from typing import List, Union
import torch
from torchtyping import TensorType
from tqdm import tqdm
from gflownet.proxy.base import Proxy
from gflownet.utils.common import tfloat, tint, tlong
from gflownet.utils.scrabble.utils import read_alphabet, read_vocabulary
class ScrabbleScorer(Proxy):
    """
    Oracle to compute the Scrabble scores from words, that is the sum of the score of
    each letter in a sequence of letters (word).

    Attributes
    ----------
    vocabulary_check : bool
        Whether sequences that are not in the vocabulary are given a score of 0.
    alphabet_dict : dict
        Score of each token, as returned by ``read_alphabet()``. After ``setup()``
        has been called with an environment it is a private copy that also maps the
        padding token to 0.
    vocabulary_orig
        The collection of words returned by ``read_vocabulary()``. Words are
        compared against lower-cased input strings.
    pad_token, scores, vocabulary
        Only available after ``setup()`` has been called with an environment.
    """

    def __init__(self, vocabulary_check: bool = False, **kwargs):
        """
        Args
        ----
        vocabulary_check : bool
            If True, sequences that are not in the vocabulary are assigned a score
            of 0. If False (default), all sequences are scored.
        **kwargs
            Forwarded unchanged to the constructor of the base class.
        """
        # These attributes are deliberately set before calling the base constructor,
        # preserving the original initialisation order.
        self.vocabulary_check = vocabulary_check
        self.alphabet_dict = read_alphabet()
        self.vocabulary_orig = read_vocabulary()
        super().__init__(**kwargs)

    def setup(self, env=None):
        """
        Builds the attributes of the proxy that depend on the environment.

        Each attribute is built only the first time ``setup()`` is called with an
        environment, so that the scores tensor and the index-based vocabulary always
        refer to the same token indexing. Calling ``setup()`` without an environment
        does nothing.

        Args
        ----
        env : optional
            Environment from which the following are read: ``pad_token``,
            ``idx2token``, ``max_length`` and ``readable2state()``.
        """
        # Every step below reads from env, so there is nothing to do without it.
        # The same truthiness test as in the original guards is used.
        if not env:
            return

        # Add pad_token to alphabet dict
        if not hasattr(self, "pad_token"):
            # Make a copy before modifying because the dictionary is global
            self.alphabet_dict = self.alphabet_dict.copy()
            self.pad_token = env.pad_token
            self.alphabet_dict[self.pad_token] = 0

        # Build scores tensor: entry i holds the score of the token with index i.
        # idx2token is looked up by consecutive integer index rather than iterated,
        # so it works whether it is a list or an index -> token dictionary.
        if not hasattr(self, "scores"):
            scores = [
                self.alphabet_dict[env.idx2token[idx]]
                for idx in range(len(env.idx2token))
            ]
            self.scores = tlong(scores, device=self.device)

        # Build index-based version of the vocabulary as a tensor, one row per word.
        if not hasattr(self, "vocabulary"):
            # The tensor is only stored once complete, so that a failure half-way
            # does not leave a partially filled vocabulary behind.
            vocabulary = torch.zeros(
                (len(self.vocabulary_orig), env.max_length),
                dtype=torch.int16,
                device=self.device,
            )
            for idx, word in enumerate(self.vocabulary_orig):
                # readable2state() is given the upper-cased letters separated by
                # single spaces, e.g. "cat" -> "C A T".
                readable = " ".join(word.upper())
                vocabulary[idx] = tint(
                    env.readable2state(readable),
                    device=self.device,
                    int_type=torch.int16,
                )
            self.vocabulary = vocabulary

    def __call__(
        self, states: Union[List[str], TensorType["batch", "state_dim"]]
    ) -> TensorType["batch"]:
        """
        Computes and returns the Scrabble score of each sequence in a batch.

        In principle and in general, the input states is a tensor, where each state
        (row) is represented by the index of each token.

        However, for debugging purposes, this proxy also works if the input states is
        a list of:
            - Strings
            - List of string tokens

        If ``vocabulary_check`` is True, sequences that are not in the vocabulary get
        a score of 0.

        See: tests/gflownet/proxy/test_scrabble_proxy.py

        Args
        ----
        states : tensor or list
            If a tensor: A batch of states, where each row is a state and each state
            represents a sequence by the indices of the tokens, including the
            padding. This requires ``setup()`` to have been called with an
            environment.
            If a list: A batch of states, where each entry is either a string
            containing the word or a list of letters.

        Returns
        -------
        A vector with the score of each sequence in the batch.

        Raises
        ------
        NotImplementedError
            If states is neither a tensor nor a list.
        """
        if torch.is_tensor(states):
            output = torch.zeros(states.shape[0], device=self.device, dtype=self.float)
            if self.vocabulary_check:
                is_in_vocabulary = self._is_in_vocabulary(states)
            else:
                is_in_vocabulary = torch.ones_like(output, dtype=torch.bool)
            # Look up the score of every token index and add them up per sequence.
            # Rows that are not selected by the mask keep their initial score of 0.
            output[is_in_vocabulary] = tfloat(
                self.scores[states[is_in_vocabulary]].sum(dim=1),
                float_type=self.float,
                device=self.device,
            )
            return output

        if isinstance(states, list):
            scores = []
            for sample in states:
                if (
                    self.vocabulary_check
                    and self._unpad_and_string(sample) not in self.vocabulary_orig
                ):
                    scores.append(0.0)
                else:
                    scores.append(self._sum_scores(sample))
            return tfloat(scores, device=self.device, float_type=self.float)

        raise NotImplementedError(
            "The Scrabble proxy currently only supports input states as a tensor "
            "of indices or as list of strings containing a token each"
        )

    def _sum_scores(self, sample: list) -> int:
        """
        Returns the sum of the scores of the tokens in sample, which can be a list of
        tokens or a string (scored character by character).
        """
        return sum(self.alphabet_dict[token] for token in sample)

    def _is_in_vocabulary(
        self, states: TensorType["batch", "state_dim"]
    ) -> TensorType["batch"]:
        """
        Returns a boolean mask with one entry per state, which is True if the state
        matches any of the words in the index-based vocabulary.

        See: https://stackoverflow.com/a/77419829/6194082
        """
        # Broadcasting compares every state against every vocabulary row: all(-1)
        # requires every position to match, any(-1) requires at least one word.
        return (self.vocabulary == states.unsqueeze(1)).all(-1).any(-1)

    def _unpad_and_string(self, sample: list) -> str:
        """
        Returns sample as a lower-case string, after discarding everything from the
        first padding token onwards.

        If no padding token has been set yet, that is if ``setup()`` has not been
        called with an environment, the sample is converted as is.
        """
        pad_token = getattr(self, "pad_token", None)
        if pad_token is not None and pad_token in sample:
            sample = sample[: sample.index(pad_token)]
        return "".join(sample).lower()