@@ -2,8 +2,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import time
 import torch
 import random
 import datetime
 import json
 from tool_helper import tool_list, parse_and_execute_tool_call
 from tool_functions import register_dummy
+import utils
 t_start = time.time()
@@ -41,10 +42,9 @@ print("load took %.3fs" % (time.time() - t_start))
 max_context_length = model.config.max_position_embeddings
 
 # if tokenizer.chat_template is None:
 print("apply external chat template...")
-with open("chat_template.json", "r") as f:
-    tokenizer.chat_template = json.load(f)
+tokenizer.chat_template = utils.load_json_file("chat_template.json")
 
 print("max_context_length is %d tokens." % (max_context_length))
@@ -97,22 +97,9 @@ messages = [
 roleflip = {"role": "system", "content": "Keep the conversation going, ask for more information on the subject. Keep messages short at max 1-2 sentences. Do not thank and say goodbye."}
 
-def current_time():
-    """Get the current local date and time as a string."""
-    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
-
-def random_float():
-    """Get a random float from 0..1"""
-    return str(random.random())
-
-def random_int(a: int, b: int):
-    """Return random integer in range [a, b], including both end points.
-    Args:
-        a: minimum possible value
-        b: maximum possible value"""
-    return str(random.randint(a, b))
-
-tool_functions = [current_time, random_float, random_int]
+register_dummy()
+# tool_functions = [current_time, random_float, random_int]
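Note: the inline tool definitions move out of this script here; register_dummy() from tool_functions presumably registers equivalent demo tools into the shared tool_helper.tool_list that the rest of the diff uses. A sketch of that registration pattern, assuming tool_list is a plain list of callables:

    # tool_functions.py (hypothetical sketch, not part of this diff)
    import datetime
    from tool_helper import tool_list

    def current_time():
        """Get the current local date and time as a string."""
        return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")

    def register_dummy():
        # Make the demo tool visible to the chat loop's shared registry.
        tool_list.append(current_time)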
@@ -139,31 +126,35 @@ def generate_incremental(inputs):
     generated_tokens = input_ids  # Initially, this is just the input tokens
     n = 0
 
+    try:
-    # Loop to generate one token at a time
-    while True:
-        # Call the model with the current tokens
-        outputs = model(input_ids=generated_tokens, use_cache=True)
+        # Loop to generate one token at a time
+        while True:
+            # Call the model with the current tokens
+            outputs = model(input_ids=generated_tokens, use_cache=True)
 
-        # Get the next token (the last token from the generated sequence)
-        next_token = outputs.logits.argmax(dim=-1)[:, -1]
+            # Get the next token (the last token from the generated sequence)
+            next_token = outputs.logits.argmax(dim=-1)[:, -1]
 
-        # Append the new token to the sequence
-        generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)
+            # Append the new token to the sequence
+            generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)
 
-        # Decode and print the newly generated token (skip special tokens)
-        out_text = tokenizer.decode(next_token, skip_special_tokens=True)
-        print(out_text, end="", flush=True)  # Print without newline
+            # Decode and print the newly generated token (skip special tokens)
+            out_text = tokenizer.decode(next_token, skip_special_tokens=True)
+            print(out_text, end="", flush=True)  # Print without newline
 
-        # Check if the generated token is the end-of-sequence token
-        if next_token.item() == tokenizer.eos_token_id:
-            print("")
-            break
+            # Check if the generated token is the end-of-sequence token
+            if next_token.item() == tokenizer.eos_token_id:
+                print("")
+                break
 
-        n += 1
-        if n >= 15:
-            n = 0
-            torch.cuda.empty_cache()
+            n += 1
+            if n >= 30:
+                n = 0
+                torch.cuda.empty_cache()
+    except KeyboardInterrupt:
+        pass
 
     # Once done, return the full generated sequence
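Aside: even with use_cache=True, the loop above feeds the full generated_tokens back into the model on every step, so each step re-encodes the whole sequence. Passing past_key_values back in would make decoding truly incremental. A sketch of that variant (same greedy argmax logic as the diff, kv-cache reuse added):

    past_key_values = None
    next_input = input_ids
    while True:
        outputs = model(input_ids=next_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values          # reuse the cache next step
        next_token = outputs.logits[:, -1].argmax(dim=-1)  # greedy pick from the last position
        generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)
        next_input = next_token.unsqueeze(0)               # feed only the new token
        if next_token.item() == tokenizer.eos_token_id:
            break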
@@ -184,7 +175,7 @@ def append_generate_chat(input_text: str, role="user"):
 
     # input_text = "Hello, who are you?"
    # inputs = tokenizer(input_text, return_tensors="pt").to("cpu")  # .to("cuda") .to("cpu")
-    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_functions)  # continue_final_message=True,
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True, tools=tool_list)  # continue_final_message=True,
     inputs = {key: value.to(model.device) for key, value in inputs.items()}
     # inputs = {key: value.to("cpu") for key, value in inputs.items()}
     # inputs["input_ids"] = inputs["input_ids"][:, 1:]
@@ -194,82 +185,105 @@ def append_generate_chat(input_text: str, role="user"):
     # append result to message history
     messages.append({"role": "assistant", "content": out_text})
     print("")
     print("generation took %.3fs (%d tokens)" % (time.time() - t_start, len(outputs[0])))
 
-while True:
-    # print an input prompt to receive text or commands
-    input_text = input(">>> ")
-    print("")
+    # handle tool call and check if a tool call has happened.
+    tool_result = parse_and_execute_tool_call(out_text, tool_list)
+    if tool_result != None:
+        # tool call happened
+        # tool_result = "<tool_response>%s</tool_response>" % tool_result
+        # depending on the chat template the tool response tags must or must not be passed. :(
+        append_generate_chat(tool_result, role="tool")
 
-    if input_text.startswith("!"):
-        # append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
-        append_generate_chat("%s" % input_text[1:], role="tool")  # depending on the chat template the tool response tags must or must not be passed. :(
-    elif input_text.startswith("/clear"):
-        print("clearing chat history")
-        messages = [messages[0]]
+def main():
+    global messages
+    while True:
+        # print an input prompt to receive text or commands
+        input_text = input(">>> ")
+        print("")
 
-    elif input_text.startswith("/history"):
-        history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_functions)
-        print(history)
-    elif input_text.startswith("/undo"):
-        if len(messages) > 2:
-            print("undo latest prompt")
-            messages = messages[:-2]
-        else:
-            print("cannot undo because there are not enough messages on history.")
-        print("")
+        if input_text.startswith("!"):
+            # append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
+            append_generate_chat("%s" % input_text[1:], role="tool")  # depending on the chat template the tool response tags must or must not be passed. :(
-    elif input_text.startswith("/regen"):
-        if len(messages) >= 2:
-            print("regenerating message (not working)")
-            messages = messages[:-1]
-            seed = random.randint(0, 2**32 - 1)  # Generate a random seed
-            torch.manual_seed(seed)
-            torch.cuda.manual_seed_all(seed)
+        elif input_text.startswith("/clear"):
+            print("clearing chat history")
+            start_msg = messages[0]
+            messages = [start_msg]
+            print("")
+        elif input_text.startswith("/history"):
+            history = tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tool_list)
+            print(history)
+        elif input_text.startswith("/undo"):
+            if len(messages) > 2:
+                print("undo latest prompt")
+                messages = messages[:-2]
+            else:
+                print("cannot undo because there are not enough messages on history.")
+            print("")
+        elif input_text.startswith("/regen"):
+            if len(messages) >= 2:
+                print("regenerating message (not working)")
+                messages = messages[:-1]
+                seed = random.randint(0, 2**32 - 1)  # Generate a random seed
+                torch.manual_seed(seed)
+                torch.cuda.manual_seed_all(seed)
+                append_generate_chat(None)
+            else:
+                print("cannot regenerate because there are not enough messages on history.")
+            print("")
+        elif input_text.startswith("/more"):
+            append_generate_chat(None)
-        else:
-            print("cannot regenerate because there are not enough messages on history.")
-        print("")
-    elif input_text.startswith("/more"):
-        append_generate_chat(None)
-    elif input_text.startswith("/auto"):
-        messages_backup = messages
-        messages = [roleflip]
-        for m in messages_backup:
-            role = m["role"]
-            content = m["content"]
-            if role == "user":
-                role = "assistant"
-            elif role == "assistant":
-                role = "user"
-            if role != "system":
-                messages.append({"role": role, "content": content})
-        append_generate_chat(None)  # will automatically advance the conversation as 'user'
-        last_message = messages[-1]
-        last_message["role"] = "user"
-        messages = messages_backup + [last_message]
-        append_generate_chat(None)  # 'regular' chatbot answer
-    elif input_text.startswith("/help"):
-        print("!<prompt> answer as 'tool' in <tool_response> tags")
-        print("/clear clear chat history")
-        print("/undo undo latest prompt")
-        print("/regen regenerate the last message")
-        print("/more generate more additional information")
-        print("/auto automatically advance conversation")
-        print("/help print this message")
-        print("")
+        elif input_text.startswith("/file"):
+            filename = input_text[len("/file "):]
+            print("read '%s' for prompt:" % filename)
+            with open(filename, "r") as f:
+                content = f.read()
+            print(content)
+            append_generate_chat(content)
+        elif input_text.startswith("/auto"):
+            messages_backup = messages
+            messages = [roleflip]
+            for m in messages_backup:
+                role = m["role"]
+                content = m["content"]
+                if role == "user":
+                    role = "assistant"
+                elif role == "assistant":
+                    role = "user"
+                if role != "system":
+                    messages.append({"role": role, "content": content})
+            append_generate_chat(None)  # will automatically advance the conversation as 'user'
+            last_message = messages[-1]
+            last_message["role"] = "user"
+            messages = messages_backup + [last_message]
+            append_generate_chat(None)  # 'regular' chatbot answer
+        elif input_text.startswith("/help"):
+            print("!<prompt> answer as 'tool' in <tool_response> tags")
+            print("/clear clear chat history")
+            print("/undo undo latest prompt")
+            print("/regen regenerate the last message")
+            print("/more generate more additional information")
+            print("/file read prompt input from file")
+            print("/auto automatically advance conversation")
+            print("/help print this message")
+            print("")
-    elif input_text.startswith("/"):
-        print("unknown command.")
+        elif input_text.startswith("/"):
+            print("unknown command.")
-    else:
-        append_generate_chat(input_text)
+        else:
+            append_generate_chat(input_text)
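With the tool handling moved into append_generate_chat, a model-issued tool call now round-trips automatically: the assistant output is parsed, the matching function from tool_list is executed, and the result is appended as a role="tool" message, which triggers another generation. parse_and_execute_tool_call itself is not shown in this diff; a sketch of what it presumably does (the exact tag format depends on the chat template):

    # hypothetical sketch of tool_helper.parse_and_execute_tool_call
    import json, re

    def parse_and_execute_tool_call(text, tools):
        # Look for a <tool_call>{...}</tool_call> block emitted by the model.
        match = re.search(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
        if match is None:
            return None  # no tool call in this message
        call = json.loads(match.group(1))  # e.g. {"name": "current_time", "arguments": {}}
        for fn in tools:
            if fn.__name__ == call["name"]:
                return str(fn(**call.get("arguments", {})))
        return None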