add mcp
This commit is contained in:
parent
b7dcbfecb3
commit
14fa48e8d7
@ -1,5 +1,6 @@
|
|||||||
from app.agent.base import BaseAgent
|
from app.agent.base import BaseAgent
|
||||||
from app.agent.browser import BrowserAgent
|
from app.agent.browser import BrowserAgent
|
||||||
|
from app.agent.mcp import MCPAgent
|
||||||
from app.agent.planning import PlanningAgent
|
from app.agent.planning import PlanningAgent
|
||||||
from app.agent.react import ReActAgent
|
from app.agent.react import ReActAgent
|
||||||
from app.agent.swe import SWEAgent
|
from app.agent.swe import SWEAgent
|
||||||
@ -13,4 +14,5 @@ __all__ = [
|
|||||||
"ReActAgent",
|
"ReActAgent",
|
||||||
"SWEAgent",
|
"SWEAgent",
|
||||||
"ToolCallAgent",
|
"ToolCallAgent",
|
||||||
|
"MCPAgent",
|
||||||
]
|
]
|
||||||
|
185
app/agent/mcp.py
Normal file
185
app/agent/mcp.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.agent.toolcall import ToolCallAgent
|
||||||
|
from app.logger import logger
|
||||||
|
from app.prompt.mcp import MULTIMEDIA_RESPONSE_PROMPT, NEXT_STEP_PROMPT, SYSTEM_PROMPT
|
||||||
|
from app.schema import AgentState, Message
|
||||||
|
from app.tool.base import ToolResult
|
||||||
|
from app.tool.mcp import MCP
|
||||||
|
|
||||||
|
|
||||||
|
class MCPAgent(ToolCallAgent):
|
||||||
|
"""Agent for interacting with MCP (Model Context Protocol) servers.
|
||||||
|
|
||||||
|
This agent connects to an MCP server using either SSE or stdio transport
|
||||||
|
and makes the server's tools available through the agent's tool interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "mcp_agent"
|
||||||
|
description: str = "An agent that connects to an MCP server and uses its tools."
|
||||||
|
|
||||||
|
system_prompt: str = SYSTEM_PROMPT
|
||||||
|
next_step_prompt: str = NEXT_STEP_PROMPT
|
||||||
|
|
||||||
|
# Initialize MCP tool collection
|
||||||
|
mcp_tools: MCP = Field(default_factory=MCP)
|
||||||
|
available_tools: MCP = None # Will be set in initialize()
|
||||||
|
|
||||||
|
max_steps: int = 20
|
||||||
|
connection_type: str = "stdio" # "stdio" or "sse"
|
||||||
|
|
||||||
|
# Track tool schemas to detect changes
|
||||||
|
tool_schemas: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
|
||||||
|
_refresh_tools_interval: int = 5 # Refresh tools every N steps
|
||||||
|
|
||||||
|
# Special tool names that should trigger termination
|
||||||
|
special_tool_names: List[str] = Field(default_factory=lambda: ["terminate"])
|
||||||
|
|
||||||
|
async def initialize(
|
||||||
|
self,
|
||||||
|
connection_type: Optional[str] = None,
|
||||||
|
server_url: Optional[str] = None,
|
||||||
|
command: Optional[str] = None,
|
||||||
|
args: Optional[List[str]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the MCP connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_type: Type of connection to use ("stdio" or "sse")
|
||||||
|
server_url: URL of the MCP server (for SSE connection)
|
||||||
|
command: Command to run (for stdio connection)
|
||||||
|
args: Arguments for the command (for stdio connection)
|
||||||
|
"""
|
||||||
|
if connection_type:
|
||||||
|
self.connection_type = connection_type
|
||||||
|
|
||||||
|
# Connect to the MCP server based on connection type
|
||||||
|
if self.connection_type == "sse":
|
||||||
|
if not server_url:
|
||||||
|
raise ValueError("Server URL is required for SSE connection")
|
||||||
|
await self.mcp_tools.connect_sse(server_url=server_url)
|
||||||
|
elif self.connection_type == "stdio":
|
||||||
|
if not command:
|
||||||
|
raise ValueError("Command is required for stdio connection")
|
||||||
|
await self.mcp_tools.connect_stdio(command=command, args=args or [])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported connection type: {self.connection_type}")
|
||||||
|
|
||||||
|
# Set available_tools to our MCP instance
|
||||||
|
self.available_tools = self.mcp_tools
|
||||||
|
|
||||||
|
# Store initial tool schemas
|
||||||
|
await self._refresh_tools()
|
||||||
|
|
||||||
|
# Add system message about available tools
|
||||||
|
tool_names = list(self.mcp_tools.tool_map.keys())
|
||||||
|
tools_info = ", ".join(tool_names)
|
||||||
|
|
||||||
|
# Add system prompt and available tools information
|
||||||
|
self.memory.add_message(
|
||||||
|
Message.system_message(
|
||||||
|
f"{self.system_prompt}\n\nAvailable MCP tools: {tools_info}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _refresh_tools(self) -> Tuple[List[str], List[str]]:
|
||||||
|
"""Refresh the list of available tools from the MCP server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (added_tools, removed_tools)
|
||||||
|
"""
|
||||||
|
if not self.mcp_tools.session:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# Get current tool schemas directly from the server
|
||||||
|
response = await self.mcp_tools.session.list_tools()
|
||||||
|
current_tools = {tool.name: tool.inputSchema for tool in response.tools}
|
||||||
|
|
||||||
|
# Determine added, removed, and changed tools
|
||||||
|
current_names = set(current_tools.keys())
|
||||||
|
previous_names = set(self.tool_schemas.keys())
|
||||||
|
|
||||||
|
added_tools = list(current_names - previous_names)
|
||||||
|
removed_tools = list(previous_names - current_names)
|
||||||
|
|
||||||
|
# Check for schema changes in existing tools
|
||||||
|
changed_tools = []
|
||||||
|
for name in current_names.intersection(previous_names):
|
||||||
|
if current_tools[name] != self.tool_schemas.get(name):
|
||||||
|
changed_tools.append(name)
|
||||||
|
|
||||||
|
# Update stored schemas
|
||||||
|
self.tool_schemas = current_tools
|
||||||
|
|
||||||
|
# Log and notify about changes
|
||||||
|
if added_tools:
|
||||||
|
logger.info(f"Added MCP tools: {added_tools}")
|
||||||
|
self.memory.add_message(
|
||||||
|
Message.system_message(f"New tools available: {', '.join(added_tools)}")
|
||||||
|
)
|
||||||
|
if removed_tools:
|
||||||
|
logger.info(f"Removed MCP tools: {removed_tools}")
|
||||||
|
self.memory.add_message(
|
||||||
|
Message.system_message(
|
||||||
|
f"Tools no longer available: {', '.join(removed_tools)}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if changed_tools:
|
||||||
|
logger.info(f"Changed MCP tools: {changed_tools}")
|
||||||
|
|
||||||
|
return added_tools, removed_tools
|
||||||
|
|
||||||
|
async def think(self) -> bool:
|
||||||
|
"""Process current state and decide next action."""
|
||||||
|
# Check MCP session and tools availability
|
||||||
|
if not self.mcp_tools.session or not self.mcp_tools.tool_map:
|
||||||
|
logger.info("MCP service is no longer available, ending interaction")
|
||||||
|
self.state = AgentState.FINISHED
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Refresh tools periodically
|
||||||
|
if self.current_step % self._refresh_tools_interval == 0:
|
||||||
|
added, removed = await self._refresh_tools()
|
||||||
|
# All tools removed indicates shutdown
|
||||||
|
if not self.mcp_tools.tool_map:
|
||||||
|
logger.info("MCP service has shut down, ending interaction")
|
||||||
|
self.state = AgentState.FINISHED
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Use the parent class's think method
|
||||||
|
return await super().think()
|
||||||
|
|
||||||
|
async def _handle_special_tool(self, name: str, result: Any, **kwargs) -> None:
|
||||||
|
"""Handle special tool execution and state changes"""
|
||||||
|
# First process with parent handler
|
||||||
|
await super()._handle_special_tool(name, result, **kwargs)
|
||||||
|
|
||||||
|
# Handle multimedia responses
|
||||||
|
if isinstance(result, ToolResult) and result.base64_image:
|
||||||
|
self.memory.add_message(
|
||||||
|
Message.system_message(
|
||||||
|
MULTIMEDIA_RESPONSE_PROMPT.format(tool_name=name)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _should_finish_execution(self, name: str, **kwargs) -> bool:
|
||||||
|
"""Determine if tool execution should finish the agent"""
|
||||||
|
# Terminate if the tool name is 'terminate'
|
||||||
|
return name.lower() == "terminate"
|
||||||
|
|
||||||
|
async def cleanup(self) -> None:
|
||||||
|
"""Clean up MCP connection when done."""
|
||||||
|
if self.mcp_tools.session:
|
||||||
|
await self.mcp_tools.disconnect()
|
||||||
|
logger.info("MCP connection closed")
|
||||||
|
|
||||||
|
async def run(self, request: Optional[str] = None) -> str:
|
||||||
|
"""Run the agent with cleanup when done."""
|
||||||
|
try:
|
||||||
|
result = await super().run(request)
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
# Ensure cleanup happens even if there's an error
|
||||||
|
await self.cleanup()
|
@ -227,5 +227,10 @@ class Config:
|
|||||||
"""Get the workspace root directory"""
|
"""Get the workspace root directory"""
|
||||||
return WORKSPACE_ROOT
|
return WORKSPACE_ROOT
|
||||||
|
|
||||||
|
@property
|
||||||
|
def root_path(self) -> Path:
|
||||||
|
"""Get the root path of the application"""
|
||||||
|
return PROJECT_ROOT
|
||||||
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
|
0
app/mcp/__init__.py
Normal file
0
app/mcp/__init__.py
Normal file
196
app/mcp/server.py
Normal file
196
app/mcp/server.py
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import atexit
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from inspect import Parameter, Signature
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
|
||||||
|
# Add directories to Python path (needed for proper importing)
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
parent_dir = os.path.dirname(current_dir)
|
||||||
|
root_dir = os.path.dirname(parent_dir)
|
||||||
|
sys.path.insert(0, parent_dir)
|
||||||
|
sys.path.insert(0, current_dir)
|
||||||
|
sys.path.insert(0, root_dir)
|
||||||
|
|
||||||
|
# Configure logging (using the same format as original)
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("mcp-server")
|
||||||
|
|
||||||
|
from app.tool.base import BaseTool
|
||||||
|
from app.tool.bash import Bash
|
||||||
|
from app.tool.browser_use_tool import BrowserUseTool
|
||||||
|
from app.tool.str_replace_editor import StrReplaceEditor
|
||||||
|
from app.tool.terminate import Terminate
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServer:
|
||||||
|
"""MCP Server implementation with tool registration and management."""
|
||||||
|
|
||||||
|
def __init__(self, name: str = "openmanus"):
|
||||||
|
self.server = FastMCP(name)
|
||||||
|
self.tools: Dict[str, BaseTool] = {}
|
||||||
|
|
||||||
|
# Initialize standard tools
|
||||||
|
self.tools["bash"] = Bash()
|
||||||
|
self.tools["browser"] = BrowserUseTool()
|
||||||
|
self.tools["editor"] = StrReplaceEditor()
|
||||||
|
self.tools["terminate"] = Terminate()
|
||||||
|
|
||||||
|
from app.logger import logger as app_logger
|
||||||
|
|
||||||
|
global logger
|
||||||
|
logger = app_logger
|
||||||
|
|
||||||
|
def register_tool(self, tool: BaseTool, method_name: Optional[str] = None) -> None:
|
||||||
|
"""Register a tool with parameter validation and documentation."""
|
||||||
|
tool_name = method_name or tool.name
|
||||||
|
tool_param = tool.to_param()
|
||||||
|
tool_function = tool_param["function"]
|
||||||
|
|
||||||
|
# Define the async function to be registered
|
||||||
|
async def tool_method(**kwargs):
|
||||||
|
logger.info(f"Executing {tool_name}: {kwargs}")
|
||||||
|
result = await tool.execute(**kwargs)
|
||||||
|
|
||||||
|
logger.info(f"Result of {tool_name}: {result}")
|
||||||
|
|
||||||
|
# Handle different types of results (match original logic)
|
||||||
|
if hasattr(result, "model_dump"):
|
||||||
|
return json.dumps(result.model_dump())
|
||||||
|
elif isinstance(result, dict):
|
||||||
|
return json.dumps(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Set method metadata
|
||||||
|
tool_method.__name__ = tool_name
|
||||||
|
tool_method.__doc__ = self._build_docstring(tool_function)
|
||||||
|
tool_method.__signature__ = self._build_signature(tool_function)
|
||||||
|
|
||||||
|
# Store parameter schema (important for tools that access it programmatically)
|
||||||
|
param_props = tool_function.get("parameters", {}).get("properties", {})
|
||||||
|
required_params = tool_function.get("parameters", {}).get("required", [])
|
||||||
|
tool_method._parameter_schema = {
|
||||||
|
param_name: {
|
||||||
|
"description": param_details.get("description", ""),
|
||||||
|
"type": param_details.get("type", "any"),
|
||||||
|
"required": param_name in required_params,
|
||||||
|
}
|
||||||
|
for param_name, param_details in param_props.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Register with server
|
||||||
|
self.server.tool()(tool_method)
|
||||||
|
logger.info(f"Registered tool: {tool_name}")
|
||||||
|
|
||||||
|
def _build_docstring(self, tool_function: dict) -> str:
|
||||||
|
"""Build a formatted docstring from tool function metadata."""
|
||||||
|
description = tool_function.get("description", "")
|
||||||
|
param_props = tool_function.get("parameters", {}).get("properties", {})
|
||||||
|
required_params = tool_function.get("parameters", {}).get("required", [])
|
||||||
|
|
||||||
|
# Build docstring (match original format)
|
||||||
|
docstring = description
|
||||||
|
if param_props:
|
||||||
|
docstring += "\n\nParameters:\n"
|
||||||
|
for param_name, param_details in param_props.items():
|
||||||
|
required_str = (
|
||||||
|
"(required)" if param_name in required_params else "(optional)"
|
||||||
|
)
|
||||||
|
param_type = param_details.get("type", "any")
|
||||||
|
param_desc = param_details.get("description", "")
|
||||||
|
docstring += (
|
||||||
|
f" {param_name} ({param_type}) {required_str}: {param_desc}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
return docstring
|
||||||
|
|
||||||
|
def _build_signature(self, tool_function: dict) -> Signature:
|
||||||
|
"""Build a function signature from tool function metadata."""
|
||||||
|
param_props = tool_function.get("parameters", {}).get("properties", {})
|
||||||
|
required_params = tool_function.get("parameters", {}).get("required", [])
|
||||||
|
|
||||||
|
parameters = []
|
||||||
|
|
||||||
|
# Follow original type mapping
|
||||||
|
for param_name, param_details in param_props.items():
|
||||||
|
param_type = param_details.get("type", "")
|
||||||
|
default = Parameter.empty if param_name in required_params else None
|
||||||
|
|
||||||
|
# Map JSON Schema types to Python types (same as original)
|
||||||
|
annotation = Any
|
||||||
|
if param_type == "string":
|
||||||
|
annotation = str
|
||||||
|
elif param_type == "integer":
|
||||||
|
annotation = int
|
||||||
|
elif param_type == "number":
|
||||||
|
annotation = float
|
||||||
|
elif param_type == "boolean":
|
||||||
|
annotation = bool
|
||||||
|
elif param_type == "object":
|
||||||
|
annotation = dict
|
||||||
|
elif param_type == "array":
|
||||||
|
annotation = list
|
||||||
|
|
||||||
|
# Create parameter with same structure as original
|
||||||
|
param = Parameter(
|
||||||
|
name=param_name,
|
||||||
|
kind=Parameter.KEYWORD_ONLY,
|
||||||
|
default=default,
|
||||||
|
annotation=annotation,
|
||||||
|
)
|
||||||
|
parameters.append(param)
|
||||||
|
|
||||||
|
return Signature(parameters=parameters)
|
||||||
|
|
||||||
|
async def cleanup(self) -> None:
|
||||||
|
"""Clean up server resources."""
|
||||||
|
logger.info("Cleaning up resources")
|
||||||
|
# Follow original cleanup logic - only clean browser tool
|
||||||
|
if "browser" in self.tools and hasattr(self.tools["browser"], "cleanup"):
|
||||||
|
await self.tools["browser"].cleanup()
|
||||||
|
|
||||||
|
def register_all_tools(self) -> None:
|
||||||
|
"""Register all tools with the server."""
|
||||||
|
for tool in self.tools.values():
|
||||||
|
self.register_tool(tool)
|
||||||
|
|
||||||
|
def run(self, transport: str = "stdio") -> None:
|
||||||
|
"""Run the MCP server."""
|
||||||
|
# Register all tools
|
||||||
|
self.register_all_tools()
|
||||||
|
|
||||||
|
# Register cleanup function (match original behavior)
|
||||||
|
atexit.register(lambda: asyncio.run(self.cleanup()))
|
||||||
|
|
||||||
|
# Start server (with same logging as original)
|
||||||
|
logger.info(f"Starting OpenManus server ({transport} mode)")
|
||||||
|
self.server.run(transport=transport)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
"""Parse command line arguments."""
|
||||||
|
parser = argparse.ArgumentParser(description="OpenManus MCP Server")
|
||||||
|
parser.add_argument(
|
||||||
|
"--transport",
|
||||||
|
choices=["stdio"],
|
||||||
|
default="stdio",
|
||||||
|
help="Communication method: stdio or http (default: stdio)",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# Create and run server (maintaining original flow)
|
||||||
|
server = MCPServer()
|
||||||
|
server.run(transport=args.transport)
|
114
app/tool/mcp.py
Normal file
114
app/tool/mcp.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from mcp import ClientSession, StdioServerParameters
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
from mcp.client.stdio import stdio_client
|
||||||
|
from mcp.types import TextContent
|
||||||
|
|
||||||
|
from app.logger import logger
|
||||||
|
from app.tool.base import BaseTool, ToolResult
|
||||||
|
from app.tool.tool_collection import ToolCollection
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerTool(BaseTool):
|
||||||
|
"""Represents a tool available on the MCP server."""
|
||||||
|
|
||||||
|
schema: Dict[str, Any] = None
|
||||||
|
session: Optional[ClientSession] = None
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> ToolResult:
|
||||||
|
"""Execute the tool on the MCP server."""
|
||||||
|
if not self.session:
|
||||||
|
return ToolResult(error="Not connected to MCP server")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self.session.call_tool(self.name, kwargs)
|
||||||
|
content_str = ", ".join(
|
||||||
|
item.text for item in result.content if isinstance(item, TextContent)
|
||||||
|
)
|
||||||
|
return ToolResult(output=content_str or "No output returned.")
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult(error=f"Error executing tool: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
class MCP(ToolCollection):
|
||||||
|
"""AN MCP tool collection that connects to an MCP server and executes commands."""
|
||||||
|
|
||||||
|
session: Optional[ClientSession] = None
|
||||||
|
exit_stack: AsyncExitStack = None
|
||||||
|
description: str = "MCP tools for server interaction"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__() # Initialize with empty tools list
|
||||||
|
self.name = "mcp" # Keep name for backward compatibility
|
||||||
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
|
async def connect_sse(self, server_url: str) -> None:
|
||||||
|
"""Connect to an MCP server using SSE transport."""
|
||||||
|
if not server_url:
|
||||||
|
raise ValueError("Server URL is required.")
|
||||||
|
if self.session:
|
||||||
|
await self.disconnect()
|
||||||
|
|
||||||
|
streams_context = sse_client(url=server_url)
|
||||||
|
streams = await self.exit_stack.enter_async_context(streams_context)
|
||||||
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
|
ClientSession(*streams)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._initialize_and_list_tools()
|
||||||
|
|
||||||
|
async def connect_stdio(self, command: str, args: List[str]) -> None:
|
||||||
|
"""Connect to an MCP server using stdio transport."""
|
||||||
|
if not command:
|
||||||
|
raise ValueError("Server command is required.")
|
||||||
|
if self.session:
|
||||||
|
await self.disconnect()
|
||||||
|
|
||||||
|
server_params = StdioServerParameters(command=command, args=args)
|
||||||
|
stdio_transport = await self.exit_stack.enter_async_context(
|
||||||
|
stdio_client(server_params)
|
||||||
|
)
|
||||||
|
read, write = stdio_transport
|
||||||
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
|
ClientSession(read, write)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._initialize_and_list_tools()
|
||||||
|
|
||||||
|
async def _initialize_and_list_tools(self) -> None:
|
||||||
|
"""Initialize session and populate tool map."""
|
||||||
|
if not self.session:
|
||||||
|
raise RuntimeError("Session not initialized.")
|
||||||
|
|
||||||
|
await self.session.initialize()
|
||||||
|
response = await self.session.list_tools()
|
||||||
|
|
||||||
|
# Clear existing tools
|
||||||
|
self.tools = tuple()
|
||||||
|
self.tool_map = {}
|
||||||
|
|
||||||
|
# Create proper tool objects for each server tool
|
||||||
|
for tool in response.tools:
|
||||||
|
server_tool = MCPServerTool(
|
||||||
|
name=tool.name,
|
||||||
|
description=tool.description,
|
||||||
|
schema=tool.inputSchema,
|
||||||
|
session=self.session,
|
||||||
|
)
|
||||||
|
self.tool_map[tool.name] = server_tool
|
||||||
|
|
||||||
|
self.tools = tuple(self.tool_map.values())
|
||||||
|
logger.info(
|
||||||
|
f"Connected to server with tools: {[tool.name for tool in response.tools]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""Disconnect from the MCP server and clean up resources."""
|
||||||
|
if self.session and self.exit_stack:
|
||||||
|
await self.exit_stack.aclose()
|
||||||
|
self.session = None
|
||||||
|
self.tools = tuple()
|
||||||
|
self.tool_map = {}
|
||||||
|
logger.info("Disconnected from MCP server")
|
@ -8,6 +8,9 @@ from app.tool.base import BaseTool, ToolFailure, ToolResult
|
|||||||
class ToolCollection:
|
class ToolCollection:
|
||||||
"""A collection of defined tools."""
|
"""A collection of defined tools."""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def __init__(self, *tools: BaseTool):
|
def __init__(self, *tools: BaseTool):
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
self.tool_map = {tool.name: tool for tool in tools}
|
self.tool_map = {tool.name: tool for tool in tools}
|
||||||
|
107
run_mcp.py
Normal file
107
run_mcp.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from app.agent.mcp import MCPAgent
|
||||||
|
from app.config import config
|
||||||
|
from app.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
class MCPRunner:
|
||||||
|
"""Runner class for MCP Agent with proper path handling and configuration."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.root_path = config.root_path
|
||||||
|
self.server_script = self.root_path / "app" / "mcp" / "server.py"
|
||||||
|
self.agent = MCPAgent()
|
||||||
|
|
||||||
|
async def initialize(self, connection_type: str, server_url: str = None) -> None:
|
||||||
|
"""Initialize the MCP agent with the appropriate connection."""
|
||||||
|
logger.info(f"Initializing MCPAgent with {connection_type} connection...")
|
||||||
|
|
||||||
|
if connection_type == "stdio":
|
||||||
|
await self.agent.initialize(
|
||||||
|
connection_type="stdio",
|
||||||
|
command=sys.executable,
|
||||||
|
args=[str(self.server_script)],
|
||||||
|
)
|
||||||
|
else: # sse
|
||||||
|
await self.agent.initialize(connection_type="sse", server_url=server_url)
|
||||||
|
|
||||||
|
logger.info(f"Connected to MCP server via {connection_type}")
|
||||||
|
|
||||||
|
async def run_interactive(self) -> None:
|
||||||
|
"""Run the agent in interactive mode."""
|
||||||
|
print("\nMCP Agent Interactive Mode (type 'exit' to quit)\n")
|
||||||
|
while True:
|
||||||
|
user_input = input("\nEnter your request: ")
|
||||||
|
if user_input.lower() in ["exit", "quit", "q"]:
|
||||||
|
break
|
||||||
|
response = await self.agent.run(user_input)
|
||||||
|
print(f"\nAgent: {response}")
|
||||||
|
|
||||||
|
async def run_single_prompt(self, prompt: str) -> None:
|
||||||
|
"""Run the agent with a single prompt."""
|
||||||
|
await self.agent.run(prompt)
|
||||||
|
|
||||||
|
async def run_default(self) -> None:
|
||||||
|
"""Run the agent in default mode."""
|
||||||
|
await self.agent.run(
|
||||||
|
"Hello, what tools are available to me? Terminate after you have listed the tools."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def cleanup(self) -> None:
|
||||||
|
"""Clean up agent resources."""
|
||||||
|
await self.agent.cleanup()
|
||||||
|
logger.info("Session ended")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
"""Parse command line arguments."""
|
||||||
|
parser = argparse.ArgumentParser(description="Run the MCP Agent")
|
||||||
|
parser.add_argument(
|
||||||
|
"--connection",
|
||||||
|
"-c",
|
||||||
|
choices=["stdio", "sse"],
|
||||||
|
default="stdio",
|
||||||
|
help="Connection type: stdio or sse",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--server-url",
|
||||||
|
default="http://127.0.0.1:8000/sse",
|
||||||
|
help="URL for SSE connection",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--interactive", "-i", action="store_true", help="Run in interactive mode"
|
||||||
|
)
|
||||||
|
parser.add_argument("--prompt", "-p", help="Single prompt to execute and exit")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_mcp() -> None:
|
||||||
|
"""Main entry point for the MCP runner."""
|
||||||
|
args = parse_args()
|
||||||
|
runner = MCPRunner()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await runner.initialize(args.connection, args.server_url)
|
||||||
|
|
||||||
|
if args.prompt:
|
||||||
|
await runner.run_single_prompt(args.prompt)
|
||||||
|
elif args.interactive:
|
||||||
|
await runner.run_interactive()
|
||||||
|
else:
|
||||||
|
await runner.run_default()
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Program interrupted by user")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error running MCPAgent: {str(e)}", exc_info=True)
|
||||||
|
sys.exit(1)
|
||||||
|
finally:
|
||||||
|
await runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(run_mcp())
|
Loading…
x
Reference in New Issue
Block a user