Coverage for /builds/ase/ase/ase/gui/images.py: 70.97%
279 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import warnings
4from math import sqrt
6import numpy as np
8from ase import Atoms
9from ase.calculators.singlepoint import SinglePointCalculator
10from ase.constraints import FixAtoms
11from ase.data import covalent_radii
12from ase.geometry import find_mic
13from ase.gui.defaults import read_defaults
14from ase.gui.i18n import _
15from ase.io import read, string2index, write
class Images:
    """Ordered collection of Atoms objects plus per-collection GUI state.

    Besides the images themselves (``self._images``) the object keeps:

    * ``filenames`` -- one label per image (or ``None`` entries),
    * ``covalent_radii`` / ``atom_scale`` -- radii table and scale factor
      used by :meth:`get_radii`,
    * ``selected`` / ``visible`` -- boolean masks of length ``maxnatoms``,
    * ``selected_ordered`` -- list, reset to empty by :meth:`initialize`,
    * ``repeat`` -- integer array of length 3 with the current repetition,
    * ``have_varying_species`` -- True if any image has atomic numbers
      differing from those of the first image.
    """

    def __init__(self, images=None):
        """Create the collection; defaults to a single empty Atoms object."""
        # Private copy so that scale_radii() does not modify ase.data.
        self.covalent_radii = covalent_radii.copy()
        self.config = read_defaults()
        self.atom_scale = self.config['radii_scale']
        if images is None:
            images = [Atoms()]
        self.initialize(images)

    def __len__(self):
        """Return the number of images."""
        return len(self._images)

    def __getitem__(self, index):
        """Return the image(s) at *index* of the underlying list."""
        return self._images[index]

    def __iter__(self):
        """Iterate over the images in order."""
        return iter(self._images)

    # XXXXXXX hack
    # compatibility hacks while allowing variable number of atoms
    def get_dynamic(self, atoms: Atoms) -> np.ndarray:
        """Return a boolean mask that is False for FixAtoms-constrained atoms.

        Only ``FixAtoms`` constraints are considered; every other atom is
        reported as dynamic (True).
        """
        dynamic = np.ones(len(atoms), bool)
        for constraint in atoms.constraints:
            if isinstance(constraint, FixAtoms):
                dynamic[constraint.index] = False
        return dynamic

    def set_dynamic(self, mask, value):
        """Set the dynamic flag of the atoms selected by *mask* to *value*.

        For every image all existing ``FixAtoms`` constraints are replaced
        by a single ``FixAtoms`` built from the updated mask; constraints of
        other types are kept.
        """
        # Does not make much sense if different images have different
        # atom counts.  Attempts to apply mask to all images,
        # to the extent possible.
        for atoms in self:
            dynamic = self.get_dynamic(atoms)
            dynamic[mask[:len(atoms)]] = value
            kept = [c for c in atoms.constraints
                    if not isinstance(c, FixAtoms)]
            kept.append(FixAtoms(mask=~dynamic))
            atoms.constraints = kept

    def scale_radii(self, scaling_factor):
        """Multiply the stored covalent radii in place by *scaling_factor*."""
        self.covalent_radii *= scaling_factor

    def get_energy(self, atoms: Atoms) -> np.float64:
        """Return the potential energy of *atoms*, or NaN on RuntimeError."""
        try:
            return atoms.get_potential_energy()
        except RuntimeError:
            return np.nan  # type: ignore[return-value]

    def get_forces(self, atoms: Atoms):
        """Return forces without constraints applied, or None on RuntimeError.
        """
        try:
            return atoms.get_forces(apply_constraint=False)
        except RuntimeError:
            return None

    def initialize(self, images, filenames=None):
        """(Re)populate the collection and reset the per-collection state.

        Parameters:

        images: sequence of Atoms
            Stored by reference (not copied).
        filenames: list or None
            Labels for the images; defaults to a list of ``None`` with one
            entry per image.

        Emits a warning if the images do not all share the periodic
        boundary conditions of the first image.  An empty *images* sequence
        is accepted and yields ``maxnatoms == 0``.
        """
        if filenames is None:
            filenames = [None] * len(images)
        self.filenames = filenames

        mixed_pbc = False
        self._images = []

        # Whether length or chemical composition changes:
        self.have_varying_species = False
        for atoms in images:
            # copy atoms or not?  Not copying allows back-editing,
            # but copying actually forgets things like the attached
            # calculator (might have forces/energies
            self._images.append(atoms)
            self.have_varying_species |= not np.array_equal(
                self[0].numbers, atoms.numbers)
            if (atoms.pbc != self[0].pbc).any():
                mixed_pbc = True

        if mixed_pbc:
            warnings.warn('Not all images have the same boundary conditions!')

        # default=0 keeps an empty collection (e.g. after deleting the
        # last image) from raising ValueError.
        self.maxnatoms = max((len(atoms) for atoms in self), default=0)
        self.selected = np.zeros(self.maxnatoms, bool)
        self.selected_ordered = []
        self.visible = np.ones(self.maxnatoms, bool)
        self.repeat = np.ones(3, int)

    def get_radii(self, atoms: Atoms) -> np.ndarray:
        """Return the scaled covalent radius of every atom in *atoms*."""
        radii = np.array([self.covalent_radii[z] for z in atoms.numbers])
        radii *= self.atom_scale
        return radii

    def read(self, filenames, default_index=':', filetype=None):
        """Read images from *filenames* and re-initialize the collection.

        Parameters:

        filenames: iterable of str
            Names to read.  ``'-'`` means standard input, which is read as
            a binary stream with file type ``'traj'``.
        default_index: str, int or slice
            Index used for names that do not carry their own ``@index``;
            a string is converted with ``string2index``.
        filetype: str or None
            Passed on to ``ase.io.read`` for every file except stdin.

        Each image is labelled ``<actual_filename>@<index>``.
        """
        from ase.io.formats import parse_filename

        if isinstance(default_index, str):
            default_index = string2index(default_index)

        images = []
        names = []
        for filename in filenames:
            has_own_index = (
                ('@' in filename and 'postgres' not in filename)
                or ('postgres' in filename and filename.count('@') == 2))
            if has_own_index:
                actual_filename, index = parse_filename(filename, None)
            else:
                actual_filename, index = parse_filename(filename,
                                                        default_index)

            # Per-file locals, so that the stdin override of the file type
            # does not leak into the files that follow.
            source = filename
            fmt = filetype
            if filename == '-':
                # Read from stdin:
                import sys
                from io import BytesIO
                source = BytesIO(sys.stdin.buffer.read())
                fmt = 'traj'

            imgs = read(source, index, fmt)
            if hasattr(imgs, 'iterimages'):
                imgs = list(imgs.iterimages())

            images.extend(imgs)

            # Name each file as filename@index:
            if isinstance(index, slice):
                start = index.start or 0
                step = index.step or 1
            else:
                start = index
                step = 1
            for i, _img in enumerate(imgs):
                if isinstance(start, int):
                    names.append(f'{actual_filename}@{start + i * step}')
                else:
                    names.append(f'{actual_filename}@{start}')

        self.initialize(images, names)

    def repeat_results(self, atoms: Atoms, repeat=None, oldprod=None):
        """Return a dictionary which updates the magmoms, energy and forces
        to the repeated amount of atoms.

        Parameters:

        atoms: Atoms
            Image whose calculator results are read.
        repeat: int or sequence of int, optional
            New repetition; only its product is used.  Defaults to the
            product of ``self.repeat``.
        oldprod: int, optional
            Number of repetitions *atoms* currently holds.  Defaults to the
            product of ``self.repeat``.

        Only quantities the calculator can provide without a new
        calculation are included in the returned dictionary.
        """
        def getresult(name, get_quantity):
            # ase/io/trajectory.py line 170 does this by using
            # the get_property(prop, atoms, allow_calculation=False)
            # so that is an alternative option.
            try:
                if (not atoms.calc or
                        atoms.calc.calculation_required(atoms, [name])):
                    quantity = None
                else:
                    quantity = get_quantity()
            except Exception as err:
                # Deliberately best-effort: a failing calculator must not
                # prevent repeating, so report the problem and move on.
                quantity = None
                errmsg = ('An error occurred while retrieving {} '
                          'from the calculator: {}'.format(name, err))
                warnings.warn(errmsg)
            return quantity

        if repeat is None:
            repeat = self.repeat.prod()
        if oldprod is None:
            oldprod = self.repeat.prod()

        results = {}

        original_length = len(atoms) // oldprod
        # np.prod accepts scalars, sequences and arrays alike, so callers
        # are not required to pass a numpy object.
        newprod = int(np.prod(repeat))

        # Read the old properties
        magmoms = getresult('magmoms', atoms.get_magnetic_moments)
        magmom = getresult('magmom', atoms.get_magnetic_moment)
        energy = getresult('energy', atoms.get_potential_energy)
        forces = getresult('forces', atoms.get_forces)

        # Update old properties to the repeated image
        if magmoms is not None:
            results['magmoms'] = np.tile(magmoms[:original_length], newprod)

        if magmom is not None:
            results['magmom'] = magmom * newprod / oldprod

        if forces is not None:
            # Transpose so that tiling repeats whole rows (atoms).
            results['forces'] = np.tile(forces[:original_length].T,
                                        newprod).T

        if energy is not None:
            results['energy'] = energy * newprod / oldprod

        return results

    def repeat_unit_cell(self):
        """Scale each cell by the current repetition and reset ``repeat``.

        The calculator of every image is replaced by a
        ``SinglePointCalculator`` holding the results returned by
        :meth:`repeat_results`.
        """
        for atoms in self:
            # Get quantities taking into account current repeat():
            prod = self.repeat.prod()
            results = self.repeat_results(atoms, prod, oldprod=prod)

            atoms.cell *= self.repeat.reshape((3, 1))
            atoms.calc = SinglePointCalculator(atoms, **results)
        self.repeat = np.ones(3, int)

    def repeat_images(self, repeat):
        """Repeat every image *repeat* times along the three cell vectors.

        The images are modified in place: atoms beyond the original ones
        are deleted, the remainder is repeated, the cell is set back to
        the value it had on entry, and the calculator is replaced by a
        ``SinglePointCalculator`` with results from :meth:`repeat_results`.
        Constraints other than ``FixAtoms`` are discarded; if that happens
        a warning dialog is shown.
        """
        repeat = np.array(repeat)
        oldprod = self.repeat.prod()
        images = []
        constraints_removed = False

        for atoms in self:
            refcell = atoms.get_cell()
            constraints = atoms.constraints
            fixatoms = [c for c in constraints if isinstance(c, FixAtoms)]
            if len(fixatoms) != len(constraints):
                constraints_removed = True
            atoms.set_constraint(fixatoms)

            # Update results dictionary to repeated atoms
            results = self.repeat_results(atoms, repeat, oldprod)

            del atoms[len(atoms) // oldprod:]  # Original atoms

            atoms *= repeat
            atoms.cell = refcell

            atoms.calc = SinglePointCalculator(atoms, **results)

            images.append(atoms)

        if constraints_removed:
            from ase.gui.ui import showwarning, tk

            # We must be able to show warning before the main GUI
            # has been created.  So we create a new window,
            # then show the warning, then destroy the window.
            tmpwindow = tk.Tk()
            try:
                tmpwindow.withdraw()  # Host window will never be shown
                showwarning(_('Constraints discarded'),
                            _('Constraints other than FixAtoms '
                              'have been discarded.'))
            finally:
                # Do not leave the helper window behind if the dialog fails.
                tmpwindow.destroy()

        self.initialize(images, filenames=self.filenames)
        self.repeat = repeat

    def center(self):
        """Center each image in the existing unit cell, keeping the
        cell constant."""
        for atoms in self:
            atoms.center()

    def graph(self, expr: str) -> np.ndarray:
        """Routine to create the data in graphs, defined by the
        string expr.

        *expr* is a comma separated list of Python expressions evaluated
        once per image.  Names available to it:

        ``i`` image index, ``s`` accumulated path length, ``R`` positions,
        ``V`` velocities, ``F`` forces, ``A`` cell, ``M`` masses, ``E``
        energies of all images, ``f``/``fmax``/``fave`` per-atom force
        magnitudes of the dynamic atoms and their max/mean, ``epot``,
        ``ekin``, ``e`` (their sum), ``T`` temperature, and the helpers
        ``d(n1, n2)``, ``a(n1, n2, n3)``, ``dih(n1, n2, n3, n4)``
        (distance, angle and dihedral; angles in degrees).

        Returns an array of shape ``(number of expressions, number of
        images)``; for an empty collection an empty ``(0, 0)`` array.

        SECURITY: *expr* is executed with ``eval``; it must never come from
        an untrusted source.
        """
        code = compile(expr + ',', '<input>', 'eval')

        nimages = len(self)
        if nimages == 0:
            # Nothing to evaluate, so the number of variables is unknown.
            return np.empty((0, 0))

        import ase.units as units

        # The helpers below read ``R``, which is rebound for every image
        # in the loop further down.
        def d(n1, n2):
            return sqrt(((R[n1] - R[n2])**2).sum())

        def a(n1, n2, n3):
            v1 = R[n1] - R[n2]
            v2 = R[n3] - R[n2]
            arg = np.vdot(v1, v2) / (sqrt((v1**2).sum() * (v2**2).sum()))
            # Clamp against rounding errors before taking the arccos.
            if arg > 1.0:
                arg = 1.0
            if arg < -1.0:
                arg = -1.0
            return 180.0 * np.arccos(arg) / np.pi

        def dih(n1, n2, n3, n4):
            # vector 0->1, 1->2, 2->3 and their normalized cross products:
            v01 = R[n2] - R[n1]
            v12 = R[n3] - R[n2]
            v23 = R[n4] - R[n3]
            bxa = np.cross(v12, v01)
            bxa /= np.sqrt(np.vdot(bxa, bxa))
            cxb = np.cross(v23, v12)
            cxb /= np.sqrt(np.vdot(cxb, cxb))
            angle = np.vdot(bxa, cxb)
            # check for numerical trouble due to finite precision:
            if angle < -1:
                angle = -1
            if angle > 1:
                angle = 1
            angle = np.arccos(angle)
            if np.vdot(bxa, v23) > 0:
                angle = 2 * np.pi - angle
            return angle * 180.0 / np.pi

        E = np.array([self.get_energy(atoms) for atoms in self])

        s = 0.0

        # Namespace for eval:
        ns = {'E': E,
              'd': d, 'a': a, 'dih': dih}

        # NOTE(review): 'F', 'f', 'fmax', 'fave' and 'T' are only written
        # when available, so for an image lacking them the values of an
        # earlier image remain visible to the expression.
        xy = None
        for i in range(nimages):
            atoms = self[i]
            ns['i'] = i
            ns['s'] = s
            ns['R'] = R = atoms.get_positions()
            ns['V'] = atoms.get_velocities()
            F = self.get_forces(atoms)
            if F is not None:
                ns['F'] = F
            ns['A'] = atoms.get_cell()
            ns['M'] = atoms.get_masses()
            # XXX askhl verify:
            dynamic = self.get_dynamic(atoms)
            if F is not None:
                # Force magnitudes, with constrained atoms zeroed out.
                ns['f'] = f = ((F * dynamic[:, None])**2).sum(1)**.5
                ns['fmax'] = f.max()
                ns['fave'] = f.mean()
            ns['epot'] = epot = E[i]
            ns['ekin'] = ekin = atoms.get_kinetic_energy()
            ns['e'] = epot + ekin
            # number of mobile atoms for temperature calculation
            ndynamic = dynamic.sum()
            if ndynamic > 0:
                ns['T'] = 2.0 * ekin / (3.0 * ndynamic * units.kB)
            data = eval(code, ns)
            if xy is None:
                # The first image determines the number of variables.
                xy = np.empty((len(data), nimages))
            xy[:, i] = data
            if i + 1 < nimages and not self.have_varying_species:
                # Path length to the next image, using minimum-image
                # displacements.
                dR = find_mic(self[i + 1].positions - R, atoms.get_cell(),
                              atoms.get_pbc())[0]
                s += sqrt((dR**2).sum())
        return xy

    def write(self, filename, rotations='', bbox=None,
              **kwargs):
        """Write images to *filename*.

        A trailing ``@index`` selects the images to write; if the part
        after the last ``@`` is not a valid index it is treated as part of
        the file name.  For ``.eps``, ``.png`` and ``.pov`` files
        *rotations* and *bbox* are passed on to ``ase.io.write``; remaining
        keyword arguments are always passed on.
        """
        # XXX We should show the unit cell whenever there is one
        indices = range(len(self))
        p = filename.rfind('@')
        if p != -1:
            try:
                index = string2index(filename[p + 1:])
            except ValueError:
                pass
            else:
                indices = indices[index]
                filename = filename[:p]
                if isinstance(indices, int):
                    indices = [indices]

        images = [self.get_atoms(i) for i in indices]
        if len(filename) > 4 and filename.endswith(('.eps', '.png', '.pov')):
            write(filename, images,
                  rotation=rotations,
                  bbox=bbox, **kwargs)
        else:
            write(filename, images, **kwargs)

    def get_atoms(self, frame, remove_hidden=False):
        """Return image *frame* with a SinglePointCalculator attached.

        Energy and forces are taken from the image; each is ``None`` when
        retrieving it raises RuntimeError.  With ``remove_hidden=True`` a
        new Atoms object containing only the visible atoms is returned;
        otherwise the stored image itself is returned and its calculator
        is replaced.
        """
        atoms = self[frame]
        try:
            E = atoms.get_potential_energy()
        except RuntimeError:
            E = None
        try:
            F = atoms.get_forces()
        except RuntimeError:
            F = None

        # Remove hidden atoms if applicable
        if remove_hidden:
            # self.visible has maxnatoms entries; cut it to this image so
            # that images with fewer atoms can be indexed as well.
            visible = self.visible[:len(atoms)]
            atoms = atoms[visible]
            if F is not None:
                F = F[visible]
        atoms.calc = SinglePointCalculator(atoms, energy=E, forces=F)
        return atoms

    def delete(self, i):
        """Remove image *i* and its filename, then re-initialize the state."""
        self._images.pop(i)
        self.filenames.pop(i)
        self.initialize(self._images, self.filenames)
402 self.initialize(self._images, self.filenames)