Coverage for ase / parallel.py: 52.73%

220 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-04 10:20 +0000

1# fmt: off 

2 

3import atexit 

4import functools 

5import os 

6import pickle 

7import sys 

8import time 

9import warnings 

10 

11import numpy as np 

12 

13 

def paropen(name, mode='r', buffering=-1, encoding=None, comm=None):
    """MPI-safe version of open function.

    In read mode, the file is opened on all nodes.  In write and
    append mode, the file is opened on the master only, and /dev/null
    is opened on all other nodes.
    """
    comm = world if comm is None else comm
    writing = mode[0] != 'r'
    if writing and comm.rank > 0:
        # Non-master ranks direct their output to the null device.
        name = os.devnull
    return open(name, mode, buffering, encoding)

26 

27 

def parprint(*args, comm=None, **kwargs):
    """MPI-safe print: only the master rank actually prints."""
    if comm is None:
        comm = world
    if comm.rank != 0:
        return
    print(*args, **kwargs)

34 

35 

class DummyMPI:
    """Serial stand-in for an MPI communicator: one rank, all no-ops."""

    rank = 0
    size = 1

    def _returnval(self, a, root=-1):
        # Mirror the MPI interface: a scalar reduction returns the
        # value, an array reduction works in place and returns None.
        if np.isscalar(a):
            return a
        arr = a.__array__() if hasattr(a, '__array__') else a
        assert isinstance(arr, np.ndarray)
        return None

    def sum(self, a, root=-1):
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a)

    def sum_scalar(self, a, root=-1):
        return a

    def product(self, a, root=-1):
        return self._returnval(a)

    def broadcast(self, a, root):
        assert root == 0
        return self._returnval(a)

    def barrier(self):
        # Nothing to synchronize in a serial run.
        pass

68 

69 

class MPI:
    """Wrapper for MPI world object.

    Decides at runtime (after all imports) which one to use:

    * MPI4Py
    * GPAW
    * Asap
    * a dummy implementation for serial runs

    """

    def __init__(self):
        self.comm = None

    def __getattr__(self, name):
        # Unpickling probes for an optional __setstate__ (which we do
        # not implement); resolving the communicator to answer that
        # probe would recurse forever, so report the attribute as
        # missing.  This also affects modules that use pickling, e.g.
        # multiprocessing.  For more details see:
        # https://gitlab.com/ase/ase/-/merge_requests/2695
        if name == '__setstate__':
            raise AttributeError(name)

        if self.comm is None:
            # Resolve the concrete backend lazily, after all imports.
            self.comm = _get_comm()
        return getattr(self.comm, name)

99 

100 

def _get_comm():
    """Get the correct MPI world object."""
    # Probe only backends whose extension modules are already loaded;
    # we never trigger an import ourselves.
    if 'mpi4py' in sys.modules:
        return MPI4PY()
    for modname, factory in (('_gpaw', _gpaw_world), ('_asap', _asap_world)):
        if modname in sys.modules:
            comm = factory()
            if comm is not None:
                return comm
    return DummyMPI()

114 

115 

class MPI4PY:
    """ASE-interface wrapper around an mpi4py communicator."""

    def __init__(self, mpi4py_comm=None):
        # Import lazily so mpi4py remains an optional dependency.
        if mpi4py_comm is None:
            from mpi4py import MPI
            mpi4py_comm = MPI.COMM_WORLD
        self.comm = mpi4py_comm

    @property
    def rank(self):
        # Rank of this process within the wrapped communicator.
        return self.comm.rank

    @property
    def size(self):
        # Number of processes in the wrapped communicator.
        return self.comm.size

    def _returnval(self, a, b):
        """Behave correctly when working on scalars/arrays.

        Either input is an array and we in-place write b (output from
        mpi4py) back into a, or input is a scalar and we return the
        corresponding output scalar."""
        if np.isscalar(a):
            assert np.isscalar(b)
            return b
        else:
            assert not np.isscalar(b)
            a[:] = b
            return None

    def sum(self, a, root=-1):
        # root == -1 means allreduce (every rank gets the result);
        # otherwise mpi4py's reduce delivers to *root* only.
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            b = self.comm.reduce(a, root)
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a, b)

    def sum_scalar(self, a, root=-1):
        # Scalar-only variant: simply returns what mpi4py returns
        # (NOTE(review): with an explicit root, mpi4py's reduce returns
        # None on non-root ranks — confirm callers expect that).
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            b = self.comm.reduce(a, root)
        return b

    def split(self, split_size=None):
        """Divide the communicator."""
        # color - subgroup id
        # key - new subgroup rank
        if not split_size:
            split_size = self.size
        color = int(self.rank // (self.size / split_size))
        key = int(self.rank % (self.size / split_size))
        comm = self.comm.Split(color, key)
        return MPI4PY(comm)

    def barrier(self):
        # Block until every rank of this communicator arrives here.
        self.comm.barrier()

    def abort(self, code):
        # Hard-terminate all ranks with the given exit code.
        self.comm.Abort(code)

    def broadcast(self, a, root):
        b = self.comm.bcast(a, root=root)
        if self.rank == root:
            # The root already holds the data: return the scalar
            # unchanged, or leave the array untouched.
            if np.isscalar(a):
                return a
            return None
        return self._returnval(a, b)

186 

187 

class AsapCommWrapper:
    """Compatibility hack to save people from trouble with older asap.

    We can definitely remove this in 2027."""

    def __init__(self, world):
        self.world = world

    def __getattr__(self, attr):
        # Everything else is forwarded to the wrapped communicator.
        return getattr(self.world, attr)

    def sum_scalar(self, a, root=-1):
        # Older asap has no sum_scalar; emulate it via an array sum.
        work = np.array([a])
        self.world.sum(work, root=root)
        return work[0]

202 

203 

def _asap_world():
    import _asap
    try:
        comm = _asap.Communicator()
    except AttributeError:
        # Asap build too old: no Communicator class available.
        return None

    if hasattr(comm, 'sum_scalar'):
        return comm
    # Older asap: patch in sum_scalar through the compatibility wrapper.
    return AsapCommWrapper(comm)

215 

216 

def _gpaw_world():
    import _gpaw

    if hasattr(_gpaw, 'Communicator'):
        # Return the actual public (possibly wrapper) object
        from gpaw.mpi import world
        return world
    return None

226 

227 

# Module-level world communicator; the concrete backend (mpi4py, GPAW,
# Asap or the serial dummy) is resolved lazily on first attribute access.
world = MPI()

229 

230 

def barrier(comm=None):
    """MPI-safe barrier: synchronize all ranks of *comm*.

    *comm* defaults to the global world communicator, consistent with
    paropen() and parprint().  Passing comm explicitly is backward
    compatible with the old zero-argument form.
    """
    if comm is None:
        comm = world
    comm.barrier()

233 

234 

def broadcast(obj, root=0, comm=None):
    """Broadcast a Python object across an MPI communicator and return it.

    The object is pickled on *root*; the byte count is broadcast first
    so the other ranks can allocate a receive buffer, then the payload
    is broadcast and unpickled.

    *comm* defaults to the global world communicator, resolved at call
    time (consistent with paropen/parprint) instead of being bound at
    import time; the old ``comm=world`` default behaved identically for
    callers, so this is backward compatible.
    """
    if comm is None:
        comm = world
    if comm.rank == root:
        payload = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        nbytes = np.array([len(payload)], int)
    else:
        payload = None
        nbytes = np.empty(1, int)
    # First broadcast the size so receivers can size their buffers.
    comm.broadcast(nbytes, root)
    if comm.rank == root:
        buf = np.frombuffer(payload, np.int8)
    else:
        buf = np.zeros(nbytes, np.int8)
    comm.broadcast(buf, root)
    if comm.rank == root:
        # Root already holds the object; skip the pickle round-trip.
        return obj
    return pickle.loads(buf.tobytes())

253 

254 

def parallel_function(func):
    """Decorator for broadcasting from master to slaves using MPI.

    Disable by passing parallel=False to the function. For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        parallel = kwargs.pop('parallel', True)
        run_everywhere = (world.size == 1 or
                          (args and getattr(args[0], 'serial', False)) or
                          not parallel)
        if run_everywhere:
            # Parallel behavior disabled: every rank does the work.
            return func(*args, **kwargs)

        captured = None
        result = None
        if world.rank == 0:
            # Only the master executes; capture any exception so that
            # every rank can re-raise it after the broadcast.
            try:
                result = func(*args, **kwargs)
            except Exception as err:
                captured = err
        captured, result = broadcast((captured, result))
        if captured is not None:
            raise captured
        return result

    return new_func

284 

285 

def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    Disable by passing parallel=False to the function. For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        # Consume the switch here so the wrapped generator never sees it.
        parallel = kwargs.pop('parallel', True)
        if (world.size == 1 or
            (args and getattr(args[0], 'serial', False)) or
                not parallel):
            # Disable:
            for result in generator(*args, **kwargs):
                yield result
            return

        if world.rank == 0:
            # Master: run the real generator and ship every item to the
            # other ranks as an (exception, result) pair.
            try:
                for result in generator(*args, **kwargs):
                    broadcast((None, result))
                    yield result
            except Exception as ex:
                # Forward the failure so all ranks raise in sync.
                broadcast((ex, None))
                raise ex
            # (None, None) marks normal exhaustion of the generator.
            broadcast((None, None))
        else:
            # Slaves: keep receiving until the (None, None) sentinel;
            # re-raise any exception forwarded from the master.
            # NOTE(review): a legitimately yielded value of None would
            # be indistinguishable from the end-of-stream sentinel —
            # confirm decorated generators never yield None.
            ex2, result = broadcast((None, None))
            if ex2 is not None:
                raise ex2
            while result is not None:
                yield result
                ex2, result = broadcast((None, None))
                if ex2 is not None:
                    raise ex2

    return new_generator

325 

326 

def register_parallel_cleanup_function():
    """Call MPI_Abort if python crashes.

    This will terminate the processes on the other nodes."""

    if world.size == 1:
        # Serial run: nothing to abort.
        return

    def cleanup(sys=sys, time=time, world=world):
        # Default-argument bindings keep these modules reachable even
        # during interpreter shutdown, when globals may be cleared.
        error = getattr(sys, 'last_type', None)
        if not error:
            return
        sys.stdout.flush()
        sys.stderr.write(('ASE CLEANUP (node %d): %s occurred. ' +
                          'Calling MPI_Abort!\n') % (world.rank, error))
        sys.stderr.flush()
        # Give other nodes a moment to crash by themselves (perhaps
        # producing helpful error messages):
        time.sleep(3)
        world.abort(42)

    atexit.register(cleanup)

348 

349 

def distribute_cpus(size, comm):
    """Distribute cpus to tasks and calculators.

    Input:
    size: number of nodes per calculator
    comm: total communicator object

    Output:
    communicator for this rank, number of calculators, index for this rank
    """

    # The total communicator must divide evenly into groups of *size*.
    assert size <= comm.size
    assert comm.size % size == 0

    task_index = comm.rank // size
    first = task_index * size
    members = np.arange(first, first + size)
    subcomm = comm.new_communicator(members)
    ncalculators = comm.size // size
    return subcomm, ncalculators, task_index

371 

372 

def myslice(ntotal, comm):
    """Return the slice of your tasks for ntotal jobs"""
    per_rank = (ntotal + comm.size - 1) // comm.size  # ceiling division
    start = per_rank * comm.rank
    return slice(start, start + per_rank)