Coverage for /builds/ase/ase/ase/ga/element_crossovers.py: 77.46%
71 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3"""Crossover classes, that cross the elements in the supplied
4atoms objects.
6"""
7import numpy as np
9from ase.ga.offspring_creator import OffspringCreator
class ElementCrossover(OffspringCreator):
    """Base class for all operators where the elements of
    the atoms objects cross.

    The constructor normalises the constraint arguments into one entry
    per element pool, so subclasses can index ``element_pools``,
    ``max_diff_elements`` and ``min_percentage_elements`` in parallel.

    Parameters:

    element_pool: Either a flat sequence of element symbols (one pool)
        or a sequence of lists/tuples/arrays (one pool per group).

    max_diff_elements: None (no limit for any pool), a single integer
        (only valid with a single pool) or one value per pool.

    min_percentage_elements: None (0 for every pool), a single number
        (only valid with a single pool) or one value per pool.

    verbose, rng: Passed on to OffspringCreator.
    """

    def __init__(self, element_pool, max_diff_elements,
                 min_percentage_elements, verbose, rng=np.random):
        OffspringCreator.__init__(self, verbose, rng=rng)

        if len(element_pool) == 0:
            raise ValueError('element_pool must not be empty')

        # Grouped pools are recognised by their first entry being itself
        # a sequence; a flat sequence of symbols becomes a single pool.
        if isinstance(element_pool[0], (list, tuple, np.ndarray)):
            self.element_pools = element_pool
        else:
            self.element_pools = [element_pool]

        # Scalars (Python or numpy) are wrapped into a one-element list.
        if max_diff_elements is None:
            self.max_diff_elements = [None for _ in self.element_pools]
        elif isinstance(max_diff_elements, (int, np.integer)):
            self.max_diff_elements = [max_diff_elements]
        else:
            self.max_diff_elements = max_diff_elements
        assert len(self.max_diff_elements) == len(self.element_pools), \
            'max_diff_elements needs one entry per element pool'

        if min_percentage_elements is None:
            self.min_percentage_elements = [0 for _ in self.element_pools]
        elif isinstance(min_percentage_elements,
                        (int, float, np.integer, np.floating)):
            self.min_percentage_elements = [min_percentage_elements]
        else:
            self.min_percentage_elements = min_percentage_elements
        assert (len(self.min_percentage_elements)
                == len(self.element_pools)), \
            'min_percentage_elements needs one entry per element pool'

        # Element crossovers combine exactly two parents.
        self.min_inputs = 2

    def get_new_individual(self, parents):
        """Create a new individual from ``parents``; subclasses implement
        this."""
        raise NotImplementedError
class OnePointElementCrossover(ElementCrossover):
    """Crossover of the elements in the atoms objects. The point of
    crossover is chosen randomly.

    Parameters:

    element_pool: List of elements in the phase space. The elements can be
        grouped if the individual consists of different types of elements.
        The list should then be a list of lists e.g. [[list1], [list2]]

    max_diff_elements: The maximum number of different elements in the
        individual. Default is infinite. If the elements are grouped
        max_diff_elements should be supplied as a list with each input
        corresponding to the elements specified in the same input in
        element_pool.

    min_percentage_elements: The minimum share of any element in the
        individual, given as a fraction (e.g. 0.25 for 25%). By default
        any share is allowed. If the elements are grouped
        min_percentage_elements should be supplied as a list with each
        input corresponding to the elements specified in the same input
        in element_pool.

    Example: element_pool=[[A,B,C,D],[x,y,z]], max_diff_elements=[3,2],
        min_percentage_elements=[.25, .5]
        An individual could be "D,B,B,C,x,x,x,x,z,z,z,z"

    verbose: Passed on to OffspringCreator.

    rng: Random number generator
        By default numpy.random.
    """

    def __init__(self, element_pool, max_diff_elements=None,
                 min_percentage_elements=None, verbose=False, rng=np.random):
        ElementCrossover.__init__(self, element_pool, max_diff_elements,
                                  min_percentage_elements, verbose, rng=rng)
        self.descriptor = 'OnePointElementCrossover'

    def get_new_individual(self, parents):
        """Create one child from two parents by a one-point crossover.

        The child takes the atoms before the cut from the first parent
        ``f`` and the atoms from the cut onwards from the second parent
        ``m``. Candidate cuts ``1 .. len(f) - 2`` are tried in random
        order. The first cut whose resulting symbols satisfy the
        max_diff_elements and min_percentage_elements constraints of
        every element pool is used. If no cut satisfies them, the last
        cut tried is used anyway, so the child may violate the
        constraints.

        Parameters:

        parents: Sequence of exactly two atoms objects (f, m), each with
            a 'confid' entry in its info dict.

        Returns a tuple (child, description), where description is the
        descriptor followed by ':Parents <confid of f> <confid of m>'.

        Raises ValueError if f has fewer than three atoms, because there
        is then no candidate cut.
        """
        f, m = parents

        # Cuts 0 and len(f) - 1 are never candidates.
        cut_choices = list(range(1, len(f) - 1))
        if not cut_choices:
            raise ValueError('OnePointElementCrossover needs parents with '
                             'at least 3 atoms, got {}'.format(len(f)))

        indi = self.initialize_individual(f)
        indi.info['data']['parents'] = [i.info['confid'] for i in parents]

        self.rng.shuffle(cut_choices)

        # None of this depends on the cut, so compute it once.
        fsyms = f.get_chemical_symbols()
        msyms = m.get_chemical_symbols()
        constraints = self._pool_constraints(fsyms)

        for cut in cut_choices:
            if self._satisfies_constraints(fsyms[:cut] + msyms[cut:],
                                           constraints):
                break
        # If no cut was acceptable, ``cut`` is the last one tried and the
        # child may fall below min_percentage_elements or exceed
        # max_diff_elements.

        for a in f[:cut] + m[cut:]:
            indi.append(a)

        parent_message = ':Parents {} {}'.format(f.info['confid'],
                                                  m.info['confid'])
        return (self.finalize_individual(indi),
                self.descriptor + parent_message)

    def _pool_constraints(self, fsyms):
        """Return (indices, max_diff, min_fraction) per pool present in
        ``fsyms``.

        ``indices`` are the positions whose symbol in ``fsyms`` belongs to
        the pool. A pool absent from ``fsyms`` constrains nothing and is
        skipped. A min fraction of 0 becomes 1 / len(indices), which every
        present element meets. A max of None becomes len(indices).
        """
        constraints = []
        for pool, max_diff, min_fraction in zip(
                self.element_pools, self.max_diff_elements,
                self.min_percentage_elements):
            indices = [idx for idx, sym in enumerate(fsyms) if sym in pool]
            if not indices:
                continue
            if min_fraction == 0:
                min_fraction = 1. / len(indices)
            if max_diff is None:
                max_diff = len(indices)
            constraints.append((indices, max_diff, min_fraction))
        return constraints

    @staticmethod
    def _satisfies_constraints(syms, constraints):
        """Return True if ``syms`` meets every pool constraint in
        ``constraints``."""
        for indices, max_diff, min_fraction in constraints:
            syms_in = [syms[idx] for idx in indices]
            distinct = set(syms_in)
            if len(distinct) > max_diff:
                return False
            total = float(len(syms_in))
            for sym in distinct:
                if syms_in.count(sym) / total < min_fraction:
                    return False
        return True
class TwoPointElementCrossover(ElementCrossover):
    """Crosses two individuals by choosing two cross points
    at random.

    Not implemented yet: get_new_individual raises NotImplementedError.
    The constructor accepts the same arguments as
    OnePointElementCrossover, including ``rng`` (by default
    numpy.random).
    """

    def __init__(self, element_pool, max_diff_elements=None,
                 min_percentage_elements=None, verbose=False,
                 rng=np.random):
        ElementCrossover.__init__(self, element_pool,
                                  max_diff_elements,
                                  min_percentage_elements, verbose,
                                  rng=rng)
        self.descriptor = 'TwoPointElementCrossover'

    def get_new_individual(self, parents):
        """Not implemented for this operator."""
        raise NotImplementedError