You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
170 lines
7.6 KiB
170 lines
7.6 KiB
import time
|
|
import json
|
|
import random
|
|
from tool_helper import tool_list, parse_and_execute_tool_call
|
|
from inference import Inference, torch_reseed
|
|
from file_append import check_append_file
|
|
|
|
|
|
|
|
def msg(role: str, content: str) -> dict:
    """Build a single chat-message dict from a role and its content."""
    message = dict(role=role, content=content)
    return message
|
class Terminal:
    """Interactive chat REPL around an Inference backend.

    The running conversation is kept in ``self.messages`` — a list of
    ``{"role": ..., "content": ...}`` dicts whose first entry is always the
    system message. ``join()`` runs the read/dispatch/generate loop and
    understands a set of slash-commands (see ``/help``).
    """

    def __init__(self, inference: Inference, systemmessage: dict):
        self.inference = inference
        # full chat history; index 0 is always the system message
        self.messages: list[dict] = [systemmessage]

        # default helper prompts — these are meant to be overwritten by better ones
        self.roleflip = msg("system", "keep going.")
        self.summarize = msg("system", "summarize conversation")
        self.summarize_user = msg("system", "please summarize conversation")
        self.title_prompt = msg("system", "create a title for this conversation")

    def append_generate_chat(self, input_text: str, role="user"):
        """Append *input_text* (unless None) under *role*, then generate a reply.

        The assistant's reply is appended to the history. If the reply
        contains a tool call, the tool is executed and this method recurses
        once with the tool's response under role ``"tool"``. Passing
        ``input_text=None`` generates a continuation without adding a new
        input message first.
        """
        t_start = time.time()

        # add the new input message (None means: just continue generating)
        if input_text is not None:
            self.messages.append({"role": role, "content": input_text})

        inputs = self.inference.tokenize(self.messages, tokenize=True)
        number_of_input_tokens = inputs.shape[1]

        outputs, out_text = self.inference.generate(inputs)

        # append result to message history
        self.messages.append({"role": "assistant", "content": out_text})

        print("")
        time_taken = time.time() - t_start
        number_of_tokens = len(outputs[0])
        tokens_per_second = (number_of_tokens - number_of_input_tokens) / time_taken
        print("generation took %.3fs (%d tokens, %.3f t/s)" % (time_taken, number_of_tokens, tokens_per_second))

        # check whether the reply contained a tool call; if so, execute it
        tool_result = parse_and_execute_tool_call(out_text, tool_list)
        if tool_result is not None:
            # tool call happened
            tool_result = "<tool_response>%s</tool_response>" % tool_result
            # depending on the chat template the tool response tags must or must not be passed. :(
            self.append_generate_chat(tool_result, role="tool")

    def join(self):
        """Run the interactive loop: read input, dispatch commands, chat.

        Loops forever; every iteration reads one line from stdin. Lines
        starting with ``/`` are commands, lines starting with ``!`` are
        injected as tool responses, everything else is a user prompt.
        """
        while True:
            # print an input prompt to receive text or commands
            input_text = input(">>> ")
            print("")

            input_text = check_append_file(input_text)

            if input_text.startswith("!"):
                # inject the rest of the line as a tool response
                self.append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")

            elif input_text.startswith("/clear"):
                print("clearing chat history")
                # keep only the system message
                # (was `self.message = ...` — typo that left history untouched)
                start_msg = self.messages[0]
                self.messages = [start_msg]
                print("")

            elif input_text.startswith("/history"):
                # render the raw chat-template text without tokenizing
                history = self.inference.tokenize(self.messages, tokenize=False)
                print(history)

            elif input_text.startswith("/undo"):
                if len(self.messages) > 2:
                    print("undo latest prompt")
                    # drop the last user/assistant pair
                    # (was `self.message = ...` — typo that made undo a no-op)
                    self.messages = self.messages[:-2]
                else:
                    print("cannot undo because there are not enough self.message on history.")
                print("")

            elif input_text.startswith("/regen"):
                if len(self.messages) >= 2:
                    print("regenerating message (not working)")
                    # drop the last assistant reply and resample with a new seed
                    self.messages = self.messages[:-1]
                    seed = random.randint(0, 2**32 - 1)  # Generate a random seed
                    torch_reseed(seed)
                    self.append_generate_chat(None)
                else:
                    print("cannot regenerate because there are not enough self.message on history.")
                print("")

            elif input_text.startswith("/more"):
                # continue generating without adding a new input message
                self.append_generate_chat(None)

            elif input_text.startswith("/file"):
                filename = input_text[len("/file "):]
                print("read '%s' for prompt:" % filename)
                try:
                    with open(filename, "r") as f:
                        content = f.read()
                except OSError as e:
                    # a bad path must not kill the whole REPL loop
                    print("could not read file: %s" % e)
                    continue
                print(content)
                self.append_generate_chat(content)

            elif input_text.startswith("/auto"):
                # temporarily flip user/assistant roles so the model produces
                # the next *user* turn, then answer it as the assistant
                message_backup = self.messages
                self.messages = [self.roleflip]
                # (was `self.message_backup` — AttributeError on a local name)
                for m in message_backup:
                    role = m["role"]
                    content = m["content"]
                    if role == "user":
                        role = "assistant"
                    elif role == "assistant":
                        role = "user"
                    if role != "system":
                        # (was `self.message.append` — typo, messages were lost)
                        self.messages.append({"role": role, "content": content})
                self.append_generate_chat(None)  # will automatically advance the conversation as 'user'
                last_message = self.messages[-1]
                last_message["role"] = "user"
                self.messages = message_backup + [last_message]
                self.append_generate_chat(None)  # 'regular' chatbot answer

            elif input_text.startswith("/summarize"):
                messages_temp = list(filter(lambda x: x["role"] != "system", self.messages))
                messages_temp = [self.summarize] + messages_temp + [self.summarize_user]
                input_ids = self.inference.tokenize(messages_temp, tokenize=True, assistant_prefix="The conversation was about ")
                generated_tokens, full_output = self.inference.generate(input_ids)

            elif input_text.startswith("/title"):
                messages_temp = list(filter(lambda x: x["role"] != "system", self.messages))
                messages_temp = [self.title_prompt] + messages_temp
                # copy the last dict before flipping its role: the filter above
                # shares dicts with self.messages, so mutating in place would
                # corrupt the real history
                messages_temp[-1] = dict(messages_temp[-1], role="user")
                input_ids = self.inference.tokenize(messages_temp, tokenize=True, assistant_prefix="Title: ")
                generated_tokens, full_output = self.inference.generate(input_ids)

            elif input_text.startswith("/save"):
                with open("messages.json", "w") as f:
                    json.dump(self.messages, f, indent=4)

            elif input_text.startswith("/load"):
                with open("messages.json", "r") as f:
                    new_messages = json.load(f)
                    # keep the current system message, adopt the saved history
                    self.messages = [self.messages[0]] + new_messages[1:]

            elif input_text.startswith("/help"):
                print("!<prompt> answer as 'tool' in <tool_response> tags")
                print("/clear clear chat history")
                print("/undo undo latest prompt")
                print("/regen regenerate the last message")
                print("/more generate more additional information")
                print("/file read prompt input from file")
                print("/auto automatically advance conversation")
                print("/summarize generate a summary of the chat")
                print("/title generate a title of the chat")
                print("/save write chat history to file")
                print("/load load previously saved history")
                print("/help print this message")
                print("")

            elif input_text.startswith("/"):
                print("unknown command.")

            else:
                self.append_generate_chat(input_text)