gflownet.envs.torus

IMPORTANT: this environment is currently broken!

Classes to represent hyper-torus environments

Classes

Torus

Hyper-torus environment with discretized angles and fixed-length trajectories.

Module Contents

class gflownet.envs.torus.Torus(n_dim=2, n_angles=3, length_traj=1, max_increment=1, max_dim_per_action=1, **kwargs)[source]

Bases: gflownet.envs.base.GFlowNetEnv

Hyper-torus environment in which the action space consists of:
  • Increasing the angle index of dimension d

  • Decreasing the angle index of dimension d

  • Keeping all dimensions as they are

and the trajectory is of fixed length length_traj.

Each state is the concatenation of the angle index of each dimension and the number of actions taken so far.

Parameters:
  • n_dim (int)

  • n_angles (int)

  • length_traj (int)

  • max_increment (int)

  • max_dim_per_action (int)

n_dim

Dimensionality of the torus

Type:

int

n_angles[source]

Number of angles into which each dimension is divided

Type:

int

length_traj[source]

Fixed length of the trajectory.

Type:

int

n_dim = 2[source]
n_angles = 3[source]
length_traj = 1[source]
max_increment = 1[source]
max_dim_per_action = 1[source]
source_angles[source]
source[source]
eos[source]
angle_rad[source]
get_action_space()[source]

Constructs a list with all possible actions, including eos. An action is represented by a vector of length n_dim, where the value at index d is the increment or decrement to apply to dimension d of the hyper-torus. A negative value indicates a decrement. The action “keep” (no increment or decrement in any dimension) is valid and is represented by all zeros.
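The enumeration described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the function name build_action_space is hypothetical, and the EOS encoding (a tuple filled with max_increment + 1) is an assumption chosen only so that EOS cannot collide with a regular action.

```python
from itertools import product


def build_action_space(n_dim=2, max_increment=1, max_dim_per_action=1):
    """Enumerate increment vectors of length n_dim, then append an EOS marker."""
    increments = range(-max_increment, max_increment + 1)
    actions = [
        action
        for action in product(increments, repeat=n_dim)
        # Keep only actions that modify at most max_dim_per_action dimensions.
        # The all-zeros "keep" action modifies none, so it always passes.
        if sum(step != 0 for step in action) <= max_dim_per_action
    ]
    # Assumed EOS encoding: a value outside the valid increment range.
    eos = tuple(max_increment + 1 for _ in range(n_dim))
    return actions + [eos]


actions = build_action_space()
```

With the default arguments this yields the "keep" action, one increment and one decrement per dimension, and the EOS marker.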

get_mask_invalid_actions_forward(state=None, done=None)[source]
Returns a list with one entry per action in the action space, with values:
  • True if the forward action is invalid from the current state.

  • False otherwise.

If the maximum number of actions has not been reached, all actions except EOS are valid; once it has been reached, only EOS is valid.

Parameters:
  • state (Optional[List])

  • done (Optional[bool])

Return type:

List
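The masking rule can be illustrated with a short sketch. It assumes that the last element of the state is the number of actions taken so far, that length_traj is the maximum number of actions, and that every action is masked once the trajectory is done; the function name and the EOS tuple (2, 2) are hypothetical.

```python
def mask_invalid_actions_forward(state, done, action_space, eos, length_traj):
    """Return one boolean per action: True means the action is invalid."""
    if done:
        # Assumption: nothing can be done after the trajectory has ended.
        return [True] * len(action_space)
    n_actions = state[-1]
    if n_actions >= length_traj:
        # Maximum number of actions reached: only EOS remains valid.
        return [action != eos for action in action_space]
    # Otherwise every action except EOS is valid.
    return [action == eos for action in action_space]


eos = (2, 2)  # hypothetical EOS encoding
action_space = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1), eos]
mask_start = mask_invalid_actions_forward([0, 0, 0], False, action_space, eos, 1)
mask_end = mask_invalid_actions_forward([1, 0, 1], False, action_space, eos, 1)
```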

states2proxy(states)[source]

Prepares a batch of states in “environment format” for the proxy: each state is a vector of length n_dim where each value is an angle in radians. The n_actions item is removed.

Parameters:

states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.

Returns:

A tensor containing all the states in the batch.

Return type:

torchtyping.TensorType[batch, state_proxy_dim]
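A minimal sketch of this conversion on plain Python lists, assuming that the n_angles angle indices are evenly spaced over a full turn, so that one index step corresponds to 2π / n_angles radians. The function name is hypothetical, and the real method returns a tensor rather than nested lists.

```python
import math


def states_to_proxy(states, n_angles):
    """Drop the trailing n_actions item and map angle indices to radians."""
    # Assumption: indices are evenly spaced over [0, 2*pi).
    angle_rad = 2 * math.pi / n_angles
    return [[index * angle_rad for index in state[:-1]] for state in states]


proxy_states = states_to_proxy([[1, 3, 4], [0, 2, 4]], n_angles=4)
```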

states2policy(states)[source]

Prepares a batch of states in “environment format” for the policy model: the policy format is a one-hot encoding of the states.

Each row is a vector of length n_angles * n_dim + 1: each of the first n_dim successive blocks of n_angles elements is a one-hot encoding of the angle index in the corresponding dimension, and the last element is the number of actions.

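Example, n_dim = 2, n_angles = 4:
  • state: [1, 3, 4], where [1, 3] are the angle indices and 4 is n_actions
  • policy format: [0, 1, 0, 0, 0, 0, 0, 1, 4], where [0, 1, 0, 0] encodes angle index 1, [0, 0, 0, 1] encodes angle index 3, and the final 4 is n_actions
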
Parameters:

states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.

Returns:

A tensor containing all the states in the batch.

Return type:

torchtyping.TensorType[batch, policy_input_dim]
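The encoding can be reproduced for a single state with a few lines of plain Python. This is a sketch for illustration only: the function name is hypothetical, and the real method operates on batches and returns a tensor.

```python
def state_to_policy(state, n_dim, n_angles):
    """One-hot encode each angle index and append n_actions unchanged."""
    encoding = [0] * (n_dim * n_angles + 1)
    for dim, angle_index in enumerate(state[:n_dim]):
        # Each dimension owns a block of n_angles consecutive positions.
        encoding[dim * n_angles + angle_index] = 1
    # The last position carries the number of actions.
    encoding[-1] = state[-1]
    return encoding


policy_state = state_to_policy([1, 3, 4], n_dim=2, n_angles=4)
```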

state2readable(state=None)[source]

Converts a state (a list of positions) into a human-readable string representing a state.

Parameters:

state (Optional[List])

Return type:

str

readable2state(readable)[source]

Converts a human-readable string representing a state into a state as a list of positions.

Parameters:

readable (str)

Return type:

List

get_parents(state=None, done=None, action=None)[source]

Determines all parents and actions that lead to state.

Parameters:
  • state (list) – Representation of a state, as a list of length n_dim + 1: the angle index of each dimension followed by the number of actions.

  • done (bool) – Whether the trajectory is done. If None, done is taken from instance.

  • action (None) – Ignored

Returns:

  • parents (list) – List of parents in state format

  • actions (list) – List of actions that lead to state for each parent in parents

Return type:

Tuple[List, List]
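One way to picture the parent computation is to undo every non-EOS action and decrease the action counter by one. The sketch below assumes wrap-around (modulo n_angles) arithmetic on the angle indices, assumes that a done state has itself as its only parent via EOS, and does not check whether a parent is reachable from the source state. The function name and the EOS tuple are hypothetical.

```python
def get_parents_sketch(state, done, action_space, eos, n_angles):
    """Return (parents, actions) such that applying actions[i] to parents[i] gives state."""
    if done:
        # Assumption: a finished trajectory was reached by EOS from the same state.
        return [list(state)], [eos]
    *angles, n_actions = state
    if n_actions == 0:
        # No action has been taken yet, so there is no parent.
        return [], []
    parents, actions = [], []
    for action in action_space:
        if action == eos:
            continue
        # Undo the increment in every dimension, wrapping around the torus.
        parent_angles = [(a - inc) % n_angles for a, inc in zip(angles, action)]
        parents.append(parent_angles + [n_actions - 1])
        actions.append(action)
    return parents, actions


eos = (2, 2)  # hypothetical EOS encoding
action_space = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1), eos]
parents, parent_actions = get_parents_sketch([0, 1, 1], False, action_space, eos, 3)
```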

step(action, skip_mask_check=False)[source]

Executes step given an action.

Parameters:
  • action (tuple) – Action to be executed. See: get_action_space()

  • skip_mask_check (bool) – If True, skip computing forward mask of invalid actions to check if the action is valid.

Returns:

  • self.state (list) – The state after executing the action

  • action (tuple) – Action executed

  • valid (bool) – False if the action is not allowed for the current state, True otherwise.

Return type:

Tuple[List[int], Tuple[int], bool]
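The return contract can be illustrated with a stateless sketch that takes the state as an argument instead of reading it from the instance. It assumes wrap-around arithmetic on the angle indices and the same validity rule as the forward mask (EOS only once length_traj actions have been taken); the function name and the EOS tuple are hypothetical.

```python
def step_sketch(state, action, eos, n_angles, length_traj):
    """Return (next_state, action, valid) without mutating the input state."""
    *angles, n_actions = state
    reached_max = n_actions >= length_traj
    if action == eos:
        # Assumption: EOS is valid only once the fixed trajectory length is reached.
        return list(state), action, reached_max
    if reached_max:
        # No further increments are allowed; the state is left unchanged.
        return list(state), action, False
    # Apply the increment in every dimension, wrapping around the torus.
    new_angles = [(a + inc) % n_angles for a, inc in zip(angles, action)]
    return new_angles + [n_actions + 1], action, True


eos = (2, 2)  # hypothetical EOS encoding
next_state, executed, valid = step_sketch([2, 0, 0], (1, 0), eos, n_angles=3, length_traj=1)
```

An invalid action leaves the state unchanged and is reported through the third element of the returned tuple.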

get_all_terminating_states()[source]
fit_kde(kernel='exponential', bandwidth=0.1)[source]