base

Represent sequence-like environments.

Sequences are constructed by adding tokens from a dictionary, from left to right.

Classes

SequenceBase

Initialize a sequence environment.

Module Contents

class base.SequenceBase(tokens=[0, 1], min_length=1, max_length=5, pad_token=-1, **kwargs)[source]

Bases: gflownet.envs.base.GFlowNetEnv

Initialize a sequence environment.

Parameters:
  • tokens (Iterable) – Vocabulary of tokens used to build sequences.

  • min_length (int) – Minimum sequence length required before the EOS action is allowed.

  • max_length (int) – Maximum sequence length.

  • pad_token (int, float, str) – Token used to pad incomplete sequences.

  • **kwargs – Additional keyword arguments forwarded to GFlowNetEnv.

device[source]
tokens = (0, 1)[source]
pad_token = -1[source]
n_tokens = 2[source]
min_length = 1[source]
max_length = 5[source]
eos_idx = -1[source]
pad_idx = 0[source]
dtype[source]
idx2token[source]
token2idx[source]
source[source]
eos[source]
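
The default attribute values above (n_tokens = 2, pad_idx = 0, eos_idx = -1) suggest how tokens relate to indices. The following is a minimal stand-alone sketch of that mapping. It does not import gflownet, and it assumes that token indices start at 1 because index 0 is reserved for padding, and that the padding token is part of both dictionaries. The real idx2token and token2idx attributes may be built differently.

```python
# Hypothetical stand-alone sketch; it does not use the gflownet package.
tokens = (0, 1)
pad_token = -1
pad_idx = 0   # index reserved for padding
eos_idx = -1  # index carried by the EOS action

# Assumption: token indices start at 1, since index 0 is taken by padding.
idx2token = {idx + 1: token for idx, token in enumerate(tokens)}
idx2token[pad_idx] = pad_token  # assumption: padding is included in the map
token2idx = {token: idx for idx, token in idx2token.items()}

n_tokens = len(tokens)
print(idx2token)  # {1: 0, 2: 1, 0: -1}
print(token2idx)  # {0: 1, 1: 2, -1: 0}
```

With this mapping, the token sequence 0100 corresponds to the indices [1, 2, 1, 1], which agrees with the state examples given for the conversion methods in this module.
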
get_action_space()[source]

Construct the list of all possible actions, including EOS.

An action is represented by a single-element tuple indicating the index of the token to be added to the current sequence (state).

The action space of this parent class is:

action_space: [(1,), (2,), (-1,)]

Return type:

List[Tuple]
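
As an illustration, the action space shown above can be reproduced with a few lines of plain Python. This is a sketch, not the library code: build_action_space is a hypothetical helper, and it assumes that token actions use indices 1 to n_tokens and that the EOS action comes last.

```python
from typing import List, Tuple


def build_action_space(n_tokens: int, eos_idx: int = -1) -> List[Tuple[int]]:
    """Hypothetical helper: one single-element tuple per token, then EOS."""
    # Assumption: token indices run from 1 to n_tokens (0 is the padding index).
    actions = [(idx,) for idx in range(1, n_tokens + 1)]
    # The EOS action is appended last.
    actions.append((eos_idx,))
    return actions


action_space = build_action_space(n_tokens=2)
print(action_space)  # [(1,), (2,), (-1,)]
```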

get_mask_invalid_actions_forward(state=None, done=None)[source]

Return the mask of invalid forward actions.

The returned list has one entry per action:
  • True if the forward action is invalid from the current state.

  • False otherwise.

Parameters:
  • state (tensor) – Input state. If None, self.state is used.

  • done (bool) – Whether the trajectory is done. If None, self.done is used.

Returns:

A list of boolean values.

Return type:

List[bool]
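
The sketch below illustrates one plausible way to compute such a mask with plain lists. The rules are assumptions inferred from the parameter descriptions of the class: every action is invalid once the trajectory is done, token actions become invalid when the sequence reaches max_length, and EOS is invalid while the sequence is shorter than min_length. mask_invalid_forward is a hypothetical function, and the library may apply additional or different rules.

```python
from typing import List, Sequence


def mask_invalid_forward(
    state: Sequence[int],
    done: bool,
    n_tokens: int = 2,
    min_length: int = 1,
    max_length: int = 5,
    pad_idx: int = 0,
) -> List[bool]:
    """Hypothetical sketch: True marks an invalid action, EOS is the last entry."""
    n_actions = n_tokens + 1
    if done:
        # Assumption: no action is valid once the trajectory is done.
        return [True] * n_actions
    # Current sequence length: number of non-padding positions.
    length = sum(1 for idx in state if idx != pad_idx)
    tokens_invalid = length >= max_length  # no room left for another token
    eos_invalid = length < min_length      # sequence still too short to end
    return [tokens_invalid] * n_tokens + [eos_invalid]


print(mask_invalid_forward([0, 0, 0, 0, 0], done=False))  # [False, False, True]
print(mask_invalid_forward([1, 2, 1, 1, 0], done=False))  # [False, False, False]
print(mask_invalid_forward([1, 2, 1, 1, 2], done=False))  # [True, True, False]
```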

get_parents(state=None, done=None, action=None)[source]

Determine all parents and actions that lead to a state.

Because tokens are only appended from left to right, the GFlowNet graph of this environment is a tree and there is only one parent per state.

Parameters:
  • state (tensor) – Input state. If None, self.state is used.

  • done (bool) – Whether the trajectory is done. If None, self.done is used.

  • action (None) – Ignored.

Returns:

  • parents (list) – List of parents in state format. This environment has a single parent per state.

  • actions (list) – List of actions that lead to state, one for each parent in parents.

Return type:

Tuple[List, List]
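
The single-parent structure can be sketched as follows, using plain lists instead of tensors. get_parent_sketch is a hypothetical function. It assumes that the parent of a state is obtained by replacing the last token with padding, that a done state has itself as parent through the EOS action, and that the source state has no parents. These conventions are assumptions, not statements about the library code.

```python
from typing import List, Sequence, Tuple


def get_parent_sketch(
    state: Sequence[int], done: bool, pad_idx: int = 0, eos_idx: int = -1
) -> Tuple[List[List[int]], List[Tuple[int]]]:
    """Hypothetical sketch: return ([parent], [action]) for a sequence state."""
    if done:
        # Assumption: a done state is reached from itself via the EOS action.
        return [list(state)], [(eos_idx,)]
    length = sum(1 for idx in state if idx != pad_idx)
    if length == 0:
        # Assumption: the source (empty) state has no parents.
        return [], []
    parent = list(state)
    last_idx = parent[length - 1]
    parent[length - 1] = pad_idx  # undo the last appended token
    return [parent], [(last_idx,)]


parents, actions = get_parent_sketch([1, 2, 1, 1, 0], done=False)
print(parents, actions)  # [[1, 2, 1, 0, 0]] [(1,)]
```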

step(action, skip_mask_check=False)[source]

Execute a step for the given action.

Parameters:
  • action (tuple) – Action to be executed. An action is represented by a single-element tuple indicating the index of the token to be added to the current sequence (state).

  • skip_mask_check (bool) – If True, skip computing the forward mask of invalid actions that is otherwise used to check whether the action is valid.

Returns:

  • self.state (tensor) – The sequence after executing the action.

  • action (tuple) – Action executed.

  • valid (bool) – False if the action is not allowed for the current state, True otherwise.

Return type:

Tuple[torchtyping.TensorType[max_length], Tuple[int], bool]
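
To make the transition concrete, here is a stand-alone sketch of a step on list-based states. step_sketch is a hypothetical function that also returns the done flag, because it has no access to self.done. The validity rules (EOS allowed only from min_length, tokens allowed only below max_length, nothing allowed after done) are assumptions inferred from the parameter descriptions.

```python
from typing import List, Sequence, Tuple


def step_sketch(
    state: Sequence[int],
    done: bool,
    action: Tuple[int],
    n_tokens: int = 2,
    min_length: int = 1,
    max_length: int = 5,
    pad_idx: int = 0,
    eos_idx: int = -1,
) -> Tuple[List[int], bool, Tuple[int], bool]:
    """Hypothetical sketch: return (new_state, new_done, action, valid)."""
    length = sum(1 for idx in state if idx != pad_idx)
    (idx,) = action
    if done:
        # Assumption: no action can be applied after EOS.
        return list(state), True, action, False
    if idx == eos_idx:
        if length < min_length:
            return list(state), False, action, False
        return list(state), True, action, True
    if idx < 1 or idx > n_tokens or length >= max_length:
        return list(state), False, action, False
    new_state = list(state)
    new_state[length] = idx  # write the token index at the first padded slot
    return new_state, False, action, True


# Build the sequence 0100 and terminate it with EOS.
state, done = [0, 0, 0, 0, 0], False
for action in [(1,), (2,), (1,), (1,), (-1,)]:
    state, done, _, valid = step_sketch(state, done, action)
print(state, done)  # [1, 2, 1, 1, 0] True
```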

states2proxy(states)[source]

Prepare a batch of states for a proxy.

States are represented by the tokens instead of the indices, with padding up to the max_length.

Important: by default, the output of states2proxy() is a list of lists, instead of a tensor as in most environments. This is to allow for string tokens.

Example, with max_length = 5:
  • Sequence (tokens): 0100

  • state: [1, 2, 1, 1, 0]

  • proxy format: [0, 1, 0, 0, -1]

Parameters:

states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.

Returns:

A list containing all the states in the batch, represented themselves as lists.

Return type:

List[List]
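
The conversion in the example above amounts to an index-to-token lookup. The sketch below reproduces it with a hypothetical stand-alone function and a hand-written idx2token dictionary. The second call uses made-up string tokens to show why a list of lists is a convenient output type.

```python
from typing import Dict, Hashable, List, Sequence


def states2proxy_sketch(
    states: Sequence[Sequence[int]], idx2token: Dict[int, Hashable]
) -> List[List]:
    """Hypothetical sketch: replace every index by its token, padding included."""
    return [[idx2token[int(idx)] for idx in state] for state in states]


# Mapping consistent with tokens = (0, 1), pad_token = -1 and pad_idx = 0.
idx2token = {0: -1, 1: 0, 2: 1}
proxy_states = states2proxy_sketch([[1, 2, 1, 1, 0]], idx2token)
print(proxy_states)  # [[0, 1, 0, 0, -1]]

# Illustrative string vocabulary (not part of the documented defaults).
idx2token_str = {0: "PAD", 1: "A", 2: "B"}
print(states2proxy_sketch([[1, 2, 1, 1, 0]], idx2token_str))
# [['A', 'B', 'A', 'A', 'PAD']]
```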

states2policy(states)[source]

Prepare a batch of states for the policy model.

States are one-hot encoded.

Example, with max_length = 5:
  • Sequence (tokens): 0100

  • state: [1, 2, 1, 1, 0]

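  • policy format: [0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0], that is, one block of three entries per position: [0, 1, 0] (token 0), [0, 0, 1] (token 1), [0, 1, 0] (token 0), [0, 1, 0] (token 0), [1, 0, 0] (PAD).

Parameters: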

states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.

Returns:

A tensor containing all the states in the batch.

Return type:

torchtyping.TensorType[batch, policy_input_dim]
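
The one-hot layout in the example can be checked with the following sketch. It uses nested Python lists rather than a tensor, and states2policy_sketch is a hypothetical function. It assumes n_tokens + 1 classes per position (padding plus one class per token), which is what the example implies.

```python
from typing import List, Sequence


def states2policy_sketch(
    states: Sequence[Sequence[int]], n_tokens: int = 2
) -> List[List[int]]:
    """Hypothetical sketch: flatten the per-position one-hot vectors of each state."""
    n_classes = n_tokens + 1  # assumption: class 0 is padding, 1..n_tokens are tokens
    batch = []
    for state in states:
        row: List[int] = []
        for idx in state:
            one_hot = [0] * n_classes
            one_hot[int(idx)] = 1
            row.extend(one_hot)
        batch.append(row)
    return batch


policy_states = states2policy_sketch([[1, 2, 1, 1, 0]])
print(policy_states)
# [[0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0]]
```

Under this assumption, each row has max_length * (n_tokens + 1) entries, which is 15 for the default values.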

state2readable(state=None)[source]

Convert a state into a human-readable string.

Example, with max_length = 5:
  • state: [1, 2, 1, 1, 0]

  • readable: “0 1 0 0”

The output string contains the token corresponding to each index in the state, separated by spaces.

Parameters:
  • state (Optional[torchtyping.TensorType[max_length]]) – A state in environment format. If None, self.state is used.

Returns:

A string of space-separated tokens.

Return type:

str
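
A minimal sketch of this conversion, written as a hypothetical stand-alone function on a list-based state. It assumes that padding positions are dropped, which is what the example output suggests.

```python
from typing import Dict, Hashable, Sequence


def state2readable_sketch(
    state: Sequence[int], idx2token: Dict[int, Hashable], pad_idx: int = 0
) -> str:
    """Hypothetical sketch: join the tokens of all non-padding positions."""
    return " ".join(
        str(idx2token[int(idx)]) for idx in state if int(idx) != pad_idx
    )


idx2token = {1: 0, 2: 1}  # mapping consistent with tokens = (0, 1)
readable = state2readable_sketch([1, 2, 1, 1, 0], idx2token)
print(readable)  # 0 1 0 0
```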

readable2state(readable)[source]

Convert a readable state into environment format.

Example, with max_length = 5:
  • readable: “0 1 0 0”

  • state: [1, 2, 1, 1, 0]

Parameters:

readable (str) – A state in readable format, that is, a string of space-separated tokens.

Returns:

A tensor containing the indices of the tokens.

Return type:

torchtyping.TensorType[max_length]
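
The inverse conversion can be sketched in the same way. readable2state_sketch is a hypothetical function that returns a list instead of a tensor. It assumes that tokens are looked up by their string form and that the result is padded with pad_idx up to max_length.

```python
from typing import Dict, List


def readable2state_sketch(
    readable: str, token2idx: Dict[str, int], max_length: int = 5, pad_idx: int = 0
) -> List[int]:
    """Hypothetical sketch: map each token to its index and pad to max_length."""
    indices = [token2idx[token] for token in readable.split()]
    return indices + [pad_idx] * (max_length - len(indices))


# Assumption: the lookup is keyed by the string form of each token.
token2idx = {"0": 1, "1": 2}
state = readable2state_sketch("0 1 0 0", token2idx)
print(state)  # [1, 2, 1, 1, 0]
```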

get_all_terminating_states()[source]

Construct a batch with all terminating states in the sample space.

Return type:

List[torchtyping.TensorType[max_length]]
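
For small vocabularies the sample space can be enumerated directly. The sketch below assumes that the terminating states are all sequences whose length lies between min_length and max_length, padded to max_length. all_terminating_states_sketch is a hypothetical function that returns lists rather than tensors, and the ordering of the real method may differ.

```python
from itertools import product
from typing import List


def all_terminating_states_sketch(
    n_tokens: int = 2, min_length: int = 1, max_length: int = 5, pad_idx: int = 0
) -> List[List[int]]:
    """Hypothetical sketch: enumerate every padded sequence of valid length."""
    states = []
    for length in range(min_length, max_length + 1):
        # Token indices run from 1 to n_tokens (0 is the padding index).
        for seq in product(range(1, n_tokens + 1), repeat=length):
            states.append(list(seq) + [pad_idx] * (max_length - length))
    return states


all_states = all_terminating_states_sketch()
print(len(all_states))  # 62
print(all_states[0])    # [1, 0, 0, 0, 0]
```

With the default values the count is 2 + 4 + 8 + 16 + 32 = 62, so full enumeration is only practical for short sequences and small vocabularies.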

get_uniform_terminating_states(n_states, seed=None)[source]

Construct a batch of states sampled uniformly from the sample space.

Parameters:
  • n_states (int) – The number of states to sample.

  • seed (int) – Random seed.

Return type:

List[torchtyping.TensorType[max_length]]
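
One way to sample uniformly from the sample space without enumerating it is to first draw a length with probability proportional to the number of sequences of that length, and then draw each token uniformly. The sketch below follows this approach with the standard random module. uniform_terminating_states_sketch is a hypothetical function, and whether the library samples in this way is not stated here.

```python
import random
from typing import List, Optional


def uniform_terminating_states_sketch(
    n_states: int,
    seed: Optional[int] = None,
    n_tokens: int = 2,
    min_length: int = 1,
    max_length: int = 5,
    pad_idx: int = 0,
) -> List[List[int]]:
    """Hypothetical sketch: sample padded sequences uniformly over all valid ones."""
    rng = random.Random(seed)
    lengths = list(range(min_length, max_length + 1))
    # There are n_tokens ** length sequences of a given length, so lengths
    # are weighted accordingly to keep the overall distribution uniform.
    weights = [n_tokens**length for length in lengths]
    states = []
    for _ in range(n_states):
        length = rng.choices(lengths, weights=weights, k=1)[0]
        seq = [rng.randint(1, n_tokens) for _ in range(length)]
        states.append(seq + [pad_idx] * (max_length - length))
    return states


samples = uniform_terminating_states_sketch(n_states=4, seed=0)
print(len(samples))  # 4
```

Passing the same seed twice yields the same batch, which is convenient for reproducible evaluation sets.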