gflownet.envs.torus
IMPORTANT: this environment is currently broken!
Classes to represent hyper-torus environments
Classes
Torus: Hyper-torus environment in which each action increases, decreases, or keeps the angle index of each dimension.
Module Contents
- class gflownet.envs.torus.Torus(n_dim=2, n_angles=3, length_traj=1, max_increment=1, max_dim_per_action=1, **kwargs)[source]
Bases: gflownet.envs.base.GFlowNetEnv
Hyper-torus environment in which the action space consists of:
Increasing the angle index of dimension d
Decreasing the angle index of dimension d
Keeping all dimensions as are
and the trajectory is of fixed length length_traj.
Each state is the concatenation of the angle index at each dimension and the number of actions taken so far.
- Parameters:
n_dim (int)
n_angles (int)
length_traj (int)
max_increment (int)
max_dim_per_action (int)
- ndim
Dimensionality of the torus
- Type:
int
- get_action_space()[source]
Constructs a list of all possible actions, including EOS. An action is represented by a vector of length n_dim where the value at index d is the increment or decrement to apply to dimension d of the hyper-torus. A negative value indicates a decrement. The action “keep” (no increment or decrement of any dimension) is valid and is represented by all zeros.
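The enumeration described above can be sketched as follows. This is a minimal, standalone illustration and not the library's code: the helper name `build_action_space` is hypothetical, and the EOS encoding (a vector filled with `max_increment + 1`) is an assumption chosen only so that EOS cannot collide with a real increment vector.

```python
import itertools


def build_action_space(n_dim=2, max_increment=1, max_dim_per_action=1):
    """Enumerate increment vectors that change at most max_dim_per_action dimensions."""
    increments = range(-max_increment, max_increment + 1)
    actions = [
        action
        for action in itertools.product(increments, repeat=n_dim)
        # Keep only vectors with few enough non-zero entries; all zeros is "keep".
        if sum(1 for el in action if el != 0) <= max_dim_per_action
    ]
    # Assumption: EOS is a sentinel vector outside the valid increment range.
    eos = tuple(max_increment + 1 for _ in range(n_dim))
    actions.append(eos)
    return actions


actions = build_action_space()
```

With the default arguments this yields the "keep" action, one increment and one decrement per dimension, and the EOS sentinel as the last entry.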
- get_mask_invalid_actions_forward(state=None, done=None)[source]
- Returns a list with one entry per action in the action space, with values:
True if the forward action is invalid from the current state.
False otherwise.
While the maximum number of actions has not been reached, every action except EOS is valid. Once it has been reached, only EOS is valid.
- Parameters:
state (Optional[List])
done (Optional[bool])
- Return type:
List
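The masking rule can be illustrated with a short standalone sketch. The function below is hypothetical and makes three assumptions that the documentation does not confirm: EOS is the last entry of the action space, the action counter is the last element of the state, and every action is invalid once the trajectory is done.

```python
def mask_invalid_actions_forward(state, action_space_size, length_traj, done=False):
    """Return one bool per action; True means the action is invalid."""
    if done:
        # Assumption: no action is valid after the trajectory has ended.
        return [True] * action_space_size
    n_actions = state[-1]
    if n_actions < length_traj:
        # Budget not yet used up: every move is valid, EOS (assumed last) is not.
        return [False] * (action_space_size - 1) + [True]
    # Budget used up: only EOS is valid.
    return [True] * (action_space_size - 1) + [False]


mask_start = mask_invalid_actions_forward([0, 0, 0], action_space_size=6, length_traj=1)
mask_end = mask_invalid_actions_forward([1, 0, 1], action_space_size=6, length_traj=1)
```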
- states2proxy(states)[source]
Prepares a batch of states in “environment format” for the proxy: each state is a vector of length n_dim where each value is an angle in radians. The n_actions item is removed.
- Parameters:
states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.
- Returns:
A tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[batch, state_proxy_dim]
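As a rough illustration of this conversion, the sketch below maps angle indices to radians and drops the trailing n_actions item. It uses plain Python lists rather than tensors, the name `states_to_proxy` is hypothetical, and the mapping `index * 2π / n_angles` is an assumption about how indices correspond to angles.

```python
import math


def states_to_proxy(states, n_angles):
    """Convert angle indices to radians and drop the trailing n_actions item."""
    # Assumption: the n_angles indices are spread evenly over [0, 2*pi).
    angle_step = 2 * math.pi / n_angles
    return [[idx * angle_step for idx in state[:-1]] for state in states]


proxy_states = states_to_proxy([[1, 3, 4]], n_angles=4)
```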
- states2policy(states)[source]
Prepares a batch of states in “environment format” for the policy model: the policy format is a one-hot encoding of the states.
Each row is a vector of length n_angles * n_dim + 1, where the n-th successive block of n_angles elements is a one-hot encoding of the position in the n-th dimension and the last element is the number of actions.
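- Example, n_dim = 2, n_angles = 4:
- state: [1, 3, 4], where the first two items are the angle indices (a) and the last item is n_actions (n)
- policy format: [0, 1, 0, 0, 0, 0, 0, 1, 4], that is, the one-hot encoding of 1, the one-hot encoding of 3, and n_actions = 4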
- Parameters:
states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.
- Returns:
A tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[batch, policy_input_dim]
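The encoding in the example can be reproduced with a small standalone sketch. It works on plain lists instead of tensors, and the function name `states_to_policy` is hypothetical rather than part of the library.

```python
def states_to_policy(states, n_dim, n_angles):
    """One-hot encode each angle index and append n_actions unchanged."""
    rows = []
    for state in states:
        row = [0] * (n_dim * n_angles + 1)
        for dim, angle_idx in enumerate(state[:n_dim]):
            # Each dimension owns a block of n_angles consecutive entries.
            row[dim * n_angles + angle_idx] = 1
        # The last entry carries the number of actions as-is.
        row[-1] = state[-1]
        rows.append(row)
    return rows


policy_states = states_to_policy([[1, 3, 4]], n_dim=2, n_angles=4)
print(policy_states)  # [[0, 1, 0, 0, 0, 0, 0, 1, 4]]
```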
- state2readable(state=None)[source]
Converts a state (a list of positions) into a human-readable string representing a state.
- Parameters:
state (Optional[List])
- Return type:
str
- readable2state(readable)[source]
Converts a human-readable string representing a state into a state as a list of positions.
- Parameters:
readable (str)
- Return type:
List
- get_parents(state=None, done=None, action=None)[source]
Determines all parents and actions that lead to state.
- Parameters:
state (list) – Representation of a state, as a list containing the position (angle index) at each dimension followed by the number of actions.
done (bool) – Whether the trajectory is done. If None, done is taken from instance.
action (None) – Ignored
- Returns:
parents (list) – List of parents in state format
actions (list) – List of actions that lead to state for each parent in parents
- Return type:
Tuple[List, List]
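One way to picture the parent computation is to undo every non-EOS action with wrap-around on the angle indices. The sketch below is a hypothetical standalone function, not the library's implementation. It assumes that the source state (zero actions) has no parents, that a done state has itself as its only parent via EOS, and that angle indices wrap modulo n_angles.

```python
def get_parents(state, moves, n_angles, eos, done=False):
    """Return (parents, actions) such that applying actions[i] to parents[i] gives state."""
    if done:
        # Assumption: a finished trajectory was reached from the same state via EOS.
        return [list(state)], [eos]
    *angles, n_actions = state
    if n_actions == 0:
        # Assumption: the source state has no parents.
        return [], []
    parents, actions = [], []
    for move in moves:
        # Undo the move on each dimension, wrapping around the torus.
        prev_angles = [(a - inc) % n_angles for a, inc in zip(angles, move)]
        parents.append(prev_angles + [n_actions - 1])
        actions.append(move)
    return parents, actions


moves = [(-1, 0), (0, -1), (0, 0), (0, 1), (1, 0)]
parents, parent_actions = get_parents([0, 1, 1], moves, n_angles=3, eos=(2, 2))
```

In this sketch every move yields exactly one parent, so a state with at least one action taken has as many parents as there are non-EOS actions.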
- step(action, skip_mask_check=False)[source]
Executes step given an action.
- Parameters:
action (tuple) – Action to be executed. See: get_action_space()
skip_mask_check (bool) – If True, skip computing forward mask of invalid actions to check if the action is valid.
- Returns:
self.state (list) – The state after executing the action
action (tuple) – Action executed
valid (bool) – False, if the action is not allowed for the current state.
- Return type:
Tuple[List[int], Tuple[int], bool]
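The state update performed by a step can be sketched as below. This is a simplified, stateless illustration with a hypothetical signature: the real method operates on the environment instance, and both the wrap-around on angle indices and the EOS handling shown here are assumptions.

```python
def step(state, action, n_angles, length_traj, eos):
    """Return (new_state, action, valid) without mutating the input state."""
    *angles, n_actions = state
    if action == eos:
        # Assumption: EOS is only valid once the action budget is used up.
        return list(state), action, n_actions == length_traj
    if n_actions >= length_traj:
        # No moves are allowed once the budget is used up.
        return list(state), action, False
    # Apply the increment per dimension, wrapping around the torus.
    new_angles = [(a + inc) % n_angles for a, inc in zip(angles, action)]
    return new_angles + [n_actions + 1], action, True


new_state, executed, valid = step([2, 0, 0], (1, 0), n_angles=3, length_traj=1, eos=(2, 2))
```

Here the first dimension wraps from index 2 back to 0 and the action counter advances by one.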