gflownet.envs.grid
Classes to represent hyper-grid environments
Classes
Grid – Hyper-grid environment: A grid with n_dim dimensions and length cells per dimension.
Module Contents
- class gflownet.envs.grid.Grid(n_dim=2, length=3, max_increment=1, max_dim_per_action=1, cell_min=-1, cell_max=1, **kwargs)
Bases: gflownet.envs.base.GFlowNetEnv
Hyper-grid environment: A grid with n_dim dimensions and length cells per dimension.
The state space is the entire grid, and each state is represented by the vector of its coordinates along each dimension. For example, in 3D, the origin is at [0, 0, 0], and after incrementing dimension 0 by 2, dimension 1 by 3 and dimension 2 by 1, the state is [2, 3, 1].
Each action is the vector of increments to apply to each dimension. For instance, (0, 0, 1) increments dimension 2 by 1, and the action that goes from [1, 1, 1] to [2, 3, 1] is (1, 2, 0).
- Parameters:
n_dim (int)
length (int)
max_increment (int)
max_dim_per_action (int)
cell_min (float)
cell_max (float)
- max_dim_per_action
Maximum number of dimensions to increment per action. If -1, then max_dim_per_action is set to n_dim.
- Type:
int
- cell_min
Lower bound of the cell range.
- Type:
float
- cell_max
Upper bound of the cell range.
- Type:
float
- get_action_space()
Constructs the list of all possible actions, including the end-of-sequence (EOS) action. An action is represented by a vector of length n_dim where each index d indicates the increment to apply to dimension d of the hyper-grid.
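The construction can be sketched in plain Python. This is a minimal sketch, assuming that each increment ranges over 0..max_increment, that a non-EOS action changes at least one and at most max_dim_per_action dimensions, and that EOS is encoded as the all-zeros tuple appended last; the function name build_action_space is hypothetical and not part of the library.

```python
from itertools import product
from typing import List, Tuple


def build_action_space(
    n_dim: int, max_increment: int = 1, max_dim_per_action: int = 1
) -> List[Tuple[int, ...]]:
    """Hypothetical sketch of an action space like Grid.get_action_space()."""
    # As documented for max_dim_per_action: -1 means "all dimensions".
    if max_dim_per_action == -1:
        max_dim_per_action = n_dim
    # Assumption: EOS is represented by the all-zeros increment vector.
    eos = tuple([0] * n_dim)
    actions = []
    for action in product(range(max_increment + 1), repeat=n_dim):
        n_changed = sum(1 for incr in action if incr > 0)
        # Skip the no-op vector and actions that touch too many dimensions.
        if 0 < n_changed <= max_dim_per_action:
            actions.append(action)
    actions.append(eos)  # assumption: EOS goes last
    return actions


print(build_action_space(2))
# [(0, 1), (1, 0), (0, 0)]
```

With the default arguments, each action moves exactly one dimension by one cell, which keeps the action space linear in n_dim rather than exponential.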
- get_mask_invalid_actions_forward(state=None, done=None)
- Returns a list with one boolean per action in the action space:
True if the forward action is invalid from the current state.
False otherwise.
- Parameters:
state (Optional[List])
done (Optional[bool])
- Return type:
List
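A forward mask of this kind can be sketched as follows. This is a minimal sketch under stated assumptions, not the library's implementation: an action is taken to be invalid if it would push any coordinate beyond length - 1, every action is taken to be invalid once done is True, and EOS (all zeros) is therefore always allowed otherwise. The function name mask_invalid_forward is hypothetical.

```python
from typing import List, Sequence, Tuple


def mask_invalid_forward(
    state: Sequence[int],
    done: bool,
    action_space: List[Tuple[int, ...]],
    length: int,
) -> List[bool]:
    """Hypothetical sketch: True marks an invalid forward action."""
    # Assumption: once the trajectory is done, no forward action is allowed.
    if done:
        return [True] * len(action_space)
    mask = []
    for action in action_space:
        # Invalid if the action pushes any coordinate past the last cell.
        overflow = any(s + a > length - 1 for s, a in zip(state, action))
        mask.append(overflow)
    return mask


actions = [(0, 1), (1, 0), (0, 0)]  # last entry assumed to be EOS
print(mask_invalid_forward([2, 0], False, actions, length=3))
# [False, True, False]
```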
- states2proxy(states)
Prepares a batch of states in “environment format” for the proxy: each state is a vector of length n_dim with values in the range [cell_min, cell_max].
See: states2policy()
- Parameters:
states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.
- Returns:
A tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[batch, state_proxy_dim]
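One plausible reading of "values in the range [cell_min, cell_max]" is a linear mapping that sends cell index 0 to cell_min and index length - 1 to cell_max. The sketch below assumes exactly that (evenly spaced cells) and returns nested lists rather than a torch tensor to stay dependency-free; states2proxy_sketch is a hypothetical name.

```python
from typing import List, Sequence


def states2proxy_sketch(
    states: Sequence[Sequence[int]],
    length: int,
    cell_min: float = -1.0,
    cell_max: float = 1.0,
) -> List[List[float]]:
    """Hypothetical sketch: map grid indices to evenly spaced cell values."""
    if length < 2:
        raise ValueError("length must be at least 2 for an even spacing")
    # Spacing between consecutive cell values (assumed linear).
    step = (cell_max - cell_min) / (length - 1)
    return [[cell_min + idx * step for idx in state] for state in states]


print(states2proxy_sketch([[0, 1, 2], [2, 2, 0]], length=3))
# [[-1.0, 0.0, 1.0], [1.0, 1.0, -1.0]]
```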
- states2policy(states)
Prepares a batch of states in “environment format” for the policy model: states are one-hot encoded.
The output is a 2D tensor whose second dimension has size length * n_dim: the n-th block of length consecutive elements is the one-hot encoding of the position along the n-th dimension.
- Example (n_dim = 3, length = 4):
state: [0, 3, 1]
policy format: [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]
blocks: [1, 0, 0, 0] encodes 0 | [0, 0, 0, 1] encodes 3 | [0, 1, 0, 0] encodes 1
- Parameters:
states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.
- Returns:
A tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[batch, policy_input_dim]
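The block-wise one-hot layout described above can be reproduced in a few lines of plain Python. This is an illustrative sketch (states2policy_sketch is a hypothetical name) that returns nested lists instead of a torch tensor; it reproduces the documented example exactly.

```python
from typing import List, Sequence


def states2policy_sketch(
    states: Sequence[Sequence[int]], length: int
) -> List[List[int]]:
    """Hypothetical sketch: concatenate one one-hot block per dimension."""
    encoded = []
    for state in states:
        row = [0] * (length * len(state))
        for dim, pos in enumerate(state):
            # Block `dim` spans row[dim * length:(dim + 1) * length].
            row[dim * length + pos] = 1
        encoded.append(row)
    return encoded


print(states2policy_sketch([[0, 3, 1]], length=4))
# [[1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]]
```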
- readable2state(readable, alphabet={})
Converts a human-readable string representing a state into a state as a list of positions.
- state2readable(state=None, alphabet={})
Converts a state (a list of positions) into a human-readable string representing a state.
- Parameters:
state (Optional[List])
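The exact readable format is not specified here. The round trip below assumes a bracketed, space-separated list such as "[0 3 1]" and ignores the alphabet parameter; both function names are hypothetical, and the real format may differ.

```python
from typing import List, Sequence


def state2readable_sketch(state: Sequence[int]) -> str:
    """Hypothetical sketch: render a state as '[x0 x1 ...]'."""
    return "[" + " ".join(str(pos) for pos in state) + "]"


def readable2state_sketch(readable: str) -> List[int]:
    """Hypothetical sketch: parse '[x0 x1 ...]' back into a list of ints."""
    return [int(tok) for tok in readable.strip().strip("[]").split()]


text = state2readable_sketch([0, 3, 1])
print(text)                         # [0 3 1]
print(readable2state_sketch(text))  # [0, 3, 1]
```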
- get_parents(state=None, done=None, action=None)
Determines all parents of state and, for each parent, the action that leads from it to state.
- Parameters:
state (list) – Representation of a state, as a list of length n_dim where each element is the position along the corresponding dimension.
done (bool) – Whether the trajectory is done. If None, done is taken from the instance.
action (None) – Ignored
- Returns:
parents (list) – List of parents in state format
actions (list) – List of actions that lead to state for each parent in parents
- Return type:
Tuple[List, List]
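Parent enumeration can be sketched by undoing each candidate action. This minimal sketch assumes that a parent is obtained by subtracting a non-EOS action from state and is kept only if all its coordinates remain non-negative, and that a done state has a single parent, itself, reached via EOS. These are assumptions following the usual GFlowNet convention, not confirmed details of this class; get_parents_sketch is a hypothetical name.

```python
from typing import List, Sequence, Tuple

Action = Tuple[int, ...]


def get_parents_sketch(
    state: Sequence[int], done: bool, action_space: List[Action], eos: Action
) -> Tuple[List[List[int]], List[Action]]:
    """Hypothetical sketch: undo each action and keep in-grid predecessors."""
    # Assumption: a finished trajectory's only parent is itself, via EOS.
    if done:
        return [list(state)], [eos]
    parents, actions = [], []
    for action in action_space:
        if action == eos:
            continue
        parent = [s - a for s, a in zip(state, action)]
        # The parent must lie inside the grid (no negative coordinates).
        if all(c >= 0 for c in parent):
            parents.append(parent)
            actions.append(action)
    return parents, actions


space = [(0, 1), (1, 0), (0, 0)]  # last entry assumed to be EOS
print(get_parents_sketch([1, 0], False, space, eos=(0, 0)))
# ([[0, 0]], [(1, 0)])
```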
- step(action, skip_mask_check=False)
Executes a step given an action.
- Parameters:
action (tuple) – Action to be executed. An action is a tuple of int values indicating the increment to apply to each dimension.
skip_mask_check (bool) – If True, skip computing the forward mask of invalid actions to check whether the action is valid.
- Returns:
self.state (list) – The state after executing the action
action (tuple) – Action executed
valid (bool) – False if the action is not allowed from the current state, True otherwise.
- Return type:
Tuple[List[int], Tuple[int], bool]
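The transition logic can be sketched as a pure function. This is a hypothetical illustration (step_sketch is not a library function) that assumes EOS ends the trajectory without moving, that any move leaving the grid is rejected with the state unchanged, and that membership in the action space and an already-done trajectory are not checked. It uses the (1, 2, 0) example from the class description, which requires length of at least 4.

```python
from typing import List, Sequence, Tuple

Action = Tuple[int, ...]


def step_sketch(
    state: Sequence[int], action: Action, length: int, eos: Action
) -> Tuple[List[int], bool, bool]:
    """Hypothetical sketch returning (new_state, done, valid)."""
    # Assumption: EOS ends the trajectory without moving.
    if action == eos:
        return list(state), True, True
    new_state = [s + a for s, a in zip(state, action)]
    # Reject moves that leave the grid; the state is left unchanged.
    if any(c > length - 1 for c in new_state):
        return list(state), False, False
    return new_state, False, True


print(step_sketch([1, 1, 1], (1, 2, 0), length=4, eos=(0, 0, 0)))
# ([2, 3, 1], False, True)
```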
- get_uniform_terminating_states(n_states, seed=None)
- Parameters:
n_states (int)
seed (int)
- Return type:
List[List]
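This method is listed without a description. Judging from its name, it plausibly returns n_states terminating states drawn uniformly over the grid, with seed controlling the random generator. The sketch below assumes exactly that, and that every grid cell is a valid terminating state; uniform_terminating_states_sketch is a hypothetical name.

```python
import random
from typing import List, Optional


def uniform_terminating_states_sketch(
    n_states: int, n_dim: int, length: int, seed: Optional[int] = None
) -> List[List[int]]:
    """Hypothetical sketch: draw grid cells uniformly at random."""
    rng = random.Random(seed)  # seeded generator for reproducibility
    # Assumption: every cell of the grid is a valid terminating state.
    return [[rng.randrange(length) for _ in range(n_dim)] for _ in range(n_states)]


samples = uniform_terminating_states_sketch(5, n_dim=2, length=3, seed=0)
print(samples)
```

Calling it again with the same seed yields the same list, which makes evaluation sets reproducible.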
- plot_reward_samples(samples, samples_reward, rewards, dpi=150, n_ticks_max=50, reward_norm=True, **kwargs)
Plots the reward density as a 2D histogram on the grid, alongside a histogram representing the density of the samples.
It is assumed that the rewards cover the entire domain of the grid and are sorted row by row: from left to right within each row, and rows from top to bottom.
- Parameters:
samples (tensor) – A batch of samples from the GFlowNet policy in proxy format. These samples will be plotted on top of the reward density.
samples_reward (tensor) – A batch of samples containing a grid over the sample space, from which the reward has been obtained. Ignored by this method.
rewards (tensor) – The rewards of samples_reward. It should be a vector of dimensionality length ** 2, sorted such that each block rewards[i * length:(i + 1) * length] corresponds to the rewards at the i-th row of the grid of samples, from top to bottom.
dpi (int) – Dots per inch, indicating the resolution of the plot.
n_ticks_max (int) – Maximum number of ticks to include in the axes.
reward_norm (bool) – Whether to normalize the histogram. True by default.
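The expected ordering of rewards is row-major: slicing consecutive blocks of length elements recovers the grid rows from top to bottom. The hypothetical helper below illustrates that layout with plain lists; for a torch tensor, rewards.reshape(length, length) produces the same row-major arrangement.

```python
from typing import List, Sequence


def rewards_to_grid(rewards: Sequence[float], length: int) -> List[List[float]]:
    """Hypothetical helper: reshape a flat reward vector into grid rows."""
    if len(rewards) != length ** 2:
        raise ValueError("expected length ** 2 rewards")
    # Row i (top to bottom) holds rewards[i * length:(i + 1) * length].
    return [list(rewards[i * length:(i + 1) * length]) for i in range(length)]


print(rewards_to_grid(list(range(9)), length=3))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```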