from collections import OrderedDict
from typing import List, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
import torch
from torchtyping import TensorType
from gflownet.envs.base import GFlowNetEnv
from gflownet.policy.base import Policy
from gflownet.proxy.base import Proxy
from gflownet.utils.common import (
concat_items,
copy,
extend,
select_indices,
set_device,
set_float_precision,
tbool,
tfloat,
tlong,
)
[docs]
class Batch:
"""
Class to handle GFlowNet batches.
Important note: one env should correspond to only one trajectory, all env_id should
be unique.
Note: self.state_indices start from index 1 to indicate that index 0 would correspond
to the source state, but the latter is not stored in the batch for each trajectory.
This implies that one has to be careful when indexing the list of batch_indices in
self.trajectories by using self.state_indices. For example, the batch index of
state state_idx of trajectory traj_idx is self.trajectories[traj_idx][state_idx-1]
(not self.trajectories[traj_idx][state_idx]).
"""
def __init__(
    self,
    env: Optional[GFlowNetEnv] = None,
    proxy: Optional[Proxy] = None,
    device: Union[str, torch.device] = "cpu",
    float_type: Union[int, torch.dtype] = 32,
    collect_forwards_masks: bool = False,
    collect_backwards_masks: bool = False,
):
    """
    Initializes an empty batch.

    Parameters
    ----------
    env : GFlowNetEnv
        An instance of the environment that will be used to form the batch.
        Optional. If None, the environment-dependent attributes are set to None
        and can be set later via set_env().
    proxy : Proxy
        An instance of a GFlowNet proxy that will be used to compute proxy values
        and rewards. Optional. It can also be set later via set_proxy().
    device : str or torch.device
        torch.device or string indicating the device to use ("cpu" or "cuda")
    float_type : torch.dtype or int
        One of float torch.dtype or an int indicating the float precision (16, 32
        or 64).
    collect_forwards_masks : bool
        If True, the forward masks of invalid actions are computed and stored for
        each item as it is added in add_to_batch(). Otherwise (default), None is
        stored and the masks are computed lazily when requested.
    collect_backwards_masks : bool
        Same as collect_forwards_masks, for the backward masks.
    """
    # Device
    self.device = set_device(device)
    # Float precision
    self.float = set_float_precision(float_type)
    # Keep reference to the provided env. However, this env instance might be
    # used by other objects so, to avoid causing issues, this object should be
    # treated as read-only. No method that changes the state of this object may be
    # called on this object. Ex : reset(), step(), any setter method, ...
    if env is not None:
        self.set_env(env)
    else:
        self.readonly_env = None
        self.source = None
        self.conditional = None
        self.continuous = None
    # Proxy
    self.proxy = proxy
    # Initialize batch size 0
    self.size = 0
    # Initialize empty batch variables
    # TODO: make single ordered dictionary of dictionaries
    self.envs = OrderedDict()
    self.trajectories = OrderedDict()
    self.is_backward = OrderedDict()
    # Per-item lists: add_to_batch() appends one entry to each of them for every
    # item added, so they must exist before the first call.
    self.traj_indices = []
    # TODO: state_indices is currently unused, it is redundant and inconsistent
    # between forward and backward trajectories. We may want to remove it.
    self.state_indices = []
    self.states = []
    self.actions = []
    self.done = []
    self.parents = []
    self.logprobs_forward = []
    self.logprobs_backward = []
    self.logprobs_forward_avail = []
    self.logprobs_backward_avail = []
    self.masks_invalid_actions_forward = []
    self.masks_invalid_actions_backward = []
    self.parents_actions_all = []
    self.states_policy = None
    self.parents_policy = None
    # Flag to indicate if masks should be collected in add_to_batch
    self.collect_forwards_masks = collect_forwards_masks
    self.collect_backwards_masks = collect_backwards_masks
    # Flags for available items
    self._parents_available = False
    self._parents_policy_available = False
    self._parents_all_available = False
    self._masks_forward_available = False
    self._masks_backward_available = False
    self._rewards_available = False
    self._rewards_parents_available = False
    self._rewards_source_available = False
    self._logrewards_available = False
    self._logrewards_parents_available = False
    self._logrewards_source_available = False
    self._proxy_values_available = False
[docs]
def __len__(self):
return self.size
[docs]
def batch_idx_to_traj_state_idx(self, batch_idx: int):
traj_idx = self.traj_indices[batch_idx]
state_idx = self.state_indices[batch_idx]
return traj_idx, state_id
[docs]
def traj_idx_to_batch_indices(self, traj_idx: int):
batch_indices = self.trajectories[traj_idx]
return batch_indices
[docs]
def traj_state_idx_to_batch_idx(self, traj_idx: int, state_idx: int):
batch_idx = self.trajectories[traj_idx][state_idx]
return batch_idx
[docs]
def traj_idx_action_idx_to_batch_idx(
self, traj_idx: int, action_idx: int, backward: bool
):
if traj_idx not in self.trajectories:
return None
if backward:
if action_idx >= len(self.trajectories[traj_idx]):
return None
return self.trajectories[traj_idx][::-1][action_idx]
if action_idx > len(self.trajectories[traj_idx]):
return None
return self.trajectories[traj_idx][action_idx - 1]
[docs]
def idx2state_idx(self, idx: int):
return self.trajectories[self.traj_indices[idx]].index(idx)
[docs]
def rewards_available(self, log: bool = False) -> bool:
"""
Returns True if the (log)rewards are available.
Parameters
----------
log : bool
If True, check self._logrewards_available. Otherwise (default), check
self._rewards_available.
Returns
-------
bool
True if the (log)rewards are available, False otherwise.
"""
if log:
return self._logrewards_available
else:
return self._rewards_available
[docs]
def rewards_parents_available(self, log: bool = False) -> bool:
"""
Returns True if the (log)rewards of the parents are available.
Parameters
----------
log : bool
If True, check self._logrewards_parents_available. Otherwise (default),
check self._rewards_parents_available.
Returns
-------
bool
True if the (log)rewards of the parents are available, False otherwise.
"""
if log:
return self._logrewards_parents_available
else:
return self._rewards_parents_available
[docs]
def rewards_source_available(self, log: bool = False) -> bool:
"""
Returns True if the (log)rewards of the source are available.
Parameters
----------
log : bool
If True, check self._logrewards_source_available. Otherwise (default),
check self._rewards_source_available.
Returns
-------
bool
True if the (log)rewards of the source are available, False otherwise.
"""
if log:
return self._logrewards_source_available
else:
return self._rewards_source_available
[docs]
def set_env(self, env: GFlowNetEnv):
"""
Sets the generic environment passed as an argument and initializes the
environment-dependent properties.
"""
self.readonly_env = env
source_state = self.readonly_env.source
self.source = {
"state": source_state,
"mask_forward": self.readonly_env.get_mask_invalid_actions_forward(
state=source_state, done=False
),
}
self.conditional = self.readonly_env.conditional
self.continuous = self.readonly_env.continuous
[docs]
def set_proxy(self, proxy: Proxy):
"""
Sets the proxy, used to compute rewards from a batch of states.
"""
self.proxy = proxy
[docs]
def add_to_batch(
    self,
    envs: List[GFlowNetEnv],
    actions: List[Tuple],
    logprobs: TensorType,
    logprobs_rev: TensorType,
    valids: List[bool],
    backward: Optional[bool] = False,
    train: Optional[bool] = True,
):
    """
    Adds information from a list of environments and actions to the batch after
    performing steps in the envs. If train is False, only the variables of
    terminating states are stored. Items whose action was not valid are skipped.

    Parameters
    ----------
    envs : list
        A list of environments (GFlowNetEnv).
    actions : list
        A list of actions attempted or performed on the envs.
    logprobs : torch.tensor or list of None
        Log probabilities corresponding to the actions or None.
    logprobs_rev : torch.tensor or list of None
        Log probabilities corresponding to the transitions in the opposite
        direction to the sampling direction, from the current states but with the
        actions added in the previous step. It may be a list of None.
    valids : list
        A list of boolean values indicating whether the actions were valid.
    backward : bool
        A boolean value indicating whether the action was sampled backward (False
        by default). If True, the behavior is slightly different so as to match
        what is stored in forward sampling:
            - If it is the first state in the trajectory (action from a done
              state/env), then done is stored as True, instead of taking env.done
              which will be False after having performed the step.
            - If it is not the first state in the trajectory, the stored state will
              be the previous one in the trajectory, to match the state-action
              stored in forward sampling and the convention that the source state
              is not stored, but the terminating state is repeated with action eos.
    train : bool
        A boolean value indicating whether the data to add to the batch will be used
        for training. Optional, default is True.

    Raises
    ------
    AssertionError
        If envs, actions, logprobs, logprobs_rev, valids and the indices of the
        previous transitions do not all have the same length.
    """
    # TODO: do we need this?
    if self.continuous is None:
        self.continuous = envs[0].continuous
    # For each env, batch index of its previously added transition (or None);
    # used below to back-fill the logprob of the reverse direction.
    indices_prev_trans = self.get_indices_of_previous_transitions(envs, backward)
    assert (
        len(envs)
        == len(actions)
        == len(logprobs)
        == len(logprobs_rev)
        == len(valids)
        == len(indices_prev_trans)
    )
    # Add data samples to the batch
    for env, action, logp, logp_rev, valid, idx_prev in zip(
        envs, actions, logprobs, logprobs_rev, valids, indices_prev_trans
    ):
        # Outside of training, only terminating states are stored
        if train is False and env.done is False:
            continue
        # Invalid actions are not stored
        if not valid:
            continue
        # Add env to dictionary
        if env.id not in self.envs:
            self.envs.update({env.id: env})
        # Add batch index to trajectory. len(self) is the batch index that the
        # new item will have. Backward transitions are inserted at the front so
        # that the list of batch indices is kept in forward order.
        if env.id not in self.trajectories:
            self.trajectories.update({env.id: [len(self)]})
        else:
            if backward:
                self.trajectories[env.id].insert(0, len(self))
            else:
                self.trajectories[env.id].append(len(self))
        # Set whether trajectory is backward
        if env.id not in self.is_backward:
            self.is_backward.update({env.id: backward})
        # Add trajectory index and state index
        self.traj_indices.append(env.id)
        self.state_indices.append(env.n_actions)
        # Add action
        self.actions.append(action)
        # Handle backward transition
        if backward:
            # Add state, parent and done
            self.parents.append(copy(env.state))
            if len(self.trajectories[env.id]) == 1:
                # First item of the trajectory: done is stored as True
                self.states.append(copy(env.state))
                self.done.append(True)
            else:
                # The new batch index was inserted at position 0, so position 1
                # holds the item added in the previous backward step; its parent
                # is stored as the state of the new item.
                self.states.append(copy(self.parents[self.trajectories[env.id][1]]))
                self.done.append(env.done)
            # Add backward logp for current action
            self.logprobs_backward.append(logp)
            self.logprobs_backward_avail.append(True if logp is not None else False)
            # Add a placeholder (2.0) forward logp for transition into current state
            self.logprobs_forward.append(
                tfloat(2.0, device=self.device, float_type=self.float)
            )
            self.logprobs_forward_avail.append(False)
            # If the index and logp for a previous transition are available, set
            # the reverse (forward) logp at the index
            if idx_prev is not None and logp_rev is not None:
                self.logprobs_forward[idx_prev] = logp_rev
                self.logprobs_forward_avail[idx_prev] = True
        # Handle forward transition
        else:
            # Add state, parent and done
            self.states.append(copy(env.state))
            self.done.append(env.done)
            if len(self.trajectories[env.id]) == 1:
                # First item of the trajectory: the parent is the source state
                self.parents.append(copy(self.source["state"]))
            else:
                # Otherwise the parent is the state of the second-to-last item of
                # the trajectory (the last one is the item being added)
                self.parents.append(
                    copy(self.states[self.trajectories[env.id][-2]])
                )
            # Add forward logp for current action
            self.logprobs_forward.append(logp)
            self.logprobs_forward_avail.append(True if logp is not None else False)
            # Add backward logp for transition into current action
            if env.done:
                # If the trajectory is done, the backward logp is always 0.0
                self.logprobs_backward.append(
                    tfloat(0.0, device=self.device, float_type=self.float)
                )
                self.logprobs_backward_avail.append(True)
            else:
                # Otherwise, add a placeholder (2.0) backward logp state
                self.logprobs_backward.append(
                    tfloat(2.0, device=self.device, float_type=self.float)
                )
                self.logprobs_backward_avail.append(False)
            # If the index and logp for a previous transition are available, set
            # the reverse (backward) logp at the index
            if idx_prev is not None and logp_rev is not None:
                self.logprobs_backward[idx_prev] = logp_rev
                self.logprobs_backward_avail[idx_prev] = True
        # Collect masks if needed. Otherwise, None is stored so that the lists of
        # masks stay aligned with the batch indices.
        if self.collect_forwards_masks:
            self.masks_invalid_actions_forward.append(
                env.get_mask_invalid_actions_forward(self.states[-1], self.done[-1])
            )
        else:
            self.masks_invalid_actions_forward.append(None)
        if self.collect_backwards_masks:
            self.masks_invalid_actions_backward.append(
                env.get_mask_invalid_actions_backward(
                    self.states[-1], self.done[-1]
                )
            )
        else:
            self.masks_invalid_actions_backward.append(None)
        # Increment size of batch
        self.size += 1
    # Other variables are not available after new items were added to the batch
    # NOTE(review): self._parents_available is not reset and self.states_policy
    # is not cleared here - confirm whether these cached values can become stale
    # if items are added after they have been computed.
    self._masks_forward_available = False
    self._masks_backward_available = False
    self._parents_policy_available = False
    self._parents_all_available = False
    self._rewards_available = False
    self._logrewards_available = False
[docs]
def get_n_trajectories(self) -> int:
"""
Returns the number of trajectories in the batch.
Returns
-------
The number of trajectories in the batch (int).
"""
return len(self.trajectories)
[docs]
def get_unique_trajectory_indices(self) -> List:
"""
Returns the unique trajectory indices as the keys of self.trajectories, which
is an OrderedDict, as a list.
"""
return list(self.trajectories.keys())
[docs]
def get_trajectory_indices(
    self, consecutive: bool = False, return_mapping_dict: bool = False
) -> TensorType["n_states", int]:
    """
    Returns the trajectory index of every item in the batch as a long int torch
    tensor.

    Parameters
    ----------
    consecutive : bool
        If True, the trajectory indices are re-mapped to consecutive integers
        starting from 0, following the order of the keys of self.trajectories.
        If False (default), the indices are returned as stored.
    return_mapping_dict : bool
        If True, the dictionary {actual_index: consecutive_index} is returned as
        a second output. Ignored if consecutive is False.

    Returns
    -------
    traj_indices : torch.tensor
        self.traj_indices (re-mapped if consecutive is True) as a long int
        tensor.
    traj_index_to_consecutive_dict : dict
        The mapping from actual to consecutive trajectory indices. Only returned
        if both consecutive and return_mapping_dict are True.
    """
    if not consecutive:
        return tlong(self.traj_indices, device=self.device)
    to_consecutive = {
        traj_idx: position for position, traj_idx in enumerate(self.trajectories)
    }
    remapped = [to_consecutive[traj_idx] for traj_idx in self.traj_indices]
    remapped_tensor = tlong(remapped, device=self.device)
    if return_mapping_dict:
        return remapped_tensor, to_consecutive
    return remapped_tensor
[docs]
def get_state_indices(self) -> TensorType["n_states", int]:
    """
    Returns the state index of every item in the batch.

    Returns
    -------
    state_indices : torch.tensor
        self.state_indices converted into a long int tensor on self.device.
    """
    state_indices = tlong(self.state_indices, device=self.device)
    return state_indices
[docs]
def get_states(
self,
policy: Optional[bool] = False,
proxy: Optional[bool] = False,
force_recompute: Optional[bool] = False,
indices: Optional[Union[List, Tuple, TensorType, npt.NDArray]] = None,
) -> Union[TensorType["n_states", "..."], npt.NDArray[np.float32], List]:
"""
Returns all the states in the batch.
The states are returned in "policy format" if policy is True, in "proxy format"
if proxy is True and otherwise they are returned in "GFlowNet" format by
default. An error is raised if both policy and proxy are True.
Args
----
policy : bool
If True, the policy format of the states is returned and self.states_policy
is updated if not available yet or if force_recompute is True.
proxy : bool
If True, the proxy format of the states is returned. States in proxy format
are not stored.
force_recompute : bool
If True, the policy states are recomputed even if they are available.
Ignored if policy is False.
indices: list, tuple, tensor or np.ndarray
1-dimensional sequence of batch indices for selecting states, optional.
If None (default), all the states will be returned.
Returns
-------
self.states or self.states_policy or self.states2proxy(self.states) : list or
torch.tensor or ndarray
The set of states in the selected format with the selected indices. If indices is
None, all states of the batch are returned.
"""
if policy is True and proxy is True:
raise ValueError(
"Ambiguous request! Only one of policy or proxy can be True."
)
if policy is True:
if self.states_policy is None or force_recompute is True:
self.states_policy = self.states2policy()
return select_indices(self.states_policy, indices)
if proxy is True:
return select_indices(self.states2proxy(), indices)
return select_indices(self.states, indices)
[docs]
def states2policy(
self,
states: Optional[Union[List[List], List[TensorType["n_states", "..."]]]] = None,
traj_indices: Optional[Union[List, TensorType["n_states"]]] = None,
) -> TensorType["n_states", "state_policy_dims"]:
"""
Converts states from a list of states in GFlowNet format to a tensor of states
in policy format.
Args
----
states: list
List of states in GFlowNet format.
traj_indices: list or torch.tensor
Ids indicating which env corresponds to each state in states. It is only
used if the environments are conditional to call state2policy from the
right environment. Ignored if self.conditional is False.
Returns
-------
states: torch.tensor
States in policy format.
"""
# If traj_indices is not None and self.conditional is True, then both states
# and traj_indices must be the same type and have the same length.
if traj_indices is not None and self.conditional is True:
assert type(states) == type(traj_indices)
assert len(states) == len(traj_indices)
if states is None:
states = self.states
traj_indices = self.traj_indices
# TODO: will env.policy_input_dim be the same for all envs if conditional?
if self.conditional:
states_policy = torch.zeros(
(len(states), self.readonly_env.policy_input_dim),
device=self.device,
dtype=self.float,
)
traj_indices_torch = tlong(traj_indices, device=self.device)
for traj_idx in self.trajectories:
if traj_idx not in traj_indices:
continue
states_policy[traj_indices_torch == traj_idx] = self.envs[
traj_idx
].statebatchpolicy(
self.get_states_of_trajectory(traj_idx, states, traj_indices)
)
return states_policy
return self.readonly_env.states2policy(states)
[docs]
def states2proxy(
self,
states: Optional[Union[List[List], List[TensorType["n_states", "..."]]]] = None,
traj_indices: Optional[Union[List, TensorType["n_states"]]] = None,
) -> Union[
TensorType["n_states", "state_proxy_dims"], npt.NDArray[np.float32], List
]:
"""
Converts states from a list of states in GFlowNet format to a tensor of states
in proxy format. Note that the implementatiuon of this method differs from
Batch.states2policy() because the latter always returns torch.tensors. The
output of the present method can also be numpy arrays or Python lists,
depending on the proxy.
Args
----
states: list
List of states in GFlowNet format.
traj_indices: list or torch.tensor
Ids indicating which env corresponds to each state in states. It is only
used if the environments are conditional to call state2proxy from the right
environment. Ignored if self.conditional is False.
Returns
-------
states: torch.tensor or ndarray or list
States in proxy format.
"""
# If traj_indices is not None and self.conditional is True, then both states
# and traj_indices must be the same type and have the same length.
if traj_indices is not None and self.conditional is True:
assert type(states) == type(traj_indices)
assert len(states) == len(traj_indices)
if states is None:
states = self.states
traj_indices = self.traj_indices
if self.conditional:
states_proxy = []
index = torch.arange(len(states), device=self.device)
perm_index = []
# TODO: rethink this
for traj_idx in self.trajectories:
if traj_idx not in traj_indices:
continue
states_proxy.append(
self.envs[traj_idx].states2proxy(
self.get_states_of_trajectory(traj_idx, states, traj_indices)
)
)
perm_index.append(index[env_ids == env_id])
perm_index = torch.cat(perm_index)
# Reverse permutation to make it index the states_proxy array
index[perm_index] = index.clone()
states_proxy = concat_items(states_proxy, index)
return states_proxy
return self.readonly_env.states2proxy(states)
[docs]
def get_actions(
    self,
    indices: Optional[Union[List, Tuple, TensorType, npt.NDArray]] = None,
) -> List:
    """
    Returns the actions in the batch, optionally restricted to a set of indices.

    Parameters
    ----------
    indices : list, tuple, tensor or np.ndarray
        1-dimensional sequence of batch indices for selecting actions. Optional.
        If None (default), all the actions are returned.

    Returns
    -------
    list
        The actions of the batch at the selected indices.
    """
    actions = self.actions
    return select_indices(actions, indices)
[docs]
def get_logprobs(
    self, backward: bool = False
) -> Tuple[TensorType["n_states"], TensorType["n_states"]]:
    """
    Returns the log probabilities stored in the batch as a float tensor, together
    with a boolean tensor flagging which of them are available.

    Parameters
    ----------
    backward : bool
        If True, the log probabilities of the backward transitions are returned.
        Otherwise (default), those of the forward transitions.

    Returns
    -------
    logprobs : TensorType["n_states"]
        Tensor of log probabilities.
    valids : TensorType["n_states"]
        Boolean tensor built from the corresponding availability list
        (self.logprobs_forward_avail or self.logprobs_backward_avail). A False
        entry indicates that the log probability at that position has not been
        flagged as available.
    """
    if backward:
        values = self.logprobs_backward
        availability = self.logprobs_backward_avail
    else:
        values = self.logprobs_forward
        availability = self.logprobs_forward_avail
    logprobs = tfloat(values, device=self.device, float_type=self.float)
    valids = tbool(availability, device=self.device)
    return logprobs, valids
[docs]
def get_done(self) -> TensorType["n_states"]:
    """
    Returns the done flags of all items in the batch (self.done) as a boolean
    tensor on self.device.
    """
    done = tbool(self.done, device=self.device)
    return done
# TODO: check availability one by one as in get_masks
[docs]
def get_parents(
    self,
    policy: Optional[bool] = False,
    force_recompute: Optional[bool] = False,
    indices: Optional[Union[List, Tuple, TensorType, npt.NDArray]] = None,
) -> TensorType["n_states", "..."]:
    """
    Returns the parent (a single parent per state) of the states in the batch.

    The parents are computed via self._compute_parents() if they are not flagged
    as available or if force_recompute is True. The same applies to the parents
    in policy format, computed via self._compute_parents_policy().

    Parameters
    ----------
    policy : bool
        If True, the parents are returned in policy format. Otherwise (default),
        they are returned in GFlowNet format.
    force_recompute : bool
        If True, the parents are recomputed even if they are available.
    indices : list, tuple, tensor or np.ndarray
        1-dimensional sequence of batch indices for selecting parents. Optional.
        If None (default), the parents of all states in the batch are returned.

    Returns
    -------
    self.parents or self.parents_policy : list or torch.tensor
        The parents of the states selected by indices.
    """
    recompute = force_recompute is True
    if recompute or self._parents_available is False:
        self._compute_parents()
    if not policy:
        return select_indices(self.parents, indices)
    if recompute or self._parents_policy_available is False:
        self._compute_parents_policy()
    return select_indices(self.parents_policy, indices)
[docs]
def get_parents_indices(self, indices=None):
    """
    Returns the batch indices of the parents of the states in the batch.

    The i-th item of self.parents_indices is the index in self.states of the
    parent of self.states[i], or -1 if the parent is not present in self.states
    because it is the source. The indices are computed via
    self._compute_parents() if the parents are not flagged as available.

    Parameters
    ----------
    indices : list, tuple, tensor or np.ndarray
        1-dimensional sequence of batch indices for selecting parents indices.
        Optional. If None (default), the parents indices of all states in the
        batch are returned.

    Returns
    -------
    self.parents_indices
        The indices in self.states of the parents of the selected states.
    """
    if self._parents_available is False:
        self._compute_parents()
    parents_indices = self.parents_indices
    return select_indices(parents_indices, indices)
def _compute_parents(self):
    """
    Computes the parent (a single parent per state) of all states in the batch,
    as well as the batch index of each parent.

    The following attributes are set:

    - self.parents: list with the parent of each state, in the same order as
      self.states. The parent of the first state of a trajectory is the source
      of the environment of the trajectory.
    - self.parents_indices: long int tensor with the position in self.states of
      each parent, or -1 if the parent is the source.

    self._parents_available is set to True.
    """
    self.parents = []
    self.parents_indices = []
    # Maps each batch index to the position of its parent in the
    # trajectory-ordered lists built in the loop below
    position_of = {}
    next_position = 0
    for traj_idx, batch_indices in self.trajectories.items():
        preceding = batch_indices[:-1]
        # The parent of the first state is the source, which is not in the batch
        self.parents.append(self.envs[traj_idx].source)
        self.parents_indices.append(-1)
        # The parent of every other state is the preceding state in the
        # trajectory
        # TODO: check if tensor and sort without iter
        self.parents.extend([self.states[idx] for idx in preceding])
        self.parents_indices.extend(preceding)
        for batch_idx in batch_indices:
            position_of[batch_idx] = next_position
            next_position += 1
    # Reorder from trajectory order into the order of the states in the batch
    # TODO: check if tensor and sort without iter
    batch_order = range(len(self))
    self.parents = [self.parents[position_of[idx]] for idx in batch_order]
    self.parents_indices = tlong(
        [self.parents_indices[position_of[idx]] for idx in batch_order],
        device=self.device,
    )
    self._parents_available = True
# TODO: consider converting directly from self.parents
def _compute_parents_policy(self):
    """
    Computes the parent (a single parent per state) of all states in the batch,
    in policy format.

    The states in policy format are obtained via self.get_states(policy=True)
    and stored in self.states_policy. The following attribute is set:

    - self.parents_policy: tensor with the same shape as self.states_policy
      containing the parent of each state in policy format. The parent of the
      first state of a trajectory is the source of the environment of the
      trajectory.

    self._parents_policy_available is set to True.
    """
    self.states_policy = self.get_states(policy=True)
    self.parents_policy = torch.zeros_like(self.states_policy)
    for traj_idx, batch_indices in self.trajectories.items():
        env = self.envs[traj_idx]
        # The parent of the first state of the trajectory is the source
        first_idx = batch_indices[0]
        self.parents_policy[first_idx] = tfloat(
            env.state2policy(env.source),
            device=self.device,
            float_type=self.float,
        )
        # The parent of every other state is the preceding state
        children, their_parents = batch_indices[1:], batch_indices[:-1]
        self.parents_policy[children] = self.states_policy[their_parents]
    self._parents_policy_available = True
[docs]
def get_parents_all(
self, policy: bool = False, force_recompute: bool = False
) -> Tuple[
Union[List, TensorType["n_parents", "..."]],
TensorType["n_parents", "..."],
TensorType["n_parents"],
]:
"""
Returns the whole set of parents, their corresponding actions and indices of
all states in the batch. If the parents are not available
(self._parents_all_available is False) or if force_recompute is True, then
self._compute_parents_all() is called to compute the required components.
The parents are returned in "policy format" if policy is True, otherwise they
are returned in "GFlowNet" format (default).
Args
----
policy : bool
If True, the policy format of parents is returned. Otherwise, the GFlowNet
format is returned.
force_recompute : bool
If True, the parents are recomputed even if they are available.
Returns
-------
self.parents_all or self.parents_all_policy : list or torch.tensor
The whole set of parents of all states in the batch.
self.parents_actions_all : torch.tensor
The actions corresponding to each parent in self.parents_all or
self.parents_all_policy, linking them to the corresponding state in the
trajectory.
self.parents_all_indices : torch.tensor
The state index corresponding to each parent in self.parents_all or
self.parents_all_policy, linking them to the corresponding state in the
batch.
"""
if self.continuous:
raise Exception("get_parents() is ill-defined for continuous environments!")
if self._parents_all_available is False or force_recompute is True:
self._compute_parents_all()
if policy:
return (
self.parents_all_policy,
self.parents_actions_all,
self.parents_all_indices,
)
else:
return self.parents_all, self.parents_actions_all, self.parents_all_indices
def _compute_parents_all(self):
    """
    Computes the whole set of parents of all states in the batch, by calling
    get_parents() of the environment of each state.

    The following attributes are set:

    - self.parents_all: list with all the parents of all states in the batch.
    - self.parents_actions_all: float tensor with the action of the transition
      from each parent in self.parents_all to its corresponding state.
    - self.parents_all_indices: long int tensor with the batch index of the state
      of which each parent in self.parents_all is a parent.
    - self.parents_all_policy: tensor with self.parents_all in policy format.

    self._parents_all_available is set to True.

    Raises
    ------
    AssertionError
        If the representative of the action stored for a state is not among the
        actions returned by get_parents() for that state.
    """
    self.parents_all = []
    self.parents_actions_all = []
    self.parents_all_indices = []
    self.parents_all_policy = []
    # Iterate over the items of the batch to obtain all their parents
    for idx, traj_idx in enumerate(self.traj_indices):
        env = self.envs[traj_idx]
        state, done, action = self.states[idx], self.done[idx], self.actions[idx]
        parents, parents_a = env.get_parents(state=state, done=done, action=action)
        assert self.readonly_env.action2representative(action) in parents_a, f"""
Sampled action is not in the list of valid actions from parents.
\nState:\n{state}\nAction:\n{action}
"""
        self.parents_all.extend(parents)
        self.parents_actions_all.extend(parents_a)
        # Every parent is linked to the batch index of its child
        self.parents_all_indices.extend([idx] * len(parents))
        self.parents_all_policy.append(env.states2policy(parents))
    # Convert to tensors
    self.parents_actions_all = tfloat(
        self.parents_actions_all, device=self.device, float_type=self.float
    )
    self.parents_all_indices = tlong(self.parents_all_indices, device=self.device)
    self.parents_all_policy = torch.cat(self.parents_all_policy)
    self._parents_all_available = True
# TODO: opportunity to improve efficiency by caching.
[docs]
def get_masks_forward(
    self,
    of_parents: bool = False,
    force_recompute: bool = False,
    indices: Optional[Union[List, Tuple, TensorType, npt.NDArray]] = None,
) -> TensorType["n_states", "action_space_dim"]:
    """
    Returns the forward mask of invalid actions of all states in the batch or of
    their parent in the trajectory if of_parents is True. The masks are computed
    via self._compute_masks_forward if they are not available or if force_recompute
    is True.

    Parameters
    ----------
    of_parents : bool
        If True, the returned masks will correspond to the parents of the states,
        instead of to the states (default). The mask of a parent that is the
        source is taken from self.source["mask_forward"].
    force_recompute : bool
        If True, the masks are recomputed even if they are available.
    indices : list, tuple, tensor or np.ndarray
        1-dimensional sequence of batch indices for selecting masks. Optional.
        If None (default), the masks of all states in the batch are returned.

    Returns
    -------
    torch.tensor
        The forward masks of the selected states (or of their parents), as a
        boolean tensor.
    """
    if self._masks_forward_available is False or force_recompute is True:
        self._compute_masks_forward()
    # Make tensor
    masks_invalid_actions_forward = tbool(
        self.masks_invalid_actions_forward, device=self.device
    )
    if not of_parents:
        return select_indices(masks_invalid_actions_forward, indices)
    # Map every batch index to the batch index of its parent in the trajectory
    # (-1 if the parent is the source), in a single pass over the trajectories.
    # This relies on each batch index belonging to exactly one trajectory, which
    # is how add_to_batch() builds self.trajectories.
    parent_of = {}
    for batch_indices in self.trajectories.values():
        previous_idx = -1
        for batch_idx in batch_indices:
            parent_of[batch_idx] = previous_idx
            previous_idx = batch_idx
    parents_indices = tlong(
        [parent_of[idx] for idx in range(len(self.traj_indices))],
        device=self.device,
    )
    parent_is_source = parents_indices == -1
    masks_invalid_actions_forward_parents = torch.zeros_like(
        masks_invalid_actions_forward
    )
    masks_invalid_actions_forward_parents[parent_is_source] = tbool(
        self.source["mask_forward"], device=self.device
    )
    masks_invalid_actions_forward_parents[~parent_is_source] = (
        masks_invalid_actions_forward[parents_indices[~parent_is_source]]
    )
    return select_indices(masks_invalid_actions_forward_parents, indices)
def _compute_masks_forward(self):
    """
    Fills in the missing forward masks of invalid actions of the batch.

    Every entry of self.masks_invalid_actions_forward that is None is replaced by
    the output of get_mask_invalid_actions_forward(state, done) of the environment
    of the corresponding trajectory. Entries that are already set are left as they
    are. Finally, self._masks_forward_available is set to True.
    """
    masks = self.masks_invalid_actions_forward
    for batch_idx, cached_mask in enumerate(masks):
        if cached_mask is None:
            state, done = self.states[batch_idx], self.done[batch_idx]
            env = self.envs[self.traj_indices[batch_idx]]
            masks[batch_idx] = env.get_mask_invalid_actions_forward(state, done)
    self._masks_forward_available = True
# TODO: opportunity to improve efficiency by caching. Note that
# env.get_mask_invalid_actions_backward() may be expensive because it calls
# env.get_parents().
[docs]
def get_masks_backward(
    self,
    force_recompute: bool = False,
    indices: Optional[Union[List, Tuple, TensorType, npt.NDArray]] = None,
) -> TensorType["n_states", "action_space_dim"]:
    """
    Returns the backward masks of invalid actions of the states in the batch.

    self._compute_masks_backward() is called first if the masks are not flagged as
    available or if force_recompute is True.

    Args
    ----
    force_recompute : bool
        If True, self._compute_masks_backward() is called even if the masks are
        flagged as available.
    indices : list, tuple, tensor or np.ndarray
        1-dimensional sequence of batch indices for selecting masks, optional. If
        None (default), the masks of all states in the batch will be returned.

    Returns
    -------
    torch.tensor
        The backward masks of the selected states (by indices). If indices is
        None, the masks of all states in the batch will be returned.
    """
    must_compute = self._masks_backward_available is False or force_recompute is True
    if must_compute:
        self._compute_masks_backward()
    # Convert the stored masks into a boolean tensor before selecting
    all_masks = tbool(self.masks_invalid_actions_backward, device=self.device)
    return select_indices(all_masks, indices)
def _compute_masks_backward(self):
    """
    Fills in the missing backward masks of invalid actions of the batch.

    Every entry of self.masks_invalid_actions_backward that is None is replaced by
    the output of get_mask_invalid_actions_backward(state, done) of the environment
    of the corresponding trajectory. Entries that are already set are left as they
    are. Finally, self._masks_backward_available is set to True.
    """
    masks = self.masks_invalid_actions_backward
    for batch_idx, cached_mask in enumerate(masks):
        if cached_mask is None:
            state, done = self.states[batch_idx], self.done[batch_idx]
            env = self.envs[self.traj_indices[batch_idx]]
            masks[batch_idx] = env.get_mask_invalid_actions_backward(state, done)
    self._masks_backward_available = True
# TODO: better handling of availability of rewards, logrewards, proxy_values.
[docs]
def get_rewards(
    self,
    log: bool = False,
    force_recompute: Optional[bool] = False,
    do_non_terminating: Optional[bool] = False,
) -> TensorType["n_states"]:
    """
    Returns the rewards of all states in the batch (including not done).

    The rewards are computed first, via self._compute_rewards(), if
    self.rewards_available(log) is False or if force_recompute is True.

    Parameters
    ----------
    log : bool
        If True, return the logarithm of the rewards (self.logrewards).
    force_recompute : bool
        If True, the rewards are recomputed even if they are available.
    do_non_terminating : bool
        Forwarded to self._compute_rewards(), so it only has an effect if the
        rewards are (re)computed by this call.

    Returns
    -------
    self.logrewards if log is True, self.rewards otherwise.
    """
    missing = self.rewards_available(log) is False
    if missing or force_recompute is True:
        self._compute_rewards(log, do_non_terminating)
    return self.logrewards if log else self.rewards
[docs]
def get_proxy_values(
    self,
    force_recompute: Optional[bool] = False,
    do_non_terminating: Optional[bool] = False,
) -> TensorType["n_states"]:
    """
    Returns the proxy values of all states in the batch (including not done).

    The proxy values are computed first, via self._compute_rewards(), if they are
    not flagged as available or if force_recompute is True.

    Parameters
    ----------
    force_recompute : bool
        If True, the proxy values are recomputed even if they are available.
    do_non_terminating : bool
        Forwarded to self._compute_rewards(), so it only has an effect if the
        proxy values are (re)computed by this call.

    Returns
    -------
    self.proxy_values
    """
    stale = self._proxy_values_available is False
    if stale or force_recompute is True:
        self._compute_rewards(do_non_terminating=do_non_terminating)
    return self.proxy_values
def _compute_rewards(
    self, log: bool = False, do_non_terminating: Optional[bool] = False
):
    """
    Computes the rewards and proxy values of the states in the batch.

    The proxy values are stored in self.proxy_values and the rewards in
    self.logrewards (if log is True) or self.rewards (otherwise). The corresponding
    availability flags are set to True.

    Parameters
    ----------
    log : bool
        If True, compute the logarithm of the rewards.
    do_non_terminating : bool
        If True, the proxy is evaluated on all states (self.states2proxy()). If
        False, only the done states are evaluated and the rest are assigned
        self.proxy.get_min_reward(log) as reward and inf as proxy value.
    """
    if do_non_terminating:
        # Evaluate the proxy on every state in the batch
        rewards, proxy_values = self.proxy.rewards(
            self.states2proxy(), log, return_proxy=True
        )
    else:
        # Start from the default values and overwrite the entries of done states
        min_reward = self.proxy.get_min_reward(log)
        rewards = min_reward * torch.ones(
            len(self), dtype=self.float, device=self.device
        )
        proxy_values = torch.full_like(rewards, torch.inf)
        is_done = self.get_done()
        if len(is_done) > 0:
            rewards_done, proxy_values_done = self.proxy.rewards(
                self.get_terminating_states(proxy=True), log, return_proxy=True
            )
            rewards[is_done] = rewards_done
            proxy_values[is_done] = proxy_values_done

    self.proxy_values = proxy_values
    self._proxy_values_available = True
    if log:
        self.logrewards = rewards
        self._logrewards_available = True
    else:
        self.rewards = rewards
        self._rewards_available = True
[docs]
def get_rewards_parents(self, log: bool = False) -> TensorType["n_states"]:
    """
    Returns the rewards of the parents of all states in the batch.

    The rewards are computed first via self._compute_rewards_parents() if
    self.rewards_parents_available(log) is falsy.

    Parameters
    ----------
    log : bool
        If True, return the logarithm of the rewards.

    Returns
    -------
    self.rewards_parents or self.logrewards_parents
        A tensor containing the rewards of the parents of self.states.
    """
    already_computed = self.rewards_parents_available(log)
    if not already_computed:
        self._compute_rewards_parents(log)
    return self.logrewards_parents if log else self.rewards_parents
def _compute_rewards_parents(self, log: bool = False):
    """
    Computes the rewards of the parents of all states in the batch by reusing the
    rewards of the states, as returned by self.get_rewards().

    Parents that are the source state (parent index -1) take their reward from
    self.get_rewards_source(); all other parents take the reward stored at their
    own batch index. The result is stored in self.logrewards_parents (if log is
    True) or self.rewards_parents and the matching availability flag is set.

    Parameters
    ----------
    log : bool
        If True, compute the logarithm of the rewards.
    """
    # TODO: this may return zero rewards for all parents if before
    # rewards for states were computed with do_non_terminating=False
    rewards_states = self.get_rewards(log=log, do_non_terminating=True)
    rewards_parents = torch.zeros_like(rewards_states)
    parents_idx = self.get_parents_indices()
    from_source = parents_idx == -1
    from_state = ~from_source
    # Parents stored in the batch: copy the reward of the parent's batch index
    rewards_parents[from_state] = rewards_states[parents_idx[from_state]]
    # Parents that are the source state
    rewards_parents[from_source] = self.get_rewards_source(log)[from_source]

    if log:
        self.logrewards_parents = rewards_parents
        self._logrewards_parents_available = True
    else:
        self.rewards_parents = rewards_parents
        self._rewards_parents_available = True
[docs]
def get_rewards_source(self, log: bool = False) -> TensorType["n_states"]:
    """
    Returns the rewards of the corresponding source states for each state in the
    batch.

    The rewards are computed first via self._compute_rewards_source() if
    self.rewards_source_available(log) is falsy.

    Parameters
    ----------
    log : bool
        If True, return the logarithm of the rewards.

    Returns
    -------
    self.rewards_source or self.logrewards_source
        A tensor containing the rewards of the source states.
    """
    already_computed = self.rewards_source_available(log)
    if not already_computed:
        self._compute_rewards_source(log)
    return self.logrewards_source if log else self.rewards_source
def _compute_rewards_source(self, log: bool = False):
    """
    Computes a tensor of length len(self) with the reward of the source state.

    The source state of self.readonly_env is converted into proxy format, its
    reward is computed with self.proxy.rewards() and the result is expanded to the
    length of the batch. The tensor is stored in self.logrewards_source (if log is
    True) or self.rewards_source and the matching availability flag is set.

    Parameters
    ----------
    log : bool
        If True, compute the logarithm of the rewards.

    Raises
    ------
    NotImplementedError
        If self.conditional is truthy.
    """
    if self.conditional:
        raise NotImplementedError

    # This will not work if source is randomised
    env = self.readonly_env
    reward_source = self.proxy.rewards(env.state2proxy(env.source), log)
    rewards_source = reward_source.expand(len(self))

    if log:
        self.logrewards_source = rewards_source
        self._logrewards_source_available = True
    else:
        self.rewards_source = rewards_source
        self._rewards_source_available = True
[docs]
def get_terminating_states(
    self,
    sort_by: str = "insertion",
    policy: Optional[bool] = False,
    proxy: Optional[bool] = False,
) -> Union[TensorType["n_trajectories", "..."], npt.NDArray[np.float32], List]:
    """
    Returns the terminating states in the batch, that is all states with done =
    True.

    The states are returned in GFlowNet format (default), in policy format (policy
    = True) or in proxy format (proxy = True).

    Args
    ----
    sort_by : str
        Indicates how to sort the output:
            - insert[ion]: sort by order of insertion (states of trajectories that
              reached the terminating state first come first)
            - traj[ectory]: sort by trajectory index (argsort of
              self.traj_indices)
    policy : bool
        If True, the policy format of the states is returned.
    proxy : bool
        If True, the proxy format of the states is returned.

    Raises
    ------
    ValueError
        If sort_by is not one of the accepted values, or if both policy and proxy
        are True (ambiguous request).
    NotImplementedError
        If self.states is neither a list nor a torch tensor.
    """
    if sort_by in ("insert", "insertion"):
        indices = np.arange(len(self))
    elif sort_by in ("traj", "trajectory"):
        indices = np.argsort(self.traj_indices)
    else:
        raise ValueError("sort_by must be either insert[ion] or traj[ectory]")
    if policy is True and proxy is True:
        raise ValueError(
            "Ambiguous request! Only one of policy or proxy can be True."
        )

    # Trajectory indices are only needed to convert states of conditional batches
    traj_indices = None
    if torch.is_tensor(self.states):
        indices = tlong(indices, device=self.device)
        done = self.get_done()[indices]
        states_term = self.states[indices][done, :]
        if self.conditional and (policy is True or proxy is True):
            traj_indices_sorted = tlong(self.traj_indices, device=self.device)
            traj_indices = traj_indices_sorted[indices][done]
            assert len(traj_indices) == len(torch.unique(traj_indices))
    elif isinstance(self.states, list):
        states_term = [self.states[idx] for idx in indices if self.done[idx]]
        if self.conditional and (policy is True or proxy is True):
            done = np.array(self.done, dtype=bool)[indices]
            traj_indices = np.array(self.traj_indices)[indices][done]
            assert len(traj_indices) == len(np.unique(traj_indices))
    else:
        raise NotImplementedError("self.states can only be list or torch.tensor")

    if policy is True:
        return self.states2policy(states_term, traj_indices)
    if proxy is True:
        return self.states2proxy(states_term, traj_indices)
    return states_term
[docs]
def get_terminating_rewards(
    self,
    sort_by: str = "insertion",
    log: bool = False,
    force_recompute: Optional[bool] = False,
) -> TensorType["n_trajectories"]:
    """
    Returns the reward of the terminating states in the batch, that is all states
    with done = True.

    The rewards are computed first, via self._compute_rewards() with
    do_non_terminating=False, if self.rewards_available(log) is False or if
    force_recompute is True.

    Parameters
    ----------
    sort_by : str
        Indicates how to sort the output:
            - insert[ion]: sort by order of insertion (rewards of trajectories that
              reached the terminating state first come first)
            - traj[ectory]: sort by trajectory index (argsort of
              self.traj_indices)
    log : bool
        If True, return the logarithm of the rewards.
    force_recompute : bool
        If True, the rewards are recomputed even if they are available.

    Raises
    ------
    ValueError
        If sort_by is not one of the accepted values.
    """
    if sort_by in ("insert", "insertion"):
        indices = np.arange(len(self))
    elif sort_by in ("traj", "trajectory"):
        indices = np.argsort(self.traj_indices)
    else:
        raise ValueError("sort_by must be either insert[ion] or traj[ectory]")

    missing = self.rewards_available(log) is False
    if missing or force_recompute is True:
        self._compute_rewards(log, do_non_terminating=False)

    done = self.get_done()[indices]
    rewards = self.logrewards if log else self.rewards
    return rewards[indices][done]
[docs]
def get_terminating_proxy_values(
    self,
    sort_by: str = "insertion",
    force_recompute: Optional[bool] = False,
) -> TensorType["n_trajectories"]:
    """
    Returns the proxy values of the terminating states in the batch, that is all
    states with done = True.

    The proxy values are computed first, via self._compute_rewards() with
    do_non_terminating=False, if they are not flagged as available or if
    force_recompute is True.

    Parameters
    ----------
    sort_by : str
        Indicates how to sort the output:
            - insert[ion]: sort by order of insertion (proxy values of trajectories
              that reached the terminating state first come first)
            - traj[ectory]: sort by trajectory index (argsort of
              self.traj_indices)
    force_recompute : bool
        If True, the proxy_values are recomputed even if they are available.

    Returns
    -------
    The entries of self.proxy_values that correspond to done states, in the
    requested order.

    Raises
    ------
    ValueError
        If sort_by is not one of the accepted values.
    """
    if sort_by == "insert" or sort_by == "insertion":
        indices = np.arange(len(self))
    elif sort_by == "traj" or sort_by == "trajectory":
        indices = np.argsort(self.traj_indices)
    else:
        raise ValueError("sort_by must be either insert[ion] or traj[ectory]")
    if self._proxy_values_available is False or force_recompute is True:
        # This method has no `log` argument (passing one raised a NameError), so
        # the default of _compute_rewards() is used, as in get_proxy_values().
        self._compute_rewards(do_non_terminating=False)
    done = self.get_done()[indices]
    return self.proxy_values[indices][done]
[docs]
def get_actions_trajectories(self) -> List[List[Tuple]]:
    """
    Returns the actions of all trajectories in the batch as a list with one list
    of actions per trajectory, following the order of self.trajectories. Within a
    trajectory, the actions follow the order of its batch indices.
    """
    return [
        [self.actions[batch_idx] for batch_idx in traj_batch_indices]
        for traj_batch_indices in self.trajectories.values()
    ]
[docs]
def get_states_of_trajectory(
    self,
    traj_idx: int,
    states: Optional[
        Union[TensorType["n_states", "..."], npt.NDArray[np.float32], List]
    ] = None,
    traj_indices: Optional[Union[List, TensorType["n_states"]]] = None,
) -> Union[
    TensorType["n_states", "state_proxy_dims"], npt.NDArray[np.float32], List
]:
    """
    Returns the states of the trajectory indicated by traj_idx.

    If states and traj_indices are both None, self.states and self.traj_indices
    are used. Otherwise, the states and trajectory indices passed as arguments are
    the only ones considered; in this case both must be of the same type and have
    the same length (checked with asserts).

    See: states2policy()
    See: states2proxy()

    Args
    ----
    traj_idx : int
        Index of the trajectory from which to return the states.
    states : tensor, array or list
        States to consider.
    traj_indices : tensor, array or list
        Trajectory index of each of the states to consider.

    Returns
    -------
    Tensor, array or list (same container type as states) with the states whose
    trajectory index is equal to traj_idx.

    Raises
    ------
    ValueError
        If states is not a list, a torch tensor or a numpy array.
    """
    # TODO: re-implement using the batch indices in self.trajectories[traj_idx]
    # TODO: or add sort_by
    if states is None and traj_indices is None:
        states, traj_indices = self.states, self.traj_indices
    else:
        # If either is given, both must be given and be mutually consistent
        assert type(states) == type(traj_indices)
        assert len(states) == len(traj_indices)

    if torch.is_tensor(states):
        in_trajectory = tlong(traj_indices, device=self.device) == traj_idx
        return states[in_trajectory]
    if isinstance(states, list):
        return [
            state
            for state, state_traj_idx in zip(states, traj_indices)
            if state_traj_idx == traj_idx
        ]
    if isinstance(states, np.ndarray):
        return states[np.array(traj_indices) == traj_idx]
    raise ValueError("states can only be list, torch.tensor or ndarray")
[docs]
def get_logprobs_of_trajectory(
    self,
    traj_idx: int,
    backward: bool = False,
) -> Union[
    TensorType["n_states", "state_proxy_dims"], npt.NDArray[np.float32], List
]:
    """
    Returns the logprobs of the trajectory indicated by traj_idx.

    Parameters
    ----------
    traj_idx : int
        Index of the trajectory from which to return the logprobs.
    backward : bool
        If True, returns backward logprobs (self.logprobs_backward), otherwise
        returns forward logprobs (self.logprobs_forward).

    Returns
    -------
    A list with the logprobs stored at the batch positions whose trajectory index
    is equal to traj_idx.
    """
    # TODO: re-implement using the batch indices in self.trajectories[traj_idx]
    # TODO: or add sort_by
    logprobs = self.logprobs_backward if backward else self.logprobs_forward
    return [
        logprob
        for logprob, logprob_traj_idx in zip(logprobs, self.traj_indices)
        if logprob_traj_idx == traj_idx
    ]
[docs]
def merge(self, batches: List):
"""
Merges the current Batch (self) with the Batch or list of Batches passed as
argument. The data of each batch is appended to the data of self, in the order in
which the batches are given. Batches of length 0 are skipped.
Note that the batches passed as argument are modified in place: their trajectory
indices, environment ids and batch indices are shifted via batch._shift_indices()
before their data is appended to self.
Parameters
----------
batches : Batch or list
A single Batch or a list of Batch objects. An argument that is not a list is
wrapped into a single-element list.
Returns
-------
self
"""
if not isinstance(batches, list):
batches = [batches]
for batch in batches:
if len(batch) == 0:
continue
# Shift trajectory indices of batch to merge so that they start right after the
# largest trajectory index currently in self, and shift its batch indices by the
# current length of self.
if len(self) == 0:
traj_idx_shift = 0
else:
traj_idx_shift = np.max(list(self.trajectories.keys())) + 1
batch._shift_indices(traj_shift=traj_idx_shift, batch_shift=len(self))
# Merge main data
self.size += batch.size
self.envs.update(batch.envs)
self.trajectories.update(batch.trajectories)
self.traj_indices.extend(batch.traj_indices)
self.state_indices.extend(batch.state_indices)
self.states.extend(batch.states)
self.actions.extend(batch.actions)
self.logprobs_forward.extend(batch.logprobs_forward)
self.logprobs_backward.extend(batch.logprobs_backward)
self.logprobs_forward_avail.extend(batch.logprobs_forward_avail)
self.logprobs_backward_avail.extend(batch.logprobs_backward_avail)
self.done.extend(batch.done)
self.masks_invalid_actions_forward = extend(
self.masks_invalid_actions_forward,
batch.masks_invalid_actions_forward,
)
self.masks_invalid_actions_backward = extend(
self.masks_invalid_actions_backward,
batch.masks_invalid_actions_backward,
)
# Merge "optional" data: each item is only kept if it is available in both self
# and the batch being merged; otherwise it is reset to None.
# NOTE(review): the availability flags (self._parents_available,
# self._rewards_available, etc.) are not modified in this method, even when the
# corresponding data is reset to None - confirm that this is handled elsewhere.
if self.states_policy is not None and batch.states_policy is not None:
self.states_policy = extend(self.states_policy, batch.states_policy)
else:
self.states_policy = None
if self._parents_available and batch._parents_available:
self.parents = extend(self.parents, batch.parents)
else:
self.parents = None
if self._parents_policy_available and batch._parents_policy_available:
self.parents_policy = extend(self.parents_policy, batch.parents_policy)
else:
self.parents_policy = None
if self._parents_all_available and batch._parents_all_available:
self.parents_all = extend(self.parents_all, batch.parents_all)
else:
self.parents_all = None
if self._rewards_available and batch._rewards_available:
self.rewards = extend(self.rewards, batch.rewards)
else:
self.rewards = None
if self._logrewards_available and batch._logrewards_available:
self.logrewards = extend(self.logrewards, batch.logrewards)
else:
self.logrewards = None
# Sanity check of the merged batch
assert self.is_valid()
return self
[docs]
def is_valid(self) -> bool:
    """
    Performs basic checks on the current state of the batch.

    The checks are:
        - The per-state containers (states, actions, logprobs, done, trajectory
          indices and state indices) have length len(self).
        - The trajectory indices in self.traj_indices and the keys of
          self.trajectories coincide with the keys of self.envs.
        - The batch indices stored in self.trajectories are unique and there are
          len(self) of them.

    Returns
    -------
    True if all the checks are valid, False otherwise.
    """
    size = len(self)
    per_state_attributes = (
        "states",
        "actions",
        "logprobs_forward",
        "logprobs_backward",
        "done",
        "traj_indices",
        "state_indices",
    )
    if any(len(getattr(self, name)) != size for name in per_state_attributes):
        return False

    env_ids = set(self.envs.keys())
    if set(np.unique(self.traj_indices)) != env_ids:
        return False
    if set(self.trajectories.keys()) != env_ids:
        return False

    batch_indices = [
        batch_idx
        for traj_batch_indices in self.trajectories.values()
        for batch_idx in traj_batch_indices
    ]
    if len(batch_indices) != size:
        return False
    return len(np.unique(batch_indices)) == len(batch_indices)
[docs]
def traj_indices_are_consecutive(self) -> bool:
    """
    Returns True if both the keys of self.trajectories and the keys of self.envs
    are exactly 0, 1, ..., n_trajectories - 1, in this order; False otherwise.
    """
    expected = list(np.arange(self.get_n_trajectories()))
    return list(self.trajectories) == expected and list(self.envs) == expected
[docs]
def make_indices_consecutive(self):
    """
    Updates the trajectory indices as well as the env ids such that they start from
    0 and are consecutive. Note that only the trajectory indices are changed, but
    importantly the order of the main data in the batch is preserved.

    Nothing is done if self.traj_indices_are_consecutive() is already True.

    Examples:

    - Original indices: 0, 10, 20
    - New indices: 0, 1, 2

    - Original indices: 1, 5, 3
    - New indices: 0, 1, 2

    Note: this method is unused as of September 1st 2023, but is left here for
    potential future use.
    """
    if self.traj_indices_are_consecutive():
        return

    self.traj_indices = self.get_trajectory_indices(consecutive=True).tolist()
    # Re-key the trajectories by position, keeping their current order
    new_traj_ids = range(self.get_n_trajectories())
    self.trajectories = OrderedDict(zip(new_traj_ids, self.trajectories.values()))
    # Re-key the envs by position and update their ids accordingly
    renumbered_envs = OrderedDict()
    for new_id, env in enumerate(self.envs.values()):
        renumbered_envs[new_id] = env.set_id(new_id)
    self.envs = renumbered_envs

    assert self.traj_indices_are_consecutive()
    assert self.is_valid()
def _shift_indices(self, traj_shift: int, batch_shift: int):
    """
    Shifts all the trajectory indices and environment ids by traj_shift and the
    batch indices by batch_shift.

    The validity of the batch (self.is_valid()) is checked both before and after
    the shift.

    Parameters
    ----------
    traj_shift : int
        Value added to every trajectory index and environment id.
    batch_shift : int
        Value added to every batch index stored in self.trajectories.

    Returns
    -------
    self

    Raises
    ------
    Exception
        If the batch is not valid before or after shifting the indices.
    """
    if not self.is_valid():
        raise Exception("Batch is not valid before attempting indices shift")

    self.traj_indices = [traj_idx + traj_shift for traj_idx in self.traj_indices]
    shifted_trajectories = {}
    for traj_idx, batch_indices in self.trajectories.items():
        shifted_trajectories[traj_idx + traj_shift] = [
            batch_idx + batch_shift for batch_idx in batch_indices
        ]
    self.trajectories = shifted_trajectories
    shifted_envs = {}
    for env_id, env in self.envs.items():
        shifted_envs[env_id + traj_shift] = env.set_id(env_id + traj_shift)
    self.envs = shifted_envs

    if not self.is_valid():
        raise Exception("Batch is not valid after performing indices shift")
    return self
# TODO: rewrite once cache is implemented
[docs]
def get_item(
    self,
    item: str,
    env: GFlowNetEnv = None,
    traj_idx: int = None,
    action_idx: int = None,
    backward: bool = False,
):
    """
    Returns the item specified by item of either:
        - environment env, OR
        - trajectory traj_idx AND action number action_idx (in the order of
          sampling)

    If all arguments are given, then they must be consistent, otherwise an
    exception (assert) is raised due to ambiguity.

    If a mask is requested but is missing, it is computed and stored.

    Args
    ----
    item : str
        String identifier of the item to retrieve from the batch. Options
            - state
            - parent
            - action
            - logprob_f[orward]
            - logprob_b[ackward]
            - done
            - mask_f[orward]
            - mask_b[ackward]
    env : GFlowNetEnv
        Environment from which traj_idx (env.id) and action_idx (env.n_actions)
        are taken if they are not given. If the requested transition is not in
        the batch, the state, done or masks are taken from this environment.
    traj_idx : int
        Trajectory index
    action_idx : int
        Action index. Regardless of forward of backward, n-th item sampled when
        forming the batch.
    backward : bool
        Whether the trajectory is sampling backward. False (forward) by default.

    Returns
    -------
    The requested item.

    Raises
    ------
    AssertionError
        If env is inconsistent with traj_idx or action_idx, or if neither env nor
        both traj_idx and action_idx are provided.
    ValueError
        If the item is not valid for the request, or if the requested transition
        is not in the batch and no env was provided.
    """
    # Preliminary checks
    if env is not None:
        if traj_idx is not None:
            assert (
                env.id == traj_idx
            ), f"env.id {env.id} different to traj_idx {traj_idx}."
        else:
            traj_idx = env.id
        if action_idx is not None:
            assert env.n_actions == action_idx, (
                f"env.n_actions {env.n_actions} different to action_idx "
                f"{action_idx}."
            )
        else:
            action_idx = env.n_actions
    else:
        assert (
            traj_idx is not None and action_idx is not None
        ), "Either env or traj_idx AND action_idx must be provided"
    # Handle action_idx = 0 (source state)
    if action_idx == 0:
        if backward is False:
            if item == "state":
                return self.source["state"]
            elif item == "mask_f" or item == "mask_forward":
                return self.source["mask_forward"]
            else:
                raise ValueError(
                    "Only state or mask_forward are available for a fresh env "
                    "(action_idx = 0)"
                )
        # TODO: handle backward masks with cache. For now, backward requests with
        # action_idx = 0 fall through to the regular batch look-up below.
    batch_idx = self.traj_idx_action_idx_to_batch_idx(
        traj_idx, action_idx, backward
    )
    if batch_idx is None:
        # The transition is not in the batch: fall back to the env, if provided
        # TODO: handle this
        if env is None:
            raise ValueError(
                f"{item} not available for action {action_idx} of trajectory "
                f"{traj_idx} and no env was provided."
            )
        else:
            if item == "state":
                return env.state
            elif item == "done":
                return env.done
            elif item == "mask_f" or item == "mask_forward":
                return env.get_mask_invalid_actions_forward()
            elif item == "mask_b" or item == "mask_backward":
                return env.get_mask_invalid_actions_backward()
            else:
                raise ValueError(
                    "Not available in the batch. item must be one of: state, done, "
                    "mask_f[orward] or mask_b[ackward]."
                )
    if item == "state":
        return self.states[batch_idx]
    elif item == "parent":
        return self.parents[batch_idx]
    elif item == "action":
        return self.actions[batch_idx]
    elif item == "logprob_f" or item == "logprob_forward":
        return self.logprobs_forward[batch_idx]
    elif item == "logprob_b" or item == "logprob_backward":
        return self.logprobs_backward[batch_idx]
    elif item == "done":
        return self.done[batch_idx]
    elif item == "mask_f" or item == "mask_forward":
        # Compute and store the mask if it is missing
        if self.masks_invalid_actions_forward[batch_idx] is None:
            state = self.states[batch_idx]
            done = self.done[batch_idx]
            self.masks_invalid_actions_forward[batch_idx] = self.envs[
                traj_idx
            ].get_mask_invalid_actions_forward(state, done)
        return self.masks_invalid_actions_forward[batch_idx]
    elif item == "mask_b" or item == "mask_backward":
        # Compute and store the mask if it is missing
        if self.masks_invalid_actions_backward[batch_idx] is None:
            state = self.states[batch_idx]
            done = self.done[batch_idx]
            self.masks_invalid_actions_backward[batch_idx] = self.envs[
                traj_idx
            ].get_mask_invalid_actions_backward(state, done)
        return self.masks_invalid_actions_backward[batch_idx]
    else:
        raise ValueError(
            "item must be one of: state, parent, action, done, mask_f[orward] or "
            "mask_b[ackward]"
        )
[docs]
def get_indices_of_previous_transitions(
    self, envs: List[GFlowNetEnv], backward: bool
) -> List[int]:
    """
    Get batch indices of the latest elements (states, actions, etc) added to the
    batch for the provided environments.

    For each environment, the index is taken from the list of batch indices of its
    trajectory (self.trajectories[env.id]): the first element if backward is True
    or the last element otherwise.

    Parameters
    ----------
    envs: list
        List of envs for which the batch indices are requested.
    backward: bool
        Whether the trajectories are sampled backward (True) or forward (False).

    Returns
    -------
    list
        Batch indices of the previous transitions for the requested environments.
        Environments without any previous transitions in the batch are assigned
        None.
    """
    latest = 0 if backward else -1
    return [
        self.trajectories[env.id][latest] if env.id in self.trajectories else None
        for env in envs
    ]
[docs]
def get_actions_of_previous_transitions(
    self, envs: List[GFlowNetEnv], backward: bool
) -> List:
    """
    Retrieves the latest actions added to the batch for the provided envs.

    Parameters
    ----------
    envs: list
        List of envs for which the actions are requested.
    backward: bool
        Indicates whether the trajectories are sampled backward (True) or forward
        (False).

    Returns
    -------
    actions: list
        Actions of previous transitions for the requested envs. Environments
        without any previous transition in the batch are assigned None.
    """
    indices = self.get_indices_of_previous_transitions(envs, backward)
    # A None index indicates that the env has no transitions in the batch. The
    # action can only be looked up when the index is NOT None.
    return [self.actions[idx] if idx is not None else None for idx in indices]
[docs]
def compute_logprobs_trajectories(
batch: Batch,
env: GFlowNetEnv = None,
forward_policy: Policy = None,
backward_policy: Policy = None,
backward: bool = False,
):
"""
Computes the forward or backward log probabilities of the trajectories in a
batch.
The log probabilities of the transitions that are already available in the batch
(as returned by batch.get_logprobs()) are reused. The rest are computed according
to the forward or backward policy models passed as parameters. The log
probability of a trajectory is the sum of the log probabilities of its
transitions.
Parameters
----------
batch : Batch
A batch of data, containing all the states in the trajectories.
env : :py:class:`gflownet.envs.base.GFlowNetEnv`, optional
An instance of the environment used to compute log probabilities of state
transitions. If None, batch.readonly_env is used.
forward_policy : :py:class:`gflownet.policy.base.Policy`, optional
The model used to compute the forward policy outputs from input states. It is
ignored if `backward` is True. It must not be None if `backward` is False
(checked with an assert).
backward_policy : :py:class:`gflownet.policy.base.Policy`, optional
Same as `forward_policy`, but for the backward policy. It is ignored if
`backward` is False. It must not be None if `backward` is True (checked with an
assert).
backward : bool
False: log probabilities of forward trajectories.
True: log probabilities of backward trajectories.
Returns
-------
logprobs : torch.tensor
The log probabilities of the trajectories: a 1-dimensional tensor with
batch.get_n_trajectories() elements.
"""
if env is None:
env = batch.readonly_env
if backward:
assert backward_policy is not None
else:
assert forward_policy is not None
# Make indices of batch consecutive since they are used for indexing here
# (they are the target positions of index_add_ at the end of the function)
traj_indices = batch.get_trajectory_indices(consecutive=True)
# Take logprobs from the batch if they are available.
logprobs_states, logprobs_avail = batch.get_logprobs(backward)
# Compute the unavailable log probs from the states and actions in the batch
indices_select = torch.where(~logprobs_avail)[0]
if len(indices_select) > 0:
if len(indices_select) == len(logprobs_avail):
# Set select indices to None to select everything
indices_select = None
# Get necessary tensors from batch
# Note that both the states and the parents are retrieved regardless of the
# direction, although each branch below only uses one of the two.
states = batch.get_states(policy=False, indices=indices_select)
actions = batch.get_actions(indices=indices_select)
parents = batch.get_parents(policy=False, indices=indices_select)
states_policy = batch.get_states(policy=True, indices=indices_select)
parents_policy = batch.get_parents(policy=True, indices=indices_select)
if backward:
# Backward trajectories: the policy is evaluated on the states and the masks
# are the backward masks of the states
masks_b = batch.get_masks_backward(indices=indices_select)
policy_output_b = backward_policy(states_policy)
logprobs_states_val = env.get_logprobs(
policy_output_b, actions, masks_b, states, backward
)
# If everything was selected, the computed values replace the whole tensor;
# otherwise only the selected entries are overwritten
if indices_select is None:
logprobs_states = logprobs_states_val
else:
logprobs_states[indices_select] = logprobs_states_val
else:
# Forward trajectories: the policy is evaluated on the parents and the masks
# are the forward masks of the parents
masks_f = batch.get_masks_forward(of_parents=True, indices=indices_select)
policy_output_f = forward_policy(parents_policy)
logprobs_states_val = env.get_logprobs(
policy_output_f, actions, masks_f, parents, backward
)
# If everything was selected, the computed values replace the whole tensor;
# otherwise only the selected entries are overwritten
if indices_select is None:
logprobs_states = logprobs_states_val
else:
logprobs_states[indices_select] = logprobs_states_val
# Sum log probabilities of all transitions in each trajectory
logprobs = torch.zeros(
batch.get_n_trajectories(),
dtype=logprobs_states.dtype,
device=logprobs_states.device,
).index_add_(0, traj_indices, logprobs_states)
return logprobs