4 changed files with 136 additions and 169 deletions
@ -1,7 +1,19 @@ |
|||||
|
|
||||
|
|
||||
def tool_dummy(a: int, b: str): |
def tool_dummy(a: int, b: str): |
||||
|
""" |
||||
|
tool_dummy |
||||
|
Args: |
||||
|
a: how much |
||||
|
b: how text? |
||||
|
""" |
||||
return "result_%d_%s" % (a, b) |
return "result_%d_%s" % (a, b) |
||||
|
|
||||
|
|
||||
def tool_dummy2(text: str): |
def tool_dummy2(text: str): |
||||
|
""" |
||||
|
tool_dummy2 |
||||
|
Args: |
||||
|
text: only text? |
||||
|
""" |
||||
return text.upper() |
return text.upper() |
@ -0,0 +1,122 @@ |
|||||
|
import pytest |
||||
|
import tests.helper as helper |
||||
|
|
||||
|
inference = None |
||||
|
InferenceClass = None |
||||
|
Tensor = None |
||||
|
|
||||
|
|
||||
|
def prepare(): |
||||
|
if InferenceClass == None: |
||||
|
test_import_inference_module_librarys() |
||||
|
if inference == None: |
||||
|
test_instantiate_inference_instance() |
||||
|
|
||||
|
|
||||
|
def test_import_inference_module_librarys(): |
||||
|
import inference |
||||
|
import torch |
||||
|
global InferenceClass |
||||
|
global Tensor |
||||
|
InferenceClass = inference.Inference |
||||
|
Tensor = torch.Tensor |
||||
|
|
||||
|
|
||||
|
def test_instantiate_inference_instance(): |
||||
|
if InferenceClass == None: |
||||
|
test_import_inference_module_librarys() |
||||
|
global inference |
||||
|
inference = InferenceClass() |
||||
|
|
||||
|
|
||||
|
def test_tool_header_generation(): |
||||
|
prepare() |
||||
|
tools = [helper.tool_dummy, helper.tool_dummy2] |
||||
|
header = inference.generate_tool_use_header(tools) |
||||
|
assert len(header) > 100 |
||||
|
|
||||
|
|
||||
|
def test_tokenize_dummy(): |
||||
|
prepare() |
||||
|
|
||||
|
system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences." |
||||
|
user_message = "say 'Hello World!'" |
||||
|
assistant_message = "Hello World!" |
||||
|
|
||||
|
messages = [ |
||||
|
{"role": "system", "content": system_message}, |
||||
|
{"role": "user", "content": user_message}, |
||||
|
{"role": "assistant", "content": assistant_message} |
||||
|
] |
||||
|
history = inference.tokenize(messages, tokenize=False) |
||||
|
|
||||
|
assert type(history) is str |
||||
|
assert history.find(system_message) != -1 |
||||
|
assert history.find(user_message) != -1 |
||||
|
assert history.find(assistant_message) != -1 |
||||
|
|
||||
|
|
||||
|
def test_tokenize_tensor(): |
||||
|
prepare() |
||||
|
|
||||
|
system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences." |
||||
|
user_message = "say 'Hello World!'" |
||||
|
assistant_message = "Hello World!" |
||||
|
|
||||
|
messages = [ |
||||
|
{"role": "system", "content": system_message}, |
||||
|
{"role": "user", "content": user_message}, |
||||
|
{"role": "assistant", "content": assistant_message} |
||||
|
] |
||||
|
history = inference.tokenize(messages, tokenize=True) |
||||
|
|
||||
|
assert type(history) is Tensor |
||||
|
assert len(history[0]) >= len(str(messages).split(" ")) |
||||
|
|
||||
|
|
||||
|
def test_inference(): |
||||
|
prepare() |
||||
|
|
||||
|
system_message = "Pretend you are a Python console." |
||||
|
user_message = "print('Hello World!')" |
||||
|
|
||||
|
messages = [ |
||||
|
{"role": "system", "content": system_message}, |
||||
|
{"role": "user", "content": user_message}, |
||||
|
] |
||||
|
input_ids = inference.tokenize(messages, tokenize=True) |
||||
|
|
||||
|
assert type(input_ids) is Tensor |
||||
|
assert len(input_ids[0]) >= len(str(messages).split(" ")) |
||||
|
|
||||
|
generated_tokens, full_output = inference.generate_incremental(input_ids) |
||||
|
|
||||
|
assert type(generated_tokens) is Tensor |
||||
|
assert len(generated_tokens[0]) > 2 |
||||
|
|
||||
|
assert type(full_output) is str |
||||
|
# assert full_output.find("Hello World!") >= 0 |
||||
|
|
||||
|
|
||||
|
def test_inference_2(): |
||||
|
prepare() |
||||
|
|
||||
|
system_message = "Pretend you are a Python console." |
||||
|
user_message = "print('Hello World!')" |
||||
|
|
||||
|
messages = [ |
||||
|
{"role": "system", "content": system_message}, |
||||
|
{"role": "user", "content": user_message}, |
||||
|
] |
||||
|
input_ids = inference.tokenize(messages, tokenize=True) |
||||
|
|
||||
|
assert type(input_ids) is Tensor |
||||
|
assert len(input_ids[0]) >= len(str(messages).split(" ")) |
||||
|
|
||||
|
generated_tokens, full_output = inference.generate_batch(input_ids) |
||||
|
|
||||
|
assert type(generated_tokens) is Tensor |
||||
|
assert len(generated_tokens[0]) > 2 |
||||
|
|
||||
|
assert type(full_output) is str |
||||
|
# assert full_output.find("Hello World!") >= 0 |
Loading…
Reference in new issue