"""Represent sequence-like environments.
Sequences are constructed by adding tokens from a dictionary, from left to
right.
"""
import itertools
from typing import Iterable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torchtyping import TensorType
from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import copy, set_device, tlong
class SequenceBase(GFlowNetEnv):
    """
    Represent sequence environments built one token at a time.

    By default, for illustration purposes, this class is functional and represents
    binary sequences of 0s and 1s that can be padded with the special token [PAD]
    and terminated by a special EOS action without appending any token to the
    state.

    States are tensors of length ``max_length`` containing token indices, filled
    from left to right. Regular tokens are mapped to the indices ``1..n_tokens`` (in
    the order given by ``tokens``), and unused positions hold ``pad_idx``. The EOS
    action is represented by the single-element tuple ``(eos_idx,)``.
    """

    # Index of the padding token in states. Index 0 is reserved for padding so that
    # regular tokens occupy 1..n_tokens. This matches the docstring examples below
    # and the one-hot encoding over n_tokens + 1 classes in states2policy().
    # Subclasses may override this class-level default.
    pad_idx: int = 0
    # Index of the end-of-sequence action. For the default binary vocabulary, the
    # action space is [(1,), (2,), (-1,)]. Subclasses may override this default.
    eos_idx: int = -1

    def __init__(
        self,
        tokens: Iterable = (0, 1),
        min_length: int = 1,
        max_length: int = 5,
        pad_token: Union[int, float, str] = -1,
        **kwargs,
    ):
        """
        Initialize a sequence environment.

        Parameters
        ----------
        tokens : Iterable
            Vocabulary of tokens used to build sequences. Tokens must be unique and
            of the same type as ``pad_token``. Any iterable is accepted; it is
            materialized into a tuple.
        min_length : int
            Minimum valid sequence length before the EOS action is allowed.
        max_length : int
            Maximum sequence length.
        pad_token : int, float, str
            Token used to pad incomplete sequences. It cannot be one of ``tokens``.
        **kwargs
            Additional keyword arguments forwarded to :class:`GFlowNetEnv`. They
            must include ``device``, which is also needed here to build the source
            state.

        Raises
        ------
        AssertionError
            If a length is not positive, if ``min_length > max_length``, or if the
            tokens are not unique.
        ValueError
            If ``pad_token`` is one of the regular tokens, or if the regular tokens
            and the padding token are not all of the same type.
        """
        # Materialize the vocabulary once. The checks below need len() and several
        # passes, which a one-shot iterable (e.g. a generator) would not support.
        tokens = tuple(tokens)
        assert max_length > 0
        assert min_length > 0
        assert min_length <= max_length
        assert len(set(tokens)) == len(tokens)
        # Make sure that padding token is not one of the regular tokens
        if pad_token in tokens:
            raise ValueError(
                f"The padding token ({pad_token}) cannot be one of the regular tokens."
            )
        # Make sure that all tokens are the same type
        token_types = {type(token) for token in set(tokens).union({pad_token})}
        if len(token_types) != 1:
            raise ValueError(
                "All tokens must be the same type, but more than one type was found."
            )
        # Set device because it is needed in the init (to build the source state)
        self.device = set_device(kwargs["device"])
        # Main attributes
        self.tokens = tokens
        self.pad_token = pad_token
        self.n_tokens = len(self.tokens)
        self.min_length = min_length
        self.max_length = max_length
        # Type shared by all tokens; readable2state() uses it to parse tokens
        self.dtype = type(pad_token)
        # Dictionaries: regular tokens get the indices 1..n_tokens, and pad_idx maps
        # to the padding token
        self.idx2token = {idx + 1: token for idx, token in enumerate(self.tokens)}
        self.idx2token[self.pad_idx] = pad_token
        self.token2idx = {token: idx for idx, token in self.idx2token.items()}
        # Source state: vector of length max_length filled with the padding index
        self.source = tlong(
            torch.full((self.max_length,), self.pad_idx), device=self.device
        )
        # End-of-sequence action
        self.eos = (self.eos_idx,)
        # Base class init
        super().__init__(**kwargs)

    def get_action_space(self) -> List[Tuple]:
        """
        Construct the list of all possible actions, including EOS.

        An action is represented by a single-element tuple indicating the index of the
        token to be added to the current sequence (state). The EOS action comes last.

        The action space of this parent class is:
            action_space: [(1,), (2,), (-1,)]
        """
        return [(self.token2idx[token],) for token in self.tokens] + [(self.eos_idx,)]

    def _get_max_trajectory_length(self) -> int:
        """Return the maximum trajectory length, including the EOS action."""
        return self.max_length + 1

    def get_mask_invalid_actions_forward(
        self,
        state: Optional[TensorType["max_length"]] = None,  # noqa: F821
        done: Optional[bool] = None,
    ) -> List[bool]:
        """
        Return the mask of invalid forward actions.

        The returned list has one entry per action:
            - True if the forward action is invalid from the current state.
            - False otherwise.

        Rules:
            - If the trajectory is done, every action is invalid.
            - Below ``min_length``, every action except EOS is valid.
            - From ``min_length`` until the sequence is full, every action is valid.
            - Once the sequence is full (no padding left), only EOS is valid.

        Parameters
        ----------
        state : tensor
            Input state. If None, self.state is used.
        done : bool
            Whether the trajectory is done. If None, self.done is used.

        Returns
        -------
        A list of boolean values.
        """
        state = self._get_state(state)
        done = self._get_done(done)
        n_actions = self.action_space_dim
        if done:
            return [True] * n_actions
        eos_pos = self.action_space.index(self.eos)
        # EOS action is invalid before attaining minimum sequence length
        if self._get_seq_length(state) < self.min_length:
            # All actions are valid except EOS
            mask = [False] * n_actions
            mask[eos_pos] = True
            return mask
        # If sequence is not at maximum length (last slot still padded), all
        # actions are valid
        if state[-1] == self.pad_idx:
            return [False] * n_actions
        # Otherwise, only EOS is valid
        mask = [True] * n_actions
        mask[eos_pos] = False
        return mask

    def get_parents(
        self,
        state: Optional[TensorType["max_length"]] = None,  # noqa: F821
        done: Optional[bool] = None,
        action: Optional[Tuple] = None,
    ) -> Tuple[List, List]:
        """
        Determine all parents and actions that lead to a state.

        The GFlowNet graph is a tree and there is only one parent per state: the
        state with its last token replaced by padding, or the state itself (via EOS)
        if the trajectory is done.

        Parameters
        ----------
        state : tensor
            Input state. If None, self.state is used.
        done : bool
            Whether the trajectory is done. If None, self.done is used.
        action : None
            Ignored

        Returns
        -------
        parents : list
            List of parents in state format. This environment has a single parent per
            state; the source state has none.
        actions : list
            List of actions that lead to state for each parent in parents. This
            environment has a single parent per state.
        """
        state = self._get_state(state)
        done = self._get_done(done)
        if done:
            return [state], [self.eos]
        if self.equal(state, self.source):
            return [], []
        pos_last_token = self._get_seq_length(state) - 1
        parent = copy(state)
        parent[pos_last_token] = self.pad_idx
        # The action that led here is the index of the removed (last) token
        p_action = (int(state[pos_last_token]),)
        return [parent], [p_action]

    def step(
        self, action: Tuple[int], skip_mask_check: bool = False
    ) -> Tuple[TensorType["max_length"], Tuple[int], bool]:  # noqa: F821
        """
        Execute a step for the given action.

        Parameters
        ----------
        action : tuple
            Action to be executed. An action is represented by a single-element tuple
            indicating the index of the token to be added to the current sequence
            (state).
        skip_mask_check : bool
            If True, skip computing forward mask of invalid actions to check if the
            action is valid.

        Returns
        -------
        self.state : tensor
            The sequence after executing the action.
        action : tuple
            Action executed.
        valid : bool
            False, if the action is not allowed for the current state.
        """
        # Generic pre-step checks
        do_step, self.state, action = self._pre_step(
            action, skip_mask_check or self.skip_mask_check
        )
        if not do_step:
            return self.state, action, False
        valid = True
        self.n_actions += 1
        # If action is EOS, set done to True and return state as is
        if action == self.eos:
            self.done = True
            return self.state, action, valid
        # Update state: write the token index into the first padded position
        self.state[self._get_seq_length()] = action[0]
        return self.state, action, valid

    def states2proxy(
        self,
        states: Union[
            List[TensorType["max_length"]],  # noqa: F821
            TensorType["batch", "max_length"],  # noqa: F821
        ],
    ) -> List[List]:
        """
        Prepare a batch of states for a proxy.

        States are represented by the tokens instead of the indices, with
        padding up to the max_length.

        Important: by default, the output of states2proxy() is a list of lists, instead
        of a tensor as in most environments. This is to allow for string tokens.

        Example, with max_length = 5:
            - Sequence (tokens): 0100
            - state: [1, 2, 1, 1, 0]
            - proxy format: [0, 1, 0, 0, -1]

        Parameters
        ----------
        states : list or tensor
            A batch of states in environment format, either as a list of states or as a
            single tensor.

        Returns
        -------
        A list containing all the states in the batch, represented themselves as lists.
        """
        states = tlong(states, device=self.device).tolist()
        return [[self.idx2token[idx] for idx in state] for state in states]

    def states2policy(
        self,
        states: Union[
            List[TensorType["max_length"]],  # noqa: F821
            TensorType["batch", "max_length"],  # noqa: F821
        ],
    ) -> TensorType["batch", "policy_input_dim"]:  # noqa: F821
        """
        Prepare a batch of states for the policy model.

        States are one-hot encoded over the n_tokens + 1 indices (padding included)
        and flattened per state.

        Example, with max_length = 5:
            - Sequence (tokens): 0100
            - state: [1, 2, 1, 1, 0]
            - policy format: [0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0]
                             |     0 |     1 |     0 |     0 |   PAD |

        Parameters
        ----------
        states : list or tensor
            A batch of states in environment format, either as a list of states or as a
            single tensor.

        Returns
        -------
        A tensor containing all the states in the batch.
        """
        states = tlong(states, device=self.device)
        return (
            F.one_hot(states, self.n_tokens + 1)
            .reshape(states.shape[0], -1)
            .to(self.float)
        )

    def state2readable(
        self, state: Optional[TensorType["max_length"]] = None  # noqa: F821
    ) -> str:
        """
        Convert a state into a human-readable string.

        Example, with max_length = 5:
            - state: [1, 2, 1, 1, 0]
            - readable: "0 1 0 0"

        The output string contains the token corresponding to each non-padding index
        in the state, separated by spaces. The source state yields an empty string.

        Parameters
        ----------
        state : tensor
            A state in environment format. If None, self.state is used.

        Returns
        -------
        A string of space-separated tokens.
        """
        state = self._get_state(state)
        indices = self._unpad(state.tolist())
        return " ".join(str(self.idx2token[idx]) for idx in indices)

    def readable2state(self, readable: str) -> TensorType["max_length"]:  # noqa: F821
        """
        Convert a readable state into environment format.

        Example, with max_length = 5:
            - readable: "0 1 0 0"
            - state: [1, 2, 1, 1, 0]

        Parameters
        ----------
        readable : str
            A state in readable format - space-separated tokens. The empty string
            corresponds to the source state.

        Returns
        -------
        A tensor containing the indices of the tokens.
        """
        if readable == "":
            return copy(self.source)
        # Tokens are parsed back into their original type before the lookup
        return tlong(
            self._pad(
                [self.token2idx[self.dtype(token)] for token in readable.split(" ")]
            ),
            device=self.device,
        )

    def get_all_terminating_states(
        self,
    ) -> List[TensorType["max_length"]]:  # noqa: F821
        """
        Construct a list with all terminating states in the sample space.

        Terminating states are all sequences of regular tokens with a length between
        ``min_length`` and ``max_length`` (inclusive), padded with ``pad_idx`` up to
        ``max_length``. They are enumerated by increasing length and, within each
        length, in lexicographic order of token indices.

        Returns
        -------
        A list of states in environment format.
        """
        # Regular token indices in increasing order (1..n_tokens); padding excluded
        token_indices = [self.token2idx[token] for token in self.tokens]
        samples = []
        for length in range(self.min_length, self.max_length + 1):
            padding = [self.pad_idx] * (self.max_length - length)
            samples.extend(
                tlong(list(seq) + padding, device=self.device)
                for seq in itertools.product(token_indices, repeat=length)
            )
        return samples

    def _get_generator(self, seed: Optional[int] = None) -> Optional[torch.Generator]:
        """Return a local random number generator for the given seed, or None."""
        if seed is None:
            return None
        generator = torch.Generator()
        generator.manual_seed(seed)
        return generator

    def _pad(self, seq_list: list):
        """
        Pad a sequence represented as a list of indices up to max_length.

        Parameters
        ----------
        seq_list : list
            The input sequence, as a list of token indices.

        Returns
        -------
        The input list padded at the end with self.pad_idx. Lists already at (or
        beyond) max_length are returned unchanged.
        """
        return seq_list + [self.pad_idx] * (self.max_length - len(seq_list))

    def _unpad(self, seq_list: list):
        """
        Remove trailing padding from a sequence represented as a list of indices.

        Parameters
        ----------
        seq_list : list
            The input sequence, as a list of token indices, possibly including
            padding indices.

        Returns
        -------
        The prefix of the input list before the first self.pad_idx (the whole list
        if it contains no padding).
        """
        if self.pad_idx not in seq_list:
            return seq_list
        return seq_list[: seq_list.index(self.pad_idx)]

    def _get_seq_length(
        self, state: Optional[TensorType["max_length"]] = None  # noqa: F821
    ) -> int:
        """
        Return the effective length of a state, ignoring padding.

        Parameters
        ----------
        state : tensor
            The input sequence. If None, self.state is used.

        Returns
        -------
        The number of positions before the first padding index, as a Python int.
        """
        state = self._get_state(state)
        if state[-1] == self.pad_idx:
            # Padding only appears at the end, so the first padded position is the
            # sequence length. Convert to int so both branches return the same type.
            return int(torch.where(state == self.pad_idx)[0][0])
        return len(state)