172 lines
6.8 KiB
Python
172 lines
6.8 KiB
Python
import json
|
||
from typing import Any, List, Literal
|
||
|
||
from pydantic import Field
|
||
|
||
from app.agent.react import ReActAgent
|
||
from app.logger import logger
|
||
from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
|
||
from app.schema import AgentState, Message, ToolCall
|
||
from app.tool import CreateChatCompletion, Terminate, ToolCollection
|
||
|
||
# Message of the ValueError raised by ToolCallAgent.act() when tool_choices is
# "required" but there are no tool calls to execute.
TOOL_CALL_REQUIRED = "Tool calls required but none provided"
|
||
|
||
|
||
class ToolCallAgent(ReActAgent):
    """Base agent class for handling tool/function calls with enhanced abstraction.

    ``think`` asks the LLM for the next step (optionally with tool calls) and
    records the reply in memory; ``act`` executes the recorded tool calls and
    stores each result as a tool message. Tools whose names appear in
    ``special_tool_names`` move the agent to ``AgentState.FINISHED`` once
    they have run.
    """

    name: str = "toolcall"
    description: str = "an agent that can execute tool calls."

    system_prompt: str = SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT

    # Tools offered to the LLM on every think() call.
    available_tools: ToolCollection = ToolCollection(
        CreateChatCompletion(), Terminate()
    )
    # Forwarded to the LLM as ``tool_choice`` and used to interpret its reply.
    tool_choices: Literal["none", "auto", "required"] = "auto"
    # Tool names (compared case-insensitively) that get special handling
    # after execution; see _handle_special_tool().
    special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])

    # Tool calls selected by the most recent think(); consumed by act().
    tool_calls: List[ToolCall] = Field(default_factory=list)

    max_steps: int = 30

    async def think(self) -> bool:
        """Process current state and decide next actions using tools.

        Appends ``next_step_prompt`` (when set) as a user message, queries the
        LLM with the available tools, stores the selected tool calls on
        ``self.tool_calls`` and records the assistant reply in memory.

        Returns:
            True when ``act`` should run next, False otherwise. False is also
            returned when processing the response raises; in that case an
            assistant message describing the error is added to memory.
        """
        if self.next_step_prompt:
            user_msg = Message.user_message(self.next_step_prompt)
            self.messages += [user_msg]

        # Get response with tool options
        response = await self.llm.ask_tool(
            messages=self.messages,
            system_msgs=(
                [Message.system_message(self.system_prompt)]
                if self.system_prompt
                else None
            ),
            tools=self.available_tools.to_params(),
            tool_choice=self.tool_choices,
        )
        # Normalise a missing tool-call list to [] so the attribute always
        # matches its declared List[ToolCall] type.
        self.tool_calls = response.tool_calls or []

        # Log response info in a more engaging way
        logger.info(f"✨ AI的思考过程:{response.content}")
        logger.info(f"🛠️ AI选择了 {len(self.tool_calls)} 个工具来解决问题")
        if self.tool_calls:
            logger.info(
                f"🧰 准备使用的工具箱:{[call.function.name for call in self.tool_calls]}"
            )

        try:
            # Handle different tool_choices modes
            if self.tool_choices == "none":
                if self.tool_calls:
                    logger.warning("🤔 咦?AI想用工具,但现在不允许使用任何工具呢!")
                    # Drop the disallowed calls; otherwise a subsequent act()
                    # would still execute them.
                    self.tool_calls = []
                if response.content:
                    self.memory.add_message(
                        Message.assistant_message(response.content)
                    )
                    return True
                return False

            # Create and add assistant message
            if self.tool_calls:
                assistant_msg = Message.from_tool_calls(
                    content=response.content, tool_calls=self.tool_calls
                )
            else:
                assistant_msg = Message.assistant_message(response.content)
            self.memory.add_message(assistant_msg)

            if self.tool_choices == "required" and not self.tool_calls:
                return True  # Will be handled in act()

            # For 'auto' mode, continue with content if no commands but content exists
            if self.tool_choices == "auto" and not self.tool_calls:
                return bool(response.content)

            return bool(self.tool_calls)
        except Exception as e:
            logger.error(f"🚨 糟糕!AI思考时遇到了一点小问题:{e}")
            self.memory.add_message(
                Message.assistant_message(
                    f"Error encountered while processing: {str(e)}"
                )
            )
            return False

    async def act(self) -> str:
        """Execute tool calls and handle their results.

        Returns:
            The observations of all executed tool calls joined by blank lines.
            When there are no tool calls, the content of the last message, or
            a fixed placeholder string if that is empty or no message exists.

        Raises:
            ValueError: If ``tool_choices`` is "required" and there are no
                tool calls to execute.
        """
        if not self.tool_calls:
            if self.tool_choices == "required":
                raise ValueError(TOOL_CALL_REQUIRED)

            # Return last message content if no tool calls; guard against an
            # empty history so the lookup cannot raise IndexError.
            last_content = self.messages[-1].content if self.messages else None
            return last_content or "No content or commands to execute"

        results = []
        for command in self.tool_calls:
            result = await self.execute_tool(command)
            logger.info(
                f"🎯 工具「{command.function.name}」完成任务啦!结果:{result[:100]}{'...' if len(result) > 100 else ''}"
            )

            # Add tool response to memory
            tool_msg = Message.tool_message(
                content=result, tool_call_id=command.id, name=command.function.name
            )
            self.memory.add_message(tool_msg)
            results.append(result)

        return "\n\n".join(results)

    async def execute_tool(self, command: ToolCall) -> str:
        """Execute a single tool call with robust error handling.

        Args:
            command: The tool call to run; its ``function.arguments`` is
                parsed as JSON (an empty/missing value is treated as ``{}``).

        Returns:
            An observation string describing the tool output, or a string
            starting with ``"Error: "`` when the command is malformed, the
            tool is unknown, the arguments are not valid JSON, or the tool
            raises.
        """
        if not command or not command.function or not command.function.name:
            return "Error: Invalid command format"

        name = command.function.name
        if name not in self.available_tools.tool_map:
            return f"Error: Unknown tool '{name}'"

        try:
            # Parse arguments
            args = json.loads(command.function.arguments or "{}")

            # Execute the tool
            logger.info(f"🔧 正在启动工具:「{name}」...")
            result = await self.available_tools.execute(name=name, tool_input=args)

            # Format result for display
            observation = (
                f"Observed output of cmd `{name}` executed:\n{str(result)}"
                if result
                else f"Cmd `{name}` completed with no output"
            )

            # Handle special tools like `finish`
            await self._handle_special_tool(name=name, result=result)

            return observation
        except json.JSONDecodeError:
            error_msg = f"Error parsing arguments for {name}: Invalid JSON format"
            logger.error(f"📝 哎呀!「{name}」工具的参数格式不对 - JSON解析失败")
            return f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"Error executing tool {name}: {str(e)}"
            logger.error(f"⚠️ 工具「{name}」执行时遇到了一点小麻烦:{str(e)}")
            return f"Error: {error_msg}"

    async def _handle_special_tool(self, name: str, result: Any, **kwargs):
        """Handle special tool execution and state changes.

        Does nothing for tools not listed in ``special_tool_names``. For a
        special tool, sets ``self.state`` to ``AgentState.FINISHED`` when
        ``_should_finish_execution`` returns True.
        """
        if not self._is_special_tool(name):
            return

        if self._should_finish_execution(name=name, result=result, **kwargs):
            # Set agent state to finished
            logger.info(f"🏁 特殊工具「{name}」已完成任务,AI助手的工作告一段落!")
            self.state = AgentState.FINISHED

    @staticmethod
    def _should_finish_execution(**kwargs) -> bool:
        """Determine if tool execution should finish the agent.

        This implementation ignores its keyword arguments and always returns
        True.
        """
        return True

    def _is_special_tool(self, name: str) -> bool:
        """Check if tool name is in special tools list (case-insensitive)."""
        lowered = name.lower()
        return any(lowered == special.lower() for special in self.special_tool_names)
|