gflownet.losses.trajectorybalance

Trajectory Balance loss or objective for training GFlowNets.

The Trajectory Balance (TB) loss or objective was defined by Malkin et al. (2022).

Classes

TrajectoryBalance

Initialization method for the Trajectory Balance loss class.

Module Contents

class gflownet.losses.trajectorybalance.TrajectoryBalance(**kwargs)[source]

Bases: gflownet.losses.base.BaseLoss

Initialization method for the Trajectory Balance loss class.

name[source]

The name of the loss or objective function: Trajectory Balance

Type:

str

acronym[source]

The acronym of the loss or objective function: TB

Type:

str

id[source]

The identifier of the loss or objective function: trajectorybalance

Type:

str

name = 'Trajectory Balance'[source]
acronym = 'TB'[source]
id = 'trajectorybalance'[source]
requires_backward_policy()[source]

Returns True if the loss function requires a backward policy.

The Trajectory Balance loss does require a backward policy model, hence True is returned.

Returns:

True

Return type:

bool

requires_state_flow_model()[source]

Returns True if the loss function requires a state flow model.

The Trajectory Balance loss does not require a state flow model, hence False is returned.

Returns:

False

Return type:

bool

is_defined_for_continuous()[source]

Returns True if the loss function is well defined for continuous GFlowNets (that is, for continuous environments), or False otherwise.

The Trajectory Balance loss is well defined for continuous GFlowNets, therefore this method returns True.

Returns:

True

Return type:

bool
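As an illustration of how the three predicates above might be consumed, here is a minimal sketch. It uses a stand-in class that only mirrors the attributes and return values documented on this page; it is not the library's implementation. The helper `missing_components` and its arguments are hypothetical names introduced for this example.

```python
class TrajectoryBalanceStandIn:
    """Stand-in mirroring the documented attributes and predicates (not the real class)."""

    name = "Trajectory Balance"
    acronym = "TB"
    id = "trajectorybalance"

    def requires_backward_policy(self):
        return True

    def requires_state_flow_model(self):
        return False

    def is_defined_for_continuous(self):
        return True


def missing_components(loss, has_backward_policy, has_state_flow, continuous_env):
    """Hypothetical helper: list what a training setup lacks for the given loss."""
    problems = []
    if loss.requires_backward_policy() and not has_backward_policy:
        problems.append("backward policy")
    if loss.requires_state_flow_model() and not has_state_flow:
        problems.append("state flow model")
    if continuous_env and not loss.is_defined_for_continuous():
        problems.append("support for continuous environments")
    return problems


loss = TrajectoryBalanceStandIn()
# A setup without a backward policy is flagged; the absent state flow model is not.
print(missing_components(loss, has_backward_policy=False, has_state_flow=False, continuous_env=True))
```

A check of this kind would let a training script fail early when the configured models do not match what the loss needs.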

compute_losses_of_batch(batch)[source]

Computes the Trajectory Balance loss for each trajectory of the input batch.

The Trajectory Balance (TB) loss is computed as defined in Equation 14 of Malkin et al. (2022).

Parameters:

batch (Batch) – A batch of trajectories.

Returns:

tensor – The loss of each trajectory in the batch.

Return type:

torchtyping.TensorType[batch_size]
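To make the per-trajectory computation concrete, the sketch below evaluates the usual form of the TB objective: the squared difference between log Z plus the summed forward log-probabilities and the log-reward plus the summed backward log-probabilities. It is a plain-Python illustration, not the library's code: the function name and all arguments are assumptions, and the real method takes a Batch object and returns a tensor.

```python
import math


def trajectory_balance_losses(log_z, sum_logprobs_fwd, sum_logprobs_bwd, log_rewards):
    """Per-trajectory TB loss (illustrative sketch, hypothetical signature).

    For each trajectory:
        (log Z + sum_t log P_F - log R(x) - sum_t log P_B) ** 2

    The three sequences hold one value per trajectory and must have equal length.
    """
    return [
        (log_z + logp_f - log_r - logp_b) ** 2
        for logp_f, logp_b, log_r in zip(sum_logprobs_fwd, sum_logprobs_bwd, log_rewards)
    ]


# First trajectory: the flows balance exactly, so the loss is zero.
# Second trajectory: the residual is 1.0 - 2.0 - 0.5 + 0.5 = -1.0, so the loss is 1.0.
losses = trajectory_balance_losses(
    log_z=1.0,
    sum_logprobs_fwd=[-1.0, -2.0],
    sum_logprobs_bwd=[0.0, -0.5],
    log_rewards=[0.0, 0.5],
)
print(losses)
```

Working in log space avoids multiplying many small probabilities along long trajectories.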

aggregate_losses_of_batch(losses, batch)[source]

Aggregates the losses computed from a batch to obtain the overall average loss.

The result is returned as a dictionary with the following items:

  • ‘all’: Overall average loss

Parameters:
  • losses (tensor) – The loss of each trajectory in the batch.

  • batch (Batch) – A batch of trajectories.

Returns:

loss_dict (dict) – A dictionary of loss aggregations.

Return type:

dict[str, float]
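The aggregation described above can be sketched as a plain mean over the per-trajectory losses, returned under the documented 'all' key. The function name is hypothetical and a list of floats stands in for the tensor; the batch argument is omitted because this simplified version does not need it.

```python
def aggregate_losses(losses):
    """Illustrative sketch: average per-trajectory losses into {'all': mean}."""
    if not losses:
        raise ValueError("losses must contain at least one value")
    return {"all": sum(losses) / len(losses)}


loss_dict = aggregate_losses([0.0, 1.0, 2.0])
print(loss_dict["all"])
```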