Source code for gflownet.envs.composite.base

"""
Base class for composite environments.

Composite environments are environments which consist of multiple environments.
"""

from typing import Dict, Iterable, List, Optional, Tuple, Union

import torch
from torchtyping import TensorType

from gflownet.envs.base import GFlowNetEnv


class CompositeBase(GFlowNetEnv):
    """
    Base class for composite environments.

    The states of composite environments are dictionaries. Keys with integers,
    starting from 0, are reserved to contain the states at the position indicated by
    the key. Additionally, the dictionary may include keys with meta-data about the
    state of the composite environment. The following keys and values are supported
    by default by the base composite environment:

    - ``_active``: The index of the currently active sub-environment, or -1 if none
      is active.
    - ``_dones``: A list of flags indicating whether the sub-environments are done
      (1) or not done (0). For example, ``[0, 1, 0]`` indicates that the
      sub-environments at indices 0 and 2 are not done, and the sub-environment at
      index 1 is done.
    - ``_envs_unique``: A list of indices identifying the unique environment
      corresponding to each sub-environment. For example, ``[1, 1, 0]`` indicates
      that the sub-environments at indices 0 and 1 are of the same type, in
      particular the type of environment at index 1 in ``self.envs_unique``; the
      sub-environment at index 2 is of the type at index 0 in ``self.envs_unique``.
      If the environment type at a position is unknown, it contains -1.

    Not all supported keys need to be included in the states of sub-classes and new
    keys may be included as needed.

    Attributes
    ----------
    max_elements : int
        The maximum number of elements that can be included in the composite
        environment. Note that this number does not refer to the number of unique
        environments, but the number of elements (instances of a sub-environment)
        that can form a composite environment.
    subenvs : iterable
        The collection of sub-environments included in the composite environment.
    envs_unique : iterable
        The collection of unique environments that can make part of the composite
        environment. Uniqueness is defined in terms of the environment type and
        action space.
    """

    def __init__(
        self,
        **kwargs,
    ):
        """
        Initializes the CompositeBase environment.

        Parameters
        ----------
        **kwargs
            Keyword arguments forwarded unchanged to the parent class initializer.
        """
        # Constraints
        # NOTE(review): the flag is set before the base class init, presumably
        # because the base initializer may call methods (such as reset()) that read
        # ``has_constraints`` - confirm against GFlowNetEnv.
        self._has_constraints = self._check_has_constraints()
        # Base class init
        super().__init__(**kwargs)

    def _get_substate(self, state: Dict, idx_subenv: Optional[int] = None):
        """
        Returns the part of the state corresponding to the sub-environment indicated
        at ``idx_subenv``.

        Parameters
        ----------
        state : dict
            A state of the composite environment.
        idx_subenv : int
            Index of the sub-environment of which the corresponding part of the
            state is to be extracted. If None, the state of the active subenv is
            used.

        Returns
        -------
        The state of a sub-environment.

        Raises
        ------
        ValueError
            If the (given or active) index is negative, for instance -1 when no
            sub-environment is active, or not smaller than ``self.max_elements``.
        """
        if idx_subenv is None:
            idx_subenv = self._get_active_subenv(state)
        # Negative indices (such as the -1 "no active sub-environment" sentinel) are
        # not valid keys of the state and are rejected together with indices that
        # are too large, instead of surfacing as an uninformative KeyError.
        if idx_subenv < 0 or idx_subenv >= self.max_elements:
            raise ValueError(
                f"Index {idx_subenv} is not a valid sub-environment index."
            )
        return state[idx_subenv]

    def _get_substates(self, state: Dict) -> List:
        """
        Returns a list with all the states of the sub-environments.

        Only the indices in ``range(self.max_elements)`` that are present as keys of
        the state are included.

        Parameters
        ----------
        state : dict
            A state of the global composite environment.

        Returns
        -------
        list
            All sub-states in the state.
        """
        return [
            self._get_substate(state, idx)
            for idx in range(self.max_elements)
            if idx in state
        ]

    def _set_substate(
        self,
        idx_subenv: int,
        state_subenv: Union[List, TensorType, dict],
        state: Optional[Dict] = None,
    ) -> Dict:
        """
        Updates the global composite state by setting as substate of subenv
        ``idx_subenv`` the current state of the sub-environment.

        This method modifies ``self.state`` if ``state`` is None.

        Parameters
        ----------
        idx_subenv : int
            Index of the sub-environment of which to set the state.
        state_subenv : list or tensor or dict
            The state of a sub-environment.
        state : dict
            A state of the global composite environment.

        Returns
        -------
        The updated composite state.
        """
        assert idx_subenv < self.max_elements
        state = self._get_state(state)
        state[idx_subenv] = state_subenv
        return state

    def _get_active_subenv(self, state: Optional[Dict] = None) -> int:
        """
        Returns the index of the currently active sub-environment.

        If no state is passed, ``self.state`` is used.

        The active sub-environment is indicated in ``state["_active"]``.

        Parameters
        ----------
        state : dict
            A state of the composite environment.

        Returns
        -------
        int
            The value stored in ``state["_active"]``.
        """
        state = self._get_state(state)
        return state["_active"]

    def _set_active_subenv(self, idx_subenv: int, state: Optional[Dict] = None) -> Dict:
        """
        Sets the index of the active sub-environment.

        If no state is passed, ``self.state`` is used.

        The active sub-environment is set in ``state["_active"]``.

        Parameters
        ----------
        idx_subenv : int
            Index of the sub-environment to set as active, or -1.
        state : dict
            A state of the composite environment.

        Returns
        -------
        The updated composite state.
        """
        assert idx_subenv < self.max_elements or idx_subenv == -1
        state = self._get_state(state)
        state["_active"] = idx_subenv
        return state

    def _get_dones(self, state: Optional[Dict] = None) -> List[int]:
        """
        Returns the part of the state containing the list of done flags.

        The list of done flags indicates which sub-environments are done or not.

        The list of done flags is indicated in ``state["_dones"]``.

        Parameters
        ----------
        state : dict
            A state of the composite environment.

        Returns
        -------
        The list of dones as integer flags (0 or 1).
        """
        state = self._get_state(state)
        return state["_dones"]

    def _set_subdone(
        self, idx_subenv: int, done: bool, state: Optional[Dict] = None
    ) -> Dict:
        """
        Updates the done flag corresponding to the sub-environment at
        ``idx_subenv``.

        Parameters
        ----------
        idx_subenv : int
            Index of the sub-environment of which to set the done flag.
        done : bool
            The value of done to be set in the state. It is stored as an integer.
        state : dict
            A state of the composite environment.

        Returns
        -------
        The updated composite state.
        """
        assert idx_subenv < self.max_elements
        state = self._get_state(state)
        state["_dones"][idx_subenv] = int(done)
        return state

    def _get_subdone(self, idx_subenv: int, state: Optional[Dict] = None) -> bool:
        """
        Returns whether the sub-environment at ``idx_subenv`` is done.

        Parameters
        ----------
        idx_subenv : int
            Index of the sub-environment to query.
        state : dict
            A state of the composite environment.

        Returns
        -------
        True if the sub-environment at ``idx_subenv`` is done; False otherwise.
        """
        assert idx_subenv < self.max_elements
        return bool(self._get_dones(state)[idx_subenv])

    def _set_unique_index(
        self, idx_subenv: int, idx_unique: int, state: Optional[Dict] = None
    ) -> Dict:
        """
        Updates the index of the sub-environment indicated by idx_subenv in the list
        of unique indices.

        Parameters
        ----------
        idx_subenv : int
            Index of the sub-environment of which to set the unique index.
        idx_unique : int
            The unique index to be set.
        state : dict
            A state of the composite environment.

        Returns
        -------
        The updated composite state.
        """
        assert idx_subenv < self.max_elements
        assert idx_unique < self.n_unique_envs
        state = self._get_state(state)
        state["_envs_unique"][idx_subenv] = idx_unique
        return state

    def _get_unique_environments(
        self, subenvs: Iterable[GFlowNetEnv]
    ) -> Tuple[List[GFlowNetEnv], Tuple, List]:
        """
        Determines the set of unique environments in the iterable subenvs passed as
        an argument.

        Uniqueness is determined by both the type of environment and the action
        space: two environments are only considered equal if both the type and the
        action space are the same.

        Parameters
        ----------
        subenvs : iterable
            Iterable of sub-environments.

        Returns
        -------
        envs_unique : list
            The list of unique environments. For each key, the first environment
            found in ``subenvs`` is kept.
        envs_unique_keys : tuple
            A tuple containing the keys that identify each unique environment,
            namely tuples with the type of environment and the action space.
        unique_indices : list
            A list containing the index of the unique environment corresponding to
            each sub-environment in the input iterable.
        """
        envs_unique = []
        envs_unique_keys = []
        unique_indices = []
        for env in subenvs:
            env_key = (type(env), tuple(env.action_space))
            if env_key not in envs_unique_keys:
                envs_unique_keys.append(env_key)
                envs_unique.append(env)
            unique_indices.append(envs_unique_keys.index(env_key))
        return envs_unique, tuple(envs_unique_keys), unique_indices

    @property
    def n_unique_envs(self) -> int:
        """
        Returns the number of unique environments.

        The value is computed on the first access and stored in
        ``self._n_unique_envs``. If ``self.envs_unique`` exists, its length is used;
        otherwise the unique environments are determined from ``self.subenvs``.

        Returns
        -------
        int
            The number of unique environments.

        Raises
        ------
        RuntimeError
            If neither ``self.envs_unique`` nor ``self.subenvs`` is an attribute of
            the environment.
        """
        if hasattr(self, "_n_unique_envs"):
            return self._n_unique_envs
        if hasattr(self, "envs_unique"):
            self._n_unique_envs = len(self.envs_unique)
        else:
            if not hasattr(self, "subenvs"):
                raise RuntimeError("The environment does not contain self.subenvs")
            envs_unique, _, _ = self._get_unique_environments(self.subenvs)
            self._n_unique_envs = len(envs_unique)
        return self._n_unique_envs

    def _get_env_unique(self, idx_unique: int) -> GFlowNetEnv:
        """
        Returns the unique environment with index ``idx_unique``.

        This method requires the definition of the attribute ``envs_unique``
        containing the set of unique environments.

        Parameters
        ----------
        idx_unique : int
            The index of the unique environment to be retrieved.

        Returns
        -------
        GFlowNetEnv
            The unique environment with index idx_unique.
        """
        return self.envs_unique[idx_unique]

    def _get_unique_idx_of_subenv(
        self, idx_subenv: int, state: Optional[Dict] = None
    ) -> int:
        """
        Returns the index of the unique environment corresponding to the subenv at
        index ``idx_subenv``.

        The index refers to the state passed as an input. If the state is None,
        ``self.state`` is used.

        This method requires the definition of the attribute ``max_elements``,
        containing the maximum number of sub-environments allowed in the composite
        environment. The returned index refers to the unique environments in
        ``envs_unique``.

        Parameters
        ----------
        idx_subenv : int
            Index of a sub-environment (from 0 to ``self.max_elements``). Note that
            this is the index of a subenv, not of the unique environments.
        state : dict
            A state of the global composite environment.

        Returns
        -------
        int
            The entry of ``state["_envs_unique"]`` at position ``idx_subenv``.
        """
        assert idx_subenv < self.max_elements
        return self._get_unique_indices(state)[idx_subenv]

    def _get_unique_env_of_subenv(
        self, idx_subenv: int, state: Optional[Dict] = None
    ) -> GFlowNetEnv:
        """
        Returns the unique environment corresponding to the sub-environment at index
        ``idx_subenv``.

        Parameters
        ----------
        idx_subenv : int
            Index of a sub-environment (from 0 to ``self.max_elements``). Note that
            this is the index of a subenv, not of the unique environments.
        state : dict
            A state of the global composite environment.

        Returns
        -------
        GFlowNetEnv
            The unique environment of the sub-environment at ``idx_subenv``.
        """
        return self._get_env_unique(self._get_unique_idx_of_subenv(idx_subenv, state))

    def _get_unique_indices(self, state: Optional[Dict] = None) -> List[int]:
        """
        Returns the part of the state containing the unique indices.

        The unique indices identify the type of environment of each element in the
        state of the composite environment.

        Parameters
        ----------
        state : dict
            A state of the global composite environment. If None, ``self.state`` is
            used.

        Returns
        -------
        list
            The content of ``state["_envs_unique"]``.
        """
        state = self._get_state(state)
        return state["_envs_unique"]

    def get_action_space(self) -> List[Tuple]:
        """
        Constructs a list with all possible actions, including EOS.

        By default, the action space of a Composite environment consists of the
        concatenation of the actions of all unique environments. Certain composite
        environments may make use of additional actions, for example to toggle
        specific sub-environments. Sub-classes with additional actions should
        override this method.

        In order to make all actions the same length (required to construct batches
        of actions as a tensor), the actions are zero-padded from the back.

        In order to make all actions unique, the unique environment index is added
        as the first element of the action.

        Note that the actions of unique environments are only added once to the
        action space, regardless of how many elements of the unique environment
        (sub-environments) there are in the composite environment. In other words,
        identical environments that are part of the composite environment share the
        actions and a given action will have an effect on the sub-environment that
        is next or active.

        See:
        - :py:meth:`~gflownet.envs.composite.base.CompositeBase._pad_action`
        - :py:meth:`~gflownet.envs.composite.base.CompositeBase._depad_action`

        Returns
        -------
        list
            The padded actions of all unique environments, ordered by unique index.
        """
        action_space = []
        for idx in range(self.n_unique_envs):
            action_space.extend(
                [
                    self._pad_action(action, idx)
                    for action in self._get_env_unique(idx).action_space
                ]
            )
        return action_space

    def _pad_action(self, action: Tuple, idx_unique: int) -> Tuple:
        """
        Pads an action by adding the unique index (or -1) as the first element and
        zeros as padding.

        See:
        - :py:meth:`~gflownet.envs.composite.base.CompositeBase.get_action_space`

        Parameters
        ----------
        action : tuple
            The action to be padded.
        idx_unique : int
            The index of the unique environment or -1 for meta-actions (EOS and
            other actions of the composite environment)

        Returns
        -------
        tuple
            The padded and pre-fixed action.
        """
        # One position is taken by the prefix; the rest up to action_dim is padding.
        return (idx_unique,) + action + (0,) * (self.action_dim - len(action) - 1)

    def _depad_action(self, action: Tuple, idx_unique: Optional[int] = None) -> Tuple:
        """
        Reverses the padding operation, such that the resulting action can be passed
        to the underlying environment.

        See:
        - :py:meth:`~gflownet.envs.composite.base.CompositeBase._pad_action`

        Parameters
        ----------
        action : tuple
            The action to be depadded.
        idx_unique : int
            The index of the unique environment or -1 for meta-actions (EOS and
            other actions of the composite environment). If None, the prefix of the
            action (its first element) is used.

        Returns
        -------
        tuple
            The depadded action, as it appears in the action space of the
            sub-environment it belongs to. The length of the EOS action of the
            unique environment is used as the length of the depadded action. If
            idx_unique is -1 (meta-action), then the returned action is a
            single-element tuple with the second element of the input action.

        Raises
        ------
        RuntimeError
            If ``idx_unique`` is given and does not match the prefix of the action.
        """
        if idx_unique is None:
            idx_unique = action[0]
        elif idx_unique != action[0]:
            # The start of this message is matched by _do_constraints_for_subenv():
            # do not modify it without updating that method.
            raise RuntimeError(
                f"There is a mismatch between the input idx_unique ({idx_unique}) "
                f"for de-padding and the unique index in the action ({action[0]})"
            )
        if idx_unique != -1:
            return action[1 : 1 + len(self._get_env_unique(idx_unique).eos)]
        return (action[1],)

    def _depad_action_batch(
        self,
        actions: TensorType["batch_size", "action_dim"],
        idx_unique: int,
    ) -> TensorType["batch_size", "action_dim_subenv"]:
        """
        Reverses the padding operation for a batch of actions.

        It is assumed that all actions correspond to the same unique environment or
        that all of them are meta-actions. Unlike in ``_depad_action()``, the prefix
        of the actions is not checked against ``idx_unique``.

        See:
        - :py:meth:`~gflownet.envs.composite.base.CompositeBase._depad_action`

        Parameters
        ----------
        actions : tensor
            The batch of actions to be depadded.
        idx_unique : int
            The index of the unique environment or -1 for meta-actions (EOS and
            other actions of the composite environment)

        Returns
        -------
        tensor
            The depadded batch of actions. If idx_unique is -1 (meta-action), then
            the result is the second column of the input, indexed as
            ``actions[:, 1]``.
        """
        if idx_unique != -1:
            return actions[:, 1 : 1 + len(self._get_env_unique(idx_unique).eos)]
        return actions[:, 1]

    def set_state(self, state: Dict, done: Optional[bool] = False):
        """
        Sets a state and done.

        The correct state and done of each sub-environment are set too.

        Parameters
        ----------
        state : dict
            A state of the global composite environment.
        done : bool
            Whether the trajectory of the environment is done or not.

        Returns
        -------
        The environment itself (``self``).
        """
        # If done is True, then all sub-environments must be done
        if done:
            dones = [True] * self.max_elements
        else:
            dones = self._get_dones(state)

        # Set state and done
        super().set_state(state, done)

        # Set state and done of each sub-environment
        for idx, (subenv, done_subenv) in enumerate(zip(self.subenvs, dones)):
            subenv.set_state(self._get_substate(self.state, idx), bool(done_subenv))

        # Apply constraints across sub-environments, in case they apply.
        self._apply_constraints(state=self.state, is_backward=None)

        return self

    def reset(self, env_id: Optional[Union[int, str]] = None):
        """
        Resets the environment by resetting the sub-environments.

        Parameters
        ----------
        env_id : int or str
            Identifier passed on to the ``reset()`` method of the parent class.

        Returns
        -------
        The environment itself (``self``).
        """
        if self.subenvs is not None:
            for subenv in self.subenvs:
                subenv.reset()
        super().reset(env_id=env_id)

        # Apply constraints across sub-environments, in case they apply. is_backward
        # is set to True to bypass forward constraints.
        self._apply_constraints(state=self.state, is_backward=True)

        return self

    def get_policy_output(self, params: list[dict]) -> TensorType["policy_output_dim"]:
        """
        Defines the structure of the output of the policy model.

        By default, the policy output is the concatenation of the policy outputs of
        the unique environments.

        Sub-classes should override this method if the structure of the policy
        outputs changes, for example, if meta-actions are added.

        Parameters
        ----------
        params : list
            A list of distribution parameters. This list has as many elements as
            there are unique environments, since all sub-environments of the same
            environment type are expected to be identical.

        Returns
        -------
        tensor
            The concatenation of the policy outputs of the unique environments, in
            the order of ``params``.
        """
        return torch.cat(
            [
                self._get_env_unique(idx).get_policy_output(params_env_unique)
                for idx, params_env_unique in enumerate(params)
            ]
        )

    def _get_policy_outputs_of_env_unique(
        self,
        policy_outputs: TensorType["n_states", "policy_output_dim"],
        idx_unique: int,
    ):
        """
        Returns the columns of the policy outputs that correspond to the unique
        environment indicated by idx_unique.

        Since the policy outputs corresponding to each unique environment are
        concatenated across the columns of the input tensor, the outputs of a
        particular environment can be retrieved by iterating over the unique
        environments and calculating their output dimensions. In order to avoid this
        iteration at every request, the first column of each requested unique
        environment is stored in a dictionary of offsets, which is created as an
        attribute of the environment in the first call.

        Parameters
        ----------
        policy_outputs : tensor
            A tensor containing a batch of policy outputs. It is assumed that all
            the rows in this tensor correspond to the same unique environment.
        idx_unique : int
            Index of the unique environment of which the corresponding columns of
            the policy outputs are to be extracted.

        Returns
        -------
        tensor or None
            The columns of ``policy_outputs`` of the unique environment, or None if
            ``idx_unique`` is not in ``range(self.n_unique_envs)``.
        """
        if not hasattr(self, "_policy_outputs_offset_of_unique_env"):
            self._policy_outputs_offset_of_unique_env = {}
        offsets = self._policy_outputs_offset_of_unique_env

        # Fast path: the offset of this unique environment is already known.
        if idx_unique in offsets:
            init_col = offsets[idx_unique]
            end_col = init_col + self._get_env_unique(idx_unique).policy_output_dim
            return policy_outputs[:, init_col:end_col]

        # Slow path: accumulate the output dimensions of the preceding unique
        # environments and store the offset for subsequent calls.
        init_col = 0
        for idx in range(self.n_unique_envs):
            end_col = init_col + self._get_env_unique(idx).policy_output_dim
            if idx == idx_unique:
                offsets[idx] = init_col
                return policy_outputs[:, init_col:end_col]
            init_col = end_col

        # idx_unique does not correspond to any unique environment.
        return None

    @property
    def has_constraints(self):
        """
        Whether the composite environment has constraints across sub-environments.

        Returns
        -------
        True if the composite environment has constraints across sub-environments.
        """
        return self._has_constraints

    def _check_has_constraints(self) -> bool:
        """
        Checks whether the composite environment has constraints across
        sub-environments.

        By default, composite environments do not have constraints (False). This
        method should be overridden in environments that incorporate constraints
        across sub-environments via ``_apply_constraints()``.

        Returns
        -------
        bool
            True if the composite environment has constraints, False otherwise
        """
        return False

    def _apply_constraints(
        self,
        action: Optional[Tuple] = None,
        state: Optional[Dict] = None,
        is_backward: Optional[bool] = None,
    ) -> bool:
        """
        Applies constraints across sub-environments.

        This method is called from the methods that can modify the state, namely:
            - :py:meth:`~gflownet.envs.composite.base.CompositeBase.step()`
            - :py:meth:`~gflownet.envs.composite.base.CompositeBase.step_backwards()`
            - :py:meth:`~gflownet.envs.composite.base.CompositeBase.set_state()`
            - :py:meth:`~gflownet.envs.composite.base.CompositeBase.reset()`

        Furthermore, it is also called from methods that receive an input state
        different to ``self.state``, in order to make sure that the sub-environments
        have the properties and constraints corresponding to the input state, rather
        than those of ``self.state``. For example:
            - ``get_mask_invalid_actions_forward()``
            - ``get_mask_invalid_actions_backward()``
            - ``get_valid_actions()``
            - ``get_parents()``

        In general, the application of constraints can be initiated by and depend on
        an action (for example, from ``step()`` or ``step_backwards()``) or by a
        state (most other methods).

        In the former case, the input action is not ``None`` and the state may be
        None. Furthermore, ``is_backward`` should be ``True`` or ``False`` to
        indicate the direction of the action.

        In the latter case, the action is ``None`` and the input state may be not
        ``None``. Furthermore, ``is_backward`` is ``None``, indicating that no
        transition is involved in the application of constraints.

        This method simply calls
        :py:meth:`~gflownet.envs.composite.base.CompositeBase._apply_constraints_forward`
        and/or
        :py:meth:`~gflownet.envs.composite.base.CompositeBase._apply_constraints_backward`.

        In general, this method should _not_ be overridden. Instead, environments
        inheriting composite classes may override:
            - `_apply_constraints_forward`
            - `_apply_constraints_backward`

        Parameters
        ----------
        action : tuple (optional)
            An action, used to determine whether and which constraints should be
            applied and which should not, since the computations may be intensive.
            If the call of the method is not initiated by an action, then it is
            expected to be ``None``.
        state : dict (optional)
            A state in environment format used to indicate the state of the
            trajectory which should inform whether and which constraints should be
            applied. If ``None``, ``self.state`` may be used.
        is_backward : bool or None
            Boolean flag to indicate whether the constraint should be applied in the
            backward direction (True), meaning 'undoing' the constraint (this is the
            value when the call method is initiated by ``step_backwards()``); or in
            the forward direction, meaning 'applying' the constraint (if initiated
            by ``step()``). If the call of the method is not initiated by an action,
            then the value may be ``None``, indicating that the constraints depend
            on the input state. ``is_backward`` can be not None even if the action
            is None to indicate that either the forward or the backward constraints
            can be ignored. For example, ``reset()`` can pass ``is_backward=True``
            to bypass the forward constraints.

        Returns
        -------
        bool
            True if any constraint was applied, in either direction; False
            otherwise.
        """
        if not self.has_constraints:
            return False

        # If the call is initiated by a transition (action is not None), the
        # direction of the transition must be given explicitly.
        if action is not None:
            assert isinstance(is_backward, bool)

        # Forward constraints are attempted unless is_backward is True and backward
        # constraints are attempted unless is_backward is False. Therefore, both
        # directions are attempted if is_backward is None, and only the matching
        # direction if the call is initiated by a forward or a backward transition.
        do_forward = is_backward is not True
        do_backward = is_backward is not False

        applied_constraints = False
        if do_forward:
            applied_constraints = self._apply_constraints_forward(action, state)
        if do_backward:
            # The backward constraints are evaluated first and then combined with
            # the forward result, so that they are always applied and a constraint
            # applied in the forward direction is not lost in the return value.
            applied_backward = self._apply_constraints_backward(action, state)
            applied_constraints = applied_constraints or applied_backward
        return applied_constraints

    def _apply_constraints_forward(
        self,
        action: Optional[Tuple] = None,
        state: Optional[Dict] = None,
    ) -> bool:
        """
        Applies constraints across sub-environments in the forward direction.

        Environments inheriting composite classes may override this method if
        constraints across sub-environments must be applied. The method
        :py:meth:`~gflownet.envs.composite.base.CompositeBase._do_constraints_for_subenv`
        may be used as a helper to determine whether the constraints imposed by a
        sub-environment should be applied depending on the action.

        The base implementation applies no constraints.

        Parameters
        ----------
        action : tuple (optional)
            An action from the global composite environment. If the call of this
            method is not initiated by a transition, then ``action`` is None.
        state : dict (optional)
            A state of the global composite environment.

        Returns
        -------
        bool
            True if any constraint was applied; False otherwise.
        """
        return False

    def _apply_constraints_backward(
        self,
        action: Optional[Tuple] = None,
        state: Optional[Dict] = None,
    ) -> bool:
        """
        Applies constraints across sub-environments in the backward direction.

        In the backward direction, in this case, means that the constraints between
        two sub-environments are undone and reset as in the source state.

        Environments inheriting composite classes may override this method if
        constraints across sub-environments must be applied. The method
        :py:meth:`~gflownet.envs.composite.base.CompositeBase._do_constraints_for_subenv`
        may be used as a helper to determine whether the constraints imposed by a
        sub-environment should be applied depending on the action.

        The base implementation applies no constraints.

        Parameters
        ----------
        action : tuple (optional)
            An action from the global composite environment. If the call of this
            method is not initiated by a transition, then ``action`` is None.
        state : dict (optional)
            A state of the global composite environment.

        Returns
        -------
        bool
            True if any constraint was applied; False otherwise.
        """
        return False

    def _do_constraints_for_subenv(
        self,
        state: Dict,
        idx_subenv: int,
        action: Optional[Tuple] = None,
        is_backward: bool = False,
    ) -> bool:
        """
        Returns True if constraints should be applied given the state, relevant
        sub-environment, action and direction.

        This method is meant to be used by environments inheriting composite classes
        to determine whether the constraints imposed by a particular sub-environment
        should be applied. This depends on whether the environment is done or not,
        whether the constraints are to be done or undone, and whether they would be
        triggered by a transition or by a state.

        This method is meant to be called from:
            - :py:meth:`~gflownet.envs.composite.base.CompositeBase._apply_constraints_forward`
            - :py:meth:`~gflownet.envs.composite.base.CompositeBase._apply_constraints_backward`

        Additionally, composite environments may include other specific checks
        before setting inter-environment constraints, besides the output of this
        method.

        Forward constraints could be applied if:
            - The condition environment is done, and
            - The action is either None or EOS

        Backward constraints could be applied if:
            - The condition environment is not done, and
            - The action is either None or EOS

        Parameters
        ----------
        state : dict
            A state of the global composite environment.
        idx_subenv : int
            Index of the sub-environment that would trigger constraints.
        action : tuple (optional)
            The action involved in the transition, or None if there is no
            transition.
        is_backward : bool
            Boolean flag to indicate whether the potential constraint is in the
            backward direction (True) or in the forward direction (False).

        Returns
        -------
        bool
            True if the constraints of the sub-environment should be applied; False
            otherwise.

        Raises
        ------
        RuntimeError
            If de-padding the action fails for a reason other than a mismatch
            between the unique index of the sub-environment and the action prefix.
        """
        # If the index of the sub-environment is -1, then no sub-environment is
        # currently relevant and constraints should not be applied.
        if idx_subenv == -1:
            return False

        # If the action is not None, get the unique environment and depad the action
        if action is not None:
            idx_unique = self._get_unique_idx_of_subenv(idx_subenv, state)
            try:
                action = self._depad_action(action, idx_unique)
            except RuntimeError as e:
                # If there is a mismatch between idx_unique and the action index,
                # the action does not belong to this sub-environment: return False
                if str(e).startswith(
                    "There is a mismatch between the input idx_unique"
                ):
                    return False
                # Any other error is unexpected and must not be silently ignored.
                raise
            # If the action is not None, indicating that the check was initiated by
            # a transition, constraints are not applied if the action is not EOS.
            env_unique = self._get_env_unique(idx_unique)
            if action != env_unique.eos:
                return False

        subenv_is_done = self._get_subdone(idx_subenv, state)
        # Backward constraints should only be applied if the sub-environment is not
        # done
        if is_backward:
            return not subenv_is_done
        # Forward constraints should only be applied if the sub-environment is done
        return subenv_is_done