Coverage for /builds/ase/ase/ase/ga/data.py: 78.22%
225 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3"""
4 Objects which handle all communication with the SQLite database.
5"""
6import os
8import ase.db
9from ase import Atoms
10from ase.ga import get_raw_score, set_neighbor_list, set_parametrization
def split_description(desc):
    """Split a ``'type:description'`` string at its colon.

    Parameters:

    desc: str
        String that must contain exactly one ``':'``.

    Returns the tuple ``(type, description)``; no whitespace is stripped.
    If ``desc`` does not contain exactly one colon an AssertionError whose
    message is ``desc`` is raised (only while assertions are enabled).
    """
    pieces = desc.split(':')
    assert len(pieces) == 2, desc
    head = pieces[0]
    tail = pieces[1]
    return head, tail
def test_raw_score(atoms):
    """Check that a raw score has been stored on ``atoms``.

    Parameters:

    atoms: Atoms
        Candidate whose ``info['key_value_pairs']`` dict is inspected.

    Raises AssertionError if ``'raw_score'`` is not a key of
    ``atoms.info['key_value_pairs']``. The whole check is an ``assert``
    statement, so it is skipped when Python runs with ``-O``.
    """
    missing_msg = "raw_score not put in atoms.info['key_value_pairs']"
    assert 'raw_score' in atoms.info['key_value_pairs'], missing_msg
class DataConnection:
    """Class that handles all database communication.

    All data communication is collected in this class in order to
    decouple the data representation from the GA method.

    A new candidate must be added with one of the functions
    add_unrelaxed_candidate or add_relaxed_candidate; this will correctly
    initialize a configuration id (gaid) used to keep track of candidates
    in the database.
    After one of the add_*_candidate functions has been used, if the
    candidate is further modified or relaxed the functions
    add_unrelaxed_step or add_relaxed_step must be used. This way the
    configuration id carries through correctly.

    Parameters:

    db_file_name: Path to the ase.db data file. The file must already
        exist, otherwise OSError is raised.
    """

    def __init__(self, db_file_name):
        self.db_file_name = db_file_name
        if not os.path.isfile(self.db_file_name):
            raise OSError(f'DB file {self.db_file_name} not found')
        self.c = ase.db.connect(self.db_file_name)
        # gaids already handed out by get_all_relaxed_candidates; used to
        # implement its only_new option.
        self.already_returned = set()

    def get_number_of_unrelaxed_candidates(self):
        """Return the number of candidates not yet queued or relaxed."""
        return len(self.__get_ids_of_all_unrelaxed_candidates__())

    def get_an_unrelaxed_candidate(self):
        """Return one candidate ready for relaxation.

        The returned Atoms object has ``info['confid']`` set and an
        ``info['data']`` dict (created empty if missing).

        Raises ValueError if there is no unrelaxed candidate.
        """
        to_get = self.__get_ids_of_all_unrelaxed_candidates__()
        if not to_get:
            raise ValueError('No unrelaxed candidate to return')
        return self._prepare_unrelaxed(to_get[0])

    def get_all_unrelaxed_candidates(self):
        """Return all unrelaxed candidates (an empty list if there are none).

        Useful if they can all be evaluated quickly. Each candidate is
        prepared exactly as in get_an_unrelaxed_candidate.
        """
        return [self._prepare_unrelaxed(confid)
                for confid in self.__get_ids_of_all_unrelaxed_candidates__()]

    def _prepare_unrelaxed(self, confid):
        """Fetch the latest structure of ``confid`` and tag it for the GA."""
        a = self.__get_latest_traj_for_confid__(confid)
        a.info['confid'] = confid
        a.info.setdefault('data', {})
        return a

    def __get_ids_of_all_unrelaxed_candidates__(self):
        """Return gaids that have an unrelaxed row but are neither relaxed
        nor queued.

        Helper for the unrelaxed-candidate getters above. The order of the
        returned list follows set iteration order.
        """
        all_unrelaxed_ids = {t.gaid for t in self.c.select(relaxed=0)}
        all_relaxed_ids = {t.gaid for t in self.c.select(relaxed=1)}
        all_queued_ids = {t.gaid for t in self.c.select(queued=1)}

        return [gaid for gaid in all_unrelaxed_ids
                if gaid not in all_relaxed_ids
                and gaid not in all_queued_ids]

    def __get_latest_traj_for_confid__(self, confid):
        """Return the structure of the row with the largest ``mtime`` for
        the given configuration.

        There can be several rows for one configuration if it has
        undergone several changes (mutations, pairings, etc.).
        Raises IndexError if no row has ``gaid == confid``.
        """
        allcands = sorted(self.c.select(gaid=confid), key=lambda x: x.mtime)
        return self.get_atoms(allcands[-1].id)

    def mark_as_queued(self, a):
        """Mark a configuration as queued for relaxation.

        A new row is written (with ``None`` passed as the atoms) carrying
        the candidate's gaid, ``queued=1`` and its key-value pairs.
        """
        gaid = a.info['confid']
        self.c.write(None, gaid=gaid, queued=1,
                     key_value_pairs=a.info['key_value_pairs'])

    def _prepare_relaxed(self, a, find_neighbors, perform_parametrization):
        """Set the generation (if missing), neighbor list and
        parametrization of a relaxed candidate before it is written."""
        key_value_pairs = a.info['key_value_pairs']
        if 'generation' not in key_value_pairs:
            key_value_pairs['generation'] = self.get_generation_number()

        if find_neighbors is not None:
            set_neighbor_list(a, find_neighbors(a))
        if perform_parametrization is not None:
            set_parametrization(a, perform_parametrization(a))

    def add_relaxed_step(self, a, find_neighbors=None,
                         perform_parametrization=None):
        """After a candidate is relaxed it must be marked
        as such. Use this function if the candidate has already been in the
        database in an unrelaxed version, i.e. add_unrelaxed_candidate has
        been used.

        Neighbor list and parametrization parameters to screen
        candidates before relaxation can be added. Default is not to use.

        The new row id is stored in ``a.info['relax_id']``.
        """
        test_raw_score(a)

        # confid has already been set in add_unrelaxed_candidate
        gaid = a.info['confid']

        self._prepare_relaxed(a, find_neighbors, perform_parametrization)

        relax_id = self.c.write(a, relaxed=1, gaid=gaid,
                                key_value_pairs=a.info['key_value_pairs'],
                                data=a.info['data'])
        a.info['relax_id'] = relax_id

    def add_relaxed_candidate(self, a, find_neighbors=None,
                              perform_parametrization=None):
        """After a candidate is relaxed it must be marked
        as such. Use this function if the candidate has *not* been in the
        database in an unrelaxed version, i.e. add_unrelaxed_candidate has
        *not* been used.

        Neighbor list and parametrization parameters to screen
        candidates before relaxation can be added. Default is not to use.

        The new row id becomes the candidate's gaid and is stored in both
        ``a.info['confid']`` and ``a.info['relax_id']``.
        """
        test_raw_score(a)

        self._prepare_relaxed(a, find_neighbors, perform_parametrization)

        relax_id = self.c.write(a, relaxed=1,
                                key_value_pairs=a.info['key_value_pairs'],
                                data=a.info['data'])
        self.c.update(relax_id, gaid=relax_id)
        a.info['confid'] = relax_id
        a.info['relax_id'] = relax_id

    def add_more_relaxed_steps(self, a_list):
        """Deprecated alias of add_more_relaxed_candidates."""
        # This function will be removed soon as the function name indicates
        # that unrelaxed candidates are added beforehand
        print('Please use add_more_relaxed_candidates instead')
        self.add_more_relaxed_candidates(a_list)

    def add_more_relaxed_candidates(self, a_list):
        """Add several relaxed candidates quickly.

        All rows are written inside a single ``with self.c`` block and get
        consecutive gaids starting at get_next_id(). A missing raw_score
        only produces a printed warning.
        """
        for a in a_list:
            try:
                a.info['key_value_pairs']['raw_score']
            except KeyError:
                print("raw_score not put in atoms.info['key_value_pairs']")

        g = self.get_generation_number()

        # Insert gaid by getting the next available id and assuming that the
        # entire a_list will be written without interruption
        next_id = self.get_next_id()
        with self.c as con:
            for j, a in enumerate(a_list):
                if 'generation' not in a.info['key_value_pairs']:
                    a.info['key_value_pairs']['generation'] = g

                gaid = next_id + j
                relax_id = con.write(a, relaxed=1, gaid=gaid,
                                     key_value_pairs=a.info['key_value_pairs'],
                                     data=a.info['data'])
                # Sanity check of the "no interruption" assumption above.
                assert gaid == relax_id
                a.info['confid'] = relax_id
                a.info['relax_id'] = relax_id

    def get_next_id(self):
        """Get the id of the next candidate to be added to the database.

        This is a hacky way of obtaining the id and it only works on a
        sqlite database. The temporary connection is always closed, even
        if reading the last id fails.
        """
        con = self.c._connect()
        try:
            last_id = self.c.get_last_id(con.cursor())
        finally:
            con.close()
        return last_id + 1

    def get_largest_in_db(self, var):
        """Return ``var`` from the first row when sorting by ``var``
        descending.

        Raises StopIteration if the selection yields no rows.
        """
        return next(self.c.select(sort=f'-{var}')).get(var)

    def add_unrelaxed_candidate(self, candidate, description):
        """Add a new candidate which needs to be relaxed.

        ``description`` has the form ``'type:text'``; ``type`` is written
        as a key with value 1 and ``text`` as the ``description`` key. The
        new row id becomes the gaid and is stored in
        ``candidate.info['confid']``.
        """
        t, desc = split_description(description)
        kwargs = {'relaxed': 0,
                  'extinct': 0,
                  t: 1,
                  'description': desc}

        if 'generation' not in candidate.info['key_value_pairs']:
            kwargs['generation'] = self.get_generation_number()

        gaid = self.c.write(candidate,
                            key_value_pairs=candidate.info['key_value_pairs'],
                            data=candidate.info['data'],
                            **kwargs)
        self.c.update(gaid, gaid=gaid)
        candidate.info['confid'] = gaid

    def add_unrelaxed_step(self, candidate, description):
        """Add a change to a candidate without it having been relaxed.

        This method is typically used when a candidate has been mutated.
        ``description`` has the form ``'type:text'`` as in
        add_unrelaxed_candidate; the row keeps the candidate's gaid.
        """
        # confid has already been set by add_unrelaxed_candidate
        gaid = candidate.info['confid']

        t, desc = split_description(description)
        kwargs = {'relaxed': 0,
                  'extinct': 0,
                  t: 1,
                  'description': desc, 'gaid': gaid}

        self.c.write(candidate,
                     key_value_pairs=candidate.info['key_value_pairs'],
                     data=candidate.info['data'],
                     **kwargs)

    def get_number_of_atoms_to_optimize(self):
        """Get the number of atoms being optimized."""
        return len(self.get_atom_numbers_to_optimize())

    def get_atom_numbers_to_optimize(self):
        """Get the list of atom numbers being optimized.

        Read from ``data.stoichiometry`` of the simulation_cell row.
        """
        v = self.c.get(simulation_cell=True)
        return v.data.stoichiometry

    def get_slab(self):
        """Get the super cell, including stationary atoms, in which
        the structure is being optimized."""
        return self.c.get_atoms(simulation_cell=True)

    def get_participation_in_pairing(self):
        """Get information about how many direct offspring each candidate
        has, and which specific pairings have been made.

        This information is used for the extended fitness calculation
        described in L.B. Vilhelmsen et al., JACS, 2012, 134 (30),
        pp 12807-12816.

        Returns ``(frequency, pairs)``: ``frequency`` maps each parent
        gaid to the number of pairing rows it appears in, and ``pairs``
        lists each pairing's parents as a sorted tuple.
        """
        frequency = {}
        pairs = []
        for e in self.c.select(pairing=1):
            c1, c2 = e.data['parents']
            pairs.append(tuple(sorted([c1, c2])))
            for parent in (c1, c2):
                frequency[parent] = frequency.get(parent, 0) + 1
        return (frequency, pairs)

    def get_all_relaxed_candidates(self, only_new=False, use_extinct=False):
        """Return all candidates that have been relaxed, sorted by
        descending raw_score.

        Parameters:

        only_new: boolean (optional)
            Used to specify only to get candidates relaxed since last
            time this function was invoked. Default: False.

        use_extinct: boolean (optional)
            Set to True if the extinct key (and mass extinction) is going
            to be used; candidates with extinct=1 are then skipped.
            Default: False.
        """
        if use_extinct:
            selection = 'relaxed=1,extinct=0'
        else:
            selection = 'relaxed=1'

        trajs = []
        for v in self.c.select(selection, sort='-raw_score'):
            if only_new and v.gaid in self.already_returned:
                continue
            t = self.get_atoms(id=v.id)
            t.info['confid'] = v.gaid
            t.info['relax_id'] = v.id
            trajs.append(t)
            self.already_returned.add(v.gaid)
        return trajs

    def get_all_relaxed_candidates_after_generation(self, gen):
        """Return all non-extinct candidates that have been relaxed up to
        and including the specified generation, sorted by descending raw
        score.
        """
        entries = self.c.select(f'relaxed=1,extinct=0,generation<={gen}')

        trajs = []
        for v in entries:
            t = self.get_atoms(id=v.id)
            t.info['confid'] = v.gaid
            t.info['relax_id'] = v.id
            trajs.append(t)
        trajs.sort(key=get_raw_score, reverse=True)
        return trajs

    def get_all_candidates_in_queue(self):
        """Return the gaids of all structures that are queued but have not
        yet been relaxed."""
        all_queued_ids = [t.gaid for t in self.c.select(queued=1)]
        # A set makes the membership test below O(1) per queued id.
        all_relaxed_ids = {t.gaid for t in self.c.select(relaxed=1)}

        return [qid for qid in all_queued_ids
                if qid not in all_relaxed_ids]

    def remove_from_queue(self, confid):
        """Remove the candidate confid from the queue by deleting its
        queued rows."""
        ids = [q.id for q in self.c.select(queued=1, gaid=confid)]
        self.c.delete(ids)

    def get_generation_number(self, size=None):
        """Return the current generation number.

        This is determined by counting the relaxed individuals of each
        generation and comparing these counts to the supplied size or,
        if not given, the stored population_size parameter.

        If all individuals in generation 3 have been relaxed it will
        return 4 if not all in generation 4 have been relaxed.
        Returns 0 if no size is available.
        """
        if size is None:
            size = self.get_param('population_size')
        if size is None:
            return 0
        if size <= 0:
            # NOTE(review): kept from the original implementation - a
            # non-positive size yields None rather than a generation.
            return None

        # Count relaxed candidates per generation in a single pass
        # instead of rescanning all candidates for every generation.
        counts = {}
        for row in self.c.select(relaxed=1):
            counts[row.generation] = counts.get(row.generation, 0) + 1

        g = 0
        while counts.get(g, 0) >= size:
            g += 1
        return g

    def get_atoms(self, id, add_info=True):
        """Return the atoms object with the specified id."""
        return self.c.get_atoms(id, add_additional_information=add_info)

    def get_param(self, parameter):
        """Get a parameter saved in the ``data`` of row 1 when creating
        the database, or None if it is absent."""
        data = self.c.get(1).get('data')
        if data:
            return data.get(parameter, None)
        return None

    def remove_old_queued(self):
        """Placeholder; currently does nothing."""
        pass

    def is_duplicate(self, **kwargs):
        """Check if the key-value pair is already present in the database.

        Stops at the first matching row instead of fetching all matches.
        """
        return any(True for _ in self.c.select(**kwargs))

    def kill_candidate(self, confid):
        """Set extinct=1 in the key_value_pairs of every row of the
        candidate with gaid=confid. This could be used in the
        mass extinction operator."""
        for dct in self.c.select(gaid=confid):
            self.c.update(dct.id, extinct=1)
class PrepareDB:
    """Create and initialize a new GA database.

    This class is used once, to set up the database file before the
    GA run starts.

    Parameters:

    db_file_name: Database file to use. It must not exist yet,
        otherwise OSError is raised.

    simulation_cell: Atoms object (optional)
        Written as the row flagged ``simulation_cell=True``. An empty
        ``Atoms()`` is used when omitted.

    Any further keyword arguments are stored in the ``data`` of that
    simulation_cell row.
    """

    def __init__(self, db_file_name, simulation_cell=None, **kwargs):
        if os.path.exists(db_file_name):
            location = os.path.abspath(db_file_name)
            raise OSError('DB file {} already exists'.format(location))
        self.db_file_name = db_file_name
        cell = Atoms() if simulation_cell is None else simulation_cell

        self.c = ase.db.connect(self.db_file_name)

        # Everything goes into data, because these values are never
        # searched for in the database.
        self.c.write(cell, data=dict(kwargs), simulation_cell=True)

    def add_unrelaxed_candidate(self, candidate, **kwargs):
        """Add an unrelaxed starting candidate (generation 0).

        Extra keyword arguments are passed on to the database write call.
        The new row id becomes the gaid and is stored in
        ``candidate.info['confid']``.
        """
        row_id = self.c.write(candidate, origin='StartingCandidateUnrelaxed',
                              relaxed=0, generation=0, extinct=0, **kwargs)
        self.c.update(row_id, gaid=row_id)
        candidate.info['confid'] = row_id

    def add_relaxed_candidate(self, candidate, **kwargs):
        """Add a relaxed starting candidate (generation 0).

        The candidate must carry a raw_score in
        ``candidate.info['key_value_pairs']``; ``candidate.info['data']``
        is stored if present, an empty dict otherwise. The new row id
        becomes the gaid and is stored in ``candidate.info['confid']``.
        """
        test_raw_score(candidate)

        stored_data = candidate.info.get('data', {})

        row_id = self.c.write(candidate, origin='StartingCandidateRelaxed',
                              relaxed=1, generation=0, extinct=0,
                              key_value_pairs=candidate.info['key_value_pairs'],
                              data=stored_data, **kwargs)
        self.c.update(row_id, gaid=row_id)
        candidate.info['confid'] = row_id