Coverage for /builds/ase/ase/ase/calculators/kim/kimpy_wrappers.py: 75.81%
339 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3"""
4Wrappers that provide a minimal interface to kimpy methods and objects
6Daniel S. Karls
7University of Minnesota
8"""
10import functools
11from abc import ABC
13import numpy as np
15from .exceptions import (
16 KIMModelInitializationError,
17 KIMModelNotFound,
18 KIMModelParameterError,
19 KimpyError,
20)
class LazyKimpyImport:
    """Stand-in for the optional kimpy module that imports it on demand.

    Attribute lookups that the proxy cannot satisfy itself are forwarded to
    the real ``kimpy`` module, so nothing is imported at module level.
    """

    @functools.cached_property
    def _kimpy(self):
        # Function-scope import: kimpy is an optional dependency.
        import kimpy as kimpy_module
        return kimpy_module

    def __getattr__(self, attr):
        # Python only calls __getattr__ after normal lookup has failed, so
        # reading ``_kimpy`` here does not recurse.
        module = self._kimpy
        return getattr(module, attr)
class Wrappers:
    """Error-checked shortcuts to kimpy callables and constants.

    Everything is exposed as a property so that ``kimpy`` attributes are
    only looked up when a shortcut is accessed, which avoids a module-level
    kimpy import.
    """

    @staticmethod
    def _checked(kimpy_callable):
        """Pre-bind ``kimpy_callable`` as the function run by check_call."""
        return functools.partial(check_call, kimpy_callable)

    @property
    def collections_create(self):
        return self._checked(kimpy.collections.create)

    @property
    def model_create(self):
        return self._checked(kimpy.model.create)

    @property
    def simulator_model_create(self):
        return self._checked(kimpy.simulator_model.create)

    @property
    def get_species_name(self):
        return self._checked(kimpy.species_name.get_species_name)

    @property
    def get_number_of_species_names(self):
        return self._checked(
            kimpy.species_name.get_number_of_species_names)

    @property
    def collection_item_type_portableModel(self):
        item_types = kimpy.collection_item_type
        return item_types.portableModel


# Module-level instances shared by the rest of this file
kimpy = LazyKimpyImport()
wrappers = Wrappers()
# Function used for casting parameter/extent indices to C-compatible ints
c_int = np.intc

# Function used for casting floating point parameter values to C-compatible
# doubles
c_double = np.double


def c_int_args(func):
    """
    Decorator for instance methods that will cast all of the positional
    args passed, excluding the first (which corresponds to 'self'), to
    C-compatible integers before calling the decorated method.

    Keyword arguments are forwarded unchanged.
    """

    @functools.wraps(func)
    def myfunc(*args, **kwargs):
        # Keep 'self' as-is and cast every remaining positional argument.
        args_cast = [args[0], *map(c_int, args[1:])]
        # Forward the *cast* arguments; passing ``args`` here would silently
        # discard the conversion above.
        return func(*args_cast, **kwargs)

    return myfunc
def check_call(f, *args, **kwargs):
    """Call a kimpy function using its arguments and, if a RuntimeError is
    raised, catch it and raise a KimpyError with the exception's
    message.

    (Starting with kimpy 2.0.0, a RuntimeError is the only exception
    type raised when something goes wrong.)

    Parameters
    ----------
    f : callable
        The function to call.
    *args, **kwargs
        Forwarded to ``f`` unchanged.

    Returns
    -------
    Whatever ``f(*args, **kwargs)`` returns.

    Raises
    ------
    KimpyError
        If ``f`` raises a RuntimeError; the RuntimeError is attached as the
        cause. Any other exception type propagates untouched.
    """
    try:
        return f(*args, **kwargs)
    except RuntimeError as e:
        # Some callables (e.g. functools.partial objects) have no __name__;
        # fall back to repr so building the message cannot itself fail and
        # hide the original error.
        name = getattr(f, "__name__", repr(f))
        raise KimpyError(
            f'Calling kimpy function "{name}" failed:\n {e!s}') from e
def check_call_wrapper(func):
    """Decorator that routes every call of ``func`` through check_call.

    All positional and keyword arguments are forwarded unchanged, and the
    returned wrapper carries the metadata of ``func`` via functools.wraps.
    """

    def routed(*call_args, **call_kwargs):
        return check_call(func, *call_args, **call_kwargs)

    return functools.wraps(func)(routed)
class ModelCollections:
    """
    Handle used to search the KIM API model collections on this system.

    KIM Portable Models and Simulator Models are installed/managed into
    different "collections", and a collections object has to be
    instantiated in order to search through them. For more on model
    collections, see the KIM API's install file:
    https://github.com/openkim/kim-api/blob/master/INSTALL
    """

    def __init__(self):
        create_collections = wrappers.collections_create
        self.collection = create_collections()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        # Nothing to clean up on exit
        return None

    def get_item_type(self, model_name):
        """Return the collection item type of ``model_name``.

        Raises KIMModelNotFound if the lookup fails with a KimpyError.
        """
        try:
            return check_call(self.collection.get_item_type, model_name)
        except KimpyError:
            not_found = (
                "Could not find model {} installed in any of the KIM API "
                "model collections on this system. See "
                "https://openkim.org/doc/usage/obtaining-models/ for "
                "instructions on installing models.".format(model_name)
            )
            raise KIMModelNotFound(not_found)

    @property
    def initialized(self):
        has_collection = hasattr(self, "collection")
        return has_collection
class PortableModel:
    """Wraps a KIM API Portable Model object behind a minimal interface."""

    def __init__(self, model_name, debug):
        self.model_name = model_name
        self.debug = debug

        # Create the KIM API Model object, requesting zero-based numbering
        # and the units A, eV, e, K and ps.
        creation_args = (
            kimpy.numbering.zeroBased,
            kimpy.length_unit.A,
            kimpy.energy_unit.eV,
            kimpy.charge_unit.e,
            kimpy.temperature_unit.K,
            kimpy.time_unit.ps,
            self.model_name,
        )
        units_accepted, self.kim_model = wrappers.model_create(
            *creation_args)

        if not units_accepted:
            raise KIMModelInitializationError(
                "Requested units not accepted in kimpy.model.create"
            )

        if self.debug:
            units = check_call(self.kim_model.get_units)
            l_unit, e_unit, c_unit, te_unit, ti_unit = units
            print(f"Length unit is: {l_unit}")
            print(f"Energy unit is: {e_unit}")
            print(f"Charge unit is: {c_unit}")
            print(f"Temperature unit is: {te_unit}")
            print(f"Time unit is: {ti_unit}")
            print()

        self._create_parameters()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        # Nothing to clean up on exit
        return None

    @check_call_wrapper
    def _get_number_of_parameters(self):
        return self.kim_model.get_number_of_parameters()

    def _create_parameters(self):
        """Build ``self._parameters``, one wrapper object per model parameter.

        Raises KIMModelParameterError for a parameter whose dtype is neither
        'Integer' nor 'Double'.
        """
        wrapper_by_dtype = {
            "Integer": KIMModelParameterInteger,
            "Double": KIMModelParameterDouble,
        }

        self._parameters = {}
        for index_param in range(self._get_number_of_parameters()):
            meta = self._get_one_parameter_metadata(index_param)
            name = meta["name"]
            dtype = meta["dtype"]

            wrapper_cls = wrapper_by_dtype.get(dtype)
            if wrapper_cls is None:
                raise KIMModelParameterError(
                    f"Invalid model parameter type {dtype}. Supported types "
                    "'Integer' and 'Double'."
                )

            self._parameters[name] = wrapper_cls(
                kim_model=self.kim_model,
                dtype=dtype,
                extent=meta["extent"],
                name=name,
                description=meta["description"],
                parameter_index=index_param,
            )

    def get_model_supported_species_and_codes(self):
        """Get all of the supported species for this model and their
        corresponding integer codes that are defined in the KIM API

        Returns
        -------
        species : list of str
            Abbreviated chemical symbols of all species the model
            supports (e.g. ["Mo", "S"])

        codes : list of int
            Integer codes used by the model for each species (order
            corresponds to the order of ``species``)
        """
        species, codes = [], []

        for i in range(wrappers.get_number_of_species_names()):
            species_name = wrappers.get_species_name(i)
            supported, code = self.get_species_support_and_code(species_name)

            if not supported:
                continue

            species.append(str(species_name))
            codes.append(code)

        return species, codes

    @check_call_wrapper
    def clear_then_refresh(self):
        self.kim_model.clear_then_refresh()

    @c_int_args
    def _get_parameter_metadata(self, index_parameter):
        """Return (dtype, extent, name, description) of one parameter.

        Raises KIMModelParameterError if kimpy fails to provide the metadata.
        """
        try:
            raw_metadata = check_call(
                self.kim_model.get_parameter_metadata, index_parameter
            )
            dtype, extent, name, description = raw_metadata
        except KimpyError as e:
            raise KIMModelParameterError(
                "Failed to retrieve metadata for "
                f"parameter at index {index_parameter}"
            ) from e

        return dtype, extent, name, description

    def parameters_metadata(self):
        """Metadata associated with all model parameters.

        Returns
        -------
        dict
            Metadata associated with all model parameters.
        """
        all_metadata = {}
        for param_name, param in self._parameters.items():
            all_metadata[param_name] = param.metadata
        return all_metadata

    def parameter_names(self):
        """Names of model parameters registered in the KIM API.

        Returns
        -------
        tuple
            Names of model parameters registered in the KIM API
        """
        return tuple(self._parameters)

    def get_parameters(self, **kwargs):
        """
        Get the values of one or more model parameter arrays.

        Given the names of one or more model parameters and a set of indices
        for each of them, retrieve the corresponding elements of the relevant
        model parameter arrays.

        Parameters
        ----------
        **kwargs
            Names of the model parameters and the indices whose values should
            be retrieved.

        Returns
        -------
        dict
            The requested indices and the values of the model's parameters.

        Note
        ----
        The output of this method can be used as input of
        ``set_parameters``.

        Example
        -------
        To get `epsilons` and `sigmas` in the LJ universal model for Mo-Mo
        (index 4879), Mo-S (index 2006) and S-S (index 1980) interactions::

            >>> LJ = 'LJ_ElliottAkerson_2015_Universal__MO_959249795837_003'
            >>> calc = KIM(LJ)
            >>> calc.get_parameters(epsilons=[4879, 2006, 1980],
            ...                     sigmas=[4879, 2006, 1980])
            {'epsilons': [[4879, 2006, 1980],
                          [4.47499, 4.421814057295943, 4.36927]],
             'sigmas': [[4879, 2006, 1980],
                        [2.74397, 2.30743, 1.87089]]}
        """
        collected = {}
        for parameter_name, index_range in kwargs.items():
            one_parameter = self._get_one_parameter(
                parameter_name, index_range)
            collected.update(one_parameter)
        return collected

    def set_parameters(self, **kwargs):
        """
        Set the values of one or more model parameter arrays.

        Given the names of one or more model parameters and a set of indices
        and corresponding values for each of them, mutate the corresponding
        elements of the relevant model parameter arrays.

        Parameters
        ----------
        **kwargs
            Names of the model parameters to mutate and the corresponding
            indices and values to set.

        Returns
        -------
        dict
            The requested indices and the values of the model's parameters
            that were set.

        Example
        -------
        To set `epsilons` in the LJ universal model for Mo-Mo (index 4879),
        Mo-S (index 2006) and S-S (index 1980) interactions to 5.0, 4.5, and
        4.0, respectively::

            >>> LJ = 'LJ_ElliottAkerson_2015_Universal__MO_959249795837_003'
            >>> calc = KIM(LJ)
            >>> calc.set_parameters(epsilons=[[4879, 2006, 1980],
            ...                               [5.0, 4.5, 4.0]])
            {'epsilons': [[4879, 2006, 1980],
                          [5.0, 4.5, 4.0]]}
        """
        applied = {}
        for parameter_name, parameter_data in kwargs.items():
            index_range, values = parameter_data
            self._set_one_parameter(parameter_name, index_range, values)
            applied[parameter_name] = parameter_data

        return applied

    def _require_parameter(self, parameter_name):
        """Return the wrapper for ``parameter_name`` or raise
        KIMModelParameterError if the model has no such parameter."""
        if parameter_name not in self._parameters:
            raise KIMModelParameterError(
                f"Parameter '{parameter_name}' is not "
                "supported by this model. "
                "Please check that the parameter name is spelled correctly."
            )
        return self._parameters[parameter_name]

    def _get_one_parameter(self, parameter_name, index_range):
        """
        Retrieve value of one or more components of a model parameter array.

        Parameters
        ----------
        parameter_name : str
            Name of model parameter registered in the KIM API.
        index_range : int or list
            Zero-based index (int) or indices (list of int) specifying the
            component(s) of the corresponding model parameter array that are
            to be retrieved.

        Returns
        -------
        dict
            The requested indices and the corresponding values of the model
            parameter array.
        """
        parameter = self._require_parameter(parameter_name)
        return parameter.get_values(index_range)

    def _set_one_parameter(self, parameter_name, index_range, values):
        """
        Set the value of one or more components of a model parameter array.

        Parameters
        ----------
        parameter_name : str
            Name of model parameter registered in the KIM API.
        index_range : int or list
            Zero-based index (int) or indices (list of int) specifying the
            component(s) of the corresponding model parameter array that are
            to be mutated.
        values : int/float or list
            Value(s) to assign to the component(s) of the model parameter
            array specified by ``index_range``.
        """
        parameter = self._require_parameter(parameter_name)
        parameter.set_values(index_range, values)

    def _get_one_parameter_metadata(self, index_parameter):
        """
        Get metadata associated with a single model parameter.

        Parameters
        ----------
        index_parameter : int
            Zero-based index used by the KIM API to refer to this model
            parameter.

        Returns
        -------
        dict
            Metadata associated with the requested model parameter.
        """
        raw_metadata = self._get_parameter_metadata(index_parameter)
        dtype, extent, name, description = raw_metadata
        # The dtype is stored as its repr string
        return {
            "name": name,
            "dtype": repr(dtype),
            "extent": extent,
            "description": description,
        }

    @check_call_wrapper
    def compute(self, compute_args_wrapped, release_GIL):
        raw_compute_args = compute_args_wrapped.compute_args
        return self.kim_model.compute(raw_compute_args, release_GIL)

    @check_call_wrapper
    def get_species_support_and_code(self, species_name):
        return self.kim_model.get_species_support_and_code(species_name)

    @check_call_wrapper
    def get_influence_distance(self):
        return self.kim_model.get_influence_distance()

    @check_call_wrapper
    def get_neighbor_list_cutoffs_and_hints(self):
        return self.kim_model.get_neighbor_list_cutoffs_and_hints()

    def compute_arguments_create(self):
        return ComputeArguments(self, self.debug)

    @property
    def initialized(self):
        has_model = hasattr(self, "kim_model")
        return has_model
class KIMModelParameter(ABC):
    """Base class giving access to one parameter array of a KIM model.

    Subclasses are expected to define two class attributes:

    ``_dtype_c``
        Callable used to cast a value before it is handed to kimpy.
    ``_dtype_accessor``
        Name of the ``kim_model`` method used to read one component.
    """

    def __init__(self, kim_model, dtype, extent,
                 name, description, parameter_index):
        self._kim_model = kim_model
        self._dtype = dtype
        self._extent = extent
        self._name = name
        self._description = description

        # Ensure that parameter_index is cast to a C-compatible integer. This
        # is necessary because this is passed to kimpy.
        self._parameter_index = c_int(parameter_index)

    @property
    def metadata(self):
        """dict with the 'dtype', 'extent', 'name' and 'description' that
        were passed to the constructor."""
        return {
            "dtype": self._dtype,
            "extent": self._extent,
            "name": self._name,
            "description": self._description,
        }

    @c_int_args
    def _get_one_value(self, index_extent):
        """Read the component at ``index_extent`` of this parameter.

        Raises KIMModelParameterError if the kimpy call fails.
        """
        get_parameter = getattr(self._kim_model, self._dtype_accessor)
        try:
            return check_call(
                get_parameter, self._parameter_index, index_extent)
        except KimpyError as exception:
            raise KIMModelParameterError(
                f"Failed to access component {index_extent} of model "
                f"parameter of type '{self._dtype}' at parameter index "
                f"{self._parameter_index}"
            ) from exception

    def _set_one_value(self, index_extent, value):
        """Write ``value`` to the component at ``index_extent``.

        The value is cast with ``_dtype_c`` and the index with ``c_int``
        before being passed to kimpy. Raises KIMModelParameterError if the
        kimpy call fails.
        """
        value_typecast = self._dtype_c(value)

        try:
            check_call(
                self._kim_model.set_parameter,
                self._parameter_index,
                c_int(index_extent),
                value_typecast,
            )
        except KimpyError as exception:
            # Chain the underlying error, as _get_one_value does, so the
            # kimpy failure message is not lost.
            raise KIMModelParameterError(
                f"Failed to set component {index_extent} at parameter index "
                f"{self._parameter_index} to {self._dtype} value "
                f"{value_typecast}"
            ) from exception

    def get_values(self, index_range):
        """Read one or several components of this parameter.

        Parameters
        ----------
        index_range : int or list of int
            A single index (0-d) or a 1-d sequence of indices.

        Returns
        -------
        dict
            ``{name: [index_range, values]}`` where ``values`` is a single
            value for a scalar ``index_range`` and a list otherwise.

        Raises
        ------
        KIMModelParameterError
            If ``index_range`` has more than one dimension.
        """
        index_range_dim = np.ndim(index_range)
        if index_range_dim == 0:
            values = self._get_one_value(index_range)
        elif index_range_dim == 1:
            values = [self._get_one_value(idx) for idx in index_range]
        else:
            raise KIMModelParameterError(
                "Index range must be an integer or a list of integers"
            )
        return {self._name: [index_range, values]}

    def set_values(self, index_range, values):
        """Write one or several components of this parameter.

        Parameters
        ----------
        index_range : int or list of int
            A single index (0-d) or a 1-d sequence of indices.
        values : scalar or list
            Value(s) to assign; must have the same shape as ``index_range``.

        Raises
        ------
        AssertionError
            If ``index_range`` and ``values`` differ in dimension or length.
        KIMModelParameterError
            If ``index_range`` has more than one dimension.
        """
        index_range_dim = np.ndim(index_range)
        values_dim = np.ndim(values)

        # Check the shape of index_range and values.
        # NOTE(review): these asserts are skipped under ``python -O``; they
        # are kept as asserts so callers still see AssertionError.
        msg = "index_range and values must have the same shape"
        assert index_range_dim == values_dim, msg

        if index_range_dim == 0:
            self._set_one_value(index_range, values)
        elif index_range_dim == 1:
            assert len(index_range) == len(values), msg
            for idx, value in zip(index_range, values):
                self._set_one_value(idx, value)
        else:
            # Any 1-d list of indices is accepted above, so say so here
            # (same wording as get_values).
            raise KIMModelParameterError(
                "Index range must be an integer or a list of integers"
            )
class KIMModelParameterInteger(KIMModelParameter):
    """Model parameter whose components are read with
    ``get_parameter_int`` and cast with ``c_int`` when set."""

    _dtype_accessor = "get_parameter_int"
    _dtype_c = c_int
class KIMModelParameterDouble(KIMModelParameter):
    """Model parameter whose components are read with
    ``get_parameter_double`` and cast with ``c_double`` when set."""

    _dtype_accessor = "get_parameter_double"
    _dtype_c = c_double
class ComputeArguments:
    """Creates a KIM API ComputeArguments object from a KIM Portable
    Model object and configures it for ASE.  A ComputeArguments object
    is associated with a KIM Portable Model and is used to inform the
    KIM API of what the model can compute.  It is also used to
    register the data arrays that allow the KIM API to pass the atomic
    coordinates to the model and retrieve the corresponding energy and
    forces, etc."""

    def __init__(self, kim_model_wrapped, debug):
        self.kim_model_wrapped = kim_model_wrapped
        self.debug = debug

        # Create KIM API ComputeArguments object
        create = self.kim_model_wrapped.kim_model.compute_arguments_create
        self.compute_args = check_call(create)

        self._check_compute_arguments()
        self._check_compute_callbacks()

    def _check_compute_arguments(self):
        """Raise KIMModelInitializationError if the model requires any
        compute argument other than partialEnergy or partialForces."""
        arg_names = kimpy.compute_argument_name
        num_arguments = arg_names.get_number_of_compute_argument_names()
        if self.debug:
            print(f"Number of compute_args: {num_arguments}")

        for i in range(num_arguments):
            name = check_call(arg_names.get_compute_argument_name, i)
            dtype = check_call(
                arg_names.get_compute_argument_data_type, name)

            arg_support = self.get_argument_support_status(name)

            if self.debug:
                report = (
                    "Compute Argument name {:21} is of type {:7} "
                    "and has support "
                    "status {}"
                )
                print(report.format(str(name), str(dtype), str(arg_support)))

            # See if the model demands that we ask it for anything
            # other than energy and forces.  If so, raise an
            # exception.
            if (
                arg_support == kimpy.support_status.required
                and name != kimpy.compute_argument_name.partialEnergy
                and name != kimpy.compute_argument_name.partialForces
            ):
                raise KIMModelInitializationError(
                    f"Unsupported required ComputeArgument {name}"
                )

    def _check_compute_callbacks(self):
        """Raise KIMModelInitializationError if the model requires any
        compute callback."""
        callback_names = kimpy.compute_callback_name
        num_callbacks = callback_names.get_number_of_compute_callback_names()
        if self.debug:
            print()
            print(f"Number of callbacks: {num_callbacks}")

        for i in range(num_callbacks):
            name = check_call(callback_names.get_compute_callback_name, i)

            support_status = self.get_callback_support_status(name)

            if self.debug:
                report = "Compute callback {:17} has support status {}"
                print(report.format(str(name), support_status))

            # Cannot handle any "required" callbacks
            if support_status == kimpy.support_status.required:
                raise KIMModelInitializationError(
                    f"Unsupported required ComputeCallback: {name}"
                )

    @check_call_wrapper
    def set_argument_pointer(self, compute_arg_name, data_object):
        target = self.compute_args
        return target.set_argument_pointer(compute_arg_name, data_object)

    @check_call_wrapper
    def get_argument_support_status(self, name):
        target = self.compute_args
        return target.get_argument_support_status(name)

    @check_call_wrapper
    def get_callback_support_status(self, name):
        target = self.compute_args
        return target.get_callback_support_status(name)

    @check_call_wrapper
    def set_callback(self, compute_callback_name,
                     callback_function, data_object):
        target = self.compute_args
        return target.set_callback(
            compute_callback_name, callback_function, data_object)

    @check_call_wrapper
    def set_callback_pointer(
            self, compute_callback_name, callback, data_object):
        target = self.compute_args
        return target.set_callback_pointer(
            compute_callback_name, callback, data_object)

    def update(
        self, num_particles, species_code, particle_contributing,
        coords, energy, forces
    ):
        """Register model input and output in the kim_model object."""
        arg_names = kimpy.compute_argument_name

        # (compute argument name, data object) pairs, registered in order
        registrations = (
            (arg_names.numberOfParticles, num_particles),
            (arg_names.particleSpeciesCodes, species_code),
            (arg_names.particleContributing, particle_contributing),
            (arg_names.coordinates, coords),
            (arg_names.partialEnergy, energy),
            (arg_names.partialForces, forces),
        )
        for arg_name, data_object in registrations:
            self.set_argument_pointer(arg_name, data_object)

        if self.debug:
            print("Debug: called update_kim")
            print()
class SimulatorModel:
    """Creates a KIM API Simulator Model object and provides a minimal
    interface to it.  This is only necessary in this package in order to
    extract any information about a given simulator model because it is
    generally embedded in a shared object.
    """

    def __init__(self, model_name):
        self.model_name = model_name

        # Create a KIM API Simulator Model object for this model
        create_simulator_model = wrappers.simulator_model_create
        self.simulator_model = create_simulator_model(self.model_name)

        # The template map has to be closed before the simulator model
        # metadata can be accessed
        self.simulator_model.close_template_map()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        # Nothing to clean up on exit
        return None

    @property
    def simulator_name(self):
        name_and_version = (
            self.simulator_model.get_simulator_name_and_version())
        name, _version = name_and_version
        return name

    @property
    def num_supported_species(self):
        count = self.simulator_model.get_number_of_supported_species()
        if count != 0:
            return count

        raise KIMModelInitializationError(
            "Unable to determine supported species of "
            "simulator model {}.".format(self.model_name)
        )

    @property
    def supported_species(self):
        get_species = self.simulator_model.get_supported_species
        return tuple(
            check_call(get_species, spec_code)
            for spec_code in range(self.num_supported_species)
        )

    @property
    def num_metadata_fields(self):
        return self.simulator_model.get_number_of_simulator_fields()

    @property
    def metadata(self):
        """dict mapping each simulator field name to the list of its lines."""
        fields = {}
        for field in range(self.num_metadata_fields):
            extent, field_name = check_call(
                self.simulator_model.get_simulator_field_metadata, field
            )
            fields[field_name] = [
                check_call(
                    self.simulator_model.get_simulator_field_line, field, ln
                )
                for ln in range(extent)
            ]

        return fields

    @property
    def supported_units(self):
        try:
            return self.metadata["units"][0]
        except (KeyError, IndexError):
            raise KIMModelInitializationError(
                "Unable to determine supported units of "
                "simulator model {}.".format(self.model_name)
            )

    @property
    def atom_style(self):
        """
        See if a 'model-init' field exists in the SM metadata and, if
        so, whether it contains any entries including an "atom_style"
        command.  This is specific to LAMMPS SMs and is only required
        for using the LAMMPSrun calculator because it uses
        lammps.inputwriter to create a data file.  All other content in
        'model-init', if it exists, is ignored.
        """
        style = None
        for line in self.metadata.get("model-init", []):
            if "atom_style" not in line:
                continue
            # The last matching line wins; its second token is the style
            style = line.split()[1]

        return style

    @property
    def model_defn(self):
        return self.metadata["model-defn"]

    @property
    def initialized(self):
        has_simulator_model = hasattr(self, "simulator_model")
        return has_simulator_model