Coverage for /builds/ase/ase/ase/io/castep/castep_input_file.py: 78.55%

289 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import difflib 

4import re 

5import warnings 

6from typing import List, Set 

7 

8import numpy as np 

9 

10from ase import Atoms 

11 

12# A convenient table to avoid the previously used "eval" 

13_tf_table = { 

14 '': True, # Just the keyword is equivalent to True 

15 'True': True, 

16 'False': False} 

17 

18 

def _parse_tss_block(value, scaled=False):
    """Build the text of a Transition State Search positions block.

    Used for ``positions_abs/frac_intermediate`` and
    ``positions_abs/frac_product``.

    Parameters
    ----------
    value : Atoms or str or sequence of str
        Either an ``Atoms`` object, whose positions are written out, or the
        block contents as a sequence of lines. A single multi-line string
        is split into its non-blank lines instead of being iterated
        character by character.
    scaled : bool, optional
        If True, fractional coordinates are written and the units check on
        string input is skipped.

    Returns
    -------
    str
        The block text, one entry per line.

    Raises
    ------
    TypeError
        If ``value`` is neither an ``Atoms`` object nor strings.
    RuntimeError
        If, for absolute positions, the first line is a lone token other
        than ``ang`` (compared case-insensitively).
    """
    if isinstance(value, Atoms):
        positions = (value.get_scaled_positions() if scaled else
                     value.get_positions())
        symbols = value.get_chemical_symbols()
        header = '' if scaled else 'ang\n'
        return header + ''.join(
            ' {} {:.3f} {:.3f} {:.3f}\n'.format(s, *p)
            for s, p in zip(symbols, positions))

    if isinstance(value, str):
        # A str is itself an iterable of one-character strings; split it
        # into lines so it is not joined character by character.
        value = [line for line in value.splitlines() if line.strip()]

    try:
        # Materialise once so generators are not consumed by the check
        items = list(value)
    except TypeError:
        items = None
    if items is None or not all(isinstance(x, str) for x in items):
        # Invalid!
        raise TypeError('castep.cell.positions_abs/frac_intermediate/'
                        'product expects Atoms object or list of strings')

    lines = [line.strip() for line in items]

    # First line must be Angstroms, or nothing: a lone token is a unit
    if lines and not scaled:
        first = lines[0]
        if len(first.split()) == 1 and first.lower() != 'ang':
            raise RuntimeError('Only ang units currently supported in castep.'
                               'cell.positions_abs_intermediate/product')
    return '\n'.join(lines)

48 

49 

50class CastepOption: 

51 """"A CASTEP option. It handles basic conversions from string to its value 

52 type.""" 

53 

54 default_convert_types = { 

55 'boolean (logical)': 'bool', 

56 'defined': 'bool', 

57 'string': 'str', 

58 'integer': 'int', 

59 'real': 'float', 

60 'integer vector': 'int_vector', 

61 'real vector': 'float_vector', 

62 'physical': 'float_physical', 

63 'block': 'block' 

64 } 

65 

66 def __init__(self, keyword, level, option_type, value=None, 

67 docstring='No information available'): 

68 self.keyword = keyword 

69 self.level = level 

70 self.type = option_type 

71 self._value = value 

72 self.__doc__ = docstring 

73 

74 @property 

75 def value(self): 

76 

77 if self._value is not None: 

78 if self.type.lower() in ('integer vector', 'real vector', 

79 'physical'): 

80 return ' '.join(map(str, self._value)) 

81 elif self.type.lower() in ('boolean (logical)', 'defined'): 

82 return str(self._value).upper() 

83 else: 

84 return str(self._value) 

85 

86 @property 

87 def raw_value(self): 

88 # The value, not converted to a string 

89 return self._value 

90 

91 @value.setter # type: ignore[attr-defined, no-redef] 

92 def value(self, val): 

93 

94 if val is None: 

95 self.clear() 

96 return 

97 

98 ctype = self.default_convert_types.get(self.type.lower(), 'str') 

99 typeparse = f'_parse_{ctype}' 

100 try: 

101 self._value = getattr(self, typeparse)(val) 

102 except ValueError: 

103 raise ConversionError(ctype, self.keyword, val) 

104 

105 def clear(self): 

106 """Reset the value of the option to None again""" 

107 self._value = None 

108 

109 @staticmethod 

110 def _parse_bool(value): 

111 try: 

112 value = _tf_table[str(value).strip().title()] 

113 except (KeyError, ValueError): 

114 raise ValueError() 

115 return value 

116 

117 @staticmethod 

118 def _parse_str(value): 

119 value = str(value) 

120 return value 

121 

122 @staticmethod 

123 def _parse_int(value): 

124 value = int(value) 

125 return value 

126 

127 @staticmethod 

128 def _parse_float(value): 

129 value = float(value) 

130 return value 

131 

132 @staticmethod 

133 def _parse_int_vector(value): 

134 # Accepts either a string or an actual list/numpy array of ints 

135 if isinstance(value, str): 

136 if ',' in value: 

137 value = value.replace(',', ' ') 

138 value = list(map(int, value.split())) 

139 

140 value = np.array(value) 

141 

142 if value.shape != (3,) or value.dtype != int: 

143 raise ValueError() 

144 

145 return list(value) 

146 

147 @staticmethod 

148 def _parse_float_vector(value): 

149 # Accepts either a string or an actual list/numpy array of floats 

150 if isinstance(value, str): 

151 if ',' in value: 

152 value = value.replace(',', ' ') 

153 value = list(map(float, value.split())) 

154 

155 value = np.array(value) * 1.0 

156 

157 if value.shape != (3,) or value.dtype != float: 

158 raise ValueError() 

159 

160 return list(value) 

161 

162 @staticmethod 

163 def _parse_float_physical(value): 

164 # If this is a string containing units, saves them 

165 if isinstance(value, str): 

166 value = value.split() 

167 

168 try: 

169 l = len(value) 

170 except TypeError: 

171 l = 1 

172 value = [value] 

173 

174 if l == 1: 

175 try: 

176 value = (float(value[0]), '') 

177 except (TypeError, ValueError): 

178 raise ValueError() 

179 elif l == 2: 

180 try: 

181 value = (float(value[0]), value[1]) 

182 except (TypeError, ValueError, IndexError): 

183 raise ValueError() 

184 else: 

185 raise ValueError() 

186 

187 return value 

188 

189 @staticmethod 

190 def _parse_block(value): 

191 

192 if isinstance(value, str): 

193 return value 

194 elif hasattr(value, '__getitem__'): 

195 return '\n'.join(value) # Arrays of lines 

196 else: 

197 raise ValueError() 

198 

199 def __repr__(self): 

200 if self._value: 

201 expr = ('Option: {keyword}({type}, {level}):\n{_value}\n' 

202 ).format(**self.__dict__) 

203 else: 

204 expr = ('Option: {keyword}[unset]({type}, {level})' 

205 ).format(**self.__dict__) 

206 return expr 

207 

208 def __eq__(self, other): 

209 if not isinstance(other, CastepOption): 

210 return False 

211 else: 

212 return self.__dict__ == other.__dict__ 

213 

214 

class CastepOptionDict:
    """A dictionary-like object to hold a set of options for .cell or .param
    files loaded from a dictionary, for the sake of validation.

    Replaces the old CastepCellDict and CastepParamDict that were defined in
    the castep_keywords.py file.

    Parameters
    ----------
    options : dict, optional
        Maps each keyword to a dict of ``CastepOption`` constructor
        arguments. ``None`` (the default) gives an empty option set.
    """

    def __init__(self, options=None):
        if options is None:
            options = {}
        # A plain dict is enough (no ComparableDict any more): CastepOption
        # instances can be compared directly.
        self._options = {}
        for kw in options:
            opt = CastepOption(**options[kw])
            self._options[opt.keyword] = opt
            # Also expose each option as an attribute
            self.__dict__[opt.keyword] = opt

231 

232 

class CastepInputFile:

    """Master class for CastepParam and CastepCell to inherit from"""

    # Groups of mutually exclusive keywords: setting one of them resets the
    # others to None. Subclasses override this.
    _keyword_conflicts: List[Set[str]] = []

    def __init__(self, options_dict=None, keyword_tolerance=1):
        """
        Parameters
        ----------
        options_dict : CastepOptionDict, optional
            Source of the known options; its ``_options`` mapping is shared
            (not copied). Defaults to an empty set of options.
        keyword_tolerance : int, optional
            How strict the checks on new attributes are (clipped to 0..2):
            0 = no new attributes allowed,
            1 = new attributes allowed, warning given,
            2 = new attributes allowed, silent.
        """
        if options_dict is None:
            options_dict = CastepOptionDict({})

        self._options = options_dict._options
        # Make every known option reachable as a plain attribute
        self.__dict__.update(self._options)
        self._perm = np.clip(keyword_tolerance, 0, 2)

        # keyword -> the other keywords of its conflict set
        self._conflict_dict = {
            kw: set(cset).difference({kw})
            for cset in self._keyword_conflicts for kw in cset}

    def __repr__(self):
        lines = ['%20s : %s\n' % (key, option.value)
                 for key, option in sorted(self._options.items())
                 if option.value is not None]
        expr = ''.join(lines) if lines else 'Default\n'
        return expr + f'Keyword tolerance: {self._perm}'

    def __setattr__(self, attr, value):
        # Hidden attributes are treated normally
        if attr.startswith('_'):
            self.__dict__[attr] = value
            return

        # Keywords are stored lower-case. Normalise before the lookup so
        # that e.g. CUT_OFF_ENERGY updates the existing typed option instead
        # of being registered as a new unknown option that replaces it.
        attr = attr.lower()

        if attr in self._options:
            opt = self._options[attr]
        else:
            if self._perm == 0:
                similars = difflib.get_close_matches(attr,
                                                     self._options.keys())
                if similars:
                    raise RuntimeError(
                        f'Option "{attr}" not known! You mean "{similars[0]}"?')
                raise RuntimeError(f'Option "{attr}" is not known!')

            # Do we consider it a string or a block?
            is_str = isinstance(value, str)
            is_block = ((hasattr(value, '__getitem__') and not is_str)
                        or (is_str and '\n' in value))

            if self._perm == 1:
                warnings.warn(('Option "%s" is not known and will '
                               'be added as a %s') % (attr,
                                                      ('block' if is_block else
                                                       'string')))
            opt = CastepOption(keyword=attr, level='Unknown',
                               option_type='block' if is_block else 'string')
            self._options[attr] = opt
            self.__dict__[attr] = opt

        # ':' is accepted as a separator in non-block values
        if opt.type.lower() != 'block' and isinstance(value, str):
            value = value.replace(':', ' ')

        # Use a keyword-specific parser when the subclass defines one
        attrparse = f'_parse_{attr}'

        # Check for any conflicts if the value is not None
        if value is not None:
            for c in self._conflict_dict.get(attr, set()):
                if c in self._options and self._options[c].value:
                    warnings.warn(
                        'option "{attr}" conflicts with "{conflict}" in '
                        'calculator. Setting "{conflict}" to '
                        'None.'.format(attr=attr, conflict=c))
                    self._options[c].value = None

        # hasattr is False for missing parsers: __getattr__ refuses '_' names
        if hasattr(self, attrparse):
            opt.value = getattr(self, attrparse)(value)
        else:
            opt.value = value

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Private names are
        # never synthesised, which keeps hasattr(self, '_parse_x') False for
        # options without a dedicated parser.
        if name.startswith('_'):
            raise AttributeError(name)

        # Options are stored lower-case; accept any capitalisation
        options = self.__dict__.get('_options', {})
        if name.lower() in options:
            return options[name.lower()]

        if self._perm == 0:
            raise AttributeError(name)

        if self._perm == 1:
            warnings.warn(f'Option {(name)} is not known, returning None')

        return CastepOption(keyword='none', level='Unknown',
                            option_type='string', value=None)

    def get_attr_dict(self, raw=False, types=False):
        """Settings that go into .param file in a traditional dict

        Parameters
        ----------
        raw : bool
            Use each option's unconverted ``raw_value`` instead of its
            string ``value``.
        types : bool
            Map each keyword to a ``(value, option_type)`` tuple instead.

        Only options whose ``value`` is not None are included.
        """
        attrdict = {}
        for key, opt in self._options.items():
            if opt.value is None:
                continue
            val = opt.raw_value if raw else opt.value
            attrdict[key] = (val, opt.type) if types else val
        return attrdict

353 

354 

class CastepParam(CastepInputFile):
    """CastepParam abstracts the settings that go into the .param file"""

    _keyword_conflicts = [{'cut_off_energy', 'basis_precision'}, ]

    def __init__(self, castep_keywords, keyword_tolerance=1):
        self._castep_version = castep_keywords.castep_version
        CastepInputFile.__init__(self, castep_keywords.CastepParamDict(),
                                 keyword_tolerance)

    @property
    def castep_version(self):
        return self._castep_version

    # .param specific parsers
    def _restart_value(self, value, other):
        """Shared parser for the mutually exclusive reuse/continuation.

        Returns None (clearing the option) when ``value`` is None or False,
        or when the ``other`` keyword is already set; ``True`` becomes
        ``'default'`` and anything else is converted with ``str``.

        The name deliberately avoids the ``_parse_`` prefix, so it is never
        picked up as a keyword parser.
        """
        # False means "do not restart"; stringifying it would produce the
        # file name 'False'.
        if value is None or value is False:
            return None  # Reset the value
        other_opt = self._options.get(other)
        if other_opt is not None and other_opt.value:
            warnings.warn('Cannot set reuse if continuation is set, and '
                          'vice versa. Set the other to None, if you want '
                          'this setting.')
            return None
        return 'default' if (value is True) else str(value)

    def _parse_reuse(self, value):
        return self._restart_value(value, 'continuation')

    def _parse_continuation(self, value):
        return self._restart_value(value, 'reuse')

395 

396 

397class CastepCell(CastepInputFile): 

398 

399 """CastepCell abstracts all setting that go into the .cell file""" 

400 

401 _keyword_conflicts = [ 

402 {'kpoint_mp_grid', 'kpoint_mp_spacing', 'kpoint_list', 

403 'kpoints_mp_grid', 'kpoints_mp_spacing', 'kpoints_list'}, 

404 {'bs_kpoint_mp_grid', 

405 'bs_kpoint_mp_spacing', 

406 'bs_kpoint_list', 

407 'bs_kpoint_path', 

408 'bs_kpoints_mp_grid', 

409 'bs_kpoints_mp_spacing', 

410 'bs_kpoints_list', 

411 'bs_kpoints_path'}, 

412 {'spectral_kpoint_mp_grid', 

413 'spectral_kpoint_mp_spacing', 

414 'spectral_kpoint_list', 

415 'spectral_kpoint_path', 

416 'spectral_kpoints_mp_grid', 

417 'spectral_kpoints_mp_spacing', 

418 'spectral_kpoints_list', 

419 'spectral_kpoints_path'}, 

420 {'phonon_kpoint_mp_grid', 

421 'phonon_kpoint_mp_spacing', 

422 'phonon_kpoint_list', 

423 'phonon_kpoint_path', 

424 'phonon_kpoints_mp_grid', 

425 'phonon_kpoints_mp_spacing', 

426 'phonon_kpoints_list', 

427 'phonon_kpoints_path'}, 

428 {'fine_phonon_kpoint_mp_grid', 

429 'fine_phonon_kpoint_mp_spacing', 

430 'fine_phonon_kpoint_list', 

431 'fine_phonon_kpoint_path'}, 

432 {'magres_kpoint_mp_grid', 

433 'magres_kpoint_mp_spacing', 

434 'magres_kpoint_list', 

435 'magres_kpoint_path'}, 

436 {'elnes_kpoint_mp_grid', 

437 'elnes_kpoint_mp_spacing', 

438 'elnes_kpoint_list', 

439 'elnes_kpoint_path'}, 

440 {'optics_kpoint_mp_grid', 

441 'optics_kpoint_mp_spacing', 

442 'optics_kpoint_list', 

443 'optics_kpoint_path'}, 

444 {'supercell_kpoint_mp_grid', 

445 'supercell_kpoint_mp_spacing', 

446 'supercell_kpoint_list', 

447 'supercell_kpoint_path'}, ] 

448 

449 def __init__(self, castep_keywords, keyword_tolerance=1): 

450 self._castep_version = castep_keywords.castep_version 

451 CastepInputFile.__init__(self, castep_keywords.CastepCellDict(), 

452 keyword_tolerance) 

453 

454 @property 

455 def castep_version(self): 

456 return self._castep_version 

457 

458 # .cell specific parsers 

459 def _parse_species_pot(self, value): 

460 

461 # Single tuple 

462 if isinstance(value, tuple) and len(value) == 2: 

463 value = [value] 

464 # List of tuples 

465 if hasattr(value, '__getitem__'): 

466 pspots = [tuple(map(str.strip, x)) for x in value] 

467 if not all(map(lambda x: len(x) == 2, value)): 

468 warnings.warn( 

469 'Please specify pseudopotentials in python as ' 

470 'a tuple or a list of tuples formatted like: ' 

471 '(species, file), e.g. ("O", "path-to/O_OTFG.usp") ' 

472 'Anything else will be ignored') 

473 return None 

474 

475 text_block = self._options['species_pot'].value 

476 

477 text_block = text_block if text_block else '' 

478 # Remove any duplicates 

479 for pp in pspots: 

480 text_block = re.sub(fr'\n?\s*{pp[0]}\s+.*', '', text_block) 

481 if pp[1]: 

482 text_block += '\n%s %s' % pp 

483 

484 return text_block 

485 

486 def _parse_symmetry_ops(self, value): 

487 if not isinstance(value, tuple) \ 

488 or not len(value) == 2 \ 

489 or not value[0].shape[1:] == (3, 3) \ 

490 or not value[1].shape[1:] == (3,) \ 

491 or not value[0].shape[0] == value[1].shape[0]: 

492 warnings.warn('Invalid symmetry_ops block, skipping') 

493 return 

494 # Now on to print... 

495 text_block = '' 

496 for op_i, (op_rot, op_tranls) in enumerate(zip(*value)): 

497 text_block += '\n'.join([' '.join([str(x) for x in row]) 

498 for row in op_rot]) 

499 text_block += '\n' 

500 text_block += ' '.join([str(x) for x in op_tranls]) 

501 text_block += '\n\n' 

502 

503 return text_block 

504 

505 def _parse_positions_abs_intermediate(self, value): 

506 return _parse_tss_block(value) 

507 

508 def _parse_positions_abs_product(self, value): 

509 return _parse_tss_block(value) 

510 

511 def _parse_positions_frac_intermediate(self, value): 

512 return _parse_tss_block(value, True) 

513 

514 def _parse_positions_frac_product(self, value): 

515 return _parse_tss_block(value, True) 

516 

517 

class ConversionError(Exception):

    """Print customized error for options that are not converted correctly
    and point out that they are maybe not implemented, yet"""

    def __init__(self, key_type, attr, value):
        # Forward the arguments so ``self.args`` is populated: the default
        # exception pickling re-creates the instance from ``args``.
        Exception.__init__(self, key_type, attr, value)
        self.key_type = key_type
        self.value = value
        self.attr = attr

    def __str__(self):
        contact_email = 'simon.rittmeyer@tum.de'
        return (f'Could not convert {self.attr} = {self.value} '
                f'to {self.key_type}\n'
                'This means you either tried to set a value of the wrong\n'
                'type or this keyword needs some special care. Please feel\n'
                'free to add it to the corresponding __setattr__ method and '
                'send\n'
                f'the patch to {contact_email}, so we can all benefit.')