Coverage for /builds/ase/ase/ase/io/vtkxml.py: 5.33%

75 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import numpy as np 

4 

# Module-level switch read by write_vti and write_vtu. When True, the XML
# writers use appended data mode with encoding of the appended data turned
# off; when False (default) they write ASCII data.
fast = False

6 

7 

def write_vti(filename, atoms, data=None):
    """Write a scalar field on an orthogonal cell to a VTK XML image (VTI) file.

    Parameters
    ----------
    filename : str or file-like
        Destination path. For a file object, its ``name`` attribute is used
        because the VTK writer opens the file by name itself.
    atoms : Atoms or list of Atoms
        Supplies the (orthogonal) unit cell. A list must contain exactly one
        configuration.
    data : array_like
        Three-dimensional array of grid values; axis 0, 1 and 2 run along the
        first, second and third cell vectors respectively. Complex values are
        written as their magnitude. Values are stored as doubles.

    Raises
    ------
    ValueError
        If a list with more than one (or no) configuration is given, if
        *data* is missing or not three-dimensional, or if the unit cell is
        not orthogonal.
    """
    from vtk import (
        VTK_MAJOR_VERSION,
        vtkStructuredPoints,
        vtkXMLImageDataWriter,
    )
    from vtk.util.numpy_support import numpy_to_vtk

    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTI file!')
        if not atoms:
            raise ValueError('No configuration to write to a VTI file!')
        atoms = atoms[0]

    if data is None:
        raise ValueError('VTK XML Image Data (VTI) format requires data!')

    data = np.asarray(data)

    # iscomplexobj covers every complex dtype (e.g. complex64), whereas
    # ``data.dtype == complex`` only matched complex128.
    if np.iscomplexobj(data):
        data = np.abs(data)

    # The grid dimensions and spacing below need exactly three axes.
    if data.ndim != 3:
        raise ValueError('VTI data must be a three-dimensional array!')

    cell = np.array(atoms.get_cell())
    if not np.all(cell == np.diag(np.diag(cell))):
        raise ValueError('Unit cell must be orthogonal')
    lengths = cell.diagonal()

    # Create a VTK grid of structured points spanning the cell.
    spts = vtkStructuredPoints()
    if VTK_MAJOR_VERSION <= 5:
        # SetWholeBoundingBox belongs to the VTK 5 pipeline API on
        # vtkDataObject; VTK 6+ data objects no longer provide it.
        bbox = np.array(list(zip(np.zeros(3), lengths))).ravel()
        spts.SetWholeBoundingBox(bbox)
    spts.SetDimensions(data.shape)
    spts.SetSpacing(lengths / data.shape)

    # Flatten with the first index varying fastest -- identical ordering to
    # data.swapaxes(0, 2).flatten() -- and convert to float64 so the VTK
    # array is a vtkDoubleArray, as before. One bulk copy replaces the
    # per-element SetTuple1 loop.
    values = np.ascontiguousarray(data.ravel(order='F'), dtype=float)
    scalars = numpy_to_vtk(values, deep=True)
    scalars.SetName('scalars')

    # Assign the VTK array as point data of the grid.
    spts.GetPointData().SetScalars(scalars)

    # Save the ImageData dataset to a VTK XML file.
    w = vtkXMLImageDataWriter()

    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        w.SetDataModeToAscii()

    # VTK opens the file itself, so only a path can be handed over.
    w.SetFileName(filename if isinstance(filename, str) else filename.name)
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(spts)
    else:
        w.SetInputData(spts)
    w.Write()

86 

87 

def write_vtu(filename, atoms, data=None):
    """Write atoms as the points of a VTK XML unstructured grid (VTU) file.

    Each atom becomes one point at its Cartesian position. The points carry
    three point-data arrays: ``"atomic numbers"``, ``"tags"`` and ``"radii"``
    (covalent radii looked up in ``ase.data.covalent_radii``).

    Parameters
    ----------
    filename : str or file-like
        Destination path. For a file object, its ``name`` attribute is used
        because the VTK writer opens the file by name itself.
    atoms : Atoms or list of Atoms
        A list must contain exactly one configuration.
    data : optional
        Accepted but not used by this writer (presumably kept so the
        signature mirrors write_vti -- confirm with callers).

    Raises
    ------
    ValueError
        If a list with more than one (or no) configuration is given.
    """
    from vtk import (
        VTK_MAJOR_VERSION,
        vtkPoints,
        vtkUnstructuredGrid,
        vtkXMLUnstructuredGridWriter,
    )
    from vtk.util.numpy_support import numpy_to_vtk

    from ase.data import covalent_radii

    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTU file!')
        if not atoms:
            raise ValueError('No configuration to write to a VTU file!')
        atoms = atoms[0]

    # Create an unstructured grid whose points are the atoms.
    ugd = vtkUnstructuredGrid()

    # Hand all positions to VTK in one double-precision (N, 3) array instead
    # of inserting them point by point.
    positions = np.ascontiguousarray(atoms.get_positions(), dtype=float)
    points = vtkPoints()
    points.SetDataTypeToDouble()
    points.SetData(numpy_to_vtk(positions, deep=True))
    ugd.SetPoints(points)

    # Per-atom point data, added in this order.
    numbers = atoms.get_atomic_numbers()
    point_data = ugd.GetPointData()
    for name, values in (
        ("atomic numbers", numbers),
        ("tags", atoms.get_tags()),
        ("radii", covalent_radii[numbers]),
    ):
        array = numpy_to_vtk(np.ascontiguousarray(values), deep=True)
        array.SetName(name)
        point_data.AddArray(array)

    # Save the UnstructuredGrid dataset to a VTK XML file.
    w = vtkXMLUnstructuredGridWriter()

    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        # Defensive: GetCompressor() returns None if the writer's
        # compressor has been unset.
        compressor = w.GetCompressor()
        if compressor is not None:
            compressor.SetCompressionLevel(0)
        w.SetDataModeToAscii()

    # VTK opens the file itself, so only a path can be handed over.
    w.SetFileName(filename if isinstance(filename, str) else filename.name)
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(ugd)
    else:
        w.SetInputData(ugd)
    w.Write()