try out some more models
inference.py
@@ -17,41 +17,49 @@ import time
 import utils
 import re
 import os
+from modelconfig import Modelconfig

 torch.set_num_threads(os.cpu_count())  # Adjust this to the number of threads/cores you have


 class Inference:
-    def __init__(self):
-        print("loading LLM...")
+    def __init__(self, modelconfig: Modelconfig):
+        print("loading LLM '%s'..." % modelconfig.model_name)
         t_start = time.time()

         # model_name = "NousResearch/Llama-2-7b-hf"  # will cache on C:\Users\ftobler\.cache\huggingface\hub
-        model_name = "NousResearch/Hermes-3-Llama-3.2-3B"  # will cache on C:\Users\ftobler\.cache\huggingface\hub
+        # model_name = "NousResearch/Hermes-3-Llama-3.2-3B"  # will cache on C:\Users\ftobler\.cache\huggingface\hub
         # model_name = "unsloth/phi-4-unsloth-bnb-4bit"  # too big
         # model_name = "gpt2"
         # model_name = "NousResearch/Hermes-2-Pro-Llama-3-8B"
         # model_name = "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2"
         # "meta-llama/Llama-2-7b-hf"  # Replace with your chosen model


-        quantization_config_4bit = BitsAndBytesConfig(  # tool calls don't really work in 4 bit mode
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",  # Recommended for better performance
-            bnb_4bit_use_double_quant=True,  # Optional: Further quantization for more memory saving
-            bnb_4bit_compute_dtype=torch.bfloat16  # Use bfloat16 for computation
-        )
+        # quantization_config_4bit = BitsAndBytesConfig(  # tool calls don't really work in 4 bit mode
+        #     load_in_4bit=True,
+        #     bnb_4bit_quant_type="nf4",  # Recommended for better performance
+        #     bnb_4bit_use_double_quant=True,  # Optional: Further quantization for more memory saving
+        #     bnb_4bit_compute_dtype=torch.bfloat16  # Use bfloat16 for computation
+        # )

-        quantization_config_8bit = BitsAndBytesConfig(load_in_8bit=True)
+        # quantization_config_8bit = BitsAndBytesConfig(load_in_8bit=True)

         # Load the model with quantization (optional)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            # device_map="auto",  # Automatically places parts of the model on GPU/CPU
-            # device_map="cuda",  # Automatically places parts of the model on GPU/CPU
-            device_map="cuda",  # Automatically places parts of the model on GPU/CPU
-            # load_in_8bit=True,  # Enables 8-bit quantization if bitsandbytes is installed
-            quantization_config=quantization_config_8bit
-        )
+        if modelconfig.bits_and_bytes_config != None:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                modelconfig.model_name,
+                # device_map="auto",  # Automatically places parts of the model on GPU/CPU
+                # device_map="cuda",  # Automatically places parts of the model on GPU/CPU
+                device_map="cuda",  # Automatically places parts of the model on GPU/CPU
+                # load_in_8bit=True,  # Enables 8-bit quantization if bitsandbytes is installed
+                quantization_config=modelconfig.bits_and_bytes_config
+            )
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                modelconfig.model_name,
+                device_map="cuda",
+            )

         # print("apply optimization")
         # self.model.generation_config.cache_implementation = "static"
@@ -59,25 +67,25 @@ class Inference:


         # Load tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(modelconfig.model_name)

         print("load took %.3fs" % (time.time() - t_start))

-        max_context_length = self.model.config.max_position_embeddings
+        self.max_context_length = self.model.config.max_position_embeddings


         self.tokenizer.chat_template = utils.load_json_file("chat_template.json")

-        print("max_context_length is %d tokens." % (max_context_length))
+        print("max_context_length is %d tokens." % (self.max_context_length))


-    def generate(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
-        with torch.inference_mode():
-            return self.generate_incremental_2(input_ids)
+    def generate(self, input_ids: torch.Tensor, print_stdout=True) -> tuple[torch.Tensor, str]:
+        with torch.no_grad():
+            return self.generate_incremental_2(input_ids, print_stdout)


-    def generate_batch(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
+    def generate_batch(self, input_ids: torch.Tensor, print_stdout:bool=True) -> tuple[torch.Tensor, str]:
         outputs = self.model.generate(
             input_ids,  # **inputs, inputs["input_ids"]
             max_new_tokens=500,  # max_length=max_context_length,
@@ -90,11 +98,12 @@ class Inference:
         # skip all input tokens and only output the additional generated part of the conversation
         input_token_count = len(input_ids[0])
         out_text = self.tokenizer.decode(outputs[0][input_token_count:], skip_special_tokens=True)
-        print(out_text)
+        if print_stdout:
+            print(out_text)
         return outputs, out_text


-    def generate_incremental_2(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
+    def generate_incremental_2(self, input_ids: torch.Tensor, print_stdout:bool=True) -> tuple[torch.Tensor, str]:
         generated_tokens = input_ids

         past_key_values = DynamicCache()
@@ -126,12 +135,14 @@ class Inference:
             # Decode and print the newly generated token (skip special tokens)
             # out_text = self.tokenizer.decode(next_token, skip_special_tokens=True)
             out_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
-            print(out_text, end="", flush=True)  # Print without newline
+            if print_stdout:
+                print(out_text, end="", flush=True)  # Print without newline

             # Check if the generated token is the end-of-sequence token
             # if next_token.item() == self.tokenizer.eos_token_id:
             if new_tokens[-1].item() == self.tokenizer.eos_token_id:
-                print("")
+                if print_stdout:
+                    print("")
                 break

             # n += 1
@@ -150,12 +161,12 @@ class Inference:
         return generated_tokens, full_output


-    def generate_incremental(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
+    def generate_incremental(self, input_ids: torch.Tensor, print_stdout:bool=True) -> tuple[torch.Tensor, str]:
         with torch.inference_mode():
-            return self._generate_incremental(input_ids)
+            return self._generate_incremental(input_ids, print_stdout)


-    def _generate_incremental(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
+    def _generate_incremental(self, input_ids: torch.Tensor, print_stdout:bool=True) -> tuple[torch.Tensor, str]:
         # Start with the initial input tokens
         generated_tokens = input_ids  # Initially, this is just the input tokens

@@ -183,11 +194,13 @@ class Inference:

             # Decode and print the newly generated token (skip special tokens)
             out_text = self.tokenizer.decode(next_token, skip_special_tokens=True)
-            print(out_text, end="", flush=True)  # Print without newline
+            if print_stdout:
+                print(out_text, end="", flush=True)  # Print without newline

             # Check if the generated token is the end-of-sequence token
             if next_token.item() == self.tokenizer.eos_token_id:
-                print("")
+                if print_stdout:
+                    print("")
                 break

             n += 1
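
Note (not part of the diff above): inference.py now reads modelconfig.model_name and modelconfig.bits_and_bytes_config, but modelconfig.py itself is not shown in this commit. The sketch below is an assumption of the minimal shape that would satisfy the new code, plus a hypothetical call site for the added print_stdout flag; the dataclass layout and the example names (hermes_8bit, hermes_4bit, messages) are illustrative, not taken from the repository.

# Minimal sketch of the assumed Modelconfig interface (real modelconfig.py may differ).
from dataclasses import dataclass
from typing import Optional

import torch
from transformers import BitsAndBytesConfig


@dataclass
class Modelconfig:
    model_name: str                                             # HuggingFace model id to load
    bits_and_bytes_config: Optional[BitsAndBytesConfig] = None  # None -> load unquantized on CUDA


# Example configs mirroring the options toggled in this commit (illustrative only):
hermes_8bit = Modelconfig(
    model_name="NousResearch/Hermes-3-Llama-3.2-3B",
    bits_and_bytes_config=BitsAndBytesConfig(load_in_8bit=True),
)
hermes_4bit = Modelconfig(
    model_name="NousResearch/Hermes-3-Llama-3.2-3B",
    bits_and_bytes_config=BitsAndBytesConfig(  # tool calls reportedly don't work well in 4-bit mode
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)

# Hypothetical call site for the new signatures (not from the repository):
# inference = Inference(hermes_8bit)
# input_ids = inference.tokenizer.apply_chat_template(messages, return_tensors="pt").to(inference.model.device)
# tokens, text = inference.generate(input_ids, print_stdout=False)  # suppress streaming output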