"""
Classes to represent crystal environments
"""
import itertools
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from torch import Tensor
from torchtyping import TensorType
from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import tfloat, tlong
from gflownet.utils.crystals.pyxtal_cache import space_group_check_compatible
[docs]
CRYSTAL_LATTICE_SYSTEMS = None
[docs]
POINT_SYMMETRIES = None
def _get_crystal_lattice_systems():
global CRYSTAL_LATTICE_SYSTEMS
if CRYSTAL_LATTICE_SYSTEMS is None:
with open(Path(__file__).parent / "crystal_lattice_systems.yaml", "r") as f:
CRYSTAL_LATTICE_SYSTEMS = yaml.safe_load(f)
return CRYSTAL_LATTICE_SYSTEMS
def _get_point_symmetries():
global POINT_SYMMETRIES
if POINT_SYMMETRIES is None:
with open(Path(__file__).parent / "point_symmetries.yaml", "r") as f:
POINT_SYMMETRIES = yaml.safe_load(f)
return POINT_SYMMETRIES
def _get_space_groups():
global SPACE_GROUPS
if SPACE_GROUPS is None:
with open(Path(__file__).parent / "space_groups.yaml", "r") as f:
SPACE_GROUPS = yaml.safe_load(f)
return SPACE_GROUPS
[docs]
class Prop(Enum):
"""
Enumeration of the 3 properties of the SpaceGroup Environment:
- Crystal lattice system
- Point symmetry
- Space group
"""
[docs]
class SpaceGroup(GFlowNetEnv):
"""
SpaceGroup environment for ionic conductivity.
The state space is the combination of three properties:
1. The crystal-lattice system: combination of crystal system and lattice system
See: https://en.wikipedia.org/wiki/Crystal_system#Crystal_system
See: https://en.wikipedia.org/wiki/Crystal_system#Lattice_system
See: https://en.wikipedia.org/wiki/Hexagonal_crystal_family
(8 options + none)
2. The point symmetry
See: https://en.wikipedia.org/wiki/Crystal_system#Crystal_classes
(5 options + none)
3. The space group
See: https://en.wikipedia.org/wiki/Space_group#Table_of_space_groups_in_3_dimensions
(230 options + none)
The action space is the choice of property to update, the index within the property
and the combination of properties (state type) already set in the originating state
type (e.g. crystal-lattice system 2 from source, point symmetry 4 from
crystal-lattice system, space group 69 from point symmetry, etc.). The state type
is included in the action to differentiate actions that lead to same state from
different states, as in GFlowNet the distribution is over states not over actions.
The selection of crystal-lattice system restricts the possible point symmetries and
space groups; the selection of point symmetry restricts the possible
crystal-lattice systems and space groups. The selection of space groups determines
a specific crystal-lattice system and point symmetry. There is no restriction in
the order of selection of properties.
"""
def __init__(
self,
space_groups_subset: Optional[Iterable] = None,
n_atoms: Optional[List[int]] = None,
policy_fmt: str = "onehot",
**kwargs,
):
"""
Parameters
----------
space_groups_subset : iterable
A subset of space group (international) numbers to which to restrict the
state space. If None (default), the entire set of 230 space groups is
considered.
n_atoms : list of int (optional)
A list with the number of atoms per element, used to compute constraints on
the space group. 0's are removed from the list. If None, composition/space
group constraints are ignored.
policy_fmt : str
Specifies the policy encoding. Options:
- onehot: One-hot encoding of each property (crystal-lattice system,
point symmetry, space group), all concatenated to make the overall
input.
- indices: A three-dimensional vector with the indices of each property
"""
# Policy format
if policy_fmt not in ["onehot", "indices"]:
raise NotImplementedError(
"Unknown policy format. policy_fmt must be either 'onehot' or "
f"'indices'. Found {policy_fmt}."
)
[docs]
self.policy_fmt = policy_fmt
# Get dictionaries
[docs]
self.crystal_lattice_systems = _get_crystal_lattice_systems()
[docs]
self.point_symmetries = _get_point_symmetries()
[docs]
self.space_groups = _get_space_groups()
self._restrict_space_groups(space_groups_subset)
# Create tensors with possible values of each property
[docs]
self.cls_valid = torch.tensor([0] + list(self.crystal_lattice_systems.keys()))
[docs]
self.ps_valid = torch.tensor([0] + list(self.point_symmetries.keys()))
[docs]
self.sg_valid = torch.tensor([0] + list(self.space_groups.keys()))
# Set dictionary of compatibility with number of atoms
self.set_n_atoms_compatibility_dict(n_atoms)
# Indices in the state representation: crystal-lattice system (cls), point
# symmetry (ps) and space group (sg)
self.cls_idx, self.ps_idx, self.sg_idx = 0, 1, 2
# Dictionary of all properties
[docs]
self.properties = {
Prop.CLS: self.crystal_lattice_systems,
Prop.PS: self.point_symmetries,
Prop.SG: self.space_groups,
}
# Indices of state types (see self.get_state_type)
[docs]
self.state_type_indices = [0, 1, 2, 3]
# End-of-sequence action
[docs]
self.eos = (-1, -1, -1)
# Source state: index 0 (empty) for all three properties (crystal-lattice
# system index, point symmetry index, space group)
[docs]
self.source = [0 for _ in range(3)]
# Base class init
super().__init__(**kwargs)
[docs]
def get_action_space(self):
"""
Constructs list with all possible actions. An action is described by a tuple
(property, index, state_from_type), where property is (0: crystal-lattice
system, 1: point symmetry, 2: space group), index is the index of the property
set by the action and state_from_type is the state type of the originating
state (see self.state_type_indices).
"""
actions = []
for prop, indices in self.properties.items():
for s_from_type in self.state_type_indices:
if prop == Prop.CLS and s_from_type in [1, 3]:
continue
if prop == Prop.PS and s_from_type in [2, 3]:
continue
actions_prop = [(prop.value, idx, s_from_type) for idx in indices]
actions += actions_prop
actions += [self.eos]
return actions
[docs]
def get_mask_invalid_actions_forward(
self,
state: Optional[List] = None,
done: Optional[bool] = None,
) -> List:
"""
Returns a list of length the action space with values:
- True if the forward action is invalid given the current state.
- False otherwise.
"""
state = self._get_state(state)
done = self._get_done(done)
if done:
return [True for _ in self.action_space]
cls_idx, ps_idx, sg_idx = state
# If space group has been selected, only valid action is EOS
if sg_idx != 0:
mask = [True for _ in self.action_space]
mask[-1] = False
return mask
state_type = self.get_state_type(state)
# If neither crystal-lattice system nor point symmetry selected, apply only
# composition-compatibility constraints
if cls_idx == 0 and ps_idx == 0:
crystal_lattice_systems = [
(self.cls_idx, idx, state_type)
for idx in self.crystal_lattice_systems
if self._is_compatible(cls_idx=idx)
]
point_symmetries = [
(self.ps_idx, idx, state_type)
for idx in self.point_symmetries
if self._is_compatible(ps_idx=idx)
]
# Constraints after having selected crystal-lattice system
if cls_idx != 0:
crystal_lattice_systems = []
space_groups_cls = [
(self.sg_idx, sg, state_type)
for sg in self.crystal_lattice_systems[cls_idx]["space_groups"]
if self.n_atoms_compatibility_dict[sg]
]
# If no point symmetry selected yet
if ps_idx == 0:
point_symmetries = [
(self.ps_idx, idx, state_type)
for idx in self.crystal_lattice_systems[cls_idx]["point_symmetries"]
if self._is_compatible(cls_idx=cls_idx, ps_idx=idx)
]
else:
space_groups_cls = [
(self.sg_idx, idx, state_type)
for idx in self.space_groups
if self.n_atoms_compatibility_dict[idx]
]
# Constraints after having selected point symmetry
if ps_idx != 0:
point_symmetries = []
space_groups_ps = [
(self.sg_idx, sg, state_type)
for sg in self.point_symmetries[ps_idx]["space_groups"]
if self.n_atoms_compatibility_dict[sg]
]
# If no crystal-lattice system selected yet
if cls_idx == 0:
crystal_lattice_systems = [
(self.cls_idx, idx, state_type)
for idx in self.point_symmetries[ps_idx]["crystal_lattice_systems"]
if self._is_compatible(cls_idx=idx, ps_idx=ps_idx)
]
else:
space_groups_ps = [
(self.sg_idx, idx, state_type)
for idx in self.space_groups
if self.n_atoms_compatibility_dict[idx]
]
# Merge space_groups constraints and determine valid space group actions
space_groups = list(set(space_groups_cls).intersection(set(space_groups_ps)))
# Construct mask
actions_valid = set.union(
set(crystal_lattice_systems), set(point_symmetries), set(space_groups)
)
assert len(actions_valid) > 0
mask = [
False if action in actions_valid else True for action in self.action_space
]
return mask
[docs]
def states2proxy(
self, states: Union[List[List], TensorType["batch", "state_dim"]]
) -> TensorType["batch", "state_proxy_dim"]:
"""
Prepares a batch of states in "environment format" for the proxy: the proxy
format is simply the space group.
Args
----
states : list or tensor
A batch of states in environment format, either as a list of states or as a
single tensor.
Returns
-------
A tensor containing all the states in the batch.
"""
states = tlong(states, device=self.device)
return torch.unsqueeze(states[:, self.sg_idx], dim=1)
[docs]
def states2policy(
self, states: List[List]
) -> TensorType["batch", "policy_input_dim"]:
"""
Prepares a batch of states in "environment format" for the policy model, by
calling the appropriate conversion method depending on the settings.
Parameters
----------
states : list
A batch of states in environment format, that is a list of lists.
Returns
-------
A tensor containing the policy representation of all the states in the batch.
"""
if self.policy_fmt == "onehot":
return self.states2policy_onehot(states)
elif self.policy_fmt == "indices":
return super().states2policy(states)
else:
raise NotImplementedError(
"Unknown policy format. policy_fmt must be either 'onehot' or "
f"'indices'. Found {self.policy_fmt}."
)
[docs]
def states2policy_onehot(
self, states: List[List]
) -> TensorType["batch", "policy_input_dim"]:
"""
Prepares a batch of states in "environment format" for the policy model: states
are one-hot encoded.
In particular, the policy input for a state is a vector containing the
following encodings, in this order:
- One-hot encoding of the crystal-lattice system (max length 8).
- One-hot encoding of the point symmetry (max length 5).
- One-hot encoding of the space group (max length 230).
Besides, the states in which each property has not been set yet are included as
an additional class in the encoding. Thus, each property is one-hot encoded
with a vector of length the number of classes in the property plus one.
Notes
-----
In order to not waste memory and for backward compatibility, the one-hot
encodings have a maximum length equal to the maximum number of options in the
configuration.
To obtain the one-hot encoding of a given property index, while accounting for
the fact that not all possible indices might be valid given the current
configuration, we use torch.searchsorted, which receives as first input the
valid set of indices and as second input the value to be encoded, and outputs
the corresponding index. This index in then one-hot encoded.
See: `torch.searchsorted
<https://pytorch.org/docs/stable/generated/torch.searchsorted.html>`_
Example
-------
Consider a configuration with valid space groups [1, 17, 39], and then valid
crystal-lattice systems [1, 3] and valid point symmetries [1, 3, 4].
Additionally, each property can take the value 0 for the case where it is not
set yet.
states = [[0, 0, 0], [1, 1, 1], [3, 4, 17], [3, 3, 39]]
self.states2policy(states)
tensor(
[
[1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
[0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0],
[0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1],
]
)
Parameters
----------
states : list
A batch of states in environment format, that is a list of lists.
Returns
-------
A tensor containing the policy representation of all the states in the batch.
"""
states = tlong(states, device=self.device)
cls_onehot = F.one_hot(
torch.searchsorted(self.cls_valid, states[:, 0]),
self.cls_valid.shape[0],
)
ps_onehot = F.one_hot(
torch.searchsorted(self.ps_valid, states[:, 1]),
self.ps_valid.shape[0],
)
sg_onehot = F.one_hot(
torch.searchsorted(self.sg_valid, states[:, 2]),
self.sg_valid.shape[0],
)
return tfloat(
torch.cat([cls_onehot, ps_onehot, sg_onehot], dim=1),
device=self.device,
float_type=self.float,
)
[docs]
def state2readable(self, state=None):
"""
Transforms the state, represented as a list of property indices, into a
human-readable string with the format:
<space group idx> | <space group symbol> |
<crystal-lattice system> (<crystal-lattice system idx>) |
<point symmetry> (<point symmetry idx>)
<crystal class> | <point group>
Example:
space group: 69
space group symbol: Fmmm
crystal-lattice system: orthorhombic (3)
point symmetry: centrosymmetric (2)
crystal class: rhombic-dipyramidal
point group: mmm
output:
69 | Fmmm | orthorhombic (3) | centrosymmetric (2) |
rhombic-dipyramidal | mmm |
"""
state = self._get_state(state)
cls_idx, ps_idx, sg_idx = state
crystal_lattice_system = self.get_crystal_lattice_system(state)
point_symmetry = self.get_point_symmetry(state)
sg_symbol = self.get_space_group_symbol(state)
crystal_class = self.get_crystal_class(state)
point_group = self.get_point_group(state)
readable = (
f"{sg_idx} | {sg_symbol} | {crystal_lattice_system} ({cls_idx}) | "
+ f"{point_symmetry} ({ps_idx}) | {crystal_class} | {point_group}"
)
return readable
[docs]
def readable2state(self, readable):
"""
Converts a human-readable representation of a state into the standard format.
See: state2readable
"""
properties = readable.split(" | ")
space_group = int(properties[0])
crystal_lattice_system = int(properties[2].split(" ")[-1].strip("(").strip(")"))
point_symmetry = int(properties[3].split(" ")[-1].strip("(").strip(")"))
state = [crystal_lattice_system, point_symmetry, space_group]
return state
[docs]
def get_parents(self, state=None, done=None, action=None):
"""
Determines all parents and actions that lead to a state.
Args
----
state : list
done : bool
Whether the trajectory is done. If None, done is taken from instance.
action : None
Ignored
Returns
-------
parents : list
List of parents in state format
actions : list
List of actions that lead to state for each parent in parents
"""
state = self._get_state(state)
done = self._get_done(done)
if done:
return [state], [self.eos]
else:
parents = []
actions = []
# Catch cases where space group has been selected
if state[self.sg_idx] != 0:
sg = state[self.sg_idx]
# Add parent: source
parents.append(self.source)
action = (self.sg_idx, sg, 0)
actions.append(action)
# Add parents: states before setting space group
state_pre_sg = state.copy()
state_pre_sg[self.sg_idx] = 0
for prop in range(len(state_pre_sg)):
parent = state_pre_sg.copy()
parent[prop] = 0
parents.append(parent)
parent_type = self.get_state_type(parent)
action = (self.sg_idx, sg, parent_type)
actions.append(action)
else:
# Catch other parents
for prop, idx in enumerate(state[: self.sg_idx]):
if idx != 0:
parent = state.copy()
parent[prop] = 0
parents.append(parent)
parent_type = self.get_state_type(parent)
action = (prop, idx, parent_type)
actions.append(action)
return parents, actions
[docs]
def step(self, action: Tuple[int, int]) -> Tuple[List[int], Tuple[int, int], bool]:
"""
Executes step given an action.
Args
----
action : tuple
Action to be executed. See: get_action_space()
Returns
-------
self.state : list
The new state after executing the action
action : tuple
Action executed
valid : bool
False, if the action is not allowed for the current state.
"""
# If action not found in action space raise an error
if action not in self.action_space:
raise ValueError(
f"Tried to execute action {action} not present in action space."
)
else:
action_idx = self.action_space.index(action)
# If action is in invalid mask, exit immediately
if self.get_mask_invalid_actions_forward()[action_idx]:
return self.state, action, False
valid = True
self.n_actions += 1
prop, idx, _ = action
# Action is not eos
if action != self.eos:
state_next = self.state[:]
state_next[prop] = idx
# Set crystal-lattice system and point symmetry if space group is set
self.state = self._set_constrained_properties(state_next)
return self.state, action, valid
# Action is eos
else:
self.done = True
return self.state, action, valid
def _get_max_trajectory_length(self) -> int:
"""
Returns the maximum trajectory length of the environment, including the EOS
action.
"""
return len(self.source) + 1
def _set_constrained_properties(self, state: List[int]) -> List[int]:
"""
Sets the missing properties in a state that can be determined from the existing
properties in the input state.
Parameters
----------
state : list
A state in environment format.
Returns
-------
list
The updated state.
"""
cls_idx, ps_idx, sg_idx = state
if sg_idx != 0:
if sg_idx not in self.space_groups:
return state
if cls_idx == 0:
state[self.cls_idx] = self.space_groups[state[self.sg_idx]][
"crystal_lattice_system_idx"
]
if ps_idx == 0:
state[self.ps_idx] = self.space_groups[state[self.sg_idx]][
"point_symmetry_idx"
]
return state
[docs]
def get_crystal_system(self, state: List[int] = None) -> str:
"""
Returns the name of the crystal system given a state.
"""
state = self._get_state(state)
state = self._set_constrained_properties(state)
if state[self.cls_idx] != 0:
return self.crystal_lattice_systems[state[self.cls_idx]]["crystal_system"]
else:
return "None"
@property
[docs]
def crystal_system(self) -> str:
return self.get_crystal_system(self.state)
[docs]
def get_lattice_system(self, state: List[int] = None) -> str:
"""
Returns the name of the lattice system given a state.
"""
state = self._get_state(state)
state = self._set_constrained_properties(state)
if state[self.cls_idx] != 0:
return self.crystal_lattice_systems[state[self.cls_idx]]["lattice_system"]
else:
return "None"
@property
[docs]
def lattice_system(self, state: List[int] = None) -> str:
return self.get_lattice_system(self.state)
[docs]
def get_crystal_lattice_system(self, state: List[int] = None) -> str:
"""
Returns the name of the crystal-lattice system given a state.
"""
state = self._get_state(state)
crystal_system = self.get_crystal_system(state)
lattice_system = self.get_lattice_system(state)
if crystal_system != lattice_system:
return f"{crystal_system}-{lattice_system}"
else:
return crystal_system
@property
[docs]
def crystal_lattice_system(self) -> str:
return self.get_crystal_lattice_system(self.state)
[docs]
def get_point_symmetry(self, state: List[int] = None) -> str:
"""
Returns the name of the point symmetry given a state.
"""
state = self._get_state(state)
state = self._set_constrained_properties(state)
if state[self.ps_idx] != 0:
return self.point_symmetries[state[self.ps_idx]]["point_symmetry"]
else:
return "None"
@property
[docs]
def point_symmetry(self) -> str:
return self.get_point_symmetry(self.state)
[docs]
def get_space_group_symbol(self, state: List[int] = None) -> str:
"""
Returns the name of the space group symbol given a state.
"""
state = self._get_state(state)
if state[self.sg_idx] != 0:
return self.space_groups[state[self.sg_idx]]["full_symbol"]
else:
return "None"
@property
[docs]
def space_group_symbol(self) -> str:
return self.get_space_group_symbol(self.state)
[docs]
def get_space_group(self, state: List[int] = None) -> int:
"""
Returns the index of the space group symbol given a state.
"""
state = self._get_state(state)
if state[self.sg_idx] != 0:
return state[self.sg_idx]
else:
return None
@property
[docs]
def space_group(self) -> int:
return self.get_space_group(self.state)
# TODO: Technically the crystal class could be determined from crystal-lattice
# system + point symmetry
[docs]
def get_crystal_class(self, state: List[int] = None) -> str:
"""
Returns the name of the crystal_class given a state.
"""
state = self._get_state(state)
if state[self.sg_idx] != 0:
return self.space_groups[state[self.sg_idx]]["crystal_class"]
else:
return "None"
@property
[docs]
def crystal_class(self) -> str:
return self.get_crystal_class(self.state)
# TODO: Technically the point group could be determined from crystal-lattice system
# + point symmetry
[docs]
def get_point_group(self, state: List[int] = None) -> str:
"""
Returns the name of the point group given a state.
"""
state = self._get_state(state)
if state[self.sg_idx] != 0:
return self.space_groups[state[self.sg_idx]]["point_group"]
else:
return "None"
@property
[docs]
def point_group(self) -> str:
return self.get_point_group(self.state)
[docs]
def get_state_type(self, state: List[int] = None) -> int:
"""
Returns the index of the type of the state passed as an argument. The state
type is one of the following (self.state_type_indices):
0: both crystal-lattice system and point symmetry are unset (== 0)
1: crystal-lattice system is set (!= 0); point symmetry is unset
2: crystal-lattice system is unset; point symmetry is set
3: both crystal-lattice system and point symmetry are set
"""
state = self._get_state(state)
return sum([int(s > 0) * f for s, f in zip(state, (1, 2))])
[docs]
def set_n_atoms_compatibility_dict(self, n_atoms: List):
"""
Sets self.n_atoms_compatibility_dict by calling
SpaceGroup.build_n_atoms_compatibility_dict(), which contains a dictionary of
{space_group: is_compatible} indicating whether each space_group in
space_groups is compatible with the stoichiometry defined by n_atoms.
See: build_n_atoms_compatibility_dict()
Args
----
n_atoms : list of int
A list of number of atoms for each element in a composition. 0s will be
removed from the list since they do not count towards the compatibility
with a space group.
"""
# Get compatibility with stoichiometry
self.n_atoms_compatibility_dict = SpaceGroup.build_n_atoms_compatibility_dict(
n_atoms, self.space_groups.keys()
)
def _is_compatible(
self, cls_idx: Optional[int] = None, ps_idx: Optional[int] = None
):
"""
Returns True if there is exists at least one space group compatible with the
atom composition (according to self.n_atoms_compatibility_dict), with the
crystal-lattice system (if provided), and with the point symmetry (if provided).
False otherwise.
"""
# Get list of space groups compatible with the composition
space_groups = [
sg for sg in self.space_groups if self.n_atoms_compatibility_dict[sg]
]
# Prune the list of space groups to those compatible with the provided crystal-
# lattice system
if cls_idx is not None:
space_groups_cls = self.crystal_lattice_systems[cls_idx]["space_groups"]
space_groups = list(set(space_groups).intersection(set(space_groups_cls)))
# Prune the list of space groups to those compatible with the provided point
# symmetry
if ps_idx is not None:
space_groups_ps = self.point_symmetries[ps_idx]["space_groups"]
space_groups = list(set(space_groups).intersection(set(space_groups_ps)))
return len(space_groups) > 0
@staticmethod
[docs]
def build_n_atoms_compatibility_dict(
n_atoms: List[int], space_groups: Iterable[int]
):
"""
Obtains which space groups are compatible with the stoichiometry given as
argument (n_atoms).
It relies on a function which, internally, calls pyxtal's
pyxtal.symmetry.Group.check_compatible(). Note that sometimes that pyxtal
is known to return invalid results.
Args
----
n_atoms : list of int
A list of number of atoms for each element in a stoichiometry. 0s will be
removed from the list since they do not count towards the compatibility
with a space group. If None, all space groups will be marked as compatible.
space_groups : list of int
A list of space group international numbers, in [1, 230]
Returns
-------
A dictionary of {space_group: is_compatible} indicating whether each
space_group in space_groups is compatible with the stoichiometry defined by
n_atoms.
"""
if n_atoms is None:
return {sg: True for sg in space_groups}
n_atoms = [n for n in n_atoms if n > 0]
assert all([n > 0 for n in n_atoms])
assert all([sg > 0 and sg <= 230 for sg in space_groups])
return {sg: space_group_check_compatible(sg, n_atoms) for sg in space_groups}
def _restrict_space_groups(self, sg_subset: Optional[Iterable] = None):
"""
Updates the dictionaries:
- self.space_groups
- self.crystal_lattice_systems
- self.point_symmetries
by eliminating the space groups that are not in the subset sg_subset passed as
an argument.
"""
if sg_subset is None:
return
sg_subset = set(sg_subset)
# Update self.space_groups
self.space_groups = {
k: v for (k, v) in self.space_groups.items() if k in sg_subset
}
# Update self.crystal_lattice_systems based on space groups
self.crystal_lattice_systems = deepcopy(self.crystal_lattice_systems)
cls_to_remove = []
for cls in self.crystal_lattice_systems:
cls_space_groups = sg_subset.intersection(
set(self.crystal_lattice_systems[cls]["space_groups"])
)
if len(cls_space_groups) == 0:
cls_to_remove.append(cls)
else:
self.crystal_lattice_systems[cls]["space_groups"] = list(
cls_space_groups
)
for cls in cls_to_remove:
del self.crystal_lattice_systems[cls]
# Update self.point_symmetries based on space groups
self.point_symmetries = deepcopy(self.point_symmetries)
ps_to_remove = []
for ps in self.point_symmetries:
ps_space_groups = sg_subset.intersection(
set(self.point_symmetries[ps]["space_groups"])
)
if len(ps_space_groups) == 0:
ps_to_remove.append(ps)
else:
self.point_symmetries[ps]["space_groups"] = list(ps_space_groups)
for ps in ps_to_remove:
del self.point_symmetries[ps]
# Update point symmetries of remaining crystal lattice systems
point_symmetries = set(self.point_symmetries)
for cls in self.crystal_lattice_systems:
cls_point_symmetries = point_symmetries.intersection(
set(self.crystal_lattice_systems[cls]["point_symmetries"])
)
self.crystal_lattice_systems[cls]["point_symmetries"] = list(
cls_point_symmetries
)
# Update crystal lattice systems of remaining point symmetries
crystal_lattice_systems = set(self.crystal_lattice_systems)
for ps in self.point_symmetries:
ps_crystal_lattice_systems = crystal_lattice_systems.intersection(
set(self.point_symmetries[ps]["crystal_lattice_systems"])
)
self.point_symmetries[ps]["crystal_lattice_systems"] = list(
ps_crystal_lattice_systems
)
[docs]
def get_all_terminating_states(
self, apply_stoichiometry_constraints: Optional[bool] = True
) -> List[List]:
all_x = []
for sg in self.space_groups:
if (
apply_stoichiometry_constraints
and self.n_atoms_compatibility_dict[sg] is False
):
continue
all_x.append(self._set_constrained_properties([0, 0, sg]))
return all_x
[docs]
def is_valid(self, x: List) -> bool:
"""
Determines whether a state is valid, according to the attributes of the
environment.
"""
if x[self.sg_idx] in self.space_groups:
return True
return False