Source code for gflownet.envs.crystals.composition

"""
Classes to represent material compositions (stoichiometry)
"""

import re
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from torchtyping import TensorType

from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import tfloat, tlong
from gflownet.utils.crystals.constants import ELEMENT_NAMES, OXIDATION_STATES
from gflownet.utils.crystals.pyxtal_cache import (
    get_space_group,
    space_group_check_compatible,
    space_group_lowest_free_wp_multiplicity,
    space_group_wyckoff_gcd,
)

[docs] N_ELEMENTS_ORACLE = 94
[docs] class Composition(GFlowNetEnv): """ Composition environment for crystal materials. States are represented as a dictionary where the keys are the atomic numbers of the elements and the values the number of atoms per element. """ def __init__( self, elements: Union[List, int] = 94, max_diff_elem: int = 5, min_diff_elem: int = 2, min_atoms: int = 2, max_atoms: int = 20, min_atom_i: int = 1, max_atom_i: int = 16, oxidation_states: Optional[Dict] = None, alphabet: Optional[Dict] = None, required_elements: Optional[Union[Tuple, List]] = (), space_group: Optional[int] = None, do_charge_check: bool = False, do_spacegroup_check: bool = True, **kwargs, ): """ Args ---------- elements : list or int Elements that will be used for construction of crystal. Either list, in which case every value should indicate the atomic number of an element, or int, in which case n consecutive atomic numbers will be used. Note that we assume this will correspond to real atomic numbers, i.e. start from 1, not 0. max_diff_elem : int Maximum number of unique elements in the crystal min_diff_elem : int Minimum number of unique elements in the crystal min_atoms : int Minimum number of atoms that needs to be used to construct a crystal max_atoms : int Maximum number of atoms that can be used to construct a crystal min_atom_i : int Minimum number of elements of each kind that needs to be used to construct a crystal max_atom_i : int Maximum number of elements of each kind that can be used to construct a crystal oxidation_states : (optional) dict Mapping from ints (representing elements) to lists of different oxidation states alphabet : (optional) dict Mapping from ints (representing elements) to strings containing human-readable elements' names required_elements : (optional) list List of elements that must be present in a crystal for it to represent a valid end state space_group : (optional) int International number of a space group to be used for compatibility check, using pyxtal.symmetry.Group.check_compatible(). do_charge_check : bool Whether to do neutral charge check and forbid compositions for which neutral charge is not possible. do_spacegroup_check : bool Whether to do a space group compatibility check and forbid compositions with incompatible Wyckoff positions with the given space group. """ if isinstance(elements, int): elements = [i + 1 for i in range(elements)] if len(elements) != len(set(elements)): raise ValueError( f"Provided elements must be unique, detected {len(elements) - len(set(elements))} duplicates." ) if any(e <= 0 for e in elements): raise ValueError( "Provided elements should be non-negative (assumed indexing from 1 for H)." )
[docs] self.elements = sorted(elements)
[docs] self.max_diff_elem = max_diff_elem
[docs] self.min_diff_elem = min_diff_elem
[docs] self.min_atoms = min_atoms
[docs] self.max_atoms = max_atoms
[docs] self.min_atom_i = min_atom_i
[docs] self.max_atom_i = max_atom_i
[docs] self.oxidation_states = ( oxidation_states if oxidation_states is not None else OXIDATION_STATES.copy() )
[docs] self.alphabet = alphabet if alphabet is not None else ELEMENT_NAMES.copy()
[docs] self.alphabet_rev = {v: k for k, v in self.alphabet.items()}
[docs] self.required_elements = ( required_elements if required_elements is not None else [] )
[docs] self.space_group = space_group
[docs] self.do_charge_check = do_charge_check
[docs] self.do_spacegroup_check = do_spacegroup_check
[docs] self.elem2idx = {el: idx for idx, el in enumerate(self.elements)}
# Source state: empty dict
[docs] self.source = {}
# End-of-sequence action
[docs] self.eos = (-1, -1)
super().__init__(**kwargs)
[docs] def set_space_group(self, space_group: int): """ Sets the space group. Parameters ---------- space_group : int Space group number. """ self.space_group = space_group
[docs] def get_action_space(self): """ Constructs list with all possible actions. An action is described by a tuple (element, n), indicating that the count of element will be set to n. """ assert self.max_diff_elem >= self.min_diff_elem assert self.max_atom_i >= self.min_atom_i valid_word_len = np.arange(self.min_atom_i, self.max_atom_i + 1) actions = [(element, n) for element in self.elements for n in valid_word_len] actions.append(self.eos) return actions
def _get_max_trajectory_length(self) -> int: """ Returns the maximum trajectory length of the environment, including the EOS action. """ return min(self.max_diff_elem, self.max_atoms // self.min_atom_i) + 1 def _refine_compatibility_check( self, state, mask_required_element, mask_unrequired_element ): """ Refines the masks (in-place) of required and unrequired elements by doing compatibility checks between the space group and the number of atoms. Args ---- state : list The state on which the masks are to be applied. mask_required_element: list Element-wise mask indicating invalid actions for required elements. This masks indicates whether each individual actions is invalid or not for elements that are required to be in the composition by the end of the trajectory. mask_unrequired_element: list Element-wise mask indicating invalid actions for unrequired elements. This masks indicates whether each individual actions is invalid or not for elements that are not required to be in the composition by the end of the trajectory. """ space_group = get_space_group(self.space_group) n_atoms_per_element = self.get_n_atoms_per_element(state) # Get the greated common divisor of the group's wyckoff position. # It cannot be valid to add a number of atoms that is not a # multiple of this value wyckoff_gcd = space_group_wyckoff_gcd(self.space_group) # Get the multiplicity of the group's most specific wyckoff position with # at least one degree of freedom free_multiplicity = space_group_lowest_free_wp_multiplicity(self.space_group) # Go through each action in the masks, validating them # individually for action_idx, nb_atoms_action in enumerate( range(self.min_atom_i, self.max_atom_i + 1) ): if ( not mask_required_element[action_idx] or not mask_unrequired_element[action_idx] ): # If the number of atoms added by this action is not a # multiple of the greatest common divisor of the wyckoff # positions' multiplicities, mark action as invalid if nb_atoms_action % wyckoff_gcd != 0: mask_required_element[action_idx] = True mask_unrequired_element[action_idx] = True continue # If the number of atoms added by this action is a # multiple of a non-specific wyckoff position, nothing # prevents it from being valid if nb_atoms_action % free_multiplicity == 0: continue # Checking validity by induction. If a composition is # valid, adding a number of atoms is equal to the # multiplicity of a non-specific position, then this # action must also be valid. if nb_atoms_action > free_multiplicity and ( not mask_required_element[action_idx - free_multiplicity] or not mask_unrequired_element[action_idx - free_multiplicity] ): continue # If the composition resulting from this action is # incompatible with the space group, mark action as # invalid n_atoms_post_action = n_atoms_per_element + [nb_atoms_action] sg_compatible = space_group_check_compatible( self.space_group, n_atoms_post_action ) if not sg_compatible: mask_required_element[action_idx] = True mask_unrequired_element[action_idx] = True
[docs] def get_mask_invalid_actions_forward(self, state=None, done=None): """ Returns a vector of length the action space + 1: True if forward action is invalid given the current state, False otherwise. """ state = self._get_state(state) done = self._get_done(done) if done: return [True for _ in range(self.action_space_dim)] mask = [False] * self.action_space_dim used_elements = list(state.keys()) unused_required_elements = [ e for e in self.required_elements if e not in used_elements ] n_used_elements = len(used_elements) n_unused_required_elements = len(unused_required_elements) n_used_atoms = sum(state.values()) if self.do_spacegroup_check and isinstance(self.space_group, int): space_group = get_space_group(self.space_group) # Determine, based on the space group's Wyckoff's positions, what # is the min/max number of atoms of a given element that could be # added. most_specific_wp = space_group.get_wyckoff_position(-1) min_atom_i = most_specific_wp.multiplicity wyckoff_gcd = space_group_wyckoff_gcd(self.space_group) max_atom_i = (self.max_atom_i // wyckoff_gcd) * wyckoff_gcd # Determine if the current composition is compatible with the # space group n_atoms_per_element = self.get_n_atoms_per_element(state) sg_compatible = space_group_check_compatible( self.space_group, n_atoms_per_element ) else: # Don't impose additional constraints on the min/max number of # atoms per element min_atom_i = self.min_atom_i max_atom_i = self.max_atom_i # Assume the current composition is compatible with the space group sg_compatible = True # Compute the min and max number of atoms to add to satisfy constraints nb_atoms_still_needed = max(0, self.min_atoms - n_used_atoms) nb_atoms_still_allowed = self.max_atoms - n_used_atoms # Compute the min and max number of elements to add to satisfy constraints nb_elems_still_needed = max( n_unused_required_elements, self.min_diff_elem - n_used_elements ) nb_elems_still_allowed = self.max_diff_elem - n_used_elements # How many elements, other than the required elements, can still be added n_max_unrequired_elements_left = self.max_diff_elem - ( n_used_elements + n_unused_required_elements ) # What is the minimum number of atoms needed for a new required element in # order to reach the number of required atoms before we can't add new elements # anymore min_atoms_per_required_element = max( nb_atoms_still_needed - (nb_elems_still_allowed - 1) * max_atom_i, min_atom_i, ) # What is the maximum number of atoms allowed for a new required element in # order to be able to reach the number of required elements before we can't add # new atoms anymore max_atoms_per_required_element = min( nb_atoms_still_allowed - (nb_elems_still_needed - 1) * min_atom_i, max_atom_i, ) # Determine if there is a need to add unrequired elements to either reach the # number of required distinct elements or the number of required atoms unrequired_element_needed = ( nb_elems_still_needed > n_unused_required_elements or max_atoms_per_required_element * n_unused_required_elements < nb_atoms_still_needed ) # Determine if it is possible to add unrequired elements without going over the # maximum number of elements or atoms unrequired_element_allowed = ( n_max_unrequired_elements_left > 0 and min_atoms_per_required_element * n_unused_required_elements + min_atom_i <= nb_atoms_still_allowed ) # Compute the minimum and maximum number of atoms available for an unrequired # element if unrequired_element_needed: # Some unrequired elements are needed so they are treated the same as the # required elements min_atoms_per_unrequired_element = min_atoms_per_required_element max_atoms_per_unrequired_element = max_atoms_per_required_element elif unrequired_element_allowed: # Unrequired elements are optional so there is no minium amount to add for # them and the maximum is only as high as possible without preventing the # addition of the required elements later min_atoms_per_unrequired_element = min_atom_i max_atoms_per_unrequired_element = min( nb_atoms_still_allowed - min_atoms_per_required_element * n_unused_required_elements, max_atom_i, ) else: # No unrequired elements can be added min_atoms_per_unrequired_element = 0 max_atoms_per_unrequired_element = 0 if n_used_atoms < self.min_atoms: mask[-1] = True if n_used_elements < self.min_diff_elem: mask[-1] = True if any(r not in used_elements for r in self.required_elements): mask[-1] = True if not sg_compatible: # The current composition is incompatible with the space group, # we must allow EOS to end the trajectory. mask[-1] = False # Obtain action mask for each category of element def get_element_mask(min_atoms, max_atoms): return [ bool(i < min_atoms or i > max_atoms) for i in range(self.min_atom_i, self.max_atom_i + 1) ] mask_required_element = get_element_mask( min_atoms_per_required_element, max_atoms_per_required_element ) mask_unrequired_element = get_element_mask( min_atoms_per_unrequired_element, max_atoms_per_unrequired_element ) # If required, refine the masks by doing compatibility checks between # the space group and the number of atoms if self.do_spacegroup_check and isinstance(self.space_group, int): self._refine_compatibility_check( state, mask_required_element, mask_unrequired_element ) # Set action mask for each element nb_actions_per_element = self.max_atom_i - self.min_atom_i + 1 for element_idx, element in enumerate(self.elements): # Compute the start and end indices of the actions associated with this # element action_start_idx = element_idx * nb_actions_per_element action_end_idx = action_start_idx + nb_actions_per_element # Set the mask for the actions associated with this element if element in state: # This element has already been added, we cannot add more mask[action_start_idx:action_end_idx] = [True] * nb_actions_per_element elif element in unused_required_elements: mask[action_start_idx:action_end_idx] = mask_required_element else: mask[action_start_idx:action_end_idx] = mask_unrequired_element # If no other action is valid, ensure that the EOS action is available if all(mask): mask[-1] = False return mask
[docs] def states2proxy( self, states: List[dict] ) -> TensorType["batch", "state_proxy_dim"]: """ Prepares a batch of states in "environment format" for the proxy: The output is a tensor of dtype long with N_ELEMENTS_ORACLE + 1 columns, where the positions of self.elements are filled with the number of atoms of each element in the state. Args ---- states : list or tensor A batch of states in environment format, either as a list of states or as a single tensor. Returns ------- A tensor containing all the states in the batch. """ states_proxy = np.zeros((len(states), N_ELEMENTS_ORACLE + 1)) for idx, state in enumerate(states): if len(state) == 0: continue states_proxy[idx, list(state.keys())] = list(state.values()) return tlong(states_proxy, device=self.device)
[docs] def states2policy( self, states: List[dict] ) -> TensorType["batch", "policy_input_dim"]: """ Prepares a batch of states in "environment format" for the policy model: in order to not waste memory and for backward compatibility, the policy state only contains the number of atoms of the allowed elements. For example, if self.elements is [1, 2, 3, 4]: states: [{2: 1, 4: 2}, {1: 3}] states2policy(states): tensor([[0, 1, 0, 2], [3, 0, 0, 0]]) Args ---- states : list A batch of states in environment format, that is a list of dictionaries. Returns ------- A tensor containing all the states in the batch. """ states_policy = np.zeros((len(states), len(self.elements))) for idx, state in enumerate(states): if len(state) == 0: continue indices_elements = [self.elem2idx[el] for el in state] states_policy[idx, indices_elements] = list(state.values()) return tfloat(states_policy, device=self.device, float_type=self.float)
[docs] def state2readable(self, state=None): """ Transforms the state, represented as a dictionary of element: n_atoms key-value pairs, into a human-readable version: a non-reduced formula, following the Hill system. See: https://en.wikipedia.org/wiki/Chemical_formula#Hill_system Example: state: {1: 2, 3: 1} 1: atomic number of H 3: atomic number of Li output: H2Li1 """ state = self._get_state(state) state_elements = {self.alphabet[el]: n for el, n in state.items()} formula = "" if "C" in state_elements: formula += "C" + str(state_elements["C"]) if "H" in state_elements: formula += "H" + str(state_elements["H"]) formula += "".join( [ el + str(n) for el, n in sorted(state_elements.items()) if el not in ["C", "H"] ] ) return formula
[docs] def readable2state(self, readable): """ Converts the readable representation of a state (a chemical formula) into the environment format. Example: readable: H2Li1 self.alphabet: {1: "H", 2: "He", 3: "Li", 4: "Be"} output: {1: 2, 3: 1} """ state = {} offset = 0 for match in re.finditer(r"\d+", readable): span = match.span(0) element = readable[offset : span[0]] n_atoms = int(readable[span[0] : span[1]]) state[self.alphabet_rev[element]] = n_atoms offset = span[1] return state
[docs] def reset(self, env_id=None): """ Resets the environment. """ self.state = self.source.copy() self.n_actions = 0 self.done = False self.id = env_id return self
[docs] def get_parents(self, state=None, done=None, action=None): """ Determines all parents and actions that lead to a state. Args ---- state : dict Representation of a state as a dictionary of element: n_atoms key-value pairs. Elements whose atomic number is not a key of the dictionary have zero atoms. done : bool Whether the trajectory is done. If None, done is taken from instance. action : None Ignored Returns ------- parents : list List of parents in state format actions : list List of actions that lead to state for each parent in parents """ state = self._get_state(state) done = self._get_done(done) if done: return [state], [self.eos] else: parents = [] actions = [] for idx, action in enumerate(self.action_space[:-1]): element, n = action if element in state and state[element] == n: parent = state.copy() del parent[element] parents.append(parent) actions.append(action) return parents, actions
[docs] def step(self, action: Tuple[int, int]) -> Tuple[List[int], Tuple[int, int], bool]: """ Executes step given an action. Args ---- action : tuple Action to be executed. See: get_action_space() Returns ------- self.state : list The state after executing the action action : tuple Action executed valid : bool False, if the action is not allowed for the current state. """ # If done, return invalid if self.done: return self.state, action, False # If action not found in action space raise an error if action not in self.action_space: raise ValueError( f"Tried to execute action {action} not present in action space." ) else: action_idx = self.action_space.index(action) # If action is in invalid mask, exit immediately if self.get_mask_invalid_actions_forward()[action_idx]: return self.state, action, False # If action is not eos, then perform action if action != self.eos: element, num = action self.state[element] = num self.n_actions += 1 return self.state, action, True # If action is eos, then attempt eos action else: if self.get_mask_invalid_actions_forward()[-1]: valid = False else: if self.do_charge_check: # Currently enabling it causes errors when training combined # Crystal env, and very significantly increases training time. if self._can_produce_neutral_charge(): self.done = True valid = True self.n_actions += 1 else: valid = False else: self.done = True valid = True self.n_actions += 1 return self.state, self.eos, valid
[docs] def get_n_atoms_per_element(self, state: Optional[List[int]] = None) -> List[int]: """ Returns the number of atoms per element. The result is returned as a list of integers, with no particular order. That is, the output does not allow to identify which element has how many atoms, since the intended use is mainly to check the compatibility with a space group. Parameters ---------- state : list A state in environment format. If None, self.state is used. Returns ------- list A list of integers containing the number of atoms per element in the state. """ state = self._get_state(state) return list(state.values())
def _can_produce_neutral_charge(self, state: Optional[List[int]] = None) -> bool: """ Helper that checks whether there is a configuration of oxidation states that can produce a neutral charge for the given state. """ state = self._get_state(state) nums_charges = [ (num, self.oxidation_states[element]) for element, num in state.items() ] # Process all atoms one by one, gradually accumulating a set of all possible # charge totals so far poss_charge_sum = set([0]) while len(nums_charges) > 0: num, charges = nums_charges[0] # Compute all possible charge totals that can be obtained by combining # all the previous charge totals with all the possible charges for the # current atom new_poss_charge_sum = set() for old_charge_sum in poss_charge_sum: for element_charge in charges: new_poss_charge_sum.add(old_charge_sum + element_charge) poss_charge_sum = new_poss_charge_sum # Remove the atom that was processed from nums_charges if num == 1: # Remove element from nums_charges del nums_charges[0] else: # Remove one atom from this element nums_charges[0] = (num - 1, charges) return 0 in poss_charge_sum
[docs] def is_valid(self, state: dict) -> bool: """ Determines whether a state is valid, according to the attributes of the environment. Parameters ---------- state : dict A state in environment format. Returns ------- bool True if the state is valid according to the attributes of the environment; False otherwise. """ n_atoms_per_element = self.get_n_atoms_per_element(state) # Check total number of atoms n_atoms = sum(n_atoms_per_element) if n_atoms < self.min_atoms: return False if n_atoms > self.max_atoms: return False # Check number element if any([n < self.min_atom_i for n in n_atoms_per_element]): return False if any([n > self.max_atom_i for n in n_atoms_per_element]): return False # Check required elements used_elements = set(state.keys()) if len(used_elements - set(self.elements)) > 0: return False if len(used_elements) < self.min_diff_elem: return False if len(used_elements) > self.max_diff_elem: return False if len(set(self.required_elements) - used_elements) > 0: return False # If all checks are passed, return True return True