Coverage for /builds/ase/ase/ase/dft/wannierstate.py: 63.79%
58 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import numpy as np
4from scipy.linalg import qr
def random_orthogonal_matrix(dim, rng=np.random, real=False):
    """Generate a random orthogonal ``dim`` x ``dim`` matrix.

    Parameters
    ----------
    dim : int
        Number of rows and columns of the generated matrix.
    rng : random generator, optional
        Object providing ``random(shape)``.  When ``real`` is true it is
        instead handed to scipy as ``random_state``.  Defaults to the
        ``np.random`` module.
    real : bool, optional
        If true, the matrix is drawn with
        ``scipy.stats.special_ortho_group``.  Otherwise it is built from
        the QR decomposition of a matrix filled by ``rng.random``.

    Returns
    -------
    numpy.ndarray
        The generated matrix.
    """
    if real:
        from scipy.stats import special_ortho_group
        return special_ortho_group.rvs(dim=dim, random_state=rng)

    # scipy.stats.unitary_group.rvs(dim=dim, random_state=rng) would be the
    # best method, but it is not supported on old systems.  The alternative
    # used here comes from https://stackoverflow.com/questions/38426349
    samples = rng.random((dim, dim))
    q_factor, r_factor = qr(samples)
    # Scale the columns of Q by the signs of the diagonal entries of R.
    diag_signs = np.sign(np.diag(r_factor))
    return q_factor @ np.diag(diag_signs)
def _empty():
    """Return a new one-dimensional complex array of length zero."""
    return np.empty(shape=0, dtype=complex)
class WannierSpec:
    """Problem dimensions plus factory methods returning ``WannierState``.

    The constructor stores its four arguments unchanged as attributes.
    Going by how they are used below: ``Nk`` is the length of the leading
    axis of ``U_kww``, ``Nw`` the size of each square ``U`` block, ``Nb``
    a band count, and ``fixedstates_k`` holds one entry per k-point.
    NOTE(review): physical meaning inferred from names -- confirm.
    """

    def __init__(self, Nk, Nw, Nb, fixedstates_k):
        self.Nk, self.Nw, self.Nb = Nk, Nw, Nb
        self.fixedstates_k = fixedstates_k

    def _zeros(self):
        """Return a zero-filled complex array of shape ``(Nk, Nw, Nw)``."""
        shape = (self.Nk, self.Nw, self.Nw)
        return np.zeros(shape, dtype=complex)

    def bloch(self, edf_k):
        """Build a state from identity matrices.

        Every ``U`` block becomes the ``Nw`` x ``Nw`` identity.  For each
        k-point, ``C`` is the first ``L`` columns of an identity matrix of
        size ``Nb - M`` (``M`` from ``fixedstates_k``, ``L`` from
        ``edf_k``), or an empty array when ``L`` is not positive.
        """
        U_kww = self._zeros()
        C_kul = []
        unit_ww = np.identity(self.Nw, complex)
        for U_ww, n_fixed, n_extra in zip(U_kww, self.fixedstates_k, edf_k):
            U_ww[:] = unit_ww
            if n_extra > 0:
                C_ul = np.identity(self.Nb - n_fixed, complex)[:, :n_extra]
            else:
                C_ul = _empty()
            C_kul.append(C_ul)
        return WannierState(C_kul, U_kww)

    def random(self, rng, edf_k):
        """Build a state from random (orthogonal) matrices drawn with ``rng``.

        Same layout as ``bloch``, with ``random_orthogonal_matrix`` used in
        place of the identity matrices.  For each k-point the ``U`` matrix
        is drawn before the ``C`` matrix.
        """
        U_kww = self._zeros()
        C_kul = []
        for U_ww, n_fixed, n_extra in zip(U_kww, self.fixedstates_k, edf_k):
            U_ww[:] = random_orthogonal_matrix(self.Nw, rng, real=False)
            if n_extra > 0:
                full_uu = random_orthogonal_matrix(
                    self.Nb - n_fixed, rng=rng, real=False)
                C_kul.append(full_uu[:, :n_extra])
            else:
                C_kul.append(_empty())
        return WannierState(C_kul, U_kww)

    def initial_orbitals(self, calc, orbitals, kptgrid, edf_k, spin):
        """Wrap the result of ``calc.initial_wannier`` for ``orbitals``."""
        coeffs, rotations = calc.initial_wannier(
            orbitals, kptgrid, self.fixedstates_k, edf_k, spin, self.Nb)
        return WannierState(coeffs, rotations)

    def initial_wannier(self, calc, method, kptgrid, edf_k, spin):
        """Wrap the result of ``calc.initial_wannier`` for ``method``."""
        coeffs, rotations = calc.initial_wannier(
            method, kptgrid, self.fixedstates_k, edf_k, spin, self.Nb)
        return WannierState(coeffs, rotations)

    def scdm(self, calc, kpt_kc, spin):
        """Build a state from ``ase.dft.wannier.scdm``.

        Pseudo wave functions for every band and k-point are read from
        ``calc``, flattened, and passed to ``scdm`` together with
        ``kpt_kc``, ``fixedstates_k`` and ``Nw``.
        """
        from ase.dft.wannier import scdm

        # Get the size of the grid and check if there are Nw bands.  This
        # probe always uses k-point 0 and spin 0, not the ``spin`` argument.
        probe = calc.get_pseudo_wave_function(band=self.Nw, kpt=0, spin=0)
        n_grid = probe.size

        pseudo_nkG = np.zeros((self.Nb, self.Nk, n_grid),
                              dtype=np.complex128)
        for k in range(self.Nk):
            for band in range(self.Nb):
                wave = calc.get_pseudo_wave_function(
                    band=band, kpt=k, spin=spin)
                pseudo_nkG[band, k] = wave.ravel()

        # Use initial guess to determine U and C
        coeffs, rotations = scdm(pseudo_nkG,
                                 kpts=kpt_kc,
                                 fixed_k=self.fixedstates_k,
                                 Nw=self.Nw)
        return WannierState(coeffs, rotations)
class WannierState:
    """Container for the ``C_kul`` and ``U_kww`` arrays of a state."""

    def __init__(self, C_kul, U_kww):
        # The number of u is not always the same, so C_kul is ragged: it is
        # kept as a list with one complex array per entry.
        converted = []
        for C_ul in C_kul:
            converted.append(C_ul.astype(complex))
        self.C_kul = converted
        # U_kww is stored as given, without copying or converting.
        self.U_kww = U_kww