gflownet.envs.base

Base class of GFlowNet environments

Attributes

CMAP

Plotting colour map (cividis).

Classes

GFlowNetEnv

Base class of GFlowNet environments

Module Contents

gflownet.envs.base.CMAP[source]

Plotting colour map (cividis).

class gflownet.envs.base.GFlowNetEnv(device='cpu', float_precision=32, env_id='env', fixed_distr_params=None, random_distr_params=None, skip_mask_check=False, conditional=False, continuous=False, **kwargs)[source]

Base class of GFlowNet environments

Parameters:
  • device (str)

  • float_precision (int)

  • env_id (Union[int, str])

  • fixed_distr_params (Optional[dict])

  • random_distr_params (Optional[dict])

  • skip_mask_check (bool)

  • conditional (bool)

  • continuous (bool)

conditional = False[source]
continuous = False[source]
device = 'cpu'[source]
float = 32[source]
skip_mask_check = False[source]
logsoftmax[source]
action_space[source]
action_space_torch[source]
fixed_distr_params = None[source]
random_distr_params = None[source]
fixed_policy_output[source]
random_policy_output[source]
policy_output_dim[source]
policy_input_dim[source]
abstract get_action_space()[source]

Constructs list with all possible actions (excluding end of sequence)

property action_space_dim: int[source]

Returns the dimensionality of the action space (number of actions).

Returns:

The number of actions in the action space.

Return type:

int

property mask_dim[source]

Returns the dimensionality of the masks.

Returns:

The dimensionality of the masks.

property max_traj_length: int[source]

Returns the maximum trajectory length of the environment, including the EOS action.

Returns:

The maximum number of steps in a trajectory of the environment.

Return type:

int

action2representative(action)[source]

For continuous or hybrid environments, converts a continuous action into its representative in the action space. Discrete actions remain identical, thus fully discrete environments do not need to re-implement this method. Continuous environments should re-implement this method in order to replace continuous actions by their representatives in the action space.

Parameters:

action (Tuple)

Return type:

int

action_produces_permutation(action, is_backward=False)[source]

Determines whether an action produces permutations in the resulting state.

Permutations can be introduced, for example, in environments that need to incorporate permutation invariance, as in sets of elements. In these cases, some actions may result in states with elements that are randomly permuted.

This method makes it possible to identify these actions, which is useful, for instance, in unit tests.

By default, actions do not produce permutations and the returned value of this method is False.

Environments with actions that produce permutations should override this method and properly identify such actions.

Parameters:
  • action (tuple) – An action of the environment.

  • is_backward (bool) – Whether the transition to consider is backward (True) or forward (False).

Returns:

bool – Whether the input action produces permutations in the resulting state, in the direction indicated by is_backward.

Return type:

bool

action2index(action)[source]

Returns the index in the action space of the action passed as an argument, or its representative if it is a continuous action.

The method uses the dictionary lookup self._action2index.

See: self.action2representative()

Parameters:

action (tuple) – An action from the action space.

Returns:

int – The index of the action in the action space.

Return type:

int
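The dictionary lookup described above can be sketched as follows. This is a minimal illustration, assuming a toy action space of tuples; `ACTION_SPACE` and `ACTION_TO_INDEX` are hypothetical stand-ins for the environment's action space and for `self._action2index`, and the functions are standalone rather than methods of the class.

```python
# Hypothetical toy action space: two moves plus an end-of-sequence action.
ACTION_SPACE = [(1, 0), (0, 1), (-1, -1)]

# Assumed structure of the lookup: each action tuple maps to its position.
ACTION_TO_INDEX = {action: idx for idx, action in enumerate(ACTION_SPACE)}


def action2index(action):
    """Returns the index of a discrete action in the toy action space."""
    return ACTION_TO_INDEX[action]


def actions2indices(actions):
    """Maps a batch (here, a plain list) of actions to their indices."""
    return [action2index(action) for action in actions]


indices = actions2indices([(-1, -1), (1, 0)])
```

In a continuous environment, the action would presumably first be replaced by its representative (see action2representative()) before the lookup.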

actions2indices(actions)[source]

Returns the corresponding indices in the action space of the actions in a batch.

Parameters:

actions (torchtyping.TensorType[batch_size, action_dim])

Return type:

torchtyping.TensorType[batch_size]

is_source(state=None)[source]

Returns True if the environment’s state or the state passed as parameter (if not None) is the source state of the environment.

Parameters:

state (list or tensor or None) – None, or a state in environment format.

Returns:

bool – Whether the state is the source state of the environment

Return type:

bool

get_mask_invalid_actions_forward(state=None, done=None)[source]
Returns a list with one entry per action in the action space, with values:
  • True if the forward action is invalid from the current state.

  • False otherwise.

For continuous or hybrid environments, this mask corresponds to the discrete part of the action space.

Parameters:
  • state (Optional[List])

  • done (Optional[bool])

Return type:

List
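As an illustration of the mask convention, the sketch below builds a forward mask for a hypothetical 2-D grid environment. The grid, `ACTION_SPACE`, `EOS` and the `length` parameter are assumptions made for this example and are not part of the base class.

```python
# Hypothetical grid environment: states are [x, y] with 0 <= x, y < length.
EOS = (-1, -1)
ACTION_SPACE = [(1, 0), (0, 1), EOS]


def get_mask_invalid_actions_forward(state, done, length=3):
    """Returns True for each forward action that is invalid from state."""
    if done:
        # No forward action is valid once the trajectory is done.
        return [True] * len(ACTION_SPACE)
    mask = []
    for action in ACTION_SPACE:
        if action == EOS:
            # In this toy grid, EOS is valid from every state.
            mask.append(False)
        else:
            # A move is invalid if it would leave the grid.
            mask.append(any(s + a >= length for s, a in zip(state, action)))
    return mask


mask = get_mask_invalid_actions_forward([2, 0], done=False)
```

Here the increment along the first dimension is masked out because the state is already at the edge of the grid.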

get_mask_invalid_actions_backward(state=None, done=None, parents_a=None)[source]
Returns a list with one entry per action in the action space, with values:
  • True if the backward action is invalid from the current state.

  • False otherwise.

For continuous or hybrid environments, this mask corresponds to the discrete part of the action space.

The base implementation should be common to all discrete spaces, as it relies on get_parents, which is environment-specific and must be implemented. Continuous environments will probably need to implement their own version of this method.

Parameters:
  • state (Optional[List])

  • done (Optional[bool])

  • parents_a (Optional[List])

Return type:

List

get_mask(state=None, done=None, backward=False)[source]

Returns a mask of invalid actions given a state and a done value. Depending on backward, either the forward or the backward mask is returned, by calling the corresponding method.

Parameters:
  • state (Optional[List])

  • done (Optional[bool])

  • backward (Optional[bool])

Return type:

List

get_valid_actions(mask=None, state=None, done=None, backward=False)[source]

Returns the list of non-invalid (valid, for short) actions according to the mask of invalid actions.

More documentation about the meaning and use of invalid actions can be found in gflownet/envs/README.md.

Parameters:
  • mask (Optional[bool])

  • state (Optional[List])

  • done (Optional[bool])

  • backward (Optional[bool])

Return type:

List[Tuple]
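The relationship between the mask and the valid actions can be sketched as a simple filter. This is a standalone illustration with a hypothetical action space, not the library's implementation, which can also compute the mask from state and done when no mask is given.

```python
def get_valid_actions(action_space, mask_invalid):
    """Keeps the actions whose entry in the mask of invalid actions is False."""
    return [
        action
        for action, invalid in zip(action_space, mask_invalid)
        if not invalid
    ]


# Hypothetical action space and mask: only the first action is invalid.
action_space = [(1, 0), (0, 1), (-1, -1)]
mask_invalid = [True, False, False]
valid = get_valid_actions(action_space, mask_invalid)
```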

get_parents(state=None, done=None, action=None)[source]

Determines all parents and actions that lead to state.

In continuous environments, get_parents() should return only the parent from which action leads to state.

Parameters:
  • state (list) – Representation of a state

  • done (bool) – Whether the trajectory is done. If None, done is taken from instance.

  • action (tuple) – Last action performed

Returns:

  • parents (list) – List of parents in state format

  • actions (list) – List of actions that lead to state for each parent in parents

Return type:

Tuple[List, List]
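A minimal sketch of what a discrete get_parents() may look like, assuming a hypothetical grid environment where each action increments one coordinate. `ACTIONS` and `EOS` are illustrative names, and treating a done state as its own parent via the EOS action is an assumption of this example.

```python
# Hypothetical grid environment: states are lists of non-negative integers.
EOS = (-1, -1)
ACTIONS = [(1, 0), (0, 1)]


def get_parents(state, done):
    """Returns the parents of state and the action leading from each parent."""
    if done:
        # Assumption: a done state was reached from itself through EOS.
        return [list(state)], [EOS]
    parents, actions = [], []
    for action in ACTIONS:
        # Undo the action; the result is a parent only if it is in the grid.
        parent = [s - a for s, a in zip(state, action)]
        if all(p >= 0 for p in parent):
            parents.append(parent)
            actions.append(action)
    return parents, actions


parents, actions = get_parents([1, 1], done=False)
```

The two returned lists are aligned: `actions[i]` is the action that leads from `parents[i]` to the input state.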

abstract step(action, skip_mask_check=False)[source]

Executes step given an action.

Parameters:
  • action (tuple) – Action from the action space.

  • skip_mask_check (bool) – If True, skip computing forward mask of invalid actions to check if the action is valid.

Returns:

  • self.state (list) – The sequence after executing the action

  • action (int) – Action index

  • valid (bool) – False, if the action is not allowed for the current state, e.g. stop at the root state

Return type:

Tuple[List[int], Tuple[int], bool]

step_backwards(action, skip_mask_check=False)[source]

Executes a backward step given an action. This generic implementation should work for all discrete environments, as it relies on get_parents(). Continuous environments should re-implement a custom step_backwards(). Although it is valid for any discrete environment, the call to get_parents() may be expensive, so it may be advantageous to re-implement step_backwards() more efficiently for discrete environments as well. In particular, this generic implementation makes two calls to get_parents(): one here and one in _pre_step(), through the call to get_mask_invalid_actions_backward(), unless the mask check is skipped (skip_mask_check is True).

Parameters:
  • action (tuple) – Action from the action space.

  • skip_mask_check (bool) – If True, skip computing the backward mask of invalid actions to check if the action is valid.

Returns:

  • self.state (list) – The sequence after executing the action

  • action (int) – Action index

  • valid (bool) – False, if the action is not allowed for the current state.

Return type:

Tuple[List[int], Tuple[int], bool]
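The idea of a generic backward step built on get_parents() can be sketched as follows. This is an illustration under stated assumptions, not the library's code: the function is standalone, `get_parents` is passed in as a callable, and `counter_parents` is a hypothetical 1-D environment where the only action, `(1,)`, increments an integer state.

```python
def step_backwards(state, action, get_parents):
    """Moves to the parent from which action leads to state, if there is one."""
    parents, parent_actions = get_parents(state)
    if action not in parent_actions:
        # The action is not a valid backward action: state is unchanged.
        return state, action, False
    return parents[parent_actions.index(action)], action, True


def counter_parents(state):
    """Parents in a hypothetical counter environment with action (1,)."""
    if state > 0:
        return [state - 1], [(1,)]
    return [], []


result = step_backwards(3, (1,), counter_parents)
```

Because every backward step enumerates all parents, an environment that can invert an action directly may prefer its own, cheaper implementation.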

randomize_and_temper_sampling_distribution(policy_outputs, probability_random_action=0.0, temperature=1.0)[source]

Replaces the rows of policy_outputs by a vector corresponding to a random sampling policy with the probability indicated by probability_random_action.

Note that the tensor of policy outputs is not cloned if neither tempering nor random actions are incorporated. This implies that the original tensor of policy outputs may be modified by subsequent methods (namely sample_actions_batch()), for example to mask the invalid actions.

Parameters:
  • policy_outputs (tensor) – The original outputs of the sampling policy. For example, they may correspond to the output (logits) of the GFlowNet policy model.

  • probability_random_action (float, optional) – The probability of sampling a random action. If larger than zero, the logits will be replaced by a random policy vector with this probability, according to a Bernoulli distribution. By default, the probability is 0.0 (no random actions).

  • temperature (float, optional) – A scalar by which the logits are divided to adjust the sampling distribution. A temperature larger than one will result in a flatter distribution, favouring exploration. A temperature smaller than one will sharpen the distribution, favouring concentration around high probability actions. By default, the temperature is 1.0 (no tempering).

Returns:

policy_outputs (tensor) – The modified policy outputs.

Return type:

torchtyping.TensorType[n_states, policy_output_dim]
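The behaviour described above can be sketched with plain Python lists in place of tensors. This is a simplified illustration: the `random_policy` argument stands in for the environment's random policy output, the `rng` argument is a hypothetical hook for the random number source, and applying the temperature to randomized rows as well is an assumption.

```python
import random


def randomize_and_temper(
    policy_outputs,
    random_policy,
    probability_random_action=0.0,
    temperature=1.0,
    rng=random,
):
    """Replaces rows by random_policy with some probability, then tempers."""
    if probability_random_action == 0.0 and temperature == 1.0:
        # Nothing to do: the input is returned as is, without copying.
        return policy_outputs
    outputs = []
    for row in policy_outputs:
        # Bernoulli draw per row: rng.random() lies in [0, 1).
        if rng.random() < probability_random_action:
            row = random_policy
        # Dividing the logits by the temperature builds a new row.
        outputs.append([logit / temperature for logit in row])
    return outputs


logits = [[2.0, 4.0]]
tempered = randomize_and_temper(logits, [1.0, 1.0], temperature=2.0)
```

Note that the early return hands back the very same object, which mirrors the warning about the policy outputs not being cloned.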

sample_actions_batch(policy_outputs, mask=None, states_from=None, is_backward=False, random_action_prob=0.0, temperature_logits=1.0)[source]

Samples a batch of actions from a batch of policy outputs.

This implementation is generally valid for all discrete environments but continuous or mixed environments need to reimplement this method.

The method is valid for both forward and backward actions in the case of discrete environments. Some continuous environments may also be agnostic to the difference between forward and backward actions since the necessary information can be contained in the mask. However, some continuous environments do need to know whether the actions are forward or backward, which is why this can be specified by the argument is_backward.

Most environments do not need to know the states from which the actions are to be sampled since the necessary information is in both the policy outputs and the mask. However, some continuous environments do need to know the originating states in order to construct the actions, which is why one of the arguments is states_from.

Note that methods overriding this method should randomize and temper the logits.

Parameters:
  • policy_outputs (tensor) – The output of the GFlowNet policy model.

  • mask (tensor) – The mask of invalid actions. For continuous or mixed environments, the mask may be a tensor of arbitrary length containing information about special states, as defined elsewhere in the environment.

  • states_from (tensor) – The states originating the actions, in GFlowNet format. Ignored in discrete environments and only required in certain continuous environments.

  • is_backward (bool) – True if the actions are backward, False if the actions are forward (default). Ignored in discrete environments and only required in certain continuous environments.

  • random_action_prob (float, optional) – The probability of sampling a random action. If larger than zero, the model outputs will be replaced by a random policy vector with probability random_action_prob, according to a Bernoulli distribution.

  • temperature_logits (float, optional) – A scalar by which the model outputs are divided to temper the sampling distribution.

Returns:

actions (list) – The list of sampled actions.

Return type:

Tuple[List[Tuple], torchtyping.TensorType[n_states]]
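For discrete environments, sampling with a mask of invalid actions can be sketched as below. This is a simplified, list-based illustration rather than the library's tensor implementation: it returns only the sampled actions, omits randomization and tempering, and assumes that every state has at least one valid action. `ACTION_SPACE` and the `rng` argument are hypothetical.

```python
import math
import random

# Hypothetical action space shared by all states in the batch.
ACTION_SPACE = [(1, 0), (0, 1), (-1, -1)]


def sample_actions_batch(logits_batch, mask_batch, action_space, rng=random):
    """Samples one valid action per state from masked logits."""
    actions = []
    for logits, mask in zip(logits_batch, mask_batch):
        # Subtract the largest valid logit for numerical stability.
        max_logit = max(l for l, invalid in zip(logits, mask) if not invalid)
        # A zero weight for invalid actions is equivalent to a logit of -inf.
        weights = [
            0.0 if invalid else math.exp(l - max_logit)
            for l, invalid in zip(logits, mask)
        ]
        idx = rng.choices(range(len(action_space)), weights=weights, k=1)[0]
        actions.append(action_space[idx])
    return actions


# Only the last action is valid, so the outcome is deterministic.
sampled = sample_actions_batch([[0.0, 5.0, 1.0]], [[True, True, False]], ACTION_SPACE)
```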

get_logprobs(policy_outputs, actions, mask=None, states_from=None, is_backward=False)[source]

Computes log probabilities of actions given policy outputs and actions. This implementation is generally valid for all discrete environments, but continuous environments will likely have to implement their own.

Parameters:
  • policy_outputs (tensor) – The output of the GFlowNet policy model.

  • mask (tensor) – The mask of invalid actions. For continuous or mixed environments, the mask may be a tensor of arbitrary length containing information about special states, as defined elsewhere in the environment.

  • actions (list or tensor) – The actions from each state in the batch for which to compute the log probability. The actions may be a list or a tensor. Most environments handle the actions as a list, but in some cases it is practical to use a tensor for easier indexing, such as in meta-environments.

  • states_from (tensor) – The states originating the actions, in GFlowNet format. Ignored in discrete environments and only required in certain continuous environments.

  • is_backward (bool) – True if the actions are backward, False if the actions are forward (default). Ignored in discrete environments and only required in certain continuous environments.

Return type:

torchtyping.TensorType[batch_size]
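For a discrete environment, the log probability of an action is its logit minus the log-normalizer over the valid actions. The sketch below illustrates this with plain lists; it assumes that actions are given as indices in the action space and that each of them is valid under its mask, and it is not the library's tensor implementation.

```python
import math


def get_logprobs(logits_batch, action_indices, mask_batch):
    """Computes masked log-softmax values at the given action indices."""
    logprobs = []
    for logits, idx, mask in zip(logits_batch, action_indices, mask_batch):
        # Only valid actions contribute to the normalizer.
        valid_logits = [l for l, invalid in zip(logits, mask) if not invalid]
        max_logit = max(valid_logits)
        # Log-sum-exp with the maximum subtracted for numerical stability.
        log_norm = max_logit + math.log(
            sum(math.exp(l - max_logit) for l in valid_logits)
        )
        logprobs.append(logits[idx] - log_norm)
    return logprobs


# Two equally likely valid actions; the third one is masked out.
logprobs = get_logprobs([[0.0, 0.0, 3.0]], [0], [[False, False, True]])
```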

step_random(backward=False)[source]

Samples a random action and executes the step.

Parameters:

backward (bool) – If True, the step is performed backwards. False by default.

Returns:

  • state (list) – The state after executing the action.

  • action (int) – Action, randomly sampled.

  • valid (bool) – False, if the action is not allowed for the current state.

trajectory_random(backward=False)[source]

Samples and applies a random trajectory on the environment, by sampling random actions until an EOS action is sampled.

Parameters:

backward (bool) – If True, the trajectory is sampled backwards. False by default.

Returns:

  • state (list) – The final state.

  • action (list) – The list of actions (tuples) in the trajectory.

get_random_terminating_states(n_states, unique=True, max_attempts=100000)[source]

Samples n_states terminating states by using the random policy of the environment (calling self.trajectory_random()).

Note that this method is general for all environments but it may be suboptimal in terms of efficiency. In particular, 1) it samples full trajectories in order to get terminating states, 2) if unique is True, it needs to compare each newly sampled state with all the previously sampled states. If get_uniform_terminating_states is available, it may be preferred, or for some environments, a custom get_random_terminating_states may be straightforward to implement in a much more efficient way.

Parameters:
  • n_states (int) – The number of terminating states to sample.

  • unique (bool) – Whether samples should be unique. True by default.

  • max_attempts (int) – The maximum number of attempts, to prevent the method from getting stuck trying to obtain n_states different samples if unique is True. 100000 by default, therefore if more than 100000 are requested, max_attempts should be increased accordingly.

Returns:

states (list) – A list of randomly sampled terminating states.

Return type:

List

get_random_states(n_states, unique=True, exclude_source=False, max_attempts=1000)[source]

Samples n_states states (not necessarily terminating) by using the random policy of the environment (calling self.step_random()).

It relies on self.max_traj_length in order to uniformly sample the number of steps, in order to obtain states with varying trajectory lengths.

The method iteratively samples first a trajectory length and attempts to perform as many steps. If the trajectory ends before the requested number of steps is reached, then it is discarded and a new one is attempted.

This may introduce a bias towards states that can be reached with a few steps.

Note that this method is general for all environments but it may be suboptimal in terms of efficiency. In particular, 1) it samples trajectories step by step in order to get random states, 2) if unique is True, it needs to compare each newly sampled state with all the previously sampled states, 3) states are copied before adding them to the list, 4) only the last state of a trajectory is added to the list in order to have diversity of trajectories.

Parameters:
  • n_states (int) – The number of states to sample.

  • unique (bool) – Whether samples should be unique. True by default.

  • max_attempts (int) – The maximum number of attempts, to prevent the method from getting stuck trying to obtain n_states different samples if unique is True. 1000 by default, therefore if more than 1000 are requested, max_attempts should be increased accordingly.

  • exclude_source (bool) – If True, exclude the source state from the list of states.

Returns:

states (list) – A list of randomly sampled states.

Raises:
  • ValueError – If max_attempts is smaller than n_states

  • RuntimeError – If the maximum number of attempts is reached before obtaining the requested number of unique states.

Return type:

List
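The interplay of n_states, unique and max_attempts, including the two documented exceptions, can be sketched as follows. This is a hedged illustration of the attempt loop only: `sample_fn` is a hypothetical callable standing in for the random sampling of a single state, and the sampling of trajectory lengths and the copying of states are left out.

```python
import itertools


def sample_states(sample_fn, n_states, unique=True, max_attempts=1000):
    """Collects n_states samples, optionally unique, within max_attempts."""
    if max_attempts < n_states:
        raise ValueError("max_attempts must not be smaller than n_states")
    states = []
    n_attempts = 0
    while len(states) < n_states:
        if n_attempts >= max_attempts:
            raise RuntimeError("Maximum number of attempts reached")
        n_attempts += 1
        state = sample_fn()
        # Each new state is compared with all previously sampled states.
        if unique and state in states:
            continue
        states.append(state)
    return states


# Hypothetical sampler that cycles through a fixed sequence of states.
source = itertools.cycle([[0], [1], [0], [2]])
states = sample_states(lambda: next(source), n_states=3)
```

The linear scan in `state in states` is the cost mentioned in point 2) above: each new sample is checked against everything collected so far.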

get_policy_output(params=None)[source]

Defines the structure of the output of the policy model, from which an action is to be determined or sampled, by returning a vector with a fixed random policy. As a baseline, the policy is uniform over the dimensionality of the action space.

Continuous environments will generally have to override this method.

Parameters:

params (Optional[dict])

Return type:

torchtyping.TensorType[policy_output_dim]
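A uniform baseline policy can be expressed as a vector of identical logits, one per action. The sketch below assumes that the policy output is interpreted as logits and uses a plain list instead of a tensor; `softmax` is a helper defined here only to show that identical logits yield a uniform distribution.

```python
import math


def get_policy_output(action_space_dim):
    """Returns identical logits, which correspond to a uniform policy."""
    return [1.0] * action_space_dim


def softmax(logits):
    """Converts logits into probabilities."""
    max_logit = max(logits)
    exps = [math.exp(l - max_logit) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


probabilities = softmax(get_policy_output(4))
```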

states2proxy(states)[source]

Prepares a batch of states in “environment format” for the proxy. By default, the batch of states is converted into a tensor with float dtype and returned as is.

Parameters:

states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.

Returns:

A tensor containing all the states in the batch.

Return type:

torchtyping.TensorType[batch, state_proxy_dim]

state2proxy(state=None)[source]

Prepares a single state in “GFlowNet format” for the proxy. By default, simply states2proxy is called and the output will be a “batch” with a single state in the proxy format.

Parameters:

state (list) – A state

Return type:

torchtyping.TensorType[state_proxy_dim]

states2policy(states)[source]

Prepares a batch of states in “environment format” for the policy model. By default, the batch of states is converted into a tensor with float dtype and returned as is.

Parameters:

states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.

Returns:

A tensor containing all the states in the batch.

Return type:

torchtyping.TensorType[batch, policy_input_dim]

state2policy(state=None)[source]

Prepares a state in “GFlowNet format” for the policy model. By default, states2policy is called, which by default will return the state as is.

Parameters:

state (list) – A state

Return type:

torchtyping.TensorType[policy_input_dim]

state2readable(state=None)[source]

Converts a state into human-readable representation.

readable2state(readable)[source]

Converts a human-readable representation of a state into the standard format.

traj2readable(traj=None)[source]

Converts a trajectory into a human-readable string.

reset(env_id=None)[source]

Resets the environment.

Parameters:

env_id (int or str) – Unique (ideally) identifier of the environment instance, used to identify the trajectory generated with this environment. If None, uuid.uuid4() is used.

Returns:

self

set_id(env_id)[source]

Sets the id given as argument and returns the environment.

Parameters:

env_id (int or str) – Unique (ideally) identifier of the environment instance, used to identify the trajectory generated with this environment.

Returns:

self

set_state(state, done=False)[source]

Sets the state and done of an environment. Environments that cannot be “done” at all states (intermediate states are not fully constructed objects) should override this method and check for validity.

Parameters:
  • state (List)

  • done (Optional[bool])

copy()[source]
static equal(state_x, state_y)[source]

Checks whether the two input states are equal.

This method recursively handles multiple structure types: numbers, strings, tensors, dictionaries, lists and tuples.

The result is only True if the content of the two input states is identical.

The core functionality is implemented in gflownet.envs.base.GFlowNetEnv.isclose() and this method simply calls it with do_equal=True.

Parameters:
  • state_x (number, str, tensor, dict, list, tuple) – One of the states to be compared.

  • state_y (number, str, tensor, dict, list, tuple) – The other state to be compared.

Returns:

bool – True if the two input states are equal; False otherwise.

Raises:

NotImplementedError – If the input types are not part of the explicitly handled types.

Return type:

bool

static isclose(state_x, state_y, rtol=1e-05, atol=1e-08, do_equal=False)[source]

Checks whether the two input states are close, according to a tolerance.

This method relies on numpy’s and torch’s isclose() methods, which both use the following formula:

abs(x - y) <= rtol * abs(y) + atol

This method is used as well by gflownet.envs.base.GFlowNetEnv.equal() in order to avoid code repetition. In this case, do_equal is True and numpy’s and torch’s equal() methods are used. This is preferred over using rtol and atol equal to 0.0 for efficiency reasons.

This method recursively handles multiple structure types: numbers, strings, tensors, dictionaries, lists and tuples.

The result is only True if the content of the two input states is identical or close enough, as defined by the tolerance values rtol and atol. In the case of strings, True is only returned if the states are identical.

Parameters:
  • state_x (number, str, tensor, dict, list, tuple) – One of the states to be compared.

  • state_y (number, str, tensor, dict, list, tuple) – The other state to be compared.

  • rtol (float) – Relative tolerance for numeric values.

  • atol (float) – Maximum absolute tolerance threshold for numeric values.

  • do_equal (bool) – If True, comparisons are by equality instead of closeness and rtol and atol are ignored.

Returns:

bool – True if the two input states are equal or closer than the maximum tolerance; False otherwise.

Raises:

NotImplementedError – If the input types are not part of the explicitly handled types.

Return type:

bool
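The recursive comparison and the tolerance formula can be sketched in pure Python. This illustration omits tensors and the do_equal shortcut, and treating lists and tuples interchangeably is an assumption; it is not the library's implementation.

```python
def isclose(state_x, state_y, rtol=1e-05, atol=1e-08):
    """Recursively checks whether two nested states are close."""
    # Strings must be identical.
    if isinstance(state_x, str) or isinstance(state_y, str):
        return state_x == state_y
    # Dictionaries must have the same keys and close values.
    if isinstance(state_x, dict) and isinstance(state_y, dict):
        return state_x.keys() == state_y.keys() and all(
            isclose(state_x[k], state_y[k], rtol, atol) for k in state_x
        )
    # Lists and tuples are compared element by element.
    if isinstance(state_x, (list, tuple)) and isinstance(state_y, (list, tuple)):
        return len(state_x) == len(state_y) and all(
            isclose(x, y, rtol, atol) for x, y in zip(state_x, state_y)
        )
    # Numbers use the same formula as numpy's and torch's isclose().
    if isinstance(state_x, (int, float)) and isinstance(state_y, (int, float)):
        return abs(state_x - state_y) <= rtol * abs(state_y) + atol
    raise NotImplementedError("Unsupported input types")


close = isclose([1.0, {"a": (2.0, "x")}], [1.0 + 1e-9, {"a": (2.0, "x")}])
```

Since the relative term scales with `abs(y)` only, the check is not symmetric in its arguments when the values differ by an amount near the tolerance.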

__eq__(other, ignored_keys=[])[source]

Checks whether the current environment instance is equal to the input environment instance.

The attribute self.id is ignored to determine whether the environments are equal.

Parameters:
  • other (GFlowNetEnv) – The environment instance to be compared.

  • ignored_keys (list) – A list of keys (strings) to be ignored in the comparison. This parameter may be used by subclasses that may need to ignore certain keys.

Returns:

bool – True if the environments’ attributes are considered equal; False otherwise.

Return type:

bool

get_trajectories(traj_list, traj_actions_list, current_traj, current_actions)[source]

Determines all trajectories leading to each state in traj_list, recursively.

Parameters:
  • traj_list (list) – List of trajectories (lists)

  • traj_actions_list (list) – List of actions within each trajectory

  • current_traj (list) – Current trajectory

  • current_actions (list) – Actions of current trajectory

Returns:

  • traj_list (list) – List of trajectories (lists)

  • traj_actions_list (list) – List of actions within each trajectory

compute_train_energy_proxy_and_rewards()[source]

Gather batched proxy data:

  • The ground-truth energy of the train set

  • The predicted proxy energy over the train set

  • The reward version of those energies (with env.proxy2reward)

Returns:

  • gt_energy (torch.Tensor) – The ground-truth energies in the proxy’s train set

  • proxy_energy (torch.Tensor) – The proxy’s predicted energies over its train set

  • gt_reward (torch.Tensor) – The reward version of the ground-truth energies

  • proxy_reward (torch.Tensor) – The reward version of the proxy’s predicted energies

mask_conditioning(mask, env_cond, backward)[source]

Conditions the input mask based on the restrictions imposed by a conditioning environment, env_cond.

It is assumed that the state space of the conditioning environment is a subset of the state space of the original environment (self). The conditioning mechanism goes as follows: given a state, its corresponding mask and a conditioning environment, the mask of invalid actions is updated such that all actions that would be invalid in the conditioning environment are made invalid, even though they may not be invalid in the original environment.

Parameters:
  • mask (Union[List[bool], torchtyping.TensorType[mask_dim]])

  • backward (bool)
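Since both masks flag invalid actions with True, the conditioning described above amounts to an element-wise logical OR. The sketch below assumes that the mask of the conditioning environment has already been computed for the same state and direction, and that both masks are plain lists of the same length; the real method takes the conditioning environment itself.

```python
def mask_conditioning(mask, mask_cond):
    """Marks as invalid every action that is invalid in either mask."""
    return [invalid or invalid_cond for invalid, invalid_cond in zip(mask, mask_cond)]


# The second action is valid in the original environment but not in the
# conditioning environment, so it becomes invalid.
conditioned = mask_conditioning([False, False, True], [False, True, False])
```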

top_k_metrics_and_plots(states, top_k, name, energy=None, reward=None, step=None, **kwargs)[source]

Compute top_k metrics and plots for the given states.

In particular, if no states, energy, or reward are passed, then the name must be “train”, and the energy and reward will be computed from the proxy using env.compute_train_energy_proxy_and_rewards(). In this case, top_k_metrics_and_plots will be called a second time to compute the metrics and plots of the proxy distribution in addition to the ground-truth distribution. Train mode should only be called once at the beginning of training, as the distributions do not change over time.

If states are passed, then the energy and reward will be computed from the proxy for those states. They are typically sampled from the current GFN.

Otherwise, energy and reward should be passed directly.

Plots and metrics:
  • mean+std of energy and reward

  • mean+std of top_k energy and reward

  • histogram of energy and reward

  • histogram of top_k energy and reward

Parameters:
  • states (list) – List of states to compute metrics and plots for.

  • top_k (int) – Number of top k states to compute metrics and plots for. “top” means lowest energy/highest reward.

  • name (str) – Name of the distribution to compute metrics and plots for. Typically “gflownet”, “random” or “train”. Will be used in metrics names like f"Mean {name} energy".

  • energy (torch.Tensor, optional) – Batch of pre-computed energies

  • reward (torch.Tensor, optional) – Batch of pre-computed rewards

  • step (int, optional) – Step number to use for the plot title.

Returns:

  • metrics (dict) – Dictionary of metrics: str->float

  • figs (list) – List of matplotlib figures

  • figs_names (list) – List of figure names for figs
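The mean and standard deviation metrics can be sketched for the energy alone, without plots. This is an illustration under assumptions: apart from the `f"Mean {name} energy"` pattern mentioned above, the metric names are hypothetical, the sample standard deviation is an arbitrary choice, and top_k must be at least 2 for it to be defined.

```python
import statistics


def top_k_energy_metrics(energies, top_k, name):
    """Summarises all energies and the top_k lowest energies."""
    # "Top" means lowest energy, so sort in ascending order.
    top = sorted(energies)[:top_k]
    return {
        f"Mean {name} energy": statistics.mean(energies),
        f"Std {name} energy": statistics.stdev(energies),
        # Hypothetical names for the top-k metrics.
        f"Mean {name} top-k energy": statistics.mean(top),
        f"Std {name} top-k energy": statistics.stdev(top),
    }


metrics = top_k_energy_metrics([3.0, 1.0, 2.0, 4.0], top_k=2, name="random")
```

For rewards the selection would be reversed, since “top” means highest reward.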

plot_reward_distribution(states=None, scores=None, ax=None, title=None, proxy=None, **kwargs)[source]
test(samples)[source]

Placeholder for a custom test function that can be defined for a specific environment. It can be overridden if a special evaluation procedure is needed for a given environment.

Parameters:

samples (Union[torchtyping.TensorType[n_trajectories, ...], numpy.typing.NDArray[numpy.float32], List]) – A collection of sampled terminating states.

Returns:

metrics – A dictionary with metrics and their calculated values.

Return type:

dict