Coverage for ase / constraints / constraint.py: 92.78%
97 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1from __future__ import annotations
3import numpy as np
5from ase import Atoms
def slice2enlist(s, n):
    """Pair each selected old index with its new position.

    ``s`` is either a slice, which is expanded against a sequence of
    length ``n``, or any other iterable of indices, which is enumerated
    as given.  The result is an iterator of ``(new, old)`` tuples.
    """
    selected = range(*s.indices(n)) if isinstance(s, slice) else s
    return enumerate(selected)
def ints2string(x, threshold=None):
    """Format an ndarray of ints as a list-like string.

    When ``threshold`` is given and ``x`` holds more than ``threshold``
    entries, only the first ``threshold`` entries are printed and the
    remainder is abbreviated as ``', ...]'``.
    """
    if threshold is not None and len(x) > threshold:
        head = str(x[:threshold].tolist())
        # Drop the closing bracket and append the ellipsis marker instead.
        return head[:-1] + ', ...]'
    return str(x.tolist())
def constrained_indices(atoms, only_include=None):
    """Return the sorted, unique indices of all constrained atoms.

    Parameters
    ----------
    atoms
        Object exposing a ``constraints`` iterable; each constraint must
        provide ``get_indices()``.
    only_include : type or tuple of types, optional
        If given, only constraints that are instances of this type
        contribute to the result.

    Returns
    -------
    numpy.ndarray
        Integer array of unique indices in ascending order.  The array
        is empty (but still of integer dtype) when no constraint matches.
    """
    indices = []
    for constraint in atoms.constraints:
        if only_include is not None and not isinstance(
            constraint, only_include
        ):
            continue
        indices.extend(np.array(constraint.get_indices()))
    # Force an integer dtype: np.unique on an empty list would otherwise
    # produce a float64 array, which cannot be used for indexing.
    return np.unique(np.asarray(indices, dtype=int))
37def _normalize(direction):
38 if np.shape(direction) != (3,):
39 raise ValueError('len(direction) is {len(direction)}. Has to be 3')
41 direction = np.asarray(direction) / np.linalg.norm(direction)
42 return direction
45def _projection(vectors, direction):
46 dotprods = vectors @ direction
47 projection = direction[None, :] * dotprods[:, None]
48 return projection
class FixConstraint:
    """Common interface for constraints that fix one or more atoms."""

    def index_shuffle(self, atoms: Atoms, ind):
        """Remap the constraint after the atoms have been reordered.

        Meant to be called when the ordering of the atoms in the Atoms
        object changes, so the constraint can shuffle its own indices.

        ind -- List or tuple of indices.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError

    def repeat(self, m: int, n: int):
        """Multiply the constraint by ``m``.

        The length of the underlying atoms object is needed so that the
        multiplied constraints can be assigned correctly.  The base
        implementation does not support this and always raises
        NotImplementedError with an explanatory message.
        """
        raise NotImplementedError(
            "Repeat is not compatible with your atoms' constraints. "
            'Use atoms.set_constraint() before calling repeat to '
            'remove your constraints.'
        )

    def get_removed_dof(self, atoms: Atoms):
        """Return how many degrees of freedom the constraint removes.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError

    def adjust_positions(self, atoms: Atoms, new):
        """Adjust positions; the base implementation does nothing."""
        return None

    def adjust_momenta(self, atoms: Atoms, momenta):
        """Adjust momenta by treating them the same way as forces."""
        # Momenta are routed through adjust_forces by default.
        # TODO: This default is not always reasonable.
        self.adjust_forces(atoms, momenta)

    def adjust_forces(self, atoms: Atoms, forces):
        """Adjust forces; the base implementation does nothing."""
        return None

    def copy(self):
        """Return a new constraint rebuilt from a copy of ``todict()``."""
        # Imported locally to prevent circular imports.
        from ase.constraints import dict2constraint

        state = self.todict()
        return dict2constraint(state.copy())

    def todict(self):
        """Convert constraint to dictionary.

        The base implementation has no body and therefore returns None.
        """
        return None
class IndexedConstraint(FixConstraint):
    """Constraint that acts on a set of atoms stored in ``self.index``.

    ``self.index`` is a one-dimensional integer array without duplicate
    entries; the methods below keep it up to date when atoms are
    reordered, repeated or deleted.
    """

    def __init__(self, indices=None, mask=None):
        """Constrain chosen atoms.

        Parameters
        ----------
        indices : sequence of int
            Indices for those atoms that should be constrained.
            A boolean array passed here is interpreted as a mask.
        mask : sequence of bool
            One boolean per atom indicating if the atom should be
            constrained or not.

        Raises
        ------
        ValueError
            If both ``indices`` and ``mask`` are given, if the input has
            more than one dimension, if it is neither integer nor
            boolean, or if it contains duplicate indices.
        """
        if mask is not None:
            if indices is not None:
                raise ValueError('Use only one of "indices" and "mask".')
            indices = mask
        indices = np.atleast_1d(indices)
        if np.ndim(indices) > 1:
            raise ValueError(
                'indices has wrong amount of dimensions. '
                f'Got {np.ndim(indices)}, expected ndim <= 1'
            )

        if indices.dtype == bool:
            # Convert the mask into the positions of its True entries.
            indices = np.arange(len(indices))[indices]
        elif len(indices) == 0:
            # An empty sequence has no meaningful dtype; use int.
            indices = np.empty(0, dtype=int)
        elif not np.issubdtype(indices.dtype, np.integer):
            raise ValueError(
                'Indices must be integers or boolean mask, '
                f'not dtype={indices.dtype}'
            )

        if np.unique(indices).size < indices.size:
            raise ValueError(
                'The indices array contains duplicates. '
                'Perhaps you want to specify a mask instead, but '
                'forgot the mask= keyword.'
            )

        self.index = indices

    def index_shuffle(self, atoms, ind):
        """Remap ``self.index`` to the ordering selected by ``ind``.

        ``ind`` is a slice or a sequence of old atom indices; constrained
        atoms that are selected keep their constraint under their new
        position, all others are dropped.

        Raises IndexError if none of the constrained atoms is selected.
        """
        # Resolve negative indices:
        actual_indices = set(np.arange(len(atoms))[self.index])

        index = [
            new
            for new, old in slice2enlist(ind, len(atoms))
            if old in actual_indices
        ]
        if not index:
            raise IndexError('All indices in FixAtoms not part of slice')
        self.index = np.asarray(index, int)
        # XXX make immutable

    def get_indices(self):
        """Return a copy of the constrained atom indices."""
        return self.index.copy()

    def repeat(self, m, n):
        """Replicate the constrained indices for a repeated atoms object.

        Parameters
        ----------
        m : int or sequence of three ints
            Number of repetitions; a single integer (Python or NumPy)
            is used for all three entries.
        n : int
            Offset between consecutive copies (the length of the
            underlying atoms object, see ``FixConstraint.repeat``).

        Returns
        -------
        IndexedConstraint
            ``self``, with ``self.index`` replaced in place.
        """
        if isinstance(m, (int, np.integer)):
            m = (m, m, m)
        # len(range(k)) clamps negative counts to zero and rejects
        # non-integers, exactly like looping ``for _ in range(k)`` would.
        ncopies = len(range(m[0])) * len(range(m[1])) * len(range(m[2]))
        offsets = n * np.arange(ncopies)
        index = np.asarray(self.index, dtype=int)
        # One row per copy, flattened copy after copy.
        repeated = (offsets[:, np.newaxis] + index).ravel()
        self.index = np.asarray(repeated, dtype=int)
        # XXX make immutable
        return self

    def delete_atoms(self, indices, natoms):
        """Removes atoms from the index array, if present.

        Required for removing atoms with existing constraint.

        Parameters
        ----------
        indices
            Atom indices being deleted (anything ``np.delete`` accepts).
        natoms : int
            Number of atoms before the deletion.

        Returns
        -------
        IndexedConstraint or None
            ``self`` with remapped indices, or None if no constrained
            atom remains.
        """
        # Map every old atom index to its new position, -1 if deleted.
        new_position = np.full(natoms, -1, dtype=int)
        kept = np.delete(np.arange(natoms), indices)
        new_position[kept] = np.arange(len(kept))
        remapped = new_position[self.index]
        self.index = remapped[remapped >= 0]
        # XXX make immutable
        if len(self.index) == 0:
            return None
        return self