gflownet.losses.vargrad
VarGrad loss or objective for training GFlowNets.
The VarGrad (VG) loss or objective was introduced by Richter and Boustati (2020) and Nüsken and Richter (2023), and was later rediscovered for GFlowNets (GFNs) by Zhang et al. (2023).
Classes
VarGrad: Initialization method for the VarGrad loss class.
Module Contents
- class gflownet.losses.vargrad.VarGrad(**kwargs)[source]
Bases: gflownet.losses.trajectorybalance.TrajectoryBalance
Initialization method for the VarGrad loss class.
- compute_losses_of_batch(batch)[source]
Computes the VarGrad loss for each trajectory of the input batch.
The loss is computed as defined in Equation 8 of Zhang et al. (2023).
- Parameters:
batch (Batch) – A batch of trajectories.
- Returns:
tensor – The loss of each trajectory in the batch.
- Return type:
torchtyping.TensorType[batch_size]
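The per-trajectory computation can be illustrated with a minimal, framework-free sketch. It assumes the commonly stated form of the VarGrad objective, in which the learned log-partition term of trajectory balance is replaced by the batch mean of log R(x) + log P_B(τ|x) - log P_F(τ). The function name `vargrad_losses` and its three list arguments are hypothetical and are not part of the gflownet API, which operates on a Batch object and returns a tensor.

```python
from statistics import fmean


def vargrad_losses(logprobs_f, logprobs_b, logrewards):
    """Per-trajectory VarGrad losses (hypothetical helper, not the library API).

    Each argument is a list with one float per trajectory:
    forward log-probability, backward log-probability, and log-reward.
    """
    # zeta(tau) = log R(x) + log P_B(tau | x) - log P_F(tau)
    zetas = [
        lr + lb - lf
        for lf, lb, lr in zip(logprobs_f, logprobs_b, logrewards)
    ]
    # Assumption: log Z is estimated as the batch mean of zeta
    # instead of being a learned parameter as in trajectory balance.
    log_z = fmean(zetas)
    # Squared deviation from the batch estimate, one value per trajectory.
    return [(log_z - z) ** 2 for z in zetas]


losses = vargrad_losses(
    logprobs_f=[-1.0, -2.0, -3.0],
    logprobs_b=[-0.5, -0.5, -0.5],
    logrewards=[0.0, 1.0, 2.0],
)
print(losses)
```

Under this assumed form, the mean of the returned values equals the batch variance of zeta, so the loss is zero for a batch in which every trajectory yields the same estimate of log Z.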