"""Debugging proxy for Crystal-GFN"""
from typing import Dict, List, Optional

import numpy as np
import torch
from torchtyping import TensorType

from gflownet.proxy.base import Proxy
from gflownet.proxy.box.corners import Corners
[docs]
class CrystalCorners(Proxy):
    """
    A synthetic proxy which resembles the Corners proxy in the lattice parameters
    domain. It places different corners (with varying mean and standard deviations for
    the Gaussians) depending on the space group and composition.

    Specifically, a different proxy than the default one can be used for states that
    contain a particular space group or a particular element.

    For states that simultaneously meet more than one condition, the resulting score is
    the weighted sum of the multiple proxies. This will result in a mixture of corners
    with homogeneous coefficients.

    Attributes
    ----------
    proxies : list
        List of dictionaries, one per entry of the configuration. Each dictionary
        is a shallow copy of the corresponding configuration entry, extended with
        the key 'proxy', which holds the Corners sub-proxy built from the entry's
        'mu' and 'sigma'.
    proxy_default : Corners
        Default proxy to be used for the states not included in the conditions
        specified in the configuration.
    min_length : as provided by the environment
        Minimum lattice length, read from ``env.lattice_parameters`` in
        :py:meth:`setup`. Only available after ``setup`` is called with an
        environment.
    max_length : as provided by the environment
        Maximum lattice length, read from ``env.lattice_parameters`` in
        :py:meth:`setup`. Only available after ``setup`` is called with an
        environment.
    """

    def __init__(
        self,
        mu: float = 0.75,
        sigma: float = 0.05,
        config: Optional[List[Dict]] = None,
        **kwargs,
    ):
        """
        Initializes the CrystalCorners proxy according to the configuration passed as
        an argument.

        Parameters
        ----------
        mu : float
            Mean of the multivariate Gaussian distribution used to construct the
            default corners proxy, that is the proxy used for states not included in
            the conditions specified in the config. Note that mu is a single float
            because the mean is homogeneous.
        sigma : float
            Standard deviation of the multivariate Gaussian distribution used to
            construct the default corners proxy, that is the proxy used for states
            not included in the conditions specified in the config. Note that sigma
            is a single float because the covariance matrix is diagonal.
        config : list, optional
            A list of dictionaries specifying the parameters of the Corners
            sub-proxies. If None (default), no sub-proxies are created and the
            default proxy is used for all states. Each element in the list is a
            dictionary that must contain one (and only one) of the following keys:
                - spacegroup: int
                - element: int
            Additionally, it must contain the following keys:
                - mu: float
                - sigma: float
            mu and sigma are the mean and standard deviation for a Corners sub-proxy
            to be applied for states that have the spacegroup or element indicated
            in the configuration. The dictionaries passed by the caller are not
            modified.
        **kwargs
            Forwarded both to the parent Proxy and to every Corners (sub-)proxy.

        Raises
        ------
        ValueError
            If any of the elements of config is not valid according to
            :py:meth:`_dict_is_valid`.
        """
        super().__init__(**kwargs)
        # None is used as default instead of a (shared) mutable list
        if config is None:
            config = []
        # Check whether the config list is invalid
        if not all(self._dict_is_valid(el) for el in config):
            raise ValueError("Configuration is not valid")
        # Initialize special case proxies. Shallow copies of the configuration
        # dictionaries are stored so that the dictionaries of the caller are not
        # altered by the addition of the 'proxy' key.
        self.proxies = [dict(conf) for conf in config]
        for conf in self.proxies:
            conf["proxy"] = Corners(
                n_dim=3, mu=conf["mu"], sigma=conf["sigma"], **kwargs
            )
            conf["proxy"].setup()
        # Initialize default proxy
        self.proxy_default = Corners(n_dim=3, mu=mu, sigma=sigma, **kwargs)
        self.proxy_default.setup()

    def setup(self, env=None):
        """
        Sets the minimum length and maximum length of the LatticeParameters
        sub-environment.

        This is needed to be able to rescale the lattice lengths before passing them to
        the Corners proxy.

        Parameters
        ----------
        env : Crystal
            A Crystal environment. If None, nothing is done.
        """
        if env is not None:
            self.min_length = env.lattice_parameters.min_length
            self.max_length = env.lattice_parameters.max_length

    def lattice_lengths_to_corners_proxy(
        self, lp_lengths: TensorType["batch", "3"]
    ) -> TensorType["batch", "3"]:
        """
        Converts a batch of lattice lengths in LatticeParameters proxy format into the
        format expected by the Corners proxy.

        The lattice lengths in LatticeParameters proxy format are in angstroms.
        The corners proxy expects the states in the range [-1, 1].

        The attributes min_length and max_length must have been set by calling
        :py:meth:`setup` with an environment before using this method.

        Parameters
        ----------
        lp_lengths : tensor
            Batch of lattice lengths in LatticeParameters proxy format (angstroms).

        Returns
        -------
        tensor
            Batch of re-scaled lattice lengths in the range [-1, 1].
        """
        # Linear map: min_length -> -1 and max_length -> 1
        return -1.0 + ((lp_lengths - self.min_length) * 2.0) / (
            self.max_length - self.min_length
        )

    @torch.no_grad()
    def __call__(self, states: TensorType["batch", "102"]) -> TensorType["batch"]:
        """
        Builds the proxy values of the CornersProxy for a batch Crystal states.

        Different Corners proxies are applied depending on the states and according to
        the configuration passed at initialization, which sets conditions on the values
        of the space group and the presence of elements.

        If a state meets more than one condition, the resulting score is a weighted sum
        of the multiple proxies, with coefficients equal to 1/N, where N is the number
        of proxies.

        If a state does not meet any specific condition, the default proxy is used.

        Parameters
        ----------
        states : torch.Tensor
            States to infer on. Shape: ``(batch, [6 + 1 + n_elements])``.

        Returns
        -------
        torch.Tensor
            Proxy scores. Shape: ``(batch,)``.

        Raises
        ------
        ValueError
            If a sub-proxy contains neither the key 'spacegroup' nor 'element'.
        """
        # The last 6 columns are the lattice parameters, the column before them is
        # the space group and all the preceding columns are the composition.
        comp = states[:, :-7]
        sg = states[:, -7]
        lat_params = states[:, -6:]
        # Only the first three lattice parameters (the lengths) are scored
        lp_lengths = self.lattice_lengths_to_corners_proxy(lat_params[:, :3])
        # Apply the corresponding proxy for each state in the batch.
        # The accumulators are allocated on the device of the states so that they
        # can be indexed with the boolean masks computed from the states.
        n_states = states.shape[0]
        device = states.device
        scores = torch.zeros(n_states, dtype=self.float, device=device)
        coefficients = torch.zeros(n_states, dtype=self.float, device=device)
        default = torch.ones(n_states, dtype=torch.bool, device=device)
        for proxy in self.proxies:
            if "spacegroup" in proxy:
                indices = sg == proxy["spacegroup"]
            elif "element" in proxy:
                indices = comp[:, proxy["element"]] > 0
            else:
                raise ValueError("Configuration is not valid")
            # The proxy values are added to the current scores (which are
            # initialized at zero)
            scores[indices] = scores[indices] + proxy["proxy"](lp_lengths[indices])
            # Count how many sub-proxies contributed to each state
            coefficients[indices] += 1
            default[indices] = False
        # Divide scores by the coefficients (only states matched at least once, so
        # the coefficients are never zero here)
        scores[~default] = scores[~default] / coefficients[~default]
        # Apply default proxy
        scores[default] = self.proxy_default(lp_lengths[default])
        return scores

    @staticmethod
    def _dict_is_valid(config: Dict):
        """
        Checks whether a dictionary of configuration is valid.

        To be valid, the following conditions must be satisfied:
            - There is a key 'mu' containing a float number.
            - There is a key 'sigma' containing a float number.
            - There is a key 'spacegroup' or 'element', but only one of the two,
              containing an int number.

        Parameters
        ----------
        config : dict
            A dictionary of configuration of a Corners sub-proxy.

        Returns
        -------
        bool
            True if the dictionary is valid, False otherwise.
        """
        if "mu" not in config or not isinstance(config["mu"], float):
            return False
        if "sigma" not in config or not isinstance(config["sigma"], float):
            return False
        has_spacegroup = "spacegroup" in config
        has_element = "element" in config
        # Exactly one of the two conditions must be present
        if has_spacegroup == has_element:
            return False
        if has_spacegroup and not isinstance(config["spacegroup"], int):
            return False
        if has_element and not isinstance(config["element"], int):
            return False
        return True