Coverage for /builds/ase/ase/ase/ga/data.py: 78.22%

225 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3""" 

4 Objects which handle all communication with the SQLite database. 

5""" 

6import os 

7 

8import ase.db 

9from ase import Atoms 

10from ase.ga import get_raw_score, set_neighbor_list, set_parametrization 

11 

12 

def split_description(desc):
    """Split a ``'<type>:<text>'`` description string at its colon.

    Parameters:

    desc: str
        Description containing exactly one ``':'``.

    Returns the two parts as a tuple ``(before_colon, after_colon)``.

    Raises AssertionError (carrying ``desc``) when the string does not
    split into exactly two parts.
    """
    parts = desc.split(':')
    # Raised explicitly rather than via ``assert`` so the check is still
    # performed when Python runs with -O; the exception type is unchanged.
    if len(parts) != 2:
        raise AssertionError(desc)
    head, tail = parts
    return head, tail

18 

19 

def test_raw_score(atoms):
    """Check that a raw score has been stored on ``atoms``.

    Parameters:

    atoms: object with an ``info`` dict
        ``atoms.info['key_value_pairs']`` must contain ``'raw_score'``.

    Raises AssertionError when ``'raw_score'`` is missing, and KeyError
    when ``atoms.info`` has no ``'key_value_pairs'`` entry at all.
    """
    err_msg = "raw_score not put in atoms.info['key_value_pairs']"
    # Raised explicitly rather than via ``assert`` so the check is still
    # performed when Python runs with -O; the exception type is unchanged.
    if 'raw_score' not in atoms.info['key_value_pairs']:
        raise AssertionError(err_msg)


# Despite its name this is a validation helper, not a unit test.  Without
# this marker pytest collects it (and fails on the unknown ``atoms``
# fixture) whenever the function is imported into a test module.
test_raw_score.__test__ = False

24 

25 

class DataConnection:
    """Class that handles all database communication.

    All data communication is collected in this class in order to
    decouple the data representation from the GA method.

    A new candidate must be added with one of the functions
    add_unrelaxed_candidate or add_relaxed_candidate; this correctly
    initializes the configuration id (``gaid`` in the database,
    ``info['confid']`` on the atoms object) used to keep track of
    candidates in the database.
    After one of the add_*_candidate functions has been used, if the
    candidate is further modified or relaxed, the functions
    add_unrelaxed_step or add_relaxed_step must be used. This way the
    configuration id carries through correctly.

    Parameters:

    db_file_name: Path to the ase.db data file.
    """

    def __init__(self, db_file_name):
        self.db_file_name = db_file_name
        if not os.path.isfile(self.db_file_name):
            raise OSError(f'DB file {self.db_file_name} not found')
        self.c = ase.db.connect(self.db_file_name)
        # gaids already handed out by get_all_relaxed_candidates; this is
        # what its ``only_new`` filter checks against.
        self.already_returned = set()

    def get_number_of_unrelaxed_candidates(self):
        """ Returns the number of candidates not yet queued or relaxed. """
        return len(self.__get_ids_of_all_unrelaxed_candidates__())

    def get_an_unrelaxed_candidate(self):
        """ Returns a candidate ready for relaxation.

        Raises ValueError if there is no unrelaxed candidate. """
        to_get = self.__get_ids_of_all_unrelaxed_candidates__()
        if not to_get:
            raise ValueError('No unrelaxed candidate to return')
        return self._load_unrelaxed(to_get[0])

    def get_all_unrelaxed_candidates(self):
        """Return all unrelaxed candidates,
        useful if they can all be evaluated quickly."""
        to_get = self.__get_ids_of_all_unrelaxed_candidates__()
        return [self._load_unrelaxed(confid) for confid in to_get]

    def _load_unrelaxed(self, confid):
        """ Fetch the latest entry for confid and make sure that
        info['confid'] and info['data'] are set on it. """
        a = self.__get_latest_traj_for_confid__(confid)
        a.info['confid'] = confid
        if 'data' not in a.info:
            a.info['data'] = {}
        return a

    def __get_ids_of_all_unrelaxed_candidates__(self):
        """ Helper method used by the unrelaxed-candidate getters.

        Returns the gaids that have an entry with relaxed=0 but no entry
        with relaxed=1 and no entry with queued=1. """
        all_unrelaxed_ids = {t.gaid for t in self.c.select(relaxed=0)}
        all_relaxed_ids = {t.gaid for t in self.c.select(relaxed=1)}
        all_queued_ids = {t.gaid for t in self.c.select(queued=1)}

        taken = all_relaxed_ids | all_queued_ids
        return [gaid for gaid in all_unrelaxed_ids if gaid not in taken]

    def __get_latest_traj_for_confid__(self, confid):
        """ Method for obtaining the latest entry
            for a given configuration.
            There can be several entries for
            one configuration if it has undergone
            several changes (mutations, pairings, etc.).

            Raises IndexError if no entry has gaid=confid. """
        allcands = list(self.c.select(gaid=confid))
        # Stable sort + last element: among entries with equal mtime the
        # one returned last by the database wins.
        allcands.sort(key=lambda x: x.mtime)
        return self.get_atoms(allcands[-1].id)

    def mark_as_queued(self, a):
        """ Marks a configuration as queued for relaxation.

        An entry without atoms is written, carrying the candidate's gaid,
        queued=1 and its key_value_pairs. """
        gaid = a.info['confid']
        self.c.write(None, gaid=gaid, queued=1,
                     key_value_pairs=a.info['key_value_pairs'])

    def _finalize_relaxed(self, a, find_neighbors, perform_parametrization):
        """ Fill in the generation number if it is missing and attach the
        optional neighbor list / parametrization to a relaxed candidate. """
        if 'generation' not in a.info['key_value_pairs']:
            g = self.get_generation_number()
            a.info['key_value_pairs']['generation'] = g

        if find_neighbors is not None:
            set_neighbor_list(a, find_neighbors(a))
        if perform_parametrization is not None:
            set_parametrization(a, perform_parametrization(a))

    def add_relaxed_step(self, a, find_neighbors=None,
                         perform_parametrization=None):
        """After a candidate is relaxed it must be marked
        as such. Use this function if the candidate has already been in the
        database in an unrelaxed version, i.e. add_unrelaxed_candidate has
        been used.

        Neighbor list and parametrization parameters to screen
        candidates before relaxation can be added. Default is not to use.

        The id of the written row is stored in a.info['relax_id'].
        """
        # Same raw_score check as add_relaxed_candidate.
        test_raw_score(a)

        # confid has already been set in add_unrelaxed_candidate
        gaid = a.info['confid']

        self._finalize_relaxed(a, find_neighbors, perform_parametrization)

        relax_id = self.c.write(a, relaxed=1, gaid=gaid,
                                key_value_pairs=a.info['key_value_pairs'],
                                data=a.info['data'])
        a.info['relax_id'] = relax_id

    def add_relaxed_candidate(self, a, find_neighbors=None,
                              perform_parametrization=None):
        """After a candidate is relaxed it must be marked
        as such. Use this function if the candidate has *not* been in the
        database in an unrelaxed version, i.e. add_unrelaxed_candidate has
        *not* been used.

        Neighbor list and parametrization parameters to screen
        candidates before relaxation can be added. Default is not to use.

        The id of the written row becomes both a.info['confid'] and
        a.info['relax_id'].
        """
        test_raw_score(a)

        self._finalize_relaxed(a, find_neighbors, perform_parametrization)

        relax_id = self.c.write(a, relaxed=1,
                                key_value_pairs=a.info['key_value_pairs'],
                                data=a.info['data'])
        # A brand new candidate uses its own row id as configuration id.
        self.c.update(relax_id, gaid=relax_id)
        a.info['confid'] = relax_id
        a.info['relax_id'] = relax_id

    def add_more_relaxed_steps(self, a_list):
        """ Old name of add_more_relaxed_candidates.

        Prints a notice and forwards to add_more_relaxed_candidates. """
        # This function will be removed soon as the function name indicates
        # that unrelaxed candidates are added beforehand
        print('Please use add_more_relaxed_candidates instead')
        self.add_more_relaxed_candidates(a_list)

    def add_more_relaxed_candidates(self, a_list):
        """Add more relaxed candidates quickly

        All candidates are written inside a single ``with self.c`` block.
        A missing raw_score is only reported with a printed message; the
        candidate is still written."""
        for a in a_list:
            try:
                a.info['key_value_pairs']['raw_score']
            except KeyError:
                print("raw_score not put in atoms.info['key_value_pairs']")

        g = self.get_generation_number()

        # Insert gaid by getting the next available id and assuming that the
        # entire a_list will be written without interruption
        next_id = self.get_next_id()
        with self.c as con:
            for j, a in enumerate(a_list):
                if 'generation' not in a.info['key_value_pairs']:
                    a.info['key_value_pairs']['generation'] = g

                gaid = next_id + j
                relax_id = con.write(a, relaxed=1, gaid=gaid,
                                     key_value_pairs=a.info['key_value_pairs'],
                                     data=a.info['data'])
                # Sanity check of the "no interruption" assumption above.
                assert gaid == relax_id
                a.info['confid'] = relax_id
                a.info['relax_id'] = relax_id

    def get_next_id(self):
        """Get the id of the next candidate to be added to the database.
        This is a hacky way of obtaining the id and it only works on a
        sqlite database.
        """
        con = self.c._connect()
        try:
            last_id = self.c.get_last_id(con.cursor())
        finally:
            # Close the raw connection even if the lookup fails.
            con.close()
        return last_id + 1

    def get_largest_in_db(self, var):
        """ Return the value of var from the first row when the database
        is sorted descending on var.

        Raises StopIteration if the selection is empty. """
        return next(self.c.select(sort=f'-{var}')).get(var)

    def add_unrelaxed_candidate(self, candidate, description):
        """ Adds a new candidate which needs to be relaxed.

        description must have the form '<type>:<text>' (see
        split_description); <type> is stored as a key with value 1 and
        <text> under the key 'description'. The id of the written row is
        used as gaid and stored in candidate.info['confid']. """
        t, desc = split_description(description)
        kwargs = {'relaxed': 0,
                  'extinct': 0,
                  t: 1,
                  'description': desc}

        if 'generation' not in candidate.info['key_value_pairs']:
            kwargs['generation'] = self.get_generation_number()

        gaid = self.c.write(candidate,
                            key_value_pairs=candidate.info['key_value_pairs'],
                            data=candidate.info['data'],
                            **kwargs)
        self.c.update(gaid, gaid=gaid)
        candidate.info['confid'] = gaid

    def add_unrelaxed_step(self, candidate, description):
        """ Add a change to a candidate without it having been relaxed.
        This method is typically used when a
        candidate has been mutated. """

        # confid has already been set by add_unrelaxed_candidate
        gaid = candidate.info['confid']

        t, desc = split_description(description)
        kwargs = {'relaxed': 0,
                  'extinct': 0,
                  t: 1,
                  'description': desc, 'gaid': gaid}

        self.c.write(candidate,
                     key_value_pairs=candidate.info['key_value_pairs'],
                     data=candidate.info['data'],
                     **kwargs)

    def get_number_of_atoms_to_optimize(self):
        """ Get the number of atoms being optimized. """
        v = self.c.get(simulation_cell=True)
        return len(v.data.stoichiometry)

    def get_atom_numbers_to_optimize(self):
        """ Get the list of atom numbers being optimized. """
        v = self.c.get(simulation_cell=True)
        return v.data.stoichiometry

    def get_slab(self):
        """ Get the super cell, including stationary atoms, in which
        the structure is being optimized. """
        return self.c.get_atoms(simulation_cell=True)

    def get_participation_in_pairing(self):
        """ Get information about how many direct
            offsprings each candidate has, and which specific
            pairings have been made. This information is used
            for the extended fitness calculation described in
            L.B. Vilhelmsen et al., JACS, 2012, 134 (30), pp 12807-12816

            Returns a tuple (frequency, pairs): frequency maps a parent id
            to the number of pairings it took part in, pairs is a list
            with one sorted (parent, parent) tuple per pairing entry.
        """
        frequency = {}
        pairs = []
        for e in self.c.select(pairing=1):
            c1, c2 = e.data['parents']
            pairs.append(tuple(sorted([c1, c2])))
            for parent in (c1, c2):
                frequency[parent] = frequency.get(parent, 0) + 1
        return (frequency, pairs)

    def get_all_relaxed_candidates(self, only_new=False, use_extinct=False):
        """ Returns all candidates that have been relaxed.

        Parameters:

        only_new: boolean (optional)
            Used to specify only to get candidates relaxed since last
            time this function was invoked. Default: False.

        use_extinct: boolean (optional)
            Set to True if the extinct key (and mass extinction) is going
            to be used. Default: False."""

        if use_extinct:
            entries = self.c.select('relaxed=1,extinct=0',
                                    sort='-raw_score')
        else:
            entries = self.c.select('relaxed=1', sort='-raw_score')

        trajs = []
        for v in entries:
            if only_new and v.gaid in self.already_returned:
                continue
            t = self.get_atoms(id=v.id)
            t.info['confid'] = v.gaid
            t.info['relax_id'] = v.id
            trajs.append(t)
            self.already_returned.add(v.gaid)
        return trajs

    def get_all_relaxed_candidates_after_generation(self, gen):
        """ Returns all candidates that have been relaxed up to
            and including the specified generation, sorted by
            descending raw score. Extinct candidates are left out.
        """
        entries = self.c.select(f'relaxed=1,extinct=0,generation<={gen}')

        trajs = []
        for v in entries:
            t = self.get_atoms(id=v.id)
            t.info['confid'] = v.gaid
            t.info['relax_id'] = v.id
            trajs.append(t)
        trajs.sort(key=get_raw_score,
                   reverse=True)
        return trajs

    def get_all_candidates_in_queue(self):
        """ Returns the gaids of all structures that are queued, but have
        not yet been relaxed. """
        all_queued_ids = [t.gaid for t in self.c.select(queued=1)]
        # A set gives constant-time membership tests in the filter below.
        all_relaxed_ids = {t.gaid for t in self.c.select(relaxed=1)}

        return [qid for qid in all_queued_ids
                if qid not in all_relaxed_ids]

    def remove_from_queue(self, confid):
        """ Removes the candidate confid from the queue. """
        queued_ids = self.c.select(queued=1, gaid=confid)
        ids = [q.id for q in queued_ids]
        self.c.delete(ids)

    def get_generation_number(self, size=None):
        """ Returns the current generation number, by looking
            at the number of relaxed individuals and comparing
            this number to the supplied size or population size.

            If all individuals in generation 3 has been relaxed
            it will return 4 if not all in generation 4 has been
            relaxed.

            Returns 0 when no size is supplied and no 'population_size'
            parameter is stored in the database.
        """
        if size is None:
            size = self.get_param('population_size')
        if size is None:
            return 0

        # Count the relaxed candidates per generation once instead of
        # rescanning every row for each generation tested below.
        per_generation = {}
        for row in self.c.select(relaxed=1):
            gen = row.generation
            per_generation[gen] = per_generation.get(gen, 0) + 1

        lg = size
        g = 0
        while lg > 0:
            lg = per_generation.get(g, 0)
            if lg >= size:
                g += 1
            else:
                return g
        # Only reached when size is not positive (the loop never runs).
        return None

    def get_atoms(self, id, add_info=True):
        """Return the atoms object with the specified id"""
        a = self.c.get_atoms(id, add_additional_information=add_info)
        return a

    def get_param(self, parameter):
        """ Get a parameter saved when creating the database.

        Looks the parameter up in the data of the row with id 1 and
        returns None if that row has no data or lacks the parameter. """
        first_row = self.c.get(1)  # fetched once, used for both lookups
        if first_row.get('data'):
            return first_row.data.get(parameter, None)
        return None

    def remove_old_queued(self):
        """ Placeholder; currently does nothing. """

    def is_duplicate(self, **kwargs):
        """Check if the key-value pair is already present in the database"""
        return len(list(self.c.select(**kwargs))) > 0

    def kill_candidate(self, confid):
        """Sets extinct=1 in the key_value_pairs of the candidate
        with gaid=confid. This could be used in the
        mass extinction operator."""
        # Collect the ids before updating so that no row is modified
        # while the selection is still being iterated.
        ids = [dct.id for dct in self.c.select(gaid=confid)]
        for row_id in ids:
            self.c.update(row_id, extinct=1)

410 

411 

class PrepareDB:
    """ Class used to initialize a database.

    This class is used once to set up a new database file. The
    constructor refuses to touch an existing file, creates the database
    and writes the simulation cell as a row marked simulation_cell=True.

    Parameters:

    db_file_name: Database file to use; must not exist yet

    simulation_cell: Atoms object written as the simulation cell row
        (an empty Atoms object if omitted)

    kwargs: stored in the data of the simulation cell row

    """

    def __init__(self, db_file_name, simulation_cell=None, **kwargs):
        if os.path.exists(db_file_name):
            full_path = os.path.abspath(db_file_name)
            raise OSError(f'DB file {full_path} already exists')

        self.db_file_name = db_file_name
        cell = Atoms() if simulation_cell is None else simulation_cell
        self.c = ase.db.connect(self.db_file_name)

        # Everything goes into data because none of it needs to be
        # searchable in the db.
        self.c.write(cell, data=dict(kwargs), simulation_cell=True)

    def add_unrelaxed_candidate(self, candidate, **kwargs):
        """ Add an unrelaxed starting candidate.

        The id of the written row is used as gaid and stored in
        candidate.info['confid']. """
        row_id = self.c.write(candidate,
                              origin='StartingCandidateUnrelaxed',
                              relaxed=0,
                              generation=0,
                              extinct=0,
                              **kwargs)
        self.c.update(row_id, gaid=row_id)
        candidate.info['confid'] = row_id

    def add_relaxed_candidate(self, candidate, **kwargs):
        """ Add a relaxed starting candidate.

        The candidate must carry a raw_score in
        candidate.info['key_value_pairs']; candidate.info['data'] is
        optional. The id of the written row is used as gaid and stored in
        candidate.info['confid']. """
        test_raw_score(candidate)

        info = candidate.info
        # An empty dict is written when the candidate carries no data.
        payload = info['data'] if 'data' in info else {}

        row_id = self.c.write(candidate,
                              origin='StartingCandidateRelaxed',
                              relaxed=1,
                              generation=0,
                              extinct=0,
                              key_value_pairs=info['key_value_pairs'],
                              data=payload,
                              **kwargs)
        self.c.update(row_id, gaid=row_id)
        candidate.info['confid'] = row_id