"""
Base class for Stack environments.
Stack environments are environments which consist of a sequence of multiple
sub-environments in a fixed order.
"""
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from torchtyping import TensorType
from gflownet.envs.base import GFlowNetEnv
from gflownet.envs.composite.base import CompositeBase
from gflownet.utils.common import copy, tfloat
class Stack(CompositeBase):
    """
    Base class to create new environments by stacking multiple sub-environments.

    This class imposes the order specified in the creation, such that the trajectory
    of the first sub-environment needs to be completed (reach done) before the actions
    of the second sub-environment become valid, and so on and so forth.

    The Stack is a sub-class of the CompositeBase, and it thus enables the
    incorporation of constraints across sub-environments via the
    :py:meth:`~gflownet.envs.composite.base.CompositeBase._apply_constraints` method.
    In order to implement the application of constraints, Stack sub-classes must
    override:

    - :py:meth:`~gflownet.envs.composite.base.CompositeBase._apply_constraints_forward`
    - :py:meth:`~gflownet.envs.composite.base.CompositeBase._apply_constraints_backward`
    """

    def __init__(
        self,
        subenvs: Sequence[GFlowNetEnv],
        **kwargs,
    ):
        """
        Parameters
        ----------
        subenvs : Sequence[GFlowNetEnv]
            A sequence containing the ordered list of the sub-environments to be
            stacked.
        **kwargs
            Keyword arguments forwarded to the base class initializer. The keys
            ``fixed_distr_params`` and ``random_distr_params`` are overwritten with
            the parameters of the unique environments.

        Raises
        ------
        ValueError
            If ``subenvs`` is empty.
        """
        # Store the sub-environments first: everything below reads self.subenvs.
        self.subenvs = tuple(subenvs)
        if not self.subenvs:
            raise ValueError("A Stack requires at least one sub-environment")
        self.n_subenvs = len(self.subenvs)
        self.max_elements = self.n_subenvs

        # Determine the unique environments
        (
            self.envs_unique,
            _,
            self.unique_indices,
        ) = self._get_unique_environments(self.subenvs)

        # States are represented as a dictionary with the following keys and values:
        # - Meta-data about the Stack
        #   - "_active": The index of the currently active sub-environment, starting
        #     from 0 at the source state, up to the total number of sub-environments.
        # - States of the sub-environments, with keys the indices of the subenvs.
        # Note: "_dones" is not included because it can be inferred from the active
        # sub-environment.
        self.source = {"_active": 0}
        self.source.update(
            {idx: subenv.source for idx, subenv in enumerate(self.subenvs)}
        )

        # TODO: review if needed
        # Get action dimensionality by computing the maximum action length among all
        # sub-environments, and adding 1 to indicate the sub-environment.
        self.action_dim = max(len(subenv.eos) for subenv in self.subenvs) + 1

        # The EOS action of the Stack is EOS action of the last sub-environment
        self.eos = self._pad_action(
            self.subenvs[-1].eos, idx_unique=self.unique_indices[-1]
        )

        # Policy distributions parameters
        kwargs["fixed_distr_params"] = [
            env.fixed_distr_params for env in self.envs_unique
        ]
        kwargs["random_distr_params"] = [
            env.random_distr_params for env in self.envs_unique
        ]

        # Base class init
        super().__init__(**kwargs)

        # TODO: this should be turned into a property
        # The stack is continuous if any subenv is continuous
        self.continuous = any(subenv.continuous for subenv in self.subenvs)

    def _get_dones(self, state: Optional[Dict] = None) -> List[int]:
        """
        Returns a list indicating which sub-environments are done (1) or not (0).

        This method is overridden because Stack states do not contain the key
        ``"_dones"``, since this information can be inferred from the active subenv:
        all sub-environments before the active one are done.

        Parameters
        ----------
        state : Dict
            A state of the Stack environment. Resolved via ``self._get_state()``.

        Returns
        -------
        list
            The list of dones as integer flags (0 or 1).
        """
        state = self._get_state(state)
        active_subenv = self._get_active_subenv(state)
        return [1] * active_subenv + [0] * (self.n_subenvs - active_subenv)

    def _get_subdone(self, idx_subenv: int, state: Optional[Dict] = None) -> bool:
        """
        Returns whether the sub-environment at ``idx_subenv`` is done.

        This method is overridden for efficiency: a sub-environment is done if and
        only if its index is smaller than the index of the active sub-environment.

        Parameters
        ----------
        idx_subenv : int
            Index of the sub-environment to query.
        state : dict
            A state of the composite environment. It is passed as is to
            ``self._get_active_subenv()``.

        Returns
        -------
        bool
            True if the sub-environment at ``idx_subenv`` is done; False otherwise.
        """
        assert idx_subenv < self.n_subenvs
        return self._get_active_subenv(state) > idx_subenv

    # TODO: remove after refactoring of CompositeBase
    def _get_unique_indices(self, state: Optional[Dict] = None) -> List[int]:
        """
        Returns the list of unique indices.

        The unique indices are the indices of the unique environments corresponding
        to each sub-environment and sub-state.

        This method is overridden because the Stack states do not contain the key
        ``"_envs_unique"``. A future refactoring should change the CompositeBase so
        that this key is not included by default in the states.

        Parameters
        ----------
        state : dict
            Ignored.

        Returns
        -------
        list
            The list of unique environment indices.
        """
        return self.unique_indices

    def _compute_mask_dim(self) -> int:
        """
        Calculates the mask dimensionality of the Stack environment.

        The mask consists of:

        - A one-hot encoding of the index of the active sub-environment.
        - The mask of the active sub-environment.

        Therefore, the dimensionality is the number of sub-environments, plus the
        maximum dimensionality of the mask of all sub-environments.

        Returns
        -------
        int
            The number of elements in the Stack masks.
        """
        return max(env.mask_dim for env in self.envs_unique) + self.n_subenvs

    def _get_max_trajectory_length(self) -> int:
        """
        Returns the maximum trajectory length of the environment, computed as the
        sum of the maximum trajectory lengths of the sub-environments.

        Returns
        -------
        int
            The maximum trajectory length.
        """
        return sum(subenv.max_traj_length for subenv in self.subenvs)

    def _resolve_relevant_subenv(
        self, state: Dict, done: bool, backward: bool
    ) -> Tuple[int, object, GFlowNetEnv, bool]:
        """
        Determines the sub-environment that is relevant for a transition from a state.

        The relevant sub-environment is the active one, except for backward
        transitions from a state that is not done, whose active sub-environment is
        not the first one and whose active substate is the source of its
        sub-environment. In that case, the relevant sub-environment is the preceding
        one, which is regarded as done, so that its EOS action can be undone.

        Parameters
        ----------
        state : Dict
            A state of the Stack environment (not None).
        done : bool
            Whether the Stack state is done.
        backward : bool
            Whether the transition is backward.

        Returns
        -------
        relevant_subenv : int
            The index of the relevant sub-environment.
        state_subenv
            The substate of the relevant sub-environment.
        subenv : GFlowNetEnv
            The unique environment corresponding to the relevant sub-environment.
        done : bool
            The done flag to be used with the relevant sub-environment.
        """
        active_subenv = self._get_active_subenv(state)
        state_subenv = self._get_substate(state, active_subenv)
        subenv = self._get_unique_env_of_subenv(active_subenv)
        if (
            backward
            and active_subenv > 0
            and not done
            and subenv.is_source(state_subenv)
        ):
            relevant_subenv = active_subenv - 1
            return (
                relevant_subenv,
                self._get_substate(state, relevant_subenv),
                self._get_unique_env_of_subenv(relevant_subenv),
                True,
            )
        return active_subenv, state_subenv, subenv, done

    def get_mask_invalid_actions_forward(
        self, state: Optional[Dict] = None, done: Optional[bool] = None
    ) -> List[bool]:
        """
        Computes the forward actions mask of the state.

        The mask of the Stack environment is the mask of the active sub-environment,
        preceded by a one-hot encoding of the index of the subenv and padded with
        False up to ``mask_dim``. Including only the relevant mask saves memory and
        computation.

        If a state is passed as an argument (not ``None``) and the Stack has
        constraints, the constraints are applied before computing the mask and reset
        thereafter. This is necessary because otherwise the sub-environments may not
        have the correct attributes necessary to calculate the mask.

        Parameters
        ----------
        state : Dict
            A state of the Stack environment. If None, ``self.state`` is used.
        done : bool
            Whether the trajectory is done. If None, ``self.done`` is used.

        Returns
        -------
        list
            The forward mask in Stack format.
        """
        do_constraints = state is not None and state is not self.state
        state = self._get_state(state)
        done = self._get_done(done)
        # Apply constraints based on the input state
        if do_constraints:
            do_constraints = self._apply_constraints(state=state)
        try:
            # Get active sub-environment, substate and unique environment
            active_subenv = self._get_active_subenv(state)
            state_subenv = self._get_substate(state, active_subenv)
            subenv = self._get_unique_env_of_subenv(active_subenv)
            # Obtain mask of substate
            mask = subenv.get_mask_invalid_actions_forward(state_subenv, done)
        finally:
            # Reset constraints for self.state, even if the mask computation failed
            if do_constraints:
                self._apply_constraints(state=self.state)
        return self._format_mask(mask, active_subenv)

    def get_mask_invalid_actions_backward(
        self, state: Optional[Dict] = None, done: Optional[bool] = None
    ) -> List[bool]:
        """
        Computes the backward actions mask of the state.

        The mask of the Stack environment is the mask of the relevant
        sub-environment, preceded by a one-hot encoding of the index of the subenv
        and padded with False up to ``mask_dim``. Including only the relevant mask
        saves memory and computation.

        The relevant sub-environment regarding the backward mask is the current
        sub-environment except if the state of the sub-environment is the subenv's
        source, in which case the mask must be the one of the preceding
        sub-environment, so as to sample its EOS action.

        There are two exceptions to the above case:

        - If ``done`` is True, in which case the current sub-environment is the last
          one and the EOS action must come from itself, not the preceding subenv.
        - If the current stage is the first sub-environment, in which case there is
          no preceding subenv.

        If a state is passed as an argument (not ``None``) and the Stack has
        constraints, the constraints are applied before computing the mask and reset
        thereafter. This is necessary because otherwise the sub-environments may not
        have the correct attributes necessary to calculate the mask.

        Parameters
        ----------
        state : Dict
            A state of the Stack environment. If None, ``self.state`` is used.
        done : bool
            Whether the trajectory is done. If None, ``self.done`` is used.

        Returns
        -------
        list
            The backward mask in Stack format.
        """
        do_constraints = state is not None and state is not self.state
        state = self._get_state(state)
        done = self._get_done(done)
        # Apply constraints based on the input state
        if do_constraints:
            do_constraints = self._apply_constraints(state=state)
        try:
            # Change the relevant sub-environment and set done to True if the
            # substate is the source of an intermediate sub-environment
            relevant_subenv, state_subenv, subenv, done = (
                self._resolve_relevant_subenv(state, done, backward=True)
            )
            # Obtain mask of substate
            mask = subenv.get_mask_invalid_actions_backward(state_subenv, done)
        finally:
            # Reset constraints for self.state, even if the mask computation failed
            if do_constraints:
                self._apply_constraints(state=self.state)
        return self._format_mask(mask, relevant_subenv)

    # TODO: rethink whether padding should be True (invalid) instead.
    def _format_mask(self, mask: List[bool], idx_subenv: int):
        """
        Formats the mask of a sub-environment into a Stack mask.

        The output format is the input mask, which corresponds to a sub-environment,
        preceded by a one-hot encoding of the index of active sub-environment and
        padded with False up to ``self.mask_dim``.

        Parameters
        ----------
        mask : List[bool] or tensor
            The mask of a sub-environment. If it is a (1D) tensor, the output is a
            tensor with the same dtype and device; otherwise the output is a list.
        idx_subenv : int
            The index of the sub-environment to be one-hot encoded.

        Returns
        -------
        list or tensor
            The mask in Stack format.
        """
        # No padding is added if the mask already fills (or exceeds) mask_dim
        n_padding = max(0, self.mask_dim - (len(mask) + self.n_subenvs))
        if torch.is_tensor(mask):
            idx_onehot = torch.zeros(
                self.n_subenvs, dtype=mask.dtype, device=mask.device
            )
            idx_onehot[idx_subenv] = True
            padding = torch.zeros(n_padding, dtype=mask.dtype, device=mask.device)
            return torch.cat([idx_onehot, mask, padding])
        idx_onehot = [False] * self.n_subenvs
        idx_onehot[idx_subenv] = True
        padding = [False] * n_padding
        return idx_onehot + mask + padding

    def _unformat_mask(
        self,
        mask: Union[List[bool], TensorType["batch_size", "mask_dim"]],
        idx_subenv: Optional[int] = None,
        mask_dim: Optional[int] = None,
    ):
        """
        Extracts the mask of the sub-environment from a Stack-formatted mask or batch
        of masks.

        This method removes the one-hot encoding of the index of the active
        sub-environment that precedes the subenv mask, as well as the padding.

        Parameters
        ----------
        mask : List[bool] or tensor
            A Stack mask (list or tensor) or a batch of masks (tensor). For tensors,
            the slicing is applied along the last dimension.
        idx_subenv : int
            The index of the sub-environment whose mask is to be extracted. If None,
            ``mask_dim`` must be passed.
        mask_dim : int
            The dimensionality of the mask to be extracted. If None, ``idx_subenv``
            must be passed. Ignored if ``idx_subenv`` is not None.

        Returns
        -------
        list or tensor
            The mask of the sub-environment, of the same type as the input.

        Raises
        ------
        ValueError
            If ``idx_subenv`` and ``mask_dim`` are both None.
        ValueError
            If the mask is neither a list nor a tensor.
        """
        if idx_subenv is not None:
            mask_dim = self.subenvs[idx_subenv].mask_dim
        elif mask_dim is None:
            raise ValueError("idx_subenv and mask_dim cannot be both None")
        start = self.n_subenvs
        end = self.n_subenvs + mask_dim
        if isinstance(mask, list):
            return mask[start:end]
        if torch.is_tensor(mask):
            # Ellipsis indexing supports both single masks and batches of masks
            return mask[..., start:end]
        raise ValueError("The input mask can only be a list or a tensor")

    def get_valid_actions(
        self,
        mask: Optional[List[bool]] = None,
        state: Optional[Dict] = None,
        done: Optional[bool] = None,
        backward: Optional[bool] = False,
    ) -> List[Tuple]:
        """
        Returns the list of non-invalid (valid, for short) actions.

        This method is overridden because the mask of a Stack of environments does
        not cover the entire action space, but only the relevant sub-environment.
        Therefore, this method calls the ``get_valid_actions()`` method of the
        currently relevant sub-environment and returns the padded actions.

        If a state is passed as an argument (not ``None``) and the Stack has
        constraints, the constraints are applied before computing the actions and
        reset thereafter. This is necessary because otherwise the sub-environments
        may not have the correct attributes necessary to calculate the mask.

        Parameters
        ----------
        mask : list
            A mask in Stack format. If None, it is passed as None to the
            sub-environment.
        state : Dict
            A state of the Stack environment. If None, ``self.state`` is used.
        done : bool
            Whether the trajectory is done. If None, ``self.done`` is used.
        backward : bool
            Whether the actions are backward (True) or forward (False).

        Returns
        -------
        list
            The list of valid actions, padded to the Stack format.
        """
        do_constraints = state is not None and state is not self.state
        state = self._get_state(state)
        done = self._get_done(done)
        # Apply constraints based on the input state
        if do_constraints:
            do_constraints = self._apply_constraints(state=state)
        try:
            # Change the relevant sub-environment and set done to True if the
            # substate is the source of an intermediate sub-environment
            relevant_subenv, state_subenv, subenv, done = (
                self._resolve_relevant_subenv(state, done, backward=backward)
            )
            if mask is not None:
                # Extract the part of the mask corresponding to the sub-environment
                mask = self._unformat_mask(mask, relevant_subenv)
            # Obtain valid actions
            idx_unique = self._get_unique_idx_of_subenv(relevant_subenv)
            valid_actions = [
                self._pad_action(action, idx_unique)
                for action in subenv.get_valid_actions(
                    mask, state_subenv, done, backward
                )
            ]
        finally:
            # Reset constraints for self.state, even if the computation failed
            if do_constraints:
                self._apply_constraints(state=self.state)
        return valid_actions

    # TODO: review
    def mask_conditioning(
        self, mask: Union[List[bool], TensorType["mask_dim"]], env_cond, backward: bool
    ):
        """
        Conditions the input mask based on the restrictions imposed by a conditioning
        environment, env_cond.

        This method is overridden because the base mask_conditioning would change the
        mask unaware of the special Stack format. Therefore, this method calls the
        mask_conditioning() method of the currently active sub-environment and
        returns the mask with the correct Stack format.

        Parameters
        ----------
        mask : list or tensor
            A mask in Stack format.
        env_cond
            The conditioning environment. It must have an attribute ``subenvs``
            that can be indexed by the index of the active sub-environment.
        backward : bool
            Whether the mask is a backward mask. It is passed to the
            sub-environment.

        Returns
        -------
        list or tensor
            The conditioned mask in Stack format.
        """
        # NOTE(review): the active sub-environment of self.state is used; for
        # backward masks at inter-subenv transitions the relevant sub-environment
        # may be the preceding one - confirm the intended behavior.
        stage = self._get_active_subenv(self.state)
        subenv = self.subenvs[stage]
        # Extract the part of the mask corresponding to the sub-environment
        mask = self._unformat_mask(mask, stage)
        env_cond = env_cond.subenvs[stage]
        mask = subenv.mask_conditioning(mask, env_cond, backward)
        return self._format_mask(mask, stage)

    def get_parents(
        self,
        state: Optional[Dict] = None,
        done: Optional[bool] = None,
        action: Optional[Tuple] = None,
    ) -> Tuple[List, List]:
        """
        Determines all parents and actions that lead to the input state.

        If a state is passed as an argument (not ``None``) and the Stack has
        constraints, the constraints are applied before computing the parents and
        reset thereafter. This is necessary because otherwise the sub-environments
        may not have the correct attributes necessary to calculate the parents.

        Parameters
        ----------
        state : Dict
            State in environment format. If None, self.state is used.
        done : bool
            Whether the trajectory is done. If None, self.done is used.
        action : tuple
            Ignored.

        Returns
        -------
        parents : list
            List of parents in state format
        actions : list
            List of actions that lead to state for each parent in parents
        """
        do_constraints = state is not None and state is not self.state
        state = self._get_state(state)
        done = self._get_done(done)
        # If done is True, the only parent is the state itself with action EOS.
        if done:
            return [state], [self.eos]
        # Apply constraints based on the input state
        if do_constraints:
            do_constraints = self._apply_constraints(state=state)
        try:
            # Change the relevant sub-environment and set done to True if the
            # substate is the source of an intermediate sub-environment
            relevant_subenv, state_subenv, subenv, done = (
                self._resolve_relevant_subenv(state, done, backward=True)
            )
            # Get parents of the relevant sub-environment
            parents_subenv, parent_actions = subenv.get_parents(state_subenv, done)
            # Convert subenv parents to Stack states
            parents = []
            for parent_subenv in parents_subenv:
                parent = copy(state)
                parent = self._set_active_subenv(relevant_subenv, parent)
                parent = self._set_substate(relevant_subenv, parent_subenv, parent)
                parents.append(parent)
            # Pad actions
            idx_unique = self._get_unique_idx_of_subenv(relevant_subenv)
            parent_actions = [
                self._pad_action(action, idx_unique) for action in parent_actions
            ]
        finally:
            # Reset constraints for self.state, even if the computation failed
            if do_constraints:
                self._apply_constraints(state=self.state)
        return parents, parent_actions

    def step(
        self, action: Tuple, skip_mask_check: bool = False
    ) -> Tuple[Dict, Tuple, bool]:
        """
        Executes forward step given an action.

        The action is performed by the corresponding sub-environment and then the
        global state is updated accordingly. If the action is the EOS of the
        sub-environment, the stage is advanced and constraints are set on the
        subsequent sub-environment.

        Parameters
        ----------
        action : tuple
            Action to be executed. The input action is global, that is padded.
        skip_mask_check : bool
            If True, the mask check of the pre-step of the sub-environment is
            skipped. The check is also skipped if the active sub-environment is
            continuous or if ``self.skip_mask_check`` is True.

        Returns
        -------
        self.state : Dict
            The state after executing the action.
        action : tuple
            Action executed.
        valid : bool
            False, if the action is not allowed for the current state. True
            otherwise.
        """
        # If done, exit immediately
        if self.done:
            return self.state, action, False

        # Get active sub-environment, unique index, subenv and action of subenv
        active_subenv = self._get_active_subenv(self.state)
        idx_unique = self._get_unique_idx_of_subenv(active_subenv)
        subenv = self.subenvs[active_subenv]
        action_subenv = self._depad_action(action, idx_unique)

        # Perform pre-step from subenv - if it was done from the Stack env there
        # could be a mismatch between mask and action space due to continuous
        # subenvs.
        action_to_check = subenv.action2representative(action_subenv)
        # Skip mask check if stage is continuous
        if subenv.continuous:
            skip_mask_check = True
        do_step, _, _ = subenv._pre_step(
            action_to_check,
            skip_mask_check=(skip_mask_check or self.skip_mask_check),
        )
        if not do_step:
            return self.state, action, False

        # Call step of active sub-environment
        _, action_subenv, valid = subenv.step(action_subenv)

        # If action is invalid, exit immediately.
        if not valid:
            return self.state, action, False
        # Otherwise, increment number of actions and go on
        self.n_actions += 1

        # Check if action is EOS of subenv
        if action_subenv == subenv.eos:
            # If it is global EOS set done to True
            if active_subenv == (self.n_subenvs - 1) and action == self.eos:
                self.done = True
            else:
                # Increment active subenv and apply constraints
                self._set_active_subenv(active_subenv + 1)
                self._apply_constraints(action=action, is_backward=False)
        else:
            # Update substate
            self._set_substate(active_subenv, subenv.state)

        return self.state, action, valid

    def step_backwards(
        self, action: Tuple, skip_mask_check: bool = False
    ) -> Tuple[Dict, Tuple, bool]:
        """
        Executes backward step given an action.

        The action is performed by the corresponding sub-environment and then the
        global state is updated accordingly. If the action is the EOS of a
        sub-environment and the Stack is not done, the action undoes the transition
        between sub-environments and the preceding sub-environment becomes active.

        Parameters
        ----------
        action : tuple
            Action to be executed. The input action is global, that is padded.
        skip_mask_check : bool
            If True, the mask check of the pre-step of the sub-environment is
            skipped. The check is also skipped if the relevant sub-environment is
            continuous or if ``self.skip_mask_check`` is True.

        Returns
        -------
        self.state : Dict
            The state after executing the action.
        action : tuple
            Action executed.
        valid : bool
            False, if the action is not allowed for the current state. True
            otherwise.
        """
        # Get active sub-environment, unique index, and action of subenv
        # The unique index is taken from the action, not the state, since they may
        # not coincide in inter-subenv transition actions
        active_subenv = self._get_active_subenv(self.state)
        idx_unique = action[0]
        action_subenv = self._depad_action(action, idx_unique)

        # Determine the relevant sub-environment: if the action is EOS, then it is a
        # transition between sub-environments and the relevant sub-environment is
        # the previous to the active subenv in the state. Exception: if the env is
        # done
        env = self._get_env_unique(idx_unique)
        if action_subenv == env.eos and not self.done:
            # There is no sub-environment preceding the first one, thus an EOS
            # action cannot be undone from it. Without this guard, the index -1
            # would silently select the last sub-environment.
            if active_subenv == 0:
                return self.state, action, False
            relevant_subenv = active_subenv - 1
        else:
            relevant_subenv = active_subenv
        subenv = self.subenvs[relevant_subenv]

        # Perform pre-step from subenv - if it was done from the "superenv" there
        # could be a mismatch between mask and action space due to continuous
        # subenvs.
        action_to_check = subenv.action2representative(action_subenv)
        # Skip mask check if stage is continuous
        if subenv.continuous:
            skip_mask_check = True
        do_step, _, _ = subenv._pre_step(
            action_to_check,
            backward=True,
            skip_mask_check=(skip_mask_check or self.skip_mask_check),
        )
        if not do_step:
            return self.state, action, False

        # Call step of current subenvironment
        _, _, valid = subenv.step_backwards(action_subenv)

        # If action is invalid, exit immediately
        if not valid:
            return self.state, action, False
        # Otherwise, increment number of actions and go on
        self.n_actions += 1

        # If action was from done, set done False
        if self.done:
            assert action == self.eos
            self.done = False

        # Update substate
        self._set_substate(relevant_subenv, subenv.state)
        self._set_active_subenv(relevant_subenv)

        # If action is EOS of subenv, apply backward constraints
        if action_subenv == subenv.eos:
            self._apply_constraints(action=action, is_backward=True)

        return self.state, action, valid

    def _split_batch_by_env_unique(
        self,
        mask: TensorType["n_states", "mask_dim"],
        states_from: List[Dict],
    ) -> Tuple[TensorType["n_states"], List[int], Dict[int, List]]:
        """
        Groups a batch of states by the unique environment that is relevant for each
        of them, according to the one-hot prefix of the masks.

        Parameters
        ----------
        mask : tensor
            A batch of masks in Stack format, with exactly one True value in the
            one-hot prefix (first ``self.n_subenvs`` columns) of each row.
        states_from : list
            A batch of states in environment format, aligned with ``mask``.

        Returns
        -------
        indices_unique : tensor
            The index of the unique environment relevant for each state.
        indices_unique_int : list
            The same indices, as a list of integers.
        states_dict : dict
            A dictionary with the unique indices as keys and the lists of the
            corresponding substates as values, in batch order.
        """
        # Get the indices of the relevant sub-environments from the one-hot prefix
        # of the mask
        indices_relevant = torch.where(mask[:, : self.n_subenvs])[1]
        indices_relevant_int = indices_relevant.tolist()
        # Map the index of each relevant subenv onto the index of its unique
        # environment
        indices_unique = torch.empty_like(indices_relevant)
        for idx_subenv in set(indices_relevant_int):
            idx_unique = self._get_unique_idx_of_subenv(idx_subenv)
            indices_unique[indices_relevant == idx_subenv] = idx_unique
        indices_unique_int = indices_unique.tolist()
        # Collect the substates corresponding to each unique environment
        states_dict = {idx: [] for idx in self.unique_indices}
        for state, idx_subenv, idx_unique in zip(
            states_from, indices_relevant_int, indices_unique_int
        ):
            states_dict[idx_unique].append(self._get_substate(state, idx_subenv))
        return indices_unique, indices_unique_int, states_dict

    # TODO: review if random action probability works
    def sample_actions_batch(
        self,
        policy_outputs: TensorType["n_states", "policy_output_dim"],
        mask: Optional[TensorType["n_states", "policy_output_dim"]] = None,
        states_from: List[Dict] = None,
        is_backward: Optional[bool] = False,
        random_action_prob: Optional[float] = 0.0,
        temperature_logits: Optional[float] = 1.0,
    ) -> List[Tuple]:
        """
        Samples a batch of actions from a batch of policy outputs.

        This method calls the ``sample_actions_batch()`` method of the unique
        environment corresponding to each state in the batch.

        Note that in order to call sample_actions_batch() of the sub-environments,
        we need to first extract the part of the policy outputs, the masks and the
        states that correspond to the sub-environment.

        Parameters
        ----------
        policy_outputs : tensor
            The output of the GFlowNet policy model.
        mask : tensor
            The mask of invalid actions, formatted as in
            :py:meth:`gflownet.envs.composite.stack.Stack._format_mask`. Despite the
            default value, it is required because the relevant sub-environments are
            read from it.
        states_from : list
            The states originating the actions, in environment format. Despite the
            default value, it is required.
        is_backward : bool
            True if the actions are backward, False if the actions are forward
        random_action_prob : float, optional
            The probability of sampling a random action. It is passed unchanged to
            the sub-environments.
        temperature_logits : float, optional
            A scalar by which the model outputs are divided to temper the sampling
            distribution. It is passed unchanged to the sub-environments.

        Returns
        -------
        actions : list
            The list of sampled actions, padded and in the order of the batch.
        """
        indices_unique, indices_unique_int, states_dict = (
            self._split_batch_by_env_unique(mask, states_from)
        )

        # Sample actions from each unique environment. Iterators are kept in order
        # to consume the actions in batch order without repeated list.pop(0)
        actions_iters = {}
        for idx in set(indices_unique_int):
            env = self._get_env_unique(idx)
            env_mask = indices_unique == idx
            actions_iters[idx] = iter(
                env.sample_actions_batch(
                    self._get_policy_outputs_of_env_unique(
                        policy_outputs[env_mask], idx
                    ),
                    self._unformat_mask(mask[env_mask, :], mask_dim=env.mask_dim),
                    states_dict[idx],
                    is_backward,
                    random_action_prob,
                    temperature_logits,
                )
            )

        # Stitch all actions in the right order, with the right padding
        return [
            self._pad_action(next(actions_iters[idx]), idx)
            for idx in indices_unique_int
        ]

    def get_logprobs(
        self,
        policy_outputs: TensorType["n_states", "policy_output_dim"],
        actions: Union[List, TensorType["n_states", "action_dim"]],
        mask: TensorType["n_states", "mask_dim"],
        states_from: List[Dict],
        is_backward: bool,
    ) -> TensorType["batch_size"]:
        """
        Computes log probabilities of actions given policy outputs and actions.

        The log probabilities are computed by the ``get_logprobs()`` method of the
        unique environment corresponding to each state in the batch.

        Parameters
        ----------
        policy_outputs : tensor
            The output of the GFlowNet policy model.
        actions : list or tensor
            The actions (global) from each state in the batch for which to compute
            the log probability.
        mask : tensor
            The mask containing information about invalid actions and special cases,
            in Stack format.
        states_from : list
            The states originating the actions, in environment format.
        is_backward : bool
            True if the actions are backward, False if the actions are forward.

        Returns
        -------
        tensor
            The log probabilities of the transitions.
        """
        actions = tfloat(actions, float_type=self.float, device=self.device)
        n_states = policy_outputs.shape[0]

        indices_unique, indices_unique_int, states_dict = (
            self._split_batch_by_env_unique(mask, states_from)
        )

        # Compute logprobs from each unique environment
        logprobs = torch.empty(n_states, dtype=self.float, device=self.device)
        for idx in set(indices_unique_int):
            env = self._get_env_unique(idx)
            env_mask = indices_unique == idx
            logprobs[env_mask] = env.get_logprobs(
                self._get_policy_outputs_of_env_unique(policy_outputs[env_mask], idx),
                self._depad_action_batch(actions[env_mask, :], idx),
                self._unformat_mask(mask[env_mask, :], mask_dim=env.mask_dim),
                states_dict[idx],
                is_backward,
            )
        return logprobs

    def states2policy(
        self, states: List[Dict]
    ) -> TensorType["batch", "state_policy_dim"]:
        """
        Prepares a batch of states in environment format for the policy model.

        The default policy representation is a concatenation of the policy-format
        states of the sub-environments.

        There is only one call of ``subenv.states2policy()`` for each
        sub-environment, on all the corresponding substates in the batch.

        Parameters
        ----------
        states : list
            A batch of states in environment format.

        Returns
        -------
        tensor
            A tensor containing all the states in the batch.
        """
        return torch.cat(
            [
                subenv.states2policy([state[idx] for state in states])
                for idx, subenv in enumerate(self.subenvs)
            ],
            dim=1,
        )

    def states2proxy(self, states: List[Dict]) -> List[List]:
        """
        Prepares a batch of states in environment format for a proxy: simply a
        concatenation of the proxy-format states of the sub-environments.

        Parameters
        ----------
        states : list
            A batch of states in environment format.

        Returns
        -------
        list
            A list of lists, each containing the proxy representation of all the
            substates in the Stack, for all the Stacks in the batch.
        """
        return [
            [
                subenv.state2proxy(self._get_substate(state, idx))[0]
                for idx, subenv in enumerate(self.subenvs)
            ]
            for state in states
        ]

    def state2readable(self, state: Optional[Dict] = None) -> str:
        """
        Converts a state into human-readable representation.

        It concatenates the readable representations of each sub-environment,
        separated by "; " and preceded by "Active: {active_subenv}; ".

        Parameters
        ----------
        state : Dict
            A state of the Stack environment. If None, ``self.state`` is used.

        Returns
        -------
        str
            The human-readable representation of the state.
        """
        state = self._get_state(state)
        readables = [f"Active: {self._get_active_subenv(state)}"]
        readables.extend(
            subenv.state2readable(self._get_substate(state, idx))
            for idx, subenv in enumerate(self.subenvs)
        )
        return "; ".join(readables)

    def readable2state(self, readable: str) -> Dict:
        """
        Converts a human-readable representation of a state into the standard format.

        The input is expected in the format produced by ``state2readable()``. Note
        that the readable representations of the sub-environments must not contain
        the separator "; ".

        Parameters
        ----------
        readable : str
            The human-readable representation of a state.

        Returns
        -------
        Dict
            The state in environment format.
        """
        readables = readable.split("; ")
        # Parse everything after the colon of "Active: N", so that indices with more
        # than one digit are supported
        active_subenv = int(readables[0].split(":")[-1])
        readables = readables[1:]
        state = {"_active": active_subenv}
        state.update(
            {
                idx: subenv.readable2state(readables[idx])
                for idx, subenv in enumerate(self.subenvs)
            }
        )
        return state

    # TODO: review if this could be moved to CompositeBase
    def action2representative(self, action: Tuple) -> Tuple:
        """
        Replaces the part of the action associated with a sub-environment by its
        representative. The part of the action that identifies the sub-environment
        concerned by the action remains unaffected.

        Parameters
        ----------
        action : tuple
            An action in Stack format (padded), whose first element is the index of
            the unique environment.

        Returns
        -------
        tuple
            The representative of the action, in Stack format (padded).
        """
        idx_unique = action[0]
        action_subenv = self._depad_action(action, idx_unique)
        env = self._get_env_unique(idx_unique)
        action_subenv_representative = env.action2representative(action_subenv)
        return self._pad_action(action_subenv_representative, idx_unique)

    def is_source(self, state: Optional[Dict] = None) -> bool:
        """
        Returns True if the environment's state or the state passed as parameter (if
        not None) is the source state of the environment.

        This method is overridden for efficiency (for example, it would return False
        immediately if the stage is not the first stage) and to cover special uses
        of the Stack.

        Parameters
        ----------
        state : Dict
            None, or a state in environment format.

        Returns
        -------
        bool
            Whether the state is the source state of the environment
        """
        state = self._get_state(state)
        if self._get_active_subenv(state) != 0:
            return False
        # The state is the source only if all substates are the source of their
        # sub-environment
        return all(
            subenv.is_source(self._get_substate(state, idx))
            for idx, subenv in enumerate(self.subenvs)
        )