Coverage for /builds/ase/ase/ase/io/castep/castep_input_file.py: 78.55%
289 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import difflib
4import re
5import warnings
6from typing import List, Set
8import numpy as np
10from ase import Atoms
12# A convenient table to avoid the previously used "eval"
13_tf_table = {
14 '': True, # Just the keyword is equivalent to True
15 'True': True,
16 'False': False}
def _parse_tss_block(value, scaled=False):
    """Build the text of a Transition State Search structure block.

    Parameters
    ----------
    value : Atoms or list of str
        Either an Atoms object, or the block lines. A single multi-line
        string is split into its lines.
    scaled : bool
        True for the ``positions_frac_*`` blocks (fractional coordinates),
        False for ``positions_abs_*`` (Angstrom coordinates).

    Returns
    -------
    str
        The block text. For Atoms, one ``symbol x y z`` line per atom with
        3 decimals, preceded by an ``ang`` line when not scaled. For lines,
        the stripped lines joined with newlines ('' for no lines).

    Raises
    ------
    TypeError
        If ``value`` is neither Atoms nor a sequence of strings.
    RuntimeError
        If absolute positions declare a unit other than ``ang``.
    """
    if isinstance(value, Atoms):
        positions = (value.get_scaled_positions() if scaled else
                     value.get_positions())
        header = '' if scaled else 'ang\n'
        body = ''.join(' {} {:.3f} {:.3f} {:.3f}\n'.format(s, *p)
                       for s, p in zip(value.get_chemical_symbols(),
                                       positions))
        return header + body

    # A plain string would otherwise be iterated character by character and
    # pass the "list of strings" check below; treat it as its lines instead.
    if isinstance(value, str):
        value = value.splitlines()

    # Materialise once so generators are not consumed by the type check.
    try:
        lines = list(value)
    except TypeError:
        lines = None
    if lines is None or not all(isinstance(x, str) for x in lines):
        # Invalid!
        raise TypeError('castep.cell.positions_abs/frac_intermediate/'
                        'product expects Atoms object or list of strings')

    if not lines:
        return ''

    # First line must be Angstroms, or nothing
    has_units = len(lines[0].strip().split()) == 1
    if (not scaled) and has_units and lines[0].strip() != 'ang':
        raise RuntimeError('Only ang units currently supported in castep.'
                           'cell.positions_abs_intermediate/product')
    return '\n'.join(map(str.strip, lines))
class CastepOption:
    """A CASTEP option. It handles basic conversions from string to its value
    type.

    The option type (``option_type``, stored as ``self.type``) is looked up
    case-insensitively in ``default_convert_types`` to pick the
    ``_parse_<ctype>`` static method used whenever ``value`` is assigned;
    types not in the table are parsed as ``str``.
    """

    default_convert_types = {
        'boolean (logical)': 'bool',
        'defined': 'bool',
        'string': 'str',
        'integer': 'int',
        'real': 'float',
        'integer vector': 'int_vector',
        'real vector': 'float_vector',
        'physical': 'float_physical',
        'block': 'block'
    }

    def __init__(self, keyword, level, option_type, value=None,
                 docstring='No information available'):
        self.keyword = keyword
        self.level = level
        self.type = option_type
        # Stored as given: the constructor does not run the type parser.
        self._value = value
        self.__doc__ = docstring

    @property
    def value(self):
        """The value rendered as a string for writing, or None if unset.

        Vectors and physical quantities are space-joined; booleans are
        upper-cased (e.g. ``TRUE``).
        """
        if self._value is None:
            return None
        if self.type.lower() in ('integer vector', 'real vector',
                                 'physical'):
            return ' '.join(map(str, self._value))
        elif self.type.lower() in ('boolean (logical)', 'defined'):
            return str(self._value).upper()
        else:
            return str(self._value)

    @property
    def raw_value(self):
        # The value, not converted to a string
        return self._value

    @value.setter  # type: ignore[attr-defined, no-redef]
    def value(self, val):
        """Parse ``val`` according to the option type; None clears it.

        Raises ConversionError when the type parser raises ValueError.
        """
        if val is None:
            self.clear()
            return

        ctype = self.default_convert_types.get(self.type.lower(), 'str')
        parser = getattr(self, f'_parse_{ctype}')
        try:
            self._value = parser(val)
        except ValueError as err:
            raise ConversionError(ctype, self.keyword, val) from err

    def clear(self):
        """Reset the value of the option to None again"""
        self._value = None

    @staticmethod
    def _parse_bool(value):
        # Accepts True/False (any case, via title()) or '' for a bare keyword
        try:
            return _tf_table[str(value).strip().title()]
        except (KeyError, ValueError):
            raise ValueError(f'{value!r} is not a boolean') from None

    @staticmethod
    def _parse_str(value):
        return str(value)

    @staticmethod
    def _parse_int(value):
        return int(value)

    @staticmethod
    def _parse_float(value):
        return float(value)

    @staticmethod
    def _parse_int_vector(value):
        # Accepts either a string ("1 2 3" or "1, 2, 3") or an actual
        # list/numpy array of 3 integers of any integer dtype.
        if isinstance(value, str):
            value = list(map(int, value.replace(',', ' ').split()))

        value = np.array(value)

        if value.shape != (3,) or not np.issubdtype(value.dtype, np.integer):
            raise ValueError()

        return list(value)

    @staticmethod
    def _parse_float_vector(value):
        # Accepts either a string ("1 2 3" or "1, 2, 3") or an actual
        # list/numpy array of 3 numbers; the result has a floating dtype.
        if isinstance(value, str):
            value = list(map(float, value.replace(',', ' ').split()))

        try:
            # Multiplying by 1.0 promotes integer/bool input to float; it
            # raises TypeError for non-numeric arrays (e.g. strings).
            value = np.array(value) * 1.0
        except TypeError:
            raise ValueError() from None

        if value.shape != (3,) or not np.issubdtype(value.dtype, np.floating):
            raise ValueError()

        return list(value)

    @staticmethod
    def _parse_float_physical(value):
        # Returns (number, units); a string like "300 eV" keeps its units,
        # a bare number gets '' as units.
        if isinstance(value, str):
            value = value.split()

        try:
            n_items = len(value)
        except TypeError:
            # A scalar without units
            n_items = 1
            value = [value]

        if n_items not in (1, 2):
            raise ValueError()
        try:
            number = float(value[0])
            units = value[1] if n_items == 2 else ''
        except (TypeError, ValueError, IndexError):
            raise ValueError() from None
        return (number, units)

    @staticmethod
    def _parse_block(value):
        if isinstance(value, str):
            return value
        elif hasattr(value, '__getitem__'):
            try:
                return '\n'.join(value)  # Arrays of lines
            except TypeError:
                # Non-string items: report as a conversion failure
                raise ValueError() from None
        else:
            raise ValueError()

    def __repr__(self):
        # Compare against None so that falsy values (False, 0) still show
        if self._value is not None:
            expr = ('Option: {keyword}({type}, {level}):\n{_value}\n'
                    ).format(**self.__dict__)
        else:
            expr = ('Option: {keyword}[unset]({type}, {level})'
                    ).format(**self.__dict__)
        return expr

    def __eq__(self, other):
        if not isinstance(other, CastepOption):
            return False
        else:
            return self.__dict__ == other.__dict__
class CastepOptionDict:
    """A dictionary-like object to hold a set of options for .cell or .param
    files loaded from a dictionary, for the sake of validation.

    Replaces the old CastepCellDict and CastepParamDict that were defined in
    the castep_keywords.py file.

    ``options`` maps names to dicts of CastepOption constructor keyword
    arguments. Each resulting option is stored both in ``self._options`` and
    as an attribute, keyed by the option's own ``keyword``. ``None`` (the
    default) gives an empty set of options.
    """

    def __init__(self, options=None):
        # ComparableDict is not needed any more as CastepOptions can be
        # compared directly now
        self._options = {}
        if options is None:
            options = {}
        for kw in options:
            opt = CastepOption(**options[kw])
            self._options[opt.keyword] = opt
            self.__dict__[opt.keyword] = opt
class CastepInputFile:

    """Master class for CastepParam and CastepCell to inherit from.

    Options are kept in ``self._options`` (keyword -> option object) and are
    also copied into the instance ``__dict__``, so ``obj.keyword`` returns the
    option object itself. Assigning to a public attribute goes through
    ``__setattr__``, which validates the keyword (case-insensitively),
    resolves keyword conflicts and applies any ``_parse_<keyword>`` method
    defined by a subclass.
    """

    _keyword_conflicts: List[Set[str]] = []

    def __init__(self, options_dict=None, keyword_tolerance=1):
        """
        Parameters
        ----------
        options_dict : CastepOptionDict, optional
            Source of the known options; its ``_options`` dict is shared,
            not copied. Defaults to an empty CastepOptionDict.
        keyword_tolerance : int
            How strict the checks on new attributes are (clipped to 0..2):
            0 = no new attributes allowed
            1 = new attributes allowed, warning given
            2 = new attributes allowed, silent
        """
        if options_dict is None:
            options_dict = CastepOptionDict({})

        self._options = options_dict._options
        self.__dict__.update(self._options)
        self._perm = np.clip(keyword_tolerance, 0, 2)

        # Compile a dictionary for quick check of conflict sets:
        # keyword -> the other keywords in its conflict set(s)
        self._conflict_dict = {
            kw: set(cset).difference({kw})
            for cset in self._keyword_conflicts for kw in cset}

    def __repr__(self):
        expr = ''
        is_default = True
        for key, option in sorted(self._options.items()):
            if option.value is not None:
                is_default = False
                expr += ('%20s : %s\n' % (key, option.value))
        if is_default:
            expr = 'Default\n'

        expr += f'Keyword tolerance: {self._perm}'
        return expr

    def __setattr__(self, attr, value):

        # Hidden attributes are treated normally
        if attr.startswith('_'):
            self.__dict__[attr] = value
            return

        # Option keywords are stored in lower case. Normalise BEFORE the
        # membership test, otherwise e.g. ``CUT_OFF_ENERGY`` is treated as
        # unknown and would replace the known ``cut_off_energy`` option.
        key = attr.lower()

        if key not in self._options:
            if self._perm == 0:
                similars = difflib.get_close_matches(key,
                                                     self._options.keys())
                if similars:
                    raise RuntimeError(
                        f'Option "{attr}" not known! You mean "{similars[0]}"?')
                raise RuntimeError(f'Option "{attr}" is not known!')

            # Do we consider it a string or a block?
            is_str = isinstance(value, str)
            is_block = ((hasattr(value, '__getitem__') and not is_str)
                        or (is_str and len(value.split('\n')) > 1))
            if self._perm == 1:
                warnings.warn(('Option "%s" is not known and will '
                               'be added as a %s') % (attr,
                                                      ('block' if is_block else
                                                       'string')))
            opt = CastepOption(keyword=key, level='Unknown',
                               option_type='block' if is_block else 'string')
            self._options[key] = opt
            self.__dict__[key] = opt
        else:
            opt = self._options[key]

        # ':' is accepted as a separator in non-block values
        if opt.type.lower() != 'block' and isinstance(value, str):
            value = value.replace(':', ' ')

        # Check for any conflicts if the value is not None
        if value is not None:
            for conflict in self._conflict_dict.get(key, ()):
                if (conflict in self._options
                        and self._options[conflict].value):
                    warnings.warn(
                        'option "{attr}" conflicts with "{conflict}" in '
                        'calculator. Setting "{conflict}" to '
                        'None.'.format(attr=key, conflict=conflict))
                    self._options[conflict].value = None

        # Use a keyword-specific parser if the subclass defines one
        attrparse = f'_parse_{key}'
        if hasattr(self, attrparse):
            self._options[key].value = getattr(self, attrparse)(value)
        else:
            self._options[key].value = value

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. for unknown options
        if name.startswith('_') or self._perm == 0:
            raise AttributeError(name)

        if self._perm == 1:
            warnings.warn(f'Option {(name)} is not known, returning None')

        return CastepOption(keyword='none', level='Unknown',
                            option_type='string', value=None)

    def get_attr_dict(self, raw=False, types=False):
        """Return the options that are set, as a plain dict.

        Parameters
        ----------
        raw : bool
            Use each option's unconverted ``raw_value`` instead of its
            string ``value``.
        types : bool
            Map each keyword to a ``(value, option type)`` tuple instead.

        Options whose string ``value`` is None are left out.
        """
        attrdict = {}
        for key, opt in self._options.items():
            if opt.value is None:
                continue
            val = opt.raw_value if raw else opt.value
            attrdict[key] = (val, opt.type) if types else val
        return attrdict
class CastepParam(CastepInputFile):
    """CastepParam abstracts the settings that go into the .param file"""

    _keyword_conflicts = [{'cut_off_energy', 'basis_precision'}, ]

    def __init__(self, castep_keywords, keyword_tolerance=1):
        # castep_keywords must provide ``castep_version`` and a
        # ``CastepParamDict()`` factory for the known options.
        self._castep_version = castep_keywords.castep_version
        CastepInputFile.__init__(self, castep_keywords.CastepParamDict(),
                                 keyword_tolerance)

    @property
    def castep_version(self):
        return self._castep_version

    # NOTE: deliberately not named ``_parse_*`` -- __setattr__ dispatches on
    # that prefix, so such a name would be mistaken for a keyword parser.
    def _exclusive_restart_value(self, value, other):
        """Shared logic for the mutually exclusive reuse/continuation options.

        Returns None when ``value`` is None, or (with a warning) when the
        ``other`` keyword already has a value; otherwise ``'default'`` for
        ``True`` and ``str(value)`` for anything else.
        """
        if value is None:
            return None  # Reset the value
        other_opt = self._options.get(other)
        if other_opt is not None and other_opt.value:
            warnings.warn('Cannot set reuse if continuation is set, and '
                          'vice versa. Set the other to None, if you want '
                          'this setting.')
            return None
        return 'default' if (value is True) else str(value)

    # .param specific parsers
    def _parse_reuse(self, value):
        return self._exclusive_restart_value(value, 'continuation')

    def _parse_continuation(self, value):
        return self._exclusive_restart_value(value, 'reuse')
class CastepCell(CastepInputFile):

    """CastepCell abstracts all setting that go into the .cell file"""

    _keyword_conflicts = [
        {'kpoint_mp_grid', 'kpoint_mp_spacing', 'kpoint_list',
         'kpoints_mp_grid', 'kpoints_mp_spacing', 'kpoints_list'},
        {'bs_kpoint_mp_grid',
         'bs_kpoint_mp_spacing',
         'bs_kpoint_list',
         'bs_kpoint_path',
         'bs_kpoints_mp_grid',
         'bs_kpoints_mp_spacing',
         'bs_kpoints_list',
         'bs_kpoints_path'},
        {'spectral_kpoint_mp_grid',
         'spectral_kpoint_mp_spacing',
         'spectral_kpoint_list',
         'spectral_kpoint_path',
         'spectral_kpoints_mp_grid',
         'spectral_kpoints_mp_spacing',
         'spectral_kpoints_list',
         'spectral_kpoints_path'},
        {'phonon_kpoint_mp_grid',
         'phonon_kpoint_mp_spacing',
         'phonon_kpoint_list',
         'phonon_kpoint_path',
         'phonon_kpoints_mp_grid',
         'phonon_kpoints_mp_spacing',
         'phonon_kpoints_list',
         'phonon_kpoints_path'},
        {'fine_phonon_kpoint_mp_grid',
         'fine_phonon_kpoint_mp_spacing',
         'fine_phonon_kpoint_list',
         'fine_phonon_kpoint_path'},
        {'magres_kpoint_mp_grid',
         'magres_kpoint_mp_spacing',
         'magres_kpoint_list',
         'magres_kpoint_path'},
        {'elnes_kpoint_mp_grid',
         'elnes_kpoint_mp_spacing',
         'elnes_kpoint_list',
         'elnes_kpoint_path'},
        {'optics_kpoint_mp_grid',
         'optics_kpoint_mp_spacing',
         'optics_kpoint_list',
         'optics_kpoint_path'},
        {'supercell_kpoint_mp_grid',
         'supercell_kpoint_mp_spacing',
         'supercell_kpoint_list',
         'supercell_kpoint_path'}, ]

    def __init__(self, castep_keywords, keyword_tolerance=1):
        # castep_keywords must provide ``castep_version`` and a
        # ``CastepCellDict()`` factory for the known options.
        self._castep_version = castep_keywords.castep_version
        CastepInputFile.__init__(self, castep_keywords.CastepCellDict(),
                                 keyword_tolerance)

    @property
    def castep_version(self):
        return self._castep_version

    # .cell specific parsers
    def _parse_species_pot(self, value):
        """Merge pseudopotential assignments into the species_pot block.

        ``value`` is a ``(species, file)`` tuple or an iterable of such
        pairs. Any existing line for the same species is replaced; a pair
        with an empty file name only removes that species. ``None`` clears
        the block. Anything else triggers a warning and leaves the current
        block unchanged.
        """
        if value is None:
            return None  # Reset the value

        # Single tuple
        if isinstance(value, tuple) and len(value) == 2:
            value = [value]

        # List of tuples
        try:
            if isinstance(value, str):
                raise TypeError('a bare string is not a (species, file) pair')
            entries = list(value)
            if not all(len(x) == 2 for x in entries):
                raise ValueError('entries must be (species, file) pairs')
            pspots = [tuple(map(str.strip, x)) for x in entries]
        except (TypeError, ValueError):
            warnings.warn(
                'Please specify pseudopotentials in python as '
                'a tuple or a list of tuples formatted like: '
                '(species, file), e.g. ("O", "path-to/O_OTFG.usp") '
                'Anything else will be ignored')
            # "Ignored": keep whatever block is already set
            return self._options['species_pot'].value

        text_block = self._options['species_pot'].value or ''
        for species, path in pspots:
            # Remove any duplicates: drop an existing line for this species.
            # Escaped and anchored to line start so that the species name is
            # matched literally and only as the first token of a line.
            text_block = re.sub(fr'\n?^\s*{re.escape(species)}\s+.*', '',
                                text_block, flags=re.MULTILINE)
            if path:
                text_block += f'\n{species} {path}'

        return text_block

    def _parse_symmetry_ops(self, value):
        """Format ``(rotations, translations)`` as a symmetry_ops block.

        ``rotations`` must have shape (n, 3, 3) and ``translations`` shape
        (n, 3); array-likes are converted with ``np.asarray``. Each operation
        is written as three rotation rows, one translation row and a blank
        line. ``None`` clears the block; invalid input triggers a warning and
        leaves the current block unchanged.
        """
        if value is None:
            return None  # Reset the value

        rotations = translations = None
        if isinstance(value, tuple) and len(value) == 2:
            try:
                rotations = np.asarray(value[0])
                translations = np.asarray(value[1])
            except (TypeError, ValueError):
                rotations = translations = None

        if (rotations is None
                or rotations.shape[1:] != (3, 3)
                or translations.shape[1:] != (3,)
                or rotations.shape[0] != translations.shape[0]):
            warnings.warn('Invalid symmetry_ops block, skipping')
            return self._options['symmetry_ops'].value

        # Now on to print...
        chunks = []
        for rot, trans in zip(rotations, translations):
            rows = [' '.join(map(str, row)) for row in rot]
            rows.append(' '.join(map(str, trans)))
            chunks.append('\n'.join(rows) + '\n\n')

        return ''.join(chunks)

    def _parse_positions_abs_intermediate(self, value):
        return _parse_tss_block(value)

    def _parse_positions_abs_product(self, value):
        return _parse_tss_block(value)

    def _parse_positions_frac_intermediate(self, value):
        return _parse_tss_block(value, True)

    def _parse_positions_frac_product(self, value):
        return _parse_tss_block(value, True)
class ConversionError(Exception):

    """Print customized error for options that are not converted correctly
    and point out that they are maybe not implemented, yet.

    Parameters
    ----------
    key_type : str
        The conversion type that failed (e.g. ``'int_vector'``).
    attr : str
        The option keyword being set.
    value
        The value that could not be converted.
    """

    def __init__(self, key_type, attr, value):
        # Forward the arguments so ``args`` is populated and the exception
        # can be pickled and re-created.
        super().__init__(key_type, attr, value)
        self.key_type = key_type
        self.value = value
        self.attr = attr

    def __str__(self):
        contact_email = 'simon.rittmeyer@tum.de'
        return (f'Could not convert {self.attr} = {self.value} '
                f'to {self.key_type}\n'
                'This means you either tried to set a value of the wrong\n'
                'type or this keyword needs some special care. Please feel\n'
                'free to add it to the corresponding __setattr__ method and\n'
                f'send the patch to {contact_email}, so we can all benefit.')