"""
This implementation uses the Stack meta-environment and the continuous Lattice
Parameters environment. Alternative implementations preceded this one but have been
removed for simplicity. Check commit 9f3477d8e46c4624f9162d755663993b83196546 to see
these changes or the history previous to that commit to consult previous
implementations.
"""
from typing import Dict, List, Optional, Tuple, Union
import pandas as pd
import torch
from torchtyping import TensorType
from tqdm import tqdm
from gflownet.envs.composite.stack import Stack
from gflownet.envs.crystals.composition import Composition
from gflownet.envs.crystals.lattice_parameters import (
PARAMETER_NAMES,
LatticeParameters,
LatticeParametersSGCCG,
)
from gflownet.envs.crystals.spacegroup import SpaceGroup
from gflownet.utils.common import copy
from gflownet.utils.crystals.constants import LATTICE_SYSTEMS, TRICLINIC
[docs]
class Crystal(Stack):
"""
A combination of Composition, SpaceGroup and LatticeParameters into a single
environment. Works sequentially, by first filling in the Composition, then
SpaceGroup, and finally LatticeParameters.
Attributes
----------
do_spacegroup : bool
Whether to include the SpaceGroup as a sub-environment and thus sample the
space group of the crystal.
do_lattice_parameters : bool
Whether to include the LatticeParameters as a sub-environment and thus sample
the lattice parameters (a, b, c, α, β, γ) of the crystal.
do_projected_lattice_parameters : bool
If True, the LatticeParametersSGCCG environment is used instead of
LatticeParameters. The latter operates in the natural space of the lattice
parameters, while the former operates in a projection which ensures the
validity of the angles. By default, LatticeParameters is used because
LatticeParametersSGCCG does not currently allow setting constraints on the min and
max lengths and angles, and it slows down the run time. Note that the default
natural LatticeParameters can generate angles that result in invalid volumes.
do_sg_before_composition : bool
Whether the SpaceGroup sub-environment should precede the composition.
do_composition_to_sg_constraints : bool
Whether to apply constraints on the space group sub-environment, based on the
composition, in case the composition goes first.
do_sg_to_composition_constraints : bool
Whether to apply constraints on the composition sub-environment, based on the
space group, in case the space group goes first.
do_sg_to_lp_constraints : bool
Whether to apply constraints on the lattice parameters sub-environment, based
on the space group.
composition_kwargs : dict
An optional dictionary with configuration to be passed to the Composition
sub-environment.
space_group_kwargs : dict
An optional dictionary with configuration to be passed to the SpaceGroup
sub-environment.
lattice_parameters_kwargs : dict
An optional dictionary with configuration to be passed to the LatticeParameters
sub-environment (or to LatticeParametersSGCCG if do_projected_lattice_parameters
is True).
"""
def __init__(
    self,
    do_spacegroup: bool = True,
    do_lattice_parameters: bool = True,
    do_projected_lattice_parameters: bool = False,
    do_sg_before_composition: bool = True,
    do_composition_to_sg_constraints: bool = True,
    do_sg_to_composition_constraints: bool = True,
    do_sg_to_lp_constraints: bool = True,
    composition_kwargs: Optional[Dict] = None,
    space_group_kwargs: Optional[Dict] = None,
    lattice_parameters_kwargs: Optional[Dict] = None,
    **kwargs,
):
    """
    Builds the sub-environments and initializes the base Stack environment.

    The meaning of each argument is described in the class docstring. The
    sub-environments are stacked in this order: SpaceGroup and Composition (their
    relative order is given by ``do_sg_before_composition``), then the lattice
    parameters.

    The attributes ``idx_composition``, ``idx_spacegroup`` and
    ``idx_latticeparameters`` store the position of each sub-environment in the
    stack. ``idx_spacegroup`` and ``idx_latticeparameters`` are only defined if
    the corresponding sub-environment is part of the stack.

    Parameters
    ----------
    **kwargs
        Additional keyword arguments forwarded to the Stack initializer.
    """
    self.do_spacegroup = do_spacegroup
    self.do_lattice_parameters = do_lattice_parameters
    self.do_projected_lattice_parameters = do_projected_lattice_parameters
    self.do_sg_before_composition = do_sg_before_composition
    # Constraints are only kept active if the sub-environments they involve are
    # part of the stack and the order of the sub-environments makes them
    # applicable. Otherwise, applying them would access missing sub-environments.
    self.do_sg_to_composition_constraints = (
        do_spacegroup
        and do_sg_to_composition_constraints
        and do_sg_before_composition
    )
    self.do_composition_to_sg_constraints = (
        do_spacegroup
        and do_composition_to_sg_constraints
        and not do_sg_before_composition
    )
    self.do_sg_to_lp_constraints = (
        do_spacegroup and do_lattice_parameters and do_sg_to_lp_constraints
    )
    self.composition_kwargs = dict(
        composition_kwargs or {},
        do_spacegroup_check=self.do_sg_to_composition_constraints,
    )
    self.space_group_kwargs = space_group_kwargs or {}
    self.lattice_parameters_kwargs = lattice_parameters_kwargs or {}

    # Initialize list of subenvs. Each index is taken from the current length of
    # the list, so that it is correct regardless of which sub-environments are
    # included.
    subenvs = []
    if self.do_spacegroup:
        space_group = SpaceGroup(**self.space_group_kwargs)
    else:
        space_group = None

    if space_group is not None and self.do_sg_before_composition:
        self.idx_spacegroup = len(subenvs)
        subenvs.append(space_group)

    # The composition is always part of the stack
    composition = Composition(**self.composition_kwargs)
    self.idx_composition = len(subenvs)
    subenvs.append(composition)

    if space_group is not None and not self.do_sg_before_composition:
        self.idx_spacegroup = len(subenvs)
        subenvs.append(space_group)

    if self.do_lattice_parameters:
        # We initialize lattice parameters with triclinic lattice system as it is
        # the most general one, but it will have to be reinitialized using proper
        # lattice system from space group once that is determined.
        if self.do_projected_lattice_parameters:
            lattice_parameters = LatticeParametersSGCCG(
                lattice_system=TRICLINIC, **self.lattice_parameters_kwargs
            )
        else:
            lattice_parameters = LatticeParameters(
                lattice_system=TRICLINIC, **self.lattice_parameters_kwargs
            )
        self.idx_latticeparameters = len(subenvs)
        subenvs.append(lattice_parameters)

    # Initialize base Stack environment
    super().__init__(subenvs=tuple(subenvs), **kwargs)
@property
def composition(self) -> Optional[Composition]:
    """
    Returns the sub-environment corresponding to the composition.

    Returns
    -------
    Composition or None
        The sub-environment stored at position ``idx_composition`` of
        ``subenvs``, or None if ``idx_composition`` is not defined or is None.
    """
    # The index may be missing or None if the sub-environment is not in the stack
    idx = getattr(self, "idx_composition", None)
    if idx is None:
        return None
    return self.subenvs[idx]
@property
def space_group(self) -> Optional[SpaceGroup]:
    """
    Returns the sub-environment corresponding to the space group.

    Returns
    -------
    SpaceGroup or None
        The sub-environment stored at position ``idx_spacegroup`` of ``subenvs``,
        or None if ``idx_spacegroup`` is not defined or is None.
    """
    # The index may be missing or None if the sub-environment is not in the stack
    idx = getattr(self, "idx_spacegroup", None)
    if idx is None:
        return None
    return self.subenvs[idx]
@property
def lattice_parameters(
    self,
) -> Optional[Union[LatticeParameters, LatticeParametersSGCCG]]:
    """
    Returns the sub-environment corresponding to the lattice parameters.

    Returns
    -------
    LatticeParameters, LatticeParametersSGCCG or None
        The sub-environment stored at position ``idx_latticeparameters`` of
        ``subenvs``, or None if ``idx_latticeparameters`` is not defined or is
        None.
    """
    # The index may be missing or None if the sub-environment is not in the stack
    idx = getattr(self, "idx_latticeparameters", None)
    if idx is None:
        return None
    return self.subenvs[idx]
def _check_has_constraints(self) -> bool:
    """
    Tells whether at least one cross-sub-environment constraint is enabled.

    The constraints considered are composition -> space group, space group ->
    composition and space group -> lattice parameters.

    Returns
    -------
    bool
        True if the Crystal has constraints, False otherwise
    """
    constraint_flags = (
        self.do_composition_to_sg_constraints,
        self.do_sg_to_composition_constraints,
        self.do_sg_to_lp_constraints,
    )
    return any(constraint_flags)
def _apply_constraints_forward(
    self,
    action: Optional[Tuple] = None,
    state: Optional[Dict] = None,
) -> bool:
    """
    Applies constraints across sub-environments, when applicable, in the forward
    direction.

    - composition -> space group (if composition is first)
    - space group -> composition (if space group is first)
    - space group -> lattice parameters

    Constraints that involve a sub-environment that is not part of the stack
    (``do_spacegroup`` or ``do_lattice_parameters`` set to False) are skipped.

    Parameters
    ----------
    action : tuple (optional)
        An action from the Crystal environment or None.
    state : dict (optional)
        A state from the Crystal environment or None.

    Returns
    -------
    bool
        True if the conditions to apply the constraints of any sub-environment
        were met; False otherwise.
    """
    applied_constraints = False

    # Apply constraints composition -> space group
    # Apply constraint only if action is None or if it is the composition EOS
    if (
        self.do_spacegroup
        and self.do_composition_to_sg_constraints
        and not self.do_sg_before_composition
        and self._do_constraints_for_subenv(
            state, self.idx_composition, action, is_backward=False
        )
    ):
        applied_constraints = True
        state = self._get_state(state)
        composition_substate = self._get_substate(state, self.idx_composition)
        n_atoms_per_element = self.composition.get_n_atoms_per_element(
            composition_substate
        )
        self.space_group.set_n_atoms_compatibility_dict(n_atoms_per_element)

    # Apply constraints:
    # - space group -> composition
    # - space group -> lattice parameters
    # Apply constraint only if action is None or if it is the space group EOS
    # The check of do_spacegroup goes first because idx_spacegroup is only
    # meaningful if the space group is part of the stack.
    if self.do_spacegroup and self._do_constraints_for_subenv(
        state, self.idx_spacegroup, action, is_backward=False
    ):
        applied_constraints = True
        state = self._get_state(state)
        spacegroup_substate = self._get_substate(state, self.idx_spacegroup)
        if self.do_sg_before_composition and self.do_sg_to_composition_constraints:
            space_group = self.space_group.get_space_group(spacegroup_substate)
            self.composition.set_space_group(space_group)
        # Without the lattice parameters sub-environment there is nothing to
        # constrain
        if self.do_sg_to_lp_constraints and self.do_lattice_parameters:
            lattice_system = self.space_group.get_lattice_system(
                spacegroup_substate
            )
            self.lattice_parameters.set_lattice_system(lattice_system)
            # Propagate the state of the re-configured sub-environment to the
            # state of the Crystal
            self._set_substate(
                self.idx_latticeparameters, self.lattice_parameters.state, state
            )

    return applied_constraints
def _apply_constraints_backward(
    self, action: Optional[Tuple] = None, state: Optional[Dict] = None
) -> bool:
    """
    Applies constraints across sub-environments, when applicable, in the backward
    direction.

    That is, it reverts the constraints set in the forward direction:

    - space group -> lattice parameters: the lattice system of the lattice
      parameters is set back to TRICLINIC.
    - space group -> composition: the space group of Composition is set back to
      None.
    - composition -> space group: the number of atoms is set back to None.

    Parameters
    ----------
    action : tuple (optional)
        An action from the Crystal environment or None.
    state : dict (optional)
        A state from the Crystal environment or None.

    Returns
    -------
    bool
        True if any constraint was applied; False otherwise.
    """
    applied_constraints = False

    # Revert constraints:
    # - space group -> lattice parameters: lattice system of LatticeParameters is
    # set back to TRICLINIC
    # - space group -> composition: space group of Composition is set back to None
    # Apply constraint only if action is None or if it is the space group EOS
    # The lattice parameters can only be reverted if they are part of the stack
    revert_lattice_parameters = (
        self.do_sg_to_lp_constraints and self.do_lattice_parameters
    )
    # The space group of the composition must be reverted whenever it may have
    # been set in the forward direction, even if the constraints on the lattice
    # parameters are disabled.
    revert_composition = self.do_sg_to_composition_constraints
    if (
        self.do_spacegroup
        and (revert_lattice_parameters or revert_composition)
        and self._do_constraints_for_subenv(
            state, self.idx_spacegroup, action, is_backward=True
        )
    ):
        applied_constraints = True
        if revert_lattice_parameters:
            self.lattice_parameters.set_lattice_system(TRICLINIC)
            # NOTE(review): unlike in the forward direction, state is not passed
            # through _get_state() here - confirm that _set_substate() accepts
            # None as state.
            self._set_substate(
                self.idx_latticeparameters, self.lattice_parameters.state, state
            )
        self.composition.set_space_group(None)

    # Revert constraints composition -> space group: The number of atoms is set
    # back to None
    # Apply constraint only if action is None or if it is the composition EOS
    if (
        self.do_spacegroup
        and self.do_composition_to_sg_constraints
        and not self.do_sg_before_composition
        and self._do_constraints_for_subenv(
            state, self.idx_composition, action, is_backward=True
        )
    ):
        applied_constraints = True
        self.space_group.set_n_atoms_compatibility_dict(None)

    return applied_constraints
[docs]
def states2proxy(
    self, states: List[Dict]
) -> TensorType["batch", "state_oracle_dim"]:
    """
    Prepares a batch of states in environment format for the proxies.

    The output is the concatenation of the proxy-format states of the
    sub-environments.

    This method is overridden to improve the efficiency, to create a tensor as an
    output and to account for the space group before composition case, since the
    proxy expects composition first regardless.

    Sub-environments that are not part of the stack, that is whose index
    attribute is not defined or is None, are skipped.

    Parameters
    ----------
    states : list
        A batch of states in environment format.

    Returns
    -------
    A tensor containing all the states in the batch.
    """
    # Order expected by the proxy: composition, space group, lattice parameters.
    # getattr is used because the index attributes may not be defined if the
    # corresponding sub-environment is not included.
    indices_subenvs_proxy = [
        getattr(self, attribute, None)
        for attribute in (
            "idx_composition",
            "idx_spacegroup",
            "idx_latticeparameters",
        )
    ]
    return torch.cat(
        [
            self.subenvs[idx].states2proxy([state[idx] for state in states])
            for idx in indices_subenvs_proxy
            if idx is not None
        ],
        dim=1,
    )
[docs]
def process_data_set(
    self, data: Union[pd.DataFrame, List], progress=False
) -> List[List]:
    """
    Processes a data set passed as a pandas DataFrame or as a list of states by
    filtering out the states that are not valid according to the environment
    configuration.

    If the input is a DataFrame, the rows are converted into environment states.

    Parameters
    ----------
    data : DataFrame or list
        One of the following:
            - A pandas DataFrame containing the necessary columns to represent a
              crystal as described in _process_dataframe().
            - A list of states in environment format. The states may be
              dictionaries or lists. An empty list is returned as is.
    progress : bool
        Whether to display a progress bar.

    Returns
    -------
    list
        A list of states in environment format.

    Raises
    ------
    ValueError
        If data is neither a DataFrame nor a list of states.
    """
    if isinstance(data, pd.DataFrame):
        return self._process_dataframe(data, progress)
    if isinstance(data, list):
        # An empty data set is valid and there is nothing to filter
        if not data:
            return []
        # The type of the first element is taken as representative of the list
        if isinstance(data[0], (dict, list)):
            return self._process_states_list(data, progress)
    raise ValueError("Unknown data type")
def _process_states_list(self, data: List, progress=False) -> List[List]:
    """
    Filters a list of states in environment format, keeping only the states whose
    sub-states are all valid according to the corresponding sub-environments.

    Parameters
    ----------
    data : list
        A list of states in environment format.
    progress : bool
        Whether to display a progress bar.

    Returns
    -------
    list
        The valid states, in the same order as in the input.
    """
    # A state is kept if every sub-environment accepts its sub-state. The
    # validity is evaluated for all the sub-environments before combining it.
    return [
        candidate
        for candidate in tqdm(data, total=len(data), disable=not progress)
        if all(
            [
                subenv.is_valid(self._get_substate(candidate, idx_subenv))
                for idx_subenv, subenv in enumerate(self.subenvs)
            ]
        )
    ]
def _process_dataframe(self, df: pd.DataFrame, progress=False) -> List[List]:
    """
    Builds a list of valid environment states from the rows of a pandas DataFrame.

    The DataFrame is expected to have the following columns:

    - Formulae: non-reduced formulae of the composition
    - Space Group: international number of the space group
    - a, b, c, alpha, beta, gamma: lattice parameters

    Rows that result in a state which is not valid for any of the sub-environments
    are discarded. The valid states include the meta-data key "_active", set to
    the index of the last sub-environment.

    Parameters
    ----------
    df : DataFrame
        A pandas DataFrame containing the necessary columns to represent a crystal
        as described above.
    progress : bool
        Whether to display a progress bar.

    Returns
    -------
    list
        A list of states in environment format.
    """
    valid_states = []
    # iterrows() yields (row index, row) pairs; only the row is needed
    for _, row in tqdm(df.iterrows(), total=len(df), disable=not progress):
        state = {}

        # Composition
        composition_env = self.subenvs[self.idx_composition]
        state[self.idx_composition] = composition_env.readable2state(
            row["Formulae"]
        )

        # Space group
        spacegroup_env = self.subenvs[self.idx_spacegroup]
        state[self.idx_spacegroup] = spacegroup_env._set_constrained_properties(
            [0, 0, row["Space Group"]]
        )

        # Lattice parameters
        lattice_system = self.space_group.get_lattice_system(
            state[self.idx_spacegroup]
        )
        # Fall back to the most general lattice system if it is not a known one
        if lattice_system not in LATTICE_SYSTEMS:
            lattice_system = TRICLINIC
        lp_env = self.lattice_parameters
        # Start from a copy of the source state, activate the cube
        # sub-environment and set the lattice system
        substate_lp = copy(lp_env.source)
        substate_lp = lp_env._set_active_subenv(lp_env.idx_cube, substate_lp)
        substate_lp = lp_env.set_lattice_system(lattice_system, substate_lp)
        # Convert the lattice parameters of the row into a cube state
        parameters = tuple(row[list(PARAMETER_NAMES)])
        substate_cube = lp_env.revert_lattice_constraints(
            parameters, lattice_system
        )
        state[self.idx_latticeparameters] = lp_env._set_substate(
            lp_env.idx_cube, substate_cube, substate_lp
        )

        # Check validity
        validity = [
            subenv.is_valid(self._get_substate(state, idx_subenv))
            for idx_subenv, subenv in enumerate(self.subenvs)
        ]
        if not all(validity):
            continue

        # Add meta-data to state
        state["_active"] = self.n_subenvs - 1
        valid_states.append(state)

    return valid_states
def _print_state(self, state: Optional[List] = None):
    """
    Prints a state in more human-readable format, for debugging purposes.

    For each sub-environment, a header with the index of the stage is printed,
    followed by the corresponding sub-state.

    Parameters
    ----------
    state : list (optional)
        A state from the Crystal environment. It is passed through _get_state()
        before printing.
    """
    state = self._get_state(state)
    for stage, _ in enumerate(self.subenvs):
        print(f"Stage {stage}")
        substate = self._get_substate(state, stage)
        print(substate)