Merge branch 'mannaandpoem:main' into main

This commit is contained in:
zyren123 2025-03-18 09:46:10 +08:00 committed by GitHub
commit f474290395
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 995 additions and 201 deletions

View File

@ -0,0 +1,33 @@
name: Environment Corruption Check
on:
push:
branches: ["main"]
paths:
- requirements.txt
pull_request:
branches: ["main"]
paths:
- requirements.txt
concurrency:
group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }}
cancel-in-progress: true
jobs:
test-python-versions:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11.11", "3.12.8", "3.13.2"]
fail-fast: false
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: |
python -m pip install --upgrade pip
- name: Install dependencies
run: |
pip install -r requirements.txt

127
.github/workflows/pr-autodiff.yaml vendored Normal file
View File

@ -0,0 +1,127 @@
name: PR Diff Summarization
on:
# pull_request:
# branches: [main]
# types: [opened, ready_for_review, reopened]
issue_comment:
types: [created]
permissions:
contents: read
pull-requests: write
jobs:
pr-diff-summarization:
runs-on: ubuntu-latest
if: |
(github.event_name == 'pull_request') ||
(github.event_name == 'issue_comment' &&
contains(github.event.comment.body, '!pr-diff') &&
(github.event.comment.author_association == 'COLLABORATOR' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') &&
github.event.issue.pull_request)
steps:
- name: Get PR head SHA
id: get-pr-sha
run: |
if [ "${{ github.event_name }}" == "pull_request" ]; then
echo "pr_sha=${{ github.event.pull_request.head.sha }}" >> $GITHUB_OUTPUT
echo "Retrieved PR head SHA: ${{ github.event.pull_request.head.sha }}"
else
PR_URL="${{ github.event.issue.pull_request.url }}"
SHA=$(curl -s -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" $PR_URL | jq -r '.head.sha')
echo "pr_sha=$SHA" >> $GITHUB_OUTPUT
echo "Retrieved PR head SHA from API: $SHA"
fi
- name: Check out code
uses: actions/checkout@v4
with:
ref: ${{ steps.get-pr-sha.outputs.pr_sha }}
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install openai requests
- name: Create and run Python script
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
GH_TOKEN: ${{ github.token }}
PR_NUMBER: ${{ github.event.pull_request.number || github.event.issue.number }}
run: |-
cat << 'EOF' > /tmp/_workflow_core.py
import os
import subprocess
import json
import requests
from openai import OpenAI
def get_diff():
result = subprocess.run(
['git', 'diff', 'origin/main...HEAD'],
capture_output=True, text=True, check=True)
return '\n'.join(
line for line in result.stdout.split('\n')
if any(line.startswith(c) for c in ('+', '-'))
and not line.startswith(('---', '+++'))
)[:round(200000 * 0.4)] # Truncate to prevent overflow
def generate_comment(diff_content):
client = OpenAI(
base_url=os.getenv("OPENAI_BASE_URL"),
api_key=os.getenv("OPENAI_API_KEY")
)
guidelines = '''
1. English version first, Chinese Simplified version after
2. Example format:
# Diff Report
## English
- Added `ABC` class
- Fixed `f()` behavior in `foo` module
### Comments Highlight
- `config.toml` needs to be configured properly to make sure new features work as expected.
### Spelling/Offensive Content Check
- No spelling mistakes or offensive content found in the code or comments.
3. Highlight non-English comments
4. Check for spelling/offensive content'''
response = client.chat.completions.create(
model="o3-mini",
messages=[{
"role": "system",
"content": "Generate bilingual code review feedback."
}, {
"role": "user",
"content": f"Review these changes per guidelines:\n{guidelines}\n\nDIFF:\n{diff_content}"
}]
)
return response.choices[0].message.content
def post_comment(comment):
repo = os.getenv("GITHUB_REPOSITORY")
pr_number = os.getenv("PR_NUMBER")
headers = {
"Authorization": f"Bearer {os.getenv('GH_TOKEN')}",
"Accept": "application/vnd.github.v3+json"
}
url = f"https://api.github.com/repos/{repo}/issues/{pr_number}/comments"
requests.post(url, json={"body": comment}, headers=headers)
if __name__ == "__main__":
diff_content = get_diff()
if not diff_content.strip():
print("No meaningful diff detected.")
exit(0)
comment = generate_comment(diff_content)
post_comment(comment)
print("Comment posted successfully.")
EOF
python /tmp/_workflow_core.py

View File

@ -24,4 +24,4 @@ jobs:
dashboard_show_total_reactions: true
top_issues: true
top_pull_requests: true
top_list_size: 32
top_list_size: 14

View File

@ -0,0 +1,10 @@
# Python version check: 3.11-3.13
import sys
if sys.version_info < (3, 11) or sys.version_info > (3, 13):
print(
"Warning: Unsupported Python version {ver}, please use 3.11-3.13".format(
ver=".".join(map(str, sys.version_info))
)
)

View File

@ -84,6 +84,7 @@ class BaseAgent(BaseModel, ABC):
self,
role: ROLE_TYPE, # type: ignore
content: str,
base64_image: Optional[str] = None,
**kwargs,
) -> None:
"""Add a message to the agent's memory.
@ -91,6 +92,7 @@ class BaseAgent(BaseModel, ABC):
Args:
role: The role of the message sender (user, system, assistant, tool).
content: The message content.
base64_image: Optional base64 encoded image.
**kwargs: Additional arguments (e.g., tool_call_id for tool messages).
Raises:
@ -106,9 +108,9 @@ class BaseAgent(BaseModel, ABC):
if role not in message_map:
raise ValueError(f"Unsupported message role: {role}")
msg_factory = message_map[role]
msg = msg_factory(content, **kwargs) if role == "tool" else msg_factory(content)
self.memory.add_message(msg)
# Create message with appropriate parameters based on role
kwargs = {"base64_image": base64_image, **(kwargs if role == "tool" else {})}
self.memory.add_message(message_map[role](content, **kwargs))
async def run(self, request: Optional[str] = None) -> str:
"""Execute the agent's main loop asynchronously.

View File

@ -1,14 +1,15 @@
from typing import Any
import json
from typing import Any, Optional
from pydantic import Field
from app.agent.toolcall import ToolCallAgent
from app.logger import logger
from app.prompt.manus import NEXT_STEP_PROMPT, SYSTEM_PROMPT
from app.tool import Terminate, ToolCollection
from app.tool.browser_use_tool import BrowserUseTool
from app.tool.file_saver import FileSaver
from app.tool.python_execute import PythonExecute
from app.tool.web_search import WebSearch
class Manus(ToolCallAgent):
@ -34,7 +35,7 @@ class Manus(ToolCallAgent):
# Add general-purpose tools to the tool collection
available_tools: ToolCollection = Field(
default_factory=lambda: ToolCollection(
PythonExecute(), WebSearch(), BrowserUseTool(), FileSaver(), Terminate()
PythonExecute(), BrowserUseTool(), FileSaver(), Terminate()
)
)
@ -44,3 +45,45 @@ class Manus(ToolCallAgent):
else:
await self.available_tools.get_tool(BrowserUseTool().name).cleanup()
await super()._handle_special_tool(name, result, **kwargs)
async def get_browser_state(self) -> Optional[dict]:
"""Get the current browser state for context in next steps."""
browser_tool = self.available_tools.get_tool(BrowserUseTool().name)
if not browser_tool:
return None
try:
# Get browser state directly from the tool with no context parameter
result = await browser_tool.get_current_state()
if result.error:
logger.debug(f"Browser state error: {result.error}")
return None
# Store screenshot if available
if hasattr(result, "base64_image") and result.base64_image:
self._current_base64_image = result.base64_image
# Parse the state info
return json.loads(result.output)
except Exception as e:
logger.debug(f"Failed to get browser state: {str(e)}")
return None
async def think(self) -> bool:
# Add your custom pre-processing here
browser_state = await self.get_browser_state()
# Modify the next_step_prompt temporarily
original_prompt = self.next_step_prompt
if browser_state and not browser_state.get("error"):
self.next_step_prompt += f"\nCurrent browser state:\nURL: {browser_state.get('url', 'N/A')}\nTitle: {browser_state.get('title', 'N/A')}\n"
# Call parent implementation
result = await super().think()
# Restore original prompt
self.next_step_prompt = original_prompt
return result

View File

@ -30,6 +30,7 @@ class ToolCallAgent(ReActAgent):
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
tool_calls: List[ToolCall] = Field(default_factory=list)
_current_base64_image: Optional[str] = None
max_steps: int = 30
max_observe: Optional[Union[int, bool]] = None
@ -44,9 +45,11 @@ class ToolCallAgent(ReActAgent):
# Get response with tool options
response = await self.llm.ask_tool(
messages=self.messages,
system_msgs=[Message.system_message(self.system_prompt)]
if self.system_prompt
else None,
system_msgs=(
[Message.system_message(self.system_prompt)]
if self.system_prompt
else None
),
tools=self.available_tools.to_params(),
tool_choice=self.tool_choices,
)
@ -79,6 +82,9 @@ class ToolCallAgent(ReActAgent):
logger.info(
f"🧰 Tools being prepared: {[call.function.name for call in response.tool_calls]}"
)
logger.info(
f"🔧 Tool arguments: {response.tool_calls[0].function.arguments}"
)
try:
# Handle different tool_choices modes
@ -130,6 +136,9 @@ class ToolCallAgent(ReActAgent):
results = []
for command in self.tool_calls:
# Reset base64_image for each tool call
self._current_base64_image = None
result = await self.execute_tool(command)
if self.max_observe:
@ -141,7 +150,10 @@ class ToolCallAgent(ReActAgent):
# Add tool response to memory
tool_msg = Message.tool_message(
content=result, tool_call_id=command.id, name=command.function.name
content=result,
tool_call_id=command.id,
name=command.function.name,
base64_image=self._current_base64_image,
)
self.memory.add_message(tool_msg)
results.append(result)
@ -165,16 +177,29 @@ class ToolCallAgent(ReActAgent):
logger.info(f"🔧 Activating tool: '{name}'...")
result = await self.available_tools.execute(name=name, tool_input=args)
# Format result for display
# Handle special tools
await self._handle_special_tool(name=name, result=result)
# Check if result is a ToolResult with base64_image
if hasattr(result, "base64_image") and result.base64_image:
# Store the base64_image for later use in tool_message
self._current_base64_image = result.base64_image
# Format result for display
observation = (
f"Observed output of cmd `{name}` executed:\n{str(result)}"
if result
else f"Cmd `{name}` completed with no output"
)
return observation
# Format result for display (standard case)
observation = (
f"Observed output of cmd `{name}` executed:\n{str(result)}"
if result
else f"Cmd `{name}` completed with no output"
)
# Handle special tools like `finish`
await self._handle_special_tool(name=name, result=result)
return observation
except json.JSONDecodeError:
error_msg = f"Error parsing arguments for {name}: Invalid JSON format"

View File

@ -59,6 +59,9 @@ class BrowserSettings(BaseModel):
proxy: Optional[ProxySettings] = Field(
None, description="Proxy settings for the browser"
)
max_content_length: int = Field(
2000, description="Maximum length for content retrieval operations"
)
class AppConfig(BaseModel):

View File

@ -1,3 +1,4 @@
import math
from typing import Dict, List, Optional, Union
import tiktoken
@ -31,6 +32,142 @@ from app.schema import (
REASONING_MODELS = ["o1", "o3-mini"]
class TokenCounter:
# Token constants
BASE_MESSAGE_TOKENS = 4
FORMAT_TOKENS = 2
LOW_DETAIL_IMAGE_TOKENS = 85
HIGH_DETAIL_TILE_TOKENS = 170
# Image processing constants
MAX_SIZE = 2048
HIGH_DETAIL_TARGET_SHORT_SIDE = 768
TILE_SIZE = 512
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def count_text(self, text: str) -> int:
"""Calculate tokens for a text string"""
return 0 if not text else len(self.tokenizer.encode(text))
def count_image(self, image_item: dict) -> int:
"""
Calculate tokens for an image based on detail level and dimensions
For "low" detail: fixed 85 tokens
For "high" detail:
1. Scale to fit in 2048x2048 square
2. Scale shortest side to 768px
3. Count 512px tiles (170 tokens each)
4. Add 85 tokens
"""
detail = image_item.get("detail", "medium")
# For low detail, always return fixed token count
if detail == "low":
return self.LOW_DETAIL_IMAGE_TOKENS
# For medium detail (default in OpenAI), use high detail calculation
# OpenAI doesn't specify a separate calculation for medium
# For high detail, calculate based on dimensions if available
if detail == "high" or detail == "medium":
# If dimensions are provided in the image_item
if "dimensions" in image_item:
width, height = image_item["dimensions"]
return self._calculate_high_detail_tokens(width, height)
# Default values when dimensions aren't available or detail level is unknown
if detail == "high":
# Default to a 1024x1024 image calculation for high detail
return self._calculate_high_detail_tokens(1024, 1024) # 765 tokens
elif detail == "medium":
# Default to a medium-sized image for medium detail
return 1024 # This matches the original default
else:
# For unknown detail levels, use medium as default
return 1024
def _calculate_high_detail_tokens(self, width: int, height: int) -> int:
"""Calculate tokens for high detail images based on dimensions"""
# Step 1: Scale to fit in MAX_SIZE x MAX_SIZE square
if width > self.MAX_SIZE or height > self.MAX_SIZE:
scale = self.MAX_SIZE / max(width, height)
width = int(width * scale)
height = int(height * scale)
# Step 2: Scale so shortest side is HIGH_DETAIL_TARGET_SHORT_SIDE
scale = self.HIGH_DETAIL_TARGET_SHORT_SIDE / min(width, height)
scaled_width = int(width * scale)
scaled_height = int(height * scale)
# Step 3: Count number of 512px tiles
tiles_x = math.ceil(scaled_width / self.TILE_SIZE)
tiles_y = math.ceil(scaled_height / self.TILE_SIZE)
total_tiles = tiles_x * tiles_y
# Step 4: Calculate final token count
return (
total_tiles * self.HIGH_DETAIL_TILE_TOKENS
) + self.LOW_DETAIL_IMAGE_TOKENS
def count_content(self, content: Union[str, List[Union[str, dict]]]) -> int:
"""Calculate tokens for message content"""
if not content:
return 0
if isinstance(content, str):
return self.count_text(content)
token_count = 0
for item in content:
if isinstance(item, str):
token_count += self.count_text(item)
elif isinstance(item, dict):
if "text" in item:
token_count += self.count_text(item["text"])
elif "image_url" in item:
token_count += self.count_image(item)
return token_count
def count_tool_calls(self, tool_calls: List[dict]) -> int:
"""Calculate tokens for tool calls"""
token_count = 0
for tool_call in tool_calls:
if "function" in tool_call:
function = tool_call["function"]
token_count += self.count_text(function.get("name", ""))
token_count += self.count_text(function.get("arguments", ""))
return token_count
def count_message_tokens(self, messages: List[dict]) -> int:
"""Calculate the total number of tokens in a message list"""
total_tokens = self.FORMAT_TOKENS # Base format tokens
for message in messages:
tokens = self.BASE_MESSAGE_TOKENS # Base tokens per message
# Add role tokens
tokens += self.count_text(message.get("role", ""))
# Add content tokens
if "content" in message:
tokens += self.count_content(message["content"])
# Add tool calls tokens
if "tool_calls" in message:
tokens += self.count_tool_calls(message["tool_calls"])
# Add name and tool_call_id tokens
tokens += self.count_text(message.get("name", ""))
tokens += self.count_text(message.get("tool_call_id", ""))
total_tokens += tokens
return total_tokens
class LLM:
_instances: Dict[str, "LLM"] = {}
@ -82,6 +219,8 @@ class LLM:
else:
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
self.token_counter = TokenCounter(self.tokenizer)
def count_tokens(self, text: str) -> int:
"""Calculate the number of tokens in a text"""
if not text:
@ -89,46 +228,7 @@ class LLM:
return len(self.tokenizer.encode(text))
def count_message_tokens(self, messages: List[dict]) -> int:
"""Calculate the number of tokens in a message list"""
token_count = 0
for message in messages:
# Base token count for each message (according to OpenAI's calculation method)
token_count += 4 # Base token count for each message
# Calculate tokens for the role
if "role" in message:
token_count += self.count_tokens(message["role"])
# Calculate tokens for the content
if "content" in message and message["content"]:
token_count += self.count_tokens(message["content"])
# Calculate tokens for tool calls
if "tool_calls" in message and message["tool_calls"]:
for tool_call in message["tool_calls"]:
if "function" in tool_call:
# Function name
if "name" in tool_call["function"]:
token_count += self.count_tokens(
tool_call["function"]["name"]
)
# Function arguments
if "arguments" in tool_call["function"]:
token_count += self.count_tokens(
tool_call["function"]["arguments"]
)
# Calculate tokens for tool responses
if "name" in message and message["name"]:
token_count += self.count_tokens(message["name"])
if "tool_call_id" in message and message["tool_call_id"]:
token_count += self.count_tokens(message["tool_call_id"])
# Add extra tokens for message format
token_count += 2 # Extra tokens for message format
return token_count
return self.token_counter.count_message_tokens(messages)
def update_token_count(self, input_tokens: int, completion_tokens: int = 0) -> None:
"""Update token counts"""
@ -184,22 +284,58 @@ class LLM:
formatted_messages = []
for message in messages:
# Convert Message objects to dictionaries
if isinstance(message, Message):
message = message.to_dict()
if isinstance(message, dict):
# If message is a dict, ensure it has required fields
if "role" not in message:
raise ValueError("Message dict must contain 'role' field")
if "content" in message or "tool_calls" in message:
formatted_messages.append(message)
# else: do not include the message
else:
if not isinstance(message, dict):
raise TypeError(f"Unsupported message type: {type(message)}")
# Validate all messages have required fields
for msg in formatted_messages:
if msg["role"] not in ROLE_VALUES:
raise ValueError(f"Invalid role: {msg['role']}")
# Validate required fields
if "role" not in message:
raise ValueError("Message dict must contain 'role' field")
# Process base64 images if present
if message.get("base64_image"):
# Initialize or convert content to appropriate format
if not message.get("content"):
message["content"] = []
elif isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
elif isinstance(message["content"], list):
# Convert string items to proper text objects
message["content"] = [
(
{"type": "text", "text": item}
if isinstance(item, str)
else item
)
for item in message["content"]
]
# Add the image to content
message["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{message['base64_image']}"
},
}
)
# Remove the base64_image field
del message["base64_image"]
# Only include messages with content or tool_calls
if "content" in message or "tool_calls" in message:
formatted_messages.append(message)
# Validate all roles
invalid_roles = [
msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES
]
if invalid_roles:
raise ValueError(f"Invalid role: {invalid_roles[0]['role']}")
return formatted_messages
@ -328,6 +464,152 @@ class LLM:
logger.error(f"Unexpected error in ask: {e}")
raise
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
retry=retry_if_exception_type(
(OpenAIError, Exception, ValueError)
), # Don't retry TokenLimitExceeded
)
async def ask_with_images(
self,
messages: List[Union[dict, Message]],
images: List[Union[str, dict]],
system_msgs: Optional[List[Union[dict, Message]]] = None,
stream: bool = False,
temperature: Optional[float] = None,
) -> str:
"""
Send a prompt with images to the LLM and get the response.
Args:
messages: List of conversation messages
images: List of image URLs or image data dictionaries
system_msgs: Optional system messages to prepend
stream (bool): Whether to stream the response
temperature (float): Sampling temperature for the response
Returns:
str: The generated response
Raises:
TokenLimitExceeded: If token limits are exceeded
ValueError: If messages are invalid or response is empty
OpenAIError: If API call fails after retries
Exception: For unexpected errors
"""
try:
# Format messages
formatted_messages = self.format_messages(messages)
# Ensure the last message is from the user to attach images
if not formatted_messages or formatted_messages[-1]["role"] != "user":
raise ValueError(
"The last message must be from the user to attach images"
)
# Process the last user message to include images
last_message = formatted_messages[-1]
# Convert content to multimodal format if needed
content = last_message["content"]
multimodal_content = (
[{"type": "text", "text": content}]
if isinstance(content, str)
else content
if isinstance(content, list)
else []
)
# Add images to content
for image in images:
if isinstance(image, str):
multimodal_content.append(
{"type": "image_url", "image_url": {"url": image}}
)
elif isinstance(image, dict) and "url" in image:
multimodal_content.append({"type": "image_url", "image_url": image})
elif isinstance(image, dict) and "image_url" in image:
multimodal_content.append(image)
else:
raise ValueError(f"Unsupported image format: {image}")
# Update the message with multimodal content
last_message["content"] = multimodal_content
# Add system messages if provided
if system_msgs:
all_messages = self.format_messages(system_msgs) + formatted_messages
else:
all_messages = formatted_messages
# Calculate tokens and check limits
input_tokens = self.count_message_tokens(all_messages)
if not self.check_token_limit(input_tokens):
raise TokenLimitExceeded(self.get_limit_error_message(input_tokens))
# Set up API parameters
params = {
"model": self.model,
"messages": all_messages,
"stream": stream,
}
# Add model-specific parameters
if self.model in REASONING_MODELS:
params["max_completion_tokens"] = self.max_tokens
else:
params["max_tokens"] = self.max_tokens
params["temperature"] = (
temperature if temperature is not None else self.temperature
)
# Handle non-streaming request
if not stream:
response = await self.client.chat.completions.create(**params)
if not response.choices or not response.choices[0].message.content:
raise ValueError("Empty or invalid response from LLM")
self.update_token_count(response.usage.prompt_tokens)
return response.choices[0].message.content
# Handle streaming request
self.update_token_count(input_tokens)
response = await self.client.chat.completions.create(**params)
collected_messages = []
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or ""
collected_messages.append(chunk_message)
print(chunk_message, end="", flush=True)
print() # Newline after streaming
full_response = "".join(collected_messages).strip()
if not full_response:
raise ValueError("Empty response from streaming LLM")
return full_response
except TokenLimitExceeded:
raise
except ValueError as ve:
logger.error(f"Validation error in ask_with_images: {ve}")
raise
except OpenAIError as oe:
logger.error(f"OpenAI API error: {oe}")
if isinstance(oe, AuthenticationError):
logger.error("Authentication failed. Check API key.")
elif isinstance(oe, RateLimitError):
logger.error("Rate limit exceeded. Consider increasing retry attempts.")
elif isinstance(oe, APIError):
logger.error(f"API error: {oe}")
raise
except Exception as e:
logger.error(f"Unexpected error in ask_with_images: {e}")
raise
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),

View File

@ -6,9 +6,7 @@ PythonExecute: Execute Python code to interact with the computer system, data pr
FileSaver: Save files locally, such as txt, py, html, etc.
BrowserUseTool: Open, browse, and use web browsers.If you open a local HTML file, you must provide the absolute path to the file.
WebSearch: Perform web information retrieval
BrowserUseTool: Open, browse, and use web browsers. If you open a local HTML file, you must provide the absolute path to the file.
Terminate: End the current interaction when the task is complete or when you need additional information from the user. Use this tool to signal that you've finished addressing the user's request or need clarification before proceeding further.

View File

@ -59,6 +59,7 @@ class Message(BaseModel):
tool_calls: Optional[List[ToolCall]] = Field(default=None)
name: Optional[str] = Field(default=None)
tool_call_id: Optional[str] = Field(default=None)
base64_image: Optional[str] = Field(default=None)
def __add__(self, other) -> List["Message"]:
"""支持 Message + list 或 Message + Message 的操作"""
@ -91,12 +92,16 @@ class Message(BaseModel):
message["name"] = self.name
if self.tool_call_id is not None:
message["tool_call_id"] = self.tool_call_id
if self.base64_image is not None:
message["base64_image"] = self.base64_image
return message
@classmethod
def user_message(cls, content: str) -> "Message":
def user_message(
cls, content: str, base64_image: Optional[str] = None
) -> "Message":
"""Create a user message"""
return cls(role=Role.USER, content=content)
return cls(role=Role.USER, content=content, base64_image=base64_image)
@classmethod
def system_message(cls, content: str) -> "Message":
@ -104,33 +109,50 @@ class Message(BaseModel):
return cls(role=Role.SYSTEM, content=content)
@classmethod
def assistant_message(cls, content: Optional[str] = None) -> "Message":
def assistant_message(
cls, content: Optional[str] = None, base64_image: Optional[str] = None
) -> "Message":
"""Create an assistant message"""
return cls(role=Role.ASSISTANT, content=content)
return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)
@classmethod
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
def tool_message(
cls, content: str, name, tool_call_id: str, base64_image: Optional[str] = None
) -> "Message":
"""Create a tool message"""
return cls(
role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id
role=Role.TOOL,
content=content,
name=name,
tool_call_id=tool_call_id,
base64_image=base64_image,
)
@classmethod
def from_tool_calls(
cls, tool_calls: List[Any], content: Union[str, List[str]] = "", **kwargs
cls,
tool_calls: List[Any],
content: Union[str, List[str]] = "",
base64_image: Optional[str] = None,
**kwargs,
):
"""Create ToolCallsMessage from raw tool calls.
Args:
tool_calls: Raw tool calls from LLM
content: Optional message content
base64_image: Optional base64 encoded image
"""
formatted_calls = [
{"id": call.id, "function": call.function.model_dump(), "type": "function"}
for call in tool_calls
]
return cls(
role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs
role=Role.ASSISTANT,
content=content,
tool_calls=formatted_calls,
base64_image=base64_image,
**kwargs,
)

View File

@ -37,6 +37,7 @@ class ToolResult(BaseModel):
output: Any = Field(default=None)
error: Optional[str] = Field(default=None)
base64_image: Optional[str] = Field(default=None)
system: Optional[str] = Field(default=None)
class Config:
@ -58,6 +59,7 @@ class ToolResult(BaseModel):
return ToolResult(
output=combine_fields(self.output, other.output),
error=combine_fields(self.error, other.error),
base64_image=combine_fields(self.base64_image, other.base64_image, False),
system=combine_fields(self.system, other.system),
)
@ -76,7 +78,3 @@ class CLIResult(ToolResult):
class ToolFailure(ToolResult):
"""A ToolResult that represents a failure."""
class AgentAwareTool:
agent: Optional = None

View File

@ -1,6 +1,6 @@
import asyncio
import json
from typing import Optional
from typing import Generic, Optional, TypeVar
from browser_use import Browser as BrowserUseBrowser
from browser_use import BrowserConfig
@ -10,31 +10,45 @@ from pydantic import Field, field_validator
from pydantic_core.core_schema import ValidationInfo
from app.config import config
from app.llm import LLM
from app.tool.base import BaseTool, ToolResult
from app.tool.web_search import WebSearch
MAX_LENGTH = 2000
_BROWSER_DESCRIPTION = """
Interact with a web browser to perform various actions such as navigation, element interaction,
content extraction, and tab management. Supported actions include:
- 'navigate': Go to a specific URL
- 'click': Click an element by index
- 'input_text': Input text into an element
- 'screenshot': Capture a screenshot
- 'get_html': Get page HTML content
- 'get_text': Get text content of the page
- 'read_links': Get all links on the page
- 'execute_js': Execute JavaScript code
- 'scroll': Scroll the page
- 'switch_tab': Switch to a specific tab
- 'new_tab': Open a new tab
- 'close_tab': Close the current tab
Interact with a web browser to perform various actions such as navigation, element interaction, content extraction, and tab management. This tool provides a comprehensive set of browser automation capabilities:
Navigation:
- 'go_to_url': Go to a specific URL in the current tab
- 'go_back': Go back
- 'refresh': Refresh the current page
- 'web_search': Search the query in the current tab, the query should be a search query like humans search in web, concrete and not vague or super long. More the single most important items.
Element Interaction:
- 'click_element': Click an element by index
- 'input_text': Input text into a form element
- 'scroll_down'/'scroll_up': Scroll the page (with optional pixel amount)
- 'scroll_to_text': If you dont find something which you want to interact with, scroll to it
- 'send_keys': Send strings of special keys like Escape,Backspace, Insert, PageDown, Delete, Enter, Shortcuts such as `Control+o`, `Control+Shift+T` are supported as well. This gets used in keyboard.press.
- 'get_dropdown_options': Get all options from a dropdown
- 'select_dropdown_option': Select dropdown option for interactive element index by the text of the option you want to select
Content Extraction:
- 'extract_content': Extract page content to retrieve specific information from the page, e.g. all company names, a specifc description, all information about, links with companies in structured format or simply links
Tab Management:
- 'switch_tab': Switch to a specific tab
- 'open_tab': Open a new tab with a URL
- 'close_tab': Close the current tab
Utility:
- 'wait': Wait for a specified number of seconds
"""
Context = TypeVar("Context")
class BrowserUseTool(BaseTool):
class BrowserUseTool(BaseTool, Generic[Context]):
name: str = "browser_use"
description: str = _BROWSER_DESCRIPTION
parameters: dict = {
@ -43,52 +57,79 @@ class BrowserUseTool(BaseTool):
"action": {
"type": "string",
"enum": [
"navigate",
"click",
"go_to_url",
"click_element",
"input_text",
"screenshot",
"get_html",
"get_text",
"execute_js",
"scroll",
"scroll_down",
"scroll_up",
"scroll_to_text",
"send_keys",
"get_dropdown_options",
"select_dropdown_option",
"go_back",
"web_search",
"wait",
"extract_content",
"switch_tab",
"new_tab",
"open_tab",
"close_tab",
"refresh",
],
"description": "The browser action to perform",
},
"url": {
"type": "string",
"description": "URL for 'navigate' or 'new_tab' actions",
"description": "URL for 'go_to_url' or 'open_tab' actions",
},
"index": {
"type": "integer",
"description": "Element index for 'click' or 'input_text' actions",
"description": "Element index for 'click_element', 'input_text', 'get_dropdown_options', or 'select_dropdown_option' actions",
},
"text": {"type": "string", "description": "Text for 'input_text' action"},
"script": {
"text": {
"type": "string",
"description": "JavaScript code for 'execute_js' action",
"description": "Text for 'input_text', 'scroll_to_text', or 'select_dropdown_option' actions",
},
"scroll_amount": {
"type": "integer",
"description": "Pixels to scroll (positive for down, negative for up) for 'scroll' action",
"description": "Pixels to scroll (positive for down, negative for up) for 'scroll_down' or 'scroll_up' actions",
},
"tab_id": {
"type": "integer",
"description": "Tab ID for 'switch_tab' action",
},
"query": {
"type": "string",
"description": "Search query for 'web_search' action",
},
"goal": {
"type": "string",
"description": "Extraction goal for 'extract_content' action",
},
"keys": {
"type": "string",
"description": "Keys to send for 'send_keys' action",
},
"seconds": {
"type": "integer",
"description": "Seconds to wait for 'wait' action",
},
},
"required": ["action"],
"dependencies": {
"navigate": ["url"],
"click": ["index"],
"go_to_url": ["url"],
"click_element": ["index"],
"input_text": ["index", "text"],
"execute_js": ["script"],
"switch_tab": ["tab_id"],
"new_tab": ["url"],
"scroll": ["scroll_amount"],
"open_tab": ["url"],
"scroll_down": ["scroll_amount"],
"scroll_up": ["scroll_amount"],
"scroll_to_text": ["text"],
"send_keys": ["keys"],
"get_dropdown_options": ["index"],
"select_dropdown_option": ["index", "text"],
"go_back": [],
"web_search": ["query"],
"wait": ["seconds"],
"extract_content": ["goal"],
},
}
@ -96,6 +137,12 @@ class BrowserUseTool(BaseTool):
browser: Optional[BrowserUseBrowser] = Field(default=None, exclude=True)
context: Optional[BrowserContext] = Field(default=None, exclude=True)
dom_service: Optional[DomService] = Field(default=None, exclude=True)
web_search_tool: WebSearch = Field(default_factory=WebSearch, exclude=True)
# Context for generic functionality
tool_context: Optional[Context] = Field(default=None, exclude=True)
llm: Optional[LLM] = Field(default_factory=LLM)
@field_validator("parameters", mode="before")
def validate_parameters(cls, v: dict, info: ValidationInfo) -> dict:
@ -158,9 +205,12 @@ class BrowserUseTool(BaseTool):
url: Optional[str] = None,
index: Optional[int] = None,
text: Optional[str] = None,
script: Optional[str] = None,
scroll_amount: Optional[int] = None,
tab_id: Optional[int] = None,
query: Optional[str] = None,
goal: Optional[str] = None,
keys: Optional[str] = None,
seconds: Optional[int] = None,
**kwargs,
) -> ToolResult:
"""
@ -170,10 +220,13 @@ class BrowserUseTool(BaseTool):
action: The browser action to perform
url: URL for navigation or new tab
index: Element index for click or input actions
text: Text for input action
script: JavaScript code for execution
text: Text for input action or search query
scroll_amount: Pixels to scroll for scroll action
tab_id: Tab ID for switch_tab action
query: Search query for Google search
goal: Extraction goal for content extraction
keys: Keys to send for keyboard actions
seconds: Seconds to wait
**kwargs: Additional arguments
Returns:
@ -183,15 +236,68 @@ class BrowserUseTool(BaseTool):
try:
context = await self._ensure_browser_initialized()
if action == "navigate":
# Get max content length from config
max_content_length = getattr(
config.browser_config, "max_content_length", 2000
)
# Navigation actions
if action == "go_to_url":
if not url:
return ToolResult(error="URL is required for 'navigate' action")
await context.navigate_to(url)
return ToolResult(
error="URL is required for 'go_to_url' action"
)
page = await context.get_current_page()
await page.goto(url)
await page.wait_for_load_state()
return ToolResult(output=f"Navigated to {url}")
elif action == "click":
elif action == "go_back":
await context.go_back()
return ToolResult(output="Navigated back")
elif action == "refresh":
await context.refresh_page()
return ToolResult(output="Refreshed current page")
elif action == "web_search":
if not query:
return ToolResult(
error="Query is required for 'web_search' action"
)
search_results = await self.web_search_tool.execute(query)
if search_results:
# Navigate to the first search result
first_result = search_results[0]
if isinstance(first_result, dict) and "url" in first_result:
url_to_navigate = first_result["url"]
elif isinstance(first_result, str):
url_to_navigate = first_result
else:
return ToolResult(
error=f"Invalid search result format: {first_result}"
)
page = await context.get_current_page()
await page.goto(url_to_navigate)
await page.wait_for_load_state()
return ToolResult(
output=f"Searched for '{query}' and navigated to first result: {url_to_navigate}\nAll results:"
+ "\n".join([str(r) for r in search_results])
)
else:
return ToolResult(
error=f"No search results found for '{query}'"
)
# Element interaction actions
elif action == "click_element":
if index is None:
return ToolResult(error="Index is required for 'click' action")
return ToolResult(
error="Index is required for 'click_element' action"
)
element = await context.get_dom_element_by_index(index)
if not element:
return ToolResult(error=f"Element with index {index} not found")
@ -214,72 +320,174 @@ class BrowserUseTool(BaseTool):
output=f"Input '{text}' into element at index {index}"
)
elif action == "screenshot":
screenshot = await context.take_screenshot(full_page=True)
return ToolResult(
output=f"Screenshot captured (base64 length: {len(screenshot)})",
system=screenshot,
elif action == "scroll_down" or action == "scroll_up":
direction = 1 if action == "scroll_down" else -1
amount = (
scroll_amount
if scroll_amount is not None
else context.config.browser_window_size["height"]
)
elif action == "get_html":
html = await context.get_page_html()
truncated = (
html[:MAX_LENGTH] + "..." if len(html) > MAX_LENGTH else html
)
return ToolResult(output=truncated)
elif action == "get_text":
text = await context.execute_javascript("document.body.innerText")
return ToolResult(output=text)
elif action == "read_links":
links = await context.execute_javascript(
"document.querySelectorAll('a[href]').forEach((elem) => {if (elem.innerText) {console.log(elem.innerText, elem.href)}})"
)
return ToolResult(output=links)
elif action == "execute_js":
if not script:
return ToolResult(
error="Script is required for 'execute_js' action"
)
result = await context.execute_javascript(script)
return ToolResult(output=str(result))
elif action == "scroll":
if scroll_amount is None:
return ToolResult(
error="Scroll amount is required for 'scroll' action"
)
await context.execute_javascript(
f"window.scrollBy(0, {scroll_amount});"
f"window.scrollBy(0, {direction * amount});"
)
direction = "down" if scroll_amount > 0 else "up"
return ToolResult(
output=f"Scrolled {direction} by {abs(scroll_amount)} pixels"
output=f"Scrolled {'down' if direction > 0 else 'up'} by {amount} pixels"
)
elif action == "scroll_to_text":
if not text:
return ToolResult(
error="Text is required for 'scroll_to_text' action"
)
page = await context.get_current_page()
try:
locator = page.get_by_text(text, exact=False)
await locator.scroll_into_view_if_needed()
return ToolResult(output=f"Scrolled to text: '{text}'")
except Exception as e:
return ToolResult(error=f"Failed to scroll to text: {str(e)}")
elif action == "send_keys":
if not keys:
return ToolResult(
error="Keys are required for 'send_keys' action"
)
page = await context.get_current_page()
await page.keyboard.press(keys)
return ToolResult(output=f"Sent keys: {keys}")
elif action == "get_dropdown_options":
if index is None:
return ToolResult(
error="Index is required for 'get_dropdown_options' action"
)
element = await context.get_dom_element_by_index(index)
if not element:
return ToolResult(error=f"Element with index {index} not found")
page = await context.get_current_page()
options = await page.evaluate(
"""
(xpath) => {
const select = document.evaluate(xpath, document, null,
XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
if (!select) return null;
return Array.from(select.options).map(opt => ({
text: opt.text,
value: opt.value,
index: opt.index
}));
}
""",
element.xpath,
)
return ToolResult(output=f"Dropdown options: {options}")
elif action == "select_dropdown_option":
if index is None or not text:
return ToolResult(
error="Index and text are required for 'select_dropdown_option' action"
)
element = await context.get_dom_element_by_index(index)
if not element:
return ToolResult(error=f"Element with index {index} not found")
page = await context.get_current_page()
await page.select_option(element.xpath, label=text)
return ToolResult(
output=f"Selected option '{text}' from dropdown at index {index}"
)
# Content extraction actions
elif action == "extract_content":
if not goal:
return ToolResult(
error="Goal is required for 'extract_content' action"
)
page = await context.get_current_page()
try:
# Get page content and convert to markdown for better processing
html_content = await page.content()
# Import markdownify here to avoid global import
try:
import markdownify
content = markdownify.markdownify(html_content)
except ImportError:
# Fallback if markdownify is not available
content = html_content
# Create prompt for LLM
prompt_text = """
Your task is to extract the content of the page. You will be given a page and a goal, and you should extract all relevant information around this goal from the page.
Examples of extraction goals:
- Extract all company names
- Extract specific descriptions
- Extract all information about a topic
- Extract links with companies in structured format
- Extract all links
If the goal is vague, summarize the page. Respond in JSON format.
Extraction goal: {goal}
Page content:
{page}
"""
# Format the prompt with the goal and content
max_content_length = min(50000, len(content))
formatted_prompt = prompt_text.format(
goal=goal, page=content[:max_content_length]
)
# Create a proper message list for the LLM
from app.schema import Message
messages = [Message.user_message(formatted_prompt)]
# Use LLM to extract content based on the goal
response = await self.llm.ask(messages)
msg = f"Extracted from page:\n{response}\n"
return ToolResult(output=msg)
except Exception as e:
# Provide a more helpful error message
error_msg = f"Failed to extract content: {str(e)}"
try:
# Try to return a portion of the page content as fallback
return ToolResult(
output=f"{error_msg}\nHere's a portion of the page content:\n{content[:2000]}..."
)
except:
# If all else fails, just return the error
return ToolResult(error=error_msg)
# Tab management actions
elif action == "switch_tab":
if tab_id is None:
return ToolResult(
error="Tab ID is required for 'switch_tab' action"
)
await context.switch_to_tab(tab_id)
page = await context.get_current_page()
await page.wait_for_load_state()
return ToolResult(output=f"Switched to tab {tab_id}")
elif action == "new_tab":
elif action == "open_tab":
if not url:
return ToolResult(error="URL is required for 'new_tab' action")
return ToolResult(error="URL is required for 'open_tab' action")
await context.create_new_tab(url)
return ToolResult(output=f"Opened new tab with URL {url}")
return ToolResult(output=f"Opened new tab with {url}")
elif action == "close_tab":
await context.close_current_tab()
return ToolResult(output="Closed current tab")
elif action == "refresh":
await context.refresh_page()
return ToolResult(output="Refreshed current page")
# Utility actions
elif action == "wait":
seconds_to_wait = seconds if seconds is not None else 3
await asyncio.sleep(seconds_to_wait)
return ToolResult(output=f"Waited for {seconds_to_wait} seconds")
else:
return ToolResult(error=f"Unknown action: {action}")
@ -287,21 +495,58 @@ class BrowserUseTool(BaseTool):
except Exception as e:
return ToolResult(error=f"Browser action '{action}' failed: {str(e)}")
async def get_current_state(self) -> ToolResult:
"""Get the current browser state as a ToolResult."""
async with self.lock:
try:
context = await self._ensure_browser_initialized()
state = await context.get_state()
state_info = {
"url": state.url,
"title": state.title,
"tabs": [tab.model_dump() for tab in state.tabs],
"interactive_elements": state.element_tree.clickable_elements_to_string(),
}
return ToolResult(output=json.dumps(state_info))
except Exception as e:
return ToolResult(error=f"Failed to get browser state: {str(e)}")
async def get_current_state(
self, context: Optional[BrowserContext] = None
) -> ToolResult:
"""
Get the current browser state as a ToolResult.
If context is not provided, uses self.context.
"""
try:
# Use provided context or fall back to self.context
ctx = context or self.context
if not ctx:
return ToolResult(error="Browser context not initialized")
state = await ctx.get_state()
# Create a viewport_info dictionary if it doesn't exist
viewport_height = 0
if hasattr(state, "viewport_info") and state.viewport_info:
viewport_height = state.viewport_info.height
elif hasattr(ctx, "config") and hasattr(ctx.config, "browser_window_size"):
viewport_height = ctx.config.browser_window_size.get("height", 0)
# Take a screenshot for the state
screenshot = await ctx.take_screenshot(full_page=True)
# Build the state info with all required fields
state_info = {
"url": state.url,
"title": state.title,
"tabs": [tab.model_dump() for tab in state.tabs],
"help": "[0], [1], [2], etc., represent clickable indices corresponding to the elements listed. Clicking on these indices will navigate to or interact with the respective content behind them.",
"interactive_elements": (
state.element_tree.clickable_elements_to_string()
if state.element_tree
else ""
),
"scroll_info": {
"pixels_above": getattr(state, "pixels_above", 0),
"pixels_below": getattr(state, "pixels_below", 0),
"total_height": getattr(state, "pixels_above", 0)
+ getattr(state, "pixels_below", 0)
+ viewport_height,
},
"viewport_height": viewport_height,
}
return ToolResult(
output=json.dumps(state_info, indent=4, ensure_ascii=False),
base64_image=screenshot,
)
except Exception as e:
return ToolResult(error=f"Failed to get browser state: {str(e)}")
async def cleanup(self):
"""Clean up browser resources."""
@ -323,3 +568,10 @@ class BrowserUseTool(BaseTool):
loop = asyncio.new_event_loop()
loop.run_until_complete(self.cleanup())
loop.close()
@classmethod
def create_with_context(cls, context: Context) -> "BrowserUseTool[Context]":
"""Factory method to create a BrowserUseTool with a specific context."""
tool = cls()
tool.tool_context = context
return tool

View File

@ -1,11 +1,10 @@
# Global LLM configuration
[llm]
model = "gpt-4o" # The LLM model to use
base_url = "https://api.openai.com/v1" # API endpoint URL
api_key = "sk-..." # Your API key
max_tokens = 8192 # Maximum number of tokens in the response
temperature = 0.0 # Controls randomness
# max_input_tokens = 100000 # Maximum input tokens to use across all requests (set to null or delete this line for unlimited)
model = "claude-3-7-sonnet-20250219" # The LLM model to use
base_url = "https://api.anthropic.com/v1/" # API endpoint URL
api_key = "YOUR_API_KEY" # Your API key
max_tokens = 8192 # Maximum number of tokens in the response
temperature = 0.0 # Controls randomness
# [llm] #AZURE OPENAI:
# api_type= 'azure'
@ -26,11 +25,11 @@ temperature = 0.0 # Controls randomness
# Optional configuration for specific LLM models
[llm.vision]
model = "gpt-4o" # The vision model to use
base_url = "https://api.openai.com/v1" # API endpoint URL for vision model
api_key = "sk-..." # Your API key for vision model
max_tokens = 8192 # Maximum number of tokens in the response
temperature = 0.0 # Controls randomness for vision model
model = "claude-3-7-sonnet-20250219" # The vision model to use
base_url = "https://api.anthropic.com/v1/" # API endpoint URL for vision model
api_key = "YOUR_API_KEY" # Your API key for vision model
max_tokens = 8192 # Maximum number of tokens in the response
temperature = 0.0 # Controls randomness for vision model
# [llm.vision] #OLLAMA VISION:
# api_type = 'ollama'

View File

@ -20,6 +20,6 @@ baidusearch~=1.0.3
duckduckgo_search~=7.5.1
aiofiles~=24.1.0
pydantic_core~=2.32.0
pydantic_core~=2.27.2
colorama~=0.4.6
playwright~=1.50.0

View File

@ -31,7 +31,7 @@ setup(
"browser-use~=0.1.40",
"googlesearch-python~=1.3.0",
"aiofiles~=24.1.0",
"pydantic_core>=2.27.2,<2.33.0",
"pydantic_core>=2.27.2,<2.28.0",
"colorama~=0.4.6",
],
classifiers=[