if __name__ == "__main__":
    # this message is at the start, because initializing torch/transformers takes lots of time. fail fast.
    raise Exception("cannot execute this file directly")


from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.cache_utils import (
    DynamicCache,
    SinkCache,
    StaticCache,
    SlidingWindowCache,
    QuantoQuantizedCache,
    QuantizedCacheConfig,
)
import torch
import time
import utils
import re
import os
from modelconfig import Modelconfig


torch.set_num_threads(os.cpu_count())  # Adjust this to the number of threads/cores you have


class Inference:
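    """Wrapper around a Hugging Face causal language model and its tokenizer.

    Provides chat-template tokenization, batch generation, pseudo-streaming
    generation in small chunks, token-by-token greedy generation, and a helper
    that extracts the tool-use system header from the chat template.
    """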

    def __init__(self, modelconfig: Modelconfig):
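        """Load the model onto the GPU (optionally quantized via the given
        BitsAndBytesConfig), load its tokenizer and the custom chat template,
        and record the maximum context length."""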
        print("loading LLM '%s'..." % modelconfig.model_name)
        t_start = time.time()

        # model_name = "NousResearch/Llama-2-7b-hf"  # will cache on C:\Users\ftobler\.cache\huggingface\hub
        # model_name = "NousResearch/Hermes-3-Llama-3.2-3B"  # will cache on C:\Users\ftobler\.cache\huggingface\hub
        # model_name = "unsloth/phi-4-unsloth-bnb-4bit"  # too big
        # model_name = "gpt2"
        # model_name = "NousResearch/Hermes-2-Pro-Llama-3-8B"
        # model_name = "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2"
        # "meta-llama/Llama-2-7b-hf"  # Replace with your chosen model

        # quantization_config_4bit = BitsAndBytesConfig(  # tool calls don't really work in 4 bit mode
        #     load_in_4bit=True,
        #     bnb_4bit_quant_type="nf4",  # Recommended for better performance
        #     bnb_4bit_use_double_quant=True,  # Optional: Further quantization for more memory saving
        #     bnb_4bit_compute_dtype=torch.bfloat16  # Use bfloat16 for computation
        # )

        # quantization_config_8bit = BitsAndBytesConfig(load_in_8bit=True)

        # Load the model with quantization (optional)
        if modelconfig.bits_and_bytes_config is not None:
            self.model = AutoModelForCausalLM.from_pretrained(
                modelconfig.model_name,
                # device_map="auto",  # Automatically places parts of the model on GPU/CPU
                # device_map="cuda",  # Automatically places parts of the model on GPU/CPU
                device_map="cuda",  # Put the whole model on the GPU
                # load_in_8bit=True,  # Enables 8-bit quantization if bitsandbytes is installed
                quantization_config=modelconfig.bits_and_bytes_config
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                modelconfig.model_name,
                device_map="cuda",
            )

        # print("apply optimization")
        # self.model.generation_config.cache_implementation = "static"
        # self.model.forward = torch.compile(self.model.forward, mode="reduce-overhead", fullgraph=True)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(modelconfig.model_name)

        print("load took %.3fs" % (time.time() - t_start))

        self.max_context_length = self.model.config.max_position_embeddings

        self.tokenizer.chat_template = utils.load_json_file("chat_template.json")

        print("max_context_length is %d tokens." % (self.max_context_length))

    def generate(self, input_ids: torch.Tensor, print_stdout: bool = True) -> tuple[torch.Tensor, str]:
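        """Public entry point: run generation under torch.inference_mode() and
        delegate to generate_incremental_2()."""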
        with torch.inference_mode():
            with torch.no_grad():  # redundant with inference_mode(), but harmless
                return self.generate_incremental_2(input_ids, print_stdout)

    def generate_batch(self, input_ids: torch.Tensor, print_stdout: bool = True) -> tuple[torch.Tensor, str]:
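        """Generate up to 500 new tokens in a single generate() call (sampling
        enabled) and return the full output tensor plus the decoded new text."""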
        outputs = self.model.generate(
            input_ids,  # **inputs, inputs["input_ids"]
            max_new_tokens=500,  # max_length=max_context_length,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            num_return_sequences=1,
            num_beams=1
        )
        # skip all input tokens and only output the additional generated part of the conversation
        input_token_count = len(input_ids[0])
        out_text = self.tokenizer.decode(outputs[0][input_token_count:], skip_special_tokens=True)
        if print_stdout:
            print(out_text)
        return outputs, out_text

    def generate_incremental_2(self, input_ids: torch.Tensor, print_stdout: bool = True) -> tuple[torch.Tensor, str]:
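        """Pseudo-streaming generation: repeatedly call generate() for 5 tokens at a
        time, reusing a DynamicCache as past_key_values, optionally printing each
        chunk, until EOS or KeyboardInterrupt."""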
        generated_tokens = input_ids

        past_key_values = DynamicCache()

        # n = 0
        try:
            while True:
                outputs = self.model.generate(
                    generated_tokens,  # **inputs, inputs["input_ids"]
                    max_new_tokens=5,  # like streaming
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    do_sample=True,
                    num_return_sequences=1,
                    num_beams=1,
                    # use_cache=True,
                    past_key_values=past_key_values
                )

                # Get the newly generated tokens (everything after the previous sequence length)
                # next_token = outputs.argmax(dim=-1)[:, -1]
                new_tokens = outputs[0, len(generated_tokens[0]):]
                # next_token = outputs[0,-1]

                # Append the new tokens to the sequence (outputs already contains the full sequence)
                generated_tokens = outputs
                # generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0).unsqueeze(0)], dim=1)

                # Decode and print the newly generated tokens (skip special tokens)
                # out_text = self.tokenizer.decode(next_token, skip_special_tokens=True)
                out_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
                if print_stdout:
                    print(out_text, end="", flush=True)  # Print without newline

                # Check if the last generated token is the end-of-sequence token
                # if next_token.item() == self.tokenizer.eos_token_id:
                if new_tokens[-1].item() == self.tokenizer.eos_token_id:
                    if print_stdout:
                        print("")
                    break

                # n += 1
                # if n >= 15:
                #     n = 0
                #     torch.cuda.empty_cache()
        except KeyboardInterrupt:
            pass

        # Once done, return the full generated sequence
        input_token_count = len(input_ids[0])
        full_output = self.tokenizer.decode(generated_tokens[0][input_token_count:], skip_special_tokens=True)

        # torch.cuda.empty_cache()

        return generated_tokens, full_output

    def generate_incremental(self, input_ids: torch.Tensor, print_stdout: bool = True) -> tuple[torch.Tensor, str]:
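        """Token-by-token greedy generation; wraps _generate_incremental() in
        torch.inference_mode()."""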
        with torch.inference_mode():
            return self._generate_incremental(input_ids, print_stdout)

    def _generate_incremental(self, input_ids: torch.Tensor, print_stdout: bool = True) -> tuple[torch.Tensor, str]:
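        """Greedy decoding loop: run a full forward pass per step (no KV-cache reuse
        between steps), append the argmax token, optionally print it, stop at EOS or
        KeyboardInterrupt, and periodically empty the CUDA cache."""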
        # Start with the initial input tokens
        generated_tokens = input_ids  # Initially, this is just the input tokens

        # past_key_values = DynamicCache()
        # max_cache_length = past_key_values.get_max_length()

        n = 0
        try:
            # Loop to generate one token at a time
            while True:
                # Call the model with the current tokens
                outputs = self.model(
                    input_ids=generated_tokens,
                    use_cache=True,
                    num_beams=1  # note: num_beams is a generate() option and has no meaning for a direct forward call
                    # past_key_values=past_key_values
                )

                # Get the next token (argmax of the logits at the last position)
                next_token = outputs.logits.argmax(dim=-1)[:, -1]

                # Append the new token to the sequence
                generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)

                # Decode and print the newly generated token (skip special tokens)
                out_text = self.tokenizer.decode(next_token, skip_special_tokens=True)
                if print_stdout:
                    print(out_text, end="", flush=True)  # Print without newline

                # Check if the generated token is the end-of-sequence token
                if next_token.item() == self.tokenizer.eos_token_id:
                    if print_stdout:
                        print("")
                    break

                n += 1
                if n >= 15:
                    n = 0
                    torch.cuda.empty_cache()

        except KeyboardInterrupt:
            pass

        # Once done, return the full generated sequence
        input_token_count = len(input_ids[0])
        full_output = self.tokenizer.decode(generated_tokens[0][input_token_count:], skip_special_tokens=True)

        torch.cuda.empty_cache()

        return generated_tokens, full_output

    def tokenize(self, messages: list[dict], tokenize: bool, assistant_prefix: str | None = None) -> str | torch.Tensor:
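        """Apply the chat template to messages. With tokenize=True, return input_ids
        on the model device (optionally with assistant_prefix appended); otherwise
        return the rendered prompt string (with the prefix appended as ChatML text)."""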
        if tokenize:
            inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True)  # continue_final_message=True,
            inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
            input_ids = inputs["input_ids"]

            # Append the assistant prefix if provided
            if assistant_prefix:
                prefix_ids = self.tokenizer(assistant_prefix, return_tensors="pt")["input_ids"]
                input_ids = torch.cat([input_ids, prefix_ids.to(self.model.device)], dim=-1)

            return input_ids
        else:
            # only plain text generation
            message = self.tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False)

            # Append the assistant prefix to the raw text if provided
            if assistant_prefix:
                message += f"<|im_start|>assistant\n{assistant_prefix}"

            return message

    def generate_tool_use_header(self, tools: list[callable]) -> str:
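        """Render the chat template for a placeholder message with the given tools and
        return the system block (the tool-use header) extracted via regex."""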
        temp_messages = [{}]  # for some reason an empty list is not allowed, but a {} inside works like an empty list.
        s = self.tokenizer.apply_chat_template(temp_messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tools)
        pattern = r"<\|im_start\|>system\n(.*)<\|im_end\|>"
        match = re.search(pattern, s, re.DOTALL)
        if not match:
            raise Exception("Failed to regex match the template tool system text.")
        extraction = match.group(1)
        return extraction


def torch_reseed(seed: int):
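    """Reseed both the CPU RNG and all CUDA RNGs so that sampling is reproducible."""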
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
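

# Usage sketch (comments only, since this module refuses to be executed directly).
# It assumes Modelconfig can be constructed from a model name with an optional
# BitsAndBytesConfig and that this file is importable as `inference`; check
# modelconfig.py for the real constructor signature before relying on it.
#
#   from modelconfig import Modelconfig
#   from inference import Inference, torch_reseed
#
#   torch_reseed(42)                    # make sampling reproducible
#   llm = Inference(Modelconfig("NousResearch/Hermes-3-Llama-3.2-3B"))
#   messages = [{"role": "user", "content": "Hello!"}]
#   input_ids = llm.tokenize(messages, tokenize=True)
#   _, reply = llm.generate(input_ids)  # streams to stdout and returns the new text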