import os
import random
from copy import deepcopy
from functools import partial
from os.path import expandvars
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
import torch
from hydra import compose, initialize_config_dir
from hydra.utils import get_original_cwd, instantiate
from omegaconf import DictConfig, OmegaConf
from torchtyping import TensorType
from gflownet.utils.policy import parse_policy_config
[docs]
def set_device(device: Union[str, torch.device]):
"""
Get `torch` device from device.
Examples
--------
>>> set_device("cuda")
device(type='cuda')
>>> set_device("cpu")
device(type='cpu')
>>> set_device(torch.device("cuda"))
device(type='cuda')
Parameters
----------
device : Union[str, torch.device]
Device.
Returns
-------
torch.device
`torch` device.
"""
if isinstance(device, torch.device):
return device
if device.lower() == "cuda" and torch.cuda.is_available():
return torch.device("cuda")
else:
return torch.device("cpu")
[docs]
def set_float_precision(precision: Union[int, torch.dtype]):
"""
Get `torch` float type from precision.
Examples
--------
>>> set_float_precision(32)
torch.float32
>>> set_float_precision(torch.float32)
torch.float32
Parameters
----------
precision : Union[int, torch.dtype]
Precision.
Returns
-------
torch.dtype
`torch` float type.
Raises
------
ValueError
If precision is not one of [16, 32, 64].
"""
if isinstance(precision, torch.dtype):
return precision
if precision == 16:
return torch.float16
elif precision == 32:
return torch.float32
elif precision == 64:
return torch.float64
else:
raise ValueError("Precision must be one of [16, 32, 64]")
[docs]
def set_int_precision(precision: Union[int, torch.dtype]):
"""
Get `torch` integer type from `int` precision.
Examples
--------
>>> set_int_precision(32)
torch.int32
>>> set_int_precision(torch.int32)
torch.int32
Parameters
----------
precision : Union[int, torch.dtype]
Integer precision.
Returns
-------
torch.dtype
`torch` integer type.
Raises
------
ValueError
If precision is not one of [16, 32, 64].
"""
if isinstance(precision, torch.dtype):
return precision
if precision == 16:
return torch.int16
elif precision == 32:
return torch.int32
elif precision == 64:
return torch.int64
else:
raise ValueError("Precision must be one of [16, 32, 64]")
[docs]
def torch2np(x):
"""
Convert a torch tensor to a numpy array.
Parameters
----------
x : Union[torch.Tensor, np.ndarray, list]
Data to be converted.
Returns
-------
np.ndarray
Converted data.
"""
if hasattr(x, "is_cuda") and x.is_cuda:
x = x.detach().cpu()
return np.array(x)
[docs]
def download_file_if_not_exists(path: str, url: str):
"""
Download a file from google drive if path doestn't exist.
url should be in the format: https://drive.google.com/uc?id=FILE_ID
"""
import gdown
path = Path(path)
if not path.is_absolute():
# to avoid storing downloaded files with the logs, prefix is set to the original working dir
prefix = get_original_cwd()
path = Path(prefix) / path
if not path.exists():
path.absolute().parent.mkdir(parents=True, exist_ok=True)
gdown.download(url, str(path.absolute()), quiet=False)
return path
[docs]
def resolve_path(path: str) -> Path:
"""
Resolve a path by expanding environment variables, user home directory, and making
it absolute.
Examples
--------
>>> resolve_path("~/scratch/$SLURM_JOB_ID/data")
Path("/home/user/scratch/12345/data")
Parameters
----------
path : Union[str, Path]
Path to be resolved.
Returns
-------
Path
Resolved path.
"""
return Path(expandvars(str(path))).expanduser().resolve()
[docs]
def find_latest_checkpoint(ckpt_dir):
"""
Find the latest checkpoint in the input directory.
If the directory contains a checkpoint file with the name "final", that checkpoint
is returned. Otherwise, the latest checkpoint is returned based on the iteration
number set in the file names.
Parameters
----------
ckpt_dir : Union[str, Path]
Directory in which to search for the checkpoints.
Returns
-------
Path
Path to the latest checkpoint.
Raises
------
ValueError
If no checkpoint files are found in the input directory.
"""
ckpt_dir = Path(ckpt_dir)
final = [f for f in ckpt_dir.glob(f"*final*")]
if len(final) > 0:
return final[0]
ckpts = [f for f in ckpt_dir.glob(f"iter_*")]
if not ckpts:
raise ValueError(
f"No checkpoints found in {ckpt_dir} with pattern iter_* or *final*"
)
return sorted(ckpts, key=lambda f: float(f.stem.split("iter_")[1]))[-1]
[docs]
def read_hydra_config(rundir=None, config_name="config"):
if rundir is None:
rundir = Path(config_name)
hydra_dir = rundir.parent
config_name = rundir.name
else:
hydra_dir = rundir / ".hydra"
with initialize_config_dir(
version_base=None, config_dir=str(hydra_dir), job_name="xxx"
):
return compose(config_name=config_name)
[docs]
def gflownet_from_config(config, env=None):
"""
Create GFlowNet from a Hydra OmegaConf config.
Parameters
----------
config : DictConfig
Config.
env : GFlowNetEnv
Optional environment instance to be used in the initialization.
Returns
-------
GFN
GFlowNet.
"""
# Logger
logger = instantiate(config.logger, config, _recursive_=False)
# The proxy is required by the GFlowNetAgent for computing rewards
proxy = instantiate(
config.proxy,
device=config.device,
float_precision=config.float_precision,
)
# Using Hydra's partial instantiation, see:
# https://hydra.cc/docs/advanced/instantiate_objects/overview/#partial-instantiation
# If env is passed as an argument, we create an env maker with a partial
# instantiation from the copy method of the environment (this is used in unit
# tests, for example). Otherwise, we create the env maker with partial
# instantiation from the config.
if env is not None:
env_maker = partial(env.copy)
else:
env_maker = instantiate(
config.env,
device=config.device,
float_precision=config.float_precision,
_partial_=True,
)
env = env_maker()
# TOREVISE: set up proxy so when buffer calls it (when it creates train / test
# dataset) it has the correct infro from env
# proxy.setup(env)
buffer = instantiate(
config.buffer,
env=env,
proxy=proxy,
datadir=logger.datadir,
)
# The evaluator is used to compute metrics and plots
evaluator = instantiate(config.evaluator)
# The policy is used to model the probability of a forward/backward action
forward_config = parse_policy_config(config, kind="forward")
backward_config = parse_policy_config(config, kind="backward")
forward_policy = instantiate(
forward_config,
env=env,
device=config.device,
float_precision=config.float_precision,
)
backward_policy = instantiate(
backward_config,
env=env,
device=config.device,
float_precision=config.float_precision,
base=forward_policy,
)
# State flow
if config.gflownet.state_flow is not None:
state_flow = instantiate(
config.gflownet.state_flow,
env=env,
device=config.device,
float_precision=config.float_precision,
base=forward_policy,
)
else:
state_flow = None
# Loss
loss = instantiate(
config.loss,
forward_policy=forward_policy,
backward_policy=backward_policy,
state_flow=state_flow,
device=config.device,
float_precision=config.float_precision,
)
# GFlowNet Agent
gflownet = instantiate(
config.gflownet,
device=config.device,
float_precision=config.float_precision,
env_maker=env_maker,
proxy=proxy,
loss=loss,
forward_policy=forward_policy,
backward_policy=backward_policy,
state_flow=state_flow,
buffer=buffer,
logger=logger,
evaluator=evaluator,
)
return gflownet
[docs]
def load_gflownet_from_rundir(
rundir,
no_wandb=True,
print_config=False,
device=None,
load_last_checkpoint=True,
is_resumed: bool = False,
):
"""
Load GFlowNet from a run path (directory with a `.hydra` directory inside).
Parameters
----------
rundir : Union[str, Path]
Path to the run directory. Must contain a `.hydra` directory.
no_wandb : bool, optional
Whether to disable wandb in the GFN init, by default True.
print_config : bool, optional
Whether to print the loaded config, by default False.
device : str, optional
Device to which the models should be moved. If None (default), take the device
from the loaded config.
load_last_checkpoint : bool, optional
Whether to load the final models, by default True.
is_resumed : bool, optional
Whether the GFlowNet is loaded to resume training.
Returns
-------
Tuple[GFN, DictConfig]
Loaded GFlowNet and the loaded config.
Raises
------
ValueError
If no checkpoints are found in the directory.
"""
rundir = resolve_path(rundir)
# Read experiment config
config = OmegaConf.load(Path(rundir) / ".hydra" / "config.yaml")
# Resolve variables
config = OmegaConf.to_container(config, resolve=True)
# Re-create OmegaCong DictConfig
config = OmegaConf.create(config)
if print_config:
print(OmegaConf.to_yaml(config))
# Device
if device is None:
device = config.device
if no_wandb:
# Disable wandb
config.logger.do.online = False
# -----------------------------------------
# ----- Load last model checkpoints -----
# -----------------------------------------
if load_last_checkpoint:
checkpoint_latest = find_latest_checkpoint(rundir / config.logger.logdir.ckpts)
checkpoint = torch.load(checkpoint_latest, map_location=set_device(device))
# Set run id in logger to enable WandB resume
config.logger.run_id = checkpoint["run_id"]
# Set up Buffer configuration to load data sets and buffers from run
if checkpoint["buffer"]["train"]:
config.buffer.train = {
"type": "pkl",
"path": checkpoint["buffer"]["train"],
}
if checkpoint["buffer"]["test"]:
config.buffer.test = {
"type": "pkl",
"path": checkpoint["buffer"]["test"],
}
if checkpoint["buffer"]["replay"]:
config.buffer.replay_buffer = checkpoint["buffer"]["replay"]
# load them here
if is_resumed:
config.logger.logdir.root = rundir
config.logger.is_resumed = True
# Initialize a GFlowNet agent from the configuration
gflownet = gflownet_from_config(config)
# Load checkpoint into the GFlowNet agent
if load_last_checkpoint:
gflownet.load_checkpoint(checkpoint)
return gflownet, config
[docs]
def batch_with_rest(start, stop, step, tensor=False):
"""
Yields batches of indices from start to stop with step size. The last batch may be
smaller than step.
Parameters
----------
start : int
Start index
stop : int
End index (exclusive)
step : int
Step size
tensor : bool, optional
Whether to return a `torch` tensor of indices instead of a `numpy` array, by
default False.
Yields
------
Union[np.ndarray, torch.Tensor]
Batch of indices
"""
for i in range(start, stop, step):
if tensor:
yield torch.arange(i, min(i + step, stop))
else:
yield np.arange(i, min(i + step, stop))
[docs]
def tfloat(x, device, float_type):
"""
Convert input to a float tensor. If the input is a list of tensors, the tensors
are stacked along the first dimension.
The resulting tensor is moved to the specified device.
Parameters
----------
x : Union[List[torch.Tensor], torch.Tensor, List[Union[int, float]], Union[int,
float]]
Input to be converted to a float tensor.
device : torch.device
Device to which the tensor should be moved.
float_type : torch.dtype
Float type to which the tensor should be converted.
Returns
-------
Union[torch.Tensor, List[torch.Tensor]]
Float tensor.
"""
if isinstance(x, list) and torch.is_tensor(x[0]):
return torch.stack(x).to(device=device, dtype=float_type)
if torch.is_tensor(x):
return x.to(device=device, dtype=float_type)
else:
return torch.tensor(x, dtype=float_type, device=device)
[docs]
def tlong(x, device):
"""
Convert input to a long tensor. If the input is a list of tensors, the tensors
are stacked along the first dimension.
The resulting tensor is moved to the specified device.
Parameters
----------
x : Union[List[torch.Tensor], torch.Tensor, List[Union[int, float]], Union[int,
float]]
Input to be converted to a long tensor.
device : torch.device
Device to which the tensor should be moved.
Returns
-------
Union[torch.Tensor, List[torch.Tensor]]
Long tensor.
"""
if isinstance(x, list) and torch.is_tensor(x[0]):
return torch.stack(x).to(device=device, dtype=torch.long)
if torch.is_tensor(x):
return x.to(device=device, dtype=torch.long)
else:
return torch.tensor(x, dtype=torch.long, device=device)
[docs]
def tint(x, device, int_type):
"""
Convert input to an integer tensor. If the input is a list of tensors, the tensors
are stacked along the first dimension.
The resulting tensor is moved to the specified device.
Parameters
----------
x : Union[List[torch.Tensor], torch.Tensor, List[Union[int, float]], Union[int,
float]]
Input to be converted to an integer tensor.
device : torch.device
Device to which the tensor should be moved.
int_type : torch.dtype
Integer type to which the tensor should be converted.
Returns
-------
Union[torch.Tensor, List[torch.Tensor]]
Integer tensor.
"""
if isinstance(x, list) and torch.is_tensor(x[0]):
return torch.stack(x).to(device=device, dtype=int_type)
if torch.is_tensor(x):
return x.to(device=device, dtype=int_type)
else:
return torch.tensor(x, dtype=int_type, device=device)
[docs]
def tbool(x, device):
"""
Convert input to a boolean tensor. If the input is a list of tensors, the tensors
are stacked along the first dimension.
The resulting tensor is moved to the specified device.
Parameters
----------
x : Union[List[torch.Tensor], torch.Tensor, List[Union[int, float]], Union[int,
float]]
Input to be converted to a boolean tensor.
device : torch.device
Device to which the tensor should be moved.
Returns
-------
Union[torch.Tensor, List[torch.Tensor]]
Boolean tensor.
"""
if isinstance(x, list) and torch.is_tensor(x[0]):
return torch.stack(x).to(device=device, dtype=torch.bool)
if torch.is_tensor(x):
return x.to(device=device, dtype=torch.bool)
else:
return torch.tensor(x, dtype=torch.bool, device=device)
[docs]
def concat_items(list_of_items, indices=None):
"""
Concatenates a list of items into a single tensor or array.
Parameters
----------
list_of_items :
List of items to be concatenated, i.e. list of arrays or list of tensors.
indices : Union[List[np.ndarray], List[torch.Tensor]], optional
Indices to select in the resulting concatenated tensor or array, by default
None.
Returns
-------
Union[np.ndarray, torch.Tensor]
Concatenated tensor or array, with optional selection of indices.
Raises
------
NotImplementedError
If the input type is not supported, i.e., not a list of arrays or a list of
tensors.
"""
if isinstance(list_of_items[0], np.ndarray):
result = np.concatenate(list_of_items)
if indices is not None:
if torch.is_tensor(indices[0]):
indices = indices.cpu().numpy()
result = result[indices]
elif torch.is_tensor(list_of_items[0]):
result = torch.cat(list_of_items)
if indices is not None:
result = result[indices]
else:
raise NotImplementedError(
"cannot concatenate {}".format(type(list_of_items[0]))
)
return result
[docs]
def extend(
orig: Union[List, TensorType["..."]], new: Union[List, TensorType["..."]]
) -> Union[List, TensorType["..."]]:
"""
Extends the original list or tensor with the new list or tensor.
Returns
-------
Union[List, TensorType["..."]]
Extended list or tensor.
Raises
------
NotImplementedError
If the input type is not supported, i.e., not a list or a tensor.
"""
assert isinstance(orig, type(new))
if isinstance(orig, list):
orig.extend(new)
elif torch.tensor(orig):
orig = torch.cat([orig, new])
else:
raise NotImplementedError(
"Extension only supported for lists and torch tensors"
)
return orig
[docs]
def copy(x: Union[List, TensorType["..."]]):
"""
Makes copy of the input tensor or list.
A tensor is cloned and detached from the computational graph.
Parameters
----------
x : Union[List, TensorType["..."]]
Input tensor or list to be copied.
Returns
-------
Union[List, TensorType["..."]]
Copy of the input tensor or list.
"""
if torch.is_tensor(x):
return x.clone().detach()
else:
return deepcopy(x)
[docs]
def bootstrap_samples(tensor, num_samples):
"""
Bootstraps tensor along the last dimention
returns tensor of the shape [initial_shape, num_samples]
"""
dim_size = tensor.size(-1)
bs_indices = torch.randint(
0, dim_size, size=(num_samples * dim_size,), device=tensor.device
)
bs_samples = torch.index_select(tensor, -1, index=bs_indices)
bs_samples = bs_samples.view(
tensor.size()[:-1] + (num_samples, dim_size)
).transpose(-1, -2)
return bs_samples
[docs]
def example_documented_function(arg1, arg2):
r"""Summary line: this function is not used anywhere, it's just an example.
Extended description of function from the docstrings tutorial :ref:`write
docstrings-extended`.
Refer to
* functions with :py:func:`gflownet.utils.common.set_device`
* classes with :py:class:`gflownet.gflownet.GFlowNetAgent`
* methods with :py:meth:`gflownet.envs.base.GFlowNetEnv.get_action_space`
* constants with :py:const:`gflownet.envs.base.CMAP`
Prepenend with ``~`` to refer to the name of the object only instead of the full
path -> :py:func:`~gflownet.utils.common.set_device` will display as ``set_device``
instead of the full path.
Great maths:
.. math::
\int_0^1 x^2 dx = \frac{1}{3}
.. important::
A docstring with **math** MUST be a raw Python string (a string prepended with
an ``r``: ``r"raw"``) to avoid backslashes being treated as escape characters.
Alternatively, you can use double backslashes.
.. warning::
Display a warning. See :ref:`learn by example`. (<-- this is a cross reference,
learn about it `here
<https://www.sphinx-doc.org/en/master/usage/referencing.html#ref-rolel>`_)
Examples
--------
>>> function(1, 'a')
True
>>> function(1, 2)
True
>>> function(1, 1)
Traceback (most recent call last):
...
Notes
-----
This block uses ``$ ... $`` for inline maths -> $e^{\frac{x}{2}}$.
Or ``$$ ... $$`` for block math instead of the ``.. math:`` directive above.
$$\int_0^1 x^2 dx = \frac{1}{3}$$
Parameters
----------
arg1 : int
Description of arg1
arg2 : str
Description of arg2
Returns
-------
bool
Description of return value
"""
if arg1 == arg2:
raise ValueError("arg1 must not be equal to arg2")
return True
[docs]
def select_indices(
iterable: Union[List, Tuple, TensorType, npt.NDArray],
indices: Optional[Union[List, Tuple, TensorType, npt.NDArray]] = None,
):
"""
Select elements form iterable
Parameters
----------
iterable: list, tuple, tensor or np.ndarray
An iterable to select elements from. It can have multiple dimensions
and selection is always preformed over the first dimension
indices: list, tuple, tensor or np.ndarray
1-dimentional sequence of indecies for selecting elements, optional. If None,
the iterable will be returned as is. Default None
Returns
-------
list, tuple, tensor or np.ndarray
A sequence of selected elements. The type of the returned sequence is
the same as the type of the input iterable
"""
if indices is None:
return iterable
if isinstance(iterable, (list, tuple)):
result = [iterable[idx] for idx in indices]
if isinstance(iterable, tuple):
result = tuple(result)
return result
elif torch.is_tensor(iterable) or isinstance(iterable, np.ndarray):
if isinstance(indices, tuple):
indices = list(indices)
return iterable[indices]
else:
raise Exception(f"Cannot select elements from {type(iterable)}")