Coverage for /builds/ase/ase/ase/utils/parsemath.py: 85.11%

94 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3"""A Module to safely parse/evaluate Mathematical Expressions""" 

4import ast 

5import math 

6import operator as op 

7 

8from numpy import int64 

9 

10# Sets the limit of how high the number can get to prevent DoS attacks 

11max_value = 1e17 

12 

13 

14# Redefine mathematical operations to prevent DoS attacks 

def add(a, b):
    """Return ``a + b`` after checking the size of both operands.

    Raises
    ------
    ValueError
        If the magnitude of either operand exceeds the module-level
        ``max_value``.  The exception carries the ``(a, b)`` pair.
    """
    # Check the operands one at a time, left to right, and bail out on the
    # first one that is too large.
    for operand in (a, b):
        if abs(operand) > max_value:
            raise ValueError((a, b))
    return op.add(a, b)

20 

21 

def sub(a, b):
    """Return ``a - b`` after checking the size of both operands.

    Raises
    ------
    ValueError
        If the magnitude of either operand exceeds the module-level
        ``max_value``.  The exception carries the ``(a, b)`` pair.
    """
    left_too_big = abs(a) > max_value
    if left_too_big or abs(b) > max_value:
        raise ValueError((a, b))
    return op.sub(a, b)

27 

28 

def mul(a, b):
    """Return ``a * b`` unless the product would be too large.

    The size of the product is estimated from the sum of the base-10
    logarithms of the operand magnitudes, so the product itself is never
    computed when it would exceed the module-level ``max_value``.

    Raises
    ------
    ValueError
        If the estimated magnitude of the product exceeds ``max_value``.
        The exception carries the ``(a, b)`` pair.
    """
    # A zero factor gives a zero product, and log10(0) is undefined, so the
    # magnitude estimate is only made when both factors are non-zero.
    if not (a == 0.0 or b == 0.0):
        product_exponent = math.log10(abs(a)) + math.log10(abs(b))
        if product_exponent > math.log10(max_value):
            raise ValueError((a, b))
    return op.mul(a, b)

36 

37 

def div(a, b):
    """Return ``a / b`` (true division) unless the quotient would be too large.

    The size of the quotient is estimated from the difference of the
    base-10 logarithms of the operand magnitudes.

    Raises
    ------
    ValueError
        If ``b`` is zero, or if the estimated magnitude of the quotient
        exceeds the module-level ``max_value``.  The exception carries the
        ``(a, b)`` pair.
    """
    if b == 0.0:
        raise ValueError((a, b))
    # A zero numerator always yields zero; log10(0) is undefined, so the
    # magnitude estimate is skipped in that case.
    if not a == 0.0:
        quotient_exponent = math.log10(abs(a)) - math.log10(abs(b))
        if quotient_exponent > math.log10(max_value):
            raise ValueError((a, b))
    return op.truediv(a, b)

47 

48 

def power(a, b):
    """Return ``a ** b`` unless the result would be too large.

    Parameters
    ----------
    a : number
        The base.
    b : number
        The exponent.

    Returns
    -------
    number
        ``a ** b``.  For a zero base the result is the float ``0.0`` for
        any exponent that is neither negative nor zero, and ``1.0`` for a
        zero exponent.

    Raises
    ------
    ValueError
        If the base is zero and the exponent is negative (a division by
        zero), or if the exponent reaches the logarithm of the module-level
        ``max_value`` taken in base ``abs(a)``.  The exception carries the
        ``(a, b)`` pair.
    """
    if a == 0.0:
        if b < 0:
            # 0 ** negative is a division by zero; report it the same way
            # div() reports a zero divisor.
            raise ValueError((a, b))
        # 0 ** 0 is 1 by Python's convention.
        return 1.0 if b == 0 else 0.0

    base = abs(a)
    if base == 1:
        # log(max_value, 1) divides by log(1) == 0 and raises
        # ZeroDivisionError.  No size check is needed here: a power of +/-1
        # always has magnitude 1.
        return op.pow(a, b)

    # a ** b reaches max_value once b reaches log_{|a|}(max_value).  For
    # |a| < 1 that logarithm is negative, so the same comparison catches
    # large results produced by negative exponents.
    if b / math.log(max_value, base) >= 1:
        raise ValueError((a, b))
    return op.pow(a, b)

56 

57 

def exp(a):
    """Return ``e ** a`` unless the result would be too large.

    Raises
    ------
    ValueError
        If ``a`` is greater than the natural logarithm of the module-level
        ``max_value``.  The exception carries ``a``.
    """
    largest_exponent = math.log(max_value)
    if a > largest_exponent:
        raise ValueError(a)
    return math.exp(a)

63 

64 

# Whitelist of supported operators: maps each allowed ``ast`` operator node
# type to the callable that implements it.  An operator that is not listed
# here cannot be evaluated.
operators = {
    # Binary operators with a size guard on operands or result
    ast.Add: add,
    ast.Sub: sub,
    ast.Mult: mul,
    ast.Div: div,
    ast.Pow: power,
    # Unary sign operators
    ast.USub: op.neg,
    ast.UAdd: op.pos,  # lets an explicit sign such as "+1" be evaluated
    # Remaining binary operators, taken directly from ``operator``
    ast.Mod: op.mod,
    # Plain floordiv instead of the in-place ifloordiv: same result for
    # numbers, and it is the function that matches the ``//`` operator.
    ast.FloorDiv: op.floordiv,
}

76 

# Whitelist of callable names accepted in expressions.  Most entries are the
# ``math`` function of the same name; the trailing entries map a name to a
# different callable ("abs" -> math.fabs, the builtin round, and the
# size-guarded exp defined in this module).
allowed_math_fxn = {
    **{
        name: getattr(math, name)
        for name in (
            "sin", "cos", "tan",
            "asin", "acos", "atan", "atan2",
            "hypot",
            "sinh", "cosh", "tanh",
            "asinh", "acosh", "atanh",
            "radians", "degrees",
            "sqrt",
            "log", "log10", "log2",
            "fmod",
        )
    },
    "abs": math.fabs,
    "ceil": math.ceil,
    "floor": math.floor,
    "round": round,
    "exp": exp,
}

106 

107 

def get_function(node):
    """Return the name of the function called by a call node.

    ``node.func`` may be a bare name (``sin(x)``) or an attribute access
    (``math.sin(x)``).  In the attribute case only the final attribute name
    is returned; the object it is looked up on is ignored.

    Raises
    ------
    TypeError
        If ``node.func`` is neither an ``ast.Name`` nor an
        ``ast.Attribute``.
    """
    callee = node.func
    if isinstance(callee, ast.Name):
        return callee.id
    if isinstance(callee, ast.Attribute):
        return callee.attr
    raise TypeError("node.func is of the wrong type")

118 

119 

def limit(max_=None):
    """Return a decorator that limits the values a function may return.

    Parameters
    ----------
    max_ : number, optional
        Largest allowed magnitude of the return value.  With the default
        ``None`` no bound is enforced.

    Returns
    -------
    callable
        A decorator.  The wrapped function behaves as follows:

        * a return value that does not support ``abs()`` is passed through
          unchanged;
        * a return value whose magnitude exceeds ``max_`` raises
          ``ValueError`` carrying that value;
        * an ``int`` return value is converted to ``numpy.int64``.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ret = func(*args, **kwargs)
            try:
                mag = abs(ret)
            except TypeError:
                # No magnitude: the limit is not applicable.
                return ret
            # Comparing against None would raise TypeError, so only compare
            # when a bound was actually given.
            if max_ is not None and mag > max_:
                raise ValueError(ret)
            if isinstance(ret, int):
                ret = int64(ret)
            return ret

        return wrapper

    return decorator

142 

143 

@limit(max_=max_value)
def _eval(node):
    """Evaluate an ``ast`` node of a parsed mathematical expression.

    Only the node types handled below are accepted:

    * numeric constants (``int`` / ``float``);
    * binary and unary operations whose operator is a key of ``operators``;
    * calls to a function whose name is a key of ``allowed_math_fxn``, with
      positional and keyword arguments;
    * the names ``pi``, ``e`` and ``tau`` (matched case-insensitively).

    Raises
    ------
    TypeError
        For any other node type, for an unknown name, or for a call that
        uses ``**`` unpacking.
    KeyError
        For an operator missing from ``operators`` or a function name
        missing from ``allowed_math_fxn``.
    ValueError
        Raised by the ``limit`` decorator when a result is larger in
        magnitude than ``max_value``, and by the guarded arithmetic
        functions.
    """
    if isinstance(node, ast.Constant) and isinstance(node.value, (float, int)):
        return node.value
    elif isinstance(node, ast.BinOp):  # <left> <operator> <right>
        return operators[type(node.op)](_eval(node.left), _eval(node.right))
    elif isinstance(node, ast.UnaryOp):  # <operator> <operand> e.g., -1
        return operators[type(node.op)](_eval(node.operand))
    elif isinstance(node, ast.Call):  # sin(x) or math.sin(x)
        func = get_function(node)
        evaled_args = [_eval(arg) for arg in node.args]
        # Keyword arguments must be forwarded as well; dropping them would
        # silently change the result (e.g. ``round(x, ndigits=2)``).
        evaled_kwargs = {}
        for keyword in node.keywords:
            if keyword.arg is None:
                # ``**mapping`` unpacking is not a mathematical construct.
                raise TypeError(node)
            evaled_kwargs[keyword.arg] = _eval(keyword.value)
        return allowed_math_fxn[func](*evaled_args, **evaled_kwargs)
    elif isinstance(node, ast.Name):
        name = node.id.lower()
        if name == "pi":
            return math.pi
        elif name == "e":
            return math.e
        elif name == "tau":
            return math.pi * 2.0
        else:
            raise TypeError(
                "Found a str in the expression, either param_dct/the "
                "expression has a mistake in the parameter names or "
                "attempting to parse non-mathematical code")
    else:
        raise TypeError(node)

173 

174 

def eval_expression(expression, param_dct=None):
    """Parse and evaluate a mathematical expression.

    Parameters
    ----------
    expression : str
        The expression to evaluate; at most 10000 characters.
    param_dct : dict, optional
        Maps parameter names to values.  Each occurrence of a name in the
        expression is replaced by ``str(value)`` wrapped in parentheses
        before the expression is parsed.  A name that starts or ends with a
        word character is only replaced where it is not directly adjacent
        to another word character on that side, so a parameter ``a`` does
        not alter ``tan`` and a parameter ``e`` does not alter ``exp`` or
        ``1e5``.  Parameters are substituted in the iteration order of the
        dict.

    Returns
    -------
    The value of the expression as computed by ``_eval``.

    Raises
    ------
    TypeError
        If ``expression`` is not a string.
    ValueError
        If the expression is longer than 10000 characters or contains
        ``()``.
    """
    import re

    if not isinstance(expression, str):
        raise TypeError("The expression must be a string")
    if len(expression) > 1e4:
        raise ValueError("The expression is too long.")

    expression_rep = expression.strip()

    if "()" in expression_rep:
        raise ValueError("Invalid operation in expression")

    if param_dct is None:
        param_dct = {}

    for key, val in param_dct.items():
        pattern = re.escape(key)
        # Anchor the key to token boundaries on the sides where it consists
        # of word characters, so it is not replaced inside longer names or
        # numeric literals.
        if re.match(r"\w", key):
            pattern = r"(?<!\w)" + pattern
        if re.search(r"\w\Z", key):
            pattern = pattern + r"(?!\w)"
        # Parenthesise the value so that e.g. a negative number keeps its
        # sign under a following ``**`` ("x**2" with x=-2 -> "(-2)**2").
        replacement = "(" + str(val) + ")"
        # A callable replacement avoids backslash/group interpretation of
        # the substituted text.
        expression_rep = re.sub(
            pattern, lambda _match: replacement, expression_rep)

    return _eval(ast.parse(expression_rep, mode="eval").body)