from typing import Any, Dict, List, Optional, Tuple

from pydantic import Field

from app.agent.toolcall import ToolCallAgent
from app.logger import logger
from app.prompt.mcp import MULTIMEDIA_RESPONSE_PROMPT, NEXT_STEP_PROMPT, SYSTEM_PROMPT
from app.schema import AgentState, Message
from app.tool.base import ToolResult
from app.tool.mcp import MCPClients


class MCPAgent(ToolCallAgent):
    """Agent for interacting with MCP (Model Context Protocol) servers.

    This agent connects to an MCP server using either SSE or stdio transport
    and makes the server's tools available through the agent's tool interface.

    Typical lifecycle: call ``initialize()`` to open the connection, then
    ``run()``; ``run()`` always calls ``cleanup()`` when it finishes.
    """

    name: str = "mcp_agent"
    description: str = "An agent that connects to an MCP server and uses its tools."

    system_prompt: str = SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT

    # MCP client wrapper; it also serves as the agent's tool collection once
    # initialize() has run.
    mcp_clients: MCPClients = Field(default_factory=MCPClients)
    available_tools: MCPClients = None  # Will be set in initialize()

    max_steps: int = 20
    connection_type: str = "stdio"  # "stdio" or "sse"

    # Last known tool name -> input schema mapping, used to detect changes
    # between refreshes.
    tool_schemas: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    _refresh_tools_interval: int = 5  # Refresh tools every N steps

    # Special tool names that should trigger termination
    special_tool_names: List[str] = Field(default_factory=lambda: ["terminate"])

    async def initialize(
        self,
        connection_type: Optional[str] = None,
        server_url: Optional[str] = None,
        command: Optional[str] = None,
        args: Optional[List[str]] = None,
    ) -> None:
        """Initialize the MCP connection.

        Args:
            connection_type: Type of connection to use ("stdio" or "sse").
                When omitted (or empty), the current ``self.connection_type``
                is used.
            server_url: URL of the MCP server (for SSE connection)
            command: Command to run (for stdio connection)
            args: Arguments for the command (for stdio connection)

        Raises:
            ValueError: If the connection type is unsupported, or the
                parameter required by the chosen transport is missing.
        """
        if connection_type:
            self.connection_type = connection_type

        # Connect to the MCP server based on connection type
        if self.connection_type == "sse":
            if not server_url:
                raise ValueError("Server URL is required for SSE connection")
            await self.mcp_clients.connect_sse(server_url=server_url)
        elif self.connection_type == "stdio":
            if not command:
                raise ValueError("Command is required for stdio connection")
            await self.mcp_clients.connect_stdio(command=command, args=args or [])
        else:
            raise ValueError(f"Unsupported connection type: {self.connection_type}")

        # Set available_tools to our MCP instance
        self.available_tools = self.mcp_clients

        # Store initial tool schemas.
        # NOTE(review): if tool_schemas is still empty at this point, every
        # server tool counts as "added", so _refresh_tools() writes a
        # "New tools available" message to memory *before* the system prompt
        # below. Confirm whether that ordering is intended.
        await self._refresh_tools()

        # Add system prompt and available tools information
        tool_names = list(self.mcp_clients.tool_map.keys())
        tools_info = ", ".join(tool_names)
        self.memory.add_message(
            Message.system_message(
                f"{self.system_prompt}\n\nAvailable MCP tools: {tools_info}"
            )
        )

    async def _refresh_tools(self) -> Tuple[List[str], List[str]]:
        """Refresh the list of available tools from the MCP server.

        Compares the server's current tool list with ``self.tool_schemas``,
        stores the new snapshot, logs any differences, and adds a system
        message to memory for added and removed tools. Schema changes to
        existing tools are only logged.

        Returns:
            A tuple of (added_tools, removed_tools), each sorted by name.
            Both are empty when there is no active session.
        """
        if not self.mcp_clients.session:
            return [], []

        # Get current tool schemas directly from the server
        response = await self.mcp_clients.session.list_tools()
        current_tools = {tool.name: tool.inputSchema for tool in response.tools}

        current_names = set(current_tools)
        previous_names = set(self.tool_schemas)

        # Sort the results: set iteration order for strings varies between
        # interpreter runs, and these lists end up in logs and in messages
        # sent to the model, so keep them deterministic.
        added_tools = sorted(current_names - previous_names)
        removed_tools = sorted(previous_names - current_names)

        # Tools present before and after whose input schema differs
        changed_tools = sorted(
            name
            for name in current_names & previous_names
            if current_tools[name] != self.tool_schemas.get(name)
        )

        # Update stored schemas
        self.tool_schemas = current_tools

        # Log and notify about changes
        if added_tools:
            logger.info(f"Added MCP tools: {added_tools}")
            self.memory.add_message(
                Message.system_message(f"New tools available: {', '.join(added_tools)}")
            )
        if removed_tools:
            logger.info(f"Removed MCP tools: {removed_tools}")
            self.memory.add_message(
                Message.system_message(
                    f"Tools no longer available: {', '.join(removed_tools)}"
                )
            )
        if changed_tools:
            logger.info(f"Changed MCP tools: {changed_tools}")

        return added_tools, removed_tools

    async def think(self) -> bool:
        """Process current state and decide next action.

        Returns:
            False (after marking the agent FINISHED) when the MCP session or
            its tools are gone; otherwise the result of the parent class's
            ``think()``.
        """
        # Check MCP session and tools availability
        if not self.mcp_clients.session or not self.mcp_clients.tool_map:
            logger.info("MCP service is no longer available, ending interaction")
            self.state = AgentState.FINISHED
            return False

        # Refresh tools periodically
        if self.current_step % self._refresh_tools_interval == 0:
            await self._refresh_tools()
            # All tools removed indicates shutdown.
            # NOTE(review): _refresh_tools() only rewrites self.tool_schemas;
            # this check relies on MCPClients keeping tool_map current by
            # itself - verify against MCPClients.
            if not self.mcp_clients.tool_map:
                logger.info("MCP service has shut down, ending interaction")
                self.state = AgentState.FINISHED
                return False

        # Use the parent class's think method
        return await super().think()

    async def _handle_special_tool(self, name: str, result: Any, **kwargs) -> None:
        """Handle special tool execution and state changes.

        Delegates to the parent handler first, then adds a multimedia prompt
        to memory when the result is a ``ToolResult`` carrying a base64 image.

        Args:
            name: Name of the tool that was executed.
            result: Value returned by the tool.
            **kwargs: Passed through unchanged to the parent handler.
        """
        # First process with parent handler
        await super()._handle_special_tool(name, result, **kwargs)

        # Handle multimedia responses
        if isinstance(result, ToolResult) and result.base64_image:
            self.memory.add_message(
                Message.system_message(
                    MULTIMEDIA_RESPONSE_PROMPT.format(tool_name=name)
                )
            )

    def _should_finish_execution(self, name: str, **kwargs) -> bool:
        """Determine if tool execution should finish the agent.

        Args:
            name: Name of the tool that was executed.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            True only when ``name`` is "terminate" (case-insensitive).
        """
        return name.lower() == "terminate"

    async def cleanup(self) -> None:
        """Clean up MCP connection when done.

        Does nothing when there is no active session.
        """
        if self.mcp_clients.session:
            await self.mcp_clients.disconnect()
            logger.info("MCP connection closed")

    async def run(self, request: Optional[str] = None) -> str:
        """Run the agent with cleanup when done.

        Args:
            request: Optional initial request forwarded to the parent ``run()``.

        Returns:
            The result string produced by the parent ``run()``.
        """
        try:
            result = await super().run(request)
            return result
        finally:
            # Ensure cleanup happens even if there's an error
            await self.cleanup()