You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

102 lines
3.4 KiB

from typing import Callable, List, Optional
import json
import re
import utils
# Registry of every function registered through the ``tool`` decorator,
# in registration order.
tool_list = []


def tool(fn):
    """Register ``fn`` as a tool and return it unchanged.

    Intended to be used as a decorator. The function is appended to the
    module-level ``tool_list`` and a registration message is printed.

    Args:
        fn: The function to register; its ``__name__`` is used in the message.

    Returns:
        The same ``fn`` object, so the decorated name stays bound to the
        function.
    """
    print("register tool '%s'" % fn.__name__)
    tool_list.append(fn)
    # A decorator's return value replaces the decorated name; without this
    # return, ``@tool`` would rebind every decorated function to None.
    return fn
# def parse_and_execute_tool_call(message: str, tools: list[function]) -> str | None:
# """execute tool call if needed according to the <tool_call> tag and return the content of the tool call or None if no call happened."""
# pass
def parse_and_execute_tool_call(message: str, tools: List[Callable]) -> Optional[str]:
    """
    Execute a tool call if a tool-call tag is present and return the tool's response.

    The message is searched with a series of increasingly lenient patterns so
    that slightly malformed tags produced by the LLM are still recognised.
    The first pattern that yields a non-empty payload wins; that payload is
    handed to ``_execute_tool_call_str`` and its result is returned as-is.

    Args:
        message (str): The message that may contain a tool call.
        tools (List[Callable]): Tool functions available for execution.

    Returns:
        Optional[str]: The result of executing the tool call, or None if no
        pattern produced a non-empty payload (or the execution returned None).
    """
    patterns = (
        # The LLM responds with <tool_call>...</tool_call>, the correct way.
        # Non-greedy: when the message holds several tool calls, capture only
        # the first one instead of everything up to the last closing tag
        # (which would not be valid JSON and would execute nothing).
        r"<tool_call>(.*?)<\/tool_call>",
        # The LLM closes with </tool_response> (or any </tool_...>) by accident.
        r"<tool_call>(.*?)<\/tool_.*>",
        # The LLM closes with an arbitrary </...> tag by accident. Kept greedy:
        # the payload itself may contain "</" (e.g. markup inside a string
        # argument), so anchor on the last closing tag.
        r"<tool_call>(.*)<\/.*>",
        # The LLM opens with <tool_response> instead of <tool_call> by accident.
        r"<tool_response>(.*)<\/.*>",
    )
    for pattern in patterns:
        extracted = _match_and_extract(message, pattern)
        # An empty payload counts as "no match" and falls through to the
        # next, more lenient pattern.
        if extracted:
            return _execute_tool_call_str(extracted, tools)
    return None
def _match_and_extract(message: str, pattern: str) -> Optional[str]:
    """Search ``message`` for ``pattern`` and return capture group 1.

    The search uses ``re.DOTALL``, so ``.`` in the pattern also matches
    newlines. Returns None when the pattern does not match anywhere.
    """
    found = re.search(pattern, message, re.DOTALL)
    return found.group(1) if found else None
def _execute_tool_call_str(tool_call: str, tools: List[Callable]) -> Optional[str]:
    """Parse ``tool_call`` as JSON and execute the tool call it describes.

    Args:
        tool_call (str): The raw payload extracted from the message; must be
            valid JSON.
        tools (List[Callable]): Tool functions available for execution.

    Returns:
        Optional[str]: The result of ``_execute_tool_call_json`` for the
        parsed payload, or None if the payload is not valid JSON (an error
        is reported via ``utils.print_error`` in that case).
    """
    # Keep the try body limited to parsing: a JSONDecodeError raised by the
    # tool itself must not be misreported as a malformed tool call.
    try:
        js = json.loads(tool_call)
    except json.JSONDecodeError:
        utils.print_error("Json was malformed. Will be ignored.")
        return None
    return _execute_tool_call_json(js, tools)
def _execute_tool_call_json(data: object, tools: List[Callable]) -> Optional[str]:
    """Look up the tool named in ``data`` and call it with the given arguments.

    Args:
        data: The parsed tool call. Expected to be a dict with a "name" key
            and an optional "arguments" key (defaults to an empty dict).
        tools (List[Callable]): Tool functions available for execution; a
            tool is selected by comparing its ``__name__`` to the name.

    Returns:
        Optional[str]: The result of ``_execute_tool_function`` for the
        first matching tool, or None if ``data`` is not a dict or no tool
        has the requested name (an error is reported in both cases).
    """
    # Parsed JSON may be a list, string, number, ... - only an object can
    # carry "name"/"arguments". Report it instead of raising AttributeError.
    if not isinstance(data, dict):
        utils.print_error("Tool call must be a JSON object. Will be ignored.")
        return None

    # Extract tool name and arguments
    tool_name = data.get("name")
    arguments = data.get("arguments", {})

    # Find the tool by name in the list of tools. The loop variable is not
    # called ``tool`` to avoid shadowing the module-level decorator.
    for candidate in tools:
        if candidate.__name__ == tool_name:
            # Execute the tool
            return _execute_tool_function(arguments, candidate)

    utils.print_error("Specified tool '%s' not found." % tool_name)
    return None
def _execute_tool_function(arguments: any, tool: Callable) -> Optional[str]:
    """Call ``tool`` with ``arguments`` unpacked as keyword arguments.

    On success the result is printed between <tool_response> markers and
    returned. If a TypeError is raised, it is reported through
    ``utils.print_error`` and None is returned instead.
    """
    try:
        outcome = tool(**arguments)
        print("<tool_response>", outcome, "</tool_response>")
    except TypeError as err:
        utils.print_error("Type error while executing function call: '%s'" % str(err))
        outcome = None
    return outcome