import gc
import time

import torch

import nvidia_smi
from chatbug.inference import Inference
from chatbug.modelconfig import Modelconfig


def empty_cuda():
    """Block until CUDA memory has been released back to the driver (< 200 MB used)."""
    while True:
        gc.collect()
        torch.cuda.empty_cache()
        time.sleep(0.5)
        vram = nvidia_smi.get_gpu_stats()["memory_used"]
        print("vram: %d MB" % vram)
        if vram < 200:
            return


def profile_ex(model_conf: Modelconfig):
    print("")
    empty_cuda()
    messages = [
        {"role": "system", "content": "Hold a casual conversation with the user. Keep responses short, at most 3 sentences. Answer the user using markdown."},
        {"role": "user", "content": "How do astronomers determine the original wavelength of light emitted by a celestial body at rest, which is necessary for measuring its speed using the Doppler effect?"},
    ]

    # Snapshot VRAM before and after the model load to isolate the load cost.
    gpu_stats_before = nvidia_smi.get_gpu_stats()
    inference = Inference(model_conf)
    gpu_stats_loaded = nvidia_smi.get_gpu_stats()

    # Time a single generation pass.
    t_start = time.time()
    input_ids = inference.tokenize(messages, tokenize=True)
    generated_tokens, full_output = inference.generate_batch(input_ids, print_stdout=False)
    t_end = time.time()
    gpu_stats_after = nvidia_smi.get_gpu_stats()

    took = t_end - t_start
    tokens = len(generated_tokens[0])
    tokens_per_second = tokens / took
    # vram_bulk: VRAM consumed by loading the weights.
    # vram_top: additional VRAM consumed during generation (KV cache, activations).
    vram_bulk = gpu_stats_loaded["memory_used"] - gpu_stats_before["memory_used"]
    vram_top = gpu_stats_after["memory_used"] - gpu_stats_loaded["memory_used"]

    print("model: %s" % model_conf.model_name)
    print("tokens: %d tk" % tokens)
    print("time: %.3f s" % took)
    print("speed: %.3f tk/s" % tokens_per_second)
    print("vram_bulk: %d MB" % vram_bulk)
    print("vram_top: %d MB" % vram_top)
    print("context: %d tk" % inference.max_context_length)
    print("")


def profile(model_conf: Modelconfig):
    # Keep profiling the remaining models even if one fails (e.g. out of memory).
    try:
        profile_ex(model_conf)
    except Exception as e:
        print("exception: " + str(e))


def main():
    profile(Modelconfig("NousResearch/Hermes-3-Llama-3.2-3B", load_in_8bit=True))
    profile(Modelconfig("unsloth/Llama-3.2-1B"))
    profile(Modelconfig("unsloth/Llama-3.2-3B-Instruct", load_in_8bit=True))
    profile(Modelconfig("unsloth/llama-3-8b-bnb-4bit"))
    # profile(Modelconfig("unsloth/Llama-3.2-3B-Instruct-GGUF", load_in_8bit=True))
    profile(Modelconfig("unsloth/gemma-2-9b-it-bnb-4bit"))
    profile(Modelconfig("unsloth/Qwen2.5-7B-Instruct-bnb-4bit"))
    profile(Modelconfig("unsloth/Qwen2.5-3B-Instruct", load_in_4bit=True))
    profile(Modelconfig("unsloth/Qwen2.5-3B-Instruct", load_in_8bit=True))
    profile(Modelconfig("unsloth/mistral-7b-instruct-v0.3-bnb-4bit"))


if __name__ == "__main__":
    main()
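

# --- Assumption: `nvidia_smi` imported above is a project-local helper module,
# not the PyPI NVML bindings, and its actual implementation is not shown here.
# A minimal sketch of the interface this script relies on, built on pynvml
# (`pip install nvidia-ml-py`); `_get_gpu_stats_sketch` is a hypothetical
# stand-in for illustration only and is never called by this script:
def _get_gpu_stats_sketch(device_index=0):
    import pynvml

    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    mem = pynvml.nvmlDeviceGetMemoryInfo(handle)  # .used/.total/.free in bytes
    pynvml.nvmlShutdown()
    # The script above reads "memory_used" in MB, so convert from bytes.
    return {"memory_used": mem.used // (1024 * 1024)}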