Coverage for /builds/ase/ase/ase/visualize/mlab.py: 14.47%

76 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import optparse 

4 

5import numpy as np 

6 

7from ase.calculators.calculator import get_calculator_class 

8from ase.data import covalent_radii 

9from ase.data.colors import cpk_colors 

10from ase.io.cube import read_cube_data 

11 

12 

def plot(atoms, data, contours):
    """Plot atoms, unit cell and iso-surfaces using Mayavi.

    Opens Mayavi figure 1 with a white background, draws every atom as a
    sphere, draws the twelve unit-cell edges as tubes, adds transparent
    iso-surfaces of *data* and finally calls ``mlab.show()``.

    Parameters:

    atoms: Atoms object
        Positions, atomic numbers and unit cell.
    data: 3-d array-like of float
        Data for iso-surfaces.  Grid index ``(i, j, k)`` is mapped onto the
        cell as ``i/n1*a1 + j/n2*a2 + k/n3*a3`` with
        ``(n1, n2, n3) = data.shape``, so non-orthogonal cells work.
    contours: list of float
        Contour values.
    """

    # Delay slow imports:
    import os

    from mayavi import mlab

    # mayavi GUI bug fix for remote access via ssh (X11 forwarding)
    if "SSH_CONNECTION" in os.environ:
        f = mlab.gcf()
        f.scene._lift()

    mlab.figure(1, bgcolor=(1, 1, 1))  # make a white figure

    # Plot the atoms as spheres, sized by covalent radius and coloured
    # from the CPK table:
    for pos, Z in zip(atoms.positions, atoms.numbers):
        mlab.points3d(*pos,
                      scale_factor=covalent_radii[Z],
                      resolution=20,
                      color=tuple(cpk_colors[Z]))

    # Draw the unit cell.  Work on a plain ndarray so the arithmetic below
    # does not depend on which operators the cell object implements.
    A = np.asarray(atoms.cell)
    origin = np.zeros(3)
    for i1, a in enumerate(A):
        i2 = (i1 + 1) % 3
        i3 = (i1 + 2) % 3
        # The four edges parallel to cell vector ``a`` start at
        # 0, A[i2], A[i3] and A[i2] + A[i3].
        for b in (origin, A[i2]):
            for c in (origin, A[i3]):
                p1 = b + c
                p2 = p1 + a
                mlab.plot3d([p1[0], p2[0]],
                            [p1[1], p2[1]],
                            [p1[2], p2[2]],
                            tube_radius=0.1)

    data = np.asarray(data)
    cp = mlab.contour3d(data, contours=contours, transparent=True,
                        opacity=0.5, colormap='hot')
    # Do some tvtk magic in order to allow for non-orthogonal unit cells:
    # the iso-surface points are taken to be in grid-index coordinates and
    # are rewritten in place as Cartesian positions inside the cell.
    polydata = cp.actor.actors[0].mapper.input
    # NOTE(review): the shift by one assumes mlab starts the index grid at
    # 1 rather than 0 -- confirm against the Mayavi version in use.
    pts = np.array(polydata.points) - 1
    # Row i of the matrix is a_i / n_i, so each point becomes
    # sum_i idx_i / n_i * a_i:
    polydata.points = np.dot(pts, A / np.array(data.shape)[:, np.newaxis])

    # Apparently we need this to redraw the figure, maybe it can be done in
    # another way?
    mlab.view(azimuth=155, elevation=70, distance='auto')
    # Show the 3d plot:
    mlab.show()

72 

73 

def view_mlab(atoms, *args, **kwargs):
    """Display *atoms* with Mayavi by delegating to ``plot``.

    Every argument after *atoms* is forwarded unchanged to ``plot`` and its
    return value is handed back to the caller.
    """
    forwarded = plot(atoms, *args, **kwargs)
    return forwarded

76 

77 

# Help text passed to the option parser built in ``main``.
description = ('Plot iso-surfaces from a cube-file or a wave function or an '
               'electron\n'
               'density from a calculator-restart file.')

81 

82 

def _parse_contour_spec(spec):
    """Parse the value of the ``--contours`` option.

    Parameters:

    spec: str
        Either a plain integer such as ``"3"`` (number of evenly spaced
        contours) or a comma-separated list of contour values such as
        ``"-0.5,0.5"``; a trailing comma is ignored.

    Returns an ``int`` for the first form and a list of floats for the
    second.  Raises ValueError for anything else, including a contour count
    of zero (which would otherwise cause a division by zero).
    """
    if spec.isdigit():
        count = int(spec)
        if count < 1:
            raise ValueError('the number of contours must be at least 1')
        return count
    return [float(x) for x in spec.rstrip(',').split(',')]


def _parse_repeat(spec):
    """Parse the value of the ``--repeat`` option.

    Accepts ``"n1,n2,n3"`` or a single ``"n"`` (shorthand for ``"n,n,n"``);
    a trailing comma is ignored.  Returns a list of three positive ints and
    raises ValueError for anything else.
    """
    repeat = [int(r) for r in spec.rstrip(',').split(',')]
    if len(repeat) == 1:
        repeat *= 3
    if len(repeat) != 3 or min(repeat) < 1:
        raise ValueError('expected one or three positive integers')
    return repeat


def main(args=None):
    """Command-line entry point: read volumetric data and plot iso-surfaces.

    Parameters:

    args: list of str or None
        Command-line arguments without the program name; ``None`` lets
        optparse read ``sys.argv[1:]``.

    A ``.cube`` file is read with ``read_cube_data``.  Any other file is
    opened with the calculator class named by ``-C`` and its pseudo density,
    electrostatic potential (``-e``) or pseudo wave function (``-n``; the
    absolute value is used if it is complex) is plotted.  Invalid options
    are reported through ``parser.error`` (usage message and exit).
    """
    parser = optparse.OptionParser(usage='%prog [options] filename',
                                   description=description)
    add = parser.add_option
    add('-n', '--band-index', type=int, metavar='INDEX',
        help='Band index counting from zero.')
    add('-s', '--spin-index', type=int, metavar='SPIN',
        help='Spin index: zero or one.')
    add('-e', '--electrostatic-potential', action='store_true',
        help='Plot the electrostatic potential.')
    add('-c', '--contours', default='4',
        help='Use "-c 3" for 3 contours or "-c -0.5,0.5" for specific ' +
        'values. Default is four contours.')
    add('-r', '--repeat',
        help='Example: "-r 2,2,2" (a single value such as "-r 2" '
        'repeats equally along all three cell vectors).')
    add('-C', '--calculator-name', metavar='NAME', help='Name of calculator.')

    opts, args = parser.parse_args(args)
    if len(args) != 1:
        parser.error('Incorrect number of arguments')

    # Validate the cheap options before loading a possibly large file.
    try:
        contour_spec = _parse_contour_spec(opts.contours)
    except ValueError as err:
        parser.error('Invalid --contours value %r: %s' % (opts.contours, err))

    repeat = None
    if opts.repeat:
        try:
            repeat = _parse_repeat(opts.repeat)
        except ValueError as err:
            parser.error('Invalid --repeat value %r: %s' % (opts.repeat, err))

    filename = args[0]
    if filename.endswith('.cube'):
        data, atoms = read_cube_data(filename)
    else:
        if opts.calculator_name is None:
            # get_calculator_class(None) would fail with an obscure error.
            parser.error('A calculator name (-C NAME) is required for '
                         'non-cube files')
        calc = get_calculator_class(opts.calculator_name)(filename, txt=None)
        atoms = calc.get_atoms()
        if opts.band_index is None:
            if opts.electrostatic_potential:
                data = calc.get_electrostatic_potential()
            else:
                data = calc.get_pseudo_density(opts.spin_index)
        else:
            data = calc.get_pseudo_wave_function(opts.band_index,
                                                 opts.spin_index or 0)

    # Iso-surfaces need real values.  np.iscomplexobj also catches
    # complex64, which a ``dtype == complex`` (complex128-only) test misses.
    if np.iscomplexobj(data):
        data = abs(data)

    mn = data.min()
    mx = data.max()
    print('Min: %16.6f' % mn)
    print('Max: %16.6f' % mx)

    if isinstance(contour_spec, int):
        n = contour_spec
        d = (mx - mn) / n
        # Centres of n equal-width bins spanning [mn, mx]:
        contours = np.linspace(mn + d / 2, mx - d / 2, n).tolist()
    else:
        contours = contour_spec

    if len(contours) == 1:
        print('1 contour:', contours[0])
    else:
        print('%d contours: %.6f, ..., %.6f' %
              (len(contours), contours[0], contours[-1]))

    if repeat is not None:
        data = np.tile(data, repeat)
        atoms *= repeat

    plot(atoms, data, contours)

144 

145 

if __name__ == '__main__':
    # Executed as a script: let main() take its arguments from the command
    # line (optparse falls back to sys.argv when args is None).
    main(args=None)