Coverage for /builds/ase/ase/ase/io/cif.py: 90.84%

491 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3"""Module to read and write atoms in cif file format. 

4 

5See http://www.iucr.org/resources/cif/spec/version1.1/cifsyntax for a 

6description of the file format. STAR extensions as save frames, 

7global blocks, nested loops and multi-data values are not supported. 

8The "latin-1" encoding is required by the IUCR specification. 

9""" 

10 

11import collections.abc 

12import io 

13import re 

14import shlex 

15import warnings 

16from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union 

17 

18import numpy as np 

19 

20from ase import Atoms 

21from ase.cell import Cell 

22from ase.io.cif_unicode import format_unicode, handle_subscripts 

23from ase.spacegroup import crystal 

24from ase.spacegroup.spacegroup import Spacegroup, spacegroup_from_data 

25from ase.utils import iofunction 

26 

# Space group numbers for which get_spacegroup() distinguishes between a
# hexagonal (setting 1) and a trigonal/rhombohedral (setting 2) description.
rhombohedral_spacegroups = {146, 148, 155, 160, 161, 166, 167}


# Hermann-Mauguin symbols that are replaced by another spelling before the
# space group is looked up.
old_spacegroup_names = {
    'Abm2': 'Aem2',
    'Aba2': 'Aea2',
    'Cmca': 'Cmce',
    'Cmma': 'Cmme',
    'Ccca': 'Ccc1',
}

# CIF maps names to either single values or to multiple values via loops.
CIFDataValue = Union[str, int, float]
CIFData = Union[CIFDataValue, List[CIFDataValue]]

39 

40 

def convert_value(value: str) -> CIFDataValue:
    """Convert the text of a single CIF value to a Python object.

    Surrounding whitespace is ignored.  Quoted text loses its quotes,
    integers become ``int``, decimal numbers become ``float`` (a trailing
    standard uncertainty in parentheses is discarded), and everything else
    is returned as a string.  String results are passed through
    ``handle_subscripts``.
    """
    text = value.strip()

    if re.match('(".*")|(\'.*\')$', text):
        # Quoted string: remove the first and last character.
        return handle_subscripts(text[1:-1])

    if re.match(r'[+-]?\d+$', text):
        return int(text)

    if re.match(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?$', text):
        return float(text)

    if re.match(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?\(\d+\)$',
                text):
        # Number followed by "(uncertainty)": keep only the number.
        return float(text[:text.index('(')])

    if re.match(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?\(\d+$',
                text):
        # Same as above, but the closing parenthesis is missing.
        warnings.warn(f'Badly formed number: "{text}"')
        return float(text[:text.index('(')])

    return handle_subscripts(text)

59 

60 

def parse_multiline_string(lines: List[str], line: str) -> str:
    """Read a semicolon-delimited text field and return its content.

    ``line`` is the opening line (it must start with ';'); the following
    lines are popped from the end of ``lines`` until one starting with ';'
    is consumed.  Each line is stripped and the lines are joined with
    newlines.
    """
    assert line[0] == ';'
    collected = [line[1:].lstrip()]

    current = lines.pop().strip()
    while current[:1] != ';':
        collected.append(current)
        current = lines.pop().strip()

    return '\n'.join(collected).strip()

71 

72 

def parse_singletag(lines: List[str], line: str) -> Tuple[str, CIFDataValue]:
    """Parse one CIF tag (an entry starting with an underscore).

    When the value is not on the same line as the tag, it is taken from
    the next line that is neither blank nor a comment; that line may open
    a semicolon-delimited text field.  Returns the ``(key, value)`` pair
    with the value converted by ``convert_value``.
    """
    parts = line.split(None, 1)
    if len(parts) != 1:
        # Tag and value share the line.
        key, raw = parts
        return key, convert_value(raw)

    key = line
    following = lines.pop().strip()
    while not following or following[0] == '#':
        following = lines.pop().strip()

    if following[0] == ';':
        raw = parse_multiline_string(lines, following)
    else:
        raw = following
    return key, convert_value(raw)

89 

90 

def parse_cif_loop_headers(lines: List[str]) -> Iterator[str]:
    """Yield the lower-cased column tags at the start of a CIF loop.

    Lines are popped from the end of ``lines`` for as long as each one
    consists of a single token starting with '_'.  The first line that
    does not is pushed back unchanged and iteration stops.
    """
    while lines:
        candidate = lines.pop()
        words = candidate.split()

        is_header = len(words) == 1 and words[0].startswith('_')
        if not is_header:
            lines.append(candidate)  # 'undo' pop
            return

        yield words[0].lower()

102 

103 

def parse_cif_loop_data(lines: List[str],
                        ncolumns: int) -> List[List[CIFDataValue]]:
    """Read the data rows of a CIF loop and return them column by column.

    Lines are popped from the end of ``lines`` until a blank line, a tag,
    or a line starting with ``data_`` / ``loop_`` is met; that line is
    pushed back (stripped).  Tokens are accumulated across lines until a
    full row of ``ncolumns`` values is available.  A row with too many
    tokens is dropped with a warning.

    Raises RuntimeError if the loop ends in the middle of a row.
    """
    columns: List[List[CIFDataValue]] = [[] for _ in range(ncolumns)]

    pending: List[str] = []
    while lines:
        line = lines.pop().strip()
        if (not line
                or line.startswith('_')
                or line.lower().startswith(('data_', 'loop_'))):
            lines.append(line)
            break

        if line.startswith('#'):
            continue

        # Drop a trailing comment.
        line = line.split(' #')[0]

        if line.startswith(';'):
            pending.append(parse_multiline_string(lines, line))
        elif ncolumns == 1:
            # With a single column the whole line is the value.
            pending.append(line)
        else:
            pending.extend(shlex.split(line, posix=False))

        if len(pending) < ncolumns:
            # Row continues on the next line.
            continue

        if len(pending) == ncolumns:
            for column, token in zip(columns, pending):
                column.append(convert_value(token))
        else:
            warnings.warn(f'Wrong number {len(pending)} of tokens, '
                          f'expected {ncolumns}: {pending}')

        # Start collecting the next row.
        pending = []

    if pending:
        assert len(pending) < ncolumns
        raise RuntimeError('CIF loop ended unexpectedly with incomplete row: '
                           f'{pending}, expected {ncolumns} tokens')

    return columns

151 

152 

def parse_loop(lines: List[str]) -> Dict[str, List[CIFDataValue]]:
    """Parse a CIF loop.

    Returns a dict mapping each column tag to the list of values in that
    column.  If a tag is repeated, a warning is issued and only the first
    occurrence is kept.
    """
    headers = list(parse_cif_loop_headers(lines))
    columns = parse_cif_loop_data(lines, len(headers))

    by_tag = {}
    for header, column in zip(headers, columns):
        if header in by_tag:
            warnings.warn(f'Duplicated loop tags: {header}')
            continue
        by_tag[header] = column
    return by_tag

169 

170 

def parse_items(lines: List[str], line: str) -> Dict[str, CIFData]:
    """Parse the data items of one CIF block and return all tags.

    Lines are popped from the end of ``lines`` until it is exhausted or a
    new ``data_`` header is found (which is pushed back).  Tag names are
    lower-cased.  A free-standing semicolon text field is consumed and
    discarded.  The incoming ``line`` argument is not inspected.

    Raises ValueError for a line that is none of the known entry kinds.
    """
    tags: Dict[str, CIFData] = {}

    while lines:
        line = lines.pop().strip()
        if not line or line.startswith('#'):
            continue

        lowered = line.lower()
        if line.startswith('_'):
            key, value = parse_singletag(lines, line)
            tags[key.lower()] = value
        elif lowered.startswith('loop_'):
            tags.update(parse_loop(lines))
        elif lowered.startswith('data_'):
            # Start of the next block: leave it for the caller.
            lines.append(line)
            break
        elif line.startswith(';'):
            parse_multiline_string(lines, line)
        else:
            raise ValueError(f'Unexpected CIF file entry: "{line}"')
    return tags

198 

199 

class NoStructureData(RuntimeError):
    """Raised when a CIF block lacks the data needed to build a structure."""


class CIFBlock(collections.abc.Mapping):
    """A block (i.e., a single system) in a crystallographic information file.

    Use this object to query CIF tags or import information as ASE objects.
    The block acts as a read-only mapping from tag names to tag values."""

    cell_tags = ['_cell_length_a', '_cell_length_b', '_cell_length_c',
                 '_cell_angle_alpha', '_cell_angle_beta', '_cell_angle_gamma']

    def __init__(self, name: str, tags: Dict[str, CIFData]):
        self.name = name
        self._tags = tags

    def __repr__(self) -> str:
        tags = set(self._tags)
        return f'CIFBlock({self.name}, tags={tags})'

    def __getitem__(self, key: str) -> CIFData:
        return self._tags[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._tags)

    def __len__(self) -> int:
        return len(self._tags)

    def get(self, key, default=None):
        """Return the value of tag *key*, or *default* if it is absent."""
        return self._tags.get(key, default)

    def get_cellpar(self) -> Optional[List]:
        """Return the six cell parameters, or None if any tag is missing."""
        try:
            return [self[tag] for tag in self.cell_tags]
        except KeyError:
            return None

    def get_cell(self) -> Cell:
        """Return the cell; an all-zero cell if cell parameters are missing."""
        cellpar = self.get_cellpar()
        if cellpar is None:
            return Cell.new([0, 0, 0])
        return Cell.new(cellpar)

    def _raw_scaled_positions(self) -> Optional[np.ndarray]:
        """Return fractional coordinates as rows of (x, y, z), or None."""
        coords = [self.get(name) for name in ['_atom_site_fract_x',
                                              '_atom_site_fract_y',
                                              '_atom_site_fract_z']]
        # XXX Shall we try to handle mixed coordinates?
        # (Some scaled vs others fractional)
        if None in coords:
            return None
        return np.array(coords).T

    def _raw_positions(self) -> Optional[np.ndarray]:
        """Return Cartesian coordinates as rows of (x, y, z), or None."""
        coords = [self.get('_atom_site_cartn_x'),
                  self.get('_atom_site_cartn_y'),
                  self.get('_atom_site_cartn_z')]
        if None in coords:
            return None
        return np.array(coords).T

    def _get_site_coordinates(self):
        """Return ``('scaled', coords)`` or ``('cartesian', coords)``.

        Fractional coordinates are preferred.  Raises NoStructureData if
        the block contains neither kind."""
        scaled = self._raw_scaled_positions()

        if scaled is not None:
            return 'scaled', scaled

        cartesian = self._raw_positions()

        if cartesian is None:
            raise NoStructureData('No positions found in structure')

        return 'cartesian', cartesian

    def _get_symbols_with_deuterium(self):
        """Return the chemical symbols of all sites, keeping 'D' as is.

        Symbols are read from '_atom_site_type_symbol' or, failing that,
        '_atom_site_label'.  Raises NoStructureData if neither tag exists,
        if a symbol is '.' or '?', or if no symbol can be extracted from
        a label."""
        labels = self._get_any(['_atom_site_type_symbol',
                                '_atom_site_label'])
        if labels is None:
            raise NoStructureData('No symbols')

        symbols = []
        for label in labels:
            if label == '.' or label == '?':
                raise NoStructureData('Symbols are undetermined')
            # Strip off additional labeling on chemical symbols.
            # str() is needed because a purely numeric label is turned
            # into a number by convert_value().
            match = re.search(r'([A-Z][a-z]?)', str(label))
            if match is None:
                # Without this check the code would fail with an
                # unrelated AttributeError on ``None.group``.
                raise NoStructureData(
                    f'Cannot determine chemical symbol from label {label!r}')
            symbols.append(match.group(0))
        return symbols

    def get_symbols(self) -> List[str]:
        """Return the chemical symbols with deuterium mapped to 'H'."""
        symbols = self._get_symbols_with_deuterium()
        return [symbol if symbol != 'D' else 'H' for symbol in symbols]

    def _where_deuterium(self):
        """Return a boolean array marking the sites labelled 'D'."""
        return np.array([symbol == 'D' for symbol
                         in self._get_symbols_with_deuterium()], bool)

    def _get_masses(self) -> Optional[np.ndarray]:
        """Return masses with deuterium sites set to 2.01355.

        Returns None when the block contains no deuterium."""
        mask = self._where_deuterium()
        if not any(mask):
            return None

        symbols = self.get_symbols()
        masses = Atoms(symbols).get_masses()
        masses[mask] = 2.01355
        return masses

    def _get_any(self, names):
        """Return the value of the first tag in *names* that is present."""
        for name in names:
            if name in self:
                return self[name]
        return None

    def _get_spacegroup_number(self):
        # Symmetry specification, see
        # http://www.iucr.org/resources/cif/dictionaries/cif_sym for a
        # complete list of official keys.  In addition we also try to
        # support some commonly used deprecated notations
        return self._get_any(['_space_group.it_number',
                              '_space_group_it_number',
                              '_symmetry_int_tables_number'])

    def _get_spacegroup_name(self):
        """Return the Hermann-Mauguin symbol with old names replaced."""
        hm_symbol = self._get_any(['_space_group_name_h-m_alt',
                                   '_symmetry_space_group_name_h-m',
                                   '_space_group.Patterson_name_h-m',
                                   '_space_group.patterson_name_h-m'])

        hm_symbol = old_spacegroup_names.get(hm_symbol, hm_symbol)
        return hm_symbol

    def _get_sitesym(self):
        """Return the symmetry operations as a list of strings, or None."""
        sitesym = self._get_any(['_space_group_symop_operation_xyz',
                                 '_space_group_symop.operation_xyz',
                                 '_symmetry_equiv_pos_as_xyz'])
        if isinstance(sitesym, str):
            # A single operation given outside of a loop.
            sitesym = [sitesym]
        return sitesym

    def _get_fractional_occupancies(self):
        return self.get('_atom_site_occupancy')

    def _get_setting(self) -> Optional[int]:
        """Return '_symmetry_space_group_setting' as int (1 or 2), or None.

        Raises ValueError for any other setting."""
        setting_str = self.get('_symmetry_space_group_setting')
        if setting_str is None:
            return None

        setting = int(setting_str)
        if setting not in [1, 2]:
            raise ValueError(
                f'Spacegroup setting must be 1 or 2, not {setting}')
        return setting

    def get_spacegroup(self, subtrans_included) -> Spacegroup:
        """Return the Spacegroup described by the symmetry tags."""
        # XXX The logic in this method needs serious cleaning up!
        no = self._get_spacegroup_number()
        if isinstance(no, str):
            # If the value was specified as "key 'value'" with ticks,
            # then "integer values" become strings and we'll have to
            # manually convert it:
            no = int(no)

        hm_symbol = self._get_spacegroup_name()
        sitesym = self._get_sitesym()

        if sitesym:
            # Special cases: sitesym can be None or an empty list.
            # The empty list could be replaced with just the identity
            # function, but it seems more correct to try to get the
            # spacegroup number and derive the symmetries for that.
            subtrans = [(0.0, 0.0, 0.0)] if subtrans_included else None

            spacegroup = spacegroup_from_data(
                no=no, symbol=hm_symbol, sitesym=sitesym,
                subtrans=subtrans,
                setting=1)  # should the setting be passed from somewhere?
        elif no is not None:
            spacegroup = no
        elif hm_symbol is not None:
            spacegroup = hm_symbol
        else:
            spacegroup = 1

        setting_std = self._get_setting()

        setting = 1
        setting_name = None
        if '_symmetry_space_group_setting' in self:
            assert setting_std is not None
            setting = setting_std
        elif '_space_group_crystal_system' in self:
            setting_name = self['_space_group_crystal_system']
        elif '_symmetry_cell_setting' in self:
            setting_name = self['_symmetry_cell_setting']

        if setting_name:
            no = Spacegroup(spacegroup).no
            if no in rhombohedral_spacegroups:
                if setting_name == 'hexagonal':
                    setting = 1
                elif setting_name in ('trigonal', 'rhombohedral'):
                    setting = 2
                else:
                    warnings.warn(
                        f'unexpected crystal system {setting_name!r} '
                        f'for space group {spacegroup!r}')
            # FIXME - check for more crystal systems...
            else:
                warnings.warn(
                    f'crystal system {setting_name!r} is not '
                    f'interpreted for space group {spacegroup!r}. '
                    'This may result in wrong setting!')

        spg = Spacegroup(spacegroup, setting)
        if no is not None:
            assert int(spg) == no, (int(spg), no)
        return spg

    def get_unsymmetrized_structure(self) -> Atoms:
        """Return Atoms without symmetrizing coordinates.

        This returns a (normally) unphysical Atoms object
        corresponding only to those coordinates included
        in the CIF file, useful for e.g. debugging.

        This method may change behaviour in the future."""
        symbols = self.get_symbols()
        coordtype, coords = self._get_site_coordinates()

        atoms = Atoms(symbols=symbols,
                      cell=self.get_cell(),
                      masses=self._get_masses())

        if coordtype == 'scaled':
            atoms.set_scaled_positions(coords)
        else:
            assert coordtype == 'cartesian'
            atoms.positions[:] = coords

        return atoms

    def has_structure(self):
        """Whether this CIF block has an atomic configuration."""
        try:
            self.get_symbols()
            self._get_site_coordinates()
        except NoStructureData:
            return False
        else:
            return True

    def get_atoms(self, store_tags=False, primitive_cell=False,
                  subtrans_included=True, fractional_occupancies=True) -> Atoms:
        """Returns an Atoms object from a cif tags dictionary.  See read_cif()
        for a description of the arguments."""
        if primitive_cell and subtrans_included:
            raise RuntimeError(
                'Primitive cell cannot be determined when sublattice '
                'translations are included in the symmetry operations listed '
                'in the CIF file, i.e. when `subtrans_included` is True.')

        cell = self.get_cell()
        assert cell.rank in [0, 3]

        kwargs: Dict[str, Any] = {}
        if store_tags:
            kwargs['info'] = self._tags.copy()

        if fractional_occupancies:
            occupancies = self._get_fractional_occupancies()
        else:
            occupancies = None

        if occupancies is not None:
            # no warnings in this case
            kwargs['onduplicates'] = 'keep'

        # The unsymmetrized_structure is not the asymmetric unit
        # because the asymmetric unit should have (in general) a smaller cell,
        # whereas we have the full cell.
        unsymmetrized_structure = self.get_unsymmetrized_structure()

        if cell.rank == 3:
            spacegroup = self.get_spacegroup(subtrans_included)
            atoms = crystal(unsymmetrized_structure,
                            spacegroup=spacegroup,
                            setting=spacegroup.setting,
                            occupancies=occupancies,
                            primitive_cell=primitive_cell,
                            **kwargs)
        else:
            # No cell: nothing to symmetrize, so attach the extra
            # information directly.
            atoms = unsymmetrized_structure
            if kwargs.get('info') is not None:
                atoms.info.update(kwargs['info'])
            if occupancies is not None:
                occ_dict = {
                    str(i): {sym: occupancies[i]}
                    for i, sym in enumerate(atoms.symbols)
                }
                atoms.info['occupancy'] = occ_dict

        return atoms

504 

505 

def parse_block(lines: List[str], line: str) -> CIFBlock:
    """Parse one data block whose ``data_<name>`` header is ``line``.

    The block name is the text after the first underscore with trailing
    whitespace removed; the block's items are read from ``lines``.
    """
    assert line.lower().startswith('data_')
    blockname = line.split('_', 1)[1].rstrip()
    return CIFBlock(blockname, parse_items(lines, line))

511 

512 

def parse_cif(fileobj, reader='ase') -> Iterator[CIFBlock]:
    """Parse ``fileobj`` with the selected reader ('ase' or 'pycodcif').

    Raises ValueError for any other reader name.
    """
    if reader == 'ase':
        return parse_cif_ase(fileobj)
    if reader == 'pycodcif':
        return parse_cif_pycodcif(fileobj)
    raise ValueError(f'No such reader: {reader}')

520 

521 

def parse_cif_ase(fileobj) -> Iterator[CIFBlock]:
    """Parse a CIF file using ase CIF parser.

    ``fileobj`` is a file name or an object with a ``read()`` method.
    Bytes are decoded as latin1.  Yields one CIFBlock per data block.
    """
    if isinstance(fileobj, str):
        with open(fileobj, 'rb') as fd:
            data = fd.read()
    else:
        data = fileobj.read()

    if isinstance(data, bytes):
        data = data.decode('latin1')
    data = format_unicode(data)

    nonempty = [row for row in data.split('\n') if len(row) > 0]
    if len(nonempty) > 0 and nonempty[0].rstrip() == '#\\#CIF_2.0':
        warnings.warn('CIF v2.0 file format detected; `ase` CIF reader might '
                      'incorrectly interpret some syntax constructions, use '
                      '`pycodcif` reader instead')

    # The parser pops lines from the end of the list, so store them in
    # reverse order (with one empty line at the very end of the input).
    stack = [''] + nonempty[::-1]

    while stack:
        header = stack.pop().strip()
        if not header or header.startswith('#'):
            continue

        yield parse_block(stack, header)

547 

548 

def parse_cif_pycodcif(fileobj) -> Iterator[CIFBlock]:
    """Parse a CIF file using pycodcif CIF parser.

    A non-string ``fileobj`` is replaced by its ``name`` attribute.
    Raises ImportError if pycodcif cannot be imported.
    """
    if not isinstance(fileobj, str):
        fileobj = fileobj.name

    try:
        from pycodcif import parse
    except ImportError:
        raise ImportError(
            'parse_cif_pycodcif requires pycodcif '
            '(http://wiki.crystallography.net/cod-tools/pycodcif/)')

    data, _, _ = parse(fileobj)

    for datablock in data:
        tags = datablock['values']
        for tag in tags.keys():
            converted = [convert_value(item) for item in tags[tag]]
            # A tag with exactly one value is stored as a scalar.
            tags[tag] = converted[0] if len(converted) == 1 else converted
        yield CIFBlock(datablock['name'], tags)

572 

573 

def iread_cif(
    fileobj,
    index=-1,
    store_tags: bool = False,
    primitive_cell: bool = False,
    subtrans_included: bool = True,
    fractional_occupancies: bool = True,
    reader: str = 'ase',
) -> Iterator[Atoms]:
    """Yield the Atoms objects selected by ``index`` from a CIF file.

    All blocks that contain a structure are converted first; ``index``
    (an int, a slice, None or ':') is then applied to that list.  See
    read_cif() for a description of the other arguments.
    """
    # Find all CIF blocks with valid crystal data
    # TODO: return Atoms of the block name ``index`` if it is a string.
    images = [
        block.get_atoms(
            store_tags, primitive_cell,
            subtrans_included,
            fractional_occupancies=fractional_occupancies)
        for block in parse_cif(fileobj, reader)
        if block.has_structure()
    ]

    if index is None or index == ':':
        index = slice(None, None, None)

    if not isinstance(index, (slice, str)):
        # Integer -> slice of length one ("or None" handles index == -1).
        index = slice(index, (index + 1) or None)

    yield from images[index]

604 

605 

def read_cif(
    fileobj,
    index=-1,
    *,
    store_tags: bool = False,
    primitive_cell: bool = False,
    subtrans_included: bool = True,
    fractional_occupancies: bool = True,
    reader: str = 'ase',
) -> Union[Atoms, List[Atoms]]:
    """Read Atoms object from CIF file.

    Parameters
    ----------
    store_tags : bool
        If true, the *info* attribute of the returned Atoms object will be
        populated with all tags in the corresponding cif data block.
    primitive_cell : bool
        If true, the primitive cell is built instead of the conventional cell.
    subtrans_included : bool
        If true, sublattice translations are assumed to be included among the
        symmetry operations listed in the CIF file (seems to be the common
        behaviour of CIF files).
        Otherwise the sublattice translations are determined from setting 1 of
        the extracted space group.  A result of setting this flag to true is
        that it will not be possible to determine the primitive cell.
    fractional_occupancies : bool
        If true, the resulting atoms object will be equipped with a
        dictionary `occupancy`.  The keys of this dictionary will be integers
        converted to strings.  The conversion to string is done in order to
        avoid troubles with JSON encoding/decoding of the dictionaries with
        non-string keys.
        Also, in case of mixed occupancies, the atom's chemical symbol will be
        that of the most dominant species.
    reader : str
        Select CIF reader.

        * ``ase`` : built-in CIF reader (default)
        * ``pycodcif`` : CIF reader based on ``pycodcif`` package

    Returns
    -------
    A list of Atoms if *index* is a slice or a string, otherwise a single
    Atoms object.

    Notes
    -----
    Only blocks with valid crystal data will be included.
    """
    images = iread_cif(
        fileobj,
        index,
        store_tags,
        primitive_cell,
        subtrans_included,
        fractional_occupancies,
        reader,
    )
    if isinstance(index, (slice, str)):
        return list(images)
    return next(images)

665 

666 

def format_cell(cell: Cell) -> str:
    """Return the six cell-parameter tag lines for a rank-3 cell."""
    assert cell.rank == 3
    rows = [f'{name:20} {value}\n'
            for name, value in zip(CIFBlock.cell_tags, cell.cellpar())]
    assert len(rows) == 6
    return ''.join(rows)

675 

676 

def format_generic_spacegroup_info() -> str:
    """Return the CIF text describing space group P 1 (identity only)."""
    # We assume no symmetry whatsoever
    header = (
        '_space_group_name_H-M_alt "P 1"',
        '_space_group_IT_number 1',
        '',
    )
    symop_loop = (
        'loop_',
        ' _space_group_symop_operation_xyz',
        " 'x, y, z'",
        '',
    )
    return '\n'.join(header + symop_loop)

688 

689 

class CIFLoop:
    """Collects named, equally long columns and formats them as a CIF loop."""

    def __init__(self):
        self.names = []    # column tags, each starting with '_'
        self.formats = []  # one str.format() field spec per column
        self.arrays = []   # column data; all columns have equal length

    def add(self, name, array, fmt):
        """Append a column.

        ``fmt`` is the format field (e.g. ``'{:6.4f}'``) applied to every
        element of ``array`` when the loop is written.

        Raises ValueError if ``array`` differs in length from the columns
        added before; the loop is left unchanged in that case.
        """
        assert name.startswith('_')
        # Validate before mutating so that a failed call does not leave
        # a column of the wrong length behind.
        if self.arrays and len(self.arrays[0]) != len(array):
            raise ValueError(f'Loop data "{name}" has {len(array)} '
                             f'elements, expected {len(self.arrays[0])}')
        self.names.append(name)
        self.formats.append(fmt)
        self.arrays.append(array)

    def tostring(self):
        """Return the loop as text: header, column tags, then one row per
        line, terminated by a newline."""
        lines = ['loop_']
        lines.extend(f' {name}' for name in self.names)

        template = ' ' + ' '.join(self.formats)

        nrows = len(self.arrays[0]) if self.arrays else 0
        for row in range(nrows):
            rowdata = [array[row] for array in self.arrays]
            lines.append(template.format(*rowdata))
        lines.append('')
        return '\n'.join(lines)

722 

723 

@iofunction('wb')
def write_cif(fd, images, cif_format=None,
              wrap=True, labels=None, loop_keys=None) -> None:
    r"""Write *images* to CIF file.

    Each image is written as its own data block named ``image<i>``.

    wrap: bool
        Wrap atoms into unit cell.

    labels: list
        Use this list (shaped list[i_frame][i_atom] = string) for the
        '_atom_site_label' section instead of automatically generating
        it from the element symbol.

    loop_keys: dict
        Add the information from this dictionary to the `loop\_`
        section.  Keys are printed to the `loop\_` section preceded by
        ' _'. dict[key] should contain the data printed for each atom,
        so it needs to have the setup `dict[key][i_frame][i_atom] =
        string`.  The strings are printed as they are, so take care of
        formatting.  Information can be re-read using the `store_tags`
        option of the cif reader.

    cif_format:
        Deprecated; passing anything but None only triggers a
        FutureWarning.
    """
    if cif_format is not None:
        warnings.warn('The cif_format argument is deprecated and may be '
                      'removed in the future.  Use loop_keys to customize '
                      'data written in loop.', FutureWarning)

    if loop_keys is None:
        loop_keys = {}

    # A single Atoms object is treated as a one-image sequence.
    if hasattr(images, 'get_positions'):
        images = [images]

    fd = io.TextIOWrapper(fd, encoding='latin-1')
    try:
        for i, atoms in enumerate(images):
            blockname = f'data_image{i}\n'
            image_loop_keys = {key: loop_keys[key][i] for key in loop_keys}

            write_cif_image(blockname, atoms, fd,
                            wrap=wrap,
                            labels=None if labels is None else labels[i],
                            loop_keys=image_loop_keys)

    finally:
        # Using the TextIOWrapper somehow causes the file to close
        # when this function returns.
        # Detach in order to circumvent this highly illogical problem:
        fd.detach()

775 

776 

def autolabel(symbols: Sequence[str]) -> List[str]:
    """Return labels made of each symbol plus a running per-symbol count.

    The count starts at 1 for every distinct symbol, e.g.
    ``['H', 'H', 'O']`` gives ``['H1', 'H2', 'O1']``.
    """
    seen: Dict[str, int] = {}
    labels = []
    for symbol in symbols:
        seen[symbol] = seen.get(symbol, 0) + 1
        labels.append(f'{symbol!s}{seen[symbol]:d}')
    return labels

787 

788 

def chemical_formula_header(atoms):
    """Return the two chemical-formula tag lines for ``atoms``.

    The structural formula is the string form of ``atoms.symbols``; the
    sum formula lists every symbol with its count, separated by spaces.
    """
    counts = atoms.symbols.formula.count()
    pieces = [f'{sym}{count}' for sym, count in counts.items()]
    formula_sum = ' '.join(pieces)

    structural_line = f'_chemical_formula_structural {atoms.symbols}\n'
    sum_line = f'_chemical_formula_sum "{formula_sum}"\n'
    return structural_line + sum_line

795 

796 

class BadOccupancies(ValueError):
    """Raised when the stored occupancy information does not cover an atom."""


def expand_kinds(atoms, coords):
    """Expand mixed-occupancy sites into one entry per species.

    Reads ``atoms.info['occupancy']`` (mapping ``str(kind)`` to a
    ``{symbol: occupancy}`` dict) and ``atoms.arrays['spacegroup_kinds']``.
    If either is missing, every atom gets occupancy 1.  Otherwise each
    atom receives the occupancy of its own symbol, and for every other
    species sharing its kind an extra entry with the same coordinates is
    appended.

    Returns the tuple ``(symbols, coords, occupancies)`` as lists.

    Raises BadOccupancies if an atom's symbol is not listed for its kind.
    """
    # try to fetch occupancies // spacegroup_kinds - occupancy mapping
    symbols = list(atoms.symbols)
    coords = list(coords)
    occupancies = [1] * len(symbols)
    occ_info = atoms.info.get('occupancy')
    kinds = atoms.arrays.get('spacegroup_kinds')
    if occ_info is None or kinds is None:
        return symbols, coords, occupancies

    for i, kind in enumerate(kinds):
        occ_info_kind = occ_info[str(kind)]
        symbol = symbols[i]
        if symbol not in occ_info_kind:
            raise BadOccupancies('Occupancies present but no occupancy '
                                 f'info for "{symbol}"')
        occupancies[i] = occ_info_kind[symbol]
        # extend the positions array in case of mixed occupancy
        for sym, occ in occ_info_kind.items():
            if sym != symbol:
                symbols.append(sym)
                coords.append(coords[i])
                occupancies.append(occ)
    return symbols, coords, occupancies

823 

824 

def atoms_to_loop_data(atoms, wrap, labels, loop_keys):
    """Collect the per-atom columns for the atom-site loop.

    Fractional coordinates are used when the cell has rank 3, Cartesian
    ones otherwise.  Returns ``(loopdata, coord_headers)`` where
    ``loopdata`` maps each column tag to an ``(array, format)`` pair and
    ``coord_headers`` lists the three coordinate tags.
    """
    if atoms.cell.rank == 3:
        coord_type = 'fract'
        coords = atoms.get_scaled_positions(wrap).tolist()
    else:
        coord_type = 'Cartn'
        coords = atoms.get_positions(wrap).tolist()

    try:
        symbols, coords, occupancies = expand_kinds(atoms, coords)
    except BadOccupancies as err:
        # Fall back to full occupancy for every atom.
        warnings.warn(str(err))
        occupancies = [1] * len(atoms)
        symbols = list(atoms.symbols)

    if labels is None:
        labels = autolabel(symbols)

    coord_headers = [f'_atom_site_{coord_type}_{axisname}'
                     for axisname in 'xyz']
    coord_array = np.array(coords)

    loopdata = {
        '_atom_site_label': (labels, '{:<8s}'),
        '_atom_site_occupancy': (occupancies, '{:6.4f}'),
    }
    for axis, header in enumerate(coord_headers):
        loopdata[header] = (coord_array[:, axis], '{}')

    natoms = len(symbols)
    loopdata['_atom_site_type_symbol'] = (symbols, '{:<2s}')
    loopdata['_atom_site_symmetry_multiplicity'] = ([1.0] * natoms, '{}')

    for key in loop_keys:
        # Should expand the loop_keys like we expand the occupancy stuff.
        # Otherwise user will never figure out how to do this.
        values = [loop_keys[key][i] for i in range(natoms)]
        loopdata['_' + key] = (values, '{}')

    return loopdata, coord_headers

865 

866 

def write_cif_image(blockname, atoms, fd, *, wrap,
                    labels, loop_keys):
    """Write one data block for ``atoms`` to the text stream ``fd``.

    Writes the block header and chemical formula, then (for a rank-3
    cell) the cell parameters and generic space group information, and
    finally the atom-site loop.

    Raises ValueError if the cell rank is neither 0 nor 3.
    """
    fd.write(blockname)
    fd.write(chemical_formula_header(atoms))

    rank = atoms.cell.rank
    if rank not in (0, 3):
        raise ValueError('CIF format can only represent systems with '
                         f'0 or 3 lattice vectors.  Got {rank}.')
    if rank == 3:
        fd.write(format_cell(atoms.cell))
        fd.write('\n')
        fd.write(format_generic_spacegroup_info())
        fd.write('\n')

    loopdata, coord_headers = atoms_to_loop_data(atoms, wrap, labels,
                                                 loop_keys)

    # Column order of the atom-site loop; user columns come last.
    headers = [
        '_atom_site_type_symbol',
        '_atom_site_label',
        '_atom_site_symmetry_multiplicity',
        *coord_headers,
        '_atom_site_occupancy',
        *('_' + key for key in loop_keys),
    ]

    loop = CIFLoop()
    for header in headers:
        array, fmt = loopdata[header]
        loop.add(header, array, fmt)

    fd.write(loop.tostring())