"""
IMPORTANT: this environment is currently broken!
Classes to represent hyper-torus environments
"""
import itertools
from typing import List, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
import pandas as pd
import torch
from torchtyping import TensorType
from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import tfloat, tlong
class Torus(GFlowNetEnv):
    """
    Hyper-torus environment in which the action space consists of:
        - Increasing the angle index of dimension d
        - Decreasing the angle index of dimension d
        - Keeping all dimensions as are
    and the trajectory is of fixed length length_traj.

    The states space is the concatenation of the angle index at each dimension and the
    number of actions.

    Attributes
    ----------
    n_dim : int
        Dimensionality of the torus
    n_angles : int
        Number of angles into which each dimension is divided
    length_traj : int
        Fixed length of the trajectory.
    max_increment : int
        Maximum absolute increment/decrement applied to one dimension by one action.
    max_dim_per_action : int
        Maximum number of dimensions a single action may modify (-1 at construction
        means all n_dim dimensions).
    """

    def __init__(
        self,
        n_dim: int = 2,
        n_angles: int = 3,
        length_traj: int = 1,
        max_increment: int = 1,
        max_dim_per_action: int = 1,
        **kwargs,
    ):
        assert n_dim > 0
        assert n_angles > 1
        assert length_traj > 0
        assert max_increment > 0
        assert max_dim_per_action == -1 or max_dim_per_action > 0
        # Must be stored before anything below reads self.n_dim.
        self.n_dim = n_dim
        self.n_angles = n_angles
        self.length_traj = length_traj
        self.max_increment = max_increment
        # -1 means that a single action may modify every dimension at once.
        if max_dim_per_action == -1:
            max_dim_per_action = self.n_dim
        self.max_dim_per_action = max_dim_per_action
        # Source state: angle index 0 at all dimensions and number of actions 0
        self.source_angles = [0 for _ in range(self.n_dim)]
        self.source = self.source_angles + [0]
        # End-of-sequence action: (self.max_increment + 1) in all dimensions, a value
        # no regular action can take since regular increments lie in
        # [-max_increment, max_increment].
        self.eos = tuple([self.max_increment + 1 for _ in range(self.n_dim)])
        # Angle increments in radians
        self.angle_rad = 2 * np.pi / self.n_angles
        # Base class init
        super().__init__(**kwargs)

    def get_action_space(self):
        """
        Constructs list with all possible actions, including eos. An action is
        represented by a vector of length n_dim where each index d indicates the
        increment/decrement to apply to dimension d of the hyper-torus. A negative
        value indicates a decrement. The action "keep" (no increment/decrement of any
        dimensions) is valid and is indicated by all zeros.

        Only actions that modify at most max_dim_per_action dimensions are kept. EOS is
        always the last element of the returned list.
        """
        increments = range(-self.max_increment, self.max_increment + 1)
        actions = [
            tuple(action)
            for action in itertools.product(increments, repeat=self.n_dim)
            if sum(1 for el in action if el != 0) <= self.max_dim_per_action
        ]
        actions.append(self.eos)
        return actions

    def get_mask_invalid_actions_forward(
        self,
        state: Optional[List] = None,
        done: Optional[bool] = None,
    ) -> List:
        """
        Returns a list of length the action space with values:
            - True if the forward action is invalid from the current state.
            - False otherwise.

        All actions except EOS are valid if the maximum number of actions has not been
        reached, and vice versa. If the trajectory is done, all actions are invalid.
        """
        state = self._get_state(state)
        done = self._get_done(done)
        if done:
            return [True for _ in range(self.action_space_dim)]
        # The last element of the state counts the actions taken so far.
        if state[-1] >= self.length_traj:
            mask = [True for _ in range(self.action_space_dim)]
            mask[-1] = False
        else:
            mask = [False for _ in range(self.action_space_dim)]
            mask[-1] = True
        return mask

    def states2proxy(
        self, states: Union[List[List], TensorType["batch", "state_dim"]]
    ) -> TensorType["batch", "state_proxy_dim"]:
        """
        Prepares a batch of states in "environment format" for the proxy: each state is
        a vector of length n_dim where each value is an angle in radians. The n_actions
        item is removed.

        Args
        ----
        states : list or tensor
            A batch of states in environment format, either as a list of states or as a
            single tensor.

        Returns
        -------
        A tensor containing all the states in the batch.
        """
        return (
            tfloat(states, device=self.device, float_type=self.float)[:, :-1]
            * self.angle_rad
        )

    # TODO: circular encoding as in htorus
    def states2policy(
        self, states: Union[List, TensorType["batch", "state_dim"]]
    ) -> TensorType["batch", "policy_input_dim"]:
        """
        Prepares a batch of states in "environment format" for the policy model: the
        policy format is a one-hot encoding of the angle indices followed by the number
        of actions.

        Each row is a vector of length n_angles * n_dim + 1, where the n-th successive
        block of n_angles elements is a one-hot encoding of the angle index in the n-th
        dimension, and the last element is the number of actions.

        Example, n_dim = 2, n_angles = 4:
          - state: [1, 3, 4]
                    | a  | n | (a = angles, n = n_actions)
          - policy format: [0, 1, 0, 0, 0, 0, 0, 1, 4]
                           |     1    |     3     | 4 |

        Args
        ----
        states : list or tensor
            A batch of states in environment format, either as a list of states or as a
            single tensor.

        Returns
        -------
        A tensor containing all the states in the batch.
        """
        states = tlong(states, device=self.device)
        # Column of the "hot" entry of each dimension: offset by n_angles per block.
        cols = states[:, :-1] + torch.arange(self.n_dim).to(self.device) * self.n_angles
        rows = torch.repeat_interleave(
            torch.arange(states.shape[0]).to(self.device), self.n_dim
        )
        states_policy = torch.zeros(
            (states.shape[0], self.n_angles * self.n_dim + 1)
        ).to(states)
        states_policy[rows, cols.flatten()] = 1.0
        states_policy[:, -1] = states[:, -1]
        return states_policy.to(self.float)

    def state2readable(self, state: Optional[List] = None) -> str:
        """
        Converts a state (a list of positions) into a human-readable string
        representing a state, e.g. [1, 3, 4] -> "[1 3] | 4".
        """
        state = self._get_state(state)
        angles = (
            str(state[: self.n_dim])
            .replace("(", "[")
            .replace(")", "]")
            .replace(",", "")
        )
        n_actions = str(state[-1])
        return angles + " | " + n_actions

    def readable2state(self, readable: str) -> List:
        """
        Converts a human-readable string representing a state into a state as a list of
        positions, e.g. "[1 3] | 4" -> [1, 3, 4].
        """
        pair = readable.split(" | ")
        angles = [int(el) for el in pair[0].strip("[]").split(" ")]
        n_actions = [int(pair[1])]
        return angles + n_actions

    def get_parents(
        self,
        state: Optional[List] = None,
        done: Optional[bool] = None,
        action: Optional[Tuple] = None,
    ) -> Tuple[List, List]:
        """
        Determines all parents and actions that lead to state.

        Args
        ----
        state : list
            Representation of a state, as a list of length n_dim + 1 where the first
            n_dim elements are the angle indices at each dimension and the last one is
            the number of actions.

        done : bool
            Whether the trajectory is done. If None, done is taken from instance.

        action : None
            Ignored

        Returns
        -------
        parents : list
            List of parents in state format

        actions : list
            List of actions that lead to state for each parent in parents
        """
        state = self._get_state(state)
        done = self._get_done(done)
        if done:
            return [state], [self.eos]
        # The source state (no actions taken) has no parents.
        if state[-1] == 0:
            return [], []
        n_actions_p = state[-1] - 1
        parents = []
        actions = []
        # Every regular action (EOS excluded) is a candidate last step.
        for action in self.action_space[:-1]:
            # Undo the action, wrapping around the torus.
            angles_p = [
                (angle - incr) % self.n_angles
                for angle, incr in zip(state[: self.n_dim], action)
            ]
            # The parent is valid only if it can be reached from the source in
            # exactly n_actions_p steps; since "keep" is always available, this holds
            # iff the minimum number of actions needed does not exceed n_actions_p.
            if self._get_min_actions_to_source(angles_p) <= n_actions_p:
                parents.append(angles_p + [n_actions_p])
                actions.append(action)
        return parents, actions

    def step(
        self, action: Tuple[int], skip_mask_check: bool = False
    ) -> Tuple[List[int], Tuple[int], bool]:
        """
        Executes step given an action.

        Args
        ----
        action : tuple
            Action to be executed. See: get_action_space()

        skip_mask_check : bool
            If True, skip computing forward mask of invalid actions to check if the
            action is valid.

        Returns
        -------
        self.state : list
            The sequence after executing the action

        action : tuple
            Action executed

        valid : bool
            False, if the action is not allowed for the current state.
        """
        # Generic pre-step checks
        do_step, self.state, action = self._pre_step(
            action, skip_mask_check or self.skip_mask_check
        )
        if not do_step:
            return self.state, action, False
        # Once the trajectory has reached its fixed length the only possible action is
        # EOS, so it is forced. The state's own action counter is used, matching
        # get_mask_invalid_actions_forward.
        if self.state[-1] >= self.length_traj:
            self.n_actions += 1
            self.done = True
            return self.state, self.eos, True
        # EOS is not allowed before the trajectory length is reached; without this
        # check it would be applied as an angle increment (e.g. when mask checks are
        # skipped).
        if tuple(action) == self.eos:
            return self.state, action, False
        # Perform non-EOS action, wrapping each angle index around the torus.
        angles_next = [
            (angle + incr) % self.n_angles
            for angle, incr in zip(self.state[: self.n_dim], action)
        ]
        self.state = angles_next + [self.state[-1] + 1]
        self.n_actions += 1
        return self.state, action, True

    def get_all_terminating_states(self):
        """
        Returns all terminating states as a list of lists: every combination of angle
        indices reachable from the source in at most length_traj actions, each followed
        by the number of actions length_traj.
        """
        all_x = itertools.product(*[list(range(self.n_angles))] * self.n_dim)
        all_x_valid = [
            x for x in all_x if self._get_min_actions_to_source(x) <= self.length_traj
        ]
        all_x = np.int32(all_x_valid)
        n_actions = self.length_traj * np.ones([all_x.shape[0], 1], dtype=np.int32)
        all_x = np.concatenate([all_x, n_actions], axis=1)
        return all_x.tolist()

    def fit_kde(self, x, kernel="exponential", bandwidth=0.1):
        """
        Placeholder for fitting a kernel density estimator on the samples ``x``.

        Not implemented in this environment: it requires a kernel density estimator
        (such as scikit-learn's ``KernelDensity``) which is not a dependency of this
        module.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "Torus.fit_kde is not implemented: it requires a kernel density "
            "estimator (e.g. sklearn.neighbors.KernelDensity), which is not a "
            "dependency of this module."
        )

    def _get_min_actions_to_source(self, angles):
        """
        Returns the minimum number of regular (non-EOS) actions needed to go from the
        source angles to ``angles``.

        For each dimension, covering the circular distance d between source and target
        index takes ceil(d / max_increment) actions that modify that dimension. As one
        action modifies each dimension at most once and at most max_dim_per_action
        dimensions in total, the minimum is the larger of the largest per-dimension
        count and ceil(sum of counts / max_dim_per_action). With max_increment = 1 and
        max_dim_per_action = 1 this reduces to the sum of circular distances.
        """
        n_touches = []
        for u, v in zip(self.source_angles, angles):
            dist = abs(int(u) - int(v)) % self.n_angles
            # Go around whichever way is shorter.
            dist = min(dist, self.n_angles - dist)
            # Ceiling division without floats.
            n_touches.append(-(-dist // self.max_increment))
        total = sum(n_touches)
        return max(max(n_touches), -(-total // self.max_dim_per_action))