"""
Classes to represent hyper-cube environments
"""
import itertools
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.neighbors import KernelDensity
from torch.distributions import Bernoulli, Beta, Categorical, MixtureSameFamily
from torchtyping import TensorType
from gflownet.envs.base import GFlowNetEnv
from gflownet.envs.composite.setfix import SetFix
from gflownet.utils.common import copy, tbool, tfloat, torch2np
[docs]
class CubeBase(GFlowNetEnv, ABC):
"""
Base class for hyper-cube environments, continuous or hybrid versions of the
hyper-grid in which the continuous increments are modelled by a (mixture of) Beta
distribution(s).
The states space is the value of each dimension, defined in the closed set [0, 1].
If the value of a dimension gets larger than 1 - min_incr, then the trajectory is
ended (the only possible action is EOS).
Attributes
----------
n_dim : int
Dimensionality of the hyper-cube.
min_incr : float
Minimum increment in the actions, in (0, 1). This is necessary to ensure
that all trajectories have finite length.
n_comp : int
Number of components in the mixture of Beta distributions.
epsilon : float
Small constant to control the clamping interval of the inputs to the
calculation of log probabilities. Clamping interval will be [epsilon, 1 -
epsilon]. The smaller the value, the lower the probability to incur an
unbounded result due to numerical precision, but the lower the precision too.
Default: 1e-6.
kappa : float
Small constant to control the intervals of the generated sets of states (in a
grid or uniformly). States will be in the interval [kappa, 1 - kappa]. Default:
1e-3.
ignored_dims : list
Boolean mask of ignored dimensions. This can be used for trajectories that may
have multiple dimensions coupled or fixed. For each dimension, True if ignored,
False, otherwise. If None, no dimension is ignored.
"""
def __init__(
self,
n_dim: int = 2,
min_incr: float = 0.1,
n_comp: int = 1,
beta_params_min: float = 1.0,
beta_params_max: float = 100.0,
epsilon: float = 1e-6,
kappa: float = 1e-3,
ignored_dims: Optional[List[bool]] = None,
fixed_distr_params: dict = {
"beta_weights": 1.0,
"beta_alpha": 10.0,
"beta_beta": 10.0,
"bernoulli_bts_prob": 0.1,
"bernoulli_eos_prob": 0.1,
},
random_distr_params: dict = {
"beta_weights": 1.0,
"beta_alpha": 10.0,
"beta_beta": 10.0,
"bernoulli_bts_prob": 0.1,
"bernoulli_eos_prob": 0.1,
},
**kwargs,
):
assert n_dim > 0
assert min_incr > 0.0
assert min_incr < 1.0
assert n_comp > 0
# Main properties
if ignored_dims:
self.ignored_dims = ignored_dims
else:
self.ignored_dims = [False] * self.n_dim
# Parameters of the policy distribution
[docs]
self.beta_params_min = beta_params_min
[docs]
self.beta_params_max = beta_params_max
# Source state is abstract - not included in the cube: -1 for all dimensions.
[docs]
self.source = [-1 for _ in range(self.n_dim)]
# Small constant to clamp the inputs to the beta distribution
# Small constant to restrict the interval of (test) sets
# Base class init
super().__init__(
fixed_distr_params=fixed_distr_params,
random_distr_params=random_distr_params,
**kwargs,
)
@abstractmethod
[docs]
def get_action_space(self):
pass
@abstractmethod
[docs]
def get_policy_output(self, params: dict) -> TensorType["policy_output_dim"]:
pass
@abstractmethod
[docs]
def get_mask_invalid_actions_forward(
self,
state: Optional[List] = None,
done: Optional[bool] = None,
) -> List:
pass
@abstractmethod
[docs]
def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=None):
pass
[docs]
def states2proxy(
self, states: Union[List, TensorType["batch", "state_dim"]]
) -> TensorType["batch", "state_dim"]:
"""
Prepares a batch of states in "environment format" for a proxy: clips the
states into [0, 1] and maps them to [CELL_MIN, CELL_MAX]
Args
----
states : list or tensor
A batch of states in environment format, either as a list of states or as a
single tensor.
Returns
-------
A tensor containing all the states in the batch.
"""
states = tfloat(states, device=self.device, float_type=self.float)
return 2.0 * torch.clip(states, min=0.0, max=CELL_MAX) - CELL_MAX
# TODO: Change value of ignored dimensions / source values
[docs]
def states2policy(
self, states: Union[List, TensorType["batch", "state_dim"]]
) -> TensorType["batch", "state_dim"]:
"""
Prepares a batch of states in "environment format" for the policy model: clips
the states into [0, 1] and maps them to [-1.0, 1.0]
Args
----
states : list or tensor
A batch of states in environment format, either as a list of states or as a
single tensor.
Returns
-------
A tensor containing all the states in the batch.
"""
states = tfloat(states, device=self.device, float_type=self.float)
return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0
[docs]
def state2readable(self, state: List) -> str:
"""
Converts a state (a list of positions) into a human-readable string
representing a state.
"""
return str(state).replace("(", "[").replace(")", "]").replace(",", "")
[docs]
def readable2state(self, readable: str) -> List:
"""
Converts a human-readable string representing a state into a state as a list of
positions.
"""
return [el for el in readable.strip("[]").split(" ")]
@abstractmethod
[docs]
def get_parents(
self, state: List = None, done: bool = None, action: Tuple[int, float] = None
) -> Tuple[List[List], List[Tuple[int, float]]]:
"""
Determines all parents and actions that lead to state.
Args
----
state : list
Representation of a state
done : bool
Whether the trajectory is done. If None, done is taken from instance.
action : int
Last action performed
Returns
-------
parents : list
List of parents in state format
actions : list
List of actions that lead to state for each parent in parents
"""
pass
@abstractmethod
[docs]
def sample_actions_batch(
self,
policy_outputs: TensorType["n_states", "policy_output_dim"],
mask: Optional[TensorType["n_states", "policy_output_dim"]] = None,
states_from: Optional[List] = None,
is_backward: Optional[bool] = False,
random_action_prob: Optional[float] = 0.0,
temperature_logits: Optional[float] = 1.0,
) -> Tuple[List[Tuple], TensorType["n_states"]]:
"""
Samples a batch of actions from a batch of policy outputs.
"""
pass
[docs]
def get_logprobs(
self,
policy_outputs: TensorType["n_states", "policy_output_dim"],
is_forward: bool,
actions: Union[List, TensorType["n_states", "action_dim"]],
mask_invalid_actions: TensorType["batch_size", "policy_output_dim"] = None,
loginf: float = 1000,
) -> TensorType["batch_size"]:
"""
Computes log probabilities of actions given policy outputs and actions.
"""
pass
[docs]
def step(
self, action: Tuple[int, float]
) -> Tuple[List[float], Tuple[int, float], bool]:
"""
Executes step given an action.
Args
----
action : tuple
Action to be executed. An action is a tuple with two values:
(dimension, increment).
Returns
-------
self.state : list
The sequence after executing the action
action : int
Action executed
valid : bool
False, if the action is not allowed for the current state, e.g. stop at the
root state
"""
pass
def _beta_params_to_policy_outputs(self, param_name: str, params_dict: dict):
"""
Maps the values of alpha and beta given in the configuration to new values such
that when passed to _make_increments_distribution, the actual alpha and beta
passed to the Beta distribution(s) are the ones from the configuration.
Args
----
param_name : str
Name of the parameter to transform: alpha or beta
params_dict : dict
Dictionary with the complete set of parameters of the distribution.
See
---
_make_increments_distribution()
"""
param_value = tfloat(
params_dict[f"beta_{param_name}"], float_type=self.float, device=self.device
)
return torch.logit((param_value - self.beta_params_min) / self.beta_params_max)
def _get_effective_dims(self, state: Optional[List] = None) -> List:
state = self._get_state(state)
return [s for s, ign_dim in zip(state, self.ignored_dims) if not ign_dim]
[docs]
def fit_kde(
self,
samples: TensorType["batch_size", "state_proxy_dim"],
kernel: str = "gaussian",
bandwidth: float = 0.1,
):
r"""
Fits a Kernel Density Estimator on a batch of samples.
Parameters
----------
samples : tensor
A batch of samples in proxy format.
kernel : str
An identifier of the kernel to use for the density estimation. It must be a
valid kernel for the scikit-learn method
:py:meth:`sklearn.neighbors.KernelDensity`.
bandwidth : float
The bandwidth of the kernel.
"""
samples = torch2np(samples)
return KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(samples)
[docs]
def plot_reward_samples(
self,
samples: TensorType["batch_size", "state_proxy_dim"],
samples_reward: TensorType["batch_size", "state_proxy_dim"],
rewards: TensorType["batch_size"],
alpha: float = 0.5,
dpi: int = 150,
max_samples: int = 500,
**kwargs,
):
"""
Plots the reward contour alongside a batch of samples.
Parameters
----------
samples : tensor
A batch of samples from the GFlowNet policy in proxy format. These samples
will be plotted on top of the reward density.
samples_reward : tensor
A batch of samples containing a grid over the sample space, from which the
reward has been obtained. These samples are used to plot the contour of
reward density.
rewards : tensor
The rewards of samples_reward. It should be a vector of dimensionality
n_per_dim ** 2 and be sorted such that the each block at rewards[i *
n_per_dim:i * n_per_dim + n_per_dim] correspond to the rewards at the i-th
row of the grid of samples, from top to bottom. The same is assumed for
samples_reward.
alpha : float
Transparency of the reward contour.
dpi : int
Dots per inch, indicating the resolution of the plot.
max_samples : int
Maximum of number of samples to include in the plot.
"""
if self.n_dim != 2:
return None
samples = torch2np(samples)
samples_reward = torch2np(samples_reward)
rewards = torch2np(rewards)
# Create mesh grid from samples_reward
n_per_dim = int(np.sqrt(samples_reward.shape[0]))
assert n_per_dim**2 == samples_reward.shape[0]
x_coords = samples_reward[:, 0].reshape((n_per_dim, n_per_dim))
y_coords = samples_reward[:, 1].reshape((n_per_dim, n_per_dim))
rewards = rewards.reshape((n_per_dim, n_per_dim))
# Init figure
fig, ax = plt.subplots()
fig.set_dpi(dpi)
# Plot reward contour
h = ax.contourf(x_coords, y_coords, rewards, alpha=alpha)
ax.axis("scaled")
fig.colorbar(h, ax=ax)
# Plot samples
random_indices = np.random.permutation(samples.shape[0])[:max_samples]
ax.scatter(samples[random_indices, 0], samples[random_indices, 1], alpha=alpha)
# Figure settings
ax.grid()
padding = 0.05 * (CELL_MAX - CELL_MIN)
ax.set_xlim([CELL_MIN - padding, CELL_MAX + padding])
ax.set_ylim([CELL_MIN - padding, CELL_MAX + padding])
plt.tight_layout()
return fig
[docs]
def plot_kde(
self,
samples: TensorType["batch_size", "state_proxy_dim"],
kde,
alpha: float = 0.5,
dpi=150,
colorbar: bool = True,
**kwargs,
):
"""
Plots the density previously estimated from a batch of samples via KDE over the
entire sample space.
Parameters
----------
samples : tensor
A batch of samples containing a grid over the sample space. These samples
are used to plot the contour of the estimated density.
kde : KDE
A scikit-learn KDE object fit with a batch of samples.
alpha : float
Transparency of the density contour.
dpi : int
Dots per inch, indicating the resolution of the plot.
"""
if self.n_dim != 2:
return None
samples = torch2np(samples)
# Create mesh grid from samples
n_per_dim = int(np.sqrt(samples.shape[0]))
assert n_per_dim**2 == samples.shape[0]
x_coords = samples[:, 0].reshape((n_per_dim, n_per_dim))
y_coords = samples[:, 1].reshape((n_per_dim, n_per_dim))
# Score samples with KDE
Z = np.exp(kde.score_samples(samples)).reshape((n_per_dim, n_per_dim))
# Init figure
fig, ax = plt.subplots()
fig.set_dpi(dpi)
# Plot KDE
h = ax.contourf(x_coords, y_coords, Z, alpha=alpha)
ax.axis("scaled")
if colorbar:
fig.colorbar(h, ax=ax)
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
spine.set_visible(False)
# Set tight layout
plt.tight_layout()
return fig
[docs]
class ContinuousCube(CubeBase):
"""
Continuous hyper-cube environment (continuous version of a hyper-grid) in which the
action space consists of the increment of each dimension d, modelled by a mixture
of Beta distributions. The state space is the value of each dimension. In order to
ensure that all trajectories are of finite length, actions have a minimum increment
for all dimensions determined by min_incr. If the value of any dimension is larger
than 1 - min_incr, then that dimension can't be further incremented. In order to
ensure the coverage of the state space, the first action (from the source state) is
not constrained by the minimum increment.
Actions do not represent absolute increments but rather the relative increment with
respect to the distance to the edges of the hyper-cube, from the minimum increment.
That is, if dimension d of a state has value 0.3, the minimum increment (min_incr)
is 0.1 and the maximum value is 1.0, an action of 0.5 will increment the
value of the dimension in 0.5 * (1.0 - 0.3 - 0.1) = 0.5 * 0.6 = 0.3. Therefore, the
value of d in the next state will be 0.3 + 0.3 = 0.6.
Attributes
----------
n_dim : int
Dimensionality of the hyper-cube.
min_incr : float
Minimum increment in the actions, in (0, 1). This is necessary to ensure
that all trajectories have finite length.
n_comp : int
Number of components in the mixture of Beta distributions.
"""
def __init__(self, **kwargs):
# Number of fixed entries in the mask
super().__init__(**kwargs)
def _compute_mask_dim(self):
"""
Calculates the mask dimensionality..
The mask consists of three fixed flags, plus a flag for each dimension.
Returns
-------
int
The number of elements in the Stack masks.
"""
return self.mask_dim_base + self.n_dim
[docs]
def get_action_space(self):
"""
The action space is continuous, thus not defined as such here.
The actions contained in the action space are "representatives"
The actions are tuples of length n_dim + 1, where the value at position d
indicates the increment of dimension d, and the value at position -1 indicates
whether the action is from or to source (1), or 0 otherwise.
EOS is indicated by np.inf for all dimensions.
The action space consists of the EOS actions and two representatives:
- Generic increment action, not from or to source: (0, 0, ..., 0, 0)
- Generic increment action, from or to source: (0, 0, ..., 0, 1)
- EOS: (inf, inf, ..., inf, inf)
"""
actions_dim = self.n_dim + 1
self.eos = tuple([np.inf] * actions_dim)
self.representative_no_source = tuple([0.0] * self.n_dim) + (0,)
self.representative_source = tuple([0.0] * self.n_dim) + (1,)
return [self.representative_no_source, self.representative_source, self.eos]
[docs]
def action2representative(self, action: Tuple) -> Tuple:
"""
Replaces the continuous values of an action by 0s (the "generic" or
"representative" action in the first position of the action space), so that
they can be compared against the action space or a mask.
If the action is EOS, it is returned as is.
Parameters
----------
action : tuple
An actual action of the Cube environment (with continuous values)
Returns
-------
tuple
A representative of the action, where continuous values are replaced by
zeros.
"""
if action == self.eos:
return action
elif action[-1] == 0:
return self.action_space[0]
elif action[-1] == 1:
return self.action_space[1]
else:
raise ValueError(f"Action {action} does not seem like a valid action")
def _get_max_trajectory_length(self) -> int:
"""
Returns the maximum trajectory length of the environment, including the EOS
action.
"""
return int(np.ceil(1.0 / self.min_incr) + 2) + 1
[docs]
def get_policy_output(self, params: dict) -> TensorType["policy_output_dim"]:
"""
Defines the structure of the output of the policy model.
The policy output will be used to initialize a distribution, from which an
action is to be determined or sampled. This method returns a vector with a
fixed policy defined by params.
The environment consists of both continuous and discrete actions.
Continuous actions
For each dimension d of the hyper-cube and component c of the mixture, the
output of the policy should return:
1) the weight of the component in the mixture,
2) the pre-alpha parameter of the Beta distribution to sample the increment,
3) the pre-beta parameter of the Beta distribution to sample the increment.
These parameters are the first n_dim * n_comp * 3 of the policy output such
that the first 3 x C elements correspond to the first dimension, and so on.
Discrete actions
Additionally, the policy output contains one logit (pos -1) of a Bernoulli
distribution to model the (discrete) forward probability of selecting the EOS
action and another logit (pos -2) for the (discrete) backward probability of
returning to the source node.
Therefore, the output of the policy model has dimensionality D x C x 3 + 2,
where D is the number of dimensions (self.n_dim) and C is the number of
components (self.n_comp).
See
---
_beta_params_to_policy_outputs()
"""
# Parameters for continuous actions
self._len_policy_output_cont = self.n_dim * self.n_comp * 3
policy_output_cont = torch.empty(
self._len_policy_output_cont,
dtype=self.float,
device=self.device,
)
policy_output_cont[0::3] = params["beta_weights"]
policy_output_cont[1::3] = self._beta_params_to_policy_outputs("alpha", params)
policy_output_cont[2::3] = self._beta_params_to_policy_outputs("beta", params)
# Logit for Bernoulli distribution to model EOS action
policy_output_eos_logit = torch.logit(
tfloat(
[params["bernoulli_eos_prob"]],
float_type=self.float,
device=self.device,
)
)
# Logit for Bernoulli distribution to model back-to-source action
policy_output_bts_logit = torch.logit(
tfloat(
[params["bernoulli_bts_prob"]],
float_type=self.float,
device=self.device,
)
)
# Concatenate all outputs
policy_output = torch.cat(
(
policy_output_cont,
policy_output_bts_logit,
policy_output_eos_logit,
)
)
return policy_output
def _get_policy_betas_weights(
self, policy_output: TensorType["n_states", "policy_output_dim"]
) -> TensorType["n_states", "n_dim * n_comp"]:
"""
Reduces a given policy output to the part corresponding to the weights of the
mixture of Beta distributions.
See: get_policy_output()
"""
return policy_output[:, 0 : self._len_policy_output_cont : 3]
def _get_policy_betas_alpha(
self, policy_output: TensorType["n_states", "policy_output_dim"]
) -> TensorType["n_states", "n_dim * n_comp"]:
"""
Reduces a given policy output to the part corresponding to the alphas of the
mixture of Beta distributions.
See: get_policy_output()
"""
return policy_output[:, 1 : self._len_policy_output_cont : 3]
def _get_policy_betas_beta(
self, policy_output: TensorType["n_states", "policy_output_dim"]
) -> TensorType["n_states", "n_dim * n_comp"]:
"""
Reduces a given policy output to the part corresponding to the betas of the
mixture of Beta distributions.
See: get_policy_output()
"""
return policy_output[:, 2 : self._len_policy_output_cont : 3]
def _get_policy_eos_logit(
self, policy_output: TensorType["n_states", "policy_output_dim"]
) -> TensorType["n_states", "1"]:
"""
Reduces a given policy output to the part corresponding to the logit of the
Bernoulli distribution to model the EOS action.
See: get_policy_output()
"""
return policy_output[:, -1]
def _get_policy_source_logit(
self, policy_output: TensorType["n_states", "policy_output_dim"]
) -> TensorType["n_states", "1"]:
"""
Reduces a given policy output to the part corresponding to the logit of the
Bernoulli distribution to model the back-to-source action.
See: get_policy_output()
"""
return policy_output[:, -2]
[docs]
def get_mask_invalid_actions_forward(
self,
state: Optional[List] = None,
done: Optional[bool] = None,
) -> List:
"""
The action space is continuous, thus the mask is not only of invalid actions as
in discrete environments, but also an indicator of "special cases", for example
states from which only certain actions are possible.
The values of True/False intend to approximately stick to the semantics in
discrete environments, where the mask is of "invalid" actions, but it is
important to note that a direct interpretation in this sense does not always
apply.
For example, the mask values of special cases are True if the special cases they
refer to are "invalid". In other words, the values are False if the state has
the special case.
The forward mask has the following structure:
- 0 : whether a continuous action is invalid. True if the value at any
dimension is larger than 1 - min_incr, or if done is True. False otherwise.
- 1 : special case when the state is the source state. False when the state is
the source state, True otherwise.
- 2 : whether EOS action is invalid. EOS is valid from any state, except the
source state or if done is True.
- -n_dim: : dimensions that should be ignored when sampling actions or
computing logprobs. This can be used for trajectories that may have
multiple dimensions coupled or fixed. For each dimension, True if ignored,
False, otherwise.
"""
state = self._get_state(state)
done = self._get_done(done)
# If done, the entire mask is True (all actions are "invalid" and no special
# cases)
if done:
return [True] * self.mask_dim
mask = [False] * self.mask_dim_base + self.ignored_dims
# If the state is the source state, EOS is invalid
if self._get_effective_dims(state) == self._get_effective_dims(self.source):
mask[2] = True
# If the state is not the source, indicate not special case (True)
else:
mask[1] = True
# If the value of any dimension is greater than 1 - min_incr, then continuous
# actions are invalid (True).
if any([s > 1 - self.min_incr for s in self._get_effective_dims(state)]):
mask[0] = True
return mask
[docs]
def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=None):
"""
The action space is continuous, thus the mask is not only of invalid actions as
in discrete environments, but also an indicator of "special cases", for example
states from which only certain actions are possible.
In order to approximately stick to the semantics in discrete environments,
where the mask is of "invalid" actions, that is the value is True if an action
is invalid, the mask values of special cases are True if the special cases they
refer to are "invalid". In other words, the values are False if the state has
the special case.
The backward mask has the following structure:
- 0 : whether a continuous action is invalid. True if the value at any
dimension is smaller than min_incr, or if done is True. False otherwise.
- 1 : special case when back-to-source action is the only possible action.
False if any dimension is smaller than min_incr, True otherwise.
- 2 : whether EOS action is invalid. False only if done is True, True
(invalid) otherwise.
- -n_dim: : dimensions that should be ignored when sampling actions or
computing logprobs. this can be used for trajectories that may have
multiple dimensions coupled or fixed. for each dimension, true if ignored,
false, otherwise. By default, no dimension is ignored.
"""
state = self._get_state(state)
done = self._get_done(done)
mask = [True] * self.mask_dim_base + self.ignored_dims
# If the state is the source state, entire mask is True
if self._get_effective_dims(state) == self._get_effective_dims(self.source):
return mask
# If done, only valid action is EOS.
if done:
mask[2] = False
return mask
# If any dimension is smaller than m, then back-to-source action is the only
# possible actiona.
if any([s < self.min_incr for s in self._get_effective_dims(state)]):
mask[1] = False
return mask
# Otherwise, continuous actions are valid
mask[0] = False
return mask
[docs]
def get_valid_actions(
self,
mask: Optional[bool] = None,
state: Optional[List] = None,
done: Optional[bool] = None,
backward: Optional[bool] = False,
) -> List[Tuple]:
"""
Returns the list of non-invalid (valid, for short) according to the mask of
invalid actions.
As a continuous environment, the returned actions are "representatives", that
is the actions represented in the action space.
Parameters
----------
mask : list (optional)
The mask of a state. If None, it is computed in place.
state : list (optional)
A state in GFlowNet format. If None, self.state is used.
done : bool (optional)
Whether the trajectory is done. If None, self.done is used.
backward : bool
True if the transtion is backwards; False if forward.
Returns
-------
list
The list of representatives of the valid actions.
"""
state = self._get_state(state)
done = self._get_done(done)
if mask is None:
mask = self.get_mask(state, done, backward)
actions_valid = []
if backward:
if mask[0] is False or mask[1] is False:
actions_valid.append(self.action_space[1])
if mask[0] is False:
actions_valid.append(self.action_space[0])
else:
if mask[1] is False:
actions_valid.append(self.action_space[1])
elif mask[0] is False:
actions_valid.append(self.action_space[0])
if mask[2] is False:
actions_valid.append(self.eos)
return actions_valid
[docs]
def get_parents(
self, state: List = None, done: bool = None, action: Tuple[int, float] = None
) -> Tuple[List[List], List[Tuple[int, float]]]:
"""
Defined only because it is required. A ContinuousEnv should be created to avoid
this issue.
"""
pass
[docs]
def relative_to_absolute_increments(
self,
states: TensorType["n_states", "n_dim"],
increments_rel: TensorType["n_states", "n_dim"],
is_backward: bool,
):
"""
Returns a batch of absolute increments (actions) given a batch of states,
relative increments and minimum_increments.
Given a dimension value x, a relative increment r, and a minimum increment m,
then the absolute increment a is given by:
Forward:
a = m + r * (1 - x - m)
Backward:
a = m + r * (x - m)
"""
min_increments = torch.full_like(
increments_rel, self.min_incr, dtype=self.float, device=self.device
)
if is_backward:
return min_increments + increments_rel * (states - min_increments)
else:
return min_increments + increments_rel * (1.0 - states - min_increments)
[docs]
def absolute_to_relative_increments(
self,
states: TensorType["n_states", "n_dim"],
increments_abs: TensorType["n_states", "n_dim"],
is_backward: bool,
):
"""
Returns a batch of relative increments (as sampled by the Beta distributions)
given a batch of states, absolute increments (actions) and minimum_increments.
Given a dimension value x, an absolute increment a, and a minimum increment m,
then the relative increment r is given by:
Forward:
r = (a - m) / (1 - x - m)
Backward:
r = (a - m) / (x - m)
"""
min_increments = torch.full_like(
increments_abs, self.min_incr, dtype=self.float, device=self.device
)
if is_backward:
denominator = (states - min_increments).clamp(min=self.epsilon)
increments_rel = (increments_abs - min_increments) / denominator
return increments_rel
else:
denominator = (1.0 - states - min_increments).clamp(min=self.epsilon)
return (increments_abs - min_increments) / denominator
@staticmethod
def _get_beta_params_from_mean_variance(
mean: TensorType["n_states", "n_dim_x_n_comp"],
variance: TensorType["n_states", "n_dim_x_n_comp"],
) -> Tuple[
TensorType["n_states", "n_dim_x_n_comp"],
TensorType["n_states", "n_dim_x_n_comp"],
]:
"""
Calculates the alpha and beta parameters of a Beta distribution from the mean
and variance.
The method operates on tensors containing a batch of means and variances.
Args
----
mean : tensor
A batch of means.
variance : tensor
A batch of variances.
Returns
-------
alpha : tensor
The alpha parameters for the Beta distributions as a function of the mean
and variances.
beta : tensor
The beta parameters for the Beta distributions as a function of the mean
and variances.
"""
one_minus_mean = 1.0 - mean
beta = one_minus_mean * (mean * one_minus_mean - variance) / variance
alpha = (mean * beta) / one_minus_mean
return alpha, beta
def _make_increments_distribution(
self,
policy_outputs: TensorType["n_states", "policy_output_dim"],
) -> MixtureSameFamily:
mix_logits = self._get_policy_betas_weights(policy_outputs).reshape(
-1, self.n_dim, self.n_comp
)
mix = Categorical(logits=mix_logits)
alphas = self._get_policy_betas_alpha(policy_outputs).reshape(
-1, self.n_dim, self.n_comp
)
alphas = self.beta_params_max * torch.sigmoid(alphas) + self.beta_params_min
betas = self._get_policy_betas_beta(policy_outputs).reshape(
-1, self.n_dim, self.n_comp
)
betas = self.beta_params_max * torch.sigmoid(betas) + self.beta_params_min
beta_distr = Beta(alphas, betas)
return MixtureSameFamily(mix, beta_distr)
def _mask_ignored_dimensions(
self,
mask: TensorType["n_states", "policy_outputs_dim"],
tensor_to_mask: TensorType["n_states", "n_dim"],
) -> MixtureSameFamily:
"""
Makes the actions, logprobs or log jacobian entries of ignored dimensions zero.
Since the shape of all the tensor of actions, the logprobs of increments and
the log of the diagonal of the Jacobian must be the same, this method makes no
distiction between for simplicity.
Args
----
mask : tensor
Boolean mask indicating (True) which dimensions should be set to zero.
tensor_to_mask : tensor
Tensor to be modified. It may be a tensor of actions, of logprobs of
increments or the log of the diagonal of the Jacobian.
"""
is_ignored_dim = mask[:, -self.n_dim :]
if torch.any(is_ignored_dim):
shape_orig = tensor_to_mask.shape
tensor_to_mask[is_ignored_dim] = 0.0
tensor_to_mask = tensor_to_mask.reshape(shape_orig)
return tensor_to_mask
[docs]
def sample_actions_batch(
self,
policy_outputs: TensorType["n_states", "policy_output_dim"],
mask: Optional[TensorType["n_states", "mask_dim"]] = None,
states_from: List = None,
is_backward: Optional[bool] = False,
random_action_prob: Optional[float] = 0.0,
temperature_logits: Optional[float] = 1.0,
) -> Tuple[List[Tuple], TensorType["n_states"]]:
"""
Samples a batch of actions from a batch of policy outputs.
This method overwrites the methof of the GFlowNetEnv because it
is a continious enviroment.
Parameters
----------
policy_outputs : tensor
The output of the GFlowNet policy model.
mask : tensor
The mask of invalid actions. For continuous or mixed environments, the mask
may be tensor with an arbitrary length contaning information about special
states, as defined elsewhere in the environment.
states_from : tensor
The states originating the actions, in GFlowNet format. Ignored in discrete
environments and only required in certain continuous environments.
is_backward : bool
True if the actions are backward, False if the actions are forward
(default).
random_action_prob : float, optional
The probability of sampling a random action. If larger than one, the model
outputs will be replaced by a random policy vector with probability
`random_action_prob`, according to Bernoulli distribution.
temperature_logits : float, optional
A scalar by which the model outputs are divided to temper the sampling
distribution.
Returns
-------
actions : list
The list of sampled actions.
"""
if not is_backward:
return self._sample_actions_batch_forward(
policy_outputs=policy_outputs,
mask=mask,
states_from=states_from,
random_action_prob=random_action_prob,
temperature_logits=temperature_logits,
)
else:
return self._sample_actions_batch_backward(
policy_outputs=policy_outputs,
mask=mask,
states_from=states_from,
random_action_prob=random_action_prob,
temperature_logits=temperature_logits,
)
def _sample_actions_batch_forward(
self,
policy_outputs: TensorType["n_states", "policy_output_dim"],
mask: Optional[TensorType["n_states", "mask_dim"]] = None,
states_from: List = None,
random_action_prob: Optional[float] = 0.0,
temperature_logits: Optional[float] = 1.0,
) -> Tuple[List[Tuple], TensorType["n_states"]]:
"""
Samples a a batch of forward actions from a batch of policy outputs.
An action indicates, for each dimension, the absolute increment of the
dimension value. However, in order to ensure that trajectories have finite
length, increments must have a minumum increment (self.min_incr) except if the
originating state is the source state (special case, see
get_mask_invalid_actions_forward()). Furthermore, absolute increments must also
be smaller than the distance from the dimension value to the edge of the cube
(1.0). In order to accomodate these constraints, first relative
increments (in [0, 1]) are sampled from a (mixture of) Beta distribution(s),
where 0.0 indicates an absolute increment of min_incr and 1.0 indicates an
absolute increment of 1 - x + min_incr (going to the edge).
Therefore, given a dimension value x, a relative increment r, a minimum
increment m and a maximum value 1, the absolute increment a is given by:
a = m + r * (1 - x - m)
The continuous distribution to sample the continuous action described above
must be mixed with the discrete distribution to model the sampling of the EOS
action. The EOS action can be sampled from any state except from the source
state or whether the trajectory is done. That the EOS action is invalid is
indicated by mask[-1] being False.
Finally, regarding the constraints on the increments, the following special
cases are taken into account:
- The originating state is the source state: in this case, the minimum
increment is 0.0 instead of self.min_incr. This is to ensure that the entire
state space can be reached. This is indicated by mask[-2] being False.
- The value at any dimension is at a distance from the cube edge smaller than the
minimum increment (x > 1 - m). In this case, only EOS is valid.
This is indicated by mask[0] being True (continuous actions are invalid).
"""
# Initialize variables
n_states = policy_outputs.shape[0]
states_from_tensor = tfloat(
states_from, float_type=self.float, device=self.device
)
is_eos = torch.zeros(n_states, dtype=torch.bool, device=self.device)
# Determine source states
is_source = ~mask[:, 1]
# EOS is the only possible action if continuous actions are invalid (mask[0] is
# True)
is_eos_forced = mask[:, 0]
is_eos[is_eos_forced] = True
# Ensure that is_eos_forced does not include any source state
assert not torch.any(torch.logical_and(is_source, is_eos_forced))
# Randomize actions and temper the logits
logits_sampling = self.randomize_and_temper_sampling_distribution(
policy_outputs, random_action_prob, temperature_logits
)
# Sample EOS from Bernoulli distribution
do_eos = torch.logical_and(~is_source, ~is_eos_forced)
if torch.any(do_eos):
is_eos_sampled = torch.zeros_like(do_eos)
logits_eos = self._get_policy_eos_logit(logits_sampling)[do_eos]
distr_eos = Bernoulli(logits=logits_eos)
is_eos_sampled[do_eos] = tbool(distr_eos.sample(), device=self.device)
is_eos[is_eos_sampled] = True
# Sample (relative) increments if EOS is not the (sampled or forced) action
do_increments = ~is_eos
if torch.any(do_increments):
distr_increments = self._make_increments_distribution(
logits_sampling[do_increments]
)
# Shape of increments: [n_do_increments, n_dim]
increments = distr_increments.sample()
# Compute absolute increments from sampled relative increments if state is
# not source
is_relative = ~is_source[do_increments]
states_from_rel = tfloat(
states_from_tensor[do_increments],
float_type=self.float,
device=self.device,
)[is_relative]
increments[is_relative] = self.relative_to_absolute_increments(
states_from_rel,
increments[is_relative],
is_backward=False,
)
# Build actions
actions_tensor = torch.full(
(n_states, self.n_dim + 1), torch.inf, dtype=self.float, device=self.device
)
if torch.any(do_increments):
# Make increments of ignored dimensions zero
increments = self._mask_ignored_dimensions(mask[do_increments], increments)
# Add dimension is_source and add to actions tensor
actions_tensor[do_increments] = torch.cat(
(increments, torch.zeros((increments.shape[0], 1))), dim=1
)
actions_tensor[is_source, -1] = 1
return [tuple(a) for a in actions_tensor.tolist()]
def _sample_actions_batch_backward(
self,
policy_outputs: TensorType["n_states", "policy_output_dim"],
mask: Optional[TensorType["n_states", "mask_dim"]] = None,
states_from: List = None,
random_action_prob: Optional[float] = 0.0,
temperature_logits: Optional[float] = 1.0,
) -> Tuple[List[Tuple], TensorType["n_states"]]:
"""
Samples a a batch of backward actions from a batch of policy outputs.
An action indicates, for each dimension, the absolute increment of the
dimension value. However, in order to ensure that trajectories have finite
length, increments must have a minumum increment (self.min_incr). Furthermore,
absolute increments must also be smaller than the distance from the dimension
value to the edge of the cube. In order to accomodate these constraints, first
relative increments (in [0, 1]) are sampled from a (mixture of) Beta
distribution(s), where 0.0 indicates an absolute increment of min_incr and 1.0
indicates an absolute increment of x (going back to the source).
Therefore, given a dimension value x, a relative increment r, a minimum
increment m and a maximum value 1, the absolute increment a is given by:
a = m + r * (x - m)
The continuous distribution to sample the continuous action described above
must be mixed with the discrete distribution to model the sampling of the back
to source (BTS) action. While the BTS action is also a continuous action, it
needs to be modelled with a (discrete) Bernoulli distribution in order to
ensure that this action has positive likelihood.
Finally, regarding the constraints on the increments, the special case where
the trajectory is done and the only possible action is EOS, is also taken into
account.
"""
# Initialize variables
n_states = policy_outputs.shape[0]
is_bts = torch.zeros(n_states, dtype=torch.bool, device=self.device)
# EOS is the only possible action only if done is True (mask[2] is False)
is_eos = ~mask[:, 2]
# Back-to-source (BTS) is the only possible action if mask[1] is False
is_bts_forced = ~mask[:, 1]
is_bts[is_bts_forced] = True
# Randomize actions and temper the logits
logits_sampling = self.randomize_and_temper_sampling_distribution(
policy_outputs, random_action_prob, temperature_logits
)
# Sample BTS from Bernoulli distribution
do_bts = torch.logical_and(~is_bts_forced, ~is_eos)
if torch.any(do_bts):
is_bts_sampled = torch.zeros_like(do_bts)
logits_bts = self._get_policy_source_logit(logits_sampling)[do_bts]
distr_bts = Bernoulli(logits=logits_bts)
is_bts_sampled[do_bts] = tbool(distr_bts.sample(), device=self.device)
is_bts[is_bts_sampled] = True
# Sample relative increments if actions are neither BTS nor EOS
do_increments = torch.logical_and(~is_bts, ~is_eos)
if torch.any(do_increments):
distr_increments = self._make_increments_distribution(
logits_sampling[do_increments]
)
# Shape of increments_rel: [n_do_increments, n_dim]
increments = distr_increments.sample()
# Compute absolute increments from all sampled relative increments
states_from_rel = tfloat(
states_from, float_type=self.float, device=self.device
)[do_increments]
increments = self.relative_to_absolute_increments(
states_from_rel,
increments,
is_backward=True,
)
# Build actions
actions_tensor = torch.zeros(
(n_states, self.n_dim + 1), dtype=self.float, device=self.device
)
actions_tensor[is_eos] = tfloat(
self.eos, float_type=self.float, device=self.device
)
if torch.any(do_increments):
# Make increments of ignored dimensions zero
increments = self._mask_ignored_dimensions(mask[do_increments], increments)
# Add dimension is_source and add to actions tensor
actions_tensor[do_increments] = torch.cat(
(increments, torch.zeros((increments.shape[0], 1))), dim=1
)
if torch.any(is_bts):
# BTS actions are equal to the originating states
actions_bts = tfloat(
states_from, float_type=self.float, device=self.device
)[is_bts]
actions_bts = torch.cat(
(actions_bts, torch.ones((actions_bts.shape[0], 1))), dim=1
)
actions_tensor[is_bts] = actions_bts
# Make ignored dimensions zero
actions_tensor[is_bts, :-1] = self._mask_ignored_dimensions(
mask[is_bts], actions_tensor[is_bts, :-1]
)
return [tuple(a) for a in actions_tensor.tolist()]
[docs]
def get_logprobs(
self,
policy_outputs: TensorType["n_states", "policy_output_dim"],
actions: Union[List, TensorType["n_states", "action_dim"]],
mask: TensorType["n_states", "mask_dim"],
states_from: List,
is_backward: bool,
) -> TensorType["batch_size"]:
"""
Computes log probabilities of actions given policy outputs and actions.
Parameters
----------
policy_outputs : tensor
The output of the GFlowNet policy model.
mask : tensor
The mask containing information about invalid actions and special cases.
actions : list or tensor
The actions (absolute increments) from each state in the batch for which to
compute the log probability.
states_from : tensor
The states originating the actions, in GFlowNet format. They are required
so as to compute the relative increments and the Jacobian.
is_backward : bool
True if the actions are backward, False if the actions are forward
(default). Required, since the computation for forward and backward actions
is different.
"""
if is_backward:
return self._get_logprobs_backward(
policy_outputs, actions, mask, states_from
)
else:
return self._get_logprobs_forward(
policy_outputs, actions, mask, states_from
)
def _get_logprobs_forward(
self,
policy_outputs: TensorType["n_states", "policy_output_dim"],
actions: Union[List, TensorType["n_states", "action_dim"]],
mask: TensorType["n_states", "3"],
states_from: List,
) -> TensorType["batch_size"]:
"""
Computes log probabilities of forward actions.
"""
# Initialize variables
actions = tfloat(actions, float_type=self.float, device=self.device)
n_states = policy_outputs.shape[0]
states_from_tensor = tfloat(
states_from, float_type=self.float, device=self.device
)
is_eos = torch.zeros(n_states, dtype=torch.bool, device=self.device)
logprobs_eos = torch.zeros(n_states, dtype=self.float, device=self.device)
logprobs_increments_rel = torch.zeros(
(n_states, self.n_dim), dtype=self.float, device=self.device
)
log_jacobian_diag = torch.zeros(
(n_states, self.n_dim), device=self.device, dtype=self.float
)
eos_tensor = tfloat(self.eos, float_type=self.float, device=self.device)
# Determine source states
is_source = ~mask[:, 1]
# EOS is the only possible action if continuous actions are invalid (mask[0] is
# True)
is_eos_forced = mask[:, 0]
is_eos[is_eos_forced] = True
# Ensure that is_eos_forced does not include any source state
assert not torch.any(torch.logical_and(is_source, is_eos_forced))
# Get sampled EOS actions and get log probs from Bernoulli distribution
do_eos = torch.logical_and(~is_source, ~is_eos_forced)
if torch.any(do_eos):
is_eos_sampled = torch.zeros_like(do_eos)
is_eos_sampled[do_eos] = torch.all(actions[do_eos] == eos_tensor, dim=1)
is_eos[is_eos_sampled] = True
logits_eos = self._get_policy_eos_logit(policy_outputs)[do_eos]
distr_eos = Bernoulli(logits=logits_eos)
logprobs_eos[do_eos] = distr_eos.log_prob(
is_eos_sampled[do_eos].to(self.float)
)
# Get log probs of relative increments if EOS was not the sampled or forced
# action
do_increments = ~is_eos
if torch.any(do_increments):
# Get absolute increments
increments = actions[do_increments, :-1]
# Make sure increments are finite
assert torch.any(torch.isfinite(increments))
# Compute relative increments from absolute increments if state is not
# source
is_relative = ~is_source[do_increments]
if torch.any(is_relative):
states_from_rel = tfloat(
states_from_tensor[do_increments],
float_type=self.float,
device=self.device,
)[is_relative]
increments[is_relative] = self.absolute_to_relative_increments(
states_from_rel,
increments[is_relative],
is_backward=False,
)
# Compute diagonal of the Jacobian (see _get_jacobian_diag()) if state is
# not source
is_relative = torch.logical_and(do_increments, ~is_source)
if torch.any(is_relative):
log_jacobian_diag[is_relative] = torch.log(
self._get_jacobian_diag(
states_from_rel,
is_backward=False,
)
)
# Make ignored dimensions zero
log_jacobian_diag = self._mask_ignored_dimensions(mask, log_jacobian_diag)
# Get logprobs
distr_increments = self._make_increments_distribution(
policy_outputs[do_increments]
)
# Clamp because increments of 0.0 or 1.0 would yield nan
logprobs_increments_rel[do_increments] = distr_increments.log_prob(
torch.clamp(increments, min=self.epsilon, max=(1 - self.epsilon))
)
# Make ignored dimensions zero
logprobs_increments_rel = self._mask_ignored_dimensions(
mask, logprobs_increments_rel
)
# Sum log Jacobian across dimensions
log_det_jacobian = torch.sum(log_jacobian_diag, dim=1)
# Compute combined probabilities
sumlogprobs_increments = logprobs_increments_rel.sum(axis=1)
logprobs = logprobs_eos + sumlogprobs_increments + log_det_jacobian
return logprobs
def _get_logprobs_backward(
self,
policy_outputs: TensorType["n_states", "policy_output_dim"],
actions: Union[List, TensorType["n_states", "action_dim"]],
mask: TensorType["n_states", "3"],
states_from: List,
) -> TensorType["batch_size"]:
"""
Computes log probabilities of backward actions.
"""
# Initialize variables
actions = tfloat(actions, float_type=self.float, device=self.device)
n_states = policy_outputs.shape[0]
states_from_tensor = tfloat(
states_from, float_type=self.float, device=self.device
)
is_bts = torch.zeros(n_states, dtype=torch.bool, device=self.device)
logprobs_bts = torch.zeros(n_states, dtype=self.float, device=self.device)
logprobs_increments_rel = torch.zeros(
(n_states, self.n_dim), dtype=self.float, device=self.device
)
log_jacobian_diag = torch.zeros(
(n_states, self.n_dim), device=self.device, dtype=self.float
)
# EOS is the only possible action only if done is True (mask[2] is False)
is_eos = ~mask[:, 2]
# Back-to-source (BTS) is the only possible action if mask[1] is False
is_bts_forced = ~mask[:, 1]
is_bts[is_bts_forced] = True
# Get sampled BTS actions and get log probs from Bernoulli distribution
do_bts = torch.logical_and(~is_bts_forced, ~is_eos)
if torch.any(do_bts):
# BTS actions are equal to the originating states
is_bts_sampled = torch.zeros_like(do_bts)
is_bts_sampled[do_bts] = torch.all(
actions[do_bts, :-1] == states_from_tensor[do_bts], dim=1
)
is_bts[is_bts_sampled] = True
logits_bts = self._get_policy_source_logit(policy_outputs)[do_bts]
distr_bts = Bernoulli(logits=logits_bts)
logprobs_bts[do_bts] = distr_bts.log_prob(
is_bts_sampled[do_bts].to(self.float)
)
# Get log probs of relative increments if actions were neither BTS nor EOS
do_increments = torch.logical_and(~is_bts, ~is_eos)
if torch.any(do_increments):
# Get absolute increments
increments = actions[do_increments, :-1]
# Compute absolute increments from all sampled relative increments
increments = self.absolute_to_relative_increments(
states_from_tensor[do_increments],
increments,
is_backward=True,
)
# Make sure increments are finite
assert torch.all(torch.isfinite(increments))
# Compute diagonal of the Jacobian (see _get_jacobian_diag())
log_jacobian_diag[do_increments] = torch.log(
self._get_jacobian_diag(
states_from_tensor[do_increments],
is_backward=True,
)
)
# Make ignored dimensions zero
log_jacobian_diag = self._mask_ignored_dimensions(mask, log_jacobian_diag)
# Get logprobs
distr_increments = self._make_increments_distribution(
policy_outputs[do_increments]
)
# Clamp because increments of 0.0 or 1.0 would yield nan
logprobs_increments_rel[do_increments] = distr_increments.log_prob(
torch.clamp(increments, min=self.epsilon, max=(1 - self.epsilon))
)
# Make ignored dimensions zero
logprobs_increments_rel = self._mask_ignored_dimensions(
mask, logprobs_increments_rel
)
# Sum log Jacobian across dimensions
log_det_jacobian = torch.sum(log_jacobian_diag, dim=1)
# Compute combined probabilities
sumlogprobs_increments = logprobs_increments_rel.sum(axis=1)
logprobs = logprobs_bts + sumlogprobs_increments + log_det_jacobian
# Ensure that logprobs of forced EOS are 0
logprobs[is_eos] = 0.0
return logprobs
def _get_jacobian_diag(
self,
states_from: TensorType["n_states", "n_dim"],
is_backward: bool,
):
"""
Computes the diagonal of the Jacobian of the sampled actions with respect to
the target states.
Forward: the sampled variables are the relative increments r_f and the state
updates (s -> s') are:
s' = s + m + r_f(1 - s - m)
r_f = (s' - s - m) / (1 - s - m)
Therefore, the derivative of r_f wrt to s' is
dr_f/ds' = 1 / (1 - s - m)
Backward: the sampled variables are the relative decrements r_b and the state
updates (s' -> s) are:
s = s' - m - r_b(s' - m)
r_b = (s' - s - m) / (s' - m)
Therefore, the derivative of r_b wrt to s is
dr_b/ds = -1 / (s' - m)
We take the absolute value of the derivative (Jacobian).
The derivatives of the components of r with respect to dimensions of s or s'
other than itself are zero. Therefore, the Jacobian is diagonal and the
determinant is the product of the diagonal.
"""
min_increments = torch.full_like(
states_from, self.min_incr, dtype=self.float, device=self.device
)
if is_backward:
return 1.0 / ((states_from - min_increments))
else:
return 1.0 / ((1.0 - states_from - min_increments))
def _step(
self,
action: Tuple[float],
backward: bool,
) -> Tuple[List[float], Tuple[float], bool]:
"""
Updates self.state given a non-EOS action. This method is called by both step()
and step_backwards(), with the corresponding value of argument backward.
Args
----
action : tuple
Action to be executed. An action is a tuple of length n_dim, with the
absolute increment for each dimension.
backward : bool
If True, perform backward step. Otherwise (default), perform forward step.
Returns
-------
self.state : list
The sequence after executing the action
action : int
Action executed
valid : bool
False, if the action is not allowed for the current state, e.g. stop at the
root state
"""
# If forward action is from source, initialize state to all zeros.
if not backward and action[-1] == 1 and self.state == self.source:
state = [0.0 for _ in range(self.n_dim)]
else:
assert action[-1] == 0
state = copy(self.state)
# Increment dimensions
for dim, incr in enumerate(action[:-1]):
if backward:
state[dim] -= incr
else:
state[dim] += incr
# If state is out of bounds, return invalid
effective_dims = self._get_effective_dims(state)
if any([s > 1.0 for s in effective_dims]) or any(
[s < 0.0 for s in effective_dims]
):
warnings.warn(f"""
State is out of cube bounds.
\nCurrent state:\n{self.state}\nAction:\n{action}\nNext state: {state}
""")
return self.state, action, False
# Otherwise, set self.state as the udpated state and return valid.
self.n_actions += 1
self.state = state
return self.state, action, True
# TODO: make generic for continuous environments?
[docs]
def step(self, action: Tuple[float]) -> Tuple[List[float], Tuple[int, float], bool]:
"""
Executes step given an action. An action is the absolute increment of each
dimension.
Args
----
action : tuple
Action to be executed. An action is a tuple of length n_dim, with the
absolute increment for each dimension.
Returns
-------
self.state : list
The sequence after executing the action
action : int
Action executed
valid : bool
False, if the action is not allowed for the current state, e.g. stop at the
root state
"""
if self.done:
return self.state, action, False
if action == self.eos:
assert self._get_effective_dims(self.state) != self._get_effective_dims(
self.source
)
self.done = True
self.n_actions += 1
return self.state, self.eos, True
# Otherwise perform action
else:
return self._step(action, backward=False)
[docs]
def step_backwards(
self, action: Tuple[int, float]
) -> Tuple[List[float], Tuple[int, float], bool]:
"""
Executes backward step given an action. An action is the absolute decrement of
each dimension.
Args
----
action : tuple
Action to be executed. An action is a tuple of length n_dim, with the
absolute decrement for each dimension.
Returns
-------
self.state : list
The sequence after executing the action
action : int
Action executed
valid : bool
False, if the action is not allowed for the current state, e.g. stop at the
root state
"""
# If done is True, set done to False, increment n_actions and return same state
if self.done:
assert action == self.eos
self.done = False
self.n_actions += 1
return self.state, action, True
assert action != self.eos
# If action is BTS, set source state
if action[-1] == 1 and self.state != self.source:
self.n_actions += 1
self.state = self.source
return self.state, action, True
# Otherwise perform action
return self._step(action, backward=True)
[docs]
def get_grid_terminating_states(
self, n_states: int, kappa: Optional[float] = None
) -> List[List]:
"""
Constructs a grid of terminating states within the range of the hyper-cube.
Args
----
n_states : int
Requested number of states. The actual number of states will be rounded up
such that all dimensions have the same number of states.
kappa : float
Small constant indicating the distance to the theoretical limits of the
cube [0, 1], in order to avoid innacuracies in the computation of the log
probabilities due to clamping. The grid will thus be in [kappa, 1 -
kappa]. If None, self.kappa will be used.
"""
if kappa is None:
kappa = self.kappa
n_per_dim = int(np.ceil(n_states ** (1 / self.n_dim)))
linspaces = [
np.linspace(kappa, 1.0 - kappa, n_per_dim) for _ in range(self.n_dim)
]
states = list(itertools.product(*linspaces))
states = [list(el) for el in states]
return states
[docs]
class HybridCube(SetFix):
"""
Hybrid hyper-cube environment (continuous version of a hyper-grid) in which the
action space consists of the increment of one dimension dimension at a time, as
opposed to the ContinuousCube, in which all dimensions are incremented at once.
In practice, it is implemented as a SetFix of M 1D ContinuousCube environments, where
M is the dimensionality of the hyper-cube.
Attributes
----------
n_dim : int
Dimensionality of the hyper-cube.
min_incr : float
Minimum increment in the actions, in (0, 1). This is necessary to ensure
that all trajectories have finite length.
n_comp : int
Number of components in the mixture of Beta distributions.
"""
def __init__(self, **kwargs):
# Set number of dimensions as attribute
[docs]
self.n_dim = kwargs["n_dim"]
# Change number of dimensions in kwargs to 1, to be passed to the
# ContinuousCube environments
kwargs["n_dim"] = 1
# The sub-environments are self.n_dim 1D ContinuousCube environments
subenvs = [ContinuousCube(**kwargs) for _ in range(self.n_dim)]
super().__init__(subenvs=subenvs, **kwargs)
[docs]
def states2proxy(
self, states: List[List]
) -> TensorType["batch", "state_proxy_dim"]:
r"""
Prepares a batch of states in environment format for a proxy.
The input states are in the environment format of the Set. The outputs contain
only the Cube part and the format is as in
:py:meth:`gflownet.envs.cube.ContinuousCube.states2proxy`.
Parameters
----------
states : list
A batch of states in Set environment format.
Returns
-------
A tensor containing all the states in the batch.
"""
# Construct list of states as if they came from a ContinuousCube with self.n_dim
states = [[s[0] for s in self._get_substates(state)] for state in states]
# Convert to proxy via states2proxy() of the ContinuousCube
return self.subenvs[0].states2proxy(states)
[docs]
def get_grid_terminating_states(
self, n_states: int, kappa: Optional[float] = None
) -> List[List]:
"""
Constructs a grid of terminating states within the range of the hyper-cube.
Args
----
n_states : int
Requested number of states. The actual number of states will be rounded up
such that all dimensions have the same number of states.
kappa : float
Small constant indicating the distance to the theoretical limits of the
cube [0, 1], in order to avoid innacuracies in the computation of the log
probabilities due to clamping. The grid will thus be in [kappa, 1 -
kappa]. If None, self.kappa will be used.
"""
if kappa is None:
kappa = self.subenvs[0].kappa
n_per_dim = int(np.ceil(n_states ** (1 / self.n_dim)))
linspaces = [
np.linspace(kappa, 1.0 - kappa, n_per_dim) for _ in range(self.n_dim)
]
states = []
for dims in itertools.product(*linspaces):
state = copy(self.source)
for idx in range(self.n_subenvs):
state = self._set_subdone(idx, True, state)
state = self._set_substate(idx, [dims[idx]], state)
states.append(state)
return states
[docs]
def fit_kde(
self,
samples: TensorType["batch_size", "state_proxy_dim"],
kernel: str = "gaussian",
bandwidth: float = 0.1,
):
"""
Fits a Kernel Density Estimator on a batch of samples.
Simply calls fit_kde() of CubeBase.
"""
return CubeBase.fit_kde(self, samples, kernel, bandwidth)
[docs]
def plot_reward_samples(
self,
samples: TensorType["batch_size", "state_proxy_dim"],
samples_reward: TensorType["batch_size", "state_proxy_dim"],
rewards: TensorType["batch_size"],
alpha: float = 0.5,
dpi: int = 150,
max_samples: int = 500,
**kwargs,
):
"""
Plots the reward contour alongside a batch of samples.
Simply calls plot_reward_samples() of CubeBase.
"""
return CubeBase.plot_reward_samples(
self, samples, samples_reward, rewards, alpha, dpi, max_samples
)
[docs]
def plot_kde(
self,
samples: TensorType["batch_size", "state_proxy_dim"],
kde,
alpha: float = 0.5,
dpi=150,
colorbar: bool = True,
**kwargs,
):
"""
Plots the density previously estimated from a batch of samples via KDE over the
entire sample space.
Simply calls plot_reward_samples() of CubeBase.
"""
return CubeBase.plot_kde(self, samples, kde, alpha, dpi, colorbar)