Coverage for ase / constraints / constraint.py: 92.78%

97 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4 

5from ase import Atoms 

6 

7 

def slice2enlist(s, n):
    """Pair every new position with the old index it selects.

    A ``slice`` is first resolved against a sequence of length *n*
    (so negative bounds and steps are honoured); any other iterable of
    indices is used as given. The result is an ``enumerate`` iterator
    yielding ``(new, old)`` tuples.
    """
    old_indices = range(*s.indices(n)) if isinstance(s, slice) else s
    return enumerate(old_indices)

13 

14 

def ints2string(x, threshold=None):
    """Convert a sequence of ints to its list-style string representation.

    Parameters
    ----------
    x : array_like of int
        Values to format. Converted with ``np.asarray`` so plain lists are
        accepted as well as ndarrays.
    threshold : int, optional
        If given and ``len(x)`` exceeds it, only the first *threshold*
        values are shown, followed by ``...``.

    Returns
    -------
    str
        For example ``'[1, 2, 3]'`` or, when truncated, ``'[1, 2, ...]'``.
    """
    values = np.asarray(x)
    if threshold is None or len(values) <= threshold:
        return str(values.tolist())
    head = values[:threshold].tolist()
    if not head:
        # Stripping the ']' off '[]' would otherwise give the malformed '[, ...]'.
        return '[...]'
    return str(head)[:-1] + ', ...]'

20 

21 

def constrained_indices(atoms, only_include=None):
    """Return the indices of the atoms that are constrained by a constraint.

    Parameters
    ----------
    atoms : Atoms
        Object whose ``constraints`` list is inspected; each constraint must
        provide ``get_indices()``.
    only_include : type or tuple of types, optional
        If given, only constraints that are instances of it (checked with
        ``isinstance``) contribute indices.

    Returns
    -------
    numpy.ndarray
        Sorted array of unique integer indices. It is empty but still of
        integer dtype when no matching constraint is applied, so it can always
        be used to index other arrays.
    """
    indices = []
    for constraint in atoms.constraints:
        if only_include is not None and not isinstance(constraint, only_include):
            continue
        indices.extend(np.array(constraint.get_indices()))
    # np.unique([]) would return a float64 array, which NumPy refuses as an
    # index; force an integer dtype so the result is always usable as one.
    return np.unique(np.asarray(indices, dtype=int))

35 

36 

37def _normalize(direction): 

38 if np.shape(direction) != (3,): 

39 raise ValueError('len(direction) is {len(direction)}. Has to be 3') 

40 

41 direction = np.asarray(direction) / np.linalg.norm(direction) 

42 return direction 

43 

44 

45def _projection(vectors, direction): 

46 dotprods = vectors @ direction 

47 projection = direction[None, :] * dotprods[:, None] 

48 return projection 

49 

50 

class FixConstraint:
    """Base class for classes that fix one or more atoms in some way.

    Subclasses override the hooks below. adjust_positions and adjust_forces
    do nothing here, adjust_momenta forwards to adjust_forces, and
    index_shuffle, repeat and get_removed_dof raise NotImplementedError.
    """

    def index_shuffle(self, atoms: Atoms, ind):
        """Remap the constrained indices after the atoms are reordered.

        Call this whenever the ordering of the atoms in the Atoms object
        changes, so that the constraint follows the atoms.

        ind -- List or tuple of indices.
        """
        raise NotImplementedError

    def repeat(self, m: int, n: int):
        """Multiply the constraint by m for a repeated Atoms object.

        n is the length of the unrepeated Atoms object, which is needed to
        assign the multiplied constraints. The base class does not support
        this and always raises NotImplementedError.
        """
        raise NotImplementedError(
            "Repeat is not compatible with your atoms' constraints. "
            'Use atoms.set_constraint() before calling repeat to '
            'remove your constraints.'
        )

    def get_removed_dof(self, atoms: Atoms):
        """Return the number of degrees of freedom removed by the constraint."""
        raise NotImplementedError

    def adjust_positions(self, atoms: Atoms, new):
        """Hook for adjusting positions; does nothing in the base class."""

    def adjust_momenta(self, atoms: Atoms, momenta):
        """Adjust momenta by treating them exactly like forces.

        The result of adjust_forces is discarded; this method returns None.
        """
        # TODO: the force-like default is not always reasonable for momenta.
        self.adjust_forces(atoms, momenta)

    def adjust_forces(self, atoms: Atoms, forces):
        """Hook for adjusting forces; does nothing in the base class."""

    def copy(self):
        """Return a new constraint rebuilt from this one's dictionary form."""
        # Deferred import: importing at module level would be circular.
        from ase.constraints import dict2constraint

        as_dict = self.todict()
        return dict2constraint(as_dict.copy())

    def todict(self):
        """Return a dictionary describing the constraint (None in the base)."""

103 

104 

class IndexedConstraint(FixConstraint):
    """Base for constraints acting on an explicit set of atom indices.

    The selected atoms are stored as an integer ndarray in ``self.index``.
    """

    def __init__(self, indices=None, mask=None):
        """Constrain chosen atoms.

        Parameters
        ----------
        indices : sequence of int
            Indices for those atoms that should be constrained.
        mask : sequence of bool
            One boolean per atom indicating if the atom should be
            constrained or not.

        Raises
        ------
        ValueError
            If both ``indices`` and ``mask`` are given, if the input has more
            than one dimension, if it is neither integer nor boolean, or if
            it contains duplicate indices.
        """
        if mask is not None:
            if indices is not None:
                raise ValueError('Use only one of "indices" and "mask".')
            indices = mask
        indices = np.atleast_1d(indices)
        if np.ndim(indices) > 1:
            raise ValueError(
                'indices has wrong amount of dimensions. '
                f'Got {np.ndim(indices)}, expected ndim <= 1'
            )

        if indices.dtype == bool:
            # Boolean mask -> positions of the True entries.
            indices = np.arange(len(indices))[indices]
        elif len(indices) == 0:
            # np.atleast_1d([]) is float64; store an empty integer array.
            indices = np.empty(0, dtype=int)
        elif not np.issubdtype(indices.dtype, np.integer):
            raise ValueError(
                'Indices must be integers or boolean mask, '
                f'not dtype={indices.dtype}'
            )

        if len(np.unique(indices)) < len(indices):
            raise ValueError(
                'The indices array contains duplicates. '
                'Perhaps you want to specify a mask instead, but '
                'forgot the mask= keyword.'
            )

        self.index = indices

    def index_shuffle(self, atoms, ind):
        """Remap ``self.index`` onto the atoms selected by *ind*.

        Only constrained atoms that survive the selection are kept, renumbered
        to their new positions.

        Raises
        ------
        IndexError
            If none of the constrained atoms is part of the selection.
        """
        # Resolve negative indices so they compare equal to the "old"
        # indices produced by slice2enlist.
        actual_indices = set(np.arange(len(atoms))[self.index])
        index = [
            new
            for new, old in slice2enlist(ind, len(atoms))
            if old in actual_indices
        ]
        if not index:
            raise IndexError(
                f'All indices in {type(self).__name__} not part of slice'
            )
        self.index = np.asarray(index, int)
        # XXX make immutable

    def get_indices(self):
        """Return a copy of the constrained indices."""
        return self.index.copy()

    def repeat(self, m, n):
        """Replicate the constrained indices for an m-fold repeated system.

        Parameters
        ----------
        m : int or sequence of 3 ints
            Repetitions along the three cell vectors; a single integer is
            used for all three.
        n : int
            Number of atoms in the unrepeated system.

        Returns
        -------
        self
            With ``self.index`` holding one copy of the original indices per
            repeated image, copy k shifted by ``k * n``.
        """
        if isinstance(m, (int, np.integer)):
            m = (m, m, m)
        # Number of images; clamped at 0 the way range() treats negative counts.
        ncopies = max(m[0], 0) * max(m[1], 0) * max(m[2], 0)
        offsets = n * np.arange(ncopies)
        base = np.asarray(self.index)
        # Row k is base + k * n; row-major ravel keeps image order.
        self.index = (offsets[:, None] + base[None, :]).ravel().astype(int)
        # XXX make immutable
        return self

    def delete_atoms(self, indices, natoms):
        """Remove atoms from the index array, if present.

        Required for removing atoms with existing constraint. Remaining
        indices are renumbered to account for the deleted atoms.

        Parameters
        ----------
        indices : index expression accepted by ``np.delete``
            Atoms being deleted.
        natoms : int
            Number of atoms before the deletion.

        Returns
        -------
        self, or None if no constrained atom remains.
        """
        # Map old atom index -> new atom index; -1 marks a deleted atom.
        old2new = np.full(natoms, -1, dtype=int)
        kept = np.delete(np.arange(natoms), indices)
        old2new[kept] = np.arange(len(kept))
        index = old2new[self.index]
        self.index = index[index >= 0]
        # XXX make immutable
        if len(self.index) == 0:
            return None
        return self