Coverage for /builds/ase/ase/ase/io/jsonio.py: 89.29%
112 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import datetime
4import json
6import numpy as np
8from ase.utils import reader, writer
10# Note: We are converting JSON classes to the recommended mechanisms
11# by the json module. That means instead of classes, we will use the
12# functions default() and object_hook().
13#
14# The encoder classes are to be deprecated (but maybe not removed, if
15# widely used).
def default(obj):
    """Convert ``obj`` into something the json module can serialize.

    Used as the ``default`` hook of :class:`json.JSONEncoder`, i.e. it is
    only consulted for objects that json cannot encode natively.

    Conversions:

    * objects with a ``todict()`` method -> the returned dict, tagged with
      ``'__ase_objtype__'`` when the object has an ``ase_objtype`` attribute;
    * numpy arrays -> ``{'__ndarray__': (shape, dtype string, flat data)}``,
      where complex data is stored as interleaved real/imag values;
    * numpy integer / floating / bool scalars -> int / float / bool;
    * ``datetime.datetime`` -> ``{'__datetime__': isoformat string}``;
    * Python or numpy complex numbers -> ``{'__complex__': (real, imag)}``.

    Raises:
        RuntimeError: if ``todict()`` does not return a dict.
        TypeError: for any other unsupported type.
    """
    if hasattr(obj, 'todict'):
        dct = obj.todict()

        if not isinstance(dct, dict):
            raise RuntimeError('todict() of {} returned object of type {} '
                               'but should have returned dict'
                               .format(obj, type(dct)))
        if hasattr(obj, 'ase_objtype'):
            # We modify the dictionary, so it is wise to take a copy.
            dct = dct.copy()
            dct['__ase_objtype__'] = obj.ase_objtype

        return dct
    if isinstance(obj, np.ndarray):
        # ravel() (C order) always yields a contiguous 1-D array, so complex
        # data can be reinterpreted as (real, imag, real, imag, ...) through
        # a view rather than by assigning to the ``dtype`` attribute.
        flatobj = obj.ravel()
        if np.iscomplexobj(obj):
            flatobj = flatobj.view(obj.real.dtype)
        # We use str(obj.dtype) here instead of obj.dtype.name, because
        # they are not always the same (e.g. for numpy arrays of strings).
        # Using obj.dtype.name can break the ability to recursively decode/
        # encode such arrays.
        return {'__ndarray__': (obj.shape,
                                str(obj.dtype),
                                flatobj.tolist())}
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, datetime.datetime):
        return {'__datetime__': obj.isoformat()}
    if isinstance(obj, (complex, np.complexfloating)):
        # np.complex64 (and other numpy complex scalars) do not subclass
        # Python's complex, so normalize through complex() first.
        value = complex(obj)
        return {'__complex__': (value.real, value.imag)}

    raise TypeError(f'Cannot convert object of type {type(obj)} to '
                    'dictionary for JSON')
class MyEncoder(json.JSONEncoder):
    """:class:`json.JSONEncoder` that falls back on the module-level ``default``.

    Kept as a class for existing users; the function-based ``default`` hook
    is the preferred mechanism.
    """

    def default(self, obj):
        # Inside the method body the bare name is resolved in the module's
        # global namespace, i.e. it is the module-level function and this
        # is not a recursive call.
        serializable = default(obj)
        return serializable


# Bound ``encode`` method of a module-wide MyEncoder instance: obj -> str.
encode = MyEncoder().encode
def object_hook(dct):
    """Turn a freshly decoded JSON object (dict) back into a Python object.

    Used as the ``object_hook`` of :class:`json.JSONDecoder`. It reverses the
    tagged encodings produced by ``default``:

    * ``{'__datetime__': iso}`` -> :class:`datetime.datetime`
    * ``{'__complex__': (re, im)}`` -> :class:`complex`
    * ``{'__ndarray__': (shape, dtype, data)}`` -> :class:`numpy.ndarray`
    * ``{'__complex_ndarray__': (re, im)}`` -> complex ndarray (legacy)
    * ``{..., '__ase_objtype__': name}`` -> the corresponding ASE object

    Any other dict is returned unchanged.
    """
    if '__datetime__' in dct:
        txt = dct['__datetime__']
        try:
            # Handles everything datetime.isoformat() emits, including
            # values with microsecond == 0 (no fractional part) and
            # timezone-aware values with a UTC offset suffix.
            return datetime.datetime.fromisoformat(txt)
        except ValueError:
            # Historical explicit format; also accepts fractional seconds
            # with fewer than six digits on older Python versions.
            return datetime.datetime.strptime(txt, '%Y-%m-%dT%H:%M:%S.%f')

    if '__complex__' in dct:
        return complex(*dct['__complex__'])

    if '__ndarray__' in dct:
        return create_ndarray(*dct['__ndarray__'])

    # No longer used (only here for backwards compatibility):
    if '__complex_ndarray__' in dct:
        r, i = (np.array(x) for x in dct['__complex_ndarray__'])
        return r + i * 1j

    if '__ase_objtype__' in dct:
        objtype = dct.pop('__ase_objtype__')
        dct = numpyfy(dct)
        return create_ase_object(objtype, dct)

    return dct
def create_ndarray(shape, dtype, data):
    """Create ndarray from shape, dtype and flattened data.

    Inverse of the ``'__ndarray__'`` encoding: ``data`` lists the elements
    in C order, and for complex dtypes it holds interleaved real and
    imaginary parts.

    Returns:
        A new array with the given ``shape`` and ``dtype``.
    """
    array = np.empty(shape, dtype=dtype)
    # ``array`` is freshly allocated and C-contiguous, so ravel() returns a
    # view: writing into ``flatbuf`` fills ``array`` itself.
    flatbuf = array.ravel()
    if np.iscomplexobj(array):
        # Reinterpret the complex storage as pairs of reals through a view
        # instead of assigning to the ``dtype`` attribute.
        flatbuf = flatbuf.view(array.real.dtype)
    flatbuf[:] = data
    return array
def create_ase_object(objtype, dct):
    """Instantiate the ASE object named by ``objtype`` from ``dct``.

    Parameters:
        objtype: value of the ``'__ase_objtype__'`` tag. One of ``'atoms'``,
            ``'cell'``, ``'bandpath'``, ``'bandstructure'`` or
            ``'vibrationsdata'``.
        dct: keyword data for the object. It may be modified in place
            (``'pbc'`` and ``'labelseq'`` entries are popped).

    Raises:
        ValueError: if ``objtype`` is not one of the supported names.
    """
    # Each branch does its own import, so only the module for the
    # requested type gets imported.
    if objtype == 'atoms':
        from ase import Atoms
        instance = Atoms.fromdict(dct)
    elif objtype == 'cell':
        from ase.cell import Cell
        # Older files may carry a 'pbc' entry; it is dropped for compatibility.
        dct.pop('pbc', None)
        instance = Cell(**dct)
    elif objtype == 'bandpath':
        from ase.dft.kpoints import BandPath
        # 'labelseq' is removed before the remaining entries are splatted.
        labelseq = dct.pop('labelseq')
        instance = BandPath(path=labelseq, **dct)
    elif objtype == 'bandstructure':
        from ase.spectrum.band_structure import BandStructure
        instance = BandStructure(**dct)
    elif objtype == 'vibrationsdata':
        from ase.vibrations import VibrationsData
        instance = VibrationsData.fromdict(dct)
    else:
        raise ValueError('Do not know how to decode object type {} '
                         'into an actual object'.format(objtype))
    # Consistency check: the built object must report the tag it came from.
    assert instance.ase_objtype == objtype
    return instance
# JSON decoder whose object_hook rebuilds datetimes, complex numbers,
# ndarrays and ASE objects from their tagged-dict encodings. ``mydecode`` is
# its bound ``decode`` method (str -> object).
_json_decoder = json.JSONDecoder(object_hook=object_hook)
mydecode = _json_decoder.decode
def intkey(key):
    """Return ``key`` as an int if ``int()`` accepts it, otherwise unchanged."""
    try:
        as_int = int(key)
    except ValueError:
        return key
    return as_int
def fix_int_keys_in_dicts(obj):
    """Undo json's int-to-str key conversion: "1" -> 1.

    json.dump() writes integer dict keys as strings. This maps every key
    through ``intkey`` and recurses into nested dict values. Non-dict
    values, including lists, are returned as they are.
    """
    if not isinstance(obj, dict):
        return obj
    fixed = {}
    for key, value in obj.items():
        fixed[intkey(key)] = fix_int_keys_in_dicts(value)
    return fixed
def numpyfy(obj):
    """Convert JSON-decoded lists into numpy arrays where possible.

    * A dict carrying the legacy ``'__complex_ndarray__'`` pair of
      (real, imag) lists becomes a complex array.
    * A non-empty list that numpy turns into a bool, int or float array is
      returned as that array. Any other non-empty list is rebuilt with
      each element passed through ``numpyfy``.
    * Everything else is returned unchanged. This includes other dicts,
      whose values are not visited.
    """
    if isinstance(obj, dict) and '__complex_ndarray__' in obj:
        real, imag = (np.array(part) for part in obj['__complex_ndarray__'])
        return real + imag * 1j

    if not isinstance(obj, list) or len(obj) == 0:
        return obj

    try:
        candidate = np.array(obj)
    except ValueError:
        # e.g. ragged nested lists, which recent numpy refuses to stack.
        pass
    else:
        if candidate.dtype in [bool, int, float]:
            return candidate

    return [numpyfy(item) for item in obj]
def decode(txt, always_array=True):
    """Decode JSON text written by this module.

    Parameters:
        txt: the JSON string.
        always_array: if true (the default), the result is also passed
            through ``numpyfy``, which turns suitable bool/int/float lists
            into numpy arrays.

    Returns:
        The decoded object. Dict keys that parse as integers are converted
        back to ints.
    """
    decoded = fix_int_keys_in_dicts(mydecode(txt))
    if not always_array:
        return decoded
    return numpyfy(decoded)
@reader
def read_json(fd, always_array=True):
    """Read all of ``fd`` and decode it as JSON.

    ``fd`` is whatever the ``reader`` decorator passes in. It must provide
    ``read()`` returning the JSON text. ``always_array`` is forwarded to
    ``decode``.
    """
    return decode(fd.read(),
                  always_array=always_array)
@writer
def write_json(fd, obj):
    """Serialize ``obj`` with the module's ``encode`` and write it to ``fd``.

    ``fd`` is supplied via the ``writer`` decorator and must provide
    ``write()``. Nothing is returned.
    """
    write = fd.write
    write(encode(obj))