from sympy.parsing.sympy_parser import parse_expr
from sympy.core.numbers import Integer, One, Zero
from sympy import symbols, Eq, solveset, linsolve, nonlinsolve
from sympy.core.symbol import Symbol

from chatbug.matheval import ast


def interpret(statement: ast.Statement) -> str:
    """Dispatch a parsed math statement to the matching interpreter.

    Returns the human-readable answer, or ``"interpretation error"`` for
    statement types this module does not handle.
    """
    if isinstance(statement, ast.Solve):
        return interpret_solve(statement)
    elif isinstance(statement, ast.Expression):
        return interpret_expression(statement)
    return "interpretation error"


def interpret_solve(statement: ast.Solve) -> str:
    """Solve the equation(s) of a ``Solve`` statement for its variables.

    A single equation in a single variable is handed to
    :func:`solve_simple_equation`; anything else is treated as a system and
    handed to :func:`solve_multi_equation`.
    """
    # convert equations to list of sympy Eq objects
    equations = [
        Eq(
            _math_expression_sanitation_and_parse(e.lhs.value),
            _math_expression_sanitation_and_parse(e.rhs.value),
        )
        for e in statement.equations
    ]
    variables = [symbols(v.value) for v in statement.variables]

    if len(equations) == 1 and len(variables) == 1:
        return solve_simple_equation(equations[0], variables[0])
    return solve_multi_equation(equations, variables)


def solve_simple_equation(equation, variable):
    """Solve one equation for one variable with ``solveset``.

    Returns a sentence containing the equation and the solution set.
    """
    result = solveset(equation, variable)
    return "solved %s = %s for %s = %s" % (equation.lhs, equation.rhs, variable, result)


def solve_multi_equation(equations, variables):
    """Solve a system of equations and describe the solution(s) in prose.

    Every solution tuple returned by the solver is reported, joined with
    ``or`` when there is more than one. An empty solution set yields a
    "no solution" sentence rather than an ``IndexError``, and a solution set
    sympy cannot enumerate yields a "could not find" sentence.
    """
    noun = "equation system" if len(equations) > 1 else "equation"
    equations_text = _natural_join([_pretty_equation(e) for e in equations])

    solution = _solve_system(equations, variables)
    try:
        solutions = list(solution)
    except TypeError:
        # The result is not an enumerable set (e.g. a ConditionSet), so there
        # are no explicit solution tuples to report.
        return f"Could not find an explicit solution for {noun} {equations_text}."

    if not solutions:
        return f"The {noun} {equations_text} has no solution."

    described = [
        _natural_join(
            [
                f"{variable}={_format_solution_value(value)}"
                for variable, value in zip(variables, solution_tuple)
            ]
        )
        for solution_tuple in solutions
    ]
    if len(described) == 1:
        solution_text = described[0]
    else:
        solution_text = " or ".join(f"({text})" for text in described)
    return f"Solved {noun} {equations_text} for {solution_text}."


def _solve_system(equations, variables):
    """Return sympy's solution set for the system, using linsolve when linear."""
    if is_linear(equations, variables):
        # parse_expr(evaluate=False) leaves unevaluated trees such as x**1;
        # doit() rebuilds them in canonical form before linsolve inspects them.
        expressions = [(e.lhs - e.rhs).doit() for e in equations]
        try:
            return linsolve(expressions, variables)
        except ValueError:
            # sympy signals input it cannot treat as linear with a ValueError
            # subclass (NonlinearError); nonlinsolve handles the general case.
            pass
    return nonlinsolve(equations, variables)


def _format_solution_value(value) -> str:
    """Render one solution value, with a 3-decimal approximation if possible.

    Integers are shown exactly. Other values get ``=~<approx>`` appended,
    or are replaced by ``~<approx>`` when their exact form is longer than
    20 characters. Values that cannot be formatted as a real float
    (complex numbers, expressions with free symbols, non-numeric objects)
    are shown exactly.
    """
    value_str = str(value)
    if isinstance(value, Integer):
        return value_str
    try:
        approximation = "%.3f" % value.evalf()
    except (TypeError, ValueError, AttributeError):
        return value_str
    if len(value_str) > 20:
        return "~" + approximation
    return value_str + "=~" + approximation


def _natural_join(data: list[str], joiner=", ", last=" and "):
    """Join strings the way prose lists do: ``a, b and c``.

    An empty list gives ``""``; a single element is returned unchanged.
    """
    if len(data) > 1:
        return joiner.join(data[:-1]) + last + data[-1]
    return last.join(data)


def _pretty_equation(simpy_Eq) -> str:
    """Format a sympy ``Eq`` as ``lhs = rhs``."""
    return f"{simpy_Eq.lhs} = {simpy_Eq.rhs}"


def is_linear(equations, variables):
    """Check whether a system of equations is linear in *variables*.

    ``lhs = rhs`` is linear in the variables when every partial derivative
    of ``lhs - rhs`` with respect to them is free of those variables.
    Other symbols may appear as coefficients. For example, ``a*x`` is
    linear in ``x``, but ``x*y`` is not linear in ``(x, y)``. Anything that
    is not a sympy ``Eq`` is reported as non-linear, so the caller falls
    back to the general solver.
    """
    variable_set = set(variables)
    for equation in equations:
        if not isinstance(equation, Eq):
            return False
        expression = equation.lhs - equation.rhs
        for variable in variables:
            derivative = expression.diff(variable)  # partial derivative
            if not derivative.free_symbols.isdisjoint(variable_set):
                return False
    return True


def interpret_expression(statement: ast.Expression) -> str:
    """Evaluate an ``Expression`` statement and return it as an equation chain."""
    return _math_evaluate_expression(statement.value)


def _math_evaluate_expression(expression: str):
    """Evaluate a simple mathematical expression using sympy.

    Returns the parsed term, its evaluated form and its numerical value
    joined by `` = ``. An integer evaluated form is shown without the
    numerical value. An evaluated form equal to the term or to the
    numerical value is omitted.
    """
    term, simple, result = _math_evaluate_internal(expression)
    if isinstance(simple, Integer):
        return _build_equation_pair([term, simple])
    if term == simple or simple == result:
        return _build_equation_pair([term, result])
    return _build_equation_pair([term, simple, result])


def _math_evaluate_internal(expression: str):
    """Parse *expression*; return (unevaluated term, ``doit()`` form, ``evalf()`` form)."""
    term = _math_expression_sanitation_and_parse(expression)
    simple = term.doit()
    numerical = term.evalf()
    return term, simple, numerical


def _math_expression_sanitation_and_parse(expression: str):
    """Parse *expression* with sympy without evaluating it.

    ``^`` is rewritten to ``**`` so it means exponentiation, not Python XOR.
    """
    # NOTE(security): parse_expr is built on eval(); SymPy documents that it
    # must not be used on unsanitized input. This presumably receives chat
    # input, so the ^ replacement above is NOT sufficient sanitation —
    # review before exposing to untrusted users.
    expression = expression.replace("^", "**")
    return parse_expr(expression, evaluate=False)


def _build_equation_pair(expressions: list[object]) -> str:
    """Join the string forms of *expressions* with `` = ``."""
    return " = ".join(str(e) for e in expressions)