Coverage for /builds/ase/ase/ase/utils/parsemath.py: 85.11%
94 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3"""A Module to safely parse/evaluate Mathematical Expressions"""
4import ast
5import math
6import operator as op
8from numpy import int64
# Upper bound on the magnitude of numbers handled while evaluating an
# expression, to prevent denial-of-service (DoS) attacks through huge
# intermediate values.
max_value = 1e17
# Redefine mathematical operations to prevent denial-of-service (DoS) attacks
def add(a, b):
    """Return ``a + b`` after checking both operands against ``max_value``.

    Raises
    ------
    ValueError
        With the tuple ``(a, b)`` when ``abs(a)`` or ``abs(b)`` exceeds
        ``max_value``.
    """
    if abs(a) > max_value or abs(b) > max_value:
        raise ValueError((a, b))
    return a + b
def sub(a, b):
    """Return ``a - b`` after checking both operands against ``max_value``.

    Raises
    ------
    ValueError
        With the tuple ``(a, b)`` when either operand's magnitude exceeds
        ``max_value``.
    """
    for operand in (a, b):
        if abs(operand) > max_value:
            raise ValueError((a, b))
    return a - b
def mul(a, b):
    """Return ``a * b``, refusing products whose magnitude would be too large.

    The check compares ``log10|a| + log10|b|`` with ``log10(max_value)``,
    so the oversized product is never actually computed.  A zero factor
    skips the check (its logarithm is undefined).

    Raises
    ------
    ValueError
        With the tuple ``(a, b)`` when the product would exceed
        ``max_value``.
    """
    if a == 0.0 or b == 0.0:
        return a * b
    product_exponent = math.log10(abs(a)) + math.log10(abs(b))
    if product_exponent > math.log10(max_value):
        raise ValueError((a, b))
    return a * b
def div(a, b):
    """Return ``a / b``, refusing a zero divisor or an oversized quotient.

    The quotient size is checked in log space
    (``log10|a| - log10|b|`` against ``log10(max_value)``).  A zero
    numerator skips that check.

    Raises
    ------
    ValueError
        With the tuple ``(a, b)`` when ``b`` is zero or the quotient would
        exceed ``max_value``.
    """
    if b == 0.0:
        raise ValueError((a, b))
    if a == 0.0:
        return a / b
    quotient_exponent = math.log10(abs(a)) - math.log10(abs(b))
    if quotient_exponent > math.log10(max_value):
        raise ValueError((a, b))
    return a / b
def power(a, b):
    """Return ``a ** b``, refusing results that would exceed ``max_value``.

    For ``|a|`` other than 0 and 1, ``b`` is compared with
    ``log_|a|(max_value)``, the exponent at which ``|a| ** b`` reaches
    ``max_value``.

    Raises
    ------
    ValueError
        With the tuple ``(a, b)`` when the result would be too large, or
        when ``a`` is zero and ``b`` is negative (a division by zero).
    """
    if a == 0.0:
        if b < 0:
            # 0 ** negative divides by zero; report it like ``div`` does.
            raise ValueError((a, b))
        if b == 0:
            # 0 ** 0 == 1 (Python semantics), not 0.
            return op.pow(a, b)
        return 0.0
    if abs(a) == 1.0:
        # The magnitude stays at 1 for any real exponent, and
        # math.log(x, 1) would divide by log(1) == 0.
        return op.pow(a, b)
    if b / math.log(max_value, abs(a)) >= 1:
        raise ValueError((a, b))
    return op.pow(a, b)
def exp(a):
    """Return ``math.exp(a)``, refusing exponents above ``log(max_value)``.

    Raises
    ------
    ValueError
        With ``a`` itself when ``exp(a)`` would exceed ``max_value``.
    """
    threshold = math.log(max_value)
    if a > threshold:
        raise ValueError(a)
    return math.exp(a)
# Whitelisted AST operator node types mapped to the functions implementing
# them. Binary +, -, *, / and ** use the size-guarded functions of this
# module; the remaining operators come straight from the ``operator`` module.
operators = {
    ast.Add: add,
    ast.Sub: sub,
    ast.Mult: mul,
    ast.Div: div,
    ast.Pow: power,
    ast.USub: op.neg,
    ast.UAdd: op.pos,  # unary plus, e.g. "+1"
    ast.Mod: op.mod,
    ast.FloorDiv: op.floordiv,
}
# Whitelist of callable names accepted in expressions. This is a curated
# subset of the math module plus the builtin ``round`` and the size-guarded
# ``exp`` of this module; any other name is rejected. Note that "abs" maps
# to math.fabs, so it always returns a float.
allowed_math_fxn = {
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,
    "atan2": math.atan2,
    "hypot": math.hypot,
    "sinh": math.sinh,
    "cosh": math.cosh,
    "tanh": math.tanh,
    "asinh": math.asinh,
    "acosh": math.acosh,
    "atanh": math.atanh,
    "radians": math.radians,
    "degrees": math.degrees,
    "sqrt": math.sqrt,
    "log": math.log,
    "log10": math.log10,
    "log2": math.log2,
    "fmod": math.fmod,
    "abs": math.fabs,
    "ceil": math.ceil,
    "floor": math.floor,
    "round": round,
    "exp": exp,
}
def get_function(node):
    """Return the name of the callable invoked by an ``ast.Call`` node.

    Bare calls (``sin(x)``) yield the name itself; attribute calls
    (``math.sin(x)``) yield only the attribute name (``"sin"``).

    Raises
    ------
    TypeError
        If ``node.func`` is neither an ``ast.Name`` nor an ``ast.Attribute``.
    """
    callee = node.func
    if isinstance(callee, ast.Attribute):
        return callee.attr
    if isinstance(callee, ast.Name):
        return callee.id
    raise TypeError("node.func is of the wrong type")
def limit(max_=None):
    """Return a decorator that bounds the magnitude of a function's result.

    Parameters
    ----------
    max_ : float or None
        Largest allowed ``abs(result)``. ``None`` (the default) disables the
        bound instead of failing on the comparison.

    The wrapped function raises ``ValueError(result)`` when ``abs(result)``
    exceeds ``max_``. Results that do not support ``abs`` are passed through
    unchecked. Plain Python ``int`` results are converted to
    ``numpy.int64``.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ret = func(*args, **kwargs)
            if max_ is not None:
                try:
                    mag = abs(ret)
                except TypeError:
                    pass  # not applicable
                else:
                    if mag > max_:
                        raise ValueError(ret)
            if isinstance(ret, int):
                # NOTE: with max_=None an int outside the int64 range would
                # make this conversion raise OverflowError.
                ret = int64(ret)
            return ret

        return wrapper

    return decorator
@limit(max_=max_value)
def _eval(node):
    """Recursively evaluate an ``ast`` node of a mathematical expression.

    Accepted nodes are numeric constants, operators listed in ``operators``,
    calls to functions listed in ``allowed_math_fxn`` (positional and
    keyword arguments), and the names ``pi``, ``e`` and ``tau``
    (case-insensitive). Every result passes through the ``limit`` decorator,
    which rejects magnitudes above ``max_value``.

    Raises
    ------
    TypeError
        For any other node type or an unknown name.
    KeyError
        For an operator or function name that is not whitelisted.
    """
    if isinstance(node, ast.Constant) and isinstance(node.value, (float, int)):
        return node.value
    elif isinstance(node, ast.BinOp):  # <left> <operator> <right>
        return operators[type(node.op)](_eval(node.left), _eval(node.right))
    elif isinstance(node, ast.UnaryOp):  # <operator> <operand> e.g., -1
        return operators[type(node.op)](_eval(node.operand))
    elif isinstance(node, ast.Call):  # using math.function
        func = get_function(node)
        # Evaluate all arguments. Keyword arguments are evaluated too, so
        # e.g. round(x, ndigits=2) is honoured instead of silently dropping
        # ndigits. A "**mapping" keyword has a Dict value, which this
        # function rejects with TypeError.
        evaled_args = [_eval(arg) for arg in node.args]
        evaled_kwargs = {kw.arg: _eval(kw.value) for kw in node.keywords}
        return allowed_math_fxn[func](*evaled_args, **evaled_kwargs)
    elif isinstance(node, ast.Name):
        if node.id.lower() == "pi":
            return math.pi
        elif node.id.lower() == "e":
            return math.e
        elif node.id.lower() == "tau":
            return math.pi * 2.0
        else:
            raise TypeError(
                "Found a str in the expression, either param_dct/the "
                "expression has a mistake in the parameter names or "
                "attempting to parse non-mathematical code")
    else:
        raise TypeError(node)
def eval_expression(expression, param_dct=None):
    """Parse and evaluate a mathematical expression.

    Replaces the parameters named in ``param_dct`` with their values, then
    parses the expression and evaluates it with ``_eval``.

    Parameters
    ----------
    expression : str
        The expression to evaluate (at most 10,000 characters).
    param_dct : dict, optional
        Mapping of parameter names to values. A name is only replaced where
        it stands alone, i.e. not inside a longer identifier (``a`` leaves
        ``tan`` alone) and not directly after a ``.``. Each value is
        inserted as ``(str(value))`` so that, e.g., a negative value keeps
        its sign under ``**``.

    Returns
    -------
    The value computed by ``_eval``.

    Raises
    ------
    TypeError
        If ``expression`` is not a string (further errors may come from
        ``_eval``).
    ValueError
        If the expression is too long or contains ``()`` (further errors
        may come from ``_eval``).
    """
    import re

    if not isinstance(expression, str):
        raise TypeError("The expression must be a string")
    if len(expression) > 1e4:
        raise ValueError("The expression is too long.")

    expression_rep = expression.strip()

    # Any empty pair of parentheses is rejected outright (presumably to
    # block zero-argument calls).
    if "()" in expression_rep:
        raise ValueError("Invalid operation in expression")

    # Whole-token substitution: plain str.replace would also rewrite parts of
    # other names (key "x" turns "exp(x)" into "e2p(2)").
    for key, val in (param_dct or {}).items():
        pattern = r"(?<![\w.])" + re.escape(key) + r"(?!\w)"
        replacement = "(" + str(val) + ")"
        # A callable replacement avoids re's backslash-escape processing.
        expression_rep = re.sub(
            pattern, lambda _match, text=replacement: text, expression_rep)

    return _eval(ast.parse(expression_rep, mode="eval").body)