
tool functions and pytest

master · Florin Tobler · 6 months ago · parent commit fd7e3d5235
Changed files:

 1. .gitignore (3)
 2. __main__.py (8)
 3. llama.py (226)
 4. tests/__init__.py (1)
 5. tests/helper.py (7)
 6. tests/test_tool_function_decorator.py (30)
 7. tests/test_tool_parse_exec.py (89)
 8. tool_functions.py (35)
 9. tool_helper.py (102)
10. utils.py (11)

.gitignore (3)

@@ -1,2 +1,3 @@
 /model/*
 *.prof
+__pycache__

__main__.py (8)

@@ -0,0 +1,8 @@
print("running __main__.py")

from llama import main

if __name__ == "__main__":
    main()

llama.py (226)

@@ -2,8 +2,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import time
 import torch
 import random
-import datetime
-import json
+from tool_helper import tool_list, parse_and_execute_tool_call
+from tool_functions import register_dummy
+import utils

 t_start = time.time()

@@ -41,10 +42,9 @@ print("load took %.3fs" % (time.time() - t_start))
 max_context_length = model.config.max_position_embeddings

-# if tokenizer.chat_template is None:
-print("apply external chat template...")
-with open("chat_template.json", "r") as f:
-    tokenizer.chat_template = json.load(f)
+tokenizer.chat_template = utils.load_json_file("chat_template.json")

 print("max_context_length is %d tokens." % (max_context_length))
@@ -97,22 +97,9 @@ messages = [
 roleflip = {"role": "system", "content": "Keep the conversation going, ask for more information on the subject. Keep messages short at max 1-2 sentences. Do not thank and say goodbye."}

-def current_time():
-    """Get the current local date and time as a string."""
-    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
-
-def random_float():
-    """Get a random float from 0..1"""
-    return str(random.random())
-
-def random_int(a: int, b: int):
-    """Return random integer in range [a, b], including both end points.
-    Args:
-        a: minimum possible value
-        b: maximum possible value"""
-    return str(random.randint(a, b))
-
-tool_functions = [current_time, random_float, random_int]
+register_dummy()
+# tool_functions = [current_time, random_float, random_int]
@@ -139,31 +126,35 @@ def generate_incremental(inputs):
     generated_tokens = input_ids  # Initially, this is just the input tokens
     n = 0
+    try:
         # Loop to generate one token at a time
         while True:
             # Call the model with the current tokens
             outputs = model(input_ids=generated_tokens, use_cache=True)

             # Get the next token (the last token from the generated sequence)
             next_token = outputs.logits.argmax(dim=-1)[:, -1]

             # Append the new token to the sequence
             generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)

             # Decode and print the newly generated token (skip special tokens)
             out_text = tokenizer.decode(next_token, skip_special_tokens=True)
             print(out_text, end="", flush=True)  # Print without newline

             # Check if the generated token is the end-of-sequence token
             if next_token.item() == tokenizer.eos_token_id:
                 print("")
                 break

-            n += 1
-            if n >= 30:
-                n = 0
-                torch.cuda.empty_cache()
+            n += 1
+            if n >= 15:
+                n = 0
+                torch.cuda.empty_cache()
+    except KeyboardInterrupt:
+        pass

     # Once done, return the full generated sequence
@@ -184,7 +175,7 @@ def append_generate_chat(input_text: str, role="user"):
     # input_text = "Hello, who are you?"
     # inputs = tokenizer(input_text, return_tensors="pt").to("cpu")  # .to("cuda") .to("cpu")
-    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_functions)  # continue_final_message=True,
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_list)  # continue_final_message=True,
     inputs = {key: value.to(model.device) for key, value in inputs.items()}
     # inputs = {key: value.to("cpu") for key, value in inputs.items()}
     # inputs["input_ids"] = inputs["input_ids"][:, 1:]
@@ -194,82 +185,105 @@ def append_generate_chat(input_text: str, role="user"):
     # append result to message history
     messages.append({"role": "assistant", "content": out_text})
     print("")
     print("generation took %.3fs (%d tokens)" % (time.time() - t_start, len(outputs[0])))

+    # handle tool call and check if a tool call has happened.
+    tool_result = parse_and_execute_tool_call(out_text, tool_list)
+    if tool_result != None:
+        # tool call happened
+        # tool_result = "<tool_response>%s</tool_response>" % tool_result
+        # depending on the chat template the tool response tags must or must not be passed. :(
+        append_generate_chat(tool_result, role="tool")

+def main():
+    global messages
     while True:
         # print an input prompt to receive text or commands
         input_text = input(">>> ")
         print("")

         if input_text.startswith("!"):
             # append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
             append_generate_chat("%s" % input_text[1:], role="tool")  # depending on the chat template the tool response tags must or must not be passed. :(

         elif input_text.startswith("/clear"):
             print("clearing chat history")
-            messages = [messages[0]]
+            start_msg = messages[0]
+            messages = [start_msg]
             print("")

         elif input_text.startswith("/history"):
-            history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_functions)
+            history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_list)
             print(history)

         elif input_text.startswith("/undo"):
             if len(messages) > 2:
                 print("undo latest prompt")
                 messages = messages[:-2]
             else:
                 print("cannot undo because there are not enough messages on history.")
             print("")

         elif input_text.startswith("/regen"):
             if len(messages) >= 2:
                 print("regenerating message (not working)")
                 messages = messages[:-1]
                 seed = random.randint(0, 2**32 - 1)  # Generate a random seed
                 torch.manual_seed(seed)
                 torch.cuda.manual_seed_all(seed)
                 append_generate_chat(None)
             else:
                 print("cannot regenerate because there are not enough messages on history.")
             print("")

         elif input_text.startswith("/more"):
             append_generate_chat(None)

+        elif input_text.startswith("/file"):
+            filename = input_text[len("/file "):]
+            print("read '%s' for prompt:" % filename)
+            with open(filename, "r") as f:
+                content = f.read()
+            print(content)
+            append_generate_chat(content)

         elif input_text.startswith("/auto"):
             messages_backup = messages
             messages = [roleflip]
             for m in messages_backup:
                 role = m["role"]
                 content = m["content"]
                 if role == "user":
                     role = "assistant"
                 elif role == "assistant":
                     role = "user"
                 if role != "system":
                     messages.append({"role": role, "content": content})
             append_generate_chat(None)  # will automatically advance the conversation as 'user'
             last_message = messages[-1]
             last_message["role"] = "user"
             messages = messages_backup + [last_message]
             append_generate_chat(None)  # 'regular' chatbot answer

         elif input_text.startswith("/help"):
             print("!<prompt> answer as 'tool' in <tool_response> tags")
             print("/clear clear chat history")
             print("/undo undo latest prompt")
             print("/regen regenerate the last message")
             print("/more generate more additional information")
+            print("/file read prompt input from file")
             print("/auto automatically advance conversation")
             print("/help print this message")
             print("")

         elif input_text.startswith("/"):
             print("unknown command.")
         else:
             append_generate_chat(input_text)
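For context, the new tool-call handling in append_generate_chat can be exercised in isolation. A minimal sketch follows (hypothetical driver code, not part of the commit; it assumes the model emitted the JSON-in-`<tool_call>`-tags form that tool_helper parses):

    # hypothetical driver: parse an assistant message and execute its tool call
    from tool_helper import tool_list, parse_and_execute_tool_call
    from tool_functions import register_dummy

    register_dummy()  # importing tool_functions registered its @tool functions

    out_text = '<tool_call>{"name": "random_int", "arguments": {"a": 1, "b": 6}}</tool_call>'
    tool_result = parse_and_execute_tool_call(out_text, tool_list)
    if tool_result is not None:
        # llama.py appends this with role="tool" and generates the follow-up reply
        print("tool returned:", tool_result)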

tests/__init__.py (1)

@@ -0,0 +1 @@
# empty

tests/helper.py (7)

@@ -0,0 +1,7 @@
def tool_dummy(a: int, b: str):
    return "result_%d_%s" % (a, b)

def tool_dummy2(text: str):
    return text.upper()

tests/test_tool_function_decorator.py (30)

@@ -0,0 +1,30 @@
import pytest
import tool_helper
import tests.helper as helper

def test_tool_function_decorator_if_clean_tool_list():
    """tests for the tool list to be empty. NOT strictly necessary,
    but I want to be warned if this is not the case anymore; it could be unintentional"""
    start_len = len(tool_helper.tool_list)
    assert start_len == 0

def test_tool_function_decorator():
    # get length before adding tools
    start_len = len(tool_helper.tool_list)

    # add tools like a decorator application would
    tool_helper.tool(helper.tool_dummy)
    tool_helper.tool(helper.tool_dummy2)

    # get length after adding tools
    end_len = len(tool_helper.tool_list)

    # remove the added ones again
    tool_helper.tool_list = tool_helper.tool_list[:-2]

    assert end_len == start_len + 2
    assert len(tool_helper.tool_list) == start_len

tests/test_tool_parse_exec.py (89)

@@ -0,0 +1,89 @@
import pytest
import tool_helper
from unittest import mock
import tests.helper as helper

def test_tool_dummy():
    with mock.patch("tests.helper.tool_dummy") as mock_dummy:
        helper.tool_dummy()
        mock_dummy.assert_called_once()  # this will check if the mocked function on the context was called.

def test_tool_parse_no_exec():
    with mock.patch("tests.helper.tool_dummy") as mock_dummy:
        tool_helper.parse_and_execute_tool_call("something else", [helper.tool_dummy, helper.tool_dummy2])
        assert mock_dummy.call_count == 0

def test_match_and_extract_no_match():
    result = tool_helper._match_and_extract("something else", r"<tool_call>(.*)<\/tool_call>")
    assert result == None

def test_match_and_extract_matching():
    result = tool_helper._match_and_extract("asdfsdfas <tool_call>{json content}</tool_call> adfafsd", r"<tool_call>(.*)<\/tool_call>")
    assert result == "{json content}"

def test_match_and_extract_matching2():
    result = tool_helper._match_and_extract("<tool_call>{json content}</tool_call>", r"<tool_call>(.*)<\/tool_call>")
    assert result == "{json content}"

def test_match_and_extract_matching3_with_newline():
    result = tool_helper._match_and_extract("<tool_call>\n{json content}\n</tool_call>", r"<tool_call>(.*)<\/tool_call>")
    assert result == "\n{json content}\n"

def test_string_malformed_faulty():
    with mock.patch("utils.print_error") as print_error_mock:
        result = tool_helper._execute_tool_call_str("{json_content}", [])
        assert result == None
        print_error_mock.assert_called_once()  # this will check if the mocked function on the context was called.

def test_tool_call_json_1():
    with mock.patch("utils.print_error") as print_error_mock:
        result = tool_helper._execute_tool_call_json({"name": "tool_dummy", "arguments": {"a": 1, "b": "zwei"}}, [helper.tool_dummy, helper.tool_dummy2])
        assert result == "result_1_zwei"
        assert print_error_mock.call_count == 0

def test_tool_call_json_2():
    with mock.patch("utils.print_error") as print_error_mock:
        result = tool_helper._execute_tool_call_json({"name": "tool_dummy2", "arguments": {"text": "some_text"}}, [helper.tool_dummy, helper.tool_dummy2])
        assert result == "SOME_TEXT"
        assert print_error_mock.call_count == 0

def test_tool_call_json_non_existing_call_check():
    with mock.patch("utils.print_error") as print_error_mock:
        result = tool_helper._execute_tool_call_json({"name": "tool_dummy_which_is_not_existing", "arguments": {"text": "some_text"}}, [helper.tool_dummy, helper.tool_dummy2])
        assert result == None
        assert print_error_mock.call_count == 1  # this will check if the mocked function on the context was called.

def test_tool_call_json_wrong_arguments_check():
    with mock.patch("utils.print_error") as print_error_mock:
        result = tool_helper._execute_tool_call_json({"name": "tool_dummy", "arguments": {"a": "must_be_an_int_but_is_string", "b": "zwei"}}, [helper.tool_dummy, helper.tool_dummy2])
        assert result == None
        assert print_error_mock.call_count == 1  # this will check if the mocked function on the context was called.

def test_regex_multiline():
    import re
    pattern = r"<start>(.*)</end>"

    # The text to search (spanning multiple lines)
    text = """<start>
{json}
</end>"""

    # Use re.search with re.DOTALL to match across newlines
    match = re.search(pattern, text, re.DOTALL)
    assert match.group(1).find("{json}") != -1
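Note: both test modules import tool_helper and utils by their top-level names, so the suite is presumably meant to be run from the repository root, e.g. with `python -m pytest tests` (running via `python -m` puts the current directory on sys.path, and the `# empty` tests/__init__.py makes the test modules importable as the tests package).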

tool_functions.py (35)

@@ -0,0 +1,35 @@
import random
import datetime
from tool_helper import tool

@tool
def current_time():
    """Get the current local date and time as a string."""
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")

# @tool
# def random_float():
#     """Generate a random float from 0..1."""
#     return str(random.random())

@tool
def random_float(a: float = 0.0, b: float = 1.0):
    """Generate a random float in range [a, b], including both end points. If called without arguments, the range 0..1 is used.
    Args:
        a: minimum possible value
        b: maximum possible value"""
    return str(random.uniform(a, b))

@tool
def random_int(a: int, b: int):
    """Generate a random integer in range [a, b], including both end points.
    Args:
        a: minimum possible value
        b: maximum possible value"""
    return str(random.randint(a, b))

def register_dummy():
    pass  # dummy function to call, to be sure the module is imported and the decorators have run
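A note on register_dummy: the @tool decorators register each function as a side effect of importing this module, and the no-op call gives llama.py an explicit call site so the import is visibly used. A minimal sketch of that registration pattern (generic names, not from the repo):

    # sketch: decorators populate a registry at import time
    registry = []

    def register(fn):
        registry.append(fn)  # record the function
        return fn            # hand it back unchanged

    @register
    def example():
        pass

    # by the time the module is imported, the registry is already populated;
    # a no-op like register_dummy() merely keeps the import from looking unused.
    assert registry == [example]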

tool_helper.py (102)

@@ -0,0 +1,102 @@
from typing import Any, Callable, List, Optional
import json
import re
import utils

tool_list = []

def tool(fn):
    """tool function decorator"""
    print("register tool '%s'" % fn.__name__)
    tool_list.append(fn)
    return fn  # return the function so the decorated name stays bound to it

# def parse_and_execute_tool_call(message: str, tools: list[function]) -> str | None:
#     """execute tool call if needed according to the <tool_call> tag and return the content of the tool call or None if no call happened."""
#     pass

def parse_and_execute_tool_call(message: str, tools: List[Callable]) -> Optional[str]:
    """
    Execute a tool call if the <tool_call> tag is present and return the tool's response.
    If no <tool_call> tag is found, return None.

    Args:
        message (str): The message containing the tool call.
        tools (list[function]): A list of tool functions available for execution.

    Returns:
        Optional[str]: The content of the tool response or None if no tool call occurred.
    """
    # in case the LLM responds with <tool_call></tool_call> the correct way
    extracted = _match_and_extract(message, r"<tool_call>(.*)<\/tool_call>")
    if extracted:
        return _execute_tool_call_str(extracted, tools)

    # in case the LLM responds with <tool_call></tool_response> by accident
    extracted = _match_and_extract(message, r"<tool_call>(.*)<\/tool_.*>")
    if extracted:
        return _execute_tool_call_str(extracted, tools)

    # in case the LLM responds with <tool_call></???> by accident
    extracted = _match_and_extract(message, r"<tool_call>(.*)<\/.*>")
    if extracted:
        return _execute_tool_call_str(extracted, tools)

    # in case the LLM responds with <tool_response></???> by accident
    extracted = _match_and_extract(message, r"<tool_response>(.*)<\/.*>")
    if extracted:
        return _execute_tool_call_str(extracted, tools)

    return None

def _match_and_extract(message: str, pattern: str) -> Optional[str]:
    """helper function to match a regex and extract group 1"""
    match = re.search(pattern, message, re.DOTALL)
    if match:
        return match.group(1)
    return None

def _execute_tool_call_str(tool_call: str, tools: List[Callable]) -> Optional[str]:
    """execute a tool call given as string content. The content must be valid JSON"""
    try:
        js = json.loads(tool_call)
        return _execute_tool_call_json(js, tools)
    except json.JSONDecodeError:
        utils.print_error("JSON was malformed. Will be ignored.")
        return None

def _execute_tool_call_json(data: Any, tools: List[Callable]) -> Optional[str]:
    """extract name and arguments from parsed data and call the tool, which is matched from the tools list"""
    # Extract tool name and arguments
    tool_name = data.get("name")
    arguments = data.get("arguments", {})

    # Find the tool by name in the list of tools
    for tool_fn in tools:
        if tool_fn.__name__ == tool_name:
            # Execute the tool
            return _execute_tool_function(arguments, tool_fn)

    utils.print_error("Specified tool '%s' not found." % tool_name)
    return None

def _execute_tool_function(arguments: Any, tool_fn: Callable) -> Optional[str]:
    """Execute the tool and return the result."""
    try:
        result = tool_fn(**arguments)
        print("<tool_response>", result, "</tool_response>")
        return result
    except TypeError as e:
        utils.print_error("Type error while executing function call: '%s'" % str(e))
        return None
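Taken together, the decorator and the parser form a small end-to-end API, including the fallback patterns that tolerate mismatched closing tags. A usage sketch (the greet tool is hypothetical, introduced only for illustration):

    from tool_helper import tool, tool_list, parse_and_execute_tool_call

    @tool
    def greet(name: str):
        """Say hello to someone."""
        return "hello %s" % name

    # the canonical form matches the first pattern...
    assert parse_and_execute_tool_call(
        '<tool_call>{"name": "greet", "arguments": {"name": "world"}}</tool_call>',
        tool_list) == "hello world"

    # ...and a mismatched closing tag is still caught by a fallback pattern.
    assert parse_and_execute_tool_call(
        '<tool_call>{"name": "greet", "arguments": {"name": "world"}}</tool_response>',
        tool_list) == "hello world"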

utils.py (11)

@@ -0,0 +1,11 @@
from typing import Any
import json
import sys

def load_json_file(filepath: str) -> Any:
    with open(filepath, "r") as f:
        return json.load(f)

def print_error(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)