Coverage for /builds/ase/ase/ase/dft/bz.py: 94.09%
186 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3from itertools import product
4from math import cos, pi, sin
5from typing import Any, Dict, Optional, Tuple, Union
7import numpy as np
8from matplotlib.patches import FancyArrowPatch
9from mpl_toolkits.mplot3d import Axes3D, proj3d
10from scipy.spatial.transform import Rotation
12from ase.cell import Cell
def bz_vertices(icell, dim=3):
    """Return the faces of the first Brillouin zone.

    The BZ is built as the Voronoi cell around the origin of the 27
    reciprocal lattice points with integer coordinates in {-1, 0, 1}
    (see also https://xkcd.com/1421).

    Parameters
    ----------
    icell : array_like, shape (3, 3)
        Reciprocal cell; its rows are the reciprocal lattice vectors.
        It is copied (as float) and never modified.
    dim : int
        Dimensionality of the lattice.  For ``dim < 3`` the missing
        reciprocal vector(s) are replaced by tiny (1e-3) ones so that the
        Voronoi construction stays three-dimensional.

    Returns
    -------
    list of (ndarray, ndarray)
        One ``(vertices, normal)`` pair per BZ face: the face vertices and
        the unit normal pointing from the origin towards the neighbouring
        lattice point that defines the face.
    """
    from scipy.spatial import Voronoi

    # Work on a float copy: with an integer input the 1e-3 placeholders
    # below would be truncated to 0 (coplanar points break the Voronoi
    # construction) and the in-place normalisation of the normals would
    # raise a casting error.
    icell = np.array(icell, dtype=float)
    if dim < 3:
        icell[2, 2] = 1e-3
    if dim < 2:
        icell[1, 1] = 1e-3

    # Integer coordinates in {-1, 0, 1}^3; column 13 is (0, 0, 0).
    indices = (np.indices((3, 3, 3)) - 1).reshape((3, 27))
    G = np.dot(icell.T, indices).T
    vor = Voronoi(G)
    origin = 13
    bz1 = []
    for vertices, points in zip(vor.ridge_vertices, vor.ridge_points):
        # Keep bounded ridges (-1 marks a vertex at infinity) that separate
        # the origin from one of its neighbours.
        if -1 not in vertices and origin in points:
            # One of the two points is the origin (G == 0), so the sum is
            # the neighbouring lattice point: the outward face normal.
            normal = G[points].sum(0)
            normal /= (normal**2).sum()**0.5
            bz1.append((vor.vertices[vertices], normal))
    return bz1
class FlatPlot:
    """Plotting backend for 1D and 2D Brillouin zones.

    Everything is drawn on an ordinary two-dimensional matplotlib Axes.
    """

    # Number of plotted coordinates; 2 even when the BZ itself is 1D.
    axis_dim = 2
    # Extra options passed to the k-point scatter plot.
    point_options = {'zorder': 5}

    def new_axes(self, fig):
        """Return the current Axes of *fig*."""
        return fig.gca()

    def adjust_view(self, ax, minp, maxp, symmetric: bool = True):
        """Adjust the view of the drawn BZ (1D/2D).

        Parameters
        ----------
        ax: Axes
            matplotlib Axes object.
        minp: float
            Minimum coordinate of the plotting region, which determines
            the bottom-left corner of the figure.  Ignored when
            ``symmetric`` is True.
        maxp: float
            Maximum coordinate of the plotting region, which determines
            the top-right corner of the figure.
        symmetric: bool
            If True, put the (0, 0) position (Gamma-bar point) at the
            center of the figure.
        """
        ax.autoscale_view(tight=True)
        # 5% margin around the drawn region.
        upper = 1.05 * maxp
        lower = -upper if symmetric else 1.05 * minp
        for set_limits in (ax.set_xlim, ax.set_ylim):
            set_limits(lower, upper)
        ax.set_aspect('equal')

    def draw_arrow(self, ax, vector, **kwargs):
        """Draw the first two components of *vector* as an arrow from 0."""
        dx, dy = vector[0], vector[1]
        ax.arrow(0, 0, dx, dy,
                 lw=1,
                 length_includes_head=True,
                 head_width=0.03,
                 head_length=0.05,
                 **kwargs)

    def label_options(self, point):
        """Return text alignment options for labelling *point*.

        The label is put to the right of the point when x > 0 (otherwise
        to its left) and above it unless y < 0.
        """
        x, y = point
        horizontal = {1: 'left', 0: 'right', -1: 'right'}
        vertical = {1: 'bottom', 0: 'bottom', -1: 'top'}
        return {'ha': horizontal[int(np.sign(x))],
                'va': vertical[int(np.sign(y))],
                'zorder': 4}

    def view(self):
        """Placeholder; does nothing."""
        pass
class SpacePlot:
    """Helper class for ordinary (3D) Brillouin zone plots.

    Attributes
    ----------
    azim : float
        Azimuthal angle in radian for viewing 3D BZ.
        default value is pi/5
    elev : float
        Elevation angle in radian for viewing 3D BZ.
        default value is pi/6
    view : list[float]
        Unit vector from the origin towards the camera set up by
        ``adjust_view``; used to decide which BZ faces are hidden.
    arrow3d : type
        Arrow artist class used by ``draw_arrow``.
    """
    axis_dim = 3  # Dimension of the plotting space.
    point_options: Dict[str, Any] = {}

    def __init__(self, *, azim: Optional[float] = None,
                 elev: Optional[float] = None):
        class Arrow3D(FancyArrowPatch):
            """Arrow patch whose 3D end points are projected at draw time."""

            def __init__(self, ax, xs, ys, zs, *args, **kwargs):
                FancyArrowPatch.__init__(self, (0, 0), (0, 0),
                                         *args, **kwargs)
                self._verts3d = xs, ys, zs
                self.ax = ax

            def draw(self, renderer):
                xs3d, ys3d, zs3d = self._verts3d
                xs, ys, _zs = proj3d.proj_transform(xs3d, ys3d,
                                                    zs3d, self.ax.axes.M)
                self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
                FancyArrowPatch.draw(self, renderer)

            # FIXME: Compatibility fix for matplotlib 3.5.0: Handling of 3D
            # artists have changed and all 3D artists now need
            # "do_3d_projection". Since this class is a hack that manually
            # projects onto the 3D axes we don't need to do anything in this
            # method. Ideally we shouldn't resort to a hack like this.
            def do_3d_projection(self, *_, **__):
                return 0

        self.arrow3d = Arrow3D
        self.azim: float = pi / 5 if azim is None else azim
        self.elev: float = pi / 6 if elev is None else elev
        # ``adjust_view`` calls ``ax.view_init(azim, elev)``, which places
        # the camera at (cos e cos a, cos e sin a, sin e): the azimuth is
        # measured in the xy plane from +x towards +y.  The hidden-face test
        # needs exactly this direction, so x takes cos(azim) and y sin(azim).
        self.view = [
            cos(self.azim) * cos(self.elev),
            sin(self.azim) * cos(self.elev),
            sin(self.elev),
        ]

    def new_axes(self, fig):
        """Add and return a new 3D subplot of *fig*."""
        return fig.add_subplot(projection='3d')

    def draw_arrow(self, ax: Axes3D, vector, **kwargs):
        """Draw *vector* as an arrow from the origin on the 3D axes *ax*."""
        ax.add_artist(self.arrow3d(
            ax,
            [0, vector[0]],
            [0, vector[1]],
            [0, vector[2]],
            mutation_scale=20,
            arrowstyle='-|>',
            **kwargs))

    def adjust_view(self, ax, minp, maxp, symmetric=True):
        """Adjust the view of the drawn BZ (3D).

        Parameters
        ----------
        ax: Axes
            matplotlib Axes object.
        minp: float
            Minimum coordinate of the plotting region; used as the lower
            limit of all three axes (after a 0.9 scaling).
        maxp: float
            Maximum coordinate of the plotting region; used as the upper
            limit of all three axes (after a 0.9 scaling).
        symmetric: bool
            Currently not used; kept for consistency with the 2D version.
        """
        import matplotlib.pyplot as plt

        # ax.set_aspect('equal') <-- won't work anymore in 3.1.0
        ax.view_init(azim=np.rad2deg(self.azim), elev=np.rad2deg(self.elev))
        # We want aspect 'equal', but apparently there was a bug in
        # matplotlib causing wrong behaviour. Matplotlib raises
        # NotImplementedError as of v3.1.0. This is a bit unfortunate
        # because the workarounds known to StackOverflow and elsewhere
        # all involve using set_aspect('equal') and then doing
        # something more.
        #
        # We try to get square axes here by setting a square figure,
        # but this is probably rather inexact.
        fig = ax.get_figure()
        xx = plt.figaspect(1.0)
        fig.set_figheight(xx[1])
        fig.set_figwidth(xx[0])

        ax.set_proj_type('ortho')

        minp0 = 0.9 * minp  # Here we cheat a bit to trim spacings
        maxp0 = 0.9 * maxp
        ax.set_xlim3d(minp0, maxp0)
        ax.set_ylim3d(minp0, maxp0)
        ax.set_zlim3d(minp0, maxp0)

        ax.set_box_aspect([1, 1, 1])

    def label_options(self, point):
        """Return text options for labelling *point*: centred above it."""
        return dict(ha='center', va='bottom')
def normalize_name(name):
    r"""Convert a special-point label into text suitable for LaTeX math.

    ``'G'`` becomes ``\Gamma``.  A label longer than one character is split
    into a non-digit stem and an optional trailing number, and the number
    becomes a subscript (``'M1'`` -> ``'M_{1}'``).  Labels of at most one
    character are returned unchanged.

    Raises
    ------
    ValueError
        If a multi-character label is not of the form
        ``<non-digits><digits>``.
    """
    if name == 'G':
        return '\\Gamma'
    if len(name) <= 1:
        return name

    import re

    match = re.match(r'^(\D+?)(\d*)$', name)
    if match is None:
        raise ValueError(f'Bad label: {name}')
    stem, number = match.group(1, 2)
    return f'{stem}_{{{number}}}' if number else stem
def bz_plot(cell: Cell, vectors: bool = False, paths=None, points=None,
            azim: Optional[float] = None, elev: Optional[float] = None,
            scale=1, interactive: bool = False,
            transforms: Optional[list] = None,
            repeat: Union[Tuple[int, int], Tuple[int, int, int]] = (1, 1, 1),
            pointstyle: Optional[dict] = None,
            ax=None, show=False, **kwargs):
    """Plot the Brillouin zone of the Cell

    Parameters
    ----------
    cell: Cell
        Cell object for BZ drawing.
    vectors : bool
        If True, draw the reciprocal lattice vectors as arrows.
    paths : list[tuple[Sequence[str], np.ndarray]] | None
        Band paths as ``(names, points)`` pairs: special-point labels and
        their coordinates (same Cartesian frame as the drawn BZ).  Each
        path is drawn as a red line and every point is labelled.
    points : np.ndarray | None
        k-points to draw as a scatter plot, shape (n, 3); a single point of
        shape (3,) is also accepted.
    azim : float | None
        Azimuthal angle in radian for viewing 3D BZ.
    elev : float | None
        Elevation angle in radian for viewing 3D BZ.
    scale : float
        Not used. Kept for backward compatibility.
    interactive : bool
        If True, draw all 3D faces with solid lines instead of dotting
        the faces that point away from the viewer.
    transforms: List
        Linear transformations (scipy.spatial.transform.Rotation) applied
        in order to everything that is drawn.  Defaults to the identity.
    repeat: Tuple[int, int] | Tuple[int, int, int]
        Number of BZ copies to draw along each reciprocal lattice vector
        (list or tuple).  Default is (1, 1, 1), no repeat.
    pointstyle : Dict
        Style of the k-points, merged over the defaults.
    ax : Axes | Axes3D
        matplotlib Axes (Axes3D in 3D) object; a new one is created from
        the current figure if None.
    show : bool
        If true, show the figure.
    **kwargs
        Additional keyword arguments to pass to ax.plot

    Returns
    -------
    ax
        A matplotlib axis object.
    """
    import matplotlib.pyplot as plt

    if pointstyle is None:
        pointstyle = {}

    if transforms is None:
        transforms = [Rotation.from_rotvec((0, 0, 0))]

    cell = cell.copy()

    dimensions = cell.rank
    if dimensions == 3:
        plotter: Union[SpacePlot, FlatPlot] = SpacePlot(azim=azim, elev=elev)
    else:
        plotter = FlatPlot()
    assert dimensions > 0, 'No BZ for 0D!'

    if ax is None:
        ax = plotter.new_axes(plt.gcf())

    assert not np.array(cell)[dimensions:, :].any()
    assert not np.array(cell)[:, dimensions:].any()

    icell = cell.reciprocal()
    icell_array = np.asarray(icell, dtype=float)
    bz1 = bz_vertices(icell, dim=dimensions)

    # Accept lists as well as tuples so the (1, 1, 1) check below works,
    # and treat a 2-tuple as (a, b, 1).
    repeat = tuple(repeat)
    if len(repeat) == 2:
        repeat = (repeat[0], repeat[1], 1)

    maxp = 0.0
    minp = 0.0
    for bz_i in bz_index(repeat):
        # Translation of this BZ copy; the same for all of its faces.
        shift = _apply_transforms(transforms,
                                  icell_array.T @ np.array(bz_i))
        for face, normal in bz1:
            # Close the polygon by repeating its first vertex.
            xyz = np.concatenate([face, face[:1]])
            x, y, z = (_apply_transforms(transforms, xyz) + shift).T
            ls = '-'
            if (dimensions == 3 and not interactive
                    and normal @ plotter.view < 0):
                # The face points away from the viewer: draw it dotted.
                ls = ':'
            if plotter.axis_dim == 2:
                ax.plot(x, y, c='k', ls=ls, **kwargs)
            else:
                ax.plot(x, y, z, c='k', ls=ls, **kwargs)
            maxp = max(maxp, x.max(), y.max(), z.max())
            minp = min(minp, x.min(), y.min(), z.min())

    if vectors:
        # Always an ndarray, even when ``transforms`` is empty.
        rotated_icell = _apply_transforms(transforms, icell)
        for i in range(dimensions):
            plotter.draw_arrow(ax, rotated_icell[i], color='k')

        # XXX Can this be removed?
        if dimensions == 3:
            maxp = max(maxp, 0.6 * rotated_icell.max())
        else:
            maxp = max(maxp, rotated_icell.max())

    if paths is not None:
        for names, path_points in paths:
            path_points = _apply_transforms(transforms, path_points)
            coords = path_points.T[:plotter.axis_dim, :]
            ax.plot(*coords, c='r', ls='-')

            for name, point in zip(names, path_points):
                name = normalize_name(name)
                point = point[:plotter.axis_dim]
                ax.text(*point, rf'$\mathrm{{{name}}}$',
                        color='g', **plotter.label_options(point))

    if points is not None:
        kw = {'c': 'b', **plotter.point_options, **pointstyle}
        # atleast_2d also accepts a single k-point of shape (3,).
        kpoints = np.atleast_2d(_apply_transforms(transforms, points))
        ax.scatter(*kpoints[:, :plotter.axis_dim].T, **kw)

    ax.set_axis_off()

    if repeat == (1, 1, 1):
        plotter.adjust_view(ax, minp, maxp)
    else:
        plotter.adjust_view(ax, minp, maxp, symmetric=False)
    if show:
        plt.show()

    return ax


def _apply_transforms(transforms, array):
    """Return *array* as a float ndarray with *transforms* applied in order.

    Each transform must provide an ``apply`` method, e.g.
    ``scipy.spatial.transform.Rotation``.
    """
    result = np.asarray(array, dtype=float)
    for transform in transforms:
        result = transform.apply(result)
    return result
def bz_index(repeat):
    """BZ index from the repeat

    A helper function to iterate over the BZ copies to draw.

    Parameters
    ----------
    repeat: Tuple[int, int] | Tuple[int, int, int]
        Number of BZ copies along each reciprocal lattice vector.  A
        negative entry repeats in the negative direction; a 2-tuple is
        treated as ``(a, b, 1)``.  Zero entries are not allowed.

    Returns
    -------
    Iterator[Tuple[int, int, int]]
        Integer coordinates of each BZ copy, in units of the reciprocal
        lattice vectors.

    >>> list(bz_index((1, 2, -2)))
    [(0, 0, 0), (0, 0, -1), (0, 1, 0), (0, 1, -1)]
    >>> list(bz_index((2, 1)))
    [(0, 0, 0), (1, 0, 0)]
    """
    if len(repeat) == 2:
        repeat = (repeat[0], repeat[1], 1)
    assert len(repeat) == 3
    assert repeat[0] != 0
    assert repeat[1] != 0
    assert repeat[2] != 0
    # Count away from 0 in the direction given by the sign of each entry.
    ranges = [range(0, n) if n > 0 else range(0, n, -1) for n in repeat]
    return product(*ranges)