Coverage for /builds/ase/ase/ase/io/vtkxml.py: 5.33%
75 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import numpy as np
# Output-mode switch read by write_vti() and write_vtu(): when True the VTK
# XML writers use appended data mode with appended-data encoding turned off
# (EncodeAppendedDataOff); when False they write the data as ASCII.
fast = False
def write_vti(filename, atoms, data=None):
    """Write a scalar field on an orthogonal unit cell to a VTK XML (VTI) file.

    The grid is written as VTK XML image data (a ``vtkStructuredPoints``
    dataset) carrying a single point-data array named ``'scalars'``.

    Parameters
    ----------
    filename : str, os.PathLike or file-like
        Output path. For a file-like object its ``.name`` attribute is used
        as the path, because VTK opens and writes the file itself.
    atoms : Atoms or list of Atoms
        Configuration whose (orthogonal) unit cell spans the grid. A list
        must contain exactly one configuration.
    data : array_like
        3-D array of grid values. Axis ``n`` is spread evenly along cell
        vector ``n`` (grid spacing ``cell[n, n] / data.shape[n]``). Complex
        data is written as its absolute value.

    Raises
    ------
    ValueError
        If zero or several configurations are given, if ``data`` is missing
        or not 3-D, or if the unit cell is not orthogonal.
    """
    import os

    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTI file!')
        if not atoms:
            raise ValueError('No configuration given to write to a VTI file!')
        atoms = atoms[0]

    if data is None:
        raise ValueError('VTK XML Image Data (VTI) format requires data!')

    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError('VTI data must be a 3-D array, got shape %s'
                         % (data.shape,))

    # Only a real-valued field can be stored, so keep the magnitude of
    # complex input. iscomplexobj also catches complex64, which a plain
    # ``dtype == complex`` comparison misses.
    if np.iscomplexobj(data):
        data = np.abs(data)

    cell = np.array(atoms.get_cell())
    if not np.all(cell == np.diag(np.diag(cell))):
        raise ValueError('Unit cell must be orthogonal')
    lengths = cell.diagonal()

    # Imported only now so that argument validation works without VTK.
    from vtk import (
        VTK_MAJOR_VERSION,
        vtkStructuredPoints,
        vtkXMLImageDataWriter,
    )
    from vtk.util.numpy_support import numpy_to_vtk

    # Create a VTK grid of structured points
    spts = vtkStructuredPoints()
    # VTK 6 removed pipeline meta-data setters such as SetWholeBoundingBox
    # from data objects, so only attach the bounding box where it exists.
    if hasattr(spts, 'SetWholeBoundingBox'):
        # (xmin, xmax, ymin, ymax, zmin, zmax) spanning the unit cell
        bbox = np.array(list(zip(np.zeros(3), lengths))).ravel()
        spts.SetWholeBoundingBox(bbox)
    spts.SetDimensions(data.shape)
    spts.SetSpacing(lengths / data.shape)

    # VTK image data stores points with x varying fastest. Swapping axes 0
    # and 2 and flattening in C order gives exactly that ordering. The cast
    # to float makes numpy_to_vtk produce a vtkDoubleArray.
    values = np.ascontiguousarray(data.swapaxes(0, 2).ravel(), dtype=float)
    scalars = numpy_to_vtk(values, deep=1)
    scalars.SetName('scalars')
    spts.GetPointData().SetScalars(scalars)

    # Save the ImageData dataset to a VTK XML file.
    w = vtkXMLImageDataWriter()
    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        w.SetDataModeToAscii()

    if isinstance(filename, (str, os.PathLike)):
        w.SetFileName(os.fspath(filename))
    else:
        w.SetFileName(filename.name)
    # SetInput was replaced by SetInputData in VTK 6.
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(spts)
    else:
        w.SetInputData(spts)
    w.Write()


def write_vtu(filename, atoms, data=None):
    """Write atomic positions and per-atom properties to a VTK XML (VTU) file.

    Each atom becomes one point of a ``vtkUnstructuredGrid``. The point data
    holds three arrays: ``'atomic numbers'``, ``'tags'`` and ``'radii'``
    (covalent radii from ``ase.data.covalent_radii``). No cells are added to
    the grid, only points.

    Parameters
    ----------
    filename : str, os.PathLike or file-like
        Output path. For a file-like object its ``.name`` attribute is used
        as the path, because VTK opens and writes the file itself.
    atoms : Atoms or list of Atoms
        Configuration to write. A list must contain exactly one
        configuration.
    data : optional
        Not used by this writer. It is accepted so the signature matches
        ``write_vti``.

    Raises
    ------
    ValueError
        If zero or several configurations are given.
    """
    import os

    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTU file!')
        if not atoms:
            raise ValueError('No configuration given to write to a VTU file!')
        atoms = atoms[0]

    # Imported only now so that argument validation works without VTK.
    from vtk import (
        VTK_MAJOR_VERSION,
        vtkPoints,
        vtkUnstructuredGrid,
        vtkXMLUnstructuredGridWriter,
    )
    from vtk.util.numpy_support import numpy_to_vtk

    from ase.data import covalent_radii

    # Create a VTK unstructured grid holding one point per atom
    ugd = vtkUnstructuredGrid()

    points = vtkPoints()
    points.SetNumberOfPoints(len(atoms))
    points.SetDataTypeToDouble()
    for i, pos in enumerate(atoms.get_positions()):
        points.InsertPoint(i, *pos)
    ugd.SetPoints(points)

    # Per-atom point-data arrays: (array name, values)
    numbers = atoms.get_atomic_numbers()
    point_arrays = (
        ('atomic numbers', numbers),
        ('tags', atoms.get_tags()),
        ('radii', covalent_radii[numbers]),
    )
    point_data = ugd.GetPointData()
    for name, values in point_arrays:
        vtk_array = numpy_to_vtk(values, deep=1)
        vtk_array.SetName(name)
        point_data.AddArray(vtk_array)

    # Save the UnstructuredGrid dataset to a VTK XML file.
    w = vtkXMLUnstructuredGridWriter()
    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        w.GetCompressor().SetCompressionLevel(0)
        w.SetDataModeToAscii()

    # Path-like objects must be converted with os.fspath. Their ``.name``
    # is only the basename, which would silently redirect the output to the
    # current working directory.
    if isinstance(filename, (str, os.PathLike)):
        w.SetFileName(os.fspath(filename))
    else:
        w.SetFileName(filename.name)
    # SetInput was replaced by SetInputData in VTK 6.
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(ugd)
    else:
        w.SetInputData(ugd)
    w.Write()