Coverage for /builds/ase/ase/ase/ga/standard_comparators.py: 87.27%

110 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import numpy as np 

4 

5from ase.ga import get_raw_score 

6 

7 

def get_sorted_dist_list(atoms, mic=False):
    """ Utility method used to calculate the sorted distance list
    describing the cluster in atoms.

    Parameters:

    atoms: The structure to describe. Only ``atoms.numbers`` and
        ``atoms.get_distance`` are used.
    mic: Forwarded as the third positional argument of
        ``atoms.get_distance``.

    Returns a dict with one entry per atomic number found in
    ``atoms.numbers``. Each value is a numpy array holding the
    distances between all pairs of atoms of that element, sorted in
    ascending order (an empty array if the element occurs only once).
    """
    # Group the atom indices by atomic number in one pass instead of
    # re-scanning (and re-indexing) the whole structure per element.
    members = {}
    for index, number in enumerate(atoms.numbers):
        members.setdefault(number, []).append(index)

    pair_cor = {}
    for number, indices in members.items():
        # Every unordered pair (a, b) of same-element atoms, once.
        dists = [atoms.get_distance(a, b, mic)
                 for k, a in enumerate(indices)
                 for b in indices[k + 1:]]
        dists.sort()
        pair_cor[number] = np.array(dists)
    return pair_cor

23 

24 

class InteratomicDistanceComparator:

    """ An implementation of the comparison criteria described in
      L.B. Vilhelmsen and B. Hammer, PRL, 108, 126101 (2012)

      Two structures look alike when their energy difference, their
      cumulative sorted-distance difference and their largest single
      sorted-distance difference are all below the given limits.

      Parameters:

      n_top: The number of atoms being optimized by the GA.
          Default 0 - meaning all atoms.

      pair_cor_cum_diff: The limit in eq. 2 of the letter.
      pair_cor_max: The limit in eq. 3 of the letter
      dE: The limit of eq. 1 of the letter
      mic: Determines if distances are calculated
          using the minimum image convention
    """

    def __init__(self, n_top=None, pair_cor_cum_diff=0.015,
                 pair_cor_max=0.7, dE=0.02, mic=False):
        self.pair_cor_cum_diff = pair_cor_cum_diff
        self.pair_cor_max = pair_cor_max
        self.dE = dE
        # None is stored as 0; the slice a[-0:] used in looks_like
        # then selects every atom.
        self.n_top = n_top or 0
        self.mic = mic

    def looks_like(self, a1, a2):
        """ Return True if structures a1 and a2 are similar.

        The energy criterion is checked first; the structural
        comparison is only carried out when the energies are closer
        than ``dE``.

        Raises an Exception if a1 and a2 differ in length.
        """
        if len(a1) != len(a2):
            raise Exception('The two configurations are not the same size')

        # first we check the energy criteria
        dE = abs(a1.get_potential_energy() - a2.get_potential_energy())
        if dE >= self.dE:
            return False

        # then we check the structure (last n_top atoms only)
        a1top = a1[-self.n_top:]
        a2top = a2[-self.n_top:]
        cum_diff, max_diff = self.__compare_structure__(a1top, a2top)

        return (cum_diff < self.pair_cor_cum_diff
                and max_diff < self.pair_cor_max)

    def __compare_structure__(self, a1, a2):
        """ Private method for calculating the structural difference.

        Returns the tuple (total_cum_diff, max_diff) computed from the
        sorted distance lists of a1 and a2.
        """
        p1 = get_sorted_dist_list(a1, mic=self.mic)
        p2 = get_sorted_dist_list(a2, mic=self.mic)
        return self._pair_cor_diff(p1, p2, a1.numbers)

    @staticmethod
    def _pair_cor_diff(p1, p2, numbers):
        """ Compare two dicts of sorted distance lists.

        p1, p2: dicts mapping atomic number -> sorted distances.
            p2 must contain every key of p1 with a list of the same
            length.
        numbers: atomic numbers of the structure p1 was built from;
            used to weight each element by its share of the atoms.

        Returns (total_cum_diff, max_diff): the weighted sum over all
        elements of the relative cumulative difference, and the
        largest single distance difference found for any element.
        """
        numbers = np.asarray(numbers)
        total_cum_diff = 0.
        max_diff = 0.
        for n in p1:
            c1 = np.asarray(p1[n])
            c2 = np.asarray(p2[n])
            assert len(c1) == len(c2)
            if len(c1) == 0:
                # A single atom of this element has no pair distances.
                continue
            t_size = np.sum(c1)
            d = np.abs(c1 - c2)
            cum_diff = np.sum(d)
            # Keep the maximum over *all* elements; assigning
            # np.max(d) directly would only report the element that
            # happens to be iterated last.
            max_diff = max(max_diff, np.max(d))
            ntype = float(np.count_nonzero(numbers == n))
            total_cum_diff += cum_diff / t_size * ntype / float(len(numbers))
        return (total_cum_diff, max_diff)

89 

90 

class SequentialComparator:
    """Combine several comparators and evaluate them group by group.

    Every comparator is tagged with a logic value (``logics``).
    Comparators sharing a value form a group; a group is positive only
    if all of its members are positive, and the structures look alike
    as soon as one group is positive. Without ``logics`` every
    comparator gets its own group.
    Ex:
    methods = [a, b, c, d], logics = [0, 1, 1, 2]
    if a or d is positive -> return True
    if b and c are positive -> return True
    if b and not c are positive (or vice versa) -> return False

    Objects in ``methods`` without a ``looks_like`` attribute are
    dropped together with their logic value.
    """

    def __init__(self, methods, logics=None):
        # A single comparator / logic value is wrapped into a list.
        if not isinstance(methods, list):
            methods = [methods]
        if logics is None:
            logics = list(range(len(methods)))
        elif not isinstance(logics, list):
            logics = [logics]
        assert len(logics) == len(methods)

        usable = [(comparator, tag)
                  for comparator, tag in zip(methods, logics)
                  if hasattr(comparator, 'looks_like')]
        self.methods = [comparator for comparator, _ in usable]
        self.logics = [tag for _, tag in usable]

    def looks_like(self, a1, a2):
        """Return True if every comparator of at least one group
        reports a1 and a2 as alike, otherwise False."""
        groups = {}
        for comparator, tag in zip(self.methods, self.logics):
            groups.setdefault(tag, []).append(comparator)

        # any/all short-circuit exactly like the nested loops they
        # replace: a group stops at its first negative comparator and
        # the search stops at the first fully positive group.
        return any(
            all(comparator.looks_like(a1, a2) for comparator in members)
            for members in groups.values()
        )

131 

132 

class StringComparator:
    """Compares values stored under given keys, e.g. hash strings.

    The values are read from atoms.info['key_value_pairs'][key1],
    atoms.info['key_value_pairs'][key2] ...
    and the keys are handed over as positional parameters i.e.
    StringComparator(key1, key2, ...)

    Two structures look alike as soon as the values of one key are
    equal.
    """

    def __init__(self, *keys):
        self.keys = keys

    def looks_like(self, a1, a2):
        """Return True if a1 and a2 agree on at least one key."""
        # Lookups stay inside the generator so nothing is read from
        # the info dicts when no keys were supplied.
        return any(
            a1.info['key_value_pairs'][key] == a2.info['key_value_pairs'][key]
            for key in self.keys
        )

149 

150 

class EnergyComparator:
    """Compares two atoms objects by their potential energy, obtained
    from get_potential_energy().

    Parameters:

    dE: the difference in energy below which two energies are
    deemed equal.
    """

    def __init__(self, dE=0.02):
        self.dE = dE

    def looks_like(self, a1, a2):
        """Return False if the absolute energy difference of a1 and
        a2 is at least dE, otherwise True."""
        energy_1 = a1.get_potential_energy()
        energy_2 = a2.get_potential_energy()
        # Written as a negated >= (not as <) so that the outcome for
        # non-comparable values such as NaN stays the same.
        return not (abs(energy_1 - energy_2) >= self.dE)

170 

171 

class RawScoreComparator:
    """Compares the raw_score of the supplied individuals, as
    returned by ase.ga.get_raw_score (presumably read from
    a1.info['key_value_pairs']['raw_score'] - confirm against
    ase.ga).

    Parameters:

    dist: the difference in raw_score below which two
    scores are deemed equal.
    """

    def __init__(self, dist=0.02):
        self.dist = dist

    def looks_like(self, a1, a2):
        """Return False if the absolute raw_score difference of a1
        and a2 is at least dist, otherwise True."""
        score_gap = abs(get_raw_score(a1) - get_raw_score(a2))
        # Negated >= rather than < keeps the result for NaN scores.
        return not (score_gap >= self.dist)

191 

192 

class NoComparator:
    """Comparator that never finds two structures alike.

    Use it where a comparator is required but no comparison is
    wanted."""

    def looks_like(self, *args):
        """Return False regardless of the arguments."""
        return False

198 

199 

class AtomsComparator:
    """Compares the two objects directly with the == operator."""

    def looks_like(self, a1, a2):
        """Return the result of a1 == a2."""
        return a1 == a2

205 

206 

class CompositionComparator:
    """Compares two Atoms objects by their chemical formula."""

    def looks_like(self, a1, a2):
        """Return True if get_chemical_formula() gives equal results
        for a1 and a2."""
        formula_1 = a1.get_chemical_formula()
        formula_2 = a2.get_chemical_formula()
        return formula_1 == formula_2