tool functions and pytest
This commit is contained in:
228
llama.py
228
llama.py
@@ -2,8 +2,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||
import time
|
||||
import torch
|
||||
import random
|
||||
import datetime
|
||||
import json
|
||||
from tool_helper import tool_list, parse_and_execute_tool_call
|
||||
from tool_functions import register_dummy
|
||||
import utils
|
||||
|
||||
t_start = time.time()
|
||||
|
||||
@@ -41,10 +42,9 @@ print("load took %.3fs" % (time.time() - t_start))
|
||||
max_context_length = model.config.max_position_embeddings
|
||||
|
||||
|
||||
# if tokenizer.chat_template is None:
|
||||
print("apply external chat template...")
|
||||
with open("chat_template.json", "r") as f:
|
||||
tokenizer.chat_template = json.load(f)
|
||||
tokenizer.chat_template = utils.load_json_file("chat_template.json")
|
||||
|
||||
|
||||
|
||||
|
||||
print("max_context_length is %d tokens." % (max_context_length))
|
||||
@@ -97,22 +97,9 @@ messages = [
|
||||
roleflip = {"role": "system", "content": "Keep the conversation going, ask for more information on the subject. Keep messages short at max 1-2 sentences. Do not thank and say goodbye."}
|
||||
|
||||
|
||||
def current_time():
|
||||
"""Get the current local date and time as a string."""
|
||||
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
def random_float():
|
||||
"""Get a random float from 0..1"""
|
||||
return str(random.random())
|
||||
|
||||
def random_int(a: int, b: int):
|
||||
"""Return random integer in range [a, b], including both end points.
|
||||
Args:
|
||||
a: minimum possible value
|
||||
b: maximum possible value"""
|
||||
return str(random.randint(a, b))
|
||||
|
||||
tool_functions = [current_time, random_float, random_int]
|
||||
register_dummy()
|
||||
# tool_functions = [current_time, random_float, random_int]
|
||||
|
||||
|
||||
|
||||
@@ -139,31 +126,35 @@ def generate_incremental(inputs):
|
||||
generated_tokens = input_ids # Initially, this is just the input tokens
|
||||
|
||||
n = 0
|
||||
try:
|
||||
|
||||
# Loop to generate one token at a time
|
||||
while True:
|
||||
# Call the model with the current tokens
|
||||
outputs = model(input_ids=generated_tokens, use_cache=True)
|
||||
# Loop to generate one token at a time
|
||||
while True:
|
||||
# Call the model with the current tokens
|
||||
outputs = model(input_ids=generated_tokens, use_cache=True)
|
||||
|
||||
# Get the next token (the last token from the generated sequence)
|
||||
next_token = outputs.logits.argmax(dim=-1)[:, -1]
|
||||
# Get the next token (the last token from the generated sequence)
|
||||
next_token = outputs.logits.argmax(dim=-1)[:, -1]
|
||||
|
||||
# Append the new token to the sequence
|
||||
generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)
|
||||
# Append the new token to the sequence
|
||||
generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)
|
||||
|
||||
# Decode and print the newly generated token (skip special tokens)
|
||||
out_text = tokenizer.decode(next_token, skip_special_tokens=True)
|
||||
print(out_text, end="", flush=True) # Print without newline
|
||||
# Decode and print the newly generated token (skip special tokens)
|
||||
out_text = tokenizer.decode(next_token, skip_special_tokens=True)
|
||||
print(out_text, end="", flush=True) # Print without newline
|
||||
|
||||
# Check if the generated token is the end-of-sequence token
|
||||
if next_token.item() == tokenizer.eos_token_id:
|
||||
print("")
|
||||
break
|
||||
# Check if the generated token is the end-of-sequence token
|
||||
if next_token.item() == tokenizer.eos_token_id:
|
||||
print("")
|
||||
break
|
||||
|
||||
n += 1
|
||||
if n >= 30:
|
||||
n = 0
|
||||
torch.cuda.empty_cache()
|
||||
n += 1
|
||||
if n >= 15:
|
||||
n = 0
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
# Once done, return the full generated sequence
|
||||
@@ -184,7 +175,7 @@ def append_generate_chat(input_text: str, role="user"):
|
||||
|
||||
# input_text = "Hello, who are you?"
|
||||
# inputs = tokenizer(input_text, return_tensors="pt").to("cpu") # .to("cuda") .to("cpu")
|
||||
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_functions) #continue_final_message=True,
|
||||
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_list) #continue_final_message=True,
|
||||
inputs = {key: value.to(model.device) for key, value in inputs.items()}
|
||||
# inputs = {key: value.to("cpu") for key, value in inputs.items()}
|
||||
# inputs["input_ids"] = inputs["input_ids"][:, 1:]
|
||||
@@ -194,82 +185,105 @@ def append_generate_chat(input_text: str, role="user"):
|
||||
|
||||
# append result to message history
|
||||
messages.append({"role": "assistant", "content": out_text})
|
||||
|
||||
print("")
|
||||
print("generation took %.3fs (%d tokens)" % (time.time() - t_start, len(outputs[0])))
|
||||
|
||||
|
||||
while True:
|
||||
# print an input prompt to receive text or commands
|
||||
input_text = input(">>> ")
|
||||
print("")
|
||||
# handle tool call and check if a tool call has happened.
|
||||
tool_result = parse_and_execute_tool_call(out_text, tool_list)
|
||||
if tool_result != None:
|
||||
# tool call happened
|
||||
# tool_result = "<tool_response>%s</tool_response>" % tool_result
|
||||
# depending on the chat template the tool response tags must or must not be passed. :(
|
||||
append_generate_chat(tool_result, role="tool")
|
||||
|
||||
|
||||
if input_text.startswith("!"):
|
||||
# append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
|
||||
append_generate_chat("%s" % input_text[1:], role="tool") # depending on the chat template the tool response tags must or must not be passed. :(
|
||||
|
||||
elif input_text.startswith("/clear"):
|
||||
print("clearing chat history")
|
||||
messages = [messages[0]]
|
||||
def main():
|
||||
global messages
|
||||
while True:
|
||||
# print an input prompt to receive text or commands
|
||||
input_text = input(">>> ")
|
||||
print("")
|
||||
|
||||
elif input_text.startswith("/history"):
|
||||
history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_functions)
|
||||
print(history)
|
||||
|
||||
elif input_text.startswith("/undo"):
|
||||
if len(messages) > 2:
|
||||
print("undo latest prompt")
|
||||
messages = messages[:-2]
|
||||
else:
|
||||
print("cannot undo because there are not enough messages on history.")
|
||||
print("")
|
||||
if input_text.startswith("!"):
|
||||
# append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
|
||||
append_generate_chat("%s" % input_text[1:], role="tool") # depending on the chat template the tool response tags must or must not be passed. :(
|
||||
|
||||
elif input_text.startswith("/regen"):
|
||||
if len(messages) >= 2:
|
||||
print("regenerating message (not working)")
|
||||
messages = messages[:-1]
|
||||
seed = random.randint(0, 2**32 - 1) # Generate a random seed
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
elif input_text.startswith("/clear"):
|
||||
print("clearing chat history")
|
||||
start_msg = messages[0]
|
||||
messages = [start_msg]
|
||||
print("")
|
||||
|
||||
elif input_text.startswith("/history"):
|
||||
history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_list)
|
||||
print(history)
|
||||
|
||||
elif input_text.startswith("/undo"):
|
||||
if len(messages) > 2:
|
||||
print("undo latest prompt")
|
||||
messages = messages[:-2]
|
||||
else:
|
||||
print("cannot undo because there are not enough messages on history.")
|
||||
print("")
|
||||
|
||||
elif input_text.startswith("/regen"):
|
||||
if len(messages) >= 2:
|
||||
print("regenerating message (not working)")
|
||||
messages = messages[:-1]
|
||||
seed = random.randint(0, 2**32 - 1) # Generate a random seed
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
append_generate_chat(None)
|
||||
else:
|
||||
print("cannot regenerate because there are not enough messages on history.")
|
||||
print("")
|
||||
|
||||
elif input_text.startswith("/more"):
|
||||
append_generate_chat(None)
|
||||
|
||||
elif input_text.startswith("/file"):
|
||||
filename = input_text[len("/file "):]
|
||||
print("read '%s' for prompt:" % filename)
|
||||
with open(filename, "r") as f:
|
||||
content = f.read()
|
||||
print(content)
|
||||
append_generate_chat(content)
|
||||
|
||||
elif input_text.startswith("/auto"):
|
||||
messages_backup = messages
|
||||
messages = [roleflip]
|
||||
for m in messages_backup:
|
||||
role = m["role"]
|
||||
content = m["content"]
|
||||
if role == "user":
|
||||
role = "assistant"
|
||||
elif role == "assistant":
|
||||
role = "user"
|
||||
if role != "system":
|
||||
messages.append({"role": role, "content": content})
|
||||
append_generate_chat(None) # will automatically advance the conversation as 'user'
|
||||
last_message = messages[-1]
|
||||
last_message["role"] = "user"
|
||||
messages = messages_backup + [last_message]
|
||||
append_generate_chat(None) # 'regular' chatbot answer
|
||||
|
||||
elif input_text.startswith("/help"):
|
||||
print("!<prompt> answer as 'tool' in <tool_response> tags")
|
||||
print("/clear clear chat history")
|
||||
print("/undo undo latest prompt")
|
||||
print("/regen regenerate the last message")
|
||||
print("/more generate more additional information")
|
||||
print("/file read prompt input from file")
|
||||
print("/auto automatically advance conversation")
|
||||
print("/help print this message")
|
||||
print("")
|
||||
|
||||
elif input_text.startswith("/"):
|
||||
print("unknown command.")
|
||||
|
||||
else:
|
||||
print("cannot regenerate because there are not enough messages on history.")
|
||||
print("")
|
||||
append_generate_chat(input_text)
|
||||
|
||||
elif input_text.startswith("/more"):
|
||||
append_generate_chat(None)
|
||||
|
||||
elif input_text.startswith("/auto"):
|
||||
messages_backup = messages
|
||||
messages = [roleflip]
|
||||
for m in messages_backup:
|
||||
role = m["role"]
|
||||
content = m["content"]
|
||||
if role == "user":
|
||||
role = "assistant"
|
||||
elif role == "assistant":
|
||||
role = "user"
|
||||
if role != "system":
|
||||
messages.append({"role": role, "content": content})
|
||||
append_generate_chat(None) # will automatically advance the conversation as 'user'
|
||||
last_message = messages[-1]
|
||||
last_message["role"] = "user"
|
||||
messages = messages_backup + [last_message]
|
||||
append_generate_chat(None) # 'regular' chatbot answer
|
||||
|
||||
elif input_text.startswith("/help"):
|
||||
print("!<prompt> answer as 'tool' in <tool_response> tags")
|
||||
print("/clear clear chat history")
|
||||
print("/undo undo latest prompt")
|
||||
print("/regen regenerate the last message")
|
||||
print("/more generate more additional information")
|
||||
print("/auto automatically advance conversation")
|
||||
print("/help print this message")
|
||||
print("")
|
||||
|
||||
elif input_text.startswith("/"):
|
||||
print("unknown command.")
|
||||
|
||||
else:
|
||||
append_generate_chat(input_text)
|
||||
|
Reference in New Issue
Block a user