gflownet.envs.toy
Toy environment: A playground environment for small-scale, controllable experiments.
With the default values, the sample, state and action spaces of the environment are defined as in Figure 2 of the GFlowNet Foundations paper, Bengio et al. (JMLR, 2023).
Classes
Toy – Toy environment: with the default values, the environment has a DAG as in Figure 2 of the GFlowNet Foundations paper.
Module Contents
- class gflownet.envs.toy.Toy(connections={0: (1, 2), 1: (3,), 2: (3, 4), 3: (5, -1), 4: (6, -1), 5: (7, 8), 6: (8, 10, -1), 7: (9,), 8: (9, -1), 9: (-1,), 10: (-1,)}, action_space_only_valid=True, **kwargs)
Bases: gflownet.envs.base.GFlowNetEnv
Toy environment: with the default values, the environment has a DAG as in Figure 2 of the GFlowNet Foundations paper.
The DAG can be described as follows:
The source state, s0, is connected to s1 and s2.
s1 is only connected to s3.
s2 is connected to s3 and s4.
s3 is a terminating state and is also connected to s5.
s4 is a terminating state and is also connected to s6.
s5 is connected to s7 and s8.
s6 is a terminating state and is also connected to s8 and s10.
s7 is only connected to s9.
s8 is a terminating state and is also connected to s9.
s9 is a terminating state and is not connected to other states.
s10 is a terminating state and is not connected to other states.
Therefore, the terminating states are s3, s4, s6, s8, s9 and s10.
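The terminating states can be read directly off the default connections dictionary: a state is terminating exactly when -1 appears among its connections. The standalone sketch below copies that dictionary from the class signature and derives the terminating states and the forward children of a state. The helper names `terminating_states` and `children` are hypothetical and not part of the library.

```python
# Default connections copied from the Toy class signature. Each key is a state
# index; each value lists the states it connects to in the forward direction,
# and -1 marks that the state is terminating.
CONNECTIONS = {
    0: (1, 2), 1: (3,), 2: (3, 4), 3: (5, -1), 4: (6, -1), 5: (7, 8),
    6: (8, 10, -1), 7: (9,), 8: (9, -1), 9: (-1,), 10: (-1,),
}


def terminating_states(connections):
    """Hypothetical helper: sorted indices of states whose connections include -1."""
    return sorted(s for s, targets in connections.items() if -1 in targets)


def children(connections, state):
    """Hypothetical helper: forward successors of a state, excluding the -1 marker."""
    return [t for t in connections[state] if t != -1]


print(terminating_states(CONNECTIONS))  # [3, 4, 6, 8, 9, 10]
print(children(CONNECTIONS, 6))         # [8, 10]
```

The printed terminating states match the list given in the description of the DAG.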
States are represented as a single-element list containing the integer index of the state, e.g. [3] for s3.
Actions are represented as tuples of two integers, where the first element is the source state and the second element is the target state, interpreted in the forward direction. The end-of-sequence (EOS) action is (-1, -1).
- Parameters:
connections (dict)
action_space_only_valid (bool)
- connections
A dictionary of state connections. Each key is a state index, and each value is an iterable (a tuple) of the indices of the states it is connected to in the forward direction. If the state is a terminating state, -1 must be included in the iterable.
- Type:
dict
- action_space_only_valid
Whether the action space should be restricted to the valid actions only (True), or should instead contain all theoretically available actions, that is, actions between any pair of states (False).
- Type:
bool
- get_action_space()
Constructs the list of all possible actions, including the end-of-sequence (EOS) action.
Actions are represented as tuples of two integers, where the first element is the source state and the second element is the target state, interpreted in the forward direction. The end-of-sequence (EOS) action is (-1, -1).
- Return type:
List[Tuple]
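The two settings of `action_space_only_valid` can be illustrated with a small standalone sketch. It assumes that the restricted space contains one action per edge of the DAG, that the unrestricted space contains every ordered pair of distinct states, and that EOS is appended last. These are assumptions about the construction and ordering, not the library's code; `build_action_space` is a hypothetical name.

```python
from typing import Dict, List, Tuple

# Default connections copied from the Toy class signature (-1 marks terminating states).
CONNECTIONS: Dict[int, Tuple[int, ...]] = {
    0: (1, 2), 1: (3,), 2: (3, 4), 3: (5, -1), 4: (6, -1), 5: (7, 8),
    6: (8, 10, -1), 7: (9,), 8: (9, -1), 9: (-1,), 10: (-1,),
}
EOS = (-1, -1)


def build_action_space(
    connections: Dict[int, Tuple[int, ...]], only_valid: bool = True
) -> List[Tuple[int, int]]:
    """Hypothetical sketch of an action space built from a connections dict."""
    states = sorted(connections)
    if only_valid:
        # One (source, target) action per DAG edge; the -1 markers are not edges.
        actions = [(s, t) for s in states for t in connections[s] if t != -1]
    else:
        # Assumption: every ordered pair of distinct states is an action.
        actions = [(s, t) for s in states for t in states if s != t]
    # Assumption: EOS is placed at the end of the list.
    return actions + [EOS]


valid_space = build_action_space(CONNECTIONS)
full_space = build_action_space(CONNECTIONS, only_valid=False)
print(len(valid_space), len(full_space))  # 14 111
```

With the default DAG, the restricted space has 13 edges plus EOS, while the unrestricted space has 11 × 10 ordered pairs plus EOS.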
- get_mask_invalid_actions_forward(state=None, done=None)
Returns a list with one entry per action in the action space, with values:
True if the forward action is invalid from the current state.
False otherwise.
- Parameters:
state (list) – Input state. If None, self.state is used.
done (bool) – Whether the trajectory is done. If None, self.done is used.
- Returns:
A list of boolean values.
- Return type:
List[bool]
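The masking rule above can be sketched without the library. This version assumes an action space made of the DAG edges followed by EOS, treats every action as invalid once `done` is True, and allows EOS only from terminating states. All of these are assumptions consistent with the description, not the library's implementation. `mask_invalid_forward` is a hypothetical name.

```python
# Default connections copied from the Toy class signature (-1 marks terminating states).
CONNECTIONS = {
    0: (1, 2), 1: (3,), 2: (3, 4), 3: (5, -1), 4: (6, -1), 5: (7, 8),
    6: (8, 10, -1), 7: (9,), 8: (9, -1), 9: (-1,), 10: (-1,),
}
EOS = (-1, -1)
# Hypothetical action space: DAG edges in sorted order, followed by EOS.
ACTION_SPACE = [
    (s, t) for s in sorted(CONNECTIONS) for t in CONNECTIONS[s] if t != -1
] + [EOS]


def mask_invalid_forward(state, done):
    """Return one boolean per action: True means the forward action is invalid."""
    if done:
        # Assumption: no forward action is valid once the trajectory has ended.
        return [True] * len(ACTION_SPACE)
    current = state[0]
    mask = []
    for action in ACTION_SPACE:
        if action == EOS:
            # EOS is valid only from a terminating state.
            mask.append(-1 not in CONNECTIONS[current])
        else:
            source, target = action
            # A regular action is valid only if it follows an edge out of the current state.
            mask.append(not (source == current and target in CONNECTIONS[current]))
    return mask


valid = [a for a, invalid in zip(ACTION_SPACE, mask_invalid_forward([3], False)) if not invalid]
print(valid)  # [(3, 5), (-1, -1)]
```

From s3 the only valid actions are the edge to s5 and EOS, because s3 is a terminating state.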
- get_parents(state=None, done=None, action=None)
Determines all the parents of a state and the actions that lead from each parent to the state.
- Parameters:
state (list) – Input state. If None, self.state is used.
done (bool) – Whether the trajectory is done. If None, self.done is used.
action (None) – Ignored
- Returns:
parents (list) – List of valid parents in environment format.
actions (list) – List of actions that lead to state for each parent in parents.
- Return type:
Tuple[List, List]
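Finding parents in this DAG amounts to scanning the connections for every source that lists the target state. The sketch below does that. It also assumes, following a common GFlowNet environment convention, that when `done` is True the only parent is the state itself, reached through EOS. That convention is an assumption here and not confirmed by this page. The function is a hypothetical stand-in, not the library method.

```python
# Default connections copied from the Toy class signature (-1 marks terminating states).
CONNECTIONS = {
    0: (1, 2), 1: (3,), 2: (3, 4), 3: (5, -1), 4: (6, -1), 5: (7, 8),
    6: (8, 10, -1), 7: (9,), 8: (9, -1), 9: (-1,), 10: (-1,),
}
EOS = (-1, -1)


def get_parents_sketch(state, done):
    """Return (parents, actions), with states in environment format ([index])."""
    if done:
        # Assumption: a finished trajectory's only parent is the same state via EOS.
        return [state], [EOS]
    target = state[0]
    parents, actions = [], []
    for source in sorted(CONNECTIONS):
        if target in CONNECTIONS[source]:
            parents.append([source])
            actions.append((source, target))
    return parents, actions


print(get_parents_sketch([3], done=False))  # ([[1], [2]], [(1, 3), (2, 3)])
print(get_parents_sketch([0], done=False))  # ([], [])
```

The source state s0 has no parents, and s3 is reachable from both s1 and s2.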
- step(action, skip_mask_check=False)
Performs a step given an action.
- Parameters:
action (tuple) – Action to be performed. An action is a tuple indicating the source and target states, in the forward sense.
skip_mask_check (bool) – If True, skip computing the forward mask of invalid actions, which is otherwise used to check whether the action is valid.
- Returns:
self.state (list) – The state after performing the action.
action (tuple) – The performed action or attempted to be performed.
valid (bool) – False if the action is not allowed in the current state, True otherwise.
- Return type:
[List[int], Tuple[int], bool]
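The step semantics can be illustrated with a minimal stand-in class. `ToySketch` is hypothetical. It assumes that trajectories start at s0, that an edge action moves the state to its target, that EOS marks the trajectory as done only from a terminating state, and that any action after `done` is rejected. The return value mirrors the documented triple (state, action, valid).

```python
from typing import List, Tuple

# Default connections copied from the Toy class signature (-1 marks terminating states).
CONNECTIONS = {
    0: (1, 2), 1: (3,), 2: (3, 4), 3: (5, -1), 4: (6, -1), 5: (7, 8),
    6: (8, 10, -1), 7: (9,), 8: (9, -1), 9: (-1,), 10: (-1,),
}
EOS = (-1, -1)


class ToySketch:
    """Hypothetical minimal stand-in for Toy, used only to illustrate step()."""

    def __init__(self) -> None:
        self.state: List[int] = [0]  # trajectories start at the source state s0
        self.done = False

    def step(self, action: Tuple[int, int]) -> Tuple[List[int], Tuple[int, int], bool]:
        if self.done:
            # Assumption: no action is allowed after the trajectory has ended.
            return self.state, action, False
        current = self.state[0]
        if action == EOS:
            # EOS is valid only from a terminating state; it sets done without moving.
            if -1 in CONNECTIONS[current]:
                self.done = True
                return self.state, action, True
            return self.state, action, False
        source, target = action
        if source == current and target != -1 and target in CONNECTIONS[current]:
            self.state = [target]
            return self.state, action, True
        return self.state, action, False


env = ToySketch()
print(env.step((0, 2)))  # ([2], (0, 2), True)
print(env.step((2, 4)))  # ([4], (2, 4), True)
print(env.step((0, 1)))  # ([4], (0, 1), False): the action does not start at s4
print(env.step(EOS))     # ([4], (-1, -1), True): s4 is terminating
print(env.done)          # True
```

An invalid action leaves the state unchanged and only sets the returned `valid` flag to False.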
- states2proxy(states)
Prepares a batch of states in environment format for a proxy: the batch is simply converted into a tensor of state indices.
- Parameters:
states (list) – A batch of states in environment format, as a list of states.
- Returns:
A 2D tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[batch, 1]
- states2policy(states)
Prepares a batch of states in environment format for the policy model: state indices are one-hot encoded.
- Parameters:
states (list) – A batch of states in environment format, as a list of states.
- Returns:
A 2D tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[batch, n_states]
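The two encodings described above differ only in how each state index is laid out in a row. The sketch below shows them with nested Python lists in place of the torch tensors the methods actually return. It assumes that `n_states` equals the number of keys in the default connections (11). Dtype and device handling in the library are not shown, and the function names are hypothetical.

```python
from typing import List

N_STATES = 11  # assumption: states s0..s10 of the default DAG


def states2proxy_sketch(states: List[List[int]]) -> List[List[int]]:
    """Proxy format: one row per state holding its index, shape [batch, 1]."""
    return [list(state) for state in states]


def states2policy_sketch(states: List[List[int]], n_states: int = N_STATES) -> List[List[int]]:
    """Policy format: one-hot rows, shape [batch, n_states]."""
    rows = []
    for (index,) in states:
        row = [0] * n_states
        row[index] = 1  # a single 1 in the column of the state index
        rows.append(row)
    return rows


batch = [[0], [3], [10]]
print(states2proxy_sketch(batch))      # [[0], [3], [10]]
print(states2policy_sketch(batch)[1])  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
```

In PyTorch, the one-hot step would typically be done with `torch.nn.functional.one_hot` on a tensor of indices.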
- state2readable(state=None)
Converts a state into a human-readable string.
The output string is simply “s” followed by the index of the state.
- Parameters:
state (list) – A state in environment format. If None, self.state is used.
- Returns:
A string representing the state
- Return type:
str
- readable2state(readable)
Converts a state in readable format into the environment format.
- Parameters:
readable (str) – A state in readable format.
- Returns:
A state in environment format.
- Return type:
List[int]
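Since the readable format is "s" followed by the state index, the two conversions are inverses of each other. The sketch below shows the round trip. The function names are hypothetical, and rejecting strings that do not start with "s" is an assumption made for illustration.

```python
from typing import List


def state2readable_sketch(state: List[int]) -> str:
    """Format a state as "s" followed by its index, e.g. [3] -> "s3"."""
    return f"s{state[0]}"


def readable2state_sketch(readable: str) -> List[int]:
    """Parse "s<index>" back into environment format, e.g. "s3" -> [3]."""
    if not readable.startswith("s"):
        # Assumption: malformed input is rejected rather than guessed.
        raise ValueError(f"expected a string like 's3', got {readable!r}")
    return [int(readable[1:])]


print(state2readable_sketch([3]))    # s3
print(readable2state_sketch("s10"))  # [10]
```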