Coverage for /builds/ase/ase/ase/io/gpumd.py: 90.68%

118 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import numpy as np 

4 

5from ase import Atoms 

6from ase.data import atomic_masses, chemical_symbols 

7from ase.neighborlist import NeighborList 

8 

9 

def find_nearest_index(array, value):
    """Return the (flat) index of the entry of ``array`` closest to ``value``.

    Distances are absolute differences; on ties the first matching entry
    wins, as with ``numpy.argmin``.
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()

14 

15 

def find_nearest_value(array, value):
    """Return the entry of ``array`` that lies closest to ``value``.

    Distances are absolute differences; on ties the first matching entry
    wins, as with ``numpy.argmin``.
    """
    candidates = np.asarray(array)
    return candidates[np.abs(candidates - value).argmin()]

20 

21 

def write_gpumd(fd, atoms, maximum_neighbors=None, cutoff=None,
                groupings=None, use_triclinic=False, species=None):
    """
    Writes atoms into GPUMD input format.

    Parameters
    ----------
    fd : file
        File like object to which the atoms object should be written
    atoms : Atoms
        Input structure
    maximum_neighbors: int
        Maximum number of neighbors any atom can ever have (not relevant when
        using force constant potentials). If None and ``cutoff`` is given, it
        is estimated as twice the largest neighbor count found with a neighbor
        list built from ``cutoff``, capped at 1024. If both are None, it is
        set to 1.
    cutoff: float
        Initial cutoff distance used for building the neighbor list (not
        relevant when using force constant potentials). Defaults to 0.1 when
        not given, so that the header line always contains a number.
    groupings : list[list[list[int]]]
        Groups into which the individual atoms should be divided in the form
        of a list of list of lists. Specifically, the outer list corresponds
        to the grouping methods, of which there can be three at the most,
        each of which contains a list of groups in the form of lists of site
        indices. Within every grouping method each atom index must occur
        exactly once.
    use_triclinic: bool
        Use format for triclinic cells. The triclinic format is also used
        whenever the cell is not orthorhombic.
    species : List[str]
        GPUMD uses integers to define atom types. This list allows customized
        such definitions (e.g, ['Pd', 'H'] means Pd is type 0 and H type 1).
        If None (or empty), this list is built by assigning each distinct
        species to an integer in the order of appearance in `atoms`.

    Raises
    ------
    ValueError
        Raised if parameters are incompatible: more than three grouping
        methods, a grouping method that does not list every atom exactly
        once, or a ``species`` list lacking a symbol present in ``atoms``.
    """
    n_atoms = len(atoms)

    # Check velocities.
    # NOTE(review): confirm whether Atoms.get_velocities() can ever return
    # None; if it always returns an array, velocities are always written.
    velocities = atoms.get_velocities()
    has_velocity = 0 if velocities is None else 1

    # Check groupings, and build one atom-index -> group-index map per
    # grouping method. Looking the group up in a dict avoids scanning every
    # group for every atom, which is quadratic in the number of atoms.
    group_lookups = []
    if groupings is not None:
        if len(groupings) > 3:
            raise ValueError('There can be no more than 3 grouping methods!')
        for g, grouping in enumerate(groupings):
            all_indices = [i for group in grouping for i in group]
            if len(all_indices) != n_atoms or \
                    set(all_indices) != set(range(n_atoms)):
                raise ValueError('The indices listed in grouping method {} are'
                                 ' not compatible with the input'
                                 ' structure!'.format(g))
            # Validation above guarantees each index appears exactly once.
            group_lookups.append({index: i
                                  for i, group in enumerate(grouping)
                                  for index in group})
    number_of_grouping_methods = len(group_lookups)

    # Fill in defaults. The cutoff always receives a numeric value; leaving
    # it as None would write the string 'None' into the header line.
    if cutoff is None:
        cutoff = 0.1
        if maximum_neighbors is None:
            maximum_neighbors = 1
    elif maximum_neighbors is None:
        # Estimate from the structure: largest neighbor count, doubled and
        # capped at 1024.
        nl = NeighborList([cutoff / 2] * n_atoms, skin=2, bothways=True)
        nl.update(atoms)
        largest = max((len(nl.get_neighbors(index)[0])
                       for index in range(n_atoms)), default=0)
        maximum_neighbors = min(2 * largest, 1024)

    # Add header and cell parameters
    lines = []
    if atoms.cell.orthorhombic and not use_triclinic:
        triclinic = 0
    else:
        triclinic = 1
    lines.append('{} {} {} {} {} {}'.format(n_atoms, maximum_neighbors,
                                            cutoff, triclinic, has_velocity,
                                            number_of_grouping_methods))
    if triclinic:
        # pbc flags followed by the full 3x3 cell, row by row
        lines.append((' {}' * 12)[1:].format(*atoms.pbc.astype(int),
                                             *atoms.cell[:].flatten()))
    else:
        # pbc flags followed by the three cell lengths
        lines.append((' {}' * 6)[1:].format(*atoms.pbc.astype(int),
                                            *atoms.cell.lengths()))

    # Create symbols-to-type map, i.e. integers starting at 0
    if not species:
        symbol_type_map = {}
        for symbol in atoms.get_chemical_symbols():
            if symbol not in symbol_type_map:
                symbol_type_map[symbol] = len(symbol_type_map)
    else:
        if any(sym not in species
               for sym in set(atoms.get_chemical_symbols())):
            raise ValueError('The species list does not contain all chemical '
                             'species that are present in the atoms object.')
        symbol_type_map = {symbol: i for i, symbol in enumerate(species)}

    # Add lines for all atoms: type, position, mass, [velocity], [groups]
    for a, atm in enumerate(atoms):
        t = symbol_type_map[atm.symbol]
        line = (' {}' * 5)[1:].format(t, *atm.position, atm.mass)
        if has_velocity:
            line += (' {}' * 3).format(*velocities[a])
        for lookup in group_lookups:
            line += f' {lookup[a]}'
        lines.append(line)

    # Write file
    fd.write('\n'.join(lines))

143 

144 

def load_xyz_input_gpumd(fd, species=None, isotope_masses=None):
    """
    Read the structure input file for GPUMD and return an ase Atoms object
    together with a dictionary with parameters and a types-to-symbols map

    Parameters
    ----------
    fd : file
        Open file object (any iterator over lines) from which to read the
        Atoms object; lines are consumed with ``next(fd)``
    species : List[str]
        List with the chemical symbols that correspond to each type, will take
        precedence over isotope_masses
    isotope_masses: Dict[str, List[float]]
        Dictionary with chemical symbols and lists of the associated atomic
        masses, which is used to identify the chemical symbols that correspond
        to the types when ``species`` is not given. The default is to find the
        closest match in :data:`ase.data.atomic_masses`.

    Returns
    -------
    atoms : Atoms
        Atoms object
    input_parameters : Dict[str, int | float]
        Dictionary with parameters from the first row of the input file, namely
        'N', 'M', 'cutoff', 'triclinic', 'has_velocity' and 'num_of_groups'
        ('cutoff' is a float, all others are ints)
    species : List[str]
        List with the chemical symbols that correspond to each type

    Raises
    ------
    ValueError
        Raised if the list of species is incompatible with the input file
        (an atom type outside ``range(len(species))``), or if an atom row does
        not have the number of columns implied by the header line
    """
    # Parse first line
    header = next(fd).split()
    keys = ['N', 'M', 'cutoff', 'triclinic', 'has_velocity',
            'num_of_groups']
    input_parameters = {}
    for k, key in enumerate(keys):
        typ = float if key == 'cutoff' else int
        input_parameters[key] = typ(header[k])

    # Parse second line: three pbc flags followed by the cell
    second_arr = np.array(next(fd).split())
    # Go through int: numpy's str -> bool cast tests string truthiness, so a
    # flag written as '0' would otherwise be read as True.
    pbc = second_arr[:3].astype(int).astype(bool)
    cell_values = second_arr[3:].astype(float)
    if input_parameters['triclinic']:
        cell = cell_values.reshape((3, 3))
    else:
        cell = np.diag(cell_values)

    # Parse all remaining rows: type, x, y, z, mass, [vx vy vz], [groups]
    n_rows = input_parameters['N']
    n_columns = (5 + 3 * input_parameters['has_velocity']
                 + input_parameters['num_of_groups'])
    rows = [next(fd).split() for _ in range(n_rows)]
    for r, row in enumerate(rows):
        if len(row) != n_columns:
            raise ValueError(f'Atom line {r} has {len(row)} columns, but '
                             f'{n_columns} are expected from the header.')
    # The reshape keeps a 2D shape even when the file contains no atoms.
    rest_arr = np.array(rows, dtype=str).reshape((n_rows, n_columns))

    # Extract atom types, positions and masses
    atom_types = rest_arr[:, 0].astype(int)
    positions = rest_arr[:, 1:4].astype(float)
    masses = rest_arr[:, 4].astype(float)

    # Determine the atomic species
    symbols = []
    if species is None:
        # Guess each type's symbol from the mass of its first occurrence.
        if isotope_masses is not None:
            mass_symbols = {mass: symbol
                            for symbol, isotopes in isotope_masses.items()
                            for mass in isotopes}
            candidate_masses = list(mass_symbols.keys())
        type_symbol_map = {}
        for atom_type, mass in zip(atom_types, masses):
            if atom_type not in type_symbol_map:
                if isotope_masses is not None:
                    nearest_value = find_nearest_value(candidate_masses, mass)
                    symbol = mass_symbols[nearest_value]
                else:
                    symbol = chemical_symbols[
                        find_nearest_index(atomic_masses, mass)]
                type_symbol_map[atom_type] = symbol
            symbols.append(type_symbol_map[atom_type])
        species = [type_symbol_map[i] for i in sorted(type_symbol_map)]
    else:
        for atom_type in atom_types:
            # Reject negative types too; they would silently wrap around.
            if not 0 <= atom_type < len(species):
                raise ValueError('There is no entry for atom type {} in the '
                                 'species list!'.format(atom_type))
            symbols.append(species[atom_type])

    # Create the Atoms object
    atoms = Atoms(symbols=symbols, positions=positions, masses=masses, pbc=pbc,
                  cell=cell)
    if input_parameters['has_velocity']:
        velocities = rest_arr[:, 5:8].astype(float)
        atoms.set_velocities(velocities)
    if input_parameters['num_of_groups']:
        start_col = 5 + 3 * input_parameters['has_velocity']
        groups = rest_arr[:, start_col:].astype(int)
        # Group indices are stored per atom index in atoms.info
        atoms.info = {i: {'groups': groups[i, :]} for i in range(n_rows)}

    return atoms, input_parameters, species

251 

252 

def read_gpumd(fd, species=None, isotope_masses=None):
    """
    Read only the Atoms object from a GPUMD structure input file.

    This is a thin wrapper around ``load_xyz_input_gpumd`` that discards the
    header parameters and the resolved species list.

    Parameters
    ----------
    fd : file
        Open file object yielding the lines of the input file; it is handed
        unchanged to ``load_xyz_input_gpumd``, which reads it with ``next``.
        NOTE(review): file names are presumably opened by the ase.io reader
        machinery before this is called — confirm.
    species : List[str]
        List with the chemical symbols that correspond to each type, will take
        precedence over isotope_masses
    isotope_masses: Dict[str, List[float]]
        Dictionary with chemical symbols and lists of the associated atomic
        masses, used to identify the chemical symbols of the types when
        ``species`` is not given. The default is to find the closest match
        in :data:`ase.data.atomic_masses`.

    Returns
    -------
    atoms : Atoms
        Atoms object

    Raises
    ------
    ValueError
        Raised if the list of species is incompatible with the input file
    """
    atoms, _input_parameters, _species = load_xyz_input_gpumd(
        fd, species=species, isotope_masses=isotope_masses)
    return atoms