4 changed files with 136 additions and 169 deletions
@ -1,7 +1,19 @@ |
|||
|
|||
|
|||
def tool_dummy(a: int, b: str): |
|||
""" |
|||
tool_dummy |
|||
Args: |
|||
a: how much |
|||
b: how text? |
|||
""" |
|||
return "result_%d_%s" % (a, b) |
|||
|
|||
|
|||
def tool_dummy2(text: str): |
|||
""" |
|||
tool_dummy2 |
|||
Args: |
|||
text: only text? |
|||
""" |
|||
return text.upper() |
@ -0,0 +1,122 @@ |
|||
import pytest |
|||
import tests.helper as helper |
|||
|
|||
inference = None |
|||
InferenceClass = None |
|||
Tensor = None |
|||
|
|||
|
|||
def prepare(): |
|||
if InferenceClass == None: |
|||
test_import_inference_module_librarys() |
|||
if inference == None: |
|||
test_instantiate_inference_instance() |
|||
|
|||
|
|||
def test_import_inference_module_librarys(): |
|||
import inference |
|||
import torch |
|||
global InferenceClass |
|||
global Tensor |
|||
InferenceClass = inference.Inference |
|||
Tensor = torch.Tensor |
|||
|
|||
|
|||
def test_instantiate_inference_instance(): |
|||
if InferenceClass == None: |
|||
test_import_inference_module_librarys() |
|||
global inference |
|||
inference = InferenceClass() |
|||
|
|||
|
|||
def test_tool_header_generation(): |
|||
prepare() |
|||
tools = [helper.tool_dummy, helper.tool_dummy2] |
|||
header = inference.generate_tool_use_header(tools) |
|||
assert len(header) > 100 |
|||
|
|||
|
|||
def test_tokenize_dummy(): |
|||
prepare() |
|||
|
|||
system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences." |
|||
user_message = "say 'Hello World!'" |
|||
assistant_message = "Hello World!" |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": system_message}, |
|||
{"role": "user", "content": user_message}, |
|||
{"role": "assistant", "content": assistant_message} |
|||
] |
|||
history = inference.tokenize(messages, tokenize=False) |
|||
|
|||
assert type(history) is str |
|||
assert history.find(system_message) != -1 |
|||
assert history.find(user_message) != -1 |
|||
assert history.find(assistant_message) != -1 |
|||
|
|||
|
|||
def test_tokenize_tensor(): |
|||
prepare() |
|||
|
|||
system_message = "Hold a casual conversation with the user. Keep responses short at max 3 sentences." |
|||
user_message = "say 'Hello World!'" |
|||
assistant_message = "Hello World!" |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": system_message}, |
|||
{"role": "user", "content": user_message}, |
|||
{"role": "assistant", "content": assistant_message} |
|||
] |
|||
history = inference.tokenize(messages, tokenize=True) |
|||
|
|||
assert type(history) is Tensor |
|||
assert len(history[0]) >= len(str(messages).split(" ")) |
|||
|
|||
|
|||
def test_inference(): |
|||
prepare() |
|||
|
|||
system_message = "Pretend you are a Python console." |
|||
user_message = "print('Hello World!')" |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": system_message}, |
|||
{"role": "user", "content": user_message}, |
|||
] |
|||
input_ids = inference.tokenize(messages, tokenize=True) |
|||
|
|||
assert type(input_ids) is Tensor |
|||
assert len(input_ids[0]) >= len(str(messages).split(" ")) |
|||
|
|||
generated_tokens, full_output = inference.generate_incremental(input_ids) |
|||
|
|||
assert type(generated_tokens) is Tensor |
|||
assert len(generated_tokens[0]) > 2 |
|||
|
|||
assert type(full_output) is str |
|||
# assert full_output.find("Hello World!") >= 0 |
|||
|
|||
|
|||
def test_inference_2(): |
|||
prepare() |
|||
|
|||
system_message = "Pretend you are a Python console." |
|||
user_message = "print('Hello World!')" |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": system_message}, |
|||
{"role": "user", "content": user_message}, |
|||
] |
|||
input_ids = inference.tokenize(messages, tokenize=True) |
|||
|
|||
assert type(input_ids) is Tensor |
|||
assert len(input_ids[0]) >= len(str(messages).split(" ")) |
|||
|
|||
generated_tokens, full_output = inference.generate_batch(input_ids) |
|||
|
|||
assert type(generated_tokens) is Tensor |
|||
assert len(generated_tokens[0]) > 2 |
|||
|
|||
assert type(full_output) is str |
|||
# assert full_output.find("Hello World!") >= 0 |
Loading…
Reference in new issue