cleanup, tests
This commit is contained in:
@@ -1,7 +1,19 @@
|
||||
|
||||
|
||||
def tool_dummy(a: int, b: str):
    """
    tool_dummy

    Args:
        a: how much
        b: how text?
    """
    # f-string instead of dated %-formatting; identical output for the
    # annotated int/str inputs. Docstring left as-is: it is consumed by
    # the tool-use header generation under test.
    return f"result_{a}_{b}"
|
||||
|
||||
|
||||
def tool_dummy2(text: str):
    """
    tool_dummy2

    Args:
        text: only text?
    """
    # Echo the input back in uppercase.
    upper_text = text.upper()
    return upper_text
|
122
tests/test_inference.py
Normal file
122
tests/test_inference.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import pytest

import tests.helper as helper


# Lazily-populated module globals: the import/instantiation "tests" below
# fill these in so later tests can share a single Inference instance and
# the torch Tensor type without importing heavy dependencies at collection
# time.
inference = None
InferenceClass = None
Tensor = None
|
||||
|
||||
|
||||
def prepare():
    """Ensure the inference module is imported and an instance exists.

    Idempotent helper called at the top of each dependent test so the
    tests can also run individually (pytest may select/reorder them).
    """
    # `is None`, not `== None`: identity check against the sentinel (PEP 8).
    if InferenceClass is None:
        test_import_inference_module_librarys()
    if inference is None:
        test_instantiate_inference_instance()
|
||||
|
||||
|
||||
def test_import_inference_module_librarys():
    """Import the project inference module plus torch and cache the types.

    (Name kept as-is — `prepare` and sibling tests call it.)
    """
    global InferenceClass, Tensor

    import inference
    import torch

    InferenceClass = inference.Inference
    Tensor = torch.Tensor
|
||||
|
||||
|
||||
def test_instantiate_inference_instance():
    """Create the shared Inference instance used by the remaining tests."""
    global inference

    # `is None`, not `== None`: identity check against the sentinel (PEP 8).
    if InferenceClass is None:
        test_import_inference_module_librarys()
    # Presumably loads the model; the instance is reused via the global.
    inference = InferenceClass()
|
||||
|
||||
|
||||
def test_tool_header_generation():
    """The generated tool-use header should cover both dummy tools."""
    prepare()

    dummy_tools = [helper.tool_dummy, helper.tool_dummy2]
    header = inference.generate_tool_use_header(dummy_tools)

    # Loose sanity check only: a real header describing two tools is
    # substantially longer than 100 characters.
    assert len(header) > 100
|
||||
|
||||
|
||||
def test_tokenize_dummy():
    """tokenize(..., tokenize=False) should return the rendered chat string
    containing every message verbatim."""
    prepare()

    system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences."
    user_message = "say 'Hello World!'"
    assistant_message = "Hello World!"

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message},
    ]
    history = inference.tokenize(messages, tokenize=False)

    assert isinstance(history, str)
    # `in` is the idiomatic substring test (was: `.find(...) != -1`).
    assert system_message in history
    assert user_message in history
    assert assistant_message in history
|
||||
|
||||
|
||||
def test_tokenize_tensor():
    """tokenize(..., tokenize=True) should return a token-id Tensor."""
    prepare()

    system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences."
    user_message = "say 'Hello World!'"
    assistant_message = "Hello World!"

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message},
    ]
    history = inference.tokenize(messages, tokenize=True)

    # isinstance is the idiomatic type check (was: `type(...) is Tensor`).
    assert isinstance(history, Tensor)
    # At least as many tokens as whitespace-separated words in the repr —
    # a loose lower bound that holds for subword tokenizers.
    assert len(history[0]) >= len(str(messages).split(" "))
|
||||
|
||||
|
||||
def test_inference():
    """End-to-end: tokenize a prompt, then generate incrementally."""
    prepare()

    system_message = "Pretend you are a Python console."
    user_message = "print('Hello World!')"

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    input_ids = inference.tokenize(messages, tokenize=True)

    # isinstance is the idiomatic type check (was: `type(...) is Tensor`).
    assert isinstance(input_ids, Tensor)
    assert len(input_ids[0]) >= len(str(messages).split(" "))

    generated_tokens, full_output = inference.generate_incremental(input_ids)

    assert isinstance(generated_tokens, Tensor)
    # The model should emit more than a trivial number of tokens.
    assert len(generated_tokens[0]) > 2

    assert isinstance(full_output, str)
    # Exact output is model-dependent, so no content assertion here.
    # assert full_output.find("Hello World!") >= 0
|
||||
|
||||
|
||||
def test_inference_2():
    """End-to-end: tokenize a prompt, then generate via the batch path."""
    prepare()

    system_message = "Pretend you are a Python console."
    user_message = "print('Hello World!')"

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    input_ids = inference.tokenize(messages, tokenize=True)

    # isinstance is the idiomatic type check (was: `type(...) is Tensor`).
    assert isinstance(input_ids, Tensor)
    assert len(input_ids[0]) >= len(str(messages).split(" "))

    generated_tokens, full_output = inference.generate_batch(input_ids)

    assert isinstance(generated_tokens, Tensor)
    # The model should emit more than a trivial number of tokens.
    assert len(generated_tokens[0]) > 2

    assert isinstance(full_output, str)
    # Exact output is model-dependent, so no content assertion here.
    # assert full_output.find("Hello World!") >= 0
|
@@ -5,6 +5,7 @@ import tests.helper as helper
|
||||
|
||||
|
||||
def test_tool_function_decorator():
|
||||
""" @tool """
|
||||
# get length before adding tools
|
||||
start_len = len(tool_helper.tool_list)
|
||||
|
||||
|
Reference in New Issue
Block a user