update mcp_tools to mcp_clients
This commit is contained in:
parent
74f438bde3
commit
e5808d1a90
@ -24,7 +24,7 @@ class MCPAgent(ToolCallAgent):
|
|||||||
next_step_prompt: str = NEXT_STEP_PROMPT
|
next_step_prompt: str = NEXT_STEP_PROMPT
|
||||||
|
|
||||||
# Initialize MCP tool collection
|
# Initialize MCP tool collection
|
||||||
mcp_tools: MCPClients = Field(default_factory=MCPClients)
|
mcp_clients: MCPClients = Field(default_factory=MCPClients)
|
||||||
available_tools: MCPClients = None # Will be set in initialize()
|
available_tools: MCPClients = None # Will be set in initialize()
|
||||||
|
|
||||||
max_steps: int = 20
|
max_steps: int = 20
|
||||||
@ -59,22 +59,22 @@ class MCPAgent(ToolCallAgent):
|
|||||||
if self.connection_type == "sse":
|
if self.connection_type == "sse":
|
||||||
if not server_url:
|
if not server_url:
|
||||||
raise ValueError("Server URL is required for SSE connection")
|
raise ValueError("Server URL is required for SSE connection")
|
||||||
await self.mcp_tools.connect_sse(server_url=server_url)
|
await self.mcp_clients.connect_sse(server_url=server_url)
|
||||||
elif self.connection_type == "stdio":
|
elif self.connection_type == "stdio":
|
||||||
if not command:
|
if not command:
|
||||||
raise ValueError("Command is required for stdio connection")
|
raise ValueError("Command is required for stdio connection")
|
||||||
await self.mcp_tools.connect_stdio(command=command, args=args or [])
|
await self.mcp_clients.connect_stdio(command=command, args=args or [])
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported connection type: {self.connection_type}")
|
raise ValueError(f"Unsupported connection type: {self.connection_type}")
|
||||||
|
|
||||||
# Set available_tools to our MCP instance
|
# Set available_tools to our MCP instance
|
||||||
self.available_tools = self.mcp_tools
|
self.available_tools = self.mcp_clients
|
||||||
|
|
||||||
# Store initial tool schemas
|
# Store initial tool schemas
|
||||||
await self._refresh_tools()
|
await self._refresh_tools()
|
||||||
|
|
||||||
# Add system message about available tools
|
# Add system message about available tools
|
||||||
tool_names = list(self.mcp_tools.tool_map.keys())
|
tool_names = list(self.mcp_clients.tool_map.keys())
|
||||||
tools_info = ", ".join(tool_names)
|
tools_info = ", ".join(tool_names)
|
||||||
|
|
||||||
# Add system prompt and available tools information
|
# Add system prompt and available tools information
|
||||||
@ -90,11 +90,11 @@ class MCPAgent(ToolCallAgent):
|
|||||||
Returns:
|
Returns:
|
||||||
A tuple of (added_tools, removed_tools)
|
A tuple of (added_tools, removed_tools)
|
||||||
"""
|
"""
|
||||||
if not self.mcp_tools.session:
|
if not self.mcp_clients.session:
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
# Get current tool schemas directly from the server
|
# Get current tool schemas directly from the server
|
||||||
response = await self.mcp_tools.session.list_tools()
|
response = await self.mcp_clients.session.list_tools()
|
||||||
current_tools = {tool.name: tool.inputSchema for tool in response.tools}
|
current_tools = {tool.name: tool.inputSchema for tool in response.tools}
|
||||||
|
|
||||||
# Determine added, removed, and changed tools
|
# Determine added, removed, and changed tools
|
||||||
@ -134,7 +134,7 @@ class MCPAgent(ToolCallAgent):
|
|||||||
async def think(self) -> bool:
|
async def think(self) -> bool:
|
||||||
"""Process current state and decide next action."""
|
"""Process current state and decide next action."""
|
||||||
# Check MCP session and tools availability
|
# Check MCP session and tools availability
|
||||||
if not self.mcp_tools.session or not self.mcp_tools.tool_map:
|
if not self.mcp_clients.session or not self.mcp_clients.tool_map:
|
||||||
logger.info("MCP service is no longer available, ending interaction")
|
logger.info("MCP service is no longer available, ending interaction")
|
||||||
self.state = AgentState.FINISHED
|
self.state = AgentState.FINISHED
|
||||||
return False
|
return False
|
||||||
@ -143,7 +143,7 @@ class MCPAgent(ToolCallAgent):
|
|||||||
if self.current_step % self._refresh_tools_interval == 0:
|
if self.current_step % self._refresh_tools_interval == 0:
|
||||||
await self._refresh_tools()
|
await self._refresh_tools()
|
||||||
# All tools removed indicates shutdown
|
# All tools removed indicates shutdown
|
||||||
if not self.mcp_tools.tool_map:
|
if not self.mcp_clients.tool_map:
|
||||||
logger.info("MCP service has shut down, ending interaction")
|
logger.info("MCP service has shut down, ending interaction")
|
||||||
self.state = AgentState.FINISHED
|
self.state = AgentState.FINISHED
|
||||||
return False
|
return False
|
||||||
@ -171,8 +171,8 @@ class MCPAgent(ToolCallAgent):
|
|||||||
|
|
||||||
async def cleanup(self) -> None:
|
async def cleanup(self) -> None:
|
||||||
"""Clean up MCP connection when done."""
|
"""Clean up MCP connection when done."""
|
||||||
if self.mcp_tools.session:
|
if self.mcp_clients.session:
|
||||||
await self.mcp_tools.disconnect()
|
await self.mcp_clients.disconnect()
|
||||||
logger.info("MCP connection closed")
|
logger.info("MCP connection closed")
|
||||||
|
|
||||||
async def run(self, request: Optional[str] = None) -> str:
|
async def run(self, request: Optional[str] = None) -> str:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user