Coverage for /builds/ase/ase/ase/ga/element_crossovers.py: 77.46%

71 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3"""Crossover classes, that cross the elements in the supplied 

4atoms objects. 

5 

6""" 

from collections import Counter

import numpy as np

from ase.ga.offspring_creator import OffspringCreator

10 

11 

class ElementCrossover(OffspringCreator):
    """Base class for all operators where the elements of
    the atoms objects cross.

    Subclasses implement ``get_new_individual``. This base class only
    normalises the constraint arguments into one entry per element pool.

    Parameters:

    element_pool: Either a flat sequence of element symbols (a single
        pool) or a sequence of lists/tuples/arrays (several pools).

    max_diff_elements: None (no limit for any pool), an integer, or a
        sequence with one entry per pool.

    min_percentage_elements: None (0 for every pool), a number, or a
        sequence with one entry per pool.

    verbose, rng: Passed on to OffspringCreator.
    """

    def __init__(self, element_pool, max_diff_elements,
                 min_percentage_elements, verbose, rng=np.random):
        OffspringCreator.__init__(self, verbose, rng=rng)

        # A flat sequence of symbols is one pool. A sequence whose first
        # item is itself a sequence is a list of pools.
        if not isinstance(element_pool[0], (list, tuple, np.ndarray)):
            self.element_pools = [element_pool]
        else:
            self.element_pools = element_pool
        n_pools = len(self.element_pools)

        # numpy integer scalars are not subclasses of int. They are
        # accepted here so they do not fall into the sequence branch,
        # where len() would fail.
        if max_diff_elements is None:
            self.max_diff_elements = [None] * n_pools
        elif isinstance(max_diff_elements, (int, np.integer)):
            self.max_diff_elements = [max_diff_elements]
        else:
            # Copy so later edits to the caller's sequence do not leak in.
            self.max_diff_elements = list(max_diff_elements)
        assert len(self.max_diff_elements) == n_pools, \
            'max_diff_elements needs one entry per element pool'

        if min_percentage_elements is None:
            self.min_percentage_elements = [0] * n_pools
        elif isinstance(min_percentage_elements,
                        (int, float, np.integer, np.floating)):
            self.min_percentage_elements = [min_percentage_elements]
        else:
            self.min_percentage_elements = list(min_percentage_elements)
        assert len(self.min_percentage_elements) == n_pools, \
            'min_percentage_elements needs one entry per element pool'

        # Crossovers combine two parents.
        self.min_inputs = 2

    def get_new_individual(self, parents):
        """Create a child from ``parents``; must be overridden."""
        raise NotImplementedError

47 

48 

class OnePointElementCrossover(ElementCrossover):
    """Crossover of the elements in the atoms objects. The point of the
    cross is chosen randomly.

    Parameters:

    element_pool: List of elements in the phase space. The elements can be
        grouped if the individual consists of different types of elements.
        The list should then be a list of lists, e.g. [[list1], [list2]].

    max_diff_elements: The maximum number of different elements in the
        individual. Default is no limit. If the elements are grouped,
        max_diff_elements should be supplied as a list with each entry
        corresponding to the group at the same position in element_pool.

    min_percentage_elements: The minimum fraction of any element present
        in the individual. Default is that any amount is allowed. If the
        elements are grouped, min_percentage_elements should be supplied
        as a list with each entry corresponding to the group at the same
        position in element_pool.

    Example: element_pool=[[A,B,C,D],[x,y,z]], max_diff_elements=[3,2],
        min_percentage_elements=[.25, .5]
        An individual could be "D,B,B,C,x,x,x,x,z,z,z,z"

    verbose: Passed on to OffspringCreator.

    rng: Random number generator
        By default numpy.random.
    """

    def __init__(self, element_pool, max_diff_elements=None,
                 min_percentage_elements=None, verbose=False, rng=np.random):
        ElementCrossover.__init__(self, element_pool, max_diff_elements,
                                  min_percentage_elements, verbose, rng=rng)
        self.descriptor = 'OnePointElementCrossover'

    def get_new_individual(self, parents):
        """Cross the chemical symbols of two parents at one random point.

        The child consists of the first ``cut`` atoms of the first parent
        followed by the atoms of the second parent from position ``cut``
        on. Cut positions are tried in random order until one gives a
        child that satisfies ``max_diff_elements`` and
        ``min_percentage_elements`` for every element pool. A pool is
        evaluated at the positions where the first parent holds one of
        that pool's elements. If no cut satisfies the constraints, the
        last cut tried is used anyway.

        Parameters:

        parents: Exactly two atoms objects (f, m), each carrying
            ``info['confid']``.

        Returns a tuple (new individual, description string).

        Raises ValueError if the first parent has fewer than two atoms,
        because no cut point exists then.
        """
        f, m = parents
        if len(f) < 2:
            raise ValueError(
                '{}: need at least two atoms to cut, got {}'.format(
                    self.descriptor, len(f)))

        indi = self.initialize_individual(f)
        indi.info['data']['parents'] = [i.info['confid'] for i in parents]

        # None of this depends on the cut, so compute it once.
        fsyms = f.get_chemical_symbols()
        msyms = m.get_chemical_symbols()
        pool_indices = []
        for pool in self.element_pools:
            members = set(pool)
            pool_indices.append([a.index for a in f if a.symbol in members])

        # A cut at position c takes f[:c] and m[c:]. Every c in
        # 1 .. len(f) - 1 mixes both parents.
        cut_choices = list(range(1, len(f)))
        self.rng.shuffle(cut_choices)
        for cut in cut_choices:
            if self._composition_ok(fsyms[:cut] + msyms[cut:],
                                    pool_indices):
                break

        for a in f[:cut] + m[cut:]:
            indi.append(a)

        parent_message = ':Parents {} {}'.format(f.info['confid'],
                                                 m.info['confid'])
        return (self.finalize_individual(indi),
                self.descriptor + parent_message)

    def _composition_ok(self, syms, pool_indices):
        """Return True if ``syms`` meets every pool's composition limits.

        ``pool_indices[k]`` lists the positions that belong to element
        pool ``k``. Only the symbols at those positions are counted.
        """
        for indices, max_diff, min_percent in zip(
                pool_indices, self.max_diff_elements,
                self.min_percentage_elements):
            if not indices:
                # No site of this pool is present, so there is nothing to
                # constrain.
                continue
            counts = Counter(syms[i] for i in indices)
            if max_diff is not None and len(counts) > max_diff:
                return False
            # Every element present must occupy at least min_percent of
            # the pool's sites. A min_percent of 0 can never fail.
            if min(counts.values()) / len(indices) < min_percent:
                return False
        return True

134 

135 

class TwoPointElementCrossover(ElementCrossover):
    """Crosses two individuals by choosing two cross points
    at random.

    NOTE: the crossover itself is not implemented yet;
    ``get_new_individual`` raises NotImplementedError.

    Parameters:

    element_pool, max_diff_elements, min_percentage_elements: Element
        pools and per-pool composition limits, normalised by
        ElementCrossover.

    verbose: Passed on to OffspringCreator.

    rng: Random number generator
        By default numpy.random. Accepted for consistency with
        OnePointElementCrossover.
    """

    def __init__(self, element_pool, max_diff_elements=None,
                 min_percentage_elements=None, verbose=False,
                 rng=np.random):
        ElementCrossover.__init__(self, element_pool,
                                  max_diff_elements,
                                  min_percentage_elements, verbose,
                                  rng=rng)
        self.descriptor = 'TwoPointElementCrossover'

    def get_new_individual(self, parents):
        """Not implemented yet."""
        raise NotImplementedError