gflownet.losses.detailedbalance
===============================

.. py:module:: gflownet.losses.detailedbalance

.. autoapi-nested-parse::

   Detailed Balance loss or objective for training GFlowNets.

   The Detailed Balance (DB) loss or objective is described in Malkin et al. (2022):

   .. _a link: https://arxiv.org/abs/2201.13259


Classes
-------

.. autoapisummary::

   gflownet.losses.detailedbalance.DetailedBalance


Module Contents
---------------

.. py:class:: DetailedBalance(**kwargs)

   Bases: :py:obj:`gflownet.losses.base.BaseLoss`

   Initialization method for the Detailed Balance loss class.

   .. attribute:: name

      The name of the loss or objective function: Detailed Balance

      :type: str

   .. attribute:: acronym

      The acronym of the loss or objective function: DB

      :type: str

   .. attribute:: id

      The identifier of the loss or objective function: detailedbalance

      :type: str

   .. py:attribute:: name
      :value: 'Detailed Balance'

   .. py:attribute:: acronym
      :value: 'DB'

   .. py:attribute:: id
      :value: 'detailedbalance'

   .. py:method:: requires_backward_policy()

      Returns True if the loss function requires a backward policy.

      The Detailed Balance loss requires a backward policy model, so this method
      returns True.

      :returns: *True*

   .. py:method:: requires_state_flow_model()

      Returns True if the loss function requires a state flow model.

      The Detailed Balance loss requires a state flow model, so this method returns
      True.

      :returns: *True*

   .. py:method:: is_defined_for_continuous()

      Returns True if the loss function is well defined for continuous GFlowNets,
      that is, GFlowNets on continuous environments, and False otherwise.

      The Detailed Balance loss is well defined for continuous GFlowNets, so this
      method returns True.

      :returns: *True*

   .. py:method:: compute_losses_of_batch(batch)

      Computes the Detailed Balance loss for each state of the input batch.

      The Detailed Balance (DB) loss or objective is computed as defined in
      Equation 11 of Malkin et al. (2022).

      .. _a link: https://arxiv.org/abs/2106.04399

      :param batch: A batch of states.
      :type batch: Batch

      :returns: **losses** (*tensor*) -- The loss of each state in the batch.

   .. py:method:: aggregate_losses_of_batch(losses, batch)

      Aggregates the losses computed from a batch to obtain the overall average
      loss, the average loss over terminating states and the average loss over
      intermediate states.

      The result is returned as a dictionary with the following items:

      - 'all': Overall average loss
      - 'Loss (terminating)': Average loss over terminating states
      - 'Loss (non-term.)': Average loss over non-terminating (intermediate) states

      :param losses: The loss of each state in the batch.
      :type losses: tensor
      :param batch: A batch of states.
      :type batch: Batch

      :returns: **loss_dict** (*dict*) -- A dictionary of loss aggregations.
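
For reference, the per-transition objective computed by ``compute_losses_of_batch`` follows the standard Detailed Balance formulation. The notation below is introduced here for illustration and is not defined elsewhere on this page: :math:`F` denotes the state flow model, :math:`P_F` and :math:`P_B` the forward and backward policies, and :math:`(s, s')` a transition from state :math:`s` to state :math:`s'`. The DB condition and the corresponding squared log-residual loss are:

.. math::

   F(s)\,P_F(s' \mid s) = F(s')\,P_B(s \mid s')

   \mathcal{L}_{\mathrm{DB}}(s, s') = \left( \log F(s) + \log P_F(s' \mid s) - \log F(s') - \log P_B(s \mid s') \right)^2

A common convention is to tie the flow of a terminal state :math:`x` to its reward, :math:`F(x) = R(x)`. Whether this class applies exactly that convention is an assumption here, not something this page states.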
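
The per-transition DB loss and the aggregation into the ``'all'``, ``'Loss (terminating)'`` and ``'Loss (non-term.)'`` entries can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the library's implementation: the helpers ``db_losses`` and ``aggregate_db_losses`` and their list-based inputs are hypothetical (the documented methods operate on a ``Batch`` and tensors), and replacing the next-state log-flow with the log-reward on terminating transitions is the standard DB convention, assumed rather than confirmed for this class.

```python
import math
from typing import Dict, List


def db_losses(
    logflow_s: List[float],
    logp_forward: List[float],
    logflow_next: List[float],
    logp_backward: List[float],
    log_reward: List[float],
    is_terminating: List[bool],
) -> List[float]:
    """Hypothetical helper: squared DB log-residual for each transition s -> s'."""
    losses = []
    for lf, lpf, lfn, lpb, lr, done in zip(
        logflow_s, logp_forward, logflow_next, logp_backward, log_reward, is_terminating
    ):
        # Assumed convention: on terminating transitions the log-flow of the
        # next (terminal) state is replaced by its log-reward.
        log_target = lr if done else lfn
        residual = lf + lpf - log_target - lpb
        losses.append(residual**2)
    return losses


def aggregate_db_losses(
    losses: List[float], is_terminating: List[bool]
) -> Dict[str, float]:
    """Hypothetical helper using the keys documented for aggregate_losses_of_batch."""

    def mean(values: List[float]) -> float:
        # Return NaN for an empty group instead of raising ZeroDivisionError.
        return sum(values) / len(values) if values else math.nan

    term = [loss for loss, done in zip(losses, is_terminating) if done]
    non_term = [loss for loss, done in zip(losses, is_terminating) if not done]
    return {
        "all": mean(losses),
        "Loss (terminating)": mean(term),
        "Loss (non-term.)": mean(non_term),
    }


# Usage: one intermediate transition and one terminating transition.
example_losses = db_losses(
    logflow_s=[0.0, 0.5],
    logp_forward=[-0.5, -1.0],
    logflow_next=[-1.0, 0.0],
    logp_backward=[0.0, 0.0],
    log_reward=[0.0, -0.5],
    is_terminating=[False, True],
)
print(example_losses)  # [0.25, 0.0]
print(aggregate_db_losses(example_losses, [False, True]))
# {'all': 0.125, 'Loss (terminating)': 0.0, 'Loss (non-term.)': 0.25}
```

Returning NaN for an empty group (for example, a batch with no terminating transitions) is a design choice of this sketch; the real method may handle that case differently.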