From 02b3850a9f24eb5278f73b0482f558ac1bb4d5d1 Mon Sep 17 00:00:00 2001 From: Florin Tobler Date: Thu, 2 Jan 2025 01:26:28 +0100 Subject: [PATCH] tool evaluations --- llama.py | 21 +++++++++++++++++++-- tool_functions.py | 5 +++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/llama.py b/llama.py index a0d4cf7..5e72830 100644 --- a/llama.py +++ b/llama.py @@ -5,6 +5,7 @@ import random from tool_helper import tool_list, parse_and_execute_tool_call from tool_functions import register_dummy import utils +import re t_start = time.time() @@ -94,6 +95,8 @@ messages = [ # {"role": "user", "content": "Hello, who are you?"} ] +systemmessage = "Hold a casual conversation with the user. Keep responses short at max 3 sentences." + roleflip = {"role": "system", "content": "Keep the conversation going, ask for more information on the subject. Keep messages short at max 1-2 sentences. Do not thank and say goodbye."} @@ -175,7 +178,7 @@ def append_generate_chat(input_text: str, role="user"): # input_text = "Hello, who are you?" # inputs = tokenizer(input_text, return_tensors="pt").to("cpu") # .to("cuda") .to("cpu") - inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_list) #continue_final_message=True, + inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True) #continue_final_message=True, inputs = {key: value.to(model.device) for key, value in inputs.items()} # inputs = {key: value.to("cpu") for key, value in inputs.items()} # inputs["input_ids"] = inputs["input_ids"][:, 1:] @@ -199,8 +202,22 @@ def append_generate_chat(input_text: str, role="user"): +def generate_tool_use_header(tools: list[callable]) -> str: + temp_messages = [{}] # for some reason an empty array is not allowed but a {} inside works like an empty array. 
+ s = tokenizer.apply_chat_template(temp_messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tools) + pattern = r"<\|im_start\|>system\n(.*)<\|im_end\|>" + match = re.search(pattern, s, re.DOTALL) + if not match: + raise Exception("Failed to regex match the template tool system text.") + extraction = match.group(1) + return extraction + + def main(): global messages + + messages = [{"role": "system", "content": systemmessage + "\n" + generate_tool_use_header(tool_list)}] + while True: # print an input prompt to receive text or commands input_text = input(">>> ") @@ -218,7 +235,7 @@ def main(): print("") elif input_text.startswith("/history"): - history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_list) + history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False) print(history) elif input_text.startswith("/undo"): diff --git a/tool_functions.py b/tool_functions.py index 1eb353b..cac6c54 100644 --- a/tool_functions.py +++ b/tool_functions.py @@ -40,7 +40,7 @@ Args: def math_evaluate(expression: str): """evaluate and reduce a mathematical expression. Args: - expression: Reduce mathematic expression (without '=') algebraically.. + expression: Reduce mathematical expression (without '=') algebraically. """ tokens = math_lexer.tokenize(expression) @@ -69,4 +69,5 @@ Args: def register_dummy(): - pass # dummy function to run and be sure the decorators have run \ No newline at end of file + """dummy function to run and be sure the decorators have been initialized""" + pass \ No newline at end of file