Coverage for ase / visualize / mlab.py: 13.51%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3import optparse 

4 

5import numpy as np 

6 

7from ase.calculators.calculator import get_calculator_class 

8from ase.data import covalent_radii 

9from ase.data.colors import cpk_colors 

10from ase.io.cube import read_cube_data 

11 

12 

def plot(atoms, data, contours):
    """Plot atoms, unit-cell and iso-surfaces using Mayavi.

    Parameters
    ----------
    atoms: Atoms object
        Positions, atomic numbers and unit-cell.
    data: 3-d ndarray of float
        Data for iso-surfaces.  The grid is mapped onto the unit cell:
        along axis k, grid coordinates are scaled by cell vector k divided
        by ``data.shape[k]``.
    contours: list of float
        Contour values.

    Draws one sphere per atom, the twelve unit-cell edges and the
    iso-surfaces in Mayavi figure 1, then calls ``mlab.show()``.
    """
    # Delay slow imports:
    import os

    from mayavi import mlab

    # mayavi GUI bug fix for remote access via ssh (X11 forwarding)
    if "SSH_CONNECTION" in os.environ:
        f = mlab.gcf()
        f.scene._lift()

    mlab.figure(1, bgcolor=(1, 1, 1))  # make a white figure

    # Plot the atoms as spheres, sized by covalent radius and coloured
    # with the CPK scheme:
    for pos, Z in zip(atoms.positions, atoms.numbers):
        mlab.points3d(*pos,
                      scale_factor=covalent_radii[Z],
                      resolution=20,
                      color=tuple(cpk_colors[Z]))

    # Draw the unit cell.  For each cell vector ``a`` there are four
    # parallel edges.  They start at 0, A[i2], A[i3] and A[i2] + A[i3],
    # where i2 and i3 are the other two axes.
    A = atoms.cell
    origin = np.zeros(3)  # never mutated; b + c always builds a new array
    for i1, a in enumerate(A):
        i2 = (i1 + 1) % 3
        i3 = (i1 + 2) % 3
        for b in [origin, A[i2]]:
            for c in [origin, A[i3]]:
                p1 = b + c
                p2 = p1 + a
                mlab.plot3d([p1[0], p2[0]],
                            [p1[1], p2[1]],
                            [p1[2], p2[2]],
                            tube_radius=0.1)

    cp = mlab.contour3d(data, contours=contours, transparent=True,
                        opacity=0.5, colormap='hot')
    # Do some tvtk magic in order to allow for non-orthogonal unit cells:
    # Mayavi builds the iso-surface on a rectangular grid, so its vertices
    # are rewritten in place to Cartesian positions inside the cell.
    polydata = cp.actor.actors[0].mapper.input
    # NOTE(review): the "- 1" presumably compensates for Mayavi placing
    # grid index 0 at coordinate 1; confirm against the Mayavi version used.
    pts = np.array(polydata.points) - 1
    # Transform the points to the unit cell: row k of A is divided by the
    # number of grid points along axis k, so grid coordinate i on axis k
    # contributes i / data.shape[k] * A[k].
    polydata.points = np.dot(pts, A / np.array(data.shape)[:, np.newaxis])

    # Apparently we need this to redraw the figure, maybe it can be done in
    # another way?
    mlab.view(azimuth=155, elevation=70, distance='auto')
    # Show the 3d plot:
    mlab.show()

73 

74 

def view_mlab(atoms, *args, **kwargs):
    """Show *atoms* with Mayavi; an alias for :func:`plot`.

    Every extra positional and keyword argument is handed unchanged to
    ``plot`` (i.e. ``data`` and ``contours``), and its result is returned.
    """
    forwarded = (atoms,) + args
    return plot(*forwarded, **kwargs)

77 

78 

# Description text given to the option parser in ``main`` (shown by --help).
description = """\
Plot iso-surfaces from a cube-file or a wave function or an electron
density from a calculator-restart file."""

82 

83 

def main(args=None):
    """Command-line entry point: load volumetric data, pick contours, plot.

    Parameters
    ----------
    args: list of str or None
        Command-line arguments without the program name.  ``None`` lets
        ``optparse`` fall back to ``sys.argv[1:]``.

    The single positional argument is a ``.cube`` file or a calculator
    restart file.  A restart file also needs ``-C/--calculator-name``.
    Bad arguments are reported through ``parser.error``, which prints a
    usage message and exits.
    """
    parser = optparse.OptionParser(usage='%prog [options] filename',
                                   description=description)
    add = parser.add_option
    add('-n', '--band-index', type=int, metavar='INDEX',
        help='Band index counting from zero.')
    add('-s', '--spin-index', type=int, metavar='SPIN',
        help='Spin index: zero or one.')
    add('-e', '--electrostatic-potential', action='store_true',
        help='Plot the electrostatic potential.')
    add('-c', '--contours', default='4',
        help='Use "-c 3" for 3 contours or "-c -0.5,0.5" for specific ' +
        'values. Default is four contours.')
    add('-r', '--repeat', help='Example: "-r 2,2,2".')
    add('-C', '--calculator-name', metavar='NAME', help='Name of calculator.')

    opts, args = parser.parse_args(args)
    if len(args) != 1:
        parser.error('Incorrect number of arguments')

    arg = args[0]
    if arg.endswith('.cube'):
        data, atoms = read_cube_data(arg)
    else:
        # get_calculator_class needs a name.  Report a usage error here
        # rather than failing somewhere inside the lookup.
        if opts.calculator_name is None:
            parser.error('A calculator name (-C/--calculator-name) is '
                         'required for non-cube files')
        calc = get_calculator_class(opts.calculator_name)(arg, txt=None)
        atoms = calc.get_atoms()
        if opts.band_index is None:
            if opts.electrostatic_potential:
                data = calc.get_electrostatic_potential()
            else:
                # spin_index is passed through as given (possibly None).
                data = calc.get_pseudo_density(opts.spin_index)
        else:
            data = calc.get_pseudo_wave_function(opts.band_index,
                                                 opts.spin_index or 0)
            # Reduce complex wave functions of any precision to their
            # modulus.  "data.dtype == complex" only matched complex128,
            # so complex64 data slipped through.
            if np.iscomplexobj(data):
                data = abs(data)

    mn = data.min()
    mx = data.max()
    print('Min: %16.6f' % mn)
    print('Max: %16.6f' % mx)

    if opts.contours.isdigit():
        # "-c N": N values, each centred in one of N equal slices of
        # [mn, mx].
        n = int(opts.contours)
        if n == 0:
            parser.error('Number of contours must be positive')
        d = (mx - mn) / n
        contours = np.linspace(mn + d / 2, mx - d / 2, n).tolist()
    else:
        # "-c v1,v2,...": explicit values.  A trailing comma is allowed, so
        # "-c 3," means the single value 3.0 rather than three contours.
        values = opts.contours.rstrip(',').split(',')
        try:
            contours = [float(x) for x in values]
        except ValueError:
            parser.error('Invalid contour values: %r' % opts.contours)

    if len(contours) == 1:
        print('1 contour:', contours[0])
    else:
        print('%d contours: %.6f, ..., %.6f' %
              (len(contours), contours[0], contours[-1]))

    if opts.repeat:
        try:
            repeat = [int(r) for r in opts.repeat.split(',')]
        except ValueError:
            parser.error('Invalid repeat: %r' % opts.repeat)
        if len(repeat) == 1:
            # "-r 2" means "-r 2,2,2".  With a single rep, np.tile would
            # only repeat the last axis of the 3-d grid.
            repeat *= 3
        if len(repeat) != 3:
            parser.error('Repeat needs one or three integers, '
                         'e.g. "-r 2,2,2"')
        data = np.tile(data, repeat)
        atoms *= repeat

    plot(atoms, data, contours)

145 

146 

# Allow running this module directly as a command-line script.
if __name__ == '__main__':
    main()