Source code for gflownet.proxy.torus

import torch
from torchtyping import TensorType

from gflownet.proxy.base import Proxy


class Torus(Proxy):
    """
    Proxy that scores a batch of states with a cubed sum of sines and cosines.

    For states ``x`` of shape (batch, state_dim) the score of each row is

        (sum(sin(alpha * x[:, 0::2])) + sum(cos(beta * x[:, 1::2])) + state_dim) ** 3

    divided by ``(2 * n_dim) ** 3`` when ``normalize`` is truthy.
    """

    def __init__(self, normalize, alpha=1.0, beta=1.0, **kwargs):
        """
        args:
            normalize: if truthy, scores are divided by ``(2 * n_dim) ** 3`` and
                the reported optimum is 1.0; otherwise scores are left unscaled
                and the reported optimum is ``(2 * n_dim) ** 3``.
            alpha: factor applied to the even-indexed state columns before the
                sine is taken.
            beta: factor applied to the odd-indexed state columns before the
                cosine is taken.
            **kwargs: forwarded unchanged to ``Proxy.__init__``.
        """
        super().__init__(**kwargs)
        self.normalize = normalize
        self.alpha = alpha
        self.beta = beta

    def setup(self, env=None):
        """
        Reads the number of dimensions from the environment.

        args:
            env: object exposing an ``n_dim`` attribute, or None to leave the
                proxy untouched.
        """
        # Compare against None explicitly so that an environment object that
        # happens to evaluate as falsy (e.g. defines __len__ or __bool__) is
        # not silently ignored.
        if env is None:
            return
        self.n_dim = env.n_dim
        # optimum and norm are cached on first access and both depend on
        # n_dim; drop them so a repeated setup() does not serve stale values.
        for cached in ("_optimum", "_norm"):
            if hasattr(self, cached):
                delattr(self, cached)

    @property
    def optimum(self):
        """
        Scalar tensor reported as the optimum of the proxy: 1.0 when normalizing,
        ``(2 * n_dim) ** 3`` otherwise. Computed on first access and cached.

        Requires ``n_dim`` to be set (see ``setup``) unless normalizing.
        NOTE(review): relies on ``self.device`` and ``self.float``, which are
        presumably provided by the ``Proxy`` base class - confirm.
        """
        if not hasattr(self, "_optimum"):
            value = 1.0 if self.normalize else (self.n_dim * 2) ** 3
            self._optimum = torch.tensor(
                value, device=self.device, dtype=self.float
            )
        return self._optimum

    @property
    def norm(self):
        """
        Scalar tensor the raw score is divided by: ``(2 * n_dim) ** 3`` when
        normalizing, 1.0 otherwise. Computed on first access and cached.

        Requires ``n_dim`` to be set (see ``setup``) when normalizing.
        """
        if not hasattr(self, "_norm"):
            value = (self.n_dim * 2) ** 3 if self.normalize else 1.0
            self._norm = torch.tensor(value, device=self.device, dtype=self.float)
        return self._norm

    def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]:
        """
        args:
            states: tensor with one state per row

        returns:
            tensor with one score per state
        """
        # Even-indexed columns go through the sine, odd-indexed ones through
        # the cosine; each is reduced over the state dimension.
        sin_sum = torch.sum(torch.sin(self.alpha * states[:, 0::2]), dim=1)
        cos_sum = torch.sum(torch.cos(self.beta * states[:, 1::2]), dim=1)
        # Each sine/cosine term is at least -1 and there are state_dim of them
        # in total, so adding state_dim keeps the base non-negative before
        # cubing.
        return (1.0 / self.norm) * (sin_sum + cos_sum + states.shape[1]) ** 3