Coverage for ase / io / castep / castep_input_file.py: 78.47%
288 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1# fmt: off
3import difflib
4import re
5import warnings
7import numpy as np
9from ase import Atoms
11# A convenient table to avoid the previously used "eval"
12_tf_table = {
13 '': True, # Just the keyword is equivalent to True
14 'True': True,
15 'False': False}
def _parse_tss_block(value, scaled=False):
    """Build the text of a Transition State Search structure block.

    Used for the ``positions_{abs,frac}_{intermediate,product}`` cell blocks.

    Parameters
    ----------
    value : Atoms, str or iterable of str
        An ``Atoms`` object, a multi-line string, or an iterable of lines
        that are already formatted as block content.
    scaled : bool
        If True, fractional positions are written and no unit line is
        emitted; otherwise Cartesian positions are written after an ``ang``
        unit line.

    Returns
    -------
    str
        The block text.

    Raises
    ------
    TypeError
        If ``value`` is neither an ``Atoms`` object nor text lines.
    RuntimeError
        If, for Cartesian positions, the first line is a single token
        (i.e. a unit line) other than ``ang``.
    """
    if isinstance(value, Atoms):
        positions = (value.get_scaled_positions() if scaled else
                     value.get_positions())
        symbols = value.get_chemical_symbols()
        rows = [' {} {:.3f} {:.3f} {:.3f}\n'.format(s, *p)
                for s, p in zip(symbols, positions)]
        return ('' if scaled else 'ang\n') + ''.join(rows)

    # A single string is itself an iterable of one-character strings; split
    # it into lines first, otherwise it would be joined back one character
    # per line.
    if isinstance(value, str):
        value = value.strip().splitlines()

    # Materialise the input so that one-shot iterables (generators) are not
    # exhausted by the type check below.
    try:
        lines = list(value)
    except TypeError:
        lines = None
    if lines is None or not all(isinstance(x, str) for x in lines):
        # Invalid!
        raise TypeError('castep.cell.positions_abs/frac_intermediate/'
                        'product expects Atoms object or list of strings')

    lines = [line.strip() for line in lines]
    # First line must be Angstroms, or nothing. A first line with a single
    # token is taken as a unit line; an empty block has no first line.
    has_units = bool(lines) and len(lines[0].split()) == 1
    if (not scaled) and has_units and lines[0] != 'ang':
        raise RuntimeError('Only ang units currently supported in castep.'
                           'cell.positions_abs_intermediate/product')
    return '\n'.join(lines)
class CastepOption:
    """A CASTEP option. It handles basic conversions from string to its value
    type.

    Parameters
    ----------
    keyword : str
        Name of the option.
    level : str
        Level of the option; stored and shown in ``repr``.
    option_type : str
        Type description (e.g. ``'Integer'``, ``'Real Vector'``,
        ``'Physical'``, ``'Block'``). It is matched case-insensitively
        against ``default_convert_types`` to pick a parser; unknown types are
        parsed as strings.
    value : optional
        Initial value, stored as-is without conversion.
    docstring : str
        Help text, stored as the instance ``__doc__``.
    """

    # Lower-cased type description -> suffix of the ``_parse_*`` method used
    # to convert values assigned through the ``value`` setter.
    default_convert_types = {
        'boolean (logical)': 'bool',
        'defined': 'bool',
        'string': 'str',
        'integer': 'int',
        'real': 'float',
        'integer vector': 'int_vector',
        'real vector': 'float_vector',
        'physical': 'float_physical',
        'block': 'block'
    }

    def __init__(self, keyword, level, option_type, value=None,
                 docstring='No information available'):
        self.keyword = keyword
        self.level = level
        self.type = option_type
        self._value = value
        self.__doc__ = docstring

    @property
    def value(self):
        """The value rendered as a string, or None if unset.

        Vectors and physical values are space-joined, booleans upper-cased.
        """
        if self._value is not None:
            if self.type.lower() in ('integer vector', 'real vector',
                                     'physical'):
                return ' '.join(map(str, self._value))
            elif self.type.lower() in ('boolean (logical)', 'defined'):
                return str(self._value).upper()
            else:
                return str(self._value)

    @property
    def raw_value(self):
        """The stored value, not converted to a string."""
        return self._value

    @value.setter  # type: ignore[attr-defined, no-redef]
    def value(self, val):
        if val is None:
            self.clear()
            return

        ctype = self.default_convert_types.get(self.type.lower(), 'str')
        typeparse = f'_parse_{ctype}'
        try:
            self._value = getattr(self, typeparse)(val)
        except (TypeError, ValueError) as err:
            # The parsers signal bad input with ValueError, but built-in
            # conversions such as int([1]) raise TypeError; report both as a
            # conversion failure.
            raise ConversionError(ctype, self.keyword, val) from err

    def clear(self):
        """Reset the value of the option to None again"""
        self._value = None

    @staticmethod
    def _parse_bool(value):
        try:
            value = _tf_table[str(value).strip().title()]
        except (KeyError, ValueError):
            raise ValueError()
        return value

    @staticmethod
    def _parse_str(value):
        return str(value)

    @staticmethod
    def _parse_int(value):
        return int(value)

    @staticmethod
    def _parse_float(value):
        return float(value)

    @staticmethod
    def _parse_int_vector(value):
        # Accepts either a string ('1 2 3' or '1, 2, 3') or an actual
        # list/numpy array of ints
        if isinstance(value, str):
            value = list(map(int, value.replace(',', ' ').split()))

        value = np.array(value)

        # Accept any integer dtype (e.g. int32 arrays), not only the
        # platform default ``int``.
        if value.shape != (3,) or not np.issubdtype(value.dtype, np.integer):
            raise ValueError()

        return value.tolist()

    @staticmethod
    def _parse_float_vector(value):
        # Accepts either a string or an actual list/numpy array of floats
        if isinstance(value, str):
            value = list(map(float, value.replace(',', ' ').split()))

        # Multiplying by 1.0 promotes integer input to floating point.
        value = np.array(value) * 1.0

        # Accept any floating dtype (e.g. float32), not only float64.
        if value.shape != (3,) or not np.issubdtype(value.dtype, np.floating):
            raise ValueError()

        return value.astype(float).tolist()

    @staticmethod
    def _parse_float_physical(value):
        # Returns a (number, unit) tuple. A string such as '500 eV' is split
        # into number and unit; a bare number gets an empty unit.
        if isinstance(value, str):
            value = value.split()

        try:
            n = len(value)
        except TypeError:
            n = 1
            value = [value]

        if n == 1:
            try:
                value = (float(value[0]), '')
            except (TypeError, ValueError):
                raise ValueError()
        elif n == 2:
            try:
                value = (float(value[0]), value[1])
            except (TypeError, ValueError, IndexError):
                raise ValueError()
        else:
            raise ValueError()

        return value

    @staticmethod
    def _parse_block(value):
        # A block is kept as a single string; sequences are taken as lines.
        if isinstance(value, str):
            return value
        elif hasattr(value, '__getitem__'):
            return '\n'.join(value)  # Arrays of lines
        else:
            raise ValueError()

    def __repr__(self):
        # Compare against None so that falsy values such as 0 or False are
        # not reported as unset.
        if self._value is not None:
            return (f'Option: {self.keyword}({self.type}, {self.level}):\n'
                    f'{self._value}\n')
        return f'Option: {self.keyword}[unset]({self.type}, {self.level})'

    def __eq__(self, other):
        if not isinstance(other, CastepOption):
            return False
        else:
            return self.__dict__ == other.__dict__
class CastepOptionDict:
    """A dictionary-like object to hold a set of options for .cell or .param
    files loaded from a dictionary, for the sake of validation.

    Replaces the old CastepCellDict and CastepParamDict that were defined in
    the castep_keywords.py file.

    Parameters
    ----------
    options : mapping, optional
        Maps a name to a dict of ``CastepOption`` constructor arguments.
        Each option is stored in ``_options`` under its ``keyword`` and is
        also exposed as an attribute. Defaults to no options.
    """

    def __init__(self, options=None):
        # The signature advertises ``None`` as the default; previously that
        # default crashed when iterated.
        if options is None:
            options = {}
        # ComparableDict is not needed any more as CastepOptions can be
        # compared directly now
        self._options = {}
        for kw in options:
            opt = CastepOption(**options[kw])
            self._options[opt.keyword] = opt
            self.__dict__[opt.keyword] = opt
class CastepInputFile:

    """Master class for CastepParam and CastepCell to inherit from"""

    # Sets of keywords that must not be set at the same time; setting one
    # clears the others (see ``__setattr__``).
    _keyword_conflicts: list[set[str]] = []

    def __init__(self, options_dict=None, keyword_tolerance=1):
        if options_dict is None:
            options_dict = CastepOptionDict({})

        self._options = options_dict._options
        self.__dict__.update(self._options)
        # keyword_tolerance means how strict the checks on new attributes are
        # 0 = no new attributes allowed
        # 1 = new attributes allowed, warning given
        # 2 = new attributes allowed, silent
        self._perm = np.clip(keyword_tolerance, 0, 2)

        # Compile a dictionary for quick check of conflict sets
        self._conflict_dict = {
            kw: set(cset).difference({kw})
            for cset in self._keyword_conflicts for kw in cset}

    def __repr__(self):
        expr = ''
        is_default = True
        for key, option in sorted(self._options.items()):
            if option.value is not None:
                is_default = False
                expr += ('%20s : %s\n' % (key, option.value))
        if is_default:
            expr = 'Default\n'

        expr += f'Keyword tolerance: {self._perm}'
        return expr

    def __setattr__(self, attr, value):
        """Set an option, validating and converting the value.

        Unknown options raise ``RuntimeError`` at tolerance 0; otherwise they
        are added as a string or block option (with a warning at tolerance
        1). Setting a value clears conflicting options, and a ``_parse_<key>``
        method, if defined, pre-processes the value.
        """
        # Hidden attributes are treated normally
        if attr.startswith('_'):
            self.__dict__[attr] = value
            return

        # New options are stored under lower-case names (and the known
        # keywords are assumed to be lower-case too). Normalise before the
        # lookup: otherwise an upper-case spelling of a known keyword was
        # treated as unknown and replaced the typed option with an untyped
        # string/block one.
        key = attr.lower()

        if key not in self._options:
            if self._perm == 0:
                similars = difflib.get_close_matches(key,
                                                     self._options.keys())
                if similars:
                    raise RuntimeError(
                        f'Option "{attr}" not known! You mean "{similars[0]}"?')
                raise RuntimeError(f'Option "{attr}" is not known!')

            # Do we consider it a string or a block?
            is_str = isinstance(value, str)
            is_block = ((hasattr(value, '__getitem__') and not is_str)
                        or (is_str and len(value.split('\n')) > 1))
            if self._perm == 1:
                warnings.warn(('Option "%s" is not known and will '
                               'be added as a %s') % (attr,
                                                      ('block' if is_block else
                                                       'string')))
            opt = CastepOption(keyword=key, level='Unknown',
                               option_type='block' if is_block else 'string')
            self._options[key] = opt
            self.__dict__[key] = opt
        else:
            opt = self._options[key]

        if not opt.type.lower() == 'block' and isinstance(value, str):
            value = value.replace(':', ' ')

        # If it is, use the appropriate parser, unless a custom one is defined
        attrparse = f'_parse_{key}'

        # Check for any conflicts if the value is not None
        if value is not None:
            for c in self._conflict_dict.get(key, set()):
                if c in self._options and self._options[c].value:
                    warnings.warn(
                        'option "{attr}" conflicts with "{conflict}" in '
                        'calculator. Setting "{conflict}" to '
                        'None.'.format(attr=key, conflict=c))
                    self._options[c].value = None

        # hasattr() is False for missing private names because __getattr__
        # raises AttributeError for them.
        if hasattr(self, attrparse):
            self._options[key].value = getattr(self, attrparse)(value)
        else:
            self._options[key].value = value

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Private names (and every
        # name at tolerance 0) raise AttributeError, which hasattr() relies
        # on. startswith() also copes with an empty name.
        if name.startswith('_') or self._perm == 0:
            raise AttributeError(name)

        if self._perm == 1:
            warnings.warn(f'Option {name} is not known, returning None')

        return CastepOption(keyword='none', level='Unknown',
                            option_type='string', value=None)

    def get_attr_dict(self, raw=False, types=False):
        """Return the options that are currently set, as a plain dict.

        Parameters
        ----------
        raw : bool
            Use the stored raw values instead of their string forms.
        types : bool
            If True, each entry is a ``(value, type)`` tuple.
        """
        attrdict = {k: o.raw_value if raw else o.value
                    for k, o in self._options.items() if o.value is not None}

        if types:
            attrdict = {k: (v, self._options[k].type)
                        for k, v in attrdict.items()}

        return attrdict
class CastepParam(CastepInputFile):
    """CastepParam abstracts the settings that go into the .param file"""

    _keyword_conflicts = [{'cut_off_energy', 'basis_precision'}, ]

    def __init__(self, castep_keywords, keyword_tolerance=1):
        self._castep_version = castep_keywords.castep_version
        CastepInputFile.__init__(self, castep_keywords.CastepParamDict(),
                                 keyword_tolerance)

    @property
    def castep_version(self):
        """The ``castep_version`` of the keywords object used at creation."""
        return self._castep_version

    # .param specific parsers
    def _exclusive_restart_value(self, value, other):
        """Shared parser for the mutually exclusive reuse/continuation.

        Returns None (resetting the option) for None, or, with a warning,
        when the ``other`` option is already set. Otherwise returns
        ``'default'`` for ``True`` and ``str(value)`` for anything else.

        The name deliberately avoids the ``_parse_`` prefix, which
        ``__setattr__`` uses to find per-keyword parsers.
        """
        if value is None:
            return None  # Reset the value
        try:
            other_value = self._options[other].value
        except KeyError:  # ``other`` is not among the known options
            other_value = None
        if other_value:
            warnings.warn('Cannot set reuse if continuation is set, and '
                          'vice versa. Set the other to None, if you want '
                          'this setting.')
            return None
        return 'default' if (value is True) else str(value)

    def _parse_reuse(self, value):
        return self._exclusive_restart_value(value, 'continuation')

    def _parse_continuation(self, value):
        return self._exclusive_restart_value(value, 'reuse')
class CastepCell(CastepInputFile):

    """CastepCell abstracts all setting that go into the .cell file"""

    _keyword_conflicts = [
        {'kpoint_mp_grid', 'kpoint_mp_spacing', 'kpoint_list',
         'kpoints_mp_grid', 'kpoints_mp_spacing', 'kpoints_list'},
        {'bs_kpoint_mp_grid',
         'bs_kpoint_mp_spacing',
         'bs_kpoint_list',
         'bs_kpoint_path',
         'bs_kpoints_mp_grid',
         'bs_kpoints_mp_spacing',
         'bs_kpoints_list',
         'bs_kpoints_path'},
        {'spectral_kpoint_mp_grid',
         'spectral_kpoint_mp_spacing',
         'spectral_kpoint_list',
         'spectral_kpoint_path',
         'spectral_kpoints_mp_grid',
         'spectral_kpoints_mp_spacing',
         'spectral_kpoints_list',
         'spectral_kpoints_path'},
        {'phonon_kpoint_mp_grid',
         'phonon_kpoint_mp_spacing',
         'phonon_kpoint_list',
         'phonon_kpoint_path',
         'phonon_kpoints_mp_grid',
         'phonon_kpoints_mp_spacing',
         'phonon_kpoints_list',
         'phonon_kpoints_path'},
        {'fine_phonon_kpoint_mp_grid',
         'fine_phonon_kpoint_mp_spacing',
         'fine_phonon_kpoint_list',
         'fine_phonon_kpoint_path'},
        {'magres_kpoint_mp_grid',
         'magres_kpoint_mp_spacing',
         'magres_kpoint_list',
         'magres_kpoint_path'},
        {'elnes_kpoint_mp_grid',
         'elnes_kpoint_mp_spacing',
         'elnes_kpoint_list',
         'elnes_kpoint_path'},
        {'optics_kpoint_mp_grid',
         'optics_kpoint_mp_spacing',
         'optics_kpoint_list',
         'optics_kpoint_path'},
        {'supercell_kpoint_mp_grid',
         'supercell_kpoint_mp_spacing',
         'supercell_kpoint_list',
         'supercell_kpoint_path'}, ]

    def __init__(self, castep_keywords, keyword_tolerance=1):
        self._castep_version = castep_keywords.castep_version
        CastepInputFile.__init__(self, castep_keywords.CastepCellDict(),
                                 keyword_tolerance)

    @property
    def castep_version(self):
        """The ``castep_version`` of the keywords object used at creation."""
        return self._castep_version

    # .cell specific parsers
    def _parse_species_pot(self, value):
        """Merge ``(species, file)`` pairs into the species_pot block text.

        ``value`` is one ``(species, file)`` tuple or a sequence of them.
        Any existing line for the same species is dropped first, so a pair
        with an empty file just removes that species' entry. Malformed input
        is warned about and yields None. ``None`` itself resets the option.
        """
        if value is None:
            # Reset the value; previously this fell through to an undefined
            # ``pspots`` and raised UnboundLocalError.
            return None
        # Single tuple
        if isinstance(value, tuple) and len(value) == 2:
            value = [value]
        # List of tuples
        try:
            well_formed = (hasattr(value, '__getitem__')
                           and all(len(x) == 2 for x in value))
        except TypeError:  # elements without a length
            well_formed = False
        if not well_formed:
            warnings.warn(
                'Please specify pseudopotentials in python as '
                'a tuple or a list of tuples formatted like: '
                '(species, file), e.g. ("O", "path-to/O_OTFG.usp") '
                'Anything else will be ignored')
            return None

        pspots = [tuple(map(str.strip, x)) for x in value]

        text_block = self._options['species_pot'].value or ''
        for species, path in pspots:
            # Remove any duplicates: drop a line that starts with exactly
            # this species label. The label is escaped and the match is
            # anchored at a line start, so it cannot hit regex metacharacters
            # or text in the middle of another species' line.
            text_block = re.sub(
                rf'\n?^[ \t]*{re.escape(species)}[ \t]+.*',
                '', text_block, flags=re.MULTILINE)
            if path:
                text_block += f'\n{species} {path}'

        return text_block

    def _parse_symmetry_ops(self, value):
        """Format ``(rotations, translations)`` as symmetry_ops block text.

        ``rotations`` must have shape (N, 3, 3) and ``translations`` shape
        (N, 3); nested lists are accepted as well as arrays. Each operation
        is written as three rotation rows, one translation row and a blank
        line. Invalid input is warned about and yields None; ``None`` itself
        resets the option without a warning.
        """
        if value is None:
            return None
        if not (isinstance(value, tuple) and len(value) == 2):
            warnings.warn('Invalid symmetry_ops block, skipping')
            return None
        rotations, translations = (np.asarray(v) for v in value)
        # The trailing-shape tests come first so that shape[0] is only read
        # once both arrays are known to have enough dimensions.
        if (rotations.shape[1:] != (3, 3)
                or translations.shape[1:] != (3,)
                or rotations.shape[0] != translations.shape[0]):
            warnings.warn('Invalid symmetry_ops block, skipping')
            return None
        # Now on to print...
        text_block = ''
        for op_rot, op_tranls in zip(rotations, translations):
            text_block += '\n'.join([' '.join([str(x) for x in row])
                                     for row in op_rot])
            text_block += '\n'
            text_block += ' '.join([str(x) for x in op_tranls])
            text_block += '\n\n'

        return text_block

    def _parse_positions_abs_intermediate(self, value):
        return _parse_tss_block(value)

    def _parse_positions_abs_product(self, value):
        return _parse_tss_block(value)

    def _parse_positions_frac_intermediate(self, value):
        return _parse_tss_block(value, True)

    def _parse_positions_frac_product(self, value):
        return _parse_tss_block(value, True)
class ConversionError(Exception):

    """Print customized error for options that are not converted correctly
    and point out that they are maybe not implemented, yet"""

    def __init__(self, key_type, attr, value):
        # Forward the arguments so ``args`` is populated; unpickling rebuilds
        # the exception as ``cls(*args)``, which failed when args was empty.
        super().__init__(key_type, attr, value)
        self.key_type = key_type
        self.value = value
        self.attr = attr

    def __str__(self):
        contact_email = 'simon.rittmeyer@tum.de'
        # Every segment that interpolates must be an f-string; the second
        # one previously was not and printed a literal "{self.key_type}".
        return (f'Could not convert {self.attr} = {self.value} '
                f'to {self.key_type}\n'
                'This means you either tried to set a value of the wrong\n'
                'type or this keyword needs some special care. Please feel '
                'free\n'
                'to add it to the corresponding __setattr__ method and send\n'
                f'the patch to {contact_email}, so we can all benefit.')