Coverage for /builds/ase/ase/ase/parallel.py: 51.39%
216 statements
# fmt: off

import atexit
import functools
import os
import pickle
import sys
import time
import warnings

import numpy as np


def paropen(name, mode='r', buffering=-1, encoding=None, comm=None):
    """MPI-safe version of open function.

    In read mode, the file is opened on all nodes. In write and
    append mode, the file is opened on the master only, and /dev/null
    is opened on all other nodes.
    """
    if comm is None:
        comm = world
    if comm.rank > 0 and mode[0] != 'r':
        name = os.devnull
    return open(name, mode, buffering, encoding)
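
# Usage sketch (illustrative, not part of the original module; assumes an
# MPI run with several ranks):
#
#     from ase.parallel import paropen
#     with paropen('results.txt', 'w') as fd:
#         # Only rank 0 writes the real file; other ranks write to /dev/null.
#         print('energy =', -1.234, file=fd)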


def parprint(*args, comm=None, **kwargs):
    """MPI-safe print - prints only from master."""
    if comm is None:
        comm = world
    if comm.rank == 0:
        print(*args, **kwargs)
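
# Illustrative sketch: parprint behaves like print but only rank 0 emits
# output, so log lines are not duplicated once per MPI process.
#
#     from ase.parallel import parprint, world
#     parprint('starting relaxation on', world.size, 'ranks')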


class DummyMPI:
    rank = 0
    size = 1

    def _returnval(self, a, root=-1):
        # MPI interface works either on numbers, in which case a number is
        # returned, or on arrays, in-place.
        if np.isscalar(a):
            return a
        if hasattr(a, '__array__'):
            a = a.__array__()
        assert isinstance(a, np.ndarray)
        return None

    def sum(self, a, root=-1):
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a)

    def sum_scalar(self, a, root=-1):
        return a

    def product(self, a, root=-1):
        return self._returnval(a)

    def broadcast(self, a, root):
        assert root == 0
        return self._returnval(a)

    def barrier(self):
        pass
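
# Illustrative sketch of the serial fallback: DummyMPI mimics the
# communicator interface for single-process runs, so reductions and
# broadcasts become no-ops that hand the data straight back.
#
#     comm = DummyMPI()
#     assert comm.sum_scalar(3.0) == 3.0   # nothing to reduce over
#     a = np.ones(4)
#     comm.broadcast(a, 0)                 # array is left unchanged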


class MPI:
    """Wrapper for MPI world object.

    Decides at runtime (after all imports) which one to use:

    * MPI4Py
    * GPAW
    * Asap
    * a dummy implementation for serial runs

    """

    def __init__(self):
        self.comm = None

    def __getattr__(self, name):
        # Pickling of objects that carry instances of MPI class
        # (e.g. NEB) raises RecursionError since it tries to access
        # the optional __setstate__ method (which we do not implement)
        # when unpickling. The two lines below prevent the
        # RecursionError. This also affects modules that use pickling,
        # e.g. multiprocessing. For more details see:
        # https://gitlab.com/ase/ase/-/merge_requests/2695
        if name == '__setstate__':
            raise AttributeError(name)

        if self.comm is None:
            self.comm = _get_comm()
        return getattr(self.comm, name)


def _get_comm():
    """Get the correct MPI world object."""
    if 'mpi4py' in sys.modules:
        return MPI4PY()
    if '_gpaw' in sys.modules:
        import _gpaw
        if hasattr(_gpaw, 'Communicator'):
            return _gpaw.Communicator()
    if '_asap' in sys.modules:
        import _asap
        if hasattr(_asap, 'Communicator'):
            return _asap.Communicator()
    return DummyMPI()
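
# Illustrative sketch of backend selection: if no MPI-aware module was seen
# at import time, ``world`` is an MPI wrapper that defers to _get_comm() on
# first attribute access, so importing mpi4py early is enough to get a real
# communicator instead of the DummyMPI fallback.
#
#     import mpi4py.MPI            # noqa: F401  (imported for the side effect)
#     from ase.parallel import world
#     print(world.rank, world.size)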


class MPI4PY:
    def __init__(self, mpi4py_comm=None):
        if mpi4py_comm is None:
            from mpi4py import MPI
            mpi4py_comm = MPI.COMM_WORLD
        self.comm = mpi4py_comm

    @property
    def rank(self):
        return self.comm.rank

    @property
    def size(self):
        return self.comm.size

    def _returnval(self, a, b):
        """Behave correctly when working on scalars/arrays.

        Either input is an array and we in-place write b (output from
        mpi4py) back into a, or input is a scalar and we return the
        corresponding output scalar."""
        if np.isscalar(a):
            assert np.isscalar(b)
            return b
        else:
            assert not np.isscalar(b)
            a[:] = b
            return None

    def sum(self, a, root=-1):
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            b = self.comm.reduce(a, root)
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a, b)

    def sum_scalar(self, a, root=-1):
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            b = self.comm.reduce(a, root)
        return b

    def split(self, split_size=None):
        """Divide the communicator."""
        # color - subgroup id
        # key - new subgroup rank
        if not split_size:
            split_size = self.size
        color = int(self.rank // (self.size / split_size))
        key = int(self.rank % (self.size / split_size))
        comm = self.comm.Split(color, key)
        return MPI4PY(comm)

    def barrier(self):
        self.comm.barrier()

    def abort(self, code):
        self.comm.Abort(code)

    def broadcast(self, a, root):
        b = self.comm.bcast(a, root=root)
        if self.rank == root:
            if np.isscalar(a):
                return a
            return None
        return self._returnval(a, b)
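
# Illustrative sketch (assumes mpi4py is installed and the script runs under
# ``mpirun -n 4``): arrays are reduced in place, scalars go through
# sum_scalar, and split() carves the world into subcommunicators.
#
#     comm = MPI4PY()
#     data = np.arange(3, dtype=float)
#     comm.sum(data)                       # in-place allreduce over ranks
#     total = comm.sum_scalar(comm.rank)   # 0 + 1 + 2 + 3 = 6 on every rank
#     sub = comm.split(2)                  # two subgroups of two ranks each
#     print(sub.rank, sub.size)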


world = None

# Check for special MPI-enabled Python interpreters:
if '_gpaw' in sys.builtin_module_names:
    # http://gpaw.readthedocs.io
    import _gpaw
    world = _gpaw.Communicator()
elif '_asap' in sys.builtin_module_names:
    # Modern version of Asap
    # http://wiki.fysik.dtu.dk/asap
    # We cannot import asap3.mpi here, as that creates an import deadlock
    import _asap
    world = _asap.Communicator()

# Check if MPI implementation has been imported already:
elif '_gpaw' in sys.modules:
    # Same thing as above but for the module version
    import _gpaw
    try:
        world = _gpaw.Communicator()
    except AttributeError:
        pass
elif '_asap' in sys.modules:
    import _asap
    try:
        world = _asap.Communicator()
    except AttributeError:
        pass
elif 'mpi4py' in sys.modules:
    world = MPI4PY()

if world is None:
    world = MPI()


def barrier():
    world.barrier()


def broadcast(obj, root=0, comm=world):
    """Broadcast a Python object across an MPI communicator and return it."""
    if comm.rank == root:
        string = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        n = np.array([len(string)], int)
    else:
        string = None
        n = np.empty(1, int)
    comm.broadcast(n, root)
    if comm.rank == root:
        string = np.frombuffer(string, np.int8)
    else:
        string = np.zeros(n, np.int8)
    comm.broadcast(string, root)
    if comm.rank == root:
        return obj
    else:
        return pickle.loads(string.tobytes())
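
# Illustrative sketch: broadcast() pickles an arbitrary Python object on the
# root rank, ships the byte length and then the bytes as int8 arrays, and
# unpickles on all other ranks.
#
#     from ase.parallel import broadcast, world
#     if world.rank == 0:
#         settings = {'fmax': 0.05, 'steps': 200}
#     else:
#         settings = None
#     settings = broadcast(settings, root=0)   # identical dict on every rank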


def parallel_function(func):
    """Decorator for broadcasting from master to slaves using MPI.

    Disable by passing parallel=False to the function. For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not kwargs.pop('parallel', True)):
            # Disable:
            return func(*args, **kwargs)

        ex = None
        result = None
        if world.rank == 0:
            try:
                result = func(*args, **kwargs)
            except Exception as x:
                ex = x
        ex, result = broadcast((ex, result))
        if ex is not None:
            raise ex
        return result

    return new_func
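
# Illustrative sketch: with @parallel_function only rank 0 executes the
# wrapped call; the result (or the exception) is then broadcast so every rank
# returns the same thing.  Pass parallel=False to run it on every rank.
#
#     @parallel_function
#     def read_input(filename):
#         with open(filename) as fd:
#             return fd.read()
#
#     text = read_input('in.json')                   # file read once, on rank 0
#     text = read_input('in.json', parallel=False)   # every rank reads it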


def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    Disable by passing parallel=False to the function. For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not kwargs.pop('parallel', True)):
            # Disable:
            for result in generator(*args, **kwargs):
                yield result
            return

        if world.rank == 0:
            try:
                for result in generator(*args, **kwargs):
                    broadcast((None, result))
                    yield result
            except Exception as ex:
                broadcast((ex, None))
                raise ex
            broadcast((None, None))
        else:
            ex2, result = broadcast((None, None))
            if ex2 is not None:
                raise ex2
            while result is not None:
                yield result
                ex2, result = broadcast((None, None))
                if ex2 is not None:
                    raise ex2

    return new_generator
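
# Illustrative sketch: @parallel_generator makes rank 0 drive the iteration
# and broadcast each yielded item (a None sentinel marks exhaustion), so all
# ranks see the same sequence without touching the underlying resource.
#
#     @parallel_generator
#     def read_lines(filename):
#         with open(filename) as fd:
#             yield from fd
#
#     for line in read_lines('trajectory.log'):
#         parprint(line.rstrip())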


def register_parallel_cleanup_function():
    """Call MPI_Abort if python crashes.

    This will terminate the processes on the other nodes."""

    if world.size == 1:
        return

    def cleanup(sys=sys, time=time, world=world):
        error = getattr(sys, 'last_type', None)
        if error:
            sys.stdout.flush()
            sys.stderr.write(('ASE CLEANUP (node %d): %s occurred. ' +
                              'Calling MPI_Abort!\n') % (world.rank, error))
            sys.stderr.flush()
            # Give other nodes a moment to crash by themselves (perhaps
            # producing helpful error messages):
            time.sleep(3)
            world.abort(42)

    atexit.register(cleanup)
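
# Illustrative sketch: call this once at the start of an MPI script so that
# an uncaught exception on one rank aborts the whole job instead of leaving
# the remaining ranks deadlocked in a collective call.
#
#     from ase.parallel import register_parallel_cleanup_function
#     register_parallel_cleanup_function()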


def distribute_cpus(size, comm):
    """Distribute cpus to tasks and calculators.

    Input:
    size: number of nodes per calculator
    comm: total communicator object

    Output:
    communicator for this rank, number of calculators, index for this rank
    """

    assert size <= comm.size
    assert comm.size % size == 0

    tasks_rank = comm.rank // size

    r0 = tasks_rank * size
    ranks = np.arange(r0, r0 + size)
    mycomm = comm.new_communicator(ranks)

    return mycomm, comm.size // size, tasks_rank
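
# Illustrative sketch (assumes a communicator that implements
# new_communicator, e.g. GPAW's or Asap's): with 8 ranks and size=2 this
# yields 4 calculators; ranks 0-1 share calculator 0, ranks 2-3 calculator 1,
# and so on.
#
#     mycomm, ncalcs, icalc = distribute_cpus(2, comm)
#     parprint('running', ncalcs, 'calculators in parallel')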


def myslice(ntotal, comm):
    """Return the slice of your tasks for ntotal jobs"""
    n = -(-ntotal // comm.size)  # ceil divide
    return slice(n * comm.rank, n * (comm.rank + 1))
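
# Illustrative sketch: myslice() hands each rank a contiguous chunk of
# ceil(ntotal / size) jobs; with 10 jobs on 4 ranks, rank 0 gets jobs 0-2,
# rank 1 gets 3-5, rank 2 gets 6-8 and rank 3 gets job 9.
#
#     jobs = list(range(10))
#     my_jobs = jobs[myslice(len(jobs), world)]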