Coverage for /builds/ase/ase/ase/dft/bandgap.py: 87.69%

130 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import warnings 

4from dataclasses import dataclass 

5 

6import numpy as np 

7 

# Error message raised when a caller passes the removed ``spin`` keyword.
spin_error = ('The spin keyword is no longer supported. '
              'Please call the function '
              'with the energies corresponding to the desired spins.')

# Sentinel default for ``spin``: any explicitly passed value (even None)
# differs from it by identity, so "argument given" can be detected.
_deprecated = object()

12 

13 

def get_band_gap(calc, direct=False, spin=_deprecated):
    """Deprecated wrapper around ``bandgap()``; kept for old callers.

    Parameters:

    calc: Calculator object
        Passed on to ``bandgap()``.
    direct: bool
        Calculate the direct band-gap.
    spin:
        Removed; passing any value makes ``bandgap()`` raise RuntimeError.

    Returns (gap, p1, p2). If ``calc.get_number_of_spins() == 2``, p1 and p2
    are (s, k) tuples. Otherwise they are bare k-point indices. Band indices
    are dropped in both cases.
    """
    # stacklevel=2 attributes the warning to the caller's line, which is
    # the code that needs to migrate, rather than to this module.
    warnings.warn('Please use ase.dft.bandgap.bandgap() instead!',
                  stacklevel=2)
    gap, (s1, k1, _n1), (s2, k2, _n2) = bandgap(calc, direct, spin=spin)
    if calc.get_number_of_spins() == 2:
        return gap, (s1, k1), (s2, k2)
    return gap, k1, k2

21 

22 

@dataclass
class GapInfo:
    """Direct and indirect band gaps of a set of eigenvalues.

    ``eigenvalues`` must already be shifted so that the Fermi level is at
    0.0; values below 0.0 count as occupied. Accepted shapes are
    (nspin, nkpt, nband) and (nkpt, nband). The latter is treated as a
    single spin channel. Both gaps are computed once, on construction.
    """

    eigenvalues: np.ndarray

    def __post_init__(self):
        e_skn = np.asarray(self.eigenvalues)
        if e_skn.ndim == 2:
            # Promote (nkpt, nband) to (1, nkpt, nband): _bandgap() unpacks
            # exactly three axes. The stored field itself is left untouched.
            e_skn = e_skn[np.newaxis]
        self._gapinfo = _bandgap(e_skn, direct=False)
        self._direct_gapinfo = _bandgap(e_skn, direct=True)

    @classmethod
    def fromcalc(cls, calc):
        """Build a GapInfo from a calculator's IBZ eigenvalues.

        Eigenvalues for every spin and IBZ k-point are stacked into a
        (nspin, nkpt, nband) array (assumes every k-point reports the same
        number of bands) and shifted by the calculator's Fermi level.
        """
        kpts = calc.get_ibz_k_points()
        nk = len(kpts)
        ns = calc.get_number_of_spins()
        eigenvalues = np.array([[calc.get_eigenvalues(kpt=k, spin=s)
                                 for k in range(nk)]
                                for s in range(ns)])

        efermi = calc.get_fermi_level()
        return cls(eigenvalues - efermi)

    def gap(self):
        """Return ``(gap, (s1, k1, n1), (s2, k2, n2))`` for the smallest gap.

        The transition may be indirect. For a metal the gap is 0.0 and all
        indices are None.
        """
        return self._gapinfo

    def direct_gap(self):
        """Return ``(gap, (s1, k1, n1), (s2, k2, n2))`` for the direct gap.

        Only same-k transitions are considered. In spin-polarized data the
        two spins may differ.
        """
        return self._direct_gapinfo

    @property
    def is_metallic(self) -> bool:
        """True when no gap was found (gap value exactly 0.0)."""
        # bool() so callers get a real Python bool, not numpy.bool_.
        return bool(self._gapinfo[0] == 0.0)

    @property
    def gap_is_direct(self) -> bool:
        """Whether the direct and indirect gaps are the same transition."""
        return self._gapinfo[1:] == self._direct_gapinfo[1:]

    def description(self, *, ibz_kpoints=None) -> str:
        """Return human-friendly description of direct/indirect gap.

        If ibz_kpoints are given, the coordinates of the k-point of each
        printed transition are included as well."""
        lines = []
        add = lines.append

        def fmt_skn(indices):
            """Convert k-point indices (s, k, n) to string."""
            description = 's={}, k={}, n={}'.format(*indices)
            if ibz_kpoints is not None:
                coordtxt = '[{:.2f}, {:.2f}, {:.2f}]'.format(
                    *ibz_kpoints[indices[1]])
                description = f'{description}, {coordtxt}'
            return f'({description})'

        gap, skn1, skn2 = self.gap()
        direct_gap, skn_direct1, skn_direct2 = self.direct_gap()

        if self.is_metallic:
            add('No gap')
        else:
            add(f'Gap: {gap:.3f} eV')
            add('Transition (v -> c):')
            add(f' {fmt_skn(skn1)} -> {fmt_skn(skn2)}')

        if self.gap_is_direct:
            add('No difference between direct/indirect transitions')
        else:
            add('Direct/indirect transitions are different')
            add(f'Direct gap: {direct_gap:.3f} eV')
            if skn_direct1[0] == skn_direct2[0]:
                add(f'Transition at: {fmt_skn(skn_direct1)}')
            else:
                # Spin-flip direct transition: show "s_v->s_c" in the spin
                # slot, followed by the valence (k, n).
                transition = fmt_skn(
                    (f'{skn_direct1[0]}->{skn_direct2[0]}',
                     *skn_direct1[1:]))
                add(f'Transition at: {transition}')

        return '\n'.join(lines)

99 

100 

def bandgap(calc=None, direct=False, spin=_deprecated,
            eigenvalues=None, efermi=None, output=None, kpts=None):
    """Calculates the band-gap.

    Parameters:

    calc: Calculator object
        Electronic structure calculator object. When given, eigenvalues
        (and, unless ``efermi`` is given, the Fermi level) are read from it
        and any ``eigenvalues`` argument is ignored.
    direct: bool
        Calculate direct band-gap.
    eigenvalues: ndarray of shape (nspin, nkpt, nband) or (nkpt, nband)
        Eigenvalues. A 2D array is treated as a single spin channel.
    efermi: float
        Fermi level (defaults to 0.0).
    output, kpts:
        Accepted for backward compatibility; not used for the result.

    Returns a (gap, p1, p2) tuple where p1 and p2 are tuples of indices of the
    valence and conduction points (s, k, n). For 2D eigenvalues the spin
    index is dropped and p1, p2 are (k, n). For a metal the gap is 0.0 and
    the indices are None.

    Raises RuntimeError if the removed ``spin`` keyword is passed, TypeError
    if neither ``calc`` nor ``eigenvalues`` is given, and ValueError for
    eigenvalues of the wrong rank, non-finite eigenvalues or too few bands.

    Example:

    >>> gap, p1, p2 = bandgap(silicon.calc)
    >>> print(gap, p1, p2)
    1.2 (0, 0, 3) (0, 5, 4)
    >>> gap, p1, p2 = bandgap(silicon.calc, direct=True)
    >>> print(gap, p1, p2)
    3.4 (0, 0, 3) (0, 0, 4)
    """

    if spin is not _deprecated:
        raise RuntimeError(spin_error)

    if calc is not None:
        kpts = calc.get_ibz_k_points()
        nk = len(kpts)
        ns = calc.get_number_of_spins()
        eigenvalues = np.array([[calc.get_eigenvalues(kpt=k, spin=s)
                                 for k in range(nk)]
                                for s in range(ns)])
        if efermi is None:
            efermi = calc.get_fermi_level()

    if eigenvalues is None:
        raise TypeError('Please provide either calc or eigenvalues')
    eigenvalues = np.asarray(eigenvalues)
    if eigenvalues.ndim not in (2, 3):
        raise ValueError('Eigenvalues must have shape (nspin, nkpt, nband) '
                         'or (nkpt, nband)')

    efermi = efermi or 0.0

    # Shift so the Fermi level is at 0.0, and give 2D input (e.g. spinors)
    # a leading spin axis *before* any gap search: _bandgap() requires
    # exactly three axes.
    e_skn = eigenvalues - efermi
    if eigenvalues.ndim == 2:
        e_skn = e_skn[np.newaxis]

    # Validate before searching, so NaN/inf is reported as such rather than
    # surfacing as a misleading "Too few bands!" or a bogus gap.
    if not np.isfinite(e_skn).all():
        raise ValueError('Bad eigenvalues!')

    gap, (s1, k1, n1), (s2, k2, n2) = _bandgap(e_skn, direct)

    if eigenvalues.ndim == 2:
        return gap, (k1, n1), (k2, n2)
    return gap, (s1, k1, n1), (s2, k2, n2)

163 

164 

def _bandgap(e_skn, direct):
    """Locate the band gap in eigenvalues given relative to the Fermi level.

    Parameters:

    e_skn: ndarray of shape (nspin, nkpt, nband)
        Eigenvalues shifted so the Fermi level is at 0.0. Every value
        below 0.0 is counted as an occupied state.
    direct: bool
        If true, only same-k transitions are considered. With two spins
        these include spin-flip transitions.

    Returns (gap, (s1, k1, n1), (s2, k2, n2)) for the valence and conduction
    states bounding the gap. If the number of occupied bands varies across
    k-points within any spin channel, the system is treated as metallic and
    (0.0, (None, None, None), (None, None, None)) is returned.

    Raises ValueError('Too few bands!') when some (s, k) has either no
    occupied band or every band occupied.
    """
    nspins, nkpts, nbands = e_skn.shape
    no_gap = (0.0, (None, None, None), (None, None, None))

    # Occupied-band count for every (spin, k-point).
    nocc_sk = (e_skn < 0.0).sum(2)

    # A band crossing the Fermi level shows up as a varying occupation
    # count across the k-points of one spin channel.
    if (np.ptp(nocc_sk, axis=1) > 0).any():
        return no_gap

    if (nocc_sk == 0).any() or (nocc_sk == nbands).any():
        raise ValueError('Too few bands!')

    # Keep only the highest occupied and the lowest unoccupied band.
    edges_skn = np.array([[e_n[nocc - 1:nocc + 1]
                           for e_n, nocc in zip(e_kn, nocc_k)]
                          for e_kn, nocc_k in zip(e_skn, nocc_sk)])
    ev_sk = edges_skn[:, :, 0]
    ec_sk = edges_skn[:, :, 1]

    if nspins == 1:
        gap, kv, kc = find_gap(ev_sk[0], ec_sk[0], direct)
        nv = nocc_sk[0, 0] - 1
        return gap, (0, kv, nv), (0, kc, nv + 1)

    # Several spins: search all (s, k) pairs at once on the flattened
    # arrays, where flat index = s * nkpts + k.
    gap, kv_flat, kc_flat = find_gap(ev_sk.ravel(), ec_sk.ravel(), direct)
    if direct:
        # A vertical transition may also flip spin: valence in one
        # channel, conduction in the other, at the same k-point.
        for s_v, s_c in ((0, 1), (1, 0)):
            g, k, _ = find_gap(ev_sk[s_v], ec_sk[s_c], direct)
            if g < gap:
                gap = g
                kv_flat = k + nkpts * s_v
                kc_flat = k + nkpts * s_c

    if not gap > 0.0:
        return no_gap
    sv, kv = divmod(kv_flat, nkpts)
    sc, kc = divmod(kc_flat, nkpts)
    return gap, (sv, kv, nocc_sk[sv, kv] - 1), (sc, kc, nocc_sk[sc, kc])

214 

215 

def find_gap(ev_k, ec_k, direct):
    """Find the smallest gap between valence and conduction edges.

    ``ev_k`` and ``ec_k`` hold, per (possibly flattened) k-point, the
    valence-edge and conduction-edge energies. With ``direct`` the gap is
    the minimum of ``ec_k - ev_k`` taken at equal index. Otherwise it is
    ``min(ec_k) - max(ev_k)``, possibly between different indices.

    Returns (gap, k_valence, k_conduction); ties resolve to the first index.
    """
    if not direct:
        k_top = ev_k.argmax()
        k_bottom = ec_k.argmin()
        return ec_k[k_bottom] - ev_k[k_top], k_top, k_bottom

    diff_k = ec_k - ev_k
    k_min = diff_k.argmin()
    return diff_k[k_min], k_min, k_min