"""
This environment is a conditional set of cubes and grids. The set is flexible in that
in can consist of a variable number of cubes and grids. The number of grids and cubes
for each trajectory can be regarded as the conditions of the set and in this environment
these conditions are sampled by an auxiliary 2D grid environment, where each dimension
will indicate the number of cubes and grids, respectively. Therefore, the overall
environment is a Stack of a Grid and a Set of Cubes and Grids.
This environment was originally designed for debugging the conditional mode of the Set
environment.
"""
from typing import Dict, List, Optional, Tuple, Union
import pandas as pd
import torch
from torchtyping import TensorType
from tqdm import tqdm
from gflownet.envs.composite.setflex import SetFlex
from gflownet.envs.composite.stack import Stack
from gflownet.envs.cube import ContinuousCube
from gflownet.envs.grid import Grid
from gflownet.utils.common import copy, tfloat
# Constants to identify the indices of the Cube and Grid unique environments
[docs]
class SetBox(Stack):
"""
A Stack of a Grid and a Set of Cubes and Grids. The first grid determines the
conditions (constraints) of the Set.
"""
def __init__(
self,
max_elements_per_subenv: int = 3,
n_dim: int = 2,
cube_kwargs: Optional[Dict] = None,
grid_kwargs: Optional[Dict] = None,
**kwargs,
):
"""
Initializes the SetBox environment.
Parameters
----------
max_elements_per_subenv : int
The maximum number of elements of each kind in the Set. The total maximum
number of elements in the set will thus be 2 * max_elements_per_subenv.
n_dim : int
The dimensionality of the Cubes and Grids in the Set.
"""
[docs]
self.max_elements_per_subenv = max_elements_per_subenv
[docs]
self.cube_kwargs = cube_kwargs or {}
[docs]
self.grid_kwargs = grid_kwargs or {}
# Define sub-environments of the Stack
[docs]
self.idx_conditioning_grid = 0
subenvs = [
Grid(n_dim=2, length=self.max_elements_per_subenv + 1),
SetFlex(
max_elements=self.max_elements_per_subenv * 2,
envs_unique=(
ContinuousCube(n_dim=n_dim, **self.cube_kwargs),
Grid(n_dim=n_dim, **self.grid_kwargs),
),
),
]
# Initialize base Stack environment
super().__init__(subenvs=tuple(subenvs), **kwargs)
@property
[docs]
def conditioning_grid(self) -> Grid:
"""
Returns the sub-environment corresponding to the Grid that is the first
sub-environment in the Stack, which is used to sample the conditions of the Set.
Returns
-------
Grid
"""
return self.subenvs[self.idx_conditioning_grid]
@property
[docs]
def set(self) -> SetFlex:
"""
Returns the sub-environment corresponding to the set of cubes and grids.
Returns
-------
SetFlex
"""
return self.subenvs[self.idx_set]
@property
[docs]
def cube(self) -> ContinuousCube:
"""
Returns the ContinuousCube environment that is used as unique environment to
define the Cubes in the Set.
The Cube is the unique environment in the first (0) dimension of subenvs_unique
in the Set.
Returns
-------
ContinuousCube
"""
return self.set.envs_unique[IDX_CUBE]
@property
[docs]
def grid(self) -> Grid:
"""
Returns the Grid environment that is used as unique environment to define the
Grids in the Set.
The Grid is the unique environment in the second (1) dimension of
subenvs_unique in the Set.
Returns
-------
Grid
"""
return self.set.envs_unique[IDX_GRID]
def _check_has_constraints(self) -> bool:
"""
Checks whether the Stack has constraints across sub-environments.
It returns True, because the SetBox always has constraints from the
conditioning grid to the Set.
Returns
-------
bool
True
"""
return True
def _apply_constraints_forward(
self,
action: Tuple = None,
state: Dict = None,
) -> bool:
"""
Applies constraints across sub-environments, when applicable, in the forward
direction.
Parameters
----------
action : tuple
An action from the SetBox environment.
state : dict
A state from the SetBox environment.
Returns
-------
bool
True if any constraint was applied; False otherwise.
"""
if self._do_constraints_for_subenv(
state, self.idx_conditioning_grid, action, is_backward=False
):
n_cubes = self.conditioning_grid.state[IDX_CUBE]
n_grids = self.conditioning_grid.state[IDX_GRID]
if n_cubes == 0 and n_grids == 0:
subenvs = self.set._sample_random_subenvs()
else:
cubes = [IDX_CUBE] * n_cubes
grids = [IDX_GRID] * n_grids
subenvs = self.set.get_env_instances_by_unique_indices(cubes + grids)
self.set.set_subenvs(subenvs=subenvs)
# If a state is passed as argument, set the state and done of set
# sub-environment
if state is not None:
self.set.set_state(
self._get_substate(state, self.idx_set),
done=self.set.done,
)
# Update global Stack state with state of Set
self._set_substate(self.idx_set, self.set.state)
return True
else:
return False
def _apply_constraints_backward(
self, action: Tuple = None, state: Optional[Dict] = None
) -> bool:
"""
Applies constraints across sub-environments, when applicable, in the backward
direction.
Parameters
----------
action : tuple
An action from the SetBox environment.
state : dict
A state from the SetBox environment.
Returns
-------
bool
True if any constraint was applied; False otherwise.
"""
if self._do_constraints_for_subenv(
state, self.idx_conditioning_grid, action, is_backward=True
):
# Reset source of Set, set subenvs of Set to None and update global Stack
# state
self.set.state = copy(self.set.source)
self.set.subenvs = None
self._set_substate(self.idx_set, self.set.state)
return True
else:
return False
[docs]
def states2proxy(
self, states: List[Dict]
) -> TensorType["batch", "state_oracle_dim"]:
"""
Prepares a batch of states in "environment format" for a proxy.
The proxy representation is the average of the proxy representation across all
the cubes and grids in the set.
Parameters
----------
states : list
A batch of states in environment format.
Returns
-------
A tensor containing all the states in the batch in the proxy representation.
"""
# Keep only the part of the state corresponding to the Set
states_proxy_set = self.set.states2proxy(
[self._get_substate(state, self.idx_set) for state in states]
)
states_proxy = []
for state in states_proxy_set:
states_box = tfloat(
self.set._get_substates(state),
float_type=self.float,
device=self.device,
)
states_proxy.append(torch.mean(states_box, dim=0))
return tfloat(states_proxy, float_type=self.float, device=self.device)