Source code for gflownet.losses.flowmatching

"""
Flow Matching loss or objective for training GFlowNets.

The Flow Matching (FM) loss or objective was defined by Bengio et al. (2021):

    .. _a link: https://arxiv.org/abs/2106.04399
"""

import torch
from torchtyping import TensorType

from gflownet.losses.base import BaseLoss
from gflownet.utils.batch import Batch


class FlowMatching(BaseLoss):
    def __init__(self, **kwargs):
        """
        Initialization method for the Flow Matching loss class.

        Attributes
        ----------
        name : str
            The name of the loss or objective function: Flow Matching
        acronym : str
            The acronym of the loss or objective function: FM
        id : str
            The identifier of the loss or objective function: flowmatching
        """
        super().__init__(**kwargs)
        # The FM loss is computed from the outputs of the forward policy only,
        # so a forward policy model is mandatory.
        assert self.forward_policy is not None

        self.name = "Flow Matching"
        self.acronym = "FM"
        self.id = "flowmatching"

    def requires_backward_policy(self) -> bool:
        """
        Returns True if the loss function requires a backward policy.

        The Flow Matching loss is well defined with a forward policy model only,
        hence False is returned.

        Returns
        -------
        False
        """
        return False

    def requires_state_flow_model(self) -> bool:
        """
        Returns True if the loss function requires a state flow model.

        The Flow Matching loss is well defined with a forward policy model only,
        hence False is returned.

        Returns
        -------
        False
        """
        return False

    def is_defined_for_continuous(self) -> bool:
        """
        Returns True if the loss function is well defined for continuous
        GFlowNets, that is continuous environments, or False otherwise.

        The Flow Matching loss is currently not well defined for continuous
        GFlowNets, therefore this method returns False.

        Returns
        -------
        False
        """
        return False

    # TODO: consider using epsilon
    def compute_losses_of_batch(self, batch: Batch) -> TensorType["batch_size"]:
        """
        Computes the Flow Matching loss for each state of the input batch.

        The Flow Matching (FM) loss or objective is computed in this method as
        defined in Equation 12 of Bengio et al. (2021), except that the outer
        sum is omitted here:

            .. _a link: https://arxiv.org/abs/2106.04399

        For each state, the in-flow is the log-sum-exp of the forward-policy
        logits of all (parent, action) pairs leading to it. The out-flow is the
        log-reward for terminating (done) states, and the log-sum-exp of the
        unmasked forward-policy logits for non-terminating states.

        Parameters
        ----------
        batch : Batch
            A batch of states.

        Returns
        -------
        losses : tensor
            The loss of each state in the batch, i.e. the squared difference
            between its log in-flow and its log out-flow.
        """
        assert batch.is_valid()
        # Get necessary tensors from batch
        states = batch.get_states(policy=True)
        parents, parents_actions, parents_state_idx = batch.get_parents_all(
            policy=True
        )
        done = batch.get_done()
        masks_sf = batch.get_masks_forward()
        parents_a_idx = batch.readonly_env.actions2indices(parents_actions)
        # Log-rewards are used as the out-flows of terminating states, and the
        # same tensor is then filled with the out-flows of intermediate states.
        # Clone it: get_rewards() may return a tensor stored in the batch, and
        # the in-place assignment below must not overwrite the batch's
        # log-rewards (nor attach them to the autograd graph).
        outflows = batch.get_rewards(log=True).clone()

        # Compute in-flows: one row per state, one column per action index.
        # Entries without a corresponding (parent, action) pair stay at -inf
        # so that they do not contribute to the log-sum-exp.
        inflow_logits = torch.full(
            (states.shape[0], batch.readonly_env.policy_output_dim),
            -torch.inf,
            dtype=self.float,
            device=self.device,
        )
        parents_logits = self.forward_policy(parents)
        inflow_logits[parents_state_idx, parents_a_idx] = parents_logits[
            torch.arange(parents.shape[0]), parents_a_idx
        ]
        inflows = torch.logsumexp(inflow_logits, dim=1)

        # Compute out-flows of non-terminating states from the forward logits,
        # excluding actions flagged by the forward mask. The log-sum-exp is only
        # evaluated on non-terminating rows, so rows of terminating states keep
        # their log-reward.
        outflow_logits = self.forward_policy(states)
        outflow_logits[masks_sf] = -torch.inf
        outflows[~done] = torch.logsumexp(outflow_logits[~done], dim=1)

        # Compute and return the flow matching loss for each state
        return (inflows - outflows).pow(2)

    def aggregate_losses_of_batch(
        self, losses: TensorType["batch_size"], batch: Batch
    ) -> dict[str, float]:
        """
        Aggregates the losses computed from a batch to obtain the overall
        average loss and the average loss over terminating states and
        intermediate states.

        The result is returned as a dictionary with the following items:
        - 'all': Overall average loss
        - 'Loss (terminating)': Average loss over terminating states
        - 'Loss (non-term.)': Average loss over non-terminating (intermediate)
          states

        Parameters
        ----------
        losses : tensor
            The loss of each state in the batch.
        batch : Batch
            A batch of states.

        Returns
        -------
        loss_dict : dict
            A dictionary of loss aggregations. The values are scalar tensors.
            A group-specific average is NaN if the batch contains no state of
            that group, but the overall loss remains finite.
        """
        done = batch.get_done()
        # Loss of terminating states
        loss_term = losses[done].mean()
        # Loss of non-terminating states
        loss_interm = losses[~done].mean()
        # Overall loss. Weighting each group's mean by the group's share of the
        # batch is algebraically the plain mean over all states. Computing it
        # directly avoids 0 * NaN = NaN when one of the groups is empty
        # (e.g. a batch of one-step trajectories has no intermediate states).
        loss_overall = losses.mean()
        return {
            "all": loss_overall,
            "Loss (terminating)": loss_term,
            "Loss (non-term.)": loss_interm,
        }