# Source code for gflownet.envs.composite.stack

"""
Base class for Stack environments.

Stack environments are environments which consist of a sequence of multiple
sub-environments in a fixed order.
"""

from collections import OrderedDict
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
from torchtyping import TensorType

from gflownet.envs.base import GFlowNetEnv
from gflownet.envs.composite.base import CompositeBase
from gflownet.utils.common import copy, tfloat


class Stack(CompositeBase):
    """
    Base class to create new environments by stacking multiple sub-environments.

    This class imposes the order specified in the creation, such that the trajectory
    of the first sub-environment needs to be completed (reach done) before the
    actions of the second sub-environment become valid, and so on and so forth.

    The Stack is a sub-class of the CompositeBase, and it thus enables the
    incorporation of constraints across sub-environments via the
    :py:meth:`~gflownet.envs.composite.base.CompositeBase._apply_constraints`
    method. In order to implement the application of constraints, Stack
    sub-classes must override:

    - :py:meth:`~gflownet.envs.composite.base.CompositeBase._apply_constraints_forward`
    - :py:meth:`~gflownet.envs.composite.base.CompositeBase._apply_constraints_backward`
    """

    def __init__(
        self,
        subenvs: Sequence[GFlowNetEnv],
        **kwargs,
    ):
        """
        Parameters
        ----------
        subenvs : Sequence[GFlowNetEnv]
            A sequence containing the ordered list of the sub-environments to be
            stacked. It must contain at least one sub-environment.
        **kwargs
            Keyword arguments forwarded to the base class constructor. The keys
            ``"fixed_distr_params"`` and ``"random_distr_params"`` are overwritten
            with the lists of the corresponding parameters of the unique
            environments.

        Raises
        ------
        ValueError
            If ``subenvs`` is empty.
        """
        if not subenvs:
            raise ValueError(
                "A Stack environment requires at least one sub-environment"
            )
        self.subenvs = subenvs
        self.n_subenvs = len(self.subenvs)
        self.max_elements = self.n_subenvs
        # Determine the unique environments
        (
            self.envs_unique,
            _,
            self.unique_indices,
        ) = self._get_unique_environments(self.subenvs)
        # States are represented as a dictionary with the following keys and values:
        # - Meta-data about the Stack
        #   - "_active": The index of the currently active sub-environment, starting
        #   from 0 at the source state, up to the total number of sub-environments.
        # - States of the sub-environments, with keys the indices of the subenvs.
        # Note: "_dones" is not included because it can be inferred from the active
        # sub-environment.
        self.source = {"_active": 0}
        self.source.update(
            {idx: subenv.source for idx, subenv in enumerate(self.subenvs)}
        )
        # TODO: review if needed
        # Get action dimensionality by computing the maximum action length among all
        # sub-environments, and adding 1 to indicate the sub-environment.
        self.action_dim = max([len(subenv.eos) for subenv in self.subenvs]) + 1
        # The EOS action of the Stack is EOS action of the last sub-environment
        self.eos = self._pad_action(
            self.subenvs[-1].eos, idx_unique=self.unique_indices[-1]
        )
        # Policy distributions parameters
        kwargs["fixed_distr_params"] = [
            env.fixed_distr_params for env in self.envs_unique
        ]
        kwargs["random_distr_params"] = [
            env.random_distr_params for env in self.envs_unique
        ]
        # Base class init
        super().__init__(**kwargs)
        # TODO: this should be turned into a property
        # The stack is continuous if any subenv is continuous
        self.continuous = any([subenv.continuous for subenv in self.subenvs])

    def _get_dones(self, state: Optional[Dict] = None) -> List[int]:
        """
        Returns a list indicating which sub-environments are done (1) or not done (0).

        This method is overriden because Stack states do not contain the key
        ``"_dones"``, since this information can be inferred from the active subenv:
        every sub-environment preceding the active one is done.

        Parameters
        ----------
        state : Dict
            A state of the Stack environment. If None, ``self.state`` is used.

        Returns
        -------
        list
            The list of dones as integer flags (0 or 1).
        """
        state = self._get_state(state)
        active_subenv = self._get_active_subenv(state)
        return [1] * active_subenv + [0] * (self.n_subenvs - active_subenv)

    def _get_subdone(self, idx_subenv: int, state: Optional[Dict] = None) -> bool:
        """
        Returns whether the sub-environment at ``idx_subenv`` is done.

        This method is overriden for efficiency: a sub-environment is done if and
        only if its index is smaller than the index of the active sub-environment.

        Parameters
        ----------
        idx_subenv : int
            Index of the sub-environment to query.
        state : dict
            A state of the composite environment. If None, ``self.state`` is used.

        Returns
        -------
        bool
            True if the sub-environment at ``idx_subenv`` is done; False otherwise.
        """
        assert idx_subenv < self.n_subenvs
        # Resolve the state first, consistently with _get_dones()
        state = self._get_state(state)
        return self._get_active_subenv(state) > idx_subenv

    # TODO: remove after refactoring of CompositeBase
    def _get_unique_indices(self, state: Optional[Dict] = None) -> List[int]:
        """
        Returns the list of unique indices.

        The unique indices are the indices of the unique environments corresponding
        to each sub-environment and sub-state.

        This method is overriden because the Stack states do not contain the key
        ``"_envs_unique"``. A future refactoring should change the CompositeBase so
        that this key is not included by default in the states.

        Parameters
        ----------
        state : dict
            Ignored.

        Returns
        -------
        list
            The list of unique environment indices.
        """
        return self.unique_indices

    def _compute_mask_dim(self):
        """
        Calculates the mask dimensionality of the Stack environment.

        The mask consists of:
            - A one-hot encoding of the index of the active sub-environment.
            - The mask of the active sub-environment.

        Therefore, the dimensionality is the number of sub-environments, plus the
        maximum dimensionality of the mask of all sub-environments.

        Returns
        -------
        int
            The number of elements in the Stack masks.
        """
        mask_dim_envs_unique = [env.mask_dim for env in self.envs_unique]
        return max(mask_dim_envs_unique) + self.n_subenvs

    def _get_max_trajectory_length(self) -> int:
        """
        Returns the maximum trajectory length of the environment, including the EOS
        action.

        Returns
        -------
        int
            The maximum trajectory length, computed as the sum of the maximum
            trajectory lengths of the sub-environments.
        """
        return sum([subenv.max_traj_length for subenv in self.subenvs])

    def get_mask_invalid_actions_forward(
        self, state: Optional[Dict] = None, done: Optional[bool] = None
    ) -> List[bool]:
        """
        Computes the forward actions mask of the state.

        The mask of the Stack environment is the mask of the active sub-environment,
        preceded by a one-hot encoding of the index of the subenv and padded with
        False up to mask_dim. Including only the relevant mask saves memory and
        computation.

        If a state is passed as an argument (not ``None``) and the Stack has
        constraints, the constraints are applied before computing the mask and reset
        thereafter. This is necessary because otherwise the sub-environments may not
        have the correct attributes necessary to calculate the mask.

        Parameters
        ----------
        state : dict
            A state of the Stack. If None, ``self.state`` is used.
        done : bool
            Whether the trajectory is done. If None, ``self.done`` is used.

        Returns
        -------
        list
            The Stack-formatted forward mask (see :py:meth:`_format_mask`).
        """
        do_constraints = state is not None and id(state) != id(self.state)
        state = self._get_state(state)
        done = self._get_done(done)

        # Apply constraints based on the input state
        if do_constraints:
            do_constraints = self._apply_constraints(state=state)

        # Get active sub-environment, substate and unique environment
        active_subenv = self._get_active_subenv(state)
        state_subenv = self._get_substate(state, active_subenv)
        subenv = self._get_unique_env_of_subenv(active_subenv)

        # Obtain mask of substate
        mask = subenv.get_mask_invalid_actions_forward(state_subenv, done)

        # Reset constraints for self.state
        if do_constraints:
            self._apply_constraints(state=self.state)

        return self._format_mask(mask, active_subenv)

    def get_mask_invalid_actions_backward(
        self, state: Optional[Dict] = None, done: Optional[bool] = None
    ) -> List[bool]:
        """
        Computes the backward actions mask of the state.

        The mask of the Stack environment is the mask of the relevant
        sub-environment, preceded by a one-hot encoding of the index of the subenv
        and padded with False up to ``mask_dim``. Including only the relevant mask
        saves memory and computation.

        The relevant sub-environment regarding the backward mask is the current
        sub-environment except if the state of the sub-environment is the subenv's
        source, in which case the mask must be the one of the preceding
        sub-environment, so as to sample its EOS action. There are two exceptions
        to the above case:
            - If ``done`` is True, in which case the current sub-environment is the
              last one and the EOS action must come from itself, not the preceding
              subenv.
            - If the current stage is the first sub-environment, in which case there
              is no preceding subenv.

        If a state is passed as an argument (not ``None``) and the Stack has
        constraints, the constraints are applied before computing the mask and reset
        thereafter. This is necessary because otherwise the sub-environments may not
        have the correct attributes necessary to calculate the mask.

        Parameters
        ----------
        state : dict
            A state of the Stack. If None, ``self.state`` is used.
        done : bool
            Whether the trajectory is done. If None, ``self.done`` is used.

        Returns
        -------
        list
            The Stack-formatted backward mask (see :py:meth:`_format_mask`).
        """
        do_constraints = state is not None and id(state) != id(self.state)
        state = self._get_state(state)
        done = self._get_done(done)

        # Apply constraints based on the input state
        if do_constraints:
            do_constraints = self._apply_constraints(state=state)

        # Get active sub-environment, substate and unique environment
        active_subenv = self._get_active_subenv(state)
        state_subenv = self._get_substate(state, active_subenv)
        subenv = self._get_unique_env_of_subenv(active_subenv)

        # Change the relevant sub-environment and set done to True if the substate is
        # the source of an intermediate sub-environment
        if active_subenv > 0 and not done and subenv.is_source(state_subenv):
            relevant_subenv = active_subenv - 1
            state_subenv = self._get_substate(state, relevant_subenv)
            subenv = self._get_unique_env_of_subenv(relevant_subenv)
            done = True
        else:
            relevant_subenv = active_subenv

        # Obtain mask of substate
        mask = subenv.get_mask_invalid_actions_backward(state_subenv, done)

        # Reset constraints for self.state
        if do_constraints:
            self._apply_constraints(state=self.state)

        return self._format_mask(mask, relevant_subenv)

    # TODO: rethink whether padding should be True (invalid) instead.
    def _format_mask(self, mask: List[bool], idx_subenv: int):
        """
        Formats the mask of a sub-environment into a Stack mask.

        The output format is the input mask, which corresponds to a
        sub-environment, preceded by a one-hot encoding of the index of active
        sub-environment and with False up to ``self.mask_dim``.

        Parameters
        ----------
        mask : List[bool]
            The mask of a sub-environment
        idx_subenv : int
            The index of the sub-environment to be one-hot encoded.

        Returns
        -------
        list
            The Stack-formatted mask, of length ``self.mask_dim``.
        """
        idx_onehot = [False] * self.n_subenvs
        idx_onehot[idx_subenv] = True
        padding = [False] * (self.mask_dim - (len(mask) + self.n_subenvs))
        return idx_onehot + mask + padding

    def _unformat_mask(
        self,
        mask: Union[
            List[bool], TensorType["mask_dim"], TensorType["batch_size", "mask_dim"]
        ],
        idx_subenv: int = None,
        mask_dim: int = None,
    ):
        """
        Extracts the mask of the sub-environment from a Stack-formated mask or batch
        of masks.

        This method removes the one-hot encoding of the index of the active
        sub-environment that precedes the subenv mask, as well as the padding.

        Parameters
        ----------
        mask : List[bool] or tensor
            A single Stack mask (list or 1D tensor) or a batch of masks (2D tensor).
        idx_subenv : int
            The index of the sub-environment whose mask is to be extracted. If None,
            ``mask_dim`` must be passed.
        mask_dim : int
            The dimensionality of the mask to be extracted. If None, ``idx_subenv``
            must be passed. Ignored if ``idx_subenv`` is not None.

        Returns
        -------
        list or tensor
            The sub-environment mask(s), with the same container type as the input.

        Raises
        ------
        ValueError
            If ``idx_subenv`` and ``mask_dim`` are both None.
        ValueError
            If the mask is neither a list nor a tensor.
        """
        if idx_subenv is not None:
            mask_dim = self.subenvs[idx_subenv].mask_dim
        elif mask_dim is None:
            raise ValueError("idx_subenv and mask_dim cannot be both None")
        start = self.n_subenvs
        end = start + mask_dim
        if isinstance(mask, list):
            return mask[start:end]
        if torch.is_tensor(mask):
            # Slice along the last dimension so that both single masks (1D) and
            # batches of masks (2D) are supported.
            return mask[..., start:end]
        raise ValueError("The input mask can only be a list or a tensor")

    def get_valid_actions(
        self,
        mask: Optional[Union[List[bool], TensorType["mask_dim"]]] = None,
        state: Optional[Dict] = None,
        done: Optional[bool] = None,
        backward: Optional[bool] = False,
    ) -> List[Tuple]:
        """
        Returns the list of non-invalid (valid, for short) actions.

        This method is overridden because the mask of a Stack of environments does
        not cover the entire action space, but only the relevant sub-environment.
        Therefore, this method calls the ``get_valid_actions()`` method of the
        currently relevant sub-environment and returns the padded actions.

        If a state is passed as an argument (not ``None``) and the Stack has
        constraints, the constraints are applied before computing the mask and reset
        thereafter. This is necessary because otherwise the sub-environments may not
        have the correct attributes necessary to calculate the mask.

        Parameters
        ----------
        mask : list or tensor, optional
            A Stack-formatted mask. If given, the part corresponding to the relevant
            sub-environment is passed on to the sub-environment.
        state : dict
            A state of the Stack. If None, ``self.state`` is used.
        done : bool
            Whether the trajectory is done. If None, ``self.done`` is used.
        backward : bool
            True to obtain the valid backward actions; False for forward actions.

        Returns
        -------
        list
            The valid actions, padded into the Stack action format.
        """
        do_constraints = state is not None and id(state) != id(self.state)
        state = self._get_state(state)
        done = self._get_done(done)

        # Apply constraints based on the input state
        if do_constraints:
            do_constraints = self._apply_constraints(state=state)

        # Get active sub-environment, substate and unique environment
        active_subenv = self._get_active_subenv(state)
        state_subenv = self._get_substate(state, active_subenv)
        subenv = self._get_unique_env_of_subenv(active_subenv)

        # Change the relevant sub-environment and set done to True if the substate is
        # the source of an intermediate sub-environment
        if (
            backward
            and active_subenv > 0
            and not done
            and subenv.is_source(state_subenv)
        ):
            relevant_subenv = active_subenv - 1
            state_subenv = self._get_substate(state, relevant_subenv)
            subenv = self._get_unique_env_of_subenv(relevant_subenv)
            done = True
        else:
            relevant_subenv = active_subenv

        if mask is not None:
            # Extract the part of the mask corresponding to the sub-environment
            mask = self._unformat_mask(mask, relevant_subenv)

        # Obtain valid actions
        idx_unique = self._get_unique_idx_of_subenv(relevant_subenv)
        valid_actions = [
            self._pad_action(action, idx_unique)
            for action in subenv.get_valid_actions(mask, state_subenv, done, backward)
        ]

        # Reset constraints for self.state
        if do_constraints:
            self._apply_constraints(state=self.state)

        return valid_actions

    # TODO: review
    def mask_conditioning(
        self, mask: Union[List[bool], TensorType["mask_dim"]], env_cond, backward: bool
    ):
        """
        Conditions the input mask based on the restrictions imposed by a
        conditioning environment, env_cond.

        This method is overriden because the base mask_conditioning would change
        the mask unaware of the special Stack format. Therefore, this method calls
        the mask_conditioning() method of the sub-environment the mask refers to
        and returns the mask with the correct Stack format.

        The relevant sub-environment is read from the one-hot prefix of the mask
        itself, so that backward masks computed at the source of an intermediate
        sub-environment (which refer to the preceding sub-environment) are handled
        correctly.

        Parameters
        ----------
        mask : list or tensor
            A single Stack-formatted mask.
        env_cond : Stack
            The conditioning environment. Its ``subenvs`` are indexed like those of
            this Stack.
        backward : bool
            True if the mask is a backward mask; False if it is a forward mask.

        Returns
        -------
        list or tensor
            The conditioned mask in Stack format, with the same container type as
            the input mask.
        """
        # Identify the sub-environment from the one-hot prefix of the mask
        onehot = mask[: self.n_subenvs]
        if torch.is_tensor(onehot):
            onehot = onehot.tolist()
        idx_subenv = onehot.index(True)
        subenv = self.subenvs[idx_subenv]

        # Extract, condition and re-format the sub-environment part of the mask
        mask_subenv = self._unformat_mask(mask, idx_subenv)
        mask_subenv = subenv.mask_conditioning(
            mask_subenv, env_cond.subenvs[idx_subenv], backward
        )
        if torch.is_tensor(mask_subenv):
            mask_subenv = mask_subenv.tolist()
        mask_stack = self._format_mask(list(mask_subenv), idx_subenv)

        # Preserve the container type of the input mask
        if torch.is_tensor(mask):
            return torch.tensor(mask_stack, dtype=mask.dtype, device=mask.device)
        return mask_stack

    def get_parents(
        self,
        state: Optional[Dict] = None,
        done: Optional[bool] = None,
        action: Optional[Tuple] = None,
    ) -> Tuple[List, List]:
        """
        Determines all parents and actions that lead to the input state.

        If a state is passed as an argument (not ``None``) and the Stack has
        constraints, the constraints are applied before computing the mask and reset
        thereafter. This is necessary because otherwise the sub-environments may not
        have the correct attributes necessary to calculate the mask.

        Parameters
        ----------
        state : dict
            State in environment format. If None, self.state is used.
        done : bool
            Whether the trajectory is done. If None, self.done is used.
        action : tuple
            Ignored.

        Returns
        -------
        parents : list
            List of parents in state format
        actions : list
            List of actions that lead to state for each parent in parents
        """
        do_constraints = state is not None and id(state) != id(self.state)
        state = self._get_state(state)
        done = self._get_done(done)

        # If done is True, the only parent is the state itself with action EOS.
        if done:
            return [state], [self.eos]

        # Apply constraints based on the input state
        if do_constraints:
            do_constraints = self._apply_constraints(state=state)

        # Get active sub-environment, substate and unique environment
        active_subenv = self._get_active_subenv(state)
        state_subenv = self._get_substate(state, active_subenv)
        subenv = self._get_unique_env_of_subenv(active_subenv)

        # Change the relevant sub-environment and set done to True if the substate is
        # the source of an intermediate sub-environment
        if active_subenv > 0 and not done and subenv.is_source(state_subenv):
            relevant_subenv = active_subenv - 1
            state_subenv = self._get_substate(state, relevant_subenv)
            subenv = self._get_unique_env_of_subenv(relevant_subenv)
            done = True
        else:
            relevant_subenv = active_subenv

        # Get parents of the relevant sub-environment
        parents_subenv, parent_actions = subenv.get_parents(state_subenv, done)

        # Convert subenv parents to Stack states
        parents = []
        for parent_subenv in parents_subenv:
            parent = copy(state)
            parent = self._set_active_subenv(relevant_subenv, parent)
            parent = self._set_substate(relevant_subenv, parent_subenv, parent)
            parents.append(parent)

        # Pad actions
        idx_unique = self._get_unique_idx_of_subenv(relevant_subenv)
        parent_actions = [
            self._pad_action(action, idx_unique) for action in parent_actions
        ]

        # Reset constraints for self.state
        if do_constraints:
            self._apply_constraints(state=self.state)

        return parents, parent_actions

    def step(
        self, action: Tuple, skip_mask_check: bool = False
    ) -> Tuple[List, Tuple, bool]:
        """
        Executes forward step given an action.

        The action is performed by the corresponding sub-environment and then the
        global state is updated accordingly. If the action is the EOS of the
        sub-environment, the stage is advanced and constraints are set on the
        subsequent sub-environment.

        Parameters
        ----------
        action : tuple
            Action to be executed. The input action is global, that is padded.
        skip_mask_check : bool
            If True, the mask check of the sub-environment's pre-step is skipped.
            It is always skipped for continuous sub-environments.

        Returns
        -------
        self.state : Dict
            The state after executing the action.
        action : tuple
            Action executed.
        valid : bool
            False, if the action is not allowed for the current state. True
            otherwise.
        """
        # If done, exit immediately
        if self.done:
            return self.state, action, False

        # Get active sub-environment, unique index, subenv and action of subenv
        active_subenv = self._get_active_subenv(self.state)
        idx_unique = self._get_unique_idx_of_subenv(active_subenv)
        subenv = self.subenvs[active_subenv]
        action_subenv = self._depad_action(action, idx_unique)

        # Perform pre-step from subenv - if it was done from the Stack env there could
        # be a mismatch between mask and action space due to continuous subenvs.
        action_to_check = subenv.action2representative(action_subenv)
        # Skip mask check if stage is continuous
        if subenv.continuous:
            skip_mask_check = True
        do_step, _, _ = subenv._pre_step(
            action_to_check,
            skip_mask_check=(skip_mask_check or self.skip_mask_check),
        )
        if not do_step:
            return self.state, action, False

        # Call step of active sub-environment
        _, action_subenv, valid = subenv.step(action_subenv)

        # If action is invalid, exit immediately.
        if not valid:
            return self.state, action, False

        # Otherwise, increment number of actions and go on
        self.n_actions += 1

        # Check if action is EOS of subenv
        if action_subenv == subenv.eos:
            # If it is global EOS set done to True
            if active_subenv == (self.n_subenvs - 1) and action == self.eos:
                self.done = True
            else:
                # Increment active subenv and apply constraints
                self._set_active_subenv(active_subenv + 1)
                self._apply_constraints(action=action, is_backward=False)
        else:
            # Update substate
            self._set_substate(active_subenv, subenv.state)

        return self.state, action, valid

    def step_backwards(
        self, action: Tuple, skip_mask_check: bool = False
    ) -> Tuple[List, Tuple, bool]:
        """
        Executes backward step given an action.

        The action is performed by the corresponding sub-environment and then the
        global state is updated accordingly. If the updated state of the
        sub-environment becomes its source, the stage is decreased.

        Parameters
        ----------
        action : tuple
            Action to be executed. The input action is global, that is padded.
        skip_mask_check : bool
            If True, the mask check of the sub-environment's pre-step is skipped.
            It is always skipped for continuous sub-environments.

        Returns
        -------
        self.state : dict
            The state after executing the action.
        action : tuple
            Action executed.
        valid : bool
            False, if the action is not allowed for the current state. True
            otherwise.
        """
        # Get active sub-environment, unique index, and action of subenv
        # The unique index is taken from the action, not the state, since they may not
        # coincide in inter-subenv transition actions
        active_subenv = self._get_active_subenv(self.state)
        idx_unique = action[0]
        action_subenv = self._depad_action(action, idx_unique)

        # Determine the relevant sub-environment: if the action is EOS, then it is a
        # transition between sub-environments and the relevant sub-environment is the
        # previous to the active subenv in the state. Exception: if the env is done
        env = self._get_env_unique(idx_unique)
        if action_subenv == env.eos and not self.done:
            relevant_subenv = active_subenv - 1
        else:
            relevant_subenv = active_subenv

        # An inter-subenv EOS transition is impossible from the first
        # sub-environment; without this guard, index -1 would silently select the
        # last sub-environment.
        if relevant_subenv < 0:
            return self.state, action, False
        # The action must belong to the unique environment of the relevant subenv.
        if self._get_unique_idx_of_subenv(relevant_subenv) != idx_unique:
            return self.state, action, False

        subenv = self.subenvs[relevant_subenv]

        # Perform pre-step from subenv - if it was done from the "superenv" there could
        # be a mismatch between mask and action space due to continuous subenvs.
        action_to_check = subenv.action2representative(action_subenv)
        # Skip mask check if stage is continuous
        if subenv.continuous:
            skip_mask_check = True
        do_step, _, _ = subenv._pre_step(
            action_to_check,
            backward=True,
            skip_mask_check=(skip_mask_check or self.skip_mask_check),
        )
        if not do_step:
            return self.state, action, False

        # Call step of current subenvironment
        _, _, valid = subenv.step_backwards(action_subenv)

        # If action is invalid, exit immediately
        if not valid:
            return self.state, action, False

        # Otherwise, increment number of actions and go on
        self.n_actions += 1

        # If action was from done, set done False
        if self.done:
            assert action == self.eos
            self.done = False

        # Update substate
        self._set_substate(relevant_subenv, subenv.state)
        self._set_active_subenv(relevant_subenv)

        # If action is EOS of subenv, apply backward constraints
        if action_subenv == subenv.eos:
            self._apply_constraints(action=action, is_backward=True)

        return self.state, action, valid

    # TODO: review if random action probability works
    def sample_actions_batch(
        self,
        policy_outputs: TensorType["n_states", "policy_output_dim"],
        mask: Optional[TensorType["n_states", "mask_dim"]] = None,
        states_from: List[Dict] = None,
        is_backward: Optional[bool] = False,
        random_action_prob: Optional[float] = 0.0,
        temperature_logits: Optional[float] = 1.0,
    ) -> List[Tuple]:
        """
        Samples a batch of actions from a batch of policy outputs.

        This method calls the ``sample_actions_batch()`` method of the
        sub-environment corresponding to each state in the batch. Note that in
        order to call sample_actions_batch() of the sub-environments, we need to
        first extract the part of the policy outputs, the masks and the states that
        correspond to the sub-environment.

        Parameters
        ----------
        policy_outputs : tensor
            The output of the GFlowNet policy model.
        mask : tensor
            The mask of invalid actions, formatted as in
            :py:meth:`gflownet.envs.composite.stack.Stack._format_mask`. Although
            it defaults to None for interface compatibility, it is required: the
            relevant sub-environment of each state is read from its one-hot
            prefix.
        states_from : list
            The states originating the actions, in environment format.
        is_backward : bool
            True if the actions are backward, False if the actions are forward
        random_action_prob : float, optional
            The probability of sampling a random action. If larger than zero, the
            model outputs will be replaced by a random policy vector with
            probability `random_action_prob`, according to Bernoulli distribution.
        temperature_logits : float, optional
            A scalar by which the model outputs are divided to temper the sampling
            distribution.

        Returns
        -------
        actions : list
            The list of sampled actions, padded into the Stack action format and in
            the same order as the input batch.

        Raises
        ------
        ValueError
            If ``mask`` is None.
        """
        if mask is None:
            raise ValueError("Stack.sample_actions_batch() requires a mask")

        # Get the indices of the relevant sub-environments from the one-hot prefix of
        # the mask
        indices_relevant = torch.where(mask[:, : self.n_subenvs])[1]
        indices_relevant_int = indices_relevant.tolist()

        # Create the tensor indices_unique, which contains the index of the unique
        # environment corresponding to the relevant subenv and the list of unique
        # indices
        indices_unique = torch.empty_like(indices_relevant)
        for idx_subenv in set(indices_relevant_int):
            idx_unique = self._get_unique_idx_of_subenv(idx_subenv)
            indices_unique[indices_relevant == idx_subenv] = idx_unique
        indices_unique_int = indices_unique.tolist()

        # Create a dictionary with keys equal to the unique indices and the values are
        # the substates corresponding to the sub-environment.
        states_dict = {idx: [] for idx in self.unique_indices}
        for state, idx_subenv, idx_unique in zip(
            states_from, indices_relevant_int, indices_unique_int
        ):
            states_dict[idx_unique].append(self._get_substate(state, idx_subenv))

        # Sample actions from each unique environment
        actions_dict = {}
        for idx in set(indices_unique_int):
            env = self._get_env_unique(idx)
            env_mask = indices_unique == idx
            actions_dict[idx] = env.sample_actions_batch(
                self._get_policy_outputs_of_env_unique(policy_outputs[env_mask], idx),
                self._unformat_mask(mask[env_mask, :], mask_dim=env.mask_dim),
                states_dict[idx],
                is_backward,
                random_action_prob,
                temperature_logits,
            )

        # Stitch all actions in the right order, with the right padding. Iterators
        # consume each environment's actions in order without the O(n) cost of
        # list.pop(0).
        actions_iters = {idx: iter(actions) for idx, actions in actions_dict.items()}
        return [
            self._pad_action(next(actions_iters[idx]), idx)
            for idx in indices_unique_int
        ]

    def get_logprobs(
        self,
        policy_outputs: TensorType["n_states", "policy_output_dim"],
        actions: Union[List, TensorType["n_states", "action_dim"]],
        mask: TensorType["n_states", "mask_dim"],
        states_from: List[Dict],
        is_backward: bool,
    ) -> TensorType["batch_size"]:
        """
        Computes log probabilities of actions given policy outputs and actions.

        Parameters
        ----------
        policy_outputs : tensor
            The output of the GFlowNet policy model.
        actions : list or tensor
            The actions (global) from each state in the batch for which to compute
            the log probability.
        mask : tensor
            The Stack-formatted masks of invalid actions. The relevant
            sub-environment of each state is read from the one-hot prefix.
        states_from : list
            The states originating the actions, in environment format.
        is_backward : bool
            True if the actions are backward, False if the actions are forward
            (default).

        Returns
        -------
        tensor
            The log probabilities of the transitions.
        """
        actions = tfloat(actions, float_type=self.float, device=self.device)
        n_states = policy_outputs.shape[0]

        # Get the indices of the relevant sub-environments from the one-hot prefix of
        # the mask
        indices_relevant = torch.where(mask[:, : self.n_subenvs])[1]
        indices_relevant_int = indices_relevant.tolist()

        # Create the tensor indices_unique, which contains the index of the unique
        # environment corresponding to the relevant subenv and the list of unique
        # indices
        indices_unique = torch.empty_like(indices_relevant)
        for idx_subenv in set(indices_relevant_int):
            idx_unique = self._get_unique_idx_of_subenv(idx_subenv)
            indices_unique[indices_relevant == idx_subenv] = idx_unique
        indices_unique_int = indices_unique.tolist()

        # Create a dictionary with keys equal to the unique indices and the values are
        # the substates corresponding to the sub-environment.
        states_dict = {idx: [] for idx in self.unique_indices}
        for state, idx_subenv, idx_unique in zip(
            states_from, indices_relevant_int, indices_unique_int
        ):
            states_dict[idx_unique].append(self._get_substate(state, idx_subenv))

        # Compute logprobs from each sub-environment
        logprobs = torch.empty(n_states, dtype=self.float, device=self.device)
        for idx in set(indices_unique_int):
            env = self._get_env_unique(idx)
            env_mask = indices_unique == idx
            logprobs[env_mask] = env.get_logprobs(
                self._get_policy_outputs_of_env_unique(policy_outputs[env_mask], idx),
                self._depad_action_batch(actions[env_mask, :], idx),
                self._unformat_mask(mask[env_mask, :], mask_dim=env.mask_dim),
                states_dict[idx],
                is_backward,
            )
        return logprobs

    def states2policy(
        self, states: List[Dict]
    ) -> TensorType["batch", "state_policy_dim"]:
        """
        Prepares a batch of states in environment format for the policy model.

        The default policy representation is a concatenation of the policy-format
        states of the sub-environments. There is only one call of
        ``subenv.states2policy()`` for each sub-environment, on all the
        corresponding substates in the batch.

        Parameters
        ----------
        states : list
            A batch of states in environment format.

        Returns
        -------
        tensor
            A tensor containing all the states in the batch, with the
            sub-environment representations concatenated along dimension 1.
        """
        return torch.cat(
            [
                subenv.states2policy([state[idx] for state in states])
                for idx, subenv in enumerate(self.subenvs)
            ],
            dim=1,
        )

    def states2proxy(self, states: List[Dict]) -> List[List]:
        """
        Prepares a batch of states in environment format for a proxy: simply a
        concatenation of the proxy-format states of the sub-environments.

        Parameters
        ----------
        states : list
            A batch of states in environment format.

        Returns
        -------
        list
            A list of lists, each containing the proxy representation of all the
            states in the Stack, for all the Stacks in the batch.
        """
        states_proxy = []
        for state in states:
            states_proxy.append(
                [
                    subenv.state2proxy(self._get_substate(state, idx))[0]
                    for idx, subenv in enumerate(self.subenvs)
                ]
            )
        return states_proxy

    def state2readable(self, state: Optional[Dict] = None) -> str:
        """
        Converts a state into human-readable representation.

        The output starts with ``"Active: {index of the active subenv}"``, followed
        by the readable representations of each sub-environment, all separated by
        ``"; "``.

        Parameters
        ----------
        state : dict
            A state of the Stack. If None, ``self.state`` is used.

        Returns
        -------
        str
            The human-readable representation of the state.
        """
        state = self._get_state(state)
        parts = [f"Active: {self._get_active_subenv(state)}"]
        parts.extend(
            subenv.state2readable(self._get_substate(state, idx))
            for idx, subenv in enumerate(self.subenvs)
        )
        return "; ".join(parts)

    def readable2state(self, readable: str) -> Dict:
        """
        Converts a human-readable representation of a state into the standard
        format.

        This is the inverse of :py:meth:`state2readable`. It assumes that the
        readable representations of the sub-environments do not themselves contain
        the separator ``"; "``.

        Parameters
        ----------
        readable : str
            A readable representation as produced by :py:meth:`state2readable`.

        Returns
        -------
        dict
            The state in environment format.
        """
        readables = readable.split("; ")
        # Parse the full integer after "Active:" (not only its last character), so
        # that Stacks with 10 or more sub-environments are supported.
        active_subenv = int(readables[0].split(":")[-1])
        readables = readables[1:]
        state = {"_active": active_subenv}
        state.update(
            {
                idx: subenv.readable2state(readables[idx])
                for idx, subenv in enumerate(self.subenvs)
            }
        )
        return state

    # TODO: review if this could be moved to CompositeBase
    def action2representative(self, action: Tuple) -> Tuple:
        """
        Replaces the part of the action associated with a sub-environment by its
        representative. The part of the action that identifies the sub-environment
        concerned by the action remains unaffected.

        Parameters
        ----------
        action : tuple
            A padded (global) Stack action, whose first element is the index of
            the unique environment.

        Returns
        -------
        tuple
            The padded representative action.
        """
        idx_unique = action[0]
        action_subenv = self._depad_action(action, idx_unique)
        env = self._get_env_unique(idx_unique)
        action_subenv_representative = env.action2representative(action_subenv)
        return self._pad_action(action_subenv_representative, idx_unique)

    def is_source(self, state: Optional[Dict] = None) -> bool:
        """
        Returns True if the environment's state or the state passed as parameter (if
        not None) is the source state of the environment.

        This method is overriden for efficiency (for example, it would return False
        immediately if the stage is not the first stage) and to cover special uses
        of the Stack.

        Parameters
        ----------
        state : Dict
            None, or a state in environment format.

        Returns
        -------
        bool
            Whether the state is the source state of the environment
        """
        state = self._get_state(state)
        if self._get_active_subenv(state) != 0:
            return False
        return all(
            subenv.is_source(self._get_substate(state, idx))
            for idx, subenv in enumerate(self.subenvs)
        )