7 changed files with 409 additions and 17 deletions
@ -0,0 +1,130 @@ |
|||
|
|||
import math_lexer as lexer |
|||
from math_lexer import Token |
|||
|
|||
|
|||
class Statement: |
|||
pass |
|||
|
|||
class Expression(Statement): |
|||
def __init__(self, value: str): |
|||
self.value = value |
|||
|
|||
class Equation: |
|||
def __init__(self, lhs: Expression, rhs: Expression): |
|||
self.lhs = lhs |
|||
self.rhs = rhs |
|||
|
|||
class Solve(Statement): |
|||
def __init__(self, equations: list[Equation], variables: list[Expression]): |
|||
self.equations = equations |
|||
self.variables = variables |
|||
|
|||
|
|||
|
|||
|
|||
class Parser: |
|||
def __init__(self): |
|||
self.tokens: list[Token] # tokens from lexer |
|||
self._last_eaten = None |
|||
|
|||
def not_eof(self) -> bool: |
|||
return self.tokens[0].type is not lexer.END_OF_INPUT |
|||
|
|||
def at(self) -> Token: |
|||
return self.tokens[0] |
|||
|
|||
def at_last(self) -> Token: |
|||
return self._last_eaten |
|||
|
|||
def eat(self) -> Token: |
|||
self._last_eaten = self.tokens.pop(0) |
|||
return self._last_eaten |
|||
|
|||
def backtrack(self): |
|||
if not self._last_eaten: |
|||
raise Exception("Cannot backtrack.") |
|||
self.tokens.insert(0, self._last_eaten) |
|||
self._last_eaten = None |
|||
|
|||
def eat_expect(self, token_type: int | str) -> Token: |
|||
prev = self.eat() |
|||
if prev.type is not token_type: |
|||
raise Exception("expected to consume '%s' but '%s' encountered." % (str(token_type), str(prev.type))) |
|||
return prev |
|||
|
|||
def at_expect(self, token_type: int | str) -> Token: |
|||
prev = self.at() |
|||
if prev.type is not token_type: |
|||
raise Exception("expected to be at '%s' but '%s' encountered." % (str(token_type), str(prev.type))) |
|||
return prev |
|||
|
|||
def parse(self, tokens: list[Token]) -> Statement: |
|||
self.tokens = tokens |
|||
statement = self.parse_statement() |
|||
self.at_expect(lexer.END_OF_INPUT) |
|||
return statement |
|||
|
|||
def parse_statement(self) -> Statement: |
|||
type = self.at().type |
|||
if type is lexer.SOLVE: |
|||
return self.parse_solve() |
|||
return self.parse_expression(merge_commas=True) |
|||
|
|||
def parse_solve(self) -> Solve: |
|||
""" |
|||
solve x = 1 for x |
|||
solve x = y and y = 2 for x and y |
|||
""" |
|||
self.eat_expect(lexer.SOLVE) |
|||
equations = [] # list of equations |
|||
variables = [] # list of variables to solve for |
|||
|
|||
while self.not_eof() and self.at().type is not lexer.FOR: |
|||
equations.append(self.parse_equation()) |
|||
selfattype = self.at().type |
|||
if selfattype is lexer.AND or selfattype is lexer.COMMA: |
|||
self.eat() |
|||
|
|||
self.eat_expect(lexer.FOR) |
|||
|
|||
while self.not_eof(): |
|||
variables.append(self.parse_expression(merge_commas=False)) |
|||
selfattype = self.at().type |
|||
if selfattype is lexer.AND or selfattype is lexer.COMMA: |
|||
self.eat() |
|||
|
|||
return Solve(equations, variables) |
|||
|
|||
def parse_equation(self) -> Equation: |
|||
lhs = self.parse_expression(merge_commas=False) |
|||
self.eat_expect(lexer.EQUALS) |
|||
rhs = self.parse_expression(merge_commas=False) |
|||
return Equation(lhs, rhs) |
|||
|
|||
def parse_expression(self, merge_commas) -> Expression: |
|||
""" |
|||
math expression |
|||
e.g: |
|||
sin(45) / 4 * pi |
|||
""" |
|||
|
|||
if merge_commas == True: |
|||
values = [] |
|||
while self.not_eof(): |
|||
token = self.eat() |
|||
if token.type is lexer.COMMA: |
|||
values.append(lexer.COMMA) |
|||
elif token.type is lexer.EQUALS: |
|||
values.append(lexer.EQUALS) |
|||
else: |
|||
values.append(token.value) |
|||
# token = self.eat_expect(lexer.EXPRESSION) |
|||
# values.append(token.value) |
|||
# if self.at() is lexer.COMMA: |
|||
# token = self.eat() |
|||
# values.append(lexer.COMMA) |
|||
return Expression("".join(values)) |
|||
else: |
|||
token = self.eat_expect(lexer.EXPRESSION) |
|||
return Expression(token.value) |
@ -0,0 +1,108 @@ |
|||
import math_ast as ast |
|||
|
|||
|
|||
from sympy.parsing.sympy_parser import parse_expr |
|||
from sympy.core.numbers import Integer, One, Zero |
|||
from sympy import symbols, Eq, solveset, linsolve, nonlinsolve |
|||
from sympy.core.symbol import Symbol |
|||
|
|||
|
|||
def interpret(statement: ast.Statement) -> str: |
|||
if isinstance(statement, ast.Solve): |
|||
return interpret_solve(statement) |
|||
elif isinstance(statement, ast.Expression): |
|||
return interpret_expression(statement) |
|||
return "interpretation error" |
|||
|
|||
|
|||
def interpret_solve(statement: ast.Solve) -> str: |
|||
eqs = statement.equations |
|||
var = statement.variables |
|||
|
|||
# convert equations to list of sympy Eq objects |
|||
equations = [Eq(_math_expression_sanitation_and_parse(e.lhs.value), _math_expression_sanitation_and_parse(e.rhs.value)) for e in eqs] |
|||
|
|||
variables = [symbols(v.value) for v in var] |
|||
|
|||
if len(equations) == 1 and len(variables) == 1: |
|||
return solve_simple_equation(equations[0], variables[0]) |
|||
else: |
|||
return solve_multi_equation(equations, variables) |
|||
|
|||
|
|||
|
|||
def solve_simple_equation(equation, variable): |
|||
result = solveset(equation, variable) |
|||
return "solved %s = %s for %s = %s" % (equation.lhs, equation.rhs, variable, result) |
|||
|
|||
def solve_multi_equation(equations, variables): |
|||
if is_linear(equations, variables): |
|||
solution = linsolve(equations, variables) |
|||
else: |
|||
solution = nonlinsolve(equations, variables) |
|||
|
|||
solutionpairs = [] |
|||
for variable, value in zip(variables, list(solution)[0]): |
|||
value_str = str(value) |
|||
if not isinstance(value, Integer): |
|||
try: |
|||
float_value = value.evalf() |
|||
if len(value_str) > 20: |
|||
value_str = "~%.3f" % float_value |
|||
else: |
|||
value_str += "=~%.3f" % float_value |
|||
except: |
|||
pass |
|||
|
|||
solutionpairs.append(f"{variable}={value_str}") |
|||
|
|||
# solutionpairs = [f"{variable}={value.doit()}" for variable, value in zip(variables, list(solution)[0])] |
|||
|
|||
return "solved equation system for " + ", ".join(solutionpairs[:-1]) + " and " + solutionpairs[-1] |
|||
|
|||
|
|||
|
|||
|
|||
|
|||
def is_linear(equations, variables): |
|||
return False |
|||
"""Checks if a system of equations is linear.""" |
|||
for eq in equations: |
|||
for var in variables: |
|||
deriv = eq.diff(var) # Partial derivative |
|||
if not (deriv.is_number or (isinstance(deriv, Symbol) and deriv.free_symbols.isdisjoint({var}))): # If the derivative is not a number or a symbol independent of the variable, the system is non-linear |
|||
return False |
|||
return True |
|||
|
|||
|
|||
|
|||
|
|||
def interpret_expression(statement: ast.Expression) -> str: |
|||
return _math_evaluate_expression(statement.value) |
|||
|
|||
|
|||
def _math_evaluate_expression(expression: str): |
|||
"""evaluate a simple mathematical expression using sympy expression evaluation.""" |
|||
therm, simple, result = _math_evaluate_internal(expression) |
|||
if isinstance(simple, Integer): |
|||
return _build_equation_pair([therm, simple]) |
|||
if therm == simple or simple == result: |
|||
return _build_equation_pair([therm, result]) |
|||
return _build_equation_pair([therm, simple, result]) |
|||
|
|||
|
|||
def _math_evaluate_internal(expression: str): |
|||
therm = _math_expression_sanitation_and_parse(expression) |
|||
simple = therm.doit() |
|||
numerical = therm.evalf() |
|||
return therm, simple, numerical |
|||
|
|||
|
|||
def _math_expression_sanitation_and_parse(expression: str): |
|||
expression = expression.replace("^", "**") |
|||
return parse_expr(expression, evaluate=False) |
|||
|
|||
|
|||
def _build_equation_pair(expressions: list[any]) -> str: |
|||
expressions = [str(e) for e in expressions] |
|||
return " = ".join(expressions) |
@ -0,0 +1,61 @@ |
|||
|
|||
|
|||
|
|||
EXPRESSION = 0 |
|||
END_OF_INPUT = 1 |
|||
|
|||
SOLVE = "solve" |
|||
FOR = "for" |
|||
AND = "and" |
|||
EQUALS = "=" |
|||
COMMA = "," |
|||
|
|||
keyword_tokens = [SOLVE, FOR, AND, EQUALS, COMMA] |
|||
|
|||
|
|||
|
|||
class Token: |
|||
def __init__(self, type: int|str, value: str = None): |
|||
self.type = type |
|||
self.value = value |
|||
|
|||
def __repr__(self): |
|||
if self.value == None: |
|||
return f"{self.type}" |
|||
return f"{self.type}|'{self.value}'" |
|||
|
|||
|
|||
def tokenize(expression: str) -> list[Token]: |
|||
""" |
|||
this splits a math instruction into tokens. |
|||
example: |
|||
"solve x + 1 = 5 and y = 2*x for x, y" |
|||
result: |
|||
["solve", "x + 1", "=", "5", "and", "y", "=", "2*x", "for", "x", "and", "y", "end_of_input"] |
|||
""" |
|||
|
|||
tokens = [] # output list of tokens |
|||
|
|||
symbols = expression.replace(",", " , ").replace("=", " = ").split(" ") |
|||
|
|||
current_token = [] # everything that is not directly in math_keyword_tokens gets binned here |
|||
for s in symbols: |
|||
found = False |
|||
|
|||
for keyword in keyword_tokens: |
|||
if s.lower() == keyword: |
|||
if len(current_token) != 0: |
|||
tokens.append(Token(EXPRESSION, " ".join(current_token))) |
|||
current_token = [] |
|||
tokens.append(Token(keyword)) |
|||
found = True |
|||
break |
|||
|
|||
if found == False: |
|||
current_token.append(s) |
|||
if len(current_token) != 0: |
|||
tokens.append(Token(EXPRESSION, " ".join(current_token))) |
|||
current_token = [] |
|||
|
|||
tokens.append(Token(END_OF_INPUT)) |
|||
return tokens |
@ -0,0 +1,57 @@ |
|||
import pytest |
|||
import tool_functions |
|||
|
|||
|
|||
|
|||
def test_math_evaluate_1(): |
|||
result = tool_functions.math_evaluate("1+2*pi") |
|||
assert result == "1 + 2*pi = 7.28318530717959" |
|||
|
|||
def test_math_evaluate_2a(): |
|||
result = tool_functions.math_evaluate("2**4") |
|||
assert result == "2**4 = 16" |
|||
|
|||
def test_math_evaluate_2b(): |
|||
""" test that ^ notation is also working, original sympy cannot do this """ |
|||
result = tool_functions.math_evaluate("2^4") |
|||
assert result == "2**4 = 16" |
|||
|
|||
def test_math_evaluate_3(): |
|||
result = tool_functions.math_evaluate("Integral(exp(-x**2), (x, -oo, oo))") |
|||
assert result == "Integral(exp(-x**2), (x, -oo, oo)) = sqrt(pi) = 1.77245385090552" |
|||
|
|||
def test_math_evaluate_4(): |
|||
result = tool_functions.math_evaluate("(2**x)**2") |
|||
assert result == "(2**x)**2 = 2**(2*x) = 2.0**(2*x)" |
|||
|
|||
def test_math_evaluate_5(): |
|||
result = tool_functions.math_evaluate("sin(pi/2) + cos(0)") |
|||
assert result == "sin(pi/2) + cos(0) = 2" |
|||
|
|||
|
|||
|
|||
|
|||
|
|||
|
|||
|
|||
|
|||
def test_math_solver_1(): |
|||
result = tool_functions.math_evaluate("solve x = 1 for x") |
|||
assert result == "solved x = 1 for x = {1}" |
|||
|
|||
def test_math_solver_2(): |
|||
result = tool_functions.math_evaluate("solve (x + 1)*(x - 1) = 1 for x") |
|||
assert result == "solved (x + 1)*(x - 1*1) = 1 for x = {-sqrt(2), sqrt(2)}" |
|||
|
|||
def test_math_solver_3a(): |
|||
result = tool_functions.math_evaluate("solve 2*x + 3*y = 7 and x - y = 1 for x, y") |
|||
assert result == "solved equation system for x=2 and y=1" |
|||
|
|||
def test_math_solver_3b(): |
|||
result = tool_functions.math_evaluate("solve 2*x + 3*y = 7, x - y = 1 for x and y") |
|||
assert result == "solved equation system for x=2 and y=1" |
|||
|
|||
def test_math_solver_4(): |
|||
result = tool_functions.math_evaluate("solve 2*x**3 + 3*y = 7 and x - y = 1 for x, y") |
|||
assert result == "solved equation system for x=~1.421 and y=~0.421" |
|||
|
Loading…
Reference in new issue