2 changed files with 258 additions and 99 deletions
@@ -0,0 +1,148 @@
if __name__ == "__main__":
    # this check sits at the very top because initializing torch/transformers takes a long time. fail fast.
    raise Exception("cannot execute this file directly")


from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import time
import utils
import re


class Inference:

    def __init__(self):
        print("loading LLM...")
        t_start = time.time()

        # model_name = "NousResearch/Llama-2-7b-hf"  # will cache on C:\Users\ftobler\.cache\huggingface\hub
        model_name = "NousResearch/Hermes-3-Llama-3.2-3B"  # will cache on C:\Users\ftobler\.cache\huggingface\hub
        # model_name = "NousResearch/Hermes-2-Pro-Llama-3-8B"
        # model_name = "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2"
        # model_name = "meta-llama/Llama-2-7b-hf"  # replace with your chosen model

        quantization_config_4bit = BitsAndBytesConfig(  # tool calls don't really work in 4-bit mode
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",  # recommended for better performance
            bnb_4bit_use_double_quant=True,  # optional: further quantization for more memory savings
            bnb_4bit_compute_dtype=torch.bfloat16  # use bfloat16 for computation
        )

        quantization_config_8bit = BitsAndBytesConfig(load_in_8bit=True)

        # Load the model with quantization (optional)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # device_map="auto",  # automatically spreads the model across GPU/CPU
            device_map="cuda",  # place the whole model on the GPU
            # load_in_8bit=True,  # enables 8-bit quantization if bitsandbytes is installed
            quantization_config=quantization_config_8bit
        )
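        # note: quantization_config_4bit is defined above but not passed to from_pretrained();
        # the model is currently loaded with the 8-bit config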

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        print("load took %.3fs" % (time.time() - t_start))

        max_context_length = self.model.config.max_position_embeddings

        # override the tokenizer's built-in chat template with the one shipped in chat_template.json
        self.tokenizer.chat_template = utils.load_json_file("chat_template.json")

        print("max_context_length is %d tokens." % (max_context_length))

    def generate_batch(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
        outputs = self.model.generate(
            input_ids,  # **inputs, inputs["input_ids"]
            max_new_tokens=500,  # max_length=max_context_length,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=True,
            num_return_sequences=1
        )
        # skip all input tokens and only output the additional generated part of the conversation
        input_token_count = len(input_ids[0])
        out_text = self.tokenizer.decode(outputs[0][input_token_count:], skip_special_tokens=True)
        print(out_text)
        return outputs, out_text

    def generate_incremental(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
        with torch.inference_mode():
            return self._generate_incremental(input_ids)

    def _generate_incremental(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
        # Start with the initial input tokens
        generated_tokens = input_ids  # initially, this is just the input tokens

        n = 0
        try:
            # Loop to generate one token at a time
            while True:
                # Call the model with the current tokens (the full sequence is re-fed on every step)
                outputs = self.model(input_ids=generated_tokens, use_cache=True)

                # Greedily pick the next token: argmax over the vocabulary at the last position
                next_token = outputs.logits.argmax(dim=-1)[:, -1]

                # Append the new token to the sequence
                generated_tokens = torch.cat([generated_tokens, next_token.unsqueeze(0)], dim=1)

                # Decode and print the newly generated token (skip special tokens)
                out_text = self.tokenizer.decode(next_token, skip_special_tokens=True)
                print(out_text, end="", flush=True)  # print without newline

                # Stop if the generated token is the end-of-sequence token
                if next_token.item() == self.tokenizer.eos_token_id:
                    print("")
                    break

                # every 15 tokens, release unused cached GPU memory
                n += 1
                if n >= 15:
                    n = 0
                    torch.cuda.empty_cache()

        except KeyboardInterrupt:
            # allow the user to stop generation early; keep whatever has been generated so far
            pass

        # Once done, return the full token sequence and the decoded text of the newly generated part
        input_token_count = len(input_ids[0])
        full_output = self.tokenizer.decode(generated_tokens[0][input_token_count:], skip_special_tokens=True)

        torch.cuda.empty_cache()

        return generated_tokens, full_output
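
    # Note on performance: the loop in _generate_incremental() re-feeds the whole sequence on
    # every step and never passes outputs.past_key_values back in, so use_cache=True has no
    # effect there. A cache-reusing variant could look roughly like this (illustrative sketch
    # only, not used by this class):
    #
    #   generated_tokens = input_ids
    #   past = None
    #   tokens = input_ids
    #   while True:
    #       out = self.model(input_ids=tokens, past_key_values=past, use_cache=True)
    #       past = out.past_key_values
    #       next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    #       generated_tokens = torch.cat([generated_tokens, next_token], dim=1)
    #       tokens = next_token  # after the first step, only the new token is fed in
    #       if next_token.item() == self.tokenizer.eos_token_id:
    #           break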

    def tokenize(self, messages: list[dict], tokenize: bool) -> str | torch.Tensor:
        if tokenize:
            # tokenized path: returns input_ids already moved to the model's device
            inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True)  # continue_final_message=True,
            inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
            return inputs["input_ids"]
        else:
            # text path: returns the rendered chat template as a plain string
            message = self.tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False)
            return message

    def generate_tool_use_header(self, tools: list[callable]) -> str:
        temp_messages = [{}]  # an empty list is not accepted here, but a list containing an empty dict behaves like an empty conversation
        s = self.tokenizer.apply_chat_template(temp_messages, return_tensors="pt", tokenize=False, add_generation_prompt=False, tools=tools)
        pattern = r"<\|im_start\|>system\n(.*)<\|im_end\|>"
        match = re.search(pattern, s, re.DOTALL)
        if not match:
            raise Exception("Failed to regex match the template tool system text.")
        extraction = match.group(1)
        return extraction
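
    # Note: generate_tool_use_header() renders the chat template for a dummy conversation with
    # `tools` attached and then extracts the system block between <|im_start|>system and
    # <|im_end|>. It therefore only works with ChatML-style templates that use those markers.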


def torch_reseed(seed: int):
    # reseed both the CPU and all CUDA random number generators so sampling becomes reproducible
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
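

# Usage sketch (illustrative only; assumes this file is importable as a module named
# `inference` and that callers pass standard role/content chat messages):
#
#   from inference import Inference, torch_reseed
#
#   llm = Inference()
#   torch_reseed(42)
#   messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"},
#   ]
#   input_ids = llm.tokenize(messages, tokenize=True)
#   tokens, reply = llm.generate_incremental(input_ids)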