Source code for gflownet.envs.base

"""
Base class of GFlowNet environments
"""

import math
import numbers
import random
import uuid
from abc import abstractmethod
from copy import deepcopy
from textwrap import dedent
from typing import Dict, List, Optional, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
from torch.distributions import Bernoulli, Categorical
from torchtyping import TensorType

from gflownet.utils.common import (
    copy,
    set_device,
    set_float_precision,
    tbool,
    tfloat,
    tlong,
)

[docs] CMAP = mpl.colormaps["cividis"]
""" Plotting colour map (cividis). """
[docs] class GFlowNetEnv: """ Base class of GFlowNet environments """ def __init__( self, device: str = "cpu", float_precision: int = 32, env_id: Union[int, str] = "env", fixed_distr_params: Optional[dict] = None, random_distr_params: Optional[dict] = None, skip_mask_check: bool = False, conditional: bool = False, continuous: bool = False, **kwargs, ): # Flag whether env is conditional
[docs] self.conditional = conditional
# Flag whether env is continuous
[docs] self.continuous = continuous
# Call reset() to set initial state, done, n_actions self.reset() # Device
[docs] self.device = set_device(device)
# Float precision
[docs] self.float = set_float_precision(float_precision)
# Flag to skip checking if action is valid (computing mask) before step
[docs] self.skip_mask_check = skip_mask_check
# Log SoftMax function
[docs] self.logsoftmax = torch.nn.LogSoftmax(dim=1)
# Action space
[docs] self.action_space = self.get_action_space()
self._action2index = {a: idx for idx, a in enumerate(self.action_space)}
[docs] self.action_space_torch = torch.tensor( self.action_space, device=self.device, dtype=self.float )
# Mask dimensionality self._mask_dim = self._compute_mask_dim() # Max trajectory length self._max_traj_length = self._get_max_trajectory_length() # Policy outputs
[docs] self.fixed_distr_params = fixed_distr_params
[docs] self.random_distr_params = random_distr_params
[docs] self.fixed_policy_output = self.get_policy_output(self.fixed_distr_params)
[docs] self.random_policy_output = self.get_policy_output(self.random_distr_params)
[docs] self.policy_output_dim = len(self.fixed_policy_output)
[docs] self.policy_input_dim = len(self.state2policy())
@abstractmethod
[docs] def get_action_space(self): """ Constructs list with all possible actions (excluding end of sequence) """ pass
@property
[docs] def action_space_dim(self) -> int: """ Returns the dimensionality of the action space (number of actions). Returns ------- The number of actions in the action space. """ return len(self.action_space)
@property
[docs] def mask_dim(self): """ Returns the dimensionality of the masks. Returns ------- The dimensionality of the masks. """ return self._mask_dim
def _compute_mask_dim(self) -> int: """ Calculates the mask dimensionality. By default, the mask dimensionality is equal to the dimensionality of the action space. This method should be overriden in environments where this may not be the case, for example continuous environments (ContinuousCube) and meta-environments such as Stack and Set. Returns ------- int The number of elements in the masks. """ return self.action_space_dim def _get_max_trajectory_length(self) -> int: """ Returns the maximum trajectory length of the environment, including the EOS action. While it is not required to override this method because it does return a default value of 100, it is recommended to override it to return the correct value or an upper bound as tight as possible to the maximum. The maximum trajectory length does not play a critical role but it is used for testing purposes. For example, it is used by get_random_states(), and poor estimation of the trajectory length could result in stark inefficiency. """ return 100 @property
[docs] def max_traj_length(self) -> int: """ Returns the maximum trajectory length of the environment, including the EOS action. Returns ------- The maximum number of steps in a trajectory of the environment. """ return self._max_traj_length
[docs] def action2representative(self, action: Tuple) -> int: """ For continuous or hybrid environments, converts a continuous action into its representative in the action space. Discrete actions remain identical, thus fully discrete environments do not need to re-implement this method. Continuous environments should re-implement this method in order to replace continuous actions by their representatives in the action space. """ return action
[docs] def action_produces_permutation( self, action: Tuple, is_backward: bool = False ) -> bool: """ Determines whether an action produces permutations in the resulting state. Permutations can be introduced, for example, in environments that need to incorporate permutation invariance, as in sets of elements. In these cases, some actions may result in states with elements that are randomly permuted. This method allows to identify these actions, which is useful, for instance, in unit tests. By default, actions do not produce permutations and the returned value of this method is False. Environments with actions that produce permutations should override this method and properly identify such actions. Parameters ---------- action : tuple An action of the environment. is_backward : bool Whether the transition to consider is backward (True) or forward (False). Returns ------- bool Whether the input actions produces permutations in the resulting state, in the direction indicated by ``is_backward``. """ return False
[docs] def action2index(self, action: Tuple) -> int: """ Returns the index in the action space of the action passed as an argument, or its representative if it is a continuous action. The method uses the dictionary lookup ``self._action2index``. See: self.action2representative() Parameters ---------- action : tuple An action from the action space. Returns ------- int The index of the action in the action space. """ return self._action2index[self.action2representative(action)]
[docs] def actions2indices( self, actions: TensorType["batch_size", "action_dim"] ) -> TensorType["batch_size"]: """ Returns the corresponding indices in the action space of the actions in a batch. """ # Expand the action_space tensor: [batch_size, d_actions_space, action_dim] action_space = torch.unsqueeze(self.action_space_torch, 0).expand( actions.shape[0], -1, -1 ) # Expand the actions tensor: [batch_size, d_actions_space, action_dim] actions = torch.unsqueeze(actions, 1).expand(-1, self.action_space_dim, -1) # Take the indices at the d_actions_space dimension where all the elements in # the action_dim dimension are True return torch.where(torch.all(actions == action_space, dim=2))[1]
def _get_state( self, state: Union[List, TensorType["state_dims"]], do_copy: bool = False ): """ Returns the input state or ``self.state`` if it is None. This is meant to be used as a helper method for other methods to determine whether the state should be taken from the arguments or from the environment instance (``self.state``): if is None, it is taken from the environment. If ``do_copy`` is True (False by default), the state is copied before returning it. Parameters ---------- state : list or tensor or dict or None A state in environment format, or None. do_copy : bool Whether to copy the state before returning it. Returns ------- state : list or tensor or dict or None The argument state, or self.state if state is None. """ if state is None: state = self.state if do_copy: return copy(state) else: return state def _get_done(self, done: bool): """ A helper method for other methods to determine whether done should be taken from the arguments or from the instance (self.done): if it is None, it is taken from the instance. Args ---- done : bool or None None, or whether the environment is done. Returns ------- done: bool The argument done, or self.done if done is None. """ if done is None: done = self.done return done
[docs] def is_source( self, state: Optional[Union[List, TensorType["state_dims"]]] = None ) -> bool: """ Returns True if the environment's state or the state passed as parameter (if not None) is the source state of the environment. Parameters ---------- state : list or tensor or None None, or a state in environment format. Returns ------- bool Whether the state is the source state of the environment """ state = self._get_state(state) return self.equal(state, self.source)
[docs] def get_mask_invalid_actions_forward( self, state: Optional[List] = None, done: Optional[bool] = None, ) -> List: """ Returns a list of length the action space with values: - True if the forward action is invalid from the current state. - False otherwise. For continuous or hybrid environments, this mask corresponds to the discrete part of the action space. """ return [False for _ in range(self.action_space_dim)]
[docs] def get_mask_invalid_actions_backward( self, state: Optional[List] = None, done: Optional[bool] = None, parents_a: Optional[List] = None, ) -> List: """ Returns a list of length the action space with values: - True if the backward action is invalid from the current state. - False otherwise. For continuous or hybrid environments, this mask corresponds to the discrete part of the action space. The base implementation below should be common to all discrete spaces as it relies on get_parents, which is environment-specific and must be implemented. Continuous environments will probably need to implement its specific version of this method. """ state = self._get_state(state) done = self._get_done(done) if parents_a is None: _, parents_a = self.get_parents(state, done) mask = [True for _ in range(self.action_space_dim)] for pa in parents_a: mask[self.action_space.index(pa)] = False return mask
[docs] def get_mask( self, state: Optional[List] = None, done: Optional[bool] = None, backward: Optional[bool] = False, ) -> List: """ Returns a mask of invalid actions given a state and a done value. Depending on backward, either the forward or the backward mask is returned, by calling the corresponding method. """ if backward: return self.get_mask_invalid_actions_backward(state, done) else: return self.get_mask_invalid_actions_forward(state, done)
[docs] def get_valid_actions( self, mask: Optional[bool] = None, state: Optional[List] = None, done: Optional[bool] = None, backward: Optional[bool] = False, ) -> List[Tuple]: """ Returns the list of non-invalid (valid, for short) according to the mask of invalid actions. More documentation about the meaning and use of invalid actions can be found in gflownet/envs/README.md. """ if mask is None: mask = self.get_mask(state, done, backward) return [action for action, m in zip(self.action_space, mask) if not m]
[docs] def get_parents( self, state: Optional[List] = None, done: Optional[bool] = None, action: Optional[Tuple] = None, ) -> Tuple[List, List]: """ Determines all parents and actions that lead to state. In continuous environments, get_parents() should return only the parent from which action leads to state. Args ---- state : list Representation of a state done : bool Whether the trajectory is done. If None, done is taken from instance. action : tuple Last action performed Returns ------- parents : list List of parents in state format actions : list List of actions that lead to state for each parent in parents """ state = self._get_state(state) done = self._get_done(done) if done: return [state], [(self.eos,)] parents = [] actions = [] return parents, actions
# TODO: consider returning only do_step def _pre_step( self, action: Tuple[int], backward: bool = False, skip_mask_check: bool = False ) -> Tuple[bool, List[int], Tuple[int]]: """ Performs generic checks shared by the step() and step_backward() (backward must be True) methods of all environments. Args ---- action : tuple Action from the action space. skip_mask_check : bool If True, skip computing forward mask of invalid actions to check if the action is valid. Returns ------- do_step : bool If True, step() should continue further, False otherwise. self.state : list The sequence after executing the action action : int Action index """ # If action not found in action space raise an error if action not in self.action_space: raise ValueError( f"Tried to execute action {action} not present in action space." ) # If backward and state is source, step should not proceed. if backward is True: if self.equal(self.state, self.source) and action != self.eos: return False, self.state, action # If forward and env is done, step should not proceed. else: if self.done: return False, self.state, action # If action is in invalid mask (not in valid actions), step should not proceed. if not (self.skip_mask_check or skip_mask_check): if action not in self.get_valid_actions(backward=backward): return False, self.state, action return True, self.state, action @abstractmethod
[docs] def step( self, action: Tuple[int], skip_mask_check: bool = False ) -> Tuple[List[int], Tuple[int], bool]: """ Executes step given an action. Args ---- action : tuple Action from the action space. skip_mask_check : bool If True, skip computing forward mask of invalid actions to check if the action is valid. Returns ------- self.state : list The sequence after executing the action action : int Action index valid : bool False, if the action is not allowed for the current state, e.g. stop at the root state """ _, self.state, action = self._pre_step(action, skip_mask_check) return None, None, None
[docs] def step_backwards( self, action: Tuple[int], skip_mask_check: bool = False ) -> Tuple[List[int], Tuple[int], bool]: """ Executes a backward step given an action. This generic implementation should work for all discrete environments, as it relies on get_parents(). Continuous environments should re-implement a custom step_backwards(). Despite being valid for any discrete environment, the call to get_parents() may be expensive. Thus, it may be advantageous to re-implement step_backwards() in a more efficient way as well for discrete environments. Especially, because this generic implementation will make two calls to get_parents - once here and one in _pre_step() through the call to get_mask_invalid_actions_backward() if skip_mask_check is True. Args ---- action : tuple Action from the action space. skip_mask_check : bool If True, skip computing forward mask of invalid actions to check if the action is valid. Returns ------- self.state : list The sequence after executing the action action : int Action index valid : bool False, if the action is not allowed for the current state. """ do_step, self.state, action = self._pre_step(action, True, skip_mask_check) if not do_step: return self.state, action, False parents, parents_a = self.get_parents() state_next = parents[parents_a.index(action)] self.set_state(state_next, done=False) self.n_actions += 1 return self.state, action, True
[docs] def randomize_and_temper_sampling_distribution( self, policy_outputs: TensorType["n_states", "policy_output_dim"], probability_random_action: Optional[float] = 0.0, temperature: Optional[float] = 1.0, ) -> TensorType["n_states", "policy_output_dim"]: """ Replaces the rows of `policy_outputs` by a vector corresponding to a random sampling policy with the probability indicated by `probability_random_action`. Note that the tensor of policy outputs is not cloned if neither tempering nor random actions are incorporated. This implies that the original tensor of policy outputs may be modified by subsequent methods (namely sample_actions_batch()), for example to mask the invalid actions. Parameters ---------- policy_outputs : tensor The original outputs of the sampling policy. For example, they may correspond to the output (logits) of the GFlowNet policy model. probability_random_action : float, optional The probability of sampling a random action. If larger than one, the logits will be replaced by a random policy vector with this probability, according to Bernoulli distribution. By default, the probability is 0.0 (no random actions). temperature : float, optional A scalar by which the logits are divided to adjust the sampling distribution. A temperature larger than one will result in a flatter distribution, favouring exploration. A temperature smaller than one will sharpen the distribution, favouring concentration around high probability actions. By default, the temperature is 1.0 (no tempering). Returns ------- policy_outputs : tensor The modified policy outputs. """ if not math.isclose(temperature, 1.0, abs_tol=1e-08): do_temper = True else: do_temper = False if not math.isclose(probability_random_action, 0.0, abs_tol=1e-08): do_random = True else: do_random = False if not do_temper and not do_random: return policy_outputs # Clone the sampling logits in order not to change the original tensor logits_sampling = policy_outputs.clone().detach() if do_temper: logits_sampling /= temperature if do_random: idx_random = tbool( Bernoulli( probability_random_action * torch.ones(policy_outputs.shape[0]) ).sample(), device=self.device, ) logits_sampling[idx_random, :] = self.random_policy_output return logits_sampling
[docs] def sample_actions_batch( self, policy_outputs: TensorType["n_states", "policy_output_dim"], mask: Optional[TensorType["n_states", "policy_output_dim"]] = None, states_from: Optional[List] = None, is_backward: Optional[bool] = False, random_action_prob: Optional[float] = 0.0, temperature_logits: Optional[float] = 1.0, ) -> Tuple[List[Tuple], TensorType["n_states"]]: """ Samples a batch of actions from a batch of policy outputs. This implementation is generally valid for all discrete environments but continuous or mixed environments need to reimplement this method. The method is valid for both forward and backward actions in the case of discrete environments. Some continuous environments may also be agnostic to the difference between forward and backward actions since the necessary information can be contained in the mask. However, some continuous environments do need to know whether the actions are forward of backward, which is why this can be specified by the argument is_backward. Most environments do not need to know the states from which the actions are to be sampled since the necessary information is in both the policy outputs and the mask. However, some continuous environments do need to know the originating states in order to construct the actions, which is why one of the arguments is states_from. Note that methods overriding this method should randomize and temper the logits. Parameters ---------- policy_outputs : tensor The output of the GFlowNet policy model. mask : tensor The mask of invalid actions. For continuous or mixed environments, the mask may be tensor with an arbitrary length contaning information about special states, as defined elsewhere in the environment. states_from : tensor The states originating the actions, in GFlowNet format. Ignored in discrete environments and only required in certain continuous environments. is_backward : bool True if the actions are backward, False if the actions are forward (default). Ignored in discrete environments and only required in certain continuous environments. random_action_prob : float, optional The probability of sampling a random action. If larger than one, the model outputs will be replaced by a random policy vector with probability `random_action_prob`, according to Bernoulli distribution. temperature_logits : float, optional A scalar by which the model outputs are divided to temper the sampling distribution. Returns ------- actions : list The list of sampled actions. """ # Randomize actions and temper the logits logits_sampling = self.randomize_and_temper_sampling_distribution( policy_outputs, random_action_prob, temperature_logits ) # Make the logits of invalid actions equal to -inf. if mask is not None: if torch.all(mask, dim=1).any(): raise RuntimeError( "All actions in the mask are invalid for some states in the batch." ) logits_sampling[mask] = -torch.inf # Sample actions from the Categorical distributions defined by the logits action_indices = Categorical(logits=logits_sampling).sample() # Build actions actions = [self.action_space[idx] for idx in action_indices] return actions
[docs] def get_logprobs( self, policy_outputs: TensorType["n_states", "policy_output_dim"], actions: Union[List, TensorType["n_states", "action_dim"]], mask: TensorType["batch_size", "policy_output_dim"] = None, states_from: Optional[List] = None, is_backward: bool = False, ) -> TensorType["batch_size"]: """ Computes log probabilities of actions given policy outputs and actions. This implementation is generally valid for all discrete environments but continuous environments will likely have to implement its own. Parameters ---------- policy_outputs : tensor The output of the GFlowNet policy model. mask : tensor The mask of invalid actions. For continuous or mixed environments, the mask may be tensor with an arbitrary length contaning information about special states, as defined elsewhere in the environment. actions : list or tensor The actions from each state in the batch for which to compute the log probability. The actions may be a list or a tensor. Most environments handle the actions as a list, but in some cases it is practical to use a tensor for easier indexing, such as in meta-environments. states_from : tensor The states originating the actions, in GFlowNet format. Ignored in discrete environments and only required in certain continuous environments. is_backward : bool True if the actions are backward, False if the actions are forward (default). Ignored in discrete environments and only required in certain continuous environments. """ device = policy_outputs.device ns_range = torch.arange(policy_outputs.shape[0]).to(device) logits = policy_outputs.clone() if mask is not None: logits[mask] = -torch.inf if torch.is_tensor(actions): action_indices = tlong(self.actions2indices(actions), device=self.device) else: action_indices = tlong( [self.action2index(action) for action in actions], device=self.device, ) logprobs = self.logsoftmax(logits)[ns_range, action_indices] return logprobs
[docs] def step_random(self, backward: bool = False): """ Samples a random action and executes the step. Parameters ---------- backward : bool If True, the step is performed backwards. False by default. Returns ------- state : list The state after executing the action. action : int Action, randomly sampled. valid : bool False, if the action is not allowed for the current state. """ mask_invalid = tbool( self.get_mask(backward=backward), device=self.device ).unsqueeze(0) action = self.sample_actions_batch( self.random_policy_output.clone().unsqueeze(0), mask_invalid, [self.state], backward, )[0] if backward: return self.step_backwards(action) return self.step(action)
[docs] def trajectory_random(self, backward: bool = False): """ Samples and applies a random trajectory on the environment, by sampling random actions until an EOS action is sampled. Parameters ---------- backward : bool If True, the trajectory is sampled backwards. False by default. Returns ------- state : list The final state. action: list The list of actions (tuples) in the trajectory. """ actions = [] while True: _, action, valid = self.step_random(backward) if valid: actions.append(action) if backward and self.is_source(): break elif self.done: break else: continue return self.state, actions
[docs] def get_random_terminating_states( self, n_states: int, unique: bool = True, max_attempts: int = 100000 ) -> List: """ Samples n terminating states by using the random policy of the environment (calling self.trajectory_random()). Note that this method is general for all environments but it may be suboptimal in terms of efficiency. In particular, 1) it samples full trajectories in order to get terminating states, 2) if unique is True, it needs to compare each newly sampled state with all the previously sampled states. If get_uniform_terminating_states is available, it may be preferred, or for some environments, a custom get_random_terminating_states may be straightforward to implement in a much more efficient way. Args ---- n_states : int The number of terminating states to sample. unique : bool Whether samples should be unique. True by default. max_attempts : int The maximum number of attempts, to prevent the method from getting stuck trying to obtain n_states different samples if unique is True. 100000 by default, therefore if more than 100000 are requested, max_attempts should be increased accordingly. Returns ------- states : list A list of randomly sampled terminating states. """ if unique is False: max_attempts = n_states + 1 states = [] count = 0 while len(states) < n_states and count < max_attempts: add = True self.reset() state, _ = self.trajectory_random() if unique is True: if any([self.equal(state, s) for s in states]): add = False if add is True: states.append(state) count += 1 return states
[docs] def get_random_states( self, n_states: int, unique: bool = True, exclude_source: bool = False, max_attempts: int = 1000, ) -> List: """ Samples n states (not necessarily terminating) by using the random policy of the environment (calling self.step_random()). It relies on self.max_traj_length in order to uniformly sample the number of steps, in order to obtain states with varying trajectory lengths. The method iteratively samples first a trajectory length and attempts to perform as many steps. If the trajectory ends before the requested number of steps is reached, then it is discarded and a new one is attempted. This may introduced a bias towards states that can be reached with a few steps. Note that this method is general for all environments but it may be suboptimal in terms of efficiency. In particular, 1) it samples trajectories step by step in order to get random states, 2) if unique is True, it needs to compare each newly sampled state with all the previously sampled states, 3) states are copied before adding them to the list, 4) only the last state of a trajectory is added to the list in order to have diversity of trajectories. Parameters ---------- n_states : int The number of terminating states to sample. unique : bool Whether samples should be unique. True by default. max_attempts : int The maximum number of attempts, to prevent the method from getting stuck trying to obtain n_states different samples if unique is True. 100000 by default, therefore if more than 100000 are requested, max_attempts should be increased accordingly. exclude_source : bool If True, exclude the source state from the list of states. Returns ------- states : list A list of randomly sampled states. Raises ------ ValueError If max_attempts is smaller than n_states RuntimeError If the maximum number of attempts is reached before obtaining the requested number of unique states. """ max_traj_length = self.max_traj_length if max_attempts < n_states: raise ValueError( f"max_attempts (received {max_attempts}) must larger than or " f"equal to n_states (received {n_states})." ) states = [] n_attempts = 0 # Iterate until the requested number of states is obtained while len(states) < n_states: n_attempts += 1 # Sample a trajectory length for this state traj_length = random.randint(1, max_traj_length) self.reset() is_valid = True for _ in range(traj_length): # If the trajectory has reached done before the number of requested # steps, discard it and start a new one. if self.done: is_valid = False break # Perform a random step self.step_random() # If exclude_source is True and the state is the source, mark the # trajectory as invalid. if is_valid and exclude_source and self.is_source(self.state): is_valid = False # If unique is True and the state is in the list, mark the trajetory as # invalid if is_valid and unique and any([self.equal(self.state, s) for s in states]): is_valid = False # If the trajectory is valid, add the state to the list if is_valid: states.append(copy(self.state)) # Check if the number of attempts has reached the maximum if n_attempts >= max_attempts: raise RuntimeError( f"Reached the maximum number of attempts ({max_attempts}) to " f"sample {n_states} states but only {len(states)} could " "be obtained. It is possible that the state space is too small " f"to contain {n_states} states. Otherwise, consider " "increasing the number of attempts" ) return states
[docs] def get_policy_output( self, params: Optional[dict] = None ) -> TensorType["policy_output_dim"]: """ Defines the structure of the output of the policy model, from which an action is to be determined or sampled, by returning a vector with a fixed random policy. As a baseline, the policy is uniform over the dimensionality of the action space. Continuous environments will generally have to overwrite this method. """ return torch.ones(self.action_space_dim, dtype=self.float, device=self.device)
[docs] def states2proxy( self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_proxy_dim"]: """ Prepares a batch of states in "environment format" for the proxy. By default, the batch of states is converted into a tensor with float dtype and returned as is. Args ---- states : list or tensor A batch of states in environment format, either as a list of states or as a single tensor. Returns ------- A tensor containing all the states in the batch. """ return tfloat(states, device=self.device, float_type=self.float)
[docs] def state2proxy( self, state: Union[List, TensorType["state_dim"]] = None ) -> TensorType["state_proxy_dim"]: """ Prepares a single state in "GFlowNet format" for the proxy. By default, simply states2proxy is called and the output will be a "batch" with a single state in the proxy format. Args ---- state : list A state """ state = self._get_state(state) return self.states2proxy([state])
[docs] def states2policy( self, states: Union[List, TensorType["batch", "state_dim"]] ) -> TensorType["batch", "policy_input_dim"]: """ Prepares a batch of states in "environment format" for the policy model: By default, the batch of states is converted into a tensor with float dtype and returned as is. Args ---- states : list or tensor A batch of states in environment format, either as a list of states or as a single tensor. Returns ------- A tensor containing all the states in the batch. """ return tfloat(states, device=self.device, float_type=self.float)
[docs] def state2policy( self, state: Union[List, TensorType["state_dim"]] = None ) -> TensorType["policy_input_dim"]: """ Prepares a state in "GFlowNet format" for the policy model. By default, states2policy is called, which by default will return the state as is. Args ---- state : list A state """ state = self._get_state(state) return torch.squeeze(self.states2policy([state]), dim=0)
[docs] def state2readable(self, state=None): """ Converts a state into human-readable representation. """ state = self._get_state(state) return str(state)
[docs] def readable2state(self, readable): """ Converts a human-readable representation of a state into the standard format. """ return readable
[docs] def traj2readable(self, traj=None): """ Converts a trajectory into a human-readable string. """ return str(traj).replace("(", "[").replace(")", "]").replace(",", "")
[docs] def reset(self, env_id: Union[int, str] = None): """ Resets the environment. Args ---- env_id: int or str Unique (ideally) identifier of the environment instance, used to identify the trajectory generated with this environment. If None, uuid.uuid4() is used. Returns ------- self """ self.state = copy(self.source) self.n_actions = 0 self.done = False if env_id is None: self.id = str(uuid.uuid4()) else: self.id = env_id return self
[docs] def set_id(self, env_id: Union[int, str]): """ Sets the id given as argument and returns the environment. Args ---- env_id: int or str Unique (ideally) identifier of the environment instance, used to identify the trajectory generated with this environment. Returns ------- self """ self.id = env_id return self
[docs] def set_state(self, state: List, done: Optional[bool] = False): """ Sets the state and done of an environment. Environments that cannot be "done" at all states (intermediate states are not fully constructed objects) should overwrite this method and check for validity. """ self.state = copy(state) self.done = done return self
[docs] def copy(self): # return self.__class__(**self.__dict__) return deepcopy(self)
@staticmethod
[docs] def equal( state_x: Union[numbers.Number, str, torch.Tensor, Dict, List, Tuple], state_y: Union[numbers.Number, str, torch.Tensor, Dict, List, Tuple], ) -> bool: """ Checks whether the two input states are equal. This method handles recursively multiple structure types: numbers, strings, tensors, dictionaries, lists and tuples. The result is only True if the content of the two input states is identical. The core functionality is implemented in :py:meth:`gflownet.envs.base.GFlowNetEnv.isclose` and this method simply calls it with ``do_equal=True``. Parameters ---------- state_x: number, str, tensor, dict, list, tuple One of the states to be compared. state_y: number, str, tensor, dict, list, tuple The other state to be compared. Returns ------- bool True if the two input states are equal; False otherwise. Raises ------ NotImplementedError If the input types are not part of the explicitly handles types. """ return GFlowNetEnv.isclose(state_x, state_y, do_equal=True)
@staticmethod
[docs] def isclose( state_x: Union[numbers.Number, str, torch.Tensor, Dict, List, Tuple], state_y: Union[numbers.Number, str, torch.Tensor, Dict, List, Tuple], rtol: float = 1e-5, atol: float = 1e-8, do_equal: bool = False, ) -> bool: """ Checks whether the two input states are close, according to a tolerance. This method relies on numpy's and torch's ``isclose()`` methods, which both use the following formula: ``abs(x - y) <= rtol * abs(y) + atol`` This method is used as well by :py:meth:`gflownet.envs.base.GFlowNetEnv.equal` in order to avoid code repetition. In this case, ``do_equal`` is True and numpy's and torch's ``equal()`` methods are used. This is preferred over using ``rtol`` and ``atol`` equal to 0.0 for efficiency reasons. This method handles recursively multiple structure types: numbers, strings, tensors, dictionaries, lists and tuples. The result is only True if the content of the two input states is identical or close enough, as defined by the tolerance values ``rtol`` and ``atol``. In the case of strings, True is only returned if the states are identical. Parameters ---------- state_x: number, str, tensor, dict, list, tuple One of the states to be compared. state_y: number, str, tensor, dict, list, tuple The other state to be compared. rtol : float Relative tolerance for numeric values. atol : float Maximum absolute tolerance threshold for numeric values. do_equal : bool If True, comparisons are by equality instead of closeness and ``rtol`` and ``atol`` are ignored. Returns ------- bool True if the two input states are equal or closer than the maximum tolerance; False otherwise. Raises ------ NotImplementedError If the input types are not part of the explicitly handles types. """ # Strings if isinstance(state_x, str): return state_x == state_y # Numbers elif isinstance(state_x, numbers.Number): if do_equal: return state_x == state_y return np.isclose(state_x, state_y, rtol=rtol, atol=atol) # Types elif type(state_x) != type(state_y): return False # Tensors elif torch.is_tensor(state_x) and torch.is_tensor(state_y): # Check for nans because (torch.nan == torch.nan) == False x_nan = torch.isnan(state_x) if torch.any(x_nan): y_nan = torch.isnan(state_y) if not torch.equal(x_nan, y_nan): return False if do_equal: return torch.equal(state_x[~x_nan], state_y[~y_nan]) return torch.all( torch.isclose( state_x[~x_nan], state_y[~y_nan], rtol=rtol, atol=atol ) ) if do_equal: return torch.equal(state_x, state_y) return torch.all(torch.isclose(state_x, state_y, rtol=rtol, atol=atol)) # Numpy elif isinstance(state_x, np.ndarray) and isinstance(state_y, np.ndarray): if do_equal: return np.array_equal(state_x, state_y, equal_nan=True) return np.allclose(state_x, state_y, rtol=rtol, atol=atol, equal_nan=True) # Dictionaries elif isinstance(state_x, dict) and isinstance(state_y, dict): if len(state_x) != len(state_y): return False for key_x, value_x in state_x.items(): if key_x not in state_y: return False # Recursive comparison of the values if not GFlowNetEnv.isclose( value_x, state_y[key_x], rtol=rtol, atol=atol, do_equal=do_equal ): return False else: return True # Lists and tuples elif (isinstance(state_x, list) and isinstance(state_y, list)) or ( isinstance(state_x, tuple) and isinstance(state_y, tuple) ): if len(state_x) != len(state_y): return False if len(state_x) == 0: return True # If all the elements of the list or tuple are numbers compare the list or # tuple via np.all(np.isclose(state_x == state_y)) if isinstance(state_x[0], numbers.Number): value_type = type(state_x[0]) for sx, sy in zip(state_x, state_y): if not isinstance(sx, value_type): break else: if do_equal: return state_x == state_y return np.all(np.isclose(state_x, state_y, rtol=rtol, atol=atol)) # Otherwise, iterate over the lists or tuples and compare them recursively for sx, sy in zip(state_x, state_y): if not GFlowNetEnv.isclose( sx, sy, rtol=rtol, atol=atol, do_equal=do_equal ): return False else: raise NotImplementedError(f"Unknown type: {type(state_x)}") return True
[docs] def __eq__(self, other, ignored_keys: List[str] = []) -> bool: """ Checks whether the current environment instance is equal to the input environment instance. The attribute ``self.id`` is ignored to determine whether the environments are equal. Parameters ---------- other : GFlowNetEnv The environment instance to be compared. ignored_keys : list A list of keys (strings) to be ignored in the comparison. This parameter may be used by subclasses that may need to ignore certain keys. Returns ------- bool True if the environments's attributes are considered equal; False otherwise. """ # Check if other is not a GFlowNet environment if not isinstance(other, GFlowNetEnv): return False # Obtain dictionary of attributes of the other instance and iterate over the # dictionary of self to compare all attributes other_dict = other.__dict__ for k, v in self.__dict__.items(): # Ignore id if k == "id": continue # Ignore keys in ignored_keys if k in ignored_keys: continue # Check if the attribute is not in the other dict if k not in other_dict: return False v_other = other_dict[k] # Check if value types are different if type(v_other) != type(v): return False # If the attribute is an environment, enter recursion to check the # attributes of the sub-environment if isinstance(v, GFlowNetEnv): if not v.__eq__(v_other): return False # If the attribute is a list / tuple / dict of environments, enter # recursion to check the attributes of the sub-environment. This method # does not catch differences in sub-environments that are not at the first # level of a list, tuple or dict elif isinstance(v, list) or isinstance(v, tuple): if len(v) != len(v_other): return False if len(v) == 0: return True for v_el, v_other_el in zip(v, v_other): if isinstance(v_el, GFlowNetEnv): if not v_el.__eq__(v_other_el): return False else: # Compare the values with GFlowNet.equal() try: if not GFlowNetEnv.equal(v_el, v_other_el): return False except NotImplementedError: # If the types are not handled by self.equal, then ignore # this attribute for lack of means to determine whether the # values are equal continue elif isinstance(v, dict): for (v_k, v_v), (v_other_k, v_other_v) in zip( v.items(), v_other.items() ): if v_k != v_other_k: return False if isinstance(v_v, GFlowNetEnv): if not v_v.__eq__(v_other_v): return False else: # Compare the values with GFlowNet.equal() try: if not GFlowNetEnv.equal(v_v, v_other_v): return False except NotImplementedError: # If the types are not handled by self.equal, then ignore # this attribute for lack of means to determine whether the # values are equal continue else: # Compare the values with GFlowNet.equal() try: if not GFlowNetEnv.equal(v, v_other): return False except NotImplementedError: # If the types are not handled by self.equal, then ignore this # attribute for lack of means to determine whether the values are equal continue return True
[docs] def get_trajectories( self, traj_list, traj_actions_list, current_traj, current_actions ): """ Determines all trajectories leading to each state in traj_list, recursively. Args ---- traj_list : list List of trajectories (lists) traj_actions_list : list List of actions within each trajectory current_traj : list Current trajectory current_actions : list Actions of current trajectory Returns ------- traj_list : list List of trajectories (lists) traj_actions_list : list List of actions within each trajectory """ parents, parents_actions = self.get_parents(current_traj[-1], False) if parents == []: traj_list.append(current_traj) traj_actions_list.append(current_actions) return traj_list, traj_actions_list for idx, (p, a) in enumerate(zip(parents, parents_actions)): traj_list, traj_actions_list = self.get_trajectories( traj_list, traj_actions_list, current_traj + [p], current_actions + [a] ) return traj_list, traj_actions_list
@torch.no_grad()
[docs] def compute_train_energy_proxy_and_rewards(self): """ Gather batched proxy data: * The ground-truth energy of the train set * The predicted proxy energy over the train set * The reward version of those energies (with env.proxy2reward) Returns ------- gt_energy : torch.Tensor The ground-truth energies in the proxy's train set proxy_energy : torch.Tensor The proxy's predicted energies over its train set gt_reward : torch.Tensor The reward version of the ground-truth energies proxy_reward : torch.Tensor The reward version of the proxy's predicted energies """ gt_energy, proxy_energy = self.proxy.infer_on_train_set() gt_reward = self.proxy2reward(gt_energy) proxy_reward = self.proxy2reward(proxy_energy) return gt_energy, proxy_energy, gt_reward, proxy_reward
[docs] def mask_conditioning( self, mask: Union[List[bool], TensorType["mask_dim"]], env_cond, backward: bool ): """ Conditions the input mask based on the restrictions imposed by a conditioning environment, env_cond. It is assumed that the state space of the conditioning environment is a subset of the state space of the original environment (self). The conditioning mechanism goes as follows: given a state, its corresponding mask and a conditioning environment, the mask of invalid actions is updated such that all actions that would be invalid in the conditioning environment are made invalid, even though they may not be invalid in the original environment. """ # Set state in conditional environment env_cond.reset() env_cond.set_state(self.state, self.done) # If the environment is continuous, then we simply return the mask of the # conditioning environment. It is thus assumed that the dimensionality and # interpretation is the same. if self.continuous: return env_cond.get_mask(backward=backward) # Get valid actions common to both the original and the conditioning env actions_valid_orig = self.get_valid_actions(mask) actions_valid_cond = env_cond.get_valid_actions(backward=backward) actions_valid = set(actions_valid_orig).intersection(set(actions_valid_cond)) # Construct new mask by setting to False (valid or not invalid) the actions # that are valid to both environments mask = [True] * self.mask_dim for action in actions_valid: mask[self.action_space.index(action)] = False return mask
@torch.no_grad()
[docs] def top_k_metrics_and_plots( self, states, top_k, name, energy=None, reward=None, step=None, **kwargs, ): """ Compute top_k metrics and plots for the given states. In particular, if no states, energy, or reward are passed, then the name *must* be "train", and the energy and reward will be computed from the proxy using ``env.compute_train_energy_proxy_and_rewards()``. In this case, ``top_k_metrics_and_plots`` will be called a second time to compute the metrics and plots of the proxy distribution in addition to the ground-truth distribution. Train mode should only be called once at the begining of training as distributions do not change over time. If ``states`` are passed, then the energy and reward will be computed from the proxy for those states. They are typically sampled from the current GFN. Otherwise, energy and reward should be passed directly. *Plots and metrics*: - mean+std of energy and reward - mean+std of top_k energy and reward - histogram of energy and reward - histogram of top_k energy and reward Args ---- states: list List of states to compute metrics and plots for. top_k: int Number of top k states to compute metrics and plots for. "top" means lowest energy/highest reward. name: str Name of the distribution to compute metrics and plots for. Typically "gflownet", "random" or "train". Will be used in metrics names like ``f"Mean {name} energy"``. energy: torch.Tensor, optional Batch of pre-computed energies reward: torch.Tensor, optional Batch of pre-computed rewards step: int, optional Step number to use for the plot title. Returns ------- metrics: dict Dictionary of metrics: str->float figs: list List of matplotlib figures figs_names: list List of figure names for ``figs`` """ if states is None and energy is None and reward is None: assert name == "train" ( energy, proxy, energy_reward, proxy_reward, ) = self.compute_train_energy_proxy_and_rewards() name = "train ground truth" reward = energy_reward elif energy is None and reward is None: # TODO: fix this x = torch.stack([self.state2proxy(s) for s in states]) energy = self.proxy(x.to(self.device)).cpu() reward = self.proxy2reward(energy) assert energy is not None and reward is not None # select top k best energies and rewards top_k_e = torch.topk(energy, top_k, largest=False, dim=0).values.numpy() top_k_r = torch.topk(reward, top_k, largest=True, dim=0).values.numpy() # find best energy and reward best_e = torch.min(energy).item() best_r = torch.max(reward).item() # to numpy to plot energy = energy.numpy() reward = reward.numpy() # compute stats mean_e = np.mean(energy) mean_r = np.mean(reward) std_e = np.std(energy) std_r = np.std(reward) mean_top_k_e = np.mean(top_k_e) mean_top_k_r = np.mean(top_k_r) std_top_k_e = np.std(top_k_e) std_top_k_r = np.std(top_k_r) # automatic color scale # currently: cividis colour map colors = ["full", "top_k"] normalizer = mpl.colors.Normalize(vmin=0, vmax=len(colors) - 0.5) colors = {k: CMAP(normalizer(i)) for i, k in enumerate(colors[::-1])} # two sublopts: left is energy, right is reward fig, ax = plt.subplots(1, 2, figsize=(10, 5)) # energy full distribution and stats lines ax[0].hist( energy, bins=100, alpha=0.35, label=f"All = {len(energy)}", color=colors["full"], density=True, ) ax[0].axvline( mean_e, color=colors["full"], linestyle=(0, (5, 10)), label=f"Mean = {mean_e:.3f}", ) ax[0].axvline( mean_e + std_e, color=colors["full"], linestyle=(0, (1, 10)), label=f"Std = {std_e:.3f}", ) ax[0].axvline( mean_e - std_e, color=colors["full"], linestyle=(0, (1, 10)), ) # energy top k distribution and stats lines ax[0].hist( top_k_e, bins=100, alpha=0.7, label=f"Top k = {top_k}", color=colors["top_k"], density=True, ) ax[0].axvline( mean_top_k_e, color=colors["top_k"], linestyle=(0, (5, 10)), label=f"Mean = {mean_top_k_e:.3f}", ) ax[0].axvline( mean_top_k_e + std_top_k_e, color=colors["top_k"], linestyle=(0, (1, 10)), label=f"Std = {std_top_k_e:.3f}", ) ax[0].axvline( mean_top_k_e - std_top_k_e, color=colors["top_k"], linestyle=(0, (1, 10)), ) # energy title & legend ax[0].set_title( f"Energy distribution for {top_k} vs {len(energy)}" + f" samples\nBest: {best_e:.3f}", y=0, pad=-20, verticalalignment="top", size=12, ) ax[0].legend() # reward full distribution and stats lines ax[1].hist( reward, bins=100, alpha=0.35, label=f"All = {len(reward)}", color=colors["full"], density=True, ) ax[1].axvline( mean_r, color=colors["full"], linestyle=(0, (5, 10)), label=f"Mean = {mean_r:.3f}", ) ax[1].axvline( mean_r + std_r, color=colors["full"], linestyle=(0, (1, 10)), label=f"Std = {std_r:.3f}", ) ax[1].axvline( mean_r - std_r, color=colors["full"], linestyle=(0, (1, 10)), ) # reward top k distribution and stats lines ax[1].hist( top_k_r, bins=100, alpha=0.7, label=f"Top k = {top_k}", color=colors["top_k"], density=True, ) ax[1].axvline( mean_top_k_r, color=colors["top_k"], linestyle=(0, (5, 10)), label=f"Mean = {mean_top_k_r:.3f}", ) ax[1].axvline( mean_top_k_r + std_top_k_r, color=colors["top_k"], linestyle=(0, (1, 10)), label=f"Std = {std_top_k_r:.3f}", ) ax[1].axvline( mean_top_k_r - std_top_k_r, color=colors["top_k"], linestyle=(0, (1, 10)), ) # reward title & legend ax[1].set_title( f"Reward distribution for {top_k} vs {len(reward)}" + f" samples\nBest: {best_r:.3f}", y=0, pad=-20, verticalalignment="top", size=12, ) ax[1].legend() # Finalize figure title = f"{name.capitalize()} energy and reward distributions" if step is not None: title += f" (step {step})" fig.suptitle(title, y=0.95) plt.tight_layout(rect=[0, 0.02, 1, 0.98]) # store metrics metrics = { f"Mean {name} energy": mean_e, f"Std {name} energy": std_e, f"Mean {name} reward": mean_r, f"Std {name} reward": std_r, f"Mean {name} top k energy": mean_top_k_e, f"Std {name} top k energy": std_top_k_e, f"Mean {name} top k reward": mean_top_k_r, f"Std {name} top k reward": std_top_k_r, f"Best (min) {name} energy": best_e, f"Best (max) {name} reward": best_r, } figs = [fig] fig_names = [title] if name.lower() == "train ground truth": # train stats mode: the ground truth data has meen plotted # and computed, let's do it again for the proxy data. # This can be used to visualize potential distribution mismatch # between the proxy and the ground truth data. proxy_metrics, proxy_figs, proxy_fig_names = self.top_k_metrics_and_plots( None, top_k, "train proxy", energy=proxy, reward=proxy_reward, step=None, **kwargs, ) # aggregate metrics and figures metrics.update(proxy_metrics) figs += proxy_figs fig_names += proxy_fig_names return metrics, figs, fig_names
[docs] def plot_reward_distribution( self, states=None, scores=None, ax=None, title=None, proxy=None, **kwargs ): if ax is None: fig, ax = plt.subplots() standalone = True else: standalone = False if title == None: title = "Scores of Sampled States" if proxy is None: proxy = self.proxy if scores is None: if isinstance(states[0], torch.Tensor): states = torch.vstack(states).to(self.device, self.float) if isinstance(states, torch.Tensor) == False: states = torch.tensor(states, device=self.device, dtype=self.float) states_proxy = self.states2proxy(states) scores = self.proxy(states_proxy) if isinstance(scores, TensorType): scores = scores.cpu().detach().numpy() ax.hist(scores) ax.set_title(title) ax.set_ylabel("Number of Samples") ax.set_xlabel("Energy") plt.show() if standalone == True: plt.tight_layout() plt.close() return ax
[docs] def test( self, samples: Union[ TensorType["n_trajectories", "..."], npt.NDArray[np.float32], List ], ) -> dict: """ Placeholder for a custom test function that can be defined for a specific environment. Can be overwritten if special evaluation procedure is needed for a given environment. Args ---- samples A collection of sampled terminating states. Returns ------- metrics A dictionary with metrics and their calculated values. """ return {}