gflownet.proxy.base
Base class of GFlowNet proxies
Attributes
Classes
Proxy – Base Proxy class for GFlowNet proxies.
Module Contents
- class gflownet.proxy.base.Proxy(device='cpu', float_precision=32, reward_function='identity', logreward_function=None, reward_function_kwargs={}, reward_min=0.0, do_clip_rewards=False, **kwargs)
Bases: abc.ABC
Base Proxy class for GFlowNet proxies.
The output of a proxy is the input to a reward function. Depending on the
reward_function, the reward may be the output of the proxy itself or a function of it.
- Parameters:
device (str or torch.device) – The device to be passed to torch tensors.
float_precision (int or torch.dtype) – The floating-point precision to be passed to torch tensors.
reward_function (str or Callable) – The transformation applied to the proxy outputs to obtain a GFlowNet reward. See Proxy._get_reward_functions().
logreward_function (Callable) – The transformation applied to the proxy outputs to obtain a GFlowNet log-reward. See Proxy._get_reward_functions(). If None (default), the logarithm of the reward function is used. Providing a dedicated Callable can improve the numerical stability of the transformation.
reward_function_kwargs (dict) – A dictionary of arguments to be passed to the reward function.
reward_min (float) – The minimum value allowed for rewards, 0.0 by default, which results in a minimum log-reward of LOGZERO. Note that certain loss functions, for example the Forward-Looking loss, may not work as desired if the minimum reward is 0.0. It may be set to a small positive value close to zero in order to prevent numerical issues.
do_clip_rewards (bool) – Whether to clip the rewards according to the minimum value.
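The following sketch illustrates how reward_function, reward_function_kwargs and logreward_function could fit together. It uses plain Python floats instead of torch tensors. Several details are assumptions and not the library's code. Only the name "identity" is taken from the documented default. The kwargs are assumed to be bound as keyword arguments. The exp_reward function and its beta parameter are hypothetical. The example also shows why a dedicated log-reward callable can be numerically safer than taking the log of the reward.

```python
import functools
import math
from typing import Callable, Optional, Union


def make_reward_function(
    reward_function: Union[str, Callable[..., float]],
    reward_function_kwargs: Optional[dict] = None,
) -> Callable[[float], float]:
    """Resolve a reward function from a name or a callable (illustrative only)."""
    kwargs = reward_function_kwargs or {}
    if isinstance(reward_function, str):
        # Only "identity" (the documented default) is handled in this sketch.
        if reward_function == "identity":
            return lambda x: x
        raise ValueError(f"Unknown reward function name: {reward_function!r}")
    # Assumption: extra arguments are bound as keyword arguments of the callable.
    return functools.partial(reward_function, **kwargs)


def make_logreward_function(
    reward_fn: Callable[[float], float],
    logreward_function: Optional[Callable[[float], float]] = None,
) -> Callable[[float], float]:
    """Return the log-reward transform, defaulting to log(reward_fn(x))."""
    if logreward_function is None:
        return lambda x: math.log(reward_fn(x))
    return logreward_function


# Hypothetical exponential reward with a scale parameter beta.
def exp_reward(x: float, beta: float = 1.0) -> float:
    return math.exp(beta * x)


reward_fn = make_reward_function(exp_reward, {"beta": 2.0})
naive_logreward = make_logreward_function(reward_fn)
# A closed-form log-reward avoids computing exp() and then log().
stable_logreward = make_logreward_function(reward_fn, lambda x: 2.0 * x)
```

Here naive_logreward(1000.0) raises OverflowError because math.exp(2000.0) overflows, while stable_logreward(1000.0) returns 2000.0.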
- abstract __call__(states)
Implement this method to call the get_reward method of the appropriate proxy class (EI, UCB, Proxy, Oracle, etc.).
- Parameters:
states (ndarray)
- Return type:
torchtyping.TensorType
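Because __call__ is abstract, a concrete proxy must override it before it can be instantiated. The sketch below shows that pattern with Python's abc module. ToyProxy is a hypothetical stand-in for the real base class. NegSquaredNormProxy is a hypothetical subclass. Both use plain lists where the library would return a torch tensor.

```python
import abc
from typing import List, Sequence


class ToyProxy(abc.ABC):
    """Hypothetical stand-in for gflownet.proxy.base.Proxy (not the real class)."""

    @abc.abstractmethod
    def __call__(self, states: Sequence[Sequence[float]]) -> List[float]:
        """Return one proxy value per state in the batch."""


class NegSquaredNormProxy(ToyProxy):
    """Hypothetical proxy: score each state by minus its squared Euclidean norm."""

    def __call__(self, states: Sequence[Sequence[float]]) -> List[float]:
        return [-sum(x * x for x in state) for state in states]
```

Calling ToyProxy() raises TypeError because __call__ is still abstract. NegSquaredNormProxy()([[1.0, 2.0], [0.0, 0.0]]) evaluates the whole batch at once.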
- rewards(states, log=False, return_proxy=False)
Computes the rewards of a batch of states.
The rewards are computed by first calling the proxy function, then transforming the proxy values according to the reward function.
- Parameters:
states (tensor or list or array) – A batch of states in proxy format.
log (bool) – If True, returns the logarithm of the rewards. If False (default), returns the natural rewards.
return_proxy (bool) – If True, returns the proxy values, alongside the rewards, as the second element in the returned tuple.
- Returns:
rewards (tensor) – The reward or log-reward of all elements in the batch.
proxy_values (tensor (optional)) – The proxy value of all elements in the batch. Included only if return_proxy is True.
- Return type:
Union[torchtyping.TensorType, Tuple[torchtyping.TensorType, torchtyping.TensorType]]
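The two-step flow of rewards() can be sketched as follows. First the proxy is evaluated on the batch. Then the proxy values are mapped through either the reward or the log-reward transform. When requested, the proxy values are returned as the second element of a tuple. This is a free function over plain lists, written for illustration. The proxy, reward_fn and logreward_fn arguments are hypothetical parameters standing in for the object's own state.

```python
from typing import Callable, List, Sequence, Tuple, Union


def rewards(
    proxy: Callable[[Sequence], List[float]],
    states: Sequence,
    reward_fn: Callable[[float], float],
    logreward_fn: Callable[[float], float],
    log: bool = False,
    return_proxy: bool = False,
) -> Union[List[float], Tuple[List[float], List[float]]]:
    # 1. Evaluate the proxy on the whole batch of states.
    proxy_values = proxy(states)
    # 2. Map proxy values to rewards or log-rewards.
    transform = logreward_fn if log else reward_fn
    out = [transform(v) for v in proxy_values]
    if return_proxy:
        # Proxy values go in the second element of the tuple.
        return out, proxy_values
    return out
```

For example, with a proxy that returns the length of each string, rewards(proxy, ["ab", "abcd"], lambda v: v, math.log, log=True, return_proxy=True) yields the log-rewards [log 2, log 4] together with the proxy values [2.0, 4.0].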
- proxy2reward(proxy_values)
Transform a tensor of proxy values into rewards.
If do_clip_rewards is True, rewards are clipped from below at self.reward_min.
- Parameters:
proxy_values (tensor) – The proxy values corresponding to a batch of states.
- Returns:
tensor – The reward of all elements in the batch.
- Return type:
torchtyping.TensorType
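A minimal sketch of the clipping behaviour, using plain lists. It assumes that clipping means raising every reward below reward_min up to reward_min, which matches reward_min being the minimum allowed reward. The library works on torch tensors instead.

```python
from typing import Callable, List, Sequence


def proxy2reward(
    proxy_values: Sequence[float],
    reward_fn: Callable[[float], float],
    reward_min: float = 0.0,
    do_clip_rewards: bool = False,
) -> List[float]:
    rewards = [reward_fn(v) for v in proxy_values]
    if do_clip_rewards:
        # Clip from below: rewards smaller than reward_min are raised to reward_min.
        rewards = [max(r, reward_min) for r in rewards]
    return rewards
```

With the identity reward and reward_min=0.1, the proxy values [-1.0, 0.5, 3.0] become [0.1, 0.5, 3.0] when clipping is enabled. They are returned unchanged otherwise.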
- proxy2logreward(proxy_values)
Transform a tensor of proxy values into log-rewards.
NaN values are set to self.logreward_min.
- Parameters:
proxy_values (tensor) – The proxy values corresponding to a batch of states.
- Returns:
tensor – The log-reward of all elements in the batch.
- Return type:
torchtyping.TensorType
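NaN log-rewards typically arise when a logarithm is applied to a negative value. The sketch below shows NaN entries being replaced by logreward_min. The torch_like_log helper is hypothetical. It mimics torch.log on scalars, because Python's math.log raises an exception on non-positive inputs instead of returning NaN or -inf. Only NaN is replaced here, as the description above states.

```python
import math
from typing import Callable, List, Sequence


def torch_like_log(x: float) -> float:
    """Mimic torch.log on scalars: -inf at 0, NaN for negative inputs."""
    if x > 0:
        return math.log(x)
    if x == 0:
        return float("-inf")
    return float("nan")


def proxy2logreward(
    proxy_values: Sequence[float],
    logreward_fn: Callable[[float], float],
    logreward_min: float,
) -> List[float]:
    logrewards = [logreward_fn(v) for v in proxy_values]
    # Replace NaN entries only; -inf is left untouched in this sketch.
    return [logreward_min if math.isnan(lr) else lr for lr in logrewards]
```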
- get_min_reward(log=False)
Returns the minimum value of the (log) reward, retrieved from self.reward_min and self.logreward_min.
- Parameters:
log (bool) – If True, returns the logarithm of the minimum reward. If False (default), returns the natural minimum reward.
- Returns:
float – The minimum (log) reward.
- Return type:
float
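The constructor documentation states that the default reward_min of 0.0 gives a minimum log-reward of LOGZERO. The sketch below shows that relationship. The LOGZERO value used here is a hypothetical placeholder, since the library's actual constant is not given in this documentation. Rejecting negative minima is also this sketch's own choice.

```python
import math

# Hypothetical placeholder: the library defines its own LOGZERO constant.
LOGZERO = -1e3


def get_min_reward(reward_min: float = 0.0, log: bool = False) -> float:
    if reward_min < 0:
        # This sketch rejects negative minima; the library's behaviour may differ.
        raise ValueError("reward_min must be non-negative")
    if not log:
        return reward_min
    # log(0) is undefined, so a zero minimum maps to the LOGZERO sentinel.
    return math.log(reward_min) if reward_min > 0 else LOGZERO
```

A small positive reward_min such as 1e-3 gives a finite minimum log-reward of log(1e-3) and avoids the sentinel entirely.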
- get_max_reward(log=False)
Returns the maximum value of the (log) reward, retrieved from self.optimum, if it is defined.
- Parameters:
log (bool) – If True, returns the logarithm of the maximum reward. If False (default), returns the natural maximum reward.
- Returns:
float – The maximum (log) reward.
- Return type:
float
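A rough sketch of deriving the maximum reward from the optimum. It assumes optimum is the proxy value that yields the highest reward, so the maximum reward is the reward function evaluated at that value. What the library does when optimum is undefined is not documented here, so raising an error is a hypothetical choice.

```python
import math
from typing import Callable, Optional


def get_max_reward(
    optimum: Optional[float],
    reward_fn: Callable[[float], float],
    log: bool = False,
) -> float:
    # Assumption: optimum is the proxy value that yields the highest reward.
    if optimum is None:
        # Hypothetical choice for this sketch when no optimum is defined.
        raise ValueError("optimum is not defined for this proxy")
    max_reward = reward_fn(optimum)
    return math.log(max_reward) if log else max_reward
```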