gflownet.envs.crystals.spacegroup

Classes to represent crystal environments

Attributes

CRYSTAL_LATTICE_SYSTEMS

POINT_SYMMETRIES

SPACE_GROUPS

Classes

Prop

Enumeration of the 3 properties of the SpaceGroup environment.

SpaceGroup

Module Contents

gflownet.envs.crystals.spacegroup.CRYSTAL_LATTICE_SYSTEMS = None
gflownet.envs.crystals.spacegroup.POINT_SYMMETRIES = None
gflownet.envs.crystals.spacegroup.SPACE_GROUPS = None
class gflownet.envs.crystals.spacegroup.Prop

Bases: enum.Enum

Enumeration of the 3 properties of the SpaceGroup Environment:
  • Crystal lattice system

  • Point symmetry

  • Space group

CLS = 0
PS = 1
SG = 2
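
The enumeration values appear to double as positions in the state list, as the three-element states in the examples further down suggest. The following is a minimal sketch under that assumption: the local Prop class only mirrors the documented members and is not imported from the library, and the state layout [crystal-lattice system, point symmetry, space group] is inferred from those examples.

```python
from enum import Enum


class Prop(Enum):
    """Local mirror of the documented enumeration (not imported from gflownet)."""

    CLS = 0  # crystal-lattice system
    PS = 1  # point symmetry
    SG = 2  # space group


# Assumed state layout: one index per property, 0 meaning "not set yet".
state = [3, 2, 69]

cls_idx = state[Prop.CLS.value]
ps_idx = state[Prop.PS.value]
sg_idx = state[Prop.SG.value]
print(cls_idx, ps_idx, sg_idx)
```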
class gflownet.envs.crystals.spacegroup.SpaceGroup(space_groups_subset=None, n_atoms=None, policy_fmt='onehot', **kwargs)

Bases: gflownet.envs.base.GFlowNetEnv

Parameters:
  • space_groups_subset (iterable) – A subset of space group (international) numbers to which to restrict the state space. If None (default), the entire set of 230 space groups is considered.

  • n_atoms (list of int (optional)) – A list with the number of atoms per element, used to compute constraints on the space group. 0’s are removed from the list. If None, composition/space group constraints are ignored.

  • policy_fmt (str) –

    Specifies the policy encoding. Options:
    • onehot: One-hot encoding of each property (crystal-lattice system, point symmetry, space group), all concatenated to make the overall input.

    • indices: A three-dimensional vector with the indices of each property.

policy_fmt = 'onehot'
crystal_lattice_systems = None
point_symmetries = None
space_groups = None
cls_valid
ps_valid
sg_valid
properties
state_type_indices = [0, 1, 2, 3]
eos
source
get_action_space()

Constructs a list with all possible actions. An action is described by a tuple (property, index, state_from_type), where property identifies the property being set (0: crystal-lattice system, 1: point symmetry, 2: space group), index is the index of the property set by the action, and state_from_type is the state type of the originating state (see self.state_type_indices).
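
To make the tuple layout concrete, here is a minimal sketch that enumerates (property, index, state_from_type) tuples. It is not the library's implementation: the valid indices are hypothetical, the end-of-sequence action (eos) is left out, and the rule that a property can only be set from a state type in which it is still unset is an assumption derived from the state type definitions.

```python
from itertools import product

# Hypothetical valid indices; the real environment derives them from its
# configuration (see cls_valid, ps_valid and sg_valid).
valid_indices = {
    0: [1, 3],  # crystal-lattice systems
    1: [1, 3, 4],  # point symmetries
    2: [1, 17, 39],  # space groups
}

# Assumption: a property can only be set from a state type in which that
# property is still unset. Space groups are allowed from every state type.
allowed_state_types = {
    0: [0, 2],  # crystal-lattice system unset
    1: [0, 1],  # point symmetry unset
    2: [0, 1, 2, 3],
}

actions = [
    (prop, idx, state_from_type)
    for prop in sorted(valid_indices)
    for idx, state_from_type in product(
        valid_indices[prop], allowed_state_types[prop]
    )
]
print(len(actions))
```

Encoding the originating state type in the action gives each (property, index) pair a distinct action per kind of source state.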

get_mask_invalid_actions_forward(state=None, done=None)
Returns a list with one entry per action in the action space, with values:
  • True if the forward action is invalid given the current state.

  • False otherwise.

Parameters:
  • state (Optional[List])

  • done (Optional[bool])

Return type:

List
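
The shape of such a mask can be sketched as follows. This is a deliberately simplified illustration, not the library's logic: it only checks the done flag and whether an action's state_from_type matches the type of the current state. The real method presumably also enforces compatibility between properties, composition constraints and the eos action. The function names are hypothetical.

```python
def state_type(state):
    """State type as documented under get_state_type: 0, 1, 2 or 3."""
    cls_idx, ps_idx, _ = state
    return int(cls_idx != 0) + 2 * int(ps_idx != 0)


def mask_invalid_actions_forward(state, done, action_space):
    """Return True for every action treated as invalid from `state`.

    Simplified on purpose: an action is invalid if the trajectory is done
    or if it was defined for a different originating state type.
    """
    if done:
        return [True] * len(action_space)
    current_type = state_type(state)
    return [s_from_type != current_type for _, _, s_from_type in action_space]


# A small, hypothetical action space of (property, index, state_from_type).
action_space = [(0, 1, 0), (0, 1, 2), (1, 3, 0), (1, 3, 1), (2, 17, 0), (2, 17, 1)]
mask = mask_invalid_actions_forward([1, 0, 0], False, action_space)
print(mask)
```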

states2proxy(states)

Prepares a batch of states in “environment format” for the proxy: the proxy format is simply the space group.

Parameters:

states (list or tensor) – A batch of states in environment format, either as a list of states or as a single tensor.

Returns:

A tensor containing all the states in the batch.

Return type:

torchtyping.TensorType[batch, state_proxy_dim]
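
Since the proxy format is simply the space group, the conversion amounts to selecting one entry per state. A minimal sketch, assuming the space group sits at position 2 of each state (Prop.SG) and that state_proxy_dim is 1. Plain lists stand in for the tensor that the real method returns, and states_to_proxy is a hypothetical name.

```python
SG = 2  # assumed position of the space group in a state (Prop.SG)


def states_to_proxy(states):
    """Keep only the space group of each state, one row per state."""
    return [[state[SG]] for state in states]


batch = [[1, 1, 1], [3, 4, 17], [3, 3, 39]]
proxy_batch = states_to_proxy(batch)
print(proxy_batch)  # [[1], [17], [39]]
```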

states2policy(states)

Prepares a batch of states in “environment format” for the policy model, by calling the appropriate conversion method depending on the settings.

Parameters:

states (list) – A batch of states in environment format, that is a list of lists.

Returns:

A tensor containing the policy representation of all the states in the batch.

Return type:

torchtyping.TensorType[batch, policy_input_dim]

states2policy_onehot(states)

Prepares a batch of states in “environment format” for the policy model: states are one-hot encoded.

In particular, the policy input for a state is a vector containing the following encodings, in this order:

  • One-hot encoding of the crystal-lattice system (max length 8).

  • One-hot encoding of the point symmetry (max length 5).

  • One-hot encoding of the space group (max length 230).

In addition, the state in which a property has not been set yet is included as an additional class in the encoding. Thus, each property is one-hot encoded with a vector whose length is the number of classes in the property plus one.

Notes

In order to not waste memory and for backward compatibility, the one-hot encodings have a maximum length equal to the maximum number of options in the configuration.

To obtain the one-hot encoding of a given property index, while accounting for the fact that not all possible indices might be valid given the current configuration, we use torch.searchsorted, which receives as first input the valid set of indices and as second input the value to be encoded, and outputs the corresponding index. This index is then one-hot encoded.

See: torch.searchsorted

Example

Consider a configuration with valid space groups [1, 17, 39], and then valid crystal-lattice systems [1, 3] and valid point symmetries [1, 3, 4]. Additionally, each property can take the value 0 for the case where it is not set yet.

    states = [[0, 0, 0], [1, 1, 1], [3, 4, 17], [3, 3, 39]]
    self.states2policy(states)
    tensor(
        [
            [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
            [0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0],
            [0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0],
            [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1],
        ]
    )

Parameters:

states (list) – A batch of states in environment format, that is a list of lists.

Returns:

A tensor containing the policy representation of all the states in the batch.

Return type:

torchtyping.TensorType[batch, policy_input_dim]
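
The searchsorted-based encoding described in the Notes of states2policy_onehot can be reproduced without torch. The sketch below uses the standard library's bisect.bisect_left as a stand-in for torch.searchsorted and plain lists instead of tensors; the helper names are hypothetical and the valid indices are those of the example in that entry.

```python
from bisect import bisect_left

# Valid indices per property from the example, with 0 ("not set yet")
# prepended as the additional class.
VALID = [
    [0, 1, 3],  # crystal-lattice systems
    [0, 1, 3, 4],  # point symmetries
    [0, 1, 17, 39],  # space groups
]


def one_hot(value, valid):
    """One-hot encode `value` by its position in the sorted list `valid`."""
    position = bisect_left(valid, value)  # stand-in for torch.searchsorted
    return [int(i == position) for i in range(len(valid))]


def states_to_policy_onehot(states):
    """Concatenate the one-hot encodings of the three properties."""
    return [
        [bit for value, valid in zip(state, VALID) for bit in one_hot(value, valid)]
        for state in states
    ]


states = [[0, 0, 0], [1, 1, 1], [3, 4, 17], [3, 3, 39]]
encoded = states_to_policy_onehot(states)
for row in encoded:
    print(row)
```

Looking up the position of a value in the sorted list of valid indices keeps the encoding as short as the number of valid options plus one, regardless of how sparse the valid indices are.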

state2readable(state=None)

Transforms the state, represented as a list of property indices, into a human-readable string with the format:

<space group idx> | <space group symbol> | <crystal-lattice system> (<crystal-lattice system idx>) | <point symmetry> (<point symmetry idx>) | <crystal class> | <point group>

Example

  • space group: 69

  • space group symbol: Fmmm

  • crystal-lattice system: orthorhombic (3)

  • point symmetry: centrosymmetric (2)

  • crystal class: rhombic-dipyramidal

  • point group: mmm

output:

69 | Fmmm | orthorhombic (3) | centrosymmetric (2) | rhombic-dipyramidal | mmm |

readable2state(readable)

Converts a human-readable representation of a state into the standard format. See: state2readable
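
The inverse conversion can be sketched by parsing the fields of the readable string. This is an illustration under assumptions, not the library's parser: readable_to_state is a hypothetical name, the state layout [crystal-lattice system idx, point symmetry idx, space group idx] is inferred from the examples on this page, and the input is assumed to be well formed.

```python
import re


def readable_to_state(readable):
    """Recover [cls_idx, ps_idx, sg_idx] from the readable string."""
    fields = [field.strip() for field in readable.split("|")]
    sg_idx = int(fields[0])
    # The crystal-lattice system and point symmetry carry their index in
    # parentheses, e.g. "orthorhombic (3)".
    cls_idx = int(re.search(r"\((\d+)\)", fields[2]).group(1))
    ps_idx = int(re.search(r"\((\d+)\)", fields[3]).group(1))
    return [cls_idx, ps_idx, sg_idx]


readable = (
    "69 | Fmmm | orthorhombic (3) | centrosymmetric (2) | "
    "rhombic-dipyramidal | mmm |"
)
print(readable_to_state(readable))  # [3, 2, 69]
```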

get_parents(state=None, done=None, action=None)

Determines all parents and actions that lead to a state.

Parameters:
  • state (list)

  • done (bool) – Whether the trajectory is done. If None, done is taken from instance.

  • action (None) – Ignored

Returns:

  • parents (list) – List of parents in state format

  • actions (list) – List of actions that lead to state for each parent in parents

step(action)

Executes step given an action.

Parameters:

action (tuple) – Action to be executed. See: get_action_space()

Returns:

  • self.state (list) – The new state after executing the action

  • action (tuple) – Action executed

  • valid (bool) – False, if the action is not allowed for the current state.

Return type:

Tuple[List[int], Tuple[int, int], bool]

get_crystal_system(state=None)

Returns the name of the crystal system given a state.

Parameters:

state (List[int])

Return type:

str

property crystal_system: str
Return type:

str

get_lattice_system(state=None)

Returns the name of the lattice system given a state.

Parameters:

state (List[int])

Return type:

str

property lattice_system: str
Parameters:

state (List[int])

Return type:

str

get_crystal_lattice_system(state=None)

Returns the name of the crystal-lattice system given a state.

Parameters:

state (List[int])

Return type:

str

property crystal_lattice_system: str
Return type:

str

get_point_symmetry(state=None)

Returns the name of the point symmetry given a state.

Parameters:

state (List[int])

Return type:

str

property point_symmetry: str
Return type:

str

get_space_group_symbol(state=None)

Returns the space group symbol given a state.

Parameters:

state (List[int])

Return type:

str

property space_group_symbol: str
Return type:

str

get_space_group(state=None)

Returns the index of the space group given a state.

Parameters:

state (List[int])

Return type:

int

property space_group: int
Return type:

int

get_crystal_class(state=None)

Returns the name of the crystal class given a state.

Parameters:

state (List[int])

Return type:

str

property crystal_class: str
Return type:

str

get_point_group(state=None)

Returns the name of the point group given a state.

Parameters:

state (List[int])

Return type:

str

property point_group: str
Return type:

str

get_state_type(state=None)

Returns the index of the type of the state passed as an argument. The state type is one of the following (self.state_type_indices):

  • 0: both crystal-lattice system and point symmetry are unset (== 0)

  • 1: crystal-lattice system is set (!= 0); point symmetry is unset

  • 2: crystal-lattice system is unset; point symmetry is set

  • 3: both crystal-lattice system and point symmetry are set

Parameters:

state (List[int])

Return type:

int
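
The four state types described for get_state_type follow a two-bit pattern: one bit for the crystal-lattice system and one for the point symmetry. A minimal sketch of that mapping, assuming states are laid out as [crystal-lattice system idx, point symmetry idx, space group idx]; it is a standalone function, not the library's method.

```python
def get_state_type(state):
    """Map a state [cls_idx, ps_idx, sg_idx] to its type index (0 to 3)."""
    cls_is_set = state[0] != 0
    ps_is_set = state[1] != 0
    return int(cls_is_set) + 2 * int(ps_is_set)


for state in ([0, 0, 0], [3, 0, 0], [0, 2, 0], [3, 2, 0]):
    print(state, get_state_type(state))
```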

set_n_atoms_compatibility_dict(n_atoms)

Sets self.n_atoms_compatibility_dict by calling SpaceGroup.build_n_atoms_compatibility_dict(). The result is a dictionary of {space_group: is_compatible} indicating whether each space_group in space_groups is compatible with the stoichiometry defined by n_atoms.

See: build_n_atoms_compatibility_dict()

Parameters:

n_atoms (list of int) – A list of number of atoms for each element in a composition. 0s will be removed from the list since they do not count towards the compatibility with a space group.

static build_n_atoms_compatibility_dict(n_atoms, space_groups)

Obtains which space groups are compatible with the stoichiometry given as argument (n_atoms).

It relies on a function which, internally, calls pyxtal’s pyxtal.symmetry.Group.check_compatible(). Note that pyxtal is known to sometimes return invalid results.

Parameters:
  • n_atoms (list of int) – A list of number of atoms for each element in a stoichiometry. 0s will be removed from the list since they do not count towards the compatibility with a space group. If None, all space groups will be marked as compatible.

  • space_groups (list of int) – A list of space group international numbers, in [1, 230]

Returns:

A dictionary of {space_group: is_compatible} indicating whether each space_group in space_groups is compatible with the stoichiometry defined by n_atoms.
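
The documented behaviour (zeros dropped, None meaning no constraints, one boolean per space group) can be sketched with the compatibility check injected as a callable. Both build_compatibility_dict and toy_check are hypothetical names: the callable stands in for the pyxtal-based check used by the library, and the toy rule has no physical meaning.

```python
def build_compatibility_dict(n_atoms, space_groups, check_compatible):
    """Return {space_group: is_compatible} for the given stoichiometry.

    `check_compatible(space_group, n_atoms)` is a hypothetical callable
    standing in for the pyxtal-based check used by the library.
    """
    if n_atoms is None:
        # No stoichiometry given: every space group is marked as compatible.
        return {sg: True for sg in space_groups}
    n_atoms = [n for n in n_atoms if n != 0]  # zeros do not constrain anything
    return {sg: check_compatible(sg, n_atoms) for sg in space_groups}


def toy_check(space_group, n_atoms):
    """Made-up rule for demonstration only; it has no physical meaning."""
    return space_group == 1 or all(n % 2 == 0 for n in n_atoms)


print(build_compatibility_dict([2, 0, 4], [1, 2, 3], toy_check))
print(build_compatibility_dict([1, 2], [1, 2, 3], toy_check))
print(build_compatibility_dict(None, [1, 2, 3], toy_check))
```

Injecting the check keeps the sketch runnable without pyxtal and makes it easy to swap in a stricter validator, given that the pyxtal results are noted above as occasionally invalid.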

get_all_terminating_states(apply_stoichiometry_constraints=True)
Parameters:

apply_stoichiometry_constraints (Optional[bool])

Return type:

List[List]

is_valid(x)

Determines whether a state is valid, according to the attributes of the environment.

Parameters:

x (List)

Return type:

bool