"""
Flow Matching loss or objective for training GFlowNets.
The Flow Matching (FM) loss or objective was introduced by
`Bengio et al. (2021) <https://arxiv.org/abs/2106.04399>`_.
"""
import torch
from torchtyping import TensorType
from gflownet.losses.base import BaseLoss
from gflownet.utils.batch import Batch
class FlowMatching(BaseLoss):
    """
    Flow Matching (FM) loss or objective for training GFlowNets.

    The outputs of the forward policy model are used as log edge flows: the
    in-flow of a state is the log-sum-exp of the logits of its parents for the
    actions leading to it, and the out-flow of a non-terminating state is the
    log-sum-exp of its valid forward logits. For terminating states the out-flow
    is replaced by the log-reward.
    """

    def __init__(self, **kwargs):
        """
        Initialization method for the Flow Matching loss class.

        Attributes
        ----------
        name : str
            The name of the loss or objective function: Flow Matching
        acronym : str
            The acronym of the loss or objective function: FM
        id : str
            The identifier of the loss or objective function: flowmatching
        """
        super().__init__(**kwargs)
        # The Flow Matching loss is computed from the forward policy only.
        assert self.forward_policy is not None
        self.name = "Flow Matching"
        # Set here so the attribute documented above actually exists.
        self.acronym = "FM"
        self.id = "flowmatching"

    def requires_backward_policy(self) -> bool:
        """
        Returns True if the loss function requires a backward policy.

        The Flow Matching loss is well defined with a forward policy model only,
        hence False is returned.

        Returns
        -------
        False
        """
        return False

    def requires_state_flow_model(self) -> bool:
        """
        Returns True if the loss function requires a state flow model.

        The Flow Matching loss is well defined with a forward policy model only,
        hence False is returned.

        Returns
        -------
        False
        """
        return False

    def is_defined_for_continuous(self) -> bool:
        """
        Returns True if the loss function is well defined for continuous
        GFlowNets, that is continuous environments, or False otherwise.

        The Flow Matching loss is currently not well defined for continuous
        GFlowNets, therefore this method returns False.

        Returns
        -------
        False
        """
        return False

    # TODO: consider using epsilon
    def compute_losses_of_batch(self, batch: Batch) -> TensorType["batch_size"]:
        """
        Computes the Flow Matching loss for each state of the input batch.

        The Flow Matching (FM) loss or objective is computed in this method as
        defined in Equation 12 of
        `Bengio et al. (2021) <https://arxiv.org/abs/2106.04399>`_, except that
        the outer sum is omitted here. For each state s, the loss is
        ``(log F_in(s) - log F_out(s)) ** 2``, where ``log F_in(s)`` is the
        log-sum-exp of the forward logits of the parents of s for the actions
        leading to s, and ``log F_out(s)`` is the log-sum-exp of the valid forward
        logits of s, or the log-reward of s if s is terminating.

        Parameters
        ----------
        batch : Batch
            A batch of states.

        Returns
        -------
        losses : tensor
            The loss of each state in the batch.
        """
        assert batch.is_valid()
        # Get necessary tensors from batch
        states = batch.get_states(policy=True)
        parents, parents_actions, parents_state_idx = batch.get_parents_all(policy=True)
        done = batch.get_done()
        masks_sf = batch.get_masks_forward()
        parents_a_idx = batch.readonly_env.actions2indices(parents_actions)

        # Log-rewards are stored in a variable named outflows so that the
        # out-flows of intermediate states can be written into the same tensor.
        # The tensor is cloned because the in-place write below would otherwise
        # modify the tensor returned by the batch, which may be the batch's own
        # stored log-rewards (NOTE(review): confirm Batch.get_rewards semantics).
        outflows = batch.get_rewards(log=True).clone()

        # Compute in-flows: scatter each parent's logit for the action leading to
        # its child into the child's row; unused entries stay at -inf so they do
        # not contribute to the log-sum-exp.
        parents_logits = self.forward_policy(parents)
        inflow_logits = torch.full(
            (states.shape[0], batch.readonly_env.policy_output_dim),
            -torch.inf,
            dtype=self.float,
            device=self.device,
        )
        inflow_logits[parents_state_idx, parents_a_idx] = parents_logits[
            torch.arange(parents_logits.shape[0], device=parents_logits.device),
            parents_a_idx,
        ]
        inflows = torch.logsumexp(inflow_logits, dim=1)

        # Compute out-flows of intermediate states from their valid forward
        # logits. masked_fill is out-of-place, so the autograd-tracked policy
        # output is not modified in place.
        outflow_logits = self.forward_policy(states).masked_fill(masks_sf, -torch.inf)
        # Only non-terminating rows are reduced: terminating states keep their
        # log-reward, and their (possibly all -inf) rows never enter the graph.
        outflows[~done] = torch.logsumexp(outflow_logits[~done], dim=1)

        # Compute and return the flow matching loss for each state
        return (inflows - outflows).pow(2)

    def aggregate_losses_of_batch(
        self, losses: TensorType["batch_size"], batch: Batch
    ) -> dict[str, float]:
        """
        Aggregates the losses computed from a batch to obtain the overall average
        loss and the average loss over terminating states and intermediate states.

        The result is returned as a dictionary with the following items:
        - 'all': Overall average loss
        - 'Loss (terminating)': Average loss over terminating states
        - 'Loss (non-term.)': Average loss over non-terminating (intermediate)
          states

        The values are 0-dimensional tensors. A partial average is NaN if the
        batch contains no state of the corresponding kind. The overall average is
        always computed over all states, so it remains finite in that case.

        Parameters
        ----------
        losses : tensor
            The loss of each state in the batch.
        batch : Batch
            A batch of states.

        Returns
        -------
        loss_dict : dict
            A dictionary of loss aggregations.
        """
        done = batch.get_done()
        # Loss of terminating states
        loss_term = losses[done].mean()
        # Loss of non-terminating states
        loss_interm = losses[~done].mean()
        # Overall loss. This equals the average of the two partial means weighted
        # by the fraction of states in each group. Computing it directly avoids
        # 0 * NaN = NaN when one of the groups is empty.
        loss_overall = losses.mean()
        return {
            "all": loss_overall,
            "Loss (terminating)": loss_term,
            "Loss (non-term.)": loss_interm,
        }