# Source code for base

"""Represent sequence-like environments.

Sequences are constructed by adding tokens from a dictionary, from left to
right.
"""

import itertools
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torchtyping import TensorType

from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import copy, set_device, tlong


class SequenceBase(GFlowNetEnv):
    """
    Represent sequence environments built one token at a time.

    By default, for illustration purposes, this class is functional and
    represents binary sequences of 0s and 1s that can be padded with the
    special token [PAD] and terminated by a special EOS action without
    appending any token to the state.
    """

    def __init__(
        self,
        tokens: Iterable = (0, 1),
        min_length: int = 1,
        max_length: int = 5,
        pad_token: Union[int, float, str] = -1,
        **kwargs,
    ):
        """
        Initialize a sequence environment.

        Parameters
        ----------
        tokens : Iterable
            Vocabulary of tokens used to build sequences. Any iterable is
            accepted (including one-shot iterators); it is materialised into a
            tuple before being validated. Tokens must be unique and hashable.
        min_length : int
            Minimum valid sequence length before the EOS action is allowed.
        max_length : int
            Maximum sequence length.
        pad_token : int, float, str
            Token used to pad incomplete sequences. It must not be one of the
            regular tokens and must have the same type as them.
        **kwargs
            Additional keyword arguments forwarded to :class:`GFlowNetEnv`.
            The key ``device`` is required because it is read here.

        Raises
        ------
        AssertionError
            If the length bounds are not positive, if ``min_length`` exceeds
            ``max_length`` or if ``tokens`` contains duplicates.
        ValueError
            If ``pad_token`` is one of the regular tokens, or if the tokens
            (padding token included) are not all of the same type.
        KeyError
            If ``device`` is not among the keyword arguments.
        """
        assert max_length > 0
        assert min_length > 0
        assert min_length <= max_length
        # Materialise the vocabulary once: the checks below need len() and
        # several passes over the tokens, which a generator cannot provide.
        tokens = tuple(tokens)
        assert len(set(tokens)) == len(tokens)
        # Make sure that padding token is not one of the regular tokens
        if pad_token in tokens:
            raise ValueError(
                f"The padding token ({pad_token}) cannot be one of the regular tokens."
            )
        # Make sure that all tokens are the same type
        token_types = {type(token) for token in (*tokens, pad_token)}
        if len(token_types) != 1:
            raise ValueError(
                "All tokens must be the same type, but more than one type was found."
            )
        # Set device because it is needed in the init
        self.device = set_device(kwargs["device"])
        # Main attributes
        self.tokens = tokens
        self.pad_token = pad_token
        self.n_tokens = len(self.tokens)
        self.min_length = min_length
        self.max_length = max_length
        # Index used by the end-of-sequence action
        self.eos_idx = -1
        # Index that represents padding in a state
        self.pad_idx = 0
        # Type shared by all tokens; used to parse tokens from readable strings
        self.dtype = type(pad_token)
        # Dictionaries: regular tokens are indexed from 1 because index 0
        # (pad_idx) is reserved for the padding token
        self.idx2token = {idx + 1: token for idx, token in enumerate(self.tokens)}
        self.idx2token[self.pad_idx] = pad_token
        self.token2idx = {token: idx for idx, token in self.idx2token.items()}
        # Source state: vector of length max_length filled with pad token
        self.source = tlong(
            torch.full((self.max_length,), self.pad_idx), device=self.device
        )
        # End-of-sequence action
        self.eos = (self.eos_idx,)
        # Base class init
        super().__init__(**kwargs)
def get_action_space(self) -> List[Tuple]:
    """
    Build the list of every action of the environment, EOS included.

    Each action is a single-element tuple holding the index of the token that
    the action appends to the current sequence (state). Token actions come
    first, in the order of ``self.tokens``, and the EOS action comes last.

    With the default binary vocabulary the action space is:

    action_space: [(1,), (2,), (-1,)]

    Returns
    -------
    A list of single-element tuples.
    """
    indices = [self.token2idx[token] for token in self.tokens]
    # The end-of-sequence action always closes the action space
    indices.append(self.eos_idx)
    return [(index,) for index in indices]
def _get_max_trajectory_length(self) -> int: """Return the maximum trajectory length, including the EOS action.""" return self.max_length + 1
def get_mask_invalid_actions_forward(
    self,
    state: Optional[TensorType["max_length"]] = None,  # noqa: F821
    done: Optional[bool] = None,
) -> List[bool]:
    """
    Compute which forward actions are invalid from a state.

    The mask has one entry per action of the action space: True marks an
    invalid action, False a valid one. Four situations are distinguished:

    - The trajectory is done: every action is invalid.
    - The sequence is shorter than ``self.min_length``: only EOS is invalid.
    - The sequence has room left (its last position is padding): every action
      is valid.
    - The sequence is full: only EOS is valid.

    Parameters
    ----------
    state : tensor
        Input state. If None, self.state is used.
    done : bool
        Whether the trajectory is done. If None, self.done is used.

    Returns
    -------
    A list of boolean values.
    """
    state = self._get_state(state)
    done = self._get_done(done)
    n_actions = self.action_space_dim

    if done:
        return [True] * n_actions

    too_short = self._get_seq_length(state) < self.min_length
    # Long enough and not yet full: nothing is forbidden
    if not too_short and state[-1] == self.pad_idx:
        return [False] * n_actions

    eos_position = self.action_space.index(self.eos)
    if too_short:
        # EOS is the only invalid action before reaching the minimum length
        return [position == eos_position for position in range(n_actions)]
    # The sequence is full: EOS is the only valid action
    return [position != eos_position for position in range(n_actions)]
def get_parents(
    self,
    state: Optional[TensorType["max_length"]] = None,  # noqa: F821
    done: Optional[bool] = None,
    action: Optional[Tuple] = None,
) -> Tuple[List, List]:
    """
    Return the parents of a state and the actions leading from them to it.

    Sequences grow from left to right, so the graph is a tree and a state has
    at most one parent:

    - A done state has itself as parent, reached through the EOS action.
    - The source state has no parents.
    - Any other state has as parent the same sequence with its last token
      replaced by padding, and the action is the index of that token.

    Parameters
    ----------
    state : tensor
        Input state. If None, self.state is used.
    done : bool
        Whether the trajectory is done. If None, self.done is used.
    action : None
        Ignored

    Returns
    -------
    parents : list
        List of parents in state format, with at most one element.
    actions : list
        List of actions that lead to state for each parent in parents.
    """
    state = self._get_state(state)
    done = self._get_done(done)

    if done:
        return [state], [self.eos]
    if self.equal(state, self.source):
        return [], []

    # Position of the right-most token, which is the one added last
    last_position = self._get_seq_length(state) - 1
    last_token_action = (int(state[last_position]),)
    # Work on a copy so that the input state is left untouched
    parent = copy(state)
    parent[last_position] = self.pad_idx
    return [parent], [last_token_action]
def step(
    self, action: Tuple[int], skip_mask_check: bool = False
) -> Tuple[TensorType["max_length"], Tuple[int], bool]:  # noqa: F821
    """
    Apply an action to the current state of the environment.

    The EOS action marks the trajectory as done and leaves the state as is.
    Any other action writes its token index into the first free position of
    the sequence.

    Parameters
    ----------
    action : tuple
        Action to be executed. An action is represented by a single-element
        tuple indicating the index of the token to be added to the current
        sequence (state).
    skip_mask_check : bool
        If True, skip computing forward mask of invalid actions to check if the
        action is valid. The check is also skipped if self.skip_mask_check is
        True.

    Returns
    -------
    self.state : tensor
        The sequence after executing the action
    action : tuple
        Action executed
    valid : bool
        False, if the action is not allowed for the current state.
    """
    # Generic pre-step checks
    do_step, self.state, action = self._pre_step(
        action, skip_mask_check or self.skip_mask_check
    )
    if not do_step:
        return self.state, action, False

    self.n_actions += 1
    if action == self.eos:
        # EOS only closes the trajectory
        self.done = True
    else:
        # The current sequence length is the first position holding padding
        self.state[self._get_seq_length()] = action[0]
    return self.state, action, True
def states2proxy(
    self,
    states: Union[
        List[TensorType["max_length"]],  # noqa: F821
        TensorType["batch", "max_length"],  # noqa: F821
    ],
) -> List[List]:
    """
    Convert a batch of states into the format expected by a proxy.

    Every index of a state is replaced by its token, so padding positions
    become the padding token and each output has length max_length.

    Important: by default, the output of states2proxy() is a list of lists,
    instead of a tensor as in most environments. This is to allow for string
    tokens.

    Example, with max_length = 5:
      - Sequence (tokens): 0100
      - state: [1, 2, 1, 1, 0]
      - proxy format: [0, 1, 0, 0, -1]

    Parameters
    ----------
    states : list or tensor
        A batch of states in environment format, either as a list of states or
        as a single tensor.

    Returns
    -------
    A list containing all the states in the batch, represented themselves as
    lists.
    """
    index_rows = tlong(states, device=self.device).tolist()
    return [[self.idx2token[index] for index in row] for row in index_rows]
def states2policy(
    self,
    states: Union[
        List[TensorType["max_length"]],  # noqa: F821
        TensorType["batch", "max_length"],  # noqa: F821
    ],
) -> TensorType["batch", "policy_input_dim"]:  # noqa: F821
    """
    Prepare a batch of states for the policy model.

    Each position of a state is one-hot encoded over n_tokens + 1 classes (the
    regular tokens plus padding, which has index 0) and the encodings of all
    positions are concatenated.

    Example, with max_length = 5:
      - Sequence (tokens): 0100
      - state: [1, 2, 1, 1, 0]
      - policy format: [0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0]
                       |   0   |   1   |   0   |   0   |  PAD  |

    Parameters
    ----------
    states : list or tensor
        A batch of states in environment format, either as a list of states or
        as a single tensor.

    Returns
    -------
    A tensor containing all the states in the batch, of type self.float.
    """
    states = tlong(states, device=self.device)
    one_hot = F.one_hot(states, self.n_tokens + 1)
    # flatten(start_dim=1) merges the position and class dimensions like
    # reshape(batch, -1) would, but it also works for a batch with zero rows,
    # for which the -1 dimension of reshape is ambiguous and raises an error.
    return one_hot.flatten(start_dim=1).to(self.float)
def state2readable(
    self, state: Optional[TensorType["max_length"]] = None  # noqa: F821
) -> str:
    """
    Convert a state into a human-readable string.

    Padding is dropped and each remaining index is replaced by the string
    form of its token. Tokens are separated by single spaces, so a state made
    only of padding yields an empty string.

    Example, with max_length = 5:
      - state: [1, 2, 1, 1, 0]
      - readable: "0 1 0 0"

    Parameters
    ----------
    state : tensor
        A state in environment format. If None, self.state is used.

    Returns
    -------
    A string of space-separated tokens.
    """
    indices = self._unpad(self._get_state(state).tolist())
    return " ".join(str(self.idx2token[index]) for index in indices)
def readable2state(self, readable: str) -> TensorType["max_length"]:  # noqa: F821
    """
    Convert a readable string back into a state in environment format.

    The string is split on single spaces, every piece is cast to the type of
    the tokens (self.dtype) and mapped to its index, and the resulting
    sequence is padded up to max_length. An empty string maps to a copy of
    the source state.

    Example, with max_length = 5:
      - readable: "0 1 0 0"
      - state: [1, 2, 1, 1, 0]

    Parameters
    ----------
    readable : str
        A state in readable format - space-separated tokens.

    Returns
    -------
    A tensor containing the indices of the tokens.
    """
    if readable == "":
        return copy(self.source)
    # Readable tokens are strings: cast them back to the vocabulary type
    tokens = map(self.dtype, readable.split(" "))
    indices = [self.token2idx[token] for token in tokens]
    return tlong(self._pad(indices), device=self.device)
def get_all_terminating_states(
    self,
) -> List[TensorType["max_length"]]:  # noqa: F821
    """
    Construct a batch with all terminating states in the sample space.

    For every length between min_length and max_length (both included), all
    combinations of non-padding token indices are enumerated and padded on
    the right up to max_length.

    Returns
    -------
    A list with one tensor per terminating state.
    """
    token_indices = set(self.idx2token.keys())
    token_indices.remove(self.pad_idx)

    batches = []
    for length in range(self.min_length, self.max_length + 1):
        # All sequences with exactly `length` tokens
        sequences = tlong(
            list(itertools.product(token_indices, repeat=length)),
            device=self.device,
        )
        # Right padding, with the same dtype and device as the sequences
        padding = torch.full(
            (sequences.shape[0], self.max_length - length),
            self.pad_idx,
        ).to(sequences)
        batches.append(torch.cat((sequences, padding), dim=1))

    # TODO: this is very inefficient but currently this method has to return a list
    # of states in the GFlowNet format.
    rows = torch.cat(batches).tolist()
    return [tlong(row, device=self.device) for row in rows]
def get_uniform_terminating_states(
    self, n_states: int, seed: int = None
) -> List[TensorType["max_length"]]:  # noqa: F821
    """
    Construct a batch of states sampled uniformly from the sample space.

    The length of each sequence is drawn with probability proportional to the
    number of sequences of that length (n_tokens ** length). Token indices
    are then drawn uniformly for all positions and the positions beyond the
    sampled length are overwritten with padding.

    Parameters
    ----------
    n_states : int
        The number of states to sample.
    seed : int
        Random seed. If None, self._get_generator() returns None and no local
        generator is passed to the sampling functions.

    Returns
    -------
    A list with one tensor per sampled state.
    """
    generator = self._get_generator(seed)

    # Number of distinct sequences for each valid length
    valid_lengths = range(self.min_length, self.max_length + 1)
    weights = torch.tensor(
        [self.n_tokens**length for length in valid_lengths],
        dtype=torch.float64,
    )
    # multinomial returns positions into `weights`: shift them to lengths
    lengths = torch.multinomial(
        weights,
        num_samples=n_states,
        replacement=True,
        generator=generator,
    )
    lengths += self.min_length

    # Token indices start at 1 because index 0 is the padding
    samples = torch.randint(
        low=1,
        high=self.n_tokens + 1,
        size=(n_states, self.max_length),
        generator=generator,
    )
    # Each row is a view of samples, so the assignment pads it in place
    for row, length in zip(samples, lengths.tolist()):
        row[length:] = self.pad_idx

    # TODO: this is very inefficient but currently this method has to return a list
    # of states in the GFlowNet format.
    return [tlong(row, device=self.device) for row in samples.tolist()]
def _get_generator(self, seed: Optional[int] = None) -> Optional[torch.Generator]: """Return a local random number generator for the given seed.""" if seed is None: return None generator = torch.Generator() generator.manual_seed(seed) return generator def _pad(self, seq_list: list): """ Pad a sequence represented as a list of indices. Parameters ---------- seq_list : list The input sequence. A list containing a list of indices. Returns ------- The input list padded by the end with self.pad_idx. """ return seq_list + [self.pad_idx] * (self.max_length - len(seq_list)) def _unpad(self, seq_list: list): """ Remove trailing padding from a sequence represented as a list of indices. Parameters ---------- seq_list : list The input sequence. A list containing a list of indices, including possibly padding indices. Returns ------- The input list padded by the end with self.pad_idx. """ if self.pad_idx not in seq_list: return seq_list return seq_list[: seq_list.index(self.pad_idx)] def _get_seq_length( self, state: Optional[TensorType["max_length"]] = None # noqa: F821 ): """ Return the effective length of a state, ignoring padding. Parameters ---------- state : tensor The input sequence. If None, self.state is used. Returns ------- Single element int tensor. """ state = self._get_state(state) if state[-1] == self.pad_idx: return torch.where(state == self.pad_idx)[0][0] else: return len(state)