tool functions and pytest
This commit is contained in:
102
tool_helper.py
Normal file
102
tool_helper.py
Normal file
@@ -0,0 +1,102 @@
|
||||
|
||||
from typing import Callable, List, Optional
|
||||
import json
|
||||
import re
|
||||
import utils
|
||||
|
||||
tool_list = []
|
||||
|
||||
|
||||
def tool(fn):
|
||||
"""tool function decorator"""
|
||||
print("register tool '%s'" % fn.__name__)
|
||||
tool_list.append(fn)
|
||||
|
||||
# def parse_and_execute_tool_call(message: str, tools: list[function]) -> str | None:
|
||||
# """execute tool call if needed accordint <tool_call> tag and return the content of the tool call or None if no call happened."""
|
||||
# pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def parse_and_execute_tool_call(message: str, tools: List[Callable]) -> Optional[str]:
|
||||
"""
|
||||
Execute a tool call if the <tool_call> tag is present and return the tool's response.
|
||||
If no <tool_call> tag is found, return None.
|
||||
|
||||
Args:
|
||||
message (str): The message containing the tool call.
|
||||
tools (list[function]): A list of tool functions available for execution.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The content of the tool response or None if no tool call occurred.
|
||||
"""
|
||||
|
||||
# in case LLM responds with <tool_call></tool_call> the correct way
|
||||
extracted = _match_and_extract(message, r"<tool_call>(.*)<\/tool_call>")
|
||||
if extracted:
|
||||
return _execute_tool_call_str(extracted, tools)
|
||||
|
||||
# in case LLM responds with <tool_call></tool_response> by accident
|
||||
extracted = _match_and_extract(message, r"<tool_call>(.*)<\/tool_.*>")
|
||||
if extracted:
|
||||
return _execute_tool_call_str(extracted, tools)
|
||||
|
||||
# in case LLM responds with <tool_call></???> by accident
|
||||
extracted = _match_and_extract(message, r"<tool_call>(.*)<\/.*>")
|
||||
if extracted:
|
||||
return _execute_tool_call_str(extracted, tools)
|
||||
|
||||
# in case LLM responds with <tool_call></???> by accident
|
||||
extracted = _match_and_extract(message, r"<tool_response>(.*)<\/.*>")
|
||||
if extracted:
|
||||
return _execute_tool_call_str(extracted, tools)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _match_and_extract(message: str, pattern: str) -> Optional[str]:
|
||||
""" helper function to match regex and extract group 1 """
|
||||
match = re.search(pattern, message, re.DOTALL)
|
||||
if match:
|
||||
group1 = match.group(1)
|
||||
return group1
|
||||
return None
|
||||
|
||||
|
||||
def _execute_tool_call_str(tool_call: str, tools: List[Callable]) -> Optional[str]:
|
||||
""" execute tool call per string content. The content must be a valid json """
|
||||
try:
|
||||
js = json.loads(tool_call)
|
||||
return _execute_tool_call_json(js, tools)
|
||||
except json.JSONDecodeError:
|
||||
utils.print_error("Json was malformed. Will be ignored.")
|
||||
return None
|
||||
|
||||
def _execute_tool_call_json(data: any, tools: List[Callable]) -> Optional[str]:
|
||||
""" extract name and arguments from parsed data and call the tool, which is matched from the tools list """
|
||||
# Extract tool name and arguments
|
||||
tool_name = data.get("name")
|
||||
arguments = data.get("arguments", {})
|
||||
|
||||
# Find the tool by name in the list of tools
|
||||
for tool in tools:
|
||||
if tool.__name__ == tool_name:
|
||||
# Execute the tool
|
||||
return _execute_tool_function(arguments, tool)
|
||||
|
||||
utils.print_error("Specified tool '%s' not found." % tool_name)
|
||||
return None
|
||||
|
||||
def _execute_tool_function(arguments: any, tool: Callable) -> Optional[str]:
|
||||
""" Execute the tool and return the result. """
|
||||
try:
|
||||
result = tool(**arguments)
|
||||
print("<tool_response>", result, "</tool_response>")
|
||||
return result
|
||||
except TypeError as e:
|
||||
utils.print_error("Type error while executing function call: '%s'" % str(e))
|
||||
|
||||
return None
|
Reference in New Issue
Block a user