Source code for gflownet.proxy.crystals.spacegroup

from pathlib import Path

import pandas as pd
import torch
from torchtyping import TensorType

from gflownet.proxy.base import Proxy

SPACE_GROUP_COUNTS = None


def _read_space_group_counts():
    """Return the space-group counts table, loading it from disk on first use.

    The table is read from ``spacegroups_limat_counts.csv`` located next to
    this module. It is cached in the module-level ``SPACE_GROUP_COUNTS``, so
    later calls return the same DataFrame instead of re-reading the file.

    Returns
    -------
    pandas.DataFrame
        The parsed CSV. ``SpaceGroup`` reads its ``"Space Group"`` and
        ``"Counts"`` columns.
    """
    global SPACE_GROUP_COUNTS
    if SPACE_GROUP_COUNTS is None:
        # Store the result: previously it was returned without assignment, so
        # the `global` cache was never populated and the CSV was re-read on
        # every call.
        SPACE_GROUP_COUNTS = pd.read_csv(
            Path(__file__).parent / "spacegroups_limat_counts.csv"
        )
    return SPACE_GROUP_COUNTS
class SpaceGroup(Proxy):
    """Proxy that scores space groups by their frequency in a reference table.

    Per-space-group counts come from ``_read_space_group_counts()``, using its
    ``"Space Group"`` and ``"Counts"`` columns. Calling the proxy on a batch of
    space-group indices returns the negated count of each one. When
    ``normalize`` is set, each count is also divided by the total count over
    all space groups.

    Parameters
    ----------
    normalize : bool
        If True, scores are ``-count / total_count``; otherwise ``-count``.
    **kwargs
        Forwarded to ``Proxy.__init__``. That call is assumed to set
        ``self.device``, which is used below; confirm against ``Proxy``.
    """

    def __init__(self, normalize: bool = True, **kwargs):
        super().__init__(**kwargs)
        df = _read_space_group_counts()
        # 231 slots so a space-group number can index the tensor directly.
        # Presumably the numbers run 1..230 and slot 0 is unused; confirm
        # against the CSV. Groups absent from the table keep a count of 0.
        # int64 rather than int16: a count above 32767 would overflow int16.
        self.counts = torch.zeros(231, device=self.device, dtype=torch.int64)
        # Convert the pandas columns explicitly rather than handing a Series
        # to torch for indexing / tensor construction.
        group_ids = torch.as_tensor(
            df["Space Group"].to_numpy(), device=self.device, dtype=torch.long
        )
        self.counts[group_ids] = torch.as_tensor(
            df["Counts"].to_numpy(), device=self.device, dtype=torch.int64
        )
        self.normalize = normalize
        # The norm is negative, so __call__ yields negated (optionally
        # normalized) counts.
        if self.normalize:
            self.norm = -1.0 * torch.sum(self.counts)
        else:
            self.norm = -1

    def __call__(self, states: TensorType["batch", "1"]) -> TensorType["batch"]:
        """Score a batch of space-group indices.

        Parameters
        ----------
        states : tensor of shape (batch, 1) or (batch,)
            Space-group numbers used to index ``self.counts``. They are cast
            to ``long`` before indexing.

        Returns
        -------
        tensor of shape (batch,)
            ``self.counts[index] / self.norm`` for each state.
        """
        indices = states
        # Drop only the trailing singleton dimension. torch.squeeze(states)
        # would also drop the batch dimension when batch == 1 and return a
        # 0-d tensor instead of shape (1,).
        if indices.dim() > 1:
            indices = indices.squeeze(-1)
        return self.counts[indices.long()] / self.norm