Coverage for ase / calculators / genericfileio.py: 89.66%

116 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3import shlex 

4from abc import ABC, abstractmethod 

5from collections.abc import Iterable, Mapping 

6from contextlib import ExitStack 

7from os import PathLike 

8from pathlib import Path 

9from typing import Any 

10 

11from ase.calculators.abc import GetOutputsMixin 

12from ase.calculators.calculator import ( 

13 BadConfiguration, 

14 BaseCalculator, 

15 _validate_command, 

16) 

17from ase.config import cfg as _cfg 

18 

# URL of the documentation section explaining calculator configuration;
# it is included in the error raised when a calculator has no configuration.
link_calculator_docs = (
    "https://ase-lib.org/ase/calculators/calculators.html"
    "#calculator-configuration"
)

23 

24 

25class BaseProfile(ABC): 

26 configvars: set[str] = set() 

27 

28 def __init__(self, command): 

29 self.command = _validate_command(command) 

30 

31 @property 

32 def _split_command(self): 

33 return shlex.split(self.command) 

34 

35 def get_command(self, inputfile, calc_command=None) -> list[str]: 

36 """ 

37 Get the command to run. This should be a list of strings. 

38 

39 Parameters 

40 ---------- 

41 inputfile : str 

42 calc_command: list[str]: calculator command (used for sockets) 

43 

44 Returns 

45 ------- 

46 list of str 

47 The command to run. 

48 """ 

49 if calc_command is None: 

50 calc_command = self.get_calculator_command(inputfile) 

51 return [*self._split_command, *calc_command] 

52 

53 @abstractmethod 

54 def get_calculator_command(self, inputfile): 

55 """ 

56 The calculator specific command as a list of strings. 

57 

58 Parameters 

59 ---------- 

60 inputfile : str 

61 

62 Returns 

63 ------- 

64 list of str 

65 The command to run. 

66 """ 

67 

68 def run( 

69 self, directory: Path, inputfile: str | None, 

70 outputfile: str, errorfile: str | None = None, 

71 append: bool = False 

72 ) -> None: 

73 """ 

74 Run the command in the given directory. 

75 

76 Parameters 

77 ---------- 

78 directory : pathlib.Path 

79 The directory to run the command in. 

80 inputfile : Optional[str] 

81 The name of the input file. 

82 outputfile : str 

83 The name of the output file. 

84 errorfile: Optional[str] 

85 the stderror file 

86 append: bool 

87 if True then use append mode 

88 """ 

89 

90 import os 

91 from subprocess import check_call 

92 

93 argv_command = self.get_command(inputfile) 

94 mode = 'wb' if not append else 'ab' 

95 

96 with ExitStack() as stack: 

97 output_path = directory / outputfile 

98 fd_out = stack.enter_context(open(output_path, mode)) 

99 if errorfile is not None: 

100 error_path = directory / errorfile 

101 fd_err = stack.enter_context(open(error_path, mode)) 

102 else: 

103 fd_err = None 

104 check_call( 

105 argv_command, 

106 cwd=directory, 

107 stdout=fd_out, 

108 stderr=fd_err, 

109 env=os.environ, 

110 ) 

111 

112 @abstractmethod 

113 def version(self): 

114 """Get the version of the code. 

115 

116 Returns 

117 ------- 

118 str 

119 The version of the code. 

120 """ 

121 

122 @classmethod 

123 def from_config(cls, cfg, section_name): 

124 """Create a profile from a configuration file. 

125 

126 Parameters 

127 ---------- 

128 cfg : ase.config.Config 

129 The configuration object. 

130 section_name : str 

131 The name of the section in the configuration file. E.g. the name 

132 of the template that this profile is for. 

133 

134 Returns 

135 ------- 

136 BaseProfile 

137 The profile object. 

138 """ 

139 section = cfg.parser[section_name] 

140 command = section['command'] 

141 

142 kwargs = { 

143 varname: section[varname] 

144 for varname in cls.configvars if varname in section 

145 } 

146 

147 try: 

148 return cls(command=command, **kwargs) 

149 except TypeError as err: 

150 raise BadConfiguration(*err.args) 

151 

152 

def read_stdout(args, createfile=None, encoding='utf-8'):
    """Run command in a temporary directory and return its standard output.

    Helper function for getting version numbers of DFT codes.
    Most DFT codes don't implement a --version flag, so in order to
    determine the code version, we just run the code until it prints
    a version number.

    Parameters
    ----------
    args : sequence of str or str
        Command passed unchanged to :class:`subprocess.Popen`.
    createfile : str, optional
        Name of an empty file to create in the temporary directory before
        the command is run.
    encoding : str, default 'utf-8'
        Encoding used to decode the command's standard output. Change it
        for codes that print non-UTF-8 text.

    Returns
    -------
    str
        Everything the command wrote to standard output. Standard error is
        captured and discarded, and the exit status is not checked.
    """
    import tempfile
    from subprocess import PIPE, Popen

    with tempfile.TemporaryDirectory() as directory:
        if createfile is not None:
            (Path(directory) / createfile).touch()
        # stdin is a pipe that communicate() closes immediately, so a code
        # waiting for input sees EOF instead of blocking. Using Popen as a
        # context manager guarantees the pipes are closed and the process
        # is waited for even if communicate() raises.
        with Popen(
            args,
            stdout=PIPE,
            stderr=PIPE,
            stdin=PIPE,
            cwd=directory,
            encoding=encoding,
        ) as proc:
            stdout, _ = proc.communicate()
        # Exit code will typically be != 0 because there isn't an input file,
        # so it is deliberately ignored.
        return stdout

178 

179 

class CalculatorTemplate(ABC):
    """Abstract description of how to drive one external code through files.

    Concrete templates write input files, execute the code via a profile,
    read the results back, and load their profile from a configuration.
    """

    def __init__(self, name: str, implemented_properties: Iterable[str]):
        self.name = name
        # Stored as a frozenset: duplicates collapse and later changes to
        # the caller's iterable cannot leak in.
        self.implemented_properties = frozenset(implemented_properties)

    @abstractmethod
    def write_input(self, profile, directory, atoms, parameters, properties):
        """Write the input for a calculation into ``directory``."""

    @abstractmethod
    def execute(self, directory, profile):
        """Run the external code in ``directory`` using ``profile``."""

    @abstractmethod
    def read_results(self, directory: PathLike) -> Mapping[str, Any]:
        """Return the results read from ``directory`` as a mapping."""

    @abstractmethod
    def load_profile(self, cfg):
        """Build and return the profile for this template from ``cfg``."""

    def socketio_calculator(
        self,
        profile,
        parameters,
        directory,
        # We may need quite a few socket kwargs here
        # if we want to expose all the timeout etc. from
        # SocketIOCalculator.
        unixsocket=None,
        port=None,
    ):
        """Return a SocketIOCalculator that launches this code as a client.

        Exactly one of ``unixsocket`` and ``port`` must be truthy. The
        template must provide ``socketio_argv()`` and
        ``socketio_parameters()``. The launched client also writes to
        ``self.outputname``, which this base class does not define
        (presumably provided by the concrete template).

        Raises
        ------
        TypeError
            If both or neither socket kind is requested, or if the socket
            hooks are missing.
        """
        import os
        from subprocess import Popen

        from ase.calculators.socketio import SocketIOCalculator

        if port and unixsocket:
            raise TypeError(
                'For the socketio_calculator only a UNIX '
                '(unixsocket) or INET (port) socket can be used'
                ' not both.'
            )
        if not (port or unixsocket):
            raise TypeError(
                'For the socketio_calculator either a '
                'UNIX (unixsocket) or INET (port) socket '
                'must be used'
            )

        hooks_present = all(
            hasattr(self, hook_name)
            for hook_name in ('socketio_argv', 'socketio_parameters')
        )
        if not hooks_present:
            raise TypeError(
                f'Template {self} does not implement mandatory '
                'socketio_argv() and socketio_parameters()'
            )

        # XXX need socketio ABC or something
        socket_args = self.socketio_argv(profile, unixsocket, port)
        client_argv = profile.get_command(
            inputfile=None,
            calc_command=socket_args,
        )
        socket_defaults = self.socketio_parameters(unixsocket, port)
        # User parameters take precedence over the socket defaults.
        merged_parameters = {**socket_defaults, **parameters}

        # Not so elegant that socket args are passed to this function
        # via socketiocalculator when we could make a closure right here.
        def launch_client(atoms, properties, port, unixsocket):
            directory.mkdir(exist_ok=True, parents=True)
            self.write_input(
                atoms=atoms,
                profile=profile,
                parameters=merged_parameters,
                properties=properties,
                directory=directory,
            )
            # The child process gets its own copy of the file descriptor,
            # so closing ours when the ``with`` exits is harmless.
            with open(directory / self.outputname, 'w') as log_handle:
                return Popen(
                    client_argv,
                    stdout=log_handle,
                    cwd=directory,
                    env=os.environ,
                )

        return SocketIOCalculator(
            launch_client=launch_client,
            unixsocket=unixsocket,
            port=port,
        )

269 

270 

class GenericFileIOCalculator(BaseCalculator, GetOutputsMixin):
    """Calculator that runs an external code through files in a directory.

    Code-specific behaviour comes from ``template``. The launch command
    comes from ``profile``. Each calculation writes inputs into
    ``directory``, executes the code, then reads the results from the same
    directory.
    """

    # Configuration consulted when no profile is passed explicitly.
    cfg = _cfg

    def __init__(
        self,
        *,
        template,
        profile,
        directory,
        parameters=None,
    ):
        self.template = template

        if profile is None:
            section_name = template.name
            if section_name not in self.cfg.parser:
                raise BadConfiguration(
                    f"No configuration of '{section_name}'. "
                    f"See '{link_calculator_docs}'"
                )
            try:
                profile = template.load_profile(self.cfg)
            except Exception as err:
                # Any failure while building the profile is reported as a
                # configuration error, together with the full configuration.
                raise BadConfiguration(
                    f'Failed to load section [{section_name}] '
                    f'from configuration: {self.cfg.as_dict()}'
                ) from err

        self.profile = profile

        # Maybe we should allow directory to be a factory, so
        # calculators e.g. produce new directories on demand.
        self.directory = Path(directory)
        super().__init__(parameters)

    def set(self, *args, **kwargs):
        """Always raise: parameters are fixed once the calculator exists."""
        raise RuntimeError(
            'No setting parameters for now, please. '
            'Just create new calculators.'
        )

    def __repr__(self):
        return '{}({})'.format(type(self).__name__, self.template.name)

    @property
    def implemented_properties(self):
        # Delegated so the calculator advertises exactly what its template
        # supports.
        return self.template.implemented_properties

    @property
    def name(self):
        return self.template.name

    def write_inputfiles(self, atoms, properties):
        """Create ``self.directory`` if needed and write the input files.

        SocketIOCalculators like to write inputfiles without calculating.
        """
        target = self.directory
        target.mkdir(parents=True, exist_ok=True)
        self.template.write_input(
            profile=self.profile,
            atoms=atoms,
            parameters=self.parameters,
            properties=properties,
            directory=target,
        )

    def calculate(self, atoms, properties, system_changes):
        """Write inputs, run the external code and store the parsed results."""
        self.write_inputfiles(atoms, properties)
        template, workdir = self.template, self.directory
        template.execute(workdir, self.profile)
        self.results = template.read_results(workdir)
        # XXX Return something useful?

    def _outputmixin_get_results(self):
        # Presumably the hook GetOutputsMixin uses to reach the results
        # dictionary -- confirm against ase.calculators.abc.
        return self.results

    def socketio(self, **socketkwargs):
        """Return a socket-driven calculator built from this calculator.

        ``socketkwargs`` (e.g. ``unixsocket`` or ``port``) are forwarded to
        the template's ``socketio_calculator``.
        """
        return self.template.socketio_calculator(
            directory=self.directory,
            parameters=self.parameters,
            profile=self.profile,
            **socketkwargs,
        )