gflownet.utils.batch

Classes

Batch

Functions

compute_logprobs_trajectories(batch[, env, ...])

Computes the forward or backward log probabilities of the trajectories in a batch.

Module Contents

class gflownet.utils.batch.Batch(env=None, proxy=None, device='cpu', float_type=32, collect_forwards_masks=False, collect_backwards_masks=False)[source]
Parameters:
  • env (GFlowNetEnv) – An instance of the environment that will be used to form the batch.

  • proxy (Proxy) – An instance of a GFlowNet proxy that will be used to compute proxy values and rewards.

  • device (str or torch.device) – torch.device or string indicating the device to use (“cpu” or “cuda”)

  • float_type (torch.dtype or int) – A floating-point torch.dtype, or an int indicating the float precision (16, 32 or 64).

device = 'cpu'[source]
float = 32[source]
proxy = None[source]
size = 0[source]
envs[source]
trajectories[source]
is_backward[source]
traj_indices = [][source]
state_indices = [][source]
states = [][source]
actions = [][source]
logprobs_forward = [][source]
logprobs_backward = [][source]
logprobs_forward_avail = [][source]
logprobs_backward_avail = [][source]
done = [][source]
masks_invalid_actions_forward = [][source]
masks_invalid_actions_backward = [][source]
parents = [][source]
parents_all = [][source]
parents_actions_all = [][source]
n_actions = [][source]
states_policy = None[source]
parents_policy = None[source]
collect_forwards_masks = False[source]
collect_backwards_masks = False[source]
__len__()[source]
batch_idx_to_traj_state_idx(batch_idx)[source]
Parameters:

batch_idx (int)

traj_idx_to_batch_indices(traj_idx)[source]
Parameters:

traj_idx (int)

traj_state_idx_to_batch_idx(traj_idx, state_idx)[source]
Parameters:
  • traj_idx (int)

  • state_idx (int)

traj_idx_action_idx_to_batch_idx(traj_idx, action_idx, backward)[source]
Parameters:
  • traj_idx (int)

  • action_idx (int)

  • backward (bool)

idx2state_idx(idx)[source]
Parameters:

idx (int)

rewards_available(log=False)[source]

Returns True if the (log)rewards are available.

Parameters:

log (bool) – If True, check self._logrewards_available. Otherwise (default), check self._rewards_available.

Returns:

bool – True if the (log)rewards are available, False otherwise.

Return type:

bool

rewards_parents_available(log=False)[source]

Returns True if the (log)rewards of the parents are available.

Parameters:

log (bool) – If True, check self._logrewards_parents_available. Otherwise (default), check self._rewards_parents_available.

Returns:

bool – True if the (log)rewards of the parents are available, False otherwise.

Return type:

bool

rewards_source_available(log=False)[source]

Returns True if the (log)rewards of the source are available.

Parameters:

log (bool) – If True, check self._logrewards_source_available. Otherwise (default), check self._rewards_source_available.

Returns:

bool – True if the (log)rewards of the source are available, False otherwise.

Return type:

bool

set_env(env)[source]

Sets the generic environment passed as an argument and initializes the environment-dependent properties.

Parameters:

env (gflownet.envs.base.GFlowNetEnv)

set_proxy(proxy)[source]

Sets the proxy, used to compute rewards from a batch of states.

Parameters:

proxy (gflownet.proxy.base.Proxy)

add_to_batch(envs, actions, logprobs, logprobs_rev, valids, backward=False, train=True)[source]

Adds information from a list of environments and actions to the batch after performing steps in the envs. If train is False, only the variables of terminating states are stored.

Parameters:
  • envs (list) – A list of environments (GFlowNetEnv).

  • actions (list) – A list of actions attempted or performed on the envs.

  • logprobs (torch.tensor or list of None) – Log probabilities corresponding to the actions or None.

  • logprobs_rev (torch.tensor or list of None) – Log probabilities corresponding to the transitions in the opposite direction to the sampling direction, from the current states but with the actions added in the previous step. It may be a list of None.

  • valids (list) – A list of boolean values indicating whether the actions were valid.

  • backward (bool) –

    A boolean value indicating whether the action was sampled backward (False by default). If True, the behavior is slightly different so as to match what is stored in forward sampling:

    • If it is the first state in the trajectory (action from a done state/env), then done is stored as True, instead of taking env.done which will be False after having performed the step.

    • If it is not the first state in the trajectory, the stored state will be the previous one in the trajectory, to match the state-action stored in forward sampling and the convention that the source state is not stored, but the terminating state is repeated with action eos.

  • train (bool) – A boolean value indicating whether the data to add to the batch will be used for training. Optional, default is True.
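The storage convention described above can be illustrated with a small standalone sketch. It does not import gflownet: the function name stored_transitions, the string states and the "eos" label are hypothetical, and the real method stores many more fields per transition. The reversed insertion order shown for backward sampling is an assumption inferred from the description, not confirmed behaviour.

```python
def stored_transitions(states_visited, actions, eos="eos"):
    """Pair each action with the state stored for it.

    Convention illustrated: the source state is not stored, and the
    terminating state appears a second time together with the eos action.

    states_visited: states in order, from the source to the terminating state.
    actions: the non-eos actions, one per transition between those states.
    """
    assert len(actions) == len(states_visited) - 1
    # Each entry is (state reached by the action, action, done flag).
    entries = [
        (state, action, False)
        for state, action in zip(states_visited[1:], actions)
    ]
    # The terminating state is repeated with the eos action and done=True.
    entries.append((states_visited[-1], eos, True))
    return entries


forward = stored_transitions(["s0", "s1", "s2"], ["a1", "a2"])

# Assumption: backward sampling visits the same state-action pairs in the
# opposite order, starting from the done state with done stored as True.
backward = list(reversed(forward))
```

In this toy trajectory the source state "s0" never appears in the stored entries, while "s2" appears twice: once with the action that reached it and once with eos.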

get_n_trajectories()[source]

Returns the number of trajectories in the batch.

Returns:

The number of trajectories in the batch (int).

Return type:

int

get_unique_trajectory_indices()[source]

Returns the unique trajectory indices as a list, taken from the keys of self.trajectories (an OrderedDict).

Return type:

List

get_trajectory_indices(consecutive=False, return_mapping_dict=False)[source]

Returns the trajectory index of all elements in the batch as a long int torch tensor.

Parameters:
  • consecutive (bool) – If True, the trajectory indices are mapped to consecutive indices starting from 0, in the order of the OrderedDict self.trajectories.keys(). If False (default), the trajectory indices are returned as they are.

  • return_mapping_dict (bool) – If True, the dictionary mapping actual_index: consecutive_index is returned as a second argument. Ignored if consecutive is False.

Returns:

  • traj_indices (torch.tensor) – self.traj_indices as a long int torch tensor.

  • traj_index_to_consecutive_dict (dict) – A dictionary mapping the actual trajectory indices in the Batch to the consecutive indices. Omitted if return_mapping_dict is False (default).

Return type:

torchtyping.TensorType[n_states, int]
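The effect of consecutive and return_mapping_dict can be sketched without the library. The function below is a hypothetical stand-in that works on plain Python lists rather than torch tensors; only the mapping logic follows the description above.

```python
from collections import OrderedDict


def trajectory_indices(traj_indices, trajectories, consecutive=False,
                       return_mapping_dict=False):
    """Hypothetical stand-in for the index mapping described above."""
    if not consecutive:
        # return_mapping_dict is ignored when consecutive is False.
        return list(traj_indices)
    # Map actual index -> consecutive index, in the key order of the dict.
    mapping = {actual: new for new, actual in enumerate(trajectories.keys())}
    remapped = [mapping[idx] for idx in traj_indices]
    if return_mapping_dict:
        return remapped, mapping
    return remapped


# Two trajectories with non-consecutive indices 10 and 3 (toy data).
trajectories = OrderedDict([(10, [0, 2]), (3, [1, 3])])
traj_indices = [10, 3, 10, 3]

unchanged = trajectory_indices(traj_indices, trajectories)
remapped, mapping = trajectory_indices(
    traj_indices, trajectories, consecutive=True, return_mapping_dict=True
)
```

Note that the consecutive indices follow the insertion order of the keys, not their numerical order, so trajectory 10 becomes 0 and trajectory 3 becomes 1 in this example.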

get_state_indices()[source]

Returns the state index of all elements in the batch as a long int torch tensor.

Returns:

state_indices (torch.tensor) – self.state_indices as a long int torch tensor.

Return type:

torchtyping.TensorType[n_states, int]

get_states(policy=False, proxy=False, force_recompute=False, indices=None)[source]

Returns all the states in the batch.

The states are returned in “policy format” if policy is True, in “proxy format” if proxy is True and otherwise they are returned in “GFlowNet” format by default. An error is raised if both policy and proxy are True.

Parameters:
  • policy (bool) – If True, the policy format of the states is returned and self.states_policy is updated if not available yet or if force_recompute is True.

  • proxy (bool) – If True, the proxy format of the states is returned. States in proxy format are not stored.

  • force_recompute (bool) – If True, the policy states are recomputed even if they are available. Ignored if policy is False.

  • indices (list, tuple, tensor or np.ndarray) – 1-dimensional sequence of batch indices for selecting states, optional. If None (default), all the states will be returned.

Returns:

self.states or self.states_policy or self.states2proxy(self.states) (list, torch.tensor or ndarray) – The set of states in the selected format with the selected indices. If indices is None, all states of the batch are returned.

Return type:

Union[torchtyping.TensorType[n_states, …], numpy.typing.NDArray[numpy.float32], List]
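As a rough illustration of how the policy, proxy and indices arguments interact, the following standalone sketch selects states and converts them with caller-supplied functions. The names select_states, to_policy and to_proxy are hypothetical, the exception type is an assumption (the documentation only says that an error is raised), and caching of the policy format is not modelled.

```python
def select_states(states, to_policy, to_proxy, policy=False, proxy=False,
                  indices=None):
    """Hypothetical stand-in: choose a format and an optional subset."""
    if policy and proxy:
        # Assumption: the actual error type raised by the library may differ.
        raise ValueError("policy and proxy cannot both be True")
    if indices is not None:
        states = [states[i] for i in indices]
    if policy:
        return to_policy(states)
    if proxy:
        return to_proxy(states)
    # Default: the states as stored, in "GFlowNet" format.
    return states


# Toy states and toy conversion functions, for illustration only.
states = [[0, 1], [2, 3], [4, 5]]
to_policy = lambda batch: [[float(x) for x in s] for s in batch]
to_proxy = lambda batch: [sum(s) for s in batch]

default_format = select_states(states, to_policy, to_proxy)
proxy_subset = select_states(states, to_policy, to_proxy, proxy=True,
                             indices=[0, 2])
```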

states2policy(states=None, traj_indices=None)[source]

Converts states from a list of states in GFlowNet format to a tensor of states in policy format.

Parameters:
  • states (list) – List of states in GFlowNet format.

  • traj_indices (list or torch.tensor) – Ids indicating which env corresponds to each state in states. It is only used if the environments are conditional to call state2policy from the right environment. Ignored if self.conditional is False.

Returns:

states (torch.tensor) – States in policy format.

Return type:

torchtyping.TensorType[n_states, state_policy_dims]

states2proxy(states=None, traj_indices=None)[source]

Converts states from a list of states in GFlowNet format to a tensor of states in proxy format. Note that the implementation of this method differs from Batch.states2policy() because the latter always returns torch.tensors. The output of the present method can also be numpy arrays or Python lists, depending on the proxy.

Parameters:
  • states (list) – List of states in GFlowNet format.

  • traj_indices (list or torch.tensor) – Ids indicating which env corresponds to each state in states. It is only used if the environments are conditional to call state2proxy from the right environment. Ignored if self.conditional is False.

Returns:

states (torch.tensor or ndarray or list) – States in proxy format.

Return type:

Union[torchtyping.TensorType[n_states, state_proxy_dims], numpy.typing.NDArray[numpy.float32], List]

get_actions(indices=None)[source]

Returns the actions in the batch.

Parameters:

indices (list, tuple, tensor or np.ndarray) – 1-dimensional sequence of batch indices for selecting actions, optional. If None (default), all the actions will be returned.

Returns:

list – The list of actions in the batch with selected indices. If indices is None, all the actions will be returned.

Return type:

List

get_logprobs(backward=False)[source]

Returns the logprobs in the batch as a float tensor.

If there are any None values in self.logprobs, the list cannot be converted to a tensor. This exception is caught and None is returned, as a signal that the logprobs are not available.

Parameters:

backward (bool) – Whether the requested logprobs are of backward transitions.

Returns:

  • logprobs (TensorType[“n_states”]) – Tensor of logprobs

  • valids (TensorType[“n_states”]) – Boolean tensor with flags indicating whether the corresponding logprobs are valid. The flag is False if the corresponding logprob value is a zero / None placeholder.

Return type:

Tuple[torchtyping.TensorType[n_states], torchtyping.TensorType[n_states]]
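The relationship between the returned logprobs and the valids flags can be sketched in plain Python. This is a hypothetical stand-in using lists instead of tensors; it only illustrates the idea of a zero placeholder paired with a False flag, and does not reproduce how the library stores availability.

```python
def logprobs_with_valid_flags(logprobs):
    """Replace missing logprobs by a zero placeholder and flag them."""
    # A missing value (None) becomes 0.0 so the list is fully numeric.
    values = [0.0 if lp is None else float(lp) for lp in logprobs]
    # The flag is False exactly where a placeholder was inserted.
    valids = [lp is not None for lp in logprobs]
    return values, valids


values, valids = logprobs_with_valid_flags([-0.5, None, -1.25])
```

A consumer of such a pair would typically ignore or recompute the entries whose flag is False rather than use the placeholder value.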

get_done()[source]

Returns the list of done flags as a boolean tensor.

Return type:

torchtyping.TensorType[n_states]

get_parents(policy=False, force_recompute=False, indices=None)[source]

Returns the parent (single parent for each state) of all states in the batch. The parents are computed, obtaining all necessary components, if they are not readily available. Missing components and newly computed components are added to the batch (self.component is set).

The parents are returned in “policy format” if policy is True, otherwise they are returned in “GFlowNet” format (default).

Parameters:
  • policy (bool) – If True, the policy format of parents is returned. Otherwise, the GFlowNet format is returned.

  • force_recompute (bool) – If True, the parents are recomputed even if they are available.

  • indices (list, tuple, tensor or np.ndarray) – 1-dimensional sequence of batch indices for selecting parents, optional. If None (default), the parents of all states in the batch will be returned.

Returns:

self.parents or self.parents_policy (torch.tensor) – The parent of states selected by indices. If indices is None, the parents of all states in the batch are returned.

Return type:

torchtyping.TensorType[n_states, …]

get_parents_indices(indices=None)[source]

Returns the indices of the parents of the states in the batch.

The i-th item in the returned list is the index in self.states of the parent of self.states[i], if the parent is present there. If a parent is not present in self.states (because it is the source), the index is -1.

Parameters:

indices (list, tuple, tensor or np.ndarray) – 1-dimensional sequence of batch indices for selecting parents indices, optional. If None (default), the indices of parents of all states in the batch will be returned.

Returns:

self.parents_indices – The indices in self.states of the parents of self.states.
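The -1 convention for source parents can be illustrated with a standalone sketch. It assumes, for illustration only, that the parent of the batch element at state index k of a trajectory is the element at state index k - 1 of the same trajectory; the function name and the toy indices are hypothetical.

```python
def parents_indices(traj_indices, state_indices):
    """Return, for each batch element, the batch index of its parent or -1."""
    # Position in the batch of every (trajectory index, state index) pair.
    position = {
        (t, s): i
        for i, (t, s) in enumerate(zip(traj_indices, state_indices))
    }
    # The parent is the previous state of the same trajectory; if it is not
    # stored in the batch (the source state), the index is -1.
    return [
        position.get((t, s - 1), -1)
        for t, s in zip(traj_indices, state_indices)
    ]


# Two interleaved trajectories: trajectory 0 has 3 stored states, trajectory 1 has 2.
traj_indices = [0, 1, 0, 1, 0]
state_indices = [1, 1, 2, 2, 3]
result = parents_indices(traj_indices, state_indices)
```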

get_parents_all(policy=False, force_recompute=False)[source]

Returns the whole set of parents, their corresponding actions and indices of all states in the batch. If the parents are not available (self._parents_all_available is False) or if force_recompute is True, then self._compute_parents_all() is called to compute the required components.

The parents are returned in “policy format” if policy is True, otherwise they are returned in “GFlowNet” format (default).

Parameters:
  • policy (bool) – If True, the policy format of parents is returned. Otherwise, the GFlowNet format is returned.

  • force_recompute (bool) – If True, the parents are recomputed even if they are available.

Returns:

  • self.parents_all or self.parents_all_policy (list or torch.tensor) – The whole set of parents of all states in the batch.

  • self.parents_actions_all (torch.tensor) – The actions corresponding to each parent in self.parents_all or self.parents_all_policy, linking them to the corresponding state in the trajectory.

  • self.parents_all_indices (torch.tensor) – The state index corresponding to each parent in self.parents_all or self.parents_all_policy, linking them to the corresponding state in the batch.

Return type:

Tuple[Union[List, torchtyping.TensorType[n_parents, …]], torchtyping.TensorType[n_parents, …], torchtyping.TensorType[n_parents]]

get_masks_forward(of_parents=False, force_recompute=False, indices=None)[source]

Returns the forward mask of invalid actions of all states in the batch or of their parent in the trajectory if of_parents is True. The masks are computed via self._compute_masks_forward if they are not available or if force_recompute is True.

Parameters:
  • of_parents (bool) – If True, the returned masks will correspond to the parents of the states, instead of to the states (default).

  • force_recompute (bool) – If True, the masks are recomputed even if they are available.

  • indices (list, tuple, tensor or np.ndarray) – 1-dimensional sequence of batch indices for selecting masks, optional. If None (default), the masks of all states in the batch will be returned.

Returns:

self.masks_invalid_actions_forward (torch.tensor) – The forward mask of the selected states (by indices). If indices is None, the masks of all states in the batch will be returned.

Return type:

torchtyping.TensorType[n_states, action_space_dim]

get_masks_backward(force_recompute=False, indices=None)[source]

Returns the backward mask of invalid actions of all states in the batch. The masks are computed via self._compute_masks_backward if they are not available or if force_recompute is True.

Parameters:
  • force_recompute (bool) – If True, the masks are recomputed even if they are available.

  • indices (list, tuple, tensor or np.ndarray) – 1-dimensional sequence of batch indices for selecting masks, optional. If None (default), the masks of all states in the batch will be returned.

Returns:

self.masks_invalid_actions_backward (torch.tensor) – The backward mask of the selected states (by indices). If indices is None, the masks of all states in the batch will be returned.

Return type:

torchtyping.TensorType[n_states, action_space_dim]

get_rewards(log=False, force_recompute=False, do_non_terminating=False)[source]

Returns the rewards of all states in the batch (including not done).

Parameters:
  • log (bool) – If True, return the logarithm of the rewards.

  • force_recompute (bool) – If True, the rewards are recomputed even if they are available.

  • do_non_terminating (bool) – If True, return the actual rewards of the non-terminating states. If False, non-terminating states will be assigned reward 0.

Return type:

torchtyping.TensorType[n_states]
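The role of do_non_terminating can be sketched as a simple masking step. The function below is a hypothetical stand-in operating on plain lists; it assumes the rewards of all states have already been computed and only shows how non-terminating states end up with reward 0 when the flag is False.

```python
def masked_rewards(rewards, done, do_non_terminating=False):
    """Keep the rewards of terminating states; zero out the others by default."""
    if do_non_terminating:
        # Actual rewards are returned for every state.
        return list(rewards)
    # Non-terminating states (done is False) are assigned reward 0.
    return [r if d else 0.0 for r, d in zip(rewards, done)]


rewards = [0.2, 1.5, 0.7, 3.0]
done = [False, True, False, True]

default = masked_rewards(rewards, done)
everything = masked_rewards(rewards, done, do_non_terminating=True)
```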

get_proxy_values(force_recompute=False, do_non_terminating=False)[source]

Returns the proxy values of all states in the batch (including not done).

Parameters:
  • force_recompute (bool) – If True, the proxy values are recomputed even if they are available.

  • do_non_terminating (bool) – If True, return the actual proxy values of the non-terminating states. If False, non-terminating states will be assigned value inf.

Return type:

torchtyping.TensorType[n_states]

get_rewards_parents(log=False)[source]

Returns the rewards of all parents in the batch.

Parameters:

log (bool) – If True, return the logarithm of the rewards.

Returns:

self.rewards_parents or self.logrewards_parents – A tensor containing the rewards of the parents of self.states.

Return type:

torchtyping.TensorType[n_states]

get_rewards_source(log=False)[source]

Returns rewards of the corresponding source states for each state in the batch.

Parameters:

log (bool) – If True, return the logarithm of the rewards.

Returns:

self.rewards_source or self.logrewards_source – A tensor containing the rewards of the source states.

Return type:

torchtyping.TensorType[n_states]

get_terminating_states(sort_by='insertion', policy=False, proxy=False)[source]

Returns the terminating states in the batch, that is all states with done = True. The states will be returned in either GFlowNet format (default), policy (policy = True) or proxy (proxy = True) format. If both policy and proxy are True, it raises an error due to the ambiguity. The returned states may be sorted by order of insertion (sort_by = “insert[ion]”, default) or by trajectory index (sort_by = “traj[ectory]”).

Parameters:
  • sort_by (str) –

    Indicates how to sort the output:
    • insert[ion]: sort by order of insertion (states of trajectories that reached the terminating state first come first)

    • traj[ectory]: sort by trajectory index (the order in the ordered dict self.trajectories)

  • policy (bool) – If True, the policy format of the states is returned.

  • proxy (bool) – If True, the proxy format of the states is returned.

Return type:

Union[torchtyping.TensorType[n_trajectories, …], numpy.typing.NDArray[numpy.float32], List]
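The two sorting modes can be compared with a standalone sketch. The function and argument names are hypothetical, the exception for an unknown sort_by value is an assumption, and format conversion is left out; the sketch only shows how insertion order and trajectory order can differ when trajectories terminate at different times.

```python
def terminating_states(states, done, traj_indices, trajectory_order,
                       sort_by="insertion"):
    """Return the states with done = True, sorted as requested."""
    # Keep (trajectory index, state) pairs of terminating states, in batch order.
    terminating = [
        (t, s) for s, d, t in zip(states, done, traj_indices) if d
    ]
    if sort_by in ("insert", "insertion"):
        ordered = terminating
    elif sort_by in ("traj", "trajectory"):
        # Rank of each trajectory in the ordered dict of trajectories.
        rank = {t: r for r, t in enumerate(trajectory_order)}
        ordered = sorted(terminating, key=lambda pair: rank[pair[0]])
    else:
        # Assumption: the library may handle unknown values differently.
        raise ValueError(f"Unknown sort_by: {sort_by}")
    return [s for _, s in ordered]


# Trajectory 1 terminates (state "b2") before trajectory 0 (state "a2").
states = ["a1", "b1", "b2", "a2"]
done = [False, False, True, True]
traj_indices = [0, 1, 1, 0]

by_insertion = terminating_states(states, done, traj_indices, [0, 1])
by_trajectory = terminating_states(states, done, traj_indices, [0, 1],
                                   sort_by="traj")
```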

get_terminating_rewards(sort_by='insertion', log=False, force_recompute=False)[source]

Returns the reward of the terminating states in the batch, that is all states with done = True. The returned rewards may be sorted by order of insertion (sort_by = “insert[ion]”, default) or by trajectory index (sort_by = “traj[ectory]”).

Parameters:
  • sort_by (str) –

    Indicates how to sort the output:
    • insert[ion]: sort by order of insertion (rewards of trajectories that reached the terminating state first come first)

    • traj[ectory]: sort by trajectory index (the order in the ordered dict self.trajectories)

  • log (bool) – If True, return the logarithm of the rewards.

  • force_recompute (bool) – If True, the rewards are recomputed even if they are available.

Return type:

torchtyping.TensorType[n_trajectories]

get_terminating_proxy_values(sort_by='insertion', force_recompute=False)[source]

Returns the proxy values of the terminating states in the batch, that is all states with done = True. The returned proxy values may be sorted by order of insertion (sort_by = “insert[ion]”, default) or by trajectory index (sort_by = “traj[ectory]”).

Parameters:
  • sort_by (str) –

    Indicates how to sort the output:
    • insert[ion]: sort by order of insertion (proxy values of trajectories that reached the terminating state first come first)

    • traj[ectory]: sort by trajectory index (the order in the ordered dict self.trajectories)

  • force_recompute (bool) – If True, the proxy_values are recomputed even if they are available.

Return type:

torchtyping.TensorType[n_trajectories]

get_actions_trajectories()[source]

Returns the actions corresponding to all trajectories in the batch, sorted by trajectory index (the order in the ordered dict self.trajectories).

Return type:

List[List[Tuple]]

get_states_of_trajectory(traj_idx, states=None, traj_indices=None)[source]

Returns the states of the trajectory indicated by traj_idx. If states and traj_indices are not None, then these will be the only states and trajectory indices considered.

See: states2policy() and states2proxy()

Parameters:
  • traj_idx (int) – Index of the trajectory from which to return the states.

  • states (tensor, array or list) – States from the trajectory to consider.

  • traj_indices (tensor, array or list) – Trajectory indices of the trajectory to consider.

Returns:

Tensor, array or list of states of the requested trajectory.

Return type:

Union[torchtyping.TensorType[n_states, state_proxy_dims], numpy.typing.NDArray[numpy.float32], List]

get_logprobs_of_trajectory(traj_idx, backward=False)[source]

Returns the logprobs of the trajectory indicated by traj_idx.

Parameters:
  • traj_idx (int) – Index of the trajectory from which to return the logprobs.

  • backward (bool) – If True, returns backward logprobs; otherwise, returns forward logprobs.

Returns:

A list of logprobs of the actions of the requested trajectory.

Return type:

Union[torchtyping.TensorType[n_states, state_proxy_dims], numpy.typing.NDArray[numpy.float32], List]

merge(batches)[source]

Merges the current Batch (self) with the Batch or list of Batches passed as argument.

Returns:

self

Parameters:

batches (List)

is_valid()[source]

Performs basic checks on the current state of the batch.

Returns:

True if all the checks are valid, False otherwise.

Return type:

bool

traj_indices_are_consecutive()[source]

Returns True if the trajectory indices start from 0 and are consecutive; False otherwise.

Return type:

bool

make_indices_consecutive()[source]

Updates the trajectory indices as well as the env ids such that they start from 0 and are consecutive. Note that only the trajectory indices are changed, but importantly the order of the main data in the batch is preserved.

Examples:

  • Original indices: 0, 10, 20

  • New indices: 0, 1, 2

  • Original indices: 1, 5, 3

  • New indices: 0, 1, 2

Note: this method is unused as of September 1st 2023, but is left here for potential future use.

get_item(item, env=None, traj_idx=None, action_idx=None, backward=False)[source]
Returns the item specified by item of either:
  • environment env, OR

  • trajectory traj_idx AND action number action_idx (in the order of sampling)

If all arguments are given, then they must be consistent, otherwise an exception (assert) is raised due to ambiguity.

If a mask is requested but is missing, it is computed and stored.

Parameters:
  • item (str) –

    String identifier of the item to retrieve from the batch. Options:
    • state

    • parent

    • action

    • done

    • mask_f[orward]

    • mask_b[ackward]

  • traj_idx (int) – Trajectory index

  • action_idx (int) – Action index. Regardless of forward or backward, n-th item sampled when forming the batch.

  • backward (bool) – Whether the trajectory is sampling backward. False (forward) by default.

  • env (gflownet.envs.base.GFlowNetEnv)

Returns:

The requested item if it is available or None if it is not. It raises an error if the request can be identified as incorrect.

get_indices_of_previous_transitions(envs, backward)[source]

Get batch indices of the latest elements (states, actions, etc) added to the batch for the provided environments.

Parameters:
  • envs (list) – List of envs for which the batch indices are requested.

  • backward (bool) – Whether the trajectories are sampled backward (True) or forward (False).

Returns:

list – Batch indices of the previous transitions for the requested environments. Environments without any previous transitions in the batch are assigned None.

Return type:

List[int]
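The lookup described above can be sketched with a dictionary that maps each environment id to the batch indices of its trajectory. Everything here is a hypothetical stand-in: the function takes environment ids rather than environment objects, and it assumes that forward trajectories keep their most recent element last while backward trajectories keep it first. That ordering is an assumption of this sketch, not something stated in this reference.

```python
def indices_of_previous_transitions(trajectories, env_ids, backward=False):
    """Return the batch index of the latest element per env, or None."""
    # Assumption: the latest element is stored last when sampling forward
    # and first when sampling backward.
    latest = 0 if backward else -1
    result = []
    for env_id in env_ids:
        batch_indices = trajectories.get(env_id)
        # Environments without any transition in the batch get None.
        result.append(batch_indices[latest] if batch_indices else None)
    return result


# Env 0 has three elements in the batch, env 1 has two, env 7 has none.
trajectories = {0: [0, 2, 4], 1: [1, 3]}
previous_forward = indices_of_previous_transitions(trajectories, [0, 1, 7])
previous_backward = indices_of_previous_transitions(
    trajectories, [0, 1, 7], backward=True
)
```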

get_actions_of_previous_transitions(envs, backward)[source]

Retrieves the latest actions added to the batch for the provided envs.

Parameters:
  • envs (list) – List of envs for which the actions are requested.

  • backward (bool) – Indicates whether the trajectories are sampled backward (True) or forward (False).

Returns:

actions (list) – Actions of previous transitions for the requested envs

Return type:

List

gflownet.utils.batch.compute_logprobs_trajectories(batch, env=None, forward_policy=None, backward_policy=None, backward=False)[source]

Computes the forward or backward log probabilities of the trajectories in a batch.

The log probabilities are computed according to the forward or backward policy models passed as parameters.

Parameters:
  • batch (Batch) – A batch of data, containing all the states in the trajectories.

  • env (gflownet.envs.base.GFlowNetEnv, optional) – An instance of the environment used to compute log probabilities of state transitions. If None, batch.readonly_env is used.

  • forward_policy (gflownet.policy.base.Policy, optional) – The model used to compute the forward policy outputs from input states. It is ignored if backward is True.

  • backward_policy (gflownet.policy.base.Policy, optional) – Same as forward_policy, but for the backward policy. It is ignored if backward is False.

  • backward (bool) – False: log probabilities of forward trajectories. True: log probabilities of backward trajectories.

Returns:

logprobs (torch.tensor) – The log probabilities of the trajectories.
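Since the probability of a trajectory is the product of the probabilities of its transitions, its log probability is the sum of the per-transition log probabilities. The standalone sketch below shows only that aggregation step on plain lists. The function name is hypothetical, the per-transition logprobs are taken as given (the library obtains them from the policy models and the environment), and trajectory indices are assumed to be consecutive and to start at 0.

```python
def sum_logprobs_per_trajectory(logprobs, traj_indices, n_trajectories):
    """Sum per-transition log probabilities into one value per trajectory."""
    totals = [0.0] * n_trajectories
    for logprob, traj_idx in zip(logprobs, traj_indices):
        # Assumption: traj_idx is a consecutive index in [0, n_trajectories).
        totals[traj_idx] += logprob
    return totals


# Four transitions belonging to two interleaved trajectories (toy values).
logprobs = [-0.5, -1.0, -0.25, -2.0]
traj_indices = [0, 1, 0, 1]
logprobs_trajectories = sum_logprobs_per_trajectory(logprobs, traj_indices, 2)
```

If the trajectory indices in a batch are not consecutive, they would first need to be remapped, for example with get_trajectory_indices(consecutive=True).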