gflownet.proxy.crystals.dave
============================

.. py:module:: gflownet.proxy.crystals.dave


Attributes
----------

.. autoapisummary::

   gflownet.proxy.crystals.dave.REPO_URL
   gflownet.proxy.crystals.dave.RELEASE


Classes
-------

.. autoapisummary::

   gflownet.proxy.crystals.dave.DAVE


Module Contents
---------------

.. py:data:: REPO_URL
   :value: 'https://github.com/sh-divya/crystalproxies.git'

   URL to the proxy's code repository. It is used to link to the appropriate
   release in case of a version mismatch between the requested and installed
   ``dave`` releases.


.. py:data:: RELEASE
   :value: '2.0.6'

   Requested release of the ``dave`` package.


.. py:class:: DAVE(ckpt_path=str, **kwargs)

   Bases: :py:obj:`gflownet.proxy.base.Proxy`

   Wrapper class around the Dave (Divya-Alexandre-Victor) proxy.

   :param ckpt_path: Path to a directory containing the checkpoint of a
                     pre-trained model.
   :type ckpt_path: str


   .. py:attribute:: model


   .. py:method:: __call__(states)

      Forward pass of the proxy.

      The proxy decomposes each state as follows:

      * composition: ``states[:, :-7]`` -> length 95 (dummy 0, then 94 elements)
      * space group: ``states[:, -7] - 1``
      * lattice parameters: ``states[:, -6:]``

      Input requirements:

      * The composition MUST be a vector of element counts indexed by atomic
        number and prepended with a dummy slot: ``comp[i]`` is the number of
        atoms of the element with ``Z = i``, and ``comp[0]`` MUST be 0. For
        example, LiO2 is encoded as ``[0, 0, 0, 1, 0, 0, 0, 0, 2, 0, ...]`` up
        to ``Z = 94`` for the MatBench proxy, so ``len(comp) == 95``.
      * The space group MUST be an actual space group number (1-230).
      * The lattice parameters MUST be floats in the order
        ``[a, b, c, alpha, beta, gamma]``.
      * The ``states`` tensor MUST already be on the device.

      :param states: States to infer on. Shape: ``(batch, 102)``, that is, 95
                     composition entries (including the dummy slot), 1 space
                     group entry and 6 lattice parameters.
      :type states: torch.Tensor

      :returns: *torch.Tensor* -- Proxy energies. Shape: ``(batch,)``.


   .. py:method:: infer_on_train_set()

      Infer on the training set and return the ground-truth and proxy values.

      :returns: *tuple[torch.Tensor, torch.Tensor]* -- ``(energy, proxy)``: the
                ground-truth energies and the proxy's predictions on its
                training set, both as 1D tensors.
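
The state layout documented for ``DAVE.__call__`` can be illustrated with a
small, pure-Python sketch. It is not part of ``gflownet``: ``encode_state`` and
``decompose_states`` are hypothetical helpers, rows are plain lists rather than
a ``torch.Tensor``, and storing the space group as ``sg + 1`` is an assumption
inferred from the documented ``states[:, -7] - 1`` decoding.

.. code-block:: python

   # Hypothetical helpers (not part of gflownet) illustrating the state layout
   # documented for DAVE.__call__. Rows are plain lists here; the real proxy
   # receives a 2D torch.Tensor that is already on the device.
   N_ELEMENTS = 94  # MatBench proxy: elements Z = 1..94
   STATE_WIDTH = (N_ELEMENTS + 1) + 1 + 6  # composition + space group + lattice


   def encode_state(composition, space_group, lattice_params):
       """Build one state row from {Z: count}, an actual space group number
       and [a, b, c, alpha, beta, gamma]."""
       comp = [0] * (N_ELEMENTS + 1)  # comp[0] is the dummy slot and stays 0
       for z, count in composition.items():
           if not 1 <= z <= N_ELEMENTS:
               raise ValueError(f"atomic number {z} is outside 1..{N_ELEMENTS}")
           comp[z] = count
       if not 1 <= space_group <= 230:
           raise ValueError("space group must be in 1..230")
       if len(lattice_params) != 6:
           raise ValueError("expected [a, b, c, alpha, beta, gamma]")
       # Assumption: the slot stores sg + 1, since the proxy decodes it as value - 1
       return comp + [space_group + 1] + [float(p) for p in lattice_params]


   def decompose_states(states):
       """Mirror the documented slicing on a batch given as a list of rows."""
       if any(len(row) != STATE_WIDTH for row in states):
           raise ValueError(f"each state must have {STATE_WIDTH} entries")
       compositions = [row[:-7] for row in states]
       space_groups = [row[-7] - 1 for row in states]
       lattices = [row[-6:] for row in states]
       return compositions, space_groups, lattices


   # Usage with arbitrary illustrative values: LiO2 (Li: Z=3, O: Z=8)
   row = encode_state({3: 1, 8: 2}, 225, [4.0, 4.0, 4.0, 90.0, 90.0, 90.0])
   comps, sgs, lats = decompose_states([row])
   print(len(row), comps[0][:9], sgs[0], lats[0])
   # prints: 102 [0, 0, 0, 1, 0, 0, 0, 0, 2] 225 [4.0, 4.0, 4.0, 90.0, 90.0, 90.0]

Keeping the dummy slot at index 0 lets ``comp[Z]`` hold the count of element
``Z`` directly, which is why the composition has 95 entries rather than 94.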