Source code for gflownet.proxy.uniform

from typing import List, Union

import torch
from torchtyping import TensorType

from gflownet.proxy.base import Proxy


[docs] class Uniform(Proxy): def __init__(self, **kwargs): super().__init__(**kwargs) self._optimum = torch.tensor(1.0, device=self.device, dtype=self.float)
[docs] def __call__( self, states: Union[List, TensorType["batch", "state_dim"]] ) -> TensorType["batch"]: return torch.ones(len(states), device=self.device, dtype=self.float)