Coverage for ase / optimize / minimahopping.py: 57.31%
499 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 15:52 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 15:52 +0000
1# fmt: off
3import os
5import numpy as np
7from ase import io, units
8from ase.md import MDLogger, VelocityVerlet, thermalize_momenta
9from ase.optimize import QuasiNewton
10from ase.parallel import paropen, world
class MinimaHopping:
    """Implements the minima hopping method of global optimization outlined
    by S. Goedecker, J. Chem. Phys. 120: 9911 (2004). Initialize with an
    ASE atoms object. Optional parameters are fed through keywords.
    To run multiple searches in parallel, specify the minima_traj keyword,
    and have each run point to the same path.

    Each hop consists of a short molecular-dynamics run (until ``mdmin``
    energy minima have been passed), a local optimization, and a bookkeeping
    step that adapts the MD temperature and the energy acceptance threshold.
    Progress is written to ``logfile``; per-step files are named
    ``md%05i.traj``/``md%05i.log`` and ``qn%05i.traj``/``qn%05i.log`` in the
    current working directory.
    """

    _default_settings = {
        'T0': 1000.,  # K, initial MD 'temperature'
        'beta1': 1.1,  # temperature adjustment parameter
        'beta2': 1.1,  # temperature adjustment parameter
        'beta3': 1. / 1.1,  # temperature adjustment parameter
        'Ediff0': 0.5,  # eV, initial energy acceptance threshold
        'alpha1': 0.98,  # energy threshold adjustment parameter
        'alpha2': 1. / 0.98,  # energy threshold adjustment parameter
        'mdmin': 2,  # criteria to stop MD simulation (no. of minima)
        'logfile': 'hop.log',  # text log
        'minima_threshold': 0.5,  # A, threshold for identical configs
        'timestep': 1.0,  # fs, timestep for MD simulations
        'optimizer': QuasiNewton,  # local optimizer to use
        'minima_traj': 'minima.traj',  # storage file for minima list
        'fmax': 0.05,  # eV/A, max force for optimizations
        'rng': None}  # passed as ``rng`` to thermalize_momenta

    def __init__(self, atoms, comm=world, **kwargs):
        """Initialize with an ASE atoms object and keyword arguments.

        Every key of ``_default_settings`` may be overridden through
        ``kwargs``; the resulting value is stored on the instance as
        ``_<key>``. ``comm`` is the communicator used for the rank check,
        barrier, broadcast and log-file access.

        Raises RuntimeError if a keyword is not a known setting.
        """
        self._atoms = atoms
        for key in kwargs:
            if key not in self._default_settings:
                raise RuntimeError(f'Unknown keyword: {key}')
        for k, v in self._default_settings.items():
            setattr(self, f'_{k}', kwargs.pop(k, v))

        # when a MD sim. has passed a local minimum:
        self._passedminimum = PassedMinimum()

        # Misc storage.
        self._previous_optimum = None
        self._previous_energy = None
        self._temperature = self._T0
        self._Ediff = self._Ediff0
        self.comm = comm

    def __call__(self, totalsteps=None, maxtemp=None):
        """Run the minima hopping algorithm. Can specify stopping criteria
        with total steps allowed or maximum searching temperature allowed.
        If neither is specified, runs indefinitely (or until stopped by
        batching software)."""
        self._startup()
        while True:
            if (totalsteps and self._counter >= totalsteps):
                self._log('msg', 'Run terminated. Step #%i reached of '
                          '%i allowed. Increase totalsteps if resuming.'
                          % (self._counter, totalsteps))
                return
            if (maxtemp and self._temperature >= maxtemp):
                self._log('msg', 'Run terminated. Temperature is %.2f K;'
                          ' max temperature allowed %.2f K.'
                          % (self._temperature, maxtemp))
                return

            # Remember where this hop started so it can be restored if the
            # new minimum is rejected.
            self._previous_optimum = self._atoms.copy()
            self._previous_energy = self._atoms.get_potential_energy()
            self._molecular_dynamics()
            self._optimize()
            self._counter += 1
            self._check_results()

    def _startup(self):
        """Initiates a run, and determines if running from previous data or
        a fresh run."""

        # Status codes: 0 = fresh run, 1 = fresh run reusing an existing
        # minima file, 2 = resume. Rank 0 decides and broadcasts.
        status = np.array(-1.)
        exists = self._read_minima()
        if self.comm.rank == 0:
            if not exists:
                # Fresh run with new minima file.
                status = np.array(0.)
            elif not os.path.exists(self._logfile):
                # Fresh run with existing or shared minima file.
                status = np.array(1.)
            else:
                # Must be resuming from within a working directory.
                status = np.array(2.)
        self.comm.barrier()
        self.comm.broadcast(status, 0)

        if status == 2.:
            self._resume()
        else:
            self._counter = 0
            self._log('init')
            self._log('msg', 'Performing initial optimization.')
            if status == 1.:
                self._log('msg', 'Using existing minima file with %i prior '
                          'minima: %s' % (len(self._minima),
                                          self._minima_traj))
            self._optimize()
            self._check_results()
            self._counter += 1

    def _resume(self):
        """Attempt to resume a run, based on information in the log
        file. Note it will almost always be interrupted in the middle of
        either a qn or md run or when exceeding totalsteps, so it only has
        been tested in those cases currently."""
        with paropen(self._logfile, 'r', comm=self.comm) as fd:
            lines = fd.read().splitlines()

        self._log('msg', 'Attempting to resume stopped run.')
        self._log('msg', 'Using existing minima file with %i prior '
                  'minima: %s' % (len(self._minima), self._minima_traj))
        mdcount, qncount = 0, 0
        for line in lines:
            # 'par:' lines carry temperature and Ediff; the header line
            # (which contains the word 'Ediff') is skipped.
            if line.startswith('par:') and ('Ediff' not in line):
                self._temperature = float(line.split()[1])
                self._Ediff = float(line.split()[2])
            elif line.startswith('msg: Optimization:'):
                qncount = int(line[19:].split('qn')[1])
            elif line.startswith('msg: Molecular dynamics:'):
                mdcount = int(line[25:].split('md')[1])
        self._counter = max((mdcount, qncount))
        if qncount == mdcount:
            # Either stopped during local optimization or terminated due to
            # max steps.
            self._log('msg', 'Attempting to resume at qn%05i' % qncount)
            if qncount > 0:
                atoms = io.read('qn%05i.traj' % (qncount - 1), index=-1)
                self._previous_optimum = atoms.copy()
                self._previous_energy = atoms.get_potential_energy()
            if os.path.getsize('qn%05i.traj' % qncount) > 0:
                atoms = io.read('qn%05i.traj' % qncount, index=-1)
            else:
                atoms = io.read('md%05i.traj' % qncount, index=-3)
            self._atoms.positions = atoms.get_positions()
            fmax = np.sqrt((atoms.get_forces() ** 2).sum(axis=1).max())
            if fmax < self._fmax:
                # Stopped after a qn finished.
                self._log('msg', 'qn%05i fmax already less than fmax=%.3f'
                          % (qncount, self._fmax))
                self._counter += 1
                return
            self._optimize()
            self._counter += 1
            if qncount > 0:
                self._check_results()
            else:
                self._record_minimum()
                self._log('msg', 'Found a new minimum.')
                self._log('msg', 'Accepted new minimum.')
                self._log('par')
        elif qncount < mdcount:
            # Probably stopped during molecular dynamics.
            self._log('msg', 'Attempting to resume at md%05i.' % mdcount)
            atoms = io.read('qn%05i.traj' % qncount, index=-1)
            self._previous_optimum = atoms.copy()
            self._previous_energy = atoms.get_potential_energy()
            self._molecular_dynamics(resume=mdcount)
            self._optimize()
            self._counter += 1
            self._check_results()

    def _check_results(self):
        """Adjusts parameters and positions based on outputs."""

        # No prior minima found?
        self._read_minima()
        if len(self._minima) == 0:
            self._log('msg', 'Found a new minimum.')
            self._log('msg', 'Accepted new minimum.')
            self._record_minimum()
            self._log('par')
            return
        # Returned to starting position?
        # Explicit None check: the stored optimum is an atoms object, whose
        # truth value is not a reliable "is set" test.
        if self._previous_optimum is not None:
            compare = ComparePositions(translate=False)
            dmax = compare(self._atoms, self._previous_optimum)
            self._log('msg', 'Max distance to last minimum: %.3f A' % dmax)
            if dmax < self._minima_threshold:
                self._log('msg', 'Re-found last minimum.')
                self._temperature *= self._beta1
                self._log('par')
                return
        # In a previously found position?
        unique, dmax_closest = self._unique_minimum_position()
        self._log('msg', 'Max distance to closest minimum: %.3f A' %
                  dmax_closest)
        if not unique:
            self._temperature *= self._beta2
            self._log('msg', 'Found previously found minimum.')
            self._log('par')
            if self._previous_optimum is not None:
                self._log('msg', 'Restoring last minimum.')
                self._atoms.positions = self._previous_optimum.positions
            return
        # Must have found a unique minimum.
        self._temperature *= self._beta3
        self._log('msg', 'Found a new minimum.')
        self._log('par')
        if (self._previous_energy is None or
                (self._atoms.get_potential_energy() <
                 self._previous_energy + self._Ediff)):
            self._log('msg', 'Accepted new minimum.')
            self._Ediff *= self._alpha1
            self._log('par')
            self._record_minimum()
        else:
            self._log('msg', 'Rejected new minimum due to energy. '
                      'Restoring last minimum.')
            self._atoms.positions = self._previous_optimum.positions
            self._Ediff *= self._alpha2
            self._log('par')

    def _log(self, cat='msg', message=None):
        """Records the message as a line in the log file.

        ``cat`` selects the record type: 'init' creates the log file and
        writes the column headers, 'msg' writes ``message``, 'par' writes
        the current temperature, Ediff and mdmin, and 'ene' writes the
        current (and, if available, previous) potential energy.

        Raises RuntimeError if 'init' is requested on rank 0 while the log
        file already exists, and ValueError for an unknown category.
        """
        if cat == 'init':
            if self.comm.rank == 0:
                if os.path.exists(self._logfile):
                    raise RuntimeError(f'File exists: {self._logfile}')
            with paropen(self._logfile, 'w', comm=self.comm) as fd:
                fd.write('par: %12s %12s %12s\n' % ('T (K)', 'Ediff (eV)',
                                                    'mdmin'))
                fd.write('ene: %12s %12s %12s\n' % ('E_current', 'E_previous',
                                                    'Difference'))
            return
        # Reject unknown categories before touching the file; previously
        # this surfaced as an unbound-local error after opening the log.
        if cat not in ('msg', 'par', 'ene'):
            raise ValueError(f'Unknown log category: {cat}')
        with paropen(self._logfile, 'a', comm=self.comm) as fd:
            if cat == 'msg':
                line = f'msg: {message}'
            elif cat == 'par':
                line = ('par: %12.4f %12.4f %12i' %
                        (self._temperature, self._Ediff, self._mdmin))
            else:  # cat == 'ene'
                current = self._atoms.get_potential_energy()
                if self._previous_optimum is not None:
                    previous = self._previous_energy
                    line = ('ene: %12.5f %12.5f %12.5f' %
                            (current, previous, current - previous))
                else:
                    line = ('ene: %12.5f' % current)
            fd.write(line + '\n')

    def _optimize(self):
        """Perform an optimization."""
        # Zero the momenta left over from the MD run before relaxing.
        self._atoms.set_momenta(np.zeros(self._atoms.get_momenta().shape))
        with self._optimizer(self._atoms,
                             trajectory='qn%05i.traj' % self._counter,
                             logfile='qn%05i.log' % self._counter) as opt:
            self._log('msg', 'Optimization: qn%05i' % self._counter)
            opt.run(fmax=self._fmax)
        self._log('ene')

    def _record_minimum(self):
        """Adds the current atoms configuration to the minima list."""
        with io.Trajectory(self._minima_traj, 'a') as traj:
            traj.write(self._atoms)
        self._read_minima()
        self._log('msg', 'Recorded minima #%i.' % (len(self._minima) - 1))

    def _read_minima(self):
        """Reads in the list of minima from the minima file.

        Returns True if the minima file exists (even if empty), False
        otherwise. ``self._minima`` is always (re)assigned.
        """
        exists = os.path.exists(self._minima_traj)
        if exists:
            empty = os.path.getsize(self._minima_traj) == 0
            if not empty:
                with io.Trajectory(self._minima_traj, 'r') as traj:
                    self._minima = list(traj)
            else:
                self._minima = []
            return True
        else:
            self._minima = []
            return False

    def _molecular_dynamics(self, resume=None):
        """Performs a molecular dynamics simulation, until mdmin is
        exceeded. If resuming, the file number (md%05i) is expected.

        After the run the atoms are moved back to the positions of the most
        recently detected energy minimum along the trajectory.
        """
        self._log('msg', 'Molecular dynamics: md%05i' % self._counter)
        mincount = 0
        energies, oldpositions = [], []
        thermalized = False
        # Absolute index into ``oldpositions`` of the last detected minimum.
        # Tracking it explicitly avoids indexing a stale/None detector
        # result when the MD loop below does not execute.
        minimum_index = None
        if resume:
            self._log('msg', 'Resuming MD from md%05i.traj' % resume)
            if os.path.getsize('md%05i.traj' % resume) == 0:
                self._log('msg', 'md%05i.traj is empty. Resuming from '
                          'qn%05i.traj.' % (resume, resume - 1))
                atoms = io.read('qn%05i.traj' % (resume - 1), index=-1)
            else:
                with io.Trajectory('md%05i.traj' % resume, 'r') as images:
                    for atoms in images:
                        energies.append(atoms.get_potential_energy())
                        oldpositions.append(atoms.positions.copy())
                        passedmin = self._passedminimum(energies)
                        if passedmin:
                            mincount += 1
                            minimum_index = len(energies) + passedmin[0]
                self._atoms.set_momenta(atoms.get_momenta())
                thermalized = True
            self._atoms.positions = atoms.get_positions()
            self._log('msg', 'Starting MD with %i existing energies.' %
                      len(energies))
        if not thermalized:
            thermalize_momenta(
                self._atoms,
                self._temperature,
                exact_temperature=True,
                rng=self._rng
            )
        traj = io.Trajectory('md%05i.traj' % self._counter, 'a',
                             self._atoms)
        dyn = VelocityVerlet(self._atoms, timestep=self._timestep * units.fs)
        log = MDLogger(dyn, self._atoms, 'md%05i.log' % self._counter,
                       header=True, stress=False, peratom=False)

        with traj, dyn, log:
            dyn.attach(log, interval=1)
            dyn.attach(traj, interval=1)
            while mincount < self._mdmin:
                dyn.run(1)
                energies.append(self._atoms.get_potential_energy())
                passedmin = self._passedminimum(energies)
                if passedmin:
                    mincount += 1
                    # The detector returns an offset from the end of
                    # ``energies``; convert it to an absolute index.
                    minimum_index = len(energies) + passedmin[0]
                oldpositions.append(self._atoms.positions.copy())
            # Reset atoms to minimum point.
            if minimum_index is not None:
                self._atoms.positions = oldpositions[minimum_index]

    def _unique_minimum_position(self):
        """Identifies if the current position of the atoms, which should be
        a local minima, has been found before.

        Returns ``(unique, dmax_closest)``: whether no stored minimum lies
        within ``minima_threshold``, and the smallest maximum displacement
        to any stored minimum (99999. if none are stored).
        """
        unique = True
        dmax_closest = 99999.
        compare = ComparePositions(translate=True)
        self._read_minima()
        for minimum in self._minima:
            dmax = compare(minimum, self._atoms)
            if dmax < self._minima_threshold:
                unique = False
            if dmax < dmax_closest:
                dmax_closest = dmax
        return unique, dmax_closest
class ComparePositions:
    """Class that compares the atomic positions between two ASE atoms
    objects. Returns the maximum distance that any atom has moved, assuming
    all atoms of the same element are indistinguishable. If translate is
    set to True, allows for arbitrary translations within the unit cell,
    as well as translations across any periodic boundary conditions. When
    called, returns the maximum displacement of any one atom.

    If an atom of the first structure has no remaining partner of the same
    element in the second structure, the displacement is reported as
    infinity (the structures cannot be matched).
    """

    def __init__(self, translate=True):
        self._translate = translate

    def __call__(self, atoms1, atoms2):
        # Work on copies so the callers' objects are never modified.
        atoms1 = atoms1.copy()
        atoms2 = atoms2.copy()
        if not self._translate:
            dmax = self._indistinguishable_compare(atoms1, atoms2)
        else:
            dmax = self._translated_compare(atoms1, atoms2)
        return dmax

    def _translated_compare(self, atoms1, atoms2):
        """Moves the atoms around and tries to pair up atoms, assuming any
        atoms with the same symbol are indistinguishable, and honors
        periodic boundary conditions (for example, so that an atom at
        (0.1, 0., 0.) correctly is found to be close to an atom at
        (7.9, 0., 0.) if the atoms are in an orthorhombic cell with
        x-dimension of 8. Returns dmax, the maximum distance between any
        two atoms in the optimal configuration."""
        atoms1.set_constraint()
        atoms2.set_constraint()
        for index in range(3):
            assert atoms1.pbc[index] == atoms2.pbc[index]
        least = self._get_least_common(atoms1)
        indices1 = [atom.index for atom in atoms1 if atom.symbol == least[0]]
        indices2 = [atom.index for atom in atoms2 if atom.symbol == least[0]]
        # Without a reference atom of the chosen element in both structures
        # no alignment is possible; report them as infinitely far apart
        # instead of failing on an empty list below.
        if not indices1 or not indices2:
            return np.inf
        # Make comparison sets from atoms2, which contain repeated atoms in
        # all pbc's and bring the atom listed in indices2 to (0,0,0)
        comparisons = []
        repeat = [3 if bc else 1 for bc in atoms2.pbc]
        repeated = atoms2.repeat(repeat)
        moved_cell = atoms2.cell * atoms2.pbc
        for moved in moved_cell:
            repeated.translate(-moved)
        repeated.set_cell(atoms2.cell)
        for index in indices2:
            comparison = repeated.copy()
            comparison.translate(-atoms2[index].position)
            comparisons.append(comparison)
        # Bring the atom listed in indices1 to (0,0,0) [not whole list]
        standard = atoms1.copy()
        standard.translate(-atoms1[indices1[0]].position)
        # Compare the standard to the comparison sets.
        return min(self._indistinguishable_compare(standard, comparison)
                   for comparison in comparisons)

    def _get_least_common(self, atoms):
        """Returns the least common element in atoms as ``[symbol, count]``.
        If more than one, returns the first encountered while iterating over
        the set of symbols. For an empty structure returns ``['', inf]``."""
        symbols = [atom.symbol for atom in atoms]
        elements = set(symbols)
        if not elements:
            return ['', np.inf]
        # min() keeps the first of several equally rare elements, matching
        # a strict "count < best" scan over the same set.
        rarest = min(elements, key=symbols.count)
        return [rarest, symbols.count(rarest)]

    def _indistinguishable_compare(self, atoms1, atoms2):
        """Finds each atom in atoms1's nearest neighbor with the same
        chemical symbol in atoms2. Return dmax, the farthest distance an
        individual atom differs by. Each atom of atoms2 is used at most
        once; an atom of atoms1 without any remaining partner yields
        infinity."""
        atoms2 = atoms2.copy()  # allow deletion
        atoms2.set_constraint()
        dmax = 0.
        for atom1 in atoms1:
            closest_index, closest_distance = None, np.inf
            for index, atom2 in enumerate(atoms2):
                if atom2.symbol == atom1.symbol:
                    d = np.linalg.norm(atom1.position - atom2.position)
                    if d < closest_distance:
                        closest_index, closest_distance = index, d
            if closest_distance > dmax:
                dmax = closest_distance
            # Remove the matched partner so it cannot be paired twice; when
            # nothing matched there is nothing to remove.
            if closest_index is not None:
                del atoms2[closest_index]
        return dmax
class PassedMinimum:
    """Detects whether a sequence of energies has just passed a minimum.

    A minimum is reported when the sequence ends with ``n_down`` steps that
    do not go up, followed by ``n_up`` steps that do not go down (two of
    each by default). Initialize with n_down and n_up, the integer numbers
    of down and up steps required.

    Calling the instance with a list of energies returns the tuple
    ``(index, energy)``, where ``index`` is the negative offset
    ``-n_up - 1`` of the minimum counted from the end of the list and
    ``energy`` is the value stored there. Returns None if the sequence is
    too short or does not end in such a pattern."""

    def __init__(self, n_down=2, n_up=2):
        self._ndown = n_down
        self._nup = n_up

    def __call__(self, energies):
        n_up, n_down = self._nup, self._ndown
        window = n_up + n_down + 1
        if len(energies) < window:
            return None
        # Last ``window`` values, oldest first; the candidate minimum sits
        # at position ``n_down``.
        tail = [energies[offset] for offset in range(-window, 0)]
        steps = list(zip(tail, tail[1:]))
        went_up_early = any(later > earlier
                            for earlier, later in steps[:n_down])
        went_down_late = any(later < earlier
                             for earlier, later in steps[n_down:])
        if went_up_early or went_down_late:
            return None
        return (-n_up - 1), tail[n_down]
class MHPlot:
    """Makes a plot summarizing the output of the MH algorithm from the
    specified rundirectory. If no rundirectory is supplied, uses the
    current directory.

    The log file ``logname`` inside the run directory is parsed on
    construction, and the ``md%05i.traj`` / ``qn%05i.traj`` files of that
    directory are read while drawing."""

    def __init__(self, rundirectory=None, logname='hop.log'):
        if not rundirectory:
            rundirectory = os.getcwd()
        self._rundirectory = rundirectory
        self._logname = logname
        self._read_log()
        self._fig, self._ax = self._makecanvas()
        self._plot_data()

    def get_figure(self):
        """Returns the matplotlib figure object."""
        return self._fig

    def save_figure(self, filename):
        """Saves the file to the specified path, with any allowed
        matplotlib extension (e.g., .pdf, .png, etc.)."""
        self._fig.savefig(filename)

    def _read_log(self):
        """Reads relevant parts of the log file.

        Stores one entry ``[energy, status, temperature, ediff]`` per
        finished step in ``self._data``, plus a trailing entry with a NaN
        energy when the log ends in the middle of a step. A log without
        any finished step no longer raises; it yields at most that single
        in-progress entry.
        """
        data = []  # format: [energy, status, temperature, ediff]

        with open(os.path.join(self._rundirectory, self._logname)) as fd:
            lines = fd.read().splitlines()

        # Defaults so that a truncated log cannot leave these unbound.
        status = None
        energy = temperature = ediff = np.nan
        step_almost_over = False
        step_over = False
        for line in lines:
            if line.startswith('msg: Molecular dynamics:'):
                status = 'performing MD'
            elif line.startswith('msg: Optimization:'):
                status = 'performing QN'
            elif line.startswith('ene:'):
                status = 'local optimum reached'
                energy = floatornan(line.split()[1])
            elif line.startswith('msg: Accepted new minimum.'):
                status = 'accepted'
                step_almost_over = True
            elif line.startswith('msg: Found previously found minimum.'):
                status = 'previously found minimum'
                step_almost_over = True
            elif line.startswith('msg: Re-found last minimum.'):
                status = 'previous minimum'
                step_almost_over = True
            elif line.startswith('msg: Rejected new minimum'):
                status = 'rejected'
                step_almost_over = True
            elif line.startswith('par: '):
                temperature = floatornan(line.split()[1])
                ediff = floatornan(line.split()[2])
                # A step is complete at the first 'par' line that follows
                # its verdict message.
                if step_almost_over:
                    step_over = True
                    step_almost_over = False
            if step_over:
                data.append([energy, status, temperature, ediff])
                step_over = False
        # Append the unfinished step, if any. Guard against an empty list
        # (no finished step yet) and an empty log (no status at all).
        if status is not None and (not data or data[-1][1] != status):
            data.append([np.nan, status, temperature, ediff])
        self._data = data

    def _makecanvas(self):
        """Creates the figure with two stacked energy axes and one axis each
        for Ediff and temperature; returns ``(fig, CombinedAxis)``."""
        from matplotlib import pyplot
        from matplotlib.ticker import ScalarFormatter
        fig = pyplot.figure(figsize=(6., 8.))
        lm, rm, bm, tm = 0.22, 0.02, 0.05, 0.04
        vg1 = 0.01  # between adjacent energy plots
        vg2 = 0.03  # between different types of plots
        ratio = 2.  # size of an energy plot to a parameter plot
        figwidth = 1. - lm - rm
        totalfigheight = 1. - bm - tm - vg1 - 2. * vg2
        parfigheight = totalfigheight / (2. * ratio + 2)
        epotheight = ratio * parfigheight
        ax1 = fig.add_axes((lm, bm, figwidth, epotheight))
        ax2 = fig.add_axes((lm, bm + epotheight + vg1,
                            figwidth, epotheight))
        for ax in [ax1, ax2]:
            ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
        ediffax = fig.add_axes((lm, bm + 2. * epotheight + vg1 + vg2,
                                figwidth, parfigheight))
        tempax = fig.add_axes((lm, (bm + 2 * epotheight + vg1 + 2 * vg2 +
                                    parfigheight), figwidth, parfigheight))
        for ax in [ax2, tempax, ediffax]:
            ax.set_xticklabels([])
        ax1.set_xlabel('step')
        tempax.set_ylabel('$T$, K')
        ediffax.set_ylabel(r'$E_\mathrm{diff}$, eV')
        for ax in [ax1, ax2]:
            ax.set_ylabel(r'$E_\mathrm{pot}$, eV')
        ax = CombinedAxis(ax1, ax2, tempax, ediffax)
        self._set_zoomed_range(ax)
        ax1.spines['top'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        return fig, ax

    def _set_zoomed_range(self, ax):
        """Try to intelligently set the range for the zoomed-in part of the
        graph.

        The range spans the finite step energies plus a 20% margin on each
        side. When there are no finite energies yet, a range around zero
        is used so that plotting can still proceed.
        """
        energies = [line[0] for line in self._data
                    if not np.isnan(line[0])]
        lowest = min(energies, default=0.)
        highest = max(energies, default=0.)
        dr = highest - lowest
        if dr == 0.:
            # Single (or no) energy: avoid a zero-height axis.
            dr = 1.
        ax.set_ax1_range((lowest - 0.2 * dr, highest + 0.2 * dr))

    def _plot_data(self):
        """Draws every step and then the parameter curves."""
        for step, line in enumerate(self._data):
            self._plot_energy(step, line)
            self._plot_qn(step, line)
            self._plot_md(step, line)
        self._plot_parameters()
        self._ax.set_xlim(self._ax.ax1.get_xlim())

    def _plot_energy(self, step, line):
        """Plots energy and annotation for acceptance."""
        energy, status = line[0], line[1]
        if np.isnan(energy):
            return
        self._ax.plot([step, step + 0.5], [energy] * 2, '-',
                      color='k', linewidth=2.)
        if status == 'accepted':
            self._ax.text(step + 0.51, energy, r'$\checkmark$')
        elif status == 'rejected':
            self._ax.text(step + 0.51, energy, r'$\Uparrow$', color='red')
        elif status == 'previously found minimum':
            self._ax.text(step + 0.51, energy, r'$\hookleftarrow$',
                          color='red', va='center')
        elif status == 'previous minimum':
            self._ax.text(step + 0.51, energy, r'$\leftarrow$',
                          color='red', va='center')

    def _plot_md(self, step, line):
        """Adds a curved plot of molecular dynamics trajectory."""
        if step == 0:
            # Step 0 is the initial optimization; it has no MD run.
            return
        energies = [self._data[step - 1][0]]
        path = os.path.join(self._rundirectory, 'md%05i.traj' % step)
        with io.Trajectory(path, 'r') as traj:
            for atoms in traj:
                energies.append(atoms.get_potential_energy())
        xi = step - 1 + .5
        if len(energies) > 2:
            xf = xi + (step + 0.25 - xi) * len(energies) / (len(energies) - 2.)
        else:
            xf = step
        if xf > (step + .75):
            xf = step
        self._ax.plot(np.linspace(xi, xf, num=len(energies)), energies,
                      '-k')

    def _plot_qn(self, index, line):
        """Plots a dashed vertical line for the optimization."""
        if line[1] == 'performing MD':
            return
        path = os.path.join(self._rundirectory, 'qn%05i.traj' % index)
        if os.path.getsize(path) == 0:
            return
        with io.Trajectory(path, 'r') as traj:
            energies = [traj[0].get_potential_energy(),
                        traj[-1].get_potential_energy()]
        if index > 0:
            # Start the line at the third-from-last MD image of this step.
            path = os.path.join(self._rundirectory, 'md%05i.traj' % index)
            atoms = io.read(path, index=-3)
            energies[0] = atoms.get_potential_energy()
        self._ax.plot([index + 0.25] * 2, energies, ':k')

    def _plot_parameters(self):
        """Adds a plot of temperature and Ediff to the plot."""
        steps, Ts, ediffs = [], [], []
        for step, line in enumerate(self._data):
            steps.extend([step + 0.5, step + 1.5])
            Ts.extend([line[2]] * 2)
            ediffs.extend([line[3]] * 2)
        self._ax.tempax.plot(steps, Ts)
        self._ax.ediffax.plot(steps, ediffs)

        # Pad both parameter axes by 10% of their current range.
        for ax in [self._ax.tempax, self._ax.ediffax]:
            ylim = ax.get_ylim()
            yrange = ylim[1] - ylim[0]
            ax.set_ylim((ylim[0] - 0.1 * yrange, ylim[1] + 0.1 * yrange))
def floatornan(value):
    """Return ``float(value)``, or ``np.nan`` when that raises ValueError."""
    try:
        return float(value)
    except ValueError:
        return np.nan
class CombinedAxis:
    """Helper for MHPlot that draws on a y axis split across two axes
    (``ax1`` below, ``ax2`` above) and keeps their limits in step. Also
    carries the temperature and Ediff axes so x limits can be set on all
    four at once."""

    def __init__(self, ax1, ax2, tempax, ediffax):
        self.ax1, self.ax2 = ax1, ax2
        self.tempax, self.ediffax = tempax, ediffax
        # Largest y value plotted so far; upper limit of ax2.
        self._ymax = -np.inf

    def set_ax1_range(self, ylim):
        """Fix the y range of the lower axis; must precede plot/text."""
        self._ax1_ylim = ylim
        self.ax1.set_ylim(ylim)

    def plot(self, *args, **kwargs):
        """Plot on both energy axes, then re-apply the y limits."""
        for axis in (self.ax1, self.ax2):
            axis.plot(*args, **kwargs)
        # Re-adjust yrange
        highest = self._ymax
        for candidate in args[1]:
            if candidate > highest:
                highest = candidate
        self._ymax = highest
        lower_range = self._ax1_ylim
        self.ax1.set_ylim(lower_range)
        self.ax2.set_ylim((lower_range[1], highest))

    def set_xlim(self, *args):
        """Apply the same x limits to all four axes."""
        for axis in (self.ax1, self.ax2, self.tempax, self.ediffax):
            axis.set_xlim(*args)

    def text(self, *args, **kwargs):
        """Place text on whichever energy axis contains the y coordinate."""
        below_split = args[1] < self._ax1_ylim[1]
        target = self.ax1 if below_split else self.ax2
        target.text(*args, **kwargs)