Merge branch 'mannaandpoem:main' into main

This commit is contained in:
zyren123 2025-03-18 09:46:10 +08:00 committed by GitHub
commit f474290395
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 995 additions and 201 deletions

View File

@ -0,0 +1,33 @@
# Verifies that requirements.txt still installs cleanly on each Python
# version listed in the matrix. Only runs when requirements.txt changes.
name: Environment Corruption Check
on:
  push:
    branches: ["main"]
    paths:
      - requirements.txt
  pull_request:
    branches: ["main"]
    paths:
      - requirements.txt
# One run per workflow/event/ref; a newer run cancels the one in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }}
  cancel-in-progress: true
jobs:
  test-python-versions:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so the versions stay strings instead of being parsed as numbers.
        python-version: ["3.11.11", "3.12.8", "3.13.2"]
      # Let every version finish so all failing versions are reported.
      fail-fast: false
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Upgrade pip
        run: |
          python -m pip install --upgrade pip
      - name: Install dependencies
        # Use "python -m pip" so the install runs under the same interpreter
        # whose pip was upgraded in the previous step.
        run: |
          python -m pip install -r requirements.txt

127
.github/workflows/pr-autodiff.yaml vendored Normal file
View File

@ -0,0 +1,127 @@
# Posts an LLM-generated bilingual summary of a pull request's diff as a PR
# comment. Currently triggered by a "!pr-diff" comment on a pull request.
name: PR Diff Summarization
on:
  # pull_request:
  #   branches: [main]
  #   types: [opened, ready_for_review, reopened]
  issue_comment:
    types: [created]
permissions:
  contents: read
  pull-requests: write
jobs:
  pr-diff-summarization:
    runs-on: ubuntu-latest
    # Run for pull_request events, or for a "!pr-diff" comment on a pull
    # request written by a collaborator, member or owner.
    if: |
      (github.event_name == 'pull_request') ||
      (github.event_name == 'issue_comment' &&
      contains(github.event.comment.body, '!pr-diff') &&
      (github.event.comment.author_association == 'COLLABORATOR' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') &&
      github.event.issue.pull_request)
    steps:
      - name: Get PR head SHA
        id: get-pr-sha
        # Event data and the token are handed to the script through the
        # environment instead of being expanded into the script text.
        env:
          EVENT_NAME: ${{ github.event_name }}
          PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }}
          PR_URL: ${{ github.event.issue.pull_request.url }}
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          if [ "$EVENT_NAME" == "pull_request" ]; then
            echo "pr_sha=$PR_HEAD_SHA" >> "$GITHUB_OUTPUT"
            echo "Retrieved PR head SHA: $PR_HEAD_SHA"
          else
            SHA=$(curl -sSf -H "Authorization: Bearer $GH_TOKEN" "$PR_URL" | jq -r '.head.sha')
            # An empty or "null" SHA would make the checkout step fall back to
            # the default ref, so stop here instead.
            if [ -z "$SHA" ] || [ "$SHA" == "null" ]; then
              echo "Could not resolve PR head SHA from $PR_URL" >&2
              exit 1
            fi
            echo "pr_sha=$SHA" >> "$GITHUB_OUTPUT"
            echo "Retrieved PR head SHA from API: $SHA"
          fi
      - name: Check out code
        uses: actions/checkout@v4
        with:
          ref: ${{ steps.get-pr-sha.outputs.pr_sha }}
          # Full history so that origin/main...HEAD can be diffed.
          fetch-depth: 0
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install openai requests
      - name: Create and run Python script
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
          GH_TOKEN: ${{ github.token }}
          PR_NUMBER: ${{ github.event.pull_request.number || github.event.issue.number }}
        run: |-
          cat << 'EOF' > /tmp/_workflow_core.py
          import os
          import subprocess
          import sys
          import json
          import requests
          from openai import OpenAI


          def get_diff():
              """Return the added/removed lines of origin/main...HEAD, truncated."""
              result = subprocess.run(
                  ['git', 'diff', 'origin/main...HEAD'],
                  capture_output=True, text=True, check=True)
              return '\n'.join(
                  line for line in result.stdout.split('\n')
                  if line.startswith(('+', '-'))
                  and not line.startswith(('---', '+++'))
              )[:round(200000 * 0.4)]  # Truncate to prevent overflow


          def generate_comment(diff_content):
              """Ask the model for a bilingual review of the given diff text."""
              client = OpenAI(
                  base_url=os.getenv("OPENAI_BASE_URL"),
                  api_key=os.getenv("OPENAI_API_KEY")
              )
              guidelines = '''
          1. English version first, Chinese Simplified version after
          2. Example format:
              # Diff Report
              ## English
              - Added `ABC` class
              - Fixed `f()` behavior in `foo` module

              ### Comments Highlight
              - `config.toml` needs to be configured properly to make sure new features work as expected.

              ### Spelling/Offensive Content Check
              - No spelling mistakes or offensive content found in the code or comments.
          3. Highlight non-English comments
          4. Check for spelling/offensive content'''
              response = client.chat.completions.create(
                  model="o3-mini",
                  messages=[{
                      "role": "system",
                      "content": "Generate bilingual code review feedback."
                  }, {
                      "role": "user",
                      "content": f"Review these changes per guidelines:\n{guidelines}\n\nDIFF:\n{diff_content}"
                  }]
              )
              return response.choices[0].message.content


          def post_comment(comment):
              """Post the comment on the pull request; raise if the API rejects it."""
              repo = os.getenv("GITHUB_REPOSITORY")
              pr_number = os.getenv("PR_NUMBER")
              headers = {
                  "Authorization": f"Bearer {os.getenv('GH_TOKEN')}",
                  "Accept": "application/vnd.github.v3+json"
              }
              url = f"https://api.github.com/repos/{repo}/issues/{pr_number}/comments"
              response = requests.post(
                  url, json={"body": comment}, headers=headers, timeout=30)
              # Fail the step on an HTTP error instead of reporting success.
              response.raise_for_status()


          if __name__ == "__main__":
              diff_content = get_diff()
              if not diff_content.strip():
                  print("No meaningful diff detected.")
                  sys.exit(0)
              comment = generate_comment(diff_content)
              post_comment(comment)
              print("Comment posted successfully.")
          EOF
          python /tmp/_workflow_core.py

View File

@ -24,4 +24,4 @@ jobs:
dashboard_show_total_reactions: true dashboard_show_total_reactions: true
top_issues: true top_issues: true
top_pull_requests: true top_pull_requests: true
top_list_size: 32 top_list_size: 14

View File

@ -0,0 +1,10 @@
# Python version check: 3.11-3.13
import sys

# Compare only (major, minor). The full version_info tuple for e.g. 3.13.2 is
# (3, 13, 2, "final", 0), which sorts above (3, 13) and would wrongly flag
# every 3.13.x release as unsupported.
if not (3, 11) <= sys.version_info[:2] <= (3, 13):
    print(
        "Warning: Unsupported Python version {ver}, please use 3.11-3.13".format(
            # Report major.minor.micro only, without the release level/serial.
            ver=".".join(map(str, sys.version_info[:3]))
        )
    )

View File

@ -84,6 +84,7 @@ class BaseAgent(BaseModel, ABC):
self, self,
role: ROLE_TYPE, # type: ignore role: ROLE_TYPE, # type: ignore
content: str, content: str,
base64_image: Optional[str] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
"""Add a message to the agent's memory. """Add a message to the agent's memory.
@ -91,6 +92,7 @@ class BaseAgent(BaseModel, ABC):
Args: Args:
role: The role of the message sender (user, system, assistant, tool). role: The role of the message sender (user, system, assistant, tool).
content: The message content. content: The message content.
base64_image: Optional base64 encoded image.
**kwargs: Additional arguments (e.g., tool_call_id for tool messages). **kwargs: Additional arguments (e.g., tool_call_id for tool messages).
Raises: Raises:
@ -106,9 +108,9 @@ class BaseAgent(BaseModel, ABC):
if role not in message_map: if role not in message_map:
raise ValueError(f"Unsupported message role: {role}") raise ValueError(f"Unsupported message role: {role}")
msg_factory = message_map[role] # Create message with appropriate parameters based on role
msg = msg_factory(content, **kwargs) if role == "tool" else msg_factory(content) kwargs = {"base64_image": base64_image, **(kwargs if role == "tool" else {})}
self.memory.add_message(msg) self.memory.add_message(message_map[role](content, **kwargs))
async def run(self, request: Optional[str] = None) -> str: async def run(self, request: Optional[str] = None) -> str:
"""Execute the agent's main loop asynchronously. """Execute the agent's main loop asynchronously.

View File

@ -1,14 +1,15 @@
from typing import Any import json
from typing import Any, Optional
from pydantic import Field from pydantic import Field
from app.agent.toolcall import ToolCallAgent from app.agent.toolcall import ToolCallAgent
from app.logger import logger
from app.prompt.manus import NEXT_STEP_PROMPT, SYSTEM_PROMPT from app.prompt.manus import NEXT_STEP_PROMPT, SYSTEM_PROMPT
from app.tool import Terminate, ToolCollection from app.tool import Terminate, ToolCollection
from app.tool.browser_use_tool import BrowserUseTool from app.tool.browser_use_tool import BrowserUseTool
from app.tool.file_saver import FileSaver from app.tool.file_saver import FileSaver
from app.tool.python_execute import PythonExecute from app.tool.python_execute import PythonExecute
from app.tool.web_search import WebSearch
class Manus(ToolCallAgent): class Manus(ToolCallAgent):
@ -34,7 +35,7 @@ class Manus(ToolCallAgent):
# Add general-purpose tools to the tool collection # Add general-purpose tools to the tool collection
available_tools: ToolCollection = Field( available_tools: ToolCollection = Field(
default_factory=lambda: ToolCollection( default_factory=lambda: ToolCollection(
PythonExecute(), WebSearch(), BrowserUseTool(), FileSaver(), Terminate() PythonExecute(), BrowserUseTool(), FileSaver(), Terminate()
) )
) )
@ -44,3 +45,45 @@ class Manus(ToolCallAgent):
else: else:
await self.available_tools.get_tool(BrowserUseTool().name).cleanup() await self.available_tools.get_tool(BrowserUseTool().name).cleanup()
await super()._handle_special_tool(name, result, **kwargs) await super()._handle_special_tool(name, result, **kwargs)
async def get_browser_state(self) -> Optional[dict]:
    """Fetch the browser tool's current state as a parsed JSON object.

    Looks up the browser tool in ``available_tools``, awaits its
    ``get_current_state()`` and decodes ``result.output`` with ``json.loads``.
    If the result carries a non-empty ``base64_image`` it is stored on
    ``self._current_base64_image``.

    Returns:
        The decoded state, or ``None`` when the tool is not available, the
        result reports an error, or any exception is raised along the way
        (failures are logged at debug level).
    """
    tool = self.available_tools.get_tool(BrowserUseTool().name)
    if tool:
        try:
            # Ask the tool directly; no context parameter is passed
            state = await tool.get_current_state()
            if not state.error:
                # Remember the screenshot, if the result has one
                screenshot = getattr(state, "base64_image", None)
                if screenshot:
                    self._current_base64_image = screenshot

                # Decode the state payload
                return json.loads(state.output)

            logger.debug(f"Browser state error: {state.error}")
        except Exception as exc:
            logger.debug(f"Failed to get browser state: {str(exc)}")

    return None
async def think(self) -> bool:
    """Run the parent ``think`` step with browser context added to the prompt.

    When ``get_browser_state()`` returns a state without an ``error`` entry,
    its URL and title are appended to ``next_step_prompt`` for the duration
    of the parent call only.

    Returns:
        Whatever the parent implementation of ``think`` returns.
    """
    browser_state = await self.get_browser_state()

    # Modify the next_step_prompt temporarily
    original_prompt = self.next_step_prompt
    if browser_state and not browser_state.get("error"):
        self.next_step_prompt += f"\nCurrent browser state:\nURL: {browser_state.get('url', 'N/A')}\nTitle: {browser_state.get('title', 'N/A')}\n"

    try:
        # Call parent implementation
        return await super().think()
    finally:
        # Restore the prompt even if the parent raises; otherwise the
        # browser-state suffix would stay on the prompt and a later call
        # would append another one on top of it.
        self.next_step_prompt = original_prompt

View File

@ -30,6 +30,7 @@ class ToolCallAgent(ReActAgent):
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
tool_calls: List[ToolCall] = Field(default_factory=list) tool_calls: List[ToolCall] = Field(default_factory=list)
_current_base64_image: Optional[str] = None
max_steps: int = 30 max_steps: int = 30
max_observe: Optional[Union[int, bool]] = None max_observe: Optional[Union[int, bool]] = None
@ -44,9 +45,11 @@ class ToolCallAgent(ReActAgent):
# Get response with tool options # Get response with tool options
response = await self.llm.ask_tool( response = await self.llm.ask_tool(
messages=self.messages, messages=self.messages,
system_msgs=[Message.system_message(self.system_prompt)] system_msgs=(
if self.system_prompt [Message.system_message(self.system_prompt)]
else None, if self.system_prompt
else None
),
tools=self.available_tools.to_params(), tools=self.available_tools.to_params(),
tool_choice=self.tool_choices, tool_choice=self.tool_choices,
) )
@ -79,6 +82,9 @@ class ToolCallAgent(ReActAgent):
logger.info( logger.info(
f"🧰 Tools being prepared: {[call.function.name for call in response.tool_calls]}" f"🧰 Tools being prepared: {[call.function.name for call in response.tool_calls]}"
) )
logger.info(
f"🔧 Tool arguments: {response.tool_calls[0].function.arguments}"
)
try: try:
# Handle different tool_choices modes # Handle different tool_choices modes
@ -130,6 +136,9 @@ class ToolCallAgent(ReActAgent):
results = [] results = []
for command in self.tool_calls: for command in self.tool_calls:
# Reset base64_image for each tool call
self._current_base64_image = None
result = await self.execute_tool(command) result = await self.execute_tool(command)
if self.max_observe: if self.max_observe:
@ -141,7 +150,10 @@ class ToolCallAgent(ReActAgent):
# Add tool response to memory # Add tool response to memory
tool_msg = Message.tool_message( tool_msg = Message.tool_message(
content=result, tool_call_id=command.id, name=command.function.name content=result,
tool_call_id=command.id,
name=command.function.name,
base64_image=self._current_base64_image,
) )
self.memory.add_message(tool_msg) self.memory.add_message(tool_msg)
results.append(result) results.append(result)
@ -165,16 +177,29 @@ class ToolCallAgent(ReActAgent):
logger.info(f"🔧 Activating tool: '{name}'...") logger.info(f"🔧 Activating tool: '{name}'...")
result = await self.available_tools.execute(name=name, tool_input=args) result = await self.available_tools.execute(name=name, tool_input=args)
# Format result for display # Handle special tools
await self._handle_special_tool(name=name, result=result)
# Check if result is a ToolResult with base64_image
if hasattr(result, "base64_image") and result.base64_image:
# Store the base64_image for later use in tool_message
self._current_base64_image = result.base64_image
# Format result for display
observation = (
f"Observed output of cmd `{name}` executed:\n{str(result)}"
if result
else f"Cmd `{name}` completed with no output"
)
return observation
# Format result for display (standard case)
observation = ( observation = (
f"Observed output of cmd `{name}` executed:\n{str(result)}" f"Observed output of cmd `{name}` executed:\n{str(result)}"
if result if result
else f"Cmd `{name}` completed with no output" else f"Cmd `{name}` completed with no output"
) )
# Handle special tools like `finish`
await self._handle_special_tool(name=name, result=result)
return observation return observation
except json.JSONDecodeError: except json.JSONDecodeError:
error_msg = f"Error parsing arguments for {name}: Invalid JSON format" error_msg = f"Error parsing arguments for {name}: Invalid JSON format"

View File

@ -59,6 +59,9 @@ class BrowserSettings(BaseModel):
proxy: Optional[ProxySettings] = Field( proxy: Optional[ProxySettings] = Field(
None, description="Proxy settings for the browser" None, description="Proxy settings for the browser"
) )
max_content_length: int = Field(
2000, description="Maximum length for content retrieval operations"
)
class AppConfig(BaseModel): class AppConfig(BaseModel):

View File

@ -1,3 +1,4 @@
import math
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import tiktoken import tiktoken
@ -31,6 +32,142 @@ from app.schema import (
REASONING_MODELS = ["o1", "o3-mini"] REASONING_MODELS = ["o1", "o3-mini"]
class TokenCounter:
    """Estimate the token usage of chat messages, including image parts.

    Text is measured with the tokenizer passed to the constructor (anything
    exposing ``encode(text)`` whose result has a length). Images are estimated
    from their detail level and, when provided, their pixel dimensions.
    """

    # Token constants
    BASE_MESSAGE_TOKENS = 4
    FORMAT_TOKENS = 2
    LOW_DETAIL_IMAGE_TOKENS = 85
    HIGH_DETAIL_TILE_TOKENS = 170

    # Image processing constants
    MAX_SIZE = 2048
    HIGH_DETAIL_TARGET_SHORT_SIDE = 768
    TILE_SIZE = 512

    # Flat estimate for images whose size is unknown (medium detail without
    # dimensions) or whose detail level is not recognised.
    DEFAULT_IMAGE_TOKENS = 1024

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def count_text(self, text: str) -> int:
        """Calculate tokens for a text string (0 for empty or None)."""
        return 0 if not text else len(self.tokenizer.encode(text))

    @staticmethod
    def _image_field(image_item: dict, key: str, default=None):
        """Read ``key`` from the image item, or from its nested ``image_url`` dict.

        The top-level value wins when both are present, so items that already
        carry the field at the top level are handled as before.
        """
        if key in image_item:
            return image_item[key]
        nested = image_item.get("image_url")
        if isinstance(nested, dict) and key in nested:
            return nested[key]
        return default

    def count_image(self, image_item: dict) -> int:
        """
        Calculate tokens for an image based on detail level and dimensions

        For "low" detail: fixed 85 tokens
        For "high" detail:
        1. Scale to fit in 2048x2048 square
        2. Scale shortest side to 768px
        3. Count 512px tiles (170 tokens each)
        4. Add 85 tokens

        "medium" (the default when no detail is given) uses the high-detail
        calculation when dimensions are known and a flat 1024 otherwise.
        Unknown detail levels always use the flat 1024.

        ``detail`` and ``dimensions`` may sit on the item itself or inside its
        ``image_url`` dict.
        """
        detail = self._image_field(image_item, "detail", "medium")

        # For low detail, always return fixed token count
        if detail == "low":
            return self.LOW_DETAIL_IMAGE_TOKENS

        if detail in ("high", "medium"):
            dimensions = self._image_field(image_item, "dimensions")
            if dimensions is not None:
                width, height = dimensions
                return self._calculate_high_detail_tokens(width, height)
            if detail == "high":
                # Default to a 1024x1024 image calculation for high detail
                return self._calculate_high_detail_tokens(1024, 1024)  # 765 tokens

        # Medium detail without dimensions, or an unknown detail level
        return self.DEFAULT_IMAGE_TOKENS

    def _calculate_high_detail_tokens(self, width: int, height: int) -> int:
        """Calculate tokens for high detail images based on dimensions"""
        # A non-positive side has no tiles; charge only the base image cost
        # rather than dividing by zero below.
        if width <= 0 or height <= 0:
            return self.LOW_DETAIL_IMAGE_TOKENS

        # Step 1: Scale to fit in MAX_SIZE x MAX_SIZE square
        if width > self.MAX_SIZE or height > self.MAX_SIZE:
            scale = self.MAX_SIZE / max(width, height)
            # Clamp to 1px so extreme aspect ratios cannot truncate a side to 0
            width = max(1, int(width * scale))
            height = max(1, int(height * scale))

        # Step 2: Scale so shortest side is HIGH_DETAIL_TARGET_SHORT_SIDE
        scale = self.HIGH_DETAIL_TARGET_SHORT_SIDE / min(width, height)
        scaled_width = int(width * scale)
        scaled_height = int(height * scale)

        # Step 3: Count number of 512px tiles
        tiles_x = math.ceil(scaled_width / self.TILE_SIZE)
        tiles_y = math.ceil(scaled_height / self.TILE_SIZE)
        total_tiles = tiles_x * tiles_y

        # Step 4: Calculate final token count
        return (
            total_tiles * self.HIGH_DETAIL_TILE_TOKENS
        ) + self.LOW_DETAIL_IMAGE_TOKENS

    def count_content(self, content: Union[str, List[Union[str, dict]]]) -> int:
        """Calculate tokens for message content (a string or a list of parts)."""
        if not content:
            return 0

        if isinstance(content, str):
            return self.count_text(content)

        token_count = 0
        for item in content:
            if isinstance(item, str):
                token_count += self.count_text(item)
            elif isinstance(item, dict):
                if "text" in item:
                    token_count += self.count_text(item["text"])
                elif "image_url" in item:
                    token_count += self.count_image(item)
        return token_count

    def count_tool_calls(self, tool_calls: List[dict]) -> int:
        """Calculate tokens for tool calls (function name plus arguments)."""
        token_count = 0
        # "tool_calls": None is treated as no tool calls
        for tool_call in tool_calls or []:
            if "function" in tool_call:
                function = tool_call["function"]
                token_count += self.count_text(function.get("name", ""))
                token_count += self.count_text(function.get("arguments", ""))
        return token_count

    def count_message_tokens(self, messages: List[dict]) -> int:
        """Calculate the total number of tokens in a message list"""
        total_tokens = self.FORMAT_TOKENS  # Base format tokens

        for message in messages:
            tokens = self.BASE_MESSAGE_TOKENS  # Base tokens per message

            # Add role tokens
            tokens += self.count_text(message.get("role", ""))

            # Add content tokens (missing or None content counts as 0)
            tokens += self.count_content(message.get("content"))

            # Add tool calls tokens (missing or None counts as 0)
            tokens += self.count_tool_calls(message.get("tool_calls"))

            # Add name and tool_call_id tokens
            tokens += self.count_text(message.get("name", ""))
            tokens += self.count_text(message.get("tool_call_id", ""))

            total_tokens += tokens

        return total_tokens
class LLM: class LLM:
_instances: Dict[str, "LLM"] = {} _instances: Dict[str, "LLM"] = {}
@ -82,6 +219,8 @@ class LLM:
else: else:
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url) self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
self.token_counter = TokenCounter(self.tokenizer)
def count_tokens(self, text: str) -> int: def count_tokens(self, text: str) -> int:
"""Calculate the number of tokens in a text""" """Calculate the number of tokens in a text"""
if not text: if not text:
@ -89,46 +228,7 @@ class LLM:
return len(self.tokenizer.encode(text)) return len(self.tokenizer.encode(text))
def count_message_tokens(self, messages: List[dict]) -> int: def count_message_tokens(self, messages: List[dict]) -> int:
"""Calculate the number of tokens in a message list""" return self.token_counter.count_message_tokens(messages)
token_count = 0
for message in messages:
# Base token count for each message (according to OpenAI's calculation method)
token_count += 4 # Base token count for each message
# Calculate tokens for the role
if "role" in message:
token_count += self.count_tokens(message["role"])
# Calculate tokens for the content
if "content" in message and message["content"]:
token_count += self.count_tokens(message["content"])
# Calculate tokens for tool calls
if "tool_calls" in message and message["tool_calls"]:
for tool_call in message["tool_calls"]:
if "function" in tool_call:
# Function name
if "name" in tool_call["function"]:
token_count += self.count_tokens(
tool_call["function"]["name"]
)
# Function arguments
if "arguments" in tool_call["function"]:
token_count += self.count_tokens(
tool_call["function"]["arguments"]
)
# Calculate tokens for tool responses
if "name" in message and message["name"]:
token_count += self.count_tokens(message["name"])
if "tool_call_id" in message and message["tool_call_id"]:
token_count += self.count_tokens(message["tool_call_id"])
# Add extra tokens for message format
token_count += 2 # Extra tokens for message format
return token_count
def update_token_count(self, input_tokens: int, completion_tokens: int = 0) -> None: def update_token_count(self, input_tokens: int, completion_tokens: int = 0) -> None:
"""Update token counts""" """Update token counts"""
@ -184,22 +284,58 @@ class LLM:
formatted_messages = [] formatted_messages = []
for message in messages: for message in messages:
# Convert Message objects to dictionaries
if isinstance(message, Message): if isinstance(message, Message):
message = message.to_dict() message = message.to_dict()
if isinstance(message, dict):
# If message is a dict, ensure it has required fields if not isinstance(message, dict):
if "role" not in message:
raise ValueError("Message dict must contain 'role' field")
if "content" in message or "tool_calls" in message:
formatted_messages.append(message)
# else: do not include the message
else:
raise TypeError(f"Unsupported message type: {type(message)}") raise TypeError(f"Unsupported message type: {type(message)}")
# Validate all messages have required fields # Validate required fields
for msg in formatted_messages: if "role" not in message:
if msg["role"] not in ROLE_VALUES: raise ValueError("Message dict must contain 'role' field")
raise ValueError(f"Invalid role: {msg['role']}")
# Process base64 images if present
if message.get("base64_image"):
# Initialize or convert content to appropriate format
if not message.get("content"):
message["content"] = []
elif isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
elif isinstance(message["content"], list):
# Convert string items to proper text objects
message["content"] = [
(
{"type": "text", "text": item}
if isinstance(item, str)
else item
)
for item in message["content"]
]
# Add the image to content
message["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{message['base64_image']}"
},
}
)
# Remove the base64_image field
del message["base64_image"]
# Only include messages with content or tool_calls
if "content" in message or "tool_calls" in message:
formatted_messages.append(message)
# Validate all roles
invalid_roles = [
msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES
]
if invalid_roles:
raise ValueError(f"Invalid role: {invalid_roles[0]['role']}")
return formatted_messages return formatted_messages
@ -328,6 +464,152 @@ class LLM:
logger.error(f"Unexpected error in ask: {e}") logger.error(f"Unexpected error in ask: {e}")
raise raise
# NOTE(review): the retry tuple below includes `Exception`; if
# TokenLimitExceeded derives from Exception it is retried as well, despite
# the trailing comment. Confirm the intended retry policy.
@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    retry=retry_if_exception_type(
        (OpenAIError, Exception, ValueError)
    ),  # Don't retry TokenLimitExceeded
)
async def ask_with_images(
    self,
    messages: List[Union[dict, Message]],
    images: List[Union[str, dict]],
    system_msgs: Optional[List[Union[dict, Message]]] = None,
    stream: bool = False,
    temperature: Optional[float] = None,
) -> str:
    """
    Send a prompt with images to the LLM and get the response.

    The images are attached to the last message, which must have the
    ``user`` role. The caller's message objects are not modified.

    Args:
        messages: List of conversation messages
        images: List of image URLs or image data dictionaries. A string is
            used as the image URL; a dict with a ``url`` key is used as the
            ``image_url`` payload; a dict with an ``image_url`` key is used
            as a complete content part.
        system_msgs: Optional system messages to prepend
        stream (bool): Whether to stream the response
        temperature (float): Sampling temperature for the response

    Returns:
        str: The generated response

    Raises:
        TokenLimitExceeded: If token limits are exceeded
        ValueError: If messages are invalid or response is empty
        OpenAIError: If API call fails after retries
        Exception: For unexpected errors
    """
    try:
        # Format messages
        formatted_messages = self.format_messages(messages)

        # Ensure the last message is from the user to attach images
        if not formatted_messages or formatted_messages[-1]["role"] != "user":
            raise ValueError(
                "The last message must be from the user to attach images"
            )

        # Work on a copy of the last message. It may be the very dict the
        # caller passed in, and editing it in place would make every retry
        # of this method append the same images again.
        last_message = dict(formatted_messages[-1])
        formatted_messages[-1] = last_message

        # Convert content to multimodal format if needed
        content = last_message.get("content")
        if isinstance(content, str):
            multimodal_content = [{"type": "text", "text": content}]
        elif isinstance(content, list):
            # Copy so the caller's content list is left untouched
            multimodal_content = list(content)
        else:
            multimodal_content = []

        # Add images to content
        for image in images:
            if isinstance(image, str):
                multimodal_content.append(
                    {"type": "image_url", "image_url": {"url": image}}
                )
            elif isinstance(image, dict) and "url" in image:
                multimodal_content.append({"type": "image_url", "image_url": image})
            elif isinstance(image, dict) and "image_url" in image:
                multimodal_content.append(image)
            else:
                raise ValueError(f"Unsupported image format: {image}")

        # Update the message with multimodal content
        last_message["content"] = multimodal_content

        # Add system messages if provided
        if system_msgs:
            all_messages = self.format_messages(system_msgs) + formatted_messages
        else:
            all_messages = formatted_messages

        # Calculate tokens and check limits
        input_tokens = self.count_message_tokens(all_messages)
        if not self.check_token_limit(input_tokens):
            raise TokenLimitExceeded(self.get_limit_error_message(input_tokens))

        # Set up API parameters
        params = {
            "model": self.model,
            "messages": all_messages,
            "stream": stream,
        }

        # Add model-specific parameters
        if self.model in REASONING_MODELS:
            params["max_completion_tokens"] = self.max_tokens
        else:
            params["max_tokens"] = self.max_tokens
            params["temperature"] = (
                temperature if temperature is not None else self.temperature
            )

        # Handle non-streaming request
        if not stream:
            response = await self.client.chat.completions.create(**params)

            if not response.choices or not response.choices[0].message.content:
                raise ValueError("Empty or invalid response from LLM")

            self.update_token_count(response.usage.prompt_tokens)
            return response.choices[0].message.content

        # Handle streaming request
        self.update_token_count(input_tokens)
        response = await self.client.chat.completions.create(**params)

        collected_messages = []
        async for chunk in response:
            # Skip chunks that carry no choices instead of indexing into an
            # empty list
            if not chunk.choices:
                continue
            chunk_message = chunk.choices[0].delta.content or ""
            collected_messages.append(chunk_message)
            print(chunk_message, end="", flush=True)

        print()  # Newline after streaming
        full_response = "".join(collected_messages).strip()

        if not full_response:
            raise ValueError("Empty response from streaming LLM")

        return full_response

    except TokenLimitExceeded:
        raise
    except ValueError as ve:
        logger.error(f"Validation error in ask_with_images: {ve}")
        raise
    except OpenAIError as oe:
        logger.error(f"OpenAI API error: {oe}")
        if isinstance(oe, AuthenticationError):
            logger.error("Authentication failed. Check API key.")
        elif isinstance(oe, RateLimitError):
            logger.error("Rate limit exceeded. Consider increasing retry attempts.")
        elif isinstance(oe, APIError):
            logger.error(f"API error: {oe}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error in ask_with_images: {e}")
        raise
@retry( @retry(
wait=wait_random_exponential(min=1, max=60), wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6), stop=stop_after_attempt(6),

View File

@ -6,9 +6,7 @@ PythonExecute: Execute Python code to interact with the computer system, data pr
FileSaver: Save files locally, such as txt, py, html, etc. FileSaver: Save files locally, such as txt, py, html, etc.
BrowserUseTool: Open, browse, and use web browsers.If you open a local HTML file, you must provide the absolute path to the file. BrowserUseTool: Open, browse, and use web browsers. If you open a local HTML file, you must provide the absolute path to the file.
WebSearch: Perform web information retrieval
Terminate: End the current interaction when the task is complete or when you need additional information from the user. Use this tool to signal that you've finished addressing the user's request or need clarification before proceeding further. Terminate: End the current interaction when the task is complete or when you need additional information from the user. Use this tool to signal that you've finished addressing the user's request or need clarification before proceeding further.

View File

@ -59,6 +59,7 @@ class Message(BaseModel):
tool_calls: Optional[List[ToolCall]] = Field(default=None) tool_calls: Optional[List[ToolCall]] = Field(default=None)
name: Optional[str] = Field(default=None) name: Optional[str] = Field(default=None)
tool_call_id: Optional[str] = Field(default=None) tool_call_id: Optional[str] = Field(default=None)
base64_image: Optional[str] = Field(default=None)
def __add__(self, other) -> List["Message"]: def __add__(self, other) -> List["Message"]:
"""支持 Message + list 或 Message + Message 的操作""" """支持 Message + list 或 Message + Message 的操作"""
@ -91,12 +92,16 @@ class Message(BaseModel):
message["name"] = self.name message["name"] = self.name
if self.tool_call_id is not None: if self.tool_call_id is not None:
message["tool_call_id"] = self.tool_call_id message["tool_call_id"] = self.tool_call_id
if self.base64_image is not None:
message["base64_image"] = self.base64_image
return message return message
@classmethod @classmethod
def user_message(cls, content: str) -> "Message": def user_message(
cls, content: str, base64_image: Optional[str] = None
) -> "Message":
"""Create a user message""" """Create a user message"""
return cls(role=Role.USER, content=content) return cls(role=Role.USER, content=content, base64_image=base64_image)
@classmethod @classmethod
def system_message(cls, content: str) -> "Message": def system_message(cls, content: str) -> "Message":
@ -104,33 +109,50 @@ class Message(BaseModel):
return cls(role=Role.SYSTEM, content=content) return cls(role=Role.SYSTEM, content=content)
@classmethod @classmethod
def assistant_message(cls, content: Optional[str] = None) -> "Message": def assistant_message(
cls, content: Optional[str] = None, base64_image: Optional[str] = None
) -> "Message":
"""Create an assistant message""" """Create an assistant message"""
return cls(role=Role.ASSISTANT, content=content) return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)
@classmethod @classmethod
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message": def tool_message(
cls, content: str, name, tool_call_id: str, base64_image: Optional[str] = None
) -> "Message":
"""Create a tool message""" """Create a tool message"""
return cls( return cls(
role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id role=Role.TOOL,
content=content,
name=name,
tool_call_id=tool_call_id,
base64_image=base64_image,
) )
@classmethod @classmethod
def from_tool_calls( def from_tool_calls(
cls, tool_calls: List[Any], content: Union[str, List[str]] = "", **kwargs cls,
tool_calls: List[Any],
content: Union[str, List[str]] = "",
base64_image: Optional[str] = None,
**kwargs,
): ):
"""Create ToolCallsMessage from raw tool calls. """Create ToolCallsMessage from raw tool calls.
Args: Args:
tool_calls: Raw tool calls from LLM tool_calls: Raw tool calls from LLM
content: Optional message content content: Optional message content
base64_image: Optional base64 encoded image
""" """
formatted_calls = [ formatted_calls = [
{"id": call.id, "function": call.function.model_dump(), "type": "function"} {"id": call.id, "function": call.function.model_dump(), "type": "function"}
for call in tool_calls for call in tool_calls
] ]
return cls( return cls(
role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs role=Role.ASSISTANT,
content=content,
tool_calls=formatted_calls,
base64_image=base64_image,
**kwargs,
) )

View File

@ -37,6 +37,7 @@ class ToolResult(BaseModel):
output: Any = Field(default=None) output: Any = Field(default=None)
error: Optional[str] = Field(default=None) error: Optional[str] = Field(default=None)
base64_image: Optional[str] = Field(default=None)
system: Optional[str] = Field(default=None) system: Optional[str] = Field(default=None)
class Config: class Config:
@ -58,6 +59,7 @@ class ToolResult(BaseModel):
return ToolResult( return ToolResult(
output=combine_fields(self.output, other.output), output=combine_fields(self.output, other.output),
error=combine_fields(self.error, other.error), error=combine_fields(self.error, other.error),
base64_image=combine_fields(self.base64_image, other.base64_image, False),
system=combine_fields(self.system, other.system), system=combine_fields(self.system, other.system),
) )
@ -76,7 +78,3 @@ class CLIResult(ToolResult):
class ToolFailure(ToolResult): class ToolFailure(ToolResult):
"""A ToolResult that represents a failure.""" """A ToolResult that represents a failure."""
class AgentAwareTool:
agent: Optional = None

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
import json import json
from typing import Optional from typing import Generic, Optional, TypeVar
from browser_use import Browser as BrowserUseBrowser from browser_use import Browser as BrowserUseBrowser
from browser_use import BrowserConfig from browser_use import BrowserConfig
@ -10,31 +10,45 @@ from pydantic import Field, field_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from app.config import config from app.config import config
from app.llm import LLM
from app.tool.base import BaseTool, ToolResult from app.tool.base import BaseTool, ToolResult
from app.tool.web_search import WebSearch
MAX_LENGTH = 2000
_BROWSER_DESCRIPTION = """ _BROWSER_DESCRIPTION = """
Interact with a web browser to perform various actions such as navigation, element interaction, Interact with a web browser to perform various actions such as navigation, element interaction, content extraction, and tab management. This tool provides a comprehensive set of browser automation capabilities:
content extraction, and tab management. Supported actions include:
- 'navigate': Go to a specific URL Navigation:
- 'click': Click an element by index - 'go_to_url': Go to a specific URL in the current tab
- 'input_text': Input text into an element - 'go_back': Go back
- 'screenshot': Capture a screenshot
- 'get_html': Get page HTML content
- 'get_text': Get text content of the page
- 'read_links': Get all links on the page
- 'execute_js': Execute JavaScript code
- 'scroll': Scroll the page
- 'switch_tab': Switch to a specific tab
- 'new_tab': Open a new tab
- 'close_tab': Close the current tab
- 'refresh': Refresh the current page - 'refresh': Refresh the current page
- 'web_search': Search the query in the current tab, the query should be a search query like humans search in web, concrete and not vague or super long. More the single most important items.
Element Interaction:
- 'click_element': Click an element by index
- 'input_text': Input text into a form element
- 'scroll_down'/'scroll_up': Scroll the page (with optional pixel amount)
- 'scroll_to_text': If you dont find something which you want to interact with, scroll to it
- 'send_keys': Send strings of special keys like Escape,Backspace, Insert, PageDown, Delete, Enter, Shortcuts such as `Control+o`, `Control+Shift+T` are supported as well. This gets used in keyboard.press.
- 'get_dropdown_options': Get all options from a dropdown
- 'select_dropdown_option': Select dropdown option for interactive element index by the text of the option you want to select
Content Extraction:
- 'extract_content': Extract page content to retrieve specific information from the page, e.g. all company names, a specifc description, all information about, links with companies in structured format or simply links
Tab Management:
- 'switch_tab': Switch to a specific tab
- 'open_tab': Open a new tab with a URL
- 'close_tab': Close the current tab
Utility:
- 'wait': Wait for a specified number of seconds
""" """
Context = TypeVar("Context")
class BrowserUseTool(BaseTool):
class BrowserUseTool(BaseTool, Generic[Context]):
name: str = "browser_use" name: str = "browser_use"
description: str = _BROWSER_DESCRIPTION description: str = _BROWSER_DESCRIPTION
parameters: dict = { parameters: dict = {
@ -43,52 +57,79 @@ class BrowserUseTool(BaseTool):
"action": { "action": {
"type": "string", "type": "string",
"enum": [ "enum": [
"navigate", "go_to_url",
"click", "click_element",
"input_text", "input_text",
"screenshot", "scroll_down",
"get_html", "scroll_up",
"get_text", "scroll_to_text",
"execute_js", "send_keys",
"scroll", "get_dropdown_options",
"select_dropdown_option",
"go_back",
"web_search",
"wait",
"extract_content",
"switch_tab", "switch_tab",
"new_tab", "open_tab",
"close_tab", "close_tab",
"refresh",
], ],
"description": "The browser action to perform", "description": "The browser action to perform",
}, },
"url": { "url": {
"type": "string", "type": "string",
"description": "URL for 'navigate' or 'new_tab' actions", "description": "URL for 'go_to_url' or 'open_tab' actions",
}, },
"index": { "index": {
"type": "integer", "type": "integer",
"description": "Element index for 'click' or 'input_text' actions", "description": "Element index for 'click_element', 'input_text', 'get_dropdown_options', or 'select_dropdown_option' actions",
}, },
"text": {"type": "string", "description": "Text for 'input_text' action"}, "text": {
"script": {
"type": "string", "type": "string",
"description": "JavaScript code for 'execute_js' action", "description": "Text for 'input_text', 'scroll_to_text', or 'select_dropdown_option' actions",
}, },
"scroll_amount": { "scroll_amount": {
"type": "integer", "type": "integer",
"description": "Pixels to scroll (positive for down, negative for up) for 'scroll' action", "description": "Pixels to scroll (positive for down, negative for up) for 'scroll_down' or 'scroll_up' actions",
}, },
"tab_id": { "tab_id": {
"type": "integer", "type": "integer",
"description": "Tab ID for 'switch_tab' action", "description": "Tab ID for 'switch_tab' action",
}, },
"query": {
"type": "string",
"description": "Search query for 'web_search' action",
},
"goal": {
"type": "string",
"description": "Extraction goal for 'extract_content' action",
},
"keys": {
"type": "string",
"description": "Keys to send for 'send_keys' action",
},
"seconds": {
"type": "integer",
"description": "Seconds to wait for 'wait' action",
},
}, },
"required": ["action"], "required": ["action"],
"dependencies": { "dependencies": {
"navigate": ["url"], "go_to_url": ["url"],
"click": ["index"], "click_element": ["index"],
"input_text": ["index", "text"], "input_text": ["index", "text"],
"execute_js": ["script"],
"switch_tab": ["tab_id"], "switch_tab": ["tab_id"],
"new_tab": ["url"], "open_tab": ["url"],
"scroll": ["scroll_amount"], "scroll_down": ["scroll_amount"],
"scroll_up": ["scroll_amount"],
"scroll_to_text": ["text"],
"send_keys": ["keys"],
"get_dropdown_options": ["index"],
"select_dropdown_option": ["index", "text"],
"go_back": [],
"web_search": ["query"],
"wait": ["seconds"],
"extract_content": ["goal"],
}, },
} }
@ -96,6 +137,12 @@ class BrowserUseTool(BaseTool):
browser: Optional[BrowserUseBrowser] = Field(default=None, exclude=True) browser: Optional[BrowserUseBrowser] = Field(default=None, exclude=True)
context: Optional[BrowserContext] = Field(default=None, exclude=True) context: Optional[BrowserContext] = Field(default=None, exclude=True)
dom_service: Optional[DomService] = Field(default=None, exclude=True) dom_service: Optional[DomService] = Field(default=None, exclude=True)
web_search_tool: WebSearch = Field(default_factory=WebSearch, exclude=True)
# Context for generic functionality
tool_context: Optional[Context] = Field(default=None, exclude=True)
llm: Optional[LLM] = Field(default_factory=LLM)
@field_validator("parameters", mode="before") @field_validator("parameters", mode="before")
def validate_parameters(cls, v: dict, info: ValidationInfo) -> dict: def validate_parameters(cls, v: dict, info: ValidationInfo) -> dict:
@ -158,9 +205,12 @@ class BrowserUseTool(BaseTool):
url: Optional[str] = None, url: Optional[str] = None,
index: Optional[int] = None, index: Optional[int] = None,
text: Optional[str] = None, text: Optional[str] = None,
script: Optional[str] = None,
scroll_amount: Optional[int] = None, scroll_amount: Optional[int] = None,
tab_id: Optional[int] = None, tab_id: Optional[int] = None,
query: Optional[str] = None,
goal: Optional[str] = None,
keys: Optional[str] = None,
seconds: Optional[int] = None,
**kwargs, **kwargs,
) -> ToolResult: ) -> ToolResult:
""" """
@ -170,10 +220,13 @@ class BrowserUseTool(BaseTool):
action: The browser action to perform action: The browser action to perform
url: URL for navigation or new tab url: URL for navigation or new tab
index: Element index for click or input actions index: Element index for click or input actions
text: Text for input action text: Text for input action or search query
script: JavaScript code for execution
scroll_amount: Pixels to scroll for scroll action scroll_amount: Pixels to scroll for scroll action
tab_id: Tab ID for switch_tab action tab_id: Tab ID for switch_tab action
query: Search query for Google search
goal: Extraction goal for content extraction
keys: Keys to send for keyboard actions
seconds: Seconds to wait
**kwargs: Additional arguments **kwargs: Additional arguments
Returns: Returns:
@ -183,15 +236,68 @@ class BrowserUseTool(BaseTool):
try: try:
context = await self._ensure_browser_initialized() context = await self._ensure_browser_initialized()
if action == "navigate": # Get max content length from config
max_content_length = getattr(
config.browser_config, "max_content_length", 2000
)
# Navigation actions
if action == "go_to_url":
if not url: if not url:
return ToolResult(error="URL is required for 'navigate' action") return ToolResult(
await context.navigate_to(url) error="URL is required for 'go_to_url' action"
)
page = await context.get_current_page()
await page.goto(url)
await page.wait_for_load_state()
return ToolResult(output=f"Navigated to {url}") return ToolResult(output=f"Navigated to {url}")
elif action == "click": elif action == "go_back":
await context.go_back()
return ToolResult(output="Navigated back")
elif action == "refresh":
await context.refresh_page()
return ToolResult(output="Refreshed current page")
elif action == "web_search":
if not query:
return ToolResult(
error="Query is required for 'web_search' action"
)
search_results = await self.web_search_tool.execute(query)
if search_results:
# Navigate to the first search result
first_result = search_results[0]
if isinstance(first_result, dict) and "url" in first_result:
url_to_navigate = first_result["url"]
elif isinstance(first_result, str):
url_to_navigate = first_result
else:
return ToolResult(
error=f"Invalid search result format: {first_result}"
)
page = await context.get_current_page()
await page.goto(url_to_navigate)
await page.wait_for_load_state()
return ToolResult(
output=f"Searched for '{query}' and navigated to first result: {url_to_navigate}\nAll results:"
+ "\n".join([str(r) for r in search_results])
)
else:
return ToolResult(
error=f"No search results found for '{query}'"
)
# Element interaction actions
elif action == "click_element":
if index is None: if index is None:
return ToolResult(error="Index is required for 'click' action") return ToolResult(
error="Index is required for 'click_element' action"
)
element = await context.get_dom_element_by_index(index) element = await context.get_dom_element_by_index(index)
if not element: if not element:
return ToolResult(error=f"Element with index {index} not found") return ToolResult(error=f"Element with index {index} not found")
@ -214,72 +320,174 @@ class BrowserUseTool(BaseTool):
output=f"Input '{text}' into element at index {index}" output=f"Input '{text}' into element at index {index}"
) )
elif action == "screenshot": elif action == "scroll_down" or action == "scroll_up":
screenshot = await context.take_screenshot(full_page=True) direction = 1 if action == "scroll_down" else -1
return ToolResult( amount = (
output=f"Screenshot captured (base64 length: {len(screenshot)})", scroll_amount
system=screenshot, if scroll_amount is not None
else context.config.browser_window_size["height"]
) )
elif action == "get_html":
html = await context.get_page_html()
truncated = (
html[:MAX_LENGTH] + "..." if len(html) > MAX_LENGTH else html
)
return ToolResult(output=truncated)
elif action == "get_text":
text = await context.execute_javascript("document.body.innerText")
return ToolResult(output=text)
elif action == "read_links":
links = await context.execute_javascript(
"document.querySelectorAll('a[href]').forEach((elem) => {if (elem.innerText) {console.log(elem.innerText, elem.href)}})"
)
return ToolResult(output=links)
elif action == "execute_js":
if not script:
return ToolResult(
error="Script is required for 'execute_js' action"
)
result = await context.execute_javascript(script)
return ToolResult(output=str(result))
elif action == "scroll":
if scroll_amount is None:
return ToolResult(
error="Scroll amount is required for 'scroll' action"
)
await context.execute_javascript( await context.execute_javascript(
f"window.scrollBy(0, {scroll_amount});" f"window.scrollBy(0, {direction * amount});"
) )
direction = "down" if scroll_amount > 0 else "up"
return ToolResult( return ToolResult(
output=f"Scrolled {direction} by {abs(scroll_amount)} pixels" output=f"Scrolled {'down' if direction > 0 else 'up'} by {amount} pixels"
) )
elif action == "scroll_to_text":
if not text:
return ToolResult(
error="Text is required for 'scroll_to_text' action"
)
page = await context.get_current_page()
try:
locator = page.get_by_text(text, exact=False)
await locator.scroll_into_view_if_needed()
return ToolResult(output=f"Scrolled to text: '{text}'")
except Exception as e:
return ToolResult(error=f"Failed to scroll to text: {str(e)}")
elif action == "send_keys":
if not keys:
return ToolResult(
error="Keys are required for 'send_keys' action"
)
page = await context.get_current_page()
await page.keyboard.press(keys)
return ToolResult(output=f"Sent keys: {keys}")
elif action == "get_dropdown_options":
if index is None:
return ToolResult(
error="Index is required for 'get_dropdown_options' action"
)
element = await context.get_dom_element_by_index(index)
if not element:
return ToolResult(error=f"Element with index {index} not found")
page = await context.get_current_page()
options = await page.evaluate(
"""
(xpath) => {
const select = document.evaluate(xpath, document, null,
XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
if (!select) return null;
return Array.from(select.options).map(opt => ({
text: opt.text,
value: opt.value,
index: opt.index
}));
}
""",
element.xpath,
)
return ToolResult(output=f"Dropdown options: {options}")
elif action == "select_dropdown_option":
if index is None or not text:
return ToolResult(
error="Index and text are required for 'select_dropdown_option' action"
)
element = await context.get_dom_element_by_index(index)
if not element:
return ToolResult(error=f"Element with index {index} not found")
page = await context.get_current_page()
await page.select_option(element.xpath, label=text)
return ToolResult(
output=f"Selected option '{text}' from dropdown at index {index}"
)
# Content extraction actions
elif action == "extract_content":
if not goal:
return ToolResult(
error="Goal is required for 'extract_content' action"
)
page = await context.get_current_page()
try:
# Get page content and convert to markdown for better processing
html_content = await page.content()
# Import markdownify here to avoid global import
try:
import markdownify
content = markdownify.markdownify(html_content)
except ImportError:
# Fallback if markdownify is not available
content = html_content
# Create prompt for LLM
prompt_text = """
Your task is to extract the content of the page. You will be given a page and a goal, and you should extract all relevant information around this goal from the page.
Examples of extraction goals:
- Extract all company names
- Extract specific descriptions
- Extract all information about a topic
- Extract links with companies in structured format
- Extract all links
If the goal is vague, summarize the page. Respond in JSON format.
Extraction goal: {goal}
Page content:
{page}
"""
# Format the prompt with the goal and content
max_content_length = min(50000, len(content))
formatted_prompt = prompt_text.format(
goal=goal, page=content[:max_content_length]
)
# Create a proper message list for the LLM
from app.schema import Message
messages = [Message.user_message(formatted_prompt)]
# Use LLM to extract content based on the goal
response = await self.llm.ask(messages)
msg = f"Extracted from page:\n{response}\n"
return ToolResult(output=msg)
except Exception as e:
# Provide a more helpful error message
error_msg = f"Failed to extract content: {str(e)}"
try:
# Try to return a portion of the page content as fallback
return ToolResult(
output=f"{error_msg}\nHere's a portion of the page content:\n{content[:2000]}..."
)
except:
# If all else fails, just return the error
return ToolResult(error=error_msg)
# Tab management actions
elif action == "switch_tab": elif action == "switch_tab":
if tab_id is None: if tab_id is None:
return ToolResult( return ToolResult(
error="Tab ID is required for 'switch_tab' action" error="Tab ID is required for 'switch_tab' action"
) )
await context.switch_to_tab(tab_id) await context.switch_to_tab(tab_id)
page = await context.get_current_page()
await page.wait_for_load_state()
return ToolResult(output=f"Switched to tab {tab_id}") return ToolResult(output=f"Switched to tab {tab_id}")
elif action == "new_tab": elif action == "open_tab":
if not url: if not url:
return ToolResult(error="URL is required for 'new_tab' action") return ToolResult(error="URL is required for 'open_tab' action")
await context.create_new_tab(url) await context.create_new_tab(url)
return ToolResult(output=f"Opened new tab with URL {url}") return ToolResult(output=f"Opened new tab with {url}")
elif action == "close_tab": elif action == "close_tab":
await context.close_current_tab() await context.close_current_tab()
return ToolResult(output="Closed current tab") return ToolResult(output="Closed current tab")
elif action == "refresh": # Utility actions
await context.refresh_page() elif action == "wait":
return ToolResult(output="Refreshed current page") seconds_to_wait = seconds if seconds is not None else 3
await asyncio.sleep(seconds_to_wait)
return ToolResult(output=f"Waited for {seconds_to_wait} seconds")
else: else:
return ToolResult(error=f"Unknown action: {action}") return ToolResult(error=f"Unknown action: {action}")
@ -287,21 +495,58 @@ class BrowserUseTool(BaseTool):
except Exception as e: except Exception as e:
return ToolResult(error=f"Browser action '{action}' failed: {str(e)}") return ToolResult(error=f"Browser action '{action}' failed: {str(e)}")
async def get_current_state(self) -> ToolResult: async def get_current_state(
"""Get the current browser state as a ToolResult.""" self, context: Optional[BrowserContext] = None
async with self.lock: ) -> ToolResult:
try: """
context = await self._ensure_browser_initialized() Get the current browser state as a ToolResult.
state = await context.get_state() If context is not provided, uses self.context.
state_info = { """
"url": state.url, try:
"title": state.title, # Use provided context or fall back to self.context
"tabs": [tab.model_dump() for tab in state.tabs], ctx = context or self.context
"interactive_elements": state.element_tree.clickable_elements_to_string(), if not ctx:
} return ToolResult(error="Browser context not initialized")
return ToolResult(output=json.dumps(state_info))
except Exception as e: state = await ctx.get_state()
return ToolResult(error=f"Failed to get browser state: {str(e)}")
# Create a viewport_info dictionary if it doesn't exist
viewport_height = 0
if hasattr(state, "viewport_info") and state.viewport_info:
viewport_height = state.viewport_info.height
elif hasattr(ctx, "config") and hasattr(ctx.config, "browser_window_size"):
viewport_height = ctx.config.browser_window_size.get("height", 0)
# Take a screenshot for the state
screenshot = await ctx.take_screenshot(full_page=True)
# Build the state info with all required fields
state_info = {
"url": state.url,
"title": state.title,
"tabs": [tab.model_dump() for tab in state.tabs],
"help": "[0], [1], [2], etc., represent clickable indices corresponding to the elements listed. Clicking on these indices will navigate to or interact with the respective content behind them.",
"interactive_elements": (
state.element_tree.clickable_elements_to_string()
if state.element_tree
else ""
),
"scroll_info": {
"pixels_above": getattr(state, "pixels_above", 0),
"pixels_below": getattr(state, "pixels_below", 0),
"total_height": getattr(state, "pixels_above", 0)
+ getattr(state, "pixels_below", 0)
+ viewport_height,
},
"viewport_height": viewport_height,
}
return ToolResult(
output=json.dumps(state_info, indent=4, ensure_ascii=False),
base64_image=screenshot,
)
except Exception as e:
return ToolResult(error=f"Failed to get browser state: {str(e)}")
async def cleanup(self): async def cleanup(self):
"""Clean up browser resources.""" """Clean up browser resources."""
@ -323,3 +568,10 @@ class BrowserUseTool(BaseTool):
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
loop.run_until_complete(self.cleanup()) loop.run_until_complete(self.cleanup())
loop.close() loop.close()
@classmethod
def create_with_context(cls, context: Context) -> "BrowserUseTool[Context]":
"""Factory method to create a BrowserUseTool with a specific context."""
tool = cls()
tool.tool_context = context
return tool

View File

@ -1,11 +1,10 @@
# Global LLM configuration # Global LLM configuration
[llm] [llm]
model = "gpt-4o" # The LLM model to use model = "claude-3-7-sonnet-20250219" # The LLM model to use
base_url = "https://api.openai.com/v1" # API endpoint URL base_url = "https://api.anthropic.com/v1/" # API endpoint URL
api_key = "sk-..." # Your API key api_key = "YOUR_API_KEY" # Your API key
max_tokens = 8192 # Maximum number of tokens in the response max_tokens = 8192 # Maximum number of tokens in the response
temperature = 0.0 # Controls randomness temperature = 0.0 # Controls randomness
# max_input_tokens = 100000 # Maximum input tokens to use across all requests (set to null or delete this line for unlimited)
# [llm] #AZURE OPENAI: # [llm] #AZURE OPENAI:
# api_type= 'azure' # api_type= 'azure'
@ -26,11 +25,11 @@ temperature = 0.0 # Controls randomness
# Optional configuration for specific LLM models # Optional configuration for specific LLM models
[llm.vision] [llm.vision]
model = "gpt-4o" # The vision model to use model = "claude-3-7-sonnet-20250219" # The vision model to use
base_url = "https://api.openai.com/v1" # API endpoint URL for vision model base_url = "https://api.anthropic.com/v1/" # API endpoint URL for vision model
api_key = "sk-..." # Your API key for vision model api_key = "YOUR_API_KEY" # Your API key for vision model
max_tokens = 8192 # Maximum number of tokens in the response max_tokens = 8192 # Maximum number of tokens in the response
temperature = 0.0 # Controls randomness for vision model temperature = 0.0 # Controls randomness for vision model
# [llm.vision] #OLLAMA VISION: # [llm.vision] #OLLAMA VISION:
# api_type = 'ollama' # api_type = 'ollama'

View File

@ -20,6 +20,6 @@ baidusearch~=1.0.3
duckduckgo_search~=7.5.1 duckduckgo_search~=7.5.1
aiofiles~=24.1.0 aiofiles~=24.1.0
pydantic_core~=2.32.0 pydantic_core~=2.27.2
colorama~=0.4.6 colorama~=0.4.6
playwright~=1.50.0 playwright~=1.50.0

View File

@ -31,7 +31,7 @@ setup(
"browser-use~=0.1.40", "browser-use~=0.1.40",
"googlesearch-python~=1.3.0", "googlesearch-python~=1.3.0",
"aiofiles~=24.1.0", "aiofiles~=24.1.0",
"pydantic_core>=2.27.2,<2.33.0", "pydantic_core>=2.27.2,<2.28.0",
"colorama~=0.4.6", "colorama~=0.4.6",
], ],
classifiers=[ classifiers=[