cleanup, tests

This commit is contained in:
2025-01-02 02:49:57 +01:00
parent 4d034c7f2b
commit 71e5fa96f3
4 changed files with 136 additions and 169 deletions

View File

@@ -1,7 +1,19 @@
def tool_dummy(a: int, b: str):
"""
tool_dummy
Args:
a: how much
b: how text?
"""
return "result_%d_%s" % (a, b)
def tool_dummy2(text: str):
"""
tool_dummy2
Args:
text: only text?
"""
return text.upper()

122
tests/test_inference.py Normal file
View File

@@ -0,0 +1,122 @@
import pytest
import tests.helper as helper
inference = None
InferenceClass = None
Tensor = None
def prepare():
if InferenceClass == None:
test_import_inference_module_librarys()
if inference == None:
test_instantiate_inference_instance()
def test_import_inference_module_librarys():
import inference
import torch
global InferenceClass
global Tensor
InferenceClass = inference.Inference
Tensor = torch.Tensor
def test_instantiate_inference_instance():
if InferenceClass == None:
test_import_inference_module_librarys()
global inference
inference = InferenceClass()
def test_tool_header_generation():
prepare()
tools = [helper.tool_dummy, helper.tool_dummy2]
header = inference.generate_tool_use_header(tools)
assert len(header) > 100
def test_tokenize_dummy():
prepare()
system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences."
user_message = "say 'Hello World!'"
assistant_message = "Hello World!"
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
{"role": "assistant", "content": assistant_message}
]
history = inference.tokenize(messages, tokenize=False)
assert type(history) is str
assert history.find(system_message) != -1
assert history.find(user_message) != -1
assert history.find(assistant_message) != -1
def test_tokenize_tensor():
prepare()
system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences."
user_message = "say 'Hello World!'"
assistant_message = "Hello World!"
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
{"role": "assistant", "content": assistant_message}
]
history = inference.tokenize(messages, tokenize=True)
assert type(history) is Tensor
assert len(history[0]) >= len(str(messages).split(" "))
def test_inference():
prepare()
system_message = "Pretend you are a Python console."
user_message = "print('Hello World!')"
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
]
input_ids = inference.tokenize(messages, tokenize=True)
assert type(input_ids) is Tensor
assert len(input_ids[0]) >= len(str(messages).split(" "))
generated_tokens, full_output = inference.generate_incremental(input_ids)
assert type(generated_tokens) is Tensor
assert len(generated_tokens[0]) > 2
assert type(full_output) is str
# assert full_output.find("Hello World!") >= 0
def test_inference_2():
prepare()
system_message = "Pretend you are a Python console."
user_message = "print('Hello World!')"
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
]
input_ids = inference.tokenize(messages, tokenize=True)
assert type(input_ids) is Tensor
assert len(input_ids[0]) >= len(str(messages).split(" "))
generated_tokens, full_output = inference.generate_batch(input_ids)
assert type(generated_tokens) is Tensor
assert len(generated_tokens[0]) > 2
assert type(full_output) is str
# assert full_output.find("Hello World!") >= 0

View File

@@ -5,6 +5,7 @@ import tests.helper as helper
def test_tool_function_decorator():
""" @tool """
# get length before adding tools
start_len = len(tool_helper.tool_list)