gflownet.envs.crystals.spacegroup
Classes to represent crystal environments
Classes
Prop – Enumeration of the 3 properties of the SpaceGroup Environment.
Module Contents
- class gflownet.envs.crystals.spacegroup.Prop[source]
Bases: enum.Enum
Enumeration of the 3 properties of the SpaceGroup Environment:
Crystal lattice system
Point symmetry
Space group
- class gflownet.envs.crystals.spacegroup.SpaceGroup(space_groups_subset=None, n_atoms=None, policy_fmt='onehot', **kwargs)[source]
Bases: gflownet.envs.base.GFlowNetEnv
- Parameters:
space_groups_subset (iterable) – A subset of space group (international) numbers to which to restrict the state space. If None (default), the entire set of 230 space groups is considered.
n_atoms (list of int (optional)) – A list with the number of atoms per element, used to compute constraints on the space group. 0s are removed from the list. If None, composition/space group constraints are ignored.
policy_fmt (str) – Specifies the policy encoding. Options:
onehot: One-hot encoding of each property (crystal-lattice system, point symmetry, space group), all concatenated to make the overall input.
indices: A three-dimensional vector with the indices of each property.
- get_action_space()[source]
Constructs list with all possible actions. An action is described by a tuple (property, index, state_from_type), where property is (0: crystal-lattice system, 1: point symmetry, 2: space group), index is the index of the property set by the action and state_from_type is the state type of the originating state (see self.state_type_indices).
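To illustrate the action tuple described above, the following minimal sketch unpacks an action and applies it to a state. The `PropSketch` member names and the `apply_action` helper are hypothetical. Only the integer codes (0, 1, 2) and the tuple layout come from the description, and the state layout [crystal-lattice system, point symmetry, space group] is inferred from the examples in this module.

```python
from enum import IntEnum


class PropSketch(IntEnum):
    """Stand-in for the documented Prop enumeration (member names are assumed)."""

    CRYSTAL_LATTICE_SYSTEM = 0
    POINT_SYMMETRY = 1
    SPACE_GROUP = 2


def apply_action(state, action):
    """Return a copy of `state` with the property named by the action set to its index."""
    prop, index, _state_from_type = action
    new_state = list(state)
    new_state[prop] = index
    return new_state


# From the source state (nothing set, state type 0), set the crystal-lattice system to 3.
action = (PropSketch.CRYSTAL_LATTICE_SYSTEM, 3, 0)
next_state = apply_action([0, 0, 0], action)
```

The sketch does not check validity; in the environment that role is played by the forward mask and by step().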
- get_mask_invalid_actions_forward(state=None, done=None)[source]
- Returns a list with one entry per action in the action space, with values:
True if the forward action is invalid given the current state.
False otherwise.
- Parameters:
state (Optional[List])
done (Optional[bool])
- Return type:
List
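Because True marks an invalid action, a mask of this kind is typically consumed by filtering the action space. The sketch below shows that convention with a hypothetical three-action space and a hand-written mask; the `valid_actions` helper is not part of the library.

```python
def valid_actions(action_space, mask_invalid):
    """Keep only the actions whose mask entry is False (that is, not invalid)."""
    if len(action_space) != len(mask_invalid):
        raise ValueError("mask must have one entry per action")
    return [a for a, invalid in zip(action_space, mask_invalid) if not invalid]


# Hypothetical action space and mask, for illustration only.
action_space = [(0, 1, 0), (0, 3, 0), (1, 1, 0)]
mask = [False, True, False]
allowed = valid_actions(action_space, mask)
```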
- states2proxy(states)[source]
Prepares a batch of states in “environment format” for the proxy: the proxy format is simply the space group.
- Parameters:
states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.
- Returns:
A tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[batch, state_proxy_dim]
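As a rough illustration of "the proxy format is simply the space group", the sketch below keeps only the space group entry of each state. It uses nested Python lists in place of a tensor, and it assumes that the space group is the last entry of the state and that state_proxy_dim is 1; neither is stated explicitly in this entry.

```python
def states_to_space_groups(states):
    """Keep only the space group (assumed to be entry 2) of each state, one row per state."""
    return [[state[2]] for state in states]


proxy_batch = states_to_space_groups([[0, 0, 0], [3, 4, 17], [3, 3, 39]])
```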
- states2policy(states)[source]
Prepares a batch of states in “environment format” for the policy model, by calling the appropriate conversion method depending on the settings.
- Parameters:
states (list) – A batch of states in environment format, that is a list of lists.
- Returns:
A tensor containing the policy representation of all the states in the batch.
- Return type:
torchtyping.TensorType[batch, policy_input_dim]
- states2policy_onehot(states)[source]
Prepares a batch of states in “environment format” for the policy model: states are one-hot encoded.
In particular, the policy input for a state is a vector containing the following encodings, in this order:
One-hot encoding of the crystal-lattice system (max length 8).
One-hot encoding of the point symmetry (max length 5).
One-hot encoding of the space group (max length 230).
In addition, the case in which a property has not been set yet is included as an additional class in the encoding. Thus, each property is one-hot encoded with a vector whose length is the number of classes in the property plus one.
Notes
To avoid wasting memory, and for backward compatibility, the one-hot encodings have a maximum length equal to the maximum number of options in the configuration.
To obtain the one-hot encoding of a given property index, while accounting for the fact that not all possible indices might be valid given the current configuration, we use torch.searchsorted, which receives as first input the valid set of indices and as second input the value to be encoded, and outputs the corresponding index. This index is then one-hot encoded.
See: torch.searchsorted
Example
Consider a configuration with valid space groups [1, 17, 39], valid crystal-lattice systems [1, 3] and valid point symmetries [1, 3, 4]. Additionally, each property can take the value 0 for the case where it is not set yet.
states = [[0, 0, 0], [1, 1, 1], [3, 4, 17], [3, 3, 39]]
self.states2policy(states)
tensor(
    [
        [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0],
        [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1],
    ]
)
- Parameters:
states (list) – A batch of states in environment format, that is a list of lists.
- Returns:
A tensor containing the policy representation of all the states in the batch.
- Return type:
torchtyping.TensorType[batch, policy_input_dim]
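The one-hot scheme described for states2policy_onehot can be sketched in pure Python. This is not the library's implementation: `bisect.bisect_left` stands in for torch.searchsorted, plain lists stand in for tensors, and the helper names are hypothetical. It also assumes that the "not set" value 0 is placed in front of the sorted valid indices, which is consistent with the example in that entry.

```python
from bisect import bisect_left


def one_hot_property(value, valid_indices):
    """One-hot encode `value` over the classes [0 (unset)] + sorted valid indices."""
    classes = [0] + list(valid_indices)
    # bisect_left plays the role of torch.searchsorted: position of value in classes.
    position = bisect_left(classes, value)
    if position == len(classes) or classes[position] != value:
        raise ValueError(f"{value} is not a valid index")
    encoding = [0] * len(classes)
    encoding[position] = 1
    return encoding


def state_to_policy_onehot(state, valid_cls, valid_ps, valid_sg):
    """Concatenate the three one-hot encodings in the documented order."""
    cls_idx, ps_idx, sg_idx = state
    return (
        one_hot_property(cls_idx, valid_cls)
        + one_hot_property(ps_idx, valid_ps)
        + one_hot_property(sg_idx, valid_sg)
    )


# Same configuration as the documented example.
encoded = state_to_policy_onehot([3, 4, 17], [1, 3], [1, 3, 4], [1, 17, 39])
```

With this configuration each encoded state has 3 + 4 + 4 = 11 entries, matching the width of the rows in the documented example.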
- state2readable(state=None)[source]
Transforms the state, represented as a list of property indices, into a human-readable string with the format:
<space group idx> | <space group symbol> | <crystal-lattice system> (<crystal-lattice system idx>) | <point symmetry> (<point symmetry idx>) | <crystal class> | <point group>
Example
space group: 69
space group symbol: Fmmm
crystal-lattice system: orthorhombic (3)
point symmetry: centrosymmetric (2)
crystal class: rhombic-dipyramidal
point group: mmm
output:
69 | Fmmm | orthorhombic (3) | centrosymmetric (2) | rhombic-dipyramidal | mmm |
- readable2state(readable)[source]
Converts a human-readable representation of a state into the standard format. See: state2readable
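One way such a conversion could work is sketched below: the space group index is the first field, and the two property indices are read from the parentheses. The `readable_to_state` helper is hypothetical and the state ordering [crystal-lattice system, point symmetry, space group] is inferred from the examples in this module; how the library handles unset properties is not covered here.

```python
import re


def readable_to_state(readable):
    """Recover [crystal-lattice system idx, point symmetry idx, space group idx]."""
    fields = [field.strip() for field in readable.split("|")]
    space_group = int(fields[0])
    # The property indices appear in parentheses after the property names.
    cls_idx = int(re.search(r"\((\d+)\)", fields[2]).group(1))
    ps_idx = int(re.search(r"\((\d+)\)", fields[3]).group(1))
    return [cls_idx, ps_idx, space_group]


state = readable_to_state(
    "69 | Fmmm | orthorhombic (3) | centrosymmetric (2) | rhombic-dipyramidal | mmm"
)
```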
- get_parents(state=None, done=None, action=None)[source]
Determines all parents and actions that lead to a state.
- Parameters:
state (list)
done (bool) – Whether the trajectory is done. If None, done is taken from instance.
action (None) – Ignored
- Returns:
parents (list) – List of parents in state format
actions (list) – List of actions that lead to state for each parent in parents
- step(action)[source]
Executes step given an action.
- Parameters:
action (tuple) – Action to be executed. See: get_action_space()
- Returns:
self.state (list) – The new state after executing the action
action (tuple) – Action executed
valid (bool) – False if the action is not allowed for the current state; True otherwise.
- Return type:
Tuple[List[int], Tuple[int, int], bool]
- get_crystal_system(state=None)[source]
Returns the name of the crystal system given a state.
- Parameters:
state (List[int])
- Return type:
str
- get_lattice_system(state=None)[source]
Returns the name of the lattice system given a state.
- Parameters:
state (List[int])
- Return type:
str
- get_crystal_lattice_system(state=None)[source]
Returns the name of the crystal-lattice system given a state.
- Parameters:
state (List[int])
- Return type:
str
- get_point_symmetry(state=None)[source]
Returns the name of the point symmetry given a state.
- Parameters:
state (List[int])
- Return type:
str
- get_space_group_symbol(state=None)[source]
Returns the space group symbol given a state.
- Parameters:
state (List[int])
- Return type:
str
- get_space_group(state=None)[source]
Returns the index of the space group given a state.
- Parameters:
state (List[int])
- Return type:
int
- get_crystal_class(state=None)[source]
Returns the name of the crystal class given a state.
- Parameters:
state (List[int])
- Return type:
str
- get_point_group(state=None)[source]
Returns the name of the point group given a state.
- Parameters:
state (List[int])
- Return type:
str
- get_state_type(state=None)[source]
Returns the index of the type of the state passed as an argument. The state type is one of the following (self.state_type_indices):
0: both crystal-lattice system and point symmetry are unset (== 0)
1: crystal-lattice system is set (!= 0); point symmetry is unset
2: crystal-lattice system is unset; point symmetry is set
3: both crystal-lattice system and point symmetry are set
- Parameters:
state (List[int])
- Return type:
int
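The four state types listed for get_state_type amount to a two-bit code, which the following sketch makes explicit. The `state_type` function is a hypothetical stand-in, not the library's code, and it assumes the state layout [crystal-lattice system, point symmetry, space group].

```python
def state_type(state):
    """Map a state [crystal-lattice system, point symmetry, space group] to 0..3."""
    cls_set = state[0] != 0
    ps_set = state[1] != 0
    # Bit 0 encodes the crystal-lattice system, bit 1 the point symmetry.
    return int(cls_set) + 2 * int(ps_set)


types = [state_type(s) for s in [[0, 0, 0], [3, 0, 0], [0, 2, 0], [3, 2, 69]]]
```

Note that the space group entry plays no part in the state type.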
- set_n_atoms_compatibility_dict(n_atoms)[source]
Sets self.n_atoms_compatibility_dict by calling SpaceGroup.build_n_atoms_compatibility_dict(), which returns a dictionary of {space_group: is_compatible} indicating whether each space_group in space_groups is compatible with the stoichiometry defined by n_atoms.
See: build_n_atoms_compatibility_dict()
- Parameters:
n_atoms (list of int) – A list of number of atoms for each element in a composition. 0s will be removed from the list since they do not count towards the compatibility with a space group.
- static build_n_atoms_compatibility_dict(n_atoms, space_groups)[source]
Obtains which space groups are compatible with the stoichiometry given as argument (n_atoms).
It relies on a function which, internally, calls pyxtal’s pyxtal.symmetry.Group.check_compatible(). Note that pyxtal is known to sometimes return invalid results.
- Parameters:
n_atoms (list of int) – A list of number of atoms for each element in a stoichiometry. 0s will be removed from the list since they do not count towards the compatibility with a space group. If None, all space groups will be marked as compatible.
space_groups (list of int) – A list of space group international numbers, in [1, 230]
- Returns:
A dictionary of {space_group: is_compatible} indicating whether each space_group in space_groups is compatible with the stoichiometry defined by n_atoms.
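The documented behaviour (zeros dropped, None meaning "everything compatible", one boolean per space group) can be sketched as follows. The compatibility check is passed in as a callable so that the sketch does not depend on pyxtal; both `build_compatibility_dict` and the toy predicate are hypothetical, and the predicate has no crystallographic meaning.

```python
def build_compatibility_dict(n_atoms, space_groups, is_compatible):
    """Return {space_group: bool} for the given stoichiometry."""
    if n_atoms is None:
        # No stoichiometry given: every space group is marked as compatible.
        return {sg: True for sg in space_groups}
    # Zeros do not count towards the compatibility with a space group.
    n_atoms = [n for n in n_atoms if n != 0]
    return {sg: bool(is_compatible(sg, n_atoms)) for sg in space_groups}


def toy_check(space_group, n_atoms):
    """Toy predicate for illustration only."""
    return space_group == 1 or all(n % 2 == 0 for n in n_atoms)


compat = build_compatibility_dict([2, 0, 3], [1, 2, 3], toy_check)
```

In the library, the role of `toy_check` is played by the function that wraps pyxtal's compatibility check.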