Coverage for ase / io / cif.py: 90.85%

492 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3"""Module to read and write atoms in cif file format. 

4 

5See http://www.iucr.org/resources/cif/spec/version1.1/cifsyntax for a 

6description of the file format. STAR extensions as save frames, 

7global blocks, nested loops and multi-data values are not supported. 

8The "latin-1" encoding is required by the IUCR specification. 

9""" 

10 

11import collections.abc 

12import io 

13import re 

14import shlex 

15import warnings 

16from collections.abc import Iterator, Sequence 

17from typing import Any 

18 

19import numpy as np 

20 

21from ase import Atoms 

22from ase.cell import Cell 

23from ase.io.cif_unicode import format_unicode, handle_subscripts 

24from ase.spacegroup import crystal 

25from ase.spacegroup.spacegroup import Spacegroup, spacegroup_from_data 

26from ase.utils import iofunction 

27 

28rhombohedral_spacegroups = {146, 148, 155, 160, 161, 166, 167} 

29 

30 

31old_spacegroup_names = {'Abm2': 'Aem2', 

32 'Aba2': 'Aea2', 

33 'Cmca': 'Cmce', 

34 'Cmma': 'Cmme', 

35 'Ccca': 'Ccc1'} 

36 

37# CIF maps names to either single values or to multiple values via loops. 

38CIFDataValue = str | int | float 

39CIFData = CIFDataValue | list[CIFDataValue] 

40 

41 

def convert_value(value: str) -> CIFDataValue:
    """Convert CIF value string to corresponding python type.

    Surrounding whitespace is stripped first.  Then:

    * a value enclosed in matching single or double quotes is returned
      as a string without the quotes,
    * an integer literal is returned as ``int``,
    * a floating point literal is returned as ``float``; a trailing
      uncertainty in parentheses, e.g. ``1.23(4)``, is discarded,
    * anything else is returned as a string.

    String results are passed through ``handle_subscripts``.  A number
    whose uncertainty lacks the closing parenthesis is still converted,
    but a warning is issued.
    """
    value = value.strip()
    number = r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?'

    # The alternatives are grouped so that ``$`` anchors both quote styles.
    # Without the group, a value that merely *starts* with a double-quoted
    # token would be treated as fully quoted and lose its last character.
    if re.match(r'(?:".*"|\'.*\')$', value):
        return handle_subscripts(value[1:-1])
    elif re.match(r'[+-]?\d+$', value):
        return int(value)
    elif re.match(number + r'$', value):
        return float(value)
    elif re.match(number + r'\(\d+\)$', value):
        return float(value[:value.index('(')])  # strip off uncertainties
    elif re.match(number + r'\(\d+$', value):
        warnings.warn(f'Badly formed number: "{value}"')
        return float(value[:value.index('(')])  # strip off uncertainties
    else:
        return handle_subscripts(value)

60 

61 

def parse_multiline_string(lines: list[str], line: str) -> str:
    """Parse semicolon-enclosed multiline string and return it.

    *line* is the opening line (it must start with ``;``).  Following
    lines are popped from the end of *lines* until one starting with
    ``;`` is consumed.  The collected lines are stripped, joined by
    newlines, and the result is stripped once more.
    """
    assert line[0] == ';'
    collected = [line[1:].lstrip()]
    # The terminating ';' line is consumed but not stored.
    while not (current := lines.pop().strip()).startswith(';'):
        collected.append(current)
    return '\n'.join(collected).strip()

72 

73 

def parse_singletag(lines: list[str], line: str) -> tuple[str, CIFDataValue]:
    """Parse a CIF tag (entries starting with underscore).

    Returns a key-value pair.  If *line* holds only the tag, the value
    is taken from the next line of *lines* that is neither blank nor a
    comment; that may be a semicolon-delimited multiline string.
    """
    parts = line.split(None, 1)
    if len(parts) != 1:
        # Tag and value share the line.
        key, value = parts
        return key, convert_value(value)

    key = line
    candidate = lines.pop().strip()
    while not candidate or candidate[0] == '#':
        candidate = lines.pop().strip()

    if candidate[0] == ';':
        value = parse_multiline_string(lines, candidate)
    else:
        value = candidate
    return key, convert_value(value)

90 

91 

def parse_cif_loop_headers(lines: list[str]) -> Iterator[str]:
    """Yield the lower-cased column tags at the start of a CIF loop.

    Lines are popped from the end of *lines* as long as each consists
    of a single token starting with an underscore.  The first line that
    does not is pushed back and iteration stops.
    """
    while lines:
        candidate = lines.pop()
        fields = candidate.split()
        is_header = len(fields) == 1 and fields[0].startswith('_')
        if not is_header:
            lines.append(candidate)  # 'undo' pop
            return
        yield fields[0].lower()

103 

104 

def parse_cif_loop_data(lines: list[str],
                        ncolumns: int) -> list[list[CIFDataValue]]:
    """Read the rows of a CIF loop and return them as *ncolumns* columns.

    Lines are popped from the end of *lines* until a blank line, a tag,
    a ``data_`` line or a ``loop_`` line is found; that line is pushed
    back (stripped).  A row may span several lines.  A row with too
    many tokens is dropped with a warning.

    Raises RuntimeError if the loop ends in the middle of a row.
    """
    columns: list[list[CIFDataValue]] = [[] for _ in range(ncolumns)]
    block_keywords = ('data_', 'loop_')

    pending = []  # tokens of the row being assembled
    while lines:
        line = lines.pop().strip()
        ends_loop = (not line
                     or line.startswith('_')
                     or line.lower().startswith(block_keywords))
        if ends_loop:
            lines.append(line)
            break

        if line.startswith('#'):
            continue

        # Drop a trailing comment.
        line = line.split(' #')[0]

        if line.startswith(';'):
            pending.append(parse_multiline_string(lines, line))
        elif ncolumns == 1:
            # With a single column the whole line is the value.
            pending.append(line)
        else:
            pending.extend(shlex.split(line, posix=False))

        if len(pending) < ncolumns:
            continue  # row continues on the next line

        if len(pending) == ncolumns:
            for column, token in zip(columns, pending):
                column.append(convert_value(token))
        else:
            warnings.warn(f'Wrong number {len(pending)} of tokens, '
                          f'expected {ncolumns}: {pending}')

        # (Due to continue statements we cannot move this to start of loop)
        pending = []

    if pending:
        assert len(pending) < ncolumns
        raise RuntimeError('CIF loop ended unexpectedly with incomplete row: '
                           f'{pending}, expected {ncolumns} tokens')

    return columns

152 

153 

def parse_loop(lines: list[str]) -> dict[str, list[CIFDataValue]]:
    """Parse a CIF loop.

    Returns a dict with column tag names as keys and lists of the
    column content as values.  If a tag is repeated, a warning is
    issued and only its first column is kept.
    """
    # Headers are kept as a list because they can be repeated.
    headers = list(parse_cif_loop_headers(lines))
    columns = parse_cif_loop_data(lines, len(headers))

    columns_dict = {}
    for header, column in zip(headers, columns):
        if header in columns_dict:
            warnings.warn(f'Duplicated loop tags: {header}')
            continue
        columns_dict[header] = column
    return columns_dict

170 

171 

def parse_items(lines: list[str], line: str) -> dict[str, CIFData]:
    """Parse the data items of a CIF block and return a dict with all tags.

    Lines are popped from the end of *lines* until it is empty or a
    ``data_`` line is found; that line is pushed back (stripped).  Tag
    names are lower-cased.  Blank lines, comments and stray multiline
    strings are skipped.  The *line* argument is not used.

    Raises ValueError on a line that is none of the above.
    """
    tags: dict[str, CIFData] = {}

    while lines:
        entry = lines.pop().strip()
        if not entry or entry.startswith('#'):
            continue

        keyword = entry.lower()
        if entry.startswith('_'):
            key, value = parse_singletag(lines, entry)
            tags[key.lower()] = value
            continue
        if keyword.startswith('loop_'):
            tags.update(parse_loop(lines))
            continue
        if keyword.startswith('data_'):
            # Start of the next block: leave it for the caller.
            lines.append(entry)
            break
        if entry.startswith(';'):
            # A multiline string without a tag is read and discarded.
            parse_multiline_string(lines, entry)
            continue
        raise ValueError(f'Unexpected CIF file entry: "{entry}"')
    return tags

199 

200 

class NoStructureData(RuntimeError):
    """Raised when a CIF block lacks usable symbols or positions."""

203 

204 

class CIFBlock(collections.abc.Mapping):
    """A block (i.e., a single system) in a crystallographic information file.

    Use this object to query CIF tags or import information as ASE objects.

    The block is a read-only mapping from tag names to the values
    passed at construction."""

    # Tags for the three cell lengths followed by the three cell angles.
    cell_tags = ['_cell_length_a', '_cell_length_b', '_cell_length_c',
                 '_cell_angle_alpha', '_cell_angle_beta', '_cell_angle_gamma']

    def __init__(self, name: str, tags: dict[str, CIFData]):
        """Store the block *name* and its *tags* (kept by reference)."""
        self.name = name
        self._tags = tags

    def __repr__(self) -> str:
        tags = set(self._tags)
        return f'CIFBlock({self.name}, tags={tags})'

    def __getitem__(self, key: str) -> CIFData:
        return self._tags[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._tags)

    def __len__(self) -> int:
        return len(self._tags)

    def get(self, key, default=None):
        """Return the value of tag *key*, or *default* if it is absent."""
        return self._tags.get(key, default)

    def get_cellpar(self) -> list | None:
        """Return the values of ``cell_tags``, or None if any is missing."""
        try:
            return [self[tag] for tag in self.cell_tags]
        except KeyError:
            return None

    def get_cell(self) -> Cell:
        """Return the cell; a zero cell if the cell tags are incomplete."""
        cellpar = self.get_cellpar()
        if cellpar is None:
            return Cell.new([0, 0, 0])
        return Cell.new(cellpar)

    def _raw_scaled_positions(self) -> np.ndarray | None:
        """Return the ``_atom_site_fract_*`` columns with one row per
        site, or None if any of the three tags is missing."""
        coords = [self.get(name) for name in ['_atom_site_fract_x',
                                              '_atom_site_fract_y',
                                              '_atom_site_fract_z']]
        # XXX Shall we try to handle mixed coordinates?
        # (Some scaled vs others fractional)
        if None in coords:
            return None
        return np.array(coords).T

    def _raw_positions(self) -> np.ndarray | None:
        """Return the ``_atom_site_cartn_*`` columns with one row per
        site, or None if any of the three tags is missing."""
        coords = [self.get('_atom_site_cartn_x'),
                  self.get('_atom_site_cartn_y'),
                  self.get('_atom_site_cartn_z')]
        if None in coords:
            return None
        return np.array(coords).T

    def _get_site_coordinates(self):
        """Return ``('scaled', coords)`` or ``('cartesian', coords)``.

        Scaled coordinates take precedence.  Raises NoStructureData if
        neither kind is present."""
        scaled = self._raw_scaled_positions()

        if scaled is not None:
            return 'scaled', scaled

        cartesian = self._raw_positions()

        if cartesian is None:
            raise NoStructureData('No positions found in structure')

        return 'cartesian', cartesian

    def _get_symbols_with_deuterium(self):
        """Return the chemical symbol of each site, keeping 'D' as is.

        The symbols are extracted from ``_atom_site_type_symbol`` or,
        if that is absent, ``_atom_site_label``.  Raises NoStructureData
        if neither tag exists, or if a symbol cannot be determined."""
        labels = self._get_any(['_atom_site_type_symbol',
                                '_atom_site_label'])
        if labels is None:
            raise NoStructureData('No symbols')

        symbols = []
        for label in labels:
            if label == '.' or label == '?':
                raise NoStructureData('Symbols are undetermined')
            # Strip off additional labeling on chemical symbols.
            # Labels that were converted to numbers are searched as text.
            match = re.search(r'([A-Z][a-z]?)', str(label))
            if match is None:
                # Treat an unparsable label like an undetermined symbol
                # rather than failing on ``None.group``.
                raise NoStructureData(
                    f'No chemical symbol found in label "{label}"')
            symbols.append(match.group(0))
        return symbols

    def get_symbols(self) -> list[str]:
        """Return the chemical symbols with deuterium mapped to 'H'."""
        symbols = self._get_symbols_with_deuterium()
        return [symbol if symbol != 'D' else 'H' for symbol in symbols]

    def _where_deuterium(self):
        """Return a boolean array that is True where the symbol is 'D'."""
        return np.array([symbol == 'D' for symbol
                         in self._get_symbols_with_deuterium()], bool)

    def _get_masses(self) -> np.ndarray | None:
        """Return masses with deuterium sites set to 2.01355, or None
        if there is no deuterium."""
        mask = self._where_deuterium()
        if not any(mask):
            return None

        symbols = self.get_symbols()
        masses = Atoms(symbols).get_masses()
        masses[mask] = 2.01355
        return masses

    def _get_any(self, names):
        """Return the value of the first tag in *names* that is present,
        or None if none is."""
        for name in names:
            if name in self:
                return self[name]
        return None

    def _get_spacegroup_number(self):
        # Symmetry specification, see
        # http://www.iucr.org/resources/cif/dictionaries/cif_sym for a
        # complete list of official keys.  In addition we also try to
        # support some commonly used deprecated notations
        return self._get_any(['_space_group.it_number',
                              '_space_group_it_number',
                              '_symmetry_int_tables_number'])

    def _get_spacegroup_name(self):
        """Return the Hermann-Mauguin symbol (or None), with names in
        ``old_spacegroup_names`` replaced."""
        hm_symbol = self._get_any(['_space_group_name_h-m_alt',
                                   '_symmetry_space_group_name_h-m',
                                   '_space_group.Patterson_name_h-m',
                                   '_space_group.patterson_name_h-m'])

        hm_symbol = old_spacegroup_names.get(hm_symbol, hm_symbol)
        return hm_symbol

    def _get_sitesym(self):
        """Return the symmetry operations as a list (or None); a single
        string value is wrapped in a list."""
        sitesym = self._get_any(['_space_group_symop_operation_xyz',
                                 '_space_group_symop.operation_xyz',
                                 '_symmetry_equiv_pos_as_xyz'])
        if isinstance(sitesym, str):
            sitesym = [sitesym]
        return sitesym

    def _get_fractional_occupancies(self):
        """Return the ``_atom_site_occupancy`` value, or None."""
        return self.get('_atom_site_occupancy')

    def _get_setting(self) -> int | None:
        """Return ``_symmetry_space_group_setting`` as an int, or None
        if the tag is absent.  Raises ValueError unless it is 1 or 2."""
        setting_str = self.get('_symmetry_space_group_setting')
        if setting_str is None:
            return None

        setting = int(setting_str)
        if setting not in [1, 2]:
            raise ValueError(
                f'Spacegroup setting must be 1 or 2, not {setting}')
        return setting

    def get_spacegroup(self, subtrans_included) -> Spacegroup:
        """Return the space group described by the tags of this block.

        If symmetry operations are listed, the space group is built
        from them; otherwise it is looked up by number, then by name,
        and finally defaults to space group 1.  The setting comes from
        ``_symmetry_space_group_setting`` if present, else from the
        crystal-system tags for the rhombohedral space groups."""
        # XXX The logic in this method needs serious cleaning up!
        no = self._get_spacegroup_number()
        if isinstance(no, str):
            # If the value was specified as "key 'value'" with ticks,
            # then "integer values" become strings and we'll have to
            # manually convert it:
            no = int(no)

        hm_symbol = self._get_spacegroup_name()
        sitesym = self._get_sitesym()

        if sitesym:
            # Special cases: sitesym can be None or an empty list.
            # The empty list could be replaced with just the identity
            # function, but it seems more correct to try to get the
            # spacegroup number and derive the symmetries for that.
            subtrans = [(0.0, 0.0, 0.0)] if subtrans_included else None

            spacegroup = spacegroup_from_data(
                no=no, symbol=hm_symbol, sitesym=sitesym,
                subtrans=subtrans,
                setting=1)  # should the setting be passed from somewhere?
        elif no is not None:
            spacegroup = no
        elif hm_symbol is not None:
            spacegroup = hm_symbol
        else:
            spacegroup = 1

        setting_std = self._get_setting()

        setting = 1
        setting_name = None
        if '_symmetry_space_group_setting' in self:
            assert setting_std is not None
            setting = setting_std
        elif '_space_group_crystal_system' in self:
            setting_name = self['_space_group_crystal_system']
        elif '_symmetry_cell_setting' in self:
            setting_name = self['_symmetry_cell_setting']

        if setting_name:
            no = Spacegroup(spacegroup).no
            if no in rhombohedral_spacegroups:
                if setting_name == 'hexagonal':
                    setting = 1
                elif setting_name in ('trigonal', 'rhombohedral'):
                    setting = 2
                else:
                    warnings.warn(
                        f'unexpected crystal system {setting_name!r} '
                        f'for space group {spacegroup!r}')
            # FIXME - check for more crystal systems...
            else:
                warnings.warn(
                    f'crystal system {setting_name!r} is not '
                    f'interpreted for space group {spacegroup!r}. '
                    'This may result in wrong setting!')

        spg = Spacegroup(spacegroup, setting)
        if no is not None:
            assert int(spg) == no, (int(spg), no)
        return spg

    def get_unsymmetrized_structure(self) -> Atoms:
        """Return Atoms without symmetrizing coordinates.

        This returns a (normally) unphysical Atoms object
        corresponding only to those coordinates included
        in the CIF file, useful for e.g. debugging.

        This method may change behaviour in the future."""
        symbols = self.get_symbols()
        coordtype, coords = self._get_site_coordinates()

        atoms = Atoms(symbols=symbols,
                      cell=self.get_cell(),
                      masses=self._get_masses())

        if coordtype == 'scaled':
            atoms.set_scaled_positions(coords)
        else:
            assert coordtype == 'cartesian'
            atoms.positions[:] = coords

        return atoms

    def has_structure(self):
        """Whether this CIF block has an atomic configuration."""
        try:
            self.get_symbols()
            self._get_site_coordinates()
        except NoStructureData:
            return False
        else:
            return True

    def get_atoms(self, store_tags=False, primitive_cell=False,
                  subtrans_included=True, fractional_occupancies=True) -> Atoms:
        """Returns an Atoms object from a cif tags dictionary.  See read_cif()
        for a description of the arguments."""
        if primitive_cell and subtrans_included:
            raise RuntimeError(
                'Primitive cell cannot be determined when sublattice '
                'translations are included in the symmetry operations listed '
                'in the CIF file, i.e. when `subtrans_included` is True.')

        cell = self.get_cell()
        assert cell.rank in [0, 3]

        kwargs: dict[str, Any] = {}
        if store_tags:
            kwargs['info'] = self._tags.copy()

        if fractional_occupancies:
            occupancies = self._get_fractional_occupancies()
        else:
            occupancies = None

        if occupancies is not None:
            # no warnings in this case
            kwargs['onduplicates'] = 'keep'

        # The unsymmetrized_structure is not the asymmetric unit
        # because the asymmetric unit should have (in general) a smaller cell,
        # whereas we have the full cell.
        unsymmetrized_structure = self.get_unsymmetrized_structure()

        if cell.rank == 3:
            spacegroup = self.get_spacegroup(subtrans_included)
            atoms = crystal(unsymmetrized_structure,
                            spacegroup=spacegroup,
                            setting=spacegroup.setting,
                            occupancies=occupancies,
                            primitive_cell=primitive_cell,
                            **kwargs)
        else:
            # Without a cell there is nothing to symmetrize, so the
            # tags and occupancies are attached here directly.
            atoms = unsymmetrized_structure
            if kwargs.get('info') is not None:
                atoms.info.update(kwargs['info'])
            if occupancies is not None:
                occ_dict = {
                    str(i): {sym: occupancies[i]}
                    for i, sym in enumerate(atoms.symbols)
                }
                atoms.info['occupancy'] = occ_dict

        return atoms

505 

506 

def parse_block(lines: list[str], line: str) -> CIFBlock:
    """Parse one data block whose header line is *line*.

    The block name is the text after the first underscore of *line*,
    with trailing whitespace removed.  The items are read from *lines*.
    """
    assert line.lower().startswith('data_')
    header_parts = line.split('_', 1)
    blockname = header_parts[1].rstrip()
    return CIFBlock(blockname, parse_items(lines, line))

512 

513 

def parse_cif(fileobj, reader='ase') -> Iterator[CIFBlock]:
    """Parse *fileobj* with the selected reader ('ase' or 'pycodcif').

    Raises ValueError for any other *reader*.
    """
    if reader == 'ase':
        return parse_cif_ase(fileobj)
    if reader == 'pycodcif':
        return parse_cif_pycodcif(fileobj)
    raise ValueError(f'No such reader: {reader}')

521 

522 

def parse_cif_ase(fileobj) -> Iterator[CIFBlock]:
    """Parse a CIF file using ase CIF parser.

    *fileobj* is a file name (str) or an object with a ``read()``
    method.  Bytes are decoded as latin-1.  One CIFBlock is yielded
    per data block.
    """
    if isinstance(fileobj, str):
        with open(fileobj, 'rb') as fd:
            data = fd.read()
    else:
        data = fileobj.read()

    if isinstance(data, bytes):
        data = data.decode('latin1')
    data = format_unicode(data)

    nonempty = [row for row in data.split('\n') if len(row) > 0]
    if nonempty and nonempty[0].rstrip() == '#\\#CIF_2.0':
        warnings.warn('CIF v2.0 file format detected; `ase` CIF reader might '
                      'incorrectly interpret some syntax constructions, use '
                      '`pycodcif` reader instead')

    # The parser pops from the end, so the lines are stored reversed.
    stack = [''] + nonempty[::-1]

    while stack:
        header = stack.pop().strip()
        if header and not header.startswith('#'):
            yield parse_block(stack, header)

548 

549 

def parse_cif_pycodcif(fileobj) -> Iterator[CIFBlock]:
    """Parse a CIF file using pycodcif CIF parser.

    *fileobj* is a file name, or an object whose ``name`` attribute is
    the file name.  Raises ImportError if pycodcif is not available.
    """
    filename = fileobj if isinstance(fileobj, str) else fileobj.name

    try:
        from pycodcif import parse
    except ImportError:
        raise ImportError(
            'parse_cif_pycodcif requires pycodcif ' +
            '(http://wiki.crystallography.net/cod-tools/pycodcif/)')

    data, _, _ = parse(filename)

    for datablock in data:
        tags = datablock['values']
        # Values are replaced in place; a one-element list becomes a scalar.
        for tag in tags.keys():
            converted = [convert_value(raw) for raw in tags[tag]]
            tags[tag] = converted[0] if len(converted) == 1 else converted
        yield CIFBlock(datablock['name'], tags)

573 

574 

def iread_cif(
    fileobj,
    index=-1,
    store_tags: bool = False,
    primitive_cell: bool = False,
    subtrans_included: bool = True,
    fractional_occupancies: bool = True,
    reader: str = 'ase',
) -> Iterator[Atoms]:
    """Yield Atoms objects for the CIF blocks selected by *index*.

    Only blocks that have a structure are considered.  See read_cif()
    for a description of the other arguments.
    """
    # Find all CIF blocks with valid crystal data
    # TODO: return Atoms of the block name ``index`` if it is a string.
    images = [
        block.get_atoms(
            store_tags, primitive_cell,
            subtrans_included,
            fractional_occupancies=fractional_occupancies)
        for block in parse_cif(fileobj, reader)
        if block.has_structure()
    ]

    if index is None or index == ':':
        index = slice(None, None, None)

    if not isinstance(index, (slice, str)):
        # ``or None`` keeps index -1 from becoming the empty slice(-1, 0).
        index = slice(index, (index + 1) or None)

    yield from images[index]

605 

606 

def read_cif(
    fileobj,
    index=-1,
    *,
    store_tags: bool = False,
    primitive_cell: bool = False,
    subtrans_included: bool = True,
    fractional_occupancies: bool = True,
    reader: str = 'ase',
) -> Atoms | list[Atoms]:
    """Read Atoms object from CIF file.

    Parameters
    ----------
    store_tags : bool
        If true, the *info* attribute of the returned Atoms object will be
        populated with all tags in the corresponding cif data block.
    primitive_cell : bool
        If true, the primitive cell is built instead of the conventional cell.
    subtrans_included : bool
        If true, sublattice translations are assumed to be included among the
        symmetry operations listed in the CIF file (seems to be the common
        behaviour of CIF files).
        Otherwise the sublattice translations are determined from setting 1 of
        the extracted space group.  A result of setting this flag to true, is
        that it will not be possible to determine the primitive cell.
    fractional_occupancies : bool
        If true, the resulting atoms object will be equipped with a
        dictionary `occupancy`.  The keys of this dictionary will be integers
        converted to strings.  The conversion to string is done in order to
        avoid troubles with JSON encoding/decoding of the dictionaries with
        non-string keys.
        Also, in case of mixed occupancies, the atom's chemical symbol will be
        that of the most dominant species.
    reader : str
        Select CIF reader.

        * ``ase`` : built-in CIF reader (default)
        * ``pycodcif`` : CIF reader based on ``pycodcif`` package

    Returns
    -------
    A list of Atoms objects if *index* is a slice or a string, otherwise
    a single Atoms object.

    Notes
    -----
    Only blocks with valid crystal data will be included.
    """
    images = iread_cif(
        fileobj,
        index,
        store_tags,
        primitive_cell,
        subtrans_included,
        fractional_occupancies,
        reader,
    )
    wants_list = isinstance(index, (slice, str))
    return list(images) if wants_list else next(images)

666 

667 

def format_cell(cell: Cell) -> str:
    """Return the six cell-parameter lines of a CIF block.

    Each line pairs a tag from ``CIFBlock.cell_tags`` with the matching
    value of ``cell.cellpar()``.  The cell must have rank 3.
    """
    assert cell.rank == 3
    rows = [f'{name:20} {value}\n'
            for name, value in zip(CIFBlock.cell_tags, cell.cellpar())]
    assert len(rows) == 6
    return ''.join(rows)

676 

677 

def format_generic_spacegroup_info() -> str:
    """Return the CIF text for space group P 1 with the identity as the
    only symmetry operation."""
    # We assume no symmetry whatsoever
    entries = (
        '_space_group_name_H-M_alt "P 1"',
        '_space_group_IT_number 1',
        '',
        'loop_',
        ' _space_group_symop_operation_xyz',
        " 'x, y, z'",
        '',
    )
    return '\n'.join(entries)

689 

690 

class CIFLoop:
    """Collects equally long columns and formats them as a CIF loop."""

    def __init__(self):
        # Parallel lists with one entry per column.
        self.names = []
        self.formats = []
        self.arrays = []

    def add(self, name, array, fmt):
        """Add a column.

        *name* is the tag (it must start with an underscore), *array*
        holds the values and *fmt* is the ``str.format`` field used for
        each value.

        Raises ValueError if *array* differs in length from the first
        column; in that case the loop is left unchanged.
        """
        assert name.startswith('_')
        nrows = len(array)
        # Validate before storing so a rejected column cannot leave the
        # three lists inconsistent.
        if self.arrays and nrows != len(self.arrays[0]):
            raise ValueError(f'Loop data "{name}" has {nrows} '
                             f'elements, expected {len(self.arrays[0])}')
        self.names.append(name)
        self.formats.append(fmt)
        self.arrays.append(array)

    def tostring(self):
        """Return the loop as text: ``loop_``, one line per tag, one
        line per row, and a final newline."""
        lines = []
        append = lines.append
        append('loop_')
        for name in self.names:
            append(f' {name}')

        template = ' ' + ' '.join(self.formats)

        ncolumns = len(self.arrays)
        nrows = len(self.arrays[0]) if ncolumns > 0 else 0
        for row in range(nrows):
            arraydata = [array[row] for array in self.arrays]
            line = template.format(*arraydata)
            append(line)
        append('')
        return '\n'.join(lines)

723 

724 

@iofunction('wb')
def write_cif(fd, images, cif_format=None,
              wrap=True, labels=None, loop_keys=None) -> None:
    r"""Write *images* to CIF file.

    images: Atoms or sequence of Atoms
        A single object with a ``get_positions`` attribute is treated
        as a one-image sequence.  Each image is written as a block
        named ``data_image<i>``.

    cif_format:
        Deprecated; a FutureWarning is issued if it is given.

    wrap: bool
        Wrap atoms into unit cell.

    labels: list
        Use this list (shaped list[i_frame][i_atom] = string) for the
        '_atom_site_label' section instead of automatically generating
        it from the element symbol.

    loop_keys: dict
        Add the information from this dictionary to the `loop\_`
        section.  Keys are printed to the `loop\_` section preceded by
        ' _'.  dict[key] should contain the data printed for each atom,
        so it needs to have the setup `dict[key][i_frame][i_atom] =
        string`.  The strings are printed as they are, so take care of
        formatting.  Information can be re-read using the `store_tags`
        option of the cif reader.

    """
    if cif_format is not None:
        warnings.warn('The cif_format argument is deprecated and may be '
                      'removed in the future. Use loop_keys to customize '
                      'data written in loop.', FutureWarning)

    per_key = {} if loop_keys is None else loop_keys
    frames = [images] if hasattr(images, 'get_positions') else images

    text_fd = io.TextIOWrapper(fd, encoding='latin-1')
    try:
        for i, atoms in enumerate(frames):
            frame_labels = None if labels is None else labels[i]
            frame_loop_keys = {key: per_key[key][i] for key in per_key}
            write_cif_image(f'data_image{i}\n', atoms, text_fd,
                            wrap=wrap,
                            labels=frame_labels,
                            loop_keys=frame_loop_keys)
    finally:
        # Using the TextIOWrapper somehow causes the file to close
        # when this function returns.
        # Detach in order to circumvent this highly illogical problem:
        text_fd.detach()

776 

777 

def autolabel(symbols: Sequence[str]) -> list[str]:
    """Return a label for each symbol: the symbol followed by its
    running count, e.g. ``['Na1', 'Cl1', 'Na2']``."""
    seen: dict[str, int] = {}
    labels = []
    for symbol in symbols:
        count = seen.get(symbol, 0) + 1
        seen[symbol] = count
        labels.append(f'{symbol!s}{count:d}')
    return labels

788 

789 

def chemical_formula_header(atoms):
    """Return the ``_chemical_formula_structural`` and
    ``_chemical_formula_sum`` lines for *atoms*."""
    composition = atoms.symbols.formula.count()
    parts = [f'{element}{number}' for element, number in composition.items()]
    formula_sum = ' '.join(parts)
    structural_line = f'_chemical_formula_structural {atoms.symbols}\n'
    sum_line = f'_chemical_formula_sum "{formula_sum}"\n'
    return structural_line + sum_line

796 

797 

class BadOccupancies(ValueError):
    """Raised when occupancy info has no entry for an atom's symbol."""

800 

801 

def expand_kinds(atoms, coords):
    """Return symbols, coordinates and occupancies with one entry per
    species on each site.

    If ``atoms.info['occupancy']`` and ``atoms.arrays['spacegroup_kinds']``
    are both present, each atom takes the occupancy stored for its own
    symbol under its kind.  Every other species of that kind is appended
    as an extra entry with the same coordinates.  Otherwise all
    occupancies are 1.

    Returns a tuple ``(symbols, coords, occupancies)`` of new lists.

    Raises BadOccupancies if the occupancy info of an atom's kind does
    not contain the atom's symbol.
    """
    # try to fetch occupancies // spacegroup_kinds - occupancy mapping
    symbols = list(atoms.symbols)
    coords = list(coords)
    occupancies = [1] * len(symbols)
    occ_info = atoms.info.get('occupancy')
    kinds = atoms.arrays.get('spacegroup_kinds')
    if occ_info is not None and kinds is not None:
        for i, kind in enumerate(kinds):
            occ_info_kind = occ_info[str(kind)]
            symbol = symbols[i]
            if symbol not in occ_info_kind:
                raise BadOccupancies('Occupancies present but no occupancy '
                                     f'info for "{symbol}"')
            occupancies[i] = occ_info_kind[symbol]
            # extend the positions array in case of mixed occupancy
            for sym, occ in occ_info_kind.items():
                if sym != symbol:
                    symbols.append(sym)
                    coords.append(coords[i])
                    occupancies.append(occ)
    return symbols, coords, occupancies

824 

825 

def atoms_to_loop_data(atoms, wrap, labels, loop_keys):
    """Collect the atom-site columns of *atoms* for a CIF loop.

    Fractional coordinates are used if the cell has rank 3, Cartesian
    ones otherwise.  If *labels* is None, labels are generated from the
    symbols.  Bad occupancy info is reported as a warning and all
    occupancies are then set to 1.

    Returns ``(loopdata, coord_headers)``: *loopdata* maps each tag to
    an ``(array, format)`` pair and *coord_headers* lists the three
    coordinate tags.
    """
    if atoms.cell.rank == 3:
        coord_type, coords = 'fract', atoms.get_scaled_positions(wrap).tolist()
    else:
        coord_type, coords = 'Cartn', atoms.get_positions(wrap).tolist()

    try:
        symbols, coords, occupancies = expand_kinds(atoms, coords)
    except BadOccupancies as err:
        warnings.warn(str(err))
        occupancies = [1] * len(atoms)
        symbols = list(atoms.symbols)

    if labels is None:
        labels = autolabel(symbols)

    coord_headers = [f'_atom_site_{coord_type}_{axisname}'
                     for axisname in 'xyz']

    loopdata = {
        '_atom_site_label': (labels, '{:<8s}'),
        '_atom_site_occupancy': (occupancies, '{:6.4f}'),
    }

    coord_array = np.array(coords)
    for column, header in enumerate(coord_headers):
        loopdata[header] = (coord_array[:, column], '{}')

    nsites = len(symbols)
    loopdata['_atom_site_type_symbol'] = (symbols, '{:<2s}')
    loopdata['_atom_site_symmetry_multiplicity'] = ([1.0] * nsites, '{}')

    for key in loop_keys:
        # Should expand the loop_keys like we expand the occupancy stuff.
        # Otherwise user will never figure out how to do this.
        per_site = loop_keys[key]
        loopdata['_' + key] = ([per_site[site] for site in range(nsites)],
                               '{}')

    return loopdata, coord_headers

866 

867 

def write_cif_image(blockname, atoms, fd, *, wrap,
                    labels, loop_keys):
    """Write one CIF data block for *atoms* to the text stream *fd*.

    The block consists of *blockname*, the chemical formula lines, the
    cell and a generic space group section (only if the cell has rank
    3), and the atom-site loop.

    Raises ValueError if the cell rank is neither 0 nor 3.
    """
    fd.write(blockname)
    fd.write(chemical_formula_header(atoms))

    rank = atoms.cell.rank
    if rank not in (0, 3):
        raise ValueError('CIF format can only represent systems with '
                         f'0 or 3 lattice vectors. Got {rank}.')
    if rank == 3:
        fd.write(format_cell(atoms.cell))
        fd.write('\n')
        fd.write(format_generic_spacegroup_info())
        fd.write('\n')

    loopdata, coord_headers = atoms_to_loop_data(atoms, wrap, labels,
                                                 loop_keys)

    # Column order of the atom-site loop; user keys come last.
    headers = ['_atom_site_type_symbol',
               '_atom_site_label',
               '_atom_site_symmetry_multiplicity']
    headers.extend(coord_headers)
    headers.append('_atom_site_occupancy')
    headers.extend('_' + key for key in loop_keys)

    loop = CIFLoop()
    for header in headers:
        column, fmt = loopdata[header]
        loop.add(header, column, fmt)

    fd.write(loop.tostring())