gflownet.losses.base
Base class for GFlowNet losses or objective functions.
Warning
Should not be used directly, but subclassed to implement specific losses or objective functions for training a GFlowNet.
Classes
BaseLoss – Base class for GFlowNet losses.
Module Contents
- class gflownet.losses.base.BaseLoss(forward_policy, backward_policy=None, state_flow=None, logZ=None, early_stopping_th=0.0, ema_alpha=0.0, device='cpu', float_precision=32)
Base class for GFlowNet losses.
- Parameters:
forward_policy (gflownet.policy.base.Policy) – The forward policy to be used for training. Parameterized from gflownet.yaml:forward_policy and parsed with gflownet/utils/policy.py:set_policy.
backward_policy (gflownet.policy.base.Policy, optional) – Same as forward_policy, but for the backward policy.
state_flow (gflownet.policy.state_flow.StateFlow, optional) – Same as forward_policy and backward_policy, but for the state flow model.
logZ (Parameter, optional) – The learnable parameters for the log-partition function logZ. By default None. It may be extended to consider modelling logZ with a neural network.
early_stopping_th (float) – Threshold value for early stopping. If larger than 0.0, training is stopped when the moving average of the loss falls below this threshold. If the value is 0.0 (default), no early stopping is applied.
device (str or torch.device) – The device to be passed to torch tensors.
float_precision (int or torch.dtype) – The floating point precision to be passed to torch tensors.
ema_alpha (float) – The weight α given to the current loss in the exponential moving average of the loss used for early stopping (see do_early_stopping).
- early_stopping_th
Threshold value for early stopping. If larger than 0.0, training is stopped when the moving average of the loss falls below this threshold. If the value is 0.0, no early stopping is applied.
- Type:
float
- forward_policy
The forward policy to be used for training. Parameterized from gflownet.yaml:forward_policy and parsed with gflownet/utils/policy.py:set_policy.
- name
The name of the loss or objective function. This is meant to be nicely formatted for printing purposes, for example using capital letters and spaces.
- Type:
str
- id
The identifier of the loss or objective function. This is for processing purposes.
- Type:
str
- property requires_log_z
Returns True if the loss function requires logZ in its computation.
- abstract requires_backward_policy()
Returns True if the loss function requires a backward policy.
- Returns:
bool – Whether the loss function requires a backward policy.
- Return type:
bool
- abstract requires_state_flow_model()
Returns True if the loss function requires a state flow model.
- Returns:
bool – Whether the loss function requires a state flow model.
- Return type:
bool
- abstract is_defined_for_continuous()
Returns True if the loss function is well defined for continuous GFlowNets, that is, GFlowNets on continuous environments, and False otherwise.
- Returns:
bool – Whether the loss function is well defined for continuous GFlowNets.
- Return type:
bool
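As an illustration of the subclassing pattern that the warning at the top of this page calls for, the following minimal sketch uses Python's abc module to declare the requirement flags documented above. LossInterface and ToyTrajectoryBalance are hypothetical names, not part of the gflownet package, and LossInterface is a simplified stand-in for BaseLoss, not its actual implementation. The flag values follow the usual definition of trajectory balance (it uses logZ and a backward policy but no state flow model); the continuous flag is an assumption made for the example.

```python
from abc import ABC, abstractmethod


class LossInterface(ABC):
    """Hypothetical stand-in mirroring the BaseLoss requirement flags."""

    def __init__(self, logZ=None):
        self.logZ = logZ

    @property
    def requires_log_z(self) -> bool:
        # Default for this sketch: no logZ needed unless a subclass says so.
        return False

    @abstractmethod
    def requires_backward_policy(self) -> bool: ...

    @abstractmethod
    def requires_state_flow_model(self) -> bool: ...

    @abstractmethod
    def is_defined_for_continuous(self) -> bool: ...


class ToyTrajectoryBalance(LossInterface):
    """Hypothetical subclass whose flags follow the trajectory balance objective."""

    def __init__(self, logZ=None):
        super().__init__(logZ)
        self.name = "Trajectory Balance"  # nicely formatted, for printing
        self.id = "toy_tb"                # hypothetical identifier, for processing

    @property
    def requires_log_z(self) -> bool:
        return True  # trajectory balance learns logZ

    def requires_backward_policy(self) -> bool:
        return True  # trajectory balance uses P_B

    def requires_state_flow_model(self) -> bool:
        return False  # no state flow F(s) is needed

    def is_defined_for_continuous(self) -> bool:
        return True  # assumption for this sketch only
```

Because the methods are declared with abstractmethod, instantiating LossInterface directly raises TypeError, which matches the intent that the base class is only ever subclassed.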
- abstract compute_losses_of_batch(batch)
Computes the loss for each state or trajectory of the input batch.
- Parameters:
batch (Batch) – A batch of states or trajectories.
- Returns:
losses (tensor) – The loss of each unit in the batch. Depending on the loss function, the unit may be states (for example, for the flow matching loss) or trajectories (for example, for the trajectory balance loss). That is, the notion of batch size is different for each loss function, as it may refer to the number of states or the number of trajectories.
- Return type:
torchtyping.TensorType[batch_size]
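To make the notion of a per-unit loss concrete, here is a framework-free sketch of a trajectory-based loss that returns one value per trajectory, using the standard trajectory balance residual (log Z + Σ log P_F - log R(x) - Σ log P_B), squared. The Trajectory record and the function signature are hypothetical: the documented method takes a gflownet Batch and returns a torch tensor, whereas plain Python floats are used here.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Trajectory:
    """Hypothetical per-trajectory record; the real Batch class differs."""
    logprobs_forward: List[float]   # log P_F(s_{t+1} | s_t) for each transition
    logprobs_backward: List[float]  # log P_B(s_t | s_{t+1}) for each transition
    log_reward: float               # log R(x) of the terminating state x


def compute_losses_of_batch(trajectories: List[Trajectory], log_z: float) -> List[float]:
    """Return one squared trajectory balance residual per trajectory."""
    losses = []
    for traj in trajectories:
        residual = (
            log_z
            + sum(traj.logprobs_forward)
            - traj.log_reward
            - sum(traj.logprobs_backward)
        )
        losses.append(residual ** 2)
    return losses
```

Here the batch size is the number of trajectories; a state-based loss such as flow matching would instead return one value per state.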
- abstract aggregate_losses_of_batch(losses, batch)
Aggregates the losses computed from a batch to obtain the average loss and/or multiple averages over different relevant parts of the batch.
The result is returned as a dictionary whose keys are the identifiers of each type of aggregation and whose values are the aggregated losses.
It is expected that one of the keys in the dictionary is ‘all’ and that its value corresponds to the overall loss, which may be used to compute the gradient with respect to graph leaves with backward().
- Parameters:
losses (tensor) – The loss of each unit (state or trajectory) in the batch.
batch (Batch) – A batch of states or trajectories.
- Returns:
loss_dict (dict) – A dictionary of loss aggregations. The keys are the identifiers of each type of aggregation and the values are the aggregated losses.
- Return type:
dict[str, float]
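A minimal sketch of such an aggregation for a state-based loss, assuming each unit is flagged as terminating or intermediate. Only the 'all' key comes from the contract above; the 'term' and 'inter' keys, the is_terminating list and the function signature are hypothetical, and plain floats stand in for tensors.

```python
from typing import Dict, List


def aggregate_losses_of_batch(losses: List[float], is_terminating: List[bool]) -> Dict[str, float]:
    """Hypothetical aggregation: overall mean plus two illustrative sub-averages."""

    def mean(values: List[float]) -> float:
        # NaN signals that a sub-group is empty in this batch.
        return sum(values) / len(values) if values else float("nan")

    term = [loss for loss, flag in zip(losses, is_terminating) if flag]
    inter = [loss for loss, flag in zip(losses, is_terminating) if not flag]
    # 'all' is the overall loss expected by the documented contract.
    return {"all": mean(losses), "term": mean(term), "inter": mean(inter)}
```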
- compute(batch, get_sublosses=False)
Computes and aggregates the losses of a batch of states or trajectories.
- Parameters:
batch (Batch) – A batch of states or trajectories.
get_sublosses (bool) – Whether specific, relevant sub-aggregations of the loss should be computed and returned as a dictionary. Examples of sub-losses are the average loss over the terminating states, over the intermediate states, over the on-policy trajectories, over the replay buffer trajectories, etc. If True, the returned variable is a dictionary. If False, only the mean over all losses in the batch is returned.
- Returns:
float or dict – A float containing the average loss, or a dictionary of loss aggregations, depending on the value of get_sublosses.
- Return type:
Union[float, dict[str, float]]
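One plausible way for compute() to combine the two abstract steps is a template method: the base class implements compute() and subclasses supply the per-unit losses and the aggregations. The sketch below works under that assumption; LossSketch and SquaredErrorLoss are hypothetical classes, the "max" sub-loss is invented for illustration, and the actual BaseLoss.compute may differ in its details.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Union


class LossSketch(ABC):
    """Hypothetical base class showing compute() delegating to two abstract steps."""

    @abstractmethod
    def compute_losses_of_batch(self, batch) -> List[float]: ...

    @abstractmethod
    def aggregate_losses_of_batch(self, losses: List[float], batch) -> Dict[str, float]: ...

    def compute(self, batch, get_sublosses: bool = False) -> Union[float, Dict[str, float]]:
        losses = self.compute_losses_of_batch(batch)
        if get_sublosses:
            # Dictionary of aggregations, including the overall loss under 'all'.
            return self.aggregate_losses_of_batch(losses, batch)
        # Plain mean over all units (states or trajectories) in the batch.
        return sum(losses) / len(losses)


class SquaredErrorLoss(LossSketch):
    """Toy subclass: each batch element is a residual and its loss is the square."""

    def compute_losses_of_batch(self, batch) -> List[float]:
        return [r ** 2 for r in batch]

    def aggregate_losses_of_batch(self, losses, batch) -> Dict[str, float]:
        return {"all": sum(losses) / len(losses), "max": max(losses)}
```

Keeping compute() concrete gives every loss the same entry point, while the meaning of a unit (state or trajectory) is left to each subclass.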
- set_log_z(logZ)
Sets the input logZ as an attribute of the class instance.
- Parameters:
logZ (Parameter) – The learnable parameters for the log-partition function logZ.
- do_early_stopping(loss)
Returns True if early stopping criteria are met, according to an exponential moving average of the loss and a loss threshold.
Early stopping is applied only if self.early_stopping_th is larger than 0.
The exponential moving average (EMA) is applied as follows:
\[\ell_{EMA}(t=0) = \ell(t=0)\]
\[\ell_{EMA}(t) = \alpha \cdot \ell(t) + (1 - \alpha) \cdot \ell_{EMA}(t-1),\]
where \(\ell_{EMA}(t)\) is the exponential moving average of the loss at iteration \(t\) and \(\ell(t)\) is the global average loss at iteration \(t\).
See:
- Parameters:
loss (float) – The current value of the loss.
- Returns:
bool – Whether early stopping criteria are met.
- Return type:
bool
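The EMA rule above can be sketched as a small stateful helper. EarlyStopper and its loss_ema attribute are hypothetical, and mapping the constructor argument ema_alpha to α is an assumption based on the parameter name. The comparison uses a strict less-than, matching the description of the moving average falling under the threshold.

```python
from typing import Optional


class EarlyStopper:
    """Hypothetical helper implementing the documented EMA early-stopping rule."""

    def __init__(self, early_stopping_th: float, ema_alpha: float):
        self.early_stopping_th = early_stopping_th
        self.ema_alpha = ema_alpha  # assumed to play the role of alpha
        self.loss_ema: Optional[float] = None  # undefined before the first loss

    def do_early_stopping(self, loss: float) -> bool:
        if self.loss_ema is None:
            # t = 0: the EMA is initialised with the first loss value.
            self.loss_ema = loss
        else:
            # t > 0: alpha * current loss + (1 - alpha) * previous EMA.
            self.loss_ema = self.ema_alpha * loss + (1.0 - self.ema_alpha) * self.loss_ema
        # A threshold of 0.0 disables early stopping.
        if self.early_stopping_th <= 0.0:
            return False
        return self.loss_ema < self.early_stopping_th
```

With a threshold of 0.5 and α = 0.5, the losses 1.0, 0.0, 0.0 give EMA values 1.0, 0.5 and 0.25, so only the third call returns True.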