"""
Utility functions to improve the efficiency of common
functions/workflows/calculations in ``doped``.
"""
import contextlib
import itertools
import operator
import re
from collections import defaultdict
from collections.abc import Callable, Generator, Sequence
from functools import cached_property, lru_cache
from typing import TYPE_CHECKING
import numpy as np
from numpy.typing import NDArray
from pymatgen.analysis.defects.generators import VacancyGenerator
from pymatgen.analysis.defects.utils import VoronoiPolyhedron, remove_collisions
from pymatgen.analysis.structure_matcher import (
AbstractComparator,
ElementComparator,
FrameworkComparator,
StructureMatcher,
)
from pymatgen.core.composition import Composition, DummySpecies
from pymatgen.core.periodic_table import Element, Species
from pymatgen.core.sites import PeriodicSite, Site
from pymatgen.core.structure import IStructure, Molecule, Structure
from pymatgen.io.vasp.sets import get_valid_magmom_struct
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer, SymmOp
from scipy.spatial import Voronoi
if TYPE_CHECKING:
from doped.core import Vacancy
# Note that any overrides of ``__eq__`` should also override ``__hash__``, and vice versa
# Composition overrides:
def _composition__hash__(self):
"""
Custom ``__hash__`` method for ``Composition`` instances, to make
composition comparisons faster (used in structure matching etc.).
The default ``pymatgen`` ``Composition.__hash__`` just hashes the chemical system (without
stoichiometry), which cannot then be used to distinguish different
compositions.
"""
return hash(frozenset(self._data.items()))
@lru_cache(maxsize=int(1e8))
def doped_Composition_eq_func(self_hash, other_hash):
r"""
Updated equality function for ``Composition`` instances, which breaks early
for mismatches and also uses caching, making it orders of magnitude faster
than ``pymatgen``\s equality function.
"""
self_comp = Composition.__instances__[self_hash]
other_comp = Composition.__instances__[other_hash]
return fast_Composition_eq(self_comp, other_comp)
def fast_Composition_eq(self, other):
"""
Fast equality function for ``Composition`` instances, breaking early for
mismatches.
"""
# skip matching object type check here, as already checked upstream in ``_Composition__eq__``
if len(self) != len(other):
return False
for el, amt in self.items(): # noqa: SIM110
if abs(amt - other[el]) > type(self).amount_tolerance:
return False
return True
def _Composition__eq__(self, other):
"""
Custom ``__eq__`` method for ``Composition`` instances, using a cached
equality function to speed up comparisons.
"""
if not isinstance(other, type(self) | dict):
return NotImplemented
# use object hash with instances to avoid recursion issues (for class method)
self_hash = _composition__hash__(self)
Composition.__instances__[self_hash] = self # Ensure instances are stored for caching
other_hash = _composition__hash__(other)
Composition.__instances__[other_hash] = other
return doped_Composition_eq_func(self_hash, other_hash)
class Hashabledict(dict):
def __hash__(self):
"""
Make the dictionary hashable by recursively "freezing" it into hashable
built-ins, then hashing the result.
Handles nested dicts, lists, sets, tuples, etc.
"""
def _freeze(obj):
if isinstance(obj, dict): # convert to frozenset of tuples
return frozenset((k, _freeze(v)) for k, v in obj.items())
if isinstance(obj, list): # lists → tuples
return tuple(_freeze(v) for v in obj)
if isinstance(obj, set): # sets → frozensets
return frozenset(_freeze(v) for v in obj)
if isinstance(obj, tuple): # tuples → tuples of frozen values
return tuple(_freeze(v) for v in obj)
# else assume it's already hashable (int, str, custom …)
return obj
return hash(_freeze(self))
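# Minimal usage sketch (illustrative only; ``_example_hashabledict_usage`` is a hypothetical
# helper, not called anywhere): equal nested contents give equal hashes, so ``Hashabledict``
# instances can be used as ``lru_cache`` keys.
def _example_hashabledict_usage():
    d1 = Hashabledict({"Fe": 2, "O": [1, 2, {"x": 3}]})
    d2 = Hashabledict({"Fe": 2, "O": [1, 2, {"x": 3}]})
    assert hash(d1) == hash(d2)  # nested lists/dicts are recursively frozen before hashing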
def _get_hashable_dict(d: dict) -> Hashabledict:
if isinstance(d, Hashabledict):
return d
if isinstance(d, dict):
return Hashabledict(d) # convert to hashable dict for caching purposes
return d
def _fast_dict_deepcopy_max_two_levels(d: dict) -> dict:
"""
Fast deepcopy of a dict with at most two levels of nested dicts (i.e. d →
dict → dict → values).
Implemented to allow fast deep-copying of nested chemical potential dicts,
avoiding the overhead of `deepcopy` when looping over many chemical
potential dicts.
"""
return {
k: (
{
k2: (v2.copy() if isinstance(v2, dict) else v2) # final level, shallow copy sufficient
for k2, v2 in v1.items()
}
if isinstance(v1, dict)
else v1
)
for k, v1 in d.items()
}
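# Minimal usage sketch (illustrative only; ``_example_fast_dict_deepcopy`` is a hypothetical
# helper, not called anywhere): the two-level copy gives independent inner dicts for typical
# chemical potential data, without the overhead of ``copy.deepcopy``.
def _example_fast_dict_deepcopy():
    chempots = {"limits": {"Cd-rich": {"Cd": 0.0, "Te": -1.25}}}
    copied = _fast_dict_deepcopy_max_two_levels(chempots)
    copied["limits"]["Cd-rich"]["Cd"] = -0.5  # mutate the copy...
    assert chempots["limits"]["Cd-rich"]["Cd"] == 0.0  # ...original is unchanged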
@lru_cache(maxsize=int(1e5))
def _cached_Composition_init(comp_input):
return Composition(comp_input)
def _cache_ready_Composition_init(comp_input):
return _cached_Composition_init(_get_hashable_dict(comp_input))
def _fast_get_composition_from_sites(sites, assume_full_occupancy=False):
"""
Helper function to quickly get the composition of a collection of sites,
faster than initializing a ``Structure`` object.
Used in initial drafts of defect stenciling code, but replaced by faster
methods.
"""
elem_map: dict[Species, float] = defaultdict(float)
for site in sites:
if assume_full_occupancy:
elem_map[next(iter(site._species))] += 1
else:
for species, occu in site.species.items():
elem_map[species] += occu
return Composition(elem_map)
Composition.__instances__ = {}
Composition.__eq__ = _Composition__eq__
Composition.__hash__ = _composition__hash__
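# Minimal usage sketch (illustrative only; ``_example_fast_composition_comparisons`` is a
# hypothetical helper, not called anywhere): with the patched ``__hash__``/``__eq__``,
# composition hashes are stoichiometry-aware and repeated equality checks hit the cache.
def _example_fast_composition_comparisons():
    assert hash(Composition("Fe2O3")) != hash(Composition("FeO"))  # distinguishes stoichiometries
    assert Composition("Fe2O3") == Composition("Fe2O3")  # early-breaking, cached equality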
@lru_cache(maxsize=int(1e5))
def _parse_site_species_str(site: Site, wout_charge: bool = False):
if isinstance(site._species, Element):
return site._species.symbol
if isinstance(site._species, str):
species_string = site._species
elif isinstance(site._species, Composition | dict):
species_string = str(next(iter(site._species)))
else:
raise ValueError(f"Unexpected species type: {type(site._species)}")
if wout_charge: # remove all digits, + or - from species string
return re.sub(r"\d+|[\+\-]", "", species_string)
return species_string
# PeriodicSite overrides:
def _periodic_site__hash__(self):
"""
Custom ``__hash__`` method for ``PeriodicSite`` instances.
"""
property_dict = ( # Convert properties to a hashable form
{k: tuple(v) if isinstance(v, list | np.ndarray) else v for k, v in self.properties.items()}
if self.properties
else {}
)
species_info = tuple(str(el) for el in self.species) # string representation is used for species hash
try:
return hash(
(
species_info,
tuple(self.coords),
frozenset(property_dict.items()),
)
)
except Exception: # hash without the property dict
return hash(
(
species_info,
tuple(self.coords),
)
)
def cache_ready_PeriodicSite__eq__(self, other):
"""
Custom ``__eq__`` method for ``PeriodicSite`` instances, using a cached
equality function to speed up comparisons.
"""
needed_attrs = ("_species", "coords", "properties")
if not all(hasattr(other, attr) for attr in needed_attrs):
return NotImplemented
return (
self._species == other._species # should always work fine (and is faster) if Site initialised
# without ``skip_checks`` (default)
and cached_allclose(tuple(self.coords), tuple(other.coords), atol=type(self).position_atol)
and self.properties == other.properties
)
@lru_cache(maxsize=int(1e8))
def cached_allclose(a: tuple, b: tuple, rtol: float = 1e-05, atol: float = 1e-08):
"""
Cached version of ``np.allclose``, taking tuples as inputs (so that they
are hashable and thus cacheable).
"""
return np.allclose(np.array(a), np.array(b), rtol=rtol, atol=atol)
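# Minimal usage sketch (illustrative only; ``_example_cached_allclose`` is a hypothetical
# helper, not called anywhere): coordinates are passed as tuples so that repeated
# comparisons of the same pair are served from the cache.
def _example_cached_allclose():
    assert cached_allclose((0.0, 0.0, 0.0), (0.0, 0.0, 1e-9), atol=1e-8)
    assert not cached_allclose((0.0, 0.0, 0.0), (0.0, 0.0, 0.5), atol=1e-8)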
PeriodicSite.__eq__ = cache_ready_PeriodicSite__eq__
PeriodicSite.__hash__ = _periodic_site__hash__
# Structure overrides:
def _structure__hash__(self):
"""
Custom ``__hash__`` method for ``Structure`` instances.
"""
return hash((self.lattice, frozenset(self.sites)))
@contextlib.contextmanager
def cache_species(structure_cls):
"""
Context manager that makes ``Structure.species`` a cached property, which
significantly speeds up ``pydefect`` eigenvalue parsing in large structures
(due to repeated use of ``Structure.indices_from_symbol``).
"""
Composition.__eq__ = _Composition__eq__
Composition.__hash__ = _composition__hash__ # use efficient hash for composition
original_species = structure_cls.species
try:
cached = cached_property(original_species.fget)
cached.__set_name__(structure_cls, "species") # Explicit initialization
structure_cls.species = cached
yield
finally:
structure_cls.species = original_species
def doped_Structure__eq__(self, other: IStructure) -> bool:
"""
Copied from ``pymatgen``, but updated to break early once a mis-matching
site is found, to speed up structure matching by ~2x.
"""
# skip matching object type check here, as already checked upstream in ``_Structure__eq__``
if other is self:
return True
if len(self) != len(other):
return False
if self.lattice != other.lattice:
return False
if self.properties != other.properties:
return False
for site in self: # noqa: SIM110
if site not in other:
return False # break early!
return True
@lru_cache(maxsize=int(1e4))
def cached_Structure_eq_func(self_hash, other_hash):
"""
Cached equality function for ``Structure`` instances.
"""
return doped_Structure__eq__(IStructure.__instances__[self_hash], IStructure.__instances__[other_hash])
def _Structure__eq__(self, other):
"""
Custom ``__eq__`` method for ``Structure``/``IStructure`` instances, using
both caching and an updated, faster equality function to speed up
comparisons.
"""
needed_attrs = ("lattice", "sites", "properties")
if not all(hasattr(other, attr) for attr in needed_attrs):
return NotImplemented
self_hash = _structure__hash__(self)
other_hash = _structure__hash__(other)
IStructure.__instances__[self_hash] = self # Ensure instances are stored for caching
IStructure.__instances__[other_hash] = other
return cached_Structure_eq_func(self_hash, other_hash)
IStructure.__eq__ = _Structure__eq__
IStructure.__hash__ = _structure__hash__
IStructure.__instances__ = {}
Structure.__eq__ = _Structure__eq__
Structure.__hash__ = _structure__hash__
Structure.__deepcopy__ = lambda x, y: x.copy() # make deepcopying faster, shallow copy fine for structures
# Molecule overrides:
def _DopedMolecule__hash__(self):
"""
Hash ``pymatgen`` ``Molecule`` objects using the z-matrix (which reflects
the lengths, angles, and atom types of the molecule) and the site
coordinates.
Implemented to allow caching for efficient determination of symmetry
equivalent ``Molecule`` objects. The z-matrix functions as a unique
identifier for a molecule (with translation/rotation invariance -- which is
then removed by including the actual coordinates), while the ``__hash__``
method for the parent ``Molecule`` class is based solely on the composition
of the molecule and thus not unique.
"""
z_list = self.get_zmatrix().split("\n")
rounded_z_list = tuple([round(float(i.split("=")[-1]), 2) if "=" in i else i for i in z_list])
return hash((rounded_z_list, tuple(tuple(np.round(site.coords, 3)) for site in self)))
def _DopedMolecule__eq__(self, other):
"""
Custom ``__eq__`` method for ``Molecule`` instances, using a cached
equality function to speed up comparisons.
"""
if not isinstance(other, type(self)):
return NotImplemented
return _DopedMolecule__hash__(self) == _DopedMolecule__hash__(other)
Molecule.__eq__ = _DopedMolecule__eq__
Molecule.__hash__ = _DopedMolecule__hash__
# SpacegroupAnalyzer overrides:
def _sga__hash__(self):
"""
Custom ``__hash__`` method for ``SpacegroupAnalyzer`` instances, to make
them hashable for efficient caching of e.g. symmetry operation generation.
"""
return hash((self._cell, self._symprec, self._angle_tol))
_original_get_symmetry = SpacegroupAnalyzer._get_symmetry
@lru_cache(maxsize=int(1e3))
def _get_symmetry(self) -> tuple[NDArray, NDArray]:
"""
Get the symmetry operations associated with the structure.
Refactored from ``pymatgen`` to allow caching, to boost efficiency when
working with large defect supercells.
"""
return _original_get_symmetry(self) # call the original method, with the now cacheable class
_original_get_symmetry_operations = SpacegroupAnalyzer.get_symmetry_operations
@lru_cache(maxsize=int(1e3))
def _get_symmetry_operations(self, cartesian: bool = False) -> list[SymmOp]:
"""
Get the symmetry operations associated with the structure.
Refactored from ``pymatgen`` to allow caching, to boost efficiency.
"""
return _original_get_symmetry_operations(self, cartesian)  # call the original method, with the now-cacheable class
SpacegroupAnalyzer.__hash__ = _sga__hash__
SpacegroupAnalyzer._get_symmetry = _get_symmetry
SpacegroupAnalyzer.get_symmetry_operations = _get_symmetry_operations
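# Minimal usage sketch (illustrative only; ``_example_cached_symmetry_operations`` is a
# hypothetical helper, not called anywhere): repeated calls on the same (now hashable)
# ``SpacegroupAnalyzer`` return the cached result rather than recomputing with spglib.
def _example_cached_symmetry_operations():
    from pymatgen.core import Lattice

    struct = Structure(Lattice.cubic(3.0), ["Po"], [[0, 0, 0]])
    sga = SpacegroupAnalyzer(struct)
    ops_first = sga.get_symmetry_operations()
    ops_second = sga.get_symmetry_operations()  # second call served from the lru_cache
    assert ops_first is ops_second  # same cached object returned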
def _get_symbol(element: Element | Species, comparator: AbstractComparator | None = None) -> str:
"""
Convenience function to get the symbol of an ``Element`` or ``Species`` as
a string, with charge information included or excluded depending on the
choice of ``comparator``.
By default, the returned symbol does not include any charge / oxidation
state information. If ``comparator`` is provided and is not
``ElementComparator`` / ``FrameworkComparator``, then the ``str(element)``
representation is returned (which will include charge information if
``element`` is a ``Species``).
Args:
element (Element | Species):
``Element`` or ``Species`` to get the symbol of.
comparator (AbstractComparator | None):
Comparator to check if we should return the ``str(element)``
representation (which includes charge information if ``element`` is
a ``Species``), or just the element symbol (i.e.
``element.symbol``, or ``element.element.symbol`` if ``element`` is
a ``Species`` object) -- which is the case when ``comparator`` is
``None`` (default) or ``ElementComparator`` /
``FrameworkComparator``.
Returns:
str: Symbol of the element as a string.
"""
if (
comparator is not None
and not isinstance(comparator, ElementComparator | FrameworkComparator)
and isinstance(element, Species)
):
return str(element)
return element.symbol if isinstance(element, Element | DummySpecies) else element.element.symbol
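# Minimal usage sketch (illustrative only; ``_example_get_symbol`` is a hypothetical helper,
# not called anywhere): charge information is only retained when a species-resolving
# comparator (e.g. ``SpeciesComparator``) is used.
def _example_get_symbol():
    from pymatgen.analysis.structure_matcher import SpeciesComparator

    assert _get_symbol(Species("Fe", 2)) == "Fe"  # default: oxidation state stripped
    assert _get_symbol(Species("Fe", 2), comparator=SpeciesComparator()) == "Fe2+"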
def get_element_indices(
structure: Structure,
elements: list[Element | Species | str] | None = None,
comparator: AbstractComparator | None = None,
) -> dict[str, list[int]]:
"""
Convenience function to generate a dictionary of ``{element: [indices]}``
for a given ``Structure``, where ``indices`` are the indices of the sites
in the structure corresponding to the given ``elements`` (default is all
elements in the structure).
Args:
structure (Structure):
``Structure`` to get the indices from.
elements (list[Element | Species | str] | None):
List of elements to get the indices of. If ``None`` (default), all
elements in the structure are used.
comparator (AbstractComparator | None):
Comparator to check if we should return the ``str(element)``
representation (which includes charge information if ``element`` is
a ``Species``), or just the element symbol (i.e.
``element.element.symbol``) -- which is the case when
``comparator`` is ``None`` (default) or ``ElementComparator`` /
``FrameworkComparator``.
Returns:
dict[str, list[int]]:
Dictionary of ``{element: [indices]}`` for the given ``elements``
in the structure.
"""
if elements is None:
elements = _fast_get_composition_from_sites(structure).elements
if not all(isinstance(element, str) for element in elements):
elements = [_get_symbol(element, comparator) for element in elements]
species = np.array([_get_symbol(site.specie, comparator) for site in structure])
return {element: np.where(species == element)[0].tolist() for element in elements}
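# Minimal usage sketch (illustrative only; ``_example_get_element_indices`` is a hypothetical
# helper, not called anywhere): per-element site indices for a simple CsCl structure.
def _example_get_element_indices():
    from pymatgen.core import Lattice

    struct = Structure(Lattice.cubic(4.2), ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])
    assert get_element_indices(struct) == {"Cs": [0], "Cl": [1]}
    assert get_element_indices(struct, elements=["Cl"]) == {"Cl": [1]}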
def get_element_min_max_bond_length_dict(structure: Structure, **sm_kwargs) -> dict:
r"""
Get a dictionary of ``{element: (min_bond_length, max_bond_length)}`` for a
given ``Structure``, where ``min_bond_length`` and ``max_bond_length`` are
the minimum and maximum bond lengths for each element in the structure.
Args:
structure (Structure):
Structure to calculate bond lengths for.
**sm_kwargs:
Additional keyword arguments to pass to ``StructureMatcher()``.
Just used to check if ``comparator`` has been set here (if
``ElementComparator``/``FrameworkComparator`` used, then we use
``Element``\s rather than ``Species`` as the keys), or if
``ignored_species`` is set (in which case these species are
ignored when calculating bond lengths).
Returns:
dict: Dictionary of ``{element: (min_bond_length, max_bond_length)}``.
"""
comparator = sm_kwargs.get("comparator")
if len(structure) == 1:
structure = structure * 2 # need at least two sites to calculate bond lengths
# get the distance matrix broken down by species:
element_idx_dict = get_element_indices(structure, comparator=comparator)
ignored_indices = [
idx for elt in sm_kwargs.get("ignored_species", []) for idx in element_idx_dict.get(elt, [])
]
distance_matrix = structure.distance_matrix
np.fill_diagonal(distance_matrix, np.inf) # set diagonal to np.inf to ignore self-distances of 0
distance_matrix[:, ignored_indices] = np.inf # set ignored indices to np.inf to ignore these distances
distance_matrix[ignored_indices, :] = np.inf # set ignored indices to np.inf to ignore these distances
element_min_max_bond_length_dict = {elt: np.array([0, 0]) for elt in element_idx_dict}
for elt, site_indices in element_idx_dict.items():
element_dist_matrix = distance_matrix[:, site_indices] # (N_of_that_element, N_sites) matrix
if element_dist_matrix.size != 0:
min_interatomic_distances_per_atom = np.min(element_dist_matrix, axis=0) # min along columns
element_min_max_bond_length_dict[elt] = np.array(
[np.min(min_interatomic_distances_per_atom), np.max(min_interatomic_distances_per_atom)]
)
return element_min_max_bond_length_dict
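# Minimal usage sketch (illustrative only; ``_example_element_min_max_bond_lengths`` is a
# hypothetical helper, not called anywhere): in CsCl both elements share the same
# nearest-neighbour (Cs-Cl) distance, so their (min, max) entries match.
def _example_element_min_max_bond_lengths():
    from pymatgen.core import Lattice

    struct = Structure(Lattice.cubic(4.2), ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])
    bond_length_dict = get_element_min_max_bond_length_dict(struct)
    assert np.allclose(bond_length_dict["Cs"], bond_length_dict["Cl"])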
def get_dist_equiv_stol(dist: float, structure: Structure) -> float:
"""
Get the equivalent ``stol`` value for a given Cartesian distance (``dist``)
in a given ``Structure``.
``stol`` is a site tolerance parameter used in ``pymatgen``
``StructureMatcher`` functions, defined as the fraction of the average free
length per atom := ( V / Nsites ) ** (1/3).
Args:
dist (float): Cartesian distance in Å.
structure (Structure): Structure to calculate ``stol`` for.
Returns:
float: Equivalent ``stol`` value for the given distance.
"""
return dist / (structure.volume / len(structure)) ** (1 / 3)
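# Minimal usage sketch (illustrative only; ``_example_dist_equiv_stol`` is a hypothetical
# helper, not called anywhere): for a 4 Å cubic cell with one site, the free length per atom
# is (64 / 1) ** (1/3) = 4 Å, so a 1 Å distance corresponds to stol = 0.25.
def _example_dist_equiv_stol():
    from pymatgen.core import Lattice

    struct = Structure(Lattice.cubic(4.0), ["Po"], [[0, 0, 0]])
    assert np.isclose(get_dist_equiv_stol(1.0, struct), 0.25)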
def get_min_stol_for_s1_s2(struct1: Structure, struct2: Structure, **sm_kwargs) -> float:
"""
Get the minimum possible ``stol`` value which will give a match between
``struct1`` and ``struct2`` using ``StructureMatcher``, based on the ranges
of per-element minimum interatomic distances in the two structures.
Args:
struct1 (Structure): Initial structure.
struct2 (Structure): Final structure.
**sm_kwargs:
Additional keyword arguments to pass to ``StructureMatcher()``.
Just used to check if ``ignored_species`` or ``comparator`` has
been set here.
Returns:
float:
Minimum ``stol`` value for a match between ``struct1`` and
``struct2``. If a direct match is detected (corresponding to a minimum
``stol`` of 0), then ``1e-4`` is returned.
"""
s1_min_max_bond_length_dict = get_element_min_max_bond_length_dict(struct1, **sm_kwargs)
s2_min_max_bond_length_dict = get_element_min_max_bond_length_dict(struct2, **sm_kwargs)
common_elts = set(s1_min_max_bond_length_dict.keys()) & set(s2_min_max_bond_length_dict.keys())
if not common_elts: # try without oxidation states
struct1_wout_oxi = struct1.copy()
struct2_wout_oxi = struct2.copy()
struct1_wout_oxi.remove_oxidation_states()
struct2_wout_oxi.remove_oxidation_states()
s1_min_max_bond_length_dict = get_element_min_max_bond_length_dict(struct1_wout_oxi, **sm_kwargs)
s2_min_max_bond_length_dict = get_element_min_max_bond_length_dict(struct2_wout_oxi, **sm_kwargs)
common_elts = set(s1_min_max_bond_length_dict.keys()) & set(s2_min_max_bond_length_dict.keys())
min_min_dist_change = 1e-4
with contextlib.suppress(Exception):
min_min_dist_change = max(
{
elt: max(np.abs(s1_min_max_bond_length_dict[elt] - s2_min_max_bond_length_dict[elt]))
for elt in common_elts
if elt not in sm_kwargs.get("ignored_species", [])
}.values()
)
return max(get_dist_equiv_stol(min_min_dist_change, struct1), 1e-4)
def _sm_get_atomic_disps(sm: StructureMatcher, struct1: Structure, struct2: Structure):
"""
Get the root-mean-square displacement `and atomic displacements` between
two structures, normalized by the mean free length per atom:
``(Vol/Nsites)^(1/3)``.
These values are not directly returned by ``StructureMatcher`` methods.
This function replicates ``StructureMatcher.get_rms_dist()``, but changes
the return value from ``match[0], max(match[1])`` to ``match[0], match[1]``
to allow further analysis of displacements. Mainly intended for use by
``ShakeNBreak``.
Args:
sm (StructureMatcher): ``pymatgen`` ``StructureMatcher`` object.
struct1 (Structure): Initial structure.
struct2 (Structure): Final structure.
Returns:
tuple:
- float: Normalised RMS displacement between the two structures.
- np.ndarray: Normalised displacements between the two structures.
or ``None`` if no match is found.
"""
struct1, struct2 = sm._process_species([struct1, struct2])
struct1, struct2, fu, s1_supercell = sm._preprocess(struct1, struct2)
match = sm._match(struct1, struct2, fu, s1_supercell, use_rms=True, break_on_match=False)
return None if match is None else (match[0], match[1])
def StructureMatcher_scan_stol(
struct1: Structure,
struct2: Structure,
func_name: str = "get_s2_like_s1",
min_stol: float | None = None,
max_stol: float = 5.0,
stol_factor: float = 0.5,
**sm_kwargs,
):
r"""
Utility function to scan through a range of ``stol`` values for
``StructureMatcher`` until a match is found between ``struct1`` and
``struct2`` (i.e. ``StructureMatcher.{func_name}`` returns a result).
The ``StructureMatcher.match()`` function (used in most
``StructureMatcher`` methods) speed is heavily dependent on ``stol``, with
smaller values being faster, so we can speed up evaluation by starting with
small values and increasing until a match is found (especially with the
``doped`` efficiency tools which implement caching (and other improvements)
to ensure no redundant work here).
Note that ``ElementComparator()`` is used by default here! (So sites with
different species but the same element (e.g. "S2-" & "S0+") will be
considered match-able). This can be controlled with
``sm_kwargs['comparator']``.
Args:
struct1 (Structure): ``struct1`` for ``StructureMatcher.match()``.
struct2 (Structure): ``struct2`` for ``StructureMatcher.match()``.
func_name (str):
The name of the ``StructureMatcher`` method to return the result
of ``StructureMatcher.{func_name}(struct1, struct2)`` for, such
as:
- "get_s2_like_s1" (default)
- "get_rms_dist"
- "fit"
- "fit_anonymous"
min_stol (float):
Minimum ``stol`` value to try. Default is to use ``doped``\s
``get_min_stol_for_s1_s2()`` function to estimate the minimum
``stol`` necessary, and start with 2x this value to achieve fast
structure-matching in most cases.
max_stol (float):
Maximum ``stol`` value to try. Default: 5.0.
stol_factor (float):
Fractional increment to increase ``stol`` by each time (when a
match is not found). Default value of 0.5 increases ``stol`` by 50%
each time.
**sm_kwargs:
Additional keyword arguments to pass to ``StructureMatcher()``.
Returns:
Result of ``StructureMatcher.{func_name}(struct1, struct2)`` or
``None`` if no match is found.
"""
# use doped efficiency tools to make structure-matching as fast as possible:
StructureMatcher._get_atomic_disps = _sm_get_atomic_disps # monkey-patch ``StructureMatcher`` for SnB
if "comparator" not in sm_kwargs:
sm_kwargs["comparator"] = ElementComparator()
if min_stol is None:
min_stol = get_min_stol_for_s1_s2(struct1, struct2, **sm_kwargs) * 2
# Here we cycle through a range of stol values. We only need to find the closest match, so we
# could use a high ``stol`` from the start and still get the correct result, but higher ``stol``
# values take much longer to run (as multiple possible matches are cycled through). So we start
# with a low ``stol`` and break once a match is found:
stol = min_stol
while stol < max_stol:
if user_stol := sm_kwargs.pop("stol", False): # first run, try using user-provided stol first:
sm_full_user_custom = StructureMatcher(stol=user_stol, **sm_kwargs)
result = getattr(sm_full_user_custom, func_name)(struct1, struct2)
if result is not None:
return result
sm = StructureMatcher(stol=stol, **sm_kwargs)
result = getattr(sm, func_name)(struct1, struct2)
if result is not None:
return result
stol *= 1 + stol_factor
# Note: this function could possibly be sped up if ``StructureMatcher._match()`` was updated to
# return the guessed ``best_match`` value (even if larger than ``stol``), which will always be
# >= the best possible match it seems, and then using this to determine the next ``stol`` value
# to trial. Seems like it could give a ~50% speedup in some cases? Not clear though,
# as once you're getting a reasonable guessed value out, the trial ``stol`` should be pretty
# close to the necessary value anyway.
return None
StructureMatcher._get_atomic_disps = _sm_get_atomic_disps # monkey-patch ``StructureMatcher`` for SnB
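# Minimal usage sketch (illustrative only; ``_example_scan_stol`` is a hypothetical helper,
# not called anywhere): scan ``stol`` upwards until ``get_rms_dist`` finds a match between a
# structure and a slightly perturbed copy of it.
def _example_scan_stol():
    from pymatgen.core import Lattice

    struct1 = Structure(Lattice.cubic(4.2), ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])
    struct2 = struct1.copy()
    struct2.translate_sites([0], [0.02, 0.0, 0.0], frac_coords=True)  # small perturbation
    rms_dist = StructureMatcher_scan_stol(struct1, struct2, func_name="get_rms_dist")
    assert rms_dist is not None  # a match is found at some (small) stol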
class DopedTopographyAnalyzer:
"""
This is a modified version of
``pymatgen.analysis.defects.utils.TopographyAnalyzer``, to streamline the
input options and make initialisation far more efficient (~2 orders of
magnitude faster).
The original code was written by Danny Broberg and colleagues
(10.1016/j.cpc.2018.01.004), and was added to ``pymatgen`` before later
being removed.
"""
def __init__(
self,
structure: Structure,
image_tol: float = 0.0001,
max_cell_range: int = 1,
constrained_c_frac: float = 0.5,
thickness: float = 0.5,
) -> None:
"""
Args:
structure (Structure):
Structure to analyse.
image_tol (float):
A tolerance distance for the analysis, used to determine if
sites are periodic images of each other. Default (of 1e-4) is
usually fine.
max_cell_range (int):
This is the range of periodic images to construct the Voronoi
tessellation. A value of 1 means that we include all points
from ``(x +- 1, y +- 1, z+- 1)`` in the Voronoi construction.
This is needed because the Voronoi polyhedra can extend beyond the
standard unit cell due to PBC. Typically, the default value
of 1 works fine for most structures and is fast. But for very
small unit cells with high symmetry, this may need to be
increased to 2 or higher. If there are < 5 atoms in the input
structure and ``max_cell_range`` is 1, this will automatically
be increased to 2.
constrained_c_frac (float):
Constrain the region where topology analysis is performed.
Only sites with ``z`` fractional coordinates between
``constrained_c_frac +/- thickness`` are considered. Default of
0.5 (with ``thickness`` of 0.5) includes all sites in the unit
cell.
thickness (float):
Constrain the region where topology analysis is performed.
Only sites with ``z`` fractional coordinates between
``constrained_c_frac +/- thickness`` are considered. Default of
0.5 (with ``thickness`` of 0.5) includes all sites in the unit
cell.
"""
# if input cell is very small (< 5 atoms) and max cell range is 1 (default), bump to 2 for
# accurate Voronoi tessellation:
if len(structure) < 5 and max_cell_range == 1:
max_cell_range = 2
self.structure = structure.copy()
self.structure.remove_oxidation_states()
constrained_sites = []
for site in self.structure:
if (constrained_c_frac - thickness) <= site.frac_coords[2] <= (constrained_c_frac + thickness):
constrained_sites.append(site)
constrained_struct = Structure.from_sites(sites=constrained_sites)
lattice = constrained_struct.lattice
coords = []
cell_range = list(range(-max_cell_range, max_cell_range + 1))
for shift in itertools.product(cell_range, cell_range, cell_range):
for site in constrained_struct.sites:
shifted = site.frac_coords + shift
coords.append(lattice.get_cartesian_coords(shifted))
# Perform the voronoi tessellation.
voro = Voronoi(coords)
node_points_map = defaultdict(set)
for pts, vs in voro.ridge_dict.items():
for v in vs:
node_points_map[v].update(pts)
vnodes: list[VoronoiPolyhedron] = []
def get_mapping(vnodes, poly: VoronoiPolyhedron):
"""
Check if a Voronoi Polyhedron is a periodic image of one of the
existing polyhedra.
Modified to avoid expensive ``np.allclose()`` calls.
"""
if not vnodes:
return None
distance_matrix = lattice.get_all_distances([v.frac_coords for v in vnodes], poly.frac_coords)
if np.any(distance_matrix < image_tol):
for v in vnodes:
if v.is_image(poly, image_tol):
return v
return None
# Filter all the voronoi polyhedra so that we only consider those
# which are within the unit cell:
for i, vertex in enumerate(voro.vertices):
if i == 0:
continue
fcoord = lattice.get_fractional_coords(vertex)
if np.all([-image_tol <= c < 1 + image_tol for c in fcoord]):
poly = VoronoiPolyhedron(lattice, fcoord, node_points_map[i], coords, i)
if get_mapping(vnodes, poly) is None:
vnodes.append(poly)
self.coords = coords
self.vnodes = vnodes
def get_voronoi_nodes(structure: Structure) -> list[PeriodicSite]:
"""
Get the Voronoi nodes of a ``pymatgen`` ``Structure``.
Maximises efficiency by mapping down to the primitive cell, doing Voronoi
analysis (with the efficient ``DopedTopographyAnalyzer`` class), and then
mapping back to the original structure (typically a supercell).
Args:
structure (Structure):
``pymatgen`` ``Structure`` object.
Returns:
list[PeriodicSite]:
List of ``PeriodicSite`` objects representing the Voronoi nodes.
"""
try:
return _hashable_get_voronoi_nodes(structure)
except TypeError:
structure.__hash__ = _structure__hash__ # make sure Structure is hashable
return _hashable_get_voronoi_nodes(structure)
@lru_cache(maxsize=int(1e2))
def _hashable_get_voronoi_nodes(structure: Structure) -> list[PeriodicSite]:
from doped.utils.symmetry import _doped_cluster_frac_coords, get_primitive_structure
# map all sites to the unit cell; 0 ≤ xyz < 1.
structure = Structure.from_sites(structure, to_unit_cell=True)
# get Voronoi nodes in primitive structure and then map back to the supercell:
prim_structure = get_primitive_structure(structure)
top_analyzer = DopedTopographyAnalyzer(prim_structure)
voronoi_coords = [v.frac_coords for v in top_analyzer.vnodes]
# remove nodes less than 0.5 Å from sites in the structure
voronoi_coords = remove_collisions(voronoi_coords, structure=prim_structure, min_dist=0.5)
# cluster nodes within 0.2 Å of each other:
prim_vnodes: np.ndarray = _doped_cluster_frac_coords(voronoi_coords, prim_structure, tol=0.2)
# map back to the supercell
sm = StructureMatcher(primitive_cell=False, attempt_supercell=True)
mapping = sm.get_supercell_matrix(structure, prim_structure)
voronoi_struct = Structure.from_sites(
[PeriodicSite("X", fpos, structure.lattice) for fpos in prim_vnodes]
) # Structure with Voronoi nodes as sites
voronoi_struct.make_supercell(mapping) # Map back to the supercell
# check if there was an origin shift between primitive and supercell
regenerated_supercell = prim_structure.copy()
regenerated_supercell.make_supercell(mapping)
fractional_shift = sm.get_transformation(structure, regenerated_supercell)[1]
if not np.allclose(fractional_shift, 0):
voronoi_struct.translate_sites(range(len(voronoi_struct)), fractional_shift, frac_coords=True)
return voronoi_struct.sites
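# Minimal usage sketch (illustrative only; ``_example_get_voronoi_nodes`` is a hypothetical
# helper, not called anywhere): Voronoi (interstitial) nodes of a simple CsCl structure,
# returned as dummy-species ("X") sites in the input cell.
def _example_get_voronoi_nodes():
    from pymatgen.core import Lattice

    struct = Structure(Lattice.cubic(4.2), ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])
    nodes = get_voronoi_nodes(struct)
    assert all(isinstance(node, PeriodicSite) for node in nodes)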
def _generic_group_labels(list_in: Sequence, comp: Callable = operator.eq) -> list[int]:
"""
Group a list of unsortable objects, using a given comparator function.
Templated off the ``pymatgen-analysis-defects`` function, but fixed to
avoid broken reassignment logic and overwriting of labels (which previously
resulted in sites being incorrectly dropped).
Previously in ``doped`` interstitial generation, but then removed after
updates in commit ``4699f38`` (for v3.0.0) to use faster site-matching
functions from ``doped``.
Args:
list_in (Sequence): A sequence of objects to group using ``comp``.
comp (Callable): A comparator function.
Returns:
list[int]: list of labels for the input list
"""
list_out = [-1] * len(list_in) # Initialize with -1 instead of None for clarity
label_num = 0
for i1 in range(len(list_in)):
if list_out[i1] != -1: # Already labeled
continue
list_out[i1] = label_num
for i2 in range(i1 + 1, len(list_in)):
if list_out[i2] == -1 and comp(list_in[i1], list_in[i2]):
list_out[i2] = label_num
label_num += 1
return list_out
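# Minimal usage sketch (illustrative only; ``_example_generic_group_labels`` is a hypothetical
# helper, not called anywhere): grouping integers by parity with a custom comparator.
def _example_generic_group_labels():
    labels = _generic_group_labels([1, 2, 3, 4], comp=lambda a, b: a % 2 == b % 2)
    assert labels == [0, 1, 0, 1]  # odd entries share label 0, even entries share label 1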
class DopedVacancyGenerator(VacancyGenerator):
"""
Vacancy defects generator, subclassed from ``pymatgen-analysis-defects`` to
improve efficiency (particularly when handling defect complexes).
"""
def generate(
self,
structure: Structure,
rm_species: set[str | Species] | list[str | Species] | None = None,
**kwargs,
) -> Generator["Vacancy", None, None]:
"""
Generate vacancy defects.
Args:
structure (Structure):
The structure to generate vacancy defects in.
rm_species (set[str | Species] | list[str | Species] | None):
List/set of species to be removed (i.e. to consider for vacancy
generation). If ``None``, considers all species.
**kwargs:
Additional keyword arguments for the ``Vacancy`` constructor.
Returns:
Generator[Vacancy, None, None]:
Generator that yields ``Vacancy`` objects.
"""
from doped.core import Vacancy
from doped.utils.symmetry import get_sga
# core difference is the removal of unnecessary `remove_oxidation_states` calls
structure = get_valid_magmom_struct(structure)
all_species = {elt.symbol for elt in structure.composition.elements}
rm_species = all_species if rm_species is None else {*map(str, rm_species)}
if not set(rm_species).issubset(all_species):
raise ValueError(
f"rm_species ({rm_species}) must be a subset of the structure's species ({all_species})."
)
sga = get_sga(structure)
sym_struct = sga.get_symmetrized_structure()
for site_group in sym_struct.equivalent_sites:
site = site_group[0]
if site.specie.symbol in rm_species:
yield Vacancy(
structure=structure, # note that we no longer remove oxi states here! or in get_sga
site=site,
equivalent_sites=site_group,
**kwargs,
)
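# Minimal usage sketch (illustrative only; ``_example_doped_vacancy_generator`` is a
# hypothetical helper, not called anywhere): generating the single symmetry-inequivalent
# Cl vacancy in a CsCl structure.
def _example_doped_vacancy_generator():
    from pymatgen.core import Lattice

    struct = Structure(Lattice.cubic(4.2), ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])
    vacancies = list(DopedVacancyGenerator().generate(struct, rm_species={"Cl"}))
    assert len(vacancies) == 1  # one symmetry-inequivalent Cl site
    assert vacancies[0].site.specie.symbol == "Cl"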