base
Represent sequence-like environments.
Sequences are constructed by adding tokens from a dictionary, from left to right.
Classes
SequenceBase | Initialize a sequence environment.
Module Contents
- class base.SequenceBase(tokens=[0, 1], min_length=1, max_length=5, pad_token=-1, **kwargs)[source]
Bases: gflownet.envs.base.GFlowNetEnv
Initialize a sequence environment.
- Parameters:
tokens (Iterable) – Vocabulary of tokens used to build sequences.
min_length (int) – Minimum valid sequence length before the EOS action is allowed.
max_length (int) – Maximum sequence length.
pad_token (int, float, str) – Token used to pad incomplete sequences.
**kwargs – Additional keyword arguments forwarded to GFlowNetEnv.
- get_action_space()[source]
Construct the list of all possible actions, including EOS.
An action is represented by a single-element tuple indicating the index of the token to be added to the current sequence (state).
- The action space of this parent class is:
action_space: [(1,), (2,), (-1,)]
- Return type:
List[Tuple]
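As a rough illustration of how such an action space could be assembled, here is a minimal sketch. The helper name `build_action_space` and its parameters are hypothetical and not part of the documented API. It assumes that token indices start at 1 (index 0 being reserved for padding, as the state examples on this page suggest) and that EOS is encoded as (-1,).

```python
from typing import List, Tuple


def build_action_space(n_tokens: int, eos_index: int = -1) -> List[Tuple[int]]:
    """Hypothetical helper: one single-element action per token index, then EOS."""
    # Assumption: indices start at 1 because 0 marks padding in the state.
    actions = [(idx,) for idx in range(1, n_tokens + 1)]
    # The EOS action is placed last.
    actions.append((eos_index,))
    return actions


print(build_action_space(n_tokens=2))  # [(1,), (2,), (-1,)]
```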
- get_mask_invalid_actions_forward(state=None, done=None)[source]
Return the mask of invalid forward actions.
- The returned list has one entry per action:
True if the forward action is invalid from the current state.
False otherwise.
- Parameters:
state (tensor) – Input state. If None, self.state is used.
done (bool) – Whether the trajectory is done. If None, self.done is used.
- Returns:
A list of boolean values.
- Return type:
List[bool]
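The masking rules can be sketched as follows. This is not the library's implementation: the function `mask_invalid_forward` is hypothetical, and it assumes that all actions are invalid once the trajectory is done, that token actions become invalid at max_length, that EOS is invalid below min_length, and that EOS is the last action.

```python
from typing import List, Sequence


def mask_invalid_forward(
    state: Sequence[int],
    done: bool,
    n_tokens: int,
    min_length: int,
    max_length: int,
) -> List[bool]:
    """Hypothetical sketch: True marks an invalid action (tokens first, EOS last)."""
    n_actions = n_tokens + 1
    if done:
        # Assumption: nothing can be done after the trajectory has ended.
        return [True] * n_actions
    # Index 0 is treated as padding, so non-zero entries give the length.
    length = sum(1 for idx in state if idx != 0)
    mask = [False] * n_actions
    if length >= max_length:
        # No room left: every token action is invalid.
        for i in range(n_tokens):
            mask[i] = True
    if length < min_length:
        # Sequence too short: EOS is invalid.
        mask[-1] = True
    return mask


print(mask_invalid_forward([0, 0, 0, 0, 0], False, 2, 1, 5))  # [False, False, True]
print(mask_invalid_forward([1, 2, 1, 1, 2], False, 2, 1, 5))  # [True, True, False]
```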
- get_parents(state=None, done=None, action=None)[source]
Determine all parents and actions that lead to a state.
The GFlowNet graph is a tree and there is only one parent per state.
- Parameters:
state (tensor) – Input state. If None, self.state is used.
done (bool) – Whether the trajectory is done. If None, self.done is used.
action (None) – Ignored
- Returns:
parents (list) – List of parents in state format. This environment has a single parent per state.
actions (list) – List of actions that lead to state for each parent in parents. This environment has a single parent per state.
- Return type:
Tuple[List, List]
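Because tokens are only ever appended on the right, the single parent can be recovered by removing the last token. The sketch below illustrates the idea with a hypothetical function `single_parent`. It assumes padding index 0, an EOS action of (-1,), and that a done state has itself as parent via EOS. None of these details are confirmed by this page.

```python
from typing import List, Sequence, Tuple


def single_parent(
    state: Sequence[int], done: bool, eos_action: Tuple[int] = (-1,)
) -> Tuple[List[List[int]], List[Tuple[int]]]:
    """Hypothetical sketch: return ([parent], [action]) for a left-to-right sequence."""
    if done:
        # Assumption: a finished state was reached from itself via EOS.
        return [list(state)], [eos_action]
    length = sum(1 for idx in state if idx != 0)
    if length == 0:
        # The empty sequence (source state) has no parent.
        return [], []
    parent = list(state)
    last = parent[length - 1]
    parent[length - 1] = 0  # remove the last token by restoring padding
    return [parent], [(last,)]


print(single_parent([1, 2, 1, 1, 0], done=False))  # ([[1, 2, 1, 0, 0]], [(1,)])
```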
- step(action, skip_mask_check=False)[source]
Execute a step for the given action.
- Parameters:
action (tuple) – Action to be executed. An action is represented by a single-element tuple indicating the index of the token to be added to the current sequence (state).
skip_mask_check (bool) – If True, skip computing the forward mask of invalid actions, which is otherwise used to check that the action is valid.
- Returns:
self.state (list) – The sequence after executing the action.
action (tuple) – The action executed.
valid (bool) – False if the action is not allowed for the current state; True otherwise.
- Return type:
Tuple[torchtyping.TensorType[max_length], Tuple[int], bool]
- states2proxy(states)[source]
Prepare a batch of states for a proxy.
States are represented by the tokens instead of the indices, with padding up to the max_length.
Important: by default, the output of states2proxy() is a list of lists, instead of a tensor as in most environments. This is to allow for string tokens.
- Example, with max_length = 5:
Sequence (tokens): 0100
state: [1, 2, 1, 1, 0]
proxy format: [0, 1, 0, 0, -1]
- Parameters:
states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.
- Returns:
A list containing all the states in the batch, represented themselves as lists.
- Return type:
List[List]
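The index-to-token mapping from the example above can be sketched in plain Python. The function `states_to_proxy` is a hypothetical stand-in, not the method itself. It assumes index 0 maps to the pad token and index i maps to the (i-1)-th token of the vocabulary.

```python
from typing import Any, List, Sequence


def states_to_proxy(
    states: Sequence[Sequence[int]], tokens: Sequence[Any], pad_token: Any = -1
) -> List[List[Any]]:
    """Hypothetical sketch: replace each index by its token, padding included."""
    # Assumption: index 0 is padding; token indices start at 1.
    lookup = {0: pad_token}
    lookup.update({i + 1: token for i, token in enumerate(tokens)})
    return [[lookup[int(idx)] for idx in state] for state in states]


print(states_to_proxy([[1, 2, 1, 1, 0]], tokens=[0, 1]))  # [[0, 1, 0, 0, -1]]
```

Returning nested lists rather than a tensor is what lets the same code handle string vocabularies, for example `states_to_proxy([[1, 2, 0]], tokens=["A", "B"], pad_token="-")`.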
- states2policy(states)[source]
Prepare a batch of states for the policy model.
States are one-hot encoded.
- Example, with max_length = 5:
Sequence (tokens): 0100
state: [1, 2, 1, 1, 0]
policy format: [0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0]
token encoded by each group of three entries: 0 | 1 | 0 | 0 | PAD
- Parameters:
states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.
- Returns:
A tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[batch, policy_input_dim]
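The one-hot layout in the example can be reproduced with a short sketch. The function `states_to_policy` below is hypothetical and returns nested lists instead of a tensor to stay dependency-free. It assumes each position is encoded over n_tokens + 1 classes, with class 0 standing for padding.

```python
from typing import List, Sequence


def states_to_policy(states: Sequence[Sequence[int]], n_tokens: int) -> List[List[int]]:
    """Hypothetical sketch: flatten per-position one-hot vectors for each state."""
    n_classes = n_tokens + 1  # assumption: padding (index 0) gets its own class
    batch = []
    for state in states:
        row: List[int] = []
        for idx in state:
            one_hot = [0] * n_classes
            one_hot[idx] = 1
            row.extend(one_hot)
        batch.append(row)
    return batch


print(states_to_policy([[1, 2, 1, 1, 0]], n_tokens=2))
# [[0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0]]
```

Under these assumptions each row has max_length * (n_tokens + 1) entries, which is 15 for the example.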
- state2readable(state=None)[source]
Convert a state into a human-readable string.
- Example, with max_length = 5:
state: [1, 2, 1, 1, 0]
readable: “0 1 0 0”
The output string contains the token corresponding to each index in the state, separated by spaces.
- Parameters:
state (Optional[torchtyping.TensorType[max_length]]) – A state in environment format. If None, self.state is used.
- Returns:
A string of space-separated tokens.
- Return type:
str
- readable2state(readable)[source]
Convert a readable state into environment format.
- Example, with max_length = 5:
readable: “0 1 0 0”
state: [1, 2, 1, 1, 0]
- Parameters:
readable (str) – A state in readable format, that is, a string of space-separated tokens.
- Returns:
A tensor containing the indices of the tokens.
- Return type:
torchtyping.TensorType[max_length]
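The two conversions, state2readable() and readable2state(), are inverses of each other. The following sketch shows the round trip with hypothetical stand-in functions that work on plain lists rather than tensors. It assumes padding index 0, token indices starting at 1, and that padding is omitted from the readable string.

```python
from typing import Any, List, Sequence


def state_to_readable(state: Sequence[int], tokens: Sequence[Any]) -> str:
    """Hypothetical sketch: join the tokens of all non-padding positions."""
    return " ".join(str(tokens[idx - 1]) for idx in state if idx != 0)


def readable_to_state(readable: str, tokens: Sequence[Any], max_length: int) -> List[int]:
    """Hypothetical sketch: map tokens back to indices and pad with 0."""
    index_of = {str(token): i + 1 for i, token in enumerate(tokens)}
    indices = [index_of[token] for token in readable.split()]
    return indices + [0] * (max_length - len(indices))


readable = state_to_readable([1, 2, 1, 1, 0], tokens=[0, 1])
print(readable)  # 0 1 0 0
print(readable_to_state(readable, tokens=[0, 1], max_length=5))  # [1, 2, 1, 1, 0]
```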