Coverage for /builds/ase/ase/ase/utils/__init__.py: 81.23%
389 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1import errno
2import functools
3import io
4import os
5import pickle
6import re
7import string
8import sys
9import time
10import warnings
11from contextlib import ExitStack, contextmanager
12from importlib import import_module
13from math import atan2, cos, degrees, gcd, radians, sin
14from pathlib import Path, PurePath
15from typing import Callable, Dict, List, Type, Union
17import numpy as np
19from ase.formula import formula_hill, formula_metal
# Names exported by ``from ase.utils import *``.  Several entries
# (``import_module``, ``gcd``, ``formula_hill``, ``formula_metal``,
# ``PurePath``) are re-exports of names imported at the top of this module.
__all__ = [
    'basestring',
    'import_module',
    'seterr',
    'plural',
    'devnull',
    'gcd',
    'convert_string_to_fd',
    'Lock',
    'opencew',
    'OpenLock',
    'rotate',
    'irotate',
    'pbc2pbc',
    'givens',
    'hsv2rgb',
    'hsv',
    'pickleload',
    'reader',
    'formula_hill',
    'formula_metal',
    'PurePath',
    'xwopen',
    'tokenize_version',
    'get_python_package_path_description',
]
def tokenize_version(version_string: str):
    """Split a dotted version string into a tuple usable for comparisons.

    Each dot-separated component contributes two entries: its leading
    integer (``-1`` when the component does not start with digits) and the
    remaining text.

    Usage: tokenize_version('3.8') < tokenize_version('3.8.1').
    """

    def _split(component):
        # Leading digits (possibly none) followed by whatever is left over.
        match = re.match(r'(\d*)(.*)', component)
        assert match is not None, f'Cannot parse component {component}'
        digits, tail = match.group(1, 2)
        try:
            number = int(digits)
        except ValueError:
            # No leading digits: sort such components before any number.
            number = -1
        return number, tail

    tokens = []
    for component in version_string.split('.'):
        tokens.extend(_split(component))
    return tuple(tokens)
# Python 2+3 compatibility stuff (let's try to remove these things):
# ``basestring`` is kept as a plain alias of ``str``.
basestring = str
# ``pickle.load`` with ``encoding='bytes'`` pre-bound.
# NOTE(review): unpickling can execute arbitrary code -- only use this on
# files from a trusted source.
pickleload = functools.partial(pickle.load, encoding='bytes')
def deprecated(
    message: Union[str, Warning],
    category: Type[Warning] = FutureWarning,
    callback: Callable[[List, Dict], bool] = lambda args, kwargs: True,
):
    """Build a decorator that marks a function as deprecated.

    Parameters
    ----------
    message : str or Warning
        What to emit.  When a ``Warning`` instance is given, ``category``
        is ignored and ``message.__class__`` is used instead.
    category : Type[Warning], default=FutureWarning
        Warning class used for the emitted warning.  Ignored when
        ``message`` is itself a ``Warning`` instance.
    callback : Callable[[List, Dict], bool], default=lambda args, kwargs: True
        Hook invoked on every call of the decorated function, before the
        function itself runs.  It receives the positional arguments as a
        list and the keyword arguments as a dictionary and may modify both
        in place; the (possibly modified) list and dictionary are then
        unpacked into the call of the deprecated function.  It *must*
        return ``True`` if the warning is to be emitted and ``False``
        otherwise.

    Returns
    -------
    deprecated_decorator : Callable
        Decorator which wraps a function so that it conditionally warns
        and/or has its arguments pre-processed.

    Example
    -------
    >>> # Inspect & replace a keyword parameter passed to a deprecated function
    >>> from typing import Any, Callable, Dict, List
    >>> import warnings
    >>> from ase.utils import deprecated

    >>> def alias_callback_factory(kwarg: str, alias: str) -> Callable:
    ...     def _replace_arg(_: List, kwargs: Dict[str, Any]) -> bool:
    ...         kwargs[kwarg] = kwargs[alias]
    ...         del kwargs[alias]
    ...         return True
    ...     return _replace_arg

    >>> MESSAGE = ("Calling this function with `atoms` is deprecated. "
    ...            "Use `optimizable` instead.")
    >>> @deprecated(
    ...     MESSAGE,
    ...     category=DeprecationWarning,
    ...     callback=alias_callback_factory("optimizable", "atoms")
    ... )
    ... def function(*, atoms=None, optimizable=None):
    ...     '''
    ...     .. deprecated:: 3.23.0
    ...         Calling this function with ``atoms`` is deprecated.
    ...         Use ``optimizable`` instead.
    ...     '''
    ...     print(f"atoms: {atoms}")
    ...     print(f"optimizable: {optimizable}")

    >>> with warnings.catch_warnings(record=True) as w:
    ...     warnings.simplefilter("always")
    ...     function(atoms="atoms")
    atoms: None
    optimizable: atoms

    >>> w[-1].category == DeprecationWarning
    True
    """

    def deprecated_decorator(func):
        @functools.wraps(func)
        def deprecated_function(*args, **kwargs):
            # The callback gets a mutable copy so it can rewrite arguments.
            positional = [*args]
            should_warn = callback(positional, kwargs)
            if should_warn:
                # stacklevel=2 attributes the warning to the caller of the
                # deprecated function rather than to this wrapper.
                warnings.warn(message, category=category, stacklevel=2)
            return func(*positional, **kwargs)

        return deprecated_function

    return deprecated_decorator
@contextmanager
def seterr(**kwargs):
    """Temporarily change NumPy's floating-point error handling.

    The keyword arguments are forwarded to np.seterr(); the settings that
    were active before are restored when the context exits.
    """
    previous = np.seterr(**kwargs)
    try:
        yield
    finally:
        # np.seterr() returned the old settings as a dict; put them back.
        np.seterr(**previous)
def plural(n, word):
    """Return ``n`` followed by ``word``, with an 's' appended unless n == 1.

    >>> from ase.utils import plural

    >>> plural(0, 'egg'), plural(1, 'egg'), plural(2, 'egg')
    ('0 eggs', '1 egg', '2 eggs')
    """
    return '1 ' + word if n == 1 else '%d %ss' % (n, word)
class DevNull:
    """File-like object that discards writes and reads back nothing.

    Every method is wrapped so that calling it emits a DeprecationWarning
    recommending ``open(os.devnull)`` instead.
    """

    encoding = 'UTF-8'
    closed = False

    # Deprecated for ase-3.21.0.  Change to futurewarning later on.
    _use_os_devnull = deprecated(
        'use open(os.devnull) instead', DeprecationWarning
    )

    @_use_os_devnull
    def write(self, string):
        """Discard ``string``."""

    @_use_os_devnull
    def flush(self):
        """Do nothing."""

    @_use_os_devnull
    def seek(self, offset, whence=0):
        """Ignore the request and report position 0."""
        return 0

    @_use_os_devnull
    def tell(self):
        """Report position 0."""
        return 0

    @_use_os_devnull
    def close(self):
        """Do nothing."""

    @_use_os_devnull
    def isatty(self):
        """Report that this is not a terminal."""
        return False

    @_use_os_devnull
    def read(self, n=-1):
        """Return an empty string regardless of ``n``."""
        return ''


# Module-level instance exported as ``ase.utils.devnull``.
devnull = DevNull()
@deprecated(
    'convert_string_to_fd does not facilitate proper resource '
    'management. '
    'Please use e.g. ase.utils.IOContext class instead.'
)
def convert_string_to_fd(name, world=None):
    """Turn ``name`` into a file object for text output.

    ``None`` (or any value on a rank other than 0) yields a handle on
    os.devnull, ``'-'`` yields sys.stdout, a str or path is opened for
    writing, and anything else is returned unchanged on the assumption
    that it already is a file object.

    .. deprecated:: 3.22.1
        Please use e.g. :class:`ase.utils.IOContext` class instead.
    """
    if world is None:
        from ase.parallel import world

    silent = name is None or world.rank != 0
    if silent:
        return open(os.devnull, 'w')
    if name == '-':
        return sys.stdout
    if not isinstance(name, (str, PurePath)):
        # we assume name is already a file-descriptor
        return name
    return open(str(name), 'w')  # str for py3.5 pathlib
# Flags for "create exclusively for writing": O_CREAT | O_EXCL makes
# os.open() fail if the file already exists, O_WRONLY opens write-only.
# Only Windows has O_BINARY:
CEW_FLAGS = os.O_CREAT | os.O_EXCL | os.O_WRONLY | getattr(os, 'O_BINARY', 0)
@contextmanager
def xwopen(filename, world=None):
    """Context manager around opencew() that closes the file on exit.

    The value bound by ``with xwopen(...) as fd`` is whatever opencew()
    returned: a file object when the file could be created exclusively,
    or None when it could not.  A non-None file object is closed when the
    ``with`` block ends, also if the block raises."""
    handle = opencew(filename, world)
    try:
        yield handle
    finally:
        if handle is not None:
            handle.close()
# @deprecated('use "with xwopen(...) as fd: ..." to prevent resource leak')
def opencew(filename, world=None):
    """Create and open ``filename`` exclusively for writing (binary mode).

    Parameters
    ----------
    filename : str or path-like
        File to create.  Creation fails if the file already exists.
    world : communicator, optional
        Object providing ``rank`` and ``sum_scalar()``.  Defaults to
        ``ase.parallel.world``.

    Returns
    -------
    File object or None
        On rank 0 the newly created file, on other ranks a handle on
        os.devnull.  None is returned on all ranks when the file already
        exists.  The caller is responsible for closing the returned file.

    Raises
    ------
    OSError
        If creating the file failed for a reason other than the file
        already existing.
    """
    return _opencew(filename, world)


def _opencew(filename, world=None):
    """Implementation of opencew(); see there for the contract."""
    if world is None:
        # Imported lazily so that passing an explicit communicator does
        # not require importing ase.parallel.
        from ase.parallel import world

    closelater = []

    def opener(file, flags):
        # O_CREAT | O_EXCL makes the open fail if the file exists.
        return os.open(file, flags | CEW_FLAGS)

    try:
        error = 0
        if world.rank == 0:
            try:
                fd = open(filename, 'wb', opener=opener)
            except OSError as ex:
                error = ex.errno
            else:
                closelater.append(fd)
        else:
            fd = open(os.devnull, 'wb')
            closelater.append(fd)

        # Synchronize: only rank 0 can contribute a nonzero errno, so the
        # sum makes rank 0's outcome known everywhere.
        # NOTE(review): assumes sum_scalar() reduces over all ranks.
        error = world.sum_scalar(error)
        if error == errno.EEXIST:
            # Nothing is handed to the caller, so release whatever this
            # rank opened (the devnull handle on non-master ranks).
            for opened in closelater:
                opened.close()
            return None
        if error:
            raise OSError(error, 'Error', filename)

        return fd
    except BaseException:
        for fd in closelater:
            fd.close()
        raise
def opencew_text(*args, **kwargs):
    """Like opencew(), but wrap the binary file in an io.TextIOWrapper.

    All arguments are forwarded to opencew().  None is passed through
    unchanged when opencew() returns None.
    """
    binary = opencew(*args, **kwargs)
    return None if binary is None else io.TextIOWrapper(binary)
class Lock:
    """Lock based on exclusive creation of a file.

    The lock is held while the file ``name`` exists; it is taken by
    creating the file with opencew() and released by deleting it.

    Parameters
    ----------
    name : str or path-like
        Name of the lock file.
    world : communicator, optional
        Object providing ``rank`` and ``barrier()`` (and whatever
        opencew() needs).  Defaults to ``ase.parallel.world``.
    timeout : float
        Maximum number of seconds acquire() keeps trying.
    """

    def __init__(self, name='lock', world=None, timeout=float('inf')):
        self.name = str(name)
        self.timeout = timeout
        if world is None:
            from ase.parallel import world
        self.world = world

    def acquire(self):
        """Block until the lock file could be created.

        Retries with a sleep that doubles after each failed attempt.

        Raises
        ------
        TimeoutError
            If the lock could not be taken within ``self.timeout`` seconds.
        """
        dt = 0.2
        t1 = time.time()
        while True:
            fd = opencew(self.name, self.world)
            if fd is not None:
                self.fd = fd
                break
            time_left = self.timeout - (time.time() - t1)
            if time_left <= 0:
                raise TimeoutError(
                    f'Could not acquire lock file {self.name!r} '
                    f'within {self.timeout} s'
                )
            # Never sleep past the deadline.
            time.sleep(min(dt, time_left))
            dt *= 2

    def release(self):
        """Close and delete the lock file."""
        self.world.barrier()
        # Important to close fd before deleting file on windows
        # as a WinError would otherwise be raised.
        self.fd.close()
        if self.world.rank == 0:
            os.remove(self.name)
        self.world.barrier()

    def __enter__(self):
        self.acquire()
        # Return the lock so that ``with Lock(...) as lock`` binds it.
        return self

    def __exit__(self, type, value, tb):
        self.release()
class OpenLock:
    """Lock stand-in whose operations all do nothing.

    Offers the same methods as Lock, so it can be used where a lock
    object is expected but no locking is wanted.
    """

    def acquire(self):
        """Do nothing."""

    def release(self):
        """Do nothing."""

    def __enter__(self):
        # Return the object so that ``with OpenLock() as lock`` binds it.
        return self

    def __exit__(self, type, value, tb):
        """Do nothing; exceptions from the with-block propagate."""
def search_current_git_hash(arg, world=None):
    """Search for .git directory and current git commit hash.

    Parameters:

    arg: str or path object (directory path) or python module
        .git directory is searched from the parent directory of
        the given directory or module.
    world: communicator, optional
        Object providing ``rank``; defaults to ``ase.parallel.world``.

    Returns the hash as a string, or None if this is not rank 0, if no
    .git directory / HEAD / ref file is found, or if the file content is
    not a non-empty string of hexadecimal digits.
    """
    if world is None:
        from ase.parallel import world
    if world.rank != 0:
        return None

    # Check argument
    if isinstance(arg, (str, PurePath)):
        # Directory path
        dpath = arg
    else:
        # Assume arg is module
        dpath = os.path.dirname(arg.__file__)
    # realpath rather than abspath, in case this is just symlinked
    # into $PYTHONPATH
    dpath = os.path.realpath(dpath)
    dpath = os.path.dirname(dpath)  # Go to the parent directory
    git_dpath = os.path.join(dpath, '.git')
    if not os.path.isdir(git_dpath):
        # Replace this 'if' with a loop if you want to check
        # further parent directories
        return None
    HEAD_file = os.path.join(git_dpath, 'HEAD')
    if not os.path.isfile(HEAD_file):
        return None
    with open(HEAD_file) as fd:
        line = fd.readline().strip()
    if line.startswith('ref: '):
        # HEAD points to a branch; the hash lives in the ref file.
        ref = line[5:]
        ref_file = os.path.join(git_dpath, ref)
    else:
        # Assuming detached HEAD state
        ref_file = HEAD_file
    if not os.path.isfile(ref_file):
        return None
    with open(ref_file) as fd:
        line = fd.readline().strip()
    # An empty line would trivially pass the all() test, so require content.
    if line and all(c in string.hexdigits for c in line):
        return line
    return None
def rotate(rotations, rotation=np.identity(3)):
    """Build a rotation matrix from a string such as '50x,-10y,120z'.

    Each comma-separated term is an angle in degrees followed by an axis
    letter.  The terms are applied one after another on top of
    ``rotation``, so the order matters: '50x,40z' is different from
    '40z,50x'.  An empty string returns a copy of ``rotation``.
    """
    if rotations == '':
        return rotation.copy()

    # Parse everything first so that a malformed term fails before any
    # matrix product is computed.
    steps = []
    for term in rotations.split(','):
        steps.append(('xyz'.index(term[-1]), radians(float(term[:-1]))))

    for axis, angle in steps:
        sn = sin(angle)
        cs = cos(angle)
        if axis == 0:
            matrix = [(1, 0, 0), (0, cs, sn), (0, -sn, cs)]
        elif axis == 1:
            matrix = [(cs, 0, -sn), (0, 1, 0), (sn, 0, cs)]
        else:
            matrix = [(cs, sn, 0), (-sn, cs, 0), (0, 0, 1)]
        rotation = np.dot(rotation, matrix)
    return rotation
def givens(a, b):
    """Return (c, s, r) solving the equation system::

      [ c s]   [a]   [r]
      [    ] . [ ] = [ ]
      [-s c]   [b]   [0]

    The ratio is always formed with the larger of ``a`` and ``b`` (in
    absolute value) as the denominator.
    """
    sgn = np.sign
    if b == 0:
        return sgn(a), 0, abs(a)

    if abs(b) >= abs(a):
        ratio = a / b
        norm = sgn(b) * (1 + ratio**2) ** 0.5
        sine = 1.0 / norm
        return sine * ratio, sine, b * norm

    ratio = b / a
    norm = sgn(a) * (1 + ratio**2) ** 0.5
    cosine = 1.0 / norm
    return cosine, cosine * ratio, a * norm
def irotate(rotation, initial=np.identity(3)):
    """Determine x, y, z rotation angles (in degrees) from rotation matrix.

    The matrix analysed is ``np.dot(initial, rotation)``; the angles are
    extracted with three successive givens() eliminations.
    """
    mat = np.dot(initial, rotation)
    cx, sx, rx = givens(mat[2, 2], mat[1, 2])
    cy, sy, _ry = givens(rx, mat[0, 2])
    first = cx * mat[1, 1] - sx * mat[2, 1]
    second = cy * mat[0, 1] - sy * (sx * mat[1, 1] + cx * mat[2, 1])
    cz, sz, _rz = givens(first, second)
    return (
        degrees(atan2(sx, cx)),
        degrees(atan2(-sy, cy)),
        degrees(atan2(sz, cz)),
    )
def pbc2pbc(pbc):
    """Return ``pbc`` as a new boolean array of length 3.

    A single value is repeated for all three entries; a sequence of three
    values is converted element by element.
    """
    flags = np.empty(3, dtype=bool)
    flags[...] = pbc
    return flags
def string2index(stridx: str) -> Union[int, slice, str]:
    """Convert an index string to an int or a slice.

    A string containing ':' is parsed as a slice, with empty fields
    meaning None.  Otherwise the string is converted to int; if that
    fails the string itself is returned unchanged.
    """
    if ':' in stridx:
        fields = []
        for piece in stridx.split(':'):
            fields.append(int(piece) if piece != '' else None)
        return slice(*fields)

    try:
        return int(stridx)
    except ValueError:
        # may contain database accessor
        return stridx
def hsv2rgb(h, s, v):
    """http://en.wikipedia.org/wiki/HSL_and_HSV

    h (hue) in [0, 360[
    s (saturation) in [0, 1]
    v (value) in [0, 1]

    return rgb in range [0, 1]

    Raises RuntimeError if ``h`` falls outside the six 60-degree sectors.
    """
    if v == 0:
        return 0, 0, 0
    if s == 0:
        return v, v, v

    # i: index of the 60-degree sector, f: fractional position inside it.
    i, f = divmod(h / 60.0, 1)
    p = v * (1 - s)
    q = v * (1 - s * f)
    t = v * (1 - s * (1 - f))

    if i == 0:
        return v, t, p
    elif i == 1:
        return q, v, p
    elif i == 2:
        return p, v, t
    elif i == 3:
        return p, q, v
    elif i == 4:
        return t, p, v
    elif i == 5:
        return v, p, q
    else:
        raise RuntimeError('h must be in [0, 360]')


def hsv(array, s=0.9, v=0.9):
    """Map the values of ``array`` to RGB colors via their hue.

    The values are scaled linearly so that the minimum gets hue 0 and the
    maximum hue 359; saturation ``s`` and value ``v`` are the same for all
    entries.  If all values are equal, every entry gets hue 0.

    Returns an array of shape ``array.shape + (3,)``.
    """
    array = np.asarray(array)
    lowest = array.min()
    span = array.max() - lowest
    if span == 0:
        # Constant input: avoid dividing by zero.
        hues = np.zeros(array.shape)
    else:
        # Subtract the minimum so the hues start at 0 and stay below 360,
        # which is what hsv2rgb() requires.
        hues = (array - lowest) * 359.0 / span
    result = np.empty((hues.size, 3))
    for rgb, h in zip(result, hues.flat):
        rgb[:] = hsv2rgb(h, s, v)
    return np.reshape(result, hues.shape + (3,))
553# This code does the same, but requires pylab
554# def cmap(array, name='hsv'):
555# import pylab
556# a = (array + array.min()) / array.ptp()
557# rgba = getattr(pylab.cm, name)(a)
558# return rgba[:-1] # return rgb only (not alpha)
def longsum(x):
    """Sum ``x`` using np.longdouble accumulation; return a Python float."""
    extended = np.asarray(x, dtype=np.longdouble)
    return float(np.sum(extended))
@contextmanager
def workdir(path, mkdir=False):
    """Run the with-block inside ``path``, then return to where we were.

    With ``mkdir=True`` the directory (including parents) is created
    first if it does not exist.  Nothing is bound by ``as``.
    """
    target = Path(path)
    if mkdir:
        target.mkdir(parents=True, exist_ok=True)

    previous = os.getcwd()
    os.chdir(target)
    try:
        yield  # Yield the Path or dirname maybe?
    finally:
        os.chdir(previous)
class iofunction:
    """Decorator factory: let ``func(fd, ...)`` also accept a file name.

    When the first argument is a str or path object, the file is opened
    with the mode given to the constructor, passed to the function, and
    closed afterwards.  Any other first argument is passed through as is
    and is not closed.

    (Won't work on functions that return a generator.)"""

    def __init__(self, mode):
        self.mode = mode

    def __call__(self, func):
        @functools.wraps(func)
        def iofunc(file, *args, **kwargs):
            if not isinstance(file, (str, PurePath)):
                # Already a file object: the caller owns and closes it.
                return func(file, *args, **kwargs)
            with open(str(file), self.mode) as fd:
                return func(fd, *args, **kwargs)

        return iofunc
def writer(func):
    """Decorate ``func`` with iofunction so file names are opened in 'w' mode."""
    decorate = iofunction('w')
    return decorate(func)
def reader(func):
    """Decorate ``func`` with iofunction so file names are opened in 'r' mode."""
    decorate = iofunction('r')
    return decorate(func)
617# The next two functions are for hotplugging into a JSONable class
618# using the jsonable decorator. We are supposed to have this kind of stuff
619# in ase.io.jsonio, but we'd rather import them from a 'basic' module
620# like ase/utils than one which triggers a lot of extra (cyclic) imports.
def write_json(self, fd):
    """Write this object to ``fd`` via ase.io.jsonio.write_json."""
    # Imported here rather than at module level to keep ase.utils light.
    from ase.io import jsonio

    jsonio.write_json(fd, self)
@classmethod  # type: ignore[misc]
def read_json(cls, fd):
    """Read an object from ``fd`` and check that it is an instance of cls."""
    # Imported here rather than at module level to keep ase.utils light.
    from ase.io import jsonio

    instance = jsonio.read_json(fd)
    assert isinstance(instance, cls)
    return instance
def jsonable(name):
    """Class decorator that adds JSON read/write support.

    The decorated class gets ``ase_objtype = name`` plus ``write`` and
    ``read`` attributes bound to the module-level write_json / read_json
    helpers.

    For an object to be written to JSON it has to be a known simple type
    (such as ndarray, float, ...) or implement todict().  If the class
    defines a string called ase_objtype, the decoder will want to convert
    the object back into its original type when reading.

    Raises TypeError if the class has no ``todict`` attribute."""

    def jsonableclass(cls):
        cls.ase_objtype = name
        if hasattr(cls, 'todict'):
            # We may want the write and read to be optional.
            # E.g. a calculator might want to be JSONable, but not
            # that .write() produces a JSON file.
            #
            # This is mostly for 'lightweight' object IO.
            cls.write = write_json
            cls.read = read_json
            return cls
        raise TypeError('Class must implement todict()')

    return jsonableclass
class ExperimentalFeatureWarning(Warning):
    """Warning category emitted by functions decorated with @experimental."""


def experimental(func):
    """Decorator for functions not ready for production use.

    Every call of the decorated function emits an
    ExperimentalFeatureWarning naming the function, then runs it.
    """

    @functools.wraps(func)
    def expfunc(*args, **kwargs):
        warnings.warn(
            'This function may change or misbehave: {}()'.format(
                func.__qualname__
            ),
            ExperimentalFeatureWarning,
            # Attribute the warning to the caller, not to this wrapper.
            stacklevel=2,
        )
        return func(*args, **kwargs)

    return expfunc
@deprecated('use functools.cached_property instead')
def lazymethod(meth):
    """Decorator that caches the return value of a no-argument method.

    Example::

      class MyClass:

         @lazymethod
         def thing(self):
             return expensive_calculation()

    The body of thing() runs only on the first call; the result is kept
    in a per-instance dictionary ``_lazy_cache`` and returned directly by
    later calls.

    .. deprecated:: 3.25.0
    """
    key = meth.__name__

    @functools.wraps(meth)
    def getter(self):
        try:
            store = self._lazy_cache
        except AttributeError:
            # First lazy access on this instance: create the cache.
            store = self._lazy_cache = {}

        if key in store:
            return store[key]
        store[key] = meth(self)
        return store[key]

    return getter
def atoms_to_spglib_cell(atoms):
    """Return (cell, scaled positions, atomic numbers) for calling spglib."""
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()
    return cell, scaled_positions, numbers
def warn_legacy(feature_name):
    """Emit a FutureWarning saying that ``feature_name`` is untested."""
    text = (
        f'The {feature_name} feature is untested and ASE developers do not '
        'know whether it works or how to use it.  Please rehabilitate it '
        '(by writing unittests) or it may be removed.'
    )
    warnings.warn(text, FutureWarning)
@deprecated('use functools.cached_property instead')
def lazyproperty(meth):
    """Decorator like lazymethod, but exposing the value as a property.

    .. deprecated:: 3.25.0
    """
    cached_getter = lazymethod(meth)
    return property(cached_getter)
class _DelExitStack(ExitStack):
    """ExitStack that closes itself when the object is finalized."""

    # We don't want IOContext itself to implement __del__, since IOContext
    # might be subclassed, and we don't want __del__ on objects that we
    # don't fully control.  Therefore we make a little custom class
    # that nobody else refers to, and that has the __del__.
    def __del__(self):
        # Run the registered exit callbacks if close() was never called.
        self.close()
class IOContext:
    """Helper that opens files and closes all of them together.

    Files opened through openfile() (or registered with closelater())
    are closed by close(), which also runs when the object is used as a
    context manager and the with-block ends.
    """

    @functools.cached_property
    def _exitstack(self):
        # Created lazily on first use, one stack per instance.
        return _DelExitStack()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def closelater(self, fd):
        """Register ``fd`` to be closed by close(); return what entering it gives."""
        stack = self._exitstack
        return stack.enter_context(fd)

    def close(self):
        """Close everything registered so far."""
        stack = self._exitstack
        stack.close()

    def openfile(self, file, comm, mode='w'):
        """Return a file object for ``file``.

        An object that has a ``close`` attribute is returned unchanged and
        is not registered for closing.  ``None``, or any value on a rank
        other than 0, gives a handle on os.devnull; ``'-'`` gives
        sys.stdout; anything else is opened with ``mode``.  Text modes use
        utf-8 encoding.  Files opened here are registered for closing.
        """
        if hasattr(file, 'close'):
            return file  # File already opened, not for us to close.

        encoding = None if mode.endswith('b') else 'utf-8'

        if file is None or comm.rank != 0:
            target = os.devnull
        elif file == '-':
            return sys.stdout
        else:
            target = file

        return self.closelater(open(target, mode=mode, encoding=encoding))
def get_python_package_path_description(
    package, default='module has no path'
) -> str:
    """Describe where a python package/module lives, as a string.

    The first element of ``package.__path__`` is returned as a string.
    If ``__path__`` is empty, ``default`` is returned.  If anything goes
    wrong, the result is ``default`` followed by the exception text in
    parentheses.  Never raises.
    """
    try:
        entries = list(package.__path__)
        description = str(entries[0]) if entries else default
    except Exception as ex:
        description = f'{default} ({ex})'
    return description