import pytest

from tests import helper


# Populated lazily by prepare() so that importing this module (e.g. during
# pytest collection) does not pull in torch or load the model.
inference = None
Tensor = None


def prepare():
    """Load the model and the torch Tensor type once, on first use."""
    global inference
    global Tensor
    if inference is None:
        from torch import Tensor as _Tensor
        from chatbug.inference import Inference
        from chatbug.model_selection import get_model
        inference = Inference(get_model())
        Tensor = _Tensor
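
# A more idiomatic alternative to prepare() would be a session-scoped pytest
# fixture that tests request explicitly; a minimal sketch, assuming the same
# chatbug API, could look like:
#
#     @pytest.fixture(scope="session")
#     def inference_session():
#         from chatbug.inference import Inference
#         from chatbug.model_selection import get_model
#         return Inference(get_model())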


def test_tool_header_generation():
    """The tool-use header generated for the dummy tools should be non-trivial."""
    prepare()
    tools = [helper.tool_dummy, helper.tool_dummy2]
    header = inference.generate_tool_use_header(tools)
    assert len(header) > 100


def test_tokenize_dummy():
    """tokenize(..., tokenize=False) should return the rendered chat as a string."""
    prepare()

    system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences."
    user_message = "say 'Hello World!'"
    assistant_message = "Hello World!"

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message}
    ]
    history = inference.tokenize(messages, tokenize=False)

    # Every message must appear verbatim in the rendered template.
    assert type(history) is str
    assert history.find(system_message) != -1
    assert history.find(user_message) != -1
    assert history.find(assistant_message) != -1
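
# tokenize=False presumably renders the chat template to text without encoding
# it (in Hugging Face terms, tokenizer.apply_chat_template(..., tokenize=False)),
# which is why each message content must appear verbatim in the result.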


def test_tokenize_tensor():
    """tokenize(..., tokenize=True) should return a Tensor of token ids."""
    prepare()

    system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences."
    user_message = "say 'Hello World!'"
    assistant_message = "Hello World!"

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message}
    ]
    history = inference.tokenize(messages, tokenize=True)

    assert type(history) is Tensor
    # Crude lower bound: the token sequence should be at least as long as the
    # whitespace-separated word count of the serialized messages.
    assert len(history[0]) >= len(str(messages).split(" "))


def test_inference():
    """generate_incremental() should yield a Tensor of new tokens and a decoded string."""
    prepare()

    system_message = "Pretend you are a Python console."
    user_message = "print('Hello World!')"

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    input_ids = inference.tokenize(messages, tokenize=True)

    assert type(input_ids) is Tensor
    assert len(input_ids[0]) >= len(str(messages).split(" "))

    generated_tokens, full_output = inference.generate_incremental(input_ids)

    assert type(generated_tokens) is Tensor
    assert len(generated_tokens[0]) > 2

    assert type(full_output) is str
    # assert full_output.find("Hello World!") >= 0
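
# The content check above is left disabled, presumably because sampling makes
# the exact completion nondeterministic; only types and lengths are asserted.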


def test_inference_2():
    """generate_batch() should match the generate_incremental() contract."""
    prepare()

    system_message = "Pretend you are a Python console."
    user_message = "print('Hello World!')"

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    input_ids = inference.tokenize(messages, tokenize=True)

    assert type(input_ids) is Tensor
    assert len(input_ids[0]) >= len(str(messages).split(" "))

    generated_tokens, full_output = inference.generate_batch(input_ids)

    assert type(generated_tokens) is Tensor
    assert len(generated_tokens[0]) > 2

    assert type(full_output) is str
    # assert full_output.find("Hello World!") >= 0
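
# The four chat-building tests above repeat the same messages construction; a
# small hypothetical helper could remove the duplication, e.g.:
#
#     def _messages(system_message, user_message, assistant_message=None):
#         messages = [{"role": "system", "content": system_message},
#                     {"role": "user", "content": user_message}]
#         if assistant_message is not None:
#             messages.append({"role": "assistant", "content": assistant_message})
#         return messages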