gflownet.envs.composite.stack
Base class for Stack environments.
Stack environments consist of a sequence of multiple sub-environments in a fixed order.
Module Contents
- class gflownet.envs.composite.stack.Stack(subenvs, **kwargs)
Bases: gflownet.envs.composite.base.CompositeBase
- Parameters:
subenvs (Sequence[GFlowNetEnv]) – A sequence containing the ordered list of the sub-environments to be stacked.
- get_mask_invalid_actions_forward(state=None, done=None)
Computes the forward actions mask of the state.
The mask of the Stack environment is the mask of the active sub-environment, preceded by a one-hot encoding of the index of the subenv and padded with False up to mask_dim. Including only the relevant mask saves memory and computation.
If a state is passed as an argument (not None) and the Stack has constraints, the constraints are applied before computing the mask and reset thereafter. This is necessary because otherwise the sub-environments may not have the attributes required to calculate the mask.
- Parameters:
state (Optional[Dict])
done (Optional[bool])
- Return type:
List[bool]
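The mask layout described above can be sketched with plain Python lists. This is a minimal illustration, not the library's own code: `format_stack_mask` is a hypothetical helper, and it assumes that `mask_dim` is at least the number of sub-environments plus the length of the active sub-environment's mask.

```python
from typing import List


def format_stack_mask(
    stage: int, n_subenvs: int, subenv_mask: List[bool], mask_dim: int
) -> List[bool]:
    """Hypothetical helper: one-hot stage index + sub-env mask + False padding."""
    if not 0 <= stage < n_subenvs:
        raise ValueError(f"stage {stage} out of range for {n_subenvs} sub-environments")
    # One-hot encoding of the index of the active sub-environment.
    one_hot = [i == stage for i in range(n_subenvs)]
    mask = one_hot + list(subenv_mask)
    if len(mask) > mask_dim:
        raise ValueError("mask_dim is too small for this sub-environment mask")
    # Pad with False up to the fixed global mask dimension.
    return mask + [False] * (mask_dim - len(mask))


# Stage 1 of a 3-environment Stack whose active sub-env has a 2-element mask.
print(format_stack_mask(1, 3, [True, False], 7))
# [False, True, False, True, False, False, False]
```

Because only the active sub-environment's mask is stored, the global mask keeps the same length whichever stage is active.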
- get_mask_invalid_actions_backward(state=None, done=None)
Computes the backward actions mask of the state.
The mask of the Stack environment is the mask of the relevant sub-environment, preceded by a one-hot encoding of the index of the subenv and padded with False up to mask_dim. Including only the relevant mask saves memory and computation.
The relevant sub-environment for the backward mask is the current sub-environment, except when the state of the current sub-environment is its source; in that case, the mask must be that of the preceding sub-environment, so that its EOS action can be sampled.
- There are two exceptions to the above case:
  - If done is True, in which case the current sub-environment is the last one and the EOS action must come from itself, not from the preceding subenv.
  - If the current stage is the first sub-environment, in which case there is no preceding subenv.
If a state is passed as an argument (not None) and the Stack has constraints, the constraints are applied before computing the mask and reset thereafter. This is necessary because otherwise the sub-environments may not have the attributes required to calculate the mask.
- Parameters:
state (Optional[Dict])
done (Optional[bool])
- Return type:
List[bool]
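The rule for choosing the sub-environment that provides the backward mask, including both exceptions, fits in a few lines. `backward_mask_stage` is a hypothetical helper written for illustration; it assumes stages are indexed from 0 and that the caller already knows whether the current sub-environment is at its source.

```python
def backward_mask_stage(stage: int, subenv_at_source: bool, done: bool) -> int:
    """Hypothetical helper: index of the sub-env whose backward mask is used."""
    # Exception 1: a done trajectory's EOS belongs to the current (last) sub-env.
    if done:
        return stage
    # Exception 2: the first sub-environment has no predecessor.
    if stage == 0:
        return stage
    # At the sub-env's source, the only way back is the preceding sub-env's EOS.
    if subenv_at_source:
        return stage - 1
    return stage


print(backward_mask_stage(2, subenv_at_source=True, done=False))  # 1
print(backward_mask_stage(2, subenv_at_source=True, done=True))  # 2
```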
- get_valid_actions(mask=None, state=None, done=None, backward=False)
Returns the list of non-invalid (valid, for short) actions.
This method is overridden because the mask of a Stack of environments does not cover the entire action space, but only the relevant sub-environment. Therefore, this method calls the get_valid_actions() method of the currently relevant sub-environment and returns the padded actions.
If a state is passed as an argument (not None) and the Stack has constraints, the constraints are applied before computing the mask and reset thereafter. This is necessary because otherwise the sub-environments may not have the attributes required to calculate the mask.
- Parameters:
mask (Optional[bool])
state (Optional[Dict])
done (Optional[bool])
backward (Optional[bool])
- Return type:
List[Tuple]
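Padding the sub-environment's valid actions into global actions might look like the sketch below. The global action layout used here (the stage index, followed by the sub-environment action, followed by trailing padding with `0`) is an assumption for illustration, as are the names `pad_action`, `depad_action` and `padded_valid_actions`.

```python
from typing import List, Sequence, Tuple

Action = Tuple[int, ...]


def pad_action(action: Action, stage: int, action_dim: int, pad: int = 0) -> Action:
    """Hypothetical: (stage, *sub-env action, pad, pad, ...) of length action_dim."""
    global_action = (stage,) + tuple(action)
    if len(global_action) > action_dim:
        raise ValueError("action_dim is too small for this sub-environment action")
    return global_action + (pad,) * (action_dim - len(global_action))


def depad_action(action: Action, subenv_action_dim: int) -> Action:
    """Hypothetical inverse: drops the stage index and the trailing padding."""
    return tuple(action[1 : 1 + subenv_action_dim])


def padded_valid_actions(
    subenv_actions: Sequence[Action], stage: int, action_dim: int
) -> List[Action]:
    # The sub-environment computes its own valid actions; the Stack only pads them.
    return [pad_action(a, stage, action_dim) for a in subenv_actions]


print(padded_valid_actions([(0,), (1,)], stage=2, action_dim=3))
# [(2, 0, 0), (2, 1, 0)]
```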
- mask_conditioning(mask, env_cond, backward)
Conditions the input mask based on the restrictions imposed by a conditioning environment, env_cond.
This method is overridden because the base mask_conditioning would modify the mask without accounting for the special Stack format. Therefore, this method calls the mask_conditioning() method of the currently relevant sub-environment and returns the mask with the correct Stack format.
- Parameters:
mask (Union[List[bool], torchtyping.TensorType[mask_dim]])
backward (bool)
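Conditioning only the sub-environment part of the mask while preserving the Stack format can be sketched as a strip, transform and reassemble step. `condition_stack_mask` and the `condition` callable are hypothetical; the sketch assumes the layout of a one-hot prefix, the active sub-environment's mask of known length, then padding, with `True` marking an invalid action.

```python
from typing import Callable, List


def condition_stack_mask(
    mask: List[bool],
    n_subenvs: int,
    subenv_mask_dim: int,
    condition: Callable[[List[bool]], List[bool]],
) -> List[bool]:
    """Hypothetical: apply `condition` to the sub-env slice only."""
    start, end = n_subenvs, n_subenvs + subenv_mask_dim
    sub_mask = list(condition(mask[start:end]))
    if len(sub_mask) != subenv_mask_dim:
        raise ValueError("condition must preserve the sub-environment mask length")
    # The one-hot prefix and the padding are copied unchanged.
    return mask[:start] + sub_mask + mask[end:]


def forbid_first(sub_mask: List[bool]) -> List[bool]:
    # Example conditioning: additionally mark the first sub-env action invalid.
    return [True] + sub_mask[1:]


print(condition_stack_mask([False, True, False, False, False, False], 2, 3, forbid_first))
# [False, True, True, False, False, False]
```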
- get_parents(state=None, done=None, action=None)
Determines all parents and actions that lead to the input state.
If a state is passed as an argument (not None) and the Stack has constraints, the constraints are applied before computing the mask and reset thereafter. This is necessary because otherwise the sub-environments may not have the attributes required to calculate the mask.
- Parameters:
state (list) – State in environment format. If None, self.state is used.
done (bool) – Whether the trajectory is done. If None, self.done is used.
action (tuple) – Ignored.
- Returns:
parents (list) – List of parents in state format
actions (list) – List of actions that lead to state for each parent in parents
- Return type:
Tuple[List, List]
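The apply-then-reset pattern that several of these methods describe maps naturally onto a context manager, which guarantees the reset even if the computation raises. `ToyEnv`, `apply_constraints`, `reset_constraints` and `constraints_applied` are hypothetical names used for illustration; the Stack's actual constraint hooks may differ.

```python
from contextlib import contextmanager
from typing import Iterator, Optional


class ToyEnv:
    """Hypothetical stand-in exposing constraint hooks."""

    def __init__(self, has_constraints: bool = True) -> None:
        self.has_constraints = has_constraints
        self.constrained = False

    def apply_constraints(self, state: dict) -> None:
        self.constrained = True  # a real env would update sub-env attributes here

    def reset_constraints(self) -> None:
        self.constrained = False


@contextmanager
def constraints_applied(env: ToyEnv, state: Optional[dict]) -> Iterator[None]:
    # Constraints are only touched when an explicit state is passed.
    if state is None or not env.has_constraints:
        yield
        return
    env.apply_constraints(state)
    try:
        yield
    finally:
        # Reset even if the computation inside the block raises.
        env.reset_constraints()


env = ToyEnv()
with constraints_applied(env, {"stage": 1}):
    print(env.constrained)  # True
print(env.constrained)  # False
```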
- step(action, skip_mask_check=False)
Executes forward step given an action.
The action is performed by the corresponding sub-environment and then the global state is updated accordingly. If the action is the EOS of the sub-environment, the stage is advanced and constraints are set on the subsequent sub-environment.
- Parameters:
action (tuple) – Action to be executed. The input action is global, that is, padded.
skip_mask_check (bool)
- Returns:
self.state (Dict) – The state after executing the action.
action (int) – Action executed.
valid (bool) – False if the action is not allowed in the current state; True otherwise.
- Return type:
Tuple[List, Tuple, bool]
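A toy model of the forward stage logic: each sub-environment appends integers to a list, and a hypothetical EOS action `(-1,)` either advances the stage or, on the last sub-environment, ends the trajectory. `ToyStackState`, `toy_step` and `EOS` are illustrative names; the global action is assumed to be `(stage, *subenv_action)`, and the point where the real Stack sets constraints on the next sub-environment is only marked with a comment.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

EOS = (-1,)  # hypothetical EOS action of the toy sub-environments


@dataclass
class ToyStackState:
    stage: int = 0
    substates: List[List[int]] = field(default_factory=lambda: [[], []])
    done: bool = False


def toy_step(state: ToyStackState, action: Tuple[int, ...]) -> Tuple[ToyStackState, bool]:
    stage, sub_action = action[0], tuple(action[1:])
    # Reject malformed actions or actions addressed to a non-active sub-env.
    if state.done or stage != state.stage or len(sub_action) != 1:
        return state, False
    if sub_action == EOS:
        if state.stage == len(state.substates) - 1:
            state.done = True  # EOS of the last sub-env ends the trajectory
        else:
            state.stage += 1  # the real Stack would set constraints on the next sub-env here
        return state, True
    state.substates[stage].append(sub_action[0])
    return state, True


s = ToyStackState()
for a in [(0, 5), (0, -1), (1, 3), (1, -1)]:
    s, valid = toy_step(s, a)
print(s)
# ToyStackState(stage=1, substates=[[5], [3]], done=True)
```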
- step_backwards(action, skip_mask_check=False)
Executes backward step given an action.
The action is performed by the corresponding sub-environment and then the global state is updated accordingly. If the updated state of the sub-environment becomes its source, the stage is decreased.
- Parameters:
action (tuple) – Action to be executed. The input action is global, that is, padded.
skip_mask_check (bool)
- Returns:
self.state (list) – The state after executing the action.
action (int) – Action executed.
valid (bool) – False if the action is not allowed in the current state; True otherwise.
- Return type:
Tuple[List, Tuple, bool]
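The backward counterpart can be sketched on the same kind of toy state, where the source of each sub-environment is an empty list. This is a deliberate simplification: `toy_step_backwards` is hypothetical and omits the EOS transitions between stages, keeping only the rule that the stage is decreased once the active sub-environment returns to its source.

```python
from typing import List, Tuple


def toy_step_backwards(
    stage: int, substates: List[List[int]], action: Tuple[int, int]
) -> Tuple[int, bool]:
    """Hypothetical toy: undo the last append of the active sub-environment."""
    act_stage, value = action
    sub = substates[stage]
    # Valid only if it targets the active sub-env and undoes its last step.
    if act_stage != stage or not sub or sub[-1] != value:
        return stage, False
    sub.pop()
    # The sub-env is back at its source: move to the preceding stage.
    if not sub and stage > 0:
        stage -= 1
    return stage, True


substates = [[1], [2]]
print(toy_step_backwards(1, substates, (1, 2)), substates)
# (0, True) [[1], []]
```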
- sample_actions_batch(policy_outputs, mask=None, states_from=None, is_backward=False, random_action_prob=0.0, temperature_logits=1.0)
Samples a batch of actions from a batch of policy outputs.
This method calls the sample_actions_batch() method of the sub-environment corresponding to each state in the batch.
Note that in order to call sample_actions_batch() of the sub-environments, we first need to extract the part of the policy outputs, the masks and the states that correspond to each sub-environment.
- Parameters:
policy_outputs (tensor) – The output of the GFlowNet policy model.
mask (tensor) – The mask of invalid actions, formatted as in gflownet.envs.composite.stack.Stack._format_mask().
states_from (list) – The states originating the actions, in environment format.
is_backward (bool) – True if the actions are backward, False if the actions are forward (default).
random_action_prob (float, optional) – The probability of sampling a random action. If larger than zero, the model outputs will be replaced by a random policy vector with probability random_action_prob, according to a Bernoulli distribution.
temperature_logits (float, optional) – A scalar by which the model outputs are divided to temper the sampling distribution.
- Returns:
actions (list) – The list of sampled actions.
- Return type:
Tuple[List[Tuple]]
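The dispatch to sub-environments can be sketched as grouping batch indices by stage, slicing each group's policy outputs, and writing the sampled actions back in batch order. Everything here is hypothetical: the policy output is assumed to be a concatenation of per-sub-environment segments given by `slices`, each `sampler` stands in for a sub-environment's sample_actions_batch(), masks are omitted, and actions are only prefixed with the stage index rather than fully padded.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Sequence, Tuple

Sampler = Callable[[List[List[float]]], List[Tuple[int, ...]]]


def group_by_stage(stages: Sequence[int]) -> Dict[int, List[int]]:
    """Maps each stage to the batch indices of the states in that stage."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for i, stage in enumerate(stages):
        groups[stage].append(i)
    return dict(groups)


def sample_per_stage(
    stages: Sequence[int],
    policy_outputs: Sequence[Sequence[float]],
    slices: Dict[int, Tuple[int, int]],
    samplers: Dict[int, Sampler],
) -> List[Tuple[int, ...]]:
    actions: List[Tuple[int, ...]] = [()] * len(stages)
    for stage, idx in group_by_stage(stages).items():
        start, end = slices[stage]
        # One call per sub-environment, on its slice of the policy outputs.
        sub_outputs = [list(policy_outputs[i][start:end]) for i in idx]
        for i, action in zip(idx, samplers[stage](sub_outputs)):
            actions[i] = (stage,) + tuple(action)
    return actions


def greedy(outputs: List[List[float]]) -> List[Tuple[int, ...]]:
    # Stand-in sampler: pick the highest-scoring action of each row.
    return [(max(range(len(o)), key=o.__getitem__),) for o in outputs]


outputs = [[0.9, 0.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.2, 0.5, 0.3]]
print(sample_per_stage([0, 1], outputs, {0: (0, 2), 1: (2, 5)}, {0: greedy, 1: greedy}))
# [(0, 0), (1, 1)]
```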
- get_logprobs(policy_outputs, actions, mask, states_from, is_backward)
Computes the log probabilities of actions given the policy outputs.
- Parameters:
policy_outputs (tensor) – The output of the GFlowNet policy model.
mask (tensor) – The mask containing information about invalid actions and special cases.
actions (list or tensor) – The actions (global) from each state in the batch for which to compute the log probability.
states_from (tensor) – The states originating the actions, in environment format.
is_backward (bool) – True if the actions are backward, False if the actions are forward (default).
- Returns:
tensor – The log probabilities of the transitions.
- Return type:
torchtyping.TensorType[batch_size]
- states2policy(states)
Prepares a batch of states in environment format for the policy model.
The default policy representation is a concatenation of the policy-format states of the sub-environments.
There is only one call of subenv.states2policy() for each sub-environment, on all the corresponding substates in the batch.
- Parameters:
states (list) – A batch of states in environment format.
- Returns:
A tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[batch, state_policy_dim]
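The "one call per sub-environment" strategy can be illustrated with plain lists. In this hypothetical sketch, each Stack state is a tuple of substates, and each entry of `subenv_fns` plays the role of a sub-environment's states2policy(), mapping a list of substates to one feature row per substate.

```python
from typing import Callable, List, Sequence

SubenvFn = Callable[[List[object]], List[List[float]]]


def stack_states2policy(
    batch: Sequence[Sequence[object]], subenv_fns: Sequence[SubenvFn]
) -> List[List[float]]:
    # One batched call per sub-environment, over that sub-env's substates.
    per_subenv = [fn([state[k] for state in batch]) for k, fn in enumerate(subenv_fns)]
    # Concatenate the per-sub-environment rows of each Stack state.
    return [[x for rows in per_subenv for x in rows[b]] for b in range(len(batch))]


batch = [(2, "ab"), (5, "")]
fns = [
    lambda subs: [[float(s)] for s in subs],
    lambda subs: [[float(len(s)), 1.0] for s in subs],
]
print(stack_states2policy(batch, fns))
# [[2.0, 2.0, 1.0], [5.0, 0.0, 1.0]]
```

states2proxy is documented as the same kind of concatenation, applied to the proxy-format states of the sub-environments.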
- states2proxy(states)
Prepares a batch of states in environment format for a proxy: simply a concatenation of the proxy-format states of the sub-environments.
- Parameters:
states (list) – A batch of states in environment format.
- Returns:
A list of lists, each containing the proxy representations of all the substates in the Stack, for all the Stacks in the batch.
- Return type:
List[List]
- state2readable(state=None)
Converts a state into a human-readable representation. It concatenates the readable representations of each sub-environment, separated by "; " and preceded by "Stage {stage}; ".
- Parameters:
state (Optional[Dict])
- Return type:
str
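The readable format can be reproduced with an f-string and `str.join`. `stack_state2readable` and `split_stack_readable` are hypothetical helpers; the inverse assumes that no sub-environment's readable string itself contains "; ", which would make the split ambiguous.

```python
from typing import List, Sequence, Tuple


def stack_state2readable(stage: int, readables: Sequence[str]) -> str:
    # "Stage {stage}; " followed by the sub-env representations joined by "; ".
    return f"Stage {stage}; " + "; ".join(readables)


def split_stack_readable(readable: str) -> Tuple[int, List[str]]:
    head, _, rest = readable.partition("; ")
    if not head.startswith("Stage "):
        raise ValueError(f"not a Stack readable: {readable!r}")
    return int(head[len("Stage "):]), rest.split("; ")


text = stack_state2readable(1, ["[0, 2]", "(0.5)"])
print(text)  # Stage 1; [0, 2]; (0.5)
print(split_stack_readable(text))  # (1, ['[0, 2]', '(0.5)'])
```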
- readable2state(readable)
Converts a human-readable representation of a state into the standard format.
- Parameters:
readable (str)
- Return type:
List[int]
- action2representative(action)
Replaces the part of the action associated with a sub-environment by its representative. The part of the action that identifies the sub-environment concerned by the action remains unaffected.
- Parameters:
action (Tuple)
- Return type:
int
- is_source(state=None)
Returns True if the environment’s state or the state passed as parameter (if not None) is the source state of the environment.
This method is overridden for efficiency (for example, it returns False immediately if the stage is not the first one) and to cover special uses of the Stack.
- Parameters:
state (Dict) – None, or a state in environment format.
- Returns:
bool – Whether the state is the source state of the environment
- Return type:
bool
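The early exit mentioned above amounts to checking the stage before inspecting any substate. `stack_is_source` is a hypothetical sketch assuming the Stack is at its source only if it is in the first stage and every sub-environment is at its own source; each `subenv_is_source` callable stands in for the corresponding sub-environment's check.

```python
from typing import Callable, Sequence


def stack_is_source(
    stage: int,
    substates: Sequence[object],
    subenv_is_source: Sequence[Callable[[object], bool]],
) -> bool:
    # Cheap early exit: any stage past the first cannot be the source.
    if stage != 0:
        return False
    return all(check(s) for check, s in zip(subenv_is_source, substates))


def is_empty(substate: object) -> bool:
    # Toy sub-environment whose source state is the empty list.
    return substate == []


print(stack_is_source(0, [[], []], [is_empty, is_empty]))  # True
print(stack_is_source(0, [[], [3]], [is_empty, is_empty]))  # False
print(stack_is_source(1, [[], []], [is_empty, is_empty]))  # False
```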