Coverage for /builds/ase/ase/ase/io/jsonio.py: 89.29%

112 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import datetime 

4import json 

5 

6import numpy as np 

7 

8from ase.utils import reader, writer 

9 

10# Note: We are converting JSON classes to the recommended mechanisms 

11# by the json module. That means instead of classes, we will use the 

12# functions default() and object_hook(). 

13# 

14# The encoder classes are to be deprecated (but maybe not removed, if 

15# widely used). 

16 

17 

def default(obj):
    """Convert *obj* into something the :mod:`json` module can serialise.

    Intended as the ``default`` hook of a :class:`json.JSONEncoder` (see
    ``MyEncoder``) or of :func:`json.dumps`, i.e. it is consulted for
    objects json cannot encode by itself.

    Supported inputs:

    * objects with a ``todict()`` method -> that dict, additionally tagged
      with ``'__ase_objtype__'`` when the object has ``ase_objtype``;
    * :class:`numpy.ndarray` -> ``{'__ndarray__': (shape, dtype, data)}``
      where *data* is the C-order flattened element list and complex
      values are stored as interleaved real/imaginary floats;
    * numpy integer, floating and bool scalars -> Python int/float/bool;
    * :class:`datetime.datetime` -> ``{'__datetime__': isoformat string}``;
    * Python and numpy complex scalars -> ``{'__complex__': (real, imag)}``.

    Raises:
        RuntimeError: if ``obj.todict()`` does not return a dict.
        TypeError: for any other object (the json module expects this).
    """
    if hasattr(obj, 'todict'):
        dct = obj.todict()

        if not isinstance(dct, dict):
            raise RuntimeError('todict() of {} returned object of type {} '
                               'but should have returned dict'
                               .format(obj, type(dct)))
        if hasattr(obj, 'ase_objtype'):
            # We modify the dictionary, so it is wise to take a copy.
            dct = dct.copy()
            dct['__ase_objtype__'] = obj.ase_objtype

        return dct
    if isinstance(obj, np.ndarray):
        # ravel() yields a contiguous 1-D array (copying only if needed), so
        # complex data can be reinterpreted as real/imag float pairs.  Use
        # view() instead of assigning ``.dtype``, which NumPy discourages.
        flatobj = obj.ravel()
        if np.iscomplexobj(obj):
            flatobj = flatobj.view(obj.real.dtype)
        # We use str(obj.dtype) here instead of obj.dtype.name, because
        # they are not always the same (e.g. for numpy arrays of strings).
        # Using obj.dtype.name can break the ability to recursively decode/
        # encode such arrays.
        return {'__ndarray__': (obj.shape,
                                str(obj.dtype),
                                flatobj.tolist())}
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, datetime.datetime):
        return {'__datetime__': obj.isoformat()}
    if isinstance(obj, (complex, np.complexfloating)):
        # np.complex128 subclasses Python complex, but other numpy complex
        # scalars (e.g. np.complex64) do not, so accept them explicitly.
        # float() turns numpy real/imag parts into plain Python floats.
        return {'__complex__': (float(obj.real), float(obj.imag))}

    raise TypeError(f'Cannot convert object of type {type(obj)} to '
                    'dictionary for JSON')

56 

57 

class MyEncoder(json.JSONEncoder):
    """JSON encoder that hands unsupported objects to module-level default().

    Encoder classes are slated for deprecation in favour of passing the
    ``default`` function to the json module directly; this one is still
    used to build ``encode``.
    """

    def default(self, obj):
        # Bare names inside a method body never resolve to class members,
        # so ``default`` here is the module-level function, not this method.
        converted = default(obj)
        return converted


# Encode a Python object to a JSON string via MyEncoder.
encode = MyEncoder().encode

66 

67 

def object_hook(dct):
    """Turn a tagged JSON object back into the Python object it encodes.

    Meant as the ``object_hook`` of :class:`json.JSONDecoder`; it reverses
    the tagging done by :func:`default`:

    * ``'__datetime__'``: ISO format string -> :class:`datetime.datetime`
    * ``'__complex__'``: ``(real, imag)`` -> :class:`complex`
    * ``'__ndarray__'``: ``(shape, dtype, flat data)`` -> numpy array
    * ``'__complex_ndarray__'``: legacy pair of real/imag lists -> array
    * ``'__ase_objtype__'``: ASE object built by :func:`create_ase_object`
      (the tag key is popped from *dct* first)

    Any other dict is returned unchanged.
    """
    if '__datetime__' in dct:
        text = dct['__datetime__']
        try:
            # default() writes datetime.isoformat(), which omits the
            # fractional part when microsecond == 0 and appends a UTC offset
            # for aware datetimes.  fromisoformat() is documented as the
            # inverse of isoformat() and handles both.
            return datetime.datetime.fromisoformat(text)
        except ValueError:
            # Keep accepting anything the historical fixed format parsed.
            return datetime.datetime.strptime(text, '%Y-%m-%dT%H:%M:%S.%f')

    if '__complex__' in dct:
        return complex(*dct['__complex__'])

    if '__ndarray__' in dct:
        return create_ndarray(*dct['__ndarray__'])

    # No longer used (only here for backwards compatibility):
    if '__complex_ndarray__' in dct:
        r, i = (np.array(x) for x in dct['__complex_ndarray__'])
        return r + i * 1j

    if '__ase_objtype__' in dct:
        objtype = dct.pop('__ase_objtype__')
        dct = numpyfy(dct)
        return create_ase_object(objtype, dct)

    return dct

90 

91 

def create_ndarray(shape, dtype, data):
    """Create ndarray from shape, dtype and flattened data.

    Inverse of the ``'__ndarray__'`` encoding written by :func:`default`:
    *data* is the C-order flattened element list, and complex values are
    stored as interleaved real/imaginary parts.

    Args:
        shape: Target shape (any sequence :func:`numpy.empty` accepts).
        dtype: Anything :class:`numpy.dtype` accepts, e.g. ``'complex128'``.
        data: Flat list of values; for complex dtypes it holds two floats
            per element.

    Returns:
        A new array of the requested shape and dtype filled with *data*.
    """
    array = np.empty(shape, dtype=dtype)
    # A freshly allocated array is C-contiguous, so ravel() returns a view
    # and assignments through ``flatbuf`` land in ``array``.
    flatbuf = array.ravel()
    if np.iscomplexobj(array):
        # Reinterpret complex storage as real/imag pairs.  view() is the
        # supported spelling; assigning to ``.dtype`` is discouraged.
        flatbuf = flatbuf.view(array.real.dtype)
    flatbuf[:] = data
    return array

100 

101 

def create_ase_object(objtype, dct):
    """Instantiate the ASE object named by *objtype* from its dict form.

    *dct* is the decoded dictionary with the ``'__ase_objtype__'`` tag
    already removed; some keys ('pbc', 'labelseq') may be popped from it
    in place.  Each supported class is imported only when it is needed.

    Raises:
        ValueError: if *objtype* is not one of the known object types.
    """
    def _cell():
        from ase.cell import Cell
        dct.pop('pbc', None)  # compatibility; we once had pbc
        return Cell(**dct)

    def _bandstructure():
        from ase.spectrum.band_structure import BandStructure
        return BandStructure(**dct)

    def _bandpath():
        from ase.dft.kpoints import BandPath
        labels = dct.pop('labelseq')
        return BandPath(path=labels, **dct)

    def _atoms():
        from ase import Atoms
        return Atoms.fromdict(dct)

    def _vibrationsdata():
        from ase.vibrations import VibrationsData
        return VibrationsData.fromdict(dct)

    # Ordered (name, factory) pairs; compared with == just like an
    # if/elif chain would, so any objtype value is accepted.
    factories = (
        ('cell', _cell),
        ('bandstructure', _bandstructure),
        ('bandpath', _bandpath),
        ('atoms', _atoms),
        ('vibrationsdata', _vibrationsdata),
    )
    for name, factory in factories:
        if objtype == name:
            obj = factory()
            break
    else:
        raise ValueError('Do not know how to decode object type {} '
                         'into an actual object'.format(objtype))

    # Sanity check that the constructed object reports the same type tag.
    assert obj.ase_objtype == objtype
    return obj

127 

128 

# Bound ``decode`` method of a JSONDecoder whose object_hook revives the
# tagged values written by ``default`` (ndarrays, complex numbers,
# datetimes, ASE objects); ``decode`` below calls it first.
mydecode = json.JSONDecoder(object_hook=object_hook).decode

130 

131 

def intkey(key):
    """Convert str to int if possible.

    Only keys that are exactly the text json.dumps() writes for an int key
    (e.g. ``'3'`` or ``'-1'``) are converted.  Other strings that int()
    would also accept -- ``'01'``, ``' 1'``, ``'+1'``, ``'1_000'`` -- are
    returned unchanged, so distinct JSON keys never collapse onto the same
    int key and overwrite each other.
    """
    try:
        value = int(key)
    except ValueError:
        return key
    # Round-trip check: accept only the canonical decimal spelling.
    return value if str(value) == key else key

138 

139 

def fix_int_keys_in_dicts(obj):
    """Return *obj* with numeric-string dict keys turned back into ints.

    json.dump() stores int dict keys as strings ("1"); this reverses that
    by passing every key through :func:`intkey`, recursing into nested
    dict values.  Anything that is not a dict is returned as-is.
    """
    if not isinstance(obj, dict):
        return obj

    restored = {}
    for key, value in obj.items():
        new_key = intkey(key)
        restored[new_key] = fix_int_keys_in_dicts(value)
    return restored

150 

151 

def numpyfy(obj):
    """Recursively turn JSON lists into numpy arrays where sensible.

    * A dict with the legacy ``'__complex_ndarray__'`` key becomes a complex
      array; any other dict is rebuilt with its values numpyfied.
    * A non-empty list that numpy turns into a bool, int or float array
      becomes that array; otherwise each element is numpyfied.
    * Anything that is not a dict or list (and the empty list) is returned
      unchanged.
    """
    if isinstance(obj, dict):
        if '__complex_ndarray__' in obj:
            r, i = (np.array(x) for x in obj['__complex_ndarray__'])
            return r + i * 1j
        # Recurse into dict values.  Without this, object_hook()'s
        # ``numpyfy(dct)`` and decode(always_array=True) on a top-level
        # dict would leave every nested list untouched.
        return {key: numpyfy(value) for key, value in obj.items()}
    if isinstance(obj, list) and len(obj) > 0:
        try:
            a = np.array(obj)
        except ValueError:
            # e.g. ragged nesting that numpy cannot make rectangular.
            pass
        else:
            if a.dtype in [bool, int, float]:
                return a
        obj = [numpyfy(value) for value in obj]
    return obj

167 

168 

def decode(txt, always_array=True):
    """Decode a JSON string (such as one produced by ``encode``).

    The text is parsed with ``mydecode``, numeric-string dict keys are
    restored with :func:`fix_int_keys_in_dicts`, and, when *always_array*
    is true, the result is additionally passed through :func:`numpyfy`.
    """
    decoded = fix_int_keys_in_dicts(mydecode(txt))
    if not always_array:
        return decoded
    return numpyfy(decoded)

175 

176 

@reader
def read_json(fd, always_array=True):
    """Read the whole of *fd* and return it decoded by :func:`decode`.

    *fd* is presumably an open text file supplied through the ``reader``
    decorator; only its ``read()`` method is used.  *always_array* is
    forwarded unchanged to :func:`decode`.
    """
    return decode(fd.read(), always_array=always_array)

181 

182 

@writer
def write_json(fd, obj):
    """Serialise *obj* with :func:`encode` and write the text to *fd*.

    *fd* is presumably an open text file supplied through the ``writer``
    decorator; only its ``write()`` method is used.
    """
    text = encode(obj)
    fd.write(text)