Coverage for ase / formula.py: 91.45%

269 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3import re 

4from collections.abc import Sequence 

5from functools import lru_cache 

6from math import gcd 

7from typing import Union 

8 

9from ase.data import atomic_numbers, chemical_symbols 

10 

11# For type hints (A, A2, A+B): 

12Tree = str | tuple['Tree', int] | list['Tree'] 

13 

14 

class Formula:
    """Chemical formula that can be counted, formatted and combined."""

    def __init__(self,
                 formula: Union[str, 'Formula'] = '',
                 *,
                 strict: bool = False,
                 format: str = '',
                 _tree: Tree | None = None,
                 _count: dict[str, int] | None = None):
        """Chemical formula object.

        Parameters
        ----------
        formula: str
            Text string representation of formula.  Examples: ``'6CO2'``,
            ``'30Cu+2CO'``, ``'Pt(CO)6'``.
        strict: bool
            Only allow real chemical symbols.
        format: str
            Reorder according to *format*.  Must be one of hill, metal,
            ab2, abc, a2b, periodic or reduce.
        _tree, _count:
            Private: pre-computed parse tree and symbol counts.  When
            given (and non-empty) they are used instead of parsing
            *formula*.  Must not be combined with *format*.

        Examples
        --------
        >>> from ase.formula import Formula
        >>> w = Formula('H2O')
        >>> w.count()
        {'H': 2, 'O': 1}
        >>> 'H' in w
        True
        >>> w == 'HOH'
        True
        >>> f'{w:latex}'
        'H$_{2}$O'
        >>> w.format('latex')
        'H$_{2}$O'
        >>> divmod(6 * w + 'Cu', w)
        (6, Formula('Cu'))

        Raises
        ------
        ValueError
            on malformed formula, on an illegal *format* and, with
            ``strict=True``, on an unknown chemical symbol.
        """

        # Be sure that Formula(x) works the same whether x is string or Formula
        assert isinstance(formula, (str, Formula))
        formula = str(formula)

        if format:
            assert _tree is None and _count is None
            if format not in {'hill', 'metal', 'abc', 'reduce', 'ab2', 'a2b',
                              'periodic'}:
                raise ValueError(f'Illegal format: {format}')
            # Re-order by round-tripping through the string representation
            formula = Formula(formula).format(format)

        self._formula = formula

        self._tree = _tree or parse(formula)
        self._count = _count or count_tree(self._tree)
        if strict:
            for symbol in self._count:
                if symbol not in atomic_numbers:
                    raise ValueError('Unknown chemical symbol: ' + symbol)

    def convert(self, fmt: str) -> 'Formula':
        """Reformat this formula as a new Formula.

        Same formatting rules as Formula(format=...) keyword.
        """
        return Formula(self._formula, format=fmt)

    def count(self) -> dict[str, int]:
        """Return dictionary mapping chemical symbol to number of atoms.

        The returned dictionary is a copy and may be modified freely.

        Example
        -------
        >>> Formula('H2O').count()
        {'H': 2, 'O': 1}
        """
        return self._count.copy()

    def reduce(self) -> tuple['Formula', int]:
        """Reduce formula.

        Returns
        -------
        formula: Formula
            Reduced formula.
        n: int
            Number of reduced formula units.

        Example
        -------
        >>> Formula('2H2O').reduce()
        (Formula('H2O'), 2)
        """
        dct, N = self._reduce()
        return self.from_dict(dct), N

    def stoichiometry(self) -> tuple['Formula', 'Formula', int]:
        """Reduce to unique stoichiometry using "chemical symbols" A, B, C, ...

        Returns the stoichiometry, the reduced formula ordered by
        (count, symbol) and the number of reduced formula units.

        Examples
        --------
        >>> Formula('CO2').stoichiometry()
        (Formula('AB2'), Formula('CO2'), 1)
        >>> Formula('(H2O)4').stoichiometry()
        (Formula('AB2'), Formula('OH2'), 4)
        """
        count1, N = self._reduce()
        c = ord('A')
        count2 = {}
        count3 = {}
        # Sort by count first, then by symbol
        for n, symb in sorted((n, symb)
                              for symb, n in count1.items()):
            count2[chr(c)] = n
            count3[symb] = n
            c += 1
        return self.from_dict(count2), self.from_dict(count3), N

    def format(self, fmt: str = '') -> str:
        """Format formula as string.

        Formats:

        * ``'hill'``: alphabetically ordered with C and H first
        * ``'metal'``: alphabetically ordered with metals first
        * ``'ab2'``: count-ordered first then alphabetically ordered
        * ``'abc'``: old name for ``'ab2'``
        * ``'a2b'``: reverse count-ordered first then alphabetically ordered
        * ``'periodic'``: periodic-table ordered: period first then group
        * ``'reduce'``: Reduce and keep order (ABBBC -> AB3C)
        * ``'latex'``: LaTeX representation
        * ``'html'``: HTML representation
        * ``'rest'``: reStructuredText representation

        Example
        -------
        >>> Formula('H2O').format('html')
        'H<sub>2</sub>O'
        """
        return format(self, fmt)

    def __format__(self, fmt: str) -> str:
        """Format Formula as str.

        Possible formats: ``'hill'``, ``'metal'``, ``'abc'``, ``'ab2'``,
        ``'a2b'``, ``'periodic'``, ``'reduce'``, ``'latex'``, ``'html'``,
        ``'rest'`` and ``''`` (the formula string as given).

        Example
        -------
        >>> f = Formula('OH2')
        >>> '{f}, {f:hill}, {f:latex}'.format(f=f)
        'OH2, H2O, OH$_{2}$'

        Raises
        ------
        ValueError
            on any other format specifier.
        """

        if fmt == 'hill':
            count = self.count()
            # C and H (in that order) go first, the rest alphabetically
            count2 = {symb: count.pop(symb) for symb in 'CH' if symb in count}
            for symb, n in sorted(count.items()):
                count2[symb] = n
            return dict2str(count2)

        if fmt == 'metal':
            count = self.count()
            # Pull the non-metals out; what remains in count goes first
            result2 = [(s, count.pop(s)) for s in non_metals if s in count]
            result = [(s, count[s]) for s in sorted(count)]
            result += sorted(result2)
            return dict2str(dict(result))

        if fmt == 'abc' or fmt == 'ab2':
            _, f, N = self.stoichiometry()
            return dict2str({symb: n * N for symb, n in f._count.items()})

        if fmt == 'a2b':
            _, f, N = self.stoichiometry()
            # Sorting on the negated count gives largest count first
            return dict2str({symb: -n * N
                             for n, symb
                             in sorted([(-n, symb) for symb, n
                                        in f._count.items()])})

        if fmt == 'periodic':
            count = self.count()
            order = periodic_table_order()
            # Symbols missing from the order table sort with key 0
            items = sorted(count.items(),
                           key=lambda item: order.get(item[0], 0))
            return dict2str(dict(items))

        if fmt == 'reduce':
            symbols = list(self)
            nsymb = len(symbols)
            parts = []
            i1 = 0  # start of the current run of identical symbols
            for i2, symbol in enumerate(symbols):
                if i2 == nsymb - 1 or symbol != symbols[i2 + 1]:
                    parts.append(symbol)
                    m = i2 + 1 - i1
                    if m > 1:
                        parts.append(str(m))
                    i1 = i2 + 1
            return ''.join(parts)

        if fmt == 'latex':
            return self._tostr('$_{', '}$')

        if fmt == 'html':
            return self._tostr('<sub>', '</sub>')

        if fmt == 'rest':
            return self._tostr(r'\ :sub:`', r'`\ ')

        if fmt == '':
            return self._formula

        raise ValueError('Invalid format specifier')

    @staticmethod
    def from_dict(dct: dict[str, int]) -> 'Formula':
        """Convert dict to Formula.

        Symbols with a count of zero are left out.

        >>> Formula.from_dict({'H': 2})
        Formula('H2')

        Raises
        ------
        ValueError
            if a key is not a str or a value is not a non-negative int.
        """
        dct2 = {}
        for symb, n in dct.items():
            if not (isinstance(symb, str) and isinstance(n, int) and n >= 0):
                raise ValueError(f'Bad dictionary: {dct}')
            if n > 0:  # filter out n=0 symbols
                dct2[symb] = n
        return Formula(dict2str(dct2),
                       _tree=[([(symb, n) for symb, n in dct2.items()], 1)],
                       _count=dct2)

    @staticmethod
    def from_list(symbols: Sequence[str]) -> 'Formula':
        """Convert list of chemical symbols to Formula."""
        return Formula(''.join(symbols),
                       _tree=[(symbols[:], 1)])  # type: ignore[list-item]

    def __len__(self) -> int:
        """Number of atoms."""
        return sum(self._count.values())

    def __getitem__(self, symb: str) -> int:
        """Number of atoms with chemical symbol *symb*."""
        return self._count.get(symb, 0)

    def __contains__(self, f: Union[str, 'Formula']) -> bool:
        """Check if formula contains chemical symbols in *f*.

        Type of *f* must be str or Formula.

        Examples
        --------
        >>> 'OH' in Formula('H2O')
        True
        >>> 'O2' in Formula('H2O')
        False
        """
        if isinstance(f, str):
            f = Formula(f)
        for symb, n in f._count.items():
            if self[symb] < n:
                return False
        return True

    def __eq__(self, other) -> bool:
        """Equality check.

        Note that order is not important.  A str is parsed as a formula
        first; any other non-Formula object compares unequal.

        Example
        -------
        >>> Formula('CO') == Formula('OC')
        True
        """
        if isinstance(other, str):
            other = Formula(other)
        elif not isinstance(other, Formula):
            return False
        return self._count == other._count

    def __add__(self, other: Union[str, 'Formula']) -> 'Formula':
        """Add two formulas."""
        if not isinstance(other, str):
            other = other._formula
        return Formula(self._formula + '+' + other)

    def __radd__(self, other: str) -> 'Formula':
        """Add with the str operand on the left."""
        return Formula(other) + self

    def __mul__(self, N: int) -> 'Formula':
        """Repeat formula `N` times."""
        if N == 0:
            return Formula('')
        return self.from_dict({symb: n * N
                               for symb, n in self._count.items()})

    def __rmul__(self, N: int) -> 'Formula':
        """Repeat formula `N` times (``N * formula``)."""
        return self * N

    def __divmod__(self,
                   other: Union['Formula', str]) -> tuple[int, 'Formula']:
        """Return the tuple (self // other, self % other).

        Invariant::

            div, mod = divmod(self, other)
            div * other + mod == self

        Example
        -------
        >>> divmod(Formula('H2O'), 'H')
        (2, Formula('O'))

        Raises
        ------
        ValueError
            if *other* contains no atoms.
        """
        if isinstance(other, str):
            other = Formula(other)
        if not other._count:
            # min() below would otherwise fail with an unrelated message
            raise ValueError('Cannot divide by an empty formula')
        N = min(self[symb] // n for symb, n in other._count.items())
        dct = self.count()
        if N:
            for symb, n in other._count.items():
                dct[symb] -= n * N
                if dct[symb] == 0:
                    del dct[symb]
        return N, self.from_dict(dct)

    def __rdivmod__(self, other):
        """divmod() with the non-Formula operand on the left."""
        return divmod(Formula(other), self)

    def __mod__(self, other):
        """Remainder part of divmod(self, other)."""
        return divmod(self, other)[1]

    def __rmod__(self, other):
        """Remainder with the non-Formula operand on the left."""
        return Formula(other) % self

    def __floordiv__(self, other):
        """Integer part of divmod(self, other)."""
        return divmod(self, other)[0]

    def __rfloordiv__(self, other):
        """Floor division with the non-Formula operand on the left."""
        return Formula(other) // self

    def __iter__(self):
        """Iterate over the chemical symbols, one per atom."""
        return self._tree_iter()

    def _tree_iter(self, tree=None):
        """Yield symbols of *tree* (default: the whole formula) recursively."""
        if tree is None:
            tree = self._tree
        if isinstance(tree, str):
            yield tree
        elif isinstance(tree, tuple):
            # (subtree, N): repeat the subtree N times
            tree, N = tree
            for _ in range(N):
                yield from self._tree_iter(tree)
        else:
            for tree in tree:
                yield from self._tree_iter(tree)

    def __str__(self):
        """Return the formula string."""
        return self._formula

    def __repr__(self):
        """Return ``Formula('...')``."""
        return f'Formula({self._formula!r})'

    def _reduce(self):
        """Divide all counts by their greatest common divisor.

        Returns the reduced symbol-to-count dict and the divisor (0 for
        an empty formula).
        """
        N = 0
        for n in self._count.values():
            if N == 0:
                N = n
            else:
                N = gcd(n, N)
        dct = {symb: n // N for symb, n in self._count.items()}
        return dct, N

    def _tostr(self, sub1, sub2):
        """Render the tree with numbers wrapped in *sub1* ... *sub2*.

        Helper for the latex, html and rest formats.  The "+"-separated
        top-level terms keep their multiplier as a plain prefix.
        """
        parts = []
        for tree, n in self._tree:
            s = tree2str(tree, sub1, sub2)
            # startswith/endswith also cope with s == '' (empty group),
            # where indexing s[0] would raise IndexError
            if s.startswith('(') and s.endswith(')'):
                s = s[1:-1]
            if n > 1:
                s = str(n) + s
            parts.append(s)
        return '+'.join(parts)

399 

400 

def dict2str(dct: dict[str, int]) -> str:
    """Build a formula string from a symbol-to-count mapping.

    Symbols are written in the iteration order of *dct*; a count is
    written only when it is larger than one.

    >>> dict2str({'A': 1, 'B': 2})
    'AB2'
    """
    pieces = []
    for symbol, number in dct.items():
        pieces.append(symbol)
        if number > 1:
            pieces.append(str(number))
    return ''.join(pieces)

409 

410 

def parse(f: str) -> Tree:
    """Turn a formula string into its tree representation.

    The string is split on "+"; each term contributes one
    ``(subtree, multiplier)`` pair where the multiplier is the leading
    number of the term (1 if absent).  An empty string gives ``[]``.

    >>> parse('2A+BC2')
    [('A', 2), (['B', ('C', 2)], 1)]
    """
    if not f:
        return []
    return [(parse2(body), multiplier)  # type: ignore[return-value]
            for multiplier, body in map(strip_number, f.split('+'))]

425 

426 

def parse2(f: str) -> Tree:
    """Convert formula string to tree structure (no "+" symbols).

    A formula consisting of a single unit is returned as that unit
    (a symbol, a ``(symbol, number)`` pair or a ``(subtree, number)``
    pair); otherwise a list of units is returned.  An empty string
    gives ``[]``.

    >>> parse2('H2O')
    [('H', 2), 'O']
    >>> parse2('(H2O)10')
    ([('H', 2), 'O'], 10)

    Raises
    ------
    ValueError
        If an opening parenthesis is never closed or if the text does
        not continue with a chemical symbol where one is expected.
    """
    units = []
    while f:
        unit: str | tuple[str, int] | Tree
        if f[0] == '(':
            # Find the ")" matching the "(" at position 0
            level = 0
            for i, c in enumerate(f[1:], 1):
                if c == '(':
                    level += 1
                elif c == ')':
                    if level == 0:
                        break
                    level -= 1
            else:
                raise ValueError(
                    f'Unbalanced parenthesis in formula: {f!r}')
            f2 = f[1:i]
            # A number right after ")" multiplies the whole group
            n, f = strip_number(f[i + 1:])
            unit = (parse2(f2), n)
        else:
            m = re.match('([A-Z][a-z]?)([0-9]*)', f)
            if m is None:
                raise ValueError(
                    f'Expected chemical symbol in formula at: {f!r}')
            symb = m.group(1)
            number = m.group(2)
            if number:
                unit = (symb, int(number))
            else:
                unit = symb
            f = f[m.end():]
        units.append(unit)
    if len(units) == 1:
        return units[0]
    return units

465 

466 

def strip_number(s: str) -> tuple[int, str]:
    """Split off a leading number.

    Returns the number (1 when *s* does not start with a digit) and the
    remainder of the string.

    >>> strip_number('10AB2')
    (10, 'AB2')
    >>> strip_number('AB2')
    (1, 'AB2')
    """
    match = re.match('[0-9]*', s)
    assert match is not None
    digits = match.group()
    remainder = s[match.end():]
    number = int(digits) if digits else 1
    return number, remainder

478 

479 

def tree2str(tree: Tree,
             sub1: str, sub2: str) -> str:
    """Helper function for html, latex and rest formats.

    Renders *tree* as text:

    * a symbol is returned unchanged,
    * a ``(subtree, N)`` pair is the rendered subtree followed by
      ``sub1 + str(N) + sub2``; for ``N == 1`` nothing is appended and
      enclosing parentheses are dropped,
    * a list is the concatenation of its rendered items in parentheses.
    """
    if isinstance(tree, str):
        return tree
    if isinstance(tree, tuple):
        tree, N = tree
        s = tree2str(tree, sub1, sub2)
        if N == 1:
            # startswith/endswith rather than s[0]/s[-1]: s is '' for a
            # nested empty group and indexing would raise IndexError
            if s.startswith('(') and s.endswith(')'):
                return s[1:-1]
            return s
        return s + sub1 + str(N) + sub2
    return '(' + ''.join(tree2str(child, sub1, sub2) for child in tree) + ')'

494 

495 

def count_tree(tree: Tree) -> dict[str, int]:
    """Count the atoms of each chemical symbol in *tree*.

    A symbol counts once, a ``(subtree, N)`` pair scales the counts of
    its subtree by N and a list adds up the counts of its items.
    Symbols appear in the result in order of first occurrence.
    """
    if isinstance(tree, str):
        return {tree: 1}
    if isinstance(tree, tuple):
        subtree, factor = tree
        inner = count_tree(subtree)
        return {symbol: number * factor for symbol, number in inner.items()}
    totals: dict[str, int] = {}
    for child in tree:
        for symbol, number in count_tree(child).items():
            totals[symbol] = totals.get(symbol, 0) + number
    return totals

508 

509 

# Symbols treated as non-metals by the 'metal' format (non metals,
# half-metals/metalloid, halogen, noble gas), one source line per group
# as in the original listing:
non_metals = ('H He B C N O F Ne '
              'Si P S Cl Ar '
              'Ge As Se Br Kr '
              'Sb Te I Xe '
              'Po At Rn').split()

516 

517 

@lru_cache
def periodic_table_order() -> dict[str, int]:
    """Return a symbol-to-rank dict used as sort key by the 'periodic' format.

    The ranks follow ``chemical_symbols`` cut into the slices listed
    below, taken from the highest indices down to ``[1:3]``; index 0 is
    not included.  The result is cached.
    """
    # NOTE(review): the slice bounds presumably are the period boundaries
    # of the periodic table -- confirm against ase.data.chemical_symbols.
    bounds = [(87, None), (55, 87), (37, 55), (19, 37),
              (11, 19), (3, 11), (1, 3)]
    ordered = []
    for start, stop in bounds:
        ordered.extend(chemical_symbols[start:stop])
    return {symbol: rank for rank, symbol in enumerate(ordered)}

528 

529 

530# Backwards compatibility: 

def formula_hill(numbers, empirical=False):
    """Return the Hill-ordered formula string for a list of atomic numbers.

    Elements are alphabetically ordered with C and H first.

    With a true *empirical*, the element counts are first divided by
    their greatest common divisor so that an empirical formula results.
    """
    formula = Formula('', _tree=[([chemical_symbols[Z] for Z in numbers], 1)])
    if empirical:
        formula = formula.reduce()[0]
    return formula.format('hill')

543 

544 

545# Backwards compatibility: 

def formula_metal(numbers, empirical=False):
    """Return the metals-first formula string for a list of atomic numbers.

    Elements are alphabetically ordered with metals first.

    With a true *empirical*, the element counts are first divided by
    their greatest common divisor so that an empirical formula results.
    """
    formula = Formula('', _tree=[([chemical_symbols[Z] for Z in numbers], 1)])
    if empirical:
        formula = formula.reduce()[0]
    return formula.format('metal')