Coverage for /builds/ase/ase/ase/ga/slab_operators.py: 68.48%

330 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3"""Operators that work on slabs. 

4Allowed compositions are respected. 

5Identical indexing of the slabs are assumed for the cut-splice operator.""" 

6from collections import Counter 

7from itertools import permutations 

8from operator import itemgetter 

9 

10import numpy as np 

11 

12from ase.ga.element_mutations import get_periodic_table_distance 

13from ase.ga.offspring_creator import OffspringCreator 

14from ase.utils import atoms_to_spglib_cell 

15 

16try: 

17 import spglib 

18except ImportError: 

19 spglib = None 

20 

21 

def permute2(atoms, rng=np.random):
    """Swap the chemical symbols of two randomly chosen, unlike atoms.

    The first atom is drawn uniformly from all indices; the second is drawn
    uniformly from the atoms whose symbol differs from the first one.
    ``atoms`` is modified in place.
    """
    first = rng.choice(range(len(atoms)))
    first_symbol = atoms[first].symbol
    unlike = [atom.index for atom in atoms if atom.symbol != first_symbol]
    second = rng.choice(unlike)
    # The right-hand side is evaluated before either assignment happens.
    atoms[first].symbol, atoms[second].symbol = (atoms[second].symbol,
                                                 first_symbol)

28 

29 

def replace_element(atoms, element_out, element_in):
    """Replace every ``element_out`` atom in ``atoms`` by ``element_in``.

    ``atoms`` is modified in place via ``set_chemical_symbols``.
    """
    # Build a plain list instead of assigning into a numpy string array: such
    # an array gets a fixed width from the existing symbols, so a longer
    # replacement (e.g. 'Nb' into an array holding only 'V'/'W') would be
    # silently truncated to 'N'.
    syms = [element_in if sym == element_out else sym
            for sym in atoms.get_chemical_symbols()]
    atoms.set_chemical_symbols(syms)

34 

35 

def get_add_remove_lists(**kwargs):
    """Split signed per-symbol counts into symbols to add and to remove.

    Each keyword maps a chemical symbol to a signed integer. A positive value
    ``n`` puts the symbol ``n`` times into the add list, a negative value puts
    it ``abs(n)`` times into the remove list, and zero contributes nothing.

    Returns:
        tuple ``(to_add, to_rem)`` of lists of symbols, in keyword order.
    """
    to_add = []
    to_rem = []
    for sym, amount in kwargs.items():
        bucket = to_add if amount > 0 else to_rem
        # abs(0) == 0, so zero amounts extend the remove list by nothing.
        bucket += [sym] * abs(amount)
    return to_add, to_rem

44 

45 

def get_minority_element(atoms):
    """Return the least abundant chemical symbol in ``atoms``.

    Ties go to the symbol that appears first in
    ``atoms.get_chemical_symbols()``.
    """
    counts = Counter(atoms.get_chemical_symbols())
    symbol, _ = min(counts.items(), key=itemgetter(1))
    return symbol

49 

50 

def minority_element_segregate(atoms, layer_tag=1, rng=np.random):
    """Move the minority alloy element into the layer tagged ``layer_tag``.

    Minority atoms outside the tagged layer swap symbols with randomly
    chosen non-minority atoms inside it, for as many pairs as both groups
    provide. ``atoms`` should contain atoms carrying ``layer_tag``; it is
    modified in place.
    """
    minority = get_minority_element(atoms)
    in_layer = {atom.index for atom in atoms if atom.tag == layer_tag}
    is_minority = {atom.index for atom in atoms if atom.symbol == minority}
    outside_layer = is_minority - in_layer
    swap_partners = list(in_layer - is_minority)
    rng.shuffle(swap_partners)
    # zip stops at the shorter group; an empty group means no swaps.
    for src, dst in zip(outside_layer, swap_partners):
        atoms[src].symbol = atoms[dst].symbol
        atoms[dst].symbol = minority

64 

65 

def same_layer_comp(atoms, rng=np.random):
    """Relabel atoms so that each layer has the overall slab composition.

    The per-layer target of each element is its total count divided by the
    number of layers (see ``get_layer_comps``), truncated to an int. Within
    each layer, randomly chosen atoms of elements in surplus are relabelled
    to elements in deficit. ``atoms`` is modified in place.
    """
    symbols, totals = np.unique(atoms.get_chemical_symbols(),
                                return_counts=True)
    layers = get_layer_comps(atoms)
    n_layers = len(layers)
    target = {}
    for sym, total in zip(symbols, totals):
        target[sym] = int(total / n_layers)

    for members in layers:
        deficit = dict(target)
        present, counts = np.unique([atoms[i].symbol for i in members],
                                    return_counts=True)
        for sym, count in zip(present, counts):
            deficit[sym] -= count
        to_add, to_rem = get_add_remove_lists(**deficit)
        for new_sym, old_sym in zip(to_add, to_rem):
            # Re-evaluated each time because symbols change as we go.
            pick = rng.choice([i for i in members
                               if atoms[i].symbol == old_sym])
            atoms[pick].symbol = new_sym

82 

83 

def get_layer_comps(atoms, eps=1e-2):
    """Group atom indices into layers by their z coordinate.

    Atoms are visited in increasing ``(z, index)`` order. An atom joins the
    current layer when its z differs from that of the previously visited
    atom by less than ``eps``; otherwise it starts a new layer.

    Returns:
        list of lists of atom indices, lowest layer first.
    """
    layers = []
    last_z = np.inf
    ordered = sorted((atom.z, atom.index) for atom in atoms)
    for z, index in ordered:
        same_layer = abs(last_z - z) < eps
        if same_layer:
            layers[-1].append(index)
        else:
            layers.append([index])
        last_z = z
    return layers

95 

96 

97def get_ordered_composition(syms, pools=None): 

98 if pools is None: 

99 pool_index = {sym: 0 for sym in set(syms)} 

100 else: 

101 pool_index = {} 

102 for i, pool in enumerate(pools): 

103 if isinstance(pool, str): 

104 pool_index[pool] = i 

105 else: 

106 for sym in set(syms): 

107 if sym in pool: 

108 pool_index[sym] = i 

109 syms = [(sym, pool_index[sym], c) 

110 for sym, c in zip(*np.unique(syms, return_counts=True))] 

111 unique_syms, pn, comp = zip( 

112 *sorted(syms, key=lambda k: (k[1] - k[2], k[0]))) 

113 return (unique_syms, pn, comp) 

114 

115 

def dummy_func(*args):
    """No-op distribution correction: ignore ``args`` and return None."""
    return None

118 

119 

class SlabOperator(OffspringCreator):
    """Base class for the slab operators in this module.

    Parameters:

    verbose, num_muts, rng:
        Passed on to ``OffspringCreator``.
    allowed_compositions:
        Optional list of allowed compositions, each a sequence of element
        counts. ``None`` disables composition handling.
    distribution_correction_function:
        Optional callable applied to a candidate after an operation;
        defaults to a no-op.
    element_pools:
        Optional list of element pools, each a list of symbols or a single
        symbol string.
    """

    def __init__(self, verbose=False, num_muts=1,
                 allowed_compositions=None,
                 distribution_correction_function=None,
                 element_pools=None,
                 rng=np.random):
        OffspringCreator.__init__(self, verbose, num_muts=num_muts, rng=rng)

        self.allowed_compositions = allowed_compositions
        self.element_pools = element_pools
        if distribution_correction_function is None:
            self.dcf = dummy_func
        else:
            self.dcf = distribution_correction_function
        # NOTE(review): a design note here described limiting the number of
        # different elements per pool (e.g. [2, 1] for two pools: two
        # different elements from pool 1 but only one from pool 2). Nothing
        # in this class enforces such a limit.

    def get_symbols_to_use(self, syms):
        """Get the symbols to use for the offspring candidate.

        The returned symbols respect ``self.allowed_compositions``. If the
        element counts of ``syms`` (in decreasing order) already equal one of
        the allowed compositions (also in decreasing order), ``syms`` is
        returned unchanged. Otherwise randomly chosen symbols are relabelled
        towards the closest allowed composition. ``syms`` must be a mutable
        sequence; it is modified in place and returned.
        """
        if self.allowed_compositions is None:
            return syms

        unique_syms, counts = np.unique(syms, return_counts=True)
        # Composition in decreasing count order.
        comp, unique_syms = zip(*sorted(zip(counts, unique_syms),
                                        reverse=True))

        for cc in self.allowed_compositions:
            # Trailing zeros stand for absent elements.
            comp += (0,) * (len(cc) - len(comp))
            # comp is in decreasing order, so the allowed composition must
            # be sorted the same way for the comparison to be meaningful.
            if comp == tuple(sorted(cc, reverse=True)):
                return syms

        comp_diff = self.get_closest_composition_diff(comp)
        to_add, to_rem = get_add_remove_lists(
            **dict(zip(unique_syms, comp_diff)))
        for add, rem in zip(to_add, to_rem):
            tbc = [i for i in range(len(syms)) if syms[i] == rem]
            ai = self.rng.choice(tbc)
            syms[ai] = add
        return syms

    def get_add_remove_elements(self, syms):
        """Return the symbols to add and remove to reach an allowed composition.

        Elements are ordered with ``get_ordered_composition``. Each pool keeps
        at most ``len(allowed_compositions[0]) / len(element_pools)``
        different elements; all atoms of further elements from that pool are
        scheduled for removal. The counts of the kept elements are moved to
        the closest allowed composition.

        Returns ``([], [])`` if ``element_pools`` or ``allowed_compositions``
        is None, otherwise the ``(to_add, to_rem)`` lists of
        ``get_add_remove_lists``.
        """
        if self.element_pools is None or self.allowed_compositions is None:
            return [], []
        unique_syms, pool_number, comp = get_ordered_composition(
            syms, self.element_pools)
        stay_comp, stay_syms = [], []
        add_rem = {}
        per_pool = len(self.allowed_compositions[0]) / len(self.element_pools)
        pool_count = np.zeros(len(self.element_pools), dtype=int)
        for pn, num, sym in zip(pool_number, comp, unique_syms):
            pool_count[pn] += 1
            if pool_count[pn] <= per_pool:
                stay_comp.append(num)
                stay_syms.append(sym)
            else:
                # Too many elements from this pool: remove all of them.
                add_rem[sym] = -num

        diff = self.get_closest_composition_diff(stay_comp)
        add_rem.update({s: c for s, c in zip(stay_syms, diff)})
        return get_add_remove_lists(**add_rem)

    def get_closest_composition_diff(self, c):
        """Return the difference from ``c`` to the closest allowed composition.

        Closeness is the sum of absolute element-wise differences. The
        allowed compositions are shuffled first so that ties are broken at
        random.

        Raises ValueError if there are no allowed compositions.
        """
        comp = np.array(c)
        allowed_list = list(self.allowed_compositions)
        if not allowed_list:
            raise ValueError('No allowed compositions to choose from')
        self.rng.shuffle(allowed_list)
        mindiff = np.inf
        ccdiff = None
        for ac in allowed_list:
            diff = self.get_composition_diff(comp, ac)
            numdiff = sum(abs(i) for i in diff)
            if numdiff < mindiff:
                mindiff = numdiff
                ccdiff = diff
        return ccdiff

    def get_composition_diff(self, c1, c2):
        """Return ``c2 - c1`` as a numpy array.

        ``c2`` is padded with trailing zeros when it is shorter than ``c1``.
        A new tuple is built so that list- or array-valued entries of
        ``allowed_compositions`` are never modified in place.
        """
        c2 = tuple(c2)
        difflen = len(c1) - len(c2)
        if difflen > 0:
            c2 += (0,) * difflen
        return np.array(c2) - c1

    def get_possible_mutations(self, a):
        """Return the possible ``(symbol, count)`` mutations for ``a``.

        For every element present with count ``n`` the pair ``(symbol, n)``
        is included, together with ``(symbol, j * min_num)`` for
        ``1 <= j < n // min_num``. Here ``min_num`` is the smallest positive
        count in ``allowed_compositions``.
        """
        unique_syms, comp = np.unique(sorted(a.get_chemical_symbols()),
                                      return_counts=True)
        min_num = min(
            i for i in np.ravel(list(self.allowed_compositions)) if i > 0
        )
        muts = set()
        # np.unique only reports elements that are present (n >= 1) and
        # min_num > 0, so no further filtering is required.
        # NOTE(review): multiples of min_num are added whether or not n is
        # itself a multiple of min_num -- confirm this is intended.
        for sym, n in zip(unique_syms, comp):
            muts.add((sym, n))
            for j in range(1, n // min_num):
                muts.add((sym, min_num * j))
        return list(muts)

    def get_all_element_mutations(self, a):
        """Get all possible mutations for the supplied atoms object given
        the element pools.

        Returns a list of ``(present_symbol, new_symbol)`` pairs where
        ``new_symbol`` is in the same pool and not yet present in ``a``.
        """
        muts = []
        symset = set(a.get_chemical_symbols())
        for sym in symset:
            for pool in self.element_pools:
                if sym in pool:
                    muts.extend([(sym, s) for s in pool if s not in symset])
        return muts

    def finalize_individual(self, indi):
        """Store the concatenated symbols as ``atoms_string`` and finalize."""
        atoms_string = ''.join(indi.get_chemical_symbols())
        indi.info['key_value_pairs']['atoms_string'] = atoms_string
        return OffspringCreator.finalize_individual(self, indi)

232 

233 

class CutSpliceSlabCrossover(SlabOperator):
    """Cut-and-splice crossover for two slabs with identical indexing.

    A random plane through the bounding box of the first parent splits the
    slab. Atoms on its positive side take their symbols from the first
    parent and all others from the second. When allowed compositions and
    element pools are set, the symbols closest to the plane are relabelled
    to reach an allowed composition. At most ``tries`` planes are drawn;
    planes leaving less than ``min_ratio`` of the atoms on either side are
    rejected.
    """

    def __init__(self, allowed_compositions=None, element_pools=None,
                 verbose=False,
                 num_muts=1, tries=1000, min_ratio=0.25,
                 distribution_correction_function=None, rng=np.random):
        SlabOperator.__init__(self, verbose, num_muts,
                              allowed_compositions,
                              distribution_correction_function,
                              element_pools=element_pools,
                              rng=rng)

        self.tries = tries
        self.min_ratio = min_ratio
        self.descriptor = 'CutSpliceSlabCrossover'

    def get_new_individual(self, parents):
        """Return ``(child, description)`` for a pair of parents."""
        father, mother = parents

        child = self.initialize_individual(father,
                                           self.operate(father, mother))
        child.info['data']['parents'] = [p.info['confid'] for p in parents]

        message = ': Parents {} {}'.format(father.info['confid'],
                                           mother.info['confid'])
        return self.finalize_individual(child), self.descriptor + message

    def operate(self, f, m):
        """Return a copy of ``f`` spliced with ``m`` along a random plane.

        If no acceptable plane is found, the copy keeps the symbols of ``f``.
        The distribution correction function is applied in both cases.
        """
        child = f.copy()
        positions = f.positions
        upper = positions.max(axis=0)
        lower = positions.min(axis=0)

        for _ in range(self.tries):
            # Random point inside the bounding box of f.
            rv = [self.rng.random() for _ in range(3)]
            midpoint = (upper - lower) * rv + lower

            # Random unit normal: theta in [0, 2pi), phi in [0, pi).
            theta = self.rng.random() * 2 * np.pi
            phi = self.rng.random() * np.pi
            normal = np.array((np.sin(phi) * np.cos(theta),
                               np.sin(theta) * np.sin(phi),
                               np.cos(phi)))

            signed_dist = np.dot(positions - midpoint, normal)
            from_f = signed_dist > 0
            ratio = float(np.count_nonzero(from_f)) / len(f)
            if not self.min_ratio <= ratio <= 1 - self.min_ratio:
                continue
            syms = np.where(from_f, f.get_chemical_symbols(),
                            m.get_chemical_symbols())
            distance = abs(signed_dist)

            # Fix the composition by relabelling atoms nearest to the cut.
            # NOTE(review): offspring may contain only one element pool.
            to_add, to_rem = self.get_add_remove_elements(syms)
            for add, rem in zip(to_add, to_rem):
                nearest = sorted((distance[i], i)
                                 for i, sym in enumerate(syms) if sym == rem)
                syms[nearest[0][1]] = add

            child.set_chemical_symbols(syms)
            break

        self.dcf(child)
        return child

305 

306 

# Mutations: composition changes, element replacements and permutations

308 

class RandomCompositionMutation(SlabOperator):
    """Change the current composition to another of the allowed compositions.
    The allowed compositions should be input in the same order as the element
    pools, for example:
    element_pools = [['Au', 'Cu'], ['In', 'Bi']]
    allowed_compositions = [(6, 2), (5, 3)]
    means that there can be 5 or 6 Au and Cu, and 2 or 3 In and Bi.

    If ``allowed_compositions`` is None, every ordered tuple of distinct
    positive counts that sums to the number of atoms is allowed.
    """

    def __init__(self, verbose=False, num_muts=1, element_pools=None,
                 allowed_compositions=None,
                 distribution_correction_function=None, rng=np.random):
        SlabOperator.__init__(self, verbose, num_muts,
                              allowed_compositions,
                              distribution_correction_function,
                              element_pools=element_pools,
                              rng=rng)

        self.descriptor = 'RandomCompositionMutation'

    def get_new_individual(self, parents):
        """Return ``(child, description)``; child is None if impossible."""
        f = parents[0]
        parent_message = ': Parent {}'.format(f.info['confid'])

        if self.allowed_compositions is None:
            if len(set(f.get_chemical_symbols())) == 1:
                if self.element_pools is None:
                    # We cannot find another composition without knowledge of
                    # other allowed elements or compositions
                    return None, self.descriptor + parent_message

        # Operate on a copy: operate() works in place and the parent must
        # not be modified.
        indi = self.initialize_individual(f, self.operate(f.copy()))
        indi.info['data']['parents'] = [i.info['confid'] for i in parents]

        return (self.finalize_individual(indi),
                self.descriptor + parent_message)

    def operate(self, atoms):
        """Move ``atoms`` (in place) to a random different allowed composition.

        Raises ValueError if no allowed composition differs from the
        current one.
        """
        allowed_comps = self.allowed_compositions
        if allowed_comps is None:
            n_elems = len(set(atoms.get_chemical_symbols()))
            n_atoms = len(atoms)
            # NOTE(review): permutations() only yields tuples of distinct
            # counts, so e.g. an equal split such as (4, 4) is never proposed.
            allowed_comps = [c for c in permutations(range(1, n_atoms),
                                                     n_elems)
                             if sum(c) == n_atoms]

        # Sorting the composition to have the same order as in element_pools
        syms = atoms.get_chemical_symbols()
        unique_syms, _, comp = get_ordered_composition(syms,
                                                       self.element_pools)

        # Drop the current composition (first match only). A list is used
        # instead of np.delete so the entries keep their original type.
        candidates = list(allowed_comps)
        for i, allowed in enumerate(candidates):
            if comp == tuple(allowed):
                del candidates[i]
                break
        if not candidates:
            raise ValueError(
                'No allowed composition differs from the current one')
        chosen = self.rng.randint(len(candidates))
        # tuple() so that padding never modifies a stored composition.
        comp_diff = self.get_composition_diff(comp,
                                              tuple(candidates[chosen]))

        # Get difference from current composition
        to_add, to_rem = get_add_remove_lists(
            **dict(zip(unique_syms, comp_diff)))

        # Correct current composition by relabelling random atoms
        for add, rem in zip(to_add, to_rem):
            tbc = [i for i in range(len(syms)) if syms[i] == rem]
            ai = self.rng.choice(tbc)
            syms[ai] = add

        atoms.set_chemical_symbols(syms)
        self.dcf(atoms)
        return atoms

383 

384 

class RandomElementMutation(SlabOperator):
    """Replace all atoms of one element by another element of its pool.

    The ``(old, new)`` pair is drawn uniformly from
    ``get_all_element_mutations``; ``element_pools`` is required.
    """

    def __init__(self, element_pools, verbose=False, num_muts=1,
                 allowed_compositions=None,
                 distribution_correction_function=None, rng=np.random):
        SlabOperator.__init__(self, verbose, num_muts,
                              allowed_compositions,
                              distribution_correction_function,
                              element_pools=element_pools,
                              rng=rng)

        self.descriptor = 'RandomElementMutation'

    def get_new_individual(self, parents):
        """Return ``(child, description)`` for the first parent."""
        f = parents[0]

        # Operate on a copy: operate() works in place and the parent must
        # not be modified.
        indi = self.initialize_individual(f, self.operate(f.copy()))
        indi.info['data']['parents'] = [i.info['confid'] for i in parents]

        parent_message = ': Parent {}'.format(f.info['confid'])
        return (self.finalize_individual(indi),
                self.descriptor + parent_message)

    def operate(self, atoms):
        """Apply one random element replacement to ``atoms`` in place.

        Raises ValueError if the element pools allow no replacement.
        """
        poss_muts = self.get_all_element_mutations(atoms)
        if not poss_muts:
            raise ValueError(
                'No element mutation is possible with the given element pools')
        chosen = self.rng.randint(len(poss_muts))
        replace_element(atoms, *poss_muts[chosen])
        self.dcf(atoms)
        return atoms

414 

415 

class NeighborhoodElementMutation(SlabOperator):
    """Replace one element by the pool mate nearest in the periodic table.

    Among all mutations from ``get_all_element_mutations``, those with the
    smallest ``get_periodic_table_distance`` are candidates and one of them
    is chosen at random. ``element_pools`` is required.
    """

    def __init__(self, element_pools, verbose=False, num_muts=1,
                 allowed_compositions=None,
                 distribution_correction_function=None, rng=np.random):
        SlabOperator.__init__(self, verbose, num_muts,
                              allowed_compositions,
                              distribution_correction_function,
                              element_pools=element_pools,
                              rng=rng)

        self.descriptor = 'NeighborhoodElementMutation'

    def get_new_individual(self, parents):
        """Return ``(child, description)`` for the first parent."""
        f = parents[0]

        indi = self.initialize_individual(f, f)
        indi.info['data']['parents'] = [i.info['confid'] for i in parents]

        indi = self.operate(indi)

        parent_message = ': Parent {}'.format(f.info['confid'])
        return (self.finalize_individual(indi),
                self.descriptor + parent_message)

    def operate(self, atoms):
        """Apply the nearest-neighbour element replacement to ``atoms`` in place.

        Raises ValueError if the element pools allow no replacement.
        """
        scored = [(get_periodic_table_distance(*mut), mut)
                  for mut in self.get_all_element_mutations(atoms)]
        if not scored:
            raise ValueError(
                'No element mutation is possible with the given element pools')
        least_diff = min(dist for dist, _ in scored)
        # All ties at the minimal distance, in their original order.
        poss_muts = [mut for dist, mut in scored if dist == least_diff]

        chosen = self.rng.randint(len(poss_muts))
        replace_element(atoms, *poss_muts[chosen])
        self.dcf(atoms)
        return atoms

454 

455 

class SymmetrySlabPermutation(SlabOperator):
    """Permutes the atoms in the slab until it has a higher symmetry number."""

    def __init__(self, verbose=False, num_muts=1, sym_goal=100, max_tries=50,
                 allowed_compositions=None,
                 distribution_correction_function=None, rng=np.random):
        SlabOperator.__init__(self, verbose, num_muts,
                              allowed_compositions,
                              distribution_correction_function,
                              rng=rng)
        if spglib is None:
            print("SymmetrySlabPermutation needs spglib to function")

        assert sym_goal >= 1
        self.sym_goal = sym_goal
        self.max_tries = max_tries
        self.descriptor = 'SymmetrySlabPermutation'

    def get_new_individual(self, parents):
        """Return ``(child, description)``; child is None if impossible."""
        f = parents[0]
        # Permutation only makes sense if two different elements are present
        if len(set(f.get_chemical_symbols())) == 1:
            f = parents[1]
            if len(set(f.get_chemical_symbols())) == 1:
                return None, '{1} not possible in {0}'.format(f.info['confid'],
                                                              self.descriptor)

        # Operate on a copy: operate() works in place and the parent must
        # not be modified.
        indi = self.initialize_individual(f, self.operate(f.copy()))
        indi.info['data']['parents'] = [i.info['confid'] for i in parents]

        parent_message = ': Parent {}'.format(f.info['confid'])
        return (self.finalize_individual(indi),
                self.descriptor + parent_message)

    def operate(self, atoms):
        """Randomly permute ``atoms`` in place, aiming for a symmetry goal.

        Each try applies two random swaps and the distribution correction,
        then reads the spglib symmetry ``number``. After ``max_tries``
        failed tries the goal is lowered by one. The atoms keep the last
        tried permutation, which is not necessarily the best one seen.
        """
        from ase.spacegroup.symmetrize import spglib_get_symmetry_dataset
        sym_num = 1
        sg = self.sym_goal
        while sym_num < sg:
            for _ in range(self.max_tries):
                for _ in range(2):
                    permute2(atoms, rng=self.rng)
                self.dcf(atoms)
                sym_num = spglib_get_symmetry_dataset(
                    atoms_to_spglib_cell(atoms))['number']
                if sym_num >= sg:
                    break
            sg -= 1
        return atoms

506 

507 

class RandomSlabPermutation(SlabOperator):
    """Swap the symbols of ``num_muts`` random pairs of unlike atoms."""

    def __init__(self, verbose=False, num_muts=1,
                 allowed_compositions=None,
                 distribution_correction_function=None, rng=np.random):
        SlabOperator.__init__(self, verbose, num_muts,
                              allowed_compositions,
                              distribution_correction_function,
                              rng=rng)

        self.descriptor = 'RandomSlabPermutation'

    def get_new_individual(self, parents):
        """Return ``(child, description)``; child is None if impossible.

        A swap needs two different elements, so the second parent is used
        when the first one contains a single element.
        """
        source = parents[0]
        if len(set(source.get_chemical_symbols())) == 1:
            source = parents[1]
            if len(set(source.get_chemical_symbols())) == 1:
                return None, '{1} not possible in {0}'.format(
                    source.info['confid'], self.descriptor)

        child = self.initialize_individual(source, source)
        child.info['data']['parents'] = [p.info['confid'] for p in parents]

        child = self.operate(child)

        message = ': Parent {}'.format(source.info['confid'])
        return self.finalize_individual(child), self.descriptor + message

    def operate(self, atoms):
        """Apply ``num_muts`` random swaps to ``atoms`` in place."""
        for _ in range(self.num_muts):
            permute2(atoms, rng=self.rng)
        self.dcf(atoms)
        return atoms