Coverage for ase / utils / __init__.py: 81.93%
404 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1import errno
2import functools
3import io
4import os
5import pickle
6import re
7import string
8import sys
9import time
10import warnings
11from collections.abc import Callable
12from contextlib import ExitStack, contextmanager
13from importlib import import_module
14from math import atan2, cos, degrees, gcd, radians, sin
15from pathlib import Path, PurePath
17import numpy as np
19from ase.formula import formula_hill, formula_metal
# Names exported by ``from ase.utils import *``.  Several of them
# (``import_module``, ``gcd``, ``PurePath``, ``formula_hill``,
# ``formula_metal``) are re-exports of names imported at the top of this
# module rather than definitions made here.
__all__ = [
    'basestring',
    'import_module',
    'seterr',
    'plural',
    'devnull',
    'gcd',
    'convert_string_to_fd',
    'Lock',
    'opencew',
    'OpenLock',
    'rotate',
    'irotate',
    'pbc2pbc',
    'givens',
    'hsv2rgb',
    'hsv',
    'pickleload',
    'reader',
    'formula_hill',
    'formula_metal',
    'PurePath',
    'xwopen',
    'tokenize_version',
    'get_python_package_path_description',
]
def tokenize_version(version_string: str):
    """Split a dotted version string into a tuple that sorts sensibly.

    Each dot-separated component contributes two entries: its leading
    integer (``-1`` when the component has no leading digits) followed by
    whatever text remains after those digits.

    Usage: tokenize_version('3.8') < tokenize_version('3.8.1').
    """
    parsed = []
    for component in version_string.split('.'):
        match = re.match(r'(\d*)(.*)', component)
        assert match is not None, f'Cannot parse component {component}'
        digits, remainder = match.groups()
        try:
            leading = int(digits)
        except ValueError:
            # No leading digits at all; sort such components first.
            leading = -1
        parsed.extend((leading, remainder))
    return tuple(parsed)
# Python 2+3 compatibility stuff (let's try to remove these things):
basestring = str
# ``pickle.load`` preconfigured with ``encoding='bytes'``.
# NOTE(review): unpickling can execute arbitrary code, so this must only
# ever be applied to trusted files.
pickleload = functools.partial(pickle.load, encoding='bytes')
72def deprecated(
73 message: str | Warning,
74 category: type[Warning] = FutureWarning,
75 callback: Callable[[list, dict], bool] = lambda args, kwargs: True,
76):
77 """Return a decorator deprecating a function.
79 Parameters
80 ----------
81 message : str or Warning
82 The message to be emitted. If ``message`` is a Warning, then
83 ``category`` is ignored and ``message.__class__`` will be used.
84 category : type[Warning], default=FutureWarning
85 The type of warning to be emitted. If ``message`` is a ``Warning``
86 instance, then ``category`` will be ignored and ``message.__class__``
87 will be used.
88 callback : Callable[[list, dict], bool], default=lambda args, kwargs: True
89 A callable that determines if the warning should be emitted and handles
90 any processing prior to calling the deprecated function. The callable
91 will receive two arguments, a list and a dictionary. The list will
92 contain the positional arguments that the deprecated function was
93 called with at runtime while the dictionary will contain the keyword
94 arguments. The callable *must* return ``True`` if the warning is to be
95 emitted and ``False`` otherwise. The list and dictionary will be
96 unpacked into the positional and keyword arguments, respectively, used
97 to call the deprecated function.
99 Returns
100 -------
101 deprecated_decorator : Callable
102 A decorator for deprecated functions that can be used to conditionally
103 emit deprecation warnings and/or pre-process the arguments of a
104 deprecated function.
106 Example
107 -------
108 >>> # Inspect & replace a keyword parameter passed to a deprecated function
109 >>> from typing import Any, Callable
110 >>> import warnings
111 >>> from ase.utils import deprecated
113 >>> def alias_callback_factory(kwarg: str, alias: str) -> Callable:
114 ... def _replace_arg(_: list, kwargs: dict[str, Any]) -> bool:
115 ... kwargs[kwarg] = kwargs[alias]
116 ... del kwargs[alias]
117 ... return True
118 ... return _replace_arg
120 >>> MESSAGE = ("Calling this function with `atoms` is deprecated. "
121 ... "Use `optimizable` instead.")
122 >>> @deprecated(
123 ... MESSAGE,
124 ... category=DeprecationWarning,
125 ... callback=alias_callback_factory("optimizable", "atoms")
126 ... )
127 ... def function(*, atoms=None, optimizable=None):
128 ... '''
129 ... .. deprecated:: 3.23.0
130 ... Calling this function with ``atoms`` is deprecated.
131 ... Use ``optimizable`` instead.
132 ... '''
133 ... print(f"atoms: {atoms}")
134 ... print(f"optimizable: {optimizable}")
136 >>> with warnings.catch_warnings(record=True) as w:
137 ... warnings.simplefilter("always")
138 ... function(atoms="atoms")
139 atoms: None
140 optimizable: atoms
142 >>> w[-1].category == DeprecationWarning
143 True
144 """
146 def deprecated_decorator(func):
147 @functools.wraps(func)
148 def deprecated_function(*args, **kwargs):
149 _args = list(args)
150 if callback(_args, kwargs):
151 warnings.warn(message, category=category, stacklevel=2)
153 return func(*_args, **kwargs)
155 return deprecated_function
157 return deprecated_decorator
@contextmanager
def seterr(**kwargs):
    """Temporarily change numpy's floating-point error handling.

    The keyword arguments are forwarded to ``np.seterr()``; the settings
    that call returns are put back when the ``with`` block is left.
    """
    previous_settings = np.seterr(**kwargs)
    try:
        yield
    finally:
        # np.seterr returned the old configuration as a dict of keywords.
        np.seterr(**previous_settings)
def plural(n, word):
    """Return ``n`` followed by ``word``, with an 's' appended unless n == 1.

    >>> from ase.utils import plural

    >>> plural(0, 'egg'), plural(1, 'egg'), plural(2, 'egg')
    ('0 eggs', '1 egg', '2 eggs')
    """
    return '1 ' + word if n == 1 else '%d %ss' % (n, word)
class DevNull:
    """File-like object that ignores writes and reads back nothing.

    Every method is wrapped so that calling it emits a
    ``DeprecationWarning`` advising ``open(os.devnull)`` instead.
    """

    encoding = 'UTF-8'
    closed = False

    # Deprecated for ase-3.21.0.  Change to futurewarning later on.
    _use_os_devnull = deprecated(
        'use open(os.devnull) instead', DeprecationWarning
    )

    @_use_os_devnull
    def write(self, string):
        """Discard ``string``."""

    @_use_os_devnull
    def flush(self):
        """Do nothing; there is no buffer."""

    @_use_os_devnull
    def seek(self, offset, whence=0):
        """Ignore the request and report position 0."""
        return 0

    @_use_os_devnull
    def tell(self):
        """Report position 0."""
        return 0

    @_use_os_devnull
    def close(self):
        """Do nothing; ``closed`` stays False."""

    @_use_os_devnull
    def isatty(self):
        """Report that this is not a terminal."""
        return False

    @_use_os_devnull
    def read(self, n=-1):
        """Return an empty string regardless of ``n``."""
        return ''


# Shared module-level instance.
devnull = DevNull()
@deprecated(
    'convert_string_to_fd does not facilitate proper resource '
    'management.  '
    'Please use e.g. ase.utils.IOContext class instead.'
)
def convert_string_to_fd(name, world=None):
    """Turn an output specification into a writable text stream.

    ``None`` (or any rank other than 0) gives a handle on ``os.devnull``,
    ``'-'`` gives ``sys.stdout``, a ``str`` or ``PurePath`` is opened for
    writing, and anything else is returned unchanged on the assumption
    that it is already a file object.

    .. deprecated:: 3.22.1
        Please use e.g. :class:`ase.utils.IOContext` class instead.
    """
    if world is None:
        from ase.parallel import world

    discard_output = name is None or world.rank != 0
    if discard_output:
        return open(os.devnull, 'w')
    if name == '-':
        return sys.stdout
    if not isinstance(name, (str, PurePath)):
        # we assume name is already a file-descriptor
        return name
    return open(str(name), 'w')  # str for py3.5 pathlib
# Flags OR-ed into ``os.open`` by ``_opencew`` for "create exclusively for
# writing": O_CREAT | O_EXCL makes the open fail if the file already
# exists.  Only Windows has O_BINARY, hence the getattr fallback to 0.
CEW_FLAGS = os.O_CREAT | os.O_EXCL | os.O_WRONLY | getattr(os, 'O_BINARY', 0)
@contextmanager
def xwopen(filename, world=None):
    """Context manager around :func:`opencew`.

    Yields whatever ``opencew(filename, world)`` returns -- an open file,
    or ``None`` when exclusive write access could not be obtained -- and
    closes the file, if there is one, when the block is left."""

    handle = opencew(filename, world)
    if handle is None:
        # Nothing was opened, so there is nothing to clean up.
        yield None
        return

    try:
        yield handle
    finally:
        handle.close()
# @deprecated('use "with xwopen(...) as fd: ..." to prevent resource leak')
def opencew(filename, world=None):
    """Create and open ``filename`` exclusively for writing.

    Public entry point that forwards to :func:`_opencew`; the caller is
    responsible for closing the returned file.
    """
    result = _opencew(filename, world)
    return result
278def _opencew(filename, world=None):
279 import ase.parallel as parallel
281 if world is None:
282 world = parallel.world
284 closelater = []
286 def opener(file, flags):
287 return os.open(file, flags | CEW_FLAGS)
289 try:
290 error = 0
291 if world.rank == 0:
292 try:
293 fd = open(filename, 'wb', opener=opener)
294 except OSError as ex:
295 error = ex.errno
296 else:
297 closelater.append(fd)
298 else:
299 fd = open(os.devnull, 'wb')
300 closelater.append(fd)
302 # Synchronize:
303 error = world.sum_scalar(error)
304 if error == errno.EEXIST:
305 return None
306 if error:
307 raise OSError(error, 'Error', filename)
309 return fd
310 except BaseException:
311 for fd in closelater:
312 fd.close()
313 raise
def opencew_text(*args, **kwargs):
    """Text-mode variant of :func:`opencew`.

    Passes all arguments on to ``opencew`` and wraps the resulting binary
    file in ``io.TextIOWrapper``; ``None`` is passed through unchanged.
    """
    binary = opencew(*args, **kwargs)
    return None if binary is None else io.TextIOWrapper(binary)
class Lock:
    """Lock implemented by exclusive creation of a file.

    Acquiring the lock creates ``name`` via ``opencew``; releasing it
    closes the file and removes it on rank 0.  Usable as a context
    manager.

    Parameters
    ----------
    name : path of the lock file (converted with ``str``).
    world : communicator; defaults to ``ase.parallel.world``.
    timeout : maximum number of seconds :meth:`acquire` keeps retrying.
    """

    def __init__(self, name='lock', world=None, timeout=float('inf')):
        self.name = str(name)
        self.timeout = timeout
        if world is None:
            from ase.parallel import world
        self.world = world

    def acquire(self):
        """Block until the lock file could be created.

        Retries with a sleep that doubles after every failed attempt.

        Raises
        ------
        TimeoutError
            If the lock could not be obtained within ``timeout`` seconds.
        """
        dt = 0.2
        # Monotonic clock: the timeout must not be affected by changes of
        # the system time.
        t1 = time.monotonic()
        while True:
            fd = opencew(self.name, self.world)
            if fd is not None:
                self.fd = fd
                break
            time_left = self.timeout - (time.monotonic() - t1)
            if time_left <= 0:
                raise TimeoutError
            time.sleep(min(dt, time_left))
            dt *= 2

    def release(self):
        """Close the lock file and delete it on rank 0."""
        self.world.barrier()
        # Important to close fd before deleting file on windows
        # as a WinError would otherwise be raised.
        self.fd.close()
        if self.world.rank == 0:
            os.remove(self.name)
        self.world.barrier()

    def __enter__(self):
        self.acquire()
        # Return the lock so that ``with Lock(...) as lock:`` binds it.
        return self

    def __exit__(self, type, value, tb):
        self.release()
class OpenLock:
    """Do-nothing stand-in offering the same methods as :class:`Lock`.

    ``acquire`` and ``release`` have no effect, so the "lock" is always
    open.
    """

    def acquire(self):
        """Do nothing."""

    def release(self):
        """Do nothing."""

    def __enter__(self):
        # Return the object so that ``with OpenLock() as lock:`` binds it.
        return self

    def __exit__(self, type, value, tb):
        pass
def search_current_git_hash(arg, world=None):
    """Search for .git directory and current git commit hash.

    Parameters
    ----------

    arg: str or os.PathLike (directory path) or python module
        .git directory is searched from the parent directory of
        the given directory or module.
    world: communicator, optional
        Defaults to ``ase.parallel.world``.  Only rank 0 searches; all
        other ranks return ``None``.

    Returns
    -------
    The commit hash as a string of hexadecimal digits, or ``None`` if no
    ``.git`` directory, ``HEAD`` file, ref file or valid hash was found.
    """
    if world is None:
        from ase.parallel import world
    if world.rank != 0:
        return None

    # Check argument
    if isinstance(arg, (str, os.PathLike)):
        # Directory path
        dpath = os.fspath(arg)
    else:
        # Assume arg is module
        dpath = os.path.dirname(arg.__file__)
    # dpath = os.path.abspath(dpath)
    # in case this is just symlinked into $PYTHONPATH
    dpath = os.path.realpath(dpath)
    dpath = os.path.dirname(dpath)  # Go to the parent directory
    git_dpath = os.path.join(dpath, '.git')
    if not os.path.isdir(git_dpath):
        # Replace this 'if' with a loop if you want to check
        # further parent directories
        return None
    HEAD_file = os.path.join(git_dpath, 'HEAD')
    if not os.path.isfile(HEAD_file):
        return None
    with open(HEAD_file) as fd:
        line = fd.readline().strip()
    if line.startswith('ref: '):
        ref = line[5:]
        ref_file = os.path.join(git_dpath, ref)
    else:
        # Assuming detached HEAD state
        ref_file = HEAD_file
    if not os.path.isfile(ref_file):
        return None
    with open(ref_file) as fd:
        line = fd.readline().strip()
    # ``all`` is True for an empty string, so reject an empty line
    # explicitly instead of reporting '' as the hash.
    if line and all(c in string.hexdigits for c in line):
        return line
    return None
def rotate(rotations, rotation=np.identity(3)):
    """Build a rotation matrix from a string such as '50x,-10y,120z'.

    Every comma-separated term is an angle in degrees followed by the
    axis letter; the corresponding matrices are multiplied onto
    ``rotation`` from the right, term by term.  An empty string yields a
    copy of ``rotation``.

    Note that the order of rotation matters, i.e. '50x,40z' is different
    from '40z,50x'.
    """

    if rotations == '':
        return rotation.copy()

    # Parse the complete string first so that a malformed term raises
    # before any matrix product is computed.
    steps = [
        ('xyz'.index(term[-1]), radians(float(term[:-1])))
        for term in rotations.split(',')
    ]

    for axis, angle in steps:
        s, c = sin(angle), cos(angle)
        if axis == 0:
            matrix = [(1, 0, 0), (0, c, s), (0, -s, c)]
        elif axis == 1:
            matrix = [(c, 0, -s), (0, 1, 0), (s, 0, c)]
        else:
            matrix = [(c, s, 0), (-s, c, 0), (0, 0, 1)]
        rotation = np.dot(rotation, matrix)
    return rotation
def givens(a, b):
    """Solve the equation system::

      [ c s]   [a]   [r]
      [    ] . [ ] = [ ]
      [-s c]   [b]   [0]

    and return the tuple ``(c, s, r)``.
    """
    sign = np.sign

    if b == 0:
        return sign(a), 0, abs(a)

    if abs(b) >= abs(a):
        # Divide by the larger entry so that the ratio stays within [-1, 1].
        ratio = a / b
        scale = sign(b) * (1 + ratio**2) ** 0.5
        sine = 1.0 / scale
        return sine * ratio, sine, b * scale

    ratio = b / a
    scale = sign(a) * (1 + ratio**2) ** 0.5
    cosine = 1.0 / scale
    return cosine, cosine * ratio, a * scale
def irotate(rotation, initial=np.identity(3)):
    """Determine x, y, z rotation angles from rotation matrix.

    ``initial`` is multiplied onto ``rotation`` from the left first.  The
    three angles are returned in degrees as the tuple ``(x, y, z)``.
    """
    m = np.dot(initial, rotation)
    cos_x, sin_x, r_x = givens(m[2, 2], m[1, 2])
    cos_y, sin_y, _r_y = givens(r_x, m[0, 2])
    cos_z, sin_z, _r_z = givens(
        cos_x * m[1, 1] - sin_x * m[2, 1],
        cos_y * m[0, 1] - sin_y * (sin_x * m[1, 1] + cos_x * m[2, 1]),
    )
    return (
        degrees(atan2(sin_x, cos_x)),
        degrees(atan2(-sin_y, cos_y)),
        degrees(atan2(sin_z, cos_z)),
    )
def pbc2pbc(pbc):
    """Expand ``pbc`` into a new boolean array of length 3.

    ``pbc`` is assigned into the array with numpy broadcasting, so either
    a single flag or three flags may be given.
    """
    flags = np.zeros(3, dtype=bool)
    flags[...] = pbc
    return flags
499def string2index(stridx: str) -> int | slice | str:
500 """Convert index string to either int or slice"""
501 if ':' not in stridx:
502 # may contain database accessor
503 try:
504 return int(stridx)
505 except ValueError:
506 return stridx
507 i = [None if s == '' else int(s) for s in stridx.split(':')]
508 return slice(*i)
def hsv2rgb(h, s, v):
    """http://en.wikipedia.org/wiki/HSL_and_HSV

    h (hue) in [0, 360[
    s (saturation) in [0, 1]
    v (value) in [0, 1]

    return rgb in range [0, 1]

    Raises RuntimeError if ``h`` does not fall into one of the six
    60-degree sectors starting at 0.
    """
    if v == 0:
        return 0, 0, 0
    if s == 0:
        return v, v, v

    # i: index of the 60-degree sector, f: fractional position inside it.
    i, f = divmod(h / 60.0, 1)
    p = v * (1 - s)
    q = v * (1 - s * f)
    t = v * (1 - s * (1 - f))

    if i == 0:
        return v, t, p
    elif i == 1:
        return q, v, p
    elif i == 2:
        return p, v, t
    elif i == 3:
        return p, q, v
    elif i == 4:
        return t, p, v
    elif i == 5:
        return v, p, q
    else:
        raise RuntimeError('h must be in [0, 360]')


def hsv(array, s=0.9, v=0.9):
    """Map the values of ``array`` linearly onto hues and return RGB.

    The smallest value gets hue 0 and the largest hue 359; each hue is
    converted with :func:`hsv2rgb` using saturation ``s`` and value ``v``.
    If all values are equal, every element gets hue 0.

    Parameters
    ----------
    array : array_like of numbers (must not be empty).
    s, v : saturation and value passed to ``hsv2rgb``.

    Returns
    -------
    ndarray of shape ``array.shape + (3,)`` with the RGB triplets.
    """
    array = np.asarray(array, dtype=float)
    lowest = array.min()
    span = array.max() - lowest
    if span == 0:
        # Constant input: avoid 0/0 and give everything the same hue.
        hues = np.zeros_like(array)
    else:
        # Subtract (not add) the minimum so hues cover [0, 359].
        hues = (array - lowest) * 359.0 / span
    result = np.empty((hues.size, 3))
    for rgb, h in zip(result, hues.flat):
        rgb[:] = hsv2rgb(h, s, v)
    return np.reshape(result, hues.shape + (3,))
554# This code does the same, but requires pylab
555# def cmap(array, name='hsv'):
556# import pylab
557# a = (array + array.min()) / array.ptp()
558# rgba = getattr(pylab.cm, name)(a)
559# return rgba[:-1] # return rgb only (not alpha)
def longsum(x):
    """Sum ``x`` in ``np.longdouble`` precision and return a Python float.

    NOTE(review): the actual width of ``np.longdouble`` depends on the
    platform, so the extra precision is not guaranteed everywhere.
    """
    extended = np.asarray(x, dtype=np.longdouble)
    total = extended.sum()
    return float(total)
@contextmanager
def workdir(path, mkdir=False):
    """Run the ``with`` block inside directory ``path``.

    With ``mkdir=True`` the directory (including missing parents) is
    created first.  The previous working directory is restored when the
    block is left, also if it raised.  Nothing is yielded.
    """
    target = Path(path)
    if mkdir:
        target.mkdir(parents=True, exist_ok=True)

    previous = os.getcwd()
    os.chdir(target)
    try:
        yield  # Yield the Path or dirname maybe?
    finally:
        os.chdir(previous)
class iofunction:
    """Decorator factory: let ``func(fd, ...)`` also accept a filename.

    If the first argument is a ``str`` or ``PurePath`` it is opened with
    the configured mode and closed again once ``func`` has returned or
    raised; any other object is passed to ``func`` as is and left open.

    (Won't work on functions that return a generator.)"""

    def __init__(self, mode):
        # Mode string handed to ``open()`` for path-like arguments.
        self.mode = mode

    def __call__(self, func):
        @functools.wraps(func)
        def iofunc(file, *args, **kwargs):
            if not isinstance(file, (str, PurePath)):
                # Already a file object: the caller owns it.
                return func(file, *args, **kwargs)

            with open(str(file), self.mode) as fd:
                return func(fd, *args, **kwargs)

        return iofunc
def writer(func):
    """Decorate ``func`` with :class:`iofunction` in mode ``'w'``."""
    decorate = iofunction('w')
    return decorate(func)
def reader(func):
    """Decorate ``func`` with :class:`iofunction` in mode ``'r'``."""
    decorate = iofunction('r')
    return decorate(func)
618# The next two functions are for hotplugging into a JSONable class
619# using the jsonable decorator. We are supposed to have this kind of stuff
620# in ase.io.jsonio, but we'd rather import them from a 'basic' module
621# like ase/utils than one which triggers a lot of extra (cyclic) imports.
def write_json(self, fd):
    """Write this object to ``fd`` as JSON via ``ase.io.jsonio``."""
    # Imported here rather than at module level to keep ase.utils free of
    # heavy (cyclic) imports.
    from ase.io.jsonio import write_json as _dump

    _dump(fd, self)
@classmethod  # type: ignore[misc]
def read_json(cls, fd):
    """Read an instance of ``cls`` from the JSON file ``fd``.

    The decoded object is asserted to be an instance of ``cls``.
    """
    # Imported here rather than at module level to keep ase.utils free of
    # heavy (cyclic) imports.
    from ase.io.jsonio import read_json as _load

    loaded = _load(fd)
    assert isinstance(loaded, cls)
    return loaded
def jsonable(name):
    """Class decorator that adds JSON-based ``read`` and ``write``.

    The decorated class gets ``ase_objtype = name`` and has ``write_json``
    and ``read_json`` attached as ``write`` and ``read``.

    In order to write an object to JSON, it needs to be a known simple type
    (such as ndarray, float, ...) or implement todict().  If the class
    defines a string called ase_objtype, the decoder will want to convert
    the object back into its original type when reading.

    Raises TypeError if the class has no ``todict`` attribute."""

    def jsonableclass(cls):
        cls.ase_objtype = name
        if not hasattr(cls, 'todict'):
            raise TypeError('Class must implement todict()')

        # We may want the write and read to be optional.
        # E.g. a calculator might want to be JSONable, but not
        # that .write() produces a JSON file.
        #
        # This is mostly for 'lightweight' object IO.
        for attribute, hook in (('write', write_json), ('read', read_json)):
            setattr(cls, attribute, hook)
        return cls

    return jsonableclass
class ExperimentalFeatureWarning(Warning):
    """Warning category emitted by functions marked ``@experimental``."""


def experimental(func):
    """Decorator for functions not ready for production use.

    Each call of the decorated function first emits an
    ``ExperimentalFeatureWarning`` naming the function, then runs it and
    returns its result.
    """

    @functools.wraps(func)
    def expfunc(*args, **kwargs):
        warnings.warn(
            f'This function may change or misbehave: {func.__qualname__}()',
            ExperimentalFeatureWarning,
            # Attribute the warning to the caller rather than to this
            # wrapper, as ``deprecated`` does.
            stacklevel=2,
        )
        return func(*args, **kwargs)

    return expfunc
@deprecated('use functools.cached_property instead')
def lazymethod(meth):
    """Decorator that caches the result of a no-argument method.

    Example::

      class MyClass:

         @lazymethod
         def thing(self):
             return expensive_calculation()

    The body of thing() runs only on the first call; its return value is
    stored in the instance's ``_lazy_cache`` dict under the method name
    and returned directly on later calls.

    .. deprecated:: 3.25.0
    """
    key = meth.__name__

    @functools.wraps(meth)
    def getter(self):
        # The per-instance cache dict is created on first use.
        if not hasattr(self, '_lazy_cache'):
            self._lazy_cache = {}
        cache = self._lazy_cache

        if key in cache:
            return cache[key]
        cache[key] = meth(self)
        return cache[key]

    return getter
def atoms_to_spglib_cell(atoms):
    """Return ``(cell, scaled positions, atomic numbers)`` of ``atoms``.

    This is the data needed for calling spglib.
    """
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()
    return (cell, scaled_positions, numbers)
def warn_legacy(feature_name):
    """Emit a ``FutureWarning`` that ``feature_name`` is untested legacy code.

    The warning is attributed to the caller of this function, i.e. to the
    legacy feature itself.
    """
    warnings.warn(
        f'The {feature_name} feature is untested and ASE developers do not '
        'know whether it works or how to use it.  Please rehabilitate it '
        '(by writing unittests) or it may be removed.',
        FutureWarning,
        # Point at the code using the legacy feature, not at this helper.
        stacklevel=2,
    )
@deprecated('use functools.cached_property instead')
def lazyproperty(meth):
    """Decorator like lazymethod, but exposing the value as a property.

    .. deprecated:: 3.25.0
    """
    cached_getter = lazymethod(meth)
    return property(cached_getter)
749class _DelExitStack(ExitStack):
750 # We don't want IOContext itself to implement __del__, since IOContext
751 # might be subclassed, and we don't want __del__ on objects that we
752 # don't fully control. Therefore we make a little custom class
753 # that nobody else refers to, and that has the __del__.
754 def __del__(self):
755 self.close()
class IOContext:
    """Keep track of opened files and close them together.

    Files registered through :meth:`closelater` or opened by
    :meth:`openfile` are closed by :meth:`close`, which is also called
    when leaving a ``with`` block.
    """

    @functools.cached_property
    def _exitstack(self):
        # Created lazily on first access, then cached on the instance.
        return _DelExitStack()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def closelater(self, fd):
        """Register context manager ``fd`` for closing; return what it yields."""
        stack = self._exitstack
        return stack.enter_context(fd)

    def close(self):
        """Close everything registered so far."""
        self._exitstack.close()

    def openfile(self, file, comm, mode='w'):
        """Return a file object for ``file``.

        Objects that have a ``close`` attribute are returned untouched and
        are not registered for closing.  ``None``, or any rank other than
        0 of ``comm``, gives ``os.devnull``; ``'-'`` gives ``sys.stdout``;
        anything else is opened with ``mode``.  Text modes use UTF-8.
        """
        if hasattr(file, 'close'):
            return file  # File already opened, not for us to close.

        encoding = None if mode.endswith('b') else 'utf-8'

        writes_here = file is not None and comm.rank == 0
        if writes_here and file == '-':
            return sys.stdout

        target = file if writes_here else os.devnull
        return self.closelater(open(target, mode=mode, encoding=encoding))
def get_python_package_path_description(
    package, default='module has no path'
) -> str:
    """Describe where a python package/module lives, as a string.

    The first entry of ``package.__path__`` is returned; if the path is
    empty, ``default`` is returned.  If anything goes wrong, the result
    is ``default`` followed by the exception text in parentheses.
    Always returns a string.
    """
    try:
        entries = list(package.__path__)
        description = str(entries[0]) if entries else default
    except Exception as err:
        return f'{default} ({err})'
    return description
class OldSpglibError(Exception):
    """Raised when a wrapped spglib function returns ``None``."""


def spglib_new_errorhandling(func):
    """Wrap ``func`` so that a ``None`` result raises ``OldSpglibError``.

    While ``func`` runs, the environment variable
    ``SPGLIB_OLD_ERROR_HANDLING`` is set to ``'false'``; afterwards it is
    put back to its previous state (value restored, or removed again if
    it was not set before).
    """

    @functools.wraps(func)
    def spglib_wrapper(*args, **kwargs):
        # spglib<2.7.0 returns None when there is an error.
        # spglib 2.7.0 warns that this will become exceptions in the future.
        # We hack an environment setting to silence this warning and get the
        # behaviour we want.
        key = 'SPGLIB_OLD_ERROR_HANDLING'
        orig_value = os.environ.get(key)
        try:
            os.environ[key] = 'false'
            value = func(*args, **kwargs)
            if value is None:
                raise OldSpglibError()
            return value
        finally:
            if orig_value is None:
                # It was unset before the call: do not leave our
                # temporary setting behind in the process environment.
                os.environ.pop(key, None)
            else:
                os.environ[key] = orig_value

    return spglib_wrapper