diff --git a/.gitignore b/.gitignore
index cb1f61f..6d27a49 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
/model/*
-*.prof
\ No newline at end of file
+*.prof
+__pycache__
\ No newline at end of file
diff --git a/__main__.py b/__main__.py
new file mode 100644
index 0000000..3e97233
--- /dev/null
+++ b/__main__.py
@@ -0,0 +1,8 @@
+
+print("running __main__.py")
+
+from llama import main
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
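Note: with this entry point the chat client can be started by pointing Python at the checkout directory (e.g. `python .` from the repository root), which executes `__main__.py` and hands control to `llama.main()`.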
diff --git a/llama.py b/llama.py
index 1dc1d3d..a0d4cf7 100644
--- a/llama.py
+++ b/llama.py
@@ -2,8 +2,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import time
import torch
import random
-import datetime
-import json
+from tool_helper import tool_list, parse_and_execute_tool_call
+from tool_functions import register_dummy
+import utils
t_start = time.time()
@@ -41,10 +42,9 @@ print("load took %.3fs" % (time.time() - t_start))
max_context_length = model.config.max_position_embeddings
-# if tokenizer.chat_template is None:
-print("apply external chat template...")
-with open("chat_template.json", "r") as f:
-    tokenizer.chat_template = json.load(f)
+tokenizer.chat_template = utils.load_json_file("chat_template.json")
+
+
print("max_context_length is %d tokens." % (max_context_length))
@@ -97,22 +97,9 @@ messages = [
roleflip = {"role": "system", "content": "Keep the conversation going, ask for more information on the subject. Keep messages short at max 1-2 sentences. Do not thank and say goodbye."}
-def current_time():
-    """Get the current local date and time as a string."""
-    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
-
-def random_float():
-    """Get a random float from 0..1"""
-    return str(random.random())
-
-def random_int(a: int, b: int):
-    """Return random integer in range [a, b], including both end points.
-Args:
-    a: minimum possible value
-    b: maximum possible value"""
-    return str(random.randint(a, b))
-tool_functions = [current_time, random_float, random_int]
+register_dummy()
+# tool_functions = [current_time, random_float, random_int]
@@ -139,31 +126,35 @@ def generate_incremental(inputs):
generated_tokens = input_ids # Initially, this is just the input tokens
n = 0
+    try:
-    # Loop to generate one token at a time
-    while True:
-        # Call the model with the current tokens
-        outputs = model(input_ids=generated_tokens, use_cache=True)
+        # Loop to generate one token at a time
+        while True:
+            # Call the model with the current tokens
+            outputs = model(input_ids=generated_tokens, use_cache=True)
-        # Get the next token (the last token from the generated sequence)
-        next_token = outputs.logits.argmax(dim=-1)[:, -1]
+            # Get the next token (the last token from the generated sequence)
+            next_token = outputs.logits.argmax(dim=-1)[:, -1]
-        # Append the new token to the sequence
-        generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)
+            # Append the new token to the sequence
+            generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)
-        # Decode and print the newly generated token (skip special tokens)
-        out_text = tokenizer.decode(next_token, skip_special_tokens=True)
-        print(out_text, end="", flush=True)  # Print without newline
+            # Decode and print the newly generated token (skip special tokens)
+            out_text = tokenizer.decode(next_token, skip_special_tokens=True)
+            print(out_text, end="", flush=True)  # Print without newline
-        # Check if the generated token is the end-of-sequence token
-        if next_token.item() == tokenizer.eos_token_id:
-            print("")
-            break
+            # Check if the generated token is the end-of-sequence token
+            if next_token.item() == tokenizer.eos_token_id:
+                print("")
+                break
+
+            n += 1
+            if n >= 15:
+                n = 0
+                torch.cuda.empty_cache()
-        n += 1
-        if n >= 30:
-            n = 0
-            torch.cuda.empty_cache()
+    except KeyboardInterrupt:
+        pass
# Once done, return the full generated sequence
@@ -184,7 +175,7 @@ def append_generate_chat(input_text: str, role="user"):
# input_text = "Hello, who are you?"
# inputs = tokenizer(input_text, return_tensors="pt").to("cpu") # .to("cuda") .to("cpu")
-    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_functions) #continue_final_message=True,
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_list) #continue_final_message=True,
inputs = {key: value.to(model.device) for key, value in inputs.items()}
# inputs = {key: value.to("cpu") for key, value in inputs.items()}
# inputs["input_ids"] = inputs["input_ids"][:, 1:]
@@ -194,82 +185,105 @@ def append_generate_chat(input_text: str, role="user"):
# append result to message history
messages.append({"role": "assistant", "content": out_text})
+
print("")
print("generation took %.3fs (%d tokens)" % (time.time() - t_start, len(outputs[0])))
-
-while True:
-    # print an input prompt to receive text or commands
-    input_text = input(">>> ")
-    print("")
+    # handle a tool call if one has happened.
+    tool_result = parse_and_execute_tool_call(out_text, tool_list)
+    if tool_result is not None:
+        # a tool call happened
+        # tool_result = "<tool_response>%s</tool_response>" % tool_result
+        # depending on the chat template, the tool response tags may or may not need to be passed. :(
+        append_generate_chat(tool_result, role="tool")
-    if input_text.startswith("!"):
-        # append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
-        append_generate_chat("%s" % input_text[1:], role="tool") # depending on the chat template, the tool response tags may or may not need to be passed. :(
-    elif input_text.startswith("/clear"):
-        print("clearing chat history")
-        messages = [messages[0]]
+def main():
+    global messages
+    while True:
+        # print an input prompt to receive text or commands
+        input_text = input(">>> ")
+        print("")
-    elif input_text.startswith("/history"):
-        history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_functions)
-        print(history)
-    elif input_text.startswith("/undo"):
-        if len(messages) > 2:
-            print("undo latest prompt")
-            messages = messages[:-2]
-        else:
-            print("cannot undo because there are not enough messages on history.")
-        print("")
+        if input_text.startswith("!"):
+            # append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
+            append_generate_chat("%s" % input_text[1:], role="tool") # depending on the chat template, the tool response tags may or may not need to be passed. :(
-    elif input_text.startswith("/regen"):
-        if len(messages) >= 2:
-            print("regenerating message (not working)")
-            messages = messages[:-1]
-            seed = random.randint(0, 2**32 - 1)  # Generate a random seed
-            torch.manual_seed(seed)
-            torch.cuda.manual_seed_all(seed)
+        elif input_text.startswith("/clear"):
+            print("clearing chat history")
+            start_msg = messages[0]
+            messages = [start_msg]
+            print("")
+
+        elif input_text.startswith("/history"):
+            history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_list)
+            print(history)
+
+        elif input_text.startswith("/undo"):
+            if len(messages) > 2:
+                print("undo latest prompt")
+                messages = messages[:-2]
+            else:
+                print("cannot undo because there are not enough messages on history.")
+            print("")
+
+        elif input_text.startswith("/regen"):
+            if len(messages) >= 2:
+                print("regenerating message (not working)")
+                messages = messages[:-1]
+                seed = random.randint(0, 2**32 - 1)  # Generate a random seed
+                torch.manual_seed(seed)
+                torch.cuda.manual_seed_all(seed)
+                append_generate_chat(None)
+            else:
+                print("cannot regenerate because there are not enough messages on history.")
+            print("")
+
+        elif input_text.startswith("/more"):
+            append_generate_chat(None)
-        else:
-            print("cannot regenerate because there are not enough messages on history.")
-        print("")
-    elif input_text.startswith("/more"):
-        append_generate_chat(None)
-
-    elif input_text.startswith("/auto"):
-        messages_backup = messages
-        messages = [roleflip]
-        for m in messages_backup:
-            role = m["role"]
-            content = m["content"]
-            if role == "user":
-                role = "assistant"
-            elif role == "assistant":
-                role = "user"
-            if role != "system":
-                messages.append({"role": role, "content": content})
-        append_generate_chat(None)  # will automatically advance the conversation as 'user'
-        last_message = messages[-1]
-        last_message["role"] = "user"
-        messages = messages_backup + [last_message]
-        append_generate_chat(None)  # 'regular' chatbot answer
-
-    elif input_text.startswith("/help"):
-        print("!         answer as 'tool' in <tool_response> tags")
-        print("/clear    clear chat history")
-        print("/undo     undo latest prompt")
-        print("/regen    regenerate the last message")
-        print("/more     generate more additional information")
-        print("/auto     automatically advance conversation")
-        print("/help     print this message")
-        print("")
+        elif input_text.startswith("/file"):
+            filename = input_text[len("/file "):]
+            print("read '%s' for prompt:" % filename)
+            with open(filename, "r") as f:
+                content = f.read()
+            print(content)
+            append_generate_chat(content)
+
+        elif input_text.startswith("/auto"):
+            messages_backup = messages
+            messages = [roleflip]
+            for m in messages_backup:
+                role = m["role"]
+                content = m["content"]
+                if role == "user":
+                    role = "assistant"
+                elif role == "assistant":
+                    role = "user"
+                if role != "system":
+                    messages.append({"role": role, "content": content})
+            append_generate_chat(None)  # will automatically advance the conversation as 'user'
+            last_message = messages[-1]
+            last_message["role"] = "user"
+            messages = messages_backup + [last_message]
+            append_generate_chat(None)  # 'regular' chatbot answer
+
+        elif input_text.startswith("/help"):
+            print("!         answer as 'tool' in <tool_response> tags")
+            print("/clear    clear chat history")
+            print("/undo     undo latest prompt")
+            print("/regen    regenerate the last message")
+            print("/more     generate more additional information")
+            print("/file     read prompt input from file")
+            print("/auto     automatically advance conversation")
+            print("/help     print this message")
+            print("")
+
+        elif input_text.startswith("/"):
+            print("unknown command.")
-    elif input_text.startswith("/"):
-        print("unknown command.")
+        else:
+            append_generate_chat(input_text)
-    else:
-        append_generate_chat(input_text)
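Side note on the restructured loop: it still re-feeds the full `generated_tokens` sequence on every step, so `use_cache=True` allocates a KV cache that is never reused and each token pays for re-encoding the whole prefix. A minimal sketch of a cache-reusing variant (illustrative only, not part of this diff; it assumes the `model`/`tokenizer` globals from llama.py):

```python
import torch

def generate_incremental_cached(input_ids):
    """Greedy decoding that feeds only the newest token once a KV cache exists."""
    past_key_values = None
    next_input = input_ids  # full prompt on the first step only
    generated = input_ids
    with torch.no_grad():
        while True:
            outputs = model(input_ids=next_input, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values  # reuse the cache on the next step
            next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            print(tokenizer.decode(next_token[0], skip_special_tokens=True), end="", flush=True)
            if next_token.item() == tokenizer.eos_token_id:
                print("")
                break
            next_input = next_token  # subsequent steps only see the new token
    return generated
```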
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..bc63beb
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# empty
\ No newline at end of file
diff --git a/tests/helper.py b/tests/helper.py
new file mode 100644
index 0000000..6fa7e12
--- /dev/null
+++ b/tests/helper.py
@@ -0,0 +1,7 @@
+
+
+def tool_dummy(a: int, b: str):
+    return "result_%d_%s" % (a, b)
+
+def tool_dummy2(text: str):
+    return text.upper()
\ No newline at end of file
diff --git a/tests/test_tool_function_decorator.py b/tests/test_tool_function_decorator.py
new file mode 100644
index 0000000..3405493
--- /dev/null
+++ b/tests/test_tool_function_decorator.py
@@ -0,0 +1,30 @@
+import pytest
+import tool_helper
+import tests.helper as helper
+
+
+def test_tool_function_decorator_if_clean_tool_list():
+    """Tests that the tool list starts out empty. NOT strictly necessary,
+    but I want to be warned if this is no longer the case, since it might not be intentional."""
+    start_len = len(tool_helper.tool_list)
+    assert start_len == 0
+
+def test_tool_function_decorator():
+    # get length before adding tools
+    start_len = len(tool_helper.tool_list)
+
+    # add tools like the decorator would
+    tool_helper.tool(helper.tool_dummy)
+    tool_helper.tool(helper.tool_dummy2)
+
+    # get length after adding tools
+    end_len = len(tool_helper.tool_list)
+
+    # remove the added ones again
+    tool_helper.tool_list = tool_helper.tool_list[:-2]
+
+    assert end_len == start_len + 2
+    assert len(tool_helper.tool_list) == start_len
+
+
+
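Because this test mutates the module-global `tool_helper.tool_list` and only restores it by slicing after the assertions, a failure in between would leak the dummy tools into other tests. A possible fixture-based alternative (a sketch, not part of the diff) isolates the list via pytest's `monkeypatch`:

```python
import tool_helper
import tests.helper as helper

def test_tool_decorator_isolated(monkeypatch):
    # monkeypatch restores the original tool_list even if an assert fails
    monkeypatch.setattr(tool_helper, "tool_list", [])
    tool_helper.tool(helper.tool_dummy)
    tool_helper.tool(helper.tool_dummy2)
    assert len(tool_helper.tool_list) == 2
```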
diff --git a/tests/test_tool_parse_exec.py b/tests/test_tool_parse_exec.py
new file mode 100644
index 0000000..d0caa7b
--- /dev/null
+++ b/tests/test_tool_parse_exec.py
@@ -0,0 +1,89 @@
+import pytest
+import tool_helper
+from unittest import mock
+import tests.helper as helper
+
+
+
+
+
+def test_tool_dummy():
+    with mock.patch("tests.helper.tool_dummy") as mock_dummy:
+        helper.tool_dummy()
+        mock_dummy.assert_called_once()  # check that the mocked function was called
+
+
+def test_tool_parse_no_exec():
+    with mock.patch("tests.helper.tool_dummy") as mock_dummy:
+        tool_helper.parse_and_execute_tool_call("something else", [helper.tool_dummy, helper.tool_dummy2])
+        assert mock_dummy.call_count == 0
+
+
+def test_match_and_extract_no_match():
+    result = tool_helper._match_and_extract("something else", r"<tool_call>(.*)<\/tool_call>")
+    assert result is None
+
+
+def test_match_and_extract_matching():
+    result = tool_helper._match_and_extract("asdfsdfas <tool_call>{json content}</tool_call> adfafsd", r"<tool_call>(.*)<\/tool_call>")
+    assert result == "{json content}"
+
+
+def test_match_and_extract_matching2():
+    result = tool_helper._match_and_extract("<tool_call>{json content}</tool_call>", r"<tool_call>(.*)<\/tool_call>")
+    assert result == "{json content}"
+
+
+def test_match_and_extract_matching3_with_newline():
+    result = tool_helper._match_and_extract("<tool_call>\n{json content}\n</tool_call>", r"<tool_call>(.*)<\/tool_call>")
+    assert result == "\n{json content}\n"
+
+
+def test_string_malformed_faulty():
+    with mock.patch("utils.print_error") as print_error_mock:
+        result = tool_helper._execute_tool_call_str("{json_content}", [])
+        assert result is None
+        print_error_mock.assert_called_once()  # the malformed json should be reported via print_error
+
+
+def test_tool_call_json_1():
+    with mock.patch("utils.print_error") as print_error_mock:
+        result = tool_helper._execute_tool_call_json({"name": "tool_dummy", "arguments": {"a": 1, "b": "zwei"}}, [helper.tool_dummy, helper.tool_dummy2])
+        assert result == "result_1_zwei"
+        assert print_error_mock.call_count == 0
+
+
+def test_tool_call_json_2():
+    with mock.patch("utils.print_error") as print_error_mock:
+        result = tool_helper._execute_tool_call_json({"name": "tool_dummy2", "arguments": {"text": "some_text"}}, [helper.tool_dummy, helper.tool_dummy2])
+        assert result == "SOME_TEXT"
+        assert print_error_mock.call_count == 0
+
+
+def test_tool_call_json_non_existing_call_check():
+    with mock.patch("utils.print_error") as print_error_mock:
+        result = tool_helper._execute_tool_call_json({"name": "tool_dummy_which_is_not_existing", "arguments": {"text": "some_text"}}, [helper.tool_dummy, helper.tool_dummy2])
+        assert result is None
+        assert print_error_mock.call_count == 1  # the missing tool should be reported via print_error
+
+def test_tool_call_json_wrong_arguments_check():
+    with mock.patch("utils.print_error") as print_error_mock:
+        result = tool_helper._execute_tool_call_json({"name": "tool_dummy", "arguments": {"a": "must_be_an_int_but_is_string", "b": "zwei"}}, [helper.tool_dummy, helper.tool_dummy2])
+        assert result is None
+        assert print_error_mock.call_count == 1  # the argument type error should be reported via print_error
+
+
+
+def test_regex_multiline():
+    import re
+    pattern = r"<tool_call>(.*)<\/tool_call>"
+
+    # The text to search (spanning multiple lines)
+    text = """<tool_call>
+    {json}
+    </tool_call>"""
+
+    # Use re.search with re.DOTALL to match across newlines
+    match = re.search(pattern, text, re.DOTALL)
+
+    assert match.group(1).find("{json}") != -1
\ No newline at end of file
diff --git a/tool_functions.py b/tool_functions.py
new file mode 100644
index 0000000..2a74e69
--- /dev/null
+++ b/tool_functions.py
@@ -0,0 +1,35 @@
+import random
+import datetime
+from tool_helper import tool
+
+@tool
+def current_time():
+    """Get the current local date and time as a string."""
+    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
+
+# @tool
+# def random_float():
+#     """Generate a random float from 0..1."""
+#     return str(random.random())
+
+@tool
+def random_float(a: float=0.0, b: float=1.0):
+    """Generate a random float in range [a, b], including both end points. Optionally pass no parameters and the range 0..1 will be used.
+Args:
+    a: minimum possible value
+    b: maximum possible value"""
+    return str(random.uniform(a, b))
+
+@tool
+def random_int(a: int, b: int):
+    """Generate a random integer in range [a, b], including both end points.
+Args:
+    a: minimum possible value
+    b: maximum possible value"""
+    return str(random.randint(a, b))
+
+
+
+
+def register_dummy():
+    pass  # dummy function to call so this module is imported and the @tool decorators have run
\ No newline at end of file
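Adding another tool only takes a decorated, type-annotated function; `apply_chat_template(..., tools=tool_list)` in llama.py derives the JSON schema from the signature and the Google-style docstring. A sketch following the same pattern (`dice_roll` is illustrative, not part of the diff):

```python
import random
from tool_helper import tool

@tool
def dice_roll(sides: int = 6):
    """Roll a die and return the result as a string.
Args:
    sides: number of faces on the die"""
    return str(random.randint(1, sides))
```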
diff --git a/tool_helper.py b/tool_helper.py
new file mode 100644
index 0000000..653915d
--- /dev/null
+++ b/tool_helper.py
@@ -0,0 +1,93 @@
+
+from typing import Any, Callable, List, Optional
+import json
+import re
+import utils
+
+tool_list = []
+
+
+def tool(fn):
+    """tool function decorator: registers the function in tool_list"""
+    print("register tool '%s'" % fn.__name__)
+    tool_list.append(fn)
+    return fn
+
+# def parse_and_execute_tool_call(message: str, tools: list[function]) -> str | None:
+#     """execute tool call if needed according to the <tool_call> tag and return the content of the tool call or None if no call happened."""
+#     pass
+
+
+def parse_and_execute_tool_call(message: str, tools: List[Callable]) -> Optional[str]:
+    """
+    Execute a tool call if the <tool_call> tag is present and return the tool's response.
+    If no tag is found, return None.
+
+    Args:
+        message (str): The message containing the tool call.
+        tools (list[function]): A list of tool functions available for execution.
+
+    Returns:
+        Optional[str]: The content of the tool response or None if no tool call occurred.
+    """
+
+    # in case the LLM responds with <tool_call>...</tool_call>, the correct way
+    extracted = _match_and_extract(message, r"<tool_call>(.*)<\/tool_call>")
+    if extracted:
+        return _execute_tool_call_str(extracted, tools)
+
+    # in case the LLM responds with a mismatched <tool_call>...</tool_...> by accident
+    extracted = _match_and_extract(message, r"<tool_call>(.*)<\/tool_.*>")
+    if extracted:
+        return _execute_tool_call_str(extracted, tools)
+
+    # in case the LLM responds with some other <tool...>...</...> tag pair by accident
+    extracted = _match_and_extract(message, r"<tool.*>(.*)<\/.*>")
+    if extracted:
+        return _execute_tool_call_str(extracted, tools)
+
+    return None
+
+
+def _match_and_extract(message: str, pattern: str) -> Optional[str]:
+    """ helper function to match regex and extract group 1 """
+    match = re.search(pattern, message, re.DOTALL)
+    if match:
+        return match.group(1)
+    return None
+
+
+def _execute_tool_call_str(tool_call: str, tools: List[Callable]) -> Optional[str]:
+    """ execute tool call per string content. The content must be valid json """
+    try:
+        js = json.loads(tool_call)
+        return _execute_tool_call_json(js, tools)
+    except json.JSONDecodeError:
+        utils.print_error("Json was malformed. Will be ignored.")
+        return None
+
+def _execute_tool_call_json(data: Any, tools: List[Callable]) -> Optional[str]:
+    """ extract name and arguments from the parsed data and call the matching tool from the tools list """
+    # Extract tool name and arguments
+    tool_name = data.get("name")
+    arguments = data.get("arguments", {})
+
+    # Find the tool by name in the list of tools
+    for tool in tools:
+        if tool.__name__ == tool_name:
+            # Execute the tool
+            return _execute_tool_function(arguments, tool)
+
+    utils.print_error("Specified tool '%s' not found." % tool_name)
+    return None
+
+def _execute_tool_function(arguments: Any, tool: Callable) -> Optional[str]:
+    """ Execute the tool and return the result. """
+    try:
+        result = tool(**arguments)
+        print("<tool_response>", result, "</tool_response>")
+        return result
+    except TypeError as e:
+        utils.print_error("Type error while executing function call: '%s'" % str(e))
+
+    return None
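A quick round trip of the helper, assuming a model reply that carries the `<tool_call>` tag (values illustrative):

```python
from tool_helper import tool_list, parse_and_execute_tool_call
from tool_functions import register_dummy

register_dummy()  # importing tool_functions runs the @tool decorators

reply = '<tool_call>{"name": "random_int", "arguments": {"a": 1, "b": 10}}</tool_call>'
result = parse_and_execute_tool_call(reply, tool_list)
print(result)  # e.g. "7"; None would mean no tool call was found
```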
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..4936a6a
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,11 @@
+import json
+import sys
+from typing import Any
+
+def load_json_file(filepath: str) -> Any:
+    with open(filepath, "r") as f:
+        return json.load(f)
+
+
+def print_error(*args, **kwargs):
+    print(*args, file=sys.stderr, **kwargs)
\ No newline at end of file