Coverage for ase / parallel.py: 52.73%
220 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-04 10:20 +0000
1# fmt: off
3import atexit
4import functools
5import os
6import pickle
7import sys
8import time
9import warnings
11import numpy as np
def paropen(name, mode='r', buffering=-1, encoding=None, comm=None):
    """MPI-safe version of open function.

    In read mode, the file is opened on all nodes. In write and
    append mode, the file is opened on the master only, and /dev/null
    is opened on all other nodes.
    """
    comm = world if comm is None else comm
    # Any non-read mode on a non-master rank is redirected to the null
    # device so only rank 0 actually writes the file.
    writing = mode[0] != 'r'
    if writing and comm.rank > 0:
        name = os.devnull
    return open(name, mode, buffering, encoding)
def parprint(*args, comm=None, **kwargs):
    """MPI-safe print - prints only from master. """
    if comm is None:
        comm = world
    # Non-master ranks stay silent.
    if comm.rank != 0:
        return
    print(*args, **kwargs)
class DummyMPI:
    """Serial stand-in for an MPI communicator: one process, rank 0."""

    rank = 0
    size = 1

    def _returnval(self, a, root=-1):
        # MPI-style calling convention: scalars are returned by value,
        # arrays are treated as in-place buffers (nothing is returned).
        if np.isscalar(a):
            return a
        arr = a.__array__() if hasattr(a, '__array__') else a
        assert isinstance(arr, np.ndarray)
        return None

    def sum(self, a, root=-1):
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a)

    def sum_scalar(self, a, root=-1):
        # A serial sum over one rank is the identity.
        return a

    def product(self, a, root=-1):
        return self._returnval(a)

    def broadcast(self, a, root):
        assert root == 0
        return self._returnval(a)

    def barrier(self):
        # Nothing to synchronize in a serial run.
        pass
class MPI:
    """Wrapper for MPI world object.

    Decides at runtime (after all imports) which one to use:

    * MPI4Py
    * GPAW
    * Asap
    * a dummy implementation for serial runs

    """

    def __init__(self):
        self.comm = None

    def __getattr__(self, name):
        # Unpickling probes objects for an optional __setstate__; if we
        # answered that probe by resolving a communicator we would recurse
        # forever (this also bites modules that pickle internally, e.g.
        # multiprocessing).  Report it as missing instead.  See
        # https://gitlab.com/ase/ase/-/merge_requests/2695
        if name == '__setstate__':
            raise AttributeError(name)

        # Resolve the real communicator lazily, on first use.
        comm = self.comm
        if comm is None:
            comm = self.comm = _get_comm()
        return getattr(comm, name)
def _get_comm():
    """Get the correct MPI world object."""
    # mpi4py wins outright; GPAW and Asap are consulted in that order and
    # may decline (return None) if compiled without MPI support.
    if 'mpi4py' in sys.modules:
        return MPI4PY()
    for module_name, factory in (('_gpaw', _gpaw_world),
                                 ('_asap', _asap_world)):
        if module_name in sys.modules:
            comm = factory()
            if comm is not None:
                return comm
    return DummyMPI()
class MPI4PY:
    """Adapter presenting an mpi4py communicator through ASE's interface."""

    def __init__(self, mpi4py_comm=None):
        if mpi4py_comm is None:
            # Import lazily so mpi4py is only required when actually used.
            from mpi4py import MPI
            mpi4py_comm = MPI.COMM_WORLD
        self.comm = mpi4py_comm

    @property
    def rank(self):
        return self.comm.rank

    @property
    def size(self):
        return self.comm.size

    def _returnval(self, a, b):
        """Behave correctly when working on scalars/arrays.

        Either input is an array and we in-place write b (output from
        mpi4py) back into a, or input is a scalar and we return the
        corresponding output scalar."""
        if np.isscalar(a):
            assert np.isscalar(b)
            return b
        assert not np.isscalar(b)
        a[:] = b
        return None

    def _reduce(self, a, root):
        # root == -1 means "everyone gets the result" (allreduce).
        if root == -1:
            return self.comm.allreduce(a)
        return self.comm.reduce(a, root)

    def sum(self, a, root=-1):
        b = self._reduce(a, root)
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a, b)

    def sum_scalar(self, a, root=-1):
        return self._reduce(a, root)

    def split(self, split_size=None):
        """Divide the communicator."""
        # color - subgroup id
        # key - new subgroup rank
        nsplit = split_size or self.size
        per_group = self.size / nsplit
        color = int(self.rank // per_group)
        key = int(self.rank % per_group)
        return MPI4PY(self.comm.Split(color, key))

    def barrier(self):
        self.comm.barrier()

    def abort(self, code):
        self.comm.Abort(code)

    def broadcast(self, a, root):
        b = self.comm.bcast(a, root=root)
        if self.rank != root:
            # Receivers get the value: scalar returned, array written back.
            return self._returnval(a, b)
        # On the root the data is already in place: scalars are returned
        # as-is, arrays need no write-back.
        return a if np.isscalar(a) else None
class AsapCommWrapper:
    """Compatibility hack to save people from trouble with older asap.

    We can definitely remove this in 2027."""

    def __init__(self, world):
        self.world = world

    def __getattr__(self, attr):
        # Everything except sum_scalar is delegated to the wrapped comm.
        return getattr(self.world, attr)

    def sum_scalar(self, a, root=-1):
        # Emulate the missing sum_scalar via an in-place array sum on a
        # one-element buffer.
        tmp = np.array([a])
        self.world.sum(tmp, root=root)
        return tmp[0]
def _asap_world():
    """Return Asap's communicator, or None if this Asap has no MPI."""
    import _asap

    try:
        comm = _asap.Communicator()
    except AttributeError:
        # Asap build without communicator support.
        return None

    if hasattr(comm, 'sum_scalar'):
        return comm
    # Older asap: supply the missing sum_scalar() via a wrapper.
    return AsapCommWrapper(comm)
def _gpaw_world():
    """Return GPAW's world object, or None if _gpaw lacks MPI support."""
    import _gpaw

    if hasattr(_gpaw, 'Communicator'):
        # Return the actual public (possibly wrapper) object
        from gpaw.mpi import world
        return world
    return None
# Module-level singleton; the concrete communicator (mpi4py/GPAW/Asap/dummy)
# is selected lazily on first attribute access (see MPI.__getattr__).
world = MPI()
def barrier():
    """Call barrier() on the world communicator."""
    world.barrier()
def broadcast(obj, root=0, comm=world):
    """Broadcast a Python object across an MPI communicator and return it."""
    # Step 1: pickle on the root and broadcast the byte count so the other
    # ranks can size their receive buffers.
    if comm.rank == root:
        data = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        nbytes = np.array([len(data)], int)
    else:
        data = None
        nbytes = np.empty(1, int)
    comm.broadcast(nbytes, root)

    # Step 2: ship the pickle itself as an int8 array.
    if comm.rank == root:
        buf = np.frombuffer(data, np.int8)
    else:
        buf = np.zeros(nbytes, np.int8)
    comm.broadcast(buf, root)

    # The root already holds the object; the others unpickle their copy.
    if comm.rank == root:
        return obj
    return pickle.loads(buf.tobytes())
def parallel_function(func):
    """Decorator for broadcasting from master to slaves using MPI.

    Disable by passing parallel=False to the function. For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """
    @functools.wraps(func)
    def new_func(*args, **kwargs):
        parallel = kwargs.pop('parallel', True)
        serial_instance = bool(args) and getattr(args[0], 'serial', False)
        if world.size == 1 or serial_instance or not parallel:
            # Parallel handling disabled: run on every rank.
            return func(*args, **kwargs)

        # Only the master runs func; the outcome - a result or an
        # exception - is broadcast so all ranks behave identically.
        ex = None
        result = None
        if world.rank == 0:
            try:
                result = func(*args, **kwargs)
            except Exception as x:
                ex = x
        ex, result = broadcast((ex, result))
        if ex is not None:
            raise ex
        return result

    return new_func
def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    Disable by passing parallel=False to the function. For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """
    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        # parallel=False, a serial instance, or a single-rank world all
        # force plain serial iteration on every rank.
        parallel = kwargs.pop('parallel', True)
        if (world.size == 1 or
            (args and getattr(args[0], 'serial', False)) or
                not parallel):
            # Disable:
            for result in generator(*args, **kwargs):
                yield result
            return

        if world.rank == 0:
            # Master runs the real generator and broadcasts each item as an
            # (exception, result) pair before yielding it locally.
            try:
                for result in generator(*args, **kwargs):
                    broadcast((None, result))
                    yield result
            except Exception as ex:
                # Ship the exception so all ranks raise in step.
                broadcast((ex, None))
                raise ex
            # (None, None) is the exhaustion sentinel for the other ranks.
            broadcast((None, None))
        else:
            # Slaves receive (exception, result) pairs until the sentinel
            # result None arrives; a non-None exception is re-raised here.
            # NOTE(review): a None yielded by the wrapped generator would be
            # indistinguishable from the sentinel - assumes generators never
            # yield None.
            ex2, result = broadcast((None, None))
            if ex2 is not None:
                raise ex2
            while result is not None:
                yield result
                ex2, result = broadcast((None, None))
                if ex2 is not None:
                    raise ex2

    return new_generator
def register_parallel_cleanup_function():
    """Call MPI_Abort if python crashes.

    This will terminate the processes on the other nodes."""
    # Serial runs have nothing to abort.
    if world.size == 1:
        return

    # Default arguments capture sys/time/world now: by the time atexit
    # fires, module globals may already have been torn down.
    def cleanup(sys=sys, time=time, world=world):
        error = getattr(sys, 'last_type', None)
        if not error:
            return
        sys.stdout.flush()
        sys.stderr.write(('ASE CLEANUP (node %d): %s occurred. ' +
                          'Calling MPI_Abort!\n') % (world.rank, error))
        sys.stderr.flush()
        # Give other nodes a moment to crash by themselves (perhaps
        # producing helpful error messages):
        time.sleep(3)
        world.abort(42)

    atexit.register(cleanup)
def distribute_cpus(size, comm):
    """Distribute cpus to tasks and calculators.

    Input:
    size: number of nodes per calculator
    comm: total communicator object

    Output:
    communicator for this rank, number of calculators, index for this rank
    """
    # size must evenly tile the total communicator.
    assert size <= comm.size
    assert comm.size % size == 0

    task_index = comm.rank // size
    first_rank = task_index * size
    # Contiguous block of `size` ranks forms this rank's sub-communicator.
    mycomm = comm.new_communicator(np.arange(first_rank, first_rank + size))
    ncalculators = comm.size // size
    return mycomm, ncalculators, task_index
def myslice(ntotal, comm):
    """Return the slice of your tasks for ntotal jobs"""
    # Ceiling division: every rank gets the same nominal chunk; the final
    # slice may run past ntotal, which plain slicing tolerates.
    chunk = (ntotal + comm.size - 1) // comm.size
    return slice(chunk * comm.rank, chunk * (comm.rank + 1))