You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

108 lines
3.6 KiB

import math_ast as ast
from sympy.parsing.sympy_parser import parse_expr
from sympy.core.numbers import Integer, One, Zero
from sympy import symbols, Eq, solveset, linsolve, nonlinsolve
from sympy.core.symbol import Symbol
def interpret(statement: ast.Statement) -> str:
    """Route a parsed statement to the interpreter for its node type.

    ``ast.Solve`` nodes are handled by :func:`interpret_solve` and
    ``ast.Expression`` nodes by :func:`interpret_expression` (checked in that
    order). Any other node yields the string ``"interpretation error"``.
    """
    dispatch = (
        (ast.Solve, interpret_solve),
        (ast.Expression, interpret_expression),
    )
    for node_type, handler in dispatch:
        if isinstance(statement, node_type):
            return handler(statement)
    return "interpretation error"
def interpret_solve(statement: ast.Solve) -> str:
    """Solve the equation(s) of a ``Solve`` statement for its variables.

    The ``lhs``/``rhs`` text of every AST equation is parsed and wrapped in a
    sympy ``Eq``; every variable name becomes a sympy symbol. Exactly one
    equation in exactly one unknown goes to :func:`solve_simple_equation`;
    every other combination goes to :func:`solve_multi_equation`.
    """
    ast_equations = statement.equations
    ast_variables = statement.variables
    parse = _math_expression_sanitation_and_parse

    sympy_equations = []
    for ast_eq in ast_equations:
        left = parse(ast_eq.lhs.value)
        right = parse(ast_eq.rhs.value)
        sympy_equations.append(Eq(left, right))

    unknowns = [symbols(name.value) for name in ast_variables]

    single_unknown_problem = len(sympy_equations) == 1 and len(unknowns) == 1
    if not single_unknown_problem:
        return solve_multi_equation(sympy_equations, unknowns)
    return solve_simple_equation(sympy_equations[0], unknowns[0])
def solve_simple_equation(equation, variable):
    """Solve one equation for one unknown with sympy ``solveset``.

    Returns a sentence that echoes both sides of the equation, the unknown,
    and the solution set that ``solveset`` produced.
    """
    solution_set = solveset(equation, variable)
    left, right = equation.lhs, equation.rhs
    return f"solved {left!s} = {right!s} for {variable!s} = {solution_set!s}"
def solve_multi_equation(equations, variables):
    """Solve a system of equations and describe the first solution found.

    ``linsolve`` is used when :func:`is_linear` reports a linear system and
    ``nonlinsolve`` otherwise. Only the first solution tuple of the returned
    set is reported, as ``"solved equation system for x=..., y=... and z=..."``.

    Each value is shown exactly. Non-integer values also get a 3-decimal
    approximation (see :func:`_format_solution_value`).

    Returns ``"no solution found for equation system"`` when the solver
    returns an empty set. Previously, indexing ``[0]`` raised IndexError here.
    """
    if is_linear(equations, variables):
        solution = linsolve(equations, variables)
    else:
        solution = nonlinsolve(equations, variables)

    # The solvers return a set of solution tuples; only the first is reported.
    first = next(iter(solution), None)
    if first is None:
        return "no solution found for equation system"

    pairs = [
        f"{variable}={_format_solution_value(value)}"
        for variable, value in zip(variables, first)
    ]
    return "solved equation system for " + _join_with_and(pairs)


def _format_solution_value(value) -> str:
    """Return ``str(value)``, plus a ``~x.xxx`` approximation for non-integers.

    Exact forms longer than 20 characters are replaced by ``~x.xxx``. Shorter
    ones get ``=~x.xxx`` appended. Values that cannot be turned into a real
    float (e.g. still symbolic, or complex) are returned in exact form only.
    """
    text = str(value)
    if isinstance(value, Integer):
        return text
    try:
        approx = "%.3f" % value.evalf()
    except (TypeError, ValueError, AttributeError, ArithmeticError):
        # Best effort only: an approximation is optional decoration.
        return text
    if len(text) > 20:
        return "~" + approx
    return text + "=~" + approx


def _join_with_and(items: list[str]) -> str:
    """Join items as ``"a, b and c"``; zero or one item is returned as-is."""
    if len(items) <= 1:
        return "".join(items)
    return ", ".join(items[:-1]) + " and " + items[-1]
def is_linear(equations, variables):
    """Check whether every equation is linear in all of *variables*.

    Each equation ``lhs = rhs`` is rewritten as ``lhs - rhs``. The system
    counts as linear when that expression's partial derivative with respect
    to every unknown mentions no unknown at all. Numbers and other
    (parameter) symbols are therefore allowed as coefficients, while
    products of unknowns (``x*y``), powers (``x**2``) and functions of
    unknowns (``sin(x)``) make the system non-linear.

    Anything without ``lhs``/``rhs`` (e.g. an ``Eq`` that sympy already
    evaluated to ``True``/``False``) is reported as non-linear. The caller
    then keeps using the general solver.
    """
    unknowns = set(variables)
    for eq in equations:
        if not (hasattr(eq, "lhs") and hasattr(eq, "rhs")):
            return False
        # Differentiate a plain expression, not the relational object.
        expr = eq.lhs - eq.rhs
        for var in variables:
            coefficient = expr.diff(var)
            # In a linear term the coefficient may not depend on any unknown;
            # checking only `var` would wrongly accept e.g. x*y.
            if not coefficient.free_symbols.isdisjoint(unknowns):
                return False
    return True
def interpret_expression(statement: ast.Expression) -> str:
    """Evaluate an expression statement and return its ``a = b`` rendering."""
    source_text = statement.value
    rendered = _math_evaluate_expression(source_text)
    return rendered
def _math_evaluate_expression(expression: str):
    """Render an expression together with its simplified and numeric forms.

    The parsed form is always shown first. Then:

    - if the simplified form is an ``Integer``: ``parsed = simplified``;
    - elif parsed equals simplified, or simplified equals numeric:
      ``parsed = numeric``;
    - otherwise: ``parsed = simplified = numeric``.
    """
    parsed, simplified, numeric = _math_evaluate_internal(expression)
    if isinstance(simplified, Integer):
        chain = [parsed, simplified]
    elif parsed == simplified or simplified == numeric:
        chain = [parsed, numeric]
    else:
        chain = [parsed, simplified, numeric]
    return _build_equation_pair(chain)
def _math_evaluate_internal(expression: str):
    """Parse *expression* and return ``(parsed, parsed.doit(), parsed.evalf())``.

    ``doit`` is called before ``evalf``, matching the tuple order.
    """
    parsed = _math_expression_sanitation_and_parse(expression)
    return parsed, parsed.doit(), parsed.evalf()
def _math_expression_sanitation_and_parse(expression: str):
expression = expression.replace("^", "**")
return parse_expr(expression, evaluate=False)
def _build_equation_pair(expressions: list[any]) -> str:
expressions = [str(e) for e in expressions]
return " = ".join(expressions)