Coverage for backend/django/Economics/formulas/engine/parsing.py: 86%

68 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-06-23 21:51 +0000

1from __future__ import annotations 

2 

3import ast 

4from decimal import Decimal, InvalidOperation 

5from typing import Iterable 

6 

7import sympy 

8 

9from .core import FormulaError, decimal_to_sympy 

10 

# Budget on the total number of AST objects in a parsed expression, as counted
# by _validate_expression_size. Every object yielded by ast.iter_child_nodes is
# counted, including operator nodes (ast.Add, ast.USub, ...) and the ast.Load
# context attached to each name, so the number of user-visible tokens allowed
# is noticeably smaller than this figure.
MAX_EXPRESSION_NODES: int = 80
# Budget on AST nesting depth; the ast.Expression root sits at depth 0.
MAX_EXPRESSION_DEPTH: int = 24

13 

14 

def parse_cost_expression(expression: str, *, variable_symbols: Iterable[str]) -> sympy.Expr:
    """Parse the restricted cost-curve arithmetic grammar into a SymPy tree.

    The grammar accepts int/float literals, the declared variable names,
    unary ``+``/``-`` and binary ``+ - * / **``; ``^`` is accepted as an
    alias for ``**``. Surrounding whitespace is ignored.

    Args:
        expression: Expression text, e.g. ``"3*q^2 + 10"``.
        variable_symbols: Names the expression may reference. Each is passed
            through ``str()`` and stripped. They must be unique, and every one
            of them must be used by the expression.

    Returns:
        The SymPy expression built by ``_parse_node``.

    Raises:
        FormulaError: ``missing_formula_variables`` (no variables declared),
            ``duplicate_formula_variables``, ``invalid_expression_syntax``,
            ``expression_too_large`` / ``expression_too_deep`` (size limits),
            ``unsupported_expression_name`` / ``unsupported_expression_node``
            / ``invalid_expression_number`` (grammar violations), or
            ``unused_formula_variables``.
    """
    allowed_symbols = tuple(str(symbol).strip() for symbol in variable_symbols)
    if not allowed_symbols:
        raise FormulaError("missing_formula_variables", "Cost curve expressions require declared formula variables.")
    allowed_set = frozenset(allowed_symbols)
    if len(allowed_set) != len(allowed_symbols):
        raise FormulaError(
            "duplicate_formula_variables",
            "Cost curve expression variables must be unique.",
            context={"variables": sorted(allowed_symbols)},
        )

    # ast.parse(..., mode="eval") rejects leading whitespace with
    # IndentationError (a SyntaxError subclass), so input such as " 2*q" would
    # be reported as invalid syntax. Trim it first; error contexts below still
    # report the caller's original text.
    normalized_expression = expression.strip().replace("^", "**")
    try:
        tree = ast.parse(normalized_expression, mode="eval")
    except (SyntaxError, ValueError) as exc:
        # ValueError: source containing NUL bytes on older Python versions
        # (3.12+ raises SyntaxError for the same input).
        raise FormulaError(
            "invalid_expression_syntax",
            "Cost curve expression could not be parsed.",
            context={"expression": expression},
        ) from exc

    # Bound the tree's size and depth before the recursive translation walks it.
    _validate_expression_size(tree, expression=expression)
    parsed_expression = _parse_node(
        tree,
        allowed_symbols=allowed_set,
        source_expression=expression,
    )

    # Every declared variable has to appear somewhere in the expression.
    used_symbols = {str(symbol) for symbol in parsed_expression.free_symbols}
    unused_symbols = sorted(allowed_set - used_symbols)
    if unused_symbols:
        raise FormulaError(
            "unused_formula_variables",
            "Cost curve expression does not use all declared formula variables.",
            context={"unused_variables": unused_symbols, "expression": expression},
        )
    return parsed_expression

50 

51 

52def _parse_node( 

53 node: ast.AST, 

54 *, 

55 allowed_symbols: frozenset[str], 

56 source_expression: str, 

57) -> sympy.Expr: 

58 if isinstance(node, ast.Expression): 

59 return _parse_node( 

60 node.body, 

61 allowed_symbols=allowed_symbols, 

62 source_expression=source_expression, 

63 ) 

64 

65 if isinstance(node, ast.Constant) and type(node.value) in (int, float): 

66 return _number_node(node.value) 

67 

68 if isinstance(node, ast.Name): 

69 if node.id in allowed_symbols: 

70 return sympy.Symbol(node.id) 

71 raise FormulaError( 

72 "unsupported_expression_name", 

73 "Cost curve expression used an unsupported variable name.", 

74 context={ 

75 "name": node.id, 

76 "allowed_names": sorted(allowed_symbols), 

77 "expression": source_expression, 

78 }, 

79 ) 

80 

81 if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd | ast.USub): 

82 value = _parse_node( 

83 node.operand, 

84 allowed_symbols=allowed_symbols, 

85 source_expression=source_expression, 

86 ) 

87 return value if isinstance(node.op, ast.UAdd) else sympy.Mul(sympy.Integer(-1), value, evaluate=False) 

88 

89 if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add | ast.Sub | ast.Mult | ast.Div | ast.Pow): 

90 left = _parse_node( 

91 node.left, 

92 allowed_symbols=allowed_symbols, 

93 source_expression=source_expression, 

94 ) 

95 right = _parse_node( 

96 node.right, 

97 allowed_symbols=allowed_symbols, 

98 source_expression=source_expression, 

99 ) 

100 if isinstance(node.op, ast.Add): 

101 return sympy.Add(left, right, evaluate=False) 

102 if isinstance(node.op, ast.Sub): 

103 return sympy.Add(left, sympy.Mul(sympy.Integer(-1), right, evaluate=False), evaluate=False) 

104 if isinstance(node.op, ast.Mult): 

105 return sympy.Mul(left, right, evaluate=False) 

106 if isinstance(node.op, ast.Div): 106 ↛ 107line 106 didn't jump to line 107 because the condition on line 106 was never true

107 return sympy.Mul(left, sympy.Pow(right, sympy.Integer(-1), evaluate=False), evaluate=False) 

108 return sympy.Pow(left, right, evaluate=False) 

109 

110 raise FormulaError( 

111 "unsupported_expression_node", 

112 "Cost curve expression uses unsupported syntax.", 

113 context={"node_type": type(node).__name__, "expression": source_expression}, 

114 ) 

115 

116 

def _number_node(value: int | float) -> sympy.Rational:
    """Convert a numeric literal from the expression AST into a SymPy number.

    The value goes through ``str()`` and then ``Decimal``, so a float literal
    such as ``0.1`` is converted from its shortest repr (``"0.1"``) rather
    than from its binary approximation. The final conversion is delegated to
    ``decimal_to_sympy``.

    Raises:
        FormulaError: ``invalid_expression_number`` when the literal cannot be
            converted to a Decimal, or when it is not finite. A literal such
            as ``1e999`` parses to ``float('inf')``, and ``Decimal("inf")``
            constructs without error, so it must be rejected explicitly.
    """
    try:
        # ValueError covers e.g. str() of an int beyond the interpreter's
        # integer string-conversion digit limit (Python 3.11+).
        decimal = Decimal(str(value))
    except (InvalidOperation, TypeError, ValueError) as exc:
        raise FormulaError(
            "invalid_expression_number",
            "Cost curve expression contains an invalid number.",
            context={"value": str(value)},
        ) from exc
    if not decimal.is_finite():
        raise FormulaError(
            "invalid_expression_number",
            "Cost curve expression contains an invalid number.",
            context={"value": str(value)},
        )
    return decimal_to_sympy(decimal)

127 

128 

def _validate_expression_size(tree: ast.AST, *, expression: str) -> None:
    """Reject ASTs that exceed the node-count or nesting-depth budget.

    Walks the tree in pre-order with the root at depth 0. Every object yielded
    by ``ast.iter_child_nodes`` is counted, which includes operator nodes
    (``ast.Add``, ``ast.USub``, ...) and the ``ast.Load`` context of each
    name, not just operands. The walk stops at the first node that breaks a
    limit. For that node, the node-count limit is checked before the depth
    limit.

    Raises:
        FormulaError: ``expression_too_large`` once more than
            ``MAX_EXPRESSION_NODES`` nodes have been seen;
            ``expression_too_deep`` for a node deeper than
            ``MAX_EXPRESSION_DEPTH``.
    """
    # Explicit stack instead of recursion. Children are pushed in reverse so
    # they are popped in source order, which reproduces a recursive pre-order
    # walk exactly.
    pending: list[tuple[ast.AST, int]] = [(tree, 0)]
    seen = 0
    while pending:
        current, level = pending.pop()
        seen += 1
        if seen > MAX_EXPRESSION_NODES:
            raise FormulaError(
                "expression_too_large",
                "Cost curve expression is too large.",
                context={"expression": expression, "max_nodes": MAX_EXPRESSION_NODES},
            )
        if level > MAX_EXPRESSION_DEPTH:
            raise FormulaError(
                "expression_too_deep",
                "Cost curve expression is too deeply nested.",
                context={"expression": expression, "max_depth": MAX_EXPRESSION_DEPTH},
            )
        children = list(ast.iter_child_nodes(current))
        pending.extend((child, level + 1) for child in reversed(children))