add mathematical solver engine based on sympy
This commit is contained in:
130
math_ast.py
Normal file
130
math_ast.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
|
||||||
|
import math_lexer as lexer
|
||||||
|
from math_lexer import Token
|
||||||
|
|
||||||
|
|
||||||
|
class Statement:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Expression(Statement):
|
||||||
|
def __init__(self, value: str):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
class Equation:
|
||||||
|
def __init__(self, lhs: Expression, rhs: Expression):
|
||||||
|
self.lhs = lhs
|
||||||
|
self.rhs = rhs
|
||||||
|
|
||||||
|
class Solve(Statement):
|
||||||
|
def __init__(self, equations: list[Equation], variables: list[Expression]):
|
||||||
|
self.equations = equations
|
||||||
|
self.variables = variables
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Parser:
|
||||||
|
def __init__(self):
|
||||||
|
self.tokens: list[Token] # tokens from lexer
|
||||||
|
self._last_eaten = None
|
||||||
|
|
||||||
|
def not_eof(self) -> bool:
|
||||||
|
return self.tokens[0].type is not lexer.END_OF_INPUT
|
||||||
|
|
||||||
|
def at(self) -> Token:
|
||||||
|
return self.tokens[0]
|
||||||
|
|
||||||
|
def at_last(self) -> Token:
|
||||||
|
return self._last_eaten
|
||||||
|
|
||||||
|
def eat(self) -> Token:
|
||||||
|
self._last_eaten = self.tokens.pop(0)
|
||||||
|
return self._last_eaten
|
||||||
|
|
||||||
|
def backtrack(self):
|
||||||
|
if not self._last_eaten:
|
||||||
|
raise Exception("Cannot backtrack.")
|
||||||
|
self.tokens.insert(0, self._last_eaten)
|
||||||
|
self._last_eaten = None
|
||||||
|
|
||||||
|
def eat_expect(self, token_type: int | str) -> Token:
|
||||||
|
prev = self.eat()
|
||||||
|
if prev.type is not token_type:
|
||||||
|
raise Exception("expected to consume '%s' but '%s' encountered." % (str(token_type), str(prev.type)))
|
||||||
|
return prev
|
||||||
|
|
||||||
|
def at_expect(self, token_type: int | str) -> Token:
|
||||||
|
prev = self.at()
|
||||||
|
if prev.type is not token_type:
|
||||||
|
raise Exception("expected to be at '%s' but '%s' encountered." % (str(token_type), str(prev.type)))
|
||||||
|
return prev
|
||||||
|
|
||||||
|
def parse(self, tokens: list[Token]) -> Statement:
|
||||||
|
self.tokens = tokens
|
||||||
|
statement = self.parse_statement()
|
||||||
|
self.at_expect(lexer.END_OF_INPUT)
|
||||||
|
return statement
|
||||||
|
|
||||||
|
def parse_statement(self) -> Statement:
|
||||||
|
type = self.at().type
|
||||||
|
if type is lexer.SOLVE:
|
||||||
|
return self.parse_solve()
|
||||||
|
return self.parse_expression(merge_commas=True)
|
||||||
|
|
||||||
|
def parse_solve(self) -> Solve:
|
||||||
|
"""
|
||||||
|
solve x = 1 for x
|
||||||
|
solve x = y and y = 2 for x and y
|
||||||
|
"""
|
||||||
|
self.eat_expect(lexer.SOLVE)
|
||||||
|
equations = [] # list of equations
|
||||||
|
variables = [] # list of variables to solve for
|
||||||
|
|
||||||
|
while self.not_eof() and self.at().type is not lexer.FOR:
|
||||||
|
equations.append(self.parse_equation())
|
||||||
|
selfattype = self.at().type
|
||||||
|
if selfattype is lexer.AND or selfattype is lexer.COMMA:
|
||||||
|
self.eat()
|
||||||
|
|
||||||
|
self.eat_expect(lexer.FOR)
|
||||||
|
|
||||||
|
while self.not_eof():
|
||||||
|
variables.append(self.parse_expression(merge_commas=False))
|
||||||
|
selfattype = self.at().type
|
||||||
|
if selfattype is lexer.AND or selfattype is lexer.COMMA:
|
||||||
|
self.eat()
|
||||||
|
|
||||||
|
return Solve(equations, variables)
|
||||||
|
|
||||||
|
def parse_equation(self) -> Equation:
|
||||||
|
lhs = self.parse_expression(merge_commas=False)
|
||||||
|
self.eat_expect(lexer.EQUALS)
|
||||||
|
rhs = self.parse_expression(merge_commas=False)
|
||||||
|
return Equation(lhs, rhs)
|
||||||
|
|
||||||
|
def parse_expression(self, merge_commas) -> Expression:
|
||||||
|
"""
|
||||||
|
math expression
|
||||||
|
e.g:
|
||||||
|
sin(45) / 4 * pi
|
||||||
|
"""
|
||||||
|
|
||||||
|
if merge_commas == True:
|
||||||
|
values = []
|
||||||
|
while self.not_eof():
|
||||||
|
token = self.eat()
|
||||||
|
if token.type is lexer.COMMA:
|
||||||
|
values.append(lexer.COMMA)
|
||||||
|
elif token.type is lexer.EQUALS:
|
||||||
|
values.append(lexer.EQUALS)
|
||||||
|
else:
|
||||||
|
values.append(token.value)
|
||||||
|
# token = self.eat_expect(lexer.EXPRESSION)
|
||||||
|
# values.append(token.value)
|
||||||
|
# if self.at() is lexer.COMMA:
|
||||||
|
# token = self.eat()
|
||||||
|
# values.append(lexer.COMMA)
|
||||||
|
return Expression("".join(values))
|
||||||
|
else:
|
||||||
|
token = self.eat_expect(lexer.EXPRESSION)
|
||||||
|
return Expression(token.value)
|
108
math_interpreter.py
Normal file
108
math_interpreter.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
import math_ast as ast
|
||||||
|
|
||||||
|
|
||||||
|
from sympy.parsing.sympy_parser import parse_expr
|
||||||
|
from sympy.core.numbers import Integer, One, Zero
|
||||||
|
from sympy import symbols, Eq, solveset, linsolve, nonlinsolve
|
||||||
|
from sympy.core.symbol import Symbol
|
||||||
|
|
||||||
|
|
||||||
|
def interpret(statement: ast.Statement) -> str:
|
||||||
|
if isinstance(statement, ast.Solve):
|
||||||
|
return interpret_solve(statement)
|
||||||
|
elif isinstance(statement, ast.Expression):
|
||||||
|
return interpret_expression(statement)
|
||||||
|
return "interpretation error"
|
||||||
|
|
||||||
|
|
||||||
|
def interpret_solve(statement: ast.Solve) -> str:
|
||||||
|
eqs = statement.equations
|
||||||
|
var = statement.variables
|
||||||
|
|
||||||
|
# convert equations to list of sympy Eq objects
|
||||||
|
equations = [Eq(_math_expression_sanitation_and_parse(e.lhs.value), _math_expression_sanitation_and_parse(e.rhs.value)) for e in eqs]
|
||||||
|
|
||||||
|
variables = [symbols(v.value) for v in var]
|
||||||
|
|
||||||
|
if len(equations) == 1 and len(variables) == 1:
|
||||||
|
return solve_simple_equation(equations[0], variables[0])
|
||||||
|
else:
|
||||||
|
return solve_multi_equation(equations, variables)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def solve_simple_equation(equation, variable):
|
||||||
|
result = solveset(equation, variable)
|
||||||
|
return "solved %s = %s for %s = %s" % (equation.lhs, equation.rhs, variable, result)
|
||||||
|
|
||||||
|
def solve_multi_equation(equations, variables):
|
||||||
|
if is_linear(equations, variables):
|
||||||
|
solution = linsolve(equations, variables)
|
||||||
|
else:
|
||||||
|
solution = nonlinsolve(equations, variables)
|
||||||
|
|
||||||
|
solutionpairs = []
|
||||||
|
for variable, value in zip(variables, list(solution)[0]):
|
||||||
|
value_str = str(value)
|
||||||
|
if not isinstance(value, Integer):
|
||||||
|
try:
|
||||||
|
float_value = value.evalf()
|
||||||
|
if len(value_str) > 20:
|
||||||
|
value_str = "~%.3f" % float_value
|
||||||
|
else:
|
||||||
|
value_str += "=~%.3f" % float_value
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
solutionpairs.append(f"{variable}={value_str}")
|
||||||
|
|
||||||
|
# solutionpairs = [f"{variable}={value.doit()}" for variable, value in zip(variables, list(solution)[0])]
|
||||||
|
|
||||||
|
return "solved equation system for " + ", ".join(solutionpairs[:-1]) + " and " + solutionpairs[-1]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def is_linear(equations, variables):
|
||||||
|
return False
|
||||||
|
"""Checks if a system of equations is linear."""
|
||||||
|
for eq in equations:
|
||||||
|
for var in variables:
|
||||||
|
deriv = eq.diff(var) # Partial derivative
|
||||||
|
if not (deriv.is_number or (isinstance(deriv, Symbol) and deriv.free_symbols.isdisjoint({var}))): # If the derivative is not a number or a symbol independent of the variable, the system is non-linear
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def interpret_expression(statement: ast.Expression) -> str:
|
||||||
|
return _math_evaluate_expression(statement.value)
|
||||||
|
|
||||||
|
|
||||||
|
def _math_evaluate_expression(expression: str):
|
||||||
|
"""evaluate a simple mathematical expression using sympy expression evaluation."""
|
||||||
|
therm, simple, result = _math_evaluate_internal(expression)
|
||||||
|
if isinstance(simple, Integer):
|
||||||
|
return _build_equation_pair([therm, simple])
|
||||||
|
if therm == simple or simple == result:
|
||||||
|
return _build_equation_pair([therm, result])
|
||||||
|
return _build_equation_pair([therm, simple, result])
|
||||||
|
|
||||||
|
|
||||||
|
def _math_evaluate_internal(expression: str):
|
||||||
|
therm = _math_expression_sanitation_and_parse(expression)
|
||||||
|
simple = therm.doit()
|
||||||
|
numerical = therm.evalf()
|
||||||
|
return therm, simple, numerical
|
||||||
|
|
||||||
|
|
||||||
|
def _math_expression_sanitation_and_parse(expression: str):
|
||||||
|
expression = expression.replace("^", "**")
|
||||||
|
return parse_expr(expression, evaluate=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_equation_pair(expressions: list[any]) -> str:
|
||||||
|
expressions = [str(e) for e in expressions]
|
||||||
|
return " = ".join(expressions)
|
61
math_lexer.py
Normal file
61
math_lexer.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
EXPRESSION = 0
|
||||||
|
END_OF_INPUT = 1
|
||||||
|
|
||||||
|
SOLVE = "solve"
|
||||||
|
FOR = "for"
|
||||||
|
AND = "and"
|
||||||
|
EQUALS = "="
|
||||||
|
COMMA = ","
|
||||||
|
|
||||||
|
keyword_tokens = [SOLVE, FOR, AND, EQUALS, COMMA]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Token:
|
||||||
|
def __init__(self, type: int|str, value: str = None):
|
||||||
|
self.type = type
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
if self.value == None:
|
||||||
|
return f"{self.type}"
|
||||||
|
return f"{self.type}|'{self.value}'"
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize(expression: str) -> list[Token]:
|
||||||
|
"""
|
||||||
|
this splits a math instruction into tokens.
|
||||||
|
example:
|
||||||
|
"solve x + 1 = 5 and y = 2*x for x, y"
|
||||||
|
result:
|
||||||
|
["solve", "x + 1", "=", "5", "and", "y", "=", "2*x", "for", "x", "and", "y", "end_of_input"]
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokens = [] # output list of tokens
|
||||||
|
|
||||||
|
symbols = expression.replace(",", " , ").replace("=", " = ").split(" ")
|
||||||
|
|
||||||
|
current_token = [] # everything that is not directly in math_keyword_tokens gets binned here
|
||||||
|
for s in symbols:
|
||||||
|
found = False
|
||||||
|
|
||||||
|
for keyword in keyword_tokens:
|
||||||
|
if s.lower() == keyword:
|
||||||
|
if len(current_token) != 0:
|
||||||
|
tokens.append(Token(EXPRESSION, " ".join(current_token)))
|
||||||
|
current_token = []
|
||||||
|
tokens.append(Token(keyword))
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if found == False:
|
||||||
|
current_token.append(s)
|
||||||
|
if len(current_token) != 0:
|
||||||
|
tokens.append(Token(EXPRESSION, " ".join(current_token)))
|
||||||
|
current_token = []
|
||||||
|
|
||||||
|
tokens.append(Token(END_OF_INPUT))
|
||||||
|
return tokens
|
@@ -3,19 +3,16 @@ import tool_helper
|
|||||||
import tests.helper as helper
|
import tests.helper as helper
|
||||||
|
|
||||||
|
|
||||||
def test_tool_function_decorator_if_clean_tool_list():
|
|
||||||
""" tests for the tool list to be empty. NOT strictly nessesary,
|
|
||||||
but I want to be warned if this is not the case anymore. Could be not the intention """
|
|
||||||
start_len = len(tool_helper.tool_list)
|
|
||||||
assert start_len == 0
|
|
||||||
|
|
||||||
def test_tool_function_decorator():
|
def test_tool_function_decorator():
|
||||||
# get length before adding tools
|
# get length before adding tools
|
||||||
start_len = len(tool_helper.tool_list)
|
start_len = len(tool_helper.tool_list)
|
||||||
|
|
||||||
# add tools like it would be a decorator
|
# add tools like it would be a decorator
|
||||||
tool_helper.tool(helper.tool_dummy)
|
res = tool_helper.tool(helper.tool_dummy)
|
||||||
tool_helper.tool(helper.tool_dummy2)
|
assert res == helper.tool_dummy # decorator should return the function itself, so it is usable just in case.
|
||||||
|
res = tool_helper.tool(helper.tool_dummy2)
|
||||||
|
assert res == helper.tool_dummy2 # decorator should return the function itself, so it is usable just in case.
|
||||||
|
|
||||||
# get length after adding tools
|
# get length after adding tools
|
||||||
end_len = len(tool_helper.tool_list)
|
end_len = len(tool_helper.tool_list)
|
||||||
@@ -28,3 +25,4 @@ def test_tool_function_decorator():
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
57
tests/test_tool_functions.py
Normal file
57
tests/test_tool_functions.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import pytest
|
||||||
|
import tool_functions
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_math_evaluate_1():
|
||||||
|
result = tool_functions.math_evaluate("1+2*pi")
|
||||||
|
assert result == "1 + 2*pi = 7.28318530717959"
|
||||||
|
|
||||||
|
def test_math_evaluate_2a():
|
||||||
|
result = tool_functions.math_evaluate("2**4")
|
||||||
|
assert result == "2**4 = 16"
|
||||||
|
|
||||||
|
def test_math_evaluate_2b():
|
||||||
|
""" test that ^ notation is also working, original sympy cannot do this """
|
||||||
|
result = tool_functions.math_evaluate("2^4")
|
||||||
|
assert result == "2**4 = 16"
|
||||||
|
|
||||||
|
def test_math_evaluate_3():
|
||||||
|
result = tool_functions.math_evaluate("Integral(exp(-x**2), (x, -oo, oo))")
|
||||||
|
assert result == "Integral(exp(-x**2), (x, -oo, oo)) = sqrt(pi) = 1.77245385090552"
|
||||||
|
|
||||||
|
def test_math_evaluate_4():
|
||||||
|
result = tool_functions.math_evaluate("(2**x)**2")
|
||||||
|
assert result == "(2**x)**2 = 2**(2*x) = 2.0**(2*x)"
|
||||||
|
|
||||||
|
def test_math_evaluate_5():
|
||||||
|
result = tool_functions.math_evaluate("sin(pi/2) + cos(0)")
|
||||||
|
assert result == "sin(pi/2) + cos(0) = 2"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_math_solver_1():
|
||||||
|
result = tool_functions.math_evaluate("solve x = 1 for x")
|
||||||
|
assert result == "solved x = 1 for x = {1}"
|
||||||
|
|
||||||
|
def test_math_solver_2():
|
||||||
|
result = tool_functions.math_evaluate("solve (x + 1)*(x - 1) = 1 for x")
|
||||||
|
assert result == "solved (x + 1)*(x - 1*1) = 1 for x = {-sqrt(2), sqrt(2)}"
|
||||||
|
|
||||||
|
def test_math_solver_3a():
|
||||||
|
result = tool_functions.math_evaluate("solve 2*x + 3*y = 7 and x - y = 1 for x, y")
|
||||||
|
assert result == "solved equation system for x=2 and y=1"
|
||||||
|
|
||||||
|
def test_math_solver_3b():
|
||||||
|
result = tool_functions.math_evaluate("solve 2*x + 3*y = 7, x - y = 1 for x and y")
|
||||||
|
assert result == "solved equation system for x=2 and y=1"
|
||||||
|
|
||||||
|
def test_math_solver_4():
|
||||||
|
result = tool_functions.math_evaluate("solve 2*x**3 + 3*y = 7 and x - y = 1 for x, y")
|
||||||
|
assert result == "solved equation system for x=~1.421 and y=~0.421"
|
||||||
|
|
@@ -1,24 +1,29 @@
|
|||||||
import random
|
import random
|
||||||
import datetime
|
import datetime
|
||||||
from tool_helper import tool
|
from tool_helper import tool
|
||||||
|
import math_lexer
|
||||||
|
import math_ast
|
||||||
|
import math_interpreter
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def current_time():
|
def current_time():
|
||||||
"""Get the current local date and time as a string."""
|
"""Get the current local date and time as a string."""
|
||||||
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
|
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||||
|
|
||||||
# @tool
|
|
||||||
# def random_float():
|
|
||||||
# """Generate a random float from 0..1."""
|
|
||||||
# return str(random.random())
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def random_float(a: float=0.0, b: float=1.0):
|
def random_float():
|
||||||
"""Generate a random float in range [a, b], including both end points. Optional pass no parameter and range 0..1 will be used.
|
"""Generate a random float from 0..1."""
|
||||||
Args:
|
return str(random.random())
|
||||||
a: minimum possible value
|
|
||||||
b: maximum possible value"""
|
# @tool
|
||||||
return str(random.randint(a, b))
|
# def random_float(a: float=0.0, b: float=1.0):
|
||||||
|
# """Generate a random float in range [a, b], including both end points. Optional pass no parameter and range 0..1 will be used.
|
||||||
|
# Args:
|
||||||
|
# a: minimum possible value
|
||||||
|
# b: maximum possible value"""
|
||||||
|
# return str(random.randint(a, b))
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def random_int(a: int, b: int):
|
def random_int(a: int, b: int):
|
||||||
@@ -31,5 +36,37 @@ Args:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def math_evaluate(expression: str):
|
||||||
|
"""evaluate and reduce a mathematical expression.
|
||||||
|
Args:
|
||||||
|
expression: Reduce mathematic expression (without '=') algebraically..
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokens = math_lexer.tokenize(expression)
|
||||||
|
parser = math_ast.Parser()
|
||||||
|
ast = parser.parse(tokens)
|
||||||
|
return math_interpreter.interpret(ast)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def math_solve(equations: list[str], variables: list[str]):
|
||||||
|
"""evaluate a mathematical equation system and solve equation systems. Can be used to solve (x + 1)*(x - 1) = 1 for x as an example.
|
||||||
|
Args:
|
||||||
|
equations: list of mathematical equations containing a '='.
|
||||||
|
variables: list of variables to solve for. Must be lower or equal the number of given equations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
expression = "solve " + " and ".join(equations) + " for " + " and ".join(variables)
|
||||||
|
print(expression)
|
||||||
|
|
||||||
|
tokens = math_lexer.tokenize(expression)
|
||||||
|
parser = math_ast.Parser()
|
||||||
|
ast = parser.parse(tokens)
|
||||||
|
return math_interpreter.interpret(ast)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def register_dummy():
|
def register_dummy():
|
||||||
pass # dummy function to run and be sure the decorators have run
|
pass # dummy function to run and be sure the decorators have run
|
@@ -11,6 +11,7 @@ def tool(fn):
|
|||||||
"""tool function decorator"""
|
"""tool function decorator"""
|
||||||
print("register tool '%s'" % fn.__name__)
|
print("register tool '%s'" % fn.__name__)
|
||||||
tool_list.append(fn)
|
tool_list.append(fn)
|
||||||
|
return fn
|
||||||
|
|
||||||
# def parse_and_execute_tool_call(message: str, tools: list[function]) -> str | None:
|
# def parse_and_execute_tool_call(message: str, tools: list[function]) -> str | None:
|
||||||
# """execute tool call if needed accordint <tool_call> tag and return the content of the tool call or None if no call happened."""
|
# """execute tool call if needed accordint <tool_call> tag and return the content of the tool call or None if no call happened."""
|
||||||
|
Reference in New Issue
Block a user