gflownet.envs.choice

A very simple environment to sample one element from a given set of options.

Given a set of options, a trajectory selects exactly one option from the source state. After that, only the end-of-sequence (EOS) action is valid.

Classes

Choice

Initializes a Choice environment.

Module Contents

class gflownet.envs.choice.Choice(options=None, n_options=3, source_readable='<source>', options_available=None, **kwargs)

Bases: gflownet.envs.base.GFlowNetEnv

Initializes a Choice environment.

Parameters:
  • options (iterable (optional)) – The description of the options. If None, the options are simply described by their indices. In this case, n_options must not be None.

  • n_options (int) – The number of options, if options is None. Ignored otherwise.

  • source_readable (str) – The string to be used to represent the source state as a human-readable string. By default: <source>

  • options_available (iterable (optional)) – A subset of the options to restrict the available options for the environment instance. The elements of the iterable are integers referring to the option indices.

source_readable = '<source>'
options = None
n_options
options_indices
source = [0]
eos
get_action_space()

Constructs a list with all possible actions, including EOS.

Each action is a tuple with a single element: the index of the option to be selected, starting from 1. The end-of-sequence action is (-1,).

Returns:

list – A list of tuples representing the actions.

Return type:

List[Tuple]
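As a rough illustration of the action format described above, the following standalone sketch builds such a list for a given number of options. It does not use the library. The helper name build_action_space is hypothetical, and placing EOS last is an assumption.

```python
from typing import List, Tuple

EOS = (-1,)  # end-of-sequence action, as documented above


def build_action_space(n_options: int) -> List[Tuple[int]]:
    """Return one single-element action per option (indices start at 1), then EOS."""
    # Assumption: EOS is appended after the option actions.
    return [(idx,) for idx in range(1, n_options + 1)] + [EOS]


print(build_action_space(3))  # [(1,), (2,), (3,), (-1,)]
```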

get_mask_invalid_actions_forward(state=None, done=None)

Returns a mask indicating which actions are invalid (True) and which are valid (False).

Parameters:
  • state (list) – Input state. If None, self.state is used.

  • done (bool) – Whether the trajectory is done. If None, self.done is used.

Returns:

A list of boolean values.

Return type:

List[bool]
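The masking rule implied by the module description (options are selectable only from the source state, and afterwards only EOS is valid) could be sketched as below. This is a standalone approximation, not the library code. The function name is hypothetical, and the mask order (options first, EOS last) and the 1-based indices in options_available are assumptions.

```python
from typing import Iterable, List, Optional

SOURCE = [0]  # source state, as documented above


def mask_invalid_forward(
    state: List[int],
    done: bool,
    n_options: int,
    options_available: Optional[Iterable[int]] = None,
) -> List[bool]:
    """Return True for invalid actions and False for valid ones."""
    if done:
        # Nothing can follow a finished trajectory.
        return [True] * (n_options + 1)
    if state == SOURCE:
        # Assumption: options_available holds 1-based option indices.
        all_options = range(1, n_options + 1)
        allowed = set(all_options if options_available is None else options_available)
        # Each option is masked individually; EOS (last entry) is invalid here.
        return [idx not in allowed for idx in all_options] + [True]
    # An option is already selected: only EOS is valid.
    return [True] * n_options + [False]


print(mask_invalid_forward([0], False, 3))  # [False, False, False, True]
print(mask_invalid_forward([2], False, 3))  # [True, True, True, False]
```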

get_parents(state=None, done=None, action=None)

Determines all parents and actions that lead to state.

There are only three types of states:
  • Done trajectories: the only parent is the state itself, with action EOS.

  • Source state: no parents.

  • Option selected: the only parent is the source state, with the action that selects that option.

Parameters:
  • state (list) – Input state. If None, self.state is used.

  • done (bool) – Whether the trajectory is done. If None, self.done is used.

  • action (None) – Ignored

Returns:

  • parents (list) – List of parents in state format.

  • actions (list) – List of actions that lead to state for each parent in parents.

Return type:

Tuple[List, List]
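The three cases listed above can be written out as a small standalone function. This is only a sketch of the documented logic. The name parents_of is hypothetical, and it assumes that a state holding option index i is reached from the source by action (i,).

```python
from typing import List, Tuple

SOURCE = [0]  # source state
EOS = (-1,)   # end-of-sequence action


def parents_of(state: List[int], done: bool) -> Tuple[List[List[int]], List[Tuple[int]]]:
    """Return (parents, actions) for a state, following the three cases."""
    if done:
        # Done trajectory: the parent is the state itself, reached via EOS.
        return [list(state)], [EOS]
    if state == SOURCE:
        # The source state has no parents.
        return [], []
    # Option selected: the only parent is the source state.
    return [list(SOURCE)], [(state[0],)]


print(parents_of([2], done=False))  # ([[0]], [(2,)])
```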

step(action, skip_mask_check=False)

Executes step given an action.

Parameters:
  • action (tuple) – Action to be executed. An action is a tuple with a single element indicating the index of the option to be set.

  • skip_mask_check (bool) – If True, skip computing forward mask of invalid actions to check if the action is valid.

Returns:

  • self.state (list) – The state after executing the action.

  • action (tuple) – Action executed.

  • valid (bool) – False if the action is not allowed for the current state; True otherwise.

Return type:

[List[int], Tuple[int], bool]
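To make the step semantics concrete, here is a toy stand-in that mimics the documented return values (state, action, valid). The class ToyChoice is hypothetical and unrelated to the real implementation. It omits masks, options_available and skip_mask_check, and assumes an invalid action leaves the state unchanged.

```python
from typing import List, Tuple


class ToyChoice:
    """Toy stand-in for a one-choice environment (not the library class)."""

    def __init__(self, n_options: int = 3):
        self.n_options = n_options
        self.source = [0]
        self.eos = (-1,)
        self.state = list(self.source)
        self.done = False

    def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int], bool]:
        if self.done:
            # No action is valid once the trajectory is done.
            return self.state, action, False
        if self.state == self.source:
            # From the source, only option actions (1..n_options) are valid.
            valid = 1 <= action[0] <= self.n_options
            if valid:
                self.state = [action[0]]
            return self.state, action, valid
        # An option is already selected: only EOS is valid.
        if action == self.eos:
            self.done = True
            return self.state, action, True
        return self.state, action, False


env = ToyChoice(n_options=3)
print(env.step((2,)))   # ([2], (2,), True)
print(env.step((-1,)))  # ([2], (-1,), True)
```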

set_available_options(options)

Updates the attribute options_available.

Parameters:

options (Iterable)

states2policy(states)

Prepares a batch of states in “environment format” for the policy model: states are one-hot encoded.

Parameters:

states (list) – A batch of states in environment format

Returns:

A tensor containing all the states in the batch.

Return type:

torchtyping.TensorType[batch, policy_input_dim]
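A one-hot encoding of single-element states can be illustrated with plain Python lists. The real method returns a tensor. The helper name one_hot_states and the encoding width of n_options + 1 (one column per possible state value, source included) are assumptions made for this sketch.

```python
from typing import List


def one_hot_states(states: List[List[int]], n_options: int) -> List[List[float]]:
    """One-hot encode a batch of single-element states as nested lists."""
    # Assumption: column 0 encodes the source state, column i encodes option i.
    width = n_options + 1
    return [
        [1.0 if col == state[0] else 0.0 for col in range(width)]
        for state in states
    ]


print(one_hot_states([[0], [2]], n_options=3))
# [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
```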

readable2state(readable, alphabet={})

Converts a human-readable string representing a state into a state in environment format (a list).

state2readable(state=None, alphabet={})

Converts a state into a human-readable string.

The readable representation is taken from self.options, except if the state is the source state in which case self.source_readable is returned.

Parameters:

state (Optional[List])
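The conversion between states and readable strings might look like the following standalone sketch. Both function names are hypothetical. It assumes that a state [i] refers to the i-th option (1-based) and that option descriptions are unique once converted to strings.

```python
from typing import List, Sequence


def state_to_readable(
    state: List[int], options: Sequence, source_readable: str = "<source>"
) -> str:
    """Return the option description, or source_readable for the source state."""
    if state == [0]:
        return source_readable
    # Assumption: option indices in states start at 1.
    return str(options[state[0] - 1])


def readable_to_state(
    readable: str, options: Sequence, source_readable: str = "<source>"
) -> List[int]:
    """Inverse of state_to_readable under the same assumptions."""
    if readable == source_readable:
        return [0]
    return [[str(option) for option in options].index(readable) + 1]


colors = ["red", "green", "blue"]
print(state_to_readable([2], colors))     # green
print(readable_to_state("blue", colors))  # [3]
```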

get_all_terminating_states()

Returns a list with all the terminating states in environment format.

Returns:

list – The list of all terminating states.

Return type:

List[List[int]]
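Because every terminating state holds exactly one selected option, the full list is small. A minimal sketch, assuming states store 1-based option indices and ignoring any restriction from options_available (the helper name is hypothetical):

```python
from typing import List


def all_terminating_states(n_options: int) -> List[List[int]]:
    """Return one terminating state per option, in index order."""
    return [[idx] for idx in range(1, n_options + 1)]


print(all_terminating_states(3))  # [[1], [2], [3]]
```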

get_uniform_terminating_states(n_states, seed=None)

Constructs a batch of n_states terminating states sampled uniformly from the sample space of the environment.

Parameters:
  • n_states (int) – The number of states to sample.

  • seed (int) – Random seed.

Return type:

List[List[int]]
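Uniform sampling of terminating states can be approximated with the standard library, as in the sketch below. The helper name is hypothetical. It assumes sampling with replacement over option indices 1..n_options and uses a local random.Random instance, which may differ from how the library seeds its generator.

```python
import random
from typing import List, Optional


def uniform_terminating_states(
    n_states: int, n_options: int, seed: Optional[int] = None
) -> List[List[int]]:
    """Sample n_states terminating states uniformly, with replacement."""
    rng = random.Random(seed)  # local generator, so global state is untouched
    # randint is inclusive on both ends, matching 1-based option indices.
    return [[rng.randint(1, n_options)] for _ in range(n_states)]


batch = uniform_terminating_states(5, n_options=3, seed=0)
print(len(batch))  # 5
```

Passing the same seed twice yields the same batch, which is convenient for reproducible evaluation.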