Coverage for /builds/ase/ase/ase/calculators/subprocesscalculator.py: 91.37%
197 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import os
4import pickle
5import sys
6from abc import ABC, abstractmethod
7from subprocess import PIPE, Popen
9from ase.calculators.calculator import Calculator, all_properties
class PackedCalculator(ABC):
    """Portable calculator for use via PythonSubProcessCalculator.

    A packed calculator is a picklable recipe for a calculator.  It is
    sent to a different process (possibly started with MPI or srun) and
    unpacked there, so that the actual computations happen in that
    process.

    This is useful when ASE is driven mostly in serial, but some
    calculations should run in a parallel Python environment.

    Most existing calculators can be used this way through the
    NamedPackedCalculator implementation.  To customize the behaviour
    for other calculators, subclass this class and implement
    unpack_calculator().

    Example::

        from ase.build import bulk
        from ase.calculators.subprocesscalculator import NamedPackedCalculator

        atoms = bulk('Au')
        pack = NamedPackedCalculator('emt')

        with pack.calculator() as atoms.calc:
            energy = atoms.get_potential_energy()

    The computation runs inside a subprocess that exists only for the
    duration of the ``with`` block.
    """

    @abstractmethod
    def unpack_calculator(self) -> Calculator:
        """Build and return the wrapped calculator.

        Invoked inside the subprocess that performs the computations."""

    def calculator(self, mpi_command=None) -> 'PythonSubProcessCalculator':
        """Wrap this pack in a PythonSubProcessCalculator.

        The returned object forwards computations to a subprocess that
        holds the real calculator.  ``mpi_command`` selects how that
        subprocess is launched; ``None`` lets PythonSubProcessCalculator
        choose its default."""
        wrapper = PythonSubProcessCalculator(self, mpi_command=mpi_command)
        return wrapper
class NamedPackedCalculator(PackedCalculator):
    """PackedCalculator for calculators registered by name in ASE.

    The calculator class is resolved with
    ``ase.calculators.calculator.get_calculator_class(name)`` inside the
    subprocess and instantiated with ``kwargs`` as keyword arguments."""

    def __init__(self, name, kwargs=None):
        self._name = name
        # A fresh dict when nothing is given; otherwise the caller's dict.
        self._kwargs = {} if kwargs is None else kwargs

    def unpack_calculator(self):
        """Instantiate the named calculator with the stored kwargs."""
        from ase.calculators.calculator import get_calculator_class
        calculator_class = get_calculator_class(self._name)
        return calculator_class(**self._kwargs)

    def __repr__(self):
        return f'{type(self).__name__}({self._name}, {self._kwargs})'
class MPICommand:
    """Command line that launches the subprocess side of the calculator.

    ``argv`` is the complete argument list handed to ``subprocess.Popen``.
    """

    def __init__(self, argv):
        self.argv = argv

    @classmethod
    def python_argv(cls):
        """Return the argv prefix running this module with the current
        Python interpreter."""
        return [sys.executable, '-m', 'ase.calculators.subprocesscalculator']

    @classmethod
    def parallel(cls, nprocs, mpi_argv=()):
        """Command running ``nprocs`` copies via ``mpiexec`` in mpi4py mode.

        ``mpi_argv`` holds extra arguments inserted after ``mpiexec -n N``.
        """
        return cls(['mpiexec', '-n', str(nprocs)]
                   + list(mpi_argv)
                   + cls.python_argv()
                   + ['mpi4py'])

    @classmethod
    def serial(cls):
        """Command running a single process in standard mode."""
        # Construct through ``cls`` (as parallel() does) so that subclasses
        # get instances of themselves rather than plain MPICommand.
        return cls(cls.python_argv() + ['standard'])

    def execute(self):
        """Start the command with pipes on stdin and stdout.

        Returns the ``Popen`` object."""
        # On this computer (Ubuntu 20.04 + OpenMPI) the subprocess crashes
        # without output during startup if os.environ is not passed along.
        # Hence we pass os.environ. Not sure if this is a machine thing
        # or in general. --askhl
        return Popen(self.argv, stdout=PIPE,
                     stdin=PIPE, env=os.environ)
def gpaw_process(ncores=1, **kwargs):
    """Return a PythonSubProcessCalculator that runs GPAW via its CLI.

    The subprocess is started as
    ``<python> -m gpaw -P <ncores> python -m
    ase.calculators.subprocesscalculator standard``.  Here ``-P`` is the
    gpaw command-line option that presumably sets the number of parallel
    processes; see the gpaw CLI docs to confirm.  ``kwargs`` are passed to
    the GPAW calculator created inside the subprocess.
    """
    launcher = [sys.executable, '-m', 'gpaw', '-P', str(ncores)]
    # ``python`` here is a gpaw sub-command, not an interpreter path.
    worker = ['python', '-m', 'ase.calculators.subprocesscalculator',
              'standard']
    return PythonSubProcessCalculator(
        NamedPackedCalculator('gpaw', kwargs),
        MPICommand(launcher + worker),
    )
class PythonSubProcessCalculator(Calculator):
    """Calculator for running calculations in external processes.

    TODO: This should work with arbitrary commands including MPI stuff.

    This calculator runs a subprocess wherein it sets up an
    actual calculator.  Calculations are forwarded through pickle
    to that calculator, which returns results through pickle.

    The subprocess only exists inside a ``with`` block; calculations
    outside of it raise RuntimeError.
    """
    implemented_properties = list(all_properties)

    def __init__(self, calc_input, mpi_command=None):
        super().__init__()

        # Picklable recipe (e.g. a PackedCalculator) sent to the subprocess.
        self.calc_input = calc_input
        if mpi_command is None:
            mpi_command = MPICommand.serial()
        self.mpi_command = mpi_command

        # Connection to the running subprocess; None outside ``with``.
        self.protocol = None

    def set(self, **kwargs):
        # Parameters cannot be forwarded to the calculator living in the
        # subprocess, so refuse to change them while it runs.  getattr() is
        # used because set() may be called by Calculator.__init__ before
        # self.protocol exists.
        if getattr(self, 'protocol', None) is not None:
            raise RuntimeError('No setting things for now, thanks')

    def __repr__(self):
        return '{}({})'.format(type(self).__name__,
                               self.calc_input)

    def __enter__(self):
        """Start the subprocess and send it ``calc_input``."""
        assert self.protocol is None
        proc = self.mpi_command.execute()
        protocol = Protocol(proc)
        try:
            protocol.send(self.calc_input)
        except BaseException:
            # Do not leave the subprocess behind on a failed start:
            # communicate() closes its stdin and waits for it to exit.
            proc.communicate()
            raise
        self.protocol = protocol
        return self

    def __exit__(self, *args):
        """Ask the subprocess to stop and wait for it to exit."""
        protocol = self.protocol
        self.protocol = None
        try:
            protocol.send('stop')
        finally:
            # Reap the process even if it already died and 'stop' could
            # not be delivered.
            protocol.proc.communicate()

    def _run_calculation(self, atoms, properties, system_changes):
        self.protocol.send('calculate')
        self.protocol.send((atoms, properties, system_changes))

    def calculate(self, atoms, properties, system_changes):
        """Forward the calculation to the subprocess and store its results.

        Raises RuntimeError when called outside the ``with`` block.
        """
        if self.protocol is None:
            raise RuntimeError(
                f'{type(self).__name__} has no running subprocess; '
                'use it inside a with-statement')
        Calculator.calculate(self, atoms, properties, system_changes)
        # We send a pickle of self.atoms because this is a fresh copy
        # of the input, but without an unpicklable calculator:
        self._run_calculation(self.atoms.copy(), properties, system_changes)
        results = self.protocol.recv()
        self.results.update(results)

    def backend(self):
        """Return a proxy whose method calls run on the remote calculator."""
        return ParallelBackendInterface(self)
class Protocol:
    """Parent side of the pickle protocol with a subprocess.

    Objects are pickled to the subprocess's stdin and replies are
    unpickled from its stdout.  The stream comes from our own child
    process, which is why unpickling it is acceptable here.
    """

    def __init__(self, proc):
        self.proc = proc

    def send(self, obj):
        """Pickle ``obj`` to the subprocess's stdin and flush it."""
        pickle.dump(obj, self.proc.stdin)
        self.proc.stdin.flush()

    def recv(self):
        """Read one ``(response_type, value)`` reply from the subprocess.

        Returns ``value`` for a ``'return'`` reply, and raises ``value``
        for a ``'raise'`` reply.  Any other response type means the two
        sides are out of sync and raises RuntimeError.  An explicit check
        is used instead of ``assert``, which ``python -O`` removes.
        """
        response_type, value = pickle.load(self.proc.stdout)

        if response_type == 'raise':
            raise value

        if response_type != 'return':
            raise RuntimeError(
                f'Unexpected response type from subprocess: '
                f'{response_type!r}')
        return value
class MockMethod:
    """Callable that forwards one method call to the remote calculator.

    Calling the instance sends ``'callmethod'`` followed by
    ``[name, args, kwargs]`` over ``calc.protocol`` and returns the reply.
    """

    def __init__(self, name, calc):
        self.name = name
        self.calc = calc

    def __call__(self, *args, **kwargs):
        remote = self.calc.protocol
        request = [self.name, args, kwargs]
        remote.send('callmethod')
        remote.send(request)
        return remote.recv()
class ParallelBackendInterface:
    """Proxy whose attributes are methods of the calculator in a subprocess.

    ``backend.foo(*args, **kwargs)`` calls ``foo`` on the calculator held
    by ``calc``'s subprocess (through MockMethod) and returns the result.
    """

    def __init__(self, calc):
        self.calc = calc

    def __getattr__(self, name):
        # __getattr__ is only consulted for names not found normally.
        # Dunder names are refused because protocols such as copy/pickle
        # probe for optional hooks (e.g. __setstate__) and must see
        # AttributeError.  'calc' is refused because it is only missing on an
        # instance whose __init__ has not run; returning a MockMethod there
        # would look up self.calc again and recurse forever.
        if name == 'calc' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')
        return MockMethod(name, self.calc)
# Accepted values for sys.argv[1] when this module runs as the subprocess.
run_modes = {'mpi4py', 'standard'}
def callmethod(calc, attrname, args, kwargs):
    """Return ``calc.<attrname>(*args, **kwargs)``."""
    return getattr(calc, attrname)(*args, **kwargs)
def callfunction(func, args, kwargs):
    """Return ``func(*args, **kwargs)``."""
    outcome = func(*args, **kwargs)
    return outcome
def calculate(calc, atoms, properties, system_changes):
    """Run ``calc.calculate`` from scratch and return ``calc.results``.

    The returned object is the calculator's own results dict, not a copy.
    """
    # Again we need formalization of the results/outputs, and a way to
    # programmatically access all available properties; this is a hack.
    # The results must be cleared first: without it caching breaks for
    # stress (though, oddly, not for forces).
    calc.results.clear()
    calc.calculate(
        atoms=atoms,
        properties=properties,
        system_changes=system_changes,
    )
    return calc.results
def bad_mode():
    """Build (without raising) the SystemExit reported for a bad run mode."""
    message = f'sys.argv[1] must be one of {run_modes}'
    return SystemExit(message)
def parallel_startup():
    """Validate the run mode in ``sys.argv[1]`` and return a Client.

    Raises the SystemExit from bad_mode() when the mode is missing or not
    in ``run_modes``.  As a side effect, ``sys.stdout`` is redirected to
    ``sys.stderr``; the Client writes replies to the original binary stdout.
    """
    try:
        mode = sys.argv[1]
    except IndexError:
        raise bad_mode()

    if mode not in run_modes:
        raise bad_mode()

    if mode == 'mpi4py':
        # mpi4py must be imported before the rest of ASE, or world will
        # not be correctly initialized.
        import mpi4py  # noqa

    # Keep the real stdout for the protocol and point sys.stdout at stderr,
    # so stray print statements cannot corrupt the pickled replies.
    protocol_out = sys.stdout.buffer
    sys.stdout = sys.stderr
    protocol_in = sys.stdin.buffer
    return Client(input_fd=protocol_in, output_fd=protocol_out)
class Client:
    """Subprocess side of the pickle protocol.

    Messages are read from ``input_fd`` on rank 0 of ``ase.parallel.world``
    and broadcast to all ranks.  Replies are written to ``output_fd`` by
    rank 0 only.
    """

    def __init__(self, input_fd, output_fd):
        from ase.parallel import world
        self._world = world
        self.input_fd = input_fd
        self.output_fd = output_fd

    def recv(self):
        """Read one pickled object on rank 0 and broadcast it to all ranks."""
        from ase.parallel import broadcast
        if self._world.rank == 0:
            obj = pickle.load(self.input_fd)
        else:
            obj = None

        obj = broadcast(obj, 0, self._world)
        return obj

    def send(self, obj):
        """Pickle ``obj`` to ``output_fd``; only rank 0 writes.

        The object is fully serialized before anything is written, so a
        pickling error leaves the output stream untouched.
        """
        if self._world.rank == 0:
            data = pickle.dumps(obj)
            self.output_fd.write(data)
            self.output_fd.flush()

    def mainloop(self, calc):
        """Serve instructions for ``calc`` until ``'stop'`` is received.

        Each instruction is followed by one data message and is answered
        with exactly one ``(response_type, value)`` reply.
        """
        while True:
            instruction = self.recv()
            if instruction == 'stop':
                return

            instruction_data = self.recv()

            response_type, value = self.process_instruction(
                calc, instruction, instruction_data)
            try:
                self.send((response_type, value))
            except (pickle.PicklingError, TypeError, AttributeError) as err:
                # The reply could not be pickled.  Report that instead of
                # crashing, so the parent still receives exactly one reply
                # and gets a meaningful error rather than an EOF.
                self.send(('raise', RuntimeError(
                    f'Could not pickle {response_type!r} reply of type '
                    f'{type(value).__name__}: {err}')))

    def process_instruction(self, calc, instruction, instruction_data):
        """Execute one instruction and return ``(response_type, value)``.

        ``response_type`` is ``'return'`` together with the result, or
        ``'raise'`` together with the exception; the traceback is printed
        to stderr.  Unknown instructions are reported the same way instead
        of terminating the subprocess.
        """
        try:
            value = self._dispatch(calc, instruction, instruction_data)
        except Exception as ex:
            import traceback
            traceback.print_exc()
            return 'raise', ex
        return 'return', value

    @staticmethod
    def _dispatch(calc, instruction, instruction_data):
        # Run the module-level function implementing ``instruction``.
        if instruction == 'callmethod':
            return callmethod(calc, *instruction_data)
        if instruction == 'calculate':
            return calculate(calc, *instruction_data)
        if instruction == 'callfunction':
            return callfunction(*instruction_data)
        raise RuntimeError(f'Bad instruction: {instruction}')
class ParallelDispatch:
    """Utility class to run functions in parallel.

    Usage::

        with ParallelDispatch(mpicommand) as parallel:
            result = parallel.call(function, *args, **kwargs)

    ``mpicommand`` is an MPICommand-like object whose ``execute()`` starts
    the subprocess.
    """

    def __init__(self, mpicommand):
        self._mpicommand = mpicommand
        self._protocol = None

    def call(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` in the subprocess and return the result.

        ``func`` and its arguments must be picklable.  Functions are pickled
        by reference, so they must be importable in the subprocess.  Raises
        RuntimeError when used outside the ``with`` block.
        """
        if self._protocol is None:
            raise RuntimeError(
                'ParallelDispatch.call() must be used inside a with-statement')
        self._protocol.send('callfunction')
        self._protocol.send((func, args, kwargs))
        return self._protocol.recv()

    def __enter__(self):
        """Start the subprocess and complete its start-up handshake."""
        assert self._protocol is None
        proc = self._mpicommand.execute()
        protocol = Protocol(proc)
        try:
            # Even if we are not using a calculator, we have to send one.
            # (We should get rid of that requirement.)
            protocol.send(NamedPackedCalculator('emt', {}))
        except BaseException:
            # Do not leave the subprocess behind on a failed start:
            # communicate() closes its stdin and waits for it to exit.
            proc.communicate()
            raise
        self._protocol = protocol
        return self

    def __exit__(self, *args):
        """Ask the subprocess to stop and wait for it to exit."""
        protocol = self._protocol
        self._protocol = None
        try:
            protocol.send('stop')
        finally:
            # Reap the process even if 'stop' could not be delivered.
            protocol.proc.communicate()
def main():
    """Subprocess entry point.

    Validate the run mode, receive the packed calculator sent by the
    parent, unpack it and serve instructions until told to stop.
    """
    client = parallel_startup()
    calc = client.recv().unpack_calculator()
    client.mainloop(calc)
# Run as ``python -m ase.calculators.subprocesscalculator <mode>``, which is
# the command line MPICommand.python_argv() builds for the subprocess.
if __name__ == '__main__':
    main()