Coverage for ase / optimize / minimahopping.py: 57.40%
500 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-04 10:20 +0000
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-04 10:20 +0000
1# fmt: off
3import os
5import numpy as np
7from ase import io, units
8from ase.md import MDLogger, VelocityVerlet
9from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
10from ase.optimize import QuasiNewton
11from ase.parallel import paropen, world
class MinimaHopping:
    """Implements the minima hopping method of global optimization outlined
    by S. Goedecker, J. Chem. Phys. 120: 9911 (2004). Initialize with an
    ASE atoms object. Optional parameters are fed through keywords.
    To run multiple searches in parallel, specify the minima_traj keyword,
    and have each run point to the same path.

    Every key of ``_default_settings`` may be passed as a keyword argument
    and is stored on the instance as ``_<key>``; unknown keywords raise
    RuntimeError.

    Files are written to relative paths (i.e. the current working
    directory): the text log (``logfile``), ``qn%05i.traj``/``qn%05i.log``
    for each local optimization, ``md%05i.traj``/``md%05i.log`` for each MD
    run, and the list of accepted minima (``minima_traj``).
    """

    _default_settings = {
        'T0': 1000.,  # K, initial MD 'temperature'
        'beta1': 1.1,  # temperature adjustment parameter
        'beta2': 1.1,  # temperature adjustment parameter
        'beta3': 1. / 1.1,  # temperature adjustment parameter
        'Ediff0': 0.5,  # eV, initial energy acceptance threshold
        'alpha1': 0.98,  # energy threshold adjustment parameter
        'alpha2': 1. / 0.98,  # energy threshold adjustment parameter
        'mdmin': 2,  # criteria to stop MD simulation (no. of minima)
        'logfile': 'hop.log',  # text log
        'minima_threshold': 0.5,  # A, threshold for identical configs
        'timestep': 1.0,  # fs, timestep for MD simulations
        'optimizer': QuasiNewton,  # local optimizer to use
        'minima_traj': 'minima.traj',  # storage file for minima list
        'fmax': 0.05,  # eV/A, max force for optimizations
        'rng': None}

    def __init__(self, atoms, comm=world, **kwargs):
        """Initialize with an ASE atoms object and keyword arguments.

        Parameters:
            atoms: the Atoms object that is hopped; it is modified in place.
            comm: communicator used for rank checks, barrier and broadcast.
            **kwargs: overrides for any key of ``_default_settings``.

        Raises:
            RuntimeError: if a keyword is not a known setting.
        """
        self._atoms = atoms
        for key in kwargs:
            if key not in self._default_settings:
                raise RuntimeError(f'Unknown keyword: {key}')
        for k, v in self._default_settings.items():
            setattr(self, f'_{k}', kwargs.pop(k, v))

        # when a MD sim. has passed a local minimum:
        self._passedminimum = PassedMinimum()

        # Misc storage.
        self._previous_optimum = None  # Atoms copy of the last minimum
        self._previous_energy = None  # its potential energy
        self._temperature = self._T0
        self._Ediff = self._Ediff0
        self.comm = comm

    def __call__(self, totalsteps=None, maxtemp=None):
        """Run the minima hopping algorithm. Can specify stopping criteria
        with total steps allowed or maximum searching temperature allowed.
        If neither is specified, runs indefinitely (or until stopped by
        batching software)."""
        self._startup()
        while True:
            if (totalsteps and self._counter >= totalsteps):
                self._log('msg', 'Run terminated. Step #%i reached of '
                          '%i allowed. Increase totalsteps if resuming.'
                          % (self._counter, totalsteps))
                return
            if (maxtemp and self._temperature >= maxtemp):
                self._log('msg', 'Run terminated. Temperature is %.2f K;'
                          ' max temperature allowed %.2f K.'
                          % (self._temperature, maxtemp))
                return

            # Remember where we start so the step can be undone.
            self._previous_optimum = self._atoms.copy()
            self._previous_energy = self._atoms.get_potential_energy()
            self._molecular_dynamics()
            self._optimize()
            self._counter += 1
            self._check_results()

    def _startup(self):
        """Initiates a run, and determines if running from previous data or
        a fresh run.

        Rank 0 decides the mode and broadcasts it: 0 = no minima file,
        1 = minima file exists but no log file, 2 = both exist (resume).
        """
        status = np.array(-1.)
        exists = self._read_minima()
        if self.comm.rank == 0:
            if not exists:
                # Fresh run with new minima file.
                status = np.array(0.)
            elif not os.path.exists(self._logfile):
                # Fresh run with existing or shared minima file.
                status = np.array(1.)
            else:
                # Must be resuming from within a working directory.
                status = np.array(2.)
        self.comm.barrier()
        self.comm.broadcast(status, 0)

        if status == 2.:
            self._resume()
        else:
            self._counter = 0
            self._log('init')
            self._log('msg', 'Performing initial optimization.')
            if status == 1.:
                self._log('msg', 'Using existing minima file with %i prior '
                          'minima: %s' % (len(self._minima),
                                          self._minima_traj))
            self._optimize()
            self._check_results()
            self._counter += 1

    def _resume(self):
        """Attempt to resume a run, based on information in the log
        file. Note it will almost always be interrupted in the middle of
        either a qn or md run or when exceeding totalsteps, so it only has
        been tested in those cases currently."""
        with paropen(self._logfile, 'r', comm=self.comm) as fd:
            lines = fd.read().splitlines()

        self._log('msg', 'Attempting to resume stopped run.')
        self._log('msg', 'Using existing minima file with %i prior '
                  'minima: %s' % (len(self._minima), self._minima_traj))
        # Recover the last parameters and the last qn/md run numbers from
        # the log lines written by _log (the 'par:' header contains 'Ediff'
        # and is skipped).
        mdcount, qncount = 0, 0
        for line in lines:
            if (line[:4] == 'par:') and ('Ediff' not in line):
                self._temperature = float(line.split()[1])
                self._Ediff = float(line.split()[2])
            elif line[:18] == 'msg: Optimization:':
                qncount = int(line[19:].split('qn')[1])
            elif line[:24] == 'msg: Molecular dynamics:':
                mdcount = int(line[25:].split('md')[1])
        self._counter = max((mdcount, qncount))
        if qncount == mdcount:
            # Either stopped during local optimization or terminated due to
            # max steps.
            self._log('msg', 'Attempting to resume at qn%05i' % qncount)
            if qncount > 0:
                atoms = io.read('qn%05i.traj' % (qncount - 1), index=-1)
                self._previous_optimum = atoms.copy()
                self._previous_energy = atoms.get_potential_energy()
            if os.path.getsize('qn%05i.traj' % qncount) > 0:
                atoms = io.read('qn%05i.traj' % qncount, index=-1)
            else:
                atoms = io.read('md%05i.traj' % qncount, index=-3)
            self._atoms.positions = atoms.get_positions()
            fmax = np.sqrt((atoms.get_forces() ** 2).sum(axis=1).max())
            if fmax < self._fmax:
                # Stopped after a qn finished.
                self._log('msg', 'qn%05i fmax already less than fmax=%.3f'
                          % (qncount, self._fmax))
                self._counter += 1
                return
            self._optimize()
            self._counter += 1
            if qncount > 0:
                self._check_results()
            else:
                self._record_minimum()
                self._log('msg', 'Found a new minimum.')
                self._log('msg', 'Accepted new minimum.')
                self._log('par')
        elif qncount < mdcount:
            # Probably stopped during molecular dynamics.
            self._log('msg', 'Attempting to resume at md%05i.' % mdcount)
            atoms = io.read('qn%05i.traj' % qncount, index=-1)
            self._previous_optimum = atoms.copy()
            self._previous_energy = atoms.get_potential_energy()
            self._molecular_dynamics(resume=mdcount)
            self._optimize()
            self._counter += 1
            self._check_results()

    def _check_results(self):
        """Adjusts parameters and positions based on outputs.

        Classifies the current (just optimized) structure as: first minimum
        ever, return to the previous minimum, a previously recorded minimum,
        or a new minimum (accepted or rejected by the Ediff criterion), and
        updates temperature / Ediff / positions accordingly.
        """

        # No prior minima found?
        self._read_minima()
        if len(self._minima) == 0:
            self._log('msg', 'Found a new minimum.')
            self._log('msg', 'Accepted new minimum.')
            self._record_minimum()
            self._log('par')
            return
        # Returned to starting position?
        if self._previous_optimum is not None:
            compare = ComparePositions(translate=False)
            dmax = compare(self._atoms, self._previous_optimum)
            self._log('msg', 'Max distance to last minimum: %.3f A' % dmax)
            if dmax < self._minima_threshold:
                self._log('msg', 'Re-found last minimum.')
                self._temperature *= self._beta1
                self._log('par')
                return
        # In a previously found position?
        unique, dmax_closest = self._unique_minimum_position()
        self._log('msg', 'Max distance to closest minimum: %.3f A' %
                  dmax_closest)
        if not unique:
            self._temperature *= self._beta2
            self._log('msg', 'Found previously found minimum.')
            self._log('par')
            if self._previous_optimum is not None:
                self._log('msg', 'Restoring last minimum.')
                self._atoms.positions = self._previous_optimum.positions
            return
        # Must have found a unique minimum.
        self._temperature *= self._beta3
        self._log('msg', 'Found a new minimum.')
        self._log('par')
        if (self._previous_energy is None or
                (self._atoms.get_potential_energy() <
                 self._previous_energy + self._Ediff)):
            self._log('msg', 'Accepted new minimum.')
            self._Ediff *= self._alpha1
            self._log('par')
            self._record_minimum()
        else:
            self._log('msg', 'Rejected new minimum due to energy. '
                      'Restoring last minimum.')
            self._atoms.positions = self._previous_optimum.positions
            self._Ediff *= self._alpha2
            self._log('par')

    def _log(self, cat='msg', message=None):
        """Records the message as a line in the log file.

        Parameters:
            cat: 'init' (create the log with its header; fails if the file
                exists), 'msg' (free text), 'par' (current temperature,
                Ediff and mdmin) or 'ene' (current / previous energy).
            message: text for the 'msg' category.

        Raises:
            RuntimeError: for 'init' when the log file already exists.
            ValueError: for an unknown category.
        """
        if cat == 'init':
            if self.comm.rank == 0:
                if os.path.exists(self._logfile):
                    raise RuntimeError(f'File exists: {self._logfile}')
            with paropen(self._logfile, 'w', comm=self.comm) as fd:
                fd.write('par: %12s %12s %12s\n' % ('T (K)', 'Ediff (eV)',
                                                    'mdmin'))
                fd.write('ene: %12s %12s %12s\n' % ('E_current', 'E_previous',
                                                    'Difference'))
            return
        with paropen(self._logfile, 'a', comm=self.comm) as fd:
            if cat == 'msg':
                line = f'msg: {message}'
            elif cat == 'par':
                line = ('par: %12.4f %12.4f %12i' %
                        (self._temperature, self._Ediff, self._mdmin))
            elif cat == 'ene':
                current = self._atoms.get_potential_energy()
                if self._previous_optimum is not None:
                    previous = self._previous_energy
                    line = ('ene: %12.5f %12.5f %12.5f' %
                            (current, previous, current - previous))
                else:
                    line = ('ene: %12.5f' % current)
            else:
                raise ValueError(f'Unknown log category: {cat}')
            fd.write(line + '\n')

    def _optimize(self):
        """Perform an optimization.

        Momenta are zeroed first; the trajectory and log go to
        qn%05i.traj / qn%05i.log numbered by the current step counter.
        """
        self._atoms.set_momenta(np.zeros(self._atoms.get_momenta().shape))
        with self._optimizer(self._atoms,
                             trajectory='qn%05i.traj' % self._counter,
                             logfile='qn%05i.log' % self._counter) as opt:
            self._log('msg', 'Optimization: qn%05i' % self._counter)
            opt.run(fmax=self._fmax)
            self._log('ene')

    def _record_minimum(self):
        """Adds the current atoms configuration to the minima list."""
        with io.Trajectory(self._minima_traj, 'a') as traj:
            traj.write(self._atoms)
        self._read_minima()
        self._log('msg', 'Recorded minima #%i.' % (len(self._minima) - 1))

    def _read_minima(self):
        """Reads in the list of minima from the minima file.

        Sets ``self._minima`` (empty when the file is missing or empty) and
        returns whether the file exists.
        """
        exists = os.path.exists(self._minima_traj)
        if exists:
            empty = os.path.getsize(self._minima_traj) == 0
            if not empty:
                with io.Trajectory(self._minima_traj, 'r') as traj:
                    self._minima = list(traj)
            else:
                self._minima = []
            return True
        else:
            self._minima = []
            return False

    def _molecular_dynamics(self, resume=None):
        """Performs a molecular dynamics simulation, until mdmin is
        exceeded. If resuming, the file number (md%05i) is expected.

        On exit the positions are reset to those of the most recently
        passed minimum; if no minimum was ever detected (e.g. ``mdmin`` is
        0) the positions are left where the MD stopped.
        """
        self._log('msg', 'Molecular dynamics: md%05i' % self._counter)
        mincount = 0
        energies, oldpositions = [], []
        # Absolute index into oldpositions of the last passed minimum.
        minimum_index = None
        thermalized = False
        if resume:
            self._log('msg', 'Resuming MD from md%05i.traj' % resume)
            if os.path.getsize('md%05i.traj' % resume) == 0:
                self._log('msg', 'md%05i.traj is empty. Resuming from '
                          'qn%05i.traj.' % (resume, resume - 1))
                atoms = io.read('qn%05i.traj' % (resume - 1), index=-1)
            else:
                with io.Trajectory('md%05i.traj' % resume, 'r') as images:
                    for atoms in images:
                        energies.append(atoms.get_potential_energy())
                        oldpositions.append(atoms.positions.copy())
                        passedmin = self._passedminimum(energies)
                        if passedmin:
                            mincount += 1
                            minimum_index = len(energies) + passedmin[0]
                self._atoms.set_momenta(atoms.get_momenta())
                thermalized = True
            self._atoms.positions = atoms.get_positions()
            self._log('msg', 'Starting MD with %i existing energies.' %
                      len(energies))
        if not thermalized:
            MaxwellBoltzmannDistribution(self._atoms,
                                         temperature_K=self._temperature,
                                         force_temp=True,
                                         rng=self._rng)
        traj = io.Trajectory('md%05i.traj' % self._counter, 'a',
                             self._atoms)
        dyn = VelocityVerlet(self._atoms, timestep=self._timestep * units.fs)
        log = MDLogger(dyn, self._atoms, 'md%05i.log' % self._counter,
                       header=True, stress=False, peratom=False)

        with traj, dyn, log:
            dyn.attach(log, interval=1)
            dyn.attach(traj, interval=1)
            while mincount < self._mdmin:
                dyn.run(1)
                energies.append(self._atoms.get_potential_energy())
                passedmin = self._passedminimum(energies)
                if passedmin:
                    mincount += 1
                    # passedmin[0] is a negative offset into energies;
                    # store it as an absolute index so it stays valid.
                    minimum_index = len(energies) + passedmin[0]
                oldpositions.append(self._atoms.positions.copy())
            # Reset atoms to minimum point.
            if minimum_index is not None:
                self._atoms.positions = oldpositions[minimum_index]

    def _unique_minimum_position(self):
        """Identifies if the current position of the atoms, which should be
        a local minima, has been found before.

        Returns (unique, dmax_closest): whether no stored minimum lies
        within ``minima_threshold``, and the smallest max-displacement to
        any stored minimum (99999. if there are none).
        """
        unique = True
        dmax_closest = 99999.
        compare = ComparePositions(translate=True)
        self._read_minima()
        for minimum in self._minima:
            dmax = compare(minimum, self._atoms)
            if dmax < self._minima_threshold:
                unique = False
            if dmax < dmax_closest:
                dmax_closest = dmax
        return unique, dmax_closest
class ComparePositions:
    """Class that compares the atomic positions between two ASE atoms
    objects. Returns the maximum distance that any atom has moved, assuming
    all atoms of the same element are indistinguishable. If translate is
    set to True, allows for arbitrary translations within the unit cell,
    as well as translations across any periodic boundary conditions. When
    called, returns the maximum displacement of any one atom."""

    def __init__(self, translate=True):
        self._translate = translate

    def __call__(self, atoms1, atoms2):
        """Return the max displacement between atoms1 and atoms2.

        Both inputs are copied first, so the caller's objects are not
        modified.
        """
        atoms1 = atoms1.copy()
        atoms2 = atoms2.copy()
        if not self._translate:
            dmax = self._indistinguishable_compare(atoms1, atoms2)
        else:
            dmax = self._translated_compare(atoms1, atoms2)
        return dmax

    def _translated_compare(self, atoms1, atoms2):
        """Moves the atoms around and tries to pair up atoms, assuming any
        atoms with the same symbol are indistinguishable, and honors
        periodic boundary conditions (for example, so that an atom at
        (0.1, 0., 0.) correctly is found to be close to an atom at
        (7.9, 0., 0.) if the atoms are in an orthorhombic cell with
        x-dimension of 8. Returns dmax, the maximum distance between any
        two atoms in the optimal configuration."""
        atoms1.set_constraint()
        atoms2.set_constraint()
        for index in range(3):
            assert atoms1.pbc[index] == atoms2.pbc[index]
        least = self._get_least_common(atoms1)
        indices1 = [atom.index for atom in atoms1 if atom.symbol == least[0]]
        indices2 = [atom.index for atom in atoms2 if atom.symbol == least[0]]
        # Make comparison sets from atoms2, which contain repeated atoms in
        # all pbc's and bring the atom listed in indices2 to (0,0,0)
        comparisons = []
        repeat = [3 if bc else 1 for bc in atoms2.pbc]
        repeated = atoms2.repeat(repeat)
        # Shift the 3x block back by one image along each periodic lattice
        # vector so the original cell sits in the middle. Lattice vectors
        # are the *rows* of the cell; `cell * pbc` would broadcast over
        # columns and mis-shift non-orthogonal cells with mixed pbc.
        moved_cell = np.asarray(atoms2.cell)[np.asarray(atoms2.pbc, bool)]
        for moved in moved_cell:
            repeated.translate(-moved)
        repeated.set_cell(atoms2.cell)
        for index in indices2:
            comparison = repeated.copy()
            comparison.translate(-atoms2[index].position)
            comparisons.append(comparison)
        # Bring the atom listed in indices1 to (0,0,0) [not whole list]
        standard = atoms1.copy()
        standard.translate(-atoms1[indices1[0]].position)
        # Compare the standard to the comparison sets.
        dmaxes = [self._indistinguishable_compare(standard, comparison)
                  for comparison in comparisons]
        return min(dmaxes)

    def _get_least_common(self, atoms):
        """Returns the least common element in atoms as [symbol, count].
        If more than one, returns the first encountered (in atom order).
        An empty input gives ['', inf]."""
        symbols = [atom.symbol for atom in atoms]
        least = ['', np.inf]
        # dict.fromkeys keeps first-appearance order, so ties resolve
        # deterministically (a set would depend on string hashing).
        for element in dict.fromkeys(symbols):
            count = symbols.count(element)
            if count < least[1]:
                least = [element, count]
        return least

    def _indistinguishable_compare(self, atoms1, atoms2):
        """Finds each atom in atoms1's nearest neighbor with the same
        chemical symbol in atoms2. Return dmax, the farthest distance an
        individual atom differs by.

        Each atom of atoms2 is paired at most once (greedily, in atoms1
        order). If some atom of atoms1 has no unpaired partner of the same
        symbol left, the structures cannot be matched and inf is returned.
        """
        atoms2 = atoms2.copy()  # allow deletion
        atoms2.set_constraint()
        dmax = 0.
        for atom1 in atoms1:
            closest_index, closest_distance = None, np.inf
            for index, atom2 in enumerate(atoms2):
                if atom2.symbol == atom1.symbol:
                    d = np.linalg.norm(atom1.position - atom2.position)
                    if d < closest_distance:
                        closest_index, closest_distance = index, d
            if closest_index is None:
                return np.inf
            if closest_distance > dmax:
                dmax = closest_distance
            del atoms2[closest_index]
        return dmax
class PassedMinimum:
    """Detect whether an energy sequence has just passed a minimum.

    With the default settings a minimum is reported when the sequence ends
    with two non-rising steps followed by two non-falling steps. The
    number of points required on each side is set by ``n_down`` and
    ``n_up``.

    Calling the object with a sequence of energies returns
    ``(offset, energy)`` when a minimum is detected. ``offset`` is the
    negative index of the minimum counted from the end of the sequence
    (``-n_up - 1``), and ``energy`` is its value. Otherwise it returns
    None.
    """

    def __init__(self, n_down=2, n_up=2):
        self._ndown = n_down
        self._nup = n_up

    def __call__(self, energies):
        n_up, n_down = self._nup, self._ndown
        if len(energies) < n_up + n_down + 1:
            return None
        # Last n_up points must not drop below their predecessor.
        rising = all(not energies[i] < energies[i - 1]
                     for i in range(-1, -1 - n_up, -1))
        # The n_down points before those must not rise above theirs.
        falling = all(not energies[i] > energies[i - 1]
                      for i in range(-1 - n_up, -1 - n_up - n_down, -1))
        if not (rising and falling):
            return None
        offset = -n_up - 1
        return offset, energies[offset]
class MHPlot:
    """Makes a plot summarizing the output of the MH algorithm from the
    specified rundirectory. If no rundirectory is supplied, uses the
    current directory.

    Reads the text log (``logname``) together with the qn%05i.traj and
    md%05i.traj files found in the run directory.
    """

    def __init__(self, rundirectory=None, logname='hop.log'):
        if not rundirectory:
            rundirectory = os.getcwd()
        self._rundirectory = rundirectory
        self._logname = logname
        self._read_log()
        self._fig, self._ax = self._makecanvas()
        self._plot_data()

    def get_figure(self):
        """Returns the matplotlib figure object."""
        return self._fig

    def save_figure(self, filename):
        """Saves the file to the specified path, with any allowed
        matplotlib extension (e.g., .pdf, .png, etc.)."""
        self._fig.savefig(filename)

    def _read_log(self):
        """Reads relevant parts of the log file.

        Sets ``self._data`` to a list of [energy, status, temperature,
        ediff] records: one per completed step, plus a trailing record
        with energy NaN describing an in-progress step (if any).
        """
        data = []  # format: [energy, status, temperature, ediff]

        with open(os.path.join(self._rundirectory, self._logname)) as fd:
            lines = fd.read().splitlines()

        # Initialized so a log with no completed step (e.g. a run still in
        # its first optimization) does not hit unbound names / data[-1].
        status = None
        energy = temperature = ediff = np.nan
        step_almost_over = False
        step_over = False
        for line in lines:
            if line.startswith('msg: Molecular dynamics:'):
                status = 'performing MD'
            elif line.startswith('msg: Optimization:'):
                status = 'performing QN'
            elif line.startswith('ene:'):
                status = 'local optimum reached'
                energy = floatornan(line.split()[1])
            elif line.startswith('msg: Accepted new minimum.'):
                status = 'accepted'
                step_almost_over = True
            elif line.startswith('msg: Found previously found minimum.'):
                status = 'previously found minimum'
                step_almost_over = True
            elif line.startswith('msg: Re-found last minimum.'):
                status = 'previous minimum'
                step_almost_over = True
            elif line.startswith('msg: Rejected new minimum'):
                status = 'rejected'
                step_almost_over = True
            elif line.startswith('par: '):
                temperature = floatornan(line.split()[1])
                ediff = floatornan(line.split()[2])
                # A step is complete once the 'par:' line that follows its
                # verdict has been read.
                if step_almost_over:
                    step_over = True
                    step_almost_over = False
            if step_over:
                data.append([energy, status, temperature, ediff])
                step_over = False
        if status is not None and (not data or data[-1][1] != status):
            data.append([np.nan, status, temperature, ediff])
        self._data = data

    def _makecanvas(self):
        """Create the figure and its four axes (two stacked energy panels,
        Ediff and temperature) and return (figure, CombinedAxis)."""
        from matplotlib import pyplot
        from matplotlib.ticker import ScalarFormatter
        fig = pyplot.figure(figsize=(6., 8.))
        lm, rm, bm, tm = 0.22, 0.02, 0.05, 0.04
        vg1 = 0.01  # between adjacent energy plots
        vg2 = 0.03  # between different types of plots
        ratio = 2.  # size of an energy plot to a parameter plot
        figwidth = 1. - lm - rm
        totalfigheight = 1. - bm - tm - vg1 - 2. * vg2
        parfigheight = totalfigheight / (2. * ratio + 2)
        epotheight = ratio * parfigheight
        ax1 = fig.add_axes((lm, bm, figwidth, epotheight))
        ax2 = fig.add_axes((lm, bm + epotheight + vg1,
                            figwidth, epotheight))
        for ax in [ax1, ax2]:
            ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
        ediffax = fig.add_axes((lm, bm + 2. * epotheight + vg1 + vg2,
                                figwidth, parfigheight))
        tempax = fig.add_axes((lm, (bm + 2 * epotheight + vg1 + 2 * vg2 +
                                    parfigheight), figwidth, parfigheight))
        for ax in [ax2, tempax, ediffax]:
            ax.set_xticklabels([])
        ax1.set_xlabel('step')
        tempax.set_ylabel('$T$, K')
        ediffax.set_ylabel(r'$E_\mathrm{diff}$, eV')
        for ax in [ax1, ax2]:
            ax.set_ylabel(r'$E_\mathrm{pot}$, eV')
        ax = CombinedAxis(ax1, ax2, tempax, ediffax)
        self._set_zoomed_range(ax)
        ax1.spines['top'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        return fig, ax

    def _set_zoomed_range(self, ax):
        """Try to intelligently set the range for the zoomed-in part of the
        graph (the finite step energies padded by 20% of their spread)."""
        energies = [line[0] for line in self._data
                    if not np.isnan(line[0])]
        dr = max(energies) - min(energies)
        if dr == 0.:
            dr = 1.
        ax.set_ax1_range((min(energies) - 0.2 * dr,
                          max(energies) + 0.2 * dr))

    def _plot_data(self):
        """Draw every step's energy, QN and MD traces plus parameters."""
        for step, line in enumerate(self._data):
            self._plot_energy(step, line)
            self._plot_qn(step, line)
            self._plot_md(step, line)
        self._plot_parameters()
        self._ax.set_xlim(self._ax.ax1.get_xlim())

    def _plot_energy(self, step, line):
        """Plots energy and annotation for acceptance."""
        energy, status = line[0], line[1]
        if np.isnan(energy):
            return
        self._ax.plot([step, step + 0.5], [energy] * 2, '-',
                      color='k', linewidth=2.)
        if status == 'accepted':
            self._ax.text(step + 0.51, energy, r'$\checkmark$')
        elif status == 'rejected':
            self._ax.text(step + 0.51, energy, r'$\Uparrow$', color='red')
        elif status == 'previously found minimum':
            self._ax.text(step + 0.51, energy, r'$\hookleftarrow$',
                          color='red', va='center')
        elif status == 'previous minimum':
            self._ax.text(step + 0.51, energy, r'$\leftarrow$',
                          color='red', va='center')

    def _plot_md(self, step, line):
        """Adds a curved plot of molecular dynamics trajectory."""
        if step == 0:
            return
        energies = [self._data[step - 1][0]]
        file = os.path.join(self._rundirectory, 'md%05i.traj' % step)
        with io.Trajectory(file, 'r') as traj:
            for atoms in traj:
                energies.append(atoms.get_potential_energy())
        xi = step - 1 + .5
        if len(energies) > 2:
            xf = xi + (step + 0.25 - xi) * len(energies) / (len(energies) - 2.)
        else:
            xf = step
        if xf > (step + .75):
            xf = step
        self._ax.plot(np.linspace(xi, xf, num=len(energies)), energies,
                      '-k')

    def _plot_qn(self, index, line):
        """Plots a dashed vertical line for the optimization."""
        if line[1] == 'performing MD':
            return
        file = os.path.join(self._rundirectory, 'qn%05i.traj' % index)
        if os.path.getsize(file) == 0:
            return
        with io.Trajectory(file, 'r') as traj:
            energies = [traj[0].get_potential_energy(),
                        traj[-1].get_potential_energy()]
        if index > 0:
            file = os.path.join(self._rundirectory, 'md%05i.traj' % index)
            atoms = io.read(file, index=-3)
            energies[0] = atoms.get_potential_energy()
        self._ax.plot([index + 0.25] * 2, energies, ':k')

    def _plot_parameters(self):
        """Adds a plot of temperature and Ediff to the plot."""
        steps, Ts, ediffs = [], [], []
        for step, line in enumerate(self._data):
            steps.extend([step + 0.5, step + 1.5])
            Ts.extend([line[2]] * 2)
            ediffs.extend([line[3]] * 2)
        self._ax.tempax.plot(steps, Ts)
        self._ax.ediffax.plot(steps, ediffs)

        # Pad both parameter panels by 10% of their automatic range.
        for ax in [self._ax.tempax, self._ax.ediffax]:
            ylim = ax.get_ylim()
            yrange = ylim[1] - ylim[0]
            ax.set_ylim((ylim[0] - 0.1 * yrange, ylim[1] + 0.1 * yrange))
def floatornan(value):
    """Converts the argument into a float if possible, np.nan if not.

    Both unparsable strings (ValueError) and non-numeric objects such as
    None (TypeError) map to np.nan.
    """
    try:
        return float(value)
    except (ValueError, TypeError):
        return np.nan
class CombinedAxis:
    """Helper for MHPlot that spreads energy curves over two stacked axes.

    ``ax1`` is the zoomed-in energy panel whose y range is fixed through
    ``set_ax1_range``. ``ax2`` sits on top of it and spans from the top of
    that range up to the largest y value plotted so far. ``tempax`` and
    ``ediffax`` only share the x limits.
    """

    def __init__(self, ax1, ax2, tempax, ediffax):
        self.ax1 = ax1
        self.ax2 = ax2
        self.tempax = tempax
        self.ediffax = ediffax
        self._ymax = -np.inf

    def set_ax1_range(self, ylim):
        """Fix the y range of the zoomed (lower) energy axis."""
        self._ax1_ylim = ylim
        self.ax1.set_ylim(ylim)

    def plot(self, *args, **kwargs):
        """Plot on both energy axes; args[1] must be the y values."""
        for energy_ax in (self.ax1, self.ax2):
            energy_ax.plot(*args, **kwargs)
        # Track the running maximum so the upper axis covers every curve.
        self._ymax = max([self._ymax, *args[1]])
        lower_lim = self._ax1_ylim
        self.ax1.set_ylim(lower_lim)
        self.ax2.set_ylim((lower_lim[1], self._ymax))

    def set_xlim(self, *args):
        """Apply identical x limits to all four axes."""
        for axis in (self.ax1, self.ax2, self.tempax, self.ediffax):
            axis.set_xlim(*args)

    def text(self, *args, **kwargs):
        """Place text on whichever energy axis contains its y (args[1])."""
        target = self.ax1 if args[1] < self._ax1_ylim[1] else self.ax2
        target.text(*args, **kwargs)