Coverage for /builds/ase/ase/ase/calculators/genericfileio.py: 87.40%

127 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import shlex 

4from abc import ABC, abstractmethod 

5from contextlib import ExitStack 

6from os import PathLike 

7from pathlib import Path 

8from typing import Any, Iterable, List, Mapping, Optional, Set 

9 

10from ase.calculators.abc import GetOutputsMixin 

11from ase.calculators.calculator import ( 

12 BadConfiguration, 

13 BaseCalculator, 

14 _validate_command, 

15) 

16from ase.config import cfg as _cfg 

17 

# Documentation page describing how to configure calculators; quoted in
# the BadConfiguration message raised when a calculator has no config.
link_calculator_docs = (
    "https://wiki.fysik.dtu.dk/ase/ase/calculators/calculators.html"
    "#calculator-configuration"
)

22 

23 

class BaseProfile(ABC):
    """Describe how to launch an external calculator program.

    A profile stores a launcher command string (``self.command``) and
    combines it with the calculator-specific arguments returned by
    :meth:`get_calculator_command` to build the full argv.
    """

    # Names of optional configuration keys that ``from_config`` forwards
    # to the constructor as keyword arguments when present in the section.
    # Subclasses override this with their own set.
    configvars: Set[str] = set()

    def __init__(self, command):
        self.command = _validate_command(command)

    @property
    def _split_command(self):
        # Tokenize the launcher command using shell-like quoting rules.
        return shlex.split(self.command)

    def get_command(self, inputfile, calc_command=None) -> List[str]:
        """
        Get the command to run. This should be a list of strings.

        Parameters
        ----------
        inputfile : str
            Name of the input file; forwarded to
            :meth:`get_calculator_command` when ``calc_command`` is None.
        calc_command : list of str, optional
            Calculator-specific arguments to use instead of
            ``get_calculator_command(inputfile)`` (used for sockets).

        Returns
        -------
        list of str
            The tokenized launcher command followed by the calculator
            arguments.
        """
        if calc_command is None:
            calc_command = self.get_calculator_command(inputfile)
        return [*self._split_command, *calc_command]

    @abstractmethod
    def get_calculator_command(self, inputfile):
        """
        The calculator specific command as a list of strings.

        Parameters
        ----------
        inputfile : str

        Returns
        -------
        list of str
            The command to run.
        """

    def run(
        self, directory: Path, inputfile: Optional[str],
        outputfile: str, errorfile: Optional[str] = None,
        append: bool = False
    ) -> None:
        """
        Run the command in the given directory.

        Parameters
        ----------
        directory : pathlib.Path or str
            The directory to run the command in. Strings are converted
            to :class:`pathlib.Path`.
        inputfile : Optional[str]
            The name of the input file.
        outputfile : str
            The name of the output file (relative to ``directory``) that
            receives the program's stdout.
        errorfile : Optional[str]
            The name of the file (relative to ``directory``) that receives
            stderr. If None, stderr is inherited from this process.
        append : bool
            If True, append to the output/error files instead of
            truncating them.

        Raises
        ------
        subprocess.CalledProcessError
            If the command exits with a non-zero status.
        """

        import os
        from subprocess import check_call

        # Accept plain strings as well as Path objects; ``str / str``
        # would otherwise raise TypeError below.
        directory = Path(directory)
        argv_command = self.get_command(inputfile)
        mode = 'ab' if append else 'wb'

        # ExitStack closes whichever of the optional files were opened.
        with ExitStack() as stack:
            fd_out = stack.enter_context(open(directory / outputfile, mode))
            fd_err = None
            if errorfile is not None:
                fd_err = stack.enter_context(
                    open(directory / errorfile, mode)
                )
            check_call(
                argv_command,
                cwd=directory,
                stdout=fd_out,
                stderr=fd_err,
                env=os.environ,
            )

    @abstractmethod
    def version(self):
        """Get the version of the code.

        Returns
        -------
        str
            The version of the code.
        """

    @classmethod
    def from_config(cls, cfg, section_name):
        """Create a profile from a configuration file.

        Parameters
        ----------
        cfg : ase.config.Config
            The configuration object.
        section_name : str
            The name of the section in the configuration file. E.g. the name
            of the template that this profile is for.

        Returns
        -------
        BaseProfile
            The profile object.

        Raises
        ------
        BadConfiguration
            If the constructor rejects the collected arguments with a
            TypeError. A missing section or ``command`` key propagates the
            lookup error raised by ``cfg.parser``.
        """
        section = cfg.parser[section_name]
        command = section['command']

        # Forward only the optional keys this profile class declares.
        kwargs = {
            varname: section[varname]
            for varname in cls.configvars if varname in section
        }

        try:
            return cls(command=command, **kwargs)
        except TypeError as err:
            # Typically caused by keys the constructor does not accept;
            # keep the original exception chained for debugging.
            raise BadConfiguration(*err.args) from err

150 

151 

def read_stdout(args, createfile=None, encoding='utf-8'):
    """Run command in tempdir and return standard output.

    Helper function for getting version numbers of DFT codes.
    Most DFT codes don't implement a --version flag, so in order to
    determine the code version, we just run the code until it prints
    a version number.

    Parameters
    ----------
    args : list of str
        Command line to execute, passed to :class:`subprocess.Popen`.
    createfile : str, optional
        Name of an empty file to create in the temporary working
        directory before the command is run.
    encoding : str
        Text encoding used to decode the program's output. Defaults to
        ``'utf-8'``.

    Returns
    -------
    str
        Everything the program wrote to stdout. Stderr is captured and
        discarded, and the exit status is ignored.
    """
    import tempfile
    from subprocess import PIPE, Popen

    with tempfile.TemporaryDirectory() as directory:
        if createfile is not None:
            path = Path(directory) / createfile
            path.touch()
        # stdin is a pipe that communicate() closes immediately, so a
        # program waiting for input sees EOF instead of blocking.
        proc = Popen(
            args,
            stdout=PIPE,
            stderr=PIPE,
            stdin=PIPE,
            cwd=directory,
            encoding=encoding,
        )
        stdout, _ = proc.communicate()
        # Exit code will be != 0 because there isn't an input file
        return stdout

177 

178 

class CalculatorTemplate(ABC):
    """Abstract recipe for driving an external program through files.

    Concrete templates implement input writing, execution, result
    parsing and profile loading; ``socketio_calculator`` additionally
    builds a socket-driven calculator for templates that provide the
    ``socketio_argv`` and ``socketio_parameters`` hooks.
    """

    def __init__(self, name: str, implemented_properties: Iterable[str]):
        self.name = name
        # Stored as an immutable set of property names.
        self.implemented_properties = frozenset(implemented_properties)

    @abstractmethod
    def write_input(self, profile, directory, atoms, parameters, properties):
        """Write the program's input files into ``directory``."""

    @abstractmethod
    def execute(self, directory, profile):
        """Run the program in ``directory`` as described by ``profile``."""

    @abstractmethod
    def read_results(self, directory: PathLike) -> Mapping[str, Any]:
        """Parse the program's output in ``directory`` into a mapping."""

    @abstractmethod
    def load_profile(self, cfg):
        """Build and return a profile from the configuration ``cfg``."""

    def socketio_calculator(
        self,
        profile,
        parameters,
        directory,
        # We may need quite a few socket kwargs here
        # if we want to expose all the timeout etc. from
        # SocketIOCalculator.
        unixsocket=None,
        port=None,
    ):
        """Return a SocketIOCalculator that launches this template's program.

        Exactly one of ``unixsocket`` or ``port`` must be truthy. The
        template must define ``socketio_argv(profile, unixsocket, port)``
        and ``socketio_parameters(unixsocket, port)``.

        Raises
        ------
        TypeError
            If both or neither socket kinds are given, or if the socket
            hooks are missing on this template.
        """
        import os
        from subprocess import Popen

        from ase.calculators.socketio import SocketIOCalculator

        if port and unixsocket:
            raise TypeError(
                'For the socketio_calculator only a UNIX '
                '(unixsocket) or INET (port) socket can be used'
                ' not both.'
            )
        if not (port or unixsocket):
            raise TypeError(
                'For the socketio_calculator either a '
                'UNIX (unixsocket) or INET (port) socket '
                'must be used'
            )

        required_hooks = ('socketio_argv', 'socketio_parameters')
        if not all(hasattr(self, hook) for hook in required_hooks):
            raise TypeError(
                f'Template {self} does not implement mandatory '
                'socketio_argv() and socketio_parameters()'
            )

        # XXX need socketio ABC or something
        client_args = self.socketio_argv(profile, unixsocket, port)
        argv = profile.get_command(inputfile=None, calc_command=client_args)
        # Caller-supplied parameters take precedence over socket defaults.
        merged_parameters = {
            **self.socketio_parameters(unixsocket, port),
            **parameters,
        }

        # Not so elegant that socket args are passed to this function
        # via socketiocalculator when we could make a closure right here.
        def launch(atoms, properties, port, unixsocket):
            directory.mkdir(exist_ok=True, parents=True)
            self.write_input(
                profile=profile,
                directory=directory,
                atoms=atoms,
                parameters=merged_parameters,
                properties=properties,
            )
            # Popen hands the file descriptor to the child, so closing
            # our handle when the with-block exits does not affect it.
            log_path = directory / self.outputname
            with open(log_path, 'w') as stdout_handle:
                return Popen(
                    argv, stdout=stdout_handle, cwd=directory, env=os.environ
                )

        return SocketIOCalculator(
            launch_client=launch, unixsocket=unixsocket, port=port
        )

268 

269 

class GenericFileIOCalculator(BaseCalculator, GetOutputsMixin):
    """Calculator that delegates file I/O and execution to a template.

    Input writing, program execution and result parsing are forwarded to
    ``self.template``; ``self.profile`` describes how the external program
    is launched and ``self.directory`` is where the files live.
    """

    # Configuration consulted for a profile when none is passed in.
    cfg = _cfg

    def __init__(
        self,
        *,
        template,
        profile,
        directory,
        parameters=None,
    ):
        self.template = template

        if profile is None:
            # Fall back to the configuration section named after the
            # template.
            if template.name not in self.cfg.parser:
                raise BadConfiguration(
                    f"No configuration of '{template.name}'. "
                    f"See '{link_calculator_docs}'"
                )
            try:
                profile = template.load_profile(self.cfg)
            except Exception as err:
                cfg_dump = self.cfg.as_dict()
                raise BadConfiguration(
                    f'Failed to load section [{template.name}] '
                    f'from configuration: {cfg_dump}'
                ) from err

        self.profile = profile

        # Maybe we should allow directory to be a factory, so
        # calculators e.g. produce new directories on demand.
        self.directory = Path(directory)
        super().__init__(parameters)

    def set(self, *args, **kwargs):
        """Reject parameter updates; build a new calculator instead."""
        raise RuntimeError(
            'No setting parameters for now, please. '
            'Just create new calculators.'
        )

    def __repr__(self):
        class_name = type(self).__name__
        return f'{class_name}({self.template.name})'

    @property
    def implemented_properties(self):
        """Properties the underlying template can compute."""
        return self.template.implemented_properties

    @property
    def name(self):
        """Name of the underlying template."""
        return self.template.name

    def write_inputfiles(self, atoms, properties):
        """Create ``self.directory`` if needed and write input files there."""
        # SocketIOCalculators like to write inputfiles
        # without calculating.
        target = self.directory
        target.mkdir(exist_ok=True, parents=True)
        self.template.write_input(
            profile=self.profile,
            atoms=atoms,
            parameters=self.parameters,
            properties=properties,
            directory=target,
        )

    def calculate(self, atoms, properties, system_changes):
        """Write inputs, run the program, then store the parsed results.

        ``system_changes`` is accepted for interface compatibility and is
        not used here.
        """
        self.write_inputfiles(atoms, properties)
        template = self.template
        template.execute(self.directory, self.profile)
        self.results = template.read_results(self.directory)
        # XXX Return something useful?

    def _outputmixin_get_results(self):
        return self.results

    def socketio(self, **socketkwargs):
        """Build a socket-driven calculator from this calculator's setup.

        ``socketkwargs`` (e.g. ``unixsocket`` or ``port``) are forwarded to
        the template's ``socketio_calculator``.
        """
        template = self.template
        return template.socketio_calculator(
            profile=self.profile,
            parameters=self.parameters,
            directory=self.directory,
            **socketkwargs,
        )