import pytest

from tests import helper

# Lazily-initialized module globals. Loading the model is expensive, so it is
# deferred until the first test that actually needs it (see prepare()).
inference = None
Tensor = None


def prepare():
    """Lazily build the shared Inference instance and bind torch.Tensor.

    The torch/chatbug imports are deferred so that merely collecting this
    test module does not load torch or the model weights.
    """
    global inference
    global Tensor
    if inference is None:  # fixed: was `== None`; identity check is the idiom
        from torch import Tensor as _Tensor
        from chatbug.inference import Inference
        from chatbug.model_selection import get_model
        inference = Inference(get_model())
        Tensor = _Tensor


def _chat_messages(system_message, user_message, assistant_message=None):
    """Build a chat-format message list; the assistant turn is optional."""
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    if assistant_message is not None:
        messages.append({"role": "assistant", "content": assistant_message})
    return messages


def _assert_generation(generate):
    """Tokenize a tiny prompt and validate the given generation callable.

    Shared body of test_inference / test_inference_2, which previously were
    copy-pasted duplicates differing only in the method called.
    """
    system_message = "Pretend you are a Python console."
    user_message = "print('Hello World!')"
    messages = _chat_messages(system_message, user_message)
    input_ids = inference.tokenize(messages, tokenize=True)
    assert type(input_ids) is Tensor
    # Token count should be at least the whitespace word count of the prompt.
    assert len(input_ids[0]) >= len(str(messages).split(" "))
    generated_tokens, full_output = generate(input_ids)
    assert type(generated_tokens) is Tensor
    assert len(generated_tokens[0]) > 2
    assert type(full_output) is str
    # NOTE(review): no exact-content assertion on the output — generation is
    # model-dependent (the original had `find("Hello World!")` commented out).


def test_tool_header_generation():
    """The tool-use header rendered for two dummy tools is non-trivial."""
    prepare()
    tools = [helper.tool_dummy, helper.tool_dummy2]
    header = inference.generate_tool_use_header(tools)
    assert len(header) > 100


def test_tokenize_dummy():
    """tokenize(..., tokenize=False) returns the rendered template string."""
    prepare()
    system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences."
    user_message = "say 'Hello World!'"
    assistant_message = "Hello World!"
    messages = _chat_messages(system_message, user_message, assistant_message)
    history = inference.tokenize(messages, tokenize=False)
    assert type(history) is str
    # Each turn's text must appear verbatim in the rendered template.
    assert history.find(system_message) != -1
    assert history.find(user_message) != -1
    assert history.find(assistant_message) != -1


def test_tokenize_tensor():
    """tokenize(..., tokenize=True) returns a token-id Tensor of sane length."""
    prepare()
    system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences."
    user_message = "say 'Hello World!'"
    assistant_message = "Hello World!"
    messages = _chat_messages(system_message, user_message, assistant_message)
    history = inference.tokenize(messages, tokenize=True)
    assert type(history) is Tensor
    assert len(history[0]) >= len(str(messages).split(" "))


def test_inference():
    """Incremental generation yields tokens and a decoded string."""
    prepare()
    _assert_generation(inference.generate_incremental)


def test_inference_2():
    """Batch generation yields tokens and a decoded string."""
    prepare()
    _assert_generation(inference.generate_batch)