gflownet.proxy.crystals.dave

Attributes

REPO_URL

RELEASE

URL to the proxy's code repository. It is used to provide a link to the appropriate release in case of a version mismatch between the requested and installed dave releases.

Classes

DAVE

Wrapper class around the Dave (Divya-Alexandre-Victor) proxy.

Module Contents

gflownet.proxy.crystals.dave.REPO_URL = 'https://github.com/sh-divya/crystalproxies.git'
gflownet.proxy.crystals.dave.RELEASE = '2.0.6'

URL to the proxy’s code repository. It is used to provide a link to the appropriate release in case of a version mismatch between the requested and installed dave releases.
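A minimal sketch of how such a mismatch link could be built from these two constants. The helper names `release_link` and `check_release`, the `/releases/tag/<version>` URL pattern, and passing the installed version in as a plain string are all assumptions for illustration; this is not the gflownet implementation.

```python
REPO_URL = "https://github.com/sh-divya/crystalproxies.git"
RELEASE = "2.0.6"


def release_link(repo_url: str, release: str) -> str:
    """Hypothetical helper: turn a .git repository URL into a release page URL.

    Assumes releases are tagged with the bare version string (e.g. "2.0.6").
    """
    base = repo_url[: -len(".git")] if repo_url.endswith(".git") else repo_url
    return f"{base}/releases/tag/{release}"


def check_release(installed: str, requested: str = RELEASE) -> None:
    """Hypothetical check: raise with a pointer to the requested release on mismatch."""
    if installed != requested:
        raise ImportError(
            f"Installed dave release {installed} does not match the requested "
            f"release {requested}. See {release_link(REPO_URL, requested)}"
        )
```

With the values above, `release_link(REPO_URL, RELEASE)` yields `https://github.com/sh-divya/crystalproxies/releases/tag/2.0.6`.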

class gflownet.proxy.crystals.dave.DAVE(ckpt_path=str, **kwargs)

Bases: gflownet.proxy.base.Proxy

Wrapper class around the Dave (Divya-Alexandre-Victor) proxy.

Parameters:

ckpt_path (str) – Path to a directory containing the checkpoint of a pre-trained model.

model
__call__(states)

Forward pass of the proxy.

The proxy decomposes each state as follows:

* composition: states[:, :-7] -> length 95 (a dummy 0, then 94 elements)
* space group: states[:, -7] - 1 (the space group number 1-230, shifted to a 0-based index)
* lattice parameters: states[:, -6:]
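The slicing above can be illustrated on a single state. The real proxy slices a torch.Tensor of shape (batch, 102); this sketch uses a plain Python list to stay dependency-free, and the helper name `decompose_state` is hypothetical.

```python
from typing import List, Tuple

N_COMP = 95  # dummy entry at index 0, then counts for Z = 1..94
N_LATTICE = 6  # a, b, c, alpha, beta, gamma
STATE_LEN = N_COMP + 1 + N_LATTICE  # composition + space group + lattice = 102


def decompose_state(state: List[float]) -> Tuple[List[float], int, List[float]]:
    """Hypothetical helper mirroring states[:, :-7], states[:, -7] - 1 and states[:, -6:]."""
    if len(state) != STATE_LEN:
        raise ValueError(f"expected {STATE_LEN} values, got {len(state)}")
    comp = state[:-7]  # length 95: dummy 0, then one count per element
    sg_index = int(state[-7]) - 1  # space group number 1..230 -> index 0..229
    lattice = state[-6:]  # [a, b, c, alpha, beta, gamma]
    return comp, sg_index, lattice
```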

* The composition MUST be a vector of element counts indexed by atomic number, prepended with a dummy entry: comp[i] is the number of atoms of the element with Z = i.
* The dummy padding value at comp[0] MUST be 0.
* For the MatBench proxy the composition runs up to Z = 94, so len(comp) = 95 (the dummy 0, then 94 elements). E.g. LiO2 (Li: Z = 3, O: Z = 8) -> [0, 0, 0, 1, 0, 0, 0, 0, 2, 0, ...].
* The space group (sg) MUST be the actual space group number (1-230).
* The lattice parameters (lat_params) MUST be floats in the order [a, b, c, alpha, beta, gamma].
* The states tensor MUST already be on the proxy's device.
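A minimal sketch of assembling one state that satisfies these requirements, using the LiO2 example. The helper `make_state` and its dictionary input (atomic number to count) are hypothetical, and the space group and lattice values are placeholders rather than a real LiO2 structure. Plain lists stand in for the torch tensor, which would additionally have to be moved to the proxy's device before the call.

```python
from typing import Dict, List, Sequence

N_ELEMENTS = 94  # the MatBench proxy covers Z = 1..94


def make_state(counts_by_z: Dict[int, int], space_group: int, lattice: Sequence[float]) -> List[float]:
    """Hypothetical helper: build one flat state [comp (95) | space group (1) | lattice (6)]."""
    comp = [0.0] * (N_ELEMENTS + 1)  # comp[0] is the dummy entry and stays 0
    for z, count in counts_by_z.items():
        if not 1 <= z <= N_ELEMENTS:
            raise ValueError(f"atomic number {z} outside 1..{N_ELEMENTS}")
        comp[z] = float(count)  # comp[i] holds the count of element Z = i
    if not 1 <= space_group <= 230:
        raise ValueError("space group must be an actual number in 1..230")
    if len(lattice) != 6:
        raise ValueError("lattice must be [a, b, c, alpha, beta, gamma]")
    return comp + [float(space_group)] + [float(x) for x in lattice]


# LiO2: one Li (Z=3) and two O (Z=8); space group and lattice are placeholders.
state = make_state({3: 1, 8: 2}, space_group=58, lattice=[4.0, 5.0, 3.0, 90.0, 90.0, 90.0])
```

Here `state[:9]` is `[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 2.0]` and `len(state)` is 102.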
Parameters:

states (torch.Tensor) – States to infer on. Shape: (batch, n_comp + 1 + 6), where n_comp = 95 (the dummy entry plus 94 elements), i.e. (batch, 102).

Returns:

torch.Tensor – Proxy energies. Shape: (batch,).

Return type:

torchtyping.TensorType[batch]

infer_on_train_set()

Infer on the training set and return the ground-truth and proxy values.

Returns:

tuple[torch.Tensor, torch.Tensor] – (energy, proxy): 1) the ground-truth energies and 2) the proxy's predictions on its training set, both as 1D tensors.
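One natural use of the returned pair is to measure how closely the proxy tracks the ground truth on its training set. The sketch below computes the mean absolute error and RMSE; the choice of metrics and the helper `train_set_errors` are illustrative assumptions, and plain lists replace the 1D tensors so the example runs without torch (with tensors, `(proxy - energy).abs().mean()` gives the MAE directly).

```python
import math
from typing import Sequence, Tuple


def train_set_errors(energy: Sequence[float], proxy: Sequence[float]) -> Tuple[float, float]:
    """Hypothetical helper: (MAE, RMSE) between ground-truth energies and proxy predictions."""
    if len(energy) != len(proxy) or len(energy) == 0:
        raise ValueError("energy and proxy must be non-empty and of equal length")
    diffs = [p - e for e, p in zip(energy, proxy)]  # prediction error per structure
    mae = sum(abs(d) for d in diffs) / len(diffs)
    rmse = math.sqrt(sum(d * d for d in diffs) / len(diffs))
    return mae, rmse
```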