# Source code for gflownet.envs.composite.setbase

# TODO: The docstrings need a major rewrite to clarify the new functioning with _keys
# and non-alternating subenvs.
"""
Classes implementing the family of Set meta-environments, which allow combining
multiple sub-environments without imposing any specific order on them.
"""

import random
from collections import Counter
from typing import Dict, List, Optional, Tuple, Union

import torch
from torchtyping import TensorType

from gflownet.envs.base import GFlowNetEnv
from gflownet.envs.composite.base import CompositeBase
from gflownet.utils.common import copy, tfloat, tlong


class BaseSet(CompositeBase):
    """
    Base class for the SetFlex and the SetFix classes.

    Set environments allow combining multiple sub-environments of the same or
    different types. For example, a new environment could be created by arranging
    a set of two (continuous) Cubes and a Grid.

    The SetFlex implements a Set environment with a variable number of elements
    (sub-environments), up to a pre-defined maximum. That is, trajectories may
    consist of actions in a variable number of sub-environments from a pre-defined
    set of unique environments.

    The SetFix is a special case of the SetFlex which implements a Set with a fixed
    number of elements (sub-environments). That is, all trajectories consist of
    actions in the same set of pre-defined sub-environments.

    Set environments do not impose any order on the sub-environments, unlike the
    Stack environment. For example, a Set may consist of the following 3
    sub-environments:
        - 0: 2D Cube A
        - 1: 2D Cube B
        - 2: 10x10 Grid A

    Two variants are implemented that control how much the actions of
    sub-environments can alternate:

    1. Once a sub-environment is selected, the subsequent actions must be of the
       same sub-environment until its EOS action is performed. This variant is
       selected by setting ``can_alternate_subenvs`` to False. In this variant,
       toggle actions refer to unique environments rather than to individual
       sub-environments.
    2. The actions of the sub-environments can be sampled in any order. In order to
       perform an action of a sub-environment, the sub-environment must first be
       activated with a special action. This variant is selected by setting
       ``can_alternate_subenvs`` to True.

    Therefore, if ``can_alternate_subenvs`` is True, the Set environment alternates
    actions that activate a sub-environment and actions from the active
    sub-environment. Additionally, in order to remove the ambiguity of the backward
    transitions, active sub-environments also need to be deactivated (toggled) to
    go back to a state with no active sub-environment. These actions are needed in
    the backward transitions in order to determine which sub-environment should
    perform the action.

    Finally, in order to make sure that methods work in their state-less fashion
    (without relying on ``self.subenvs``), the state needs to contain information
    about whether the sub-environments are done or not.

    All this implies that the state of a Set environment consists of:
        - The index of the active sub-environment, or -1 to indicate that no
          sub-environment is active.
        - A flag (toggle) to indicate whether a sub-environment action is expected,
          or whether an action to toggle a sub-environment is expected. This flag
          is only necessary if ``can_alternate_subenvs`` is True; it is ignored
          (always 0) if ``can_alternate_subenvs`` is False.
        - A list of flags indicating whether the sub-environments are done (1) or
          not (0).
        - A dictionary with the states of all the sub-environments, together with
          a list ``_keys`` that maps each sub-environment index to the dictionary
          key under which its sub-state is stored (see ``_get_state_key``).

    The flow of actions for each of the two variants is as follows (in the first
    variant, each sub-environment is assumed to belong to a different unique
    environment):

    1. Actions of different sub-environments cannot alternate
       (``can_alternate_subenvs`` is False)
        - s0: (active: -1, toggle: 0, dones: [0, 0]) | action: toggle subenv 1
        - s1: (active: 1, toggle: 0, dones: [0, 0]) | action: an action of subenv 1
        - s2: (active: 1, toggle: 0, dones: [0, 0]) | action: an action of subenv 1
        - s3: (active: 1, toggle: 0, dones: [0, 0]) | action: EOS action of subenv 1
        - s4: (active: 1, toggle: 0, dones: [0, 1]) | action: toggle subenv 1
        - s5: (active: -1, toggle: 0, dones: [0, 1]) | action: toggle subenv 0
        - s6: (active: 0, toggle: 0, dones: [0, 1]) | action: an action of subenv 0
        - s7: (active: 0, toggle: 0, dones: [0, 1]) | action: an action of subenv 0
        - s8: (active: 0, toggle: 0, dones: [0, 1]) | action: EOS action of subenv 0
        - s9: (active: 0, toggle: 0, dones: [1, 1]) | action: toggle subenv 0
        - s10: (active: -1, toggle: 0, dones: [1, 1]) | action: global EOS

    2. Actions of different sub-environments can alternate
       (``can_alternate_subenvs`` is True)
        - s0: (active: -1, toggle: 0, dones: [0, 0]) | action: toggle subenv 1
        - s1: (active: 1, toggle: 1, dones: [0, 0]) | action: an action of subenv 1
        - s2: (active: 1, toggle: 0, dones: [0, 0]) | action: toggle subenv 1
        - s3: (active: -1, toggle: 0, dones: [0, 0]) | action: toggle subenv 1
        - s4: (active: 1, toggle: 1, dones: [0, 0]) | action: an action of subenv 1
        - s5: (active: 1, toggle: 0, dones: [0, 0]) | action: toggle subenv 1
        - s6: (active: -1, toggle: 0, dones: [0, 0]) | action: toggle subenv 0
        - s7: (active: 0, toggle: 1, dones: [0, 0]) | action: an action of subenv 0
        - s8: (active: 0, toggle: 0, dones: [0, 0]) | action: toggle subenv 0
        - s9: (active: -1, toggle: 0, dones: [0, 0]) | action: toggle subenv 1
        - s10: (active: 1, toggle: 1, dones: [0, 0]) | action: EOS action of subenv 1
        - s11: (active: 1, toggle: 0, dones: [0, 1]) | action: toggle subenv 1
        - s12: (active: -1, toggle: 0, dones: [0, 1]) | action: toggle subenv 0
        - s13: (active: 0, toggle: 1, dones: [0, 1]) | action: an action of subenv 0
        - s14: (active: 0, toggle: 0, dones: [0, 1]) | action: toggle subenv 0
        - s15: (active: -1, toggle: 0, dones: [0, 1]) | action: toggle subenv 0
        - s16: (active: 0, toggle: 1, dones: [0, 1]) | action: EOS action of subenv 0
        - s17: (active: 0, toggle: 0, dones: [1, 1]) | action: toggle subenv 0
        - s18: (active: -1, toggle: 0, dones: [1, 1]) | action: global EOS

    A potential alternative implementation would be to keep the active
    sub-environment active until a different sub-environment is selected. However,
    this would require special handling of continuous environments, since in order
    to calculate the probability of an action, we would have to mix the continuous
    distribution with the discrete distribution over the actions to activate a
    different sub-environment.
    """

    def __init__(
        self,
        can_alternate_subenvs=True,
        **kwargs,
    ):
        """
        Initializes the BaseSet.

        Parameters
        ----------
        can_alternate_subenvs : bool
            If True, actions of different sub-environments can alternate and each
            sub-environment action is preceded and followed by a meta-action to
            toggle the sub-environment. If False, once a sub-environment is
            activated, only actions of that sub-environment can be performed until
            it gets done (its EOS action is performed).
        """
        # Stored before the base initialization because n_toggle_actions (used by
        # get_action_space) depends on it; presumably the base init builds the
        # action space.
        self.can_alternate_subenvs = can_alternate_subenvs
        # Base class init
        super().__init__(**kwargs)

    @property
    def n_toggle_actions(self) -> int:
        """
        Number of meta-actions that toggle a sub-environment or a unique environment.

        When actions of different sub-environments may alternate, there is one
        toggle action per sub-environment (``self.max_elements``). Otherwise, toggle
        actions select unique environments, so there is one per unique environment
        (``self.n_unique_envs``). The value is computed on first access and cached
        in ``self._n_toggle_actions``.
        """
        if hasattr(self, "_n_toggle_actions"):
            return self._n_toggle_actions
        self._n_toggle_actions = (
            self.max_elements if self.can_alternate_subenvs else self.n_unique_envs
        )
        return self._n_toggle_actions
def _get_state_key(self, idx_subenv: int, state: Optional[Dict] = None) -> int: """ Returns the the dictionary key of the state corresponding to the subenv with index ``idx_subenv``. In order to handle the permutations of states that would leave a Set state invariant, Set states have a key called ``_keys`` which contains a list of the dictionary keys that contain each sub-state. This list is initially sorted in ascending order, for instance ``_keys: [0, 1, 2]``, but when the states are permuted, instead of permuting the actual states in the dictionary, the keys are permuted. This happens, for instance, in backward transitions. For example, the keys could become ``_keys: [0, 2, 1]``. This would mean that the state with index 1 (at ``_keys[1]``) is stored under the key ``2`` of the dictionary, and the state with index 2 is stored under the key ``1``. Parameters ---------- idx_subenv : int Index of a sub-environment (from 0 to ``self.max_elements``). Note that this is the index of a subenv, not of the unique environments. state : dict A state of the Set environment. Returns ------- int The key in the state dictionary containing the state with index ``idx_subenv``. Raises ------ ValueError If ``index_subenv`` is not a valid sub-environment index because it is not one of the keys of the state. """ state = self._get_state(state) if idx_subenv not in state["_keys"]: raise ValueError( f"Index {idx_subenv} is not a valid sub-environment index." ) return state["_keys"][idx_subenv] def _get_substate(self, state: Dict, idx_subenv: Optional[int] = None): """ Returns the part of the state corresponding to the sub-environment indicated at ``idx_subenv``. This method is overriden to account for the potential permutation of the keys of the substates. Parameters ---------- state : dict A state of the composite environment. idx_subenv : int Index of the sub-environment of which the corresponding part of the state is to be extracted. If None, the state of the active subenv is used. 
Note that this is the index of a subenv, not of the unique environments, and that the index may not correspond to the key of the substate as is stored in the state. The actual key is obtained via :py:meth:`~gflownet.envs.composite.setbase.BaseSet._get_state_key` Returns ------- The state of a sub-environment. Raises ------ ValueError If ``index_subenv`` is not a valid sub-environment index because it is not one of the keys of the state. """ if idx_subenv is None: idx_subenv = self._get_active_subenv(state) key_substate = self._get_state_key(idx_subenv, state) return super()._get_substate(state, key_substate) def _set_substate( self, idx_subenv: int, state_subenv: Union[List, TensorType, dict], state: Optional[Dict] = None, ) -> Dict: """ Updates the global composite state by setting as substate of subenv ``idx_subenv`` the current state of the sub-environment. This method modifies ``self.state`` if ``state`` is None. This method is overriden to account for the potential permutation of the keys of the substates. Parameters ---------- idx_subenv : int Index of the sub-environment of which to set the state. Note that this is the index of a subenv and that the index may not correspond to the key of the substate as is stored in the state. The actual key is obtained via :py:meth:`~gflownet.envs.composite.setbase.BaseSet._get_state_key` state_subenv : list or tensor or dict The state of a sub-environment. state : dict A state of the global composite environment. Returns ------- The updated composite state. Raises ------ ValueError If ``index_subenv`` is not a valid sub-environment index because it is not one of the keys of the state. """ key_substate = self._get_state_key(idx_subenv, state) return super()._set_substate(key_substate, state_subenv, state) # TODO: update by using super().get_action_space(), which will require changing # other methods to use the correct indexing of actions
[docs] def get_action_space(self) -> List[Tuple]: r""" Constructs list with all possible actions, including eos. The action space of a Set environment consists of: - The actions to activate specific sub-environments or unique environments. - The EOS action. - The concatenation of the actions of all unique environments In order to make all actions the same length (required to construct batches of actions as a tensor), the actions are zero-padded from the back. In order to make all actions unique, the unique environment index is added as the first element of the action. Note that the actions of unique environments are only added once to the action space, regardless of how many elements of the unique environment (sub-environments) there are in the set. In other words, identical environments that are part of the Set share the actions and a given action will have an effect on the sub-environment that is active. The actions to activate a specific sub-environment are represented as: (-1, subenv index, ZERO-PADDING) See: - :py:meth:`~gflownet.envs.composite.setbase.BaseSet._pad_action` - :py:meth:`~gflownet.envs.composite.setbase.BaseSet._depad_action` """ action_space = [] # Actions to activate a sub-environment or unique environment action_space.extend( [self._pad_action((idx,), -1) for idx in range(self.n_toggle_actions)] ) # EOS action action_space += [self.eos] # Action space of each unique environment for idx in range(self.n_unique_envs): action_space.extend( [ self._pad_action(action, idx) for action in self._get_env_unique(idx).action_space ] ) return action_space
[docs] def action_produces_permutation( self, action: Tuple, is_backward: bool = False ) -> bool: """ Determines whether an action produces permutations in the resulting state. The Set introduces actions that produce permutations, in particular in the key ``_keys`` of the state. These actions are introduced if ``self.can_alternate_subenvs`` is False. In particular, the actions that produce permutations are backward actions that toggle a sub-environment. Note that this method does not check whether all relevant substates are identical, in which case, there is effectively not more than one permutation. Instead, True is returned if the action _could_ produce permutations in the resulting state. Parameters ---------- action : tuple An action of the environment. is_backward : bool Whether the transition to consider is backward (True) or forward (False). Returns ------- bool Whether the input actions produces permutations in the resulting state, in the direction indicated by ``is_backward``. """ if ( not self.can_alternate_subenvs and not is_backward and action[0] == -1 and action != self.eos ): return True return False
# TODO: make mask prefix indicate the unique environment rather than active subenv
[docs] def get_mask_invalid_actions_forward( self, state: Optional[Dict] = None, done: Optional[bool] = None ) -> List[bool]: """ Computes the forward actions mask of the state. The mask of the Set environment is the concatenation of the following: - A one-hot encoding of the index of the sub-environment or unique environment (True at the index of the active environment). All False if the only valid actions are meta-actions. - Actual (main) mask of invalid actions: - The mask of the actions to activate a sub-environment or unique environment, OR - The mask of the active sub-environment. The mask is False-padded from the back up to mask_dim. """ do_constraints = state is not None and id(state) != id(self.state) state = self._get_state(state) done = self._get_done(done) # Apply constraints based on the input state if do_constraints: do_constraints = self._apply_constraints(state=state) # Get active sub-environment and flag active_subenv = self._get_active_subenv(state) toggle_flag = self._get_toggle_flag(state) dones = self._get_dones(state) # Establish the case based on the active sub-environment, the toggle flag and # the done flags case_a = case_b = case_c = case_d = False if active_subenv == -1: # - Case A: no sub-environment is active: the only valid actions are to # toggle sub-environments or the global EOS. assert toggle_flag == 0 case_a = True elif not self.can_alternate_subenvs and dones[active_subenv] == 1: # Case B: in the variant where sub-environments cannot alternate, the # active sub-environment is done: this indicates that since the # sub-environment is done, the only valid action is to toggle (deactivate) # the active sub-environment. case_b = True elif self.can_alternate_subenvs and toggle_flag == 0: # Case C: in the variant where sub-environments can alternate, the toggle # flag is zero: this indicates a sub-environment action has been performed # and the only valid action is to toggle (deactivate) the active # sub-environment. 
case_c = True else: # Case D: a sub-environment is active and, if sub-environments cannot # alternate, the toggle flag is 1: this indicates that a sub-environment # action is to be performed. assert not self.can_alternate_subenvs or toggle_flag == 1 assert not dones[active_subenv] case_d = True # Build the mask based on the case if case_a: # The main mask is the mask of the meta-actions to toggle a sub-environment # (or unique environment). if self.can_alternate_subenvs: # The action to activate a sub-environment is invalid (True) if the # sub-environment is done. mask = [bool(done) for done in dones] else: # The action to activate a unique environment is invalid (True) if all # its sub-environments are done. indices_unique = self._get_unique_indices(state) mask = [True] * self.n_unique_envs for done, idx_unique in zip(dones, indices_unique): if mask[idx_unique] and done == 0: mask[idx_unique] = False # The global EOS is invalid (True) unless all other actions are invalid. mask += [not all(mask)] elif case_b or case_c: # The main mask is the mask of the meta-actions to toggle a sub-environment # or unique environment, but the only valid action is to toggle the active # sub-environment. The global EOS is invalid (True). active_subenv is set # to -1, in order to make the mask formatting reflect that the valid # actions are set meta-actions. mask = [True] * self.n_toggle_actions if self.can_alternate_subenvs: mask[active_subenv] = False else: mask[self._get_unique_idx_of_subenv(active_subenv)] = False mask += [True] active_subenv = -1 elif case_d: # The main mask is the mask of the active sub-environment # Get subenv from unique environments. This way computing the mask does not # depend on self.subenvs and can be computed without setting the subenvs if # the state is passed. 
subenv = self._get_unique_env_of_subenv(active_subenv, state) state_subenv = self._get_substate(state, active_subenv) mask = subenv.get_mask_invalid_actions_forward(state_subenv, False) else: raise RuntimeError("None of the possible forward cases is True") # Reset constraints for self.state if do_constraints: self._apply_constraints(state=self.state) # Format mask and return return self._format_mask(mask, active_subenv)
# ---- Backward mask ----
    def get_mask_invalid_actions_backward(
        self, state: Optional[Dict] = None, done: Optional[bool] = None
    ) -> List[bool]:
        """
        Computes the backward actions mask of the state.

        The mask of the Set environment is the concatenation of the following:
            - A one-hot encoding of the index of the subenv (True at the index of
              the active environment). All False if no sub-environment is active
              or if the only valid actions are meta-actions.
            - Actual (main) mask of invalid actions:
                - The mask of the actions to toggle a sub-environment (or unique
                  environment) plus the global EOS, OR
                - The mask of the active sub-environment.
        The mask is False-padded from the back up to mask_dim.

        Parameters
        ----------
        state : dict
            A state of the Set environment. If None, ``self.state`` is used. If a
            different state is passed, constraints are applied based on it while
            computing the mask and re-applied for ``self.state`` afterwards.
        done : bool
            Whether the trajectory is done. If None, ``self.done`` is used. If
            True, the only valid backward action is the global EOS.

        Returns
        -------
        list
            The formatted backward mask (True means invalid).
        """
        do_constraints = state is not None and id(state) != id(self.state)
        state = self._get_state(state)
        done = self._get_done(done)
        # Apply constraints based on the input state
        if do_constraints:
            do_constraints = self._apply_constraints(state=state)
        # Get active sub-environment, toggle flag and done flags
        active_subenv = self._get_active_subenv(state)
        toggle_flag = self._get_toggle_flag(state)
        dones = self._get_dones(state)
        # The sub-environment and its sub-state may be computed while establishing
        # the case (cases E/F) and are then reused when building the mask.
        subenv = None
        state_subenv = None
        # Establish the case based on the active sub-environment, the toggle flag
        # and the done flags
        case_a = case_b = case_c = case_d = case_e = case_f = False
        if active_subenv == -1:
            # - Case A: no sub-environment is active: the only valid actions are to
            # toggle sub-environments or the global EOS.
            assert toggle_flag == 0
            case_a = True
        elif self.can_alternate_subenvs:
            if toggle_flag == 0:
                # Case B: in the variant where sub-environments can alternate, the
                # toggle flag is zero: a sub-environment action (in the backward
                # sense) must be performed.
                case_b = True
            else:
                # Case C: in the variant where sub-environments can alternate, a
                # sub-environment is active and the toggle flag is 1: a
                # sub-environment action (in the backward sense) has been performed
                # and the only valid action is to toggle (deactivate) the
                # sub-environment.
                case_c = True
        else:
            if dones[active_subenv] == 1:
                # Case D: in the variant where sub-environments cannot alternate,
                # the sub-environment is done: it has just been activated (in the
                # backward sense) and the only valid action is the EOS of the
                # active sub-environment.
                case_d = True
            else:
                subenv = self._get_unique_env_of_subenv(active_subenv, state)
                state_subenv = self._get_substate(state, active_subenv)
                if subenv.is_source(state_subenv):
                    # Case E: in the variant where sub-environments cannot
                    # alternate, the sub-environment is in the source state: the
                    # only valid action is to toggle it.
                    case_e = True
                else:
                    # Case F: in the variant where sub-environments cannot
                    # alternate, the sub-environment is neither in the source state
                    # nor done: a sub-environment action (in the backward sense)
                    # must be performed.
                    case_f = True
        # Build the mask based on the case
        if case_a:
            # The main mask is the mask of the meta-actions to activate a
            # sub-environment, followed by the global EOS. If the Set is done, only
            # the global EOS is valid. Otherwise, the global EOS is invalid and
            # toggling a sub-environment (or its unique environment) is valid
            # unless its sub-state is the source state and it is not done.
            assert toggle_flag == 0
            mask = [True] * self.n_toggle_actions
            if done:
                mask += [False]
            else:
                # Toggling a sub-environment is invalid if the substate is source
                # but the sub-environment is not done.
                # NOTE: the loop variable ``done`` shadows the parameter; the
                # parameter is not read after this point.
                indices_unique = self._get_unique_indices(state)
                for idx, (idx_unique, done) in enumerate(zip(indices_unique, dones)):
                    # Skip non-present sub-environments
                    if idx_unique == -1:
                        continue
                    # Skip sub-envs whose unique env has already been marked as valid
                    if not self.can_alternate_subenvs and not mask[idx_unique]:
                        continue
                    if not done and self._get_env_unique(idx_unique).is_source(
                        self._get_substate(state, idx)
                    ):
                        continue
                    else:
                        # Toggle actions are indexed by sub-environment if subenvs
                        # can alternate, and by unique environment otherwise.
                        if self.can_alternate_subenvs:
                            mask[idx] = False
                        else:
                            mask[idx_unique] = False
                mask += [True]
        elif case_c or case_e:
            # The main mask is the mask of the meta-actions to toggle a
            # sub-environment, but the only valid action is to toggle the active
            # sub-environment. The global EOS is invalid.
            # active_subenv is set to -1, in order to force the prefix reflect that the
            # state is effectively inactive. EOS is invalid from this state.
            mask = [True] * self.n_toggle_actions
            if self.can_alternate_subenvs:
                mask[active_subenv] = False
            else:
                mask[self._get_unique_idx_of_subenv(active_subenv)] = False
            mask += [True]
            active_subenv = -1
        elif case_b or case_d or case_f:
            # The main mask is the mask of the active sub-environment
            # Get subenv from unique environments. This way computing the mask does not
            # depend on self.subenvs and can be computed without setting the subenvs if
            # the state is passed.
            if subenv is None or state_subenv is None:
                subenv = self._get_unique_env_of_subenv(active_subenv, state)
                state_subenv = self._get_substate(state, active_subenv)
            # The done flag of the sub-environment (not of the Set) is passed
            done_subenv = dones[active_subenv]
            mask = subenv.get_mask_invalid_actions_backward(state_subenv, done_subenv)
        else:
            raise RuntimeError("None of the possible backward cases is True")
        # Reset constraints for self.state
        if do_constraints:
            self._apply_constraints(state=self.state)
        # Format mask and return
        return self._format_mask(mask, active_subenv)
# TODO: port mask_conditioning from the Stack environment to the Set (see notes in
# the method below)
    def mask_conditioning(
        self, mask: Union[List[bool], TensorType["mask_dim"]], env_cond, backward: bool
    ):
        """
        Conditions the input mask based on the restrictions imposed by a
        conditioning environment, env_cond.

        The body calls the mask_conditioning() method of one sub-environment on its
        slice of the mask and re-formats the result.

        NOTE(review): this implementation appears to be carried over from the Stack
        environment and may not work for Sets:
            - it relies on ``self._get_stage()``, which is not used anywhere else in
              this class; the other methods use ``self._get_active_subenv()``;
            - it calls ``self._format_mask`` with three arguments, whereas the other
              call sites in this class pass two (mask, active_subenv);
            - it does not handle masks where only Set meta-actions are valid
              (no active sub-environment, or a toggle action is expected).
        Confirm before relying on it.
        """
        stage = self._get_stage()
        subenv = self.subenvs[stage]
        # Extract the part of the mask corresponding to the sub-environment
        # (assumes the core mask starts right after the toggle actions — TODO
        # confirm against the Set mask prefix format)
        # TODO: consider writing a method to do this
        mask = mask[self.n_toggle_actions : self.n_toggle_actions + subenv.mask_dim]
        env_cond = env_cond.subenvs[stage]
        mask = subenv.mask_conditioning(mask, env_cond, backward)
        return self._format_mask(mask, stage, subenv.mask_dim)
# ---- Transitions ----
    def step(
        self, action: Tuple, skip_mask_check: bool = False
    ) -> Tuple[Dict, Tuple, bool]:
        """
        Executes a forward step given an action.

        Actions may be either sub-environment actions or Set (meta) actions. In the
        former case, the action is performed by the active sub-environment and then
        the Set state is updated accordingly. In the latter case, no
        sub-environment is involved: either the global EOS is performed or only the
        meta-data of the state changes (active sub-environment and toggle flag).

        Because the same action may correspond to multiple sub-environments, the
        action is always performed on the active sub-environment.

        - Toggle actions:
            - If no sub-environment is currently active, activate the corresponding
              sub-environment. If ``can_alternate_subenvs`` is True, the toggle
              flag is set to 1. Otherwise, the action refers to a unique
              environment and the first sub-environment of that type that is not
              done is activated.
            - If a sub-environment is currently active, reset the active
              sub-environment to -1. The toggle flag is expected to be 0 and it
              remains 0.
        - Environment actions:
            - Update the active sub-environment as well as the Set state.
            - If ``can_alternate_subenvs`` is True, or if the sub-environment
              becomes done, the toggle flag is set to 0.

        Parameters
        ----------
        action : tuple
            Action to be executed. The input action is global, that is, padded.
        skip_mask_check : bool
            If True, the mask check is skipped. It is also skipped if
            ``self.skip_mask_check`` is True, and for sub-environment actions of
            continuous sub-environments.

        Returns
        -------
        self.state : dict
            The state after executing the action.
        action : tuple
            Action executed.
        valid : bool
            False, if the action is not allowed for the current state. True
            otherwise.

        Raises
        ------
        ValueError
            If ``self.subenvs`` is None.
        """
        # If self.subenvs is None, raise an exception
        if self.subenvs is None:
            raise ValueError(
                "self.subenvs of the SetFlex is None. The subenvs must be set before "
                "developing a trajectory."
            )
        # If done, exit immediately
        if self.done:
            return self.state, action, False
        # Case A: the action is EOS or is an action to toggle a sub-environment
        # (Set meta-actions have -1 as their first element)
        if action[0] == -1:
            assert self._get_toggle_flag(self.state) == 0
            # Skip mask check in pre-step from base environment because the mask would
            # not match the action space
            do_step, _, _ = self._pre_step(
                action,
                skip_mask_check=True,
            )
            # Do mask check of Set actions
            # Note that this relies on the Set actions being placed first in the action
            # space
            if not skip_mask_check and not self.skip_mask_check:
                action_idx = self.action_space.index(action)
                if self._extract_core_mask(
                    self.get_mask_invalid_actions_forward(), idx_unique=-1
                )[action_idx]:
                    do_step = False
            if not do_step:
                return self.state, action, False
            self.n_actions += 1
            # If action is EOS, set done to True and return
            if action == self.eos:
                assert all([env.done for env in self.subenvs])
                self.done = True
                return self.state, action, True
            # Otherwise, it is an action to toggle a sub-environment:
            # - Update the active sub-environment of the parent Set state
            # - Toggle the flag
            # - Return
            toggled_idx = self._depad_action(action)[0]
            if self._get_active_subenv(self.state) == -1:
                if self.can_alternate_subenvs:
                    self._set_active_subenv(toggled_idx)
                    self._set_toggle_flag(1)
                else:
                    # Activate first non-done subenv of the toggled type
                    # NOTE(review): if no such sub-environment exists (only possible
                    # when the mask check is skipped), nothing is activated but the
                    # step is still reported as valid.
                    indices_unique = self._get_unique_indices(self.state)
                    dones = self._get_dones(self.state)
                    for idx, (idx_unique, done) in enumerate(
                        zip(indices_unique, dones)
                    ):
                        if idx_unique == toggled_idx and not done:
                            self._set_active_subenv(idx)
                            break
            else:
                # Deactivate the current subenv
                active_subenv = self._get_active_subenv(self.state)
                if self.can_alternate_subenvs:
                    assert active_subenv == toggled_idx
                else:
                    assert (
                        self._get_unique_idx_of_subenv(active_subenv, self.state)
                        == toggled_idx
                    )
                assert self._get_toggle_flag(self.state) == 0
                self._set_active_subenv(-1)
            return self.state, action, True
        # Case B: the action is an action from a sub-environment
        # Get the sub-environment corresponding to the action and its sub-action
        if self.can_alternate_subenvs:
            assert self._get_toggle_flag(self.state) == 1
        # Get active sub-environment and depad action
        active_subenv = self._get_active_subenv(self.state)
        assert active_subenv != -1
        # The first element of a sub-environment action is its unique env index
        idx_unique = action[0]
        assert self._get_unique_indices(self.state)[active_subenv] == idx_unique
        subenv = self.subenvs[active_subenv]
        action_subenv = self._depad_action(action, idx_unique)
        # Perform pre-step from subenv - if it was done from the Set env there could
        # be a mismatch between mask and action space due to continuous subenvs.
        action_to_check = subenv.action2representative(action_subenv)
        # Skip mask check if active sub-environment is continuous
        if subenv.continuous:
            skip_mask_check = True
        do_step, _, _ = subenv._pre_step(
            action_to_check,
            skip_mask_check=(skip_mask_check or self.skip_mask_check),
        )
        if not do_step:
            return self.state, action, False
        # Call step of current sub-environment
        _, _, valid = subenv.step(action_subenv)
        # If action is invalid, exit immediately. Otherwise increment actions and go on
        if not valid:
            return self.state, action, False
        self.n_actions += 1
        # Update (global) Set state, apply constraints and return
        # Note that the unique indices are not changed by performing an action
        self._set_substate(active_subenv, subenv.state)
        self._set_subdone(active_subenv, subenv.done)
        self._set_active_subenv(active_subenv)
        # In the non-alternating variant the flag stays 0; in the alternating
        # variant a toggle (deactivation) is expected next.
        if self.can_alternate_subenvs or subenv.done:
            self._set_toggle_flag(0)
        self._apply_constraints(state=self.state, action=action, is_backward=False)
        return self.state, action, valid
# ---- Backward transitions ----
    def step_backwards(
        self, action: Tuple, skip_mask_check: bool = False
    ) -> Tuple[Dict, Tuple, bool]:
        """
        Executes a backward step given an action.

        Actions may be either sub-environment actions or Set (meta) actions. In the
        former case, the action is performed (backwards) by the active
        sub-environment and then the Set state is updated accordingly. In the
        latter case, no sub-environment is involved: either the global EOS is
        undone or only the meta-data of the state changes (active sub-environment
        and toggle flag).

        Because the same action may correspond to multiple sub-environments, the
        action is always performed on the active sub-environment.

        - Toggle actions:
            - If no sub-environment is currently active, activate the corresponding
              sub-environment. If ``can_alternate_subenvs`` is False, the action
              refers to a unique environment: the done sub-states of that unique
              environment are permuted and the one in the last position is
              activated.
            - If a sub-environment is currently active, reset the active
              sub-environment to -1.
            - The toggle flag is set to 0.
        - Environment actions:
            - Update the active sub-environment as well as the Set state.
            - If ``can_alternate_subenvs`` is True, the toggle flag is set to 1.

        Parameters
        ----------
        action : tuple
            Action to be executed. The input action is global, that is, padded.
        skip_mask_check : bool
            If True, the mask check is skipped. It is also skipped if
            ``self.skip_mask_check`` is True, and for sub-environment actions of
            continuous sub-environments.

        Returns
        -------
        self.state : dict
            The state after executing the action.
        action : tuple
            Action executed.
        valid : bool
            False, if the action is not allowed for the current state. True
            otherwise.

        Raises
        ------
        ValueError
            If ``self.subenvs`` is None.
        """
        # If self.subenvs is None, raise an exception
        if self.subenvs is None:
            raise ValueError(
                "self.subenvs of the SetFlex is None. The subenvs must be set before "
                "developing a trajectory."
            )
        # Case A: the action is EOS or is an action to toggle a sub-environment
        if action[0] == -1:
            # If can_alternate_subenvs is True and there is an active sub-environment
            # but the toggle flag is 0, the action cannot be a Set action, thus it is
            # invalid
            if (
                self.can_alternate_subenvs
                and self._get_active_subenv() != -1
                and self._get_toggle_flag() == 0
            ):
                return self.state, action, False
            # Skip mask check in pre-step from base environment because the mask would
            # not match the action space
            do_step, _, _ = self._pre_step(
                action,
                backward=True,
                skip_mask_check=True,
            )
            # Do mask check of Set actions
            # Note that this relies on the Set actions being placed first in the action
            # space
            if not skip_mask_check and not self.skip_mask_check:
                action_idx = self.action_space.index(action)
                if self._extract_core_mask(
                    self.get_mask_invalid_actions_backward(), idx_unique=-1
                )[action_idx]:
                    do_step = False
            if not do_step:
                return self.state, action, False
            self.n_actions += 1
            # If action is EOS, set done to False and return
            # NOTE(review): if the mask check is skipped and the Set is not done,
            # the assertion below fails instead of returning an invalid step.
            if action == self.eos:
                assert self.done
                assert all([env.done for env in self.subenvs])
                self.done = False
                return self.state, action, True
            # Otherwise, it is an action to toggle a sub-environment:
            # - Update the active sub-environment of the parent Set state
            # - Toggle the flag
            # - Return
            toggled_idx = self._depad_action(action)[0]
            if self._get_active_subenv(self.state) == -1:
                if self.can_alternate_subenvs:
                    self._set_active_subenv(toggled_idx)
                    self._set_toggle_flag(0)
                else:
                    # Permute the done subenvironments of the selected index, and
                    # activate the one in the last position
                    self.state, indices_relevant = self._permute_substates(
                        toggled_idx, self.state, done_only=True
                    )
                    self._set_active_subenv(indices_relevant[-1])
            else:
                # Toggle the current subenv
                active_subenv = self._get_active_subenv(self.state)
                if self.can_alternate_subenvs:
                    assert active_subenv == toggled_idx
                    assert self._get_toggle_flag(self.state) == 1
                    self._set_toggle_flag(0)
                else:
                    assert (
                        self._get_unique_idx_of_subenv(active_subenv, self.state)
                        == toggled_idx
                    )
                self._set_active_subenv(-1)
            return self.state, action, True
        # Case B: the action is an action from a sub-environment
        # If can_alternate_subenvs is True and the toggle flag is not 0, then it is an
        # invalid action
        if self.can_alternate_subenvs and not self._get_toggle_flag(self.state) == 0:
            return self.state, action, False
        # Get active sub-environment and depad action
        active_subenv = self._get_active_subenv(self.state)
        assert active_subenv != -1
        # The first element of a sub-environment action is its unique env index
        idx_unique = action[0]
        assert self._get_unique_indices(self.state)[active_subenv] == idx_unique
        subenv = self.subenvs[active_subenv]
        action_subenv = self._depad_action(action, idx_unique)
        # Perform pre-step from subenv - if it was done from the Set env there could
        # be a mismatch between mask and action space due to continuous subenvs.
        action_to_check = subenv.action2representative(action_subenv)
        # Skip mask check if active sub-environment is continuous
        if subenv.continuous:
            skip_mask_check = True
        do_step, _, _ = subenv._pre_step(
            action_to_check,
            backward=True,
            skip_mask_check=(skip_mask_check or self.skip_mask_check),
        )
        if not do_step:
            return self.state, action, False
        # Call step of current sub-environment
        _, _, valid = subenv.step_backwards(action_subenv)
        # If action is invalid, exit immediately. Otherwise increment actions and go on
        if not valid:
            return self.state, action, False
        self.n_actions += 1
        # Update (global) Set state, apply constraints and return
        self._set_substate(active_subenv, subenv.state)
        self._set_subdone(active_subenv, subenv.done)
        self._set_active_subenv(active_subenv)
        # In the alternating variant, a backward sub-environment action must be
        # followed by a toggle (deactivation) of the sub-environment.
        if self.can_alternate_subenvs:
            self._set_toggle_flag(1)
        self._apply_constraints(state=self.state, action=action, is_backward=True)
        return self.state, action, valid
# TODO: review if adding constraints is necessary if state is not None and
# different to self.state
def get_parents(
    self,
    state: Optional[Dict] = None,
    done: Optional[bool] = None,
    action: Optional[Tuple] = None,
) -> Tuple[List, List]:
    """
    Determines all parents and actions that lead to state.

    All lookups are done on the input ``state`` (not on ``self.state``), so that
    the method is correct for states other than the current one.

    Parameters
    ----------
    state : dict
        State in environment format. If None, self.state is used.
    done : bool
        Whether the trajectory is done. If None, self.done is used.
    action : tuple
        Ignored.

    Returns
    -------
    parents : list
        List of parents in state format
    actions : list
        List of actions that lead to state for each parent in parents
    """
    state = self._get_state(state)
    done = self._get_done(done)

    # If done is True, the only parent is the state itself with action EOS.
    if done:
        return [state], [self.eos]

    parents = []
    actions = []

    # Get active sub-environment and flag
    active_subenv = self._get_active_subenv(state)
    toggle_flag = self._get_toggle_flag(state)
    dones = self._get_dones(state)
    subenv = None
    state_subenv = None

    # Establish the case based on the active sub-environment, the toggle flag and
    # the done flags
    case_a = case_b = case_c = case_d = case_e = case_f = False
    if active_subenv == -1:
        # Case A: no sub-environment is active
        assert toggle_flag == 0
        case_a = True
    elif self.can_alternate_subenvs:
        if toggle_flag == 0:
            # Case B: alternating variant, toggle flag 0: sub-environment actions
            # (in the backward sense) are valid.
            case_b = True
        else:
            # Case C: alternating variant, toggle flag 1: the corresponding toggle
            # action (deactivate) is valid.
            case_c = True
    else:
        if dones[active_subenv] == 1:
            # Case D: non-alternating variant, the active sub-environment is done:
            # it has just been activated (in the backward sense) and the only
            # valid action is the EOS of the active sub-environment.
            case_d = True
        else:
            subenv = self._get_unique_env_of_subenv(active_subenv, state)
            state_subenv = self._get_substate(state, active_subenv)
            if subenv.is_source(state_subenv):
                # Case E: non-alternating variant, the active sub-environment is
                # at its source state: the only valid action is to toggle it.
                case_e = True
            else:
                # Case F: non-alternating variant, the active sub-environment is
                # neither at the source nor done: sub-environment actions (in the
                # backward sense) are valid.
                case_f = True

    if case_a:
        # Parents share the sub-states but have one active sub-environment.
        # If sub-environments can alternate, any sub-environment can be active in
        # the parent, unless it is at the source state and not done.
        # If sub-environments cannot alternate, only the last done sub-environment
        # of each unique environment can be active in the parents.
        assert toggle_flag == 0
        indices_unique = self._get_unique_indices(state)
        indices_unique_seen = set()
        for idx, (idx_unique, is_done) in reversed(
            list(enumerate(zip(indices_unique, dones)))
        ):
            # Skip non-present sub-environments before querying their unique
            # environment or sub-state, which do not exist for them.
            if idx_unique == -1:
                continue
            if self.can_alternate_subenvs:
                # Skip if the subenv is at the source and is not done
                if not is_done and self._get_env_unique(idx_unique).is_source(
                    self._get_substate(state, idx)
                ):
                    continue
            elif not is_done:
                continue
            # Skip if the unique index has been already added
            if idx_unique in indices_unique_seen:
                continue
            # Add parent and action
            parent = copy(state)
            if self.can_alternate_subenvs:
                actions.append(self._pad_action((idx,), -1))
                idx_active_subenv = idx
            else:
                # Permute the done subenvironments of the selected index, and
                # activate the one in the last position
                parent, indices_relevant = self._permute_substates(
                    idx_unique, parent, done_only=True
                )
                idx_active_subenv = indices_relevant[-1]
                actions.append(self._pad_action((idx_unique,), -1))
                indices_unique_seen.add(idx_unique)
            parents.append(self._set_active_subenv(idx_active_subenv, parent))
    elif case_c or case_e:
        # Only the toggle action of the active sub-environment is valid: the only
        # parent is the same state without active sub-environment (and toggle
        # flag 0 in the alternating variant).
        parent = copy(state)
        parent = self._set_active_subenv(-1, parent)
        if self.can_alternate_subenvs:
            parent = self._set_toggle_flag(0, parent)
            idx_action = active_subenv
        else:
            idx_action = self._get_unique_idx_of_subenv(active_subenv, state)
        parents.append(parent)
        actions.append(self._pad_action((idx_action,), -1))
    elif case_b or case_d or case_f:
        # Sub-environment actions are valid: the parents are determined by the
        # parents of the active sub-environment.
        assert toggle_flag == 0
        if subenv is None or state_subenv is None:
            subenv = self._get_unique_env_of_subenv(active_subenv, state)
            state_subenv = self._get_substate(state, active_subenv)
        done_subenv = bool(dones[active_subenv])
        parents_subenv, parent_actions_subenv = subenv.get_parents(
            state_subenv, done_subenv
        )
        idx_unique_active = self._get_unique_idx_of_subenv(active_subenv, state)
        for p, p_a in zip(parents_subenv, parent_actions_subenv):
            parent = copy(state)
            if self.can_alternate_subenvs:
                parent = self._set_toggle_flag(1, parent)
            parent = self._set_substate(active_subenv, p, parent)
            if p_a == subenv.eos:
                parent = self._set_subdone(active_subenv, False, parent)
            parents.append(parent)
            actions.append(self._pad_action(p_a, idx_unique_active))
    return parents, actions
def _permute_substates(
    self, idx_unique: int, state: Optional[Dict] = None, done_only: bool = True
) -> Tuple[Dict, Tuple[int, ...]]:
    """
    Permutes the sub-states of a given unique environment.

    The permutation is reflected only in the list stored in the key ``_keys`` of
    the dictionary, which contains the actual keys where the states are stored.
    Permuting the values of this list is more efficient than permuting the actual
    sub-states.

    If ``done_only`` is True (default), then only the done sub-environments are
    permuted. Otherwise, all sub-environments of the specified unique environment
    are permuted.

    Parameters
    ----------
    idx_unique : int
        The index of the unique environment type whose instances should be
        permuted.
    state : dict
        A state of the Set environment. If None, self.state is used.
    done_only : bool
        Whether to permute only the done sub-environments.

    Returns
    -------
    state : dict
        The updated state.
    indices_relevant : tuple of int
        The indices of the relevant elements whose keys are permuted, which
        correspond to the indices of the specified type and (if ``done_only`` is
        True) are also done. It is empty if there is no relevant element. Note
        that the returned indices are the actual indices and not the substate
        keys. These indices are not permuted.
    """
    state = self._get_state(state)
    unique_indices = self._get_unique_indices(state)
    # Work on a copy of the keys so that a list shared with another state (for
    # example after a shallow copy) is not modified in place.
    keys = list(self._get_keys(state))
    if done_only:
        dones = self._get_dones(state)
    else:
        dones = [True] * len(keys)

    # Find all (done) sub-environments of the relevant unique environment
    relevant = [
        (idx, key)
        for idx, (idx_u, key, done) in enumerate(zip(unique_indices, keys, dones))
        if idx_u == idx_unique and done
    ]
    indices_relevant = tuple(idx for idx, _ in relevant)

    # Nothing to permute with fewer than two relevant elements. Handling the empty
    # case here avoids unpacking an empty zip(), which raises ValueError.
    if len(indices_relevant) <= 1:
        return state, indices_relevant

    # Permute relevant keys and update keys of state
    keys_to_permute = [key for _, key in relevant]
    random.shuffle(keys_to_permute)
    for idx, key in zip(indices_relevant, keys_to_permute):
        keys[idx] = key
    state = self._set_keys(keys, state)
    return state, indices_relevant
def sample_actions_batch(
    self,
    policy_outputs: TensorType["n_states", "policy_output_dim"],
    mask: Optional[TensorType["n_states", "mask_dim"]] = None,
    states_from: List = None,
    is_backward: Optional[bool] = False,
    random_action_prob: Optional[float] = 0.0,
    temperature_logits: Optional[float] = 1.0,
) -> List[Tuple]:
    """
    Samples a batch of actions from a batch of policy outputs.

    This method calls the sample_actions_batch() method of the sub-environment
    corresponding to each state in the batch, or samples the Set actions (toggle a
    sub-environment or EOS) for the states whose mask has no active
    sub-environment.

    Note that in order to call sample_actions_batch() of the sub-environments, we
    need to first extract the part of the policy outputs, the masks and the states
    that correspond to the sub-environment.

    Parameters
    ----------
    policy_outputs : tensor
        A batch of policy outputs of the Set environment.
    mask : tensor
        The batch of masks. The first ``self.n_toggle_actions`` columns are the
        one-hot prefix of the active sub-environment (or unique environment).
        Although it defaults to None, a mask is required.
    states_from : list
        The states originating the actions, in environment format.
    is_backward : bool
        True if the actions are backward, False if forward.
    random_action_prob : float
        Forwarded unchanged to the underlying samplers.
    temperature_logits : float
        Forwarded unchanged to the underlying samplers.

    Returns
    -------
    list
        The sampled (padded) actions, one per state, in the order of the batch.
    """
    # States whose mask has a non-empty one-hot prefix have an active
    # sub-environment; the rest can only perform Set actions.
    is_active = torch.any(mask[:, : self.n_toggle_actions], axis=1)
    is_set = torch.logical_not(is_active)

    # Sample Set actions (to toggle a sub-environment or EOS).
    # Note that this relies on the Set actions being placed first in the action
    # space, since the super() method will select the actions by indexing the
    # action space, starting from 0.
    actions_set = []
    if any(is_set):
        actions_set = super().sample_actions_batch(
            self._get_policy_outputs_of_set_actions(policy_outputs[is_set]),
            self._extract_core_mask(mask[is_set], idx_unique=-1),
            None,
            is_backward,
            random_action_prob,
            temperature_logits,
        )

    # Get the active sub-environment of each mask from the one-hot prefix
    indices_active = torch.where(mask[is_active, : self.n_toggle_actions])[1]

    # If there are no states with active sub-environments (including an empty
    # batch), return here
    if len(indices_active) == 0:
        assert len(actions_set) == policy_outputs.shape[0]
        return actions_set

    # states_dict maps each unique environment index to the list of sub-states
    # (only the part corresponding to the sub-environment) of the batch states
    # whose active sub-environment is of that unique environment.
    states_dict = {idx: [] for idx in range(self.n_unique_envs)}
    indices_unique_int = []
    prefix_indices = iter(indices_active.tolist())
    for state, active in zip(states_from, is_active):
        if not active:
            continue
        idx_prefix = next(prefix_indices)
        if self.can_alternate_subenvs:
            # The prefix encodes the sub-environment index
            idx_unique = self._get_unique_indices(state)[idx_prefix]
        else:
            # The prefix encodes the unique environment index
            idx_unique = idx_prefix
        states_dict[idx_unique].append(self._get_substate(state))
        indices_unique_int.append(idx_unique)
    indices_unique = tlong(indices_unique_int, device=self.device)

    # Sample actions from each unique environment
    actions_subenvs_dict = {}
    for idx, subenv in enumerate(self.envs_unique):
        indices_unique_mask = indices_unique == idx
        if not torch.any(indices_unique_mask):
            continue
        actions_subenvs_dict[idx] = subenv.sample_actions_batch(
            self._get_policy_outputs_of_env_unique(
                policy_outputs[is_active][indices_unique_mask], idx
            ),
            self._extract_core_mask(
                mask[is_active][indices_unique_mask], idx_unique=idx
            ),
            states_dict[idx],
            is_backward,
            random_action_prob,
            temperature_logits,
        )

    # Stitch all environment actions in the right order, with the right padding.
    # Iterators avoid the quadratic cost of repeated list.pop(0).
    subenv_action_iters = {
        idx: iter(actions_unique) for idx, actions_unique in actions_subenvs_dict.items()
    }
    actions_subenvs = iter(
        [
            self._pad_action(next(subenv_action_iters[idx]), idx)
            for idx in indices_unique_int
        ]
    )

    # Stitch all actions, both Set actions and sub-environment actions
    actions_set_iter = iter(actions_set)
    return [
        next(actions_subenvs) if action_is_from_subenv else next(actions_set_iter)
        for action_is_from_subenv in is_active
    ]
def get_logprobs(
    self,
    policy_outputs: TensorType["n_states", "policy_output_dim"],
    actions: Union[List, TensorType["n_states", "action_dim"]],
    mask: TensorType["n_states", "mask_dim"],
    states_from: List,
    is_backward: bool,
) -> TensorType["batch_size"]:
    """
    Computes log probabilities of actions given policy outputs and actions.

    Parameters
    ----------
    policy_outputs : tensor
        The output of the GFlowNet policy model.
    actions : list or tensor
        The actions (global) from each state in the batch for which to compute the
        log probability.
    mask : tensor
        The mask containing information about invalid actions and special cases.
    states_from : list
        The states originating the actions, in environment format.
    is_backward : bool
        True if the actions are backward, False if the actions are forward.

    Returns
    -------
    tensor
        The log probability of each action in the batch.
    """
    actions = tfloat(actions, float_type=self.float, device=self.device)
    n_states = policy_outputs.shape[0]
    logprobs = torch.empty(n_states, dtype=self.float, device=self.device)

    # Get the states in the batch with and without an active sub-environment
    is_active = torch.any(mask[:, : self.n_toggle_actions], axis=1)
    is_set = torch.logical_not(is_active)

    # Get logprobs of Set actions (to toggle a sub-environment or EOS).
    # Note that this relies on the Set actions being placed first in the action
    # space, since the super() method will select the actions by indexing the
    # action space, starting from 0. states_from is ignored so can be None.
    if any(is_set):
        logprobs_set = super().get_logprobs(
            self._get_policy_outputs_of_set_actions(policy_outputs[is_set]),
            actions[is_set],
            self._extract_core_mask(mask[is_set], idx_unique=-1),
            None,
            is_backward,
        )
        # Apply permutation correction for backward toggle actions with a
        # stochastic component (not EOS, not when can_alternate_subenvs is True).
        # Only toggles from states without an active sub-environment permute the
        # done sub-states (see step_backwards); toggles that deactivate the active
        # sub-environment are deterministic and must not be corrected.
        if is_backward and not self.can_alternate_subenvs:
            eos_tensor = tfloat(self.eos, float_type=self.float, device=self.device)
            is_stochastic = torch.zeros_like(is_set)
            is_stochastic[is_set] = torch.any(actions[is_set] != eos_tensor, dim=1)
            has_no_active_subenv = torch.tensor(
                [self._get_active_subenv(state) == -1 for state in states_from],
                dtype=torch.bool,
                device=is_set.device,
            )
            is_stochastic = torch.logical_and(is_stochastic, has_no_active_subenv)
            if any(is_stochastic):
                logprobs_set[
                    is_stochastic[is_set]
                ] -= self._get_logprobs_of_permutations(
                    actions[is_stochastic],
                    [
                        state
                        for state, do_keep in zip(states_from, is_stochastic)
                        if do_keep
                    ],
                )
        logprobs[is_set] = logprobs_set

    # If there are no states with active sub-environments (including an empty
    # batch), return here
    if not any(is_active):
        return logprobs

    # Get the active sub-environment of each mask from the one-hot prefix
    indices_active = torch.where(mask[is_active, : self.n_toggle_actions])[1]

    # states_dict maps each unique environment index to the list of sub-states
    # (only the part corresponding to the sub-environment) of the batch states
    # whose active sub-environment is of that unique environment.
    states_dict = {idx: [] for idx in range(self.n_unique_envs)}
    indices_unique_int = []
    prefix_indices = iter(indices_active.tolist())
    for state, active in zip(states_from, is_active):
        if not active:
            continue
        idx_prefix = next(prefix_indices)
        if self.can_alternate_subenvs:
            # The prefix encodes the sub-environment index
            idx_unique = self._get_unique_indices(state)[idx_prefix]
        else:
            # The prefix encodes the unique environment index
            idx_unique = idx_prefix
        states_dict[idx_unique].append(self._get_substate(state))
        indices_unique_int.append(idx_unique)
    indices_unique = tlong(indices_unique_int, device=self.device)

    # Compute logprobs from each unique environment
    logprobs_subenvs = torch.empty(
        len(indices_active), dtype=self.float, device=self.device
    )
    for idx, subenv in enumerate(self.envs_unique):
        indices_unique_mask = indices_unique == idx
        if not torch.any(indices_unique_mask):
            continue
        logprobs_subenvs[indices_unique_mask] = subenv.get_logprobs(
            self._get_policy_outputs_of_env_unique(
                policy_outputs[is_active][indices_unique_mask], idx
            ),
            self._depad_action_batch(actions[is_active][indices_unique_mask, :], idx),
            self._extract_core_mask(
                mask[is_active][indices_unique_mask], idx_unique=idx
            ),
            states_dict[idx],
            is_backward,
        )

    # Stitch logprobs of Set actions and environment actions
    logprobs[is_active] = logprobs_subenvs
    return logprobs
def _get_logprobs_of_permutations(
    self,
    actions: TensorType["n_set_states", "action_dim"],
    states: List,
) -> TensorType["n_set_states"]:
    r"""
    Calculates the log-number of distinct permutations involved in transitions
    with a stochastic permutation component.

    Certain transitions in the Set have a stochastic component. In particular, in
    order to account for the fact that permutations of the substates correspond to
    the same state, some transitions obtain one of these permutations with uniform
    probability. The log-probability of such a transition is $-\log(P)$, where $P$
    is the number of distinct permutations. This method returns $\log(P)$, which
    callers subtract from the log-probability of the action.

    If all substates involved in the permutation are identical, $P = 1$. If all
    are different, $P = n!$, where $n$ is the number of substates. In general,
    $P = n! / (m_1! \cdot m_2! \cdot \ldots)$, where $m_i$ are the multiplicities
    of the groups of identical substates.

    Parameters
    ----------
    actions : tensor
        The actions corresponding to the set states. Toggle actions have the format
        (-1, idx_unique, 0, ...). Shape: (n_set_states, action_dim)
    states : list
        A subset of states from the batch, corresponding to the states from which
        the actions are initiated.

    Returns
    -------
    tensor
        $\log(P)$ for each of the input states (0 where fewer than two done
        instances are involved). Shape: (len(states),)
    """
    logprobs = torch.zeros(len(states), dtype=self.float, device=self.device)
    for idx_state, (state, action) in enumerate(zip(states, actions)):
        # Get the unique environment index from the toggle action
        idx_unique = int(action[1])
        # Get unique indices and done flags from state
        unique_indices = self._get_unique_indices(state)
        dones = self._get_dones(state)
        # Find all done instances of this unique environment type
        indices_relevant = [
            idx
            for idx, (idx_u, done) in enumerate(zip(unique_indices, dones))
            if idx_u == idx_unique and done
        ]
        n_done = len(indices_relevant)
        # No correction needed if fewer than 2 done instances
        if n_done <= 1:
            continue
        substates = [self._get_substate(state, idx) for idx in indices_relevant]

        # Group identical substates: each substate is compared with one
        # representative of each group found so far and joins the first group it
        # matches, or starts a new group. This counts every multiplicity fully,
        # including groups of three or more identical substates.
        representatives = []
        multiplicities = []
        for substate in substates:
            for idx_group, representative in enumerate(representatives):
                # If substates are dictionaries and have the key "_keys", compare
                # using the Set's equal(). Otherwise, use the parent's equal().
                if (
                    isinstance(representative, dict)
                    and "_keys" in representative
                    and isinstance(substate, dict)
                    and "_keys" in substate
                ):
                    is_match = self.equal(representative, substate)
                else:
                    is_match = GFlowNetEnv.equal(representative, substate)
                if is_match:
                    multiplicities[idx_group] += 1
                    break
            else:
                representatives.append(substate)
                multiplicities.append(1)

        # Number of distinct permutations: n! / (m1! * m2! * ...)
        # In log space: log(n!) - sum(log(mi!)), with log(k!) = lgamma(k + 1)
        log_n_factorial = torch.lgamma(
            torch.tensor(n_done + 1, dtype=self.float, device=self.device)
        )
        multiplicities = torch.tensor(
            multiplicities, dtype=self.float, device=self.device
        )
        log_denominator = torch.sum(torch.lgamma(multiplicities + 1))
        logprobs[idx_state] = log_n_factorial - log_denominator
    return logprobs


def _compute_mask_dim(self) -> int:
    """
    Calculates the mask dimensionality of the global Set environment.

    The mask consists of:
        - A one-hot encoding of the index of the active sub-environment or unique
          environment.
        - Actual (main) mask of invalid actions:
            - The mask of the Set actions (activate a sub-environment and EOS), OR
            - The mask of the active sub-environment.

    Therefore, the dimensionality is ``self.n_toggle_actions`` (prefix) plus the
    largest of the masks of the unique environments and the number of Set actions
    (``self.n_toggle_actions + 1``).

    Returns
    -------
    int
        The number of elements in the Set masks.
    """
    mask_dim_subenvs = [subenv.mask_dim for subenv in self.envs_unique]
    mask_dim_set_actions = self.n_toggle_actions + 1
    return max(mask_dim_subenvs + [mask_dim_set_actions]) + self.n_toggle_actions


def _get_toggle_flag(self, state: Optional[Dict] = None) -> int:
    """
    Returns the value of the toggle flag, stored in ``state["_toggle"]``.

    Parameters
    ----------
    state : dict
        A state of the parent Set environment. If None, ``self.state`` is used.

    Returns
    -------
    int
        The toggle flag of the state.
    """
    state = self._get_state(state)
    return state["_toggle"]


def _set_toggle_flag(self, toggle_flag: int, state: Optional[Dict] = None) -> Dict:
    """
    Sets the toggle flag in ``state["_toggle"]``.

    Parameters
    ----------
    toggle_flag : int
        Value of the toggle flag to set in the state. Must be 0 or 1.
    state : dict
        A state of the parent Set environment. If None, ``self.state`` is used.

    Returns
    -------
    dict
        The updated Set state (modified in place).
    """
    assert toggle_flag in [0, 1]
    state = self._get_state(state)
    state["_toggle"] = toggle_flag
    return state


def _get_keys(self, state: Optional[Dict] = None) -> list:
    """
    Returns the list of sub-state keys, stored in ``state["_keys"]``.

    Parameters
    ----------
    state : dict
        A state of the parent Set environment. If None, ``self.state`` is used.

    Returns
    -------
    list
        The list of state keys in the state (not a copy).
    """
    state = self._get_state(state)
    return state["_keys"]


def _set_keys(self, keys: list, state: Optional[Dict] = None) -> Dict:
    """
    Sets the list of sub-state keys in ``state["_keys"]``.

    Parameters
    ----------
    keys : list
        The list of state keys to set in the state. It is expected to be a
        permutation of the keys of the state (NOTE(review): previously documented
        as the integers from 0 to ``self.max_elements``; confirm the bound).
    state : dict
        A state of the Set environment. If None, ``self.state`` is used.

    Returns
    -------
    dict
        The updated Set state (modified in place).
    """
    state = self._get_state(state)
    state["_keys"] = keys
    return state
def action2representative(self, action: Tuple) -> Tuple:
    """
    Maps a (padded) Set action to its representative in the action space.

    Set actions (toggle and EOS), identified by a -1 in the first position, are
    returned unchanged. For sub-environment actions, the part of the action that
    belongs to the sub-environment is replaced by the representative given by the
    corresponding unique environment, while the leading index that identifies the
    unique environment is preserved.

    Parameters
    ----------
    action : tuple
        An action of the Set environment (padded).

    Returns
    -------
    tuple
        A representative of the action, re-padded as a Set action that should be
        in the action space.
    """
    idx_unique = action[0]
    if idx_unique != -1:
        env_unique = self._get_env_unique(idx_unique)
        representative_subenv = env_unique.action2representative(
            self._depad_action(action, idx_unique)
        )
        return self._pad_action(representative_subenv, idx_unique)
    # Set actions are their own representative
    return action
# TODO: consider moving to Composite
def _format_mask(self, mask: List[bool], active_subenv: int) -> List[bool]:
    r"""
    Applies formatting to the mask of a sub-environment.

    The output format is the mask of the input sub-environment, preceded by a
    one-hot encoding of the index of the active sub-environment (if
    ``can_alternate_subenvs`` is True) or of its unique environment (otherwise),
    and padded with False up to :py:const:`self.mask_dim`. If no sub-environment
    is active (``active_subenv`` is -1), the prefix is all False.

    Parameters
    ----------
    mask : list
        The mask of a sub-environment, or the mask of the Set actions.
    active_subenv : int
        The index of the active sub-environment, or -1 if no subenv is active.

    Returns
    -------
    list
        The formatted mask, of length ``self.mask_dim``.
    """
    active_subenv_onehot = [False] * self.n_toggle_actions
    if active_subenv != -1:
        if self.can_alternate_subenvs:
            active_subenv_onehot[active_subenv] = True
        else:
            # NOTE(review): no state is passed, so the unique index is looked up
            # with the helper's default state (presumably self.state).
            active_subenv_onehot[self._get_unique_idx_of_subenv(active_subenv)] = (
                True
            )
    mask = active_subenv_onehot + mask
    padding = [False] * (self.mask_dim - len(mask))
    return mask + padding


def _extract_core_mask(
    self,
    mask: Union[List, TensorType["batch_size", "mask_dim"]],
    idx_unique: int,
) -> Union[List, TensorType["batch_size", "mask_dim"]]:
    """
    Extracts the core part of the mask, that is without prefix and padding.

    Parameters
    ----------
    mask : list or tensor
        The mask of a state (list) or a batch of masks (tensor). In the latter
        case, it is assumed that all states in the batch of masks correspond to
        the same unique environment or are all in a state with only set actions
        valid, that is idx_unique is -1.
    idx_unique : int
        The index of the unique environment or -1 to indicate that the mask
        corresponds to set actions (toggle and EOS).

    Returns
    -------
    list or tensor
        The core mask: the ``self.n_toggle_actions + 1`` Set-action entries if
        ``idx_unique`` is -1, or the ``mask_dim`` entries of the unique
        environment otherwise. Tensors are sliced along the last dimension.
    """
    if idx_unique == -1:
        mask_dim = self.n_toggle_actions + 1
    else:
        mask_dim = self._get_env_unique(idx_unique).mask_dim
    if isinstance(mask, list):
        return mask[self.n_toggle_actions : self.n_toggle_actions + mask_dim]
    assert torch.is_tensor(mask)
    return mask[:, self.n_toggle_actions : self.n_toggle_actions + mask_dim]


# TODO: review if adding constraints is necessary if state is not None and
# different to self.state
def get_valid_actions(
    self,
    mask: Optional[List[bool]] = None,
    state: Optional[Dict] = None,
    done: Optional[bool] = None,
    backward: Optional[bool] = False,
) -> List[Tuple]:
    """
    Returns the list of non-invalid (valid, for short) actions according to the
    mask of invalid actions.

    This method is overridden because the mask of a Set of environments does not
    cover the entire action space, but only the relevant sub-environment or the
    toggle actions, depending on the state. Therefore, this method calls the
    get_valid_actions() method of the active sub-environment or retrieves the
    valid toggle actions and returns the padded actions.

    Parameters
    ----------
    mask : list
        The (formatted) mask of invalid actions. If None, it is computed from the
        state.
    state : dict
        A state of the Set environment. If None, ``self.state`` is used.
    done : bool
        Whether the trajectory is done. If None, ``self.done`` is used.
    backward : bool
        True to get the valid backward actions, False for forward actions.

    Returns
    -------
    list
        The valid (padded) actions.
    """
    state = self._get_state(state)
    done = self._get_done(done)
    active_subenv = self._get_active_subenv(state)
    if mask is None:
        mask = self.get_mask(state, done, backward)

    # An all-False one-hot prefix means that only Set actions are valid; the
    # active sub-environment of the state is then ignored.
    if not any(mask[: self.n_toggle_actions]):
        active_subenv = -1
        idx_unique = -1
    else:
        # Otherwise, get index of unique environment from state
        idx_unique = self._get_unique_indices(state)[active_subenv]

    # Extract core mask
    mask = self._extract_core_mask(mask, idx_unique)

    if active_subenv == -1:
        # Case A: the only valid actions are Set actions.
        # Note that this relies on the Set actions being placed first in the action
        # space, since the super() method will select the actions by indexing the
        # action space, starting from 0.
        return super().get_valid_actions(mask, state, done, backward)

    # Case B: the only valid actions are sub-environment actions, which are
    # retrieved from the active sub-environment and padded before returning them.
    subenv = self._get_unique_env_of_subenv(active_subenv, state)
    state_subenv = self._get_substate(state, active_subenv)
    done_subenv = bool(self._get_dones(state)[active_subenv])
    return [
        self._pad_action(action, idx_unique)
        for action in subenv.get_valid_actions(mask, state_subenv, done_subenv, backward)
    ]
def get_policy_output(self, params: list[dict]) -> TensorType["policy_output_dim"]:
    """
    Defines the structure of the output of the policy model.

    This method is overridden to add the policy outputs corresponding to the Set
    actions. The result is the policy outputs of the unique environments (obtained
    from the parent's method) followed by ``self.n_toggle_actions + 1`` entries for
    the Set actions (actions to activate a sub-environment and EOS), all set to 1.

    Parameters
    ----------
    params : list
        A list of distribution parameters. This list has as many elements as there
        are unique environments, since all sub-environments of the same environment
        type are expected to be identical.

    Returns
    -------
    tensor
        The concatenated policy output.
    """
    outputs_unique_envs = super().get_policy_output(params)
    n_set_actions = self.n_toggle_actions + 1
    outputs_set_actions = torch.ones(
        n_set_actions, dtype=self.float, device=self.device
    )
    return torch.cat([outputs_unique_envs, outputs_set_actions])
def _get_policy_outputs_of_set_actions(
    self, policy_outputs: TensorType["n_states", "policy_output_dim"]
):
    """
    Returns the columns of a batch of policy outputs that correspond to the Set
    actions, that is, the toggle actions and EOS.

    The policy outputs of the Set actions are assumed to be stored in the last
    ``self.n_toggle_actions + 1`` columns of the input tensor.

    Parameters
    ----------
    policy_outputs : tensor
        A 2D tensor with one row of policy outputs per state.

    Returns
    -------
    tensor
        The last ``self.n_toggle_actions + 1`` columns of ``policy_outputs``.
    """
    n_set_columns = self.n_toggle_actions + 1
    return policy_outputs[:, -n_set_columns:]
def is_source(self, state: Optional[Dict] = None) -> bool:
    """
    Checks whether a state is the source state of the environment.

    The state checked is the environment's state, or ``state`` if it is not None
    (resolved with ``self._get_state()``). This override checks the cheap meta-data
    first and returns False as soon as one of them differs from the source's. It
    then checks every substate against the source of its corresponding unique
    environment.

    Parameters
    ----------
    state : dict
        None, or a state in environment format.

    Returns
    -------
    bool
        True if no sub-environment is active, the toggle flag is 0, no
        sub-environment is done and every substate is a source state; False
        otherwise.
    """
    state = self._get_state(state)
    substates = self._get_substates(state)
    n_subenvs = len(substates)

    # Meta-data of the source: no active sub-environment and toggle flag unset.
    if self._get_active_subenv(state) != -1 or self._get_toggle_flag(state) != 0:
        return False
    # At the source, none of the sub-environments is done.
    if self._get_dones(state)[:n_subenvs] != [0] * n_subenvs:
        return False
    # Each substate must be the source of its own environment.
    return all(
        self._get_unique_env_of_subenv(idx_subenv, state).is_source(substate)
        for idx_subenv, substate in enumerate(substates)
    )
def equal(self, state_x: Dict, state_y: Dict) -> bool:
    """
    Checks whether the two input Set states are equal.

    This method is overridden so that states with permuted substates are considered
    equal when the permutations are equivalent. Substates are not permuted
    directly. Instead, the list of keys in ``state["_keys"]`` is permuted. Two
    states are therefore considered equal if all of the following hold:

    - Both states contain the meta-data keys ``_active``, ``_toggle``, ``_dones``,
      ``_envs_unique`` and ``_keys``. If either state lacks one, False is returned.
    - Either both states or neither has an active sub-environment (``_active``
      different from -1).
    - The ``_toggle`` flags are equal.
    - ``_dones`` and ``_envs_unique`` contain the same elements with the same
      multiplicities, regardless of order.
    - ``_keys`` contain the same set of keys, regardless of order.
    - Every substate of ``state_x`` (keys equal to -1 are skipped) is matched with
      a distinct substate of ``state_y``. The two must be equal and have the same
      done flag, the same unique environment index and the same active status.
      Every substate of ``state_y`` must be matched.

    Substates that are dictionaries containing the key ``_keys`` are assumed to be
    Set states and are compared recursively with this method. Other substates are
    compared with ``GFlowNetEnv.equal()``. If Set states appear deeper in the
    substates, the comparison may not behave as expected.

    Parameters
    ----------
    state_x : dict
        One of the Set states to be compared.
    state_y : dict
        The other Set state to be compared.

    Returns
    -------
    bool
        True if the two input states are equal; False otherwise.
    """
    # Both states must contain all the meta-data keys. Checking each state
    # separately avoids a KeyError when only one of them lacks a key.
    meta_keys = ("_active", "_toggle", "_dones", "_envs_unique", "_keys")
    if any(key not in state_x or key not in state_y for key in meta_keys):
        return False

    active_x = state_x["_active"]
    active_y = state_y["_active"]
    # Either both or neither of the states have an active sub-environment.
    if (active_x == -1) != (active_y == -1):
        return False
    if state_x["_toggle"] != state_y["_toggle"]:
        return False
    if Counter(state_x["_dones"]) != Counter(state_y["_dones"]):
        return False
    if Counter(state_x["_envs_unique"]) != Counter(state_y["_envs_unique"]):
        return False
    keys_x = state_x["_keys"]
    keys_y = state_y["_keys"]
    if set(keys_x) != set(keys_y):
        return False

    def _substates_equal(substate_x, substate_y) -> bool:
        # Nested Set states (dicts with "_keys") are compared with this method;
        # any other substate is compared with the generic GFlowNetEnv.equal().
        if (
            isinstance(substate_x, dict)
            and "_keys" in substate_x
            and isinstance(substate_y, dict)
            and "_keys" in substate_y
        ):
            return self.equal(substate_x, substate_y)
        return GFlowNetEnv.equal(substate_x, substate_y)

    # Positions (indices in ``_keys``) of the substates of state_y not yet matched.
    unmatched_y = [idx_y for idx_y, key_y in enumerate(keys_y) if key_y != -1]
    for idx_x, key_x in enumerate(keys_x):
        if key_x == -1:
            continue
        substate_x = state_x[key_x]
        done_x = state_x["_dones"][idx_x]
        env_unique_x = state_x["_envs_unique"][idx_x]
        is_active_x = idx_x == active_x
        # A candidate must satisfy all constraints (done, unique env, active status
        # and substate equality) before it is accepted. Otherwise, a first equal
        # but incompatible substate would shadow a later compatible one. The
        # cheap checks go first.
        for pos, idx_y in enumerate(unmatched_y):
            if (
                (idx_y == active_y) == is_active_x
                and state_y["_dones"][idx_y] == done_x
                and state_y["_envs_unique"][idx_y] == env_unique_x
                and _substates_equal(substate_x, state_y[keys_y[idx_y]])
            ):
                del unmatched_y[pos]
                break
        else:
            # No compatible substate of state_y was found.
            return False
    # Every substate of state_y must have been matched as well.
    return not unmatched_y
def __eq__(self, other, ignored_keys: Optional[List[str]] = None) -> bool:
    """
    Checks whether the current environment instance is equal to the input
    environment instance.

    This method is overridden so that the comparison always ignores the attribute
    ``envs_unique_cache`` in addition to any key passed by the caller. The actual
    comparison is delegated to the parent's ``__eq__``.

    Parameters
    ----------
    other : GFlowNetEnv
        The environment instance to be compared.
    ignored_keys : list, optional
        Keys (strings) to be ignored in the comparison, for example by subclasses
        that need to ignore certain attributes. None is treated as an empty list.
        The list passed by the caller is never modified.

    Returns
    -------
    bool
        True if the environments' attributes are considered equal; False
        otherwise.
    """
    # Copy into a fresh list: avoids a shared mutable default and never mutates
    # the caller's list.
    ignored_keys = list(ignored_keys) if ignored_keys is not None else []
    ignored_keys.append("envs_unique_cache")
    return super().__eq__(other, ignored_keys=ignored_keys)