Coverage for /builds/ase/ase/ase/dft/wannierstate.py: 63.79%

58 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import numpy as np 

4from scipy.linalg import qr 

5 

6 

def random_orthogonal_matrix(dim, rng=np.random, real=False):
    """Draw a random ``dim x dim`` orthogonal matrix.

    Parameters
    ----------
    dim : int
        Size of the square matrix.
    rng : random state, optional
        Source of randomness.  With ``real=False`` it must provide
        ``random(shape)``; with ``real=True`` it is forwarded to SciPy as
        ``random_state``.
    real : bool, optional
        If True, sample a rotation (determinant +1) with
        ``scipy.stats.special_ortho_group``.  Otherwise orthogonalise a
        matrix of ``rng.random`` values via a QR decomposition.

    Returns
    -------
    numpy.ndarray
        A ``(dim, dim)`` orthogonal matrix.

    Notes
    -----
    NOTE(review): the ``real=False`` branch starts from real
    ``rng.random`` values, so the result is a real orthogonal matrix
    rather than a complex unitary one.  ``scipy.stats.unitary_group``
    would give Haar-uniform unitaries but is unavailable on old SciPy.
    """
    if not real:
        # QR-based construction, see
        # https://stackoverflow.com/questions/38426349
        q_mat, r_mat = qr(rng.random((dim, dim)))
        # Flip column signs by the sign of R's diagonal so the
        # factorisation (and hence the result) is uniquely defined.
        signs = np.sign(np.diag(r_mat))
        return q_mat @ np.diag(signs)

    from scipy.stats import special_ortho_group
    return special_ortho_group.rvs(dim=dim, random_state=rng)

23 

24 

25def _empty(): 

26 return np.empty(0, complex) 

27 

28 

class WannierSpec:
    """Problem dimensions plus factories that build :class:`WannierState`.

    Parameters
    ----------
    Nk : int
        Number of k-points.
    Nw : int
        Size of each square ``U`` block (presumably the number of
        Wannier functions).
    Nb : int
        Number of bands read from the calculator.
    fixedstates_k : sequence of int
        Per-k-point count ``M``; the ``C`` matrix at that k-point has
        ``Nb - M`` rows.

    Every factory returns a state whose ``U_kww`` is an ``(Nk, Nw, Nw)``
    complex array and whose ``C_kul`` is a ragged per-k-point list.
    """

    def __init__(self, Nk, Nw, Nb, fixedstates_k):
        self.Nk = Nk
        self.Nw = Nw
        self.Nb = Nb
        self.fixedstates_k = fixedstates_k

    def _zeros(self):
        """Return a zero-filled ``(Nk, Nw, Nw)`` complex array."""
        return np.zeros((self.Nk, self.Nw, self.Nw), complex)

    def _build_state(self, edf_k, make_U, make_C):
        """Assemble a WannierState from per-k-point matrix factories.

        ``make_U()`` supplies each ``Nw x Nw`` block of ``U_kww``.  For a
        k-point with ``L > 0``, ``make_C(Nb - M)`` supplies a square matrix
        whose first ``L`` columns become ``C``; otherwise ``C`` is an empty
        placeholder.  Factories are called in k-point order, U before C,
        so a random stream is consumed in a fixed order.
        """
        U_kww = self._zeros()
        C_kul = []
        for U, M, L in zip(U_kww, self.fixedstates_k, edf_k):
            U[:] = make_U()
            if L > 0:
                C_kul.append(make_C(self.Nb - M)[:, :L])
            else:
                C_kul.append(_empty())
        return WannierState(C_kul, U_kww)

    def bloch(self, edf_k):
        """Return the trivial state: identity ``U`` and identity-column ``C``.

        ``edf_k`` gives the number of columns ``L`` of ``C`` per k-point.
        """
        return self._build_state(
            edf_k,
            make_U=lambda: np.identity(self.Nw, complex),
            make_C=lambda n: np.identity(n, complex))

    def random(self, rng, edf_k):
        """Return a state with random orthogonal ``U`` and ``C`` matrices.

        ``rng`` is passed to :func:`random_orthogonal_matrix`; ``edf_k``
        gives the number of columns ``L`` of ``C`` per k-point.
        """
        return self._build_state(
            edf_k,
            make_U=lambda: random_orthogonal_matrix(self.Nw, rng,
                                                    real=False),
            make_C=lambda n: random_orthogonal_matrix(n, rng=rng,
                                                      real=False))

    def initial_orbitals(self, calc, orbitals, kptgrid, edf_k, spin):
        """Build a state from ``calc.initial_wannier`` using ``orbitals``.

        Identical to :meth:`initial_wannier` with ``orbitals`` as method.
        """
        return self.initial_wannier(calc, orbitals, kptgrid, edf_k, spin)

    def initial_wannier(self, calc, method, kptgrid, edf_k, spin):
        """Build a state from the ``(C_kul, U_kww)`` returned by the calculator."""
        C_kul, U_kww = calc.initial_wannier(
            method, kptgrid, self.fixedstates_k,
            edf_k, spin, self.Nb)
        return WannierState(C_kul, U_kww)

    def scdm(self, calc, kpt_kc, spin):
        """Build a state with the SCDM method from pseudo wave functions.

        Pseudo wave functions for bands ``0..Nb-1`` at every k-point of
        ``spin`` are flattened into an ``(Nb, Nk, Ng)`` complex array and
        passed to :func:`ase.dft.wannier.scdm` together with ``kpt_kc``.
        """
        from ase.dft.wannier import scdm

        # Probe one wave function to learn the flattened grid size Ng.
        # Band index Nw - 1 is the last of the first Nw bands, so this also
        # checks that Nw bands exist; index Nw (used previously) would
        # demand Nw + 1 bands and reject a calculator with exactly Nw.
        probe = calc.get_pseudo_wave_function(band=max(self.Nw - 1, 0),
                                              kpt=0, spin=0)
        Ng = probe.size
        pseudo_nkG = np.zeros((self.Nb, self.Nk, Ng),
                              dtype=np.complex128)
        for k in range(self.Nk):
            for n in range(self.Nb):
                pseudo_nkG[n, k] = calc.get_pseudo_wave_function(
                    band=n, kpt=k, spin=spin).ravel()

        # Use initial guess to determine U and C
        C_kul, U_kww = scdm(pseudo_nkG,
                            kpts=kpt_kc,
                            fixed_k=self.fixedstates_k,
                            Nw=self.Nw)
        return WannierState(C_kul, U_kww)

95 

96 

class WannierState:
    """Holds the ``C`` and ``U`` matrices of a Wannier calculation.

    Parameters
    ----------
    C_kul : iterable of numpy.ndarray
        One array per k-point.  Shapes may differ between k-points, so the
        arrays are kept in a plain list, each copied to complex dtype.
    U_kww : array-like
        Stored as given, without copying or conversion.
    """

    def __init__(self, C_kul, U_kww):
        converted = []
        for C_ul in C_kul:
            # astype returns a new array, so the caller's arrays are not
            # aliased by the stored list.
            converted.append(C_ul.astype(complex))
        self.C_kul = converted
        self.U_kww = U_kww