tool functions and pytest

2025-01-01 18:20:50 +01:00
parent 823f13ab51
commit fd7e3d5235
10 changed files with 406 additions and 108 deletions

llama.py (228 changed lines)

@@ -2,8 +2,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import time
import torch
import random
import datetime
import json
from tool_helper import tool_list, parse_and_execute_tool_call
from tool_functions import register_dummy
import utils
t_start = time.time()
@@ -41,10 +42,9 @@ print("load took %.3fs" % (time.time() - t_start))
max_context_length = model.config.max_position_embeddings
# if tokenizer.chat_template is None:
print("apply external chat template...")
with open("chat_template.json", "r") as f:
tokenizer.chat_template = json.load(f)
tokenizer.chat_template = utils.load_json_file("chat_template.json")
print("max_context_length is %d tokens." % (max_context_length))
@@ -97,22 +97,9 @@ messages = [
roleflip = {"role": "system", "content": "Keep the conversation going, ask for more information on the subject. Keep messages short at max 1-2 sentences. Do not thank and say goodbye."}
def current_time():
    """Get the current local date and time as a string."""
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")

def random_float():
    """Get a random float from 0..1"""
    return str(random.random())

def random_int(a: int, b: int):
    """Return random integer in range [a, b], including both end points.
    Args:
        a: minimum possible value
        b: maximum possible value"""
    return str(random.randint(a, b))

tool_functions = [current_time, random_float, random_int]

register_dummy()
# tool_functions = [current_time, random_float, random_int]
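The dummy tools removed here move into tool_functions.py, while tool_helper.py supplies tool_list and parse_and_execute_tool_call; both files are part of this commit but not shown in this excerpt. A rough sketch of the interface these calls imply (register_tool is a guessed name; only tool_list, parse_and_execute_tool_call and register_dummy actually appear in the diff):

# tool_helper.py (sketch, assumed API): a module-level registry of callables
tool_list = []

def register_tool(fn):
    """Add a callable to the global tool list; usable as a decorator."""
    tool_list.append(fn)
    return fn

# tool_functions.py (sketch, assumed): the dummy tools removed above
# presumably live here now and are registered by register_dummy()
import datetime
from tool_helper import register_tool

def current_time():
    """Get the current local date and time as a string."""
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")

def register_dummy():
    register_tool(current_time)
    # random_float and random_int would be registered the same way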
@@ -139,31 +126,35 @@ def generate_incremental(inputs):
    generated_tokens = input_ids # Initially, this is just the input tokens
    n = 0

    # Loop to generate one token at a time
    while True:
        # Call the model with the current tokens
        outputs = model(input_ids=generated_tokens, use_cache=True)

        # Get the next token (the last token from the generated sequence)
        next_token = outputs.logits.argmax(dim=-1)[:, -1]

        # Append the new token to the sequence
        generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)

        # Decode and print the newly generated token (skip special tokens)
        out_text = tokenizer.decode(next_token, skip_special_tokens=True)
        print(out_text, end="", flush=True) # Print without newline

        # Check if the generated token is the end-of-sequence token
        if next_token.item() == tokenizer.eos_token_id:
            print("")
            break

        n += 1
        if n >= 30:
            n = 0
            torch.cuda.empty_cache()

    try:
        # Loop to generate one token at a time
        while True:
            # Call the model with the current tokens
            outputs = model(input_ids=generated_tokens, use_cache=True)

            # Get the next token (the last token from the generated sequence)
            next_token = outputs.logits.argmax(dim=-1)[:, -1]

            # Append the new token to the sequence
            generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)

            # Decode and print the newly generated token (skip special tokens)
            out_text = tokenizer.decode(next_token, skip_special_tokens=True)
            print(out_text, end="", flush=True) # Print without newline

            # Check if the generated token is the end-of-sequence token
            if next_token.item() == tokenizer.eos_token_id:
                print("")
                break

            n += 1
            if n >= 15:
                n = 0
                torch.cuda.empty_cache()
    except KeyboardInterrupt:
        pass
# Once done, return the full generated sequence
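As written, the loop feeds the whole sequence back through the model on every step, so use_cache=True alone does not speed anything up: the returned cache is never passed back in. For reference, a minimal greedy loop that actually reuses the KV cache with a standard transformers causal LM might look like the following sketch (not part of this commit; it assumes the model and tokenizer from the surrounding script):

import torch

def generate_greedy_with_cache(model, tokenizer, input_ids, max_new_tokens=256):
    """Sketch: greedy decoding that threads past_key_values between steps."""
    past_key_values = None
    generated = input_ids
    next_input = input_ids
    for _ in range(max_new_tokens):
        outputs = model(input_ids=next_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values  # reuse the cache on the next step
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
        next_input = next_token  # only the new token needs to be fed in
    return generated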
@@ -184,7 +175,7 @@ def append_generate_chat(input_text: str, role="user"):
# input_text = "Hello, who are you?"
# inputs = tokenizer(input_text, return_tensors="pt").to("cpu") # .to("cuda") .to("cpu")
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_functions) #continue_final_message=True,
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_list) #continue_final_message=True,
inputs = {key: value.to(model.device) for key, value in inputs.items()}
# inputs = {key: value.to("cpu") for key, value in inputs.items()}
# inputs["input_ids"] = inputs["input_ids"][:, 1:]
@@ -194,82 +185,105 @@ def append_generate_chat(input_text: str, role="user"):
    # append result to message history
    messages.append({"role": "assistant", "content": out_text})
    print("")
    print("generation took %.3fs (%d tokens)" % (time.time() - t_start, len(outputs[0])))

while True:
    # print an input prompt to receive text or commands
    input_text = input(">>> ")
    print("")

    # handle tool call and check if a tool call has happened.
    tool_result = parse_and_execute_tool_call(out_text, tool_list)
    if tool_result != None:
        # tool call happened
        # tool_result = "<tool_response>%s</tool_response>" % tool_result
        # depending on the chat template the tool response tags must or must not be passed. :(
        append_generate_chat(tool_result, role="tool")

    if input_text.startswith("!"):
        # append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
        append_generate_chat("%s" % input_text[1:], role="tool") # depending on the chat template the tool response tags must or must not be passed. :(

    elif input_text.startswith("/clear"):
        print("clearing chat history")
        messages = [messages[0]]

def main():
    global messages

    while True:
        # print an input prompt to receive text or commands
        input_text = input(">>> ")
        print("")

    elif input_text.startswith("/history"):
        history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_functions)
        print(history)

    elif input_text.startswith("/undo"):
        if len(messages) > 2:
            print("undo latest prompt")
            messages = messages[:-2]
        else:
            print("cannot undo because there are not enough messages on history.")
        print("")

    if input_text.startswith("!"):
        # append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
        append_generate_chat("%s" % input_text[1:], role="tool") # depending on the chat template the tool response tags must or must not be passed. :(

    elif input_text.startswith("/regen"):
        if len(messages) >= 2:
            print("regenerating message (not working)")
            messages = messages[:-1]
            seed = random.randint(0, 2**32 - 1) # Generate a random seed
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    elif input_text.startswith("/clear"):
        print("clearing chat history")
        start_msg = messages[0]
        messages = [start_msg]
        print("")

    elif input_text.startswith("/history"):
        history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_list)
        print(history)

    elif input_text.startswith("/undo"):
        if len(messages) > 2:
            print("undo latest prompt")
            messages = messages[:-2]
        else:
            print("cannot undo because there are not enough messages on history.")
        print("")

    elif input_text.startswith("/regen"):
        if len(messages) >= 2:
            print("regenerating message (not working)")
            messages = messages[:-1]
            seed = random.randint(0, 2**32 - 1) # Generate a random seed
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            append_generate_chat(None)
        else:
            print("cannot regenerate because there are not enough messages on history.")
        print("")

    elif input_text.startswith("/more"):
        append_generate_chat(None)

    elif input_text.startswith("/file"):
        filename = input_text[len("/file "):]
        print("read '%s' for prompt:" % filename)
        with open(filename, "r") as f:
            content = f.read()
        print(content)
        append_generate_chat(content)

    elif input_text.startswith("/auto"):
        messages_backup = messages
        messages = [roleflip]
        for m in messages_backup:
            role = m["role"]
            content = m["content"]
            if role == "user":
                role = "assistant"
            elif role == "assistant":
                role = "user"
            if role != "system":
                messages.append({"role": role, "content": content})
        append_generate_chat(None) # will automatically advance the conversation as 'user'
        last_message = messages[-1]
        last_message["role"] = "user"
        messages = messages_backup + [last_message]
        append_generate_chat(None) # 'regular' chatbot answer

    elif input_text.startswith("/help"):
        print("!<prompt> answer as 'tool' in <tool_response> tags")
        print("/clear clear chat history")
        print("/undo undo latest prompt")
        print("/regen regenerate the last message")
        print("/more generate more additional information")
        print("/file read prompt input from file")
        print("/auto automatically advance conversation")
        print("/help print this message")
        print("")

    elif input_text.startswith("/"):
        print("unknown command.")

    else:
        print("cannot regenerate because there are not enough messages on history.")
        print("")
        append_generate_chat(input_text)

    elif input_text.startswith("/more"):
        append_generate_chat(None)

    elif input_text.startswith("/auto"):
        messages_backup = messages
        messages = [roleflip]
        for m in messages_backup:
            role = m["role"]
            content = m["content"]
            if role == "user":
                role = "assistant"
            elif role == "assistant":
                role = "user"
            if role != "system":
                messages.append({"role": role, "content": content})
        append_generate_chat(None) # will automatically advance the conversation as 'user'
        last_message = messages[-1]
        last_message["role"] = "user"
        messages = messages_backup + [last_message]
        append_generate_chat(None) # 'regular' chatbot answer

    elif input_text.startswith("/help"):
        print("!<prompt> answer as 'tool' in <tool_response> tags")
        print("/clear clear chat history")
        print("/undo undo latest prompt")
        print("/regen regenerate the last message")
        print("/more generate more additional information")
        print("/auto automatically advance conversation")
        print("/help print this message")
        print("")

    elif input_text.startswith("/"):
        print("unknown command.")

    else:
        append_generate_chat(input_text)
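Wrapping the REPL in main() suggests an `if __name__ == "__main__": main()` guard elsewhere in llama.py (not visible in this hunk), which keeps the module importable from tests. The "pytest" half of the commit title is likewise not shown in this excerpt; the other changed files presumably add tests for the tool helpers. A sketch of what such a test might look like, relying only on behavior implied by the diff (file and test names are guesses):

# test_tool_functions.py (sketch, assumed file name); run with: pytest
from tool_helper import tool_list, parse_and_execute_tool_call
from tool_functions import register_dummy

def test_register_dummy_populates_tool_list():
    register_dummy()
    assert len(tool_list) > 0

def test_plain_text_is_not_a_tool_call():
    # ordinary assistant text should yield no tool result (None)
    assert parse_and_execute_tool_call("hello there", tool_list) is None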