Browse Source

refined tool use

master
Florin Tobler 5 months ago
parent
commit
8c00a6c326
  1. 2
      __main__.py
  2. 2
      inference.py
  3. 8
      llama.py
  4. 15
      math_interpreter.py
  5. 6
      tests/test_tool_functions.py
  6. 50
      tool_functions.py

2
__main__.py

@ -1,8 +1,6 @@
print("running __main__.-py")
from llama import main
if __name__ == "__main__":
main()

2
inference.py

@ -8,7 +8,9 @@ import torch
import time
import utils
import re
import os
torch.set_num_threads(os.cpu_count()) # Adjust this to the number of threads/cores you have
class Inference:

8
llama.py

@ -3,6 +3,7 @@ import random
from tool_helper import tool_list, parse_and_execute_tool_call
from tool_functions import register_dummy
from inference import Inference, torch_reseed
import datetime
@ -63,7 +64,8 @@ def main():
inference = Inference()
messages = [{"role": "system", "content": systemmessage + "\n" + inference.generate_tool_use_header(tool_list)}]
current_date_and_time = datetime.datetime.now().strftime("Current date is %Y-%m-%d and its %H:%M %p right now.")
messages = [{"role": "system", "content": systemmessage + "\n" + current_date_and_time + "\n" + inference.generate_tool_use_header(tool_list)}]
while True:
# print an input prompt to receive text or commands
@ -72,8 +74,8 @@ def main():
if input_text.startswith("!"):
# append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
append_generate_chat("%s" % input_text[1:], role="tool") # depending on the chat template the tool response tags must or must not be passed. :(
append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
# append_generate_chat("%s" % input_text[1:], role="tool") # depending on the chat template the tool response tags must or must not be passed. :(
elif input_text.startswith("/clear"):
print("clearing chat history")

15
math_interpreter.py

@ -58,8 +58,21 @@ def solve_multi_equation(equations, variables):
# solutionpairs = [f"{variable}={value.doit()}" for variable, value in zip(variables, list(solution)[0])]
return "solved equation system for " + ", ".join(solutionpairs[:-1]) + " and " + solutionpairs[-1]
# return "solved equation system for " + ", ".join(solutionpairs[:-1]) + " and " + solutionpairs[-1]
if len(equations) > 1:
leadin = "Solved equation system "
else:
leadin = "Solved equation "
return leadin + _natural_join([_pretty_equation(e) for e in equations]) + " for " + _natural_join(solutionpairs) + "."
def _natural_join(data: list[any], joiner=", ", last=" and "):
if len(data) > 1:
return joiner.join(data[:-1]) + last + data[-1]
return last.join(data)
def _pretty_equation(simpy_Eq) -> str:
return f"{simpy_Eq.lhs} = {simpy_Eq.rhs}"

6
tests/test_tool_functions.py

@ -45,13 +45,13 @@ def test_math_solver_2():
def test_math_solver_3a():
result = tool_functions.math_evaluate("solve 2*x + 3*y = 7 and x - y = 1 for x, y")
assert result == "solved equation system for x=2 and y=1"
assert result == "Solved equation system 2*x + 3*y = 7 and x - y = 1 for x=2 and y=1."
def test_math_solver_3b():
result = tool_functions.math_evaluate("solve 2*x + 3*y = 7, x - y = 1 for x and y")
assert result == "solved equation system for x=2 and y=1"
assert result == "Solved equation system 2*x + 3*y = 7 and x - y = 1 for x=2 and y=1."
def test_math_solver_4():
result = tool_functions.math_evaluate("solve 2*x**3 + 3*y = 7 and x - y = 1 for x, y")
assert result == "solved equation system for x=~1.421 and y=~0.421"
assert result == "Solved equation system 2*x**3 + 3*y = 7 and x - y = 1 for x=~1.421 and y=~0.421."

50
tool_functions.py

@ -7,42 +7,37 @@ import math_interpreter
import utils
@tool
def current_time():
"""Get the current local date and time as a string."""
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
# @tool
# def current_time():
# """Get the current local date and time as a string."""
# # return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
# return f"The current local date and time is {datetime.datetime.now().strftime('%Y-%m-%d %H:%M %p')}."
@tool
def random_float():
"""Generate a random float from 0..1."""
return str(random.random())
# @tool
# def random_float(a: float=0.0, b: float=1.0):
# """Generate a random float in range [a, b], including both end points. Optional pass no parameter and range 0..1 will be used.
# Args:
# a: minimum possible value
# b: maximum possible value"""
# return str(random.randint(a, b))
# def random_float():
# """Generate a random float in range 0 to 1."""
# # return str(random.random())
# return f"The freshly generated a random number from 0..1 is: {random.random():.5f}."
@tool
def random_int(a: int, b: int):
"""Generate a random integer in range [a, b], including both end points.
Args:
a: minimum possible value
b: maximum possible value"""
return str(random.randint(a, b))
# @tool
# def random_int(a: int, b: int):
# """Generate a random integer in the range [a, b], including both end points.
# Args:
# a: minimum possible value (must be <= b)
# b: maximum possible value (must be >= a)"""
# # return str(random.randint(a, b))
# return f"A fresh generated random integer between {a} and {b} is {random.randint(a, b)}."
@tool
def math_evaluate(expression: str):
"""evaluate and reduce a mathematical expression.
"""Evaluate and simplify a mathematical expression. Returns the evaluated result or a simplified version of the expression as a string.
Args:
expression: Reduce mathematic expression (without '=') algebraically.
"""
expression: A valid arithmetic expression (e.g., '2 + 3 * 4'). The expression must not contain '='."""
try:
tokens = math_lexer.tokenize(expression)
parser = math_ast.Parser()
@ -55,11 +50,10 @@ Args:
@tool
def math_solve(equations: list[str], variables: list[str]):
"""evaluate a mathematical equation system and solve equation systems.
"""Solve a system of linear or non-linear equation system. Returns the solutions as a string, or an error message if the input is invalid or unsolvable.
Args:
equations: list of mathematical equations containing a '='.
variables: list of variables to solve for. Must be lower or equal the number of given equations.
"""
equations: A list of mathematical equations in the format 'x + y = 2'.
variables: A list of variables to solve for. The number of variables must not exceed the number of equations."""
try:
expression = "solve " + " and ".join(equations) + " for " + " and ".join(variables)
print(expression)

Loading…
Cancel
Save