"""
Classes to represent hyper-grid environments.
"""
import itertools
from typing import List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
from matplotlib.axes import Axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torchtyping import TensorType
from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import tfloat, tlong, torch2np
class Grid(GFlowNetEnv):
    """
    Hyper-grid environment: A grid with n_dim dimensions and length cells per
    dimension.

    The state space is the entire grid and each state is represented by the vector of
    coordinates of each dimension. For example, in 3D, the origin will be at [0, 0, 0]
    and after incrementing dimension 0 by 2, dimension 1 by 3 and dimension 2 by 1, the
    state would be [2, 3, 1].

    The action space is the increment to be applied to each dimension. For instance,
    (0, 0, 1) will increment dimension 2 by 1 and the action that goes from [1, 1, 1]
    to [2, 3, 1] is (1, 2, 0). The end-of-sequence (eos) action is the all-zeros
    increment.

    Attributes
    ----------
    n_dim : int
        Dimensionality of the grid
    length : int
        Size of the grid (cells per dimension)
    max_increment : int
        Maximum increment of each dimension by the actions.
    max_dim_per_action : int
        Maximum number of dimensions to increment per action. If -1 is passed to the
        constructor, max_dim_per_action is set to n_dim.
    cells : np.ndarray
        The length values evenly spaced in [cell_min, cell_max] (both inclusive)
        giving the proxy-format coordinate of each cell index along any dimension.
        cell_min (default -1) and cell_max (default 1) are constructor arguments.
    source : list
        Initial state: position 0 at all dimensions.
    eos : tuple
        End-of-sequence action: increment 0 at all dimensions.
    """

    def __init__(
        self,
        n_dim: int = 2,
        length: int = 3,
        max_increment: int = 1,
        max_dim_per_action: int = 1,
        cell_min: float = -1,
        cell_max: float = 1,
        **kwargs,
    ):
        assert n_dim > 0
        assert length > 1
        assert max_increment > 0
        assert max_dim_per_action == -1 or max_dim_per_action > 0
        # These must be set before anything below (and before the base class init)
        # reads them.
        self.n_dim = n_dim
        self.length = length
        self.max_increment = max_increment
        # -1 is a shorthand for "any number of dimensions per action"
        if max_dim_per_action == -1:
            max_dim_per_action = self.n_dim
        self.max_dim_per_action = max_dim_per_action
        # Proxy-format coordinate of each cell index
        self.cells = np.linspace(cell_min, cell_max, length)
        # Source state: position 0 at all dimensions
        self.source = [0] * self.n_dim
        # End-of-sequence action: no increment in any dimension
        self.eos = (0,) * self.n_dim
        # Base class init; called last, presumably because it relies on the
        # attributes above (e.g. to build the action space) — TODO confirm
        super().__init__(**kwargs)

    def get_action_space(self):
        """
        Constructs list with all possible actions, including eos.

        An action is represented by a tuple of length n_dim where each index d
        indicates the increment to apply to dimension d of the hyper-grid. Non-eos
        actions increment at least one and at most max_dim_per_action dimensions,
        each by at most max_increment. The eos action is always the last element.
        """
        increments = range(self.max_increment + 1)
        actions = [
            action
            for action in itertools.product(increments, repeat=self.n_dim)
            # Increments are non-negative, so any() excludes only the all-zeros
            # action, which is added separately as eos.
            if any(action)
            and sum(1 for incr in action if incr > 0) <= self.max_dim_per_action
        ]
        actions.append(self.eos)
        return actions

    def get_mask_invalid_actions_forward(
        self,
        state: Optional[List] = None,
        done: Optional[bool] = None,
    ) -> List:
        """
        Returns a list of length self.policy_output_dim with values:
            - True if the forward action is invalid from the current state.
            - False otherwise.

        A non-eos action is invalid if it would move any dimension beyond the last
        cell of the grid. All actions are invalid if the trajectory is done.

        Args
        ----
        state : list
            A state as a list of n_dim positions. If None, it is taken from the
            instance.

        done : bool
            Whether the trajectory is done. If None, it is taken from the instance.
        """
        state = self._get_state(state)
        done = self._get_done(done)
        if done:
            return [True for _ in range(self.policy_output_dim)]
        mask = [False for _ in range(self.policy_output_dim)]
        # The last action is eos (see get_action_space), which is always valid
        # while the trajectory is not done.
        for idx, action in enumerate(self.action_space[:-1]):
            if any(pos + incr >= self.length for pos, incr in zip(state, action)):
                mask[idx] = True
        return mask

    def states2proxy(
        self, states: Union[List[List], TensorType["batch", "state_dim"]]
    ) -> TensorType["batch", "state_proxy_dim"]:
        """
        Prepares a batch of states in "environment format" for the proxy: each state is
        a vector of length n_dim with values in the range [cell_min, cell_max], where
        position p in a dimension is mapped to self.cells[p].

        See: states2policy()

        Args
        ----
        states : list or tensor
            A batch of states in environment format, either as a list of states or as a
            single tensor.

        Returns
        -------
        A tensor containing all the states in the batch.
        """
        states = tlong(states, device=self.device)
        cells = torch.tensor(self.cells, dtype=self.float, device=states.device)
        # Indexing the cell coordinates with the positions is equivalent to
        # multiplying the one-hot encoding (states2policy) by the cells and summing.
        return cells[states]

    def states2policy(
        self, states: Union[List, TensorType["batch", "state_dim"]]
    ) -> TensorType["batch", "policy_input_dim"]:
        """
        Prepares a batch of states in "environment format" for the policy model: states
        are one-hot encoded.

        The output is a 2D tensor, with the second dimension of size length * n_dim,
        where each n-th successive block of length elements is a one-hot encoding of
        the position in the n-th dimension.

        Example (n_dim = 3, length = 4):
            - state: [0, 3, 1]
            - policy format: [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]
                             |     0    |      3    |      1    |

        Args
        ----
        states : list or tensor
            A batch of states in environment format, either as a list of states or as a
            single tensor.

        Returns
        -------
        A tensor containing all the states in the batch.
        """
        states = tlong(states, device=self.device)
        n_states = states.shape[0]
        device = states.device
        index_dtype = states.dtype
        # Column of the "hot" element of each dimension: offset the position by the
        # start of that dimension's block.
        cols = (
            states
            + torch.arange(self.n_dim, device=device, dtype=index_dtype) * self.length
        )
        # Each state contributes n_dim hot elements, all in its own row.
        rows = torch.repeat_interleave(
            torch.arange(n_states, device=device, dtype=index_dtype), self.n_dim
        )
        states_policy = torch.zeros(
            (n_states, self.length * self.n_dim), dtype=self.float, device=device
        )
        states_policy[rows, cols.flatten()] = 1.0
        return states_policy

    def readable2state(self, readable, alphabet=None):
        """
        Converts a human-readable string representing a state into a state as a list of
        positions.

        Accepts the format produced by state2readable, e.g. "[0 3 1]". Surrounding
        whitespace is ignored and commas are also accepted as separators, e.g.
        "[0, 3, 1]". The alphabet argument is ignored.
        """
        content = readable.strip().strip("[]").replace(",", " ")
        return [int(el) for el in content.split()]

    def state2readable(self, state: Optional[List] = None, alphabet=None):
        """
        Converts a state (a list of positions) into a human-readable string
        representing a state, e.g. [0, 3, 1] -> "[0 3 1]". If state is None, it is
        taken from the instance. The alphabet argument is ignored.
        """
        state = self._get_state(state)
        return "[" + " ".join(str(el) for el in state) + "]"

    def get_parents(
        self,
        state: Optional[List] = None,
        done: Optional[bool] = None,
        action: Optional[Tuple] = None,
    ) -> Tuple[List, List]:
        """
        Determines all parents and actions that lead to state.

        Args
        ----
        state : list
            Representation of a state, as a list of length n_dim where each element is
            the position at each dimension. If None, it is taken from the instance.

        done : bool
            Whether the trajectory is done. If None, done is taken from instance.

        action : None
            Ignored

        Returns
        -------
        parents : list
            List of parents in state format

        actions : list
            List of actions that lead to state for each parent in parents
        """
        state = self._get_state(state)
        done = self._get_done(done)
        if done:
            return [state], [self.eos]
        parents = []
        actions = []
        # A non-eos action is a valid backward move if undoing it keeps every
        # dimension inside the grid (position >= 0).
        for action in self.action_space[:-1]:
            parent = [pos - incr for pos, incr in zip(state, action)]
            if all(pos >= 0 for pos in parent):
                parents.append(parent)
                actions.append(action)
        return parents, actions

    def step(
        self, action: Tuple[int], skip_mask_check: bool = False
    ) -> Tuple[List[int], Tuple[int], bool]:
        """
        Executes step given an action.

        Args
        ----
        action : tuple
            Action to be executed. An action is a tuple of n_dim int values indicating
            the increment to apply to each dimension.

        skip_mask_check : bool
            If True, skip computing forward mask of invalid actions to check if the
            action is valid.

        Returns
        -------
        self.state : list
            The state after executing the action

        action : tuple
            Action executed (eos if the trajectory was terminated)

        valid : bool
            False, if the action is not allowed for the current state.
        """
        # Generic pre-step checks
        do_step, self.state, action = self._pre_step(
            action, skip_mask_check or self.skip_mask_check
        )
        if not do_step:
            return self.state, action, False
        # Terminate if the action is eos, or if all dimensions are already at the
        # maximum position: then eos is the only possible action and it is forced
        # regardless of the action passed.
        if action == self.eos or all(pos == self.length - 1 for pos in self.state):
            self.done = True
            self.n_actions += 1
            return self.state, self.eos, True
        state_next = [pos + incr for pos, incr in zip(self.state, action)]
        # The action is invalid if any dimension would leave the grid; the state is
        # left unchanged in that case.
        if any(pos >= self.length for pos in state_next):
            return self.state, action, False
        self.state = state_next
        self.n_actions += 1
        return self.state, action, True

    def _get_max_trajectory_length(self) -> int:
        """
        Returns the maximum trajectory length of the environment, including the EOS
        action.

        Every non-eos action increases the sum of the coordinates by at least 1, and
        that sum is at most n_dim * (length - 1). Once every coordinate is at its
        maximum, eos is forced, which adds one more action.
        """
        return self.n_dim * (self.length - 1) + 1

    def get_all_terminating_states(self) -> List[List]:
        """
        Returns all the states of the grid, each as a list of n_dim ints.

        Note that np.meshgrid's default "xy" indexing swaps the first two axes, so
        the states are ordered lexicographically by (dimension 1, dimension 0,
        dimension 2, ...); in a 2D grid, dimension 0 varies fastest.
        """
        grid = np.meshgrid(*[range(self.length)] * self.n_dim)
        all_x = np.stack(grid).reshape((self.n_dim, -1)).T
        return all_x.tolist()

    def plot_reward_samples(
        self,
        samples: TensorType["batch_size", "state_proxy_dim"],
        samples_reward: TensorType["batch_size", "state_proxy_dim"],
        rewards: TensorType["batch_size"],
        dpi: int = 150,
        n_ticks_max: int = 50,
        reward_norm: bool = True,
        **kwargs,
    ):
        """
        Plots the reward density as a 2D histogram on the grid, alongside a histogram
        representing the samples density. Only 2D grids are supported.

        In both images, dimension 0 is on the horizontal axis and dimension 1 on the
        vertical axis, with [0, 0] at the bottom left.

        Parameters
        ----------
        samples : tensor
            A batch of samples from the GFlowNet policy in proxy format, i.e. with
            coordinates taken from self.cells. These samples are used to build the
            samples density histogram.

        samples_reward : tensor
            A batch of samples containing a grid over the sample space, from which the
            reward has been obtained. Ignored by this method.

        rewards : tensor
            The rewards of samples_reward: a vector of length length ** 2 covering the
            entire grid. It must be sorted such that rewards[i * length + j] is the
            reward of the state with position i in dimension 0 and position j in
            dimension 1 (dimension 1 varies fastest).
            NOTE(review): get_all_terminating_states() enumerates 2D states with
            dimension 0 varying fastest (np.meshgrid "xy" indexing), i.e. the
            transposed order; confirm which order the callers use.

        dpi : int
            Dots per inch, indicating the resolution of the plot.

        n_ticks_max : int
            Maximum of number of ticks to include in the axes.

        reward_norm : bool
            Whether to normalize the rewards so that they sum to 1. True by default.

        Returns
        -------
        The matplotlib figure, or None if the grid is not 2D.
        """
        # Only available for 2D grids
        if self.n_dim != 2:
            return None
        samples = torch2np(samples)
        rewards = torch2np(rewards)
        assert rewards.shape[0] == self.length**2
        # Init figure
        fig, axes = plt.subplots(ncols=2, dpi=dpi)
        step_ticks = np.ceil(self.length / n_ticks_max).astype(int)
        # 2D histogram of samples. The range is fixed to the span of the cells so
        # that bin k always corresponds to cell k; otherwise numpy infers the range
        # from the samples, which misplaces the bins whenever the samples do not
        # cover the whole grid.
        cells_range = [float(self.cells.min()), float(self.cells.max())]
        samples_hist, _, _ = np.histogram2d(
            samples[:, 0],
            samples[:, 1],
            bins=(self.length, self.length),
            range=[cells_range, cells_range],
            density=True,
        )
        # Transpose and reverse rows so that [0, 0] is at bottom left
        samples_hist = samples_hist.T[::-1, :]
        # Normalize and reshape reward into a grid with [0, 0] at the bottom left
        if reward_norm:
            rewards = rewards / rewards.sum()
        rewards_2d = rewards.reshape(self.length, self.length).T[::-1, :]
        # Plot reward
        self._plot_grid_2d(rewards_2d, axes[0], step_ticks, title="True reward")
        # Plot samples histogram
        self._plot_grid_2d(samples_hist, axes[1], step_ticks, title="Samples density")
        fig.tight_layout()
        return fig

    @staticmethod
    def _plot_grid_2d(img: np.ndarray, ax: Axes, step_ticks: int, title: str):
        """
        Plots a 2D histogram of a grid environment as an image.

        Parameters
        ----------
        img : np.ndarray
            An array containing a 2D histogram over a grid, with its rows reversed so
            that the last row holds vertical coordinate 0 (i.e. [0, 0] is displayed at
            the bottom left).

        ax : Axes
            A matplotlib Axes object on which the image will be plotted.

        step_ticks : int
            The step value to add ticks to the axes. For example, if it is 2, the ticks
            will be at 0, 2, 4, ...

        title : str
            Title for the axes.
        """
        ax_img = ax.imshow(img)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("top", size="5%", pad=0.05)
        n_rows, n_cols = img.shape[0], img.shape[1]
        # Columns map directly to the horizontal coordinate.
        ax.set_xticks(np.arange(start=0, stop=n_cols, step=step_ticks))
        # imshow draws row 0 at the top, but the rows of img are reversed, so
        # vertical coordinate c is drawn at row n_rows - 1 - c. Place the ticks there
        # (in increasing row order) and label them with the coordinate.
        y_coords = np.arange(start=0, stop=n_rows, step=step_ticks)[::-1]
        ax.set_yticks(n_rows - 1 - y_coords)
        ax.set_yticklabels([str(coord) for coord in y_coords])
        cax.set_title(title)
        plt.colorbar(ax_img, cax=cax, orientation="horizontal")
        cax.xaxis.set_ticks_position("top")