You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

76 lines
2.7 KiB

from inference import Inference
from modelconfig import Modelconfig
import time
import nvidia_smi
import torch
import gc
def empty_cuda(threshold_mb: float = 200, poll_interval: float = 0.5, max_wait: float = 30.0) -> None:
    """Release cached CUDA memory and wait for GPU memory usage to drop.

    Each round runs the Python garbage collector, calls
    ``torch.cuda.empty_cache()``, sleeps, and then reads ``memory_used`` from
    ``nvidia_smi.get_gpu_stats()``. The function returns as soon as that value
    is below ``threshold_mb``.

    Args:
        threshold_mb: ``memory_used`` value (printed as MB) below which the
            GPU is considered empty.
        poll_interval: Seconds to sleep between the cleanup and the reading.
        max_wait: Maximum number of seconds to keep polling. When it elapses
            a warning is printed and the function returns anyway, so a GPU
            whose baseline usage sits above the threshold (e.g. because other
            processes use it) cannot hang the script forever. Pass
            ``float("inf")`` to wait without limit.
    """
    # Monotonic clock: the deadline must not move if the system clock is adjusted.
    deadline = time.monotonic() + max_wait
    while True:
        # Collect garbage first so memory held by unreachable objects can be
        # handed back by the cache flush that follows.
        gc.collect()
        torch.cuda.empty_cache()
        time.sleep(poll_interval)

        vram = nvidia_smi.get_gpu_stats()["memory_used"]
        print("vram: %d MB" % vram)
        if vram < threshold_mb:
            return

        if time.monotonic() >= deadline:
            print("vram: still above %d MB after %.1f s, continuing anyway" % (threshold_mb, max_wait))
            return
def profile_ex(model_conf: Modelconfig):
    """Load one model, run a fixed prompt through it and print benchmark stats.

    Steps: clear GPU memory via ``empty_cuda()``, snapshot GPU stats, build an
    ``Inference`` from ``model_conf``, snapshot again, then time tokenization
    plus ``generate_batch`` and take a final snapshot.

    Printed values:
        model      ``model_conf.model_name``
        tokens     length of the first generated sequence
        time       seconds spent in tokenize + generate
        speed      tokens / time
        vram_bulk  ``memory_used`` after loading minus before loading
        vram_top   ``memory_used`` after generation minus after loading
                   (a single sample taken once generation has finished)
        context    ``inference.max_context_length``

    Args:
        model_conf: Configuration of the model to benchmark.

    Raises:
        Whatever model loading, tokenization or generation raises; nothing is
        caught here.
    """
    print("")
    empty_cuda()

    messages = [
        {"role": "system", "content": "Hold a casual conversation with the user. Keep responses short at max 3 sentences. Answer using markdown to the user."},
        {"role": "user", "content": "How do astronomers determine the original wavelength of light emitted by a celestial body at rest, which is necessary for measuring its speed using the Doppler effect?"},
    ]

    gpu_stats_before = nvidia_smi.get_gpu_stats()
    inference = Inference(model_conf)
    gpu_stats_loaded = nvidia_smi.get_gpu_stats()

    # perf_counter is monotonic and high-resolution; wall-clock time.time()
    # can jump when the system clock is adjusted and would skew the result.
    t_start = time.perf_counter()
    input_ids = inference.tokenize(messages, tokenize=True)
    generated_tokens, _ = inference.generate_batch(input_ids, print_stdout=False)
    took = time.perf_counter() - t_start
    gpu_stats_after = nvidia_smi.get_gpu_stats()

    tokens = len(generated_tokens[0])
    # Guard against a zero interval so the remaining stats are still printed.
    tokens_per = tokens / took if took > 0 else 0.0

    vram_bulk = gpu_stats_loaded["memory_used"] - gpu_stats_before["memory_used"]
    vram_top = gpu_stats_after["memory_used"] - gpu_stats_loaded["memory_used"]

    print("model: %s" % model_conf.model_name)
    print("tokens: %d tk" % tokens)
    print("time: %.3f s" % took)
    print("speed: %.3f tk/s" % tokens_per)
    print("vram_bulk: %d MB" % vram_bulk)
    print("vram_top: %d MB" % vram_top)
    print("context: %d tk" % inference.max_context_length)
    print("")
def profile(model_conf):
    """Profile one model, reporting any failure instead of propagating it.

    Delegates to ``profile_ex``. Any ``Exception`` it raises is printed as
    ``exception: <ExceptionClass>: <message>`` and then swallowed, so that one
    failing model does not stop the caller from profiling the next one.

    Args:
        model_conf: Model configuration passed through to ``profile_ex``.
    """
    try:
        profile_ex(model_conf)
    except Exception as e:
        # Deliberately best-effort. Include the exception class because the
        # message alone is often ambiguous (e.g. KeyError('x') prints as 'x').
        print("exception: %s: %s" % (type(e).__name__, e))
def main():
    """Benchmark every configured model, one after another."""
    # (model name, extra Modelconfig keyword options), in execution order.
    runs = [
        ("NousResearch/Hermes-3-Llama-3.2-3B", {"load_in_8bit": True}),
        ("unsloth/Llama-3.2-1B", {}),
        ("unsloth/Llama-3.2-3B-Instruct", {"load_in_8bit": True}),
        ("unsloth/llama-3-8b-bnb-4bit", {}),
        # Disabled: ("unsloth/Llama-3.2-3B-Instruct-GGUF", {"load_in_8bit": True}),
        ("unsloth/gemma-2-9b-it-bnb-4bit", {}),
        ("unsloth/Qwen2.5-7B-Instruct-bnb-4bit", {}),
        ("unsloth/Qwen2.5-3B-Instruct", {"load_in_4bit": True}),
        ("unsloth/Qwen2.5-3B-Instruct", {"load_in_8bit": True}),
        ("unsloth/mistral-7b-instruct-v0.3-bnb-4bit", {}),
    ]
    for model_name, options in runs:
        # Build each config only when its turn comes, so a failing
        # constructor cannot affect the runs before it.
        profile(Modelconfig(model_name, **options))


# Run the benchmark suite only when executed as a script, not on import.
if __name__ == "__main__":
    main()