Source code for doped.utils

"""
Submodule for utility functions in doped.

``doped.utils.__init__`` contains utility functions and context managers for
handling of warnings and multi-processing.
"""

import contextlib
import importlib
import inspect
import logging
import multiprocessing
import os
import warnings
from functools import cache
from typing import Any

from pymatgen.io.vasp.inputs import UnknownPotcarWarning
from pymatgen.io.vasp.sets import BadInputSetWarning


@contextlib.contextmanager
def suppress_logging(level=logging.CRITICAL):
    """
    Context manager that temporarily silences logging messages.

    On entry, ``logging.disable(level)`` is called so that messages of
    severity ``level`` and below are dropped. On exit (whether or not the
    body raised), the disable threshold that was active beforehand is
    reinstated.

    Args:
        level (int):
            Logging level passed to ``logging.disable``.
            Default is ``logging.CRITICAL``.
    """
    saved_threshold = logging.root.manager.disable  # threshold in effect before entering
    logging.disable(level)
    try:
        yield
    finally:
        # always undo our change, even if the wrapped code raised
        logging.disable(saved_threshold)
@contextlib.contextmanager
def patch_vise_for_windows():
    """
    Context manager to patch
    ``vise.defaults.UserSettings._make_yaml_file_list``, so that it returns an
    empty list.

    Fixes an issue where this function gives an infinite recursive search on
    Windows, causing hanging. The original method is restored when the
    context exits, whether or not the body raised.

    Raises:
        ImportError:
            If ``vise.defaults`` cannot be imported.
    """
    # Import and capture the original method *before* entering the ``try`` block.
    # If this setup fails there is nothing to restore, and the real error (e.g.
    # ``ImportError``) must propagate rather than being masked by an
    # ``UnboundLocalError`` raised from the ``finally`` clause.
    vise_defaults = importlib.import_module("vise.defaults")
    original = vise_defaults.UserSettings._make_yaml_file_list
    vise_defaults.UserSettings._make_yaml_file_list = lambda *args, **kwargs: []
    try:
        yield
    finally:
        # restore original
        vise_defaults.UserSettings._make_yaml_file_list = original
@contextlib.contextmanager
def vise_handling(level=logging.CRITICAL):
    r"""
    Tame ``vise``/``pydefect`` side effects, by combining
    :func:`suppress_logging`, ``warnings.catch_warnings()`` and
    :func:`patch_vise_for_windows`.

    The order of operations matters, because two things have to be in place
    *before* ``vise.defaults`` is imported for the first time anywhere:

    1. ``vise.util.logger.get_logger`` is swapped for ``logging.getLogger``,
       which avoids repeated ``vise`` INFO messages (and duplicate handlers)
       under parallelism. Only ``vise.util.logger`` is imported for this step;
       it does *not* pull in ``vise.defaults``, so the replacement is active
       by the time ``vise.defaults`` builds its module-level logger.

    2. ``vise.defaults`` is then imported inside ``catch_warnings()``, so that
       its ``warnings.simplefilter("ignore", UserWarning)`` call is reverted on
       exit. Otherwise that filter would be installed the first time
       ``vise.defaults`` is imported -- which, with multiprocessing parsing,
       happens lazily during *result unpickling* with ``Pool`` (importing a
       ``pydefect`` ``BandEdgeStates`` object -> ``pydefect.defaults`` ->
       ``vise.defaults``), i.e. *outside* any ``vise_handling()`` block.

    Calling this once at ``doped`` import leaves ``vise.defaults`` in
    ``sys.modules`` with both fixes applied, so later imports are no-ops and
    never re-trigger warning suppression.

    Args:
        level (int):
            Logging level forwarded to :func:`suppress_logging`.
            Default is ``logging.CRITICAL``.
    """
    with suppress_logging(level):
        with warnings.catch_warnings():
            # step 1: replace the logger factory before ``vise.defaults`` can use it
            import vise.util.logger as vise_logger_module

            vise_logger_module.get_logger = logging.getLogger

            # step 2: imports ``vise.defaults`` / builds the ``Defaults`` singleton
            with patch_vise_for_windows():
                yield
with vise_handling():  # tame ``vise``/``pydefect`` side effects at import time (see :func:`vise_handling`)
    pass


def _ignore_pmg_warnings():
    """
    Install global "ignore" warning filters for a set of warning categories
    and message prefixes.

    For the message-based filters, ``message`` only needs to match the start
    of the warning message.
    """
    # globally ignore these POTCAR / input-set warning categories
    ignored_categories = (UnknownPotcarWarning, BadInputSetWarning)
    for warning_category in ignored_categories:
        warnings.filterwarnings("ignore", category=warning_category)

    ignored_message_prefixes = (
        "No POTCAR file with matching TITEL fields",
        "Ignoring unknown variable type",
        # Ignore because comment after 'ALGO = Normal' causes this unnecessary warning:
        "Hybrid functionals only support",
        "Use get_magnetic_symmetry()",
        "Use of properties is now deprecated",
        # avoid warning about selective_dynamics properties (can happen if user explicitly
        # set "T T T" (or otherwise) for the bulk):
        "Not all sites have property",
        # ignore warning about structure charge that appears when getting Vasprun.as_dict():
        "Structure charge",
        # ignore UFloat warning about std_dev==0 (from MP energy corrections), can
        # potentially be removed in future if/when this issue resolved upstream
        "Using UFloat objects with std_dev==0",
    )
    for message_prefix in ignored_message_prefixes:
        warnings.filterwarnings("ignore", message=message_prefix)


_ignore_pmg_warnings()
@cache
def warn_once(message: str, category: type[Warning] = UserWarning, key: Any = None) -> None:
    """
    Emit ``message`` as a warning at most once per unique ``(message,
    category, key)``, unless ``warn_once.cache_clear()`` is called.

    De-duplication is done by memoising this function on its arguments, rather
    than relying on Python's default "show once per location" behaviour. It is
    therefore immune to the ``__warningregistry__`` being reset whenever the
    warning filters are mutated (which dependencies such as ``pandas`` do
    internally, defeating the default dedup when warning repeatedly in a loop
    -- e.g. over temperatures/conditions).

    Used for the periodicity-breaking supercell and missing-degeneracy-factor
    warnings, which can otherwise be emitted many times in thermodynamic
    analysis loops.

    Args:
        message (str):
            The warning message.
        category (type[Warning]):
            The warning category. Default is ``UserWarning``.
        key (Any):
            Optional `cheap` hashable used to check if the warning has already
            been emitted for a given object/situation; e.g. a |DefectEntry|
            ``name`` to warn once `per defect entry`. Default is ``None``.
    """
    # ``stacklevel=3`` attributes the warning two frames above this function
    warnings.warn(message, category=category, stacklevel=3)
class ParameterOrderWarning(FutureWarning):
    """
    Warning category for the ``(bulk, defect)`` -> ``(defect, bulk)``
    parameter ordering change made to some functions in ``doped`` v4.0.

    TODO: Remove all parameter-order warning handling in v4.1.
    """
def _check_parameter_order_warning():
    """
    Check if the parameter order warning should be shown, based on the
    ``DOPED_WARN_PARAMETER_ORDER`` environment variable.

    Returns:
        bool:
            ``True`` if the environment variable is not set, or is set to
            anything other than ``"0"``, ``"false"`` or ``"no"`` (compared
            case-insensitively); ``False`` otherwise.
    """
    env_value = os.environ.get("DOPED_WARN_PARAMETER_ORDER")
    return env_value is None or env_value.lower() not in ("0", "false", "no")


def _warn_parameter_order(func_name: str, stacklevel: int = 3):
    """
    Emit a ``ParameterOrderWarning`` for the given function name.

    Does nothing if the warning has been disabled through the
    ``DOPED_WARN_PARAMETER_ORDER`` environment variable.

    Args:
        func_name (str):
            Name of the function whose parameter ordering changed; it is
            interpolated into the warning message.
        stacklevel (int):
            ``stacklevel`` passed on to ``warnings.warn``. Default is 3.
    """
    if not _check_parameter_order_warning():
        return

    warnings.warn(
        f"In doped v4.0, the parameter ordering for `{func_name}` was changed from "
        f"`(bulk_..., defect_..., ...)` to `(defect_..., bulk_..., ...)`. Please ensure your code uses "
        f"the correct ordering (and/or uses keyword arguments rather than positional arguments). This "
        f"warning can be disabled by setting the environment variable DOPED_WARN_PARAMETER_ORDER=false, "
        f"and will be removed in doped v4.1.",
        ParameterOrderWarning,
        stacklevel=stacklevel,
    )


def _doped_obj_properties_methods(obj):
    """
    Return a tuple of the attributes & properties and methods of a given
    object.

    Used in the ``__repr__()`` methods of ``doped`` objects.

    Args:
        obj:
            The object to inspect. Its instance attributes are read with
            ``vars(obj)``.

    Returns:
        tuple[set[str], set[str]]:
            ``(attributes | properties, methods)``, where attributes are the
            public names in ``vars(obj)``, properties are all ``property``
            members of ``type(obj)``, and methods are the public names in
            ``dir(obj)`` whose values are callable.
    """
    attrs = {name for name in vars(obj) if not name.startswith("_")}

    methods = set()
    for name in dir(obj):
        # Filter on the name *first*: private/dunder members can never be reported, so
        # there is no reason to evaluate them (``getattr`` would run property getters).
        if name.startswith("_"):
            continue
        # best-effort: attribute access on arbitrary objects may raise; skip those names
        with contextlib.suppress(Exception):
            if callable(getattr(obj, name)):
                methods.add(name)

    # NOTE(review): properties are not filtered by a leading underscore, unlike
    # attributes and methods -- confirm whether that is intended.
    properties = {name for name, value in inspect.getmembers(type(obj)) if isinstance(value, property)}
    return attrs | properties, methods
def get_mp_context():
    """
    Get a multiprocessing context that is compatible with the current OS.

    The ``"forkserver"`` start method is preferred; if requesting it raises
    ``ValueError`` (forkserver is not available on Windows OS), the
    ``"spawn"`` context is returned instead.
    """
    with contextlib.suppress(ValueError):  # forkserver not available on Windows OS
        return multiprocessing.get_context("forkserver")

    return multiprocessing.get_context("spawn")
def get_mp_processes(processes: int | None = None):
    """
    Get the number of processes to use with ``Pool``.

    Args:
        processes (int | None):
            Requested number of processes. If falsy (``None`` or ``0``), one
            less than the CPU count reported by the multiprocessing context is
            used instead, with a minimum of 1.

    Returns:
        int: The number of processes to use.
    """
    ctx = get_mp_context()  # https://github.com/python/cpython/pull/100229
    if processes:
        return processes
    return max(1, ctx.cpu_count() - 1)
@contextlib.contextmanager
def pool_manager(processes: int | None = None):
    r"""
    Context manager for ``multiprocessing`` ``Pool``, to throw a clearer error
    message when ``RuntimeError``\s are raised because ``multiprocessing``
    within ``doped`` is used in a python script.

    Only ``RuntimeError``\s whose message mentions ``freeze_support()`` are
    re-raised with the clearer message; any other ``RuntimeError`` propagates
    unchanged. The pool is always closed and joined on exit.

    See the :ref:`Errors with Python Scripts <errors_with_python_scripts>`
    section.

    Args:
        processes (int | None):
            Number of processes to use with ``Pool``. If ``None``, will use
            ``mp.cpu_count() - 1`` (i.e. one less than the number of available
            CPUs).

    Yields:
        Pool: A ``Pool`` object with the specified number of processes.
    """
    worker_pool = None
    try:
        ctx = get_mp_context()  # https://github.com/python/cpython/pull/100229
        worker_pool = ctx.Pool(get_mp_processes(processes))
        yield worker_pool

    except RuntimeError as orig_exc:
        if "freeze_support()" not in str(orig_exc):
            raise  # unrelated RuntimeError: let it through untouched

        raise RuntimeError(
            "When using doped in python scripts with multiprocessing (recommended), you must use the "
            "`if __name__ == '__main__':` syntax, see "
            "https://doped.readthedocs.io/en/latest/Troubleshooting.html#errors-with-python-scripts "
            "-- alternatively you can set processes=1 (but this will be slower)"
        ) from orig_exc

    finally:
        # only clean up if the pool was actually created
        if worker_pool is not None:
            worker_pool.close()
            worker_pool.join()