gflownet.losses.flowmatching
Flow Matching loss or objective for training GFlowNets.
The Flow Matching (FM) loss or objective was defined by Bengio et al. (2021).
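For reference, the FM objective for a trajectory τ can be written as follows. This is transcribed from Equation 12 of Bengio et al. (2021), and the symbols are the paper's, not identifiers of this package. F^log_θ(s, a) is the learned log edge flow, T(s, a) is the state reached by taking action a in state s, A(s') is the set of actions available in s', R(s') is the reward (taken as zero for non-terminal states), and ε is a small constant for numerical stability.

```latex
\tilde{\mathcal{L}}_{\theta,\epsilon}(\tau) = \sum_{s' \in \tau \neq s_0} \left( \log\left[\epsilon + \sum_{s,a:\,T(s,a)=s'} \exp F^{\log}_\theta(s,a)\right] - \log\left[\epsilon + R(s') + \sum_{a' \in \mathcal{A}(s')} \exp F^{\log}_\theta(s',a')\right] \right)^2
```

Each term of the outer sum compares the flow entering s' with the reward at s' plus the flow leaving s'. The loss is zero when inflow and outflow match at every visited state.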
Classes
- FlowMatching – Initialization method for the Flow Matching loss class.
Module Contents
- class gflownet.losses.flowmatching.FlowMatching(**kwargs)
Bases: gflownet.losses.base.BaseLoss
Initialization method for the Flow Matching loss class.
- requires_backward_policy()
Returns True if the loss function requires a backward policy.
The Flow Matching loss is well defined with a forward policy model only, hence False is returned.
- Returns:
False
- Return type:
bool
- requires_state_flow_model()
Returns True if the loss function requires a state flow model.
The Flow Matching loss is well defined with a forward policy model only, hence False is returned.
- Returns:
False
- Return type:
bool
- is_defined_for_continuous()
Returns True if the loss function is well defined for continuous GFlowNets (that is, GFlowNets on continuous environments), or False otherwise.
The Flow Matching loss is currently not well defined for continuous GFlowNets; therefore, this method returns False.
- Returns:
False
- Return type:
bool
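The three capability methods above are boolean flags that a training setup can inspect before using the loss. The sketch below illustrates one way such flags might be consumed. `FlowMatchingFlags` is a hypothetical stand-in that only mirrors the documented return values, and `check_loss_setup` is a hypothetical helper; neither is part of the gflownet package.

```python
from typing import List


class FlowMatchingFlags:
    """Hypothetical stand-in mirroring the documented return values of FlowMatching."""

    def requires_backward_policy(self) -> bool:
        return False  # FM is defined with a forward policy model only

    def requires_state_flow_model(self) -> bool:
        return False  # no separate state flow model is needed

    def is_defined_for_continuous(self) -> bool:
        return False  # FM is currently not defined for continuous environments


def check_loss_setup(
    loss, continuous_env: bool, has_backward_policy: bool, has_state_flow: bool
) -> List[str]:
    """Hypothetical helper: list the reasons a loss cannot be used in a given setup."""
    problems = []
    if continuous_env and not loss.is_defined_for_continuous():
        problems.append("loss is not defined for continuous environments")
    if loss.requires_backward_policy() and not has_backward_policy:
        problems.append("loss requires a backward policy")
    if loss.requires_state_flow_model() and not has_state_flow:
        problems.append("loss requires a state flow model")
    return problems


# A discrete environment with only a forward policy passes all checks.
print(check_loss_setup(FlowMatchingFlags(), False, False, False))  # → []
```

Because FM needs neither a backward policy nor a state flow model, the only check that can fail for it in this sketch is the continuous-environment one.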
- compute_losses_of_batch(batch)
Computes the Flow Matching loss for each state of the input batch.
The Flow Matching (FM) loss or objective is computed in this method as defined in Equation 12 of Bengio et al. (2021), except that the outer sum over the states of each trajectory is omitted here, so that one loss value is returned per state.
- Parameters:
batch (Batch) – A batch of states.
- Returns:
losses (tensor) – The loss of each state in the batch.
- Return type:
torchtyping.TensorType[batch_size]
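The per-state term of Equation 12 of Bengio et al. (2021) can be sketched in plain Python as below. This is not the package's implementation, which operates on a `Batch` and tensors. The dictionary-based graph representation, the function name `flow_matching_losses`, and the `eps` default are assumptions for illustration only. As in the paper, the initial state s0 is excluded from the states that receive a loss.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical graph representation (not the gflownet package's data structures):
# each edge (parent, child) maps to a log edge flow F^log(s, a).
Edge = Tuple[str, str]


def flow_matching_losses(
    log_edge_flows: Dict[Edge, float],
    rewards: Dict[str, float],
    states: List[str],
    eps: float = 1e-8,
) -> Dict[str, float]:
    """Return the FM loss of each state, i.e. one term of the omitted outer sum."""
    losses = {}
    for s in states:
        # Total flow entering s: sum of exp(F^log) over edges whose child is s.
        inflow = sum(math.exp(f) for (_, child), f in log_edge_flows.items() if child == s)
        # Total flow leaving s: sum of exp(F^log) over edges whose parent is s.
        outflow = sum(math.exp(f) for (parent, _), f in log_edge_flows.items() if parent == s)
        # Reward is taken as zero for states without an entry (non-terminal states).
        reward = rewards.get(s, 0.0)
        diff = math.log(eps + inflow) - math.log(eps + reward + outflow)
        losses[s] = diff ** 2
    return losses


# Tiny DAG: s0 -> a -> x and s0 -> b -> x, with R(x) = 2 and all edge flows equal to 1.
consistent = {("s0", "a"): 0.0, ("s0", "b"): 0.0, ("a", "x"): 0.0, ("b", "x"): 0.0}
print(flow_matching_losses(consistent, {"x": 2.0}, ["a", "b", "x"]))
# → {'a': 0.0, 'b': 0.0, 'x': 0.0}
```

With consistent flows every per-state loss is zero. Halving the flow on edge (a, x) makes both state a (outflow too small) and state x (inflow below its reward) incur a positive loss. A production version would compute the log-sums with a numerically stable log-sum-exp rather than exponentiating and summing directly.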
- aggregate_losses_of_batch(losses, batch)
Aggregates the losses computed from a batch to obtain the overall average loss and the average loss over terminating states and intermediate states.
The result is returned as a dictionary with the following items:
- 'all': Overall average loss
- 'Loss (terminating)': Average loss over terminating states
- 'Loss (non-term.)': Average loss over non-terminating (intermediate) states
- Parameters:
losses (tensor) – The loss of each state in the batch.
batch (Batch) – A batch of states.
- Returns:
loss_dict (dict) – A dictionary of loss aggregations.
- Return type:
dict[str, float]
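The aggregation can be sketched as follows. The real method receives a `Batch`. Here, a plain list of per-state losses and a hypothetical boolean list `is_terminating` stand in for it. The function name `aggregate_losses` is illustrative, and returning NaN for an empty group is an assumption of this sketch, not documented behaviour. Only the three dictionary keys come from the description above.

```python
import math
from typing import Dict, List, Sequence


def aggregate_losses(losses: Sequence[float], is_terminating: Sequence[bool]) -> Dict[str, float]:
    """Hypothetical sketch: average losses overall and per state type."""
    if len(losses) != len(is_terminating):
        raise ValueError("losses and is_terminating must have the same length")

    def mean(values: List[float]) -> float:
        # Assumption: an empty group yields NaN rather than raising.
        return sum(values) / len(values) if values else math.nan

    term = [loss for loss, t in zip(losses, is_terminating) if t]
    non_term = [loss for loss, t in zip(losses, is_terminating) if not t]
    return {
        "all": mean(list(losses)),
        "Loss (terminating)": mean(term),
        "Loss (non-term.)": mean(non_term),
    }


print(aggregate_losses([1.0, 2.0, 3.0, 6.0], [False, True, False, True]))
# → {'all': 3.0, 'Loss (terminating)': 4.0, 'Loss (non-term.)': 2.0}
```

Reporting the terminating and intermediate averages separately helps show whether the flow mismatch is concentrated at terminal states, where the reward enters the loss, or along the trajectory.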