186 lines
7.0 KiB
Python
186 lines
7.0 KiB
Python
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from pydantic import Field
|
|
|
|
from app.agent.toolcall import ToolCallAgent
|
|
from app.logger import logger
|
|
from app.prompt.mcp import MULTIMEDIA_RESPONSE_PROMPT, NEXT_STEP_PROMPT, SYSTEM_PROMPT
|
|
from app.schema import AgentState, Message
|
|
from app.tool.base import ToolResult
|
|
from app.tool.mcp import MCPClients
|
|
|
|
|
|
class MCPAgent(ToolCallAgent):
    """Agent for interacting with MCP (Model Context Protocol) servers.

    This agent connects to an MCP server using either SSE or stdio transport
    and makes the server's tools available through the agent's tool interface.

    ``initialize()`` must be awaited before ``run()``: it opens the connection,
    points ``available_tools`` at the MCP client collection and records the
    initial tool schemas. ``run()`` always calls ``cleanup()`` when it exits.
    """

    name: str = "mcp_agent"
    description: str = "An agent that connects to an MCP server and uses its tools."

    system_prompt: str = SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT

    # Initialize MCP tool collection
    mcp_clients: MCPClients = Field(default_factory=MCPClients)
    available_tools: MCPClients = None  # Will be set in initialize()

    max_steps: int = 20
    connection_type: str = "stdio"  # "stdio" or "sse"

    # Track tool schemas to detect changes (tool name -> input schema)
    tool_schemas: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    # Refresh tools every N steps; a value <= 0 disables the periodic refresh
    _refresh_tools_interval: int = 5

    # Special tool names that should trigger termination
    special_tool_names: List[str] = Field(default_factory=lambda: ["terminate"])

    async def initialize(
        self,
        connection_type: Optional[str] = None,
        server_url: Optional[str] = None,
        command: Optional[str] = None,
        args: Optional[List[str]] = None,
    ) -> None:
        """Initialize the MCP connection.

        Args:
            connection_type: Type of connection to use ("stdio" or "sse").
                When omitted, the current ``self.connection_type`` is used.
            server_url: URL of the MCP server (for SSE connection)
            command: Command to run (for stdio connection)
            args: Arguments for the command (for stdio connection)

        Raises:
            ValueError: If ``server_url`` is missing for an SSE connection,
                ``command`` is missing for a stdio connection, or the
                connection type is neither "sse" nor "stdio".
        """
        if connection_type:
            self.connection_type = connection_type

        # Connect to the MCP server based on connection type
        if self.connection_type == "sse":
            if not server_url:
                raise ValueError("Server URL is required for SSE connection")
            await self.mcp_clients.connect_sse(server_url=server_url)
        elif self.connection_type == "stdio":
            if not command:
                raise ValueError("Command is required for stdio connection")
            await self.mcp_clients.connect_stdio(command=command, args=args or [])
        else:
            raise ValueError(f"Unsupported connection type: {self.connection_type}")

        # Set available_tools to our MCP instance
        self.available_tools = self.mcp_clients

        # Store initial tool schemas. While ``tool_schemas`` is still empty,
        # every tool reported by the server counts as newly added, so this
        # call also records a "New tools available" message in memory.
        await self._refresh_tools()

        # Add system prompt and available tools information
        tool_names = list(self.mcp_clients.tool_map.keys())
        tools_info = ", ".join(tool_names)
        self.memory.add_message(
            Message.system_message(
                f"{self.system_prompt}\n\nAvailable MCP tools: {tools_info}"
            )
        )

    async def _refresh_tools(self) -> Tuple[List[str], List[str]]:
        """Refresh the list of available tools from the MCP server.

        Compares the server's current tool list with ``self.tool_schemas``,
        replaces ``self.tool_schemas`` with the new snapshot, logs any
        differences and adds a system message to memory for added and for
        removed tools. Schema changes of existing tools are only logged.

        Returns:
            A tuple of (added_tools, removed_tools), each sorted by name so
            that the messages and the return value are deterministic. Both
            lists are empty when there is no active session.
        """
        if not self.mcp_clients.session:
            return [], []

        # Get current tool schemas directly from the server
        response = await self.mcp_clients.session.list_tools()
        current_tools = {tool.name: tool.inputSchema for tool in response.tools}

        # Determine added, removed, and changed tools
        current_names = set(current_tools)
        previous_names = set(self.tool_schemas)

        # Set differences have no defined order; sort for stable output.
        added_tools = sorted(current_names - previous_names)
        removed_tools = sorted(previous_names - current_names)

        # Check for schema changes in tools present before and after
        changed_tools = sorted(
            name
            for name in current_names & previous_names
            if current_tools[name] != self.tool_schemas[name]
        )

        # Update stored schemas
        self.tool_schemas = current_tools

        # Log and notify about changes
        if added_tools:
            logger.info(f"Added MCP tools: {added_tools}")
            self.memory.add_message(
                Message.system_message(f"New tools available: {', '.join(added_tools)}")
            )
        if removed_tools:
            logger.info(f"Removed MCP tools: {removed_tools}")
            self.memory.add_message(
                Message.system_message(
                    f"Tools no longer available: {', '.join(removed_tools)}"
                )
            )
        if changed_tools:
            logger.info(f"Changed MCP tools: {changed_tools}")

        return added_tools, removed_tools

    async def think(self) -> bool:
        """Process current state and decide next action.

        Returns:
            False (after setting the state to ``AgentState.FINISHED``) when
            the MCP session or its tools are gone; otherwise whatever the
            parent class's ``think()`` returns.
        """
        # Check MCP session and tools availability
        if not self.mcp_clients.session or not self.mcp_clients.tool_map:
            logger.info("MCP service is no longer available, ending interaction")
            self.state = AgentState.FINISHED
            return False

        # Refresh tools periodically. The positivity guard avoids a
        # ZeroDivisionError when the interval is set to 0.
        interval = self._refresh_tools_interval
        if interval > 0 and self.current_step % interval == 0:
            await self._refresh_tools()
            # All tools removed indicates shutdown. _refresh_tools() only
            # updates self.tool_schemas, so that snapshot must be checked too;
            # tool_map alone would not reflect what the refresh just fetched.
            if not self.mcp_clients.tool_map or not self.tool_schemas:
                logger.info("MCP service has shut down, ending interaction")
                self.state = AgentState.FINISHED
                return False

        # Use the parent class's think method
        return await super().think()

    async def _handle_special_tool(self, name: str, result: Any, **kwargs) -> None:
        """Handle special tool execution and state changes.

        Args:
            name: Name of the tool that was executed.
            result: Value the tool returned.
            **kwargs: Forwarded unchanged to the parent handler.
        """
        # First process with parent handler
        await super()._handle_special_tool(name, result, **kwargs)

        # Handle multimedia responses: when the result carries a base64 image,
        # add the multimedia prompt (formatted with the tool name) to memory.
        if isinstance(result, ToolResult) and result.base64_image:
            self.memory.add_message(
                Message.system_message(
                    MULTIMEDIA_RESPONSE_PROMPT.format(tool_name=name)
                )
            )

    def _should_finish_execution(self, name: str, **kwargs) -> bool:
        """Determine if tool execution should finish the agent.

        Args:
            name: Name of the tool that was executed.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            True only when the tool name is 'terminate' (case-insensitive).
        """
        return name.lower() == "terminate"

    async def cleanup(self) -> None:
        """Clean up MCP connection when done.

        Does nothing when there is no active session.
        """
        if self.mcp_clients.session:
            await self.mcp_clients.disconnect()
            logger.info("MCP connection closed")

    async def run(self, request: Optional[str] = None) -> str:
        """Run the agent with cleanup when done.

        Args:
            request: Optional request passed through to the parent ``run()``.

        Returns:
            The result of the parent class's ``run()``.
        """
        try:
            return await super().run(request)
        finally:
            # Ensure cleanup happens even if there's an error
            await self.cleanup()
|