"""
GFlowNet
TODO:
- Seeds
"""
import copy
import gc
import pickle
import time
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torchtyping import TensorType
from tqdm import tqdm, trange
from gflownet.envs.base import GFlowNetEnv
from gflownet.evaluator.base import BaseEvaluator
from gflownet.utils.batch import Batch, compute_logprobs_trajectories
from gflownet.utils.common import (
bootstrap_samples,
set_device,
set_float_precision,
tbool,
tfloat,
tlong,
)
[docs]
class GFlowNetAgent:
def __init__(
    self,
    env_maker: partial,
    proxy,
    seed,
    device,
    float_precision,
    loss,
    optimizer,
    buffer,
    forward_policy,
    backward_policy,
    mask_invalid_actions,
    temperature_logits,
    random_action_prob,
    logger,
    evaluator,
    state_flow=None,
    use_context=False,
    replay_sampling="permutation",
    train_sampling="permutation",
    garbage_collection_period: int = 0,
    collect_reversed_logprobs: bool = False,
    **kwargs,
):
    """
    Main class of this repository. Handles the training logic for a GFlowNet model.

    Parameters
    ----------
    env_maker : functools.partial
        Callable without arguments that returns a new instance of the environment
        (GFlowNetEnv), i.e. the DAG, action space and reward function. One
        instance is created here (self.env); further instances are created on
        demand by get_env_instances().
    proxy
        Proxy (reward model). Its setup() method is called with self.env, and it
        is passed to the Batch objects created by the agent.
    seed : int
        Seed of the agent's NumPy random generator (self.rng).
    device : str
        Device to be used for training and inference, e.g. "cuda" or "cpu".
    float_precision : int
        Precision of the floating point numbers, e.g. 32 or 64.
    loss : Loss
        An instance of a loss class, corresponding to one of the GFlowNet
        objectives, for example Flow Matching or Trajectory Balance.
    optimizer : dict-like config
        Optimizer config. See gflownet.yaml:optimizer for details. Besides being
        passed to make_opt(), the fields z_dim, n_train_steps, batch_size,
        train_to_sample_ratio, clip_grad_norm and bootstrap_tau are read here.
    buffer
        Buffer used for training. NOTE(review): it is used as an object (fields
        train, test, replay, min_tr, max_tr, mean_tr, std_tr; methods select()
        and add()), not as a config dict; confirm callers pass a Buffer
        instance.
    forward_policy : gflownet.policy.base.Policy
        The forward policy to be used for training. Parameterized from
        `gflownet.yaml:forward_policy` and parsed with
        `gflownet/utils/policy.py:set_policy`.
    backward_policy : gflownet.policy.base.Policy
        Same as forward_policy, but for the backward policy.
    mask_invalid_actions : bool
        Whether to mask invalid actions in the policy outputs.
    temperature_logits : float
        Default temperature to adjust the logits by logits /= temperature. Used
        by sample_actions() when it is called with temperature=None.
    random_action_prob : float
        Default probability of sampling random actions. Used by sample_actions()
        when it is called with random_action_prob=None.
    logger : gflownet.utils.logger.Logger
        Logger object to be used for logging and saving checkpoints
        (`gflownet/utils/logger.py:Logger`).
    evaluator : gflownet.evaluator.base.BaseEvaluator
        :py:mod:`~gflownet.evaluator` ``Evaluator`` instance. Its set_agent()
        method is called with this agent.
    state_flow : dict, optional
        State flow config dictionary. See `gflownet.yaml:state_flow` for details. By
        default None.
    use_context : bool, optional
        Whether the logger will use its context in metrics names. Formerly the
        `active_learning: bool` flag. By default False.
    replay_sampling : str, optional
        Type of sampling for the replay buffer. See
        :meth:`~gflownet.utils.buffer.select`. By default "permutation".
    train_sampling : str, optional
        Type of sampling for the train buffer (offline backward trajectories). See
        :meth:`~gflownet.utils.buffer.select`. By default "permutation".
    garbage_collection_period : int
        The periodicity (in training iterations) to perform garbage collection
        and empty the cache of the GPU. By default it is 0, so no garbage
        collection is performed. This is because it can incur a large time
        overhead unnecessarily.
    collect_reversed_logprobs : bool
        If True, reversed logprobs will be computed and collected during sampling
        batches for training.
    **kwargs
        Ignored.

    Raises
    ------
    Exception
        If the environment is continuous and the loss is not well defined for
        continuous GFlowNets.
    """
    # Seed
    self.rng = np.random.default_rng(seed)
    # Device
    self.device = set_device(device)
    # Float precision
    self.float = set_float_precision(float_precision)
    # Environment
    self.env_maker = env_maker
    self.env = self.env_maker()
    # Environment instances reused across calls by get_env_instances()
    self.env_cache = []
    # Proxy
    self.proxy = proxy
    self.proxy.setup(self.env)
    # Loss
    self.loss = loss
    if self.loss.requires_log_z:
        self.logZ = nn.Parameter(torch.ones(optimizer.z_dim) * 150.0 / 64)
        self.loss.set_log_z(self.logZ)
    else:
        self.logZ = None
    self.collect_backwards_masks = self.loss.requires_backward_policy()
    self.collect_reversed_logprobs = collect_reversed_logprobs
    # Continuous environments
    self.continuous = hasattr(self.env, "continuous") and self.env.continuous
    if self.continuous and not loss.is_defined_for_continuous():
        raise Exception(
            f"The environment is continuous but the {loss.name} loss is not well "
            "defined for continuous environments. Consider using a different loss."
        )
    # Logging
    self.logger = logger
    # Buffers
    self.buffer = buffer
    self.replay_sampling = replay_sampling
    self.train_sampling = train_sampling
    # Train set statistics
    if self.buffer.train is not None:
        print("\nTrain data")
        print(f"\tMean score: {self.buffer.mean_tr}")
        print(f"\tStd score: {self.buffer.std_tr}")
        print(f"\tMin score: {self.buffer.min_tr}")
        print(f"\tMax score: {self.buffer.max_tr}")
    # Test set statistics
    if self.buffer.test is not None:
        test_scores = self.buffer.test["scores"]
        print("\nTest data")
        print(f"\tMean score: {test_scores.mean()}")
        print(f"\tStd score: {test_scores.std()}")
        print(f"\tMin score: {test_scores.min()}")
        print(f"\tMax score: {test_scores.max()}")
    # Models
    self.forward_policy = forward_policy
    self.backward_policy = backward_policy
    self.state_flow = state_flow
    # Optimizer
    if self.forward_policy.is_model:
        self.target = copy.deepcopy(self.forward_policy.model)
        self.opt, self.lr_scheduler = make_opt(
            self.parameters(), self.logZ, optimizer
        )
    else:
        self.opt, self.lr_scheduler, self.target = None, None, None
    # Current training iteration (1-based); train() starts from this value
    self.it = 1
    # Evaluator
    self.evaluator = evaluator
    self.evaluator.set_agent(self)
    self.n_train_steps = optimizer.n_train_steps
    self.batch_size = optimizer.batch_size
    self.batch_size_total = sum(self.batch_size.values())
    # Number of training updates per sampled batch and of sampled sub-batches
    # per update; at least one of them is 1.
    self.ttsr = max(int(optimizer.train_to_sample_ratio), 1)
    self.sttr = max(int(1 / optimizer.train_to_sample_ratio), 1)
    self.clip_grad_norm = optimizer.clip_grad_norm
    self.tau = optimizer.bootstrap_tau
    self.use_context = use_context
    self.logsoftmax = torch.nn.LogSoftmax(dim=1)
    # Training
    self.mask_invalid_actions = mask_invalid_actions
    self.temperature_logits = temperature_logits
    self.random_action_prob = random_action_prob
    self.garbage_collection_period = garbage_collection_period
    # Metrics (-1.0 is used as a "not computed yet" sentinel)
    self.corr_probs_rewards = 0.0
    self.corr_logprobs_logrewards = 0.0
    self.var_logrewards_logp = -1.0
    self.mean_logprobs_std = -1.0
    self.mean_probs_std = -1.0
    self.logprobs_std_nll_ratio = -1.0
    # Read by log_train_iteration() for the progress bar.
    # NOTE(review): presumably updated during evaluation; confirm.
    self.jsd = -1.0
def parameters(self):
    """
    Collect the trainable parameters of the agent's models.

    The forward policy model always contributes its parameters. The backward
    policy model contributes its parameters when the backward policy is a
    model. The state flow model contributes its parameters when a state flow is
    set.

    Returns
    -------
    list
        Parameters in the order: forward policy, backward policy, state flow.

    Raises
    ------
    ValueError
        If the backward policy is a model but the loss does not require a
        backward policy, or if a state flow is set but the loss does not require
        a state flow model.
    """
    collected = list(self.forward_policy.model.parameters())

    has_backward_model = self.backward_policy.is_model
    if has_backward_model and not self.loss.requires_backward_policy():
        raise ValueError(
            "Backward policy initialized but not required by "
            f"loss {self.loss.name}."
        )
    if has_backward_model:
        collected.extend(self.backward_policy.model.parameters())

    has_state_flow = self.state_flow is not None
    if has_state_flow and not self.loss.requires_state_flow_model():
        raise ValueError(
            "State flow model initialized but not required by "
            f"loss {self.loss.name}."
        )
    if has_state_flow:
        collected.extend(self.state_flow.model.parameters())

    return collected
def sample_actions(
    self,
    envs: List[GFlowNetEnv],
    batch: Optional[Batch] = None,
    env_cond: Optional[GFlowNetEnv] = None,
    sampling_method: Optional[str] = "policy",
    backward: Optional[bool] = False,
    temperature: Optional[float] = 1.0,
    random_action_prob: Optional[float] = None,
    no_random: Optional[bool] = True,
    times: Optional[dict] = None,
    compute_reversed_logprobs: Optional[bool] = False,
) -> Tuple[List[Tuple], torch.Tensor, Union[torch.Tensor, List[None]]]:
    """
    Samples one action on each environment of the list envs, according to the
    sampling method specified by sampling_method.

    With probability 1 - random_action_prob, actions will be sampled from the
    self.forward_policy or self.backward_policy, depending on backward. The rest
    are sampled according to the random policy of the environment
    (model.random_distribution).

    If a batch is provided (and self.mask_invalid_actions is True), the masks are
    retrieved from the batch. Otherwise they are computed from the environments.

    Args
    ----
    envs : list of GFlowNetEnv or derived
        A list of instances of the environment. A single environment is wrapped
        in a list.
    batch : Batch
        A batch from which to obtain required variables (e.g. masks) to avoid
        recomputing them. Required if compute_reversed_logprobs is True.
    env_cond : GFlowNetEnv or derived
        An environment to do conditional sampling, that is restrict the action
        space via the masks of the main environments.
    sampling_method : string
        - policy: uses current forward to obtain the sampling probabilities.
        - random: samples purely from a random policy, that is
          - random_action_prob = 1.0
          regardless of the value passed as arguments.
    backward : bool
        True if sampling is backward. False (forward) by default.
    temperature : float
        Temperature to adjust the logits by logits /= temperature. If None,
        self.temperature_logits is used.
    random_action_prob : float
        Probability of sampling random actions. If None (default),
        self.random_action_prob is used, unless its value is forced to either 0.0
        or 1.0 by other arguments (sampling_method or no_random).
    no_random : bool
        If True, the samples will strictly be on-policy, that is
        - temperature = 1.0
        - random_action_prob = 0.0
        regardless of the values passed as arguments.
    times : dict
        Dictionary to store times. Currently not implemented.
    compute_reversed_logprobs : bool
        If True, reversed logprobs will be computed. Default is False. Reversed
        logprobs correspond to the reversed direction to sampling, i.e. if
        sampling is forwards, reversed logprobs are backward logprobs and
        vice versa. Reversed logprobs are computed on the current states of
        the envs and on the actions sampled (and added to the batch) before
        the current step.

    Returns
    -------
    actions : list of tuples
        The sampled actions, one for each environment in envs.
    logprobs : tensor
        Log probabilities of the sampled actions, as returned by the
        environment's get_logprobs().
    logprobs_rev : tensor or list of None
        If compute_reversed_logprobs is True, a tensor like logprobs holding the
        reversed log probabilities; entries of environments without a previous
        transition in the batch are 0. Otherwise, a list of None of the same
        length as actions.

    Raises
    ------
    ValueError
        If compute_reversed_logprobs is True but no batch is provided.
    """
    # Resolve temperature and probability of random actions
    if sampling_method == "random":
        assert (
            no_random is False
        ), "sampling_method random and no_random True is ambiguous"
        random_action_prob = 1.0
        temperature = 1.0
    elif no_random is True:
        temperature = 1.0
        random_action_prob = 0.0
    else:
        if temperature is None:
            temperature = self.temperature_logits
        if random_action_prob is None:
            random_action_prob = self.random_action_prob
    # Policy in the sampling direction and policy in the reversed direction
    if backward:
        model, model_rev = self.backward_policy, self.forward_policy
    else:
        model, model_rev = self.forward_policy, self.backward_policy
    if not isinstance(envs, list):
        envs = [envs]
    states = [env.state for env in envs]
    # Obtain masks of invalid actions
    mask_invalid_actions = self._get_masks(
        envs, batch, env_cond, backward, backward
    )
    # Get policy inputs from the states and obtain the policy outputs from the
    # model
    # TODO: get policy states from batch
    states_policy = tfloat(
        self.env.states2policy(states),
        device=self.device,
        float_type=self.float,
    )
    policy_outputs = model(states_policy)
    # Sample actions from policy outputs
    actions = self.env.sample_actions_batch(
        policy_outputs=policy_outputs,
        mask=mask_invalid_actions,
        states_from=states,
        is_backward=backward,
        random_action_prob=random_action_prob,
        temperature_logits=temperature,
    )
    # Compute logprobs from policy outputs
    logprobs = self.env.get_logprobs(
        policy_outputs=policy_outputs,
        actions=actions,
        mask=mask_invalid_actions,
        states_from=states,
        is_backward=backward,
    )
    if not compute_reversed_logprobs:
        return actions, logprobs, [None] * len(actions)

    # Reversed logprobs of the transitions added to the batch before this step
    if batch is None:
        raise ValueError(
            "A batch is required to compute reversed log probabilities."
        )
    logprobs_rev = torch.zeros_like(logprobs)
    indices_rev = batch.get_indices_of_previous_transitions(envs, backward)
    # Test against None explicitly: 0 is a valid index of a previous transition.
    has_rev = [idx is not None for idx in indices_rev]
    if not any(has_rev):
        return actions, logprobs, logprobs_rev
    actions_all = batch.get_actions()
    actions_rev = tfloat(
        [actions_all[idx] for idx in indices_rev if idx is not None],
        device=self.device,
        float_type=self.float,
    )
    envs_rev = [env for env, flag in zip(envs, has_rev) if flag]
    states_from_rev = [state for state, flag in zip(states, has_rev) if flag]
    is_rev = torch.tensor(has_rev, dtype=torch.bool, device=self.device)
    # Masks are computed only for the environments that have a previous
    # transition; None if masking of invalid actions is disabled.
    mask_invalid_actions_rev = self._get_masks(
        envs_rev, batch, env_cond, not backward, backward
    )
    # policy_outputs_rev already contains only the rows selected by is_rev
    policy_outputs_rev = model_rev(states_policy[is_rev])
    logprobs_rev[is_rev] = self.env.get_logprobs(
        policy_outputs=policy_outputs_rev,
        actions=actions_rev,
        mask=mask_invalid_actions_rev,
        states_from=states_from_rev,
        is_backward=not backward,
    )
    return actions, logprobs, logprobs_rev
def _get_masks(
    self,
    envs: List[GFlowNetEnv],
    batch: Optional[Batch] = None,
    env_cond: Optional[GFlowNetEnv] = None,
    is_backward_mask: Optional[bool] = False,
    is_backward_traj: Optional[bool] = False,
) -> Optional[torch.Tensor]:
    """
    Given a batch and/or a list of environments, obtains the mask of invalid
    actions of each environment's current state.

    Note that batch.get_item("mask_*") computes the mask if it is not available and
    stores it in the batch.

    If env_cond is not None, then the masks will be adjusted according to the
    restrictions imposed by the conditioning environment, env_cond (see
    GFlowNetEnv.mask_conditioning()).

    Parameters
    ----------
    envs : list of GFlowNetEnv or derived
        A list of instances of the environment.
    batch : Batch
        A batch from which to obtain the masks to avoid recomputing them. If
        None, the masks are computed by the environments.
    env_cond : GFlowNetEnv or derived
        An environment to do conditional sampling, that is restrict the action
        space via the masks of the main environments. Ignored if None.
    is_backward_mask : bool
        Whether the masks are of backward transitions (True) or forward transitions
        (False). False (forward) by default.
    is_backward_traj : bool
        Whether the trajectories in the batch are sampled backwards (True) or
        forward (False). False (forward) by default. Only used if batch is given.

    Returns
    -------
    A boolean tensor (built with tbool on self.device) with one mask of invalid
    actions per environment, or None if self.mask_invalid_actions is False.
    """
    if not self.mask_invalid_actions:
        return None
    if batch is not None:
        key = "mask_backward" if is_backward_mask else "mask_forward"
        masks = [
            batch.get_item(key, env, backward=is_backward_traj) for env in envs
        ]
    elif is_backward_mask:
        # Compute masks since a batch was not provided
        masks = [env.get_mask_invalid_actions_backward() for env in envs]
    else:
        masks = [env.get_mask_invalid_actions_forward() for env in envs]
    mask_invalid_actions = tbool(masks, device=self.device)
    # Mask conditioning
    if env_cond is not None:
        mask_invalid_actions = tbool(
            [
                env.mask_conditioning(mask, env_cond, is_backward_mask)
                for env, mask in zip(envs, mask_invalid_actions)
            ],
            device=self.device,
        )
    return mask_invalid_actions
def step(
    self,
    envs: List[GFlowNetEnv],
    actions: List[Tuple],
    backward: bool = False,
):
    """
    Executes the actions on the environments envs, one by one.

    This method simply calls env.step(action) or env.step_backwards(action) for
    each (env, action) pair, depending on the value of backward. The first
    element returned by each of those calls is discarded.

    Args
    ----
    envs : list of GFlowNetEnv or derived
        A list of instances of the environment. A single environment is wrapped
        in a list.
    actions : list
        A list of actions to be executed on each env of envs.
    backward : bool
        True if sampling is backward. False (forward) by default.

    Returns
    -------
    envs : list
        The same environment objects that were passed.
    actions : tuple
        The actions as returned by the step calls, one per environment.
    valids : tuple
        The validity flags returned by the step calls, one per environment.
    """
    # Wrap a single environment before checking lengths: len() of a single env
    # is not the number of environments.
    if not isinstance(envs, list):
        envs = [envs]
    assert len(envs) == len(actions)
    if not envs:
        return envs, (), ()
    results = [
        env.step_backwards(action) if backward else env.step(action)
        for env, action in zip(envs, actions)
    ]
    _, actions, valids = zip(*results)
    return envs, actions, valids
def get_env_instances(self, nb_env_instances):
    """
    Returns the requested number of instances of the environment.

    Instances are cached in self.env_cache and reused across calls. New
    instances are created with self.env_maker() only when the cache holds fewer
    than requested. The returned list is a new list (a slice of the cache), so
    callers may pop from it without modifying the cache. The environment
    objects themselves are shared between calls, so they must be reset before
    use.

    Args
    ----
    nb_env_instances : int
        Number of instances to return.

    Returns
    -------
    A list of environment instances.
    """
    # Create the cache lazily in case it has not been initialized yet.
    env_cache = getattr(self, "env_cache", None)
    if env_cache is None:
        env_cache = self.env_cache = []
    # Create new env instances if not enough exist in the cache.
    n_missing = nb_env_instances - len(env_cache)
    if n_missing > 0:
        env_cache.extend(self.env_maker() for _ in range(n_missing))
    # Return the requested instances
    return env_cache[:nb_env_instances]
# TODO: avoid computing gradients when not needed
def sample_batch(
    self,
    n_forward: int = 0,
    n_train: int = 0,
    n_replay: int = 0,
    env_cond: Optional[GFlowNetEnv] = None,
    train=True,
    progress=False,
    collect_forwards_masks=False,
    collect_backwards_masks=False,
):
    """
    Builds a batch of data by sampling online and/or offline trajectories.

    Three kinds of trajectories are sampled into separate sub-batches, which are
    then merged:
    - n_forward on-policy forward trajectories, from the source state until the
      environments are done;
    - n_train backward trajectories, starting from terminating states selected
      from the train buffer (only if self.buffer.train is not None);
    - n_replay backward trajectories, starting from terminating states selected
      from the replay buffer (only if it is not None and not empty). At most
      len(self.buffer.replay) replay trajectories are sampled.
    Backward trajectories end when the environment state equals its source.

    Parameters
    ----------
    n_forward : int
        Number of forward trajectories.
    n_train : int
        Number of backward trajectories from the train buffer.
    n_replay : int
        Number of backward trajectories from the replay buffer.
    env_cond : GFlowNetEnv, optional
        Conditioning environment, passed on to sample_actions().
    train : bool
        Passed to Batch.add_to_batch(). If False, actions are sampled strictly
        on-policy (no_random=True in sample_actions()).
    progress : bool
        Currently unused.
    collect_forwards_masks, collect_backwards_masks : bool
        Passed to the constructor of each sub-batch.

    Returns
    -------
    batch : Batch
        The merged batch.
    times : dict
        Time spent (in seconds) on the different stages of sampling.
    """
    # Obtain the necessary env instances (one per forward/train/replay trajectory)
    # WARNING : These instances must be reset before use.
    env_instances = self.get_env_instances(n_forward + n_train + n_replay)
    # PRELIMINARIES: Prepare Batch and environments
    times = {
        "all": 0.0,
        "forward_actions": 0.0,
        "train_actions": 0.0,
        "replay_actions": 0.0,
        "actions_envs": 0.0,
    }
    t0_all = time.time()
    batch = Batch(env=self.env, device=self.device, float_type=self.float)

    # ON-POLICY FORWARD trajectories
    t0_forward = time.time()
    envs = [env_instances.pop().reset(idx) for idx in range(n_forward)]
    batch_forward = self._new_sample_sub_batch(
        collect_forwards_masks, collect_backwards_masks
    )
    self._sample_trajectories_into(
        envs, batch_forward, env_cond, False, train, times
    )
    times["forward_actions"] = time.time() - t0_forward

    # TRAIN BACKWARD trajectories
    t0_train = time.time()
    batch_train = self._new_sample_sub_batch(
        collect_forwards_masks, collect_backwards_masks
    )
    if n_train > 0 and self.buffer.train is not None:
        envs = [env_instances.pop().reset(idx) for idx in range(n_train)]
        x_train = self.buffer.select(
            self.buffer.train, n_train, self.train_sampling, self.rng
        )["samples"].values.tolist()
        envs = self._set_terminating_states_on_envs(envs, x_train)
        self._sample_trajectories_into(
            envs, batch_train, env_cond, True, train, times
        )
    times["train_actions"] = time.time() - t0_train

    # REPLAY BACKWARD trajectories
    t0_replay = time.time()
    batch_replay = self._new_sample_sub_batch(
        collect_forwards_masks, collect_backwards_masks
    )
    if (
        n_replay > 0
        and self.buffer.replay is not None
        and len(self.buffer.replay) > 0
    ):
        # Clip before creating the envs, so that every env gets a replay state
        # instead of being left at the source state.
        n_replay = min(n_replay, len(self.buffer.replay))
        envs = [env_instances.pop().reset(idx) for idx in range(n_replay)]
        x_replay = self.buffer.select(
            self.buffer.replay,
            n_replay,
            self.replay_sampling,
            self.rng,
        )["samples"].values.tolist()
        envs = self._set_terminating_states_on_envs(envs, x_replay)
        self._sample_trajectories_into(
            envs, batch_replay, env_cond, True, train, times
        )
    times["replay_actions"] = time.time() - t0_replay

    # Merge forward and backward batches
    batch = batch.merge([batch_forward, batch_train, batch_replay])
    times["all"] = time.time() - t0_all
    return batch, times

def _new_sample_sub_batch(self, collect_forwards_masks, collect_backwards_masks):
    """Creates an empty Batch for one kind of trajectories of sample_batch()."""
    return Batch(
        env=self.env,
        proxy=self.proxy,
        device=self.device,
        float_type=self.float,
        collect_forwards_masks=collect_forwards_masks,
        collect_backwards_masks=collect_backwards_masks,
    )

@staticmethod
def _set_terminating_states_on_envs(envs, states):
    """Sets one state (as done) per env; returns only the envs that got a state."""
    envs = envs[: len(states)]
    for env, state in zip(envs, states):
        env.set_state(state, done=True)
    return envs

def _sample_trajectories_into(self, envs, batch, env_cond, backward, train, times):
    """Samples and steps all envs until their trajectories end, adding to batch."""
    while envs:
        # Sample actions
        t0_a_envs = time.time()
        actions, logprobs, logprobs_rev = self.sample_actions(
            envs,
            batch,
            env_cond,
            backward=backward,
            no_random=not train,
            times=times,
            compute_reversed_logprobs=self.collect_reversed_logprobs,
        )
        times["actions_envs"] += time.time() - t0_a_envs
        # Update environments with sampled actions
        envs, actions, valids = self.step(envs, actions, backward=backward)
        # Add to batch
        batch.add_to_batch(
            envs,
            actions,
            logprobs,
            logprobs_rev,
            valids,
            backward=backward,
            train=train,
        )
        # Filter out finished trajectories: forward ones end when the env is
        # done, backward ones when the env is back at the source state.
        if backward:
            envs = [env for env in envs if not env.equal(env.state, env.source)]
        else:
            envs = [env for env in envs if not env.done]
@torch.no_grad()
def estimate_logprobs_data(
    self,
    data: Union[List, str],
    n_trajectories: int = 1,
    max_iters_per_traj: int = 10,
    max_data_size: int = 1e5,
    batch_size: int = 100,
    bs_num_samples=10000,
):
    r"""
    Estimates the probability of sampling with current GFlowNet policy
    (self.forward_policy) the objects in a data set given by the argument data. The
    (log) probabilities are estimated by sampling a number of backward trajectories
    (n_trajectories) through importance sampling and calculating the forward
    probabilities of the trajectories.

    $$
    \log p_T(x) = \log \int_{\tau \to x} P_F(\tau) d\tau \\
    = \log \mathbb{E}_{P_B(\tau|x)} \frac{P_F(\tau)}{P_B(\tau|x)}\\
    \approx \log \frac{1}{N} \sum_{i=1}^{N} \frac{P_F(\tau_i)}{P_B(\tau_i|x)}\\
    = \log \sum_{i=1}^{N} \frac{P_F(\tau_i)}{P_B(\tau_i|x)} - \log N
    $$

    where $\tau_i \sim P_B(\tau|x)$ are the sampled backward trajectories.

    Note: torch.logsumexp is used to compute the log of the sum, in order to have
    numerical stability, since we have the log PF and log PB, instead of directly
    PF and PB.

    Note: the correct indexing of data points and trajectories is ensured by the
    fact that the indices of the environments are set in a consistent way with the
    indexing when storing the log probabilities.

    Args
    ----
    data : list or string
        A data set of terminating states. The data set may be passed directly as a
        list of states, or it may be a string defining the path to a pickled data
        set (.pkl) where the terminating states are stored in key "samples".
    n_trajectories : int
        The number of trajectories per object to sample for estimating the log
        probabilities.
    max_iters_per_traj : int
        The maximum number of attempts to sample a distinct trajectory, to avoid
        getting trapped in an infinite loop. Currently unused.
    max_data_size : int
        Maximum number of data points in the data set to avoid an accidental
        situation of having to sample too many backward trajectories. If necessary,
        the user should change this argument manually.
    batch_size : int
        Number of data points whose trajectories are sampled together.
    bs_num_samples : int
        Number of bootstrap resampling times for std estimation of logprobs_estimates.
        Doesn't require recomputing of log probabilities, so can be arbitrary large.

    Returns
    -------
    logprobs_estimates : torch.tensor
        The logarithm of the average ratio PF/PB over n trajectories sampled for
        each data point.
    logprobs_std : torch.tensor
        Bootstrap std of the logprobs_estimates.
    probs_std : torch.tensor
        Bootstrap std of the torch.exp(logprobs_estimates).

    Raises
    ------
    NotImplementedError
        If data is neither a list nor a path to a .pkl file.
    """
    times = {}
    # Determine terminating states
    if isinstance(data, list):
        states_term = data
    elif isinstance(data, str) and Path(data).suffix == ".pkl":
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(data, "rb") as f:
            data_dict = pickle.load(f)
        states_term = data_dict["samples"]
    else:
        raise NotImplementedError(
            "data must be either a list of states or a path to a .pkl file."
        )
    n_states = len(states_term)
    assert (
        n_states < max_data_size
    ), f"The size of the test data is larger than max_data_size ({max_data_size})."
    # Log probabilities of each (data point, trajectory) pair
    logprobs_f = torch.full(
        (n_states, n_trajectories),
        -torch.inf,
        dtype=self.float,
        device=self.device,
    )
    logprobs_b = torch.full(
        (n_states, n_trajectories),
        -torch.inf,
        dtype=self.float,
        device=self.device,
    )
    # Each (data point, trajectory) pair is encoded in a single env index as
    # mult_indices * state_idx + traj_idx and decoded below with // and %.
    # This is unambiguous because traj_idx < n_trajectories <= mult_indices.
    mult_indices = max(n_states, n_trajectories)
    pbar = tqdm(
        total=n_states,
        disable=self.logger.progressbar["skip"],
        leave=False,
        desc="Sampling backward actions from test data to estimate logprobs",
    )
    pbar2 = trange(
        min(batch_size, n_states) * n_trajectories,
        disable=self.logger.progressbar["skip"],
        leave=False,
        desc="Setting env terminal states",
    )
    try:
        for init_batch in range(0, n_states, batch_size):
            end_batch = min(init_batch + batch_size, n_states)
            batch = Batch(
                env=self.env,
                proxy=self.proxy,
                device=self.device,
                float_type=self.float,
            )
            # Obtain the necessary env instances: one per trajectory in the batch
            # WARNING : These instances must be reset before use.
            n_trajectories_batch = (end_batch - init_batch) * n_trajectories
            env_instances = self.get_env_instances(n_trajectories_batch)
            # For each data point and trajectory, set the state on an environment
            envs = []
            pbar2.reset(n_trajectories_batch)
            for state_idx in range(init_batch, end_batch):
                for traj_idx in range(n_trajectories):
                    idx = int(mult_indices * state_idx + traj_idx)
                    env = env_instances.pop().reset(idx)
                    env.set_state(states_term[state_idx], done=True)
                    envs.append(env)
                    pbar2.update(1)
            # Sample trajectories
            while envs:
                # Sample backward actions
                actions, logprobs, logprobs_rev = self.sample_actions(
                    envs,
                    batch,
                    backward=True,
                    no_random=True,
                    times=times,
                    compute_reversed_logprobs=True,
                )
                # Update environments with sampled actions
                envs, actions, valids = self.step(envs, actions, backward=True)
                # Add to batch
                batch.add_to_batch(
                    envs,
                    actions,
                    logprobs,
                    logprobs_rev,
                    valids,
                    backward=True,
                    train=True,
                )
                # Filter out finished trajectories
                envs = [env for env in envs if not env.equal(env.state, env.source)]
            # Prepare data structures to compute log probabilities
            traj_indices_batch = tlong(
                batch.get_unique_trajectory_indices(), device=self.device
            )
            data_indices = traj_indices_batch // mult_indices
            traj_indices = traj_indices_batch % mult_indices
            # Compute log probabilities of the trajectories
            logprobs_f[data_indices, traj_indices] = compute_logprobs_trajectories(
                batch, self.env, forward_policy=self.forward_policy, backward=False
            )
            logprobs_b[data_indices, traj_indices] = compute_logprobs_trajectories(
                batch, self.env, backward_policy=self.backward_policy, backward=True
            )
            # Advance by the number of data points just processed
            pbar.update(end_batch - init_batch)
    finally:
        pbar.close()
        pbar2.close()
    # Compute log of the average probabilities of the ratio PF / PB
    log_ratios = logprobs_f - logprobs_b
    log_n = torch.log(torch.tensor(n_trajectories, device=self.device))
    logprobs_estimates = torch.logsumexp(log_ratios, dim=1) - log_n
    logprobs_f_b_bs = bootstrap_samples(log_ratios, num_samples=bs_num_samples)
    logprobs_estimates_bs = torch.logsumexp(logprobs_f_b_bs, dim=1) - log_n
    logprobs_std = torch.std(logprobs_estimates_bs, dim=-1)
    probs_std = torch.std(torch.exp(logprobs_estimates_bs), dim=-1)
    return logprobs_estimates, logprobs_std, probs_std
def train(self):
    """
    Runs the training loop from iteration self.it up to self.n_train_steps
    (inclusive).

    At each iteration:
    - the evaluator is queried to evaluate and log;
    - self.sttr sub-batches are sampled and merged;
    - self.ttsr optimization steps are taken on the merged batch, each skipped
      if any loss is not finite;
    - the iteration is logged, with periodic garbage collection;
    - training stops early if the loss meets its early-stopping criterion.

    A final checkpoint is saved after the loop, and the logger is closed unless
    self.use_context is True.
    """
    # Start from the first iteration if no iteration counter has been set.
    if not hasattr(self, "it"):
        self.it = 1
    # Train loop
    pbar = tqdm(
        initial=self.it - 1,
        total=self.n_train_steps,
        disable=self.logger.progressbar["skip"],
    )
    for self.it in range(self.it, self.n_train_steps + 1):
        # Test and log
        if self.evaluator.should_eval(self.it):
            self.evaluator.eval_and_log(self.it)
        if self.evaluator.should_eval_top_k(self.it):
            self.evaluator.eval_and_log_top_k(self.it)
        t0_iter = time.time()
        batch = Batch(
            env=self.env,
            proxy=self.proxy,
            device=self.device,
            float_type=self.float,
        )
        for _ in range(self.sttr):
            sub_batch, times = self.sample_batch(
                n_forward=self.batch_size.forward,
                n_train=self.batch_size.backward_dataset,
                n_replay=self.batch_size.backward_replay,
                collect_forwards_masks=True,
                collect_backwards_masks=self.collect_backwards_masks,
            )
            batch.merge(sub_batch)
        for _ in range(self.ttsr):
            losses = self.loss.compute(batch, get_sublosses=True)
            # TODO: deal with this in a better way
            if not all(torch.isfinite(loss) for loss in losses.values()):
                if self.logger.debug:
                    print("Loss is not finite - skipping iteration")
            else:
                losses["all"].backward()
                if self.clip_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.parameters(), self.clip_grad_norm
                    )
                self.opt.step()
                self.lr_scheduler.step()
                self.opt.zero_grad()
        # Log training iteration: progress bar, buffer, metrics, intermediate
        # models
        times = self.log_train_iteration(pbar, losses, batch, times)
        # Log times
        times.update({"iter": time.time() - t0_iter})
        self.logger.log_time(times, use_context=self.use_context)
        # Garbage collection and cleanup GPU memory every
        # garbage_collection_period iterations
        if (
            self.garbage_collection_period > 0
            and self.it % self.garbage_collection_period == 0
        ):
            del batch
            gc.collect()
            torch.cuda.empty_cache()
        # Check early stopping
        if self.loss.do_early_stopping(losses["all"]):
            print(
                "Ending training after meeting early stopping criteria: "
                f"{self.loss.loss_ema} < {self.loss.early_stopping_th}"
            )
            break
    pbar.close()
    # Save final model
    self.logger.save_checkpoint(
        forward_policy=self.forward_policy,
        backward_policy=self.backward_policy,
        state_flow=self.state_flow,
        logZ=self.logZ,
        optimizer=self.opt,
        buffer=self.buffer,
        step=self.it,
        final=True,
    )
    # Close logger
    if self.use_context is False:
        self.logger.end()
@torch.no_grad()
def log_train_iteration(self, pbar: tqdm, losses: Dict, batch: Batch, times: dict):
    """
    Carries out the logging operations after the training iteration.

    The operations done by this method include:
    - Updating the main buffer (if self.buffer.use_main_buffer)
    - Updating the replay buffer
    - Logging the rewards and scores of the train batch
    - Logging the losses, logZ, learning rate and other metrics of the training
      process
    - Updating the progress bar
    - Save checkpoints

    Parameters
    ----------
    pbar : tqdm
        Progress bar object
    losses : dict
        Dictionary of losses after the training iteration. The total loss is
        expected under key "all"; when training metrics are logged, it is also
        copied to key "Loss" (the dict is modified in place).
    batch : Batch
        Training batch
    times : dict
        Dictionary of times. Updated in place with the keys "buffer", "log" and
        "save_interim_model".

    Returns
    -------
    times : dict
        The updated dictionary of times.
    """
    t0_buffer = time.time()
    states_term = batch.get_terminating_states(sort_by="trajectory")
    proxy_vals = batch.get_terminating_proxy_values(sort_by="trajectory")
    # The batch will typically have the log-rewards available, since they are
    # used to compute the losses. In order to avoid recalculating the proxy
    # values, the natural rewards are computed by taking the exponential of the
    # log-rewards. In case the rewards are available in the batch but not the
    # log-rewards, the latter are computed by taking the log of the rewards.
    # Numerical issues are not critical in this case, since the derived values
    # are only used for reporting purposes.
    # Availability is queried once, before fetching, so that fetching one kind of
    # reward cannot change which branch is taken below.
    has_rewards = batch.rewards_available(log=False)
    has_logrewards = batch.rewards_available(log=True)
    assert (
        has_rewards or has_logrewards
    ), "The batch has neither rewards nor log-rewards available."
    rewards = (
        batch.get_terminating_rewards(sort_by="trajectory") if has_rewards else None
    )
    logrewards = (
        batch.get_terminating_rewards(sort_by="trajectory", log=True)
        if has_logrewards
        else None
    )
    if rewards is None:
        rewards = torch.exp(logrewards)
    if logrewards is None:
        logrewards = torch.log(rewards)
    # Update main buffer
    actions_trajectories = batch.get_actions_trajectories()
    if self.buffer.use_main_buffer:
        self.buffer.add(
            states_term,
            actions_trajectories,
            rewards,
            self.it,
            buffer="main",
        )
    # Update replay buffer
    self.buffer.add(
        states_term,
        actions_trajectories,
        rewards,
        self.it,
        buffer="replay",
    )
    times.update({"buffer": time.time() - t0_buffer})
    ### Train logs
    t0_log = time.time()
    # TODO: consider moving this into separate method
    if self.evaluator.should_log_train(self.it):
        # logZ
        logz = self.logZ.sum() if self.logZ is not None else None
        # Trajectory length
        _, trajectory_lengths = torch.unique(
            batch.get_trajectory_indices(), return_counts=True
        )
        traj_length_mean = torch.mean(trajectory_lengths.to(self.float))
        traj_length_min = torch.min(trajectory_lengths)
        traj_length_max = torch.max(trajectory_lengths)
        # Learning rates. Copy the list so that padding it does not modify the
        # list returned by the scheduler.
        learning_rates = list(self.lr_scheduler.get_last_lr())
        if len(learning_rates) == 1:
            learning_rates.append(None)
        # Log train rewards and scores
        self.logger.log_rewards_and_scores(
            rewards,
            logrewards,
            proxy_vals,
            step=self.it,
            prefix="Train batch -",
            use_context=self.use_context,
        )
        # Log trajectory lengths, batch size, logZ and learning rates
        self.logger.log_metrics(
            metrics={
                "step": self.it,
                "Trajectory lengths mean": traj_length_mean,
                "Trajectory lengths min": traj_length_min,
                "Trajectory lengths max": traj_length_max,
                "Batch size": len(batch),
                "logZ": logz,
                "Learning rate": learning_rates[0],
                "Learning rate logZ": learning_rates[1],
            },
            step=self.it,
            use_context=self.use_context,
        )
        # Log losses
        losses["Loss"] = losses["all"]
        self.logger.log_metrics(
            metrics=losses,
            step=self.it,
            use_context=self.use_context,
        )
        # Log replay buffer rewards
        if self.buffer.replay_updated:
            rewards_replay = self.buffer.replay.rewards
            self.logger.log_rewards_and_scores(
                rewards_replay,
                np.log(rewards_replay),
                scores=None,
                step=self.it,
                prefix="Replay buffer -",
                use_context=self.use_context,
            )
    times.update({"log": time.time() - t0_log})
    # Progress bar
    self.logger.progressbar_update(
        pbar, losses["all"].item(), rewards.tolist(), self.jsd, self.use_context
    )
    # Save intermediate models
    t0_model = time.time()
    if self.evaluator.should_checkpoint(self.it):
        self.logger.save_checkpoint(
            forward_policy=self.forward_policy,
            backward_policy=self.backward_policy,
            state_flow=self.state_flow,
            logZ=self.logZ,
            optimizer=self.opt,
            buffer=self.buffer,
            step=self.it,
        )
    times.update({"save_interim_model": time.time() - t0_model})
    return times
[docs]
def get_sample_space_and_reward(self):
    """
    Return a representative set of terminating states and their rewards.

    Both values are computed lazily on the first call and cached on the agent
    as ``sample_space_batch`` and ``rewards_sample_space``; subsequent calls
    return the cached objects without recomputing them.

    The states are taken from ``env.get_all_terminating_states()`` when the
    environment defines that method, otherwise from
    ``env.get_grid_terminating_states(evaluator.config.n_grid)``. They are then
    converted with ``env.states2proxy`` before being cached, so the returned
    batch is in proxy format, not in environment state format.

    Returns
    -------
    sample_space_batch : tensor
        Representative terminating states of the environment (proxy format).
    rewards_sample_space : tensor
        Rewards of the states in sample_space_batch, as given by
        ``proxy.rewards``.

    Raises
    ------
    NotImplementedError
        If the environment implements neither get_all_terminating_states() nor
        get_grid_terminating_states().
    """
    if not hasattr(self, "sample_space_batch"):
        env = self.env
        if hasattr(env, "get_all_terminating_states"):
            self.sample_space_batch = env.get_all_terminating_states()
        elif hasattr(env, "get_grid_terminating_states"):
            n_grid = self.evaluator.config.n_grid
            self.sample_space_batch = env.get_grid_terminating_states(n_grid)
        else:
            raise NotImplementedError(
                "In order to obtain representative terminating states, the "
                "environment must implement either get_all_terminating_states() "
                "or get_grid_terminating_states()"
            )
        # Convert the cached states in place into the format the proxy expects
        self.sample_space_batch = env.states2proxy(self.sample_space_batch)
    if not hasattr(self, "rewards_sample_space"):
        self.rewards_sample_space = self.proxy.rewards(self.sample_space_batch)
    return self.sample_space_batch, self.rewards_sample_space
# TODO: implement other proposal distributions
# TODO: rethink whether it is needed to convert to reward
[docs]
def sample_from_reward(
    self,
    n_samples: int,
    proposal_distribution: str = "uniform",
    epsilon: float = 1e-4,
) -> List:
    """
    Draw samples from the reward distribution by rejection sampling.

    The proposal distribution is the uniform distribution over terminating
    states, obtained from ``env.get_uniform_terminating_states``. A proposed
    state x with reward R(x) is accepted with probability
    R(x) / (max_reward + epsilon), where max_reward is
    ``proxy.get_max_reward()``. Proposal rounds are repeated until
    ``n_samples`` states have been accepted. Note that if no proposed state can
    ever be accepted (e.g. all rewards are 0), the loop does not terminate.

    Parameters
    ----------
    n_samples : int
        The number of samples to draw from the reward distribution.
    proposal_distribution : str
        Identifier of the proposal distribution. Currently only `uniform` is
        implemented.
    epsilon : float
        Small value added to the maximum reward in the acceptance threshold.

    Returns
    -------
    samples_final : list
        The list of ``n_samples`` accepted samples, each taken as-is from the
        output of ``env.get_uniform_terminating_states`` (environment format).

    Raises
    ------
    NotImplementedError
        If ``proposal_distribution`` is not ``"uniform"``.
    """
    if proposal_distribution != "uniform":
        raise NotImplementedError("The proposal distribution must be uniform")
    samples_final = []
    max_reward = self.proxy.get_max_reward()
    while len(samples_final) < n_samples:
        # TODO: sample only the remaining number of samples
        samples_uniform = self.env.get_uniform_terminating_states(n_samples)
        # Flatten so that rewards of shape (n, 1) are compared element-wise with
        # the uniform draws instead of broadcasting to an (n, n) matrix.
        rewards = self.proxy.proxy2reward(
            self.proxy(self.env.states2proxy(samples_uniform))
        ).flatten()
        # u * (max_reward + epsilon) < R(x)  <=>  u < R(x) / (max_reward + epsilon)
        thresholds = torch.rand(
            rewards.shape[0], dtype=self.float, device=self.device
        ) * (max_reward + epsilon)
        # Convert the boolean acceptance mask into integer positions. Indexing
        # the samples with the mask's True/False values directly would select
        # elements 1 and 0 instead of the accepted samples.
        indices_accept = torch.nonzero(thresholds < rewards).flatten().tolist()
        n_remaining = n_samples - len(samples_final)
        samples_final.extend(
            samples_uniform[idx] for idx in indices_accept[:n_remaining]
        )
    return samples_final
[docs]
def load_checkpoint(self, checkpoint: dict):
    """
    Loads the content of a checkpoint dictionary into the corresponding variables
    of the GFlowNet agent.

    Model entries that are None in the checkpoint are skipped. logZ is moved to
    the agent's device and cast to the dtype of the agent's existing logZ
    parameter. The optimizer state is loaded only if the checkpoint contains one
    and the agent has an optimizer.

    Parameters
    ----------
    checkpoint : dict
        A dictionary containing the following keys:
        - "step": The iteration number of the checkpoint,
        - "forward": The state dict of the forward policy model,
        - "backward": The state dict of the backward policy model,
        - "state_flow": The state dict of the state flow model,
        - "logZ": The tensor containing the parameters of logZ,
        - "optimizer": The state dict of the optimizer,
        - "buffer": A dictionary with keys 'train', 'test' and 'replay', with
          the relative paths of the corresponding data sets (not read here).

    Raises
    ------
    AssertionError
        If the checkpoint contains a state for a component the agent cannot
        load it into (e.g. a forward state dict but no forward model).
    """
    # Iteration: resume from the step after the one the checkpoint was saved at
    self.it = checkpoint["step"] + 1
    # Forward model
    if checkpoint["forward"] is not None:
        assert self.forward_policy.is_model
        self.forward_policy.model.load_state_dict(checkpoint["forward"])
    # Backward model
    if checkpoint["backward"] is not None:
        assert self.backward_policy.is_model
        self.backward_policy.model.load_state_dict(checkpoint["backward"])
    # State flow model
    if checkpoint["state_flow"] is not None:
        assert self.state_flow
        self.state_flow.model.load_state_dict(checkpoint["state_flow"])
    # LogZ: cast to the existing parameter's dtype too, so that a checkpoint
    # saved with a different float precision does not silently change the
    # precision of logZ.
    if checkpoint["logZ"] is not None:
        assert isinstance(self.logZ, torch.nn.Parameter) and self.logZ.requires_grad
        self.logZ.data = checkpoint["logZ"].to(
            device=self.device, dtype=self.logZ.dtype
        )
    # Optimizer: nothing to load into (no trainable parameters) or nothing to load
    if checkpoint["optimizer"] is not None and self.opt is not None:
        self.opt.load_state_dict(checkpoint["optimizer"])
    if self.logger.debug:
        print("\nCheckpoint loaded into GFlowNet agent\n")
[docs]
def make_opt(params, logZ, config):
    """
    Set up the optimizer and its learning rate scheduler.

    Parameters
    ----------
    params : iterable of torch.nn.Parameter
        Parameters of the models to optimize. Any iterable (including a
        generator such as ``module.parameters()``) is accepted; it is
        materialized into a list.
    logZ : torch.nn.Parameter or None
        Parameter(s) of logZ. If not None, they are added to the optimizer as a
        separate parameter group with learning rate ``config.lr * config.lr_z_mult``.
    config
        Optimizer configuration, accessed by attribute. Reads ``method``
        ("adam" or "msgd"), ``lr``, ``lr_z_mult`` (when logZ is given),
        ``lr_decay_period``, ``lr_decay_gamma``, plus ``adam_beta1`` and
        ``adam_beta2`` for Adam or ``momentum`` for SGD with momentum.

    Returns
    -------
    opt : torch.optim.Optimizer or None
        The optimizer, or None if ``params`` is empty.
    lr_scheduler : torch.optim.lr_scheduler.StepLR or None
        Scheduler multiplying the learning rates by ``config.lr_decay_gamma``
        every ``config.lr_decay_period`` steps, or None if ``params`` is empty.

    Raises
    ------
    NotImplementedError
        If ``config.method`` is neither "adam" nor "msgd".
    """
    # Materialize so that generators can be checked for emptiness with len/bool
    params = list(params)
    if not params:
        # Always return a pair so that callers can unpack the result
        return None, None
    if config.method == "adam":
        opt = torch.optim.Adam(
            params,
            config.lr,
            betas=(config.adam_beta1, config.adam_beta2),
        )
    elif config.method == "msgd":
        opt = torch.optim.SGD(params, config.lr, momentum=config.momentum)
    else:
        raise NotImplementedError(
            f"Unknown optimizer method: {config.method}. "
            "Supported methods are 'adam' and 'msgd'."
        )
    # logZ gets its own parameter group with a scaled learning rate, for every
    # optimizer type, so that it is trained regardless of the chosen method.
    if logZ is not None:
        opt.add_param_group(
            {
                "params": logZ,
                "lr": config.lr * config.lr_z_mult,
            }
        )
    # Learning rate scheduling
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        opt,
        step_size=config.lr_decay_period,
        gamma=config.lr_decay_gamma,
    )
    return opt, lr_scheduler