gflownet.envs.base
Base class of GFlowNet environments
Attributes
Plotting colour map (cividis).
Classes
GFlowNetEnv – Base class of GFlowNet environments
Module Contents
- class gflownet.envs.base.GFlowNetEnv(device='cpu', float_precision=32, env_id='env', fixed_distr_params=None, random_distr_params=None, skip_mask_check=False, conditional=False, continuous=False, **kwargs)
Base class of GFlowNet environments
- Parameters:
device (str)
float_precision (int)
env_id (Union[int, str])
fixed_distr_params (Optional[dict])
random_distr_params (Optional[dict])
skip_mask_check (bool)
conditional (bool)
continuous (bool)
- abstract get_action_space()
Constructs the list of all possible actions (excluding the end-of-sequence action).
- property action_space_dim: int
Returns the dimensionality of the action space (number of actions).
- Returns:
The number of actions in the action space.
- Return type:
int
- property mask_dim
Returns the dimensionality of the masks.
- Returns:
The dimensionality of the masks.
- property max_traj_length: int
Returns the maximum trajectory length of the environment, including the EOS action.
- Returns:
The maximum number of steps in a trajectory of the environment.
- Return type:
int
- action2representative(action)
For continuous or hybrid environments, converts a continuous action into its representative in the action space. Discrete actions remain identical, so fully discrete environments do not need to re-implement this method. Continuous environments should re-implement it in order to replace continuous actions by their representatives in the action space.
- Parameters:
action (Tuple)
- Return type:
int
- action_produces_permutation(action, is_backward=False)
Determines whether an action produces permutations in the resulting state.
Permutations can be introduced, for example, in environments that need to incorporate permutation invariance, as in sets of elements. In these cases, some actions may result in states whose elements are randomly permuted.
This method makes it possible to identify such actions, which is useful, for instance, in unit tests.
By default, actions do not produce permutations and the returned value of this method is False.
Environments with actions that produce permutations should override this method and properly identify such actions.
- Parameters:
action (tuple) – An action of the environment.
is_backward (bool) – Whether the transition to consider is backward (True) or forward (False).
- Returns:
bool – Whether the input action produces permutations in the resulting state, in the direction indicated by is_backward.
- Return type:
bool
- action2index(action)
Returns the index in the action space of the action passed as an argument, or of its representative if it is a continuous action.
The method uses the dictionary lookup self._action2index. See: self.action2representative()
- Parameters:
action (tuple) – An action from the action space.
- Returns:
int – The index of the action in the action space.
- Return type:
int
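The lookup can be illustrated with a minimal, self-contained sketch. It assumes a toy discrete action space; ACTION_SPACE, ACTION2INDEX and the free functions below are hypothetical stand-ins for the environment's attributes and methods, not the library's implementation.

```python
from typing import Dict, List, Tuple

# Hypothetical discrete action space: two increment actions plus an EOS action.
ACTION_SPACE: List[Tuple[int, ...]] = [(0, 1), (1, 1), (-1,)]

# Built once, so each lookup is a constant-time dictionary access
# (the role played by self._action2index in the environment).
ACTION2INDEX: Dict[Tuple[int, ...], int] = {
    action: idx for idx, action in enumerate(ACTION_SPACE)
}


def action2representative(action: Tuple[int, ...]) -> Tuple[int, ...]:
    # In a fully discrete environment, every action is its own representative.
    return action


def action2index(action: Tuple[int, ...]) -> int:
    return ACTION2INDEX[action2representative(action)]


def actions2indices(actions: List[Tuple[int, ...]]) -> List[int]:
    # Batched version: one index per action in the batch.
    return [action2index(action) for action in actions]
```

For example, action2index((1, 1)) returns 1. Building the dictionary once trades a little memory for constant-time lookups, which matters when indices are computed for large batches.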
- actions2indices(actions)
Returns the corresponding indices in the action space of the actions in a batch.
- Parameters:
actions (torchtyping.TensorType[batch_size, action_dim])
- Return type:
torchtyping.TensorType[batch_size]
- is_source(state=None)
Returns True if the environment’s state or the state passed as parameter (if not None) is the source state of the environment.
- Parameters:
state (list or tensor or None) – None, or a state in environment format.
- Returns:
bool – Whether the state is the source state of the environment
- Return type:
bool
- get_mask_invalid_actions_forward(state=None, done=None)
- Returns a list with one entry per action in the action space, with values:
True if the forward action is invalid from the current state.
False otherwise.
For continuous or hybrid environments, this mask corresponds to the discrete part of the action space.
- Parameters:
state (Optional[List])
done (Optional[bool])
- Return type:
List
- get_mask_invalid_actions_backward(state=None, done=None, parents_a=None)
- Returns a list with one entry per action in the action space, with values:
True if the backward action is invalid from the current state.
False otherwise.
For continuous or hybrid environments, this mask corresponds to the discrete part of the action space.
The base implementation should be common to all discrete spaces, as it relies on get_parents(), which is environment-specific and must be implemented. Continuous environments will probably need to implement their own version of this method.
- Parameters:
state (Optional[List])
done (Optional[bool])
parents_a (Optional[List])
- Return type:
List
- get_mask(state=None, done=None, backward=False)
Returns a mask of invalid actions given a state and a done value. Depending on backward, either the forward or the backward mask is returned, by calling the corresponding method.
- Parameters:
state (Optional[List])
done (Optional[bool])
backward (Optional[bool])
- Return type:
List
- get_valid_actions(mask=None, state=None, done=None, backward=False)
Returns the list of non-invalid (valid, for short) actions according to the mask of invalid actions.
More documentation about the meaning and use of invalid actions can be found in gflownet/envs/README.md.
- Parameters:
mask (Optional[bool])
state (Optional[List])
done (Optional[bool])
backward (Optional[bool])
- Return type:
List[Tuple]
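Because True marks an invalid action, obtaining the valid actions of a discrete environment amounts to keeping the positions where the mask is False. A minimal sketch, assuming the mask and a hypothetical action_space list share the same ordering; the function below is illustrative, not the library's code, and does not cover the longer masks of continuous environments.

```python
from typing import List, Sequence, Tuple


def get_valid_actions(
    action_space: Sequence[Tuple], mask: Sequence[bool]
) -> List[Tuple]:
    """Return the actions whose mask entry is False (i.e. not invalid)."""
    if len(mask) != len(action_space):
        raise ValueError("The mask must have exactly one entry per action")
    return [action for action, invalid in zip(action_space, mask) if not invalid]
```

For instance, with action_space = [(0, 1), (1, 1), (-1,)] and mask = [True, False, False], the first action is invalid and the result is [(1, 1), (-1,)].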
- get_parents(state=None, done=None, action=None)
Determines all parents and actions that lead to state.
In continuous environments, get_parents() should return only the parent from which action leads to state.
- Parameters:
state (list) – Representation of a state
done (bool) – Whether the trajectory is done. If None, done is taken from the instance.
action (tuple) – Last action performed
- Returns:
parents (list) – List of parents in state format
actions (list) – List of actions that lead to state for each parent in parents
- Return type:
Tuple[List, List]
- abstract step(action, skip_mask_check=False)
Executes a step given an action.
- Parameters:
action (tuple) – Action from the action space.
skip_mask_check (bool) – If True, skip computing forward mask of invalid actions to check if the action is valid.
- Returns:
self.state (list) – The sequence after executing the action
action (int) – Action index
valid (bool) – False, if the action is not allowed for the current state, e.g. stop at the root state
- Return type:
Tuple[List[int], Tuple[int], bool]
- step_backwards(action, skip_mask_check=False)
Executes a backward step given an action. This generic implementation should work for all discrete environments, as it relies on get_parents(). Continuous environments should re-implement a custom step_backwards(). Although valid for any discrete environment, the call to get_parents() may be expensive, so it may be advantageous to re-implement step_backwards() more efficiently for discrete environments as well. This is especially true because this generic implementation calls get_parents() twice: once here and once in _pre_step(), through the call to get_mask_invalid_actions_backward(), if skip_mask_check is False.
- Parameters:
action (tuple) – Action from the action space.
skip_mask_check (bool) – If True, skip computing the backward mask of invalid actions to check if the action is valid.
- Returns:
self.state (list) – The sequence after executing the action
action (int) – Action index
valid (bool) – False, if the action is not allowed for the current state.
- Return type:
Tuple[List[int], Tuple[int], bool]
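The generic backward step can be pictured as a lookup in the output of get_parents(): the parent reached by undoing the action becomes the new state. The sketch below assumes a hypothetical 2D grid whose forward actions increment one coordinate; get_parents and step_backwards here are illustrative free functions, not the library's methods, and the mask check is omitted.

```python
from typing import List, Tuple

State = Tuple[int, int]
Action = Tuple[int, int]


def get_parents(state: State) -> Tuple[List[State], List[Action]]:
    """Hypothetical grid: action (1, 0) increments x and (0, 1) increments y."""
    x, y = state
    parents: List[State] = []
    actions: List[Action] = []
    if x > 0:
        parents.append((x - 1, y))
        actions.append((1, 0))
    if y > 0:
        parents.append((x, y - 1))
        actions.append((0, 1))
    return parents, actions


def step_backwards(state: State, action: Action) -> Tuple[State, Action, bool]:
    parents, actions = get_parents(state)
    if action not in actions:
        # No parent leads to `state` through `action`: invalid backward step.
        return state, action, False
    # The parent paired with `action` becomes the new state.
    return parents[actions.index(action)], action, True
```

Every backward step enumerates all parents, which is exactly the cost that motivates environment-specific re-implementations.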
- randomize_and_temper_sampling_distribution(policy_outputs, probability_random_action=0.0, temperature=1.0)
Replaces the rows of policy_outputs by a vector corresponding to a random sampling policy with the probability indicated by probability_random_action, and divides the logits by temperature.
Note that the tensor of policy outputs is not cloned if neither tempering nor random actions are incorporated. This implies that the original tensor of policy outputs may be modified by subsequent methods (namely sample_actions_batch()), for example to mask the invalid actions.
- Parameters:
policy_outputs (tensor) – The original outputs of the sampling policy. For example, they may correspond to the output (logits) of the GFlowNet policy model.
probability_random_action (float, optional) – The probability of sampling a random action. If larger than zero, the logits will be replaced by a random policy vector with this probability, according to a Bernoulli distribution. By default, the probability is 0.0 (no random actions).
temperature (float, optional) – A scalar by which the logits are divided to adjust the sampling distribution. A temperature larger than one will result in a flatter distribution, favouring exploration. A temperature smaller than one will sharpen the distribution, favouring concentration around high probability actions. By default, the temperature is 1.0 (no tempering).
- Returns:
policy_outputs (tensor) – The modified policy outputs.
- Return type:
torchtyping.TensorType[n_states, policy_output_dim]
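The two operations can be sketched with plain lists instead of tensors. This is a minimal illustration, not the library's implementation: the function name is hypothetical, and using equal logits as the "random policy" vector is an assumption of the sketch.

```python
import random
from typing import List, Optional


def randomize_and_temper(
    logits: List[List[float]],
    probability_random_action: float = 0.0,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Temper each row of logits and, with some probability, replace it by a uniform policy."""
    rng = rng or random.Random()
    out: List[List[float]] = []
    for row in logits:
        # Dividing by T > 1 flattens the softmax; T < 1 sharpens it.
        new_row = [x / temperature for x in row]
        # Bernoulli draw per row: with probability p, use equal logits (uniform policy).
        if rng.random() < probability_random_action:
            new_row = [1.0] * len(row)
        out.append(new_row)
    return out
```

Unlike the method documented above, this sketch always builds new lists, so its input is never modified in place by later steps.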
- sample_actions_batch(policy_outputs, mask=None, states_from=None, is_backward=False, random_action_prob=0.0, temperature_logits=1.0)
Samples a batch of actions from a batch of policy outputs.
This implementation is generally valid for all discrete environments but continuous or mixed environments need to reimplement this method.
The method is valid for both forward and backward actions in the case of discrete environments. Some continuous environments may also be agnostic to the difference between forward and backward actions, since the necessary information can be contained in the mask. However, some continuous environments do need to know whether the actions are forward or backward, which is why this can be specified by the argument is_backward.
Most environments do not need to know the states from which the actions are to be sampled since the necessary information is in both the policy outputs and the mask. However, some continuous environments do need to know the originating states in order to construct the actions, which is why one of the arguments is states_from.
Note that methods overriding this method should randomize and temper the logits.
- Parameters:
policy_outputs (tensor) – The output of the GFlowNet policy model.
mask (tensor) – The mask of invalid actions. For continuous or mixed environments, the mask may be a tensor of arbitrary length containing information about special states, as defined elsewhere in the environment.
states_from (tensor) – The states originating the actions, in GFlowNet format. Ignored in discrete environments and only required in certain continuous environments.
is_backward (bool) – True if the actions are backward, False if the actions are forward (default). Ignored in discrete environments and only required in certain continuous environments.
random_action_prob (float, optional) – The probability of sampling a random action. If larger than zero, the model outputs will be replaced by a random policy vector with probability random_action_prob, according to a Bernoulli distribution.
temperature_logits (float, optional) – A scalar by which the model outputs are divided to temper the sampling distribution.
- Returns:
actions (list) – The list of sampled actions.
- Return type:
Tuple[List[Tuple], torchtyping.TensorType[n_states]]
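For a discrete environment, sampling reduces to a softmax restricted to the valid actions followed by a categorical draw. A minimal sketch with plain lists; the function names and the extra action_space argument are hypothetical, returning the sampled indices as the second element is a choice of this sketch, and the randomization and tempering that overriding methods should apply are left out for brevity.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple


def masked_softmax(logits: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Softmax over the valid actions; invalid actions (mask True) get probability 0."""
    valid = [x for x, invalid in zip(logits, mask) if not invalid]
    if not valid:
        raise ValueError("At least one action must be valid")
    m = max(valid)  # shift by the max for numerical stability
    exps = [0.0 if invalid else math.exp(x - m) for x, invalid in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]


def sample_actions_batch(
    action_space: Sequence[Tuple],
    policy_outputs: Sequence[Sequence[float]],
    mask: Sequence[Sequence[bool]],
    rng: Optional[random.Random] = None,
) -> Tuple[List[Tuple], List[int]]:
    """Sample one action per row of logits, never picking a masked action."""
    rng = rng or random.Random()
    actions: List[Tuple] = []
    indices: List[int] = []
    for logits, row_mask in zip(policy_outputs, mask):
        probs = masked_softmax(logits, row_mask)
        # Zero-weight (masked) entries can never be selected by choices().
        idx = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        indices.append(idx)
        actions.append(action_space[idx])
    return actions, indices
```

Masking before the softmax, rather than renormalizing afterwards, keeps invalid actions at exactly zero probability.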
- get_logprobs(policy_outputs, actions, mask=None, states_from=None, is_backward=False)
Computes the log probabilities of the given actions from the policy outputs. This implementation is generally valid for all discrete environments, but continuous environments will likely have to implement their own.
- Parameters:
policy_outputs (tensor) – The output of the GFlowNet policy model.
mask (tensor) – The mask of invalid actions. For continuous or mixed environments, the mask may be a tensor of arbitrary length containing information about special states, as defined elsewhere in the environment.
actions (list or tensor) – The actions from each state in the batch for which to compute the log probability. The actions may be a list or a tensor. Most environments handle the actions as a list, but in some cases it is practical to use a tensor for easier indexing, such as in meta-environments.
states_from (tensor) – The states originating the actions, in GFlowNet format. Ignored in discrete environments and only required in certain continuous environments.
is_backward (bool) – True if the actions are backward, False if the actions are forward (default). Ignored in discrete environments and only required in certain continuous environments.
- Return type:
torchtyping.TensorType[batch_size]
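In the discrete case, the log probability of an action is its entry in a log-softmax computed over the valid actions only. A minimal sketch with plain lists; the function names and the use of integer action indices instead of action tuples are assumptions of the sketch, not the library's API.

```python
import math
from typing import List, Sequence


def masked_log_softmax(logits: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Log-probabilities over actions; invalid actions (mask True) get -inf."""
    valid = [x for x, invalid in zip(logits, mask) if not invalid]
    if not valid:
        raise ValueError("At least one action must be valid")
    m = max(valid)
    # log-sum-exp over the valid actions only, shifted by the max for stability.
    log_z = m + math.log(sum(math.exp(x - m) for x in valid))
    return [-math.inf if invalid else x - log_z for x, invalid in zip(logits, mask)]


def get_logprobs(
    policy_outputs: Sequence[Sequence[float]],
    action_indices: Sequence[int],
    mask: Sequence[Sequence[bool]],
) -> List[float]:
    """One log-probability per state in the batch, for the action actually taken."""
    return [
        masked_log_softmax(logits, row_mask)[idx]
        for logits, idx, row_mask in zip(policy_outputs, action_indices, mask)
    ]
```

With logits [0.0, 0.0, 5.0] and the third action masked, the first action has log probability log(0.5): the large logit of the masked action does not affect the result.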
- step_random(backward=False)
Samples a random action and executes the step.
- Parameters:
backward (bool) – If True, the step is performed backwards. False by default.
- Returns:
state (list) – The state after executing the action.
action (int) – Action, randomly sampled.
valid (bool) – False, if the action is not allowed for the current state.
- trajectory_random(backward=False)
Samples and applies a random trajectory on the environment, by sampling random actions until an EOS action is sampled.
- Parameters:
backward (bool) – If True, the trajectory is sampled backwards. False by default.
- Returns:
state (list) – The final state.
action (list) – The list of actions (tuples) in the trajectory.
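To make step_random() and trajectory_random() concrete, here is a hypothetical forward-only toy environment (ToyCounterEnv is not part of the library): the state is a counter in [0, max_value], action (1,) increments it and (-1,) is EOS. Random actions are drawn only among those the mask marks as valid, and the loop stops once EOS has been applied. Backward sampling is omitted.

```python
import random
from typing import List, Optional, Tuple


class ToyCounterEnv:
    """Hypothetical forward-only environment: the state is a counter in [0, max_value]."""

    eos: Tuple[int, ...] = (-1,)

    def __init__(self, max_value: int = 3, seed: Optional[int] = None):
        self.max_value = max_value
        self.action_space: List[Tuple[int, ...]] = [(1,), self.eos]
        self.rng = random.Random(seed)
        self.reset()

    def reset(self) -> "ToyCounterEnv":
        self.state, self.done = 0, False
        return self

    def get_mask_invalid_actions_forward(self) -> List[bool]:
        if self.done:
            return [True, True]
        # Incrementing is invalid at max_value; EOS is always valid otherwise.
        return [self.state >= self.max_value, False]

    def step(self, action: Tuple[int, ...]) -> Tuple[int, Tuple[int, ...], bool]:
        mask = self.get_mask_invalid_actions_forward()
        if mask[self.action_space.index(action)]:
            return self.state, action, False
        if action == self.eos:
            self.done = True
        else:
            self.state += 1
        return self.state, action, True

    def step_random(self) -> Tuple[int, Tuple[int, ...], bool]:
        # Draw uniformly among the actions that the mask marks as valid.
        mask = self.get_mask_invalid_actions_forward()
        valid = [a for a, invalid in zip(self.action_space, mask) if not invalid]
        return self.step(self.rng.choice(valid))

    def trajectory_random(self) -> Tuple[int, List[Tuple[int, ...]]]:
        actions: List[Tuple[int, ...]] = []
        while not self.done:
            _, action, valid = self.step_random()
            if valid:
                actions.append(action)
        return self.state, actions
```

Usage: ToyCounterEnv(max_value=3, seed=0).trajectory_random() returns the final counter value and a list of actions that always ends with (-1,).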
- get_random_terminating_states(n_states, unique=True, max_attempts=100000)
Samples n_states terminating states by using the random policy of the environment (calling self.trajectory_random()).
Note that this method is general for all environments but it may be suboptimal in terms of efficiency. In particular, 1) it samples full trajectories in order to get terminating states, 2) if unique is True, it needs to compare each newly sampled state with all the previously sampled states. If get_uniform_terminating_states is available, it may be preferred, or for some environments, a custom get_random_terminating_states may be straightforward to implement in a much more efficient way.
- Parameters:
n_states (int) – The number of terminating states to sample.
unique (bool) – Whether samples should be unique. True by default.
max_attempts (int) – The maximum number of attempts, to prevent the method from getting stuck trying to obtain n_states different samples if unique is True. 100000 by default, therefore if more than 100000 are requested, max_attempts should be increased accordingly.
- Returns:
states (list) – A list of randomly sampled terminating states.
- Return type:
List
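The rejection loop behind this method can be sketched independently of any environment. Here sample_terminating_state is a hypothetical callback standing in for running self.trajectory_random() and keeping the final state, and plain == stands in for the environment's own state comparison; the sketch simply returns fewer states if max_attempts runs out.

```python
from typing import Any, Callable, List


def get_random_terminating_states(
    sample_terminating_state: Callable[[], Any],
    n_states: int,
    unique: bool = True,
    max_attempts: int = 100000,
) -> List[Any]:
    states: List[Any] = []
    for _ in range(max_attempts):
        if len(states) == n_states:
            break
        state = sample_terminating_state()
        # Linear scan over all previous samples: the quadratic cost noted above.
        if unique and any(state == s for s in states):
            continue
        states.append(state)
    return states
```

The max_attempts bound is what prevents an infinite loop when fewer than n_states distinct terminating states exist.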
- get_random_states(n_states, unique=True, exclude_source=False, max_attempts=1000)
Samples n_states states (not necessarily terminating) by using the random policy of the environment (calling self.step_random()).
It relies on self.max_traj_length to uniformly sample the number of steps, so as to obtain states with varying trajectory lengths.
The method iteratively samples a trajectory length and then attempts to perform that many steps. If the trajectory ends before the requested number of steps is reached, it is discarded and a new one is attempted.
This may introduce a bias towards states that can be reached in a few steps.
Note that this method is general for all environments but it may be suboptimal in terms of efficiency. In particular, 1) it samples trajectories step by step in order to get random states, 2) if unique is True, it needs to compare each newly sampled state with all the previously sampled states, 3) states are copied before adding them to the list, 4) only the last state of a trajectory is added to the list in order to have diversity of trajectories.
- Parameters:
n_states (int) – The number of states to sample.
unique (bool) – Whether samples should be unique. True by default.
max_attempts (int) – The maximum number of attempts, to prevent the method from getting stuck trying to obtain n_states different samples if unique is True. 1000 by default, therefore if more than 1000 states are requested, max_attempts should be increased accordingly.
exclude_source (bool) – If True, exclude the source state from the list of states.
- Returns:
states (list) – A list of randomly sampled states.
- Raises:
ValueError – If max_attempts is smaller than n_states
RuntimeError – If the maximum number of attempts is reached before obtaining the requested number of unique states.
- Return type:
List
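The sampling loop described above can be sketched as follows. run_steps is a hypothetical callback that resets a fresh environment, tries to perform n_steps random steps and returns the resulting state, or None if the trajectory ended earlier. Drawing the target length uniformly from 0 to max_traj_length is an assumption of this sketch, and plain == stands in for the environment's own state comparison.

```python
import random
from typing import Any, Callable, List, Optional


def get_random_states(
    run_steps: Callable[[int], Optional[Any]],
    max_traj_length: int,
    n_states: int,
    unique: bool = True,
    exclude_source: bool = False,
    source: Any = None,
    max_attempts: int = 1000,
    rng: Optional[random.Random] = None,
) -> List[Any]:
    if max_attempts < n_states:
        raise ValueError("max_attempts must not be smaller than n_states")
    rng = rng or random.Random()
    states: List[Any] = []
    for _ in range(max_attempts):
        if len(states) == n_states:
            return states
        # Draw the target number of steps uniformly, then try to perform that many.
        n_steps = rng.randint(0, max_traj_length)
        state = run_steps(n_steps)
        if state is None:
            continue  # the trajectory ended early: discard it
        if exclude_source and state == source:
            continue
        if unique and state in states:
            continue  # linear scan over previous samples
        states.append(state)
    if len(states) < n_states:
        raise RuntimeError("max_attempts reached before collecting n_states states")
    return states
```

Discarding trajectories that end early is what biases the sample towards states reachable in few steps: long targets fail more often.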
- get_policy_output(params=None)
Defines the structure of the output of the policy model, from which an action is to be determined or sampled, by returning a vector with a fixed random policy. As a baseline, the policy is uniform over the dimensionality of the action space.
Continuous environments will generally have to override this method.
- Parameters:
params (Optional[dict])
- Return type:
torchtyping.TensorType[policy_output_dim]
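Why a vector of identical values works as a uniform baseline: the softmax of equal logits assigns the same probability to every action. A minimal sketch with hypothetical free functions and plain lists instead of a tensor; the constant 1.0 is an arbitrary choice, since any constant gives the same distribution.

```python
import math
from typing import List


def get_policy_output(action_space_dim: int) -> List[float]:
    # Identical logits for every action (the value 1.0 is an arbitrary choice).
    return [1.0] * action_space_dim


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]
```

For example, softmax(get_policy_output(4)) gives probability 0.25 to each of the four actions.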
- states2proxy(states)
Prepares a batch of states in “environment format” for the proxy. By default, the batch of states is converted into a tensor with float dtype and returned as is.
- Parameters:
states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.
- Returns:
A tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[batch, state_proxy_dim]
- state2proxy(state=None)
Prepares a single state in “GFlowNet format” for the proxy. By default, simply states2proxy is called and the output will be a “batch” with a single state in the proxy format.
- Parameters:
state (list) – A state
- Return type:
torchtyping.TensorType[state_proxy_dim]
- states2policy(states)
Prepares a batch of states in “environment format” for the policy model. By default, the batch of states is converted into a tensor with float dtype and returned as is.
- Parameters:
states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.
- Returns:
A tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[batch, policy_input_dim]
- state2policy(state=None)
Prepares a state in “GFlowNet format” for the policy model. By default, states2policy is called, which by default will return the state as is.
- Parameters:
state (list) – A state
- Return type:
torchtyping.TensorType[policy_input_dim]
- readable2state(readable)
Converts a human-readable representation of a state into the standard format.
- reset(env_id=None)
Resets the environment.
- Parameters:
env_id (int or str) – Unique (ideally) identifier of the environment instance, used to identify the trajectory generated with this environment. If None, uuid.uuid4() is used.
- Returns:
self
- set_id(env_id)
Sets the id given as argument and returns the environment.
- Parameters:
env_id (int or str) – Unique (ideally) identifier of the environment instance, used to identify the trajectory generated with this environment.
- Returns:
self
- set_state(state, done=False)
Sets the state and done of an environment. Environments that cannot be “done” at all states (intermediate states are not fully constructed objects) should override this method and check for validity.
- Parameters:
state (List)
done (Optional[bool])
- static equal(state_x, state_y)
Checks whether the two input states are equal.
This method recursively handles multiple structure types: numbers, strings, tensors, dictionaries, lists and tuples.
The result is only True if the content of the two input states is identical.
The core functionality is implemented in gflownet.envs.base.GFlowNetEnv.isclose() and this method simply calls it with do_equal=True.
- Parameters:
state_x (number, str, tensor, dict, list, tuple) – One of the states to be compared.
state_y (number, str, tensor, dict, list, tuple) – The other state to be compared.
- Returns:
bool – True if the two input states are equal; False otherwise.
- Raises:
NotImplementedError – If the input types are not among the explicitly handled types.
- Return type:
bool
- static isclose(state_x, state_y, rtol=1e-05, atol=1e-08, do_equal=False)
Checks whether the two input states are close, according to a tolerance.
This method relies on numpy’s and torch’s isclose() methods, which both use the following formula: abs(x - y) <= rtol * abs(y) + atol
This method is also used by gflownet.envs.base.GFlowNetEnv.equal() in order to avoid code repetition. In this case, do_equal is True and numpy’s and torch’s equal() methods are used. This is preferred over setting rtol and atol to 0.0 for efficiency reasons.
This method recursively handles multiple structure types: numbers, strings, tensors, dictionaries, lists and tuples.
The result is only True if the content of the two input states is identical or close enough, as defined by the tolerance values rtol and atol. In the case of strings, True is only returned if the states are identical.
- Parameters:
state_x (number, str, tensor, dict, list, tuple) – One of the states to be compared.
state_y (number, str, tensor, dict, list, tuple) – The other state to be compared.
rtol (float) – Relative tolerance for numeric values.
atol (float) – Maximum absolute tolerance threshold for numeric values.
do_equal (bool) – If True, comparisons are by equality instead of closeness and rtol and atol are ignored.
- Returns:
bool – True if the two input states are equal or closer than the maximum tolerance; False otherwise.
- Raises:
NotImplementedError – If the input types are not among the explicitly handled types.
- Return type:
bool
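The recursive comparison can be sketched with the standard library alone. This is a minimal illustration, not the library's code: tensors and arrays are omitted to stay dependency-free, numbers use the same asymmetric formula as numpy's isclose(), and treating a list and a tuple with the same content as different is an assumption of the sketch.

```python
from numbers import Number
from typing import Any


def isclose(x: Any, y: Any, rtol: float = 1e-05, atol: float = 1e-08, do_equal: bool = False) -> bool:
    # Strings are only ever compared for exact equality.
    if isinstance(x, str):
        return isinstance(y, str) and x == y
    if isinstance(x, Number):
        if not isinstance(y, Number):
            return False
        if do_equal:
            return x == y
        # Same criterion as numpy.isclose: |x - y| <= rtol * |y| + atol
        return abs(x - y) <= rtol * abs(y) + atol
    if isinstance(x, dict):
        return (
            isinstance(y, dict)
            and x.keys() == y.keys()
            and all(isclose(x[k], y[k], rtol, atol, do_equal) for k in x)
        )
    if isinstance(x, (list, tuple)):
        # Lists and tuples must match in type (an assumption of this sketch).
        return (
            type(x) is type(y)
            and len(x) == len(y)
            and all(isclose(a, b, rtol, atol, do_equal) for a, b in zip(x, y))
        )
    raise NotImplementedError(f"Unsupported type: {type(x).__name__}")


def equal(x: Any, y: Any) -> bool:
    # Exact comparison, delegating to isclose() as described above.
    return isclose(x, y, do_equal=True)
```

For example, [1.0, {"a": 2.0}] is close to [1.0 + 1e-9, {"a": 2.0}] but not equal to it.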
- __eq__(other, ignored_keys=[])
Checks whether the current environment instance is equal to the input environment instance.
The attribute self.id is ignored to determine whether the environments are equal.
- Parameters:
other (GFlowNetEnv) – The environment instance to be compared.
ignored_keys (list) – A list of keys (strings) to be ignored in the comparison. This parameter may be used by subclasses that may need to ignore certain keys.
- Returns:
bool – True if the environments’ attributes are considered equal; False otherwise.
- Return type:
bool
- get_trajectories(traj_list, traj_actions_list, current_traj, current_actions)
Determines all trajectories leading to each state in traj_list, recursively.
- Parameters:
traj_list (list) – List of trajectories (lists)
traj_actions_list (list) – List of actions within each trajectory
current_traj (list) – Current trajectory
current_actions (list) – Actions of current trajectory
- Returns:
traj_list (list) – List of trajectories (lists)
traj_actions_list (list) – List of actions within each trajectory
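The recursion can be sketched on a hypothetical 2D grid whose forward actions (1, 0) and (0, 1) increment one coordinate, starting from the source (0, 0). In this sketch a state with no parents is taken to be the source, and current_traj is ordered from the final state back towards the source; get_parents here is a hypothetical stand-in for the environment method, not the library's code.

```python
from typing import List, Tuple

State = Tuple[int, int]
Action = Tuple[int, int]


def get_parents(state: State) -> Tuple[List[State], List[Action]]:
    x, y = state
    parents: List[State] = []
    actions: List[Action] = []
    if x > 0:
        parents.append((x - 1, y))
        actions.append((1, 0))
    if y > 0:
        parents.append((x, y - 1))
        actions.append((0, 1))
    return parents, actions


def get_trajectories(
    traj_list: List[List[State]],
    traj_actions_list: List[List[Action]],
    current_traj: List[State],
    current_actions: List[Action],
) -> Tuple[List[List[State]], List[List[Action]]]:
    parents, actions = get_parents(current_traj[-1])
    if not parents:
        # Reached the source: the backward trajectory is complete.
        traj_list.append(current_traj)
        traj_actions_list.append(current_actions)
        return traj_list, traj_actions_list
    for parent, action in zip(parents, actions):
        # Branch once per parent, extending copies of the current trajectory.
        get_trajectories(
            traj_list, traj_actions_list, current_traj + [parent], current_actions + [action]
        )
    return traj_list, traj_actions_list
```

Starting from [(2, 2)] with empty accumulators yields the 6 monotone paths to (0, 0); the count grows combinatorially, so this enumeration is only practical for small state spaces.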
- compute_train_energy_proxy_and_rewards()
Gathers batched proxy data:
The ground-truth energy of the train set
The predicted proxy energy over the train set
The reward version of those energies (with env.proxy2reward)
- Returns:
gt_energy (torch.Tensor) – The ground-truth energies in the proxy’s train set
proxy_energy (torch.Tensor) – The proxy’s predicted energies over its train set
gt_reward (torch.Tensor) – The reward version of the ground-truth energies
proxy_reward (torch.Tensor) – The reward version of the proxy’s predicted energies
- mask_conditioning(mask, env_cond, backward)
Conditions the input mask based on the restrictions imposed by a conditioning environment, env_cond.
It is assumed that the state space of the conditioning environment is a subset of the state space of the original environment (self). The conditioning mechanism goes as follows: given a state, its corresponding mask and a conditioning environment, the mask of invalid actions is updated such that all actions that would be invalid in the conditioning environment are made invalid, even though they may not be invalid in the original environment.
- Parameters:
mask (Union[List[bool], torchtyping.TensorType[mask_dim]])
backward (bool)
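Since True marks an invalid action, the conditioning rule reduces to an element-wise logical OR, provided both masks are expressed over the same actions in the same order. That layout agreement, the name mask_cond (the mask env_cond would produce for the same state and direction) and the free function below are assumptions of this sketch; how a state of self is mapped to env_cond is not shown.

```python
from typing import List, Sequence


def mask_conditioning(mask: Sequence[bool], mask_cond: Sequence[bool]) -> List[bool]:
    """An action is invalid (True) if it is invalid in either environment."""
    if len(mask) != len(mask_cond):
        raise ValueError("Both masks must have the same length and action ordering")
    return [invalid or invalid_cond for invalid, invalid_cond in zip(mask, mask_cond)]
```

The OR can only add invalid actions, never remove them, which matches the requirement that conditioning restricts the original environment.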
- top_k_metrics_and_plots(states, top_k, name, energy=None, reward=None, step=None, **kwargs)
Computes top_k metrics and plots for the given states.
In particular, if no states, energy, or reward are passed, then the name must be “train”, and the energy and reward will be computed from the proxy using env.compute_train_energy_proxy_and_rewards(). In this case, top_k_metrics_and_plots will be called a second time to compute the metrics and plots of the proxy distribution in addition to the ground-truth distribution. Train mode should only be called once at the beginning of training, as the distributions do not change over time.
If states are passed, then the energy and reward will be computed from the proxy for those states. They are typically sampled from the current GFlowNet.
Otherwise, energy and reward should be passed directly.
Plots and metrics:
  - mean+std of energy and reward
  - mean+std of top_k energy and reward
  - histogram of energy and reward
  - histogram of top_k energy and reward
- Parameters:
states (list) – List of states to compute metrics and plots for.
top_k (int) – Number of top k states to compute metrics and plots for. “top” means lowest energy/highest reward.
name (str) – Name of the distribution to compute metrics and plots for. Typically “gflownet”, “random” or “train”. Will be used in metric names like f"Mean {name} energy".
energy (torch.Tensor, optional) – Batch of pre-computed energies
reward (torch.Tensor, optional) – Batch of pre-computed rewards
step (int, optional) – Step number to use for the plot title.
- Returns:
metrics (dict) – Dictionary of metrics: str->float
figs (list) – List of matplotlib figures
figs_names (list) – List of figure names for figs
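The metric part of this method can be sketched with the standard library. Only the "Mean {name} energy" naming pattern comes from the documentation above; the other key names, the use of the population standard deviation and the function name are assumptions of this sketch, and the histograms are omitted.

```python
import statistics
from typing import Dict, Sequence


def top_k_metrics(
    energies: Sequence[float], rewards: Sequence[float], top_k: int, name: str
) -> Dict[str, float]:
    """Mean and std of all and of the top-k energies and rewards (plots omitted)."""
    # "Top" means lowest energy and highest reward.
    top_energies = sorted(energies)[:top_k]
    top_rewards = sorted(rewards, reverse=True)[:top_k]
    groups = {
        f"{name} energy": list(energies),
        f"{name} reward": list(rewards),
        f"top-{top_k} {name} energy": top_energies,
        f"top-{top_k} {name} reward": top_rewards,
    }
    metrics: Dict[str, float] = {}
    for label, values in groups.items():
        metrics[f"Mean {label}"] = statistics.mean(values)
        # Population standard deviation (an assumption of this sketch).
        metrics[f"Std {label}"] = statistics.pstdev(values)
    return metrics
```

Sorting in opposite directions for energies and rewards reflects that "top" means lowest energy but highest reward.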
- plot_reward_distribution(states=None, scores=None, ax=None, title=None, proxy=None, **kwargs)
- test(samples)
Placeholder for a custom test function that can be defined for a specific environment. It can be overridden if a special evaluation procedure is needed for a given environment.
- Parameters:
samples (Union[torchtyping.TensorType[n_trajectories, ...], numpy.typing.NDArray[numpy.float32], List]) – A collection of sampled terminating states.
- Returns:
metrics – A dictionary with metrics and their calculated values.
- Return type:
dict