You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

149 lines
5.2 KiB

import time
import random
from tool_helper import tool_list, parse_and_execute_tool_call
from tool_functions import register_dummy
from inference import Inference, torch_reseed
# Global chat state: the full message history as a list of
# {"role": ..., "content": ...} dicts in chat-template order.
messages = []
# Inference engine; constructed lazily in main().
inference = None
# System prompt that seeds every new conversation.
systemmessage = "Hold a casual conversation with the user. Keep responses short at max 3 sentences."
# Alternate system prompt used by /auto — presumably makes the model
# impersonate the user to keep the conversation going. TODO confirm.
roleflip = {"role": "system", "content": "Keep the conversation going, ask for more information on the subject. Keep messages short at max 1-2 sentences. Do not thank and say goodbye."}
# NOTE(review): looks like this registers the dummy tools into tool_list
# (imported from tool_functions) — verify against tool_helper.
register_dummy()
# tool_functions = [current_time, random_float, random_int]
def append_generate_chat(input_text: str, role="user"):
    """Append *input_text* to the history, generate the assistant reply,
    and recursively feed back any tool-call result as a "tool" message.

    Args:
        input_text: Message content to append before generating, or
            ``None`` to generate from the existing history unchanged.
        role: Role of the appended message ("user", "tool", ...).
    """
    t_start = time.time()
    # generate AI response; None means "just continue", append nothing
    if input_text is not None:
        messages.append({"role": role, "content": input_text})
    inputs = inference.tokenize(messages, tokenize=True)
    outputs, out_text = inference.generate_incremental(inputs)
    # append result to message history
    messages.append({"role": "assistant", "content": out_text})
    print("")
    print("generation took %.3fs (%d tokens)" % (time.time() - t_start, len(outputs[0])))
    # handle tool call and check if a tool call has happened.
    tool_result = parse_and_execute_tool_call(out_text, tool_list)
    if tool_result is not None:
        # tool call happened; recurse with the result as a "tool" turn.
        # tool_result = "<tool_response>%s</tool_response>" % tool_result
        # depending on the chat template the tool response tags must or must not be passed. :(
        append_generate_chat(tool_result, role="tool")
def main():
    """Run the interactive chat REPL.

    Reads lines from stdin; lines starting with "/" are commands, lines
    starting with "!" are injected as tool responses, everything else is
    sent to the model as a user message. Loops forever (exit via EOF/^C).
    """
    global messages
    global inference

    inference = Inference()
    # Seed the history with the system prompt plus the tool-use header.
    messages = [{"role": "system", "content": systemmessage + "\n" + inference.generate_tool_use_header(tool_list)}]

    while True:
        # print an input prompt to receive text or commands
        input_text = input(">>> ")
        print("")

        if input_text.startswith("!"):
            # append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
            append_generate_chat("%s" % input_text[1:], role="tool")  # depending on the chat template the tool response tags must or must not be passed. :(
        elif input_text.startswith("/clear"):
            print("clearing chat history")
            start_msg = messages[0]  # keep the system message
            messages = [start_msg]
            print("")
        elif input_text.startswith("/history"):
            # Render the chat template as text without tokenizing.
            history = inference.tokenize(messages, tokenize=False)
            # history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False)
            print(history)
        elif input_text.startswith("/undo"):
            if len(messages) > 2:
                print("undo latest prompt")
                # Drop the last user/assistant pair.
                messages = messages[:-2]
            else:
                print("cannot undo because there are not enough messages on history.")
            print("")
        elif input_text.startswith("/regen"):
            if len(messages) >= 2:
                print("regenerating message (not working)")
                # Drop only the last assistant reply, reseed, and resample.
                messages = messages[:-1]
                seed = random.randint(0, 2**32 - 1)  # Generate a random seed
                torch_reseed(seed)
                append_generate_chat(None)
            else:
                print("cannot regenerate because there are not enough messages on history.")
            print("")
        elif input_text.startswith("/more"):
            # Continue generating with no new user input.
            append_generate_chat(None)
        elif input_text.startswith("/file"):
            filename = input_text[len("/file "):]
            print("read '%s' for prompt:" % filename)
            # Robustness fix: a missing/unreadable file no longer kills the REPL.
            try:
                with open(filename, "r") as f:
                    content = f.read()
            except OSError as e:
                print("could not read '%s': %s" % (filename, e))
                continue
            print(content)
            append_generate_chat(content)
        elif input_text.startswith("/auto"):
            # Swap user/assistant roles so the model advances the
            # conversation on the user's behalf, then merge the result back.
            messages_backup = messages
            messages = [roleflip]
            for m in messages_backup:
                role = m["role"]
                content = m["content"]
                if role == "user":
                    role = "assistant"
                elif role == "assistant":
                    role = "user"
                if role != "system":
                    messages.append({"role": role, "content": content})
            append_generate_chat(None)  # will automatically advance the conversation as 'user'
            last_message = messages[-1]
            last_message["role"] = "user"
            # Re-attach the generated turn to the real history as a user message.
            messages = messages_backup + [last_message]
            append_generate_chat(None)  # 'regular' chatbot answer
        elif input_text.startswith("/help"):
            print("!<prompt> answer as 'tool' in <tool_response> tags")
            print("/clear clear chat history")
            print("/undo undo latest prompt")
            print("/regen regenerate the last message")
            print("/more generate more additional information")
            print("/file read prompt input from file")
            print("/auto automatically advance conversation")
            print("/help print this message")
            print("")
        elif input_text.startswith("/"):
            print("unknown command.")
        else:
            append_generate_chat(input_text)
# Script entry point: only start the REPL when executed directly,
# not when imported as a module. (Fix: the pasted source had main()
# un-indented, outside the guard.)
if __name__ == "__main__":
    main()