python package restructuring
0
chatbug/__init__.py
Normal file
7
chatbug/__main__.py
Normal file
@@ -0,0 +1,7 @@
print("running __main__.py")

from chatbug.llama import main


if __name__ == "__main__":
    main()
37
chatbug/download_model.py
Normal file
@@ -0,0 +1,37 @@
from chatbug.inference import Inference
from chatbug.modelconfig import Modelconfig


def main():
    # Model size: 3.21B params
    Inference(Modelconfig("NousResearch/Hermes-3-Llama-3.2-3B", load_in_8bit=True))

    # Model size: 1.24B params
    Inference(Modelconfig("unsloth/Llama-3.2-1B", load_in_8bit=True))

    # Model size: 3.21B params
    Inference(Modelconfig("unsloth/Llama-3.2-3B-Instruct", load_in_8bit=True))

    # Model size: 4.65B params
    Inference(Modelconfig("unsloth/llama-3-8b-bnb-4bit", load_in_4bit=True))

    # Model size: 3.21B params
    Inference(Modelconfig("unsloth/Llama-3.2-3B-Instruct-GGUF", load_in_4bit=True))

    # Model size: 5.21B params
    Inference(Modelconfig("unsloth/gemma-2-9b-it-bnb-4bit", load_in_4bit=True))

    # Model size: 4.46B params
    Inference(Modelconfig("unsloth/Qwen2.5-7B-Instruct-bnb-4bit", load_in_4bit=True))

    # Model size: 3.09B params
    Inference(Modelconfig("unsloth/Qwen2.5-3B-Instruct", load_in_4bit=True))

    # Model size: 3.87B params
    Inference(Modelconfig("unsloth/mistral-7b-instruct-v0.3-bnb-4bit", load_in_4bit=True))


if __name__ == "__main__":
    main()
44
chatbug/file_append.py
Normal file
@@ -0,0 +1,44 @@
import os


def check_append_file(prompt: str) -> str:
    if "@" in prompt:
        parts = prompt.split(" ")
        content = []
        for part in parts:
            if part.startswith("@"):
                filename = part[1:]
                try:
                    if os.path.exists(filename):
                        with open(filename, "r") as f:
                            content.append("%s:'''\n%s'''" % (filename, f.read()))
                except FileNotFoundError:
                    print(f"File '{filename}' not found.")
        content.append(prompt)
        return "\n".join(content)
    return prompt


if __name__ == "__main__":
    exit()  # exit early so the demo below is not triggered accidentally

    # Create some sample files
    with open("fmain.py", "w") as f:
        f.write("# This is main.py\n")
    with open("finference.py", "w") as f:
        f.write("# This is inference.py\n")

    # Test cases
    test_prompts = [
        "@fmain.py",
        "@fmain.py @finference.py",
        "@fnonexistent.py",
        "@fmain.py @fnonexistent.py"
    ]

    for prompt in test_prompts:
        print(f"Testing prompt: {prompt}")
        result = check_append_file(prompt)
        print(f"Result: {result}")
        print("-" * 20)
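A minimal usage sketch for check_append_file (not part of the committed files; the filename notes.txt is only an assumed example):

from chatbug.file_append import check_append_file

# assuming a file called notes.txt exists in the working directory
expanded = check_append_file("please summarize @notes.txt")
# expanded now holds the contents of notes.txt wrapped in ''' quotes, followed by the original prompt
print(expanded)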
170
chatbug/generation_loop.py
Normal file
@@ -0,0 +1,170 @@
import time
import json
import random
from chatbug.tool_helper import tool_list, parse_and_execute_tool_call
from chatbug.inference import Inference, torch_reseed
from chatbug.file_append import check_append_file


def msg(role: str, content: str) -> dict:
    return {"role": role, "content": content}


class Terminal:

    def __init__(self, inference: Inference, systemmessage: dict):
        self.inference = inference
        self.messages: list[dict] = [systemmessage]

        # these are meant to be overwritten by better ones
        self.roleflip = msg("system", "keep going.")
        self.summarize = msg("system", "summarize conversation")
        self.summarize_user = msg("system", "please summarize conversation")
        self.title_prompt = msg("system", "create a title for this conversation")

    def append_generate_chat(self, input_text: str, role="user"):
        t_start = time.time()

        # generate AI response
        if input_text is not None:
            self.messages.append({"role": role, "content": input_text})

        inputs = self.inference.tokenize(self.messages, tokenize=True)
        number_of_input_tokens = inputs.shape[1]

        outputs, out_text = self.inference.generate(inputs)

        # append result to message history
        self.messages.append({"role": "assistant", "content": out_text})

        print("")
        time_taken = time.time() - t_start
        number_of_tokens = len(outputs[0])
        tokens_per_second = (number_of_tokens - number_of_input_tokens) / time_taken
        print("generation took %.3fs (%d tokens, %.3f t/s)" % (time_taken, number_of_tokens, tokens_per_second))

        # handle tool call and check if a tool call has happened.
        tool_result = parse_and_execute_tool_call(out_text, tool_list)
        if tool_result is not None:
            # tool call happened
            tool_result = "<tool_response>%s</tool_response>" % tool_result
            # depending on the chat template the tool response tags must or must not be passed. :(
            self.append_generate_chat(tool_result, role="tool")

    def join(self):

        while True:
            # print an input prompt to receive text or commands
            input_text = input(">>> ")
            print("")

            input_text = check_append_file(input_text)

            if input_text.startswith("!"):
                self.append_generate_chat("<tool_response>%s</tool_response>" % input_text[1:], role="tool")
                # append_generate_chat("%s" % input_text[1:], role="tool")  # depending on the chat template the tool response tags must or must not be passed. :(

            elif input_text.startswith("/clear"):
                print("clearing chat history")
                start_msg = self.messages[0]
                self.messages = [start_msg]
                print("")

            elif input_text.startswith("/history"):
                history = self.inference.tokenize(self.messages, tokenize=False)
                # history = tokenizer.apply_chat_template(self.messages, return_tensors="pt", tokenize=False, add_generation_prompt=False)
                print(history)

            elif input_text.startswith("/undo"):
                if len(self.messages) > 2:
                    print("undo latest prompt")
                    self.messages = self.messages[:-2]
                else:
                    print("cannot undo because there are not enough messages in the history.")
                print("")

            elif input_text.startswith("/regen"):
                if len(self.messages) >= 2:
                    print("regenerating message (not working)")
                    self.messages = self.messages[:-1]
                    seed = random.randint(0, 2**32 - 1)  # generate a random seed
                    torch_reseed(seed)
                    self.append_generate_chat(None)
                else:
                    print("cannot regenerate because there are not enough messages in the history.")
                print("")

            elif input_text.startswith("/more"):
                self.append_generate_chat(None)

            elif input_text.startswith("/file"):
                filename = input_text[len("/file "):]
                print("read '%s' for prompt:" % filename)
                with open(filename, "r") as f:
                    content = f.read()
                print(content)
                self.append_generate_chat(content)

            elif input_text.startswith("/auto"):
                message_backup = self.messages
                self.messages = [self.roleflip]
                for m in message_backup:
                    role = m["role"]
                    content = m["content"]
                    if role == "user":
                        role = "assistant"
                    elif role == "assistant":
                        role = "user"
                    if role != "system":
                        self.messages.append({"role": role, "content": content})
                self.append_generate_chat(None)  # will automatically advance the conversation as 'user'
                last_message = self.messages[-1]
                last_message["role"] = "user"
                self.messages = message_backup + [last_message]
                self.append_generate_chat(None)  # 'regular' chatbot answer

            elif input_text.startswith("/summarize"):
                messages_temp = list(filter(lambda x: x["role"] != "system", self.messages))
                messages_temp = [self.summarize] + messages_temp + [self.summarize_user]  # copy dict in last instance
                # messages_temp[-1]["role"] = "user"
                input_ids = self.inference.tokenize(messages_temp, tokenize=True, assistant_prefix="The conversation was about ")
                generated_tokens, full_output = self.inference.generate(input_ids)

            elif input_text.startswith("/title"):
                messages_temp = list(filter(lambda x: x["role"] != "system", self.messages))
                messages_temp = [self.title_prompt] + messages_temp  # + [dict(title)]  # copy dict in last instance
                messages_temp[-1]["role"] = "user"
                input_ids = self.inference.tokenize(messages_temp, tokenize=True, assistant_prefix="Title: ")
                generated_tokens, full_output = self.inference.generate(input_ids)

            elif input_text.startswith("/save"):
                with open("messages.json", "w") as f:
                    json.dump(self.messages, f, indent=4)

            elif input_text.startswith("/load"):
                with open("messages.json", "r") as f:
                    new_messages = json.load(f)
                    self.messages = [self.messages[0]] + new_messages[1:]

            elif input_text.startswith("/help"):
                print("!<prompt>   answer as 'tool' in <tool_response> tags")
                print("/clear      clear chat history")
                print("/undo       undo latest prompt")
                print("/regen      regenerate the last message")
                print("/more       generate more additional information")
                print("/file       read prompt input from file")
                print("/auto       automatically advance conversation")
                print("/summarize  generate a summary of the chat")
                print("/title      generate a title of the chat")
                print("/save       write chat history to file")
                print("/load       load previously saved history")
                print("/help       print this message")
                print("")

            elif input_text.startswith("/"):
                print("unknown command.")

            else:
                self.append_generate_chat(input_text)
24
chatbug/gpt2.py
Normal file
@@ -0,0 +1,24 @@
import cProfile
import pstats
from transformers import pipeline
import time

import torch
torch.set_num_threads(24)  # adjust this to the number of threads/cores you have

# Initialize the pipeline
generator = pipeline('text-generation', model='gpt2', device_map="cpu")  # gpt2

def run_inference():
    t_start = time.time()
    # Generate text
    generated_text = generator("below is a simple python function to extract email addresses from a string:", max_length=500, num_return_sequences=1)

    # Print the generated text
    print(generated_text[0]['generated_text'])
    print("took %.3fs" % (time.time() - t_start))

cProfile.run('run_inference()', 'profile_output.prof')

p = pstats.Stats('profile_output.prof')
p.sort_stats('cumulative').print_stats(30)  # show the top 30 time-consuming functions
260
chatbug/inference.py
Normal file
@@ -0,0 +1,260 @@
if __name__ == "__main__":
    # this message is at the start, because initializing torch/transformers takes lots of time. fail fast.
    raise Exception("cannot execute this file directly")


from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.cache_utils import (
    DynamicCache,
    SinkCache,
    StaticCache,
    SlidingWindowCache,
    QuantoQuantizedCache,
    QuantizedCacheConfig,
)
import torch
import time
import re
import os
import chatbug.utils as utils
from chatbug.modelconfig import Modelconfig

torch.set_num_threads(os.cpu_count())  # use all available cores


class Inference:
    def __init__(self, modelconfig: Modelconfig):
        print("loading LLM '%s'..." % modelconfig.model_name)
        t_start = time.time()

        # model_name = "NousResearch/Llama-2-7b-hf"  # will cache on C:\Users\ftobler\.cache\huggingface\hub
        # model_name = "NousResearch/Hermes-3-Llama-3.2-3B"  # will cache on C:\Users\ftobler\.cache\huggingface\hub
        # model_name = "unsloth/phi-4-unsloth-bnb-4bit"  # too big
        # model_name = "gpt2"
        # model_name = "NousResearch/Hermes-2-Pro-Llama-3-8B"
        # model_name = "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2"
        # model_name = "meta-llama/Llama-2-7b-hf"

        # quantization_config_4bit = BitsAndBytesConfig(  # tool calls don't really work in 4 bit mode
        #     load_in_4bit=True,
        #     bnb_4bit_quant_type="nf4",  # recommended for better performance
        #     bnb_4bit_use_double_quant=True,  # optional: further quantization for more memory saving
        #     bnb_4bit_compute_dtype=torch.bfloat16  # use bfloat16 for computation
        # )
        # quantization_config_8bit = BitsAndBytesConfig(load_in_8bit=True)

        # Load the model with quantization (optional)
        if modelconfig.bits_and_bytes_config is not None:
            self.model = AutoModelForCausalLM.from_pretrained(
                modelconfig.model_name,
                # device_map="auto",  # automatically places parts of the model on GPU/CPU
                device_map="cuda",  # place the model on the GPU
                quantization_config=modelconfig.bits_and_bytes_config
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                modelconfig.model_name,
                device_map="cuda",
            )

        # print("apply optimization")
        # self.model.generation_config.cache_implementation = "static"
        # self.model.forward = torch.compile(self.model.forward, mode="reduce-overhead", fullgraph=True)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(modelconfig.model_name)

        print("load took %.3fs" % (time.time() - t_start))

        self.max_context_length = self.model.config.max_position_embeddings

        self.tokenizer.chat_template = utils.load_json_file("chat_template.json")

        print("max_context_length is %d tokens." % (self.max_context_length))

    def generate(self, input_ids: torch.Tensor, print_stdout=True) -> tuple[torch.Tensor, str]:
        with torch.inference_mode():
            with torch.no_grad():
                return self.generate_incremental_2(input_ids, print_stdout)

    def generate_batch(self, input_ids: torch.Tensor, print_stdout: bool = True) -> tuple[torch.Tensor, str]:
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=500,  # max_length=max_context_length,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            num_return_sequences=1,
            num_beams=1
        )
        # skip all input tokens and only output the additional generated part of the conversation
        input_token_count = len(input_ids[0])
        out_text = self.tokenizer.decode(outputs[0][input_token_count:], skip_special_tokens=True)
        if print_stdout:
            print(out_text)
        return outputs, out_text

    def generate_incremental_2(self, input_ids: torch.Tensor, print_stdout: bool = True) -> tuple[torch.Tensor, str]:
        generated_tokens = input_ids

        past_key_values = DynamicCache()

        try:
            while True:
                outputs = self.model.generate(
                    generated_tokens,
                    max_new_tokens=5,  # like streaming
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    do_sample=True,
                    num_return_sequences=1,
                    num_beams=1,
                    # use_cache=True,
                    past_key_values=past_key_values
                )

                # Get the newly generated tokens (everything after the current sequence)
                new_tokens = outputs[0, len(generated_tokens[0]):]

                # Append the new tokens to the sequence
                generated_tokens = outputs

                # Decode and print the newly generated tokens (skip special tokens)
                out_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
                if print_stdout:
                    print(out_text, end="", flush=True)  # print without newline

                # Check if the last generated token is the end-of-sequence token
                if new_tokens[-1].item() == self.tokenizer.eos_token_id:
                    if print_stdout:
                        print("")
                    break

                # n += 1
                # if n >= 15:
                #     n = 0
                #     torch.cuda.empty_cache()
        except KeyboardInterrupt:
            pass

        # Once done, return the full generated sequence
        input_token_count = len(input_ids[0])
        full_output = self.tokenizer.decode(generated_tokens[0][input_token_count:], skip_special_tokens=True)

        # torch.cuda.empty_cache()

        return generated_tokens, full_output

    def generate_incremental(self, input_ids: torch.Tensor, print_stdout: bool = True) -> tuple[torch.Tensor, str]:
        with torch.inference_mode():
            return self._generate_incremental(input_ids, print_stdout)

    def _generate_incremental(self, input_ids: torch.Tensor, print_stdout: bool = True) -> tuple[torch.Tensor, str]:
        # Start with the initial input tokens
        generated_tokens = input_ids  # initially, this is just the input tokens

        # past_key_values = DynamicCache()
        # max_cache_length = past_key_values.get_max_length()

        n = 0
        try:
            # Loop to generate one token at a time
            while True:
                # Call the model with the current tokens
                outputs = self.model(
                    input_ids=generated_tokens,
                    use_cache=True,
                    # past_key_values=past_key_values
                )

                # Get the next token (the last token from the generated sequence)
                next_token = outputs.logits.argmax(dim=-1)[:, -1]

                # Append the new token to the sequence
                generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)

                # Decode and print the newly generated token (skip special tokens)
                out_text = self.tokenizer.decode(next_token, skip_special_tokens=True)
                if print_stdout:
                    print(out_text, end="", flush=True)  # print without newline

                # Check if the generated token is the end-of-sequence token
                if next_token.item() == self.tokenizer.eos_token_id:
                    if print_stdout:
                        print("")
                    break

                n += 1
                if n >= 15:
                    n = 0
                    torch.cuda.empty_cache()

        except KeyboardInterrupt:
            pass

        # Once done, return the full generated sequence
        input_token_count = len(input_ids[0])
        full_output = self.tokenizer.decode(generated_tokens[0][input_token_count:], skip_special_tokens=True)

        torch.cuda.empty_cache()

        return generated_tokens, full_output

    def tokenize(self, messages: list[dict], tokenize: bool, assistant_prefix: str = None) -> str | torch.Tensor:
        if tokenize:
            inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True)  # continue_final_message=True,
            inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
            input_ids = inputs["input_ids"]

            # Append the assistant prefix if provided
            if assistant_prefix:
                prefix_ids = self.tokenizer(assistant_prefix, return_tensors="pt")["input_ids"]
                input_ids = torch.cat([input_ids, prefix_ids.to(self.model.device)], dim=-1)

            return input_ids
        else:
            # only plain text generation
            message = self.tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False)

            # Append the assistant prefix to raw text if provided
            if assistant_prefix:
                message += f"<|im_start|>assistant\n{assistant_prefix}"

            return message

    def generate_tool_use_header(self, tools: list[callable]) -> str:
        temp_messages = [{}]  # for some reason an empty list is not allowed, but a {} inside works like an empty list.
        s = self.tokenizer.apply_chat_template(temp_messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tools)
        pattern = r"<\|im_start\|>system\n(.*)<\|im_end\|>"
        match = re.search(pattern, s, re.DOTALL)
        if not match:
            raise Exception("Failed to regex match the template tool system text.")
        extraction = match.group(1)
        return extraction


def torch_reseed(seed: int):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
76
chatbug/inference_profile_experiement.py
Normal file
@@ -0,0 +1,76 @@
import time
import torch
import gc
from chatbug import nvidia_smi
from chatbug.inference import Inference
from chatbug.modelconfig import Modelconfig


def empty_cuda():
    while True:
        gc.collect()
        torch.cuda.empty_cache()
        time.sleep(0.5)
        vram = nvidia_smi.get_gpu_stats()["memory_used"]
        print("vram: %d MB" % vram)
        if vram < 200:
            return


def profile_ex(model_conf: Modelconfig):
    print("")
    empty_cuda()
    messages = [
        {"role": "system", "content": "Hold a casual conversation with the user. Keep responses short at max 3 sentences. Answer using markdown to the user."},
        {"role": "user", "content": "How do astronomers determine the original wavelength of light emitted by a celestial body at rest, which is necessary for measuring its speed using the Doppler effect?"},
    ]

    gpu_stats_before = nvidia_smi.get_gpu_stats()
    inference = Inference(model_conf)

    gpu_stats_loaded = nvidia_smi.get_gpu_stats()
    t_start = time.time()
    input_ids = inference.tokenize(messages, tokenize=True)
    generated_tokens, full_output = inference.generate_batch(input_ids, print_stdout=False)
    t_end = time.time()
    gpu_stats_after = nvidia_smi.get_gpu_stats()

    took = t_end - t_start
    tokens = len(generated_tokens[0])
    tokens_per = tokens / took
    vram_bulk = gpu_stats_loaded["memory_used"] - gpu_stats_before["memory_used"]
    vram_top = gpu_stats_after["memory_used"] - gpu_stats_loaded["memory_used"]
    print("model:     %s" % model_conf.model_name)
    print("tokens:    %d tk" % tokens)
    print("time:      %.3f s" % took)
    print("speed:     %.3f tk/s" % tokens_per)
    print("vram_bulk: %d MB" % vram_bulk)
    print("vram_top:  %d MB" % vram_top)
    print("context:   %d tk" % inference.max_context_length)
    print("")


def profile(model_conf):
    try:
        profile_ex(model_conf)
    except Exception as e:
        print("exception: " + str(e))


def main():
    profile(Modelconfig("NousResearch/Hermes-3-Llama-3.2-3B", load_in_8bit=True))
    profile(Modelconfig("unsloth/Llama-3.2-1B"))
    profile(Modelconfig("unsloth/Llama-3.2-3B-Instruct", load_in_8bit=True))
    profile(Modelconfig("unsloth/llama-3-8b-bnb-4bit"))
    # profile(Modelconfig("unsloth/Llama-3.2-3B-Instruct-GGUF", load_in_8bit=True))
    profile(Modelconfig("unsloth/gemma-2-9b-it-bnb-4bit"))
    profile(Modelconfig("unsloth/Qwen2.5-7B-Instruct-bnb-4bit"))
    profile(Modelconfig("unsloth/Qwen2.5-3B-Instruct", load_in_4bit=True))
    profile(Modelconfig("unsloth/Qwen2.5-3B-Instruct", load_in_8bit=True))
    profile(Modelconfig("unsloth/mistral-7b-instruct-v0.3-bnb-4bit"))


if __name__ == "__main__":
    main()
46
chatbug/llama.py
Normal file
@@ -0,0 +1,46 @@
import datetime
from chatbug.tool_helper import tool_list
from chatbug.tool_functions import register_dummy
from chatbug.inference import Inference
from chatbug.generation_loop import Terminal, msg
from chatbug import model_selection


register_dummy()


def initialize_config(inference: Inference) -> Terminal:

    # system message at the very beginning of the chat. Will be concatenated with the automatic tool usage descriptions
    system_prompt = "Hold a casual conversation with the user. Keep responses short at max 5 sentences and on point. Answer using markdown to the user. When providing code examples, avoid comments which provide no additional information. Do not summarize."
    current_date_and_time = datetime.datetime.now().strftime("Current date is %Y-%m-%d and its %H:%M %p right now.")
    append_toolcalls = False
    if append_toolcalls:
        systemmessage = msg("system", system_prompt + "\n" + current_date_and_time + "\n" + inference.generate_tool_use_header(tool_list))
    else:
        systemmessage = msg("system", system_prompt + "\n" + current_date_and_time)

    terminal = Terminal(inference, systemmessage)

    # system message for role flip so the model automatically answers for the user
    terminal.roleflip = msg("system", "Keep the conversation going, ask for more information on the subject. Keep messages short at max 1-2 sentences. Do not thank and say goodbye.")

    # system messages and user message to bring the model to summarize the entire conversation
    terminal.summarize = msg("system", "Summarize the conversation as a single, cohesive paragraph. Avoid using any bullet points, numbers, or list formatting. Write in plain text with natural sentences that flow together seamlessly.")
    terminal.summarize_user = msg("system", "Can you summarize the conversation?")

    # system message to create a conversation title
    terminal.title_prompt = msg("system", "Please create a very short and descriptive title or label for this conversation. Maximum 2-5 words. Use only plain text, avoid numbering, special characters, or unnecessary formatting-focus on clarity and brevity.")
    return terminal


def main():
    inference = Inference(model_selection.get_model())
    terminal = initialize_config(inference)
    terminal.join()


if __name__ == "__main__":
    main()
3
chatbug/matheval/__init__.py
Normal file
@@ -0,0 +1,3 @@
from chatbug.matheval import ast
from chatbug.matheval import interpreter
from chatbug.matheval import lexer
129
chatbug/matheval/ast.py
Normal file
@@ -0,0 +1,129 @@
from chatbug.matheval import lexer
from chatbug.matheval.lexer import Token


class Statement:
    pass


class Expression(Statement):
    def __init__(self, value: str):
        self.value = value


class Equation:
    def __init__(self, lhs: Expression, rhs: Expression):
        self.lhs = lhs
        self.rhs = rhs


class Solve(Statement):
    def __init__(self, equations: list[Equation], variables: list[Expression]):
        self.equations = equations
        self.variables = variables


class Parser:
    def __init__(self):
        self.tokens: list[Token]  # tokens from lexer
        self._last_eaten = None

    def not_eof(self) -> bool:
        return self.tokens[0].type is not lexer.END_OF_INPUT

    def at(self) -> Token:
        return self.tokens[0]

    def at_last(self) -> Token:
        return self._last_eaten

    def eat(self) -> Token:
        self._last_eaten = self.tokens.pop(0)
        return self._last_eaten

    def backtrack(self):
        if not self._last_eaten:
            raise Exception("Cannot backtrack.")
        self.tokens.insert(0, self._last_eaten)
        self._last_eaten = None

    def eat_expect(self, token_type: int | str) -> Token:
        prev = self.eat()
        if prev.type is not token_type:
            raise Exception("expected to consume '%s' but '%s' encountered." % (str(token_type), str(prev.type)))
        return prev

    def at_expect(self, token_type: int | str) -> Token:
        prev = self.at()
        if prev.type is not token_type:
            raise Exception("expected to be at '%s' but '%s' encountered." % (str(token_type), str(prev.type)))
        return prev

    def parse(self, tokens: list[Token]) -> Statement:
        self.tokens = tokens
        statement = self.parse_statement()
        self.at_expect(lexer.END_OF_INPUT)
        return statement

    def parse_statement(self) -> Statement:
        type = self.at().type
        if type is lexer.SOLVE:
            return self.parse_solve()
        return self.parse_expression(merge_commas=True)

    def parse_solve(self) -> Solve:
        """
        solve x = 1 for x
        solve x = y and y = 2 for x and y
        """
        self.eat_expect(lexer.SOLVE)
        equations = []  # list of equations
        variables = []  # list of variables to solve for

        while self.not_eof() and self.at().type is not lexer.FOR:
            equations.append(self.parse_equation())
            selfattype = self.at().type
            if selfattype is lexer.AND or selfattype is lexer.COMMA:
                self.eat()

        self.eat_expect(lexer.FOR)

        while self.not_eof():
            variables.append(self.parse_expression(merge_commas=False))
            selfattype = self.at().type
            if selfattype is lexer.AND or selfattype is lexer.COMMA:
                self.eat()

        return Solve(equations, variables)

    def parse_equation(self) -> Equation:
        lhs = self.parse_expression(merge_commas=False)
        self.eat_expect(lexer.EQUALS)
        rhs = self.parse_expression(merge_commas=False)
        return Equation(lhs, rhs)

    def parse_expression(self, merge_commas) -> Expression:
        """
        math expression
        e.g.:
            sin(45) / 4 * pi
        """
        if merge_commas:
            values = []
            while self.not_eof():
                token = self.eat()
                if token.type is lexer.COMMA:
                    values.append(lexer.COMMA)
                elif token.type is lexer.EQUALS:
                    values.append(lexer.EQUALS)
                else:
                    values.append(token.value)
                # token = self.eat_expect(lexer.EXPRESSION)
                # values.append(token.value)
                # if self.at() is lexer.COMMA:
                #     token = self.eat()
                #     values.append(lexer.COMMA)
            return Expression("".join(values))
        else:
            token = self.eat_expect(lexer.EXPRESSION)
            return Expression(token.value)
122
chatbug/matheval/interpreter.py
Normal file
@@ -0,0 +1,122 @@
from sympy.parsing.sympy_parser import parse_expr
from sympy.core.numbers import Integer, One, Zero
from sympy import symbols, Eq, solveset, linsolve, nonlinsolve
from sympy.core.symbol import Symbol
from chatbug.matheval import ast


def interpret(statement: ast.Statement) -> str:
    if isinstance(statement, ast.Solve):
        return interpret_solve(statement)
    elif isinstance(statement, ast.Expression):
        return interpret_expression(statement)
    return "interpretation error"


def interpret_solve(statement: ast.Solve) -> str:
    eqs = statement.equations
    var = statement.variables

    # convert equations to a list of sympy Eq objects
    equations = [Eq(_math_expression_sanitation_and_parse(e.lhs.value), _math_expression_sanitation_and_parse(e.rhs.value)) for e in eqs]

    variables = [symbols(v.value) for v in var]

    if len(equations) == 1 and len(variables) == 1:
        return solve_simple_equation(equations[0], variables[0])
    else:
        return solve_multi_equation(equations, variables)


def solve_simple_equation(equation, variable):
    result = solveset(equation, variable)
    return "solved %s = %s for %s = %s" % (equation.lhs, equation.rhs, variable, result)


def solve_multi_equation(equations, variables):
    if is_linear(equations, variables):
        solution = linsolve(equations, variables)
    else:
        solution = nonlinsolve(equations, variables)

    solutionpairs = []
    for variable, value in zip(variables, list(solution)[0]):
        value_str = str(value)
        if not isinstance(value, Integer):
            try:
                float_value = value.evalf()
                if len(value_str) > 20:
                    value_str = "~%.3f" % float_value
                else:
                    value_str += "=~%.3f" % float_value
            except Exception:
                pass

        solutionpairs.append(f"{variable}={value_str}")

    # solutionpairs = [f"{variable}={value.doit()}" for variable, value in zip(variables, list(solution)[0])]
    # return "solved equation system for " + ", ".join(solutionpairs[:-1]) + " and " + solutionpairs[-1]

    if len(equations) > 1:
        leadin = "Solved equation system "
    else:
        leadin = "Solved equation "
    return leadin + _natural_join([_pretty_equation(e) for e in equations]) + " for " + _natural_join(solutionpairs) + "."


def _natural_join(data: list[any], joiner=", ", last=" and "):
    if len(data) > 1:
        return joiner.join(data[:-1]) + last + data[-1]
    return last.join(data)


def _pretty_equation(simpy_Eq) -> str:
    return f"{simpy_Eq.lhs} = {simpy_Eq.rhs}"


def is_linear(equations, variables):
    # the linearity check below is currently disabled; always use the non-linear solver path
    return False
    """Checks if a system of equations is linear."""
    for eq in equations:
        for var in variables:
            deriv = eq.diff(var)  # partial derivative
            # If the derivative is not a number or a symbol independent of the variable, the system is non-linear
            if not (deriv.is_number or (isinstance(deriv, Symbol) and deriv.free_symbols.isdisjoint({var}))):
                return False
    return True


def interpret_expression(statement: ast.Expression) -> str:
    return _math_evaluate_expression(statement.value)


def _math_evaluate_expression(expression: str):
    """evaluate a simple mathematical expression using sympy expression evaluation."""
    therm, simple, result = _math_evaluate_internal(expression)
    if isinstance(simple, Integer):
        return _build_equation_pair([therm, simple])
    if therm == simple or simple == result:
        return _build_equation_pair([therm, result])
    return _build_equation_pair([therm, simple, result])


def _math_evaluate_internal(expression: str):
    therm = _math_expression_sanitation_and_parse(expression)
    simple = therm.doit()
    numerical = therm.evalf()
    return therm, simple, numerical


def _math_expression_sanitation_and_parse(expression: str):
    expression = expression.replace("^", "**")
    return parse_expr(expression, evaluate=False)


def _build_equation_pair(expressions: list[any]) -> str:
    expressions = [str(e) for e in expressions]
    return " = ".join(expressions)
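A small end-to-end sketch of the matheval pipeline (lexer, parser, interpreter); this is not part of the committed files and the exact output string depends on sympy's formatting:

from chatbug.matheval import ast, interpreter, lexer

tokens = lexer.tokenize("solve x + y = 10 and x - y = 2 for x and y")
statement = ast.Parser().parse(tokens)
print(interpreter.interpret(statement))
# expected output along the lines of:
# "Solved equation system x + y = 10 and x - y = 2 for x=6 and y=4."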
61
chatbug/matheval/lexer.py
Normal file
@@ -0,0 +1,61 @@
EXPRESSION = 0
END_OF_INPUT = 1

SOLVE = "solve"
FOR = "for"
AND = "and"
EQUALS = "="
COMMA = ","

keyword_tokens = [SOLVE, FOR, AND, EQUALS, COMMA]


class Token:
    def __init__(self, type: int | str, value: str = None):
        self.type = type
        self.value = value

    def __repr__(self):
        if self.value is None:
            return f"{self.type}"
        return f"{self.type}|'{self.value}'"


def tokenize(expression: str) -> list[Token]:
    """
    this splits a math instruction into tokens.
    example:
        "solve x + 1 = 5 and y = 2*x for x, y"
    result:
        ["solve", "x + 1", "=", "5", "and", "y", "=", "2*x", "for", "x", "and", "y", "end_of_input"]
    """

    tokens = []  # output list of tokens

    symbols = expression.replace(",", " , ").replace("=", " = ").split(" ")

    current_token = []  # everything that is not directly in keyword_tokens gets binned here
    for s in symbols:
        found = False

        for keyword in keyword_tokens:
            if s.lower() == keyword:
                if len(current_token) != 0:
                    tokens.append(Token(EXPRESSION, " ".join(current_token)))
                    current_token = []
                tokens.append(Token(keyword))
                found = True
                break

        if not found:
            current_token.append(s)
    if len(current_token) != 0:
        tokens.append(Token(EXPRESSION, " ".join(current_token)))
        current_token = []

    tokens.append(Token(END_OF_INPUT))
    return tokens
95
chatbug/model_selection.py
Normal file
@@ -0,0 +1,95 @@
from chatbug.modelconfig import Modelconfig


def get_model() -> Modelconfig:

    # model:     NousResearch/Hermes-3-Llama-3.2-3B
    # tokens:    315 tk
    # time:      94.360 s
    # speed:     3.338 tk/s
    # vram_bulk: 3622 MB
    # vram_top:  80 MB
    # context:   131072 tk
    # model = Modelconfig("NousResearch/Hermes-3-Llama-3.2-3B", load_in_8bit=True)

    # model:     unsloth/Llama-3.2-1B
    # tokens:    589 tk
    # time:      39.348 s
    # speed:     14.969 tk/s
    # vram_bulk: 4708 MB
    # vram_top:  102 MB
    # context:   131072 tk
    # model = Modelconfig("unsloth/Llama-3.2-1B")  # note: fast, but talks to itself. basically does not work.

    # model:     unsloth/Llama-3.2-3B-Instruct
    # tokens:    285 tk
    # time:      75.363 s
    # speed:     3.782 tk/s
    # vram_bulk: 3512 MB
    # vram_top:  48 MB
    # context:   131072 tk
    # model = Modelconfig("unsloth/Llama-3.2-3B-Instruct", load_in_8bit=True)

    # model:     unsloth/llama-3-8b-bnb-4bit
    # tokens:    435 tk
    # time:      84.314 s
    # speed:     5.159 tk/s
    # vram_bulk: 5440 MB
    # vram_top:  216 MB
    # context:   8192 tk
    # model = Modelconfig("unsloth/llama-3-8b-bnb-4bit")

    # Model size: 3.21B params
    # vram used:  xxxxx MB
    # speed:      xxxxx t/s
    # working:    DOES NOT LOAD
    # model = Modelconfig("unsloth/Llama-3.2-3B-Instruct-GGUF", load_in_8bit=True)

    # model:     unsloth/gemma-2-9b-it-bnb-4bit
    # tokens:    154 tk
    # time:      32.727 s
    # speed:     4.706 tk/s
    # vram_bulk: 6156 MB
    # vram_top:  232 MB
    # context:   8192 tk
    # model = Modelconfig("unsloth/gemma-2-9b-it-bnb-4bit")

    # model:     unsloth/Qwen2.5-7B-Instruct-bnb-4bit
    # tokens:    120 tk
    # time:      12.248 s
    # speed:     9.798 tk/s
    # vram_bulk: 5382 MB
    # vram_top:  170 MB
    # context:   32768 tk
    model = Modelconfig("unsloth/Qwen2.5-7B-Instruct-bnb-4bit")  # note: this works really well

    # model:     unsloth/Qwen2.5-3B-Instruct
    # tokens:    112 tk
    # time:      12.703 s
    # speed:     8.816 tk/s
    # vram_bulk: 2108 MB
    # vram_top:  98 MB
    # context:   32768 tk
    # model = Modelconfig("unsloth/Qwen2.5-3B-Instruct", load_in_4bit=True)

    # model:     unsloth/Qwen2.5-3B-Instruct
    # tokens:    118 tk
    # time:      33.748 s
    # speed:     3.497 tk/s
    # vram_bulk: 3310 MB
    # vram_top:  60 MB
    # context:   32768 tk
    # model = Modelconfig("unsloth/Qwen2.5-3B-Instruct", load_in_8bit=True)

    # Model size: 3.87B params
    # vram used:  xxxxx MB
    # speed:      xxxxx t/s
    # error: requires the protobuf library but it was not found in your environment
    # model = Modelconfig("unsloth/mistral-7b-instruct-v0.3-bnb-4bit")

    return model
20
chatbug/modelconfig.py
Normal file
@@ -0,0 +1,20 @@
from transformers import BitsAndBytesConfig
import torch


class Modelconfig:
    def __init__(self, model_name, bits_and_bytes_config=None, load_in_8bit=False, load_in_4bit=False):
        self.model_name = model_name
        if load_in_4bit:
            assert bits_and_bytes_config is None
            self.bits_and_bytes_config = BitsAndBytesConfig(  # tool calls don't really work in 4 bit mode
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",  # recommended for better performance
                bnb_4bit_use_double_quant=True,  # optional: further quantization for more memory saving
                bnb_4bit_compute_dtype=torch.bfloat16  # use bfloat16 for computation
            )
        elif load_in_8bit:
            assert bits_and_bytes_config is None
            self.bits_and_bytes_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            self.bits_and_bytes_config = bits_and_bytes_config
40
chatbug/nvidia_smi.py
Normal file
@@ -0,0 +1,40 @@
import subprocess
import psutil
import time


def get_gpu_stats():
    command = [
        "nvidia-smi",
        "--query-gpu=index,name,memory.used,memory.free,power.draw,temperature.gpu",
        "--format=csv,noheader,nounits"
    ]
    result = subprocess.run(command, stdout=subprocess.PIPE, text=True)
    stats = result.stdout.strip().split(", ")
    return {
        "index": int(stats[0]),
        "name": str(stats[1]),
        "memory_used": int(stats[2]),   # in MB
        "memory_free": int(stats[3]),   # in MB
        "power_draw": float(stats[4]),  # in watts
        "temperature": float(stats[5])  # in Celsius
    }


def get_cpu_memory_stats():
    cpu_usage = psutil.cpu_percent(interval=1)
    memory_info = psutil.virtual_memory()
    return {
        "cpu_usage": cpu_usage,
        "memory_used": memory_info.used // (1024 ** 2),   # in MB
        "memory_total": memory_info.total // (1024 ** 2)  # in MB
    }


def main():
    import json
    while True:
        gpu_stats = get_gpu_stats()
        cpu_stats = get_cpu_memory_stats()
        print(json.dumps([gpu_stats, cpu_stats]))
        time.sleep(1)  # update every second


if __name__ == "__main__":
    main()
74
chatbug/tool_functions.py
Normal file
@@ -0,0 +1,74 @@
import random
import datetime
from chatbug.tool_helper import tool
import chatbug.matheval as matheval
# from chatbug.matheval import interpreter, lexer
# from chatbug.matheval.ast import Parser
import chatbug.utils as utils


# @tool
# def current_time():
#     """Get the current local date and time as a string."""
#     # return datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
#     return f"The current local date and time is {datetime.datetime.now().strftime('%Y-%m-%d %H:%M %p')}."


# @tool
# def random_float():
#     """Generate a random float in range 0 to 1."""
#     # return str(random.random())
#     return f"The freshly generated random number from 0..1 is: {random.random():.5f}."


# @tool
# def random_int(a: int, b: int):
#     """Generate a random integer in the range [a, b], including both end points.
#     Args:
#         a: minimum possible value (must be <= b)
#         b: maximum possible value (must be >= a)"""
#     # return str(random.randint(a, b))
#     return f"A freshly generated random integer between {a} and {b} is {random.randint(a, b)}."


@tool
def math_evaluate(expression: str):
    """Evaluate and simplify a mathematical expression. Returns the evaluated result or a simplified version of the expression as a string.
    Args:
        expression: A valid arithmetic expression (e.g., '2 + 3 * 4'). The expression must not contain '='."""
    try:
        tokens = matheval.lexer.tokenize(expression)
        parser = matheval.ast.Parser()
        statement = parser.parse(tokens)
        return matheval.interpreter.interpret(statement)
    except Exception as e:
        utils.print_error("Tool call evaluation failed. - " + str(e))
        return "Tool call evaluation failed."


@tool
def math_solve(equations: list[str], variables: list[str]):
    """Solve a linear or non-linear equation system. Returns the solutions as a string, or an error message if the input is invalid or unsolvable.
    Args:
        equations: A list of mathematical equations in the format 'x + y = 2'.
        variables: A list of variables to solve for. The number of variables must not exceed the number of equations."""
    try:
        expression = "solve " + " and ".join(equations) + " for " + " and ".join(variables)
        print(expression)

        tokens = matheval.lexer.tokenize(expression)
        parser = matheval.ast.Parser()
        statement = parser.parse(tokens)
        return matheval.interpreter.interpret(statement)
    except Exception as e:
        utils.print_error("Tool call evaluation failed. - " + str(e))
        return "Tool call evaluation failed."


def register_dummy():
    """dummy function to run and be sure the decorators have been initialized"""
    pass
103
chatbug/tool_helper.py
Normal file
@@ -0,0 +1,103 @@
from typing import Callable, List, Optional
import json
import re
import chatbug.utils as utils


tool_list = []


def tool(fn):
    """tool function decorator"""
    print("register tool '%s'" % fn.__name__)
    tool_list.append(fn)
    return fn


# def parse_and_execute_tool_call(message: str, tools: list[Callable]) -> str | None:
#     """execute tool call if needed according to the <tool_call> tag and return the content of the tool call or None if no call happened."""
#     pass


def parse_and_execute_tool_call(message: str, tools: List[Callable]) -> Optional[str]:
    """
    Execute a tool call if the <tool_call> tag is present and return the tool's response.
    If no <tool_call> tag is found, return None.

    Args:
        message (str): The message containing the tool call.
        tools (List[Callable]): A list of tool functions available for execution.

    Returns:
        Optional[str]: The content of the tool response or None if no tool call occurred.
    """

    # in case the LLM responds with <tool_call></tool_call> the correct way
    extracted = _match_and_extract(message, r"<tool_call>(.*)<\/tool_call>")
    if extracted:
        return _execute_tool_call_str(extracted, tools)

    # in case the LLM responds with <tool_call></tool_response> by accident
    extracted = _match_and_extract(message, r"<tool_call>(.*)<\/tool_.*>")
    if extracted:
        return _execute_tool_call_str(extracted, tools)

    # in case the LLM responds with <tool_call></???> by accident
    extracted = _match_and_extract(message, r"<tool_call>(.*)<\/.*>")
    if extracted:
        return _execute_tool_call_str(extracted, tools)

    # in case the LLM responds with <tool_response></???> by accident
    extracted = _match_and_extract(message, r"<tool_response>(.*)<\/.*>")
    if extracted:
        return _execute_tool_call_str(extracted, tools)

    return None


def _match_and_extract(message: str, pattern: str) -> Optional[str]:
    """ helper function to match regex and extract group 1 """
    match = re.search(pattern, message, re.DOTALL)
    if match:
        group1 = match.group(1)
        return group1
    return None


def _execute_tool_call_str(tool_call: str, tools: List[Callable]) -> Optional[str]:
    """ execute tool call per string content. The content must be valid json """
    try:
        js = json.loads(tool_call)
        return _execute_tool_call_json(js, tools)
    except json.JSONDecodeError:
        utils.print_error("Json was malformed. Will be ignored.")
    return None


def _execute_tool_call_json(data: any, tools: List[Callable]) -> Optional[str]:
    """ extract name and arguments from parsed data and call the tool, which is matched from the tools list """
    # Extract tool name and arguments
    tool_name = data.get("name")
    arguments = data.get("arguments", {})

    # Find the tool by name in the list of tools
    for tool in tools:
        if tool.__name__ == tool_name:
            # Execute the tool
            return _execute_tool_function(arguments, tool)

    utils.print_error("Specified tool '%s' not found." % tool_name)
    return None


def _execute_tool_function(arguments: any, tool: Callable) -> Optional[str]:
    """ Execute the tool and return the result. """
    try:
        result = tool(**arguments)
        print("<tool_response>", result, "</tool_response>")
        return result
    except TypeError as e:
        utils.print_error("Type error while executing function call: '%s'" % str(e))

    return None
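A minimal round-trip sketch for the tool registry (not part of the committed files; the echo tool is only an assumed example):

from chatbug.tool_helper import tool, tool_list, parse_and_execute_tool_call

@tool
def echo(text: str):
    """Return the given text unchanged."""
    return text

message = '<tool_call>{"name": "echo", "arguments": {"text": "hi"}}</tool_call>'
print(parse_and_execute_tool_call(message, tool_list))  # prints the tool response and returns "hi"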
20
chatbug/utils.py
Normal file
@@ -0,0 +1,20 @@
import json
import sys
import datetime


def load_json_file(filepath: str) -> any:
    with open(filepath, "r") as f:
        return json.load(f)


def save_string_as_file(content: str, filepath: str = None) -> None:
    if filepath is None:
        filepath = "temp_" + datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + ".txt"

    with open(filepath, "w") as f:
        f.write(content)


def print_error(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)