Coverage for /builds/ase/ase/ase/dft/wannier.py: 58.74%

555 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1"""Partly occupied Wannier functions 

2 

3Find the set of partly occupied Wannier functions using the method from 

4Thygesen, Hansen and Jacobsen PRB v72 i12 p125119 2005. 

5""" 

6 

7import functools 

8import warnings 

9from math import pi, sqrt 

10from time import time 

11 

12import numpy as np 

13from scipy.linalg import qr 

14 

15from ase.dft.bandgap import bandgap 

16from ase.dft.kpoints import get_monkhorst_pack_size_and_offset 

17from ase.io.jsonio import read_json, write_json 

18from ase.parallel import paropen 

19from ase.transport.tools import dagger, normalize 

20 

# Short alias for ``dagger`` (imported from ase.transport.tools); used
# throughout this module, presumably as the conjugate transpose.
dag = dagger

22 

23 

def silent(*args, **kwargs):
    """No-op logger: accept any arguments and discard them."""
    return None

26 

27 

def gram_schmidt(U):
    """Orthonormalize the columns of ``U`` in place (Gram-Schmidt).

    Each column is made orthogonal to all preceding, already processed
    columns and then scaled to unit norm.  Nothing is returned.
    """
    columns = U.T  # rows of the transpose are writable views on U's columns
    for index in range(len(columns)):
        current = columns[index]
        for previous in columns[:index]:
            current -= previous * (previous.conj() @ current)
        current /= np.linalg.norm(current)

34 

35 

def lowdin(U, S=None):
    """Orthonormalize the columns of ``U`` in place (symmetric Lowdin).

    ``U`` is factorized with an SVD, ``U = L diag(s) R``, and replaced by
    ``L @ R``.  This is the orthonormal matrix closest to ``U``, the same
    result as symmetric Lowdin orthogonalization, but obtained more
    robustly.  ``S`` is accepted for interface compatibility and ignored.
    """
    left, _singular_values, right = np.linalg.svd(U, full_matrices=False)
    U[:] = left @ right

44 

45 

def neighbor_k_search(k_c, G_c, kpt_kc, tol=1e-4):
    """Find the k-point neighbouring ``k_c`` along ``G_c``.

    Searches for an index ``k1`` into ``kpt_kc`` and a wrapping vector
    ``k0_c`` such that ``kpt_kc[k1] - k_c - G_c + k0_c`` has norm below
    ``tol``.  ``k0_c`` is taken from a fixed list of small integer vectors.
    Candidates are tried in wrap-vector-major order.

    Returns ``(k1, k0_c)``.  Raises ``ValueError`` if nothing matches.
    """
    wrap_candidates = np.array(
        [
            [0, 0, 0],
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [1, 1, 0],
            [1, 0, 1],
            [0, 1, 1],
        ],
        dtype=int,
    )
    matches = (
        (k1, k0_c)
        for k0_c in wrap_candidates
        for k1, k1_c in enumerate(kpt_kc)
        if np.linalg.norm(k1_c - k_c - G_c + k0_c) < tol
    )
    found = next(matches, None)
    if found is None:
        raise ValueError(
            f'Wannier: Did not find matching kpoint for kpt={k_c}. '
            'Probably non-uniform k-point grid'
        )
    return found

70 

71 

def calculate_weights(cell_cc, normalize=True):
    """Return finite-difference weights and directions for the cell.

    Weights are needed for non-cubic cells, see PRB **61**, 10040.  The
    three lattice directions are always kept because they are used to
    compute Wannier centers.  Each mixed direction ([1,1,0], [1,0,1],
    [0,1,1]) is added only when its weight exceeds 1e-5 in magnitude.

    If ``normalize`` is true the weights are divided by their largest
    absolute value and so lose their physical dimension.

    Returns ``(weight_d, Gdir_dc)``.
    """
    all_directions = np.array(
        [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1], [0, 1, 1]],
        dtype=int,
    )
    metric = cell_cc @ cell_cc.T
    weights = np.array(
        [
            metric[0, 0] - metric[0, 1] - metric[0, 2],
            metric[1, 1] - metric[0, 1] - metric[1, 2],
            metric[2, 2] - metric[0, 2] - metric[1, 2],
            metric[0, 1],
            metric[0, 2],
            metric[1, 2],
        ],
        dtype=float,
    )
    selected = [0, 1, 2] + [d for d in range(3, 6) if abs(weights[d]) > 1e-5]
    Gdir_dc = all_directions[selected]
    weight_d = weights[selected]
    if normalize:
        weight_d /= max(abs(weight_d))
    return weight_d, Gdir_dc

99 

100 

def steepest_descent(func, step=0.005, tolerance=1e-6, log=silent, **kwargs):
    """Iterate ``func.step(step * func.get_gradients())`` to convergence.

    Stops when the relative change of ``func.get_functional_value()``
    between two iterations is at most ``tolerance``.  Extra ``kwargs`` are
    forwarded to ``func.step`` and progress is reported through ``log``.
    """
    previous_value = 0.0
    current_value = previous_value + 10
    iteration = 0
    while abs((current_value - previous_value) / current_value) > tolerance:
        previous_value = current_value
        gradient = func.get_gradients()
        func.step(gradient * step, **kwargs)
        current_value = func.get_functional_value()
        iteration += 1
        log(f'SteepestDescent: iter={iteration}, value={current_value}')

112 

113 

def md_min(
    func, step=0.25, tolerance=1e-6, max_iter=10000, log=silent, **kwargs
):
    """Optimize ``func`` with a damped molecular-dynamics-like scheme.

    A velocity ``V`` is kept between iterations.  Each iteration does the
    following:

    - Components of ``V`` whose overlap with the current gradient is not
      positive are zeroed.
    - ``step * gradient`` is added to ``V``.
    - ``func.step(V, **kwargs)`` is applied.
    - If the functional value decreased, ``step`` is halved.

    Iteration stops when the relative change of the functional value is at
    most ``tolerance``.  It also stops after more than ``max_iter``
    iterations, in which case a warning is issued.

    ``func`` must provide ``get_functional_value()``, ``get_gradients()``
    and ``step(V, **kwargs)``.  Progress is reported through ``log``.
    """
    log('Localize with step =', step, 'and tolerance =', tolerance)
    finit = func.get_functional_value()

    t = -time()
    fvalueold = 0.0
    fvalue = fvalueold + 10
    count = 0
    V = np.zeros(func.get_gradients().shape, dtype=complex)

    while abs((fvalue - fvalueold) / fvalue) > tolerance:
        fvalueold = fvalue
        dF = func.get_gradients()

        # Drop velocity components not aligned with the current gradient,
        # then accelerate along it.
        V *= (dF * V.conj()).real > 0
        V += step * dF
        func.step(V, **kwargs)
        fvalue = func.get_functional_value()

        if fvalue < fvalueold:
            step *= 0.5
        count += 1
        log(f'MDmin: iter={count}, step={step}, value={fvalue}')
        if count > max_iter:
            # Use a separate variable: ``t`` is finalized once after the
            # loop, and adding time() twice would report a bogus duration.
            elapsed = t + time()
            warnings.warn(
                'Max iterations reached: '
                'iters=%s, step=%s, seconds=%0.2f, value=%0.4f'
                % (count, step, elapsed, fvalue.real)
            )
            break

    t += time()
    # The loop body may never run (e.g. tolerance >= 1); avoid dividing
    # by zero in that case.
    ms_per_iter = t * 1000.0 / count if count else 0.0
    log(
        '%d iterations in %0.2f seconds (%0.2f ms/iter), endstep = %s'
        % (count, t, ms_per_iter, step)
    )
    log(f'Initial value={finit}, Final value={fvalue}')

154 

155 

def rotation_from_projection(proj_nw, fixed, ortho=True):
    """Determine rotation and coefficient matrices from projections.

    proj_nw = <psi_n|p_w>, with psi_n the eigenstates and p_w localized
    functions.

    Index / size conventions::

        Nb (n) = number of bands
        Nw (w) = number of Wannier functions
        M  (f) = number of fixed states (``fixed``)
        L  (l) = number of extra degrees of freedom, Nw - M
        U  (u) = number of non-fixed states, Nb - M

    The first M rows of ``U_ww`` are the projections of the fixed states.
    When L > 0, the L eigenvectors of ``dag(proj_uw) @ proj_uw`` with the
    largest eigenvalues define ``C_ul``, and the remaining rows of
    ``U_ww`` are ``dag(C_ul) @ proj_uw``.  ``C_ul`` is normalized only
    after those rows are computed.  Otherwise ``C_ul`` is an empty (U, 0)
    matrix.  Finally ``U_ww`` is orthonormalized with ``lowdin`` when
    ``ortho`` is true, or passed to ``normalize`` otherwise.

    Returns ``(U_ww, C_ul)``.
    """
    n_bands, n_wannier = proj_nw.shape
    n_fixed = fixed
    n_extra = n_wannier - n_fixed
    n_free = n_bands - n_fixed
    dtype = proj_nw.dtype

    U_ww = np.empty((n_wannier, n_wannier), dtype=dtype)
    U_ww[:n_fixed] = proj_nw[:n_fixed]

    if n_extra <= 0:
        # No extra degrees of freedom: no coefficient matrix is needed.
        C_ul = np.empty((n_free, 0), dtype=dtype)
    else:
        C_ul = np.empty((n_free, n_extra), dtype=dtype)
        proj_uw = proj_nw[n_fixed:]
        eig_w, C_ww = np.linalg.eigh(dag(proj_uw) @ proj_uw)
        dominant = np.argsort(-eig_w.real)[:n_extra]
        C_ul[:] = proj_uw @ C_ww[:, dominant]
        U_ww[n_fixed:] = dag(C_ul) @ proj_uw
        normalize(C_ul)

    if ortho:
        lowdin(U_ww)
    else:
        normalize(U_ww)

    return U_ww, C_ul

210 

211 

def search_for_gamma_point(kpts):
    """Return the index of the Gamma point in ``kpts``, or None if absent.

    The k-point of smallest norm is the candidate.  It is accepted only if
    its norm is below 1e-14.
    """
    norms = [np.linalg.norm(kpt) for kpt in kpts]
    closest = np.argmin(norms)
    return None if norms[closest] >= 1e-14 else closest

218 

219 

def scdm(pseudo_nkG, kpts, fixed_k, Nw):
    """Compute localized orbitals with the SCDM method.

    This method was published by Anil Damle and Lin Lin in Multiscale
    Modeling & Simulation 16, 1392-1410 (2018).
    For now only the isolated bands algorithm is implemented, because it is
    intended as a drop-in replacement for other initial guess methods for
    the ASE Wannier class.

    pseudo_nkG = pseudo wave-functions on a real grid, indexed
                 [band, k-point, grid point]
    Ng (G) = number of real grid points
    kpts = list of k-points in the BZ (must contain the Gamma point)
    Nk (k) = number of k-points
    Nb (n) = number of bands
    Nw (w) = number of Wannier functions
    fixed_k = number of fixed states for each k-point

    Returns ``(C_kul, U_kww)``.  ``C_kul`` is a list with one entry per
    k-point, and ``U_kww`` is an array.

    Raises ``ValueError`` if ``kpts`` does not contain the Gamma point.
    """
    gamma_idx = search_for_gamma_point(kpts)
    if gamma_idx is None:
        # Indexing with None would silently insert a new axis and feed a
        # wrongly shaped array to the QR factorization below.
        raise ValueError(
            'SCDM requires the Gamma point to be included in the k-points.'
        )

    # Column-pivoted QR at Gamma only.  The first Nw pivots select the
    # grid points whose wave-function values serve as projections.
    _, _, P = qr(
        pseudo_nkG[:, gamma_idx, :],
        mode='full',
        pivoting=True,
        check_finite=True,
    )
    selected_points = P[:Nw]

    U_kww = []
    C_kul = []
    for k in range(len(kpts)):
        A_nw = pseudo_nkG[:, k, selected_points]
        U_ww, C_ul = rotation_from_projection(
            proj_nw=A_nw, fixed=fixed_k[k], ortho=True
        )
        U_kww.append(U_ww)
        C_kul.append(C_ul)

    return C_kul, np.asarray(U_kww)

264 

265 

def arbitrary_s_orbitals(atoms, Ns, rng=np.random):
    """Generate ``Ns`` randomly placed s-orbitals near the atoms.

    Each orbital is placed at a random scaled position that lies closer
    than 1.5 Å to at least one atom.  Positions are redrawn with
    ``rng.rand(3)`` until that holds.  Every entry has the form
    ``[[x, y, z], 0, 1]``, the format GPAW expects in initial_wannier().
    """
    # A copy of the atoms with an extra dummy H atom used as a probe.
    probe = atoms.copy()
    probe.append('H')
    scaled = probe.get_scaled_positions()

    orbitals = []
    for _ in range(Ns):
        while True:
            x, y, z = rng.rand(3)
            scaled[-1] = [x, y, z]
            probe.set_scaled_positions(scaled)
            distances = probe.get_distances(a=-1, indices=range(len(atoms)))
            if (distances < 1.5).any():
                break
        orbitals.append([[x, y, z], 0, 1])
    return orbitals

295 

296 

def init_orbitals(atoms, ntot, rng=np.random):
    """Build an initial orbital list of total size ``ntot``.

    Five d-orbitals (``[index, 2, 1]``) are assigned to each d-metal atom,
    in atom order, as long as they still fit within ``ntot``.  Any
    remainder is filled with random s-orbitals from
    ``arbitrary_s_orbitals`` (within 1.5 Å of some atom).  The list format
    is compatible with GPAW.get_initial_wannier().

    'atoms': ASE Atoms object
    'ntot': total number of needed orbitals
    'rng': random number generator
    """
    # Elements with occupied d-orbitals in the valence (per GPAW setups).
    d_metals = {
        *range(21, 31),
        *range(39, 52),
        *range(57, 84),
        *range(89, 113),
    }

    # The orbital count only grows, so once a d-shell no longer fits no
    # later one will either: keep the first ntot // 5 d-metal atoms.
    n_d_shells = max(int(ntot // 5), 0)
    d_atoms = [
        i for i, z in enumerate(atoms.get_atomic_numbers()) if z in d_metals
    ][:n_d_shells]
    orbs = [[i, 2, 1] for i in d_atoms]

    n_assigned = 5 * len(orbs)
    if n_assigned < ntot:
        orbs += arbitrary_s_orbitals(atoms, ntot - n_assigned, rng)

    assert sum(orb[1] * 2 + 1 for orb in orbs) == ntot
    return orbs

335 

336 

def square_modulus_of_Z_diagonal(Z_dww):
    """Return ``|Z_dww[d, w, w]|**2`` as an array indexed ``[d, w]``.

    The diagonal is taken over the two Wannier-function indices.
    """
    diagonal_dw = np.diagonal(Z_dww, axis1=1, axis2=2)
    return np.abs(diagonal_dw) ** 2

343 

344 

def get_kklst(kpt_kc, Gdir_dc):
    """Tabulate neighbouring k-points for each direction.

    For each direction ``d`` (a row of ``Gdir_dc``) and each k-point ``k``,
    find the neighbour index ``k1`` and the integer wrapping vector ``k0``
    with ``k1 - k - G + k0 = 0``.  Here ``G`` is the k-point spacing along
    the direction.

    Example: kpoints = (-0.375, -0.125, 0.125, 0.375), dir=0,
    G = [0.25, 0, 0].  For k = 0.375 the neighbour is k1 = -0.375 with
    k0 = [1, 0, 0].  For a single (Gamma-point) k-point, k1 = k = 0 and
    k0 is the direction vector itself.

    Returns ``(kklst_dk, k0_dkc)`` of shapes (Ndir, Nk) and (Ndir, Nk, 3).
    """
    Nk = len(kpt_kc)
    Ndir = len(Gdir_dc)

    if Nk == 1:
        return np.zeros((Ndir, 1), int), Gdir_dc.reshape(-1, 1, 3)

    # Largest gap between consecutive sorted k-point coordinates, per axis.
    kdist_c = np.empty(3)
    for c in range(3):
        order = np.argsort(kpt_kc[:, c], kind='mergesort')
        sorted_values = np.take(kpt_kc, order, axis=0)[:, c]
        kdist_c[c] = max(np.diff(sorted_values))

    kklst_dk = np.empty((Ndir, Nk), int)
    k0_dkc = np.empty((Ndir, Nk, 3), int)
    for d, Gdir_c in enumerate(Gdir_dc):
        # Step to the next k-point along this direction.
        G_c = np.where(Gdir_c > 0, kdist_c, 0)
        no_spacing = max(G_c) < 1e-4
        for k, k_c in enumerate(kpt_kc):
            if no_spacing:
                kklst_dk[d, k] = k
                k0_dkc[d, k] = Gdir_c
            else:
                kklst_dk[d, k], k0_dkc[d, k] = neighbor_k_search(
                    k_c, G_c, kpt_kc
                )
    return kklst_dk, k0_dkc

386 

387 

def get_invkklst(kklst_dk):
    """Return the inverse of the neighbour table ``kklst_dk``.

    ``result[d, k1]`` is the first ``k`` with ``kklst_dk[d, k] == k1``.
    ``list.index`` raises ValueError if some ``k1`` never occurs.
    """
    n_dir, n_k = kklst_dk.shape
    inverse = np.empty(kklst_dk.shape, int)
    for d in range(n_dir):
        neighbours = kklst_dk[d].tolist()
        inverse[d] = [neighbours.index(k1) for k1 in range(n_k)]
    return inverse

395 

396 

def _states_below(calcdata, spin, Nk, cutoff):
    """Number of eigenvalues below ``cutoff`` at each k-point (int array)."""
    return np.array(
        [calcdata.eps_skn[spin, k].searchsorted(cutoff) for k in range(Nk)],
        int,
    )


def choose_states(calcdata, fixedenergy, fixedstates, Nk, nwannier, log, spin):
    """Return ``(fixedstates_k, nwannier)``.

    The number of fixed states per k-point comes from one of these sources:

    - ``fixedstates``: an int for all k-points, or one value per k-point.
    - ``fixedenergy``: states below ``fixedenergy`` plus a reference level
      are fixed.  The reference is the Fermi level when the gap or
      ``fixedenergy`` is below 0.01, and the LUMO otherwise.

    With ``nwannier == 'auto'``, ``nwannier`` becomes the maximum number of
    fixed states.  If neither source is given, states below the Fermi level
    are used for that.  When neither source is given, every k-point finally
    gets ``nwannier`` fixed states (no extra degrees of freedom).

    Raises RuntimeError if both ``fixedenergy`` and ``fixedstates`` are set.
    """
    have_energy = fixedenergy is not None
    have_states = fixedstates is not None

    if have_energy and have_states:
        raise RuntimeError('You can not set both fixedenergy and fixedstates')

    if have_states:
        if isinstance(fixedstates, int):
            fixedstates = [fixedstates] * Nk
        fixedstates_k = np.array(fixedstates, int)
    elif have_energy:
        if calcdata.gap < 0.01 or fixedenergy < 0.01:
            reference = calcdata.fermi_level
        else:
            reference = calcdata.lumo
        fixedstates_k = _states_below(
            calcdata, spin, Nk, fixedenergy + reference
        )

    if nwannier == 'auto':
        if not have_energy and not have_states:
            log(
                "nwannier=auto but no 'fixedenergy' or 'fixedstates'",
                'parameter was provided, using Fermi level as',
                'energy cutoff.',
            )
            fixedstates_k = _states_below(
                calcdata, spin, Nk, calcdata.fermi_level
            )
        nwannier = np.max(fixedstates_k)

    # Without user choice just set nwannier fixed states without EDF.
    if not have_states and not have_energy:
        fixedstates_k = np.array([nwannier] * Nk, int)

    return fixedstates_k, nwannier

443 

444 

def get_eigenvalues(calc):
    """Return the eigenvalues of ``calc`` as an array [spin, kpt, band].

    The k-points are those reported by ``calc.get_ibz_k_points()``.
    """
    shape = (
        calc.get_number_of_spins(),
        len(calc.get_ibz_k_points()),
        calc.get_number_of_bands(),
    )
    eps_skn = np.empty(shape)
    for ispin, ikpt in np.ndindex(*shape[:2]):
        eps_skn[ispin, ikpt] = calc.get_eigenvalues(kpt=ikpt, spin=ispin)
    return eps_skn

455 

456 

class CalcData:
    """Container for the calculator data used by the Wannier code.

    Stores ``kpt_kc`` (k-points), ``atoms``, ``fermi_level``, ``lumo``,
    ``eps_skn`` (eigenvalues indexed [spin, k-point, band]) and ``gap``
    exactly as given.
    """

    def __init__(self, kpt_kc, atoms, fermi_level, lumo, eps_skn, gap):
        self.kpt_kc = kpt_kc
        self.atoms = atoms
        self.fermi_level = fermi_level
        self.lumo = lumo
        self.eps_skn = eps_skn
        self.gap = gap

    @property
    def nbands(self):
        """Number of bands: the size of the third axis of ``eps_skn``."""
        band_axis = 2
        return self.eps_skn.shape[band_axis]

469 

470 

def get_calcdata(calc):
    """Collect the data the Wannier code needs from ``calc``.

    Raises RuntimeError if the calculator used k-point symmetry, detected
    by a mismatch between the number of BZ and IBZ k-points.
    """
    kpt_kc = calc.get_bz_k_points()
    n_ibz = len(calc.get_ibz_k_points())
    if n_ibz != len(kpt_kc):
        raise RuntimeError(
            'K-point symmetry is not currently supported. '
            "Please re-run your calculator with symmetry='off'."
        )

    lumo = calc.get_homo_lumo()[1]
    gap = bandgap(calc=calc)[0]
    atoms = calc.get_atoms()
    fermi_level = calc.get_fermi_level()
    eps_skn = get_eigenvalues(calc)
    return CalcData(
        kpt_kc=kpt_kc,
        atoms=atoms,
        fermi_level=fermi_level,
        lumo=lumo,
        eps_skn=eps_skn,
        gap=gap,
    )

490 

491 

492class Wannier: 

493 """Partly occupied Wannier functions 

494 

495 Find the set of partly occupied Wannier functions according to 

496 Thygesen, Hansen and Jacobsen PRB v72 i12 p125119 2005. 

497 """ 

498 

499 def __init__( 

500 self, 

501 nwannier, 

502 calc, 

503 file=None, 

504 nbands=None, 

505 fixedenergy=None, 

506 fixedstates=None, 

507 spin=0, 

508 initialwannier='orbitals', 

509 functional='std', 

510 rng=np.random, 

511 log=silent, 

512 ): 

513 """ 

514 Required arguments: 

515 

516 ``nwannier``: The number of Wannier functions you wish to construct. 

517 This must be at least half the number of electrons in the system 

518 and at most equal to the number of bands in the calculation. 

519 It can also be set to 'auto' in order to automatically choose the 

520 minimum number of needed Wannier function based on the 

521 ``fixedenergy`` / ``fixedstates`` parameter. 

522 

523 ``calc``: A converged DFT calculator class. 

524 If ``file`` arg. is not provided, the calculator *must* provide the 

525 method ``get_wannier_localization_matrix``, and contain the 

526 wavefunctions (save files with only the density is not enough). 

527 If the localization matrix is read from file, this is not needed, 

528 unless ``get_function`` or ``write_cube`` is called. 

529 

530 Optional arguments: 

531 

532 ``nbands``: Bands to include in localization. 

533 The number of bands considered by Wannier can be smaller than the 

534 number of bands in the calculator. This is useful if the highest 

535 bands of the DFT calculation are not well converged. 

536 

537 ``spin``: The spin channel to be considered. 

538 The Wannier code treats each spin channel independently. 

539 

540 ``fixedenergy`` / ``fixedstates``: Fixed part of Hilbert space. 

541 Determine the fixed part of Hilbert space by either a maximal 

542 energy *or* a number of bands (possibly a list for multiple 

543 k-points). 

544 Default is None meaning that the number of fixed states is equated 

545 to ``nwannier``. 

546 The maximal energy is relative to the CBM if there is a finite 

547 bandgap or to the Fermi level if there is none. 

548 

549 ``file``: Read localization and rotation matrices from this file. 

550 

551 ``initialwannier``: Initial guess for Wannier rotation matrix. 

552 Can be 'bloch' to start from the Bloch states, 'random' to be 

553 randomized, 'orbitals' to start from atom-centered d-orbitals and 

554 randomly placed gaussian centers (see init_orbitals()), 

555 'scdm' to start from localized state selected with SCDM 

556 or a list passed to calc.get_initial_wannier. 

557 

558 ``functional``: The functional used to measure the localization. 

559 Can be 'std' for the standard quadratic functional from the PRB 

560 paper, 'var' to add a variance minimizing term. 

561 

562 ``rng``: Random number generator for ``initialwannier``. 

563 

564 ``log``: Function which logs, such as print(). 

565 """ 

566 # Bloch phase sign convention. 

567 # May require special cases depending on which code is used. 

568 sign = -1 

569 

570 self.log = log 

571 self.calc = calc 

572 

573 self.spin = spin 

574 self.functional = functional 

575 self.initialwannier = initialwannier 

576 self.log('Using functional:', functional) 

577 

578 self.calcdata = get_calcdata(calc) 

579 

580 self.kptgrid = get_monkhorst_pack_size_and_offset(self.kpt_kc)[0] 

581 self.calcdata.kpt_kc *= sign 

582 

583 self.largeunitcell_cc = (self.unitcell_cc.T * self.kptgrid).T 

584 self.weight_d, self.Gdir_dc = calculate_weights(self.largeunitcell_cc) 

585 assert len(self.weight_d) == len(self.Gdir_dc) 

586 

587 if nbands is None: 

588 # XXX Can work with other number of bands than calculator. 

589 # Is this case properly tested, lest we confuse them? 

590 nbands = self.calcdata.nbands 

591 self.nbands = nbands 

592 

593 self.fixedstates_k, self.nwannier = choose_states( 

594 self.calcdata, 

595 fixedenergy, 

596 fixedstates, 

597 self.Nk, 

598 nwannier, 

599 log, 

600 spin, 

601 ) 

602 

603 # Compute the number of extra degrees of freedom (EDF) 

604 self.edf_k = self.nwannier - self.fixedstates_k 

605 

606 self.log(f'Wannier: Fixed states : {self.fixedstates_k}') 

607 self.log(f'Wannier: Extra degrees of freedom: {self.edf_k}') 

608 

609 self.kklst_dk, k0_dkc = get_kklst(self.kpt_kc, self.Gdir_dc) 

610 

611 # Set the inverse list of neighboring k-points 

612 self.invkklst_dk = get_invkklst(self.kklst_dk) 

613 

614 Nw = self.nwannier 

615 Nb = self.nbands 

616 self.Z_dkww = np.empty((self.Ndir, self.Nk, Nw, Nw), complex) 

617 self.V_knw = np.zeros((self.Nk, Nb, Nw), complex) 

618 

619 if file is None: 

620 self.Z_dknn = self.new_Z(calc, k0_dkc) 

621 self.initialize(file=file, initialwannier=initialwannier, rng=rng) 

622 

623 @property 

624 def atoms(self): 

625 return self.calcdata.atoms 

626 

627 @property 

628 def kpt_kc(self): 

629 return self.calcdata.kpt_kc 

630 

631 @property 

632 def Ndir(self): 

633 return len(self.weight_d) # Number of directions 

634 

635 @property 

636 def Nk(self): 

637 return len(self.kpt_kc) 

638 

639 def new_Z(self, calc, k0_dkc): 

640 Nb = self.nbands 

641 Z_dknn = np.empty((self.Ndir, self.Nk, Nb, Nb), complex) 

642 for d, dirG in enumerate(self.Gdir_dc): 

643 for k in range(self.Nk): 

644 k1 = self.kklst_dk[d, k] 

645 k0_c = k0_dkc[d, k] 

646 Z_dknn[d, k] = calc.get_wannier_localization_matrix( 

647 nbands=Nb, 

648 dirG=dirG, 

649 kpoint=k, 

650 nextkpoint=k1, 

651 G_I=k0_c, 

652 spin=self.spin, 

653 ) 

654 return Z_dknn 

655 

656 @property 

657 def unitcell_cc(self): 

658 return self.atoms.cell 

659 

660 @property 

661 def U_kww(self): 

662 return self.wannier_state.U_kww 

663 

664 @property 

665 def C_kul(self): 

666 return self.wannier_state.C_kul 

667 

668 def initialize(self, file=None, initialwannier='random', rng=np.random): 

669 """Re-initialize current rotation matrix. 

670 

671 Keywords are identical to those of the constructor. 

672 """ 

673 from ase.dft.wannierstate import WannierSpec, WannierState 

674 

675 spec = WannierSpec( 

676 self.Nk, self.nwannier, self.nbands, self.fixedstates_k 

677 ) 

678 

679 if file is not None: 

680 with paropen(file, 'r') as fd: 

681 Z_dknn, U_kww, C_kul = read_json(fd, always_array=False) 

682 self.Z_dknn = Z_dknn 

683 wannier_state = WannierState(C_kul, U_kww) 

684 elif initialwannier == 'bloch': 

685 # Set U and C to pick the lowest Bloch states 

686 wannier_state = spec.bloch(self.edf_k) 

687 elif initialwannier == 'random': 

688 wannier_state = spec.random(rng, self.edf_k) 

689 elif initialwannier == 'orbitals': 

690 orbitals = init_orbitals(self.atoms, self.nwannier, rng) 

691 wannier_state = spec.initial_orbitals( 

692 self.calc, orbitals, self.kptgrid, self.edf_k, self.spin 

693 ) 

694 elif initialwannier == 'scdm': 

695 wannier_state = spec.scdm(self.calc, self.kpt_kc, self.spin) 

696 else: 

697 wannier_state = spec.initial_wannier( 

698 self.calc, initialwannier, self.kptgrid, self.edf_k, self.spin 

699 ) 

700 

701 self.wannier_state = wannier_state 

702 self.update() 

703 

704 def save(self, file): 

705 """Save information on localization and rotation matrices to file.""" 

706 with paropen(file, 'w') as fd: 

707 write_json(fd, (self.Z_dknn, self.U_kww, self.C_kul)) 

708 

709 def update(self): 

710 # Update large rotation matrix V (from rotation U and coeff C) 

711 for k, M in enumerate(self.fixedstates_k): 

712 self.V_knw[k, :M] = self.U_kww[k, :M] 

713 if M < self.nwannier: 

714 self.V_knw[k, M:] = self.C_kul[k] @ self.U_kww[k, M:] 

715 # else: self.V_knw[k, M:] = 0.0 

716 

717 # Calculate the Zk matrix from the large rotation matrix: 

718 # Zk = V^d[k] Zbloch V[k1] 

719 for d in range(self.Ndir): 

720 for k in range(self.Nk): 

721 k1 = self.kklst_dk[d, k] 

722 self.Z_dkww[d, k] = dag(self.V_knw[k]) @ ( 

723 self.Z_dknn[d, k] @ self.V_knw[k1] 

724 ) 

725 

726 # Update the new Z matrix 

727 self.Z_dww = self.Z_dkww.sum(axis=1) / self.Nk 

728 

729 def get_optimal_nwannier(self, nwrange=5, random_reps=5, tolerance=1e-6): 

730 """ 

731 The optimal value for 'nwannier', maybe. 

732 

733 The optimal value is the one that gives the lowest average value for 

734 the spread of the most delocalized Wannier function in the set. 

735 

736 ``nwrange``: number of different values to try for 'nwannier', the 

737 values will span a symmetric range around ``nwannier`` if possible. 

738 

739 ``random_reps``: number of repetitions with random seed, the value is 

740 then an average over these repetitions. 

741 

742 ``tolerance``: tolerance for the gradient descent algorithm, can be 

743 useful to increase the speed, with a cost in accuracy. 

744 """ 

745 

746 # Define the range of values to try based on the maximum number of fixed 

747 # states (that is the minimum number of WFs we need) and the number of 

748 # available bands we have. 

749 max_number_fixedstates = np.max(self.fixedstates_k) 

750 

751 min_range_value = max( 

752 self.nwannier - int(np.floor(nwrange / 2)), max_number_fixedstates 

753 ) 

754 max_range_value = min(min_range_value + nwrange, self.nbands + 1) 

755 Nws = np.arange(min_range_value, max_range_value) 

756 

757 # If there is no randomness, there is no need to repeat 

758 random_initials = ['random', 'orbitals'] 

759 if self.initialwannier not in random_initials: 

760 random_reps = 1 

761 

762 t = -time() 

763 avg_max_spreads = np.zeros(len(Nws)) 

764 for j, Nw in enumerate(Nws): 

765 self.log('Trying with Nw =', Nw) 

766 

767 # Define once with the fastest 'initialwannier', 

768 # then initialize with random seeds in the for loop 

769 wan = Wannier( 

770 nwannier=int(Nw), 

771 calc=self.calc, 

772 nbands=self.nbands, 

773 spin=self.spin, 

774 functional=self.functional, 

775 initialwannier='bloch', 

776 log=self.log, 

777 rng=self.rng, 

778 ) 

779 wan.fixedstates_k = self.fixedstates_k 

780 wan.edf_k = wan.nwannier - wan.fixedstates_k 

781 

782 max_spreads = np.zeros(random_reps) 

783 for i in range(random_reps): 

784 wan.initialize(initialwannier=self.initialwannier) 

785 wan.localize(tolerance=tolerance) 

786 max_spreads[i] = np.max(wan.get_spreads()) 

787 

788 avg_max_spreads[j] = max_spreads.mean() 

789 

790 self.log('Average spreads: ', avg_max_spreads) 

791 t += time() 

792 self.log(f'Execution time: {t:.1f}s') 

793 

794 return Nws[np.argmin(avg_max_spreads)] 

795 

796 def get_centers(self, scaled=False): 

797 """Calculate the Wannier centers 

798 

799 :: 

800 

801 pos = L / 2pi * phase(diag(Z)) 

802 """ 

803 coord_wc = np.angle(self.Z_dww[:3].diagonal(0, 1, 2)).T / (2 * pi) % 1 

804 if not scaled: 

805 coord_wc = coord_wc @ self.largeunitcell_cc 

806 return coord_wc 

807 

808 def get_radii(self): 

809 r"""Calculate the spread of the Wannier functions. 

810 

811 :: 

812 

813 -- / L \ 2 2 

814 radius**2 = - > | --- | ln |Z| 

815 --d \ 2pi / 

816 

817 Note that this function can fail with some Bravais lattices, 

818 see `get_spreads()` for a more robust alternative. 

819 """ 

820 r2 = -(self.largeunitcell_cc.diagonal() ** 2 / (2 * pi) ** 2) @ np.log( 

821 abs(self.Z_dww[:3].diagonal(0, 1, 2)) ** 2 

822 ) 

823 return np.sqrt(r2) 

824 

825 def get_spreads(self): 

826 r"""Calculate the spread of the Wannier functions in Ų. 

827 The definition is based on eq. 13 in PRB61-15 by Berghold and Mundy. 

828 

829 :: 

830 

831 / 1 \ 2 -- 2 

832 spread = - |----| > W_d ln |Z_dw| 

833 \2 pi/ --d 

834 

835 

836 """ 

837 # compute weights without normalization, to keep physical dimension 

838 weight_d, _ = calculate_weights(self.largeunitcell_cc, normalize=False) 

839 Z2_dw = square_modulus_of_Z_diagonal(self.Z_dww) 

840 spread_w = -(np.log(Z2_dw).T @ weight_d).real / (2 * np.pi) ** 2 

841 return spread_w 

842 

843 def get_spectral_weight(self, w): 

844 return abs(self.V_knw[:, :, w]) ** 2 / self.Nk 

845 

846 def get_pdos(self, w, energies, width): 

847 """Projected density of states (PDOS). 

848 

849 Returns the (PDOS) for Wannier function ``w``. The calculation 

850 is performed over the energy grid specified in energies. The 

851 PDOS is produced as a sum of Gaussians centered at the points 

852 of the energy grid and with the specified width. 

853 """ 

854 spec_kn = self.get_spectral_weight(w) 

855 dos = np.zeros(len(energies)) 

856 for k, spec_n in enumerate(spec_kn): 

857 eig_n = self.calcdata.eps_skn[self.spin, k] 

858 for weight, eig in zip(spec_n, eig_n): 

859 # Add gaussian centered at the eigenvalue 

860 x = ((energies - eig) / width) ** 2 

861 dos += weight * np.exp(-x.clip(0.0, 40.0)) / (sqrt(pi) * width) 

862 return dos 

863 

864 def translate(self, w, R): 

865 """Translate the w'th Wannier function 

866 

867 The distance vector R = [n1, n2, n3], is in units of the basis 

868 vectors of the small cell. 

869 """ 

870 for kpt_c, U_ww in zip(self.kpt_kc, self.U_kww): 

871 U_ww[:, w] *= np.exp(2.0j * pi * (np.array(R) @ kpt_c)) 

872 self.update() 

873 

874 def translate_to_cell(self, w, cell): 

875 """Translate the w'th Wannier function to specified cell""" 

876 scaled_c = np.angle(self.Z_dww[:3, w, w]) * self.kptgrid / (2 * pi) 

877 trans = np.array(cell) - np.floor(scaled_c) 

878 self.translate(w, trans) 

879 

880 def translate_all_to_cell(self, cell=(0, 0, 0)): 

881 r"""Translate all Wannier functions to specified cell. 

882 

883 Move all Wannier orbitals to a specific unit cell. There 

884 exists an arbitrariness in the positions of the Wannier 

885 orbitals relative to the unit cell. This method can move all 

886 orbitals to the unit cell specified by ``cell``. For a 

887 `\Gamma`-point calculation, this has no effect. For a 

888 **k**-point calculation the periodicity of the orbitals are 

889 given by the large unit cell defined by repeating the original 

890 unitcell by the number of **k**-points in each direction. In 

891 this case it is useful to move the orbitals away from the 

892 boundaries of the large cell before plotting them. For a bulk 

893 calculation with, say 10x10x10 **k** points, one could move 

894 the orbitals to the cell [2,2,2]. In this way the pbc 

895 boundary conditions will not be noticed. 

896 """ 

897 scaled_wc = ( 

898 np.angle(self.Z_dww[:3].diagonal(0, 1, 2)).T 

899 * self.kptgrid 

900 / (2 * pi) 

901 ) 

902 trans_wc = np.array(cell)[None] - np.floor(scaled_wc) 

903 for kpt_c, U_ww in zip(self.kpt_kc, self.U_kww): 

904 U_ww *= np.exp(2.0j * pi * (trans_wc @ kpt_c)) 

905 self.update() 

906 

907 def distances(self, R): 

908 """Relative distances between centers. 

909 

910 Returns a matrix with the distances between different Wannier centers. 

911 R = [n1, n2, n3] is in units of the basis vectors of the small cell 

912 and allows one to measure the distance with centers moved to a 

913 different small cell. 

914 The dimension of the matrix is [Nw, Nw]. 

915 """ 

916 Nw = self.nwannier 

917 cen = self.get_centers() 

918 r1 = cen.repeat(Nw, axis=0).reshape(Nw, Nw, 3) 

919 r2 = cen.copy() 

920 for i in range(3): 

921 r2 += self.unitcell_cc[i] * R[i] 

922 

923 r2 = np.swapaxes(r2.repeat(Nw, axis=0).reshape(Nw, Nw, 3), 0, 1) 

924 return np.sqrt(np.sum((r1 - r2) ** 2, axis=-1)) 

925 

926 @functools.lru_cache(maxsize=10000) 

927 def _get_hopping(self, n1, n2, n3): 

928 """Returns the matrix H(R)_nm=<0,n|H|R,m>. 

929 

930 :: 

931 

932 1 _ -ik.R 

933 H(R) = <0,n|H|R,m> = --- >_ e H(k) 

934 Nk k 

935 

936 where R = (n1, n2, n3) is the cell-distance (in units of the basis 

937 vectors of the small cell) and n,m are indices of the Wannier functions. 

938 This function caches up to 'maxsize' results. 

939 """ 

940 R = np.array([n1, n2, n3], float) 

941 H_ww = np.zeros([self.nwannier, self.nwannier], complex) 

942 for k, kpt_c in enumerate(self.kpt_kc): 

943 phase = np.exp(-2.0j * pi * (np.array(R) @ kpt_c)) 

944 H_ww += self.get_hamiltonian(k) * phase 

945 return H_ww / self.Nk 

946 

def get_hopping(self, R):
    """Return the hopping matrix H(R)_nm = <0,n|H|R,m>.

    ::

                             1   _   -ik.R
        H(R) = <0,n|H|R,m> = ---  >_  e      H(k)
                             Nk  k

    ``R`` is a length-3 cell distance in units of the basis vectors of the
    small cell; n and m are indices of the Wannier functions.

    The matrix is computed and memoized by ``_get_hopping``.  A copy is
    returned so that a caller modifying the result in place cannot alter
    the cached value handed out to later calls.
    """
    return self._get_hopping(R[0], R[1], R[2]).copy()

960 

def get_hamiltonian(self, k):
    """Return the Hamiltonian in the Wannier basis at k-point index ``k``.

    ::

                dag
        H(k) = V    diag(eps ) V
                k           k   k

    Here ``eps`` are the first ``nbands`` eigenvalues of spin channel
    ``self.spin`` at k-point ``k``, and ``V = self.V_knw[k]``.
    """
    V_nw = self.V_knw[k]
    eps_n = self.calcdata.eps_skn[self.spin, k, : self.nbands]
    # Broadcasting eps_n over the last axis scales column n of dag(V) by
    # eps_n, i.e. dag(V) @ diag(eps) without forming the diagonal matrix.
    scaled_wn = dag(V_nw) * eps_n
    return scaled_wn @ V_nw

973 

def get_hamiltonian_kpoint(self, kpt_c):
    """Get Hamiltonian at some new arbitrary k-vector.

    ::

                _   ik.R
        H(k) = >_  e     H(R)
                R

    The sum runs over the box of cell vectors
    ``-(Nc - 1) // 2 <= R_c <= (Nc - 1) // 2`` for each axis c, where ``Nc``
    is the k-point grid size along that axis.
    NOTE(review): for even grid sizes this box omits the R = +/- Nc/2
    boundary cells -- confirm this truncation is intended.

    Warning: This method moves all Wannier functions to cell (0, 0, 0)
    """
    from itertools import product

    self.log('Translating all Wannier functions to cell (0, 0, 0)')
    self.translate_all_to_cell()

    # Half-widths of the R box (named to avoid shadowing builtin ``max``).
    N1, N2, N3 = (self.kptgrid - 1) // 2
    Hk = np.zeros([self.nwannier, self.nwannier], complex)
    # product() iterates n1-major, then n2, then n3, like nested loops.
    for n1, n2, n3 in product(
        range(-N1, N1 + 1), range(-N2, N2 + 1), range(-N3, N3 + 1)
    ):
        R = np.array([n1, n2, n3], float)
        hop_ww = self.get_hopping(R)
        phase = np.exp(+2.0j * pi * (R @ kpt_c))
        Hk += hop_ww * phase
    return Hk

997 return Hk 

998 

def get_function(self, index, repeat=None):
    r"""Get Wannier function on grid.

    Returns an array with the function values of the indicated Wannier
    function on a grid with the size of the *repeated* unit cell.

    For a calculation using **k**-points the relevant unit cell for
    e.g. visualization of the Wannier orbitals is not the original unit
    cell, but rather a larger unit cell defined by repeating the
    original unit cell by the number of **k**-points in each direction.
    Note that for a `\Gamma`-point calculation the large unit cell
    coincides with the original unit cell.
    The large unitcell also defines the periodicity of the Wannier
    orbitals.

    ``index`` can be either a single WF (any integer, including NumPy
    integer scalars) or a coordinate vector in terms of the WFs.

    ``repeat`` is the number of small cells along each axis; it defaults to
    the k-point grid ``self.kptgrid``.
    """
    # Default size of plotting cell is the one corresponding to k-points.
    if repeat is None:
        repeat = self.kptgrid
    N1, N2, N3 = repeat

    # Coerce to an array so the element-wise product below works even if
    # the calculator returns a tuple/list (list * list would fail/repeat).
    dim = np.asarray(self.calc.get_number_of_grid_points())
    largedim = dim * [N1, N2, N3]

    # Accept Python and NumPy integers alike as a single-WF index.
    single_wf = isinstance(index, (int, np.integer))

    wanniergrid = np.zeros(largedim, dtype=complex)
    for k, kpt_c in enumerate(self.kpt_kc):
        # The coordinate vector of the Wannier function in the band basis
        if single_wf:
            vec_n = self.V_knw[k, :, index]
        else:
            vec_n = self.V_knw[k] @ index

        wan_G = np.zeros(dim, complex)
        for n, coeff in enumerate(vec_n):
            wan_G += coeff * self.calc.get_pseudo_wave_function(
                n, k, self.spin, pad=True
            )

        # Distribute the small-cell function over the large cell, one
        # small cell at a time, with the phase factor of that cell.
        # NOTE(review): the sign of the phase was questioned in the
        # original ("sign?") -- left unchanged.
        for n1 in range(N1):
            for n2 in range(N2):
                for n3 in range(N3):
                    e = np.exp(-2.0j * pi * np.array([n1, n2, n3]) @ kpt_c)
                    wanniergrid[
                        n1 * dim[0] : (n1 + 1) * dim[0],
                        n2 * dim[1] : (n2 + 1) * dim[1],
                        n3 * dim[2] : (n3 + 1) * dim[2],
                    ] += e * wan_G

    # Normalization
    wanniergrid /= np.sqrt(self.Nk)
    return wanniergrid

1054 

def write_cube(self, index, fname, repeat=None, angle=False):
    """
    Dump specified Wannier function to a cube file.

    Arguments:

    ``index``: Index of the Wannier function to save (forwarded to
    ``get_function``).

    ``repeat``: Array of integer, repeat supercell and Wannier function.
    Defaults to the k-point grid ``self.kptgrid``.

    ``fname``: Name of the cube file.

    ``angle``: If False, save the absolute value. If True, save
    the complex phase of the Wannier function.
    """
    from ase.io import write

    # Default size of plotting cell is the one corresponding to k-points.
    if repeat is None:
        repeat = self.kptgrid

    # Remove constraints, some are not compatible with repeat()
    atoms = self.atoms.copy()
    atoms.set_constraint()
    atoms = atoms * repeat
    func = self.get_function(index, repeat)

    # Compute absolute value or complex angle
    if angle:
        data = np.angle(func)
    else:
        # A global phase factor does not change the modulus
        # (|f * exp(-i*theta)| == |f|), so no phase alignment is needed
        # before taking the absolute value.
        data = np.abs(func)

    write(fname, atoms, data=data, format='cube')

1092 

def localize(
    self, step=0.25, tolerance=1e-08, updaterot=True, updatecoeff=True
):
    """Optimize rotation to give maximal localization.

    Thin wrapper that runs ``md_min`` on this object, logging through
    ``self.log``.  ``step`` and ``tolerance`` are passed straight through;
    ``updaterot`` / ``updatecoeff`` select whether the rotation and/or the
    extra-degrees-of-freedom coefficients are updated.
    """
    settings = {
        'step': step,
        'tolerance': tolerance,
        'log': self.log,
        'updaterot': updaterot,
        'updatecoeff': updatecoeff,
    }
    md_min(self, **settings)

1105 

def get_functional_value(self):
    """Calculate the value of the spread functional.

    ::

        Tr[|ZI|^2]=sum(I)sum(n) w_i|Z_(i)_nn|^2,

    where w_i are weights.

    If the functional is set to 'var' it subtracts a variance term

    ::

        Nw * var(sum(n) w_i|Z_(i)_nn|^2),

    where Nw is the number of WFs ``nwannier``.

    Raises
    ------
    ValueError
        If ``self.functional`` is neither 'std' nor 'var'.
    """
    a_w = self._spread_contributions()
    total = np.sum(a_w)

    if self.functional == 'std':
        return total

    if self.functional == 'var':
        variance = self.nwannier * np.var(a_w)
        self.log(
            f'std: {total:.4f}',
            f'\tvar: {variance:.4f}',
        )
        return total - variance

    # Previously this fell through to an UnboundLocalError on the result.
    raise ValueError(
        f"Unknown functional {self.functional!r}; expected 'std' or 'var'"
    )

1133 

def get_gradients(self):
    """Determine the gradient of the spread functional.

    The gradient for a rotation A_kij is::

        dU = dRho/dA_{k,i,j} = sum(I) sum(k')
                + Z_jj Z_kk',ij^* - Z_ii Z_k'k,ij^*
                - Z_ii^* Z_kk',ji + Z_jj^* Z_k'k,ji

    The gradient for a change of coefficients is::

        dRho/da^*_{k,i,j} = sum(I) [[(Z_0)_{k} V_{k'} diag(Z^*) +
                            (Z_0_{k''})^d V_{k''} diag(Z)] *
                            U_k^d]_{N+i,N+j}

    where diag(Z) is a square, diagonal matrix with Z_nn in the diagonal,
    k' = k + dk and k = k'' + dk.

    The extra degrees of freedom should be kept orthonormal to the fixed
    space, thus we introduce Lagrange multipliers and minimize instead::

        Rho_L = Rho - sum_{k,n,m} lambda_{k,nm} <c_{kn}|c_{km}>

    For this reason the coefficient gradients are multiplied by
    (1 - c c^dag).

    For ``functional == 'var'`` the functional is ``sum(a) - Nw * var(a)``
    with ``a = self._spread_contributions()`` and ``S = sum(a)``.  Its
    derivative is ``d/da_i = 1 + 2*S/Nw - 2*a_i``.  The same chain-rule
    factors are applied to both the rotation and the coefficient gradients.

    Returns
    -------
    ndarray
        1D array: the flattened (Nw, Nw) rotation gradient of every
        k-point, followed by the flattened (Nb - M, L) coefficient gradient
        of each k-point that has extra degrees of freedom (L > 0).
    """
    Nb = self.nbands
    Nw = self.nwannier
    use_var = self.functional == 'var'

    if use_var:
        O_w = self._spread_contributions()
        O_sum = np.sum(O_w)
        # Prefactor of d(sum a) coming from the -Nw*var(a) term.
        var_scale = 2 * O_sum / Nw

    def rotation_term(diag_w, Zk_ww, Zk2_ww):
        # temp - dag(temp), with temp = D^T * conj(Z_k) - D * conj(Z_k2)
        # and D_ij = diag_w[i].
        D_ww = np.repeat(diag_w, Nw).reshape(Nw, Nw)
        temp = D_ww.T * Zk_ww.conj() - D_ww * Zk2_ww.conj()
        return temp - dag(temp)

    def coefficient_term(diag_w, fwd_nw, bwd_nw, Udag_ww):
        # (fwd diag(conj(diag_w)) + bwd diag(diag_w)) @ dag(U)
        return (fwd_nw * diag_w.conj() + bwd_nw * diag_w) @ Udag_ww

    V_knw = self.V_knw
    dU = []
    dC = []
    for k in range(self.Nk):
        M = self.fixedstates_k[k]
        L = self.edf_k[k]
        U_ww = self.U_kww[k]
        C_ul = self.C_kul[k]
        Utemp_ww = np.zeros((Nw, Nw), complex)
        Ctemp_nw = np.zeros((Nb, Nw), complex)

        for d, weight in enumerate(self.weight_d):
            if abs(weight) < 1.0e-6:
                continue

            Z_knn = self.Z_dknn[d]
            Z_kww = self.Z_dkww[d]
            diagZ_w = self.Z_dww[d].diagonal()
            k1 = self.kklst_dk[d, k]
            k2 = self.invkklst_dk[d, k]
            if use_var:
                diagOZ_w = O_w * diagZ_w

            # Rotation gradient
            rotZ_ww = rotation_term(diagZ_w, Z_kww[k], Z_kww[k2])
            Utemp_ww += weight * rotZ_ww
            if use_var:
                rotOZ_ww = rotation_term(diagOZ_w, Z_kww[k], Z_kww[k2])
                Utemp_ww += weight * (var_scale * rotZ_ww - 2 * rotOZ_ww)

            # Coefficient gradient (only if there are extra degrees of
            # freedom at this k-point)
            if L > 0:
                fwd_nw = Z_knn[k] @ V_knw[k1]
                bwd_nw = dag(Z_knn[k2]) @ V_knw[k2]
                Udag_ww = dag(U_ww)
                coefZ_nw = coefficient_term(diagZ_w, fwd_nw, bwd_nw, Udag_ww)
                Ctemp_nw += weight * coefZ_nw
                if use_var:
                    # Same prefactors (var_scale and 2) as the rotation
                    # gradient; both follow from d/da_i = 1 + 2S/Nw - 2a_i.
                    coefOZ_nw = coefficient_term(
                        diagOZ_w, fwd_nw, bwd_nw, Udag_ww
                    )
                    Ctemp_nw += weight * (
                        var_scale * coefZ_nw - 2 * coefOZ_nw
                    )

        dU.append(Utemp_ww.ravel())

        if L > 0:
            # Ctemp has the same shape as V (Nb x Nw); the gradient is its
            # lower-right (Nb-M) x L block, projected orthogonal to C.
            Ctemp_ul = Ctemp_nw[M:, M:]
            G_ul = Ctemp_ul - ((C_ul @ dag(C_ul)) @ Ctemp_ul)
            dC.append(G_ul.ravel())

    return np.concatenate(dU + dC)

1258 

def _spread_contributions(self):
    """
    Return the contribution of each WF to the spread functional.

    The per-direction values from ``square_modulus_of_Z_diagonal``
    (presumably |Z_d,ww|^2) are contracted with the direction weights
    ``self.weight_d``.  The real part is returned, one entry per Wannier
    function.
    """
    sqmod = square_modulus_of_Z_diagonal(self.Z_dww)
    weighted_w = sqmod.T @ self.weight_d
    return weighted_w.real

1264 

def step(self, dX, updaterot=True, updatecoeff=True):
    """Apply one optimization step ``dX`` and refresh derived quantities.

    ``dX`` is (A, dC), where U -> U exp(-A) and C -> C + dC.  The first
    ``Nk * Nw**2`` entries hold the per-k rotation generators A (reshaped to
    (Nk, Nw, Nw)).  They are followed by the coefficient increments of
    every k-point with L > 0 extra degrees of freedom and at least one
    unoccupied band.

    ``updaterot`` / ``updatecoeff`` select which part of ``dX`` is applied.
    Updated coefficient blocks are re-orthonormalized with
    ``gram_schmidt``, and ``self.update()`` is always called at the end.
    """
    Nw = self.nwannier
    Nk = self.Nk
    M_k, L_k = self.fixedstates_k, self.edf_k
    nrot = Nk * Nw**2  # length of the rotation part of dX

    if updaterot:
        A_kww = dX[:nrot].reshape(Nk, Nw, Nw)
        for U_ww, A_ww in zip(self.U_kww, A_kww):
            H_ww = -1.0j * A_ww.conj()
            eps_w, vecs_ww = np.linalg.eigh(H_ww)
            # vecs_ww holds the eigenvectors as COLUMNS, so
            # exp(iH) = vecs diag(exp(i eps)) vecs^dag; this is unitary when
            # H is Hermitian (i.e. presumably when A is anti-Hermitian).
            dU_ww = (vecs_ww * np.exp(1.0j * eps_w)) @ dag(vecs_ww)
            rotated_ww = U_ww @ dU_ww
            # Keep real rotation matrices real.
            U_ww[:] = rotated_ww.real if U_ww.dtype == float else rotated_ww

    if updatecoeff:
        offset = nrot
        for C_ul, unocc, L in zip(self.C_kul, self.nbands - M_k, L_k):
            if L == 0 or unocc == 0:
                continue
            ncoeff = L * unocc
            C_ul += dX[offset : offset + ncoeff].reshape(unocc, L)
            gram_schmidt(C_ul)
            offset += ncoeff

    self.update()