Coverage for /builds/ase/ase/ase/dft/bz.py: 94.09%

186 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3from itertools import product 

4from math import cos, pi, sin 

5from typing import Any, Dict, Optional, Tuple, Union 

6 

7import numpy as np 

8from matplotlib.patches import FancyArrowPatch 

9from mpl_toolkits.mplot3d import Axes3D, proj3d 

10from scipy.spatial.transform import Rotation 

11 

12from ase.cell import Cell 

13 

14 

def bz_vertices(icell, dim=3):
    """Return the facets of the first Brillouin zone.

    The zone is the Voronoi cell of the origin among the reciprocal lattice
    points ``n0*b0 + n1*b1 + n2*b2`` with every ``ni`` in ``{-1, 0, 1}``.

    See https://xkcd.com/1421 ...

    Parameters
    ----------
    icell : array-like, shape (3, 3)
        Reciprocal cell whose rows are the reciprocal lattice vectors.
        It is copied, never modified in place.
    dim : int
        Number of periodic dimensions.  For ``dim < 3`` the missing diagonal
        entries of the copy are replaced by the small value 1e-3 so that the
        Voronoi construction still receives 3D points.

    Returns
    -------
    list of tuple
        One ``(vertices, normal)`` pair per bounded facet touching the
        origin's cell.  ``vertices`` are the Voronoi vertices of the facet
        and ``normal`` is the unit vector pointing to the neighbouring
        lattice point across that facet.
    """
    from scipy.spatial import Voronoi

    rec = icell.copy()
    if dim < 3:
        rec[2, 2] = 1e-3
    if dim < 2:
        rec[1, 1] = 1e-3

    # Integer offsets -1..1 along each axis, flattened to 27 columns.
    offsets = (np.indices((3, 3, 3)) - 1).reshape((3, 27))
    lattice_points = np.dot(rec.T, offsets).T
    origin = 13  # flat index of the (0, 0, 0) offset in the 3x3x3 grid
    voronoi = Voronoi(lattice_points)

    facets = []
    for ridge, pair in zip(voronoi.ridge_vertices, voronoi.ridge_points):
        # Skip unbounded ridges and ridges not adjacent to the origin.
        if -1 in ridge or origin not in pair:
            continue
        # The origin contributes zero, so this is the neighbour's position.
        normal = lattice_points[pair].sum(0)
        normal /= (normal ** 2).sum() ** 0.5
        facets.append((voronoi.vertices[ridge], normal))
    return facets

37 

38 

class FlatPlot:
    """Plotting backend for 1D and 2D Brillouin zones.

    Everything is drawn on an ordinary 2D matplotlib Axes; even a 1D zone
    is drawn on a 2D surface.
    """

    axis_dim = 2  # Dimension of the plotting surface (2 even if it's 1D BZ).
    point_options = {'zorder': 5}

    def new_axes(self, fig):
        """Return the current Axes of ``fig``."""
        return fig.gca()

    def adjust_view(self, ax, minp, maxp, symmetric: bool = True):
        """Adjust the view limits of the drawn BZ (1D/2D).

        Parameters
        ----------
        ax: Axes
            matplotlib Axes object.
        minp: float
            Minimum value of the plotting region, which determines the
            bottom-left corner of the figure.  Ignored when ``symmetric``
            is True.
        maxp: float
            Maximum value of the plotting region, which determines the
            top-right corner of the figure.
        symmetric: bool
            If True, put the (0, 0) position (Gamma-bar point) at the
            centre of the figure.

        """
        ax.autoscale_view(tight=True)
        upper = maxp * 1.05
        lower = -upper if symmetric else minp * 1.05
        ax.set_xlim(lower, upper)
        ax.set_ylim(lower, upper)
        ax.set_aspect('equal')

    def draw_arrow(self, ax, vector, **kwargs):
        """Draw an arrow from the origin to the 2D point ``vector``."""
        head = {'length_includes_head': True,
                'head_width': 0.03,
                'head_length': 0.05}
        ax.arrow(0, 0, vector[0], vector[1], lw=1, **head, **kwargs)

    def label_options(self, point):
        """Return text alignment that places a label outward from ``point``.

        Points with negative x get right-aligned text and points with
        negative y get top-aligned text; zero coordinates are treated like
        the right/bottom case.
        """
        x, y = point
        horizontal = {-1: 'right', 0: 'right', 1: 'left'}
        vertical = {-1: 'top', 0: 'bottom', 1: 'bottom'}
        return {'ha': horizontal[int(np.sign(x))],
                'va': vertical[int(np.sign(y))],
                'zorder': 4}

    def view(self):
        """No-op."""

96 

97 

class SpacePlot:
    """Helper class for ordinary (3D) Brillouin zone plots.

    Attributes
    ----------
    azim : float
        Azimuthal angle in radians for viewing the 3D BZ.
        Default value is pi/5.
    elev : float
        Elevation angle in radians for viewing the 3D BZ.
        Default value is pi/6.
    view : list of float
        Unit vector from the origin towards the camera, using the same
        convention as matplotlib's ``view_init(azim, elev)``.

    """
    axis_dim = 3  # plots are drawn on a 3D Axes
    point_options: Dict[str, Any] = {}

    def __init__(self, *, azim: Optional[float] = None,
                 elev: Optional[float] = None):
        class Arrow3D(FancyArrowPatch):
            """Arrow whose 3D end points are projected when drawn."""

            def __init__(self, ax, xs, ys, zs, *args, **kwargs):
                FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
                self._verts3d = xs, ys, zs
                self.ax = ax

            def draw(self, renderer):
                xs3d, ys3d, zs3d = self._verts3d
                xs, ys, _zs = proj3d.proj_transform(xs3d, ys3d,
                                                    zs3d, self.ax.axes.M)
                self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
                FancyArrowPatch.draw(self, renderer)

            # FIXME: Compatibility fix for matplotlib 3.5.0: Handling of 3D
            # artists have changed and all 3D artists now need
            # "do_3d_projection". Since this class is a hack that manually
            # projects onto the 3D axes we don't need to do anything in this
            # method. Ideally we shouldn't resort to a hack like this.
            def do_3d_projection(self, *_, **__):
                return 0

        self.arrow3d = Arrow3D
        self.azim: float = pi / 5 if azim is None else azim
        self.elev: float = pi / 6 if elev is None else elev
        # Direction from the origin towards the camera.  matplotlib's
        # view_init(azim, elev) places the eye along
        # (cos(elev)*cos(azim), cos(elev)*sin(azim), sin(elev)).  The
        # facet-visibility test (normal @ view < 0 -> hidden) must use the
        # same convention as the rendering, or facets near the silhouette
        # are drawn with the wrong line style.
        self.view = [
            cos(self.azim) * cos(self.elev),
            sin(self.azim) * cos(self.elev),
            sin(self.elev),
        ]

    def new_axes(self, fig):
        """Add and return a new 3D subplot on ``fig``."""
        return fig.add_subplot(projection='3d')

    def draw_arrow(self, ax: Axes3D, vector, **kwargs):
        """Draw an arrow from the origin to the 3D point ``vector``."""
        ax.add_artist(self.arrow3d(
            ax,
            [0, vector[0]],
            [0, vector[1]],
            [0, vector[2]],
            mutation_scale=20,
            arrowstyle='-|>',
            **kwargs))

    def adjust_view(self, ax, minp, maxp, symmetric=True):
        """Adjust the view properties of the drawn BZ (3D).

        Parameters
        ----------
        ax: Axes
            matplotlib Axes object.
        minp: float
            Minimum value of the plotting region, which determines the
            bottom-left corner of the figure.
        maxp: float
            Maximum value of the plotting region, which determines the
            top-right corner of the figure.
        symmetric: bool
            Currently unused; kept for consistency with the 2D version.

        """
        import matplotlib.pyplot as plt

        # ax.set_aspect('equal') <-- won't work anymore in 3.1.0
        ax.view_init(azim=np.rad2deg(self.azim), elev=np.rad2deg(self.elev))
        # We want aspect 'equal', but apparently there was a bug in
        # matplotlib causing wrong behaviour. Matplotlib raises
        # NotImplementedError as of v3.1.0. This is a bit unfortunate
        # because the workarounds known to StackOverflow and elsewhere
        # all involve using set_aspect('equal') and then doing
        # something more.
        #
        # We try to get square axes here by setting a square figure,
        # but this is probably rather inexact.
        fig = ax.get_figure()
        xx = plt.figaspect(1.0)
        fig.set_figheight(xx[1])
        fig.set_figwidth(xx[0])

        ax.set_proj_type('ortho')

        minp0 = 0.9 * minp  # Here we cheat a bit to trim spacings
        maxp0 = 0.9 * maxp
        ax.set_xlim3d(minp0, maxp0)
        ax.set_ylim3d(minp0, maxp0)
        ax.set_zlim3d(minp0, maxp0)

        ax.set_box_aspect([1, 1, 1])

    def label_options(self, point):
        """Return text alignment for a special-point label (3D)."""
        return dict(ha='center', va='bottom')

208 

209 

def normalize_name(name):
    """Convert a special-point label into LaTeX math text.

    ``'G'`` becomes ``\\Gamma``.  In a multi-character label, a trailing run
    of digits becomes a subscript (``'X1'`` -> ``'X_{1}'``).  Any other
    single-character (or empty) label is returned unchanged.

    Raises
    ------
    ValueError
        If a multi-character label is not one or more non-digit characters
        optionally followed by digits.
    """
    if name == 'G':
        return '\\Gamma'
    if len(name) <= 1:
        return name

    import re

    match = re.match(r'^(\D+?)(\d*)$', name)
    if match is None:
        raise ValueError(f'Bad label: {name}')
    base, digits = match.group(1, 2)
    return f'{base}_{{{digits}}}' if digits else base

224 

225 

def _apply_transforms(transforms, array):
    """Apply each rotation in ``transforms`` to ``array``, in order."""
    for transform in transforms:
        array = transform.apply(array)
    return array


def bz_plot(cell: Cell, vectors: bool = False, paths=None, points=None,
            azim: Optional[float] = None, elev: Optional[float] = None,
            scale=1, interactive: bool = False,
            transforms: Optional[list] = None,
            repeat: Union[Tuple[int, int], Tuple[int, int, int]] = (1, 1, 1),
            pointstyle: Optional[dict] = None,
            ax=None, show=False, **kwargs):
    """Plot the Brillouin zone of the Cell

    Parameters
    ----------
    cell: Cell
        Cell object for BZ drawing.
    vectors : bool
        If True, draw the reciprocal lattice vectors.
    paths : list[tuple[str, np.ndarray]] | None
        Special point names and their coordinate positions.
    points : np.ndarray
        Coordinate points along the paths.
    azim : float | None
        Azimuthal angle in radians for viewing the 3D BZ.
    elev : float | None
        Elevation angle in radians for viewing the 3D BZ.
    scale : float
        Not used. To be removed?
    interactive : bool
        If True, hidden facets of a 3D BZ are not drawn dotted.
        Does not work effectively. To be removed?
    transforms: List
        List of linear transformations (scipy.spatial.transform.Rotation)
        applied in order to everything that is drawn, including the facet
        normals used to decide which 3D facets are hidden.
    repeat: Tuple[int, int] | Tuple[int, int, int]
        Number of BZ copies drawn along each reciprocal axis (a sequence of
        length 2 or 3; lists are accepted). Default is (1, 1, 1), no repeat.
    pointstyle : Dict
        Style of the special points.
    ax : Axes | Axes3D
        matplotlib Axes (Axes3D in 3D) object.
    show : bool
        If true, show the figure.
    **kwargs
        Additional keyword arguments to pass to ax.plot

    Returns
    -------
    ax
        A matplotlib axis object.
    """
    import matplotlib.pyplot as plt

    if pointstyle is None:
        pointstyle = {}

    if transforms is None:
        transforms = [Rotation.from_rotvec((0, 0, 0))]

    cell = cell.copy()

    dimensions = cell.rank
    if dimensions == 3:
        plotter: Union[SpacePlot, FlatPlot] = SpacePlot(azim=azim, elev=elev)
    else:
        plotter = FlatPlot()
    assert dimensions > 0, 'No BZ for 0D!'

    if ax is None:
        ax = plotter.new_axes(plt.gcf())

    assert not np.array(cell)[dimensions:, :].any()
    assert not np.array(cell)[:, dimensions:].any()

    icell = cell.reciprocal()
    # Plain ndarray copy of the reciprocal cell.  Using it (instead of
    # relying on Rotation.apply to convert the Cell) keeps the arrow code
    # working when ``transforms`` is an empty list.
    icell_array = np.array(icell)
    kpoints = points
    bz1 = bz_vertices(icell, dim=dimensions)

    # Accept any sequence (e.g. a list) so the symmetric-view check below
    # works, and pad 2-element repeats with a single copy along c.
    repeat = tuple(repeat)
    if len(repeat) == 2:
        repeat = (repeat[0], repeat[1], 1)

    maxp = 0.0
    minp = 0.0
    for bz_i in bz_index(repeat):
        # The shift depends only on the repeat index, not on the facet.
        shift = _apply_transforms(transforms,
                                  np.dot(icell_array.T, np.array(bz_i)))
        for face_vertices, normal in bz1:
            # Close the polygon by repeating its first vertex.
            xyz = np.concatenate([face_vertices, face_vertices[:1]])
            x, y, z = (_apply_transforms(transforms, xyz) + shift).T
            ls = '-'
            if dimensions == 3 and not interactive:
                # The facet is drawn rotated, so its visibility must be
                # judged with the equally rotated normal.
                rotated_normal = _apply_transforms(transforms, normal)
                if rotated_normal @ plotter.view < 0:
                    ls = ':'
            if plotter.axis_dim == 2:
                ax.plot(x, y, c='k', ls=ls, **kwargs)
            else:
                ax.plot(x, y, z, c='k', ls=ls, **kwargs)
            maxp = max(maxp, x.max(), y.max(), z.max())
            minp = min(minp, x.min(), y.min(), z.min())

    if vectors:
        rec = _apply_transforms(transforms, icell_array)
        for i in range(dimensions):
            plotter.draw_arrow(ax, rec[i], color='k')

        # XXX Can this be removed?
        if dimensions == 3:
            maxp = max(maxp, 0.6 * rec.max())
        else:
            maxp = max(maxp, rec.max())

    if paths is not None:
        for names, path_points in paths:
            path_points = np.asarray(
                _apply_transforms(transforms, path_points))
            coords = path_points.T[:plotter.axis_dim, :]
            ax.plot(*coords, c='r', ls='-')

            for name, point in zip(names, path_points):
                name = normalize_name(name)
                point = point[:plotter.axis_dim]
                ax.text(*point, rf'$\mathrm{{{name}}}$',
                        color='g', **plotter.label_options(point))

    if kpoints is not None:
        kw = {'c': 'b', **plotter.point_options, **pointstyle}
        kpoints = np.asarray(_apply_transforms(transforms, kpoints))
        ax.scatter(*kpoints[:, :plotter.axis_dim].T, **kw)

    ax.set_axis_off()

    # Only a single, unrepeated BZ is centred on Gamma.
    plotter.adjust_view(ax, minp, maxp, symmetric=(repeat == (1, 1, 1)))
    if show:
        plt.show()

    return ax

365 

366 

def bz_index(repeat):
    """BZ index from the repeat

    A helper function for drawing the BZ repeatedly.  Along each axis the
    indices start at 0 and step towards the (nonzero) repeat count, so a
    negative count repeats in the negative direction.

    Parameters
    ----------
    repeat: Tuple[int, int] | Tuple[int, int, int]
        Number of repetitions along each axis; a 2-element repeat is
        treated as having a single copy along the third axis.

    Returns
    -------
    Iterator[Tuple[int, int, int]]

    >>> list(bz_index((1, 2, -2)))
    [(0, 0, 0), (0, 0, -1), (0, 1, 0), (0, 1, -1)]

    """
    if len(repeat) == 2:
        repeat = (repeat[0], repeat[1], 1)
    assert len(repeat) == 3
    for count in repeat:
        assert count != 0, 'repeat counts must be nonzero'
    # range(0, n, +1) for n > 0 and range(0, n, -1) for n < 0.
    ranges = [range(0, count, 1 if count > 0 else -1) for count in repeat]
    return product(*ranges)