diff --git a/.gitignore b/.gitignore
index cb1f61f..6d27a49 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 /model/*
-*.prof
\ No newline at end of file
+*.prof
+__pycache__
\ No newline at end of file
diff --git a/__main__.py b/__main__.py
new file mode 100644
index 0000000..3e97233
--- /dev/null
+++ b/__main__.py
@@ -0,0 +1,8 @@
+
+print("running __main__.py")
+
+from llama import main
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/llama.py b/llama.py
index 1dc1d3d..a0d4cf7 100644
--- a/llama.py
+++ b/llama.py
@@ -2,8 +2,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import time
 import torch
 import random
-import datetime
-import json
+from tool_helper import tool_list, parse_and_execute_tool_call
+from tool_functions import register_dummy
+import utils

 t_start = time.time()
@@ -41,10 +42,9 @@ print("load took %.3fs" % (time.time() - t_start))

 max_context_length = model.config.max_position_embeddings

-# if tokenizer.chat_template is None:
-print("apply external chat template...")
-with open("chat_template.json", "r") as f:
-    tokenizer.chat_template = json.load(f)
+tokenizer.chat_template = utils.load_json_file("chat_template.json")
+
+
 print("max_context_length is %d tokens." % (max_context_length))
@@ -97,22 +97,9 @@ messages = [

 roleflip = {"role": "system", "content": "Keep the conversation going, ask for more information on the subject. Keep messages short at max 1-2 sentences. Do not thank and say goodbye."}

-def current_time():
-    """Get the current local date and time as a string."""
-    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
-
-def random_float():
-    """Get a random float from 0..1"""
-    return str(random.random())
-
-def random_int(a: int, b: int):
-    """Return random integer in range [a, b], including both end points.
-Args:
-    a: minimum possible value
-    b: maximum possible value"""
-    return str(random.randint(a, b))
-tool_functions = [current_time, random_float, random_int]
+register_dummy()
+# tool_functions = [current_time, random_float, random_int]
@@ -139,31 +126,35 @@ def generate_incremental(inputs):
     generated_tokens = input_ids  # Initially, this is just the input tokens

     n = 0
+    try:
-    # Loop to generate one token at a time
-    while True:
-        # Call the model with the current tokens
-        outputs = model(input_ids=generated_tokens, use_cache=True)
+        # Loop to generate one token at a time
+        while True:
+            # Call the model with the current tokens
+            outputs = model(input_ids=generated_tokens, use_cache=True)

-        # Get the next token (the last token from the generated sequence)
-        next_token = outputs.logits.argmax(dim=-1)[:, -1]
+            # Get the next token (the last token from the generated sequence)
+            next_token = outputs.logits.argmax(dim=-1)[:, -1]

-        # Append the new token to the sequence
-        generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)
+            # Append the new token to the sequence
+            generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)

-        # Decode and print the newly generated token (skip special tokens)
-        out_text = tokenizer.decode(next_token, skip_special_tokens=True)
-        print(out_text, end="", flush=True)  # Print without newline
+            # Decode and print the newly generated token (skip special tokens)
+            out_text = tokenizer.decode(next_token, skip_special_tokens=True)
+            print(out_text, end="", flush=True)  # Print without newline

-        # Check if the generated token is the end-of-sequence token
-        if next_token.item() == tokenizer.eos_token_id:
-            print("")
-            break
+            # Check if the generated token is the end-of-sequence token
+            if next_token.item() == tokenizer.eos_token_id:
+                print("")
+                break
+
+            n += 1
+            if n >= 15:
+                n = 0
+                torch.cuda.empty_cache()

-        n += 1
-        if n >= 30:
-            n = 0
-            torch.cuda.empty_cache()
+    except KeyboardInterrupt:
+        pass

     # Once done, return the full generated sequence
@@ -184,7 +175,7 @@ def append_generate_chat(input_text: str, role="user"):
     # input_text = "Hello, who are you?"
     # inputs = tokenizer(input_text, return_tensors="pt").to("cpu")  # .to("cuda") .to("cpu")
-    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_functions)  # continue_final_message=True,
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_list)  # continue_final_message=True,
     inputs = {key: value.to(model.device) for key, value in inputs.items()}
     # inputs = {key: value.to("cpu") for key, value in inputs.items()}
     # inputs["input_ids"] = inputs["input_ids"][:, 1:]
@@ -194,82 +185,105 @@ def append_generate_chat(input_text: str, role="user"):

     # append result to message history
     messages.append({"role": "assistant", "content": out_text})
+    print("")
     print("generation took %.3fs (%d tokens)" % (time.time() - t_start, len(outputs[0])))

-
-while True:
-    # print an input prompt to receive text or commands
-    input_text = input(">>> ")
-    print("")
+    # handle tool call and check if a tool call has happened.
+    tool_result = parse_and_execute_tool_call(out_text, tool_list)
+    if tool_result is not None:
+        # tool call happened
+        # tool_result = "<tool_response>%s</tool_response>" % tool_result
+        # depending on the chat template, the tool response tags must or must not be passed. :(
+        append_generate_chat(tool_result, role="tool")

-    if input_text.startswith("!"):
-        # append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
-        append_generate_chat("%s" % input_text[1:], role="tool")  # depending on the chat template, the tool response tags must or must not be passed. :(
-    elif input_text.startswith("/clear"):
-        print("clearing chat history")
-        messages = [messages[0]]
+def main():
+    global messages
+    while True:
+        # print an input prompt to receive text or commands
+        input_text = input(">>> ")
         print("")

-    elif input_text.startswith("/history"):
-        history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_functions)
-        print(history)
-    elif input_text.startswith("/undo"):
-        if len(messages) > 2:
-            print("undo latest prompt")
-            messages = messages[:-2]
-        else:
-            print("cannot undo because there are not enough messages on history.")
-        print("")
+        if input_text.startswith("!"):
+            # append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
+            append_generate_chat("%s" % input_text[1:], role="tool")  # depending on the chat template, the tool response tags must or must not be passed. :(

-    elif input_text.startswith("/regen"):
-        if len(messages) >= 2:
-            print("regenerating message (not working)")
-            messages = messages[:-1]
-            seed = random.randint(0, 2**32 - 1)  # Generate a random seed
-            torch.manual_seed(seed)
-            torch.cuda.manual_seed_all(seed)
+        elif input_text.startswith("/clear"):
+            print("clearing chat history")
+            start_msg = messages[0]
+            messages = [start_msg]
+            print("")
+
+        elif input_text.startswith("/history"):
+            history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_list)
+            print(history)
+
+        elif input_text.startswith("/undo"):
+            if len(messages) > 2:
+                print("undo latest prompt")
+                messages = messages[:-2]
+            else:
+                print("cannot undo because there are not enough messages on history.")
+            print("")
+
+        elif input_text.startswith("/regen"):
+            if len(messages) >= 2:
+                print("regenerating message (not working)")
+                messages = messages[:-1]
+                seed = random.randint(0, 2**32 - 1)  # Generate a random seed
+                torch.manual_seed(seed)
+                torch.cuda.manual_seed_all(seed)
+                append_generate_chat(None)
+            else:
+                print("cannot regenerate because there are not enough messages on history.")
+            print("")
+
+        elif input_text.startswith("/more"):
             append_generate_chat(None)
-        else:
-            print("cannot regenerate because there are not enough messages on history.")
-        print("")
-    elif input_text.startswith("/more"):
-        append_generate_chat(None)
-
-    elif input_text.startswith("/auto"):
-        messages_backup = messages
-        messages = [roleflip]
-        for m in messages_backup:
-            role = m["role"]
-            content = m["content"]
-            if role == "user":
-                role = "assistant"
-            elif role == "assistant":
-                role = "user"
-            if role != "system":
-                messages.append({"role": role, "content": content})
-        append_generate_chat(None)  # will automatically advance the conversation as 'user'
-        last_message = messages[-1]
-        last_message["role"] = "user"
-        messages = messages_backup + [last_message]
-        append_generate_chat(None)  # 'regular' chatbot answer
-
-    elif input_text.startswith("/help"):
-        print("!        answer as 'tool' in <tool_response> tags")
-        print("/clear   clear chat history")
-        print("/undo    undo latest prompt")
-        print("/regen   regenerate the last message")
-        print("/more    generate more additional information")
-        print("/auto    automatically advance conversation")
-        print("/help    print this message")
-        print("")
+        elif input_text.startswith("/file"):
+            filename = input_text[len("/file "):]
+            print("reading '%s' as prompt:" % filename)
+            with open(filename, "r") as f:
+                content = f.read()
+            print(content)
+            append_generate_chat(content)
+
+        elif input_text.startswith("/auto"):
+            messages_backup = messages
+            messages = [roleflip]
+            for m in messages_backup:
+                role = m["role"]
+                content = m["content"]
+                if role == "user":
+                    role = "assistant"
+                elif role == "assistant":
+                    role = "user"
+                if role != "system":
+                    messages.append({"role": role, "content": content})
+            append_generate_chat(None)  # will automatically advance the conversation as 'user'
+            last_message = messages[-1]
+            last_message["role"] = "user"
+            messages = messages_backup + [last_message]
+            append_generate_chat(None)  # 'regular' chatbot answer
+
+        elif input_text.startswith("/help"):
+            print("!        answer as 'tool' in <tool_response> tags")
+            print("/clear   clear chat history")
+            print("/undo    undo latest prompt")
+            print("/regen   regenerate the last message")
+            print("/more    generate more additional information")
+            print("/file    read prompt input from file")
+            print("/auto    automatically advance conversation")
+            print("/help    print this message")
+            print("")
+
+        elif input_text.startswith("/"):
+            print("unknown command.")
-    elif input_text.startswith("/"):
-        print("unknown command.")
+        else:
+            append_generate_chat(input_text)
-    else:
-        append_generate_chat(input_text)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..bc63beb
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# empty
\ No newline at end of file
diff --git a/tests/helper.py b/tests/helper.py
new file mode 100644
index 0000000..6fa7e12
--- /dev/null
+++ b/tests/helper.py
@@ -0,0 +1,7 @@
+
+
+def tool_dummy(a: int, b: str):
+    return "result_%d_%s" % (a, b)
+
+def tool_dummy2(text: str):
+    return text.upper()
\ No newline at end of file
diff --git a/tests/test_tool_function_decorator.py b/tests/test_tool_function_decorator.py
new file mode 100644
index 0000000..3405493
--- /dev/null
+++ b/tests/test_tool_function_decorator.py
@@ -0,0 +1,30 @@
+import pytest
+import tool_helper
+import tests.helper as helper
+
+
+def test_tool_function_decorator_if_clean_tool_list():
+    """Tests that the tool list is empty. NOT strictly necessary,
+    but I want to be warned if this is not the case anymore. It might not be intentional."""
+    start_len = len(tool_helper.tool_list)
+    assert start_len == 0
+
+def test_tool_function_decorator():
+    # get length before adding tools
+    start_len = len(tool_helper.tool_list)
+
+    # add tools like it would be a decorator
+    tool_helper.tool(helper.tool_dummy)
+    tool_helper.tool(helper.tool_dummy2)
+
+    # get length after adding tools
+    end_len = len(tool_helper.tool_list)
+
+    # remove the added ones again
+    tool_helper.tool_list = tool_helper.tool_list[:-2]
+
+    assert end_len == start_len + 2
+    assert len(tool_helper.tool_list) == start_len
+
+
+
diff --git a/tests/test_tool_parse_exec.py b/tests/test_tool_parse_exec.py
new file mode 100644
index 0000000..d0caa7b
--- /dev/null
+++ b/tests/test_tool_parse_exec.py
@@ -0,0 +1,89 @@
+import pytest
+import tool_helper
+from unittest import mock
+import tests.helper as helper
+
+
+
+
+
+def test_tool_dummy():
+    with mock.patch("tests.helper.tool_dummy") as mock_dummy:
+        helper.tool_dummy()
+        mock_dummy.assert_called_once()  # check that the mocked function was called
+
+
+def test_tool_parse_no_exec():
+    with mock.patch("tests.helper.tool_dummy") as mock_dummy:
+        tool_helper.parse_and_execute_tool_call("something else", [helper.tool_dummy, helper.tool_dummy2])
+        assert mock_dummy.call_count == 0
+
+
+def test_match_and_extract_no_match():
+    result = tool_helper._match_and_extract("something else", r"<tool_call>(.*)<\/tool_call>")
+    assert result is None
+
+
+def test_match_and_extract_matching():
+    result = tool_helper._match_and_extract("asdfsdfas <tool_call>{json content}</tool_call> adfafsd", r"<tool_call>(.*)<\/tool_call>")
+    assert result == "{json content}"
+
+
+def test_match_and_extract_matching2():
+    result = tool_helper._match_and_extract("<tool_call>{json content}</tool_call>", r"<tool_call>(.*)<\/tool_call>")
+    assert result == "{json content}"
+
+
+def test_match_and_extract_matching3_with_newline():
+    result = tool_helper._match_and_extract("<tool_call>\n{json content}\n</tool_call>", r"<tool_call>(.*)<\/tool_call>")
+    assert result == "\n{json content}\n"
+
+
+def test_string_malformed_faulty():
+    with mock.patch("utils.print_error") as print_error_mock:
+        result = tool_helper._execute_tool_call_str("{json_content}", [])
+        assert result is None
+        print_error_mock.assert_called_once()  # check that the mocked function was called
+
+
+def test_tool_call_json_1():
+    with mock.patch("utils.print_error") as print_error_mock:
+        result = tool_helper._execute_tool_call_json({"name": "tool_dummy", "arguments": {"a": 1, "b": "zwei"}}, [helper.tool_dummy, helper.tool_dummy2])
+        assert result == "result_1_zwei"
+        assert print_error_mock.call_count == 0
+
+
+def test_tool_call_json_2():
+    with mock.patch("utils.print_error") as print_error_mock:
+        result = tool_helper._execute_tool_call_json({"name": "tool_dummy2", "arguments": {"text": "some_text"}}, [helper.tool_dummy, helper.tool_dummy2])
+        assert result == "SOME_TEXT"
+        assert print_error_mock.call_count == 0
+
+
+def test_tool_call_json_non_existing_call_check():
+    with mock.patch("utils.print_error") as print_error_mock:
+        result = tool_helper._execute_tool_call_json({"name": "tool_dummy_which_is_not_existing", "arguments": {"text": "some_text"}}, [helper.tool_dummy, helper.tool_dummy2])
+        assert result is None
+        assert print_error_mock.call_count == 1  # check that the mocked function was called
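+
+
+# Illustrative happy-path sketch (assumes the <tool_call> tag format used by
+# parse_and_execute_tool_call above): a complete tool-call message is parsed
+# and executed end to end.
+def test_tool_parse_and_exec_happy_path():
+    msg = '<tool_call>{"name": "tool_dummy", "arguments": {"a": 1, "b": "x"}}</tool_call>'
+    result = tool_helper.parse_and_execute_tool_call(msg, [helper.tool_dummy, helper.tool_dummy2])
+    assert result == "result_1_x"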
+
+def test_tool_call_json_wrong_arguments_check():
+    with mock.patch("utils.print_error") as print_error_mock:
+        result = tool_helper._execute_tool_call_json({"name": "tool_dummy", "arguments": {"a": "must_be_an_int_but_is_string", "b": "zwei"}}, [helper.tool_dummy, helper.tool_dummy2])
+        assert result is None
+        assert print_error_mock.call_count == 1  # check that the mocked function was called
+
+
+
+def test_regex_multiline():
+    import re
+    pattern = r"<tool_call>(.*)</tool_call>"
+
+    # The text to search (spanning multiple lines)
+    text = """
+    <tool_call>
+    {json}
+    </tool_call>
+    """
+
+    # Use re.search with re.DOTALL to match across newlines
+    match = re.search(pattern, text, re.DOTALL)
+
+    assert match.group(1).find("{json}") != -1
\ No newline at end of file
diff --git a/tool_functions.py b/tool_functions.py
new file mode 100644
index 0000000..2a74e69
--- /dev/null
+++ b/tool_functions.py
@@ -0,0 +1,35 @@
+import random
+import datetime
+from tool_helper import tool
+
+@tool
+def current_time():
+    """Get the current local date and time as a string."""
+    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
+
+# @tool
+# def random_float():
+#     """Generate a random float from 0..1."""
+#     return str(random.random())
+
+@tool
+def random_float(a: float=0.0, b: float=1.0):
+    """Generate a random float in range [a, b], including both end points. Optionally call with no arguments and the default range 0..1 will be used.
+Args:
+    a: minimum possible value
+    b: maximum possible value"""
+    return str(random.uniform(a, b))
+
+@tool
+def random_int(a: int, b: int):
+    """Generate a random integer in range [a, b], including both end points.
+Args:
+    a: minimum possible value
+    b: maximum possible value"""
+    return str(random.randint(a, b))
+
+
+
+
+def register_dummy():
+    pass  # no-op to call so this module is imported and the @tool decorators above have run
\ No newline at end of file
diff --git a/tool_helper.py b/tool_helper.py
new file mode 100644
index 0000000..653915d
--- /dev/null
+++ b/tool_helper.py
@@ -0,0 +1,102 @@
+
+from typing import Any, Callable, List, Optional
+import json
+import re
+import utils
+
+tool_list = []
+
+
+def tool(fn):
+    """tool function decorator"""
+    print("register tool '%s'" % fn.__name__)
+    tool_list.append(fn)
+    return fn
+
+# def parse_and_execute_tool_call(message: str, tools: list[function]) -> str | None:
+#     """execute tool call if needed according to the <tool_call> tag and return the content of the tool call or None if no call happened."""
+#     pass
+
+
+
+
+
+
+def parse_and_execute_tool_call(message: str, tools: List[Callable]) -> Optional[str]:
+    """
+    Execute a tool call if a <tool_call> tag is present and return the tool's response.
+    If no <tool_call> tag is found, return None.
+
+    Args:
+        message (str): The message containing the tool call.
+        tools (list[function]): A list of tool functions available for execution.
+
+    Returns:
+        Optional[str]: The content of the tool response or None if no tool call occurred.
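+
+    Example (illustrative; assumes the random_int tool from tool_functions.py is registered):
+        message = '<tool_call>{"name": "random_int", "arguments": {"a": 3, "b": 3}}</tool_call>'
+        parse_and_execute_tool_call(message, tool_list)  # prints the <tool_response> and returns '3'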
+ """ + + # in case LLM responds with the correct way + extracted = _match_and_extract(message, r"(.*)<\/tool_call>") + if extracted: + return _execute_tool_call_str(extracted, tools) + + # in case LLM responds with by accident + extracted = _match_and_extract(message, r"(.*)<\/tool_.*>") + if extracted: + return _execute_tool_call_str(extracted, tools) + + # in case LLM responds with by accident + extracted = _match_and_extract(message, r"(.*)<\/.*>") + if extracted: + return _execute_tool_call_str(extracted, tools) + + # in case LLM responds with by accident + extracted = _match_and_extract(message, r"(.*)<\/.*>") + if extracted: + return _execute_tool_call_str(extracted, tools) + + return None + + +def _match_and_extract(message: str, pattern: str) -> Optional[str]: + """ helper function to match regex and extract group 1 """ + match = re.search(pattern, message, re.DOTALL) + if match: + group1 = match.group(1) + return group1 + return None + + +def _execute_tool_call_str(tool_call: str, tools: List[Callable]) -> Optional[str]: + """ execute tool call per string content. The content must be a valid json """ + try: + js = json.loads(tool_call) + return _execute_tool_call_json(js, tools) + except json.JSONDecodeError: + utils.print_error("Json was malformed. Will be ignored.") + return None + +def _execute_tool_call_json(data: any, tools: List[Callable]) -> Optional[str]: + """ extract name and arguments from parsed data and call the tool, which is matched from the tools list """ + # Extract tool name and arguments + tool_name = data.get("name") + arguments = data.get("arguments", {}) + + # Find the tool by name in the list of tools + for tool in tools: + if tool.__name__ == tool_name: + # Execute the tool + return _execute_tool_function(arguments, tool) + + utils.print_error("Specified tool '%s' not found." % tool_name) + return None + +def _execute_tool_function(arguments: any, tool: Callable) -> Optional[str]: + """ Execute the tool and return the result. """ + try: + result = tool(**arguments) + print("", result, "") + return result + except TypeError as e: + utils.print_error("Type error while executing function call: '%s'" % str(e)) + + return None diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..4936a6a --- /dev/null +++ b/utils.py @@ -0,0 +1,11 @@ +import json +import sys + +def load_json_file(filepath: str) -> any: + with open(filepath, "r") as f: + return json.load(f) + + + +def print_error(*args, **kwargs): + print(*args, file=sys.stderr, **kwargs) \ No newline at end of file