Coverage for ase / constraints / fix_bond_lengths.py: 96.34%

82 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1import numpy as np 

2 

3from ase.constraints.constraint import FixConstraint 

4from ase.geometry import find_mic 

5 

6 

class FixBondLengths(FixConstraint):
    """Constraint keeping the distance between pairs of atoms fixed.

    Positions, momenta and forces are corrected pair by pair in repeated
    sweeps until a full sweep needs no correction larger than
    ``tolerance``; after ``maxiter`` sweeps a ``RuntimeError`` is raised.
    """

    # Maximum number of sweeps over all pairs before giving up.
    maxiter = 500

    def __init__(
        self, pairs, tolerance=1e-13, bondlengths=None, iterations=None
    ):
        """Set up the constraint.

        pairs:
            Sequence of atom-index pairs ``(a, b)``; stored as an array.
        tolerance:
            A pair counts as converged when the magnitude of its
            correction factor does not exceed this value.
        bondlengths:
            Target length for each pair, in the same order as ``pairs``.
            If None, the lengths are measured from the atoms the first
            time positions, momenta or forces are adjusted.
        iterations:
            Ignored
        """
        self.pairs = np.asarray(pairs)
        self.tolerance = tolerance
        self.bondlengths = bondlengths

    def get_removed_dof(self, atoms):
        """Return the number of constrained pairs (one per pair)."""
        return len(self.pairs)

    def adjust_positions(self, atoms, new):
        """Correct ``new`` in place so each pair has its target length.

        ``atoms.positions`` is used as the reference geometry; the two
        atoms of a pair are displaced along their reference separation
        vector, in opposite directions and weighted by inverse mass.

        Raises RuntimeError if ``maxiter`` sweeps do not converge.
        """
        old = atoms.positions
        masses = atoms.get_masses()

        if self.bondlengths is None:
            self.bondlengths = self.initialize_bond_lengths(atoms)

        for _ in range(self.maxiter):
            converged = True
            for j, ab in enumerate(self.pairs):
                a, b = ab[0], ab[1]
                cd = self.bondlengths[j]
                r0 = old[a] - old[b]
                d0, _length = find_mic(r0, atoms.cell, atoms.pbc)
                # Shift the new separation by the same offset that
                # find_mic applied to the reference separation.
                d1 = new[a] - new[b] - r0 + d0
                # Reduced mass of the pair.
                m = 1 / (1 / masses[a] + 1 / masses[b])
                x = 0.5 * (cd**2 - np.dot(d1, d1)) / np.dot(d0, d1)
                if abs(x) > self.tolerance:
                    new[a] += x * m / masses[a] * d0
                    new[b] -= x * m / masses[b] * d0
                    converged = False
            if converged:
                break
        else:
            # Loop ran out of sweeps without a fully converged pass.
            raise RuntimeError('Did not converge')

    def adjust_momenta(self, atoms, p):
        """Correct ``p`` in place to cancel relative motion along bonds.

        For each pair the relative velocity ``p/m`` is projected onto the
        separation vector taken from ``atoms.positions`` and the two
        momenta receive equal and opposite corrections along it.

        Raises RuntimeError if ``maxiter`` sweeps do not converge.
        """
        old = atoms.positions
        masses = atoms.get_masses()

        if self.bondlengths is None:
            self.bondlengths = self.initialize_bond_lengths(atoms)

        for _ in range(self.maxiter):
            converged = True
            for j, ab in enumerate(self.pairs):
                a, b = ab[0], ab[1]
                cd = self.bondlengths[j]
                d = old[a] - old[b]
                d, _length = find_mic(d, atoms.cell, atoms.pbc)
                dv = p[a] / masses[a] - p[b] / masses[b]
                # Reduced mass of the pair.
                m = 1 / (1 / masses[a] + 1 / masses[b])
                x = -np.dot(dv, d) / cd**2
                if abs(x) > self.tolerance:
                    p[a] += x * m * d
                    p[b] -= x * m * d
                    converged = False
            if converged:
                break
        else:
            # Loop ran out of sweeps without a fully converged pass.
            raise RuntimeError('Did not converge')

    def adjust_forces(self, atoms, forces):
        """Correct ``forces`` in place and record the applied change.

        The same projection as in ``adjust_momenta`` is used. Afterwards
        ``self.constraint_forces`` holds adjusted minus original forces.
        """
        # Negation yields a new array, so the in-place update of
        # ``forces`` below does not touch it.
        self.constraint_forces = -forces
        self.adjust_momenta(atoms, forces)
        self.constraint_forces += forces

    def initialize_bond_lengths(self, atoms):
        """Return the current length of every pair, measured with mic."""
        bondlengths = np.zeros(len(self.pairs))

        for i, ab in enumerate(self.pairs):
            bondlengths[i] = atoms.get_distance(ab[0], ab[1], mic=True)

        return bondlengths

    def get_indices(self):
        """Return the distinct atom indices appearing in any pair."""
        return np.unique(self.pairs.ravel())

    def todict(self):
        """Return a dict with the constraint name and constructor kwargs.

        NOTE(review): ``bondlengths`` is not part of the returned kwargs,
        so explicitly given target lengths are not carried along.
        """
        return {
            'name': 'FixBondLengths',
            'kwargs': {
                'pairs': self.pairs.tolist(),
                'tolerance': self.tolerance,
            },
        }

    def index_shuffle(self, atoms, ind):
        """Remap the pair indices to the atoms selected by ``ind``.

        Pairs with an atom that is not selected are dropped, together
        with their entry in ``bondlengths`` (if set) so that lengths
        stay aligned with pairs.

        Raises IndexError if no pair survives.
        """
        mapping = np.zeros(len(atoms), int)
        # Count the selected atoms by marking them; presumably done this
        # way so ``ind`` may be a boolean mask as well as an index list.
        mapping[ind] = 1
        n = mapping.sum()
        mapping[:] = -1
        mapping[ind] = range(n)
        pairs = mapping[self.pairs]
        keep = (pairs != -1).all(1)
        self.pairs = pairs[keep]
        if self.bondlengths is not None:
            # Keep target lengths in step with the surviving pairs.
            self.bondlengths = np.asarray(self.bondlengths)[keep]
        if len(self.pairs) == 0:
            raise IndexError('Constraint not part of slice')

112 

113 

def FixBondLength(a1, a2):
    """Return a FixBondLengths constraint for the single pair (a1, a2)."""
    single_pair = (a1, a2)
    return FixBondLengths([single_pair])