You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
3.4 KiB
102 lines
3.4 KiB
|
|
from typing import Callable, List, Optional
|
|
import json
|
|
import re
|
|
import utils
|
|
|
|
# Registry of functions decorated with @tool, in registration order.
tool_list = []


def tool(fn):
    """Register *fn* as an available tool and return it unchanged.

    Intended for use as ``@tool``: the function is appended to ``tool_list``
    and a registration message is printed.

    Args:
        fn: The function to register; its ``__name__`` appears in the message.

    Returns:
        *fn* itself, so the decorated name keeps referring to the function.
    """
    print("register tool '%s'" % fn.__name__)
    tool_list.append(fn)
    # A decorator's return value replaces the decorated name; without this
    # return every @tool-decorated function would silently become None.
    return fn
|
|
|
|
# def parse_and_execute_tool_call(message: str, tools: list[function]) -> str | None:
|
|
# """execute tool call if needed according to the <tool_call> tag and return the content of the tool call or None if no call happened."""
|
|
# pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Regexes tried in order by parse_and_execute_tool_call; the first one whose
# capture group is non-empty supplies the JSON payload. The two tag-anchored
# patterns use a non-greedy group so that a message with several tool calls
# yields only the first payload, instead of one span glued across all of them
# (which is never valid JSON). The two catch-all fallbacks stay greedy because
# their closing tag is unspecified and the payload itself may contain "</".
_TOOL_CALL_PATTERNS = (
    # LLM responded with <tool_call></tool_call> the correct way
    r"<tool_call>(.*?)<\/tool_call>",
    # LLM closed <tool_call> with </tool_response> (or another </tool_...>) by accident
    r"<tool_call>(.*?)<\/tool_.*>",
    # LLM closed <tool_call> with an arbitrary </???> tag by accident
    r"<tool_call>(.*)<\/.*>",
    # LLM wrapped the call in <tool_response></???> by accident
    r"<tool_response>(.*)<\/.*>",
)


def parse_and_execute_tool_call(message: str, tools: List[Callable]) -> Optional[str]:
    """
    Execute a tool call if a tool-call tag is present and return the tool's response.

    The payload is taken from the first pattern in ``_TOOL_CALL_PATTERNS`` that
    captures non-empty text and must be JSON of the form
    {"name": ..., "arguments": {...}}. If no pattern yields a payload, None is
    returned without executing anything.

    Args:
        message (str): The LLM message that may contain a tool call.
        tools (List[Callable]): Tool functions available for execution.

    Returns:
        Optional[str]: The tool's response, or None if no tool call was found
        or it could not be executed.
    """
    for pattern in _TOOL_CALL_PATTERNS:
        extracted = _match_and_extract(message, pattern)
        # An empty capture (e.g. "<tool_call></tool_call>") falls through to
        # the next, more lenient pattern, as the original if-chain did.
        if extracted:
            return _execute_tool_call_str(extracted, tools)
    return None
|
|
|
|
|
|
def _match_and_extract(message: str, pattern: str) -> Optional[str]:
|
|
""" helper function to match regex and extract group 1 """
|
|
match = re.search(pattern, message, re.DOTALL)
|
|
if match:
|
|
group1 = match.group(1)
|
|
return group1
|
|
return None
|
|
|
|
|
|
def _execute_tool_call_str(tool_call: str, tools: List[Callable]) -> Optional[str]:
    """Parse *tool_call* as JSON and dispatch it to the matching tool.

    Args:
        tool_call: Raw text extracted from the tool-call tag; must be valid JSON.
        tools: Tool functions available for execution.

    Returns:
        The tool's result, or None if the JSON is malformed or the call
        could not be executed.
    """
    try:
        parsed = json.loads(tool_call)
    except json.JSONDecodeError:
        utils.print_error("Json was malformed. Will be ignored.")
        return None
    # Dispatch outside the try: a JSONDecodeError raised by the tool itself
    # must not be misreported as a malformed tool call from the LLM.
    return _execute_tool_call_json(parsed, tools)
|
|
|
|
def _execute_tool_call_json(data: object, tools: List[Callable]) -> Optional[str]:
    """Look up the tool named in *data* and invoke it with the given arguments.

    Args:
        data: Parsed JSON of the tool call, expected to be an object with a
            "name" key and an optional "arguments" key.
        tools: Candidate tool functions, matched against "name" by ``__name__``.

    Returns:
        The tool's result, or None if *data* is not a JSON object, no tool
        with that name exists, or the call failed.
    """
    # json.loads may return a list, string or number; only an object can carry
    # "name"/"arguments", and calling .get on anything else would crash.
    if not isinstance(data, dict):
        utils.print_error("Tool call must be a JSON object, got %s." % type(data).__name__)
        return None

    tool_name = data.get("name")
    arguments = data.get("arguments", {})
    # Treat an explicit "arguments": null the same as an omitted key.
    if arguments is None:
        arguments = {}

    # First tool whose function name matches; `candidate` avoids shadowing the
    # module-level `tool` decorator.
    selected = next(
        (candidate for candidate in tools if candidate.__name__ == tool_name),
        None,
    )
    if selected is None:
        utils.print_error("Specified tool '%s' not found." % tool_name)
        return None

    return _execute_tool_function(arguments, selected)
|
|
|
|
def _execute_tool_function(arguments: any, tool: Callable) -> Optional[str]:
|
|
""" Execute the tool and return the result. """
|
|
try:
|
|
result = tool(**arguments)
|
|
print("<tool_response>", result, "</tool_response>")
|
|
return result
|
|
except TypeError as e:
|
|
utils.print_error("Type error while executing function call: '%s'" % str(e))
|
|
|
|
return None
|
|
|