gflownet.proxy.scrabble

Classes

ScrabbleScorer

Base Proxy class for GFlowNet proxies.

Module Contents

class gflownet.proxy.scrabble.ScrabbleScorer(vocabulary_check=False, **kwargs)[source]

Bases: gflownet.proxy.base.Proxy

Base Proxy class for GFlowNet proxies.

A proxy is the input to a reward function. Depending on the reward_function, the reward may be directly the output of the proxy or a function of it.

Parameters:
  • device (str or torch.device) – The device to be passed to torch tensors.

  • float_precision (int or torch.dtype) – The floating point precision to be passed to torch tensors.

  • reward_function (str or Callable) – The transformation applied to the proxy outputs to obtain a GFlowNet reward. See Proxy._get_reward_functions().

  • logreward_function (Callable) – The transformation applied to the proxy outputs to obtain a GFlowNet log reward. See Proxy._get_reward_functions(). If None (default), the log of the reward function is used. The Callable may be used to improve the numerical stability of the transformation.

  • reward_function_kwargs (dict) – A dictionary of arguments to be passed to the reward function.

  • reward_min (float) – The minimum value allowed for rewards, 0.0 by default, which results in a minimum log reward of LOGZERO. Note that certain loss functions, for example the Forward Looking loss, may not work as desired if the minimum reward is 0.0. It may be set to a small positive value close to zero to prevent numerical stability issues.

  • do_clip_rewards (bool) – Whether to clip the rewards according to the minimum value.

  • vocabulary_check (bool)
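
The interplay of reward_function, reward_min and do_clip_rewards described above can be illustrated with a minimal, dependency-free sketch. It is not the library's implementation: the helper names to_reward and to_logreward, the use of plain Python lists instead of torch tensors, the clipping by a simple maximum, and the numeric value chosen for LOGZERO are all assumptions made for illustration.

```python
import math

# Hypothetical stand-in for the library's LOGZERO constant; the real value may differ.
LOGZERO = -1e3


def to_reward(proxy_values, reward_function=lambda x: x,
              reward_min=0.0, do_clip_rewards=True):
    """Apply a reward function to proxy outputs, then optionally clip from below."""
    rewards = [reward_function(v) for v in proxy_values]
    if do_clip_rewards:
        # Assumed clipping rule: no reward may fall below reward_min.
        rewards = [max(r, reward_min) for r in rewards]
    return rewards


def to_logreward(rewards):
    """Take the log of each reward, mapping non-positive rewards to LOGZERO."""
    return [math.log(r) if r > 0.0 else LOGZERO for r in rewards]


# With reward_min=0.0, a zero reward maps to LOGZERO.
rewards_default = to_reward([5.0, -1.0, 0.0])
logrewards_default = to_logreward(rewards_default)

# With a small positive reward_min, the log reward stays finite and well defined.
rewards_small_min = to_reward([5.0, -1.0, 0.0], reward_min=1e-8)
logrewards_small_min = to_logreward(rewards_small_min)
```

In this sketch, setting reward_min to a small positive value replaces the LOGZERO sentinel with an ordinary logarithm, which is the situation the reward_min description recommends for losses that are sensitive to zero rewards.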

vocabulary_check = False[source]
alphabet_dict = None[source]
vocabulary_orig = None[source]
setup(env=None)[source]
__call__(states)[source]

Computes and returns the Scrabble score of each sequence in a batch.

In general, the input states is a tensor, where each state (row) represents a sequence by the indices of its tokens.

However, for debugging purposes, this proxy also works if the input states is a list of:

  • Strings

  • Lists of string tokens

See: tests/gflownet/proxy/test_scrabble_proxy.py

Parameters:

states (tensor or list) – If a tensor: A batch of states, where each row is a state and each state represents a sequence by the indices of its tokens, including the padding. If a list: A batch of states, where each entry is either a string containing the word or a list of letters.

Returns:

A vector with the score of each sequence in the batch.

Return type:

torchtyping.TensorType[batch]
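
The three accepted input forms can be illustrated with a small standalone sketch. It does not call ScrabbleScorer and makes several assumptions: it uses the standard English Scrabble letter values (the proxy's own alphabet_dict may differ), a hypothetical index encoding in which 0 is padding and 1 to 26 map to A to Z, and plain lists of integers in place of a torch tensor. The names LETTER_SCORES, INDEX_TO_LETTER, score_tokens and score_batch are illustrative only.

```python
# Standard English Scrabble letter values (an assumption; the proxy's table may differ).
LETTER_SCORES = {
    **dict.fromkeys("AEILNORSTU", 1),
    **dict.fromkeys("DG", 2),
    **dict.fromkeys("BCMP", 3),
    **dict.fromkeys("FHVWY", 4),
    "K": 5,
    **dict.fromkeys("JX", 8),
    **dict.fromkeys("QZ", 10),
}

# Hypothetical index encoding: 0 is padding, 1..26 map to A..Z.
PAD = 0
INDEX_TO_LETTER = {i + 1: chr(ord("A") + i) for i in range(26)}


def score_tokens(tokens):
    """Sum the letter values of a sequence of single-letter tokens."""
    return sum(LETTER_SCORES.get(token.upper(), 0) for token in tokens)


def score_batch(states):
    """Score a batch whose entries are strings, lists of letters, or lists of indices."""
    scores = []
    for state in states:
        if isinstance(state, str):
            tokens = list(state)
        elif state and isinstance(state[0], int):
            # Index representation: drop padding, then map indices to letters.
            tokens = [INDEX_TO_LETTER[i] for i in state if i != PAD]
        else:
            tokens = list(state)
        scores.append(score_tokens(tokens))
    return scores


batch_scores = score_batch(["CAT", ["D", "O", "G"], [3, 1, 20, 0, 0]])
```

Under these assumptions the padded index row [3, 1, 20, 0, 0] decodes to the same letters as the string "CAT", so both entries receive the same score.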