gflownet.losses.base
====================

.. py:module:: gflownet.losses.base

.. autoapi-nested-parse::

   Base class for GFlowNet losses or objective functions.

   .. warning::

      This class should not be used directly. Instead, it is meant to be subclassed to implement specific losses or objective functions for training a GFlowNet.


Classes
-------

.. autoapisummary::

   gflownet.losses.base.BaseLoss


Module Contents
---------------

.. py:class:: BaseLoss(forward_policy, backward_policy = None, state_flow = None, logZ = None, early_stopping_th = 0.0, ema_alpha = 0.0, device = 'cpu', float_precision = 32)

   Base class for GFlowNet losses.

   :param forward_policy: The forward policy to be used for training. Parameterized from ``gflownet.yaml:forward_policy`` and parsed with ``gflownet/utils/policy.py:set_policy``.
   :type forward_policy: :py:class:`gflownet.policy.base.Policy`
   :param backward_policy: Same as ``forward_policy``, but for the backward policy.
   :type backward_policy: :py:class:`gflownet.policy.base.Policy`, optional
   :param state_flow: Same as ``forward_policy`` and ``backward_policy``, but for the state flow model.
   :type state_flow: :py:class:`gflownet.policy.state_flow.StateFlow`, optional
   :param logZ: The learnable parameters for the log-partition function logZ. By default None. It may be extended to model logZ with a neural network instead.
   :type logZ: Parameter, optional
   :param early_stopping_th: Threshold value for early stopping. If larger than 0.0, training is stopped when the moving average of the loss falls below this threshold. If the value is 0.0 (default), no early stopping is applied.
   :type early_stopping_th: float
   :param ema_alpha: Coefficient for the exponential moving average (EMA) of the loss, which is used for early stopping. By default 0.0.
   :type ema_alpha: float
   :param device: The device to be passed to torch tensors.
   :type device: str or torch.device
   :param float_precision: The floating point precision to be passed to torch tensors.
   :type float_precision: int or torch.dtype

   .. attribute:: early_stopping_th

      Threshold value for early stopping. If larger than 0.0, training is stopped when the moving average of the loss falls below this threshold.
      If the value is 0.0, no early stopping is applied.

      :type: float

   .. attribute:: ema_alpha

      Coefficient for the exponential moving average (EMA) of the loss.

      :type: float

   .. attribute:: loss_ema

      The exponential moving average of the loss.

      :type: float

   .. attribute:: forward_policy

      The forward policy to be used for training. Parameterized from ``gflownet.yaml:forward_policy`` and parsed with ``gflownet/utils/policy.py:set_policy``.

      :type: gflownet.policy.base.Policy

   .. attribute:: backward_policy

      Same as ``forward_policy``, but for the backward policy.

      :type: gflownet.policy.base.Policy

   .. attribute:: state_flow

      Same as ``forward_policy`` and ``backward_policy``, but for the state flow model. By default None.

      :type: gflownet.policy.state_flow.StateFlow

   .. attribute:: logZ

      The learnable parameters for the log-partition function logZ.

      :type: Parameter

   .. attribute:: device

      The device to be passed to torch tensors.

      :type: torch.device

   .. attribute:: float

      The floating point precision to be passed to torch tensors.

      :type: torch.dtype

   .. attribute:: name

      The name of the loss or objective function. It is meant to be nicely formatted for printing purposes, for example using capital letters and spaces.

      :type: str

   .. attribute:: acronym

      The acronym of the loss or objective function.

      :type: str

   .. attribute:: id

      The identifier of the loss or objective function, intended for processing purposes.

      :type: str


   .. py:attribute:: early_stopping_th
      :value: 0.0


   .. py:attribute:: ema_alpha
      :value: 0.0


   .. py:attribute:: loss_ema
      :value: None


   .. py:attribute:: forward_policy


   .. py:attribute:: backward_policy
      :value: None


   .. py:attribute:: state_flow
      :value: None


   .. py:attribute:: logZ
      :value: None


   .. py:attribute:: device
      :value: 'cpu'


   .. py:attribute:: float
      :value: 32


   .. py:attribute:: name
      :value: 'Base Loss'


   .. py:attribute:: acronym
      :value: ''


   .. py:attribute:: id
      :value: 'base'


   .. py:property:: requires_log_z

      Returns True if the loss function requires logZ in its computation.


   .. py:method:: requires_backward_policy()
      :abstractmethod:

      Returns True if the loss function requires a backward policy.

      :returns: *bool* -- Whether the loss function requires a backward policy.


   .. py:method:: requires_state_flow_model()
      :abstractmethod:

      Returns True if the loss function requires a state flow model.

      :returns: *bool* -- Whether the loss function requires a state flow model.


   .. py:method:: is_defined_for_continuous()
      :abstractmethod:

      Returns True if the loss function is well defined for continuous GFlowNets, that is, for continuous environments, and False otherwise.

      :returns: *bool* -- Whether the loss function is well defined for continuous GFlowNets.


   .. py:method:: compute_losses_of_batch(batch)
      :abstractmethod:

      Computes the loss for each state or trajectory of the input batch.

      :param batch: A batch of states or trajectories.
      :type batch: Batch

      :returns: **losses** (*tensor*) -- The loss of each unit in the batch. Depending on the loss function, a unit may be a state (for example, for the flow matching loss) or a trajectory (for example, for the trajectory balance loss). Consequently, the notion of batch size differs across loss functions, as it may refer to the number of states or to the number of trajectories.


   .. py:method:: aggregate_losses_of_batch(losses, batch)
      :abstractmethod:

      Aggregates the losses computed from a batch to obtain the average loss and/or multiple averages over different relevant parts of the batch.

      The result is returned as a dictionary whose keys are the identifiers of each type of aggregation and whose values are the aggregated losses. One of the keys is expected to be ``'all'``, whose value is the overall loss and may be used to compute the gradients with respect to the graph leaves via ``backward()``.

      :param losses: The loss of each unit (state or trajectory) in the batch.
      :type losses: tensor
      :param batch: A batch of states or trajectories.
      :type batch: Batch

      :returns: **loss_dict** (*dict*) -- A dictionary of loss aggregations. The keys are the identifiers of each type of aggregation and the values are the aggregated losses.


   .. py:method:: compute(batch, get_sublosses = False)

      Computes and aggregates the losses of a batch of states or trajectories.

      :param batch: A batch of states or trajectories.
      :type batch: Batch
      :param get_sublosses: Whether specific, relevant sub-aggregations of the loss should be computed and returned as a dictionary. Examples of sub-losses are the average loss over the terminating states, over the intermediate states, over the on-policy trajectories, over the replay buffer trajectories, etc. If True, the returned variable is a dictionary. If False, only the mean over all losses in the batch is returned.
      :type get_sublosses: bool

      :returns: *float or dict* -- Either the average loss as a float or a dictionary of loss aggregations, depending on the value of ``get_sublosses``.


   .. py:method:: set_log_z(logZ)

      Sets the input logZ as an attribute of the class instance.

      :param logZ: The learnable parameters for the log-partition function logZ.
      :type logZ: Parameter


   .. py:method:: do_early_stopping(loss)

      Returns True if the early stopping criteria are met, according to an exponential moving average of the loss and a loss threshold. Early stopping is applied only if ``self.early_stopping_th`` is larger than 0.

      The exponential moving average (EMA) is computed as follows:

      .. math::

         \ell_{EMA}(t=0) = \ell(t=0)

         \ell_{EMA}(t) = \alpha \cdot \ell(t) + (1 - \alpha) \cdot \ell_{EMA}(t-1),

      where :math:`\ell_{EMA}(t)` is the exponential moving average of the loss at iteration :math:`t`, :math:`\ell(t)` is the global average loss at iteration :math:`t`, and :math:`\alpha` is the EMA coefficient ``self.ema_alpha``.

      See `Exponential smoothing <https://en.wikipedia.org/wiki/Exponential_smoothing>`_.

      :param loss: The current value of the loss.
      :type loss: float

      :returns: *bool* -- Whether the early stopping criteria are met.
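
The EMA update and the stopping test of ``do_early_stopping`` can be sketched in plain Python as follows. This is a minimal, hypothetical illustration rather than the gflownet implementation: the class ``EarlyStopper`` is invented for this example, it only reuses the attribute names ``early_stopping_th``, ``ema_alpha`` and ``loss_ema`` documented above, and it assumes that training stops when the EMA falls strictly below the threshold.

.. code-block:: python

   from typing import Optional


   class EarlyStopper:
       """Hypothetical stand-in for the EMA logic of ``do_early_stopping``."""

       def __init__(self, early_stopping_th: float = 0.0, ema_alpha: float = 0.0) -> None:
           self.early_stopping_th = early_stopping_th
           self.ema_alpha = ema_alpha
           # No EMA exists before the first loss is observed.
           self.loss_ema: Optional[float] = None

       def do_early_stopping(self, loss: float) -> bool:
           # l_EMA(0) = l(0); afterwards l_EMA(t) = alpha * l(t) + (1 - alpha) * l_EMA(t - 1).
           if self.loss_ema is None:
               self.loss_ema = loss
           else:
               self.loss_ema = self.ema_alpha * loss + (1.0 - self.ema_alpha) * self.loss_ema
           # Assumption: a threshold of 0.0 (or less) disables early stopping.
           if self.early_stopping_th <= 0.0:
               return False
           return self.loss_ema < self.early_stopping_th


   # Usage: with alpha = 0.5 the EMA goes 1.0, 0.75, 0.5, 0.3125.
   stopper = EarlyStopper(early_stopping_th=0.4, ema_alpha=0.5)
   for step_loss in [1.0, 0.5, 0.25, 0.125]:
       stop = stopper.do_early_stopping(step_loss)
   print(stop, stopper.loss_ema)  # True 0.3125

Note that with ``ema_alpha = 0.0`` the recursion keeps the EMA fixed at the first observed loss, so a positive coefficient is needed for the moving average to track the loss over iterations.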
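
The split between the abstract methods ``compute_losses_of_batch`` and ``aggregate_losses_of_batch`` and the concrete ``compute`` method can be illustrated with a self-contained toy example. This is a hypothetical sketch, not the gflownet code: ``LossSketch``, ``ToyLoss``, the list-of-dicts stand-in for ``Batch`` with its keys ``pred``, ``target`` and ``is_terminating``, the squared-error unit loss, the use of plain Python floats instead of tensors, and the sub-loss keys ``'term'`` and ``'nonterm'`` are all assumptions. It only follows the documented contract: per-unit losses are computed first, then aggregated into a dictionary that contains the key ``'all'``, and ``compute`` returns either that dictionary or only the overall mean.

.. code-block:: python

   from abc import ABC, abstractmethod
   from statistics import fmean
   from typing import Dict, List, Union

   # Hypothetical stand-in for a gflownet Batch: one dict per unit (here, a state).
   Batch = List[dict]


   class LossSketch(ABC):
       """Hypothetical mirror of the documented BaseLoss contract."""

       @abstractmethod
       def compute_losses_of_batch(self, batch: Batch) -> List[float]:
           """Return one loss per unit (state or trajectory) of the batch."""

       @abstractmethod
       def aggregate_losses_of_batch(self, losses: List[float], batch: Batch) -> Dict[str, float]:
           """Return named aggregations; the key 'all' holds the overall loss."""

       def compute(self, batch: Batch, get_sublosses: bool = False) -> Union[float, Dict[str, float]]:
           losses = self.compute_losses_of_batch(batch)
           loss_dict = self.aggregate_losses_of_batch(losses, batch)
           # Return all aggregations, or only the mean over the whole batch.
           return loss_dict if get_sublosses else loss_dict["all"]


   class ToyLoss(LossSketch):
       """Toy per-state loss: squared error between a prediction and a target."""

       def compute_losses_of_batch(self, batch: Batch) -> List[float]:
           return [(unit["pred"] - unit["target"]) ** 2 for unit in batch]

       def aggregate_losses_of_batch(self, losses: List[float], batch: Batch) -> Dict[str, float]:
           loss_dict = {"all": fmean(losses)}
           term = [loss for loss, unit in zip(losses, batch) if unit["is_terminating"]]
           nonterm = [loss for loss, unit in zip(losses, batch) if not unit["is_terminating"]]
           # Sub-losses are only reported for non-empty subsets of the batch.
           if term:
               loss_dict["term"] = fmean(term)
           if nonterm:
               loss_dict["nonterm"] = fmean(nonterm)
           return loss_dict


   batch = [
       {"pred": 1.0, "target": 0.0, "is_terminating": False},
       {"pred": 2.0, "target": 0.0, "is_terminating": True},
   ]
   toy = ToyLoss()
   print(toy.compute(batch))  # 2.5
   print(toy.compute(batch, get_sublosses=True))  # {'all': 2.5, 'term': 4.0, 'nonterm': 1.0}

Keeping ``compute`` concrete and the two batch-level methods abstract means each subclass only decides what a unit loss is and how to group it, while the overall-versus-sub-loss switch stays in one place.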