You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
122 lines
4.1 KiB
122 lines
4.1 KiB
|
|
|
|
|
|
from sympy.parsing.sympy_parser import parse_expr
|
|
from sympy.core.numbers import Integer, One, Zero
|
|
from sympy import symbols, Eq, solveset, linsolve, nonlinsolve
|
|
from sympy.core.symbol import Symbol
|
|
from chatbug.matheval import ast
|
|
|
|
|
|
def interpret(statement: ast.Statement) -> str:
    """Route a parsed statement to the interpreter for its node type.

    Args:
        statement: Parsed statement node.

    Returns:
        The text produced by ``interpret_solve`` for ``ast.Solve`` nodes or
        by ``interpret_expression`` for ``ast.Expression`` nodes; the literal
        ``"interpretation error"`` for any other node.
    """
    # Checked in order: Solve is tested before Expression.
    handlers = (
        (ast.Solve, interpret_solve),
        (ast.Expression, interpret_expression),
    )
    for node_type, handler in handlers:
        if isinstance(statement, node_type):
            return handler(statement)
    return "interpretation error"
|
|
|
|
|
|
def interpret_solve(statement: ast.Solve) -> str:
    """Turn a solve statement into sympy objects and solve it.

    Each entry of ``statement.equations`` becomes a sympy ``Eq`` built from
    the parsed ``lhs.value`` and ``rhs.value`` strings; each entry of
    ``statement.variables`` becomes a symbol created from its ``value``.

    Args:
        statement: Solve node carrying ``equations`` and ``variables``.

    Returns:
        The text from ``solve_simple_equation`` when there is exactly one
        equation and one variable, otherwise the text from
        ``solve_multi_equation``.
    """
    raw_equations = statement.equations
    raw_variables = statement.variables
    parse = _math_expression_sanitation_and_parse

    # Convert the AST equations into sympy Eq objects.
    equations = []
    for raw in raw_equations:
        equations.append(Eq(parse(raw.lhs.value), parse(raw.rhs.value)))

    variables = [symbols(raw.value) for raw in raw_variables]

    is_single = len(equations) == 1 and len(variables) == 1
    if is_single:
        return solve_simple_equation(equations[0], variables[0])
    return solve_multi_equation(equations, variables)
|
|
|
|
|
|
|
|
def solve_simple_equation(equation, variable):
    """Solve one equation for one variable and describe the outcome.

    Args:
        equation: Equation object exposing ``lhs`` and ``rhs``; it is passed
            to sympy's ``solveset``.
        variable: The symbol to solve for.

    Returns:
        A sentence containing both sides of the equation, the variable and
        the solution set returned by ``solveset``.
    """
    solution_set = solveset(equation, variable)
    fields = (equation.lhs, equation.rhs, variable, solution_set)
    return "solved %s = %s for %s = %s" % fields
|
|
|
|
def _format_solution_value(value) -> str:
    """Render one solved value, with a 3-decimal approximation if possible.

    Integers are shown as-is. Other values get ``=~<approx>`` appended, or
    are replaced by ``~<approx>`` when their exact text exceeds 20
    characters.
    """
    text = str(value)
    if isinstance(value, Integer):
        return text
    try:
        approx = "%.3f" % value.evalf()
    except (TypeError, ValueError):
        # Best effort: values that cannot be rendered as a float (presumably
        # symbolic or complex results) keep their exact text only.
        return text
    if len(text) > 20:
        return "~" + approx
    return text + "=~" + approx


def solve_multi_equation(equations, variables):
    """Solve a system of equations and describe the first solution found.

    The system goes to ``linsolve`` when ``is_linear`` reports it as linear
    and to ``nonlinsolve`` otherwise. Only the first element of the returned
    solution set is reported; further solutions are ignored.

    Args:
        equations: Equation objects exposing ``lhs`` and ``rhs``.
        variables: Symbols to solve for, in the order the solver should
            report them.

    Returns:
        A sentence listing the equations and one ``variable=value`` pair per
        variable, or a "no solution" sentence when the solver returns an
        empty set.
    """
    if is_linear(equations, variables):
        solution = linsolve(equations, variables)
    else:
        solution = nonlinsolve(equations, variables)

    subject = "equation system " if len(equations) > 1 else "equation "
    equations_text = _natural_join([_pretty_equation(e) for e in equations])

    # Take only the first solution without materialising the whole set; an
    # empty set yields None instead of raising IndexError.
    first_solution = next(iter(solution), None)
    if first_solution is None:
        return "Found no solution for " + subject + equations_text + "."

    solutionpairs = [
        f"{variable}={_format_solution_value(value)}"
        for variable, value in zip(variables, first_solution)
    ]
    return (
        "Solved " + subject + equations_text
        + " for " + _natural_join(solutionpairs) + "."
    )
|
|
|
|
def _natural_join(data: list[any], joiner=", ", last=" and "):
|
|
if len(data) > 1:
|
|
return joiner.join(data[:-1]) + last + data[-1]
|
|
return last.join(data)
|
|
|
|
def _pretty_equation(simpy_Eq) -> str:
|
|
return f"{simpy_Eq.lhs} = {simpy_Eq.rhs}"
|
|
|
|
|
|
|
|
|
|
def is_linear(equations, variables):
    """Report whether a system of equations is linear.

    Linearity detection is currently disabled: this always returns
    ``False``, so ``solve_multi_equation`` always uses ``nonlinsolve``.

    Args:
        equations: The equations of the system (currently unused).
        variables: The variables of the system (currently unused).

    Returns:
        Always ``False``.
    """
    # NOTE(review): a derivative-based check used to sit after an
    # unconditional ``return False`` and therefore never ran. It was removed
    # as dead code; re-introduce a verified linearity test here before
    # enabling the ``linsolve`` path.
    return False
|
|
|
|
|
|
|
|
|
|
def interpret_expression(statement: ast.Expression) -> str:
    """Evaluate an expression statement and return the result text.

    Args:
        statement: Expression node whose ``value`` holds the expression
            source text.

    Returns:
        The text produced by ``_math_evaluate_expression`` for that source.
    """
    source_text = statement.value
    return _math_evaluate_expression(source_text)
|
|
|
|
|
|
def _math_evaluate_expression(expression: str):
    """Evaluate a mathematical expression and render it as an equality chain.

    The chain always starts with the parsed expression and continues with:

    * the simplified form only, when that form is an ``Integer``;
    * the numerical value only, when the simplified form equals either the
      parsed expression or the numerical value;
    * the simplified form followed by the numerical value otherwise.

    Args:
        expression: Expression source text.

    Returns:
        The chain members joined by ``_build_equation_pair``.
    """
    parsed, simplified, numeric = _math_evaluate_internal(expression)

    if isinstance(simplified, Integer):
        chain = [parsed, simplified]
    elif parsed == simplified or simplified == numeric:
        chain = [parsed, numeric]
    else:
        chain = [parsed, simplified, numeric]

    return _build_equation_pair(chain)
|
|
|
|
|
|
def _math_evaluate_internal(expression: str):
    """Parse an expression and compute its evaluated forms.

    Args:
        expression: Expression source text.

    Returns:
        A 3-tuple ``(parsed, parsed.doit(), parsed.evalf())``; both derived
        forms are computed from the parsed expression itself.
    """
    parsed = _math_expression_sanitation_and_parse(expression)
    return parsed, parsed.doit(), parsed.evalf()
|
|
|
|
|
|
def _math_expression_sanitation_and_parse(expression: str):
    """Normalise the power operator and parse the text with sympy.

    Every ``^`` is rewritten to ``**`` before the text is handed to
    ``parse_expr`` with ``evaluate=False``.

    Args:
        expression: Expression source text.

    Returns:
        The object returned by ``parse_expr``.
    """
    # NOTE(review): sympy's documentation warns that parse_expr evaluates its
    # input with eval — confirm that text reaching this function is trusted
    # or validated upstream.
    python_style = expression.replace("^", "**")
    return parse_expr(python_style, evaluate=False)
|
|
|
|
|
|
def _build_equation_pair(expressions: list[any]) -> str:
|
|
expressions = [str(e) for e in expressions]
|
|
return " = ".join(expressions)
|