from contextlib import AsyncExitStack
from typing import Any, Dict, List, Optional

from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import TextContent

from app.logger import logger
from app.tool.base import BaseTool, ToolResult
from app.tool.tool_collection import ToolCollection


class MCPServerTool(BaseTool):
    """Represents a tool available on the MCP server."""

    # Input schema as reported by the server; None until one is supplied.
    schema: Optional[Dict[str, Any]] = None
    # Session used to forward calls; None means "not connected".
    session: Optional[ClientSession] = None

    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool on the MCP server.

        Args:
            **kwargs: Arguments forwarded unchanged to ``session.call_tool``.

        Returns:
            A ``ToolResult`` whose ``output`` is the ``", "``-joined text of
            every ``TextContent`` item in the response (other content items
            are ignored), or ``"No output returned."`` when that text is
            empty. If there is no session, or the call raises, a
            ``ToolResult`` with ``error`` set is returned instead of raising.
        """
        if not self.session:
            return ToolResult(error="Not connected to MCP server")

        try:
            result = await self.session.call_tool(self.name, kwargs)
            content_str = ", ".join(
                item.text for item in result.content if isinstance(item, TextContent)
            )
            return ToolResult(output=content_str or "No output returned.")
        except Exception as e:
            # Failures are reported through the result object rather than
            # propagated to the caller.
            return ToolResult(error=f"Error executing tool: {str(e)}")


class MCP(ToolCollection):
    """An MCP tool collection that connects to an MCP server and executes commands."""

    session: Optional[ClientSession] = None
    exit_stack: Optional[AsyncExitStack] = None
    description: str = "MCP tools for server interaction"

    def __init__(self):
        """Create an unconnected collection with a fresh exit stack."""
        super().__init__()  # Initialize with empty tools list
        self.name = "mcp"  # Keep name for backward compatibility
        self.exit_stack = AsyncExitStack()

    async def connect_sse(self, server_url: str) -> None:
        """Connect to an MCP server using SSE transport.

        Any existing session is disconnected first.

        Args:
            server_url: URL passed to ``sse_client``.

        Raises:
            ValueError: If ``server_url`` is empty.
            Exception: Whatever the transport, session setup or tool listing
                raises; resources opened so far are released before the
                error propagates.
        """
        if not server_url:
            raise ValueError("Server URL is required.")
        if self.session:
            await self.disconnect()

        await self._open_session(sse_client(url=server_url))

    async def connect_stdio(self, command: str, args: List[str]) -> None:
        """Connect to an MCP server using stdio transport.

        Any existing session is disconnected first.

        Args:
            command: Command passed to ``StdioServerParameters``.
            args: Argument list passed to ``StdioServerParameters``.

        Raises:
            ValueError: If ``command`` is empty.
            Exception: Whatever the transport, session setup or tool listing
                raises; resources opened so far are released before the
                error propagates.
        """
        if not command:
            raise ValueError("Server command is required.")
        if self.session:
            await self.disconnect()

        server_params = StdioServerParameters(command=command, args=args)
        await self._open_session(stdio_client(server_params))

    async def _open_session(self, transport_context) -> None:
        """Enter the transport, open a ClientSession on it and load the tools.

        Both contexts are registered on ``self.exit_stack``. If any step
        fails, everything registered so far is closed and the session/tool
        state is cleared, so a failed connect does not leave a transport
        open that ``disconnect`` (which is gated on ``self.session``) would
        never close.
        """
        try:
            streams = await self.exit_stack.enter_async_context(transport_context)
            self.session = await self.exit_stack.enter_async_context(
                ClientSession(*streams)
            )
            await self._initialize_and_list_tools()
        except Exception:
            try:
                await self._teardown()
            except Exception as cleanup_error:
                # Keep the original connection error as the one callers see.
                logger.warning(
                    f"Error while cleaning up failed MCP connection: {cleanup_error}"
                )
            raise

    async def _initialize_and_list_tools(self) -> None:
        """Initialize session and populate tool map.

        Raises:
            RuntimeError: If no session has been set.
        """
        if not self.session:
            raise RuntimeError("Session not initialized.")

        await self.session.initialize()
        response = await self.session.list_tools()

        # Build the replacement map locally and publish it only once every
        # server tool has been wrapped, replacing any previous tools.
        tool_map = {
            tool.name: MCPServerTool(
                name=tool.name,
                description=tool.description,
                schema=tool.inputSchema,
                session=self.session,
            )
            for tool in response.tools
        }
        self.tool_map = tool_map
        self.tools = tuple(tool_map.values())

        logger.info(
            f"Connected to server with tools: {[tool.name for tool in response.tools]}"
        )

    async def _teardown(self) -> None:
        """Close the exit stack and clear session/tool state.

        The state is cleared in ``finally`` so the collection never keeps
        pointing at a session whose contexts failed while closing.
        """
        try:
            if self.exit_stack:
                await self.exit_stack.aclose()
        finally:
            self.session = None
            self.tools = tuple()
            self.tool_map = {}

    async def disconnect(self) -> None:
        """Disconnect from the MCP server and clean up resources.

        Does nothing when there is no active session. If closing the exit
        stack raises, the session and tools are still cleared and the
        exception propagates.
        """
        if self.session and self.exit_stack:
            await self._teardown()
            logger.info("Disconnected from MCP server")