import pytest

from tests import helper

# Module-level handles, populated lazily by prepare() so that importing this
# module stays cheap when torch and the model are not needed.
inference = None
Tensor = None


def prepare():
    """Load torch and the model lazily, on the first test that needs them."""
    global inference
    global Tensor
    if inference is None:
        from torch import Tensor as _Tensor
        from chatbug.inference import Inference
        from chatbug.model_selection import get_model
        inference = Inference(get_model())
        Tensor = _Tensor
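

# Note: since every test begins with prepare(), a module-scoped pytest fixture
# would be the idiomatic alternative. A minimal sketch (the fixture name is
# hypothetical, not part of the chatbug API):
#
#   @pytest.fixture(scope="module")
#   def inference_fixture():
#       from chatbug.inference import Inference
#       from chatbug.model_selection import get_model
#       return Inference(get_model())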


def test_tool_header_generation():
    prepare()
    tools = [helper.tool_dummy, helper.tool_dummy2]
    header = inference.generate_tool_use_header(tools)
    # A usable tool header should be substantially longer than a trivial string.
    assert len(header) > 100


def test_tokenize_dummy():
    prepare()
    system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences."
    user_message = "say 'Hello World!'"
    assistant_message = "Hello World!"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message},
    ]
    history = inference.tokenize(messages, tokenize=False)
    assert isinstance(history, str)
    assert system_message in history
    assert user_message in history
    assert assistant_message in history


def test_tokenize_tensor():
    prepare()
    system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences."
    user_message = "say 'Hello World!'"
    assistant_message = "Hello World!"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message},
    ]
    history = inference.tokenize(messages, tokenize=True)
    assert isinstance(history, Tensor)
    # Heuristic lower bound: tokenizing should yield at least as many tokens as
    # there are whitespace-separated words in the serialized messages.
    assert len(history[0]) >= len(str(messages).split(" "))


def test_inference():
    prepare()
    system_message = "Pretend you are a Python console."
    user_message = "print('Hello World!')"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    input_ids = inference.tokenize(messages, tokenize=True)
    assert isinstance(input_ids, Tensor)
    assert len(input_ids[0]) >= len(str(messages).split(" "))
    generated_tokens, full_output = inference.generate_incremental(input_ids)
    assert isinstance(generated_tokens, Tensor)
    assert len(generated_tokens[0]) > 2
    assert isinstance(full_output, str)
    # Disabled: generated text is not guaranteed to contain the literal string.
    # assert full_output.find("Hello World!") >= 0


def test_inference_2():
    """Same setup as test_inference, but exercises generate_batch instead of
    generate_incremental."""
    prepare()
    system_message = "Pretend you are a Python console."
    user_message = "print('Hello World!')"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    input_ids = inference.tokenize(messages, tokenize=True)
    assert isinstance(input_ids, Tensor)
    assert len(input_ids[0]) >= len(str(messages).split(" "))
    generated_tokens, full_output = inference.generate_batch(input_ids)
    assert isinstance(generated_tokens, Tensor)
    assert len(generated_tokens[0]) > 2
    assert isinstance(full_output, str)
    # Disabled: generated text is not guaranteed to contain the literal string.
    # assert full_output.find("Hello World!") >= 0
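

# These tests build a real Inference instance via get_model(), so they can be
# slow and presumably require the selected model weights to be available
# locally. Run them with, e.g.:
#
#   pytest tests/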