Coverage for /builds/ase/ase/ase/parallel.py: 51.39%

216 statements  

coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

# fmt: off

import atexit
import functools
import os
import pickle
import sys
import time
import warnings

import numpy as np


def paropen(name, mode='r', buffering=-1, encoding=None, comm=None):
    """MPI-safe version of open function.

    In read mode, the file is opened on all nodes. In write and
    append mode, the file is opened on the master only, and /dev/null
    is opened on all other nodes.
    """
    if comm is None:
        comm = world
    if comm.rank > 0 and mode[0] != 'r':
        name = os.devnull
    return open(name, mode, buffering, encoding)
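

# Illustrative usage (an added sketch, not part of parallel.py; the file
# name is hypothetical). Every rank executes the same write path, but
# only rank 0 opens the real file; the other ranks get os.devnull, so no
# explicit "if rank == 0" guard is needed:
def _example_paropen():
    with paropen('log.txt', 'w') as fd:
        print('written once by the master, not once per rank', file=fd)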

def parprint(*args, comm=None, **kwargs):
    """MPI-safe print - prints only from master."""
    if comm is None:
        comm = world
    if comm.rank == 0:
        print(*args, **kwargs)
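

# Illustrative usage (added sketch, hypothetical file name): parprint
# forwards its keyword arguments to the built-in print, so it combines
# naturally with paropen:
def _example_parprint():
    parprint('step', 1, 'done')               # shown once, by rank 0 only
    with paropen('out.txt', 'w') as fd:
        parprint('also written once', file=fd)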

class DummyMPI:
    rank = 0
    size = 1

    def _returnval(self, a, root=-1):
        # MPI interface works either on numbers, in which case a number is
        # returned, or on arrays, in-place.
        if np.isscalar(a):
            return a
        if hasattr(a, '__array__'):
            a = a.__array__()
        assert isinstance(a, np.ndarray)
        return None

    def sum(self, a, root=-1):
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a)

    def sum_scalar(self, a, root=-1):
        return a

    def product(self, a, root=-1):
        return self._returnval(a)

    def broadcast(self, a, root):
        assert root == 0
        return self._returnval(a)

    def barrier(self):
        pass
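

# Illustrative sketch (added, not part of parallel.py): DummyMPI mirrors
# the MPI call semantics for serial runs - scalars are returned, arrays
# are "reduced" in place, which is a no-op for a single process:
def _example_dummympi():
    comm = DummyMPI()
    a = np.ones(3)
    comm.sum(a)                      # array: in-place, returns None
    assert comm.sum_scalar(5) == 5   # scalar: returned directly
    assert comm.broadcast(a, root=0) is None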

class MPI:
    """Wrapper for MPI world object.

    Decides at runtime (after all imports) which one to use:

    * MPI4Py
    * GPAW
    * Asap
    * a dummy implementation for serial runs
    """

    def __init__(self):
        self.comm = None

    def __getattr__(self, name):
        # Pickling of objects that carry instances of MPI class
        # (e.g. NEB) raises RecursionError since it tries to access
        # the optional __setstate__ method (which we do not implement)
        # when unpickling. The two lines below prevent the
        # RecursionError. This also affects modules that use pickling,
        # e.g. multiprocessing. For more details see:
        # https://gitlab.com/ase/ase/-/merge_requests/2695
        if name == '__setstate__':
            raise AttributeError(name)

        if self.comm is None:
            self.comm = _get_comm()
        return getattr(self.comm, name)

def _get_comm():
    """Get the correct MPI world object."""
    if 'mpi4py' in sys.modules:
        return MPI4PY()
    if '_gpaw' in sys.modules:
        import _gpaw
        if hasattr(_gpaw, 'Communicator'):
            return _gpaw.Communicator()
    if '_asap' in sys.modules:
        import _asap
        if hasattr(_asap, 'Communicator'):
            return _asap.Communicator()
    return DummyMPI()
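

# Illustrative sketch (added): MPI.__getattr__ defers to _get_comm() on
# first attribute access, so the backend behind ``world`` depends on which
# MPI modules have been imported by the time ``world`` is first used. In
# a plain serial session the dummy backend is selected:
def _example_serial_world():
    print(world.rank, world.size)    # 0 1 without any MPI module loaded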

class MPI4PY:
    def __init__(self, mpi4py_comm=None):
        if mpi4py_comm is None:
            from mpi4py import MPI
            mpi4py_comm = MPI.COMM_WORLD
        self.comm = mpi4py_comm

    @property
    def rank(self):
        return self.comm.rank

    @property
    def size(self):
        return self.comm.size

    def _returnval(self, a, b):
        """Behave correctly when working on scalars/arrays.

        Either input is an array and we in-place write b (output from
        mpi4py) back into a, or input is a scalar and we return the
        corresponding output scalar."""
        if np.isscalar(a):
            assert np.isscalar(b)
            return b
        else:
            assert not np.isscalar(b)
            a[:] = b
            return None

    def sum(self, a, root=-1):
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            b = self.comm.reduce(a, root)
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a, b)

    def sum_scalar(self, a, root=-1):
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            b = self.comm.reduce(a, root)
        return b

    def split(self, split_size=None):
        """Divide the communicator into split_size subcommunicators."""
        # color - subgroup id
        # key - new subgroup rank
        if not split_size:
            split_size = self.size
        color = int(self.rank // (self.size / split_size))
        key = int(self.rank % (self.size / split_size))
        comm = self.comm.Split(color, key)
        return MPI4PY(comm)

    def barrier(self):
        self.comm.barrier()

    def abort(self, code):
        self.comm.Abort(code)

    def broadcast(self, a, root):
        b = self.comm.bcast(a, root=root)
        if self.rank == root:
            if np.isscalar(a):
                return a
            return None
        return self._returnval(a, b)
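

# Illustrative sketch (added, hypothetical rank counts): with 8 MPI ranks,
# split(2) creates two 4-rank subcommunicators; ``color`` selects the
# subgroup and ``key`` the rank within it, so global rank 5 gets color 1
# and key 1. split() exists on the mpi4py backend only; run under e.g.
# ``mpiexec -n 8``:
def _example_split():
    sub = world.split(2)
    parprint('each subgroup has', sub.size, 'ranks')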

world = None

# Check for special MPI-enabled Python interpreters:
if '_gpaw' in sys.builtin_module_names:
    # http://gpaw.readthedocs.io
    import _gpaw
    world = _gpaw.Communicator()
elif '_asap' in sys.builtin_module_names:
    # Modern version of Asap
    # http://wiki.fysik.dtu.dk/asap
    # We cannot import asap3.mpi here, as that creates an import deadlock
    import _asap
    world = _asap.Communicator()

# Check if MPI implementation has been imported already:
elif '_gpaw' in sys.modules:
    # Same thing as above but for the module version
    import _gpaw
    try:
        world = _gpaw.Communicator()
    except AttributeError:
        pass
elif '_asap' in sys.modules:
    import _asap
    try:
        world = _asap.Communicator()
    except AttributeError:
        pass
elif 'mpi4py' in sys.modules:
    world = MPI4PY()

if world is None:
    world = MPI()

def barrier():
    world.barrier()

def broadcast(obj, root=0, comm=world):
    """Broadcast a Python object across an MPI communicator and return it."""
    if comm.rank == root:
        string = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        n = np.array([len(string)], int)
    else:
        string = None
        n = np.empty(1, int)
    comm.broadcast(n, root)
    if comm.rank == root:
        string = np.frombuffer(string, np.int8)
    else:
        string = np.zeros(n, np.int8)
    comm.broadcast(string, root)
    if comm.rank == root:
        return obj
    else:
        return pickle.loads(string.tobytes())
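

# Illustrative sketch (added, hypothetical payload): broadcast() pickles an
# arbitrary Python object on the root rank and reconstructs it everywhere
# else, so afterwards all ranks hold an equal object:
def _example_broadcast():
    if world.rank == 0:
        data = {'energy': -1.23, 'converged': True}
    else:
        data = None
    data = broadcast(data, root=0)
    assert data['converged']         # True on every rank now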

def parallel_function(func):
    """Decorator for broadcasting from master to slaves using MPI.

    Disable by passing parallel=False to the function. For a method,
    you can also disable the parallel behavior by setting
    self.serial = True on the instance.
    """

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not kwargs.pop('parallel', True)):
            # Disable:
            return func(*args, **kwargs)

        ex = None
        result = None
        if world.rank == 0:
            try:
                result = func(*args, **kwargs)
            except Exception as x:
                ex = x
        ex, result = broadcast((ex, result))
        if ex is not None:
            raise ex
        return result

    return new_func
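

# Illustrative sketch (added, hypothetical function and file name): the
# decorated body runs on rank 0 only; its return value (or any raised
# exception) is then broadcast, so every rank gets the same result:
@parallel_function
def _example_read_number(filename):
    with open(filename) as fd:       # opened by the master only
        return float(fd.read())

# All ranks call _example_read_number('x.txt') identically; passing
# parallel=False makes every rank execute the body itself.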

def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    Disable by passing parallel=False to the function. For a method,
    you can also disable the parallel behavior by setting
    self.serial = True on the instance.
    """

    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not kwargs.pop('parallel', True)):
            # Disable:
            for result in generator(*args, **kwargs):
                yield result
            return

        if world.rank == 0:
            try:
                for result in generator(*args, **kwargs):
                    broadcast((None, result))
                    yield result
            except Exception as ex:
                broadcast((ex, None))
                raise ex
            broadcast((None, None))
        else:
            ex2, result = broadcast((None, None))
            if ex2 is not None:
                raise ex2
            while result is not None:
                yield result
                ex2, result = broadcast((None, None))
                if ex2 is not None:
                    raise ex2

    return new_generator
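

# Illustrative sketch (added, hypothetical generator and file name): each
# item is broadcast as the master produces it, and a final (None, None)
# pair tells the other ranks that the generator is exhausted:
@parallel_generator
def _example_lines(filename):
    with open(filename) as fd:       # read by the master only
        for line in fd:
            yield line.strip()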

def register_parallel_cleanup_function():
    """Call MPI_Abort if Python crashes.

    This will terminate the processes on the other nodes."""

    if world.size == 1:
        return

    def cleanup(sys=sys, time=time, world=world):
        error = getattr(sys, 'last_type', None)
        if error:
            sys.stdout.flush()
            sys.stderr.write(('ASE CLEANUP (node %d): %s occurred. ' +
                              'Calling MPI_Abort!\n') % (world.rank, error))
            sys.stderr.flush()
            # Give other nodes a moment to crash by themselves (perhaps
            # producing helpful error messages):
            time.sleep(3)
            world.abort(42)

    atexit.register(cleanup)

def distribute_cpus(size, comm):
    """Distribute cpus to tasks and calculators.

    Input:
    size: number of nodes per calculator
    comm: total communicator object

    Output:
    communicator for this rank, number of calculators, index for this rank
    """

    assert size <= comm.size
    assert comm.size % size == 0

    tasks_rank = comm.rank // size

    r0 = tasks_rank * size
    ranks = np.arange(r0, r0 + size)
    mycomm = comm.new_communicator(ranks)

    return mycomm, comm.size // size, tasks_rank
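

# Illustrative sketch (added, hypothetical sizes): with 8 total ranks and
# size=2, the ranks pair up into 4 calculator communicators; rank 5 ends
# up in the communicator over ranks (4, 5) with index 2 of 4. Note that
# the communicator must provide new_communicator() (the GPAW/Asap-style
# API), which the dummy and mpi4py backends shown above do not:
def _example_distribute():
    mycomm, ncalcs, index = distribute_cpus(size=2, comm=world)
    parprint(ncalcs, 'calculators, my index:', index)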

def myslice(ntotal, comm):
    """Return the slice of your tasks for ntotal jobs."""
    n = -(-ntotal // comm.size)  # ceil divide
    return slice(n * comm.rank, n * (comm.rank + 1))
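

# Illustrative sketch (added, hypothetical numbers): myslice splits ntotal
# jobs into ceil(ntotal / size) chunks, one per rank; with ntotal=10 on 4
# ranks the chunk size is 3, so rank 3 gets slice(9, 12) and hence a
# single job:
def _example_myslice():
    jobs = list(range(10))
    mine = jobs[myslice(len(jobs), world)]
    print(world.rank, 'handles', mine)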