# Source code for gflownet.losses.vargrad

"""
VarGrad loss or objective for training GFlowNets.

The VarGrad (VG) loss or objective was introduced by Richter and Boustati (2020) and
Nüsken and Richter (2023). It was later rediscovered for GFlowNets (GFNs) by Zhang et
al. (2023).

    .. _Richter and Boustati (2020): https://arxiv.org/abs/2010.10436
    .. _Nüsken and Richter (2023): https://arxiv.org/abs/2005.05409
    .. _Zhang et al. (2023): https://arxiv.org/abs/2302.05446
"""

from torchtyping import TensorType

from gflownet.losses.trajectorybalance import TrajectoryBalance
from gflownet.utils.batch import Batch, compute_logprobs_trajectories


class VarGrad(TrajectoryBalance):
    """
    VarGrad (VG) loss or objective.

    Subclass of TrajectoryBalance that does not rely on a logZ provided from outside:
    the log-partition term is instead estimated from the batch itself inside
    :py:meth:`compute_losses_of_batch`.
    """

    def __init__(self, **kwargs):
        """
        Sets up the VarGrad loss.

        All keyword arguments are forwarded unchanged to the parent class
        constructor. After the parent initialization, both a forward and a backward
        policy must be available in the instance.

        Attributes
        ----------
        name : str
            The name of the loss or objective function: VarGrad
        acronym : str
            The acronym of the loss or objective function: VG
        id : str
            The identifier of the loss or objective function: vargrad
        """
        super().__init__(**kwargs)

        # Both policies are used to compute the loss, so neither can be missing.
        assert self.forward_policy is not None
        assert self.backward_policy is not None

        # Flag that logZ is *not* needed to compute this loss, in contrast with the
        # parent class TrajectoryBalance.
        self._requires_log_z = False

        self.name = "VarGrad"
        self.acronym = "VG"
        self.id = "vargrad"

    def compute_losses_of_batch(self, batch: Batch) -> TensorType["batch_size"]:
        """
        Computes the VarGrad loss of every trajectory in the input batch.

        The implementation follows equation 8 of Zhang et al. (2023),
        https://arxiv.org/pdf/2302.05446. Given the forward log-probabilities
        ``log_pf``, the backward log-probabilities ``log_pb`` and the terminating
        log-rewards ``log_r``, the steps are:

        1. logZ is estimated as the mean over dimension 0 of
           ``log_r + log_pb - log_pf``. If either log-probability tensor requires
           gradients, the estimate is detached so that no gradient flows through it.
        2. The loss is ``(logZ + log_pf - log_pb - log_r) ** 2``.

        Parameters
        ----------
        batch : Batch
            A batch of trajectories.

        Returns
        -------
        tensor
            The loss of each trajectory in the batch.
        """
        # Log-probabilities of the forward and of the backward transitions
        log_pf = compute_logprobs_trajectories(
            batch, forward_policy=self.forward_policy, backward=False
        )
        log_pb = compute_logprobs_trajectories(
            batch, backward_policy=self.backward_policy, backward=True
        )

        # Log-rewards of the terminating states
        log_r = batch.get_terminating_rewards(log=True, sort_by="trajectory")

        # Batch estimate of logZ. It is detached only if at least one of the
        # log-probability tensors is tracking gradients.
        log_z_samples = log_r + log_pb - log_pf
        if log_pf.requires_grad or log_pb.requires_grad:
            log_z_samples = log_z_samples.detach()
        log_z = log_z_samples.mean(dim=0)

        # VarGrad loss: squared residual with respect to the estimated logZ
        residuals = log_z + log_pf - log_pb - log_r
        return residuals.pow(2)