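# Benchmark script: profiles several chat models, reporting generation speed
# (tokens/s) and VRAM usage for loading the weights and for generation.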
from inference import Inference
from modelconfig import Modelconfig
import time
import nvidia_smi
import torch
import gc


def empty_cuda():
    # Free as much GPU memory as possible: keep collecting garbage and
    # flushing the CUDA cache until less than 200 MB remains in use,
    # so every profiling run starts from a clean baseline.
    while True:
        gc.collect()
        torch.cuda.empty_cache()
        time.sleep(0.5)
        vram = nvidia_smi.get_gpu_stats()["memory_used"]
        print("vram: %d MB" % vram)
        if vram < 200:
            return


def profile_ex(model_conf: Modelconfig):
    print("")
    empty_cuda()
    messages = [
        {"role": "system", "content": "Hold a casual conversation with the user. Keep responses short at max 3 sentences. Answer using markdown to the user."},
        {"role": "user", "content": "How do astronomers determine the original wavelength of light emitted by a celestial body at rest, which is necessary for measuring its speed using the Doppler effect?"},
    ]

    # Sample VRAM before loading the model and again once it is loaded.
    gpu_stats_before = nvidia_smi.get_gpu_stats()
    inference = Inference(model_conf)
    gpu_stats_loaded = nvidia_smi.get_gpu_stats()

    # Time a single batched generation over the prompt above.
    t_start = time.time()
    input_ids = inference.tokenize(messages, tokenize=True)
    generated_tokens, full_output = inference.generate_batch(input_ids, print_stdout=False)
    t_end = time.time()
    gpu_stats_after = nvidia_smi.get_gpu_stats()

    took = t_end - t_start
    tokens = len(generated_tokens[0])
    tokens_per = tokens / took
    # vram_bulk: memory taken by loading the model weights.
    # vram_top: additional memory in use after generation (KV cache, activations).
    vram_bulk = gpu_stats_loaded["memory_used"] - gpu_stats_before["memory_used"]
    vram_top = gpu_stats_after["memory_used"] - gpu_stats_loaded["memory_used"]

    print("model: %s" % model_conf.model_name)
    print("tokens: %d tk" % tokens)
    print("time: %.3f s" % took)
    print("speed: %.3f tk/s" % tokens_per)
    print("vram_bulk: %d MB" % vram_bulk)
    print("vram_top: %d MB" % vram_top)
    print("context: %d tk" % inference.max_context_length)
    print("")


def profile(model_conf):
    # Isolate each run so one failing model (e.g. out of memory) does not abort the rest.
    try:
        profile_ex(model_conf)
    except Exception as e:
        print("exception: " + str(e))


def main():
    # Models to benchmark; quantization is selected per model via Modelconfig flags.
    profile(Modelconfig("NousResearch/Hermes-3-Llama-3.2-3B", load_in_8bit=True))
    profile(Modelconfig("unsloth/Llama-3.2-1B"))
    profile(Modelconfig("unsloth/Llama-3.2-3B-Instruct", load_in_8bit=True))
    profile(Modelconfig("unsloth/llama-3-8b-bnb-4bit"))
    # profile(Modelconfig("unsloth/Llama-3.2-3B-Instruct-GGUF", load_in_8bit=True))
    profile(Modelconfig("unsloth/gemma-2-9b-it-bnb-4bit"))
    profile(Modelconfig("unsloth/Qwen2.5-7B-Instruct-bnb-4bit"))
    profile(Modelconfig("unsloth/Qwen2.5-3B-Instruct", load_in_4bit=True))
    profile(Modelconfig("unsloth/Qwen2.5-3B-Instruct", load_in_8bit=True))
    profile(Modelconfig("unsloth/mistral-7b-instruct-v0.3-bnb-4bit"))


if __name__ == "__main__":
    main()