"""
An environment to sample a selection of elements from a given set of options.
The configuration of the environment must determine not only the set of options, but
the maximum number of elements to be sampled, whether fewer than the maximum can be
sampled and whether the selection must proceed with or without replacement.
If fewer elements than the maximum can be selected, then the environment operates as
a SetFlex (currently not implemented). Otherwise, the environment operates as a
SetFix. If the selection is with replacement, no constraints are applied. If the
selection is without replacement, constraints are applied such that options that have
already been selected are made unavailable in the remaining sub-environments.
"""
from typing import Dict, Iterable, List, Optional, Set, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torchtyping import TensorType
from gflownet.envs.base import GFlowNetEnv
from gflownet.envs.choice import Choice
from gflownet.envs.composite import make_set
from gflownet.envs.composite.setfix import SetFix
from gflownet.envs.composite.setflex import SetFlex
from gflownet.utils.common import tfloat, tlong
def Choices(
    options: Optional[Iterable] = None,
    n_options: int = 3,
    max_selection: int = 2,
    can_select_fewer_than_max: bool = False,
    with_replacement: bool = True,
    source_readable: str = "<source>",
    **kwargs,
):
    """
    Factory method to instantiate a Choices environment.

    Currently only the fixed-size variant (``ChoicesSetFix``) can be instantiated.

    Parameters
    ----------
    options : iterable (optional)
        The description of the options. If None, the options are simply described by
        their indices. In this case, ``n_options`` must be not None.
    n_options : int
        The number of options, if ``options`` is None. Ignored otherwise.
    max_selection : int
        The maximum number of options that may be selected.
    can_select_fewer_than_max : bool
        Whether fewer options than the maximum can be selected.
    with_replacement : bool
        Whether the selection proceeds with replacement (True, the same option can be
        selected more than once) or without replacement (False, each option can be
        selected only once).
    source_readable : str
        The string to be used to represent the source state as a human-readable
        string. By default: <source>
    **kwargs
        Additional keyword arguments, forwarded unchanged to the constructor of the
        environment.

    Returns
    -------
    ChoicesSetFix
        The instantiated Choices environment.

    Raises
    ------
    NotImplementedError
        If ``can_select_fewer_than_max`` is True.
    """
    if can_select_fewer_than_max:
        raise NotImplementedError(
            "Selection of fewer than the maximum is currently not implemented"
        )
    # Forward the extra keyword arguments too: they are accepted by the signature of
    # this factory and must reach the environment instead of being discarded.
    return ChoicesSetFix(
        options,
        n_options,
        max_selection,
        can_select_fewer_than_max,
        with_replacement,
        source_readable,
        **kwargs,
    )
class ChoicesBase:
    """
    ChoicesBase class.

    This class is the base of Choices environments, which inherit from either SetFix or
    SetFlex as well. If the configuration allows for selecting fewer elements than the
    maximum, then the environment becomes a SetFlex. If the number of elements is fixed
    (the maximum), then the environment becomes a SetFix.

    This class determines the inputs that are passed to initialize the Set environment
    (SetFix or SetFlex) and allows for overriding the methods that implement
    functionality that is common to both versions, in particular the constraints across
    environments, which depend on whether the selection is with or without replacement.

    If sampling is without replacement, the options that have been already selected are
    made unavailable in the remaining environments.

    ``can_alternate_subenvs`` is always passed as False, since the Choice
    sub-environments only have one meaningful action - selecting the option and then
    EOS.

    Attributes
    ----------
    options : iterable
        The description of the options as an iterable of strings. These strings are
        used as readable representation. By default, the string <source> is reserved
        for the source state.
    n_options : int
        The total number of different options.
    max_selection : int
        The maximum number of options that may be selected.
    can_select_fewer_than_max : bool
        Whether fewer options than the maximum can be selected.
    with_replacement : bool
        Whether the selection proceeds with replacement (True, the same option can be
        selected more than once) or without replacement (False, each option can be
        selected only once).
    source_readable : str
        The string to be used to represent the source state in the Choice environments
        as a human-readable string.
    """

    def __init__(
        self,
        options: Optional[Iterable] = None,
        n_options: int = 3,
        max_selection: int = 2,
        can_select_fewer_than_max: bool = False,
        with_replacement: bool = True,
        source_readable: str = "<source>",
        **kwargs,
    ):
        """
        Initializes a Choices environment.

        Parameters
        ----------
        options : iterable (optional)
            The description of the options. If None, the options are simply described
            by their indices. In this case, ``n_options`` must be not None.
        n_options : int
            The number of options, if ``options`` is None. Ignored otherwise.
        max_selection : int
            The maximum number of options that may be selected.
        can_select_fewer_than_max : bool
            Whether fewer options than the maximum can be selected.
        with_replacement : bool
            Whether the selection proceeds with replacement (True, the same option can
            be selected more than once) or without replacement (False, each option can
            be selected only once).
        source_readable : str
            The string to be used to represent the source state as a human-readable
            string. By default: <source>
        **kwargs
            Additional keyword arguments, forwarded to the parent Set environment.

        Raises
        ------
        NotImplementedError
            If ``can_select_fewer_than_max`` is True.
        """
        self.max_selection = max_selection
        self.can_select_fewer_than_max = can_select_fewer_than_max
        self.with_replacement = with_replacement

        # The parent class is meant to be:
        # - SetFlex if can_select_fewer_than_max is True (not implemented yet)
        # - SetFix if can_select_fewer_than_max is False
        # TODO: to support SetFlex, initialize the parent with a single unique Choice
        # environment, i.e. envs_unique=(env_unique,) (note the trailing comma, so
        # that it is a tuple), max_elements=self.max_selection and
        # can_alternate_subenvs=False.
        if self.can_select_fewer_than_max:
            raise NotImplementedError(
                "Selection of fewer than the maximum is currently not implemented"
            )

        # One Choice sub-environment per element to be selected
        subenvs = (
            Choice(
                options=options,
                n_options=n_options,
                source_readable=source_readable,
            )
            for _ in range(self.max_selection)
        )
        super().__init__(subenvs=subenvs, can_alternate_subenvs=False, **kwargs)

        # Get attributes from sub-environment
        self.options = self.subenvs[0].options
        self.n_options = len(self.options)
        self.source_readable = self.subenvs[0].source_readable

    @property
    def choice_env(self) -> Choice:
        """
        Returns the unique Choice environment.

        Returns
        -------
        Choice
            The Choice environment that serves as unique environment of the Set.
        """
        return self.envs_unique[0]

    def _check_has_constraints(self) -> bool:
        """
        Checks whether the composite environment has constraints across
        sub-environments.

        Constraints need to be applied if the selection is without replacement, since
        the available options of the remaining sub-environments need to be restricted.

        Returns
        -------
        bool
            True if the selection is without replacement, False otherwise
        """
        return not self.with_replacement

    def get_options(self, state: Optional[Dict] = None) -> Tuple[int, ...]:
        """
        Returns all the options that have already been chosen from the state.

        The first entry of each sub-state is taken as the selected option. Entries
        equal to 0 are skipped (they appear to denote a sub-environment in which no
        option has been selected yet).

        Parameters
        ----------
        state : dict (optional)
            A state of the global set environment. It is resolved via
            ``self._get_state()`` before being used.

        Returns
        -------
        tuple
            The options selected so far, as a tuple of integers. Options may be
            repeated if the same option is present in several sub-states.
        """
        state = self._get_state(state)
        substates = self._get_substates(state)
        return tuple(substate[0] for substate in substates if substate[0] != 0)

    def _apply_constraints_forward(
        self,
        action: Optional[Tuple] = None,
        state: Optional[Dict] = None,
    ) -> bool:
        """
        Applies constraints across sub-environments in the forward direction.

        The available options of the sub-environments that are still to be set are
        restricted to the options that have not been selected yet. This is done by
        restricting the options of the unique environment, which is common to all
        sub-environments.

        Parameters
        ----------
        action : tuple (optional)
            An action from the global composite environment. If the call of this method
            is not initiated by a transition, then ``action`` is None.
        state : dict (optional)
            A state of the global set environment.

        Returns
        -------
        bool
            True if any constraint was applied; False otherwise.
        """
        idx_subenv = self._get_active_subenv(state)
        if not self._do_constraints_for_subenv(state, idx_subenv, action, False):
            return False

        # Available options: all options minus the ones already in the state
        selected = set(self.get_options(state))
        available = set(self.choice_env.options_indices).difference(selected)
        self.choice_env.set_available_options(available)
        return True

    def _apply_constraints_backward(
        self,
        action: Optional[Tuple] = None,
        state: Optional[Dict] = None,
    ) -> bool:
        """
        Applies constraints across sub-environments in the backward direction.

        In the backward direction, in this case, means that the constraints between two
        sub-environments are undone and reset as in the source state.

        The available options are restricted to the options that are not part of the
        state. Additionally, the option of the currently active sub-environment is
        also added to the available options, since, in the backward sense, it will be
        unselected and then it will be available. This is done by restricting the
        options of the unique environment, which is common to all sub-environments.

        If no constraint is applied for the active sub-environment but the state is
        the source state, the available options are reset to the full set of options.

        Parameters
        ----------
        action : tuple (optional)
            An action from the global composite environment. If the call of this method
            is not initiated by a transition, then ``action`` is None.
        state : dict (optional)
            A state of the global composite environment.

        Returns
        -------
        bool
            True if any constraint was applied; False otherwise.
        """
        idx_subenv = self._get_active_subenv(state)

        if self._do_constraints_for_subenv(state, idx_subenv, action, True):
            selected = set(self.get_options(state))
            available = set(self.choice_env.options_indices).difference(selected)
            # Add option of currently active sub-environment since its option is
            # currently part of the state and thus not available in the forward sense
            # but it should be available assuming it will be unselected in the forward
            # sense.
            option_of_active_subenv = self._get_substate(state, idx_subenv)[0]
            if option_of_active_subenv != 0:
                available.add(option_of_active_subenv)
            self.choice_env.set_available_options(available)
            return True

        # If the state is source, reset the available options to the full set of
        # options
        # TODO: Design better solution for resetting the constraints when reset() is
        # called
        if self.is_source(state):
            self.choice_env.set_available_options(set(self.choice_env.options_indices))
            return True

        return False

    def states2policy(
        self, states: List[Dict]
    ) -> TensorType["batch", "state_policy_dim"]:
        """
        Prepares a batch of states in environment format for the policy model.

        The policy representation is the concatenation of the following elements:
        - A flag indicating whether no environment is active (-1), an environment is
          active and not done (0), or an environment is active but done (1).
        - A vector of length ``self.n_options`` with the count of each selected option.

        Parameters
        ----------
        states : list
            A batch of states in environment format. Each state must contain the keys
            ``"_active"`` and ``"_dones"``.

        Returns
        -------
        A tensor containing all the states in the batch.
        """
        n_states = len(states)

        # Extract relevant data
        active, dones, substates = zip(
            *[
                [state["_active"], state["_dones"], self._get_substates(state)]
                for state in states
            ]
        )
        active = np.array(active)
        dones = np.array(dones)
        substates = tlong(substates, device=self.device).reshape(
            n_states, self.max_selection
        )

        # Build flags vector: it starts as the index of the active sub-environment
        # (-1 if none is active) and, for the states with an active sub-environment,
        # it is replaced by the done flag of that sub-environment.
        flags = active.copy()
        is_active = active != -1
        flags[is_active] = dones[is_active][
            np.arange(np.count_nonzero(is_active)), active[is_active]
        ]

        # Build counts vector: column 0 accumulates the sub-states with value 0 and
        # is discarded below, so that only actual options are counted.
        n_options = self.choice_env.n_options
        counts_all = torch.zeros(
            (n_states, n_options + 1), device=self.device, dtype=torch.long
        )
        counts_all.scatter_add_(1, substates, torch.ones_like(substates))

        # Build output tensor
        return torch.cat(
            [
                tfloat(flags, device=self.device, float_type=self.float).unsqueeze(1),
                tfloat(counts_all[:, 1:], device=self.device, float_type=self.float),
            ],
            dim=1,
        )
class ChoicesSetFix(ChoicesBase, SetFix):
    """
    Choices environment with a fixed number of selected elements.

    This is the variant of the Choices environments in which exactly
    ``max_selection`` elements are selected, that is selecting fewer elements than the
    maximum is not allowed. The composite behaviour is inherited from SetFix, while the
    Choices-specific logic comes from ChoicesBase.
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a ChoicesSetFix environment.

        All positional and keyword arguments are passed on unchanged to the next
        initializer in the method resolution order.
        """
        super().__init__(*args, **kwargs)
class ChoicesSetFlex(ChoicesBase, SetFlex):
    """
    Choices environment with a variable number of selected elements.

    This is the variant of the Choices environments in which fewer elements than the
    maximum may be selected. The composite behaviour is inherited from SetFlex, while
    the Choices-specific logic comes from ChoicesBase.
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a ChoicesSetFlex environment.

        All positional and keyword arguments are passed on unchanged to the next
        initializer in the method resolution order.
        """
        super().__init__(*args, **kwargs)