gflownet.envs.composite.stack
=============================

.. py:module:: gflownet.envs.composite.stack

.. autoapi-nested-parse::

   Base class for Stack environments.

   Stack environments are environments that consist of a sequence of multiple
   sub-environments in a fixed order.


Classes
-------

.. autoapisummary::

   gflownet.envs.composite.stack.Stack


Module Contents
---------------

.. py:class:: Stack(subenvs, **kwargs)

   Bases: :py:obj:`gflownet.envs.composite.base.CompositeBase`

   :param subenvs: A sequence containing the ordered list of the sub-environments to be stacked.
   :type subenvs: Sequence[GFlowNetEnv]

   .. py:attribute:: subenvs

   .. py:attribute:: n_subenvs

   .. py:attribute:: max_elements

   .. py:attribute:: source

   .. py:attribute:: action_dim

   .. py:attribute:: eos

   .. py:attribute:: continuous

   .. py:method:: get_mask_invalid_actions_forward(state = None, done = None)

      Computes the forward actions mask of the state.

      The mask of the Stack environment is the mask of the active
      sub-environment, preceded by a one-hot encoding of the index of the subenv
      and padded with ``False`` up to ``mask_dim``. Including only the relevant
      mask saves memory and computation.

      If a state is passed as an argument (not ``None``) and the Stack has
      constraints, the constraints are applied before computing the mask and
      reset thereafter. This is necessary because otherwise the sub-environments
      may not have the correct attributes necessary to calculate the mask.

   .. py:method:: get_mask_invalid_actions_backward(state = None, done = None)

      Computes the backward actions mask of the state.

      The mask of the Stack environment is the mask of the relevant
      sub-environment, preceded by a one-hot encoding of the index of the subenv
      and padded with ``False`` up to ``mask_dim``. Including only the relevant
      mask saves memory and computation.
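
      The mask layout described above can be sketched in plain Python. This is a
      hypothetical illustration, not the library's implementation: the helper
      names ``format_stack_mask`` and ``parse_stack_mask`` are made up, lists of
      booleans stand in for tensors, and ``mask_dim`` is assumed to equal the
      number of sub-environments plus the largest sub-environment mask size.

      .. code-block:: python

         from typing import List, Sequence, Tuple


         def format_stack_mask(
             stage: int, n_subenvs: int, subenv_mask: Sequence[bool], mask_dim: int
         ) -> List[bool]:
             """Builds a Stack-style mask (hypothetical helper, for illustration).

             Layout: one-hot index of the relevant sub-environment, followed by
             that sub-environment's own mask, padded with False up to mask_dim.
             """
             if not 0 <= stage < n_subenvs:
                 raise ValueError(f"stage must be in [0, {n_subenvs}), got {stage}")
             one_hot = [i == stage for i in range(n_subenvs)]
             mask = one_hot + list(subenv_mask)
             if len(mask) > mask_dim:
                 raise ValueError("mask_dim is too small for this sub-environment mask")
             # Pad with False so that all Stack masks share the same length.
             return mask + [False] * (mask_dim - len(mask))


         def parse_stack_mask(
             mask: Sequence[bool], n_subenvs: int, subenv_mask_dims: Sequence[int]
         ) -> Tuple[int, List[bool]]:
             """Recovers (stage, sub-environment mask) from a Stack-style mask."""
             # The one-hot prefix identifies which sub-environment the mask belongs to.
             stage = list(mask[:n_subenvs]).index(True)
             dim = subenv_mask_dims[stage]
             return stage, list(mask[n_subenvs : n_subenvs + dim])


         # Example: three sub-environments whose masks have 4, 2 and 5 entries.
         subenv_mask_dims = [4, 2, 5]
         mask_dim = len(subenv_mask_dims) + max(subenv_mask_dims)  # assumed sizing: 8
         mask = format_stack_mask(1, 3, [True, False], mask_dim)
         print(mask)
         # [False, True, False, True, False, False, False, False]
         print(parse_stack_mask(mask, 3, subenv_mask_dims))
         # (1, [True, False])

      In this sketch, the one-hot prefix is what lets ``parse_stack_mask``
      recover the relevant sub-environment from the mask alone, so only that
      sub-environment's entries need to be stored.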

      The relevant sub-environment for the backward mask is the current
      sub-environment, except if the state of the sub-environment is the
      subenv's source, in which case the mask must be that of the preceding
      sub-environment, so as to sample its EOS action. There are two exceptions
      to this rule:

      - If ``done`` is True: the current sub-environment is the last one and the
        EOS action must come from itself, not from the preceding subenv.
      - If the current stage is the first sub-environment: there is no
        preceding subenv.

      If a state is passed as an argument (not ``None``) and the Stack has
      constraints, the constraints are applied before computing the mask and
      reset thereafter. This is necessary because otherwise the sub-environments
      may not have the correct attributes necessary to calculate the mask.

   .. py:method:: get_valid_actions(mask = None, state = None, done = None, backward = False)

      Returns the list of non-invalid (valid, for short) actions.

      This method is overridden because the mask of a Stack of environments does
      not cover the entire action space, but only the relevant sub-environment.
      Therefore, this method calls the ``get_valid_actions()`` method of the
      currently relevant sub-environment and returns the padded actions.

      If a state is passed as an argument (not ``None``) and the Stack has
      constraints, the constraints are applied before computing the mask and
      reset thereafter. This is necessary because otherwise the sub-environments
      may not have the correct attributes necessary to calculate the mask.

   .. py:method:: mask_conditioning(mask, env_cond, backward)

      Conditions the input mask based on the restrictions imposed by a
      conditioning environment, ``env_cond``.

      This method is overridden because the base ``mask_conditioning()`` would
      modify the mask without taking the special Stack format into account.
      Therefore, this method calls the ``mask_conditioning()`` method of the
      currently relevant sub-environment and returns the mask in the correct
      Stack format.

   .. py:method:: get_parents(state = None, done = None, action = None)

      Determines all parents and actions that lead to the input state.

      If a state is passed as an argument (not ``None``) and the Stack has
      constraints, the constraints are applied before computing the parents and
      reset thereafter. This is necessary because otherwise the sub-environments
      may not have the correct attributes necessary to compute the parents.

      :param state: State in environment format. If None, ``self.state`` is used.
      :type state: list
      :param done: Whether the trajectory is done. If None, ``self.done`` is used.
      :type done: bool
      :param action: Ignored.
      :type action: tuple

      :returns: * **parents** (*list*) -- List of parents in state format.
                * **actions** (*list*) -- List of actions that lead to the state, one for each parent in parents.

   .. py:method:: step(action, skip_mask_check = False)

      Executes a forward step given an action.

      The action is performed by the corresponding sub-environment and then the
      global state is updated accordingly. If the action is the EOS of the
      sub-environment, the stage is advanced and constraints are set on the
      subsequent sub-environment.

      :param action: Action to be executed. The input action is global, that is, padded.
      :type action: tuple

      :returns: * **self.state** (*Dict*) -- The state after executing the action.
                * **action** (*int*) -- Action executed.
                * **valid** (*bool*) -- False if the action is not allowed for the current state, True otherwise.

   .. py:method:: step_backwards(action, skip_mask_check = False)

      Executes a backward step given an action.

      The action is performed by the corresponding sub-environment and then the
      global state is updated accordingly. If the updated state of the
      sub-environment becomes its source, the stage is decreased.

      :param action: Action to be executed. The input action is global, that is, padded.
      :type action: tuple

      :returns: * **self.state** (*list*) -- The state after executing the action.
                * **action** (*int*) -- Action executed.
                * **valid** (*bool*) -- False if the action is not allowed for the current state, True otherwise.

   .. py:method:: sample_actions_batch(policy_outputs, mask = None, states_from = None, is_backward = False, random_action_prob = 0.0, temperature_logits = 1.0)

      Samples a batch of actions from a batch of policy outputs.

      This method calls the ``sample_actions_batch()`` method of the
      sub-environment corresponding to each state in the batch. Note that in
      order to call ``sample_actions_batch()`` of the sub-environments, the parts
      of the policy outputs, the masks and the states that correspond to each
      sub-environment must first be extracted.

      :param policy_outputs: The output of the GFlowNet policy model.
      :type policy_outputs: tensor
      :param mask: The mask of invalid actions, formatted as in :py:meth:`gflownet.envs.composite.stack.Stack._format_mask`.
      :type mask: tensor
      :param states_from: The states originating the actions, in environment format.
      :type states_from: list
      :param is_backward: True if the actions are backward, False if the actions are forward.
      :type is_backward: bool
      :param random_action_prob: The probability of sampling a random action. If larger than zero, the model outputs will be replaced by a random policy vector with probability ``random_action_prob``, according to a Bernoulli distribution.
      :type random_action_prob: float, optional
      :param temperature_logits: A scalar by which the model outputs are divided to temper the sampling distribution.
      :type temperature_logits: float, optional

      :returns: **actions** (*list*) -- The list of sampled actions.

   .. py:method:: get_logprobs(policy_outputs, actions, mask, states_from, is_backward)

      Computes the log probabilities of the actions given the policy outputs.

      :param policy_outputs: The output of the GFlowNet policy model.
      :type policy_outputs: tensor
      :param mask: The mask containing information about invalid actions and special cases.
      :type mask: tensor
      :param actions: The actions (global) from each state in the batch for which to compute the log probability.
      :type actions: list or tensor
      :param states_from: The states originating the actions, in environment format.
      :type states_from: tensor
      :param is_backward: True if the actions are backward, False if the actions are forward.
      :type is_backward: bool

      :returns: *tensor* -- The log probabilities of the transitions.

   .. py:method:: states2policy(states)

      Prepares a batch of states in environment format for the policy model.

      The default policy representation is a concatenation of the policy-format
      states of the sub-environments. ``subenv.states2policy()`` is called only
      once per sub-environment, on all the corresponding sub-states in the batch.

      :param states: A batch of states in environment format.
      :type states: list

      :returns: *A tensor containing all the states in the batch.*

   .. py:method:: states2proxy(states)

      Prepares a batch of states in environment format for a proxy. The proxy
      representation is simply a concatenation of the proxy-format states of the
      sub-environments.

      :param states: A batch of states in environment format.
      :type states: list

      :returns: *A list of lists, each containing the proxy representation of all the states in the Stack, for all the Stacks in the batch.*

   .. py:method:: state2readable(state = None)

      Converts a state into a human-readable representation. It concatenates the
      readable representations of each sub-environment, separated by ``"; "``
      and preceded by ``"Stage {stage}; "``.

   .. py:method:: readable2state(readable)

      Converts a human-readable representation of a state into the standard
      format.

   .. py:method:: action2representative(action)

      Replaces the part of the action associated with a sub-environment by its
      representative. The part of the action that identifies the sub-environment
      concerned by the action remains unaffected.

   .. py:method:: is_source(state = None)

      Returns True if the environment's state, or the state passed as parameter
      (if not None), is the source state of the environment.

      This method is overridden for efficiency (for example, it returns False
      immediately if the stage is not the first stage) and to cover special uses
      of the Stack.

      :param state: None, or a state in environment format.
      :type state: Dict

      :returns: *bool* -- Whether the state is the source state of the environment.
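
      The early exit mentioned above can be sketched as follows. This is a
      hypothetical, stand-alone illustration rather than the library's code: the
      function name ``stack_is_source`` and the toy state layout (a dict with an
      integer ``"stage"`` and a list ``"substates"`` holding one sub-state per
      sub-environment) are assumptions made for the example.

      .. code-block:: python

         from typing import Any, Dict, Sequence


         def stack_is_source(state: Dict[str, Any], sources: Sequence[Any]) -> bool:
             """Hypothetical sketch of the early exit in a Stack-like is_source.

             Assumes a toy layout {"stage": int, "substates": [...]}, with one
             sub-state per sub-environment and one source per sub-environment.
             """
             # Cheap check first: a state past the first stage cannot be the source.
             if state["stage"] != 0:
                 return False
             # Only then compare each sub-state with its sub-environment's source.
             substates = state["substates"]
             return len(substates) == len(sources) and all(
                 sub == src for sub, src in zip(substates, sources)
             )


         sources = [[0, 0], [-1]]
         print(stack_is_source({"stage": 0, "substates": [[0, 0], [-1]]}, sources))  # True
         print(stack_is_source({"stage": 1, "substates": [[0, 0], [-1]]}, sources))  # False
         print(stack_is_source({"stage": 0, "substates": [[1, 0], [-1]]}, sources))  # False

      Checking the stage first means that, in this sketch, no sub-state is
      compared at all once a trajectory has moved past the first sub-environment.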