gflownet.envs.grid

Classes to represent hyper-grid environments.

Classes

Grid

Hyper-grid environment: A grid with n_dim dimensions and length cells per dimension.

Module Contents

class gflownet.envs.grid.Grid(n_dim=2, length=3, max_increment=1, max_dim_per_action=1, cell_min=-1, cell_max=1, **kwargs)

Bases: gflownet.envs.base.GFlowNetEnv

Hyper-grid environment: A grid with n_dim dimensions and length cells per dimension.

The state space is the entire grid and each state is represented by the vector of its coordinates along each dimension. For example, in 3D, the origin is at [0, 0, 0] and, after incrementing dimension 0 by 2, dimension 1 by 3 and dimension 2 by 1, the state would be [2, 3, 1].

Each action is a vector with the increment to apply to each dimension. For instance, (0, 0, 1) increments dimension 2 by 1, and the action that goes from [1, 1, 1] to [2, 3, 1] is (1, 2, 0).

Parameters:
  • n_dim (int)

  • length (int)

  • max_increment (int)

  • max_dim_per_action (int)

  • cell_min (float)

  • cell_max (float)

n_dim

Dimensionality of the grid

Type:

int

length

Size of the grid (cells per dimension)

Type:

int

max_increment

Maximum increment applied to each dimension by a single action.

Type:

int

max_dim_per_action

Maximum number of dimensions to increment per action. If -1, then max_dim_per_action is set to n_dim.

Type:

int

cell_min

Lower bound of the range of cell values.

Type:

float

cell_max

Upper bound of the range of cell values.

Type:

float

n_dim = 2
length = 3
max_increment = 1
max_dim_per_action = 1
cells
source
eos
get_action_space()

Constructs the list of all possible actions, including the eos action. An action is represented by a vector of length n_dim where the value at index d is the increment to apply to dimension d of the hyper-grid.
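A minimal pure-Python sketch of how such an action space could be enumerated, assuming that every increment lies in 0..max_increment, that the no-op vector is excluded from the regular actions, and that eos is encoded as the all-zeros tuple appended last. The EOS encoding is an assumption for illustration, not something this page states.

```python
import itertools
from typing import List, Tuple

Action = Tuple[int, ...]


def get_action_space(n_dim: int, max_increment: int, max_dim_per_action: int) -> List[Action]:
    """Enumerate increment vectors, then append EOS (assumed to be all zeros)."""
    if max_dim_per_action == -1:
        # As documented for the attribute: -1 means "up to n_dim dimensions".
        max_dim_per_action = n_dim
    actions = []
    for action in itertools.product(range(max_increment + 1), repeat=n_dim):
        n_moved = sum(1 for incr in action if incr > 0)
        # Skip the no-op vector and vectors that move too many dimensions at once.
        if 0 < n_moved <= max_dim_per_action:
            actions.append(action)
    eos = tuple([0] * n_dim)  # assumption: EOS encoded as the all-zeros vector
    actions.append(eos)
    return actions


print(get_action_space(n_dim=2, max_increment=1, max_dim_per_action=1))
# [(0, 1), (1, 0), (0, 0)]
```

With the default parameters only unit steps along a single axis are available, so the space has n_dim regular actions plus EOS.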

get_mask_invalid_actions_forward(state=None, done=None)
Returns a list with one entry per action in the action space, with values:
  • True if the forward action is invalid from the current state.

  • False otherwise.

Parameters:
  • state (Optional[List])

  • done (Optional[bool])

Return type:

List
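The masking rule can be sketched as follows. This is a hedged illustration, not the library code: it passes the action space and length explicitly instead of reading them from the instance, and it assumes that eos is the last action and is always valid while the trajectory is not done.

```python
from typing import List, Tuple

Action = Tuple[int, ...]


def get_mask_invalid_actions_forward(
    state: List[int],
    done: bool,
    action_space: List[Action],
    length: int,
) -> List[bool]:
    """True marks an invalid forward action; the last entry is assumed to be EOS."""
    if done:
        # Nothing can follow a finished trajectory.
        return [True] * len(action_space)
    mask = []
    for action in action_space[:-1]:
        child = [pos + incr for pos, incr in zip(state, action)]
        # Invalid if any coordinate would leave the grid, i.e. exceed length - 1.
        mask.append(any(pos >= length for pos in child))
    mask.append(False)  # assumption: EOS is always valid before done
    return mask


actions = [(0, 1), (1, 0), (0, 0)]
print(get_mask_invalid_actions_forward([2, 0], False, actions, length=3))
# [False, True, False]
```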

states2proxy(states)

Prepares a batch of states in “environment format” for the proxy: each state is a vector of length n_dim with values in the range [cell_min, cell_max].

See: states2policy()

Parameters:

states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.

Returns:

A tensor containing all the states in the batch.

Return type:

torchtyping.TensorType[batch, state_proxy_dim]
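A pure-Python sketch of one plausible proxy conversion, assuming the integer positions 0..length - 1 are mapped linearly onto [cell_min, cell_max]. The linear mapping is an assumption, and the sketch works on nested lists rather than tensors to stay dependency-free.

```python
from typing import List


def states2proxy(
    states: List[List[int]], length: int, cell_min: float = -1.0, cell_max: float = 1.0
) -> List[List[float]]:
    """Map integer grid positions linearly onto [cell_min, cell_max] (assumed mapping)."""
    # Assumes length > 1, so that the spacing between cells is well defined.
    step = (cell_max - cell_min) / (length - 1)
    return [[cell_min + pos * step for pos in state] for state in states]


print(states2proxy([[0, 2], [1, 1]], length=3))
# [[-1.0, 1.0], [0.0, 0.0]]
```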

states2policy(states)

Prepares a batch of states in “environment format” for the policy model: states are one-hot encoded.

The output is a 2D tensor whose second dimension has size length * n_dim: the n-th successive block of length elements is the one-hot encoding of the position along the n-th dimension.

Example (n_dim = 3, length = 4):
  • state: [0, 3, 1]

  • policy format: [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0], i.e. the blocks [1, 0, 0, 0], [0, 0, 0, 1] and [0, 1, 0, 0] encode positions 0, 3 and 1, respectively.

Parameters:

states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.

Returns:

A tensor containing all the states in the batch.

Return type:

torchtyping.TensorType[batch, policy_input_dim]
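The one-hot layout described above can be sketched in pure Python on nested lists (the documented method returns a tensor instead). It reproduces the example with n_dim = 3 and length = 4.

```python
from typing import List


def states2policy(states: List[List[int]], length: int) -> List[List[int]]:
    """One-hot encode each coordinate and concatenate the n_dim blocks."""
    encoded = []
    for state in states:
        row = [0] * (length * len(state))
        for dim, pos in enumerate(state):
            # Block `dim` occupies indices [dim * length, (dim + 1) * length).
            row[dim * length + pos] = 1
        encoded.append(row)
    return encoded


print(states2policy([[0, 3, 1]], length=4))
# [[1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]]
```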

readable2state(readable, alphabet={})

Converts a human-readable string representing a state into a state as a list of positions.

state2readable(state=None, alphabet={})

Converts a state (a list of positions) into a human-readable string representing a state.

Parameters:

state (Optional[List])
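This page does not specify the readable format, so the round trip below assumes a hypothetical format of space-separated positions inside square brackets, such as "[0 3 1]", and ignores the alphabet argument. It only illustrates that the two methods are meant to be inverses of each other.

```python
from typing import List


def state2readable(state: List[int]) -> str:
    """Assumed format: space-separated positions in square brackets, e.g. "[0 3 1]"."""
    return "[" + " ".join(str(pos) for pos in state) + "]"


def readable2state(readable: str) -> List[int]:
    """Inverse of state2readable under the same assumed format."""
    return [int(tok) for tok in readable.strip().strip("[]").split()]


print(state2readable([0, 3, 1]))  # [0 3 1]
print(readable2state("[0 3 1]"))  # [0, 3, 1]
```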

get_parents(state=None, done=None, action=None)

Determines all parents and actions that lead to state.

Parameters:
  • state (list) – Representation of a state, as a list of length n_dim where each element is the position along the corresponding dimension.

  • done (bool) – Whether the trajectory is done. If None, done is taken from the instance.

  • action (None) – Ignored

Returns:

  • parents (list) – List of parents in state format

  • actions (list) – List of actions that lead to state for each parent in parents

Return type:

Tuple[List, List]
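A hedged sketch of the parent computation: each regular action is undone by subtracting its increments, and only parents that stay inside the grid are kept. It assumes that eos is the last entry of the action space and that a finished state has itself as its only parent, reached through eos; both are assumptions for illustration.

```python
from typing import List, Tuple

Action = Tuple[int, ...]


def get_parents(
    state: List[int], done: bool, action_space: List[Action]
) -> Tuple[List[List[int]], List[Action]]:
    """Undo each non-EOS action; keep the parents that stay inside the grid."""
    eos = action_space[-1]  # assumption: EOS is the last action
    if done:
        # Assumption: a finished state is reached from itself through EOS.
        return [list(state)], [eos]
    parents, actions = [], []
    for action in action_space[:-1]:
        parent = [pos - incr for pos, incr in zip(state, action)]
        if all(pos >= 0 for pos in parent):
            parents.append(parent)
            actions.append(action)
    return parents, actions


space = [(0, 1), (1, 0), (0, 0)]
print(get_parents([1, 1], False, space))
# ([[1, 0], [0, 1]], [(0, 1), (1, 0)])
```

The source state [0, ..., 0] has no parents, since undoing any increment would produce a negative coordinate.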

step(action, skip_mask_check=False)

Executes a step given an action.

Parameters:
  • action (tuple) – Action to be executed. An action is a tuple of int values indicating the increment to apply to each dimension.

  • skip_mask_check (bool) – If True, skip computing forward mask of invalid actions to check if the action is valid.

Returns:

  • self.state (list) – The state after executing the action

  • action (tuple) – Action executed

  • valid (bool) – False if the action is not allowed in the current state, True otherwise.

Return type:

Tuple[List[int], Tuple[int], bool]
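The transition logic can be sketched with a small hypothetical class, ToyGrid, that keeps only what a step needs. It assumes eos is the all-zeros tuple and, for brevity, validates an action only against the grid bounds rather than against the full action space (max_increment and max_dim_per_action are not checked). The usage line reproduces the move from [1, 1, 1] to [2, 3, 1] described at the top of this page.

```python
from typing import List, Tuple

Action = Tuple[int, ...]


class ToyGrid:
    """Hypothetical stand-in for Grid that keeps only what step() needs."""

    def __init__(self, n_dim: int = 2, length: int = 3):
        self.length = length
        self.eos: Action = tuple([0] * n_dim)  # assumption: EOS is all zeros
        self.state: List[int] = [0] * n_dim  # source state: the origin
        self.done = False

    def step(self, action: Action) -> Tuple[List[int], Action, bool]:
        if self.done:
            # No action is allowed once the trajectory has ended.
            return self.state, action, False
        if action == self.eos:
            self.done = True
            return self.state, action, True
        new_state = [pos + incr for pos, incr in zip(self.state, action)]
        if any(pos >= self.length for pos in new_state):
            # Invalid action: the state is left unchanged.
            return self.state, action, False
        self.state = new_state
        return self.state, action, True


env = ToyGrid(n_dim=3, length=4)
env.state = [1, 1, 1]
print(env.step((1, 2, 0)))  # ([2, 3, 1], (1, 2, 0), True)
```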

get_all_terminating_states()
Return type:

List[List]
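Since the state space is the entire grid, a sketch of this method can simply enumerate every cell, assuming that eos may be taken from any cell so that every cell is a valid terminating state. The result has length ** n_dim entries, listed here in lexicographic order (the ordering is a choice of this sketch).

```python
import itertools
from typing import List


def get_all_terminating_states(n_dim: int, length: int) -> List[List[int]]:
    """Every cell of the grid, in lexicographic order of coordinates."""
    return [list(cell) for cell in itertools.product(range(length), repeat=n_dim)]


print(get_all_terminating_states(n_dim=2, length=2))
# [[0, 0], [0, 1], [1, 0], [1, 1]]
```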

get_uniform_terminating_states(n_states, seed=None)
Parameters:
  • n_states (int)

  • seed (int)

Return type:

List[List]
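Drawing each coordinate independently and uniformly from 0..length - 1 yields cells that are uniform over the whole grid. The sketch below uses the standard library's random.Random, seeded by the seed argument, and samples with replacement; both choices are assumptions of this illustration.

```python
import random
from typing import List, Optional


def get_uniform_terminating_states(
    n_states: int, n_dim: int, length: int, seed: Optional[int] = None
) -> List[List[int]]:
    """Draw n_states cells independently and uniformly at random (with replacement)."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    return [[rng.randrange(length) for _ in range(n_dim)] for _ in range(n_states)]


samples = get_uniform_terminating_states(n_states=5, n_dim=2, length=3, seed=0)
```

Using a local generator keeps results reproducible for a given seed without affecting other code that relies on the global random module.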

plot_reward_samples(samples, samples_reward, rewards, dpi=150, n_ticks_max=50, reward_norm=True, **kwargs)

Plots the reward density as a 2D histogram on the grid, alongside a histogram of the sample density.

It is assumed that the rewards cover the entire domain of the grid and are sorted from left to right first, then from top to bottom of the grid of samples.

Parameters:
  • samples (tensor) – A batch of samples from the GFlowNet policy in proxy format. These samples will be plotted on top of the reward density.

  • samples_reward (tensor) – A batch of samples containing a grid over the sample space, from which the reward has been obtained. Ignored by this method.

  • rewards (tensor) – The rewards of samples_reward. It should be a vector of dimensionality length ** 2, sorted such that each block rewards[i * length:i * length + length] corresponds to the rewards of the i-th row of the grid of samples, from top to bottom.

  • dpi (int) – Dots per inch, indicating the resolution of the plot.

  • n_ticks_max (int) – Maximum number of ticks to include in the axes.

  • reward_norm (bool) – Whether to normalize the histogram. True by default.
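The ordering contract for rewards can be made concrete with a hypothetical helper, rewards_to_rows, that splits the flat vector into rows using exactly the slicing given in the parameter description. It is a pure-Python illustration of the expected layout for a 2D grid, not part of the plotting method.

```python
from typing import List


def rewards_to_rows(rewards: List[float], length: int) -> List[List[float]]:
    """Split a flat vector of length ** 2 rewards into rows, top row first."""
    if len(rewards) != length ** 2:
        raise ValueError("expected one reward per cell of a 2D grid")
    # Row i holds rewards[i * length:i * length + length], ordered left to right.
    return [rewards[i * length:i * length + length] for i in range(length)]


print(rewards_to_rows([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], length=3))
# [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
```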