gflownet.losses.flowmatching
============================

.. py:module:: gflownet.losses.flowmatching

.. autoapi-nested-parse::

   Flow Matching loss or objective for training GFlowNets.

   The Flow Matching (FM) loss or objective was defined by
   `Bengio et al. (2021) <https://arxiv.org/abs/2106.04399>`__.


Classes
-------

.. autoapisummary::

   gflownet.losses.flowmatching.FlowMatching


Module Contents
---------------

.. py:class:: FlowMatching(**kwargs)

   Bases: :py:obj:`gflownet.losses.base.BaseLoss`

   Initialization method for the Flow Matching loss class.

   .. attribute:: name

      The name of the loss or objective function: Flow Matching

      :type: str

   .. attribute:: acronym

      The acronym of the loss or objective function: FM

      :type: str

   .. attribute:: id

      The identifier of the loss or objective function: flowmatching

      :type: str


   .. py:attribute:: name
      :value: 'Flow Matching'


   .. py:attribute:: acronym
      :value: 'FM'


   .. py:attribute:: id
      :value: 'flowmatching'


   .. py:method:: requires_backward_policy()

      Returns True if the loss function requires a backward policy.

      The Flow Matching loss is well defined with a forward policy model only,
      hence False is returned.

      :returns: *False*


   .. py:method:: requires_state_flow_model()

      Returns True if the loss function requires a state flow model.

      The Flow Matching loss is well defined with a forward policy model only,
      hence False is returned.

      :returns: *False*


   .. py:method:: is_defined_for_continuous()

      Returns True if the loss function is well defined for continuous GFlowNets,
      that is, for continuous environments, or False otherwise.

      The Flow Matching loss is currently not well defined for continuous
      GFlowNets, therefore this method returns False.

      :returns: *False*


   .. py:method:: compute_losses_of_batch(batch)

      Computes the Flow Matching loss for each state of the input batch.

      The Flow Matching (FM) loss or objective is computed in this method as
      defined in Equation 12 of
      `Bengio et al. (2021) <https://arxiv.org/abs/2106.04399>`__, except that
      the outer sum over the states of a trajectory is omitted here.

      :param batch: A batch of states.
      :type batch: Batch

      :returns: **losses** (*tensor*) -- The loss of each state in the batch.


   .. py:method:: aggregate_losses_of_batch(losses, batch)

      Aggregates the losses computed from a batch to obtain the overall average
      loss, the average loss over terminating states and the average loss over
      intermediate states.

      The result is returned as a dictionary with the following items:

      - 'all': Overall average loss
      - 'Loss (terminating)': Average loss over terminating states
      - 'Loss (non-term.)': Average loss over non-terminating (intermediate) states

      :param losses: The loss of each state in the batch.
      :type losses: tensor
      :param batch: A batch of states.
      :type batch: Batch

      :returns: **loss_dict** (*dict*) -- A dictionary of loss aggregations.
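
   For reference, the per-state term evaluated by ``compute_losses_of_batch``
   can be sketched as follows, assuming the log-flow form of the objective in
   Bengio et al. (2021). Here :math:`F^{\log}_\theta(s, a)` is a predicted log
   edge flow, :math:`T(s, a) = s'` means that action :math:`a` taken in state
   :math:`s` leads to :math:`s'`, :math:`\mathcal{A}(s')` is the set of actions
   available in :math:`s'`, :math:`R(s')` is the reward (assumed zero for
   non-terminating states) and :math:`\epsilon` is a small smoothing constant.

   .. math::

      \mathcal{L}_{\theta,\epsilon}(s') = \left( \log\left[ \epsilon + \sum_{s, a \,:\, T(s, a) = s'} \exp F^{\log}_\theta(s, a) \right] - \log\left[ \epsilon + R(s') + \sum_{a' \in \mathcal{A}(s')} \exp F^{\log}_\theta(s', a') \right] \right)^2

   Summing this term over the states :math:`s' \neq s_0` of a trajectory gives
   the trajectory-level objective; ``compute_losses_of_batch`` returns the
   per-state terms instead.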
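
   A minimal pure-Python sketch of the per-state Flow Matching term computed by
   ``compute_losses_of_batch``, for illustration only. The function
   ``fm_state_loss``, its arguments and the default ``eps`` value are
   hypothetical and not part of the gflownet API; the actual method takes a
   ``Batch`` and returns a tensor of losses.

   .. code-block:: python

      import math
      from typing import Sequence


      def fm_state_loss(
          parent_log_flows: Sequence[float],
          child_log_flows: Sequence[float],
          reward: float,
          eps: float = 1e-8,
      ) -> float:
          """Hypothetical per-state flow-matching term (no trajectory sum).

          parent_log_flows: log F(s, a) for every (s, a) leading into the state.
          child_log_flows: log F(s', a') for every action available from it.
          reward: R(s') of the state, assumed 0.0 for non-terminating states.
          eps: small smoothing constant added inside both logarithms.
          """
          # Log of the smoothed inflow: flows on edges entering the state.
          log_inflow = math.log(eps + sum(math.exp(f) for f in parent_log_flows))
          # Log of the smoothed outflow: reward plus flows on edges leaving it.
          log_outflow = math.log(
              eps + reward + sum(math.exp(f) for f in child_log_flows)
          )
          # Squared mismatch between log inflow and log outflow.
          return (log_inflow - log_outflow) ** 2


      # Balanced intermediate state: inflow 1 + 2 matches outflow 3.
      balanced = fm_state_loss([math.log(1.0), math.log(2.0)], [math.log(3.0)], 0.0)
      # Unbalanced intermediate state: inflow 1 against outflow 4.
      unbalanced = fm_state_loss([math.log(1.0)], [math.log(4.0)], 0.0)

   The smoothing constant keeps both logarithms finite when a state has no
   parents or no children, so that the corresponding sum is zero.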
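
   The aggregation performed by ``aggregate_losses_of_batch`` can be sketched
   as below, assuming the per-state losses and a per-state terminating flag are
   available as plain Python sequences. The name ``aggregate_fm_losses``, the
   ``is_terminating`` argument and returning NaN for an empty group are
   assumptions for illustration; the actual method takes the ``batch`` itself.

   .. code-block:: python

      import math
      from typing import Dict, List, Sequence


      def aggregate_fm_losses(
          losses: Sequence[float], is_terminating: Sequence[bool]
      ) -> Dict[str, float]:
          """Hypothetical aggregation of per-state losses into three averages."""

          def mean(values: List[float]) -> float:
              # Assumption: an empty group yields NaN instead of raising.
              return sum(values) / len(values) if values else math.nan

          # Split the per-state losses by whether the state is terminating.
          terminating = [loss for loss, done in zip(losses, is_terminating) if done]
          intermediate = [loss for loss, done in zip(losses, is_terminating) if not done]
          return {
              "all": mean(list(losses)),
              "Loss (terminating)": mean(terminating),
              "Loss (non-term.)": mean(intermediate),
          }


      loss_dict = aggregate_fm_losses([1.0, 3.0, 5.0], [False, False, True])
      # loss_dict == {"all": 3.0, "Loss (terminating)": 5.0, "Loss (non-term.)": 2.0}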