Coverage for /builds/ase/ase/ase/calculators/subprocesscalculator.py: 91.37%

197 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import os 

4import pickle 

5import sys 

6from abc import ABC, abstractmethod 

7from subprocess import PIPE, Popen 

8 

9from ase.calculators.calculator import Calculator, all_properties 

10 

11 

class PackedCalculator(ABC):
    """Picklable description of a calculator living in another process.

    A PackedCalculator is handed to a PythonSubProcessCalculator, which
    sends it to a subprocess (possibly started with MPI or srun).  The
    subprocess then calls unpack_calculator() to build the calculator
    that does the actual work.

    This is meant for the situation where ASE mostly runs in serial
    but selected calculations should run in a parallel Python
    environment.

    The NamedPackedCalculator implementation covers most existing
    calculators.  For anything else, inherit from this class and
    implement unpack_calculator().

    Example::

        from ase.build import bulk

        atoms = bulk('Au')
        pack = NamedPackedCalculator('emt')

        with pack.calculator() as atoms.calc:
            energy = atoms.get_potential_energy()

    The subprocess doing the computation lives exactly as long as the
    with statement.
    """

    @abstractmethod
    def unpack_calculator(self) -> Calculator:
        """Build and return the calculator described by this object.

        This is called inside the subprocess that performs the
        computations."""

    def calculator(self, mpi_command=None) -> 'PythonSubProcessCalculator':
        """Wrap this object in a PythonSubProcessCalculator.

        The returned calculator forwards all computations to a
        subprocess holding the actual calculator.  mpi_command is
        passed on unchanged to PythonSubProcessCalculator."""
        subprocess_calc = PythonSubProcessCalculator(
            self, mpi_command=mpi_command)
        return subprocess_calc

53 

54 

class NamedPackedCalculator(PackedCalculator):
    """PackedCalculator for calculators that can be looked up by name.

    When unpacking, the name is resolved with
    ase.calculators.calculator.get_calculator_class() and the resulting
    class is instantiated with the stored keyword arguments."""

    def __init__(self, name, kwargs=None):
        self._name = name
        # The caller's dictionary is stored as is (not copied).
        self._kwargs = {} if kwargs is None else kwargs

    def unpack_calculator(self):
        # Imported lazily, as this runs inside the subprocess.
        from ase.calculators.calculator import get_calculator_class
        calculator_class = get_calculator_class(self._name)
        return calculator_class(**self._kwargs)

    def __repr__(self):
        classname = self.__class__.__name__
        return f'{classname}({self._name}, {self._kwargs})'

73 

74 

class MPICommand:
    """Command line used to start the Python subprocess.

    argv is the complete argument list handed to subprocess.Popen by
    execute().  The classmethods build the argument lists for running
    this module as a script, either directly or through mpiexec.
    """

    def __init__(self, argv):
        self.argv = argv

    @classmethod
    def python_argv(cls):
        """Return the argv that runs this module with the current Python."""
        return [sys.executable, '-m', 'ase.calculators.subprocesscalculator']

    @classmethod
    def parallel(cls, nprocs, mpi_argv=()):
        """Return a command running the subprocess through mpiexec.

        nprocs is given to ``mpiexec -n``, and mpi_argv holds any extra
        arguments to insert between mpiexec and the Python command.
        The subprocess is started in 'mpi4py' mode."""
        return cls(['mpiexec', '-n', str(nprocs)]
                   + list(mpi_argv)
                   + cls.python_argv()
                   + ['mpi4py'])

    @classmethod
    def serial(cls):
        """Return a command running the subprocess in 'standard' mode."""
        # Use cls (like parallel() does) so that subclasses get an
        # instance of their own type rather than a plain MPICommand.
        return cls(cls.python_argv() + ['standard'])

    def execute(self):
        """Start the subprocess and return the Popen object.

        Both stdin and stdout of the subprocess are pipes."""
        # On this computer (Ubuntu 20.04 + OpenMPI) the subprocess crashes
        # without output during startup if os.environ is not passed along.
        # Hence we pass os.environ.  Not sure if this is a machine thing
        # or in general.  --askhl
        return Popen(self.argv, stdout=PIPE,
                     stdin=PIPE, env=os.environ)

101 

102 

def gpaw_process(ncores=1, **kwargs):
    """Return a subprocess calculator which runs GPAW via ``-m gpaw``.

    ncores is converted to a string and passed after the ``-P`` option
    of the gpaw command line.  kwargs are stored in a
    NamedPackedCalculator named 'gpaw', i.e. they become the keyword
    arguments of the calculator created inside the subprocess.
    """
    argv = [sys.executable, '-m', 'gpaw', '-P', str(ncores), 'python']
    argv += ['-m', 'ase.calculators.subprocesscalculator', 'standard']
    return PythonSubProcessCalculator(
        NamedPackedCalculator('gpaw', kwargs), MPICommand(argv))

110 

111 

class PythonSubProcessCalculator(Calculator):
    """Calculator for running calculations in external processes.

    TODO: This should work with arbitrary commands including MPI stuff.

    This calculator runs a subprocess wherein it sets up an
    actual calculator.  Calculations are forwarded through pickle
    to that calculator, which returns results through pickle.

    The subprocess only exists while the calculator is used as a
    context manager: __enter__() starts it and sends calc_input,
    and __exit__() tells it to stop and waits for it to finish.

    Parameters:

    calc_input:
        Object which is pickled and sent to the subprocess on startup.
        The main() function of this module calls unpack_calculator()
        on it, so this is expected to be a PackedCalculator.
    mpi_command:
        Object whose execute() method starts the subprocess.
        Defaults to MPICommand.serial().
    """
    implemented_properties = list(all_properties)

    def __init__(self, calc_input, mpi_command=None):
        super().__init__()

        self.calc_input = calc_input
        if mpi_command is None:
            mpi_command = MPICommand.serial()
        self.mpi_command = mpi_command

        # Protocol object of the running subprocess; None when not running.
        self.protocol = None

    def set(self, **kwargs):
        # NOTE(review): nothing in this class assigns ``self.client``, so
        # unless the base class does, this check never triggers and the
        # keyword arguments are silently ignored.  Possibly
        # ``self.protocol`` was intended -- confirm before changing.
        if hasattr(self, 'client'):
            raise RuntimeError('No setting things for now, thanks')

    def __repr__(self):
        return f'{type(self).__name__}({self.calc_input})'

    def __enter__(self):
        """Start the subprocess and send it the packed calculator."""
        assert self.protocol is None
        proc = self.mpi_command.execute()
        self.protocol = Protocol(proc)
        self.protocol.send(self.calc_input)
        return self

    def __exit__(self, *args):
        """Ask the subprocess to stop and wait until it has exited."""
        protocol = self.protocol
        try:
            protocol.send('stop')
        finally:
            # Even if 'stop' could not be delivered (e.g. the subprocess
            # has already died and the pipe is broken), forget the
            # protocol and reap the process, so that no zombie is left
            # behind and the calculator can be entered again.
            self.protocol = None
            protocol.proc.communicate()

    def _run_calculation(self, atoms, properties, system_changes):
        # An instruction name is always followed by its data.
        self.protocol.send('calculate')
        self.protocol.send((atoms, properties, system_changes))

    def calculate(self, atoms, properties, system_changes):
        """Run the calculation in the subprocess and store the results."""
        Calculator.calculate(self, atoms, properties, system_changes)
        # We send a pickle of self.atoms because this is a fresh copy
        # of the input, but without an unpicklable calculator:
        self._run_calculation(self.atoms.copy(), properties, system_changes)
        # recv() re-raises an exception raised inside the subprocess.
        results = self.protocol.recv()
        self.results.update(results)

    def backend(self):
        """Return an object that calls methods on the remote calculator."""
        return ParallelBackendInterface(self)

167 

168 

class Protocol:
    """Parent-process end of the pickle channel to the subprocess.

    Objects are pickled to the stdin of proc, and responses are
    unpickled from its stdout."""

    def __init__(self, proc):
        self.proc = proc

    def send(self, obj):
        """Pickle obj to the subprocess and flush the pipe."""
        pipe = self.proc.stdin
        pickle.dump(obj, pipe)
        pipe.flush()

    def recv(self):
        """Read one response; return its value or raise its exception.

        A response is a pair (response_type, value), where the type
        is either 'return' or 'raise'."""
        response_type, value = pickle.load(self.proc.stdout)

        if response_type != 'raise':
            assert response_type == 'return'
            return value

        # The subprocess failed; re-raise its exception in this process.
        raise value

185 

186 

class MockMethod:
    """Callable which forwards a method call to the remote calculator.

    Calling the object sends a 'callmethod' instruction with the method
    name and arguments through calc.protocol and returns the response.
    """

    def __init__(self, name, calc):
        self.name, self.calc = name, calc

    def __call__(self, *args, **kwargs):
        channel = self.calc.protocol
        # First the instruction, then the data belonging to it.
        for message in ('callmethod', [self.name, args, kwargs]):
            channel.send(message)
        return channel.recv()

197 

198 

class ParallelBackendInterface:
    """Proxy whose attributes are methods of the remote calculator.

    Any ordinary attribute lookup returns a MockMethod which, when
    called, executes the method of that name on the calculator inside
    the subprocess of calc."""

    def __init__(self, calc):
        self.calc = calc

    def __getattr__(self, name):
        # __getattr__ is only called when normal lookup fails.  Special
        # (dunder) names probed by copy, pickle, hasattr() and friends
        # must not be turned into remote calls, and a missing 'calc'
        # (instance created without __init__, e.g. while unpickling)
        # would otherwise send self.calc below into infinite recursion.
        is_special = name.startswith('__') and name.endswith('__')
        if is_special or name == 'calc':
            raise AttributeError(name)
        return MockMethod(name, self.calc)

205 

206 

# Values accepted as sys.argv[1] when this module is run as a script.
run_modes = {'standard', 'mpi4py'}

208 

209 

def callmethod(calc, attrname, args, kwargs):
    """Call the method named attrname on calc and return the result.

    args is the sequence of positional arguments and kwargs the
    dictionary of keyword arguments for the call."""
    return getattr(calc, attrname)(*args, **kwargs)

214 

215 

def callfunction(func, args, kwargs):
    """Call func with the given positional and keyword arguments.

    args is a sequence and kwargs a dictionary; the return value of
    func is returned unchanged."""
    result = func(*args, **kwargs)
    return result

218 

219 

def calculate(calc, atoms, properties, system_changes):
    """Run calc.calculate() from scratch and return calc.results.

    The returned dictionary is the results dictionary of calc itself,
    not a copy."""
    # Again we need formalization of the results/outputs, and
    # a way to programmatically access all available properties.
    # We do a wild hack for now: always start from empty results.
    # If we don't clear(), the caching is broken!  For stress.
    # But not for forces.  What dark magic from the depths of the
    # underworld is at play here?
    calc.results.clear()
    request = dict(atoms=atoms, properties=properties,
                   system_changes=system_changes)
    calc.calculate(**request)
    return calc.results

232 

233 

def bad_mode():
    """Return (not raise) the SystemExit for a bad or missing run mode."""
    message = f'sys.argv[1] must be one of {run_modes}'
    return SystemExit(message)

236 

237 

def parallel_startup():
    """Check the run mode and return a Client bound to stdin/stdout.

    The run mode is read from sys.argv[1]; SystemExit is raised if it
    is missing or not one of run_modes.  As a side effect sys.stdout
    is rebound to sys.stderr.
    """
    try:
        run_mode = sys.argv[1]
    except IndexError:
        raise bad_mode()

    if run_mode not in run_modes:
        raise bad_mode()

    if run_mode == 'mpi4py':
        # We must import mpi4py before the rest of ASE, or world will not
        # be correctly initialized.
        import mpi4py  # noqa

    # Keep the real binary stdout for pickles only, and point sys.stdout
    # at stderr so stray print statements won't interfere with outputs:
    pickle_output = sys.stdout.buffer
    sys.stdout = sys.stderr

    pickle_input = sys.stdin.buffer
    return Client(input_fd=pickle_input, output_fd=pickle_output)

258 

259 

class Client:
    """Subprocess end of the pickle channel.

    Only rank 0 of ase.parallel.world reads from input_fd and writes to
    output_fd; objects read by rank 0 are broadcast to all ranks.
    """

    def __init__(self, input_fd, output_fd):
        # Imported here rather than at module level; parallel_startup()
        # may need to import mpi4py before ASE sets up world.
        from ase.parallel import world
        self._world = world
        self.input_fd = input_fd
        self.output_fd = output_fd

    def recv(self):
        """Unpickle one object on rank 0 and return it on every rank."""
        from ase.parallel import broadcast
        on_master = self._world.rank == 0
        obj = pickle.load(self.input_fd) if on_master else None
        return broadcast(obj, 0, self._world)

    def send(self, obj):
        """Pickle obj to the output stream; does nothing unless rank 0."""
        if self._world.rank != 0:
            return
        pickle.dump(obj, self.output_fd)
        self.output_fd.flush()

    def mainloop(self, calc):
        """Serve instructions until the 'stop' instruction arrives.

        Every other instruction is followed by one data object, and
        answered with one (response_type, value) pair."""
        while True:
            instruction = self.recv()
            if instruction == 'stop':
                break

            payload = self.recv()
            response_type, value = self.process_instruction(
                calc, instruction, payload)
            self.send((response_type, value))

    def process_instruction(self, calc, instruction, instruction_data):
        """Execute one instruction and return (response_type, value).

        The response type is 'return' with the result as value, or
        'raise' with the exception raised while executing.  An unknown
        instruction raises RuntimeError directly instead."""
        if instruction in ('callmethod', 'calculate'):
            # These operate on the calculator, which goes first.
            if instruction == 'callmethod':
                function = callmethod
            else:
                function = calculate
            args = (calc, *instruction_data)
        elif instruction == 'callfunction':
            function = callfunction
            args = instruction_data
        else:
            raise RuntimeError(f'Bad instruction: {instruction}')

        try:
            value = function(*args)
        except Exception as ex:
            # Show the traceback locally as well as passing on the
            # exception object as the response.
            import traceback
            traceback.print_exc()
            return 'raise', ex
        return 'return', value

317 

318 

class ParallelDispatch:
    """Utility class to run functions in parallel.

    The functions are executed inside a subprocess started with
    mpicommand.execute(); the subprocess lives as long as the with
    statement::

        with ParallelDispatch(...) as parallel:
            parallel.call(function, *args, **kwargs)

    """

    def __init__(self, mpicommand):
        self._mpicommand = mpicommand
        # Protocol object of the running subprocess; None when not running.
        self._protocol = None

    def call(self, func, *args, **kwargs):
        """Call func(*args, **kwargs) in the subprocess; return the result.

        The function and its arguments are sent through pickle.  An
        exception raised by the function inside the subprocess is
        re-raised here."""
        self._protocol.send('callfunction')
        self._protocol.send((func, args, kwargs))
        return self._protocol.recv()

    def __enter__(self):
        assert self._protocol is None
        self._protocol = Protocol(self._mpicommand.execute())

        # Even if we are not using a calculator, we have to send one:
        pack = NamedPackedCalculator('emt', {})
        self._protocol.send(pack)
        # (We should get rid of that requirement.)

        return self

    def __exit__(self, *args):
        protocol = self._protocol
        try:
            protocol.send('stop')
        finally:
            # Even if 'stop' could not be delivered (e.g. the subprocess
            # has already died and the pipe is broken), forget the
            # protocol and reap the process, so that no zombie is left
            # behind and the dispatcher can be entered again.
            self._protocol = None
            protocol.proc.communicate()

351 

352 

def main():
    """Entry point of the subprocess.

    Sets up the client, receives the packed calculator sent by the
    parent process, unpacks it and serves instructions until told to
    stop."""
    client = parallel_startup()
    calc = client.recv().unpack_calculator()
    client.mainloop(calc)

358 

359 

# Run as a script (python -m ase.calculators.subprocesscalculator <mode>):
# this is the process at the other end of the pipes.
if __name__ == '__main__':
    main()