update ToolCallAgent and Manus
This commit is contained in:
parent
c3203e7fa3
commit
2509bc30c4
@ -1,8 +1,10 @@
|
|||||||
from typing import Any
|
import json
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from app.agent.toolcall import ToolCallAgent
|
from app.agent.toolcall import ToolCallAgent
|
||||||
|
from app.logger import logger
|
||||||
from app.prompt.manus import NEXT_STEP_PROMPT, SYSTEM_PROMPT
|
from app.prompt.manus import NEXT_STEP_PROMPT, SYSTEM_PROMPT
|
||||||
from app.tool import Terminate, ToolCollection
|
from app.tool import Terminate, ToolCollection
|
||||||
from app.tool.browser_use_tool import BrowserUseTool
|
from app.tool.browser_use_tool import BrowserUseTool
|
||||||
@ -43,3 +45,45 @@ class Manus(ToolCallAgent):
|
|||||||
else:
|
else:
|
||||||
await self.available_tools.get_tool(BrowserUseTool().name).cleanup()
|
await self.available_tools.get_tool(BrowserUseTool().name).cleanup()
|
||||||
await super()._handle_special_tool(name, result, **kwargs)
|
await super()._handle_special_tool(name, result, **kwargs)
|
||||||
|
|
||||||
|
async def get_browser_state(self) -> Optional[dict]:
|
||||||
|
"""Get the current browser state for context in next steps."""
|
||||||
|
browser_tool = self.available_tools.get_tool(BrowserUseTool().name)
|
||||||
|
if not browser_tool:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get browser state directly from the tool with no context parameter
|
||||||
|
result = await browser_tool.get_current_state()
|
||||||
|
|
||||||
|
if result.error:
|
||||||
|
logger.debug(f"Browser state error: {result.error}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Store screenshot if available
|
||||||
|
if hasattr(result, "base64_image") and result.base64_image:
|
||||||
|
self._current_base64_image = result.base64_image
|
||||||
|
|
||||||
|
# Parse the state info
|
||||||
|
return json.loads(result.output)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to get browser state: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def think(self) -> bool:
|
||||||
|
# Add your custom pre-processing here
|
||||||
|
browser_state = await self.get_browser_state()
|
||||||
|
|
||||||
|
# Modify the next_step_prompt temporarily
|
||||||
|
original_prompt = self.next_step_prompt
|
||||||
|
if browser_state and not browser_state.get("error"):
|
||||||
|
self.next_step_prompt += f"\nCurrent browser state:\nURL: {browser_state.get('url', 'N/A')}\nTitle: {browser_state.get('title', 'N/A')}\n"
|
||||||
|
|
||||||
|
# Call parent implementation
|
||||||
|
result = await super().think()
|
||||||
|
|
||||||
|
# Restore original prompt
|
||||||
|
self.next_step_prompt = original_prompt
|
||||||
|
|
||||||
|
return result
|
||||||
|
@ -30,6 +30,7 @@ class ToolCallAgent(ReActAgent):
|
|||||||
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
|
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
|
||||||
|
|
||||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||||
|
_current_base64_image: Optional[str] = None
|
||||||
|
|
||||||
max_steps: int = 30
|
max_steps: int = 30
|
||||||
max_observe: Optional[Union[int, bool]] = None
|
max_observe: Optional[Union[int, bool]] = None
|
||||||
@ -44,9 +45,11 @@ class ToolCallAgent(ReActAgent):
|
|||||||
# Get response with tool options
|
# Get response with tool options
|
||||||
response = await self.llm.ask_tool(
|
response = await self.llm.ask_tool(
|
||||||
messages=self.messages,
|
messages=self.messages,
|
||||||
system_msgs=[Message.system_message(self.system_prompt)]
|
system_msgs=(
|
||||||
if self.system_prompt
|
[Message.system_message(self.system_prompt)]
|
||||||
else None,
|
if self.system_prompt
|
||||||
|
else None
|
||||||
|
),
|
||||||
tools=self.available_tools.to_params(),
|
tools=self.available_tools.to_params(),
|
||||||
tool_choice=self.tool_choices,
|
tool_choice=self.tool_choices,
|
||||||
)
|
)
|
||||||
@ -79,6 +82,9 @@ class ToolCallAgent(ReActAgent):
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"🧰 Tools being prepared: {[call.function.name for call in response.tool_calls]}"
|
f"🧰 Tools being prepared: {[call.function.name for call in response.tool_calls]}"
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
f"🔧 Tool arguments: {response.tool_calls[0].function.arguments}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Handle different tool_choices modes
|
# Handle different tool_choices modes
|
||||||
@ -130,6 +136,9 @@ class ToolCallAgent(ReActAgent):
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
for command in self.tool_calls:
|
for command in self.tool_calls:
|
||||||
|
# Reset base64_image for each tool call
|
||||||
|
self._current_base64_image = None
|
||||||
|
|
||||||
result = await self.execute_tool(command)
|
result = await self.execute_tool(command)
|
||||||
|
|
||||||
if self.max_observe:
|
if self.max_observe:
|
||||||
@ -141,7 +150,10 @@ class ToolCallAgent(ReActAgent):
|
|||||||
|
|
||||||
# Add tool response to memory
|
# Add tool response to memory
|
||||||
tool_msg = Message.tool_message(
|
tool_msg = Message.tool_message(
|
||||||
content=result, tool_call_id=command.id, name=command.function.name
|
content=result,
|
||||||
|
tool_call_id=command.id,
|
||||||
|
name=command.function.name,
|
||||||
|
base64_image=self._current_base64_image,
|
||||||
)
|
)
|
||||||
self.memory.add_message(tool_msg)
|
self.memory.add_message(tool_msg)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
@ -165,16 +177,29 @@ class ToolCallAgent(ReActAgent):
|
|||||||
logger.info(f"🔧 Activating tool: '{name}'...")
|
logger.info(f"🔧 Activating tool: '{name}'...")
|
||||||
result = await self.available_tools.execute(name=name, tool_input=args)
|
result = await self.available_tools.execute(name=name, tool_input=args)
|
||||||
|
|
||||||
# Format result for display
|
# Handle special tools
|
||||||
|
await self._handle_special_tool(name=name, result=result)
|
||||||
|
|
||||||
|
# Check if result is a ToolResult with base64_image
|
||||||
|
if hasattr(result, "base64_image") and result.base64_image:
|
||||||
|
# Store the base64_image for later use in tool_message
|
||||||
|
self._current_base64_image = result.base64_image
|
||||||
|
|
||||||
|
# Format result for display
|
||||||
|
observation = (
|
||||||
|
f"Observed output of cmd `{name}` executed:\n{str(result)}"
|
||||||
|
if result
|
||||||
|
else f"Cmd `{name}` completed with no output"
|
||||||
|
)
|
||||||
|
return observation
|
||||||
|
|
||||||
|
# Format result for display (standard case)
|
||||||
observation = (
|
observation = (
|
||||||
f"Observed output of cmd `{name}` executed:\n{str(result)}"
|
f"Observed output of cmd `{name}` executed:\n{str(result)}"
|
||||||
if result
|
if result
|
||||||
else f"Cmd `{name}` completed with no output"
|
else f"Cmd `{name}` completed with no output"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle special tools like `finish`
|
|
||||||
await self._handle_special_tool(name=name, result=result)
|
|
||||||
|
|
||||||
return observation
|
return observation
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
error_msg = f"Error parsing arguments for {name}: Invalid JSON format"
|
error_msg = f"Error parsing arguments for {name}: Invalid JSON format"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user