Coverage for /builds/ase/ase/ase/utils/__init__.py: 81.23%

389 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1import errno 

2import functools 

3import io 

4import os 

5import pickle 

6import re 

7import string 

8import sys 

9import time 

10import warnings 

11from contextlib import ExitStack, contextmanager 

12from importlib import import_module 

13from math import atan2, cos, degrees, gcd, radians, sin 

14from pathlib import Path, PurePath 

15from typing import Callable, Dict, List, Type, Union 

16 

17import numpy as np 

18 

19from ase.formula import formula_hill, formula_metal 

20 

# Public names exported by ``from ase.utils import *``.
__all__ = [
    'basestring', 'import_module', 'seterr', 'plural', 'devnull', 'gcd',
    'convert_string_to_fd', 'Lock', 'opencew', 'OpenLock', 'rotate',
    'irotate', 'pbc2pbc', 'givens', 'hsv2rgb', 'hsv', 'pickleload',
    'reader', 'formula_hill', 'formula_metal', 'PurePath', 'xwopen',
    'tokenize_version', 'get_python_package_path_description',
]

47 

48 

def tokenize_version(version_string: str):
    """Split a dotted version string into a tuple usable for comparisons.

    Every dot-separated component contributes two entries: its leading
    integer (or -1 when the component has no leading digits), followed by
    the rest of the component as a string.  Integers and strings therefore
    alternate, so two results can be compared with ``<`` and ``==``.

    >>> tokenize_version('3.8') < tokenize_version('3.8.1')
    True
    """
    pattern = r'(\d*)(.*)'
    result = []
    for component in version_string.split('.'):
        match = re.match(pattern, component)
        assert match is not None, f'Cannot parse component {component}'
        leading_digits, rest = match.groups()
        try:
            value = int(leading_digits)
        except ValueError:
            # No leading digits at all (e.g. 'dev').
            value = -1
        result.extend((value, rest))
    return tuple(result)

65 

66 

# Legacy Python 2/3 compatibility aliases (let's try to remove these).
# Per the pickle docs, encoding='bytes' loads Python 2 8-bit strings as
# bytes objects instead of trying to decode them.
pickleload = functools.partial(pickle.load, encoding='bytes')
basestring = str

70 

71 

def deprecated(
    message: Union[str, Warning],
    category: Type[Warning] = FutureWarning,
    callback: Callable[[List, Dict], bool] = lambda args, kwargs: True,
):
    """Build a decorator that marks a function as deprecated.

    Parameters
    ----------
    message : str or Warning
        Passed to ``warnings.warn``.  When it is a ``Warning`` instance,
        ``category`` is ignored and the instance's own class is used.
    category : Type[Warning], default=FutureWarning
        Warning class used when ``message`` is a string.
    callback : Callable[[List, Dict], bool], default=always True
        Invoked before every call of the decorated function with a mutable
        list of the positional arguments and the keyword-argument dict.
        Return ``True`` to emit the warning, ``False`` to stay silent.  The
        callback may edit both containers in place; the decorated function
        receives the (possibly edited) arguments.

    Returns
    -------
    deprecated_decorator : Callable
        Decorator wrapping a function so that, on each call, the callback
        runs and the warning is emitted (attributed to the caller,
        ``stacklevel=2``) before the function itself is invoked.

    Example
    -------
    >>> import warnings
    >>> from ase.utils import deprecated
    >>> def rename_atoms(args, kwargs):
    ...     if 'atoms' in kwargs:
    ...         kwargs['optimizable'] = kwargs.pop('atoms')
    ...         return True
    ...     return False
    >>> @deprecated('Use `optimizable` instead of `atoms`.',
    ...             category=DeprecationWarning, callback=rename_atoms)
    ... def function(*, optimizable=None, atoms=None):
    ...     return optimizable
    >>> with warnings.catch_warnings(record=True) as w:
    ...     warnings.simplefilter('always')
    ...     function(atoms='x')
    'x'
    >>> w[-1].category is DeprecationWarning
    True
    """

    def deprecated_decorator(func):
        @functools.wraps(func)
        def deprecated_function(*args, **kwargs):
            positional = [*args]
            should_warn = callback(positional, kwargs)
            if should_warn:
                warnings.warn(message, category=category, stacklevel=2)
            return func(*positional, **kwargs)

        return deprecated_function

    return deprecated_decorator

158 

159 

@contextmanager
def seterr(**kwargs):
    """Temporarily change NumPy's floating-point error handling.

    Keyword arguments are forwarded to ``np.seterr``; the previous settings
    are restored when the ``with`` block exits, also if it raises.
    """
    with ExitStack() as stack:
        previous = np.seterr(**kwargs)
        stack.callback(np.seterr, **previous)
        yield

171 

172 

def plural(n, word):
    """Return ``n`` followed by ``word``, adding an 's' unless n == 1.

    >>> from ase.utils import plural

    >>> plural(0, 'egg'), plural(1, 'egg'), plural(2, 'egg')
    ('0 eggs', '1 egg', '2 eggs')
    """
    return '1 ' + word if n == 1 else '%d %ss' % (n, word)

184 

185 

class DevNull:
    """Deprecated file-like sink: writes are discarded, reads are empty.

    Every method emits a DeprecationWarning recommending
    ``open(os.devnull)`` instead.
    """

    encoding = 'UTF-8'
    closed = False

    # Deprecated for ase-3.21.0. Change to futurewarning later on.
    _use_os_devnull = deprecated(
        'use open(os.devnull) instead', DeprecationWarning
    )

    @_use_os_devnull
    def read(self, n=-1):
        return ''

    @_use_os_devnull
    def write(self, string):
        pass

    @_use_os_devnull
    def flush(self):
        pass

    @_use_os_devnull
    def seek(self, offset, whence=0):
        return 0

    @_use_os_devnull
    def tell(self):
        return 0

    @_use_os_devnull
    def isatty(self):
        return False

    @_use_os_devnull
    def close(self):
        pass


devnull = DevNull()

225 

226 

@deprecated(
    'convert_string_to_fd does not facilitate proper resource '
    'management. '
    'Please use e.g. ase.utils.IOContext class instead.'
)
def convert_string_to_fd(name, world=None):
    """Return a text output stream described by ``name``.

    * ``None`` (or any rank of ``world`` other than 0): a new handle to
      ``os.devnull``.
    * ``'-'``: ``sys.stdout``.
    * ``str`` or ``PurePath``: the file, opened for writing.
    * anything else: returned unchanged (assumed to be an open file).

    .. deprecated:: 3.22.1
        Please use e.g. :class:`ase.utils.IOContext` class instead.
    """
    if world is None:
        from ase.parallel import world
    if name is None or world.rank != 0:
        return open(os.devnull, 'w')
    if name == '-':
        return sys.stdout
    if not isinstance(name, (str, PurePath)):
        return name  # presumably already a file object
    return open(str(name), 'w')

250 

251 

# os.open flags for exclusive creation (the open fails if the file exists).
# os.O_BINARY only exists on Windows; elsewhere it contributes nothing.
CEW_FLAGS = (
    os.O_CREAT
    | os.O_EXCL
    | os.O_WRONLY
    | getattr(os, 'O_BINARY', 0)
)

254 

255 

@contextmanager
def xwopen(filename, world=None):
    """Context manager that exclusively creates ``filename`` for writing.

    Yields the binary file object returned by ``opencew``: the real file
    on the master rank, a dummy (``os.devnull``) handle on the others, or
    None on every rank if the file already existed.  Any file yielded is
    closed when the ``with`` block exits.
    """
    handle = opencew(filename, world)
    with ExitStack() as stack:
        if handle is not None:
            stack.callback(handle.close)
        yield handle

271 

272 

# TODO: deprecate in favour of ``with xwopen(...) as fd: ...`` to prevent
# resource leaks.
def opencew(filename, world=None):
    """Exclusively create and open ``filename`` for binary writing.

    Thin public wrapper around ``_opencew``.  Returns the file object (a
    dummy handle on non-master ranks), or None if the file already exists.
    The caller is responsible for closing the returned file.
    """
    return _opencew(filename, world=world)

276 

277 

def _opencew(filename, world=None):
    """Exclusively create ``filename`` for binary writing, on rank 0 only.

    Rank 0 of ``world`` opens the file with ``CEW_FLAGS`` added (so the
    open fails if it already exists); every other rank opens ``os.devnull``
    instead.  The errno seen on rank 0 is combined across ranks with
    ``world.sum_scalar`` so that all ranks take the same branch.

    Returns the opened binary file object, or None on every rank if the
    file already existed.  Raises OSError on every rank for any other
    error reported by rank 0.  Nothing is left open when None is returned
    or an exception propagates.
    """
    if world is None:
        import ase.parallel as parallel

        world = parallel.world

    closelater = []

    def opener(file, flags):
        # Add O_CREAT|O_EXCL (and O_BINARY on Windows) to what open() asks for.
        return os.open(file, flags | CEW_FLAGS)

    try:
        error = 0
        if world.rank == 0:
            try:
                fd = open(filename, 'wb', opener=opener)
            except OSError as ex:
                error = ex.errno
            else:
                closelater.append(fd)
        else:
            fd = open(os.devnull, 'wb')
            closelater.append(fd)

        # Synchronize: only rank 0 can set a nonzero error, so (assuming
        # sum_scalar sums over all ranks) every rank sees rank 0's errno.
        error = world.sum_scalar(error)
        if error == errno.EEXIST:
            # Nobody gets the file: close the devnull handles opened on the
            # other ranks instead of leaking them.
            for handle in closelater:
                handle.close()
            return None
        if error:
            raise OSError(error, 'Error', filename)

        return fd
    except BaseException:
        for handle in closelater:
            handle.close()
        raise

314 

315 

def opencew_text(*args, **kwargs):
    """Text-mode variant of ``opencew``.

    Arguments are forwarded to ``opencew``.  The binary file it returns is
    wrapped in ``io.TextIOWrapper`` with the default (locale-dependent)
    encoding.  None is passed through unchanged, i.e. when the file
    already exists.
    """
    binary = opencew(*args, **kwargs)
    return None if binary is None else io.TextIOWrapper(binary)

321 

322 

class Lock:
    """Lock based on the exclusive creation of a lock file.

    The lock is held while the file ``name`` exists: :meth:`acquire`
    creates it exclusively via ``opencew`` and :meth:`release` deletes it.
    Both synchronize over ``world`` (``opencew`` and ``world.barrier``), so
    presumably every rank must call them together.

    Parameters:

    name: str or path-like
        Name of the lock file (default ``'lock'``).
    world: communicator, optional
        Defaults to ``ase.parallel.world``.
    timeout: float
        Seconds :meth:`acquire` may wait before raising TimeoutError;
        waits forever by default.
    """

    def __init__(self, name='lock', world=None, timeout=float('inf')):
        self.name = str(name)
        self.timeout = timeout
        if world is None:
            from ase.parallel import world
        self.world = world

    def acquire(self):
        """Block until the lock file could be created.

        Retries with exponential back-off (0.2 s, 0.4 s, ...) and raises
        TimeoutError once ``self.timeout`` seconds have elapsed.
        """
        dt = 0.2
        # monotonic() is immune to wall-clock adjustments, unlike time().
        t1 = time.monotonic()
        while True:
            fd = opencew(self.name, self.world)
            if fd is not None:
                self.fd = fd
                break
            # NOTE(review): each rank measures elapsed time on its own;
            # ranks could in principle disagree about timing out.
            time_left = self.timeout - (time.monotonic() - t1)
            if time_left <= 0:
                raise TimeoutError(
                    f'Timed out after {self.timeout} s waiting for '
                    f'lock {self.name!r}'
                )
            time.sleep(min(dt, time_left))
            dt *= 2

    def release(self):
        """Close and delete the lock file (deleted on rank 0 only)."""
        self.world.barrier()
        # Important to close fd before deleting file on windows
        # as a WinError would otherwise be raised.
        self.fd.close()
        if self.world.rank == 0:
            os.remove(self.name)
        self.world.barrier()

    def __enter__(self):
        self.acquire()
        # Return the lock so ``with Lock(...) as lock:`` binds it.
        return self

    def __exit__(self, type, value, tb):
        self.release()

359 

360 

class OpenLock:
    """No-op stand-in for :class:`Lock` that never blocks.

    Provides the same acquire/release/context-manager methods as Lock,
    all of which do nothing.
    """

    def acquire(self):
        pass

    def release(self):
        pass

    def __enter__(self):
        # Return self so ``with OpenLock() as lock:`` binds the lock, as is
        # conventional for context managers.
        return self

    def __exit__(self, type, value, tb):
        pass

373 

374 

def _search_packed_refs(git_dpath, ref):
    """Return the hash recorded for ``ref`` in ``<git_dpath>/packed-refs``.

    Git may keep refs in this single file instead of one file per ref.
    Returns '' if the file does not exist or does not list ``ref``.
    """
    packed_file = os.path.join(git_dpath, 'packed-refs')
    if not os.path.isfile(packed_file):
        return ''
    with open(packed_file, encoding='utf-8') as fd:
        for line in fd:
            # Skip the '# pack-refs with: ...' header and the '^<hash>'
            # lines that give the commit a preceding annotated tag peels to.
            if line.startswith(('#', '^')):
                continue
            fields = line.split()
            if len(fields) == 2 and fields[1] == ref:
                return fields[0]
    return ''


def search_current_git_hash(arg, world=None):
    """Search for .git directory and current git commit hash.

    Parameters:

    arg: str (directory path) or python module
        .git directory is searched from the parent directory of
        the given directory or module.
    world: communicator, optional
        Only rank 0 searches; defaults to ``ase.parallel.world``.

    Returns the commit hash (a non-empty string of hexadecimal digits), or
    None if it cannot be determined: no ``.git`` directory or ``HEAD``
    file, a ref that is neither a loose ref file nor listed in
    ``packed-refs``, an empty or non-hex ref, or a rank other than 0.
    """
    if world is None:
        from ase.parallel import world
    if world.rank != 0:
        return None

    if isinstance(arg, str):
        dpath = arg  # Directory path
    else:
        dpath = os.path.dirname(arg.__file__)  # Assume arg is a module
    # Resolve symlinks in case this is just symlinked into $PYTHONPATH.
    dpath = os.path.realpath(dpath)
    dpath = os.path.dirname(dpath)  # Go to the parent directory
    git_dpath = os.path.join(dpath, '.git')
    if not os.path.isdir(git_dpath):
        # Replace this 'if' with a loop if you want to check
        # further parent directories
        return None
    head_file = os.path.join(git_dpath, 'HEAD')
    if not os.path.isfile(head_file):
        return None
    with open(head_file, encoding='utf-8') as fd:
        line = fd.readline().strip()
    if line.startswith('ref: '):
        ref = line[5:]
        ref_file = os.path.join(git_dpath, ref)
        if os.path.isfile(ref_file):
            with open(ref_file, encoding='utf-8') as fd:
                line = fd.readline().strip()
        else:
            line = _search_packed_refs(git_dpath, ref)
    # Otherwise HEAD is detached and its first line is the hash itself.

    # An empty string would vacuously pass the hex-digit test below.
    if line and all(c in string.hexdigits for c in line):
        return line
    return None

423 

424 

def rotate(rotations, rotation=np.identity(3)):
    """Convert string of format '50x,-10y,120z' to a rotation matrix.

    Each comma-separated component is an angle in degrees followed by one
    of the axis letters x, y, z.  The component matrices right-multiply
    ``rotation`` in the order given, so the order matters: '50x,40z' is
    different from '40z,50x'.  Whitespace around components and empty
    components (e.g. a trailing comma) are ignored.

    Always returns a new array; ``rotation`` itself is never modified.
    """
    components = [s.strip() for s in rotations.split(',')]
    components = [s for s in components if s]
    if not components:
        # Copy so callers can never mutate the shared default matrix.
        return rotation.copy()

    # Parse everything first so malformed input fails before any work.
    parsed = [('xyz'.index(s[-1]), radians(float(s[:-1])))
              for s in components]
    for i, a in parsed:
        s = sin(a)
        c = cos(a)
        if i == 0:
            rotation = np.dot(rotation, [(1, 0, 0), (0, c, s), (0, -s, c)])
        elif i == 1:
            rotation = np.dot(rotation, [(c, 0, -s), (0, 1, 0), (s, 0, c)])
        else:
            rotation = np.dot(rotation, [(c, s, 0), (-s, c, 0), (0, 0, 1)])
    return rotation

448 

449 

def givens(a, b):
    """Compute the Givens rotation that eliminates ``b``::

      [ c  s]   [a]   [r]
      [     ] . [ ] = [ ]
      [-s  c]   [b]   [0]

    Returns ``(c, s, r)``.  Works with the ratio of the smaller to the
    larger of ``|a|`` and ``|b|``.
    """
    if b == 0:
        return np.sign(a), 0, abs(a)
    if abs(b) >= abs(a):
        ratio = a / b
        u = np.sign(b) * (1 + ratio**2) ** 0.5
        s = 1.0 / u
        return s * ratio, s, b * u
    ratio = b / a
    u = np.sign(a) * (1 + ratio**2) ** 0.5
    c = 1.0 / u
    return c, c * ratio, a * u

475 

476 

def irotate(rotation, initial=np.identity(3)):
    """Determine x, y, z rotation angles (in degrees) from a rotation matrix.

    ``initial . rotation`` is reduced by three successive Givens rotations,
    and the angles are read off their cosine/sine pairs.
    NOTE(review): presumably the inverse of ``rotate`` for strings of the
    form 'Xx,Yy,Zz' -- not verified here.
    """
    m = np.dot(initial, rotation)
    cos_x, sin_x, r_x = givens(m[2, 2], m[1, 2])
    cos_y, sin_y, _ = givens(r_x, m[0, 2])
    cos_z, sin_z, _ = givens(
        cos_x * m[1, 1] - sin_x * m[2, 1],
        cos_y * m[0, 1] - sin_y * (sin_x * m[1, 1] + cos_x * m[2, 1]),
    )
    angles = (atan2(sin_x, cos_x), atan2(-sin_y, cos_y), atan2(sin_z, cos_z))
    return tuple(degrees(angle) for angle in angles)

490 

491 

def pbc2pbc(pbc):
    """Return ``pbc`` as a new boolean array of length 3.

    A scalar is broadcast to all three directions; a length-3 sequence is
    converted element-wise to bool.
    """
    result = np.empty(3, dtype=bool)
    result[...] = pbc
    return result

496 

497 

def string2index(stridx: str) -> Union[int, slice, str]:
    """Convert an index string to an int, a slice, or leave it as a str.

    Strings containing ':' become slices (empty fields mean None, e.g.
    '::2' -> slice(None, None, 2)).  Otherwise the string is converted to
    an int if possible and returned unchanged if not.
    """
    if ':' in stridx:
        bounds = [int(part) if part != '' else None
                  for part in stridx.split(':')]
        return slice(*bounds)
    try:
        return int(stridx)
    except ValueError:
        # Not an integer: may contain database accessor
        return stridx

508 

509 

def hsv2rgb(h, s, v):
    """Convert a colour from HSV to RGB.

    See http://en.wikipedia.org/wiki/HSL_and_HSV

    h (hue) in degrees; values outside [0, 360[ are wrapped around the
        colour circle, so 360 is the same as 0
    s (saturation) in [0, 1]
    v (value) in [0, 1]

    return rgb in range [0, 1]

    Raises RuntimeError if h is not a finite number (e.g. NaN).
    """
    if v == 0:
        return 0, 0, 0
    if s == 0:
        return v, v, v

    i, f = divmod(h / 60.0, 1)
    # Hue is periodic: wrap the sector index so h=360 behaves like h=0 and
    # negative hues work.  NaN/inf stay NaN and reach the error below.
    i %= 6
    p = v * (1 - s)
    q = v * (1 - s * f)
    t = v * (1 - s * (1 - f))

    if i == 0:
        return v, t, p
    elif i == 1:
        return q, v, p
    elif i == 2:
        return p, v, t
    elif i == 3:
        return p, q, v
    elif i == 4:
        return t, p, v
    elif i == 5:
        return v, p, q
    else:
        raise RuntimeError('h must be a finite number')


def hsv(array, s=0.9, v=0.9):
    """Map the values of ``array`` to colours around the hue circle.

    Values are rescaled linearly so that the minimum maps to hue 0 and the
    maximum to hue 359 degrees; saturation ``s`` and value ``v`` are fixed.
    A constant array maps every element to hue 0.

    Returns an array of shape ``array.shape + (3,)`` with RGB in [0, 1].
    """
    array = np.asarray(array, dtype=float)
    lo = array.min()
    span = array.max() - lo
    if span == 0:
        # All values equal: avoid 0/0 and use a single hue.
        hues = np.zeros_like(array)
    else:
        # Subtract (not add) the minimum so hues land in [0, 359].
        hues = (array - lo) * 359.0 / span
    result = np.empty((hues.size, 3))
    for rgb, h in zip(result, hues.flat):
        rgb[:] = hsv2rgb(h, s, v)
    return np.reshape(result, hues.shape + (3,))

551 

552 

# This code does roughly the same, but requires pylab (matplotlib):
# def cmap(array, name='hsv'):
#     import pylab
#     a = (array - array.min()) / np.ptp(array)
#     rgba = getattr(pylab.cm, name)(a)
#     return rgba[..., :3]  # return rgb only (not alpha)

559 

560 

def longsum(x):
    """Sum ``x`` in extended precision and return the result as a float.

    Accumulates in ``np.longdouble``, whose precision is platform
    dependent (it is not necessarily a true 128-bit float).
    """
    extended = np.asarray(x, dtype=np.longdouble)
    total = extended.sum()
    return float(total)

564 

565 

@contextmanager
def workdir(path, mkdir=False):
    """Run the ``with`` body inside directory ``path``.

    If ``mkdir`` is true, the directory (and any missing parents) is
    created first.  The previous working directory is restored on exit,
    also when the body raises.  The context manager yields None.
    """
    target = Path(path)
    if mkdir:
        target.mkdir(parents=True, exist_ok=True)

    previous = os.getcwd()
    os.chdir(target)
    with ExitStack() as stack:
        stack.callback(os.chdir, previous)
        yield

579 

580 

class iofunction:
    """Decorator factory that lets ``func`` take a path or an open file.

    If the first argument is a ``str`` or ``PurePath``, it is opened with
    ``mode`` for the duration of the call and closed afterwards.  Any other
    object is passed through unchanged and left open.

    (Won't work on functions that return a generator, because the file is
    closed as soon as the call returns.)
    """

    def __init__(self, mode):
        self.mode = mode

    def __call__(self, func):
        @functools.wraps(func)
        def iofunc(file, *args, **kwargs):
            with ExitStack() as stack:
                if isinstance(file, (str, PurePath)):
                    fd = stack.enter_context(open(str(file), self.mode))
                else:
                    fd = file
                return func(fd, *args, **kwargs)

        return iofunc

607 

608 

def writer(func):
    """Let ``func``'s first argument be a path, opened in mode 'w'.

    See ``iofunction``.
    """
    open_for_writing = iofunction('w')
    return open_for_writing(func)

611 

612 

def reader(func):
    """Let ``func``'s first argument be a path, opened in mode 'r'.

    See ``iofunction``.
    """
    open_for_reading = iofunction('r')
    return open_for_reading(func)

615 

616 

617# The next two functions are for hotplugging into a JSONable class 

618# using the jsonable decorator. We are supposed to have this kind of stuff 

619# in ase.io.jsonio, but we'd rather import them from a 'basic' module 

620# like ase/utils than one which triggers a lot of extra (cyclic) imports. 

621 

622 

def write_json(self, fd):
    """Serialize ``self`` to the JSON file ``fd``.

    Attached to classes as ``cls.write`` by ``jsonable``; the encoding is
    delegated to ``ase.io.jsonio.write_json``.
    """
    # Imported lazily to avoid the (cyclic) imports of ase.io.jsonio.
    from ase.io.jsonio import write_json as encode_to_file

    encode_to_file(fd, self)

628 

629 

@classmethod  # type: ignore[misc]
def read_json(cls, fd):
    """Read a new instance of ``cls`` from the JSON file ``fd``.

    Attached to classes as ``cls.read`` by ``jsonable``.  Fails with
    AssertionError (unless assertions are disabled) if the decoded object
    is not an instance of ``cls``.
    """
    # Imported lazily to avoid the (cyclic) imports of ase.io.jsonio.
    from ase.io.jsonio import read_json as decode_from_file

    result = decode_from_file(fd)
    assert isinstance(result, cls)
    return result

638 

639 

def jsonable(name):
    """Decorator for facilitating JSON I/O with a class.

    Pokes JSON-based read and write functions into the class.

    In order to write an object to JSON, it needs to be a known simple type
    (such as ndarray, float, ...) or implement todict(). If the class
    defines a string called ase_objtype, the decoder will want to convert
    the object back into its original type when reading.

    The decorated class gets ``ase_objtype = name`` plus ``write`` and
    ``read`` (classmethod) attributes.  Raises TypeError, leaving the class
    unmodified, if it has no ``todict``.
    """

    def jsonableclass(cls):
        # Validate before mutating so that a rejected class is untouched.
        if not hasattr(cls, 'todict'):
            raise TypeError('Class must implement todict()')
        cls.ase_objtype = name

        # We may want the write and read to be optional.
        # E.g. a calculator might want to be JSONable, but not
        # that .write() produces a JSON file.
        #
        # This is mostly for 'lightweight' object IO.
        cls.write = write_json
        cls.read = read_json
        return cls

    return jsonableclass

665 

666 

class ExperimentalFeatureWarning(Warning):
    """Warning emitted when calling a function marked ``@experimental``."""


def experimental(func):
    """Decorator for functions not ready for production use.

    Every call emits an ExperimentalFeatureWarning naming ``func`` (by its
    qualified name) and then calls ``func`` with the same arguments.
    """

    @functools.wraps(func)
    def expfunc(*args, **kwargs):
        warnings.warn(
            'This function may change or misbehave: {}()'.format(
                func.__qualname__
            ),
            ExperimentalFeatureWarning,
            # Attribute the warning to the caller, not to this wrapper.
            stacklevel=2,
        )
        return func(*args, **kwargs)

    return expfunc

685 

686 

@deprecated('use functools.cached_property instead')
def lazymethod(meth):
    """Cache the result of a method that takes only ``self``.

    Usage::

        class MyClass:

            @lazymethod
            def thing(self):
                return expensive_calculation()

    The first ``obj.thing()`` runs the method body and stores the result
    in the per-instance dict ``obj._lazy_cache`` under the method's name.
    Later calls return the stored value without re-running the body.

    .. deprecated:: 3.25.0
    """
    key = meth.__name__

    @functools.wraps(meth)
    def getter(self):
        try:
            memo = self._lazy_cache
        except AttributeError:
            memo = {}
            self._lazy_cache = memo
        if key not in memo:
            memo[key] = meth(self)
        return memo[key]

    return getter

719 

720 

def atoms_to_spglib_cell(atoms):
    """Return ``(cell, scaled_positions, atomic_numbers)`` for spglib.

    The three items come from ``atoms.get_cell()``,
    ``atoms.get_scaled_positions()`` and ``atoms.get_atomic_numbers()``.
    """
    lattice = atoms.get_cell()
    positions = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()
    return lattice, positions, numbers

728 

729 

def warn_legacy(feature_name):
    """Emit a FutureWarning that ``feature_name`` is untested legacy code.

    The warning is attributed to the code that called ``warn_legacy``
    rather than to this module.
    """
    warnings.warn(
        f'The {feature_name} feature is untested and ASE developers do not '
        'know whether it works or how to use it. Please rehabilitate it '
        '(by writing unittests) or it may be removed.',
        FutureWarning,
        stacklevel=2,
    )

737 

738 

@deprecated('use functools.cached_property instead')
def lazyproperty(meth):
    """Decorator like lazymethod, but making item available as a property.

    Only this decorator's own deprecation warning is emitted.

    .. deprecated:: 3.25.0
    """
    # ``lazymethod`` is itself wrapped by @deprecated, and functools.wraps
    # exposes the undecorated function as __wrapped__.  Use it to avoid a
    # second, duplicate deprecation warning pointing into this module.
    undecorated = getattr(lazymethod, '__wrapped__', lazymethod)
    return property(undecorated(meth))

746 

747 

748class _DelExitStack(ExitStack): 

749 # We don't want IOContext itself to implement __del__, since IOContext 

750 # might be subclassed, and we don't want __del__ on objects that we 

751 # don't fully control. Therefore we make a little custom class 

752 # that nobody else refers to, and that has the __del__. 

753 def __del__(self): 

754 self.close() 

755 

756 

757class IOContext: 

758 @functools.cached_property 

759 def _exitstack(self): 

760 return _DelExitStack() 

761 

762 def __enter__(self): 

763 return self 

764 

765 def __exit__(self, *args): 

766 self.close() 

767 

768 def closelater(self, fd): 

769 return self._exitstack.enter_context(fd) 

770 

771 def close(self): 

772 self._exitstack.close() 

773 

774 def openfile(self, file, comm, mode='w'): 

775 if hasattr(file, 'close'): 

776 return file # File already opened, not for us to close. 

777 

778 encoding = None if mode.endswith('b') else 'utf-8' 

779 

780 if file is None or comm.rank != 0: 

781 return self.closelater( 

782 open(os.devnull, mode=mode, encoding=encoding) 

783 ) 

784 

785 if file == '-': 

786 return sys.stdout 

787 

788 return self.closelater(open(file, mode=mode, encoding=encoding)) 

789 

790 

def get_python_package_path_description(
    package, default='module has no path'
) -> str:
    """Describe where a Python package/module lives, as a string.

    Returns the first entry of ``package.__path__``, or ``default`` when
    that path is empty.  Any exception raised along the way is reported
    as ``'<default> (<exception>)'`` instead of propagating, so the result
    is always a string.
    """
    try:
        paths = list(package.__path__)
        return str(paths[0]) if paths else default
    except Exception as ex:
        return f'{default} ({ex})'