Coverage for /builds/ase/ase/ase/formula.py: 91.42%

268 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import re 

4from functools import lru_cache 

5from math import gcd 

6from typing import Dict, List, Sequence, Tuple, Union 

7 

8from ase.data import atomic_numbers, chemical_symbols 

9 

# Recursive type used in type hints for parsed formulas:
#   * str               -- a single symbol, e.g. 'A'
#   * Tuple[Tree, int]  -- a repeated sub-tree, e.g. ('A', 2) for A2
#   * List[Tree]        -- a sequence of sub-trees, e.g. the terms of A+B
Tree = Union[str, Tuple['Tree', int], List['Tree']]

12 

13 

class Formula:
    """Chemical formula with symbol counting, formatting and arithmetic."""

    def __init__(self,
                 formula: Union[str, 'Formula'] = '',
                 *,
                 strict: bool = False,
                 format: str = '',
                 _tree: 'Union[Tree, None]' = None,
                 _count: Union[Dict[str, int], None] = None):
        """Chemical formula object.

        Parameters
        ----------
        formula: str
            Text string representation of formula.  Examples: ``'6CO2'``,
            ``'30Cu+2CO'``, ``'Pt(CO)6'``.
        strict: bool
            Only allow real chemical symbols.
        format: str
            Reorder according to *format*.  Must be one of hill, metal,
            ab2 (or its old name abc), a2b, periodic or reduce.

        Examples
        --------
        >>> from ase.formula import Formula
        >>> w = Formula('H2O')
        >>> w.count()
        {'H': 2, 'O': 1}
        >>> 'H' in w
        True
        >>> w == 'HOH'
        True
        >>> f'{w:latex}'
        'H$_{2}$O'
        >>> w.format('latex')
        'H$_{2}$O'
        >>> divmod(6 * w + 'Cu', w)
        (6, Formula('Cu'))

        Raises
        ------
        ValueError
            on malformed formula
        """

        # Be sure that Formula(x) works the same whether x is string or Formula
        assert isinstance(formula, (str, Formula))
        formula = str(formula)

        if format:
            # The private _tree/_count shortcuts describe the *unformatted*
            # formula, so they cannot be combined with a reordering format.
            assert _tree is None and _count is None
            if format not in {'hill', 'metal', 'abc', 'reduce', 'ab2', 'a2b',
                              'periodic'}:
                raise ValueError(f'Illegal format: {format}')
            formula = Formula(formula).format(format)

        self._formula = formula

        # Fall back to parsing/counting when no (or an empty) precomputed
        # tree/count was handed in by from_dict(), from_list() etc.
        self._tree = _tree or parse(formula)
        self._count = _count or count_tree(self._tree)
        if strict:
            for symbol in self._count:
                if symbol not in atomic_numbers:
                    raise ValueError('Unknown chemical symbol: ' + symbol)

    def convert(self, fmt: str) -> 'Formula':
        """Reformat this formula as a new Formula.

        Same formatting rules as Formula(format=...) keyword.
        """
        return Formula(self._formula, format=fmt)

    def count(self) -> Dict[str, int]:
        """Return dictionary mapping chemical symbol to number of atoms.

        A copy is returned, so the caller may modify it freely.

        Example
        -------
        >>> Formula('H2O').count()
        {'H': 2, 'O': 1}
        """
        return self._count.copy()

    def reduce(self) -> Tuple['Formula', int]:
        """Reduce formula.

        Returns
        -------
        formula: Formula
            Reduced formula.
        n: int
            Number of reduced formula units.

        Example
        -------
        >>> Formula('2H2O').reduce()
        (Formula('H2O'), 2)
        """
        dct, N = self._reduce()
        return self.from_dict(dct), N

    def stoichiometry(self) -> Tuple['Formula', 'Formula', int]:
        """Reduce to unique stoichiometry using "chemical symbols" A, B, C, ...

        Returns the abstract formula (A, B, C, ...), the reduced formula
        with the real symbols in the same order, and the number of reduced
        formula units.

        Examples
        --------
        >>> Formula('CO2').stoichiometry()
        (Formula('AB2'), Formula('CO2'), 1)
        >>> Formula('(H2O)4').stoichiometry()
        (Formula('AB2'), Formula('OH2'), 4)
        """
        count1, N = self._reduce()
        count2 = {}
        count3 = {}
        # Sort by count first and symbol second; the i'th entry is given the
        # placeholder symbol chr(ord('A') + i).
        ordered = sorted((n, symb) for symb, n in count1.items())
        for i, (n, symb) in enumerate(ordered):
            count2[chr(ord('A') + i)] = n
            count3[symb] = n
        return self.from_dict(count2), self.from_dict(count3), N

    def format(self, fmt: str = '') -> str:
        """Format formula as string.

        Formats:

        * ``'hill'``: alphabetically ordered with C and H first
        * ``'metal'``: alphabetically ordered with metals first
        * ``'ab2'``: count-ordered first then alphabetically ordered
        * ``'abc'``: old name for ``'ab2'``
        * ``'a2b'``: reverse count-ordered first then alphabetically ordered
        * ``'periodic'``: periodic-table ordered: period first then group
        * ``'reduce'``: Reduce and keep order (ABBBC -> AB3C)
        * ``'latex'``: LaTeX representation
        * ``'html'``: HTML representation
        * ``'rest'``: reStructuredText representation

        Example
        -------
        >>> Formula('H2O').format('html')
        'H<sub>2</sub>O'
        """
        return format(self, fmt)

    def __format__(self, fmt: str) -> str:
        """Format Formula as str.

        Possible formats: ``'hill'``, ``'metal'``, ``'ab2'``, ``'abc'``,
        ``'a2b'``, ``'periodic'``, ``'reduce'``, ``'latex'``, ``'html'``,
        ``'rest'`` and ``''`` (the formula string as given).

        Raises ValueError for any other format specifier.

        Example
        -------
        >>> f = Formula('OH2')
        >>> '{f}, {f:hill}, {f:latex}'.format(f=f)
        'OH2, H2O, OH$_{2}$'
        """

        if fmt == 'hill':
            count = self.count()
            # C and H (in that order) first, the rest alphabetically.
            count2 = {symb: count.pop(symb) for symb in 'CH' if symb in count}
            for symb, n in sorted(count.items()):
                count2[symb] = n
            return dict2str(count2)

        if fmt == 'metal':
            count = self.count()
            # Pull the non-metals out; whatever is left counts as metal.
            result2 = [(s, count.pop(s)) for s in non_metals if s in count]
            result = [(s, count[s]) for s in sorted(count)]
            result += sorted(result2)
            return dict2str(dict(result))

        if fmt == 'abc' or fmt == 'ab2':
            _, f, N = self.stoichiometry()
            return dict2str({symb: n * N for symb, n in f._count.items()})

        if fmt == 'a2b':
            _, f, N = self.stoichiometry()
            # Sorting on the negated count gives largest count first while
            # keeping alphabetical order among equal counts.
            return dict2str({symb: -n * N
                             for n, symb
                             in sorted([(-n, symb) for symb, n
                                        in f._count.items()])})

        if fmt == 'periodic':
            count = self.count()
            order = periodic_table_order()
            # Symbols missing from the order table get sort key 0.
            items = sorted(count.items(),
                           key=lambda item: order.get(item[0], 0))
            return dict2str(dict(items))

        if fmt == 'reduce':
            symbols = list(self)
            nsymb = len(symbols)
            parts = []
            i1 = 0  # start index of the current run of equal symbols
            for i2, symbol in enumerate(symbols):
                # A run ends at the last symbol or where the symbol changes.
                if i2 == nsymb - 1 or symbol != symbols[i2 + 1]:
                    parts.append(symbol)
                    m = i2 + 1 - i1
                    if m > 1:
                        parts.append(str(m))
                    i1 = i2 + 1
            return ''.join(parts)

        if fmt == 'latex':
            return self._tostr('$_{', '}$')

        if fmt == 'html':
            return self._tostr('<sub>', '</sub>')

        if fmt == 'rest':
            return self._tostr(r'\ :sub:`', r'`\ ')

        if fmt == '':
            return self._formula

        raise ValueError('Invalid format specifier')

    @staticmethod
    def from_dict(dct: Dict[str, int]) -> 'Formula':
        """Convert dict to Formula.

        Symbols with a count of zero are left out.  Raises ValueError if a
        key is not a str or a value is not a non-negative int.

        >>> Formula.from_dict({'H': 2})
        Formula('H2')
        """
        dct2 = {}
        for symb, n in dct.items():
            if not (isinstance(symb, str) and isinstance(n, int) and n >= 0):
                raise ValueError(f'Bad dictionary: {dct}')
            if n > 0:  # filter out n=0 symbols
                dct2[symb] = n
        return Formula(dict2str(dct2),
                       _tree=[([(symb, n) for symb, n in dct2.items()], 1)],
                       _count=dct2)

    @staticmethod
    def from_list(symbols: Sequence[str]) -> 'Formula':
        """Convert list of chemical symbols to Formula."""
        return Formula(''.join(symbols),
                       _tree=[(symbols[:], 1)])  # type: ignore[list-item]

    def __len__(self) -> int:
        """Number of atoms."""
        return sum(self._count.values())

    def __getitem__(self, symb: str) -> int:
        """Number of atoms with chemical symbol *symb* (0 if absent)."""
        return self._count.get(symb, 0)

    def __contains__(self, f: Union[str, 'Formula']) -> bool:
        """Check if formula contains chemical symbols in *f*.

        Type of *f* must be str or Formula.

        Examples
        --------
        >>> 'OH' in Formula('H2O')
        True
        >>> 'O2' in Formula('H2O')
        False
        """
        if isinstance(f, str):
            f = Formula(f)
        return all(self[symb] >= n for symb, n in f._count.items())

    def __eq__(self, other) -> bool:
        """Equality check.

        Note that order is not important.  A str is converted to a Formula
        before comparing; any other type compares unequal.

        Example
        -------
        >>> Formula('CO') == Formula('OC')
        True
        """
        if isinstance(other, str):
            other = Formula(other)
        elif not isinstance(other, Formula):
            return False
        return self._count == other._count

    def __add__(self, other: Union[str, 'Formula']) -> 'Formula':
        """Add two formulas."""
        if not isinstance(other, str):
            other = other._formula
        return Formula(self._formula + '+' + other)

    def __radd__(self, other: str):  # -> Formula
        """Add with the str (or Formula) *other* on the left-hand side."""
        return Formula(other) + self

    def __mul__(self, N: int) -> 'Formula':
        """Repeat formula `N` times."""
        if N == 0:
            return Formula('')
        return self.from_dict({symb: n * N
                               for symb, n in self._count.items()})

    def __rmul__(self, N: int):  # -> Formula
        """Repeat formula `N` times (``N * formula``)."""
        return self * N

    def __divmod__(self,
                   other: Union['Formula', str]) -> Tuple[int, 'Formula']:
        """Return the tuple (self // other, self % other).

        Invariant::

            div, mod = divmod(self, other)
            div * other + mod == self

        Raises ValueError if *other* is an empty formula.

        Example
        -------
        >>> divmod(Formula('H2O'), 'H')
        (2, Formula('O'))
        """
        if isinstance(other, str):
            other = Formula(other)
        if not other._count:
            # min() over no symbols would fail with an unrelated message.
            raise ValueError('Cannot divide by an empty formula')
        N = min(self[symb] // n for symb, n in other._count.items())
        dct = self.count()
        if N:
            for symb, n in other._count.items():
                dct[symb] -= n * N
                if dct[symb] == 0:
                    del dct[symb]
        return N, self.from_dict(dct)

    def __rdivmod__(self, other):
        """Return ``divmod(Formula(other), self)``."""
        return divmod(Formula(other), self)

    def __mod__(self, other):
        """Return what is left after removing *other* as often as possible."""
        return divmod(self, other)[1]

    def __rmod__(self, other):
        """Return ``Formula(other) % self``."""
        return Formula(other) % self

    def __floordiv__(self, other):
        """Return how many times *other* fits in this formula."""
        return divmod(self, other)[0]

    def __rfloordiv__(self, other):
        """Return ``Formula(other) // self``."""
        return Formula(other) // self

    def __iter__(self):
        """Iterate over the chemical symbols, one item per atom."""
        return self._tree_iter()

    def _tree_iter(self, tree=None):
        """Yield the symbols of *tree* (default: the whole formula)."""
        if tree is None:
            tree = self._tree
        if isinstance(tree, str):
            yield tree
        elif isinstance(tree, tuple):
            # (subtree, N): the subtree repeated N times.
            subtree, N = tree
            for _ in range(N):
                yield from self._tree_iter(subtree)
        else:
            for subtree in tree:
                yield from self._tree_iter(subtree)

    def __str__(self):
        return self._formula

    def __repr__(self):
        return f'Formula({self._formula!r})'

    def _reduce(self):
        """Return (counts divided by their gcd, the gcd).

        For an empty formula the result is ``({}, 0)``.
        """
        N = 0
        for n in self._count.values():
            if N == 0:
                N = n
            else:
                N = gcd(n, N)
        dct = {symb: n // N for symb, n in self._count.items()}
        return dct, N

    def _tostr(self, sub1, sub2):
        """Render with each subscript number wrapped in *sub1* ... *sub2*."""
        parts = []
        for tree, n in self._tree:
            s = tree2str(tree, sub1, sub2)
            # startswith/endswith instead of s[0]/s[-1]: an empty group such
            # as '()' renders to '' and must not raise IndexError.
            if s.startswith('(') and s.endswith(')'):
                s = s[1:-1]
            if n > 1:
                s = str(n) + s
            parts.append(s)
        return '+'.join(parts)

398 

399 

def dict2str(dct: Dict[str, int]) -> str:
    """Build a formula string from a symbol-to-count mapping.

    Symbols are written in the dict's iteration order; a count is only
    written out when it is larger than one.

    >>> dict2str({'A': 1, 'B': 2})
    'AB2'
    """
    pieces = []
    for symbol, number in dct.items():
        pieces.append(symbol)
        if number > 1:
            pieces.append(str(number))
    return ''.join(pieces)

408 

409 

def parse(f: str) -> Tree:
    """Turn a formula string into its tree representation.

    The string is split at every "+"; each term becomes a
    ``(subtree, multiplier)`` tuple where the multiplier is the term's
    leading number (1 if there is none).  An empty string gives ``[]``.

    >>> parse('2A+BC2')
    [('A', 2), (['B', ('C', 2)], 1)]
    """
    if not f:
        return []
    return [(parse2(body), multiplier)  # type: ignore[return-value]
            for multiplier, body in map(strip_number, f.split('+'))]

424 

425 

def parse2(f: str) -> 'Tree':
    """Convert formula string to tree structure (no "+" symbols).

    A formula consisting of a single unit is returned as that unit itself;
    otherwise a list of units is returned (empty for an empty string).

    Raises ValueError, naming the offending text, when parentheses are
    unbalanced or when the text does not start with a symbol of the form
    ``[A-Z][a-z]?``.

    >>> parse2('H2O')
    [('H', 2), 'O']
    >>> parse2('Cu')
    'Cu'
    """
    units = []
    while f:
        unit: Union[str, Tuple[str, int], Tree]
        if f[0] == '(':
            # Scan for the ')' that closes this group; `level` tracks how
            # many nested groups are currently open.
            level = 0
            for i, c in enumerate(f[1:], 1):
                if c == '(':
                    level += 1
                elif c == ')':
                    if level == 0:
                        break
                    level -= 1
            else:
                raise ValueError(
                    f'Unbalanced parentheses in formula: {f!r}')
            f2 = f[1:i]
            # A number directly after the group is its multiplier.
            n, f = strip_number(f[i + 1:])
            unit = (parse2(f2), n)
        else:
            m = re.match('([A-Z][a-z]?)([0-9]*)', f)
            if m is None:
                raise ValueError(
                    f'Expected a chemical symbol at: {f!r}')
            symb = m.group(1)
            number = m.group(2)
            if number:
                unit = (symb, int(number))
            else:
                unit = symb
            f = f[m.end():]
        units.append(unit)
    if len(units) == 1:
        return units[0]
    return units

464 

465 

def strip_number(s: str) -> Tuple[int, str]:
    """Split off a leading number.

    Returns the number (1 when *s* does not start with a digit) together
    with the remainder of the string.

    >>> strip_number('10AB2')
    (10, 'AB2')
    >>> strip_number('AB2')
    (1, 'AB2')
    """
    digits = re.match('[0-9]*', s)
    assert digits is not None  # the pattern also matches the empty string
    prefix = digits.group()
    number = int(prefix) if prefix else 1
    return number, s[len(prefix):]

477 

478 

def tree2str(tree: 'Tree',
             sub1: str, sub2: str) -> str:
    """Helper function for html, latex and rest formats.

    Renders *tree* as a string in which every repeat count other than 1
    is written as ``sub1 + str(count) + sub2``.  A list of sub-trees is
    wrapped in parentheses; a ``(subtree, 1)`` tuple has the outermost
    pair of parentheses removed again.
    """
    if isinstance(tree, str):
        return tree
    if isinstance(tree, tuple):
        subtree, N = tree
        s = tree2str(subtree, sub1, sub2)
        if N == 1:
            # startswith/endswith instead of s[0]/s[-1]: a nested empty
            # group renders to '' and must not raise IndexError.
            if s.startswith('(') and s.endswith(')'):
                return s[1:-1]
            return s
        return s + sub1 + str(N) + sub2
    return '(' + ''.join(tree2str(child, sub1, sub2) for child in tree) + ')'

493 

494 

def count_tree(tree: Tree) -> Dict[str, int]:
    """Count the atoms of each symbol in *tree*.

    A str counts as one atom, a ``(subtree, N)`` tuple multiplies the
    counts of its subtree by N, and a list adds up the counts of its
    members.
    """
    if isinstance(tree, str):
        return {tree: 1}
    if isinstance(tree, tuple):
        subtree, repeat = tree
        return {symb: n * repeat
                for symb, n in count_tree(subtree).items()}
    totals: Dict[str, int] = {}
    for child in tree:
        for symb, n in count_tree(child).items():
            totals[symb] = totals.get(symb, 0) + n
    return totals

507 

508 

# Symbols treated as non-metals by the 'metal' format: non-metals,
# half-metals/metalloids, halogens and noble gases, one row per period.
non_metals = [
    'H', 'He',                                # period 1
    'B', 'C', 'N', 'O', 'F', 'Ne',            # period 2
    'Si', 'P', 'S', 'Cl', 'Ar',               # period 3
    'Ge', 'As', 'Se', 'Br', 'Kr',             # period 4
    'Sb', 'Te', 'I', 'Xe',                    # period 5
    'Po', 'At', 'Rn',                         # period 6
]

515 

516 

@lru_cache
def periodic_table_order() -> Dict[str, int]:
    """Return the sort key of each symbol for the 'periodic' format.

    The symbols are numbered consecutively, taking the slices of
    ``chemical_symbols`` from the highest indices (87 and up) down to the
    lowest (1-2) and keeping the original order inside each slice.
    Index 0 is not included.
    """
    # (start, stop) index ranges into chemical_symbols, in output order.
    slices = [(87, None), (55, 87), (37, 55), (19, 37),
              (11, 19), (3, 11), (1, 3)]
    ordered = []
    for start, stop in slices:
        ordered += chemical_symbols[start:stop]
    return {symbol: rank for rank, symbol in enumerate(ordered)}

527 

528 

# Backwards compatibility:
def formula_hill(numbers, empirical=False):
    """Return the Hill-ordered formula string for a list of atomic numbers.

    Hill order means alphabetical order with C and H first.

    When `empirical` is true, the element counts are first divided by
    their greatest common divisor, giving the empirical formula."""
    formula = Formula('', _tree=[([chemical_symbols[Z] for Z in numbers], 1)])
    if not empirical:
        return formula.format('hill')
    reduced, _ = formula.reduce()
    return reduced.format('hill')

542 

543 

# Backwards compatibility:
def formula_metal(numbers, empirical=False):
    """Return the metals-first formula string for a list of atomic numbers.

    Elements are ordered alphabetically, with the metals placed first.

    When `empirical` is true, the element counts are first divided by
    their greatest common divisor, giving the empirical formula."""
    formula = Formula('', _tree=[([chemical_symbols[Z] for Z in numbers], 1)])
    if not empirical:
        return formula.format('metal')
    reduced, _ = formula.reduce()
    return reduced.format('metal')