# Source code for gflownet.proxy.crystals.density

import warnings

import numpy as np
import pandas as pd
import torch
from torchtyping import TensorType

from gflownet.envs.crystals.composition import N_ELEMENTS_ORACLE
from gflownet.envs.crystals.crystal import Crystal
from gflownet.proxy.base import Proxy
from gflownet.utils.common import tfloat, tlong
from gflownet.utils.crystals.constants import ATOMIC_MASS

# Multiplicative factor converting a density expressed in g/mol per cubic
# angstrom into g/cm3: 1e24 A^3 per cm^3 divided by Avogadro's number
# (6.022e23 per mol) simplifies to 10 / 6.022.
DENSITY_CONVERSION = 10 / 6.022
class Density(Proxy):
    """
    Proxy to compute the density of a crystal, in g/cm3.

    It requires the same inputs as the Dave proxy.
    """

    def __init__(self, **kwargs):
        """
        Initializes the proxy by forwarding all keyword arguments to the base
        Proxy constructor.
        """
        super().__init__(**kwargs)

    def setup(self, env=None):
        """
        Builds the vector of atomic masses used to compute the total mass.

        The vector has ``N_ELEMENTS_ORACLE + 1`` entries, all zero except at
        the indices of the elements of the composition sub-environment, which
        hold the corresponding atomic mass from ``ATOMIC_MASS``.

        Args:
            env: the environment. If it is not a ``Crystal`` instance, a
                warning is issued and ``self.atomic_mass`` is not set.
        """
        if not isinstance(env, Crystal):
            warnings.warn(
                "Attempted to setup Density proxy without passing the right "
                "Crystal env type (continuous crystal stack)"
            )
            return

        elements = env.subenvs[env.idx_composition].elements
        # Allocate on self.device: both the index tensor and the assigned
        # values below live on self.device, and the vector is later multiplied
        # with the states in __call__, so a CPU-only allocation would fail on
        # any other device.
        self.atomic_mass = torch.zeros(
            N_ELEMENTS_ORACLE + 1, dtype=self.float, device=self.device
        )
        self.atomic_mass[tlong(elements, device=self.device)] = tfloat(
            [ATOMIC_MASS[n] for n in elements],
            float_type=self.float,
            device=self.device,
        )

    @torch.no_grad()
    def __call__(
        self, states: TensorType["batch", "policy_input_dim"]
    ) -> TensorType["batch"]:
        """
        Computes the density of a batch of crystals.

        ``setup()`` must have been called with a ``Crystal`` environment
        beforehand, since this method reads ``self.atomic_mass``.

        Args:
            states (torch.Tensor): same as DAVE proxy, i.e.
                * composition: ``states[:, :-7]`` -> length 95 (dummy 0 then 94
                    elements)
                * space group: ``states[:, -7] - 1``
                * lattice parameters: ``states[:, -6:]`` -> lengths
                    ``(a, b, c)`` followed by angles
                    ``(alpha, beta, gamma)`` in degrees

        Returns:
            torch.Tensor: density in g/cm3. Shape: ``(batch,)``.
        """
        # Composition part of the state weighted by the atomic masses.
        total_mass = torch.matmul(states[:, :-7], self.atomic_mass)

        a, b, c = states[:, -6], states[:, -5], states[:, -4]
        # Angles are converted from degrees to radians before taking cosines.
        cos_alpha = torch.cos(torch.deg2rad(states[:, -3]))
        cos_beta = torch.cos(torch.deg2rad(states[:, -2]))
        cos_gamma = torch.cos(torch.deg2rad(states[:, -1]))

        # Volume of the parallelepiped spanned by the lattice vectors:
        # V = abc * sqrt(1 - cos^2(alpha) - cos^2(beta) - cos^2(gamma)
        #                + 2 cos(alpha) cos(beta) cos(gamma))
        volume = (a * b * c) * torch.sqrt(
            1
            - (cos_alpha.pow(2) + cos_beta.pow(2) + cos_gamma.pow(2))
            + (2 * cos_alpha * cos_beta * cos_gamma)
        )

        return (total_mass / volume) * DENSITY_CONVERSION