gflownet.envs.composite.setbase
Classes implementing the family of Set meta-environments, which allow combining multiple sub-environments without any specific order.
Classes
BaseSet: Initializes the BaseSet.
Module Contents
- class gflownet.envs.composite.setbase.BaseSet(can_alternate_subenvs=True, **kwargs)[source]
Bases: gflownet.envs.composite.base.CompositeBase
Initializes the BaseSet.
- Parameters:
can_alternate_subenvs (bool) – If True, actions of different sub-environments can alternate and each sub-environment action is preceded and followed by a meta-action to toggle the sub-environment. If False, once a sub-environment is activated, only actions of that sub-environment can be performed until it gets done (its EOS action is performed).
- property n_toggle_actions: int[source]
Returns the number of actions to toggle sub-environments or unique environments.
If the Set allows alternating actions between sub-environments, the number of toggle actions is the number of sub-environments. Otherwise, toggle actions activate unique environments and the number of unique environments is returned.
- Return type:
int
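A minimal sketch of this counting rule in plain Python. It is not the library's code: each sub-environment is represented by a hypothetical type label, and equal labels stand for identical (unique) environments.

```python
from typing import List


def n_toggle_actions(subenv_types: List[str], can_alternate_subenvs: bool) -> int:
    """Count toggle actions for a Set (illustrative sketch).

    subenv_types holds one hypothetical type label per sub-environment;
    equal labels stand for identical (unique) environments.
    """
    if can_alternate_subenvs:
        # One toggle action per sub-environment.
        return len(subenv_types)
    # One toggle action per unique environment.
    return len(set(subenv_types))


subenvs = ["grid", "grid", "cube"]
print(n_toggle_actions(subenvs, can_alternate_subenvs=True))   # 3
print(n_toggle_actions(subenvs, can_alternate_subenvs=False))  # 2
```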
- get_action_space()[source]
Constructs the list of all possible actions, including EOS.
- The action space of a Set environment consists of:
The actions to activate specific sub-environments or unique environments.
The EOS action.
The concatenation of the actions of all unique environments.
In order to make all actions the same length (required to construct batches of actions as a tensor), the actions are zero-padded from the back.
In order to make all actions unique, the unique environment index is added as the first element of the action.
Note that the actions of unique environments are only added once to the action space, regardless of how many elements of the unique environment (sub-environments) there are in the set. In other words, identical environments that are part of the Set share the same actions, and a given action affects the sub-environment that is active.
The actions to activate a specific sub-environment are represented as: (-1, subenv index, ZERO-PADDING)
See: _pad_action(), _depad_action()
- Return type:
List[Tuple]
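The layout described above (toggle actions, then EOS, then the prefixed and zero-padded actions of each unique environment) can be sketched as follows. This is an illustration under assumptions, not the library's implementation: sub-environment actions are plain integer tuples, and the EOS encoding `(-1, -1, 0, ...)` is a placeholder chosen only for this sketch.

```python
from typing import List, Sequence, Tuple

Action = Tuple[int, ...]


def build_action_space(
    n_toggle: int, unique_env_actions: Sequence[Sequence[Action]]
) -> List[Action]:
    """Sketch of a Set action space: toggles, then EOS, then padded env actions."""
    longest = max((len(a) for acts in unique_env_actions for a in acts), default=0)
    # One slot for the prefix, and at least two slots for (-1, index).
    width = 1 + max(longest, 1)

    def pad(action: Action) -> Action:
        # Zero-pad from the back up to the common width.
        return tuple(action) + (0,) * (width - len(action))

    # Toggle actions: (-1, index, zero-padding).
    space = [pad((-1, idx)) for idx in range(n_toggle)]
    # EOS: hypothetical encoding (-1, -1, zero-padding), chosen for this sketch.
    space.append(pad((-1, -1)))
    # Each unique environment's actions, prefixed with its index, added once.
    for env_idx, actions in enumerate(unique_env_actions):
        space.extend(pad((env_idx, *action)) for action in actions)
    return space


space = build_action_space(2, [[(0,), (1,)], [(0, 1), (2, 3)]])
print(space)
# [(-1, 0, 0), (-1, 1, 0), (-1, -1, 0), (0, 0, 0), (0, 1, 0), (1, 0, 1), (1, 2, 3)]
```

Prefixing with the unique environment index is what keeps `(0, 0, 0)` and `(1, 0, 1)` distinct even when two environments use the same raw action values.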
- action_produces_permutation(action, is_backward=False)[source]
Determines whether an action produces permutations in the resulting state.
The Set introduces actions that produce permutations, in particular in the key _keys of the state. These actions are introduced if self.can_alternate_subenvs is False. Specifically, the actions that produce permutations are backward actions that toggle a sub-environment.
Note that this method does not check whether all relevant substates are identical, in which case there is effectively only one permutation. Instead, True is returned if the action _could_ produce permutations in the resulting state.
- Parameters:
action (tuple) – An action of the environment.
is_backward (bool) – Whether the transition to consider is backward (True) or forward (False).
- Returns:
bool – Whether the input action produces permutations in the resulting state, in the direction indicated by is_backward.
- Return type:
bool
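A sketch of the decision described above, assuming toggle actions start with -1 as documented and that EOS is encoded as `(-1, -1, 0, ...)`; the EOS encoding and the helper names are hypothetical, not taken from the library.

```python
from typing import Tuple

Action = Tuple[int, ...]

TOGGLE_PREFIX = -1  # first element of toggle actions, as documented above
EOS_MARKER = -1  # hypothetical: EOS encoded as (-1, -1, 0, ...) in this sketch


def is_toggle_action(action: Action) -> bool:
    return action[0] == TOGGLE_PREFIX and action[1] != EOS_MARKER


def could_produce_permutation(
    action: Action, is_backward: bool, can_alternate_subenvs: bool
) -> bool:
    # Only backward toggle actions, and only when sub-environments cannot
    # alternate, may permute the "_keys" of the state. As in the method above,
    # this does not check whether the relevant substates are identical.
    return is_backward and not can_alternate_subenvs and is_toggle_action(action)


print(could_produce_permutation((-1, 0, 0), is_backward=True, can_alternate_subenvs=False))  # True
```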
- get_mask_invalid_actions_forward(state=None, done=None)[source]
Computes the forward actions mask of the state.
The mask of the Set environment is the concatenation of the following:
- A one-hot encoding of the index of the sub-environment or unique environment (True at the index of the active environment). All False if the only valid actions are meta-actions.
- Actual (main) mask of invalid actions:
The mask of the actions to activate a sub-environment or unique environment, OR
The mask of the active sub-environment.
The mask is False-padded from the back up to mask_dim.
- Parameters:
state (Optional[Dict])
done (Optional[bool])
- Return type:
List[bool]
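The mask layout can be sketched as plain list operations. This sketch assumes that mask_dim is the total length of the mask, one-hot prefix included; the function name and arguments are hypothetical.

```python
from typing import List, Optional


def build_set_mask(
    active_index: Optional[int], n_prefix: int, main_mask: List[bool], mask_dim: int
) -> List[bool]:
    """Concatenate a one-hot prefix and a main mask, then False-pad to mask_dim."""
    # One-hot prefix: True at the active environment, all False if none is active.
    prefix = [False] * n_prefix
    if active_index is not None:
        prefix[active_index] = True
    mask = prefix + list(main_mask)
    if len(mask) > mask_dim:
        raise ValueError("mask_dim is too small for this mask")
    # False-pad from the back.
    return mask + [False] * (mask_dim - len(mask))


print(build_set_mask(1, 3, [True, False], 6))
# [False, True, False, True, False, False]
```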
- get_mask_invalid_actions_backward(state=None, done=None)[source]
Computes the backward actions mask of the state.
The mask of the Set environment is the concatenation of the following:
- A one-hot encoding of the index of the sub-environment (True at the index of the active environment). All False if no sub-environment is active.
- Actual (main) mask of invalid actions:
The mask of the actions to activate a sub-environment, OR
The mask of the active sub-environment.
The mask is False-padded from the back up to mask_dim.
- Parameters:
state (Optional[Dict])
done (Optional[bool])
- Return type:
List[bool]
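Going the other way, a consumer of this mask has to split the one-hot prefix from the main part. A minimal sketch, assuming the prefix length is known; the helper is hypothetical and the returned main mask still contains any trailing False padding.

```python
from typing import List, Optional, Tuple


def split_set_mask(mask: List[bool], n_prefix: int) -> Tuple[Optional[int], List[bool]]:
    """Recover the active index (or None) and the main mask from a Set mask."""
    prefix, main = mask[:n_prefix], mask[n_prefix:]
    n_true = sum(prefix)
    if n_true == 0:
        # No active sub-environment: the main mask covers the toggle actions.
        return None, main
    if n_true > 1:
        raise ValueError("the prefix must be one-hot or all False")
    return prefix.index(True), main


print(split_set_mask([False, True, False, True, False], 3))  # (1, [True, False])
```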
- mask_conditioning(mask, env_cond, backward)[source]
Conditions the input mask based on the restrictions imposed by a conditioning environment, env_cond.
This method is overridden because the base mask_conditioning() would change the mask without accounting for the special Set mask format. Therefore, this method calls the mask_conditioning() method of the currently relevant sub-environment and returns the mask in the correct Set format.
- Parameters:
mask (Union[List[bool], torchtyping.TensorType[mask_dim]])
backward (bool)
- step(action, skip_mask_check=False)[source]
Executes forward step given an action.
Actions may be either sub-environment actions or Set actions. In the former case, the action is performed by the corresponding sub-environment and then the Set state is updated accordingly. In the latter case, no sub-environment is involved and only the meta-data of the state (active sub-environment and toggle flag) change.
Because the same action may correspond to multiple sub-environments, the action will always be performed on the active sub-environment.
- Toggle actions:
Activate the corresponding sub-environment if no sub-environment is currently active.
If can_alternate_subenvs is True, the toggle flag is set to 1.
Reset the active sub-environment flag to -1 if a sub-environment is currently active.
The toggle flag is expected to be 0 and it remains 0.
- Environment actions:
Updates the corresponding sub-environment as well as the set state.
If can_alternate_subenvs is True, the toggle flag is set to 0.
- Parameters:
action (tuple) – Action to be executed. The input action is global, that is, padded.
skip_mask_check (bool)
- Returns:
self.state (dict) – The state after executing the action.
action (tuple) – Action executed.
valid (bool) – False, if the action is not allowed for the current state. True otherwise.
- Return type:
Tuple[Dict, Tuple, bool]
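The meta-data rules listed above for forward steps can be sketched as a small state update. This is an illustration, not the library's code: the keys "_active" (-1 when no sub-environment is active) and "_toggle" are hypothetical names, and the sub-environment's own update for environment actions is omitted.

```python
from typing import Dict


def forward_meta_update(
    state: Dict[str, int], is_toggle: bool, subenv: int, can_alternate_subenvs: bool
) -> Dict[str, int]:
    """Update hypothetical meta-data keys "_active" and "_toggle" on a forward step."""
    new = dict(state)
    if is_toggle:
        if new["_active"] == -1:
            # Activate the requested sub-environment.
            new["_active"] = subenv
            if can_alternate_subenvs:
                new["_toggle"] = 1
        else:
            # Deactivate; the toggle flag is expected to be 0 and stays 0.
            if new["_toggle"] != 0:
                raise ValueError("toggle flag expected to be 0 when deactivating")
            new["_active"] = -1
    else:
        # Environment action: applied to the active sub-environment.
        if new["_active"] == -1:
            raise ValueError("no active sub-environment")
        if can_alternate_subenvs:
            new["_toggle"] = 0
    return new


s = {"_active": -1, "_toggle": 0}
s = forward_meta_update(s, True, 2, True)   # {'_active': 2, '_toggle': 1}
s = forward_meta_update(s, False, 2, True)  # {'_active': 2, '_toggle': 0}
s = forward_meta_update(s, True, 2, True)   # {'_active': -1, '_toggle': 0}
```

With can_alternate_subenvs=True, the flag set by activation forces at least one environment action before the same sub-environment can be toggled off again.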
- step_backwards(action, skip_mask_check=False)[source]
Executes backward step given an action.
Actions may be either sub-environment actions or Set actions. In the former case, the action is performed by the corresponding sub-environment and then the Set state is updated accordingly. In the latter case, no sub-environment is involved and only the meta-data of the state (active sub-environment and toggle flag) change.
Because the same action may correspond to multiple sub-environments, the action will always be performed on the active sub-environment.
- Toggle actions:
Activate the corresponding sub-environment if no sub-environment is currently active.
Reset the active sub-environment flag to -1 if a sub-environment is currently active.
Set the toggle flag to 0.
- Environment actions:
Updates the corresponding sub-environment as well as the set state.
If can_alternate_subenvs is True, the toggle flag is set to 1.
- Parameters:
action (tuple) – Action to be executed. The input action is global, that is, padded.
skip_mask_check (bool)
- Returns:
self.state (dict) – The state after executing the action.
action (tuple) – Action executed.
valid (bool) – False, if the action is not allowed for the current state. True otherwise.
- Return type:
Tuple[Dict, Tuple, bool]
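The backward rules can be sketched in the same way, again with the hypothetical meta-data keys "_active" and "_toggle" and without the sub-environment's own backward update.

```python
from typing import Dict


def backward_meta_update(
    state: Dict[str, int], is_toggle: bool, subenv: int, can_alternate_subenvs: bool
) -> Dict[str, int]:
    """Backward meta-data update on hypothetical keys "_active" and "_toggle"."""
    new = dict(state)
    if is_toggle:
        # Activate if nothing is active, otherwise deactivate; the flag becomes 0.
        new["_active"] = subenv if new["_active"] == -1 else -1
        new["_toggle"] = 0
    else:
        # Environment action undone on the active sub-environment.
        if new["_active"] == -1:
            raise ValueError("no active sub-environment")
        if can_alternate_subenvs:
            new["_toggle"] = 1
    return new


s = {"_active": -1, "_toggle": 0}
s = backward_meta_update(s, True, 2, True)   # {'_active': 2, '_toggle': 0}
s = backward_meta_update(s, False, 2, True)  # {'_active': 2, '_toggle': 1}
s = backward_meta_update(s, True, 2, True)   # {'_active': -1, '_toggle': 0}
```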
- get_parents(state=None, done=None, action=None)[source]
Determines all parents and actions that lead to state.
- Parameters:
state (dict) – State in environment format. If None, self.state is used.
done (bool) – Whether the trajectory is done. If None, self.done is used.
action (tuple) – Ignored.
- Returns:
parents (list) – List of parents in state format
actions (list) – List of actions that lead to state for each parent in parents
- Return type:
Tuple[List, List]
- sample_actions_batch(policy_outputs, mask=None, states_from=None, is_backward=False, random_action_prob=0.0, temperature_logits=1.0)[source]
Samples a batch of actions from a batch of policy outputs.
This method calls the sample_actions_batch() method of the sub-environment corresponding to each state in the batch, or samples the actions to activate a sub-environment for the states with no active sub-environment.
Note that in order to call sample_actions_batch() of the sub-environments, we need to first extract the part of the policy outputs, the masks and the states that correspond to the sub-environment.
- Parameters:
policy_outputs (torchtyping.TensorType[n_states, policy_output_dim])
mask (Optional[torchtyping.TensorType[n_states, policy_output_dim]])
states_from (List)
is_backward (Optional[bool])
random_action_prob (Optional[float])
temperature_logits (Optional[float])
- Return type:
Tuple[List[Tuple], torchtyping.TensorType[n_states]]
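The extraction step mentioned above amounts to grouping batch rows by their active sub-environment and slicing the corresponding columns. A minimal sketch with plain lists instead of tensors; the helper names are hypothetical, -1 marks rows with no active sub-environment, and the column ranges passed to slice_columns are assumed to be known from the policy output layout.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def group_rows_by_active(active: Sequence[int]) -> Dict[int, List[int]]:
    """Map each active index (-1 for none) to the batch rows that have it."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for row, idx in enumerate(active):
        groups[idx].append(row)
    return dict(groups)


def slice_columns(
    outputs: Sequence[Sequence[float]], rows: List[int], start: int, end: int
) -> List[List[float]]:
    """Extract the columns [start, end) of the selected rows."""
    return [list(outputs[r][start:end]) for r in rows]


print(group_rows_by_active([-1, 0, 2, 0]))  # {-1: [0], 0: [1, 3], 2: [2]}
```

Each group can then be passed to the relevant sampler in a single call, rather than sampling state by state.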
- get_logprobs(policy_outputs, actions, mask, states_from, is_backward)[source]
Computes the log probabilities of actions given the policy outputs.
- Parameters:
policy_outputs (tensor) – The output of the GFlowNet policy model.
mask (tensor) – The mask containing information about invalid actions and special cases.
actions (list or tensor) – The actions (global) from each state in the batch for which to compute the log probability.
states_from (tensor) – The states originating the actions, in environment format.
is_backward (bool) – True if the actions are backward, False if the actions are forward.
- Return type:
torchtyping.TensorType[batch_size]
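For the part of the output that concerns Set actions (toggles and EOS), one common way to obtain a log probability is a softmax restricted to the valid entries of the mask. The sketch below shows that generic masked-categorical computation; it is an assumption for illustration, not a description of how this method is implemented, and sub-environment actions would be delegated to the sub-environments.

```python
import math
from typing import Sequence


def masked_log_prob(logits: Sequence[float], invalid: Sequence[bool], index: int) -> float:
    """Log-probability of action `index` under a softmax restricted to valid actions."""
    if invalid[index]:
        raise ValueError("the action is invalid under the mask")
    valid = [x for x, bad in zip(logits, invalid) if not bad]
    # Log-sum-exp with the max subtracted for numerical stability.
    m = max(valid)
    log_norm = m + math.log(sum(math.exp(x - m) for x in valid))
    return logits[index] - log_norm


print(masked_log_prob([0.0, 0.0, 5.0], [False, False, True], 0))  # -0.693..., i.e. log(0.5)
```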
- action2representative(action)[source]
Replaces the part of the action associated with a sub-environment by its representative. The part of the action that identifies the sub-environment concerned by the action remains unaffected.
- Parameters:
action (tuple) – An action of the Set environment (padded)
- Returns:
tuple – A representative of the action, re-padded as a Set action that should be in the action space.
- Return type:
Tuple
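The depad, replace, re-pad sequence can be sketched as below. Assumptions, all hypothetical: each unique environment has a fixed action length (so depadding is a slice), the per-environment representative functions are supplied by the caller, and toggle/EOS actions are returned unchanged.

```python
from typing import Callable, Dict, Tuple

Action = Tuple[int, ...]


def set_action_to_representative(
    action: Action,
    sub_action_len: Dict[int, int],
    representative: Dict[int, Callable[[Action], Action]],
) -> Action:
    """Replace the sub-environment part of a padded Set action by its representative."""
    width = len(action)
    env_idx = action[0]
    if env_idx < 0:
        # Toggle/EOS actions have no sub-environment part in this sketch.
        return action
    # Depad: keep only this unique environment's action length.
    sub_action = action[1 : 1 + sub_action_len[env_idx]]
    rep = representative[env_idx](sub_action)
    # Re-pad: the prefix identifying the environment is kept unchanged.
    padded = (env_idx, *rep)
    return padded + (0,) * (width - len(padded))


reps = {0: lambda a: (a[0], 0)}  # hypothetical: zero out the second component
print(set_action_to_representative((0, 3, 7, 0), {0: 2}, reps))  # (0, 3, 0, 0)
```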
- get_valid_actions(mask=None, state=None, done=None, backward=False)[source]
Returns the list of non-invalid (valid, for short) actions according to the mask of invalid actions.
This method is overridden because the mask of a Set of environments does not cover the entire action space, but only the relevant sub-environment or the toggle actions, depending on the state. Therefore, this method calls the get_valid_actions() method of the active sub-environment or retrieves the valid toggle actions and returns the padded actions.
- Parameters:
mask (Optional[bool])
state (Optional[Dict])
done (Optional[bool])
backward (Optional[bool])
- Return type:
List[Tuple]
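The two cases above can be sketched as follows, assuming the valid actions of the active sub-environment have already been obtained from that sub-environment (for example through its own get_valid_actions()). The function and its arguments are hypothetical; in the mask, False means valid.

```python
from typing import List, Optional, Sequence, Tuple

Action = Tuple[int, ...]


def valid_set_actions(
    unique_env_index: Optional[int],
    main_mask: Sequence[bool],
    meta_actions: Sequence[Action],
    subenv_valid_actions: Sequence[Action],
    width: int,
) -> List[Action]:
    """Valid Set actions, given the main mask and the active environment (if any)."""
    if unique_env_index is None:
        # Meta-actions: keep those whose mask entry is False (False = valid).
        # zip stops at the shorter sequence, so trailing padding is ignored.
        return [a for a, invalid in zip(meta_actions, main_mask) if not invalid]
    # Active sub-environment: prefix its valid actions and zero-pad them.
    out = []
    for a in subenv_valid_actions:
        padded = (unique_env_index, *a)
        out.append(padded + (0,) * (width - len(padded)))
    return out


meta = [(-1, 0, 0), (-1, 1, 0), (-1, -1, 0)]
print(valid_set_actions(None, [False, True, False, False], meta, [], 3))
# [(-1, 0, 0), (-1, -1, 0)]
```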
- get_policy_output(params)[source]
Defines the structure of the output of the policy model.
This method is overridden to add the policy outputs corresponding to the Set actions (actions to activate a sub-environment and EOS). The resulting policy output is the concatenation of these Set-action outputs and the policy outputs of the unique environments, obtained from the parent’s method.
- Parameters:
params (list) – A list of distribution parameters. This list has as many elements as there are unique environments, since all sub-environments of the same environment type are expected to be identical.
- Return type:
torchtyping.TensorType[policy_output_dim]
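The concatenation can be sketched with plain lists, together with a helper that recovers each unique environment's column range. The initial value of the Set-action entries (0.0 here) and the helper names are assumptions for illustration.

```python
from typing import List, Sequence, Tuple


def set_policy_output(
    n_toggle: int, unique_env_outputs: Sequence[Sequence[float]], set_value: float = 0.0
) -> List[float]:
    """Set-action outputs (toggles + EOS) followed by each unique env's outputs."""
    out = [set_value] * (n_toggle + 1)  # hypothetical initial value for Set actions
    for env_out in unique_env_outputs:
        out.extend(env_out)
    return out


def unique_env_bounds(n_toggle: int, dims: Sequence[int]) -> List[Tuple[int, int]]:
    """Column range [start, end) of each unique environment in the output."""
    bounds, start = [], n_toggle + 1
    for d in dims:
        bounds.append((start, start + d))
        start += d
    return bounds


print(set_policy_output(2, [[1.0, 1.0], [0.5]]))  # [0.0, 0.0, 0.0, 1.0, 1.0, 0.5]
print(unique_env_bounds(2, [2, 1]))              # [(3, 5), (5, 6)]
```

Because the params list has one entry per unique environment, identical sub-environments share a single slice of the output.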
- is_source(state=None)[source]
Returns True if the environment’s state or the state passed as parameter (if not None) is the source state of the environment.
This method is overridden for efficiency (for example, it returns False immediately if the meta-data part of the state is not that of the source) and to cover special uses of the Set.
- Parameters:
state (dict) – None, or a state in environment format.
- Returns:
bool – Whether the state is the source state of the environment
- Return type:
bool
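The efficiency argument is a short-circuit: compare the cheap meta-data entries first and only fall back to a full comparison if they match. A sketch under the assumption that states are dictionaries whose meta-data keys (here the hypothetical "_active" and "_toggle") are known in advance.

```python
from typing import Any, Dict, Sequence


def is_source(state: Dict[Any, Any], source: Dict[Any, Any], meta_keys: Sequence[str]) -> bool:
    """Return False early if any meta-data entry differs from the source."""
    for key in meta_keys:
        if state.get(key) != source.get(key):
            return False  # cheap check, no substate comparison needed
    # Fall back to a full comparison, substates included.
    return state == source


src = {"_active": -1, "_toggle": 0, 0: [0, 0]}
print(is_source({**src, "_active": 1}, src, ["_active", "_toggle"]))  # False
```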
- equal(state_x, state_y)[source]
Checks whether the two input states are equal.
This method is overridden in order to account for the fact that states with permuted substates must be considered equal if the permutations are indeed equivalent. The permutation of substates is not done by permuting the substates directly but by permuting the list of keys in state["_keys"]. Thus, this method returns True if all keys of the state dictionary are equal (except _keys, which is ignored) and the substates are equal, after accounting for the permutation.
This method uses the parent method in order to compare the substates. If a substate is a dictionary containing the key _keys, then it is assumed to be a Set state and the current method is used. If Set states appear deeper in the substates, the comparison may not behave as expected.
- Parameters:
state_x (dict) – One of the Set states to be compared.
state_y (dict) – The other Set state to be compared.
- Returns:
bool – True if the two input states are equal; False otherwise.
- Return type:
bool
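A sketch of a permutation-aware comparison. It assumes, for illustration only, that position i of the Set is held by the substate stored under state["_keys"][i], so that two states are equal when their substates match position by position; the "_active" key in the example is hypothetical.

```python
from typing import Any, Dict


def set_states_equal(x: Dict[Any, Any], y: Dict[Any, Any]) -> bool:
    """Compare two Set states, reading substates in the order given by "_keys"."""
    if x.keys() != y.keys() or set(x["_keys"]) != set(y["_keys"]):
        return False
    sub_keys = set(x["_keys"])
    # Meta-data entries are compared directly; "_keys" itself is ignored.
    for k in x:
        if k != "_keys" and k not in sub_keys and x[k] != y[k]:
            return False
    # Substates are compared position by position, following each state's "_keys".
    for kx, ky in zip(x["_keys"], y["_keys"]):
        sx, sy = x[kx], y[ky]
        if isinstance(sx, dict) and "_keys" in sx and isinstance(sy, dict) and "_keys" in sy:
            if not set_states_equal(sx, sy):  # nested Set state
                return False
        elif sx != sy:
            return False
    return True


x = {"_active": -1, "_keys": [0, 1], 0: [1, 2], 1: [3]}
y = {"_active": -1, "_keys": [1, 0], 0: [3], 1: [1, 2]}
print(set_states_equal(x, y))  # True: same substates, permuted via "_keys"
```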
- __eq__(other, ignored_keys=[])[source]
Checks whether the current environment instance is equal to the input environment instance.
- This method is overridden to ignore the keys:
envs_unique_cache
- Parameters:
other (GFlowNetEnv) – The environment instance to be compared.
ignored_keys (list) – A list of keys (strings) to be ignored in the comparison. This parameter may be used by subclasses that may need to ignore certain keys.
- Returns:
bool – True if the environments' attributes are considered equal; False otherwise.
- Return type:
bool
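The attribute-based comparison can be sketched with a stand-in class. EnvSketch is hypothetical and only mirrors the documented behavior: envs_unique_cache is always ignored, and callers may add further keys to ignore.

```python
from typing import Any, Iterable, Optional


class EnvSketch:
    """Stand-in environment whose equality ignores cache attributes."""

    DEFAULT_IGNORED = {"envs_unique_cache"}

    def __init__(self, **attrs: Any) -> None:
        self.__dict__.update(attrs)

    def __eq__(self, other: object, ignored_keys: Optional[Iterable[str]] = None) -> bool:
        if not isinstance(other, EnvSketch):
            return NotImplemented
        ignored = self.DEFAULT_IGNORED | set(ignored_keys or ())
        mine = {k: v for k, v in vars(self).items() if k not in ignored}
        theirs = {k: v for k, v in vars(other).items() if k not in ignored}
        return mine == theirs


a = EnvSketch(n_dim=2, envs_unique_cache={"x": 1})
b = EnvSketch(n_dim=2, envs_unique_cache={})
print(a == b)  # True: only the ignored cache differs
```

The sketch uses None instead of a `[]` default for ignored_keys, the usual Python idiom for avoiding a shared mutable default; since the list is never mutated, the behavior is the same.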