gflownet.losses.flowmatching

Flow Matching loss or objective for training GFlowNets.

The Flow Matching (FM) loss or objective was defined by Bengio et al. (2021).
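For reference, the objective in the log-space form of Bengio et al. (2021) can be sketched as below. The symbols are assumptions of this sketch: $F^{\log}_\theta(s,a)$ is the learned log edge flow, $T(s,a)$ is the state reached from $s$ by action $a$, $\mathcal{A}(s')$ is the set of actions available at $s'$, $R(s')$ is the reward (zero for non-terminal states) and $\epsilon$ is a small smoothing constant.

```latex
\mathcal{L}_{\theta,\epsilon}(\tau) = \sum_{s' \in \tau,\, s' \neq s_0}
\left(
  \log\Big[\epsilon + \sum_{s,a:\,T(s,a)=s'} \exp F^{\log}_\theta(s,a)\Big]
  - \log\Big[\epsilon + R(s') + \sum_{a' \in \mathcal{A}(s')} \exp F^{\log}_\theta(s',a')\Big]
\right)^2
```

Each squared term compares the flow entering $s'$ with the flow leaving it plus its reward; for a terminal state the outflow sum is empty, so its inflow is matched against its reward alone.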

Classes

FlowMatching

Initialization method for the Flow Matching loss class.

Module Contents

class gflownet.losses.flowmatching.FlowMatching(**kwargs)

Bases: gflownet.losses.base.BaseLoss

Initialization method for the Flow Matching loss class.

name

The name of the loss or objective function: Flow Matching

Type:

str

acronym

The acronym of the loss or objective function: FM

Type:

str

id

The identifier of the loss or objective function: flowmatching

Type:

str

name = 'Flow Matching'
acronym = 'FM'
id = 'flowmatching'
requires_backward_policy()

Returns True if the loss function requires a backward policy.

The Flow Matching loss is well defined with a forward policy model only, hence False is returned.

Returns:

False

Return type:

bool

requires_state_flow_model()

Returns True if the loss function requires a state flow model.

The Flow Matching loss is well defined with a forward policy model only, hence False is returned.

Returns:

False

Return type:

bool

is_defined_for_continuous()

Returns True if the loss function is well defined for continuous GFlowNets, that is, for continuous environments, and False otherwise.

The Flow Matching loss is currently not well defined for continuous GFlowNets, so this method returns False.

Returns:

False

Return type:

bool
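The attributes and predicates documented above can be mirrored in a small stand-in class to show how a training loop might use them. This is a sketch under assumptions: `FlowMatchingSpec` and `models_to_build` are hypothetical names, not part of the gflownet API, and the model names ("forward_policy", "backward_policy", "state_flow") are illustrative only.

```python
from typing import List


class FlowMatchingSpec:
    """Hypothetical stand-in mirroring the documented interface of
    gflownet.losses.flowmatching.FlowMatching; not the real class."""

    name = "Flow Matching"
    acronym = "FM"
    id = "flowmatching"

    def requires_backward_policy(self) -> bool:
        # Documented behaviour: FM needs a forward policy model only.
        return False

    def requires_state_flow_model(self) -> bool:
        # Documented behaviour: no separate state flow model is needed.
        return False

    def is_defined_for_continuous(self) -> bool:
        # Documented behaviour: FM is currently not defined for continuous environments.
        return False


def models_to_build(loss, continuous_env: bool) -> List[str]:
    """Hypothetical trainer helper: choose which models to instantiate for a loss."""
    if continuous_env and not loss.is_defined_for_continuous():
        raise ValueError(
            f"{loss.name} ({loss.acronym}) is not defined for continuous environments"
        )
    models = ["forward_policy"]  # assumed to be needed by every loss in this sketch
    if loss.requires_backward_policy():
        models.append("backward_policy")
    if loss.requires_state_flow_model():
        models.append("state_flow")
    return models


print(models_to_build(FlowMatchingSpec(), continuous_env=False))  # ['forward_policy']
```

Querying the loss object rather than hard-coding per-loss branches keeps the trainer agnostic to which objective is configured.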

compute_losses_of_batch(batch)

Computes the Flow Matching loss for each state of the input batch.

The Flow Matching (FM) loss or objective is computed in this method as defined in Equation 12 of Bengio et al. (2021), except that the outer sum over the states of a trajectory is omitted, so that one loss value is returned per state.

Parameters:

batch (Batch) – A batch of states.

Returns:

losses (tensor) – The loss of each state in the batch.

Return type:

torchtyping.TensorType[batch_size]
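A minimal, framework-free sketch of the per-state term, assuming each state comes with the log edge flows into it from its parents, the log edge flows out of it along its valid actions, and its reward (zero for intermediate states). `StateFlows`, `log_eps_sum_exp`, `fm_losses_of_batch` and the default `eps` are hypothetical; the library itself operates on a `Batch` of tensors and may treat terminating states and numerical smoothing differently.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class StateFlows:
    """Hypothetical per-state record (not the library's Batch class)."""
    parent_logflows: List[float]  # log F(s, a) for every (parent s, action a) reaching this state
    child_logflows: List[float]   # log F(s', a') for every valid action a' out of this state
    reward: float = 0.0           # R(s'); 0 for intermediate states


def log_eps_sum_exp(constant: float, logs: List[float]) -> float:
    """Numerically stable log(constant + sum(exp(x) for x in logs)); needs constant > 0."""
    terms = [math.log(constant)] + list(logs)
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


def fm_losses_of_batch(batch: List[StateFlows], eps: float = 1e-8) -> List[float]:
    """Per-state FM loss: squared log-ratio of inflow to (outflow + reward).
    The outer sum over states is omitted, so one value is returned per state."""
    if eps <= 0:
        raise ValueError("eps must be positive")
    losses = []
    for state in batch:
        inflow = log_eps_sum_exp(eps, state.parent_logflows)
        outflow = log_eps_sum_exp(eps + state.reward, state.child_logflows)
        losses.append((inflow - outflow) ** 2)
    return losses


batch = [
    StateFlows([math.log(3.0)], [0.0, math.log(2.0)]),      # balanced: 3 in, 1 + 2 out
    StateFlows([math.log(2.0)], [], reward=2.0),            # terminal: inflow matches reward
    StateFlows([0.0], [math.log(4.0)]),                     # unbalanced: 1 in, 4 out
]
print(fm_losses_of_batch(batch))
```

Working in log space with a log-sum-exp keeps the computation stable when flows span many orders of magnitude, which is why the objective is stated with logarithms and the smoothing constant.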

aggregate_losses_of_batch(losses, batch)

Aggregates the losses computed from a batch to obtain the overall average loss and the average loss over terminating states and intermediate states.

The result is returned as a dictionary with the following items:

  • 'all': Overall average loss

  • 'Loss (terminating)': Average loss over terminating states

  • 'Loss (non-term.)': Average loss over non-terminating (intermediate) states

Parameters:
  • losses (tensor) – The loss of each state in the batch.

  • batch (Batch) – A batch of states.

Returns:

loss_dict (dict) – A dictionary of loss aggregations.

Return type:

dict[str, float]
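The aggregation can be sketched in plain Python, assuming the per-state losses and a parallel list of flags marking which states are terminating. This is a hypothetical stand-in: it takes `done` flags instead of a `Batch`, and returning NaN for an empty group is a choice of this sketch, not necessarily the library's behaviour.

```python
from typing import Dict, List


def aggregate_losses_of_batch(losses: List[float], done: List[bool]) -> Dict[str, float]:
    """Hypothetical aggregation using the documented dictionary keys.

    `done[i]` is assumed to be True when state i is a terminating state.
    """
    if len(losses) != len(done):
        raise ValueError("losses and done must have the same length")

    def mean(values: List[float]) -> float:
        # Empty groups yield NaN in this sketch.
        return sum(values) / len(values) if values else float("nan")

    terminating = [loss for loss, is_done in zip(losses, done) if is_done]
    non_terminating = [loss for loss, is_done in zip(losses, done) if not is_done]
    return {
        "all": mean(losses),
        "Loss (terminating)": mean(terminating),
        "Loss (non-term.)": mean(non_terminating),
    }


print(aggregate_losses_of_batch([1.0, 3.0, 5.0], [False, False, True]))
# {'all': 3.0, 'Loss (terminating)': 5.0, 'Loss (non-term.)': 2.0}
```

Reporting the two groups separately makes it visible when, for example, the reward-matching terms at terminating states dominate the overall loss.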