gflownet.proxy.crystals.corners
Debugging proxy for Crystal-GFN
Classes
CrystalCorners – Initializes the CrystalCorners proxy according to the configuration passed as an argument.
Module Contents
- class gflownet.proxy.crystals.corners.CrystalCorners(mu=0.75, sigma=0.05, config=[], **kwargs)
Bases: gflownet.proxy.base.Proxy
Initializes the CrystalCorners proxy according to the configuration passed as an argument.
- Parameters:
mu (float) – Mean of the multivariate Gaussian distribution used to construct the default corners proxy, that is, the proxy used for states that do not meet any of the conditions specified in config. Note that mu is a single float because the mean is homogeneous: it takes the same value in every dimension.
sigma (float) – Standard deviation of the multivariate Gaussian distribution used to construct the default corners proxy. Note that sigma is a single float because the covariance matrix is diagonal, with the same value on every diagonal entry.
config (list) –
A list of dictionaries specifying the parameters of the Corners sub-proxies. Each element in the list is a dictionary that must contain one (and only one) of the following keys:
spacegroup: int
element: int
- Additionally, it must contain the following keys:
mu: float
sigma: float
mu and sigma are the mean and standard deviation of the Corners sub-proxy applied to states that have the space group, or contain the element, indicated in the dictionary.
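For illustration, a config with one space-group condition and one element condition could look like the list below. validate_config is a hypothetical helper, not part of gflownet, that only checks the structure described above; the integer values and the mu/sigma numbers are placeholders, and what the element integer refers to is not specified here.

```python
from typing import Any

# Hypothetical example config: one sub-proxy conditioned on a space group and
# one conditioned on an element. All numbers are illustrative placeholders.
EXAMPLE_CONFIG: list[dict[str, Any]] = [
    {"spacegroup": 225, "mu": 0.75, "sigma": 0.01},
    {"element": 8, "mu": 0.5, "sigma": 0.05},
]

CONDITION_KEYS = ("spacegroup", "element")
REQUIRED_KEYS = ("mu", "sigma")


def validate_config(config: list[dict[str, Any]]) -> None:
    """Hypothetical checker for the config structure described in the docstring.

    Raises ValueError unless every entry has exactly one condition key
    (spacegroup or element) with an int value, plus both mu and sigma.
    """
    for i, entry in enumerate(config):
        conditions = [key for key in CONDITION_KEYS if key in entry]
        if len(conditions) != 1:
            raise ValueError(
                f"config[{i}] must contain exactly one of {CONDITION_KEYS}, "
                f"found {conditions}"
            )
        if not isinstance(entry[conditions[0]], int):
            raise ValueError(f"config[{i}][{conditions[0]!r}] must be an int")
        missing = [key for key in REQUIRED_KEYS if key not in entry]
        if missing:
            raise ValueError(f"config[{i}] is missing keys {missing}")


validate_config(EXAMPLE_CONFIG)  # passes silently for a well-formed config
```

Checking the "one (and only one)" rule up front makes a malformed entry fail at construction time rather than silently matching no states later.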
- setup(env=None)
Sets the proxy's minimum and maximum lattice lengths from the LatticeParameters sub-environment of env.
These are needed to rescale the lattice lengths before passing them to the Corners proxy.
- Parameters:
env (Crystal) – A Crystal environment.
- lattice_lengths_to_corners_proxy(lp_lengths)
Converts a batch of lattice lengths in LatticeParameters proxy format into the format expected by the Corners proxy.
The lattice lengths in LatticeParameters proxy format are in angstroms.
The Corners proxy expects the states in the range [-1, 1].
- Parameters:
lp_lengths (tensor) – Batch of lattice lengths in LatticeParameters proxy format (angstroms).
- Returns:
tensor – Batch of re-scaled lattice lengths in the range [-1, 1].
- Return type:
torchtyping.TensorType[batch, 3]
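A minimal sketch of the rescaling, assuming it is the linear map that sends the [min_length, max_length] range stored by setup() onto [-1, 1]. The function name rescale_lengths and the use of nested Python lists instead of tensors are assumptions made to keep the example self-contained; the actual method operates on a [batch, 3] tensor.

```python
def rescale_lengths(lengths, min_length, max_length):
    """Hypothetical stand-in for lattice_lengths_to_corners_proxy.

    Assumes a linear map from [min_length, max_length] (angstroms) to [-1, 1],
    so min_length -> -1 and max_length -> 1. `lengths` is a list of
    (a, b, c) rows, one row per state in the batch.
    """
    if max_length <= min_length:
        raise ValueError("max_length must be greater than min_length")
    span = max_length - min_length
    return [[2.0 * (a - min_length) / span - 1.0 for a in row] for row in lengths]


# Two states, assuming min_length = 1.0 and max_length = 5.0 angstroms.
batch = [[1.0, 3.0, 5.0], [2.0, 2.0, 4.0]]
print(rescale_lengths(batch, 1.0, 5.0))  # [[-1.0, 0.0, 1.0], [-0.5, -0.5, 0.5]]
```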
- __call__(states)
Computes the proxy values of the CrystalCorners proxy for a batch of Crystal states.
Different Corners sub-proxies are applied depending on each state, according to the configuration passed at initialization, which sets conditions on the space group and on the presence of elements.
If a state meets more than one condition, the resulting score is the mean of the matching sub-proxies, that is, a weighted sum with coefficients equal to 1/N, where N is the number of conditions the state meets.
If a state does not meet any specific condition, the default proxy is used.
- Parameters:
states (torch.Tensor) – States to infer on. Shape: (batch, [6 + 1 + n_elements]).
- Returns:
torch.Tensor – Proxy scores. Shape: (batch,).
- Return type:
torchtyping.TensorType[batch]
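The selection and averaging rules described above can be sketched for a single state as follows. This is a hedged illustration, not the gflownet implementation. Several things are assumptions: the functional form of corners_score (an unnormalized isotropic Gaussian evaluated on |x|, so that every corner (±mu, ±mu, ±mu) receives the same peak), the names corners_score and crystal_corners_score, and the split of a state into a space group, a set of present elements and already rescaled lattice lengths. The docstring does not say how the 6 + 1 + n_elements columns are ordered.

```python
import math
from typing import Any


def corners_score(x: list[float], mu: float, sigma: float) -> float:
    """Hypothetical Corners sub-proxy: unnormalized isotropic Gaussian in |x|.

    Using abs() puts one peak at each corner (+-mu, ..., +-mu). sigma is
    treated as a standard deviation, so the covariance is sigma**2 * I.
    """
    sq_dist = sum((abs(xi) - mu) ** 2 for xi in x)
    return math.exp(-0.5 * sq_dist / sigma**2)


def crystal_corners_score(
    spacegroup: int,
    elements: set[int],
    lengths: list[float],
    config: list[dict[str, Any]],
    mu: float = 0.75,
    sigma: float = 0.05,
) -> float:
    """Average the sub-proxies whose condition the state meets (1/N weights).

    Falls back to the default (mu, sigma) proxy when no condition matches.
    """
    matching = [
        entry
        for entry in config
        # An absent key yields None, which never equals an int or is in the set.
        if entry.get("spacegroup") == spacegroup or entry.get("element") in elements
    ]
    if not matching:
        return corners_score(lengths, mu, sigma)
    scores = [corners_score(lengths, e["mu"], e["sigma"]) for e in matching]
    return sum(scores) / len(scores)


config = [
    {"spacegroup": 225, "mu": 0.5, "sigma": 0.1},
    {"element": 8, "mu": 0.9, "sigma": 0.1},
]
x = [0.5, -0.5, 0.5]  # lattice lengths already rescaled to [-1, 1]
print(crystal_corners_score(225, {8}, x, config))  # both conditions match, N = 2
print(crystal_corners_score(1, {26}, x, config))  # no match: default proxy
```

Averaging rather than summing keeps the score on the same scale as a single sub-proxy, regardless of how many conditions a state meets.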