"""
VarGrad loss or objective for training GFlowNets.
The VarGrad (VG) loss or objective was introduced by Richter and Boustati (2020) and
Nüsken and Richter (2023). Then it was rediscovered for GFNs by Zhang et al. (2023).
.. _a link: https://arxiv.org/abs/2010.10436
.. _a link: https://arxiv.org/abs/2005.05409
.. _a link: https://arxiv.org/abs/2302.05446
"""
from torchtyping import TensorType
from gflownet.losses.trajectorybalance import TrajectoryBalance
from gflownet.utils.batch import Batch, compute_logprobs_trajectories
class VarGrad(TrajectoryBalance):
    """
    VarGrad loss or objective for training GFlowNets.

    VarGrad is a variant of Trajectory Balance that does not learn logZ. Instead,
    logZ is estimated from the batch itself as the mean of
    ``log R(x) + log P_B(tau | x) - log P_F(tau)`` over the trajectories in the
    batch. The loss of each trajectory is then the squared deviation of its own
    log-ratio from that batch estimate.
    """

    def __init__(self, **kwargs):
        """
        Initialization method for the VarGrad loss class.

        Parameters
        ----------
        **kwargs
            Keyword arguments forwarded unchanged to
            ``TrajectoryBalance.__init__``.

        Attributes
        ----------
        name : str
            The name of the loss or objective function: VarGrad
        acronym : str
            The acronym of the loss or objective function: vg
        id : str
            The identifier of the loss or objective function: vargrad

        NOTE(review): ``name``, ``acronym`` and ``id`` are not assigned in this
        method. Presumably they are set by the parent class from its
        configuration; confirm against ``TrajectoryBalance``.

        Raises
        ------
        AssertionError
            If, after the parent initialization, ``forward_policy`` or
            ``backward_policy`` is ``None``. Both are required to compute the
            loss.
        """
        super().__init__(**kwargs)
        assert self.forward_policy is not None, "VarGrad requires a forward policy"
        assert self.backward_policy is not None, "VarGrad requires a backward policy"
        # Attribute to indicate that logZ is *not* required in the computation of the
        # loss, unlike in the parent class TrajectoryBalance
        self._requires_log_z = False

    def compute_losses_of_batch(self, batch: Batch) -> TensorType["batch_size"]:
        """
        Computes the VarGrad loss for each trajectory of the input batch.

        The VarGrad loss or objective is computed as defined in equation 8 of
        Zhang et al. (2023), https://arxiv.org/abs/2302.05446:

            log_z  = mean_i [log R(x_i) + log P_B(tau_i) - log P_F(tau_i)]
            loss_i = (log_z + log P_F(tau_i) - log P_B(tau_i) - log R(x_i)) ** 2

        The batch estimate ``log_z`` is detached from the computational graph.
        Gradients therefore flow only through each trajectory's own terms, not
        through the batch mean.

        Parameters
        ----------
        batch : Batch
            A batch of trajectories.

        Returns
        -------
        tensor
            The loss of each trajectory in the batch.
        """
        # Get logprobs of forward and backward transitions
        logprobs_f = compute_logprobs_trajectories(
            batch, forward_policy=self.forward_policy, backward=False
        )
        logprobs_b = compute_logprobs_trajectories(
            batch, backward_policy=self.backward_policy, backward=True
        )
        # Get rewards from batch, in the same (trajectory) order as the logprobs
        logrewards = batch.get_terminating_rewards(log=True, sort_by="trajectory")
        # Estimate the expected logZ as the average over the batch. detach() is a
        # no-op on tensors that do not require grad, so it is applied
        # unconditionally. Because deviations from the mean sum to zero,
        # detaching does not change the gradient of the batch-averaged loss.
        log_z = (logrewards + logprobs_b - logprobs_f).detach().mean(dim=0)
        # VarGrad loss
        return (log_z + logprobs_f - logprobs_b - logrewards).pow(2)