class ParserError(Exception):
    """Raised when the token stream does not form a valid statement.

    Subclasses ``Exception`` so existing ``except Exception`` handlers keep
    working.
    """


class Statement:
    """Base class of the top-level nodes returned by ``Parser.parse``."""


class Expression(Statement):
    """A math expression kept as raw text, e.g. ``"sin(45) / 4 * pi"``."""

    def __init__(self, value: str):
        self.value = value


class Equation:
    """Two expressions related by ``=``; used as a member of ``Solve``."""

    def __init__(self, lhs: Expression, rhs: Expression):
        self.lhs = lhs
        self.rhs = rhs


class Solve(Statement):
    """A ``solve <equations> for <variables>`` statement."""

    def __init__(self, equations: list[Equation], variables: list[Expression]):
        self.equations = equations
        self.variables = variables


class Parser:
    """Turn the token list produced by the lexer into a ``Statement``.

    Token types are compared by value (``==``), not by identity, so the parser
    does not depend on the lexer handing out the very same constant objects.
    """

    def __init__(self):
        self.tokens: list[Token] = []  # remaining tokens, front is the next one
        self._last_eaten: Token | None = None

    def not_eof(self) -> bool:
        """Return True while there is a token left that is not END_OF_INPUT."""
        return bool(self.tokens) and self.tokens[0].type != lexer.END_OF_INPUT

    def at(self) -> Token:
        """Return the next token without consuming it."""
        if not self.tokens:
            raise ParserError("unexpected end of input.")
        return self.tokens[0]

    def at_last(self) -> Token:
        """Return the most recently consumed token (None if there is none)."""
        return self._last_eaten

    def eat(self) -> Token:
        """Consume and return the next token."""
        if not self.tokens:
            raise ParserError("unexpected end of input.")
        self._last_eaten = self.tokens.pop(0)
        return self._last_eaten

    def backtrack(self):
        """Push the last consumed token back; only one step can be undone."""
        if self._last_eaten is None:
            raise ParserError("Cannot backtrack.")
        self.tokens.insert(0, self._last_eaten)
        self._last_eaten = None

    def eat_expect(self, token_type: int | str) -> Token:
        """Consume the next token and raise unless it has type ``token_type``."""
        token = self.eat()
        if token.type != token_type:
            raise ParserError(
                "expected to consume '%s' but '%s' encountered."
                % (str(token_type), str(token.type))
            )
        return token

    def at_expect(self, token_type: int | str) -> Token:
        """Peek at the next token and raise unless it has type ``token_type``."""
        token = self.at()
        if token.type != token_type:
            raise ParserError(
                "expected to be at '%s' but '%s' encountered."
                % (str(token_type), str(token.type))
            )
        return token

    def parse(self, tokens: list[Token]) -> Statement:
        """Parse a complete token list and return the resulting statement.

        Args:
            tokens: lexer output; its last token must be END_OF_INPUT.

        Raises:
            ParserError: if the tokens do not form a valid statement.
        """
        # Work on a copy so the caller's list is not consumed while parsing.
        self.tokens = list(tokens)
        self._last_eaten = None
        statement = self.parse_statement()
        self.at_expect(lexer.END_OF_INPUT)
        return statement

    def parse_statement(self) -> Statement:
        """Dispatch on the first token: ``solve ...`` or a plain expression."""
        if self.at().type == lexer.SOLVE:
            return self.parse_solve()
        return self.parse_expression(merge_commas=True)

    def _skip_separator(self):
        """Consume one ``and`` / ``,`` separator if it is the next token."""
        if self.tokens and self.tokens[0].type in (lexer.AND, lexer.COMMA):
            self.eat()

    def parse_solve(self) -> Solve:
        """
        solve x = 1 for x
        solve x = y and y = 2 for x and y
        """
        self.eat_expect(lexer.SOLVE)

        equations = []  # list of equations
        while self.not_eof() and self.at().type != lexer.FOR:
            equations.append(self.parse_equation())
            self._skip_separator()
        if not equations:
            raise ParserError("expected at least one equation after 'solve'.")

        self.eat_expect(lexer.FOR)

        variables = []  # list of variables to solve for
        while self.not_eof():
            variables.append(self.parse_expression(merge_commas=False))
            self._skip_separator()
        if not variables:
            raise ParserError("expected at least one variable after 'for'.")

        return Solve(equations, variables)

    def parse_equation(self) -> Equation:
        """Parse ``<expression> = <expression>``."""
        lhs = self.parse_expression(merge_commas=False)
        self.eat_expect(lexer.EQUALS)
        rhs = self.parse_expression(merge_commas=False)
        return Equation(lhs, rhs)

    def parse_expression(self, merge_commas) -> Expression:
        """
        math expression
        e.g:
            sin(45) / 4 * pi

        With ``merge_commas`` set, everything up to the end of input is glued
        back into one expression; the ``,`` and ``=`` the lexer split off are
        re-inserted as text. Otherwise exactly one EXPRESSION token is taken.
        """
        if not merge_commas:
            return Expression(self.eat_expect(lexer.EXPRESSION).value)

        parts = []
        while self.not_eof():
            token = self.eat()
            if token.type in (lexer.COMMA, lexer.EQUALS):
                parts.append(token.type)
            elif token.type == lexer.EXPRESSION:
                parts.append(token.value)
            else:
                # Keyword tokens carry no text (value is None); joining them
                # used to fail with an unrelated TypeError.
                raise ParserError(
                    "unexpected keyword '%s' inside expression." % str(token.type)
                )
        return Expression("".join(parts))
def interpret(statement: "ast.Statement") -> str:
    """Evaluate a parsed statement and return the result as text.

    ``Solve`` nodes go to ``interpret_solve``, ``Expression`` nodes to
    ``interpret_expression``; anything else yields "interpretation error".
    """
    if isinstance(statement, ast.Solve):
        return interpret_solve(statement)
    if isinstance(statement, ast.Expression):
        return interpret_expression(statement)
    return "interpretation error"


def interpret_solve(statement: "ast.Solve") -> str:
    """Convert a ``Solve`` node to sympy objects and solve it."""
    # convert equations to list of sympy Eq objects
    equations = [
        Eq(
            _math_expression_sanitation_and_parse(equation.lhs.value),
            _math_expression_sanitation_and_parse(equation.rhs.value),
        )
        for equation in statement.equations
    ]
    variables = [symbols(variable.value) for variable in statement.variables]

    if len(equations) == 1 and len(variables) == 1:
        return solve_simple_equation(equations[0], variables[0])
    return solve_multi_equation(equations, variables)


def solve_simple_equation(equation, variable):
    """Solve one equation for one variable with ``solveset``."""
    result = solveset(equation, variable)
    return "solved %s = %s for %s = %s" % (equation.lhs, equation.rhs, variable, result)


def format_solution_pairs(pairs):
    """Join ``"x=1"`` style strings as ``"a, b and c"``.

    A single pair is returned as is; an empty list gives an empty string.
    """
    if len(pairs) <= 1:
        return "".join(pairs)
    return ", ".join(pairs[:-1]) + " and " + pairs[-1]


def _format_solution_value(value):
    """Render one solution value, adding a 3-digit approximation if possible."""
    text = str(value)
    if isinstance(value, Integer):
        return text
    try:
        approximation = "~%.3f" % value.evalf()
    except (TypeError, ValueError, ArithmeticError):
        # Value has no real float form (e.g. still symbolic or complex):
        # keep the exact text only.
        return text
    # Long exact forms are replaced by the approximation, short ones keep both.
    if len(text) > 20:
        return approximation
    return text + "=" + approximation


def solve_multi_equation(equations, variables):
    """Solve an equation system and report the first solution tuple.

    Returns a "solved equation system for ..." string, or a message when the
    solution set is empty or cannot be iterated.
    """
    if is_linear(equations, variables):
        solution = linsolve(equations, variables)
    else:
        solution = nonlinsolve(equations, variables)

    try:
        # Only the first solution is reported; taking it lazily avoids
        # materialising the whole set.
        first_solution = next(iter(solution))
    except StopIteration:
        return "no solution found for equation system"
    except TypeError:
        return "solution of equation system cannot be enumerated: %s" % solution

    pairs = [
        f"{variable}={_format_solution_value(value)}"
        for variable, value in zip(variables, first_solution)
    ]
    return "solved equation system for " + format_solution_pairs(pairs)


def is_linear(equations, variables):
    """Report whether the system is linear; currently always returns False.

    The linearity check is disabled, so ``solve_multi_equation`` always takes
    the ``nonlinsolve`` branch. The unreachable derivative-based check that
    used to follow the ``return`` was removed.
    """
    return False


def interpret_expression(statement: "ast.Expression") -> str:
    """Evaluate a plain expression node."""
    return _math_evaluate_expression(statement.value)


def _math_evaluate_expression(expression: str):
    """evaluate a simple mathematical expression using sympy expression evaluation."""
    term, simple, result = _math_evaluate_internal(expression)
    if isinstance(simple, Integer):
        return _build_equation_pair([term, simple])
    if term == simple or simple == result:
        # one of the intermediate forms adds nothing, show input and number only
        return _build_equation_pair([term, result])
    return _build_equation_pair([term, simple, result])


def _math_evaluate_internal(expression: str):
    """Return (parsed term, ``doit()`` result, ``evalf()`` result)."""
    term = _math_expression_sanitation_and_parse(expression)
    simple = term.doit()
    numerical = term.evalf()
    return term, simple, numerical


def _math_expression_sanitation_and_parse(expression: str):
    """Parse text into an unevaluated sympy expression; ``^`` means power."""
    expression = expression.replace("^", "**")
    return parse_expr(expression, evaluate=False)


def _build_equation_pair(expressions: list[object]) -> str:
    """Join the string forms of ``expressions`` with " = "."""
    return " = ".join(str(expression) for expression in expressions)
+FOR = "for" +AND = "and" +EQUALS = "=" +COMMA = "," + +keyword_tokens = [SOLVE, FOR, AND, EQUALS, COMMA] + + + +class Token: + def __init__(self, type: int|str, value: str = None): + self.type = type + self.value = value + + def __repr__(self): + if self.value == None: + return f"{self.type}" + return f"{self.type}|'{self.value}'" + + +def tokenize(expression: str) -> list[Token]: + """ + this splits a math instruction into tokens. + example: + "solve x + 1 = 5 and y = 2*x for x, y" + result: + ["solve", "x + 1", "=", "5", "and", "y", "=", "2*x", "for", "x", "and", "y", "end_of_input"] + """ + + tokens = [] # output list of tokens + + symbols = expression.replace(",", " , ").replace("=", " = ").split(" ") + + current_token = [] # everything that is not directly in math_keyword_tokens gets binned here + for s in symbols: + found = False + + for keyword in keyword_tokens: + if s.lower() == keyword: + if len(current_token) != 0: + tokens.append(Token(EXPRESSION, " ".join(current_token))) + current_token = [] + tokens.append(Token(keyword)) + found = True + break + + if found == False: + current_token.append(s) + if len(current_token) != 0: + tokens.append(Token(EXPRESSION, " ".join(current_token))) + current_token = [] + + tokens.append(Token(END_OF_INPUT)) + return tokens \ No newline at end of file diff --git a/tests/test_tool_function_decorator.py b/tests/test_tool_function_decorator.py index 3405493..732ba63 100644 --- a/tests/test_tool_function_decorator.py +++ b/tests/test_tool_function_decorator.py @@ -3,19 +3,16 @@ import tool_helper import tests.helper as helper -def test_tool_function_decorator_if_clean_tool_list(): - """ tests for the tool list to be empty. NOT strictly nessesary, - but I want to be warned if this is not the case anymore. 
def test_math_evaluate_1():
    """Expression with pi is reported with its numeric value."""
    expected = "1 + 2*pi = 7.28318530717959"
    assert tool_functions.math_evaluate("1+2*pi") == expected


def test_math_evaluate_2a():
    """Power written with ** is reduced to an integer."""
    expected = "2**4 = 16"
    assert tool_functions.math_evaluate("2**4") == expected


def test_math_evaluate_2b():
    """Power written with ^ gives the same output as **."""
    expected = "2**4 = 16"
    assert tool_functions.math_evaluate("2^4") == expected


def test_math_evaluate_3():
    """An integral is reported as input, closed form and numeric value."""
    expected = "Integral(exp(-x**2), (x, -oo, oo)) = sqrt(pi) = 1.77245385090552"
    assert tool_functions.math_evaluate("Integral(exp(-x**2), (x, -oo, oo))") == expected


def test_math_evaluate_4():
    """An expression with a free symbol is reported in three forms."""
    expected = "(2**x)**2 = 2**(2*x) = 2.0**(2*x)"
    assert tool_functions.math_evaluate("(2**x)**2") == expected


def test_math_evaluate_5():
    """Trigonometric terms that reduce to an integer."""
    expected = "sin(pi/2) + cos(0) = 2"
    assert tool_functions.math_evaluate("sin(pi/2) + cos(0)") == expected


def test_math_solver_1():
    """Single equation, single variable."""
    expected = "solved x = 1 for x = {1}"
    assert tool_functions.math_evaluate("solve x = 1 for x") == expected


def test_math_solver_2():
    """Quadratic equation with two roots."""
    expected = "solved (x + 1)*(x - 1*1) = 1 for x = {-sqrt(2), sqrt(2)}"
    assert tool_functions.math_evaluate("solve (x + 1)*(x - 1) = 1 for x") == expected


def test_math_solver_3a():
    """Two equations joined by 'and', variables separated by a comma."""
    expected = "solved equation system for x=2 and y=1"
    statement = "solve 2*x + 3*y = 7 and x - y = 1 for x, y"
    assert tool_functions.math_evaluate(statement) == expected


def test_math_solver_3b():
    """Two equations separated by a comma, variables joined by 'and'."""
    expected = "solved equation system for x=2 and y=1"
    statement = "solve 2*x + 3*y = 7, x - y = 1 for x and y"
    assert tool_functions.math_evaluate(statement) == expected


def test_math_solver_4():
    """System with a cubic term; values are reported as approximations."""
    expected = "solved equation system for x=~1.421 and y=~0.421"
    statement = "solve 2*x**3 + 3*y = 7 and x - y = 1 for x, y"
    assert tool_functions.math_evaluate(statement) == expected
def _run_math_statement(statement: str):
    """Tokenize, parse and interpret one statement; return the interpreter result."""
    tokens = math_lexer.tokenize(statement)
    syntax_tree = math_ast.Parser().parse(tokens)
    return math_interpreter.interpret(syntax_tree)


# NOTE(review): the docstrings of the two tools below are kept verbatim; they
# may be consumed by the tool registry as the tool description - confirm
# before rewording them.
@tool
def math_evaluate(expression: str):
    """evaluate and reduce a mathematical expression.
Args:
    expression: Reduce mathematic expression (without '=') algebraically..
    """
    return _run_math_statement(expression)


@tool
def math_solve(equations: list[str], variables: list[str]):
    """evaluate a mathematical equation system and solve equation systems. Can be used to solve (x + 1)*(x - 1) = 1 for x as an example.
Args:
    equations: list of mathematical equations containing a '='.
    variables: list of variables to solve for. Must be lower or equal the number of given equations.
    """
    # A bare string would otherwise be joined character by character.
    if isinstance(equations, str):
        equations = [equations]
    if isinstance(variables, str):
        variables = [variables]

    if not equations:
        raise ValueError("math_solve needs at least one equation.")
    if not variables:
        raise ValueError("math_solve needs at least one variable.")

    # Build the textual form the parser understands:
    #   solve <eq> and <eq> ... for <var> and <var> ...
    statement = "solve " + " and ".join(equations) + " for " + " and ".join(variables)
    return _run_math_statement(statement)