"""
Detailed Balance loss or objective for training GFlowNets.
The Detailed Balance (DB) loss or objective was defined by Malkin et al. (2022):
.. _a link: https://arxiv.org/abs/2201.13259
"""
from torchtyping import TensorType
from gflownet.losses.base import BaseLoss
from gflownet.utils.batch import Batch
class DetailedBalance(BaseLoss):
    """
    Detailed Balance (DB) loss or objective for training GFlowNets.

    The loss is computed per state transition and requires a forward policy, a
    backward policy and a state flow model, all of which are expected to be set by
    the parent class initialization.
    """

    def __init__(self, **kwargs):
        """
        Initialization method for the Detailed Balance loss class.

        All keyword arguments are forwarded to the parent class. After the parent
        initialization, the forward policy, the backward policy and the state flow
        model must be available (not None).

        Attributes
        ----------
        name : str
            The name of the loss or objective function: Detailed Balance
        acronym : str
            The acronym of the loss or objective function: DB
        id : str
            The identifier of the loss or objective function: detailedbalance
        """
        super().__init__(**kwargs)
        assert (
            self.forward_policy is not None
        ), "Detailed Balance requires a forward policy"
        assert (
            self.backward_policy is not None
        ), "Detailed Balance requires a backward policy"
        assert (
            self.state_flow is not None
        ), "Detailed Balance requires a state flow model"
        self.name = "Detailed Balance"
        # The acronym is part of the documented attributes of the class.
        self.acronym = "DB"
        self.id = "detailedbalance"

    def requires_backward_policy(self) -> bool:
        """
        Returns True if the loss function requires a backward policy.

        The Detailed Balance loss does require a backward policy model, hence True is
        returned.

        Returns
        -------
        True
        """
        return True

    def requires_state_flow_model(self) -> bool:
        """
        Returns True if the loss function requires a state flow model.

        The Detailed Balance loss does require a state flow model, hence True is
        returned.

        Returns
        -------
        True
        """
        return True

    def is_defined_for_continuous(self) -> bool:
        """
        Returns True if the loss function is well defined for continuous GFlowNets,
        that is continuous environments, or False otherwise.

        The Detailed Balance loss is well defined for continuous GFlowNets, therefore
        this method returns True.

        Returns
        -------
        True
        """
        return True

    def compute_losses_of_batch(self, batch: Batch) -> TensorType["batch_size"]:
        """
        Computes the Detailed Balance loss for each state of the input batch.

        The Detailed Balance (DB) loss or objective is computed in this method as is
        defined in Equation 11 of Malkin et al. (2022).

        .. _a link: https://arxiv.org/abs/2201.13259

        For each transition from a parent to a state, the loss is the square of::

            logflow(parent) + logprob_forward - logflow(state) - logprob_backward

        where the log-flow of the states marked as done is replaced by their
        log-reward.

        Parameters
        ----------
        batch : Batch
            A batch of states.

        Returns
        -------
        losses : tensor
            The loss of each state in the batch.
        """
        assert batch.is_valid()

        # Get necessary tensors from batch
        states = batch.get_states(policy=False)
        states_policy = batch.get_states(policy=True)
        actions = batch.get_actions()
        parents = batch.get_parents(policy=False)
        parents_policy = batch.get_parents(policy=True)
        done = batch.get_done()
        logrewards = batch.get_terminating_rewards(log=True, sort_by="insertion")

        # Forward log-probabilities: evaluated from the parents, with the forward
        # masks of the parents
        masks_f = batch.get_masks_forward(of_parents=True)
        policy_output_f = self.forward_policy(parents_policy)
        logprobs_f = batch.readonly_env.get_logprobs(
            policy_output_f, actions, masks_f, parents, is_backward=False
        )

        # Backward log-probabilities: evaluated from the states themselves
        masks_b = batch.get_masks_backward()
        policy_output_b = self.backward_policy(states_policy)
        logprobs_b = batch.readonly_env.get_logprobs(
            policy_output_b, actions, masks_b, states, is_backward=True
        )

        # Log-flows of the states; for done states the predicted log-flow is
        # overwritten by the log-reward
        logflows_states = self.state_flow(states_policy)
        logflows_states[done.eq(1)] = logrewards
        # TODO: Optimise by reusing logflows_states and batch.get_parent_indices
        logflows_parents = self.state_flow(parents_policy)

        # Detailed balance loss
        return (logflows_parents + logprobs_f - logflows_states - logprobs_b).pow(2)

    def aggregate_losses_of_batch(
        self, losses: TensorType["batch_size"], batch: Batch
    ) -> dict[str, float]:
        """
        Aggregates the losses computed from a batch to obtain the overall average loss
        and the average loss over terminating states and intermediate states.

        The result is returned as a dictionary with the following items:
        - 'all': Overall average loss
        - 'Loss (terminating)': Average loss over terminating states
        - 'Loss (non-term.)': Average loss over non-terminating (intermediate) states

        The overall loss is the plain mean of all losses. This is equal to the
        average of the two partial means weighted by the fraction of states in each
        group, but it does not become NaN if one of the groups is empty. The mean of
        an empty group is still NaN in its own entry of the dictionary.

        Parameters
        ----------
        losses : tensor
            The loss of each state in the batch.
        batch : Batch
            A batch of states.

        Returns
        -------
        loss_dict : dict
            A dictionary of loss aggregations. The values are the results of
            ``mean()`` over the corresponding losses.
        """
        # Use an explicit comparison to build the mask, consistently with
        # compute_losses_of_batch, so that it is valid as a boolean index
        is_done = batch.get_done().eq(1)
        # Loss of terminating states
        loss_term = losses[is_done].mean()
        # Loss of non-terminating states
        loss_interm = losses[~is_done].mean()
        # Overall loss
        loss_overall = losses.mean()
        return {
            "all": loss_overall,
            "Loss (terminating)": loss_term,
            "Loss (non-term.)": loss_interm,
        }