Coverage for /builds/ase/ase/ase/dft/wannier.py: 58.74%
555 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1"""Partly occupied Wannier functions
3Find the set of partly occupied Wannier functions using the method from
4Thygesen, Hansen and Jacobsen PRB v72 i12 p125119 2005.
5"""
7import functools
8import warnings
9from math import pi, sqrt
10from time import time
12import numpy as np
13from scipy.linalg import qr
15from ase.dft.bandgap import bandgap
16from ase.dft.kpoints import get_monkhorst_pack_size_and_offset
17from ase.io.jsonio import read_json, write_json
18from ase.parallel import paropen
19from ase.transport.tools import dagger, normalize
# Short alias for the Hermitian-conjugate helper, used throughout the module.
dag = dagger


def silent(*args, **kwargs):
    """Do nothing: a drop-in logger that discards whatever it is given."""
    return None
def gram_schmidt(U):
    """Orthonormalize the columns of ``U`` in place (Gram-Schmidt).

    Each column has its projections onto all previous (already
    orthonormalized) columns removed, and is then scaled to unit norm.
    Nothing is returned; ``U`` itself is modified. A column that becomes
    zero would be divided by a zero norm.
    """
    columns = U.T  # a view: writing into it writes into U
    for i in range(columns.shape[0]):
        current = columns[i]
        for j in range(i):
            previous = columns[j]
            current -= previous * (previous.conj() @ current)
        current /= np.linalg.norm(current)
def lowdin(U, S=None):
    """Orthonormalize the columns of ``U`` in place, symmetric-Lowdin style.

    Uses the thin SVD ``U = L s R`` and replaces ``U`` with ``L @ R``. This
    is the orthonormal matrix closest to ``U``, like symmetric Lowdin
    orthogonalization, but the SVD route is more robust.

    ``S`` is accepted for signature compatibility and is not used.
    """
    left, _singular, right = np.linalg.svd(U, full_matrices=False)
    U[...] = left @ right
def neighbor_k_search(k_c, G_c, kpt_kc, tol=1e-4):
    """Find the k-point reached from ``k_c`` by the displacement ``G_c``.

    Looks for an index ``k1`` into ``kpt_kc`` and an integer wrapping
    vector ``k0_c`` such that ``kpt_kc[k1] - k_c - G_c + k0_c`` has a norm
    below ``tol``. Wrapping vectors are tried in a fixed order (the zero
    vector first), and for each one all k-points in order.

    Returns:
        ``(k1, k0_c)`` for the first match.

    Raises:
        ValueError: if no match exists (probably a non-uniform grid).
    """
    shifts_dc = np.array(
        [
            [0, 0, 0],
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [1, 1, 0],
            [1, 0, 1],
            [0, 1, 1],
        ],
        dtype=int,
    )
    candidates = (
        (k1, k0_c)
        for k0_c in shifts_dc
        for k1, k1_c in enumerate(kpt_kc)
        if np.linalg.norm(k1_c - k_c - G_c + k0_c) < tol
    )
    match = next(candidates, None)
    if match is None:
        raise ValueError(
            f'Wannier: Did not find matching kpoint for kpt={k_c}. '
            'Probably non-uniform k-point grid'
        )
    return match
def calculate_weights(cell_cc, normalize=True):
    """Directional weights for non-cubic cells, see PRB **61**, 10040.

    From the metric ``g = cell @ cell.T``, six candidate weights are built
    for the directions (100), (010), (001), (110), (101), (011). The three
    lattice directions are always kept, because they are used for the
    Wannier centers. Each mixed direction is kept only when the absolute
    value of its weight exceeds 1e-5.

    With ``normalize=True`` the weights are divided by their largest
    absolute value and thus lose their physical dimension.

    Returns:
        ``(weight_d, Gdir_dc)``: float weights and the matching integer
        direction vectors.
    """
    g = cell_cc @ cell_cc.T
    candidates_dc = np.array(
        [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1], [0, 1, 1]],
        dtype=int,
    )
    all_w = np.array(
        [
            g[0, 0] - g[0, 1] - g[0, 2],
            g[1, 1] - g[0, 1] - g[1, 2],
            g[2, 2] - g[0, 2] - g[1, 2],
            g[0, 1],
            g[0, 2],
            g[1, 2],
        ],
        dtype=float,
    )
    keep = [0, 1, 2] + [d for d in range(3, 6) if abs(all_w[d]) > 1e-5]
    weight_d = all_w[keep]
    Gdir_dc = candidates_dc[keep]
    if normalize:
        weight_d /= max(abs(weight_d))
    return weight_d, Gdir_dc
def steepest_descent(func, step=0.005, tolerance=1e-6, log=silent, **kwargs):
    """Step along ``func.get_gradients()`` until the functional converges.

    ``func`` must provide ``get_gradients()``, ``step(delta, **kwargs)``
    and ``get_functional_value()``. Each iteration calls
    ``func.step(gradient * step, **kwargs)``. The loop stops once the
    relative change of the functional value is at most ``tolerance``.
    Progress is reported through ``log``.
    """
    previous, current = 0.0, 10.0
    iteration = 0
    while abs((current - previous) / current) > tolerance:
        previous = current
        func.step(func.get_gradients() * step, **kwargs)
        current = func.get_functional_value()
        iteration += 1
        log(f'SteepestDescent: iter={iteration}, value={current}')
def md_min(
    func, step=0.25, tolerance=1e-6, max_iter=10000, log=silent, **kwargs
):
    """Optimize ``func`` with a damped-velocity (MD-like) scheme.

    ``func`` must provide ``get_gradients()``, ``step(V, **kwargs)`` and
    ``get_functional_value()``. A velocity ``V`` accumulates ``step``
    times the gradient. Components of ``V`` whose product with the
    gradient has a non-positive real part are zeroed first. Whenever the
    functional value decreases, ``step`` is halved.

    The loop stops when the relative change of the value is at most
    ``tolerance``, or after more than ``max_iter`` iterations (with a
    warning). Progress and timing are reported through ``log``.
    """
    log('Localize with step =', step, 'and tolerance =', tolerance)
    finit = func.get_functional_value()

    # Measure elapsed time from a fixed start.  The previous "t = -time();
    # t += time()" scheme added time() twice on the max_iter path and so
    # logged an absolute timestamp instead of a duration.
    t0 = time()
    fvalueold = 0.0
    fvalue = fvalueold + 10
    count = 0
    V = np.zeros(func.get_gradients().shape, dtype=complex)

    while abs((fvalue - fvalueold) / fvalue) > tolerance:
        fvalueold = fvalue
        dF = func.get_gradients()

        # Drop velocity components that no longer point along the gradient.
        V *= (dF * V.conj()).real > 0
        V += step * dF
        func.step(V, **kwargs)
        fvalue = func.get_functional_value()

        # The value went down: the step is too large.
        if fvalue < fvalueold:
            step *= 0.5
        count += 1
        log(f'MDmin: iter={count}, step={step}, value={fvalue}')
        if count > max_iter:
            warnings.warn(
                'Max iterations reached: '
                'iters=%s, step=%s, seconds=%0.2f, value=%0.4f'
                % (count, step, time() - t0, fvalue.real)
            )
            break

    t = time() - t0
    log(
        '%d iterations in %0.2f seconds (%0.2f ms/iter), endstep = %s'
        % (count, t, t * 1000.0 / count, step)
    )
    log(f'Initial value={finit}, Final value={fvalue}')
def rotation_from_projection(proj_nw, fixed, ortho=True):
    """Determine rotation and coefficient matrices from projections.

    proj_nw = <psi_n|p_w>, with psi_n the eigenstates and p_w localized
    functions.

    Nb (n) = Number of bands
    Nw (w) = Number of wannier functions
    M  (f) = Number of fixed states (``fixed``)
    L  (l) = Number of extra degrees of freedom (Nw - M)
    U  (u) = Number of non-fixed states (Nb - M)

    The first M rows of the rotation matrix are the projections of the
    fixed states. When L > 0, the L eigenvectors of P^dag P with the
    largest eigenvalues (P = projections of the non-fixed states) define
    the coefficient matrix C_ul and the remaining rows of U_ww.

    Returns:
        ``(U_ww, C_ul)``. U_ww is Lowdin-orthonormalized when ``ortho``
        is true and column-normalized otherwise. C_ul has shape (U, L),
        or (U, 0) when there are no extra degrees of freedom.
    """
    Nb, Nw = proj_nw.shape
    n_fixed = fixed
    n_extra = Nw - n_fixed
    n_free = Nb - n_fixed
    dtype = proj_nw.dtype

    U_ww = np.empty((Nw, Nw), dtype=dtype)
    # Rows belonging to the fixed states come straight from the projections.
    U_ww[:n_fixed] = proj_nw[:n_fixed]

    if n_extra <= 0:
        # No extra degrees of freedom: no coefficient matrix is needed.
        C_ul = np.empty((n_free, 0), dtype=dtype)
    else:
        proj_uw = proj_nw[n_fixed:]
        eig_w, C_ww = np.linalg.eigh(dag(proj_uw) @ proj_uw)
        # Keep the eigenvectors with the L largest eigenvalues.
        largest = np.argsort(-eig_w.real)[:n_extra]
        C_ul = np.empty((n_free, n_extra), dtype=dtype)
        C_ul[:] = proj_uw @ C_ww[:, largest]
        # U_ww rows are built from the not-yet-normalized C_ul.
        U_ww[n_fixed:] = dag(C_ul) @ proj_uw
        normalize(C_ul)

    if ortho:
        lowdin(U_ww)
    else:
        normalize(U_ww)
    return U_ww, C_ul
def search_for_gamma_point(kpts):
    """Return the index of the Gamma point in ``kpts``, or None.

    The k-point with the smallest norm is taken. It counts as Gamma only
    if that norm is below 1e-14.
    """
    norms = [np.linalg.norm(kpt) for kpt in kpts]
    idx = np.argmin(norms)
    if norms[idx] >= 1e-14:
        return None
    return idx
def scdm(pseudo_nkG, kpts, fixed_k, Nw):
    """Compute localized orbitals with the SCDM method.

    This method was published by Anil Damle and Lin Lin in Multiscale
    Modeling & Simulation 16, 1392-1410 (2018).
    For now only the isolated-bands algorithm is implemented, because it
    is intended as a drop-in replacement for other initial-guess methods
    of the ASE Wannier class.

    pseudo_nkG = pseudo wave-functions on a real grid
    Ng (G) = number of real grid points
    kpts = list of k-points in the BZ
    Nk (k) = number of k-points
    Nb (n) = number of bands
    Nw (w) = number of wannier functions
    fixed_k = number of fixed states for each k-point

    A column-pivoted QR factorization at the Gamma point selects Nw grid
    points. At every k-point, the wave-function values at those points
    are turned into rotation matrices by ``rotation_from_projection``.

    Returns:
        ``(C_kul, U_kww)``: a list of coefficient matrices and an array of
        rotation matrices, one per k-point.

    Raises:
        ValueError: if ``kpts`` does not contain the Gamma point.
    """
    gamma_idx = search_for_gamma_point(kpts)
    if gamma_idx is None:
        # Without this check pseudo_nkG[:, None, :] would silently insert a
        # new axis and QR would fail with an unrelated shape error.
        raise ValueError(
            'SCDM requires the Gamma point to be part of the k-point list.'
        )
    Nk = len(kpts)

    # Compute the factorization only at the Gamma point.
    _, _, P = qr(
        pseudo_nkG[:, gamma_idx, :],
        mode='full',
        pivoting=True,
        check_finite=True,
    )
    selected = P[:Nw]

    U_kww = []
    C_kul = []
    for k in range(Nk):
        U_ww, C_ul = rotation_from_projection(
            proj_nw=pseudo_nkG[:, k, selected], fixed=fixed_k[k], ortho=True
        )
        U_kww.append(U_ww)
        C_kul.append(C_ul)

    return C_kul, np.asarray(U_kww)
def arbitrary_s_orbitals(atoms, Ns, rng=np.random):
    """Place ``Ns`` s-orbitals at random positions near the atoms.

    Scaled positions are drawn with ``rng.rand(3)``. A position is
    accepted once it lies closer than 1.5 Å to at least one atom; the
    distance is measured through a dummy H atom appended to a copy of
    ``atoms``. Each orbital is returned as ``[[x, y, z], 0, 1]``, the
    format GPAW's ``initial_wannier()`` expects.
    """
    probe = atoms.copy()
    probe.append('H')
    scaled = probe.get_scaled_positions()

    orbs = []
    for _ in range(Ns):
        while True:
            x, y, z = rng.rand(3)
            scaled[-1] = [x, y, z]
            probe.set_scaled_positions(scaled)
            dists = probe.get_distances(a=-1, indices=range(len(atoms)))
            if (dists < 1.5).any():
                break
        orbs.append([[x, y, z], 0, 1])
    return orbs
def init_orbitals(atoms, ntot, rng=np.random):
    """Build a GPAW-style orbital list with exactly ``ntot`` orbitals.

    A set of d-orbitals ``[i, 2, 1]`` (5 orbitals) is added for each atom
    whose atomic number lies in the listed d-metal ranges, as long as the
    total stays at most ``ntot``. Any remainder is filled with random
    s-orbitals near the atoms (see ``arbitrary_s_orbitals``).

    'atoms': ASE Atoms object
    'ntot': total number of needed orbitals
    'rng': random number generator, passed on to arbitrary_s_orbitals
    """
    # Elements that have occupied d-orbitals in the valence states
    # (according to GPAW setups).
    d_metals = {
        *range(21, 31),
        *range(39, 52),
        *range(57, 84),
        *range(89, 113),
    }

    orbs = []
    n_orb = 0
    for index, number in enumerate(atoms.get_atomic_numbers()):
        if number in d_metals and n_orb + 5 <= ntot:
            orbs.append([index, 2, 1])
            n_orb += 5

    if n_orb < ntot:
        orbs += arbitrary_s_orbitals(atoms, ntot - n_orb, rng)

    assert sum(orb[1] * 2 + 1 for orb in orbs) == ntot
    return orbs
def square_modulus_of_Z_diagonal(Z_dww):
    """Return ``|Z_dww[d, w, w]|**2`` as an array indexed ``[d, w]``.

    The diagonal is taken over the two Wannier-function axes (1 and 2).
    """
    diag_dw = np.diagonal(Z_dww, offset=0, axis1=1, axis2=2)
    return np.abs(diag_dw) ** 2
def get_kklst(kpt_kc, Gdir_dc):
    """Neighbouring k-point indices and wrapping vectors per direction.

    For each direction ``d`` and k-point ``k``, finds ``k1`` and the
    integer vector ``k0`` such that ``k1 - k - G + k0 = 0``. Here ``G`` is
    the k-point spacing along the components selected by ``Gdir_dc[d]``.

    Example: kpoints = (-0.375, -0.125, 0.125, 0.375), dir=0,
    G = [0.25, 0, 0]; k=0.375, k1=-0.375: -0.375-0.375-0.25 => k0=[1,0,0].

    For a Gamma-point calculation k1 = k = 0 and k0 = Gdir (e.g. [1,0,0]
    for dir=0). The same holds for any direction in which all k-points
    share the same coordinate.

    Returns:
        ``(kklst_dk, k0_dkc)`` with shapes (Ndir, Nk) and (Ndir, Nk, 3).
    """
    Nk = len(kpt_kc)
    Ndir = len(Gdir_dc)

    if Nk == 1:
        return np.zeros((Ndir, 1), int), Gdir_dc.reshape(-1, 1, 3)

    kklst_dk = np.empty((Ndir, Nk), int)
    k0_dkc = np.empty((Ndir, Nk, 3), int)

    # Largest gap between consecutive (sorted) k-point coordinates.
    kdist_c = np.empty(3)
    for c in range(3):
        sorted_c = np.sort(kpt_kc[:, c], kind='mergesort')
        kdist_c[c] = max(np.diff(sorted_c))

    for d, Gdir_c in enumerate(Gdir_dc):
        G_c = np.where(Gdir_c > 0, kdist_c, 0)
        if max(G_c) < 1e-4:
            # No spacing along this direction: each k-point is its own
            # neighbour, wrapped by the direction vector itself.
            kklst_dk[d] = np.arange(Nk)
            k0_dkc[d] = Gdir_c
            continue
        for k, k_c in enumerate(kpt_kc):
            kklst_dk[d, k], k0_dkc[d, k] = neighbor_k_search(
                k_c, G_c, kpt_kc
            )
    return kklst_dk, k0_dkc
def get_invkklst(kklst_dk):
    """Invert the neighbour lists of ``get_kklst`` row by row.

    ``invkklst_dk[d, k1]`` is the position ``k`` at which ``k1`` occurs
    in ``kklst_dk[d]`` (first occurrence).

    Raises:
        ValueError: if some index ``k1`` in ``range(Nk)`` is missing from
        a row, i.e. the row is not a permutation.
    """
    Ndir, Nk = kklst_dk.shape
    invkklst_dk = np.empty(kklst_dk.shape, int)
    for d in range(Ndir):
        # One pass per row builds the lookup. The old version called
        # .tolist().index() for every k1, which is O(Nk**2) per direction.
        first_pos = {}
        for k, k1 in enumerate(kklst_dk[d].tolist()):
            first_pos.setdefault(k1, k)
        for k1 in range(Nk):
            try:
                invkklst_dk[d, k1] = first_pos[k1]
            except KeyError:
                raise ValueError(f'{k1!r} is not in list') from None
    return invkklst_dk
def _count_states_below(calcdata, spin, Nk, energy):
    """Per k-point ``searchsorted`` position of ``energy`` in the eigenvalues.

    For ascending eigenvalues this is the number of states below
    ``energy`` at each of the first ``Nk`` k-points of ``spin``.
    """
    return np.array(
        [calcdata.eps_skn[spin, k].searchsorted(energy) for k in range(Nk)],
        int,
    )


def choose_states(calcdata, fixedenergy, fixedstates, Nk, nwannier, log, spin):
    """Choose the number of fixed states per k-point and ``nwannier``.

    * ``fixedstates`` given: used directly. A single number (Python or
      NumPy scalar) applies to every k-point.
    * ``fixedenergy`` given: all states below the cutoff are fixed. The
      reference energy is the Fermi level when the gap is below 0.01 or
      ``fixedenergy`` is below 0.01, and the LUMO otherwise.
    * ``nwannier == 'auto'``: nwannier becomes the maximum number of
      fixed states; without any fixed-state input the Fermi level is used
      as the cutoff.
    * Neither given: ``nwannier`` fixed states at every k-point, no EDF.

    Returns:
        ``(fixedstates_k, nwannier)``.

    Raises:
        RuntimeError: if both ``fixedenergy`` and ``fixedstates`` are set.
    """
    if fixedenergy is not None and fixedstates is not None:
        raise RuntimeError('You can not set both fixedenergy and fixedstates')

    if fixedstates is not None:
        # np.ndim also catches NumPy integer scalars.  The former
        # isinstance(..., int) test let them through as 0-d arrays, which
        # break later per-k-point iteration.
        if np.ndim(fixedstates) == 0:
            fixedstates = [fixedstates] * Nk
        fixedstates_k = np.array(fixedstates, int)
    elif fixedenergy is not None:
        if calcdata.gap < 0.01 or fixedenergy < 0.01:
            cutoff = fixedenergy + calcdata.fermi_level
        else:
            cutoff = fixedenergy + calcdata.lumo
        fixedstates_k = _count_states_below(calcdata, spin, Nk, cutoff)

    if nwannier == 'auto':
        if fixedenergy is None and fixedstates is None:
            # Assume fixedenergy = 0 and find the states below the Fermi
            # level at each k-point.
            log(
                "nwannier=auto but no 'fixedenergy' or 'fixedstates'",
                'parameter was provided, using Fermi level as',
                'energy cutoff.',
            )
            fixedstates_k = _count_states_below(
                calcdata, spin, Nk, calcdata.fermi_level
            )
        nwannier = np.max(fixedstates_k)

    # Without user choice just set nwannier fixed states without EDF.
    # NOTE(review): in 'auto' mode this also replaces the per-k counts
    # computed above; confirm that this is intended.
    if fixedstates is None and fixedenergy is None:
        fixedstates_k = np.array([nwannier] * Nk, int)

    return fixedstates_k, nwannier
def get_eigenvalues(calc):
    """Collect the eigenvalues of ``calc`` into a float array ``[s, k, n]``.

    The k-point count is taken from ``calc.get_ibz_k_points()`` and the
    band count from ``calc.get_number_of_bands()``.
    """
    shape = (
        calc.get_number_of_spins(),
        len(calc.get_ibz_k_points()),
        calc.get_number_of_bands(),
    )
    eps_skn = np.empty(shape)
    for s, k in np.ndindex(shape[0], shape[1]):
        eps_skn[s, k] = calc.get_eigenvalues(kpt=k, spin=s)
    return eps_skn
class CalcData:
    """Plain container for the calculator quantities used by Wannier."""

    def __init__(self, kpt_kc, atoms, fermi_level, lumo, eps_skn, gap):
        # Geometry and sampling.
        self.atoms = atoms
        self.kpt_kc = kpt_kc
        # Energies: eigenvalues indexed [spin, k, band] plus reference levels.
        self.eps_skn = eps_skn
        self.fermi_level = fermi_level
        self.lumo = lumo
        self.gap = gap

    @property
    def nbands(self):
        """Number of bands, read from axis 2 of ``eps_skn``."""
        shape = self.eps_skn.shape
        return shape[2]
def get_calcdata(calc):
    """Gather k-points, eigenvalues and reference energies from ``calc``.

    Returns:
        A ``CalcData`` holding a private copy of the BZ k-points, the
        atoms, Fermi level, LUMO, all eigenvalues and the band gap.

    Raises:
        RuntimeError: if the calculator used k-point symmetry (the IBZ
        and BZ k-point counts differ).
    """
    # Copy the k-points: Wannier flips their sign, and the calculator may
    # hand out a reference to its own internal array.
    kpt_kc = np.array(calc.get_bz_k_points(), float)
    # Make sure there is no symmetry reduction.
    if len(calc.get_ibz_k_points()) != len(kpt_kc):
        raise RuntimeError(
            'K-point symmetry is not currently supported. '
            "Please re-run your calculator with symmetry='off'."
        )

    lumo = calc.get_homo_lumo()[1]
    gap = bandgap(calc=calc)[0]
    return CalcData(
        kpt_kc=kpt_kc,
        atoms=calc.get_atoms(),
        fermi_level=calc.get_fermi_level(),
        lumo=lumo,
        eps_skn=get_eigenvalues(calc),
        gap=gap,
    )
492class Wannier:
493 """Partly occupied Wannier functions
495 Find the set of partly occupied Wannier functions according to
496 Thygesen, Hansen and Jacobsen PRB v72 i12 p125119 2005.
497 """
def __init__(
    self,
    nwannier,
    calc,
    file=None,
    nbands=None,
    fixedenergy=None,
    fixedstates=None,
    spin=0,
    initialwannier='orbitals',
    functional='std',
    rng=np.random,
    log=silent,
):
    """
    Required arguments:

      ``nwannier``: The number of Wannier functions you wish to construct.
        This must be at least half the number of electrons in the system
        and at most equal to the number of bands in the calculation.
        It can also be set to 'auto' in order to automatically choose the
        minimum number of needed Wannier functions based on the
        ``fixedenergy`` / ``fixedstates`` parameter.

      ``calc``: A converged DFT calculator class.
        If ``file`` arg. is not provided, the calculator *must* provide the
        method ``get_wannier_localization_matrix``, and contain the
        wavefunctions (save files with only the density is not enough).
        If the localization matrix is read from file, this is not needed,
        unless ``get_function`` or ``write_cube`` is called.

    Optional arguments:

      ``nbands``: Bands to include in localization.
        The number of bands considered by Wannier can be smaller than the
        number of bands in the calculator. This is useful if the highest
        bands of the DFT calculation are not well converged.

      ``spin``: The spin channel to be considered.
        The Wannier code treats each spin channel independently.

      ``fixedenergy`` / ``fixedstates``: Fixed part of Hilbert space.
        Determine the fixed part of Hilbert space by either a maximal
        energy *or* a number of bands (possibly a list for multiple
        k-points).
        Default is None meaning that the number of fixed states is equated
        to ``nwannier``.
        The maximal energy is relative to the CBM if there is a finite
        bandgap or to the Fermi level if there is none.

      ``file``: Read localization and rotation matrices from this file.

      ``initialwannier``: Initial guess for Wannier rotation matrix.
        Can be 'bloch' to start from the Bloch states, 'random' to be
        randomized, 'orbitals' to start from atom-centered d-orbitals and
        randomly placed gaussian centers (see init_orbitals()),
        'scdm' to start from localized state selected with SCDM
        or a list passed to calc.get_initial_wannier.

      ``functional``: The functional used to measure the localization.
        Can be 'std' for the standard quadratic functional from the PRB
        paper, 'var' to add a variance minimizing term.

      ``rng``: Random number generator for ``initialwannier``. It is
        stored as ``self.rng``.

      ``log``: Function which logs, such as print().
    """
    # Bloch phase sign convention.
    # May require special cases depending on which code is used.
    sign = -1

    self.log = log
    self.calc = calc

    self.spin = spin
    self.functional = functional
    self.initialwannier = initialwannier
    # Stored because get_optimal_nwannier() reads self.rng; previously it
    # was never set, which made that method fail with AttributeError.
    self.rng = rng
    self.log('Using functional:', functional)

    self.calcdata = get_calcdata(calc)

    self.kptgrid = get_monkhorst_pack_size_and_offset(self.kpt_kc)[0]
    # Rebind instead of multiplying in place.  The array may be shared with
    # the calculator, and an in-place flip would alter its data and flip it
    # again for every new Wannier instance built from the same calc.
    self.calcdata.kpt_kc = sign * np.asarray(self.calcdata.kpt_kc)

    self.largeunitcell_cc = (self.unitcell_cc.T * self.kptgrid).T
    self.weight_d, self.Gdir_dc = calculate_weights(self.largeunitcell_cc)
    assert len(self.weight_d) == len(self.Gdir_dc)

    if nbands is None:
        # XXX Can work with other number of bands than calculator.
        # Is this case properly tested, lest we confuse them?
        nbands = self.calcdata.nbands
    self.nbands = nbands

    self.fixedstates_k, self.nwannier = choose_states(
        self.calcdata,
        fixedenergy,
        fixedstates,
        self.Nk,
        nwannier,
        log,
        spin,
    )

    # Compute the number of extra degrees of freedom (EDF).
    self.edf_k = self.nwannier - self.fixedstates_k

    self.log(f'Wannier: Fixed states : {self.fixedstates_k}')
    self.log(f'Wannier: Extra degrees of freedom: {self.edf_k}')

    self.kklst_dk, k0_dkc = get_kklst(self.kpt_kc, self.Gdir_dc)

    # Set the inverse list of neighboring k-points.
    self.invkklst_dk = get_invkklst(self.kklst_dk)

    Nw = self.nwannier
    Nb = self.nbands
    self.Z_dkww = np.empty((self.Ndir, self.Nk, Nw, Nw), complex)
    self.V_knw = np.zeros((self.Nk, Nb, Nw), complex)

    if file is None:
        self.Z_dknn = self.new_Z(calc, k0_dkc)
    self.initialize(file=file, initialwannier=initialwannier, rng=rng)
@property
def atoms(self):
    """Atoms object taken from the stored calculator data."""
    data = self.calcdata
    return data.atoms

@property
def kpt_kc(self):
    """BZ k-points held in the stored calculator data."""
    data = self.calcdata
    return data.kpt_kc

@property
def Ndir(self):
    """Number of directions, i.e. the number of weights in ``weight_d``."""
    directions = self.weight_d
    return len(directions)

@property
def Nk(self):
    """Number of k-points."""
    points = self.kpt_kc
    return len(points)
def new_Z(self, calc, k0_dkc):
    """Ask ``calc`` for the Bloch-basis localization matrices.

    For every direction ``d`` and k-point ``k``, the matrix is computed
    between ``k`` and its neighbour ``kklst_dk[d, k]`` with wrapping
    vector ``k0_dkc[d, k]``.

    Returns:
        Complex array of shape ``(Ndir, Nk, nbands, nbands)``.
    """
    Nb = self.nbands
    Z_dknn = np.empty((self.Ndir, self.Nk, Nb, Nb), complex)
    for d, dirG in enumerate(self.Gdir_dc):
        neighbours_k = self.kklst_dk[d]
        for k in range(self.Nk):
            Z_dknn[d, k] = calc.get_wannier_localization_matrix(
                nbands=Nb,
                dirG=dirG,
                kpoint=k,
                nextkpoint=neighbours_k[k],
                G_I=k0_dkc[d, k],
                spin=self.spin,
            )
    return Z_dknn
@property
def unitcell_cc(self):
    """Cell of the (small) unit cell, taken from ``self.atoms``."""
    structure = self.atoms
    return structure.cell

@property
def U_kww(self):
    """Rotation matrices of the current Wannier state."""
    state = self.wannier_state
    return state.U_kww

@property
def C_kul(self):
    """Coefficient matrices of the current Wannier state."""
    state = self.wannier_state
    return state.C_kul
def initialize(self, file=None, initialwannier='random', rng=np.random):
    """Re-initialize the current rotation matrix.

    Keywords are identical to those of the constructor. When ``file`` is
    given, ``Z_dknn``, ``U_kww`` and ``C_kul`` are read from it and
    ``initialwannier`` is ignored. Otherwise ``initialwannier`` selects
    'bloch', 'random', 'orbitals' or 'scdm'; any other value is passed on
    to ``WannierSpec.initial_wannier``. Ends by calling ``update()``.
    """
    from ase.dft.wannierstate import WannierSpec, WannierState

    spec = WannierSpec(
        self.Nk, self.nwannier, self.nbands, self.fixedstates_k
    )

    if file is not None:
        with paropen(file, 'r') as fd:
            Z_dknn, U_kww, C_kul = read_json(fd, always_array=False)
        self.Z_dknn = Z_dknn
        state = WannierState(C_kul, U_kww)
    elif initialwannier == 'bloch':
        # U and C pick the lowest Bloch states.
        state = spec.bloch(self.edf_k)
    elif initialwannier == 'random':
        state = spec.random(rng, self.edf_k)
    elif initialwannier == 'orbitals':
        orbitals = init_orbitals(self.atoms, self.nwannier, rng)
        state = spec.initial_orbitals(
            self.calc, orbitals, self.kptgrid, self.edf_k, self.spin
        )
    elif initialwannier == 'scdm':
        state = spec.scdm(self.calc, self.kpt_kc, self.spin)
    else:
        state = spec.initial_wannier(
            self.calc, initialwannier, self.kptgrid, self.edf_k, self.spin
        )

    self.wannier_state = state
    self.update()
def save(self, file):
    """Write the localization and rotation matrices to ``file`` as JSON.

    The stored tuple is ``(Z_dknn, U_kww, C_kul)``, the layout that
    ``initialize(file=...)`` unpacks.
    """
    with paropen(file, 'w') as fd:
        matrices = (self.Z_dknn, self.U_kww, self.C_kul)
        write_json(fd, matrices)
def update(self):
    """Rebuild ``V_knw`` from U/C and refresh the rotated Z matrices.

    At each k-point the first ``M = fixedstates_k[k]`` rows of V come
    directly from U. When ``M < nwannier`` the remaining rows are
    ``C_kul[k] @ U_kww[k, M:]``; otherwise they are left untouched. Then

        Z_dkww[d, k] = V[k]^dagger @ (Z_dknn[d, k] @ V[kklst_dk[d, k]])

    and ``Z_dww`` is the average of ``Z_dkww`` over k-points.
    """
    Nw = self.nwannier
    for k, n_fixed in enumerate(self.fixedstates_k):
        U_ww = self.U_kww[k]
        self.V_knw[k, :n_fixed] = U_ww[:n_fixed]
        if n_fixed < Nw:
            self.V_knw[k, n_fixed:] = self.C_kul[k] @ U_ww[n_fixed:]

    for d in range(self.Ndir):
        for k in range(self.Nk):
            V_nw = self.V_knw[k]
            V1_nw = self.V_knw[self.kklst_dk[d, k]]
            self.Z_dkww[d, k] = V_nw.T.conj() @ (self.Z_dknn[d, k] @ V1_nw)

    self.Z_dww = self.Z_dkww.sum(axis=1) / self.Nk
def get_optimal_nwannier(self, nwrange=5, random_reps=5, tolerance=1e-6):
    """
    The optimal value for 'nwannier', maybe.

    The optimal value is the one that gives the lowest average value for
    the spread of the most delocalized Wannier function in the set.

    ``nwrange``: number of different values to try for 'nwannier', the
    values will span a symmetric range around ``nwannier`` if possible.

    ``random_reps``: number of repetitions with random seed, the value is
    then an average over these repetitions. Only used when the initial
    guess is 'random' or 'orbitals'.

    ``tolerance``: tolerance for the gradient descent algorithm, can be
    useful to increase the speed, with a cost in accuracy.

    Raises ``ValueError`` if no value of nwannier fits between the maximum
    number of fixed states and ``nbands``.
    """
    # The constructor stores self.rng.  Fall back to numpy's global
    # generator for instances built without it (the old code accessed
    # self.rng unconditionally and failed with AttributeError).
    rng = getattr(self, 'rng', np.random)

    # Define the range of values to try based on the maximum number of fixed
    # states (that is the minimum number of WFs we need) and the number of
    # available bands we have.
    max_number_fixedstates = np.max(self.fixedstates_k)

    min_range_value = max(
        self.nwannier - int(np.floor(nwrange / 2)), max_number_fixedstates
    )
    max_range_value = min(min_range_value + nwrange, self.nbands + 1)
    Nws = np.arange(min_range_value, max_range_value)
    if len(Nws) == 0:
        raise ValueError(
            'No value of nwannier to try: the range '
            f'[{min_range_value}, {max_range_value}) is empty.'
        )

    # If there is no randomness, there is no need to repeat.
    random_initials = ['random', 'orbitals']
    if self.initialwannier not in random_initials:
        random_reps = 1

    t = -time()
    avg_max_spreads = np.zeros(len(Nws))
    for j, Nw in enumerate(Nws):
        self.log('Trying with Nw =', Nw)

        # Define once with the fastest 'initialwannier',
        # then initialize with random seeds in the for loop.
        wan = Wannier(
            nwannier=int(Nw),
            calc=self.calc,
            nbands=self.nbands,
            spin=self.spin,
            functional=self.functional,
            initialwannier='bloch',
            log=self.log,
            rng=rng,
        )
        wan.fixedstates_k = self.fixedstates_k
        wan.edf_k = wan.nwannier - wan.fixedstates_k

        max_spreads = np.zeros(random_reps)
        for i in range(random_reps):
            wan.initialize(initialwannier=self.initialwannier, rng=rng)
            wan.localize(tolerance=tolerance)
            max_spreads[i] = np.max(wan.get_spreads())

        avg_max_spreads[j] = max_spreads.mean()

    self.log('Average spreads: ', avg_max_spreads)
    t += time()
    self.log(f'Execution time: {t:.1f}s')

    return Nws[np.argmin(avg_max_spreads)]
def get_centers(self, scaled=False):
    """Calculate the Wannier centers.

    ::

      pos = L / 2pi * phase(diag(Z))

    The phases of the diagonal of the first three ``Z_dww`` matrices give
    fractional coordinates in [0, 1). With ``scaled=False`` these are
    converted to Cartesian positions in the large (k-point-repeated) cell.
    """
    phase_wd = np.angle(np.diagonal(self.Z_dww[:3], 0, 1, 2)).T
    coord_wc = phase_wd / (2 * pi) % 1
    return coord_wc if scaled else coord_wc @ self.largeunitcell_cc
def get_radii(self):
    r"""Calculate the spread of the Wannier functions.

    ::

                    --  /  L  \ 2       2
      radius**2 = - >   | --- |   ln |Z|
                    --d \ 2pi /

    Only the diagonal of the large unit cell enters, so this can fail for
    some Bravais lattices; see `get_spreads()` for a more robust
    alternative.
    """
    L2_d = self.largeunitcell_cc.diagonal() ** 2 / (2 * pi) ** 2
    Z2_dw = abs(self.Z_dww[:3].diagonal(0, 1, 2)) ** 2
    return np.sqrt(-L2_d @ np.log(Z2_dw))
def get_spreads(self):
    r"""Calculate the spread of the Wannier functions in Ų.

    The definition is based on eq. 13 in PRB61-15 by Berghold and Mundy.

    ::

                 / 1  \ 2  --                2
      spread = - |----|    >   W_d  ln |Z_dw|
                 \2 pi/    --d

    The weights are computed without normalization so that the result
    keeps its physical dimension.
    """
    weight_d = calculate_weights(self.largeunitcell_cc, normalize=False)[0]
    log_Z2_wd = np.log(square_modulus_of_Z_diagonal(self.Z_dww)).T
    return -(log_Z2_wd @ weight_d).real / (2 * np.pi) ** 2
def get_spectral_weight(self, w):
    """Return ``|V_knw[k, n, w]|**2 / Nk`` as an array indexed ``[k, n]``."""
    amplitudes_kn = self.V_knw[:, :, w]
    return np.abs(amplitudes_kn) ** 2 / self.Nk
def get_pdos(self, w, energies, width):
    """Projected density of states (PDOS) of Wannier function ``w``.

    Every eigenvalue contributes a normalized Gaussian of the given
    ``width``, centred at that eigenvalue. Each Gaussian is weighted by
    the spectral weight of ``w`` on that state and evaluated on
    ``energies``. The Gaussian exponent is clipped at 40 to avoid
    underflow.
    """
    spectral_kn = self.get_spectral_weight(w)
    norm = sqrt(pi) * width
    dos = np.zeros(len(energies))
    for k, weights_n in enumerate(spectral_kn):
        eps_n = self.calcdata.eps_skn[self.spin, k]
        for weight, eps in zip(weights_n, eps_n):
            x = ((energies - eps) / width) ** 2
            dos += weight * np.exp(-x.clip(0.0, 40.0)) / norm
    return dos
def translate(self, w, R):
    """Translate the w'th Wannier function by the lattice vector ``R``.

    ``R = [n1, n2, n3]`` is in units of the basis vectors of the small
    cell. Column ``w`` of every ``U_kww[k]`` is multiplied in place by
    ``exp(2 pi i R.k)``, after which ``update()`` is called.
    """
    R_c = np.array(R)
    for kpt_c, U_ww in zip(self.kpt_kc, self.U_kww):
        phase = np.exp(2.0j * pi * (R_c @ kpt_c))
        U_ww[:, w] *= phase
    self.update()
def translate_to_cell(self, w, cell):
    """Translate the w'th Wannier function to the specified cell.

    The current cell is estimated from the phases of ``Z_dww[:3, w, w]``
    and the k-point grid. The function is then shifted by the difference
    to ``cell`` via ``translate``.
    """
    phase_c = np.angle(self.Z_dww[:3, w, w])
    current_cell_c = np.floor(phase_c * self.kptgrid / (2 * pi))
    self.translate(w, np.array(cell) - current_cell_c)
def translate_all_to_cell(self, cell=(0, 0, 0)):
    r"""Translate all Wannier functions to specified cell.

    Move all Wannier orbitals to a specific unit cell. There
    exists an arbitrariness in the positions of the Wannier
    orbitals relative to the unit cell. This method can move all
    orbitals to the unit cell specified by ``cell``. For a
    `\Gamma`-point calculation, this has no effect. For a
    **k**-point calculation the periodicity of the orbitals are
    given by the large unit cell defined by repeating the original
    unitcell by the number of **k**-points in each direction. In
    this case it is useful to move the orbitals away from the
    boundaries of the large cell before plotting them. For a bulk
    calculation with, say 10x10x10 **k** points, one could move
    the orbitals to the cell [2,2,2]. In this way the pbc
    boundary conditions will not be noticed.
    """
    phase_wd = np.angle(np.diagonal(self.Z_dww[:3], 0, 1, 2)).T
    scaled_wc = phase_wd * self.kptgrid / (2 * pi)
    shift_wc = np.array(cell)[None] - np.floor(scaled_wc)
    for kpt_c, U_ww in zip(self.kpt_kc, self.U_kww):
        # One phase per Wannier function, broadcast over the columns.
        U_ww *= np.exp(2.0j * pi * (shift_wc @ kpt_c))
    self.update()
def distances(self, R):
    """Relative distances between centers.

    Returns a matrix with the distances between different Wannier centers.
    Entry ``[i, j]`` is the distance from center ``i`` to center ``j``
    after ``j`` has been shifted by ``R = [n1, n2, n3]``. R is in units of
    the basis vectors of the small cell, so centers can be compared with
    their images in a different small cell.
    The dimension of the matrix is [Nw, Nw].
    """
    centers_wc = self.get_centers()
    shifted_wc = centers_wc.copy()
    for c in range(3):
        shifted_wc += self.unitcell_cc[c] * R[c]
    diff_wwc = centers_wc[:, None, :] - shifted_wc[None, :, :]
    return np.sqrt(np.sum(diff_wwc ** 2, axis=-1))
def _get_hopping(self, n1, n2, n3):
    """Returns the matrix H(R)_nm=<0,n|H|R,m>.

    ::

                          1   _   -ik.R
      H(R) = <0,n|H|R,m> = --- >_  e      H(k)
                          Nk  k

    where R = (n1, n2, n3) is the cell-distance (in units of the basis
    vectors of the small cell) and n,m are indices of the Wannier functions.

    Results are cached per instance, for up to 10000 distinct R. The cache
    is discarded whenever ``V_knw`` changes, because H(k) is built from
    ``V_knw`` (see ``get_hamiltonian``).
    """
    # The previous functools.lru_cache on this method was keyed only on
    # (self, n1, n2, n3).  It returned stale matrices after localize() or
    # translate() changed the rotation matrices, and it kept every
    # instance alive for the life of the cache.
    cache = getattr(self, '_hopping_cache', None)
    if cache is None or not np.array_equal(cache['V_knw'], self.V_knw):
        cache = {'V_knw': self.V_knw.copy(), 'H': {}}
        self._hopping_cache = cache

    key = (n1, n2, n3)
    H_ww = cache['H'].get(key)
    if H_ww is not None:
        return H_ww

    R = np.array([n1, n2, n3], float)
    H_ww = np.zeros([self.nwannier, self.nwannier], complex)
    for k, kpt_c in enumerate(self.kpt_kc):
        phase = np.exp(-2.0j * pi * (R @ kpt_c))
        H_ww += self.get_hamiltonian(k) * phase
    H_ww = H_ww / self.Nk

    # Keep the cache bounded, as the old maxsize=10000 did.
    if len(cache['H']) >= 10000:
        cache['H'].clear()
    cache['H'][key] = H_ww
    return H_ww
def get_hopping(self, R):
    """Returns the matrix H(R)_nm=<0,n|H|R,m>.

    ::

                          1   _   -ik.R
      H(R) = <0,n|H|R,m> = --- >_  e      H(k)
                          Nk  k

    where R is the cell-distance (in units of the basis vectors of
    the small cell) and n,m are indices of the Wannier functions.
    Only the first three entries of ``R`` are used.
    """
    n1, n2, n3 = R[0], R[1], R[2]
    return self._get_hopping(n1, n2, n3)
def get_hamiltonian(self, k):
    """Get the Hamiltonian at the existing k-point with index ``k``.

    ::

                    dag
      H(k) = V  diag(eps )  V
              k         k    k

    Only the first ``nbands`` eigenvalues of ``self.spin`` are used.
    """
    eps_n = self.calcdata.eps_skn[self.spin, k, : self.nbands]
    V_nw = self.V_knw[k]
    V_wn = V_nw.T.conj()  # Hermitian conjugate
    return (V_wn * eps_n) @ V_nw
def get_hamiltonian_kpoint(self, kpt_c):
    """Get the Hamiltonian at some new arbitrary k-vector.

    ::

              _   ik.R
      H(k) = >_  e     H(R)
              R

    R runs over ``-(N - 1) // 2 .. (N - 1) // 2`` along each direction,
    where N is the corresponding entry of ``kptgrid``.

    Warning: This method moves all Wannier functions to cell (0, 0, 0).
    """
    self.log('Translating all Wannier functions to cell (0, 0, 0)')
    self.translate_all_to_cell()
    # Named explicitly instead of binding the builtin name `max`.
    N1, N2, N3 = (self.kptgrid - 1) // 2
    Hk = np.zeros([self.nwannier, self.nwannier], complex)
    for n1 in range(-N1, N1 + 1):
        for n2 in range(-N2, N2 + 1):
            for n3 in range(-N3, N3 + 1):
                R = np.array([n1, n2, n3], float)
                hop_ww = self.get_hopping(R)
                phase = np.exp(+2.0j * pi * (R @ kpt_c))
                Hk += hop_ww * phase
    return Hk
def get_function(self, index, repeat=None):
    r"""Get Wannier function on grid.

    Returns an array with the function values of the indicated Wannier
    function on a grid with the size of the *repeated* unit cell.

    For a calculation using **k**-points the relevant unit cell for
    eg. visualization of the Wannier orbitals is not the original unit
    cell, but rather a larger unit cell defined by repeating the
    original unit cell by the number of **k**-points in each direction.
    Note that for a `\Gamma`-point calculation the large unit cell
    coincides with the original unit cell.
    The large unitcell also defines the periodicity of the Wannier
    orbitals.

    ``index`` can be either a single WF (a Python or NumPy integer) or a
    coordinate vector in terms of the WFs.
    """
    # Default size of plotting cell is the one corresponding to k-points.
    if repeat is None:
        repeat = self.kptgrid
    N1, N2, N3 = repeat

    dim = self.calc.get_number_of_grid_points()
    largedim = dim * [N1, N2, N3]

    # NumPy integers are accepted as a single index as well.  Previously
    # only Python ints were, and np.int64 went into the matrix-product
    # branch, where `V @ scalar` fails.
    single = isinstance(index, (int, np.integer))

    wanniergrid = np.zeros(largedim, dtype=complex)
    for k, kpt_c in enumerate(self.kpt_kc):
        # The coordinate vector of wannier functions.
        if single:
            vec_n = self.V_knw[k, :, index]
        else:
            vec_n = self.V_knw[k] @ index

        wan_G = np.zeros(dim, complex)
        for n, coeff in enumerate(vec_n):
            wan_G += coeff * self.calc.get_pseudo_wave_function(
                n, k, self.spin, pad=True
            )

        # Distribute the small wavefunction over the large cell.
        for n1 in range(N1):
            for n2 in range(N2):
                for n3 in range(N3):  # sign?
                    e = np.exp(-2.0j * pi * np.array([n1, n2, n3]) @ kpt_c)
                    wanniergrid[
                        n1 * dim[0] : (n1 + 1) * dim[0],
                        n2 * dim[1] : (n2 + 1) * dim[1],
                        n3 * dim[2] : (n3 + 1) * dim[2],
                    ] += e * wan_G

    # Normalization.
    wanniergrid /= np.sqrt(self.Nk)
    return wanniergrid
def write_cube(self, index, fname, repeat=None, angle=False):
    """
    Dump specified Wannier function to a cube file.

    Arguments:

    ``index``: Integer, index of the Wannier function to save.

    ``repeat``: Array of integer, repeat supercell and Wannier function.
    Defaults to ``self.kptgrid``.

    ``fname``: Name of the cube file.

    ``angle``: If False, save the absolute value. If True, save
    the complex phase of the Wannier function.
    """
    from ase.io import write

    # Default size of plotting cell is the one corresponding to k-points.
    if repeat is None:
        repeat = self.kptgrid

    # Remove constraints, some are not compatible with repeat()
    atoms = self.atoms.copy()
    atoms.set_constraint()
    atoms = atoms * repeat
    func = self.get_function(index, repeat)

    # Compute absolute value or complex angle.  The modulus is invariant
    # under a global phase factor, so no phase alignment is needed before
    # taking it.
    data = np.angle(func) if angle else np.abs(func)

    write(fname, atoms, data=data, format='cube')
def localize(
    self, step=0.25, tolerance=1e-08, updaterot=True, updatecoeff=True
):
    """Optimize the rotation to give maximal localization.

    Delegates to ``md_min`` with this object, its logger ``self.log`` and
    the given ``step``, ``tolerance``, ``updaterot`` and ``updatecoeff``
    settings passed through unchanged as keyword arguments.
    """
    options = {
        'step': step,
        'tolerance': tolerance,
        'log': self.log,
        'updaterot': updaterot,
        'updatecoeff': updatecoeff,
    }
    md_min(self, **options)
def get_functional_value(self):
    """Calculate the value of the spread functional.

    ::

        Tr[|ZI|^2]=sum(I)sum(n) w_i|Z_(i)_nn|^2,

    where w_i are weights.

    If the functional is set to 'var' it subtracts a variance term

    ::

        Nw * var(sum(n) w_i|Z_(i)_nn|^2),

    where Nw is the number of WFs ``nwannier``.

    Raises:
        ValueError: if ``self.functional`` is neither 'std' nor 'var'.
    """
    a_w = self._spread_contributions()
    total = np.sum(a_w)

    if self.functional == 'std':
        return total

    if self.functional == 'var':
        variance = self.nwannier * np.var(a_w)
        self.log(
            f'std: {total:.4f}',
            f'\tvar: {variance:.4f}',
        )
        return total - variance

    # Previously an unknown functional surfaced as an UnboundLocalError.
    raise ValueError(
        f"Unknown functional {self.functional!r}; expected 'std' or 'var'"
    )
def get_gradients(self):
    """Return the gradient of the spread functional as one flat array.

    The result is ``np.concatenate(dU + dC)``: first, for every k-point,
    the raveled ``(Nw, Nw)`` rotation gradient ``Utemp_ww``; then, for
    every k-point with ``L = self.edf_k[k] > 0``, the raveled coefficient
    gradient ``Ctemp_nw[M:, M:]`` with its component inside the span of
    ``C_ul`` projected out.  This layout matches the one ``step`` slices
    ``dX`` with (rotations first, then coefficients).

    When ``self.functional == 'var'`` the derivative of the variance term
    is added to both the rotation and the coefficient parts.
    """
    # Determine gradient of the spread functional.
    #
    # The gradient for a rotation A_kij is::
    #
    #    dU = dRho/dA_{k,i,j} = sum(I) sum(k')
    #            + Z_jj Z_kk',ij^* - Z_ii Z_k'k,ij^*
    #            - Z_ii^* Z_kk',ji + Z_jj^* Z_k'k,ji
    #
    # The gradient for a change of coefficients is::
    #
    #   dRho/da^*_{k,i,j} = sum(I) [[(Z_0)_{k} V_{k'} diag(Z^*) +
    #                                (Z_0_{k''})^d V_{k''} diag(Z)] *
    #                                U_k^d]_{N+i,N+j}
    #
    # where diag(Z) is a square, diagonal matrix with Z_nn in the diagonal,
    # k' = k + dk and k = k'' + dk.
    #
    # The extra degrees of freedom should be kept orthonormal to the fixed
    # space, thus we introduce lagrange multipliers, and minimize instead::
    #
    #     Rho_L = Rho - sum_{k,n,m} lambda_{k,nm} <c_{kn}|c_{km}>
    #
    # for this reason the coefficient gradients should be multiplied
    # by (1 - c c^dag).

    Nb = self.nbands
    Nw = self.nwannier

    if self.functional == 'var':
        # Per-WF spread contributions and their sum; the derivative of the
        # variance term is a per-WF reweighting built from these.
        O_w = self._spread_contributions()
        O_sum = np.sum(O_w)

    dU = []
    dC = []
    for k in range(self.Nk):
        M = self.fixedstates_k[k]
        L = self.edf_k[k]
        U_ww = self.U_kww[k]
        C_ul = self.C_kul[k]
        Utemp_ww = np.zeros((Nw, Nw), complex)
        Ctemp_nw = np.zeros((Nb, Nw), complex)

        for d, weight in enumerate(self.weight_d):
            # Directions with (near-)zero weight do not contribute.
            if abs(weight) < 1.0e-6:
                continue

            Z_knn = self.Z_dknn[d]
            diagZ_w = self.Z_dww[d].diagonal()
            # Zii_ww[i, j] == diagZ_w[i]: row i is filled with Z_ii.
            Zii_ww = np.repeat(diagZ_w, Nw).reshape(Nw, Nw)
            if self.functional == 'var':
                # Same construction with each Z_ii scaled by O_w[i].
                diagOZ_w = O_w * diagZ_w
                OZii_ww = np.repeat(diagOZ_w, Nw).reshape(Nw, Nw)

            # Neighbour k-point indices along direction d; presumably k1
            # plays the role of k' and k2 of k'' in the formulas above --
            # TODO confirm against the construction of kklst_dk.
            k1 = self.kklst_dk[d, k]
            k2 = self.invkklst_dk[d, k]
            V_knw = self.V_knw
            Z_kww = self.Z_dkww[d]

            if L > 0:
                Ctemp_nw += weight * (
                    (
                        (Z_knn[k] @ V_knw[k1]) * diagZ_w.conj()
                        + (dag(Z_knn[k2]) @ V_knw[k2]) * diagZ_w
                    )
                    @ dag(U_ww)
                )

                if self.functional == 'var':
                    # Gradient of the variance term, split in two terms
                    def variance_term_computer(factor):
                        result = (
                            self.nwannier
                            * 2
                            * weight
                            * (
                                (
                                    (Z_knn[k] @ V_knw[k1]) * factor.conj()
                                    + (dag(Z_knn[k2]) @ V_knw[k2]) * factor
                                )
                                @ dag(U_ww)
                            )
                            / Nw**2
                        )
                        return result

                    # NOTE(review): variance_term_computer already carries a
                    # net factor 2/Nw (nwannier * 2 / Nw**2), and the two
                    # terms below divide again by Nw**2 and Nw, giving net
                    # factors 2*O_sum/Nw**3 and 2/Nw**2.  The matching
                    # U-gradient variance terms further down have net
                    # factors 2*O_sum/Nw and 2, i.e. these coefficient
                    # terms carry an extra overall 1/Nw**2.  This may be a
                    # scaling bug -- verify with a finite-difference check
                    # of the 'var' functional before relying on it.
                    first_term = (
                        O_sum * variance_term_computer(diagZ_w) / Nw**2
                    )

                    second_term = -variance_term_computer(diagOZ_w) / Nw

                    Ctemp_nw += first_term + second_term

            temp = Zii_ww.T * Z_kww[k].conj() - Zii_ww * Z_kww[k2].conj()
            Utemp_ww += weight * (temp - dag(temp))

            if self.functional == 'var':
                # Variance-term contribution to the rotation gradient:
                # a term proportional to O_sum ...
                Utemp_ww += (
                    self.nwannier
                    * 2
                    * O_sum
                    * weight
                    * (temp - dag(temp))
                    / Nw**2
                )

                # ... minus a term with each Z_ii reweighted by O_w[i].
                temp = (
                    OZii_ww.T * Z_kww[k].conj() - OZii_ww * Z_kww[k2].conj()
                )
                Utemp_ww -= (
                    self.nwannier * 2 * weight * (temp - dag(temp)) / Nw
                )

        dU.append(Utemp_ww.ravel())

        if L > 0:
            # Ctemp now has same dimension as V, the gradient is in the
            # lower-right (Nb-M) x L block
            Ctemp_ul = Ctemp_nw[M:, M:]
            # Apply the (1 - c c^dag) projector from the Lagrange-multiplier
            # argument above.
            G_ul = Ctemp_ul - ((C_ul @ dag(C_ul)) @ Ctemp_ul)
            dC.append(G_ul.ravel())

    return np.concatenate(dU + dC)
def _spread_contributions(self):
    """Return each Wannier function's contribution to the spread functional.

    The squared moduli from ``square_modulus_of_Z_diagonal(self.Z_dww)``
    are contracted with the direction weights ``self.weight_d`` and only
    the real part is returned.
    """
    squared = square_modulus_of_Z_diagonal(self.Z_dww)
    weighted = squared.T @ self.weight_d
    return weighted.real
def step(self, dX, updaterot=True, updatecoeff=True):
    """Apply the flat optimization step ``dX`` and refresh derived data.

    ``dX`` holds ``Nk * Nw**2`` rotation generators ``A`` first, followed
    by the coefficient increments ``dC`` for every k-point that has both
    extra degrees of freedom and unoccupied bands.  With ``updaterot`` the
    rotation matrices are multiplied by a unitary built from ``A``.  With
    ``updatecoeff`` each coefficient matrix gets ``C += dC`` and is then
    re-orthonormalized with ``gram_schmidt``.  ``self.update()`` is always
    called at the end.
    """
    nwan = self.nwannier
    nkpt = self.Nk
    rot_size = nkpt * nwan**2

    if updaterot:
        generators = dX[:rot_size].reshape(nkpt, nwan, nwan)
        for U, A in zip(self.U_kww, generators):
            # H = -i A^* is Hermitian when A is anti-Hermitian (assumed).
            # Its eigenvectors (columns of vecs) give the unitary
            # exp(iH) = vecs diag(exp(i vals)) vecs^dagger.
            vals, vecs = np.linalg.eigh(-1.0j * A.conj())
            rotation = (vecs * np.exp(1.0j * vals)) @ dag(vecs)
            rotated = U @ rotation
            # Real-valued rotation matrices stay real.
            U[:] = rotated.real if U.dtype == float else rotated

    if updatecoeff:
        offset = rot_size
        unocc_k = self.nbands - self.fixedstates_k
        for C, unocc, L in zip(self.C_kul, unocc_k, self.edf_k):
            if L == 0 or unocc == 0:
                continue
            count = L * unocc
            C += dX[offset : offset + count].reshape(unocc, L)
            gram_schmidt(C)
            offset += count

    self.update()