# Source code for gflownet.losses.trajectorybalance

"""
Trajectory Balance loss or objective for training GFlowNets.

The Trajectory Balance (TB) loss or objective was defined by
`Malkin et al. (2022) <https://arxiv.org/abs/2201.13259>`_.
"""

from torchtyping import TensorType

from gflownet.losses.base import BaseLoss
from gflownet.utils.batch import Batch, compute_logprobs_trajectories


class TrajectoryBalance(BaseLoss):
    """
    Trajectory Balance (TB) loss or objective for training GFlowNets.

    For every trajectory in a batch, :meth:`compute_losses_of_batch` returns

        (log Z + log P_F(tau) - log P_B(tau) - log R(x)) ** 2

    where log P_F and log P_B are the forward and backward log-probabilities
    of the trajectory, log R(x) is the log-reward of its terminating state and
    log Z is ``self.logZ`` reduced to a scalar with ``.sum()``.
    """

    def __init__(self, **kwargs):
        """
        Initialization method for the Trajectory Balance loss class.

        All keyword arguments are forwarded unchanged to the base class
        initializer. The ``forward_policy`` and ``backward_policy`` attributes
        are checked only after the base initializer has run.

        Attributes
        ----------
        name : str
            The name of the loss or objective function: Trajectory Balance
        acronym : str
            The acronym of the loss or objective function: TB
        id : str
            The identifier of the loss or objective function: trajectorybalance

        Raises
        ------
        AssertionError
            If ``self.forward_policy`` or ``self.backward_policy`` is ``None``
            (only while assertions are enabled, i.e. not under ``python -O``).
        """
        super().__init__(**kwargs)
        # TB needs both policies: the forward one for log P_F and the backward
        # one for log P_B (see compute_losses_of_batch).
        assert (
            self.forward_policy is not None
        ), "Trajectory Balance requires a forward policy"
        assert (
            self.backward_policy is not None
        ), "Trajectory Balance requires a backward policy"
        # Attribute to indicate that logZ is required in the computation of the loss
        self._requires_log_z = True
        self.name = "Trajectory Balance"
        self.acronym = "TB"
        self.id = "trajectorybalance"

    def requires_backward_policy(self) -> bool:
        """
        Returns True if the loss function requires a backward policy.

        The Trajectory Balance loss does require a backward policy model, hence
        True is returned.

        Returns
        -------
        True
        """
        return True

    def requires_state_flow_model(self) -> bool:
        """
        Returns True if the loss function requires a state flow model.

        The Trajectory Balance loss does not require a state flow model, hence
        False is returned.

        Returns
        -------
        False
        """
        return False

    def is_defined_for_continuous(self) -> bool:
        """
        Returns True if the loss function is well defined for continuous
        GFlowNets, that is continuous environments, or False otherwise.

        The Trajectory Balance loss is well defined for continuous GFlowNets,
        therefore this method returns True.

        Returns
        -------
        True
        """
        return True

    def compute_losses_of_batch(self, batch: Batch) -> TensorType["batch_size"]:
        """
        Computes the Trajectory Balance loss for each trajectory of the input batch.

        The Trajectory Balance (TB) loss or objective is computed as defined in
        Equation 14 of
        `Malkin et al. (2022) <https://arxiv.org/abs/2201.13259>`_:

            (log Z + log P_F(tau) - log P_B(tau) - log R(x)) ** 2

        Parameters
        ----------
        batch : Batch
            A batch of trajectories.

        Returns
        -------
        tensor
            The loss of each trajectory in the batch.
        """
        # Get logprobs of forward and backward transitions, per trajectory
        logprobs_f = compute_logprobs_trajectories(
            batch, forward_policy=self.forward_policy, backward=False
        )
        logprobs_b = compute_logprobs_trajectories(
            batch, backward_policy=self.backward_policy, backward=True
        )
        # Log-rewards of the terminating states, requested with
        # sort_by="trajectory" — assumed to align element-wise with the
        # per-trajectory log-probabilities above; confirm against Batch.
        logrewards = batch.get_terminating_rewards(log=True, sort_by="trajectory")
        # logZ is reduced with .sum() so that a multi-element parameter still
        # contributes a single scalar log-partition estimate (broadcast over
        # the batch).
        return (self.logZ.sum() + logprobs_f - logprobs_b - logrewards).pow(2)

    # TODO: extend with loss over the different types of trajectories (forward,
    # replay buffer, training set...)
    def aggregate_losses_of_batch(
        self, losses: TensorType["batch_size"], batch: Batch
    ) -> dict[str, TensorType]:
        """
        Aggregates the losses computed from a batch to obtain the overall
        average loss.

        The result is returned as a dictionary with the following items:
            - 'all': Overall average loss

        Parameters
        ----------
        losses : tensor
            The loss of each trajectory in the batch.
        batch : Batch
            A batch of trajectories. Not used by this implementation; kept for
            interface compatibility with the base class.

        Returns
        -------
        loss_dict : dict
            A dictionary of loss aggregations. The values are 0-dim tensors
            (the result of ``Tensor.mean``), not Python floats, so any autograd
            graph attached to ``losses`` is preserved.
        """
        return {
            "all": losses.mean(),
        }