gflownet.proxy.crystals.dave
Attributes
RELEASE: URL to the proxy's code repository. It is used to provide a link to the appropriate release in case of version mismatch.
Classes
DAVE: Wrapper class around the Dave (Divya-Alexandre-Victor) proxy.
Module Contents
- gflownet.proxy.crystals.dave.RELEASE = '2.0.6'
URL to the proxy's code repository. It is used to provide a link to the appropriate release in case of a version mismatch between the requested and the installed release of dave.
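As a minimal sketch of how a pinned release string could drive such a mismatch hint, the function below compares an installed version against RELEASE and builds a link to the requested release. The helper name release_mismatch_message, the repo_url argument, and the GitHub-style /releases/tag/<version> link layout are assumptions for illustration; this page does not show the actual check or the repository URL.

```python
from typing import Optional

RELEASE = "2.0.6"  # release value documented for gflownet.proxy.crystals.dave


def release_mismatch_message(installed: str, repo_url: str,
                             requested: str = RELEASE) -> Optional[str]:
    """Hypothetical helper: return a hint with a release link, or None if versions match."""
    if installed == requested:
        return None
    # Assumption: GitHub-style tag URLs of the form <repo_url>/releases/tag/<version>.
    link = f"{repo_url.rstrip('/')}/releases/tag/{requested}"
    return (
        f"Installed dave release {installed} does not match the requested "
        f"release {requested}. See {link}"
    )
```

For example, release_mismatch_message("2.0.5", "https://example.org/dave") returns a message pointing at https://example.org/dave/releases/tag/2.0.6, while passing "2.0.6" returns None.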
- class gflownet.proxy.crystals.dave.DAVE(ckpt_path=str, **kwargs)
Bases: gflownet.proxy.base.Proxy
Wrapper class around the Dave (Divya-Alexandre-Victor) proxy.
- Parameters:
ckpt_path (str) – Path to a directory containing the checkpoint of a pre-trained model.
- __call__(states)
Forward pass of the proxy.
The proxy will decompose the state as:
* composition: states[:, :-7] -> length 95 (dummy 0, then 94 elements)
* space group: states[:, -7] - 1
* lattice parameters: states[:, -6:]
>>> composition MUST be a vector of element COUNTS indexed by ATOMIC NUMBER, prepended with a 0.
>>> dummy padding value at comp[0] MUST be 0.
i.e. -> comp[i] -> number of atoms of element Z=i
i.e. -> LiO2 -> [0, 0, 0, 1, 0, 0, 0, 0, 2, 0, ...] (Li is Z=3, O is Z=8), up until Z=94 for the MatBench proxy
i.e. -> len(comp) = 95 (0, then 94 elements)
>>> sg MUST be a list of ACTUAL space group numbers (1-230); the proxy subtracts 1 itself.
>>> lat_params MUST be a list of lattice parameters in the following order: [a, b, c, alpha, beta, gamma] as floats.
>>> the states tensor MUST already be on the device.
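The layout above can be sketched in plain Python for a single state, assuming the 95-entry composition vector (dummy 0, then Z=1..94). encode_state, decompose_state and N_ELEMENTS are hypothetical names, not part of gflownet; the real proxy slices a batched torch.Tensor as states[:, ...] that is already on the device.

```python
from typing import Dict, List, Sequence

N_ELEMENTS = 94  # elements Z=1..94 for the MatBench proxy


def encode_state(composition: Dict[int, int], space_group: int,
                 lattice_params: Sequence[float]) -> List[float]:
    """Hypothetical helper: build one state row in the layout the proxy decomposes."""
    if not 1 <= space_group <= 230:
        raise ValueError("space group must be an actual number in 1-230")
    if len(lattice_params) != 6:
        raise ValueError("expected [a, b, c, alpha, beta, gamma]")
    comp = [0] * (N_ELEMENTS + 1)  # comp[0] is the dummy padding value and stays 0
    for z, count in composition.items():
        if not 1 <= z <= N_ELEMENTS:
            raise ValueError(f"atomic number {z} outside 1-{N_ELEMENTS}")
        comp[z] = count  # comp[i] holds the count of element Z=i
    return ([float(x) for x in comp] + [float(space_group)]
            + [float(p) for p in lattice_params])


def decompose_state(state: Sequence[float]):
    """Mirror of the documented slicing, for a single (unbatched) state."""
    composition = state[:-7]          # length 95: dummy 0, then Z=1..94
    space_group = int(state[-7]) - 1  # actual number 1-230 shifted to 0-229
    lattice = state[-6:]              # [a, b, c, alpha, beta, gamma]
    return composition, space_group, lattice
```

For LiO2, encode_state({3: 1, 8: 2}, 225, [4.0, 4.0, 4.0, 90.0, 90.0, 90.0]) yields a 102-entry row whose composition part starts [0, 0, 0, 1, 0, 0, 0, 0, 2], and decompose_state recovers the space group index 224. The lattice values here are illustrative only.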
- Parameters:
states (torch.Tensor) – States to infer on. Shape: (batch, n_elements + 1 + 6), i.e. composition, space group, then lattice parameters.
- Returns:
torch.Tensor – Proxy energies. Shape: (batch,).
- Return type:
torchtyping.TensorType[batch]
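Before calling the proxy, a batch can be checked against the documented requirements. This is a hypothetical pre-flight sketch on nested lists (validate_states, N_COMP and STATE_DIM are illustrative names, not gflownet API), assuming a 95-entry composition so each row has 95 + 1 + 6 = 102 columns; the returned batch size is the expected length of the (batch,) energy tensor.

```python
from typing import Sequence

N_COMP = 95                 # dummy 0 followed by Z=1..94
STATE_DIM = N_COMP + 1 + 6  # composition, space group, lattice parameters


def validate_states(states: Sequence[Sequence[float]]) -> int:
    """Hypothetical check of each row; returns the expected number of energies."""
    for i, row in enumerate(states):
        if len(row) != STATE_DIM:
            raise ValueError(f"row {i}: expected {STATE_DIM} columns, got {len(row)}")
        if row[0] != 0:
            raise ValueError(f"row {i}: dummy padding value comp[0] must be 0")
        sg = row[-7]
        if sg != int(sg) or not 1 <= sg <= 230:
            raise ValueError(f"row {i}: space group must be an actual number in 1-230")
    return len(states)  # one energy per state: output shape (batch,)
```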