gflownet.losses.base

Base class for GFlowNet losses or objective functions.

Warning

This class should not be used directly. Subclass it to implement specific losses or objective functions for training a GFlowNet.

Classes

BaseLoss

Base class for GFlowNet losses.

Module Contents

class gflownet.losses.base.BaseLoss(forward_policy, backward_policy=None, state_flow=None, logZ=None, early_stopping_th=0.0, ema_alpha=0.0, device='cpu', float_precision=32)[source]

Base class for GFlowNet losses.

Parameters:
  • forward_policy (gflownet.policy.base.Policy) – The forward policy to be used for training. Parameterized from gflownet.yaml:forward_policy and parsed with gflownet/utils/policy.py:set_policy.

  • backward_policy (gflownet.policy.base.Policy, optional) – Same as forward_policy, but for the backward policy.

  • state_flow (gflownet.policy.state_flow.StateFlow, optional) – Same as forward_policy and backward_policy, but for the state flow model.

  • logZ (Parameter, optional) – The learnable parameters for the log-partition function logZ. By default None. It may be extended to consider modelling logZ with a neural network.

  • early_stopping_th (float) – Threshold value for early stopping. If larger than 0.0, training is stopped when the moving average of the loss falls below this threshold. If the value is 0.0 (default), no early stopping is applied.

  • device (str or torch.device) – The device to be passed to torch tensors.

  • float_precision (int or torch.dtype) – The floating point precision to be passed to torch tensors.

  • ema_alpha (float) – Coefficient for the exponential moving average (EMA) of the loss.

early_stopping_th[source]

Threshold value for early stopping. If larger than 0.0, training is stopped when the moving average of the loss falls below this threshold. If the value is 0.0, no early stopping is applied.

Type:

float

ema_alpha[source]

Coefficient for the exponential moving average (EMA) of the loss.

Type:

float

loss_ema[source]

The exponential moving average of the loss.

Type:

float

forward_policy[source]

The forward policy to be used for training. Parameterized from gflownet.yaml:forward_policy and parsed with gflownet/utils/policy.py:set_policy.

Type:

gflownet.policy.base.Policy

backward_policy[source]

Same as forward_policy, but for the backward policy.

Type:

gflownet.policy.base.Policy

state_flow[source]

State flow config dictionary. Defaults to None.

Type:

dict

logZ[source]

The learnable parameters for the log-partition function logZ.

Type:

Parameter

device[source]

The device to be passed to torch tensors.

Type:

torch.device

float[source]

The floating point precision to be passed to torch tensors.

Type:

torch.dtype

name[source]

The name of the loss or objective function. This is meant to be nicely formatted for printing purposes, for example using capital letters and spaces.

Type:

str

acronym[source]

The acronym of the loss or objective function.

Type:

str

id[source]

The identifier of the loss or objective function. This is for processing purposes.

Type:

str

early_stopping_th = 0.0[source]
ema_alpha = 0.0[source]
loss_ema = None[source]
forward_policy[source]
backward_policy = None[source]
state_flow = None[source]
logZ = None[source]
device = 'cpu'[source]
float = 32[source]
name = 'Base Loss'[source]
acronym = ''[source]
id = 'base'[source]
property requires_log_z[source]

Returns True if the loss function requires logZ in its computation.

abstract requires_backward_policy()[source]

Returns True if the loss function requires a backward policy.

Returns:

bool – Whether the loss function requires a backward policy.

Return type:

bool

abstract requires_state_flow_model()[source]

Returns True if the loss function requires a state flow model.

Returns:

bool – Whether the loss function requires a state flow model.

Return type:

bool

abstract is_defined_for_continuous()[source]

Returns True if the loss function is well defined for continuous GFlowNets (that is, for continuous environments), and False otherwise.

Returns:

bool – Whether the loss function is well defined for continuous GFlowNets.

Return type:

bool
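As a rough illustration of how the three abstract flags above might be used, the sketch below defines a stand-in interface and a toy subclass. It does not import gflownet: `LossInterface`, `ToyTrajectoryLoss`, and `check_components` are hypothetical names introduced only for this example, and the validation logic is an assumption rather than the library's behaviour.

```python
from abc import ABC, abstractmethod


class LossInterface(ABC):
    """Hypothetical stand-in mirroring the three abstract flags of BaseLoss."""

    @abstractmethod
    def requires_backward_policy(self) -> bool: ...

    @abstractmethod
    def requires_state_flow_model(self) -> bool: ...

    @abstractmethod
    def is_defined_for_continuous(self) -> bool: ...


class ToyTrajectoryLoss(LossInterface):
    """Hypothetical loss: needs a backward policy but no state flow model."""

    def requires_backward_policy(self) -> bool:
        return True

    def requires_state_flow_model(self) -> bool:
        return False

    def is_defined_for_continuous(self) -> bool:
        return True


def check_components(loss, backward_policy=None, state_flow=None):
    """Raise early if a component the loss declares as required is missing."""
    if loss.requires_backward_policy() and backward_policy is None:
        raise ValueError("This loss requires a backward policy.")
    if loss.requires_state_flow_model() and state_flow is None:
        raise ValueError("This loss requires a state flow model.")
```

Because the flags are abstract, the stand-in interface cannot be instantiated on its own, which matches the warning that the base class is meant to be subclassed.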

abstract compute_losses_of_batch(batch)[source]

Computes the loss for each state or trajectory of the input batch.

Parameters:

batch (Batch) – A batch of states or trajectories.

Returns:

losses (tensor) – The loss of each unit in the batch. Depending on the loss function, the unit may be states (for example, for the flow matching loss) or trajectories (for example, for the trajectory balance loss). That is, the notion of batch size is different for each loss function, as it may refer to the number of states or the number of trajectories.

Return type:

torchtyping.TensorType[batch_size]

abstract aggregate_losses_of_batch(losses, batch)[source]

Aggregates the losses computed from a batch to obtain the average loss and/or multiple averages over different relevant parts of the batch.

The result is returned as a dictionary whose keys are the identifiers of each type of aggregation and the values are the aggregated losses.

It is expected that one of the keys in the dictionary is ‘all’ and its value corresponds to the overall loss, which may be used to compute the gradient with respect to graph leaves with backward().

Parameters:
  • losses (tensor) – The loss of each unit (state or trajectory) in the batch.

  • batch (Batch) – A batch of states or trajectories.

Returns:

loss_dict (dict) – A dictionary of loss aggregations. The keys are the identifiers of each type of aggregation and the values are the aggregated losses.

Return type:

dict[str, float]
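The shape of the returned dictionary can be sketched with plain Python lists. In this minimal example the `'all'` key follows the description above, while the `'term'` and `'interm'` keys and the `is_terminating` flags are assumptions made up for illustration; a real subclass would read such information from the batch and operate on tensors.

```python
def aggregate_losses(losses, is_terminating):
    """Average per-unit losses overall and over two hypothetical subsets."""

    def mean(values):
        # An empty subset has no meaningful average.
        return sum(values) / len(values) if values else float("nan")

    term = [loss for loss, flag in zip(losses, is_terminating) if flag]
    interm = [loss for loss, flag in zip(losses, is_terminating) if not flag]
    return {
        "all": mean(losses),  # overall loss, the value used for backpropagation
        "term": mean(term),  # hypothetical key: terminating units only
        "interm": mean(interm),  # hypothetical key: intermediate units only
    }


loss_dict = aggregate_losses([1.0, 2.0, 3.0, 6.0], [True, False, False, True])
```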

compute(batch, get_sublosses=False)[source]

Computes and aggregates the losses of a batch of states or trajectories.

Parameters:
  • batch (Batch) – A batch of states or trajectories.

  • get_sublosses (bool) – Whether specific, relevant sub-aggregations of the loss should be computed and returned as a dictionary. Examples of sub-losses are the average loss over the terminating states, over the intermediate states, over the on-policy trajectories, over the replay buffer trajectories, etc. If True, the returned variable is a dictionary. If False, only the mean over all losses in the batch is returned.

Returns:

float or dict – A float containing the average loss, or a dictionary of loss aggregations, depending on the value of get_sublosses.

Return type:

Union[float, dict[str, float]]
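The switch between the two return types might look like the sketch below. This is an assumption about the control flow based only on the description above, not the library's code: `compute_sketch` and the two toy callables are hypothetical, and lists of floats stand in for tensors.

```python
def compute_sketch(compute_losses, aggregate_losses, batch, get_sublosses=False):
    """Return a dict of aggregations or a single mean, depending on the flag."""
    losses = compute_losses(batch)
    if get_sublosses:
        return aggregate_losses(losses, batch)
    # Assumed behaviour: only the mean over all losses in the batch is returned.
    return sum(losses) / len(losses)


def toy_losses(batch):
    # Hypothetical per-unit loss: squared value of each batch element.
    return [x * x for x in batch]


def toy_aggregate(losses, batch):
    # Hypothetical aggregation with only the mandatory 'all' key.
    return {"all": sum(losses) / len(losses)}


mean_loss = compute_sketch(toy_losses, toy_aggregate, [1.0, 2.0, 3.0])
sublosses = compute_sketch(toy_losses, toy_aggregate, [1.0, 2.0, 3.0], get_sublosses=True)
```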

set_log_z(logZ)[source]

Sets the input logZ as an attribute of the class instance.

Parameters:

logZ (Parameter) – The learnable parameters for the log-partition function logZ.

do_early_stopping(loss)[source]

Returns True if the early stopping criteria are met, according to an exponential moving average of the loss and a loss threshold.

Early stopping is applied only if self.early_stopping_th is larger than 0.

The exponential moving average (EMA) is applied as follows:

\[\ell_{EMA}(t=0) = \ell(t=0)\]

\[\ell_{EMA}(t) = \alpha \cdot \ell(t) + (1 - \alpha) \cdot \ell_{EMA}(t-1),\]

where \(\ell_{EMA}(t)\) is the exponential moving average of the loss at iteration \(t\) and \(\ell(t)\) is the global average loss at iteration \(t\).

Parameters:

loss (float) – The current value of the loss.

Returns:

bool – Whether early stopping criteria are met.

Return type:

bool
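The EMA-based check can be sketched in a few lines of plain Python. The attribute names mirror those documented for the class, but `EarlyStopper` itself is a hypothetical stand-in, and details such as skipping the EMA update when early stopping is disabled and using a strict comparison against the threshold are assumptions.

```python
class EarlyStopper:
    """Hypothetical stand-in for the early stopping logic described above."""

    def __init__(self, early_stopping_th=0.0, ema_alpha=0.0):
        self.early_stopping_th = early_stopping_th
        self.ema_alpha = ema_alpha
        self.loss_ema = None  # no average until the first loss is seen

    def do_early_stopping(self, loss):
        # Early stopping is applied only if the threshold is larger than 0.
        if self.early_stopping_th <= 0.0:
            return False
        if self.loss_ema is None:
            # First iteration: the EMA is initialised with the current loss.
            self.loss_ema = loss
        else:
            # EMA(t) = alpha * loss(t) + (1 - alpha) * EMA(t - 1)
            self.loss_ema = (
                self.ema_alpha * loss + (1.0 - self.ema_alpha) * self.loss_ema
            )
        return self.loss_ema < self.early_stopping_th


stopper = EarlyStopper(early_stopping_th=0.5, ema_alpha=0.5)
decisions = [stopper.do_early_stopping(loss) for loss in [2.0, 1.0, 0.0, 0.0]]
```

With these inputs the moving average goes 2.0, 1.5, 0.75, 0.375, so only the last call signals that training should stop.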