Coverage for ase / io / castep / castep_input_file.py: 78.47%

288 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3import difflib 

4import re 

5import warnings 

6 

7import numpy as np 

8 

9from ase import Atoms 

10 

11# A convenient table to avoid the previously used "eval" 

12_tf_table = { 

13 '': True, # Just the keyword is equivalent to True 

14 'True': True, 

15 'False': False} 

16 

17 

def _parse_tss_block(value, scaled=False):
    """Build the text of a Transition State Search structure block.

    Used for the ``positions_abs/frac_intermediate/product`` cell blocks.

    Parameters
    ----------
    value : Atoms, str or iterable of str
        An Atoms object, the block's lines, or the whole block as a single
        (possibly multi-line) string.
    scaled : bool
        If True the block holds fractional positions and no units line is
        written; otherwise positions are absolute, in Angstrom.

    Returns
    -------
    str
        The block text.

    Raises
    ------
    TypeError
        If ``value`` is neither an Atoms object nor (a list of) strings.
    ValueError
        If no lines are given.
    RuntimeError
        If an absolute-position block declares units other than ``ang``.
    """
    # Check for Atoms first so that it takes precedence over the generic
    # "iterable of strings" check below
    if isinstance(value, Atoms):
        text_block = '' if scaled else 'ang\n'
        positions = (value.get_scaled_positions() if scaled else
                     value.get_positions())
        symbols = value.get_chemical_symbols()
        for s, p in zip(symbols, positions):
            text_block += ' {} {:.3f} {:.3f} {:.3f}\n'.format(s, *p)
        return text_block

    # A single string is the whole block: split it into lines instead of
    # iterating over its characters
    if isinstance(value, str):
        value = value.strip().splitlines()

    # Materialise once so that generators are not consumed by the type check
    try:
        lines = list(value)
    except TypeError:
        lines = None
    if lines is None or not all(isinstance(x, str) for x in lines):
        # Invalid!
        raise TypeError('castep.cell.positions_abs/frac_intermediate/'
                        'product expects Atoms object or list of strings')
    if not lines:
        raise ValueError('castep.cell.positions_abs/frac_intermediate/'
                         'product block is empty')

    # First line must be Angstroms, or nothing. A single token on the first
    # line is taken to be the units; the comparison ignores case.
    first = lines[0].strip()
    has_units = len(first.split()) == 1
    if (not scaled) and has_units and first.lower() != 'ang':
        raise RuntimeError('Only ang units currently supported in castep.'
                           'cell.positions_abs_intermediate/product')
    return '\n'.join(map(str.strip, lines))

47 

48 

class CastepOption:
    """A CASTEP option. It handles basic conversions from string to its value
    type.

    Parameters
    ----------
    keyword : str
        Name of the option.
    level : str
        Level label of the option (``'Unknown'`` for options created on the
        fly by :class:`CastepInputFile`).
    option_type : str
        CASTEP type name, e.g. ``'integer vector'`` or ``'block'``. Its lower
        case form selects a converter via ``default_convert_types``; unknown
        types are treated as strings.
    value : optional
        Initial value, stored without conversion.
    docstring : str
        Help text, exposed as the instance's ``__doc__``.
    """

    default_convert_types = {
        'boolean (logical)': 'bool',
        'defined': 'bool',
        'string': 'str',
        'integer': 'int',
        'real': 'float',
        'integer vector': 'int_vector',
        'real vector': 'float_vector',
        'physical': 'float_physical',
        'block': 'block'
    }

    def __init__(self, keyword, level, option_type, value=None,
                 docstring='No information available'):
        self.keyword = keyword
        self.level = level
        self.type = option_type
        self._value = value
        self.__doc__ = docstring

    @property
    def value(self):
        """The value formatted as a string for the input file, or None."""
        if self._value is None:
            return None
        kind = self.type.lower()
        if kind in ('integer vector', 'real vector', 'physical'):
            # Sequences are written space-separated
            return ' '.join(map(str, self._value))
        if kind in ('boolean (logical)', 'defined'):
            return str(self._value).upper()
        return str(self._value)

    @property
    def raw_value(self):
        # The value, not converted to a string
        return self._value

    @value.setter  # type: ignore[attr-defined, no-redef]
    def value(self, val):
        """Convert ``val`` with the parser matching the option type.

        Setting None clears the option. A ValueError from the parser is
        re-raised as ConversionError.
        """
        if val is None:
            self.clear()
            return

        ctype = self.default_convert_types.get(self.type.lower(), 'str')
        parser = getattr(self, f'_parse_{ctype}')
        try:
            self._value = parser(val)
        except ValueError as err:
            raise ConversionError(ctype, self.keyword, val) from err

    def clear(self):
        """Reset the value of the option to None again"""
        self._value = None

    @staticmethod
    def _parse_bool(value):
        # Accepts 'true'/'FALSE'/'' etc. (case-insensitive via title())
        try:
            return _tf_table[str(value).strip().title()]
        except (KeyError, ValueError):
            raise ValueError() from None

    @staticmethod
    def _parse_str(value):
        return str(value)

    @staticmethod
    def _parse_int(value):
        return int(value)

    @staticmethod
    def _parse_float(value):
        return float(value)

    @staticmethod
    def _parse_int_vector(value):
        # Accepts either a string or an actual list/numpy array of ints
        if isinstance(value, str):
            value = list(map(int, value.replace(',', ' ').split()))

        arr = np.asarray(value)

        # Any integer dtype is accepted (e.g. int32 arrays); floats and
        # bools are not
        if arr.shape != (3,) or not np.issubdtype(arr.dtype, np.integer):
            raise ValueError()

        # tolist() hands back plain Python ints rather than numpy scalars
        return arr.tolist()

    @staticmethod
    def _parse_float_vector(value):
        # Accepts either a string or an actual list/numpy array of floats
        if isinstance(value, str):
            value = list(map(float, value.replace(',', ' ').split()))

        # Multiplying by 1.0 promotes integer input to float
        arr = np.asarray(value) * 1.0

        # Any floating dtype is accepted (e.g. float32 arrays); complex is not
        if arr.shape != (3,) or not np.issubdtype(arr.dtype, np.floating):
            raise ValueError()

        # Plain Python floats rather than numpy scalars
        return arr.astype(float).tolist()

    @staticmethod
    def _parse_float_physical(value):
        # Returns (number, units); a string may carry units, e.g. '300 eV'
        if isinstance(value, str):
            value = value.split()

        try:
            n_items = len(value)
        except TypeError:
            # A bare number: no units
            n_items = 1
            value = [value]

        if n_items not in (1, 2):
            raise ValueError()

        try:
            number = float(value[0])
            units = value[1] if n_items == 2 else ''
        except (TypeError, ValueError, IndexError):
            raise ValueError()

        return (number, units)

    @staticmethod
    def _parse_block(value):
        if isinstance(value, str):
            return value
        if hasattr(value, '__getitem__'):
            return '\n'.join(value)  # Arrays of lines
        raise ValueError()

    def __repr__(self):
        # Compare against None so that falsy values such as 0 or False are
        # not reported as unset
        if self._value is not None:
            return (f'Option: {self.keyword}({self.type}, {self.level}):\n'
                    f'{self._value}\n')
        return f'Option: {self.keyword}[unset]({self.type}, {self.level})'

    def __eq__(self, other):
        if not isinstance(other, CastepOption):
            return False
        return self.__dict__ == other.__dict__

212 

213 

class CastepOptionDict:
    """A dictionary-like object to hold a set of options for .cell or .param
    files loaded from a dictionary, for the sake of validation.

    Replaces the old CastepCellDict and CastepParamDict that were defined in
    the castep_keywords.py file.

    Parameters
    ----------
    options : dict, optional
        Maps each keyword to a dict of keyword arguments for
        :class:`CastepOption`. ``None`` (the default) gives an empty set.
    """

    def __init__(self, options=None):
        if options is None:
            options = {}
        # ComparableDict is not needed any more as CastepOptions can be
        # compared directly now
        self._options = {}
        for kw in options:
            opt = CastepOption(**options[kw])
            # Stored under the option's own keyword, and also exposed as an
            # attribute
            self._options[opt.keyword] = opt
            self.__dict__[opt.keyword] = opt

230 

231 

class CastepInputFile:

    """Master class for CastepParam and CastepCell to inherit from.

    Each option is a :class:`CastepOption` held in ``self._options`` and
    mirrored in the instance ``__dict__``, so ``obj.keyword`` returns the
    option object while ``obj.keyword = x`` sets its value. Keyword names
    are matched case-insensitively.

    Parameters
    ----------
    options_dict : CastepOptionDict, optional
        Known options; an empty set is used if omitted.
    keyword_tolerance : int
        How strictly unknown keywords are treated (clipped to 0-2):
        0 = no new attributes allowed, 1 = new attributes allowed with a
        warning, 2 = new attributes allowed silently.
    """

    _keyword_conflicts: list[set[str]] = []

    def __init__(self, options_dict=None, keyword_tolerance=1):
        if options_dict is None:
            options_dict = CastepOptionDict({})

        self._options = options_dict._options
        self.__dict__.update(self._options)
        # keyword_tolerance means how strict the checks on new attributes are
        # 0 = no new attributes allowed
        # 1 = new attributes allowed, warning given
        # 2 = new attributes allowed, silent
        self._perm = np.clip(keyword_tolerance, 0, 2)

        # Compile a dictionary for quick check of conflict sets: every
        # keyword maps to the other keywords of its set
        self._conflict_dict = {
            kw: set(cset).difference({kw})
            for cset in self._keyword_conflicts for kw in cset}

    def __repr__(self):
        expr = ''
        is_default = True
        for key, option in sorted(self._options.items()):
            if option.value is not None:
                is_default = False
                expr += ('%20s : %s\n' % (key, option.value))
        if is_default:
            expr = 'Default\n'

        expr += f'Keyword tolerance: {self._perm}'
        return expr

    def __setattr__(self, attr, value):

        # Hidden attributes are treated normally
        if attr.startswith('_'):
            self.__dict__[attr] = value
            return

        # Keywords are case-insensitive. Normalise before looking the option
        # up, so that e.g. 'KPOINT_MP_GRID' sets the known option instead of
        # replacing it with a new, untyped one.
        key = attr.lower()

        if key in self._options:
            opt = self._options[key]
        else:
            if self._perm == 0:
                similars = difflib.get_close_matches(key,
                                                     self._options.keys())
                if similars:
                    raise RuntimeError(
                        f'Option "{attr}" not known! You mean "{similars[0]}"?')
                raise RuntimeError(f'Option "{attr}" is not known!')

            # Do we consider it a string or a block?
            is_str = isinstance(value, str)
            is_block = ((hasattr(value, '__getitem__') and not is_str)
                        or (is_str and len(value.split('\n')) > 1))
            option_type = 'block' if is_block else 'string'
            if self._perm == 1:
                warnings.warn(f'Option "{attr}" is not known and will '
                              f'be added as a {option_type}')
            opt = CastepOption(keyword=key, level='Unknown',
                               option_type=option_type)
            self._options[key] = opt
            self.__dict__[key] = opt

        # Colons become spaces in every non-block string value
        if opt.type.lower() != 'block' and isinstance(value, str):
            value = value.replace(':', ' ')

        # Check for any conflicts if the value is not None
        if value is not None:
            for c in self._conflict_dict.get(key, ()):
                if c in self._options and self._options[c].value:
                    warnings.warn(
                        'option "{attr}" conflicts with "{conflict}" in '
                        'calculator. Setting "{conflict}" to '
                        'None.'.format(attr=key, conflict=c))
                    self._options[c].value = None

        # Use the subclass's custom parser if it defines one. The name starts
        # with '_', so __getattr__ raises AttributeError instead of returning
        # a dummy option when no parser exists.
        parser = getattr(self, f'_parse_{key}', None)
        if parser is not None:
            value = parser(value)
        self._options[key].value = value

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails
        if name.startswith('_'):
            raise AttributeError(name)

        # Keywords are case-insensitive: 'CUT_OFF_ENERGY' finds the option
        option = self._options.get(name.lower())
        if option is not None:
            return option

        if self._perm == 0:
            raise AttributeError(name)

        if self._perm == 1:
            warnings.warn(f'Option {name} is not known, returning None')

        return CastepOption(keyword='none', level='Unknown',
                            option_type='string', value=None)

    def get_attr_dict(self, raw=False, types=False):
        """Settings that go into .param file in a traditional dict

        Only options whose value is not None are included.

        Parameters
        ----------
        raw : bool
            Use the unconverted ``raw_value`` instead of the string value.
        types : bool
            Map each keyword to a ``(value, type)`` tuple instead.
        """
        attrdict = {}
        for key, opt in self._options.items():
            if opt.value is None:
                continue
            val = opt.raw_value if raw else opt.value
            attrdict[key] = (val, opt.type) if types else val
        return attrdict

352 

353 

class CastepParam(CastepInputFile):
    """CastepParam abstracts the settings that go into the .param file"""

    _keyword_conflicts = [{'cut_off_energy', 'basis_precision'}, ]

    def __init__(self, castep_keywords, keyword_tolerance=1):
        self._castep_version = castep_keywords.castep_version
        CastepInputFile.__init__(self, castep_keywords.CastepParamDict(),
                                 keyword_tolerance)

    @property
    def castep_version(self):
        """The ``castep_version`` of the keywords object used to build this."""
        return self._castep_version

    # .param specific parsers
    def _parse_reuse(self, value):
        return self._restart_value(value, 'continuation')

    def _parse_continuation(self, value):
        return self._restart_value(value, 'reuse')

    def _restart_value(self, value, other):
        """Shared logic of the mutually exclusive reuse/continuation options.

        Deliberately not named ``_parse_*``, so that __setattr__ never picks
        it up as the parser of an option.

        Returns None (which resets the option) if ``value`` is None or
        False, or if ``other`` is already set (with a warning). True maps
        to ``'default'``; anything else is converted with ``str``.
        """
        if value is None or value is False:
            return None  # Reset the value
        other_option = self._options.get(other)
        if other_option is not None and other_option.value:
            warnings.warn('Cannot set reuse if continuation is set, and '
                          'vice versa. Set the other to None, if you want '
                          'this setting.')
            return None
        return 'default' if (value is True) else str(value)

394 

395 

class CastepCell(CastepInputFile):

    """CastepCell abstracts all settings that go into the .cell file"""

    _keyword_conflicts = [
        {'kpoint_mp_grid', 'kpoint_mp_spacing', 'kpoint_list',
         'kpoints_mp_grid', 'kpoints_mp_spacing', 'kpoints_list'},
        {'bs_kpoint_mp_grid',
         'bs_kpoint_mp_spacing',
         'bs_kpoint_list',
         'bs_kpoint_path',
         'bs_kpoints_mp_grid',
         'bs_kpoints_mp_spacing',
         'bs_kpoints_list',
         'bs_kpoints_path'},
        {'spectral_kpoint_mp_grid',
         'spectral_kpoint_mp_spacing',
         'spectral_kpoint_list',
         'spectral_kpoint_path',
         'spectral_kpoints_mp_grid',
         'spectral_kpoints_mp_spacing',
         'spectral_kpoints_list',
         'spectral_kpoints_path'},
        {'phonon_kpoint_mp_grid',
         'phonon_kpoint_mp_spacing',
         'phonon_kpoint_list',
         'phonon_kpoint_path',
         'phonon_kpoints_mp_grid',
         'phonon_kpoints_mp_spacing',
         'phonon_kpoints_list',
         'phonon_kpoints_path'},
        {'fine_phonon_kpoint_mp_grid',
         'fine_phonon_kpoint_mp_spacing',
         'fine_phonon_kpoint_list',
         'fine_phonon_kpoint_path'},
        {'magres_kpoint_mp_grid',
         'magres_kpoint_mp_spacing',
         'magres_kpoint_list',
         'magres_kpoint_path'},
        {'elnes_kpoint_mp_grid',
         'elnes_kpoint_mp_spacing',
         'elnes_kpoint_list',
         'elnes_kpoint_path'},
        {'optics_kpoint_mp_grid',
         'optics_kpoint_mp_spacing',
         'optics_kpoint_list',
         'optics_kpoint_path'},
        {'supercell_kpoint_mp_grid',
         'supercell_kpoint_mp_spacing',
         'supercell_kpoint_list',
         'supercell_kpoint_path'}, ]

    def __init__(self, castep_keywords, keyword_tolerance=1):
        self._castep_version = castep_keywords.castep_version
        CastepInputFile.__init__(self, castep_keywords.CastepCellDict(),
                                 keyword_tolerance)

    @property
    def castep_version(self):
        """The ``castep_version`` of the keywords object used to build this."""
        return self._castep_version

    def _current_value(self, keyword):
        """Return the current (string) value of ``keyword``, or None."""
        option = self._options.get(keyword)
        return None if option is None else option.value

    # .cell specific parsers
    def _parse_species_pot(self, value):
        """Merge pseudopotential assignments into the SPECIES_POT block.

        ``value`` is a ``(species, file)`` tuple or an iterable of them.
        Any existing line for a given species is dropped; the new line is
        appended unless ``file`` is empty, so an empty file name just removes
        the species. None resets the block. Malformed input is ignored with
        a warning, leaving the current block unchanged.
        """
        if value is None:
            return None  # Reset the value

        # Single tuple
        if isinstance(value, tuple) and len(value) == 2:
            value = [value]

        # List of tuples
        try:
            pspots = [tuple(map(str.strip, x)) for x in value]
            valid = all(len(pp) == 2 for pp in pspots)
        except TypeError:
            valid = False
        if not valid:
            warnings.warn(
                'Please specify pseudopotentials in python as '
                'a tuple or a list of tuples formatted like: '
                '(species, file), e.g. ("O", "path-to/O_OTFG.usp") '
                'Anything else will be ignored')
            return self._current_value('species_pot')

        text_block = self._current_value('species_pot') or ''
        for species, pot_file in pspots:
            # Remove any duplicates. The pattern is anchored to a line start
            # and the symbol is escaped, so a species name that happens to
            # appear inside another line (e.g. in a path) is not matched.
            text_block = re.sub(
                rf'\n?^[ \t]*{re.escape(species)}[ \t]+.*$', '',
                text_block, flags=re.MULTILINE)
            if pot_file:
                text_block += f'\n{species} {pot_file}'

        return text_block

    def _parse_symmetry_ops(self, value):
        """Format symmetry operations as a SYMMETRY_OPS block.

        ``value`` is a ``(rotations, translations)`` pair of array-likes with
        shapes ``(N, 3, 3)`` and ``(N, 3)``. Each operation is written as the
        three rotation rows, the translation and a blank line. None resets
        the block; invalid input is skipped with a warning, leaving the
        current block unchanged.
        """
        if value is None:
            return None  # Reset the value
        try:
            rotations, translations = (np.asarray(v) for v in value)
            valid = (rotations.shape[1:] == (3, 3)
                     and translations.shape[1:] == (3,)
                     and rotations.shape[0] == translations.shape[0])
        except (TypeError, ValueError, IndexError):
            valid = False
        if not valid:
            warnings.warn('Invalid symmetry_ops block, skipping')
            return self._current_value('symmetry_ops')

        # Now on to print...
        blocks = []
        for op_rot, op_transl in zip(rotations, translations):
            rows = [' '.join(str(x) for x in row) for row in op_rot]
            rows.append(' '.join(str(x) for x in op_transl))
            blocks.append('\n'.join(rows) + '\n\n')
        return ''.join(blocks)

    # Transition State Search blocks; None resets the block
    def _parse_positions_abs_intermediate(self, value):
        return None if value is None else _parse_tss_block(value)

    def _parse_positions_abs_product(self, value):
        return None if value is None else _parse_tss_block(value)

    def _parse_positions_frac_intermediate(self, value):
        return None if value is None else _parse_tss_block(value, True)

    def _parse_positions_frac_product(self, value):
        return None if value is None else _parse_tss_block(value, True)

515 

516 

class ConversionError(Exception):

    """Error raised when an option value cannot be converted to its type.

    The message points out that the keyword may simply not be supported yet.
    """

    def __init__(self, key_type, attr, value):
        # Forward the arguments so that ``args`` is populated; this lets the
        # exception be pickled and unpickled (e.g. across processes)
        Exception.__init__(self, key_type, attr, value)
        self.key_type = key_type
        self.value = value
        self.attr = attr

    def __str__(self):
        contact_email = 'simon.rittmeyer@tum.de'
        return (f'Could not convert {self.attr} = {self.value} '
                f'to {self.key_type}\n'
                'This means you either tried to set a value of the wrong\n'
                'type or this keyword needs some special care. Please feel free\n'
                'to add it to the corresponding __setattr__ method and send\n'
                f'the patch to {contact_email}, so we can all benefit.')