gflownet.losses.vargrad

VarGrad loss or objective for training GFlowNets.

The VarGrad (VG) loss or objective was introduced by Richter and Boustati (2020) and Nüsken and Richter (2023), and was later rediscovered for GFlowNets (GFNs) by Zhang et al. (2023).

Classes

VarGrad

Initialization method for the VarGrad loss class.

Module Contents

class gflownet.losses.vargrad.VarGrad(**kwargs)

Bases: gflownet.losses.trajectorybalance.TrajectoryBalance

Initialization method for the VarGrad loss class.

name

The name of the loss or objective function: VarGrad

Type:

str

acronym

The acronym of the loss or objective function: VG

Type:

str

id

The identifier of the loss or objective function: vargrad

Type:

str

name = 'VarGrad'
acronym = 'VG'
id = 'vargrad'
compute_losses_of_batch(batch)

Computes the VarGrad loss for each trajectory of the input batch.

This method computes the VarGrad loss or objective as defined in Equation 8 of Zhang et al. (2023).
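As a reading aid, the per-trajectory objective can be sketched as follows, assuming the usual trajectory-balance notation: a batch of K trajectories τ_k, each of length n ending in a terminating state x, a reward R, and forward and backward policies P_F and P_B with parameters θ. The symbols ζ and K are our own labels; Equation 8 of Zhang et al. (2023) remains the authoritative statement.

```latex
\zeta(\tau) = \log R(x)
  + \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1}; \theta)
  - \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t; \theta),
\qquad
\mathcal{L}_{\mathrm{VG}}(\tau_k)
  = \Big( \zeta(\tau_k) - \frac{1}{K} \sum_{j=1}^{K} \zeta(\tau_j) \Big)^{2}.
```

Each ζ(τ_k) is the estimate of log Z implied by the trajectory-balance condition for that trajectory, and the batch mean takes the place of the learned log Z used by trajectory balance.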

Parameters:

batch (Batch) – A batch of trajectories.

Returns:

tensor – The loss of each trajectory in the batch.

Return type:

torchtyping.TensorType[batch_size]
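To make the per-trajectory return value concrete, here is a minimal pure-Python sketch of the VarGrad computation. It assumes each trajectory has already been reduced to its log-reward and its summed forward and backward log-probabilities. The function `vargrad_losses` and its arguments are hypothetical illustrations, not part of the `gflownet` API; the real `compute_losses_of_batch` takes a `Batch` and returns a tensor.

```python
from typing import List, Sequence


def vargrad_losses(
    log_rewards: Sequence[float],
    logprobs_f: Sequence[float],
    logprobs_b: Sequence[float],
) -> List[float]:
    """Per-trajectory VarGrad losses for one batch (illustrative sketch only).

    Hypothetical inputs, one entry per trajectory k:
      log_rewards[k]: log R(x_k) of the terminating state
      logprobs_f[k]:  sum of forward log-probabilities along trajectory k
      logprobs_b[k]:  sum of backward log-probabilities along trajectory k
    """
    if not (len(log_rewards) == len(logprobs_f) == len(logprobs_b)):
        raise ValueError("all inputs must have one entry per trajectory")
    # log Z estimate implied by trajectory balance, one per trajectory.
    log_z_estimates = [
        lr + lb - lf for lr, lf, lb in zip(log_rewards, logprobs_f, logprobs_b)
    ]
    # The batch mean stands in for a learned log Z parameter.
    mean_log_z = sum(log_z_estimates) / len(log_z_estimates)
    # Squared deviation from the batch mean: one loss value per trajectory.
    return [(z - mean_log_z) ** 2 for z in log_z_estimates]


losses = vargrad_losses([1.0, 2.0, 3.0], [-2.0, -3.0, -1.0], [-1.0, -1.0, -1.0])
print(losses)  # [1.0, 1.0, 0.0]
```

The output has one entry per trajectory, mirroring the `batch_size`-shaped return type. Averaging these values gives the sample variance (with 1/K normalisation) of the per-trajectory log Z estimates, and the losses are all zero exactly when every trajectory implies the same log Z.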