Coverage for ase / gui / images.py: 72.95%

292 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-04 10:20 +0000

1# fmt: off 

2 

3import warnings 

4from math import sqrt 

5 

6import numpy as np 

7 

8from ase import Atoms 

9from ase.calculators.singlepoint import SinglePointCalculator 

10from ase.constraints import FixAtoms 

11from ase.data import atomic_numbers, covalent_radii 

12from ase.geometry import find_mic 

13from ase.gui.defaults import read_defaults 

14from ase.gui.history import History 

15from ase.gui.i18n import _ 

16from ase.io import read, string2index, write 

17 

18 

class Images:
    """Container for the frames (Atoms objects) displayed by the GUI.

    Besides the frames themselves it holds display state shared by all
    frames: per-element drawing radii, the selection and visibility masks,
    and the current repetition of the unit cell.  Frames may contain
    different numbers of atoms; the masks are sized for the largest frame
    (``self.maxnatoms``).
    """

    def __init__(self, images=None):
        """Create the container.

        Parameters
        ----------
        images : list of Atoms, optional
            Initial frames.  Defaults to a single empty ``Atoms()``.
        """
        self.covalent_radii = covalent_radii.copy()
        self.config = read_defaults()
        # User configuration may override individual element radii.
        self.configure_radii(self.config['covalent_radii'])
        self.atom_scale = self.config['radii_scale']
        if images is None:
            images = [Atoms()]
        self.initialize(images)
        self.history = History(self)

    def __len__(self):
        return len(self._images)

    def __getitem__(self, index):
        return self._images[index]

    def __iter__(self):
        return iter(self._images)

    # XXXXXXX hack
    # compatibility hacks while allowing variable number of atoms
    def get_dynamic(self, atoms: Atoms) -> np.ndarray:
        """Return a boolean mask that is False for atoms fixed by FixAtoms.

        Only ``FixAtoms`` constraints are taken into account.
        """
        dynamic = np.ones(len(atoms), bool)
        for constraint in atoms.constraints:
            if isinstance(constraint, FixAtoms):
                dynamic[constraint.index] = False
        return dynamic

    def set_dynamic(self, mask, value):
        """Mark the atoms selected by ``mask`` as dynamic (True) or fixed.

        Does not make much sense if different images have different atom
        counts: the mask is applied to every image, truncated to that
        image's length.  All existing FixAtoms constraints of an image are
        replaced by a single FixAtoms built from the updated mask; other
        constraints are kept.
        """
        for atoms in self:
            dynamic = self.get_dynamic(atoms)
            dynamic[mask[:len(atoms)]] = value
            atoms.constraints = [c for c in atoms.constraints
                                 if not isinstance(c, FixAtoms)]
            atoms.constraints.append(FixAtoms(mask=~dynamic))

    def scale_radii(self, scaling_factor):
        """Multiply every element's drawing radius by ``scaling_factor``."""
        self.covalent_radii *= scaling_factor

    def configure_radii(self, radii):
        """Configure the GUI atom radii with a {atom: radius, ...}
        dictionary or a list/tuple eg. [(atom, radius), ...].

        ``atom`` may be a chemical symbol or an atomic number.  An empty or
        None ``radii`` leaves the radii unchanged.
        """
        if not radii:
            return
        if not isinstance(radii, dict):
            _radii = {entry[0]: entry[1] for entry in radii}
        else:
            _radii = radii
        for key, value in _radii.items():
            if isinstance(key, str):
                key = atomic_numbers[key]
            self.covalent_radii[key] = value

    def get_energy(self, atoms: Atoms) -> np.float64:
        """Return the potential energy of ``atoms``, or NaN if unavailable."""
        try:
            return atoms.get_potential_energy()
        except RuntimeError:
            return np.nan  # type: ignore[return-value]

    def get_forces(self, atoms: Atoms):
        """Return the unconstrained forces of ``atoms``, or None."""
        try:
            return atoms.get_forces(apply_constraint=False)
        except RuntimeError:
            return None

    def initialize(self, images, filenames=None):
        """Replace all frames and reset the per-frame GUI state.

        Parameters
        ----------
        images : sequence of Atoms
            New frames.  They are stored without copying.
        filenames : list, optional
            One name per frame; defaults to ``None`` for every frame.
        """
        if filenames is None:
            filenames = [None] * len(images)
        self.filenames = filenames

        # copy atoms or not? Not copying allows back-editing,
        # but copying actually forgets things like the attached
        # calculator (might have forces/energies).
        self._images = list(images)

        # Whether length or chemical composition changes between frames:
        self.have_varying_species = any(
            not np.array_equal(self[0].numbers, atoms.numbers)
            for atoms in self)

        if any((atoms.pbc != self[0].pbc).any() for atoms in self):
            warnings.warn('Not all images have the same boundary conditions!')

        self.maxnatoms = max(len(atoms) for atoms in self)
        self.selected = np.zeros(self.maxnatoms, bool)
        self.selected_ordered = []
        self.visible = np.ones(self.maxnatoms, bool)
        self.repeat = np.ones(3, int)

    def get_radii(self, atoms: Atoms) -> np.ndarray:
        """Return the scaled drawing radius of every atom in ``atoms``."""
        radii = np.array([self.covalent_radii[z] for z in atoms.numbers])
        radii *= self.atom_scale
        return radii

    def read(self, filenames, default_index=':', filetype=None):
        """Read frames from one or more files and replace the current ones.

        Parameters
        ----------
        filenames : list of str
            Files to read.  ``'-'`` reads standard input with format
            ``'traj'``.  A ``file@index`` suffix overrides
            ``default_index`` for that file.
        default_index : str or int or slice
            Frames to read from files without an explicit index.
        filetype : str, optional
            Format passed on to ``ase.io.read`` for every file (except
            standard input, which always uses ``'traj'``).

        Each frame is named ``<file>@<index>`` in ``self.filenames``.
        """
        from ase.io.formats import parse_filename

        if isinstance(default_index, str):
            default_index = string2index(default_index)

        images = []
        names = []
        for filename in filenames:
            # Presumably postgres URLs carry an '@' of their own
            # (user@host), so for them an index is present only when
            # there are two.
            if ('@' in filename and 'postgres' not in filename
                    or 'postgres' in filename and filename.count('@') == 2):
                actual_filename, index = parse_filename(filename, None)
            else:
                actual_filename, index = parse_filename(filename,
                                                        default_index)

            # The stdin override applies to this file only; rebinding the
            # ``filetype`` argument would force 'traj' on all later files.
            fmt = filetype
            if filename == '-':
                import sys
                from io import BytesIO
                buf = BytesIO(sys.stdin.buffer.read())
                buf.seek(0)
                filename = buf
                fmt = 'traj'

            imgs = read(filename, index, fmt)
            if hasattr(imgs, 'iterimages'):
                # A single Atoms (integer index) is turned into a list.
                imgs = list(imgs.iterimages())

            images.extend(imgs)

            # Name each file as filename@index:
            if isinstance(index, slice):
                start = index.start or 0
                step = index.step or 1
            else:
                start = index
                step = 1
            for i in range(len(imgs)):
                if isinstance(start, int):
                    names.append(f'{actual_filename}@{start + i * step}')
                else:
                    names.append(f'{actual_filename}@{start}')

        self.initialize(images, names)

    def repeat_results(self, atoms: Atoms, repeat=None, oldprod=None):
        """Return a dictionary which updates the magmoms, energy and forces
        to the repeated amount of atoms.

        ``atoms`` is taken to currently consist of ``oldprod`` copies of an
        original cell.  Per-atom quantities (magnetic moments, forces) of
        the first copy are tiled ``prod(repeat)`` times; total quantities
        (magnetic moment, energy) are scaled by ``prod(repeat) / oldprod``.
        Quantities the calculator does not already hold are left out.  The
        result can be passed as keyword arguments to
        ``SinglePointCalculator``.

        Parameters
        ----------
        atoms : Atoms
            Frame whose calculator results are rescaled.
        repeat : int or sequence of int, optional
            New repetition, as a 3-vector or its product.  Defaults to the
            product of ``self.repeat``.
        oldprod : int, optional
            Product of the repetition ``atoms`` currently has.  Defaults to
            the product of ``self.repeat``.
        """
        def getresult(name, get_quantity):
            # ase/io/trajectory.py line 170 does this by using
            # the get_property(prop, atoms, allow_calculation=False)
            # so that is an alternative option.
            # Only values the calculator already holds are used; errors are
            # reported as warnings and the quantity is treated as missing.
            try:
                if (not atoms.calc or
                        atoms.calc.calculation_required(atoms, [name])):
                    quantity = None
                else:
                    quantity = get_quantity()
            except Exception as err:
                quantity = None
                errmsg = ('An error occurred while retrieving {} '
                          'from the calculator: {}'.format(name, err))
                warnings.warn(errmsg)
            return quantity

        if repeat is None:
            repeat = self.repeat.prod()
        if oldprod is None:
            oldprod = self.repeat.prod()

        results = {}

        original_length = len(atoms) // oldprod
        # np.prod handles a 3-vector, a plain list and an already multiplied
        # scalar (as passed by repeat_unit_cell) alike.
        newprod = np.prod(repeat)

        # Read the old properties
        magmoms = getresult('magmoms', atoms.get_magnetic_moments)
        magmom = getresult('magmom', atoms.get_magnetic_moment)
        energy = getresult('energy', atoms.get_potential_energy)
        forces = getresult('forces', atoms.get_forces)

        # Update old properties to the repeated image
        if magmoms is not None:
            magmoms = np.tile(magmoms[:original_length], newprod)
            results['magmoms'] = magmoms

        if magmom is not None:
            magmom = magmom * newprod / oldprod
            results['magmom'] = magmom

        if forces is not None:
            # Tile along the atom axis: (n, 3) -> (n * newprod, 3).
            forces = np.tile(forces[:original_length].T, newprod).T
            results['forces'] = forces

        if energy is not None:
            energy = energy * newprod / oldprod
            results['energy'] = energy

        return results

    def repeat_unit_cell(self):
        """Make the current repetition permanent by enlarging each cell.

        The atoms are left as they are; each cell is multiplied by the
        current repetition, results are re-attached via a
        SinglePointCalculator and ``self.repeat`` is reset to (1, 1, 1).
        """
        for atoms in self:
            # Get quantities taking into account current repeat():
            results = self.repeat_results(atoms, self.repeat.prod(),
                                          oldprod=self.repeat.prod())

            atoms.cell *= self.repeat.reshape((3, 1))
            atoms.calc = SinglePointCalculator(atoms, **results)
        self.repeat = np.ones(3, int)

    def repeat_images(self, repeat):
        """Replicate the atoms of every frame ``repeat`` times.

        Any previous repetition is removed first, and each frame keeps its
        original cell.  Constraints other than FixAtoms are discarded and
        the user is warned in a dialog.  Calculator results are rescaled
        with :meth:`repeat_results`.
        """
        repeat = np.array(repeat)
        oldprod = self.repeat.prod()
        images = []
        constraints_removed = False

        for atoms in self:
            refcell = atoms.get_cell()
            fa = []
            for c in atoms._constraints:
                if isinstance(c, FixAtoms):
                    fa.append(c)
                else:
                    constraints_removed = True
            atoms.set_constraint(fa)

            # Update results dictionary to repeated atoms
            results = self.repeat_results(atoms, repeat, oldprod)

            del atoms[len(atoms) // oldprod:]  # Original atoms

            atoms *= repeat
            atoms.cell = refcell

            atoms.calc = SinglePointCalculator(atoms, **results)

            images.append(atoms)

        if constraints_removed:
            from ase.gui.ui import showwarning, tk

            # We must be able to show warning before the main GUI
            # has been created.  So we create a new window,
            # then show the warning, then destroy the window.
            tmpwindow = tk.Tk()
            tmpwindow.withdraw()  # Host window will never be shown
            showwarning(_('Constraints discarded'),
                        _('Constraints other than FixAtoms '
                          'have been discarded.'))
            tmpwindow.destroy()

        self.initialize(images, filenames=self.filenames)
        self.repeat = repeat

    def center(self):
        """Center each image in the existing unit cell, keeping the
        cell constant."""
        for atoms in self:
            atoms.center()

    def graph(self, expr: str) -> np.ndarray:
        """Routine to create the data in graphs, defined by the
        string expr.

        ``expr`` is a comma-separated list of Python expressions evaluated
        once per frame.  The result has shape
        ``(number of expressions, number of frames)``.

        Names available in ``expr``:

        i, s            frame index; accumulated minimum-image distance
                        between consecutive frames (stays 0 when the
                        frames differ in species)
        E               potential energies of all frames (NaN if missing)
        R, V, F, A, M   positions, velocities, unconstrained forces, cell
                        and masses of the current frame
        f, fmax, fave   per-atom force norms (0 for FixAtoms-fixed atoms),
                        their maximum and mean
        epot, ekin, e   potential, kinetic and total energy
        T               2 * ekin / (3 * number of non-fixed atoms * kB)
        d(n1, n2), a(n1, n2, n3), dih(n1, n2, n3, n4)
                        distance, angle and dihedral angle (degrees)

        Per-frame quantities that are not available (forces; ``T`` when no
        atom is free) are NaN for that frame.  They are never carried over
        from the previous frame.

        NOTE: ``expr`` is passed to ``eval`` and can run arbitrary code; it
        must only come from the trusted local user.
        """
        import ase.units as units
        code = compile(expr + ',', '<input>', 'eval')

        nimages = len(self)

        def d(n1, n2):
            return sqrt(((R[n1] - R[n2])**2).sum())

        def a(n1, n2, n3):
            v1 = R[n1] - R[n2]
            v2 = R[n3] - R[n2]
            arg = np.vdot(v1, v2) / (sqrt((v1**2).sum() * (v2**2).sum()))
            # Clip rounding errors that would push arccos out of domain.
            if arg > 1.0:
                arg = 1.0
            if arg < -1.0:
                arg = -1.0
            return 180.0 * np.arccos(arg) / np.pi

        def dih(n1, n2, n3, n4):
            # vector 0->1, 1->2, 2->3 and their normalized cross products:
            a = R[n2] - R[n1]
            b = R[n3] - R[n2]
            c = R[n4] - R[n3]
            bxa = np.cross(b, a)
            bxa /= np.sqrt(np.vdot(bxa, bxa))
            cxb = np.cross(c, b)
            cxb /= np.sqrt(np.vdot(cxb, cxb))
            angle = np.vdot(bxa, cxb)
            # check for numerical trouble due to finite precision:
            if angle < -1:
                angle = -1
            if angle > 1:
                angle = 1
            angle = np.arccos(angle)
            if np.vdot(bxa, c) > 0:
                angle = 2 * np.pi - angle
            return angle * 180.0 / np.pi

        E = np.array([self.get_energy(atoms) for atoms in self])

        s = 0.0

        # Namespace for eval:
        ns = {'E': E,
              'd': d, 'a': a, 'dih': dih}

        for i in range(nimages):
            atoms = self[i]
            ns['i'] = i
            ns['s'] = s
            ns['R'] = R = atoms.get_positions()
            ns['V'] = atoms.get_velocities()
            ns['A'] = atoms.get_cell()
            ns['M'] = atoms.get_masses()
            # XXX askhl verify:
            dynamic = self.get_dynamic(atoms)

            # Every per-frame entry is assigned on every iteration so a
            # frame lacking forces never reuses the previous frame's data.
            F = self.get_forces(atoms)
            if F is not None:
                f = ((F * dynamic[:, None])**2).sum(1)**.5
                fmax = f.max() if len(f) else np.nan
                fave = f.mean() if len(f) else np.nan
            else:
                F = np.full((len(atoms), 3), np.nan)
                f = np.full(len(atoms), np.nan)
                fmax = fave = np.nan
            ns['F'] = F
            ns['f'] = f
            ns['fmax'] = fmax
            ns['fave'] = fave

            ns['epot'] = epot = E[i]
            ns['ekin'] = ekin = atoms.get_kinetic_energy()
            ns['e'] = epot + ekin
            ndynamic = dynamic.sum()
            if ndynamic > 0:
                ns['T'] = 2.0 * ekin / (3.0 * ndynamic * units.kB)
            else:
                ns['T'] = np.nan

            data = eval(code, ns)  # trusted, user-typed expression only
            if i == 0:
                xy = np.empty((len(data), nimages))
            xy[:, i] = data
            if i + 1 < nimages and not self.have_varying_species:
                dR = find_mic(self[i + 1].positions - R, atoms.get_cell(),
                              atoms.get_pbc())[0]
                s += sqrt((dR**2).sum())
        return xy

    def write(self, filename, rotations='', bbox=None,
              **kwargs):
        """Write frames to ``filename``.

        A trailing ``@<index>`` (e.g. ``out.traj@-1`` or ``out.xyz@0:10``)
        selects the frames to write.  If the text after the last ``@`` is
        not an integer or slice, the whole string is used as the filename.
        For ``.eps``, ``.png`` and ``.pov`` output, ``rotations`` (as
        ``rotation``) and ``bbox`` are forwarded to the writer.
        """
        # XXX We should show the unit cell whenever there is one
        indices = range(len(self))
        p = filename.rfind('@')
        if p != -1:
            try:
                selection = string2index(filename[p + 1:])
            except ValueError:
                selection = None
            # string2index may return something other than an int/slice
            # (e.g. the text itself); only real indices select frames.
            if isinstance(selection, (int, slice)):
                indices = indices[selection]
                filename = filename[:p]
        if isinstance(indices, int):
            indices = [indices]

        images = [self.get_atoms(i) for i in indices]
        if len(filename) > 4 and filename[-4:] in ['.eps', '.png', '.pov']:
            write(filename, images,
                  rotation=rotations,
                  bbox=bbox, **kwargs)
        else:
            write(filename, images, **kwargs)

    def get_atoms(self, frame, remove_hidden=False):
        """Return frame ``frame`` with a SinglePointCalculator attached.

        The calculator holds the frame's energy and the forces returned by
        ``atoms.get_forces()`` with default arguments (``None`` when
        unavailable).  With ``remove_hidden`` a new Atoms object containing
        only the visible atoms is returned.

        NOTE(review): without ``remove_hidden`` the stored frame itself is
        returned and its calculator is replaced in place.
        """
        atoms = self[frame]
        try:
            E = atoms.get_potential_energy()
        except RuntimeError:
            E = None
        try:
            F = atoms.get_forces()
        except RuntimeError:
            F = None

        # Remove hidden atoms if applicable.  ``self.visible`` is sized for
        # the largest frame, so cut it to this frame's atom count.
        if remove_hidden:
            visible = self.visible[:len(atoms)]
            atoms = atoms[visible]
            if F is not None:
                F = F[visible]
        atoms.calc = SinglePointCalculator(atoms, energy=E, forces=F)
        return atoms

    def delete(self, i):
        """Remove frame ``i`` and its name, then reset the GUI state."""
        self._images.pop(i)
        self.filenames.pop(i)
        self.initialize(self._images, self.filenames)