Compare commits: main...dependabot (1 commit)

Author | SHA1 | Date
---|---|---
 | 3fa5fae7ec |
.github/ISSUE_TEMPLATE/request_new_features.md (vendored, new file, +14)
@@ -0,0 +1,14 @@
---
name: "🤔 Request new features"
about: Suggest ideas or features you’d like to see implemented in OpenManus.
title: ''
labels: kind/features
assignees: ''
---

**Feature description**
<!-- Provide a clear and concise description of the proposed feature -->

**Your Feature**
<!-- Explain your idea or implementation process. Optionally, include a Pull Request URL. -->
<!-- Ensure accompanying docs/tests/examples are provided for review. -->

.github/ISSUE_TEMPLATE/request_new_features.yaml (vendored, deleted, -21)
@@ -1,21 +0,0 @@
name: "🤔 Request new features"
description: Suggest ideas or features you’d like to see implemented in OpenManus.
labels: enhancement
body:
  - type: textarea
    id: feature-description
    attributes:
      label: Feature description
      description: |
        Provide a clear and concise description of the proposed feature
    validations:
      required: true
  - type: textarea
    id: your-feature
    attributes:
      label: Your Feature
      description: |
        Explain your idea or implementation process, if any. Optionally, include a Pull Request URL.
        Ensure accompanying docs/tests/examples are provided for review.
    validations:
      required: false

.github/ISSUE_TEMPLATE/show_me_the_bug.md (vendored, new file, +25)
@@ -0,0 +1,25 @@
---
name: "🪲 Show me the Bug"
about: Report a bug encountered while using OpenManus and seek assistance.
title: ''
labels: kind/bug
assignees: ''
---

**Bug description**
<!-- Clearly describe the bug you encountered -->

**Bug solved method**
<!-- If resolved, explain the solution. Optionally, include a Pull Request URL. -->
<!-- If unresolved, provide additional details to aid investigation -->

**Environment information**
<!-- System: e.g., Ubuntu 22.04, Python: e.g., 3.12, OpenManus version: e.g., 0.1.0 -->

- System version:
- Python version:
- OpenManus version or branch:
- Installation method (e.g., `pip install -r requirements.txt` or `pip install -e .`):

**Screenshots or logs**
<!-- Attach screenshots or logs to help diagnose the issue -->

.github/ISSUE_TEMPLATE/show_me_the_bug.yaml (vendored, deleted, -44)
@@ -1,44 +0,0 @@
name: "🪲 Show me the Bug"
description: Report a bug encountered while using OpenManus and seek assistance.
labels: bug
body:
  - type: textarea
    id: bug-description
    attributes:
      label: Bug Description
      description: |
        Clearly describe the bug you encountered
    validations:
      required: true
  - type: textarea
    id: solve-method
    attributes:
      label: Bug solved method
      description: |
        If resolved, explain the solution. Optionally, include a Pull Request URL.
        If unresolved, provide additional details to aid investigation
    validations:
      required: true
  - type: textarea
    id: environment-information
    attributes:
      label: Environment information
      description: |
        System: e.g., Ubuntu 22.04
        Python: e.g., 3.12
        OpenManus version: e.g., 0.1.0
      value: |
        - System version:
        - Python version:
        - OpenManus version or branch:
        - Installation method (e.g., `pip install -r requirements.txt` or `pip install -e .`):
    validations:
      required: true
  - type: textarea
    id: extra-information
    attributes:
      label: Extra information
      description: |
        For example, attach screenshots or logs to help diagnose the issue
    validations:
      required: false

.github/workflows/pr-autodiff.yaml (vendored, 33 changes)
@@ -15,20 +15,21 @@ jobs:
        (github.event_name == 'pull_request') ||
        (github.event_name == 'issue_comment' &&
         contains(github.event.comment.body, '!pr-diff') &&
         (github.event.comment.author_association == 'CONTRIBUTOR' || github.event.comment.author_association == 'COLLABORATOR' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') &&
         (github.event.comment.author_association == 'COLLABORATOR' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'OWNER') &&
         github.event.issue.pull_request)
    steps:
      - name: Get PR head SHA
        id: get-pr-sha
        run: |
          PR_URL="${{ github.event.issue.pull_request.url || github.event.pull_request.url }}"
          # https://api.github.com/repos/OpenManus/pulls/1
          RESPONSE=$(curl -s -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" $PR_URL)
          SHA=$(echo $RESPONSE | jq -r '.head.sha')
          TARGET_BRANCH=$(echo $RESPONSE | jq -r '.base.ref')
          echo "pr_sha=$SHA" >> $GITHUB_OUTPUT
          echo "target_branch=$TARGET_BRANCH" >> $GITHUB_OUTPUT
          echo "Retrieved PR head SHA from API: $SHA, target branch: $TARGET_BRANCH"
          if [ "${{ github.event_name }}" == "pull_request" ]; then
            echo "pr_sha=${{ github.event.pull_request.head.sha }}" >> $GITHUB_OUTPUT
            echo "Retrieved PR head SHA: ${{ github.event.pull_request.head.sha }}"
          else
            PR_URL="${{ github.event.issue.pull_request.url }}"
            SHA=$(curl -s -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" $PR_URL | jq -r '.head.sha')
            echo "pr_sha=$SHA" >> $GITHUB_OUTPUT
            echo "Retrieved PR head SHA from API: $SHA"
          fi
      - name: Check out code
        uses: actions/checkout@v4
        with:
@@ -48,7 +49,6 @@ jobs:
          OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
          GH_TOKEN: ${{ github.token }}
          PR_NUMBER: ${{ github.event.pull_request.number || github.event.issue.number }}
          TARGET_BRANCH: ${{ steps.get-pr-sha.outputs.target_branch }}
        run: |-
          cat << 'EOF' > /tmp/_workflow_core.py
          import os
@@ -59,7 +59,7 @@ jobs:

          def get_diff():
              result = subprocess.run(
                  ['git', 'diff', 'origin/' + os.getenv('TARGET_BRANCH') + '...HEAD'],
                  ['git', 'diff', 'origin/main...HEAD'],
                  capture_output=True, text=True, check=True)
              return '\n'.join(
                  line for line in result.stdout.split('\n')
@@ -86,17 +86,6 @@ jobs:

          ### Spelling/Offensive Content Check
          - No spelling mistakes or offensive content found in the code or comments.

          ## 中文(简体)
          - 新增了 `ABC` 类
          - `foo` 模块中的 `f()` 行为已修复

          ### 评论高亮
          - `config.toml` 需要正确配置才能确保新功能正常运行。

          ### 内容检查
          - 没有发现代码或注释中的拼写错误或不当措辞。

          3. Highlight non-English comments
          4. Check for spelling/offensive content'''

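Note: one side of the `@@ -59,7 +59,7 @@` hunk above swaps the hardcoded `origin/main` for the `TARGET_BRANCH` environment variable exported by the `get-pr-sha` step. Assembled in isolation, that variant of `get_diff()` would read roughly as the sketch below; the trailing line filter is cut off in the hunk, so plain stdout is returned here instead.

```python
import os
import subprocess


def get_diff() -> str:
    # Diff the PR head against its actual base branch instead of a
    # hardcoded origin/main; TARGET_BRANCH comes from the workflow env.
    result = subprocess.run(
        ["git", "diff", "origin/" + os.getenv("TARGET_BRANCH") + "...HEAD"],
        capture_output=True, text=True, check=True,
    )
    # The original goes on to filter result.stdout line by line; the
    # filter condition is truncated in the hunk, so it is omitted here.
    return result.stdout
```
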
.github/workflows/top-issues.yaml (vendored, 4 changes)
@@ -23,7 +23,5 @@ jobs:
        dashboard: true
        dashboard_show_total_reactions: true
        top_issues: true
        top_features: true
        top_bugs: true
        top_pull_requests: true
        top_list_size: 14
        top_list_size: 32

.gitignore (vendored, 3 changes)
@@ -194,6 +194,3 @@ cython_debug/

# Built Visual Studio Code Extensions
*.vsix

# OSX
.DS_Store

README.md (12 changes)
@@ -81,11 +81,6 @@ source .venv/bin/activate  # On Unix/macOS
uv pip install -r requirements.txt
```

### Browser Automation Tool (Optional)
```bash
playwright install
```

## Configuration

OpenManus requires configuration for the LLM APIs it uses. Follow these steps to set up your configuration:
@@ -124,12 +119,7 @@ python main.py

Then input your idea via terminal!

For MCP tool version, you can run:
```bash
python run_mcp.py
```

For unstable multi-agent version, you also can run:
For unstable version, you also can run:

```bash
python run_flow.py

README_ja.md (12 changes)
@@ -81,11 +81,6 @@ source .venv/bin/activate  # Unix/macOSの場合
uv pip install -r requirements.txt
```

### ブラウザ自動化ツール(オプション)
```bash
playwright install
```

## 設定

OpenManusを使用するには、LLM APIの設定が必要です。以下の手順に従って設定してください:
@@ -124,12 +119,7 @@ python main.py

その後、ターミナルからプロンプトを入力してください!

MCP ツールバージョンを使用する場合は、以下を実行します:
```bash
python run_mcp.py
```

開発中のマルチエージェントバージョンを試すには、以下を実行します:
開発中バージョンを試すには、以下を実行します:

```bash
python run_flow.py

README_ko.md (12 changes)
@@ -81,11 +81,6 @@ source .venv/bin/activate  # Unix/macOS의 경우
uv pip install -r requirements.txt
```

### 브라우저 자동화 도구 (선택사항)
```bash
playwright install
```

## 설정 방법

OpenManus를 사용하려면 사용하는 LLM API에 대한 설정이 필요합니다. 아래 단계를 따라 설정을 완료하세요:
@@ -124,12 +119,7 @@ python main.py

이후 터미널에서 아이디어를 작성하세요!

MCP 도구 버전을 사용하려면 다음을 실행하세요:
```bash
python run_mcp.py
```

불안정한 멀티 에이전트 버전을 실행하려면 다음을 실행할 수 있습니다:
unstable 버전을 실행하려면 아래 명령어를 사용할 수도 있습니다:

```bash
python run_flow.py

README_zh.md (12 changes)
@@ -82,11 +82,6 @@ source .venv/bin/activate  # Unix/macOS 系统
uv pip install -r requirements.txt
```

### 浏览器自动化工具(可选)
```bash
playwright install
```

## 配置说明

OpenManus 需要配置使用的 LLM API,请按以下步骤设置:
@@ -125,12 +120,7 @@ python main.py

然后通过终端输入你的创意!

如需使用 MCP 工具版本,可运行:
```bash
python run_mcp.py
```

如需体验不稳定的多智能体版本,可运行:
如需体验不稳定的开发版本,可运行:

```bash
python run_flow.py

@@ -1,10 +0,0 @@
# Python version check: 3.11-3.13
import sys


if sys.version_info < (3, 11) or sys.version_info > (3, 13):
    print(
        "Warning: Unsupported Python version {ver}, please use 3.11-3.13".format(
            ver=".".join(map(str, sys.version_info))
        )
    )

@@ -1,6 +1,4 @@
from app.agent.base import BaseAgent
from app.agent.browser import BrowserAgent
from app.agent.mcp import MCPAgent
from app.agent.planning import PlanningAgent
from app.agent.react import ReActAgent
from app.agent.swe import SWEAgent
@@ -9,10 +7,8 @@ from app.agent.toolcall import ToolCallAgent

__all__ = [
    "BaseAgent",
    "BrowserAgent",
    "PlanningAgent",
    "ReActAgent",
    "SWEAgent",
    "ToolCallAgent",
    "MCPAgent",
]

@@ -6,7 +6,6 @@ from pydantic import BaseModel, Field, model_validator

from app.llm import LLM
from app.logger import logger
from app.sandbox.client import SANDBOX_CLIENT
from app.schema import ROLE_TYPE, AgentState, Memory, Message


@@ -85,7 +84,6 @@ class BaseAgent(BaseModel, ABC):
        self,
        role: ROLE_TYPE,  # type: ignore
        content: str,
        base64_image: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Add a message to the agent's memory.
@@ -93,7 +91,6 @@ class BaseAgent(BaseModel, ABC):
        Args:
            role: The role of the message sender (user, system, assistant, tool).
            content: The message content.
            base64_image: Optional base64 encoded image.
            **kwargs: Additional arguments (e.g., tool_call_id for tool messages).

        Raises:
@@ -109,9 +106,9 @@ class BaseAgent(BaseModel, ABC):
        if role not in message_map:
            raise ValueError(f"Unsupported message role: {role}")

        # Create message with appropriate parameters based on role
        kwargs = {"base64_image": base64_image, **(kwargs if role == "tool" else {})}
        self.memory.add_message(message_map[role](content, **kwargs))
        msg_factory = message_map[role]
        msg = msg_factory(content, **kwargs) if role == "tool" else msg_factory(content)
        self.memory.add_message(msg)

    async def run(self, request: Optional[str] = None) -> str:
        """Execute the agent's main loop asynchronously.
@@ -150,7 +147,7 @@ class BaseAgent(BaseModel, ABC):
            self.current_step = 0
            self.state = AgentState.IDLE
            results.append(f"Terminated: Reached max steps ({self.max_steps})")
        await SANDBOX_CLIENT.cleanup()

        return "\n".join(results) if results else "No steps executed"

    @abstractmethod

@@ -1,129 +0,0 @@
import json
from typing import Any, Optional

from pydantic import Field

from app.agent.toolcall import ToolCallAgent
from app.logger import logger
from app.prompt.browser import NEXT_STEP_PROMPT, SYSTEM_PROMPT
from app.schema import Message, ToolChoice
from app.tool import BrowserUseTool, Terminate, ToolCollection


class BrowserAgent(ToolCallAgent):
    """
    A browser agent that uses the browser_use library to control a browser.

    This agent can navigate web pages, interact with elements, fill forms,
    extract content, and perform other browser-based actions to accomplish tasks.
    """

    name: str = "browser"
    description: str = "A browser agent that can control a browser to accomplish tasks"

    system_prompt: str = SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT

    max_observe: int = 10000
    max_steps: int = 20

    # Configure the available tools
    available_tools: ToolCollection = Field(
        default_factory=lambda: ToolCollection(BrowserUseTool(), Terminate())
    )

    # Use Auto for tool choice to allow both tool usage and free-form responses
    tool_choices: ToolChoice = ToolChoice.AUTO
    special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name])

    _current_base64_image: Optional[str] = None

    async def _handle_special_tool(self, name: str, result: Any, **kwargs):
        if not self._is_special_tool(name):
            return
        else:
            await self.available_tools.get_tool(BrowserUseTool().name).cleanup()
            await super()._handle_special_tool(name, result, **kwargs)

    async def get_browser_state(self) -> Optional[dict]:
        """Get the current browser state for context in next steps."""
        browser_tool = self.available_tools.get_tool(BrowserUseTool().name)
        if not browser_tool:
            return None

        try:
            # Get browser state directly from the tool
            result = await browser_tool.get_current_state()

            if result.error:
                logger.debug(f"Browser state error: {result.error}")
                return None

            # Store screenshot if available
            if hasattr(result, "base64_image") and result.base64_image:
                self._current_base64_image = result.base64_image

            # Parse the state info
            return json.loads(result.output)

        except Exception as e:
            logger.debug(f"Failed to get browser state: {str(e)}")
            return None

    async def think(self) -> bool:
        """Process current state and decide next actions using tools, with browser state info added"""
        # Add browser state to the context
        browser_state = await self.get_browser_state()

        # Initialize placeholder values
        url_info = ""
        tabs_info = ""
        content_above_info = ""
        content_below_info = ""
        results_info = ""

        if browser_state and not browser_state.get("error"):
            # URL and title info
            url_info = f"\n   URL: {browser_state.get('url', 'N/A')}\n   Title: {browser_state.get('title', 'N/A')}"

            # Tab information
            if "tabs" in browser_state:
                tabs = browser_state.get("tabs", [])
                if tabs:
                    tabs_info = f"\n   {len(tabs)} tab(s) available"

            # Content above/below viewport
            pixels_above = browser_state.get("pixels_above", 0)
            pixels_below = browser_state.get("pixels_below", 0)

            if pixels_above > 0:
                content_above_info = f" ({pixels_above} pixels)"

            if pixels_below > 0:
                content_below_info = f" ({pixels_below} pixels)"

            # Add screenshot as base64 if available
            if self._current_base64_image:
                # Create a message with image attachment
                image_message = Message.user_message(
                    content="Current browser screenshot:",
                    base64_image=self._current_base64_image,
                )
                self.memory.add_message(image_message)

        # Replace placeholders with actual browser state info
        self.next_step_prompt = NEXT_STEP_PROMPT.format(
            url_placeholder=url_info,
            tabs_placeholder=tabs_info,
            content_above_placeholder=content_above_info,
            content_below_placeholder=content_below_info,
            results_placeholder=results_info,
        )

        # Call parent implementation
        result = await super().think()

        # Reset the next_step_prompt to its original state
        self.next_step_prompt = NEXT_STEP_PROMPT

        return result

@@ -1,20 +1,21 @@
from typing import Any

from pydantic import Field

from app.agent.browser import BrowserAgent
from app.config import config
from app.prompt.browser import NEXT_STEP_PROMPT as BROWSER_NEXT_STEP_PROMPT
from app.agent.toolcall import ToolCallAgent
from app.prompt.manus import NEXT_STEP_PROMPT, SYSTEM_PROMPT
from app.tool import Terminate, ToolCollection
from app.tool.browser_use_tool import BrowserUseTool
from app.tool.file_saver import FileSaver
from app.tool.python_execute import PythonExecute
from app.tool.str_replace_editor import StrReplaceEditor
from app.tool.web_search import WebSearch


class Manus(BrowserAgent):
class Manus(ToolCallAgent):
    """
    A versatile general-purpose agent that uses planning to solve various tasks.

    This agent extends BrowserAgent with a comprehensive set of tools and capabilities,
    This agent extends PlanningAgent with a comprehensive set of tools and capabilities,
    including Python execution, web browsing, file operations, and information retrieval
    to handle a wide range of user requests.
    """
@@ -24,40 +25,22 @@ class Manus(BrowserAgent):
        "A versatile agent that can solve various tasks using multiple tools"
    )

    system_prompt: str = SYSTEM_PROMPT.format(directory=config.workspace_root)
    system_prompt: str = SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT

    max_observe: int = 10000
    max_observe: int = 2000
    max_steps: int = 20

    # Add general-purpose tools to the tool collection
    available_tools: ToolCollection = Field(
        default_factory=lambda: ToolCollection(
            PythonExecute(), BrowserUseTool(), StrReplaceEditor(), Terminate()
            PythonExecute(), WebSearch(), BrowserUseTool(), FileSaver(), Terminate()
        )
    )

    async def think(self) -> bool:
        """Process current state and decide next actions with appropriate context."""
        # Store original prompt
        original_prompt = self.next_step_prompt

        # Only check recent messages (last 3) for browser activity
        recent_messages = self.memory.messages[-3:] if self.memory.messages else []
        browser_in_use = any(
            "browser_use" in msg.content.lower()
            for msg in recent_messages
            if hasattr(msg, "content") and isinstance(msg.content, str)
        )

        if browser_in_use:
            # Override with browser-specific prompt temporarily to get browser context
            self.next_step_prompt = BROWSER_NEXT_STEP_PROMPT

        # Call parent's think method
        result = await super().think()

        # Restore original prompt
        self.next_step_prompt = original_prompt

        return result
    async def _handle_special_tool(self, name: str, result: Any, **kwargs):
        if not self._is_special_tool(name):
            return
        else:
            await self.available_tools.get_tool(BrowserUseTool().name).cleanup()
            await super()._handle_special_tool(name, result, **kwargs)

app/agent/mcp.py (deleted, -185)
@@ -1,185 +0,0 @@
from typing import Any, Dict, List, Optional, Tuple

from pydantic import Field

from app.agent.toolcall import ToolCallAgent
from app.logger import logger
from app.prompt.mcp import MULTIMEDIA_RESPONSE_PROMPT, NEXT_STEP_PROMPT, SYSTEM_PROMPT
from app.schema import AgentState, Message
from app.tool.base import ToolResult
from app.tool.mcp import MCPClients


class MCPAgent(ToolCallAgent):
    """Agent for interacting with MCP (Model Context Protocol) servers.

    This agent connects to an MCP server using either SSE or stdio transport
    and makes the server's tools available through the agent's tool interface.
    """

    name: str = "mcp_agent"
    description: str = "An agent that connects to an MCP server and uses its tools."

    system_prompt: str = SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT

    # Initialize MCP tool collection
    mcp_clients: MCPClients = Field(default_factory=MCPClients)
    available_tools: MCPClients = None  # Will be set in initialize()

    max_steps: int = 20
    connection_type: str = "stdio"  # "stdio" or "sse"

    # Track tool schemas to detect changes
    tool_schemas: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    _refresh_tools_interval: int = 5  # Refresh tools every N steps

    # Special tool names that should trigger termination
    special_tool_names: List[str] = Field(default_factory=lambda: ["terminate"])

    async def initialize(
        self,
        connection_type: Optional[str] = None,
        server_url: Optional[str] = None,
        command: Optional[str] = None,
        args: Optional[List[str]] = None,
    ) -> None:
        """Initialize the MCP connection.

        Args:
            connection_type: Type of connection to use ("stdio" or "sse")
            server_url: URL of the MCP server (for SSE connection)
            command: Command to run (for stdio connection)
            args: Arguments for the command (for stdio connection)
        """
        if connection_type:
            self.connection_type = connection_type

        # Connect to the MCP server based on connection type
        if self.connection_type == "sse":
            if not server_url:
                raise ValueError("Server URL is required for SSE connection")
            await self.mcp_clients.connect_sse(server_url=server_url)
        elif self.connection_type == "stdio":
            if not command:
                raise ValueError("Command is required for stdio connection")
            await self.mcp_clients.connect_stdio(command=command, args=args or [])
        else:
            raise ValueError(f"Unsupported connection type: {self.connection_type}")

        # Set available_tools to our MCP instance
        self.available_tools = self.mcp_clients

        # Store initial tool schemas
        await self._refresh_tools()

        # Add system message about available tools
        tool_names = list(self.mcp_clients.tool_map.keys())
        tools_info = ", ".join(tool_names)

        # Add system prompt and available tools information
        self.memory.add_message(
            Message.system_message(
                f"{self.system_prompt}\n\nAvailable MCP tools: {tools_info}"
            )
        )

    async def _refresh_tools(self) -> Tuple[List[str], List[str]]:
        """Refresh the list of available tools from the MCP server.

        Returns:
            A tuple of (added_tools, removed_tools)
        """
        if not self.mcp_clients.session:
            return [], []

        # Get current tool schemas directly from the server
        response = await self.mcp_clients.session.list_tools()
        current_tools = {tool.name: tool.inputSchema for tool in response.tools}

        # Determine added, removed, and changed tools
        current_names = set(current_tools.keys())
        previous_names = set(self.tool_schemas.keys())

        added_tools = list(current_names - previous_names)
        removed_tools = list(previous_names - current_names)

        # Check for schema changes in existing tools
        changed_tools = []
        for name in current_names.intersection(previous_names):
            if current_tools[name] != self.tool_schemas.get(name):
                changed_tools.append(name)

        # Update stored schemas
        self.tool_schemas = current_tools

        # Log and notify about changes
        if added_tools:
            logger.info(f"Added MCP tools: {added_tools}")
            self.memory.add_message(
                Message.system_message(f"New tools available: {', '.join(added_tools)}")
            )
        if removed_tools:
            logger.info(f"Removed MCP tools: {removed_tools}")
            self.memory.add_message(
                Message.system_message(
                    f"Tools no longer available: {', '.join(removed_tools)}"
                )
            )
        if changed_tools:
            logger.info(f"Changed MCP tools: {changed_tools}")

        return added_tools, removed_tools

    async def think(self) -> bool:
        """Process current state and decide next action."""
        # Check MCP session and tools availability
        if not self.mcp_clients.session or not self.mcp_clients.tool_map:
            logger.info("MCP service is no longer available, ending interaction")
            self.state = AgentState.FINISHED
            return False

        # Refresh tools periodically
        if self.current_step % self._refresh_tools_interval == 0:
            await self._refresh_tools()
            # All tools removed indicates shutdown
            if not self.mcp_clients.tool_map:
                logger.info("MCP service has shut down, ending interaction")
                self.state = AgentState.FINISHED
                return False

        # Use the parent class's think method
        return await super().think()

    async def _handle_special_tool(self, name: str, result: Any, **kwargs) -> None:
        """Handle special tool execution and state changes"""
        # First process with parent handler
        await super()._handle_special_tool(name, result, **kwargs)

        # Handle multimedia responses
        if isinstance(result, ToolResult) and result.base64_image:
            self.memory.add_message(
                Message.system_message(
                    MULTIMEDIA_RESPONSE_PROMPT.format(tool_name=name)
                )
            )

    def _should_finish_execution(self, name: str, **kwargs) -> bool:
        """Determine if tool execution should finish the agent"""
        # Terminate if the tool name is 'terminate'
        return name.lower() == "terminate"

    async def cleanup(self) -> None:
        """Clean up MCP connection when done."""
        if self.mcp_clients.session:
            await self.mcp_clients.disconnect()
            logger.info("MCP connection closed")

    async def run(self, request: Optional[str] = None) -> str:
        """Run the agent with cleanup when done."""
        try:
            result = await super().run(request)
            return result
        finally:
            # Ensure cleanup happens even if there's an error
            await self.cleanup()

@@ -29,8 +29,7 @@ class SWEAgent(ToolCallAgent):
    async def think(self) -> bool:
        """Process current state and decide next action"""
        # Update working directory
        result = await self.bash.execute("pwd")
        self.working_dir = result.output
        self.working_dir = await self.bash.execute("pwd")
        self.next_step_prompt = self.next_step_prompt.format(
            current_dir=self.working_dir
        )

@@ -30,7 +30,6 @@ class ToolCallAgent(ReActAgent):
    special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])

    tool_calls: List[ToolCall] = Field(default_factory=list)
    _current_base64_image: Optional[str] = None

    max_steps: int = 30
    max_observe: Optional[Union[int, bool]] = None
@@ -45,11 +44,9 @@ class ToolCallAgent(ReActAgent):
            # Get response with tool options
            response = await self.llm.ask_tool(
                messages=self.messages,
                system_msgs=(
                    [Message.system_message(self.system_prompt)]
                    if self.system_prompt
                    else None
                ),
                system_msgs=[Message.system_message(self.system_prompt)]
                if self.system_prompt
                else None,
                tools=self.available_tools.to_params(),
                tool_choice=self.tool_choices,
            )
@@ -71,42 +68,37 @@ class ToolCallAgent(ReActAgent):
                return False
            raise

        self.tool_calls = tool_calls = (
            response.tool_calls if response and response.tool_calls else []
        )
        content = response.content if response and response.content else ""
        self.tool_calls = response.tool_calls

        # Log response info
        logger.info(f"✨ {self.name}'s thoughts: {content}")
        logger.info(f"✨ {self.name}'s thoughts: {response.content}")
        logger.info(
            f"🛠️ {self.name} selected {len(tool_calls) if tool_calls else 0} tools to use"
            f"🛠️ {self.name} selected {len(response.tool_calls) if response.tool_calls else 0} tools to use"
        )
        if tool_calls:
        if response.tool_calls:
            logger.info(
                f"🧰 Tools being prepared: {[call.function.name for call in tool_calls]}"
                f"🧰 Tools being prepared: {[call.function.name for call in response.tool_calls]}"
            )
            logger.info(f"🔧 Tool arguments: {tool_calls[0].function.arguments}")

        try:
            if response is None:
                raise RuntimeError("No response received from the LLM")

            # Handle different tool_choices modes
            if self.tool_choices == ToolChoice.NONE:
                if tool_calls:
                if response.tool_calls:
                    logger.warning(
                        f"🤔 Hmm, {self.name} tried to use tools when they weren't available!"
                    )
                if content:
                    self.memory.add_message(Message.assistant_message(content))
                if response.content:
                    self.memory.add_message(Message.assistant_message(response.content))
                    return True
                return False

            # Create and add assistant message
            assistant_msg = (
                Message.from_tool_calls(content=content, tool_calls=self.tool_calls)
                Message.from_tool_calls(
                    content=response.content, tool_calls=self.tool_calls
                )
                if self.tool_calls
                else Message.assistant_message(content)
                else Message.assistant_message(response.content)
            )
            self.memory.add_message(assistant_msg)

@@ -115,7 +107,7 @@ class ToolCallAgent(ReActAgent):

            # For 'auto' mode, continue with content if no commands but content exists
            if self.tool_choices == ToolChoice.AUTO and not self.tool_calls:
                return bool(content)
                return bool(response.content)

            return bool(self.tool_calls)
        except Exception as e:
@@ -138,9 +130,6 @@ class ToolCallAgent(ReActAgent):

        results = []
        for command in self.tool_calls:
            # Reset base64_image for each tool call
            self._current_base64_image = None

            result = await self.execute_tool(command)

            if self.max_observe:
@@ -152,10 +141,7 @@ class ToolCallAgent(ReActAgent):

            # Add tool response to memory
            tool_msg = Message.tool_message(
                content=result,
                tool_call_id=command.id,
                name=command.function.name,
                base64_image=self._current_base64_image,
                content=result, tool_call_id=command.id, name=command.function.name
            )
            self.memory.add_message(tool_msg)
            results.append(result)
@@ -179,29 +165,16 @@ class ToolCallAgent(ReActAgent):
            logger.info(f"🔧 Activating tool: '{name}'...")
            result = await self.available_tools.execute(name=name, tool_input=args)

            # Handle special tools
            await self._handle_special_tool(name=name, result=result)

            # Check if result is a ToolResult with base64_image
            if hasattr(result, "base64_image") and result.base64_image:
                # Store the base64_image for later use in tool_message
                self._current_base64_image = result.base64_image

                # Format result for display
                observation = (
                    f"Observed output of cmd `{name}` executed:\n{str(result)}"
                    if result
                    else f"Cmd `{name}` completed with no output"
                )
                return observation

            # Format result for display (standard case)
            # Format result for display
            observation = (
                f"Observed output of cmd `{name}` executed:\n{str(result)}"
                if result
                else f"Cmd `{name}` completed with no output"
            )

            # Handle special tools like `finish`
            await self._handle_special_tool(name=name, result=result)

            return observation
        except json.JSONDecodeError:
            error_msg = f"Error parsing arguments for {name}: Invalid JSON format"
@@ -211,7 +184,7 @@ class ToolCallAgent(ReActAgent):
            return f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"⚠️ Tool '{name}' encountered a problem: {str(e)}"
            logger.exception(error_msg)
            logger.error(error_msg)
            return f"Error: {error_msg}"

    async def _handle_special_tool(self, name: str, result: Any, **kwargs):

app/bedrock.py (deleted, -334)
@@ -1,334 +0,0 @@
import json
import sys
import time
import uuid
from datetime import datetime
from typing import Dict, List, Literal, Optional

import boto3


# Global variables to track the current tool use ID across function calls
# Tmp solution
CURRENT_TOOLUSE_ID = None


# Class to handle OpenAI-style response formatting
class OpenAIResponse:
    def __init__(self, data):
        # Recursively convert nested dicts and lists to OpenAIResponse objects
        for key, value in data.items():
            if isinstance(value, dict):
                value = OpenAIResponse(value)
            elif isinstance(value, list):
                value = [
                    OpenAIResponse(item) if isinstance(item, dict) else item
                    for item in value
                ]
            setattr(self, key, value)

    def model_dump(self, *args, **kwargs):
        # Convert object to dict and add timestamp
        data = self.__dict__
        data["created_at"] = datetime.now().isoformat()
        return data


# Main client class for interacting with Amazon Bedrock
class BedrockClient:
    def __init__(self):
        # Initialize Bedrock client, you need to configure AWS env first
        try:
            self.client = boto3.client("bedrock-runtime")
            self.chat = Chat(self.client)
        except Exception as e:
            print(f"Error initializing Bedrock client: {e}")
            sys.exit(1)


# Chat interface class
class Chat:
    def __init__(self, client):
        self.completions = ChatCompletions(client)


# Core class handling chat completions functionality
class ChatCompletions:
    def __init__(self, client):
        self.client = client

    def _convert_openai_tools_to_bedrock_format(self, tools):
        # Convert OpenAI function calling format to Bedrock tool format
        bedrock_tools = []
        for tool in tools:
            if tool.get("type") == "function":
                function = tool.get("function", {})
                bedrock_tool = {
                    "toolSpec": {
                        "name": function.get("name", ""),
                        "description": function.get("description", ""),
                        "inputSchema": {
                            "json": {
                                "type": "object",
                                "properties": function.get("parameters", {}).get(
                                    "properties", {}
                                ),
                                "required": function.get("parameters", {}).get(
                                    "required", []
                                ),
                            }
                        },
                    }
                }
                bedrock_tools.append(bedrock_tool)
        return bedrock_tools

    def _convert_openai_messages_to_bedrock_format(self, messages):
        # Convert OpenAI message format to Bedrock message format
        bedrock_messages = []
        system_prompt = []
        for message in messages:
            if message.get("role") == "system":
                system_prompt = [{"text": message.get("content")}]
            elif message.get("role") == "user":
                bedrock_message = {
                    "role": message.get("role", "user"),
                    "content": [{"text": message.get("content")}],
                }
                bedrock_messages.append(bedrock_message)
            elif message.get("role") == "assistant":
                bedrock_message = {
                    "role": "assistant",
                    "content": [{"text": message.get("content")}],
                }
                openai_tool_calls = message.get("tool_calls", [])
                if openai_tool_calls:
                    bedrock_tool_use = {
                        "toolUseId": openai_tool_calls[0]["id"],
                        "name": openai_tool_calls[0]["function"]["name"],
                        "input": json.loads(
                            openai_tool_calls[0]["function"]["arguments"]
                        ),
                    }
                    bedrock_message["content"].append({"toolUse": bedrock_tool_use})
                    global CURRENT_TOOLUSE_ID
                    CURRENT_TOOLUSE_ID = openai_tool_calls[0]["id"]
                bedrock_messages.append(bedrock_message)
            elif message.get("role") == "tool":
                bedrock_message = {
                    "role": "user",
                    "content": [
                        {
                            "toolResult": {
                                "toolUseId": CURRENT_TOOLUSE_ID,
                                "content": [{"text": message.get("content")}],
                            }
                        }
                    ],
                }
                bedrock_messages.append(bedrock_message)
            else:
                raise ValueError(f"Invalid role: {message.get('role')}")
        return system_prompt, bedrock_messages

    def _convert_bedrock_response_to_openai_format(self, bedrock_response):
        # Convert Bedrock response format to OpenAI format
        content = ""
        if bedrock_response.get("output", {}).get("message", {}).get("content"):
            content_array = bedrock_response["output"]["message"]["content"]
            content = "".join(item.get("text", "") for item in content_array)
        if content == "":
            content = "."

        # Handle tool calls in response
        openai_tool_calls = []
        if bedrock_response.get("output", {}).get("message", {}).get("content"):
            for content_item in bedrock_response["output"]["message"]["content"]:
                if content_item.get("toolUse"):
                    bedrock_tool_use = content_item["toolUse"]
                    global CURRENT_TOOLUSE_ID
                    CURRENT_TOOLUSE_ID = bedrock_tool_use["toolUseId"]
                    openai_tool_call = {
                        "id": CURRENT_TOOLUSE_ID,
                        "type": "function",
                        "function": {
                            "name": bedrock_tool_use["name"],
                            "arguments": json.dumps(bedrock_tool_use["input"]),
                        },
                    }
                    openai_tool_calls.append(openai_tool_call)

        # Construct final OpenAI format response
        openai_format = {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "created": int(time.time()),
            "object": "chat.completion",
            "system_fingerprint": None,
            "choices": [
                {
                    "finish_reason": bedrock_response.get("stopReason", "end_turn"),
                    "index": 0,
                    "message": {
                        "content": content,
                        "role": bedrock_response.get("output", {})
                        .get("message", {})
                        .get("role", "assistant"),
                        "tool_calls": openai_tool_calls
                        if openai_tool_calls != []
                        else None,
                        "function_call": None,
                    },
                }
            ],
            "usage": {
                "completion_tokens": bedrock_response.get("usage", {}).get(
                    "outputTokens", 0
                ),
                "prompt_tokens": bedrock_response.get("usage", {}).get(
                    "inputTokens", 0
                ),
                "total_tokens": bedrock_response.get("usage", {}).get("totalTokens", 0),
            },
        }
        return OpenAIResponse(openai_format)

    async def _invoke_bedrock(
        self,
        model: str,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        **kwargs,
    ) -> OpenAIResponse:
        # Non-streaming invocation of Bedrock model
        (
            system_prompt,
            bedrock_messages,
        ) = self._convert_openai_messages_to_bedrock_format(messages)
        response = self.client.converse(
            modelId=model,
            system=system_prompt,
            messages=bedrock_messages,
            inferenceConfig={"temperature": temperature, "maxTokens": max_tokens},
            toolConfig={"tools": tools} if tools else None,
        )
        openai_response = self._convert_bedrock_response_to_openai_format(response)
        return openai_response

    async def _invoke_bedrock_stream(
        self,
        model: str,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        **kwargs,
    ) -> OpenAIResponse:
        # Streaming invocation of Bedrock model
        (
            system_prompt,
            bedrock_messages,
        ) = self._convert_openai_messages_to_bedrock_format(messages)
        response = self.client.converse_stream(
            modelId=model,
            system=system_prompt,
            messages=bedrock_messages,
            inferenceConfig={"temperature": temperature, "maxTokens": max_tokens},
            toolConfig={"tools": tools} if tools else None,
        )

        # Initialize response structure
        bedrock_response = {
            "output": {"message": {"role": "", "content": []}},
            "stopReason": "",
            "usage": {},
            "metrics": {},
        }
        bedrock_response_text = ""
        bedrock_response_tool_input = ""

        # Process streaming response
        stream = response.get("stream")
        if stream:
            for event in stream:
                if event.get("messageStart", {}).get("role"):
                    bedrock_response["output"]["message"]["role"] = event[
                        "messageStart"
                    ]["role"]
                if event.get("contentBlockDelta", {}).get("delta", {}).get("text"):
                    bedrock_response_text += event["contentBlockDelta"]["delta"]["text"]
                    print(
                        event["contentBlockDelta"]["delta"]["text"], end="", flush=True
                    )
                if event.get("contentBlockStop", {}).get("contentBlockIndex") == 0:
                    bedrock_response["output"]["message"]["content"].append(
                        {"text": bedrock_response_text}
                    )
                if event.get("contentBlockStart", {}).get("start", {}).get("toolUse"):
                    bedrock_tool_use = event["contentBlockStart"]["start"]["toolUse"]
                    tool_use = {
                        "toolUseId": bedrock_tool_use["toolUseId"],
                        "name": bedrock_tool_use["name"],
                    }
                    bedrock_response["output"]["message"]["content"].append(
                        {"toolUse": tool_use}
                    )
                    global CURRENT_TOOLUSE_ID
                    CURRENT_TOOLUSE_ID = bedrock_tool_use["toolUseId"]
                if event.get("contentBlockDelta", {}).get("delta", {}).get("toolUse"):
                    bedrock_response_tool_input += event["contentBlockDelta"]["delta"][
                        "toolUse"
                    ]["input"]
                    print(
                        event["contentBlockDelta"]["delta"]["toolUse"]["input"],
                        end="",
                        flush=True,
                    )
                if event.get("contentBlockStop", {}).get("contentBlockIndex") == 1:
                    bedrock_response["output"]["message"]["content"][1]["toolUse"][
                        "input"
                    ] = json.loads(bedrock_response_tool_input)
        print()
        openai_response = self._convert_bedrock_response_to_openai_format(
            bedrock_response
        )
        return openai_response

    def create(
        self,
        model: str,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        stream: Optional[bool] = True,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        **kwargs,
    ) -> OpenAIResponse:
        # Main entry point for chat completion
        bedrock_tools = []
        if tools is not None:
            bedrock_tools = self._convert_openai_tools_to_bedrock_format(tools)
        if stream:
            return self._invoke_bedrock_stream(
                model,
                messages,
                max_tokens,
                temperature,
                bedrock_tools,
                tool_choice,
                **kwargs,
            )
        else:
            return self._invoke_bedrock(
                model,
                messages,
                max_tokens,
                temperature,
                bedrock_tools,
                tool_choice,
                **kwargs,
            )

@@ -25,7 +25,7 @@ class LLMSettings(BaseModel):
        description="Maximum input tokens to use across all requests (None for unlimited)",
    )
    temperature: float = Field(1.0, description="Sampling temperature")
    api_type: str = Field(..., description="Azure, Openai, or Ollama")
    api_type: str = Field(..., description="AzureOpenai or Openai")
    api_version: str = Field(..., description="Azure Openai version if AzureOpenai")


@@ -37,18 +37,6 @@ class ProxySettings(BaseModel):

class SearchSettings(BaseModel):
    engine: str = Field(default="Google", description="Search engine the llm to use")
    fallback_engines: List[str] = Field(
        default_factory=lambda: ["DuckDuckGo", "Baidu"],
        description="Fallback search engines to try if the primary engine fails",
    )
    retry_delay: int = Field(
        default=60,
        description="Seconds to wait before retrying all engines again after they all fail",
    )
    max_retries: int = Field(
        default=3,
        description="Maximum number of times to retry all engines when all fail",
    )

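A quick sketch of what the `SearchSettings` defaults above imply when no overrides are supplied; field names are taken from the hunk, and the assertions simply restate the declared defaults.

```python
from app.config import SearchSettings

# Instantiating SearchSettings with no arguments yields the declared defaults.
settings = SearchSettings()
assert settings.engine == "Google"
assert settings.fallback_engines == ["DuckDuckGo", "Baidu"]
assert settings.retry_delay == 60
assert settings.max_retries == 3
```
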
class BrowserSettings(BaseModel):
@@ -71,30 +59,10 @@ class BrowserSettings(BaseModel):
    proxy: Optional[ProxySettings] = Field(
        None, description="Proxy settings for the browser"
    )
    max_content_length: int = Field(
        2000, description="Maximum length for content retrieval operations"
    )


class SandboxSettings(BaseModel):
    """Configuration for the execution sandbox"""

    use_sandbox: bool = Field(False, description="Whether to use the sandbox")
    image: str = Field("python:3.12-slim", description="Base image")
    work_dir: str = Field("/workspace", description="Container working directory")
    memory_limit: str = Field("512m", description="Memory limit")
    cpu_limit: float = Field(1.0, description="CPU limit")
    timeout: int = Field(300, description="Default command timeout (seconds)")
    network_enabled: bool = Field(
        False, description="Whether network access is allowed"
    )


class AppConfig(BaseModel):
    llm: Dict[str, LLMSettings]
    sandbox: Optional[SandboxSettings] = Field(
        None, description="Sandbox configuration"
    )
    browser_config: Optional[BrowserSettings] = Field(
        None, description="Browser configuration"
    )
@@ -197,11 +165,6 @@ class Config:
        search_settings = None
        if search_config:
            search_settings = SearchSettings(**search_config)
        sandbox_config = raw_config.get("sandbox", {})
        if sandbox_config:
            sandbox_settings = SandboxSettings(**sandbox_config)
        else:
            sandbox_settings = SandboxSettings()

        config_dict = {
            "llm": {
@@ -211,7 +174,6 @@ class Config:
                for name, override_config in llm_overrides.items()
            },
        },
        "sandbox": sandbox_settings,
        "browser_config": browser_settings,
        "search_config": search_settings,
    }
@@ -222,10 +184,6 @@ class Config:
    def llm(self) -> Dict[str, LLMSettings]:
        return self._config.llm

    @property
    def sandbox(self) -> SandboxSettings:
        return self._config.sandbox

    @property
    def browser_config(self) -> Optional[BrowserSettings]:
        return self._config.browser_config
@@ -234,15 +192,5 @@ class Config:
    def search_config(self) -> Optional[SearchSettings]:
        return self._config.search_config

    @property
    def workspace_root(self) -> Path:
        """Get the workspace root directory"""
        return WORKSPACE_ROOT

    @property
    def root_path(self) -> Path:
        """Get the root path of the application"""
        return PROJECT_ROOT


config = Config()

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Optional, Union

from pydantic import BaseModel
@@ -6,6 +7,10 @@ from pydantic import BaseModel
from app.agent.base import BaseAgent


class FlowType(str, Enum):
    PLANNING = "planning"


class BaseFlow(BaseModel, ABC):
    """Base class for execution flows supporting multiple agents"""

@@ -55,3 +60,32 @@ class BaseFlow(BaseModel, ABC):
    @abstractmethod
    async def execute(self, input_text: str) -> str:
        """Execute the flow with given input"""


class PlanStepStatus(str, Enum):
    """Enum class defining possible statuses of a plan step"""

    NOT_STARTED = "not_started"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    BLOCKED = "blocked"

    @classmethod
    def get_all_statuses(cls) -> list[str]:
        """Return a list of all possible step status values"""
        return [status.value for status in cls]

    @classmethod
    def get_active_statuses(cls) -> list[str]:
        """Return a list of values representing active statuses (not started or in progress)"""
        return [cls.NOT_STARTED.value, cls.IN_PROGRESS.value]

    @classmethod
    def get_status_marks(cls) -> Dict[str, str]:
        """Return a mapping of statuses to their marker symbols"""
        return {
            cls.COMPLETED.value: "[✓]",
            cls.IN_PROGRESS.value: "[→]",
            cls.BLOCKED.value: "[!]",
            cls.NOT_STARTED.value: "[ ]",
        }

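A small illustration of how the `PlanStepStatus` helpers defined in the hunk above can render a plan; the step texts are invented for the example.

```python
from app.flow.base import PlanStepStatus

marks = PlanStepStatus.get_status_marks()
steps = [  # hypothetical plan data
    ("Collect requirements", PlanStepStatus.COMPLETED.value),
    ("Write the parser", PlanStepStatus.IN_PROGRESS.value),
    ("Ship it", PlanStepStatus.NOT_STARTED.value),
]
for text, status in steps:
    print(f"{marks[status]} {text}")
# [✓] Collect requirements
# [→] Write the parser
# [ ] Ship it
```
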
@@ -1,15 +1,10 @@
from enum import Enum
from typing import Dict, List, Union

from app.agent.base import BaseAgent
from app.flow.base import BaseFlow
from app.flow.base import BaseFlow, FlowType
from app.flow.planning import PlanningFlow


class FlowType(str, Enum):
    PLANNING = "planning"


class FlowFactory:
    """Factory for creating different types of flows with support for multiple agents"""

@@ -1,47 +1,17 @@
import json
import time
from enum import Enum
from typing import Dict, List, Optional, Union

from pydantic import Field

from app.agent.base import BaseAgent
from app.flow.base import BaseFlow
from app.flow.base import BaseFlow, PlanStepStatus
from app.llm import LLM
from app.logger import logger
from app.schema import AgentState, Message, ToolChoice
from app.tool import PlanningTool


class PlanStepStatus(str, Enum):
    """Enum class defining possible statuses of a plan step"""

    NOT_STARTED = "not_started"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    BLOCKED = "blocked"

    @classmethod
    def get_all_statuses(cls) -> list[str]:
        """Return a list of all possible step status values"""
        return [status.value for status in cls]

    @classmethod
    def get_active_statuses(cls) -> list[str]:
        """Return a list of values representing active statuses (not started or in progress)"""
        return [cls.NOT_STARTED.value, cls.IN_PROGRESS.value]

    @classmethod
    def get_status_marks(cls) -> Dict[str, str]:
        """Return a mapping of statuses to their marker symbols"""
        return {
            cls.COMPLETED.value: "[✓]",
            cls.IN_PROGRESS.value: "[→]",
            cls.BLOCKED.value: "[!]",
            cls.NOT_STARTED.value: "[ ]",
        }


class PlanningFlow(BaseFlow):
    """A flow that manages planning and execution of tasks using agents."""

app/llm.py (460 changes)
@@ -1,4 +1,3 @@
|
||||
import math
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import tiktoken
|
||||
@ -10,7 +9,6 @@ from openai import (
|
||||
OpenAIError,
|
||||
RateLimitError,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
@ -18,7 +16,6 @@ from tenacity import (
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from app.bedrock import BedrockClient
|
||||
from app.config import LLMSettings, config
|
||||
from app.exceptions import TokenLimitExceeded
|
||||
from app.logger import logger # Assuming a logger is set up in your app
|
||||
@ -32,150 +29,6 @@ from app.schema import (
|
||||
|
||||
|
||||
REASONING_MODELS = ["o1", "o3-mini"]
|
||||
MULTIMODAL_MODELS = [
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
]
|
||||
|
||||
|
||||
class TokenCounter:
    # Token constants
    BASE_MESSAGE_TOKENS = 4
    FORMAT_TOKENS = 2
    LOW_DETAIL_IMAGE_TOKENS = 85
    HIGH_DETAIL_TILE_TOKENS = 170

    # Image processing constants
    MAX_SIZE = 2048
    HIGH_DETAIL_TARGET_SHORT_SIDE = 768
    TILE_SIZE = 512

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def count_text(self, text: str) -> int:
        """Calculate tokens for a text string"""
        return 0 if not text else len(self.tokenizer.encode(text))

    def count_image(self, image_item: dict) -> int:
        """
        Calculate tokens for an image based on detail level and dimensions

        For "low" detail: fixed 85 tokens
        For "high" detail:
        1. Scale to fit in 2048x2048 square
        2. Scale shortest side to 768px
        3. Count 512px tiles (170 tokens each)
        4. Add 85 tokens
        """
        detail = image_item.get("detail", "medium")

        # For low detail, always return fixed token count
        if detail == "low":
            return self.LOW_DETAIL_IMAGE_TOKENS

        # For medium detail (default in OpenAI), use high detail calculation
        # OpenAI doesn't specify a separate calculation for medium

        # For high detail, calculate based on dimensions if available
        if detail == "high" or detail == "medium":
            # If dimensions are provided in the image_item
            if "dimensions" in image_item:
                width, height = image_item["dimensions"]
                return self._calculate_high_detail_tokens(width, height)

        # Default values when dimensions aren't available or detail level is unknown
        if detail == "high":
            # Default to a 1024x1024 image calculation for high detail
            return self._calculate_high_detail_tokens(1024, 1024)  # 765 tokens
        elif detail == "medium":
            # Default to a medium-sized image for medium detail
            return 1024  # This matches the original default
        else:
            # For unknown detail levels, use medium as default
            return 1024

    def _calculate_high_detail_tokens(self, width: int, height: int) -> int:
        """Calculate tokens for high detail images based on dimensions"""
        # Step 1: Scale to fit in MAX_SIZE x MAX_SIZE square
        if width > self.MAX_SIZE or height > self.MAX_SIZE:
            scale = self.MAX_SIZE / max(width, height)
            width = int(width * scale)
            height = int(height * scale)

        # Step 2: Scale so shortest side is HIGH_DETAIL_TARGET_SHORT_SIDE
        scale = self.HIGH_DETAIL_TARGET_SHORT_SIDE / min(width, height)
        scaled_width = int(width * scale)
        scaled_height = int(height * scale)

        # Step 3: Count number of 512px tiles
        tiles_x = math.ceil(scaled_width / self.TILE_SIZE)
        tiles_y = math.ceil(scaled_height / self.TILE_SIZE)
        total_tiles = tiles_x * tiles_y

        # Step 4: Calculate final token count
        return (
            total_tiles * self.HIGH_DETAIL_TILE_TOKENS
        ) + self.LOW_DETAIL_IMAGE_TOKENS
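For concreteness, here is a worked example of the high-detail path above (a sketch reusing the same constants; the 1105-token result matches the figure OpenAI documents for a 2048x4096 high-detail image):

import math

# Step 1: fit 2048x4096 into 2048x2048 -> scale by 0.5 -> 1024x2048
# Step 2: shortest side to 768 -> scale by 768/1024 = 0.75 -> 768x1536
# Step 3: 512px tiles -> ceil(768/512) * ceil(1536/512) = 2 * 3 = 6 tiles
# Step 4: 6 * 170 + 85 = 1105 tokens
def high_detail_tokens(width: int, height: int) -> int:
    if max(width, height) > 2048:
        scale = 2048 / max(width, height)
        width, height = int(width * scale), int(height * scale)
    scale = 768 / min(width, height)
    tiles_x = math.ceil(int(width * scale) / 512)
    tiles_y = math.ceil(int(height * scale) / 512)
    return tiles_x * tiles_y * 170 + 85

assert high_detail_tokens(2048, 4096) == 1105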

    def count_content(self, content: Union[str, List[Union[str, dict]]]) -> int:
        """Calculate tokens for message content"""
        if not content:
            return 0

        if isinstance(content, str):
            return self.count_text(content)

        token_count = 0
        for item in content:
            if isinstance(item, str):
                token_count += self.count_text(item)
            elif isinstance(item, dict):
                if "text" in item:
                    token_count += self.count_text(item["text"])
                elif "image_url" in item:
                    token_count += self.count_image(item)
        return token_count

    def count_tool_calls(self, tool_calls: List[dict]) -> int:
        """Calculate tokens for tool calls"""
        token_count = 0
        for tool_call in tool_calls:
            if "function" in tool_call:
                function = tool_call["function"]
                token_count += self.count_text(function.get("name", ""))
                token_count += self.count_text(function.get("arguments", ""))
        return token_count

    def count_message_tokens(self, messages: List[dict]) -> int:
        """Calculate the total number of tokens in a message list"""
        total_tokens = self.FORMAT_TOKENS  # Base format tokens

        for message in messages:
            tokens = self.BASE_MESSAGE_TOKENS  # Base tokens per message

            # Add role tokens
            tokens += self.count_text(message.get("role", ""))

            # Add content tokens
            if "content" in message:
                tokens += self.count_content(message["content"])

            # Add tool calls tokens
            if "tool_calls" in message:
                tokens += self.count_tool_calls(message["tool_calls"])

            # Add name and tool_call_id tokens
            tokens += self.count_text(message.get("name", ""))
            tokens += self.count_text(message.get("tool_call_id", ""))

            total_tokens += tokens

        return total_tokens
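As a usage illustration for the class as a whole (a minimal sketch, assuming tiktoken's cl100k_base encoding; message dicts follow the OpenAI chat format used throughout this file):

import tiktoken

counter = TokenCounter(tiktoken.get_encoding("cl100k_base"))
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize this page."},
]
# FORMAT_TOKENS, plus BASE_MESSAGE_TOKENS per message, plus role/content tokens
print(counter.count_message_tokens(messages))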


class LLM:
@@ -206,7 +59,6 @@ class LLM:

        # Add token counting related attributes
        self.total_input_tokens = 0
        self.total_completion_tokens = 0
        self.max_input_tokens = (
            llm_config.max_input_tokens
            if hasattr(llm_config, "max_input_tokens")
@@ -226,13 +78,9 @@ class LLM:
                api_key=self.api_key,
                api_version=self.api_version,
            )
        elif self.api_type == "aws":
            self.client = BedrockClient()
        else:
            self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

        self.token_counter = TokenCounter(self.tokenizer)

    def count_tokens(self, text: str) -> int:
        """Calculate the number of tokens in a text"""
        if not text:
@@ -240,17 +88,53 @@
        return len(self.tokenizer.encode(text))

    def count_message_tokens(self, messages: List[dict]) -> int:
        return self.token_counter.count_message_tokens(messages)
        """Calculate the number of tokens in a message list"""
        token_count = 0
        for message in messages:
            # Base token count for each message (according to OpenAI's calculation method)
            token_count += 4  # Base token count for each message

    def update_token_count(self, input_tokens: int, completion_tokens: int = 0) -> None:
            # Calculate tokens for the role
            if "role" in message:
                token_count += self.count_tokens(message["role"])

            # Calculate tokens for the content
            if "content" in message and message["content"]:
                token_count += self.count_tokens(message["content"])

            # Calculate tokens for tool calls
            if "tool_calls" in message and message["tool_calls"]:
                for tool_call in message["tool_calls"]:
                    if "function" in tool_call:
                        # Function name
                        if "name" in tool_call["function"]:
                            token_count += self.count_tokens(
                                tool_call["function"]["name"]
                            )
                        # Function arguments
                        if "arguments" in tool_call["function"]:
                            token_count += self.count_tokens(
                                tool_call["function"]["arguments"]
                            )

            # Calculate tokens for tool responses
            if "name" in message and message["name"]:
                token_count += self.count_tokens(message["name"])

            if "tool_call_id" in message and message["tool_call_id"]:
                token_count += self.count_tokens(message["tool_call_id"])

            # Add extra tokens for message format
            token_count += 2  # Extra tokens for message format

        return token_count

    def update_token_count(self, input_tokens: int) -> None:
        """Update token counts"""
        # Only track tokens if max_input_tokens is set
        self.total_input_tokens += input_tokens
        self.total_completion_tokens += completion_tokens
        logger.info(
            f"Token usage: Input={input_tokens}, Completion={completion_tokens}, "
            f"Cumulative Input={self.total_input_tokens}, Cumulative Completion={self.total_completion_tokens}, "
            f"Total={input_tokens + completion_tokens}, Cumulative Total={self.total_input_tokens + self.total_completion_tokens}"
            f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}"
        )

    def check_token_limit(self, input_tokens: int) -> bool:
@@ -271,15 +155,12 @@ class LLM:
        return "Token limit exceeded"

    @staticmethod
    def format_messages(
        messages: List[Union[dict, Message]], supports_images: bool = False
    ) -> List[dict]:
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects
            supports_images: Flag indicating if the target model supports image inputs

        Returns:
            List[dict]: List of formatted messages in OpenAI format
@@ -299,52 +180,12 @@ class LLM:
        formatted_messages = []

        for message in messages:
            # Convert Message objects to dictionaries
            if isinstance(message, Message):
                message = message.to_dict()

            if isinstance(message, dict):
                # If message is a dict, ensure it has required fields
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")

                # Process base64 images if present and model supports images
                if supports_images and message.get("base64_image"):
                    # Initialize or convert content to appropriate format
                    if not message.get("content"):
                        message["content"] = []
                    elif isinstance(message["content"], str):
                        message["content"] = [
                            {"type": "text", "text": message["content"]}
                        ]
                    elif isinstance(message["content"], list):
                        # Convert string items to proper text objects
                        message["content"] = [
                            (
                                {"type": "text", "text": item}
                                if isinstance(item, str)
                                else item
                            )
                            for item in message["content"]
                        ]

                    # Add the image to content
                    message["content"].append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{message['base64_image']}"
                            },
                        }
                    )

                    # Remove the base64_image field
                    del message["base64_image"]
                # If model doesn't support images but message has base64_image, handle gracefully
                elif not supports_images and message.get("base64_image"):
                    # Just remove the base64_image field and keep the text content
                    del message["base64_image"]

                if "content" in message or "tool_calls" in message:
                    formatted_messages.append(message)
                # else: do not include the message
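To make the base64 branch concrete: a message entering format_messages as {"role": "user", "content": "What is in this image?", "base64_image": "..."} leaves it in the OpenAI multimodal shape below (an illustrative sketch; the payload string is elided):

formatted_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {
            "type": "image_url",
            "image_url": {"url": "data:image/jpeg;base64,..."},
        },
    ],
}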
@@ -391,15 +232,12 @@
            Exception: For unexpected errors
        """
        try:
            # Check if the model supports images
            supports_images = self.model in MULTIMODAL_MODELS

            # Format system and user messages with image support check
            # Format system and user messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs, supports_images)
                messages = system_msgs + self.format_messages(messages, supports_images)
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages, supports_images)
                messages = self.format_messages(messages)

            # Calculate input token count
            input_tokens = self.count_message_tokens(messages)
@@ -425,188 +263,22 @@

            if not stream:
                # Non-streaming request
                response = await self.client.chat.completions.create(
                    **params, stream=False
                )
                params["stream"] = False

                response = await self.client.chat.completions.create(**params)

                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")

                # Update token counts
                self.update_token_count(
                    response.usage.prompt_tokens, response.usage.completion_tokens
                )
                self.update_token_count(response.usage.prompt_tokens)

                return response.choices[0].message.content

            # Streaming request: update the estimated input token count before making the request
            self.update_token_count(input_tokens)

            response = await self.client.chat.completions.create(**params, stream=True)

            collected_messages = []
            completion_text = ""
            async for chunk in response:
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                completion_text += chunk_message
                print(chunk_message, end="", flush=True)

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")

            # Estimate completion tokens for the streaming response
            completion_tokens = self.count_tokens(completion_text)
            logger.info(
                f"Estimated completion tokens for streaming response: {completion_tokens}"
            )
            self.total_completion_tokens += completion_tokens

            return full_response

        except TokenLimitExceeded:
            # Re-raise token limit errors without logging
            raise
        except ValueError:
            logger.exception("Validation error")
            raise
        except OpenAIError as oe:
            logger.exception("OpenAI API error")
            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception:
            logger.exception("Unexpected error in ask")
            raise

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type(
            (OpenAIError, Exception, ValueError)
        ),  # Don't retry TokenLimitExceeded
    )
    async def ask_with_images(
        self,
        messages: List[Union[dict, Message]],
        images: List[Union[str, dict]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = False,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt with images to the LLM and get the response.

        Args:
            messages: List of conversation messages
            images: List of image URLs or image data dictionaries
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            TokenLimitExceeded: If token limits are exceeded
            ValueError: If messages are invalid or response is empty
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # For ask_with_images, we always set supports_images to True because
            # this method should only be called with models that support images
            if self.model not in MULTIMODAL_MODELS:
                raise ValueError(
                    f"Model {self.model} does not support images. Use a model from {MULTIMODAL_MODELS}"
                )

            # Format messages with image support
            formatted_messages = self.format_messages(messages, supports_images=True)

            # Ensure the last message is from the user to attach images
            if not formatted_messages or formatted_messages[-1]["role"] != "user":
                raise ValueError(
                    "The last message must be from the user to attach images"
                )

            # Process the last user message to include images
            last_message = formatted_messages[-1]

            # Convert content to multimodal format if needed
            content = last_message["content"]
            multimodal_content = (
                [{"type": "text", "text": content}]
                if isinstance(content, str)
                else content
                if isinstance(content, list)
                else []
            )

            # Add images to content
            for image in images:
                if isinstance(image, str):
                    multimodal_content.append(
                        {"type": "image_url", "image_url": {"url": image}}
                    )
                elif isinstance(image, dict) and "url" in image:
                    multimodal_content.append({"type": "image_url", "image_url": image})
                elif isinstance(image, dict) and "image_url" in image:
                    multimodal_content.append(image)
                else:
                    raise ValueError(f"Unsupported image format: {image}")

            # Update the message with multimodal content
            last_message["content"] = multimodal_content

            # Add system messages if provided
            if system_msgs:
                all_messages = (
                    self.format_messages(system_msgs, supports_images=True)
                    + formatted_messages
                )
            else:
                all_messages = formatted_messages

            # Calculate tokens and check limits
            input_tokens = self.count_message_tokens(all_messages)
            if not self.check_token_limit(input_tokens):
                raise TokenLimitExceeded(self.get_limit_error_message(input_tokens))

            # Set up API parameters
            params = {
                "model": self.model,
                "messages": all_messages,
                "stream": stream,
            }

            # Add model-specific parameters
            if self.model in REASONING_MODELS:
                params["max_completion_tokens"] = self.max_tokens
            else:
                params["max_tokens"] = self.max_tokens
                params["temperature"] = (
                    temperature if temperature is not None else self.temperature
                )

            # Handle non-streaming request
            if not stream:
                response = await self.client.chat.completions.create(**params)

                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")

                self.update_token_count(response.usage.prompt_tokens)
                return response.choices[0].message.content

            # Handle streaming request
            self.update_token_count(input_tokens)
            params["stream"] = True
            response = await self.client.chat.completions.create(**params)

            collected_messages = []
@@ -617,16 +289,16 @@ class LLM:

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()

            if not full_response:
                raise ValueError("Empty response from streaming LLM")

            return full_response

        except TokenLimitExceeded:
            # Re-raise token limit errors without logging
            raise
        except ValueError as ve:
            logger.error(f"Validation error in ask_with_images: {ve}")
            logger.error(f"Validation error: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
@@ -638,7 +310,7 @@ class LLM:
            logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_with_images: {e}")
            logger.error(f"Unexpected error in ask: {e}")
            raise
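A minimal calling sketch for the method above (assumes the configured model is one of MULTIMODAL_MODELS, e.g. "gpt-4o", and that LLM() picks up the default config section; the URL is a placeholder):

import asyncio

async def main():
    llm = LLM()
    answer = await llm.ask_with_images(
        messages=[{"role": "user", "content": "What does this chart show?"}],
        images=["https://example.com/chart.png"],
    )
    print(answer)

asyncio.run(main())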

    @retry(
@@ -657,7 +329,7 @@ class LLM:
        tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO,  # type: ignore
        temperature: Optional[float] = None,
        **kwargs,
    ) -> ChatCompletionMessage | None:
    ):
        """
        Ask LLM using functions/tools and return the response.

@@ -684,15 +356,12 @@ class LLM:
        if tool_choice not in TOOL_CHOICE_VALUES:
            raise ValueError(f"Invalid tool_choice: {tool_choice}")

        # Check if the model supports images
        supports_images = self.model in MULTIMODAL_MODELS

        # Format messages
        if system_msgs:
            system_msgs = self.format_messages(system_msgs, supports_images)
            messages = system_msgs + self.format_messages(messages, supports_images)
            system_msgs = self.format_messages(system_msgs)
            messages = system_msgs + self.format_messages(messages)
        else:
            messages = self.format_messages(messages, supports_images)
            messages = self.format_messages(messages)

        # Calculate input token count
        input_tokens = self.count_message_tokens(messages)
@@ -735,20 +404,15 @@
            temperature if temperature is not None else self.temperature
        )

        response: ChatCompletion = await self.client.chat.completions.create(
            **params, stream=False
        )
        response = await self.client.chat.completions.create(**params)

        # Check if response is valid
        if not response.choices or not response.choices[0].message:
            print(response)
            # raise ValueError("Invalid or empty response from LLM")
            return None
            raise ValueError("Invalid or empty response from LLM")

        # Update token counts
        self.update_token_count(
            response.usage.prompt_tokens, response.usage.completion_tokens
        )
        self.update_token_count(response.usage.prompt_tokens)

        return response.choices[0].message
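The truncated method above is the tool-calling entry point (named ask_tool in the repository source, an assumption based on its docstring). A hedged calling sketch, with a hypothetical get_weather tool schema:

import asyncio

async def main():
    llm = LLM()
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Get the weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
    msg = await llm.ask_tool(
        messages=[{"role": "user", "content": "Weather in Paris?"}],
        tools=[weather_tool],
    )
    if msg and msg.tool_calls:
        call = msg.tool_calls[0]
        print(call.function.name, call.function.arguments)

asyncio.run(main())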
@@ -1,180 +0,0 @@
import logging
import sys


logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stderr)])

import argparse
import asyncio
import atexit
import json
from inspect import Parameter, Signature
from typing import Any, Dict, Optional

from mcp.server.fastmcp import FastMCP

from app.logger import logger
from app.tool.base import BaseTool
from app.tool.bash import Bash
from app.tool.browser_use_tool import BrowserUseTool
from app.tool.str_replace_editor import StrReplaceEditor
from app.tool.terminate import Terminate


class MCPServer:
    """MCP Server implementation with tool registration and management."""

    def __init__(self, name: str = "openmanus"):
        self.server = FastMCP(name)
        self.tools: Dict[str, BaseTool] = {}

        # Initialize standard tools
        self.tools["bash"] = Bash()
        self.tools["browser"] = BrowserUseTool()
        self.tools["editor"] = StrReplaceEditor()
        self.tools["terminate"] = Terminate()

    def register_tool(self, tool: BaseTool, method_name: Optional[str] = None) -> None:
        """Register a tool with parameter validation and documentation."""
        tool_name = method_name or tool.name
        tool_param = tool.to_param()
        tool_function = tool_param["function"]

        # Define the async function to be registered
        async def tool_method(**kwargs):
            logger.info(f"Executing {tool_name}: {kwargs}")
            result = await tool.execute(**kwargs)

            logger.info(f"Result of {tool_name}: {result}")

            # Handle different types of results (match original logic)
            if hasattr(result, "model_dump"):
                return json.dumps(result.model_dump())
            elif isinstance(result, dict):
                return json.dumps(result)
            return result

        # Set method metadata
        tool_method.__name__ = tool_name
        tool_method.__doc__ = self._build_docstring(tool_function)
        tool_method.__signature__ = self._build_signature(tool_function)

        # Store parameter schema (important for tools that access it programmatically)
        param_props = tool_function.get("parameters", {}).get("properties", {})
        required_params = tool_function.get("parameters", {}).get("required", [])
        tool_method._parameter_schema = {
            param_name: {
                "description": param_details.get("description", ""),
                "type": param_details.get("type", "any"),
                "required": param_name in required_params,
            }
            for param_name, param_details in param_props.items()
        }

        # Register with server
        self.server.tool()(tool_method)
        logger.info(f"Registered tool: {tool_name}")

    def _build_docstring(self, tool_function: dict) -> str:
        """Build a formatted docstring from tool function metadata."""
        description = tool_function.get("description", "")
        param_props = tool_function.get("parameters", {}).get("properties", {})
        required_params = tool_function.get("parameters", {}).get("required", [])

        # Build docstring (match original format)
        docstring = description
        if param_props:
            docstring += "\n\nParameters:\n"
            for param_name, param_details in param_props.items():
                required_str = (
                    "(required)" if param_name in required_params else "(optional)"
                )
                param_type = param_details.get("type", "any")
                param_desc = param_details.get("description", "")
                docstring += (
                    f"    {param_name} ({param_type}) {required_str}: {param_desc}\n"
                )

        return docstring

    def _build_signature(self, tool_function: dict) -> Signature:
        """Build a function signature from tool function metadata."""
        param_props = tool_function.get("parameters", {}).get("properties", {})
        required_params = tool_function.get("parameters", {}).get("required", [])

        parameters = []

        # Follow original type mapping
        for param_name, param_details in param_props.items():
            param_type = param_details.get("type", "")
            default = Parameter.empty if param_name in required_params else None

            # Map JSON Schema types to Python types (same as original)
            annotation = Any
            if param_type == "string":
                annotation = str
            elif param_type == "integer":
                annotation = int
            elif param_type == "number":
                annotation = float
            elif param_type == "boolean":
                annotation = bool
            elif param_type == "object":
                annotation = dict
            elif param_type == "array":
                annotation = list

            # Create parameter with same structure as original
            param = Parameter(
                name=param_name,
                kind=Parameter.KEYWORD_ONLY,
                default=default,
                annotation=annotation,
            )
            parameters.append(param)

        return Signature(parameters=parameters)
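To illustrate what _build_signature produces, here is a standalone sketch of the same JSON-Schema-to-Parameter mapping, with an invented two-parameter schema:

from inspect import Parameter, Signature

props = {"message": {"type": "string"}, "count": {"type": "integer"}}
required = ["message"]
params = [
    Parameter(
        name=name,
        kind=Parameter.KEYWORD_ONLY,
        default=Parameter.empty if name in required else None,
        annotation=str if spec["type"] == "string" else int,
    )
    for name, spec in props.items()
]
print(Signature(parameters=params))  # prints: (*, message: str, count: int = None)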

    async def cleanup(self) -> None:
        """Clean up server resources."""
        logger.info("Cleaning up resources")
        # Follow original cleanup logic - only clean browser tool
        if "browser" in self.tools and hasattr(self.tools["browser"], "cleanup"):
            await self.tools["browser"].cleanup()

    def register_all_tools(self) -> None:
        """Register all tools with the server."""
        for tool in self.tools.values():
            self.register_tool(tool)

    def run(self, transport: str = "stdio") -> None:
        """Run the MCP server."""
        # Register all tools
        self.register_all_tools()

        # Register cleanup function (match original behavior)
        atexit.register(lambda: asyncio.run(self.cleanup()))

        # Start server (with same logging as original)
        logger.info(f"Starting OpenManus server ({transport} mode)")
        self.server.run(transport=transport)


def parse_args() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="OpenManus MCP Server")
    parser.add_argument(
        "--transport",
        choices=["stdio"],
        default="stdio",
        help="Communication method: stdio or http (default: stdio)",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Create and run server (maintaining original flow)
    server = MCPServer()
    server.run(transport=args.transport)
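Registering a custom tool with this server might look like the following (a sketch; EchoTool is hypothetical, and the field layout assumes the pydantic-style BaseTool from app.tool.base with name/description/parameters plus an async execute):

class EchoTool(BaseTool):
    name: str = "echo"
    description: str = "Echo the given text back to the caller."
    parameters: dict = {
        "type": "object",
        "properties": {"text": {"type": "string", "description": "Text to echo"}},
        "required": ["text"],
    }

    async def execute(self, text: str) -> str:
        return text

server = MCPServer()
server.register_tool(EchoTool())  # adds echo alongside the standard tools
server.run(transport="stdio")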
@@ -1,92 +0,0 @@
SYSTEM_PROMPT = """\
You are an AI agent designed to automate browser tasks. Your goal is to accomplish the ultimate task following the rules.

# Input Format
Task
Previous steps
Current URL
Open Tabs
Interactive Elements
[index]<type>text</type>
- index: Numeric identifier for interaction
- type: HTML element type (button, input, etc.)
- text: Element description
Example:
[33]<button>Submit Form</button>

- Only elements with numeric indexes in [] are interactive
- Elements without [] provide only context

# Response Rules
1. RESPONSE FORMAT: You must ALWAYS respond with valid JSON in this exact format:
{{"current_state": {{"evaluation_previous_goal": "Success|Failed|Unknown - Analyze the current elements and the image to check if the previous goals/actions are successful like intended by the task. Mention if something unexpected happened. Shortly state why/why not",
"memory": "Description of what has been done and what you need to remember. Be very specific. Count here ALWAYS how many times you have done something and how many remain. E.g. 0 out of 10 websites analyzed. Continue with abc and xyz",
"next_goal": "What needs to be done with the next immediate action"}},
"action":[{{"one_action_name": {{// action-specific parameter}}}}, // ... more actions in sequence]}}

2. ACTIONS: You can specify multiple actions in the list to be executed in sequence. But always specify only one action name per item. Use maximum {{max_actions}} actions per sequence.
Common action sequences:
- Form filling: [{{"input_text": {{"index": 1, "text": "username"}}}}, {{"input_text": {{"index": 2, "text": "password"}}}}, {{"click_element": {{"index": 3}}}}]
- Navigation and extraction: [{{"go_to_url": {{"url": "https://example.com"}}}}, {{"extract_content": {{"goal": "extract the names"}}}}]
- Actions are executed in the given order
- If the page changes after an action, the sequence is interrupted and you get the new state.
- Only provide the action sequence until an action which changes the page state significantly.
- Try to be efficient, e.g. fill forms at once, or chain actions where nothing changes on the page
- Only use multiple actions if it makes sense.

3. ELEMENT INTERACTION:
- Only use indexes of the interactive elements
- Elements marked with "[]Non-interactive text" are non-interactive

4. NAVIGATION & ERROR HANDLING:
- If no suitable elements exist, use other functions to complete the task
- If stuck, try alternative approaches - like going back to a previous page, new search, new tab etc.
- Handle popups/cookies by accepting or closing them
- Use scroll to find elements you are looking for
- If you want to research something, open a new tab instead of using the current tab
- If a captcha pops up, try to solve it - else try a different approach
- If the page is not fully loaded, use the wait action

5. TASK COMPLETION:
- Use the done action as the last action as soon as the ultimate task is complete
- Don't use "done" before you are done with everything the user asked you, except when you reach the last step of max_steps.
- If you reach your last step, use the done action even if the task is not fully finished. Provide all the information you have gathered so far. If the ultimate task is completely finished, set success to true. If not everything the user asked for is completed, set success in done to false!
- If you have to do something repeatedly, for example the task says for "each", or "for all", or "x times", count always inside "memory" how many times you have done it and how many remain. Don't stop until you have completed like the task asked you. Only call done after the last step.
- Don't hallucinate actions
- Make sure you include everything you found out for the ultimate task in the done text parameter. Do not just say you are done, but include the requested information of the task.

6. VISUAL CONTEXT:
- When an image is provided, use it to understand the page layout
- Bounding boxes with labels on their top right corner correspond to element indexes

7. Form filling:
- If you fill an input field and your action sequence is interrupted, most often something changed, e.g. suggestions popped up under the field.

8. Long tasks:
- Keep track of the status and subresults in the memory.

9. Extraction:
- If your task is to find information, call extract_content on the specific pages to get and store the information.
Your responses must always be JSON with the specified format.
"""


NEXT_STEP_PROMPT = """
What should I do next to achieve my goal?

When you see [Current state starts here], focus on the following:
- Current URL and page title{url_placeholder}
- Available tabs{tabs_placeholder}
- Interactive elements and their indices
- Content above{content_above_placeholder} or below{content_below_placeholder} the viewport (if indicated)
- Any action results or errors{results_placeholder}

For browser interactions:
- To navigate: browser_use with action="go_to_url", url="..."
- To click: browser_use with action="click_element", index=N
- To type: browser_use with action="input_text", index=N, text="..."
- To extract: browser_use with action="extract_content", goal="..."
- To scroll: browser_use with action="scroll_down" or "scroll_up"

Consider both what's visible and what might be beyond the current viewport.
Be methodical - remember your progress and what you've learned so far.
"""
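The doubled braces in the template above are str.format escapes; once rendered, the agent is expected to answer with JSON of the following shape (an illustrative example, not taken from the repository):

example_response = {
    "current_state": {
        "evaluation_previous_goal": "Success - the search results page loaded as intended",
        "memory": "1 out of 10 websites analyzed. Continue with site 2",
        "next_goal": "Open the first search result",
    },
    "action": [{"click_element": {"index": 33}}],
}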
@@ -1,8 +1,18 @@
SYSTEM_PROMPT = (
    "You are OpenManus, an all-capable AI assistant, aimed at solving any task presented by the user. You have various tools at your disposal that you can call upon to efficiently complete complex requests. Whether it's programming, information retrieval, file processing, or web browsing, you can handle it all."
    "The initial directory is: {directory}"
)
SYSTEM_PROMPT = "You are OpenManus, an all-capable AI assistant, aimed at solving any task presented by the user. You have various tools at your disposal that you can call upon to efficiently complete complex requests. Whether it's programming, information retrieval, file processing, or web browsing, you can handle it all."

NEXT_STEP_PROMPT = """You can interact with the computer using PythonExecute, save important content and information files through FileSaver, open browsers with BrowserUseTool, and retrieve information using GoogleSearch.

PythonExecute: Execute Python code to interact with the computer system, data processing, automation tasks, etc.

FileSaver: Save files locally, such as txt, py, html, etc.

BrowserUseTool: Open, browse, and use web browsers. If you open a local HTML file, you must provide the absolute path to the file.

WebSearch: Perform web information retrieval

Terminate: End the current interaction when the task is complete or when you need additional information from the user. Use this tool to signal that you've finished addressing the user's request or need clarification before proceeding further.

NEXT_STEP_PROMPT = """
Based on user needs, proactively select the most appropriate tool or combination of tools. For complex tasks, you can break down the problem and use different tools step by step to solve it. After using each tool, clearly explain the execution results and suggest the next steps.

Always maintain a helpful, informative tone throughout the interaction. If you encounter any limitations or need more details, clearly communicate this to the user before terminating.
"""
@@ -1,43 +0,0 @@
"""Prompts for the MCP Agent."""

SYSTEM_PROMPT = """You are an AI assistant with access to a Model Context Protocol (MCP) server.
You can use the tools provided by the MCP server to complete tasks.
The MCP server will dynamically expose tools that you can use - always check the available tools first.

When using an MCP tool:
1. Choose the appropriate tool based on your task requirements
2. Provide properly formatted arguments as required by the tool
3. Observe the results and use them to determine next steps
4. Tools may change during operation - new tools might appear or existing ones might disappear

Follow these guidelines:
- Call tools with valid parameters as documented in their schemas
- Handle errors gracefully by understanding what went wrong and trying again with corrected parameters
- For multimedia responses (like images), you'll receive a description of the content
- Complete user requests step by step, using the most appropriate tools
- If multiple tools need to be called in sequence, make one call at a time and wait for results

Remember to clearly explain your reasoning and actions to the user.
"""

NEXT_STEP_PROMPT = """Based on the current state and available tools, what should be done next?
Think step by step about the problem and identify which MCP tool would be most helpful for the current stage.
If you've already made progress, consider what additional information you need or what actions would move you closer to completing the task.
"""

# Additional specialized prompts
TOOL_ERROR_PROMPT = """You encountered an error with the tool '{tool_name}'.
Try to understand what went wrong and correct your approach.
Common issues include:
- Missing or incorrect parameters
- Invalid parameter formats
- Using a tool that's no longer available
- Attempting an operation that's not supported

Please check the tool specifications and try again with corrected parameters.
"""

MULTIMEDIA_RESPONSE_PROMPT = """You've received a multimedia response (image, audio, etc.) from the tool '{tool_name}'.
This content has been processed and described for you.
Use this information to continue the task or provide insights to the user.
"""
@@ -1,30 +0,0 @@
"""
Docker Sandbox Module

Provides secure containerized execution environment with resource limits
and isolation for running untrusted code.
"""
from app.sandbox.client import (
    BaseSandboxClient,
    LocalSandboxClient,
    create_sandbox_client,
)
from app.sandbox.core.exceptions import (
    SandboxError,
    SandboxResourceError,
    SandboxTimeoutError,
)
from app.sandbox.core.manager import SandboxManager
from app.sandbox.core.sandbox import DockerSandbox


__all__ = [
    "DockerSandbox",
    "SandboxManager",
    "BaseSandboxClient",
    "LocalSandboxClient",
    "create_sandbox_client",
    "SandboxError",
    "SandboxTimeoutError",
    "SandboxResourceError",
]
@@ -1,201 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional, Protocol

from app.config import SandboxSettings
from app.sandbox.core.sandbox import DockerSandbox


class SandboxFileOperations(Protocol):
    """Protocol for sandbox file operations."""

    async def copy_from(self, container_path: str, local_path: str) -> None:
        """Copies file from container to local.

        Args:
            container_path: File path in container.
            local_path: Local destination path.
        """
        ...

    async def copy_to(self, local_path: str, container_path: str) -> None:
        """Copies file from local to container.

        Args:
            local_path: Local source file path.
            container_path: Destination path in container.
        """
        ...

    async def read_file(self, path: str) -> str:
        """Reads file content from container.

        Args:
            path: File path in container.

        Returns:
            str: File content.
        """
        ...

    async def write_file(self, path: str, content: str) -> None:
        """Writes content to file in container.

        Args:
            path: File path in container.
            content: Content to write.
        """
        ...


class BaseSandboxClient(ABC):
    """Base sandbox client interface."""

    @abstractmethod
    async def create(
        self,
        config: Optional[SandboxSettings] = None,
        volume_bindings: Optional[Dict[str, str]] = None,
    ) -> None:
        """Creates sandbox."""

    @abstractmethod
    async def run_command(self, command: str, timeout: Optional[int] = None) -> str:
        """Executes command."""

    @abstractmethod
    async def copy_from(self, container_path: str, local_path: str) -> None:
        """Copies file from container."""

    @abstractmethod
    async def copy_to(self, local_path: str, container_path: str) -> None:
        """Copies file to container."""

    @abstractmethod
    async def read_file(self, path: str) -> str:
        """Reads file."""

    @abstractmethod
    async def write_file(self, path: str, content: str) -> None:
        """Writes file."""

    @abstractmethod
    async def cleanup(self) -> None:
        """Cleans up resources."""


class LocalSandboxClient(BaseSandboxClient):
    """Local sandbox client implementation."""

    def __init__(self):
        """Initializes local sandbox client."""
        self.sandbox: Optional[DockerSandbox] = None

    async def create(
        self,
        config: Optional[SandboxSettings] = None,
        volume_bindings: Optional[Dict[str, str]] = None,
    ) -> None:
        """Creates a sandbox.

        Args:
            config: Sandbox configuration.
            volume_bindings: Volume mappings.

        Raises:
            RuntimeError: If sandbox creation fails.
        """
        self.sandbox = DockerSandbox(config, volume_bindings)
        await self.sandbox.create()

    async def run_command(self, command: str, timeout: Optional[int] = None) -> str:
        """Runs command in sandbox.

        Args:
            command: Command to execute.
            timeout: Execution timeout in seconds.

        Returns:
            Command output.

        Raises:
            RuntimeError: If sandbox not initialized.
        """
        if not self.sandbox:
            raise RuntimeError("Sandbox not initialized")
        return await self.sandbox.run_command(command, timeout)

    async def copy_from(self, container_path: str, local_path: str) -> None:
        """Copies file from container to local.

        Args:
            container_path: File path in container.
            local_path: Local destination path.

        Raises:
            RuntimeError: If sandbox not initialized.
        """
        if not self.sandbox:
            raise RuntimeError("Sandbox not initialized")
        await self.sandbox.copy_from(container_path, local_path)

    async def copy_to(self, local_path: str, container_path: str) -> None:
        """Copies file from local to container.

        Args:
            local_path: Local source file path.
            container_path: Destination path in container.

        Raises:
            RuntimeError: If sandbox not initialized.
        """
        if not self.sandbox:
            raise RuntimeError("Sandbox not initialized")
        await self.sandbox.copy_to(local_path, container_path)

    async def read_file(self, path: str) -> str:
        """Reads file from container.

        Args:
            path: File path in container.

        Returns:
            File content.

        Raises:
            RuntimeError: If sandbox not initialized.
        """
        if not self.sandbox:
            raise RuntimeError("Sandbox not initialized")
        return await self.sandbox.read_file(path)

    async def write_file(self, path: str, content: str) -> None:
        """Writes file to container.

        Args:
            path: File path in container.
            content: File content.

        Raises:
            RuntimeError: If sandbox not initialized.
        """
        if not self.sandbox:
            raise RuntimeError("Sandbox not initialized")
        await self.sandbox.write_file(path, content)

    async def cleanup(self) -> None:
        """Cleans up resources."""
        if self.sandbox:
            await self.sandbox.cleanup()
            self.sandbox = None


def create_sandbox_client() -> LocalSandboxClient:
    """Creates a sandbox client.

    Returns:
        LocalSandboxClient: Sandbox client instance.
    """
    return LocalSandboxClient()


SANDBOX_CLIENT = create_sandbox_client()
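End-to-end, the module-level SANDBOX_CLIENT is used roughly as follows (a sketch; the /workspace path assumes the default SandboxSettings work_dir, and a running Docker daemon is required):

import asyncio

from app.sandbox.client import SANDBOX_CLIENT

async def main():
    await SANDBOX_CLIENT.create()  # default SandboxSettings
    await SANDBOX_CLIENT.write_file("/workspace/hello.py", "print('hi')")
    print(await SANDBOX_CLIENT.run_command("python /workspace/hello.py"))
    await SANDBOX_CLIENT.cleanup()

asyncio.run(main())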
@@ -1,17 +0,0 @@
"""Exception classes for the sandbox system.

This module defines custom exceptions used throughout the sandbox system to
handle various error conditions in a structured way.
"""


class SandboxError(Exception):
    """Base exception for sandbox-related errors."""


class SandboxTimeoutError(SandboxError):
    """Exception raised when a sandbox operation times out."""


class SandboxResourceError(SandboxError):
    """Exception raised for resource-related errors."""
@@ -1,313 +0,0 @@
import asyncio
import uuid
from contextlib import asynccontextmanager
from typing import Dict, Optional, Set

import docker
from docker.errors import APIError, ImageNotFound

from app.config import SandboxSettings
from app.logger import logger
from app.sandbox.core.sandbox import DockerSandbox


class SandboxManager:
    """Docker sandbox manager.

    Manages the lifecycle of multiple DockerSandbox instances, including creation,
    monitoring, and cleanup. Provides concurrent access control and automatic
    cleanup mechanisms for sandbox resources.

    Attributes:
        max_sandboxes: Maximum allowed number of sandboxes.
        idle_timeout: Sandbox idle timeout in seconds.
        cleanup_interval: Cleanup check interval in seconds.
        _sandboxes: Active sandbox instance mapping.
        _last_used: Last used time record for sandboxes.
    """

    def __init__(
        self,
        max_sandboxes: int = 100,
        idle_timeout: int = 3600,
        cleanup_interval: int = 300,
    ):
        """Initializes sandbox manager.

        Args:
            max_sandboxes: Maximum sandbox count limit.
            idle_timeout: Idle timeout in seconds.
            cleanup_interval: Cleanup check interval in seconds.
        """
        self.max_sandboxes = max_sandboxes
        self.idle_timeout = idle_timeout
        self.cleanup_interval = cleanup_interval

        # Docker client
        self._client = docker.from_env()

        # Resource mappings
        self._sandboxes: Dict[str, DockerSandbox] = {}
        self._last_used: Dict[str, float] = {}

        # Concurrency control
        self._locks: Dict[str, asyncio.Lock] = {}
        self._global_lock = asyncio.Lock()
        self._active_operations: Set[str] = set()

        # Cleanup task
        self._cleanup_task: Optional[asyncio.Task] = None
        self._is_shutting_down = False

        # Start automatic cleanup
        self.start_cleanup_task()

    async def ensure_image(self, image: str) -> bool:
        """Ensures Docker image is available.

        Args:
            image: Image name.

        Returns:
            bool: Whether image is available.
        """
        try:
            self._client.images.get(image)
            return True
        except ImageNotFound:
            try:
                logger.info(f"Pulling image {image}...")
                await asyncio.get_event_loop().run_in_executor(
                    None, self._client.images.pull, image
                )
                return True
            except (APIError, Exception) as e:
                logger.error(f"Failed to pull image {image}: {e}")
                return False

    @asynccontextmanager
    async def sandbox_operation(self, sandbox_id: str):
        """Context manager for sandbox operations.

        Provides concurrency control and usage time updates.

        Args:
            sandbox_id: Sandbox ID.

        Raises:
            KeyError: If sandbox not found.
        """
        if sandbox_id not in self._locks:
            self._locks[sandbox_id] = asyncio.Lock()

        async with self._locks[sandbox_id]:
            if sandbox_id not in self._sandboxes:
                raise KeyError(f"Sandbox {sandbox_id} not found")

            self._active_operations.add(sandbox_id)
            try:
                self._last_used[sandbox_id] = asyncio.get_event_loop().time()
                yield self._sandboxes[sandbox_id]
            finally:
                self._active_operations.remove(sandbox_id)

    async def create_sandbox(
        self,
        config: Optional[SandboxSettings] = None,
        volume_bindings: Optional[Dict[str, str]] = None,
    ) -> str:
        """Creates a new sandbox instance.

        Args:
            config: Sandbox configuration.
            volume_bindings: Volume mapping configuration.

        Returns:
            str: Sandbox ID.

        Raises:
            RuntimeError: If max sandbox count reached or creation fails.
        """
        async with self._global_lock:
            if len(self._sandboxes) >= self.max_sandboxes:
                raise RuntimeError(
                    f"Maximum number of sandboxes ({self.max_sandboxes}) reached"
                )

            config = config or SandboxSettings()
            if not await self.ensure_image(config.image):
                raise RuntimeError(f"Failed to ensure Docker image: {config.image}")

            sandbox_id = str(uuid.uuid4())
            try:
                sandbox = DockerSandbox(config, volume_bindings)
                await sandbox.create()

                self._sandboxes[sandbox_id] = sandbox
                self._last_used[sandbox_id] = asyncio.get_event_loop().time()
                self._locks[sandbox_id] = asyncio.Lock()

                logger.info(f"Created sandbox {sandbox_id}")
                return sandbox_id

            except Exception as e:
                logger.error(f"Failed to create sandbox: {e}")
                if sandbox_id in self._sandboxes:
                    await self.delete_sandbox(sandbox_id)
                raise RuntimeError(f"Failed to create sandbox: {e}")

    async def get_sandbox(self, sandbox_id: str) -> DockerSandbox:
        """Gets a sandbox instance.

        Args:
            sandbox_id: Sandbox ID.

        Returns:
            DockerSandbox: Sandbox instance.

        Raises:
            KeyError: If sandbox does not exist.
        """
        async with self.sandbox_operation(sandbox_id) as sandbox:
            return sandbox

    def start_cleanup_task(self) -> None:
        """Starts automatic cleanup task."""

        async def cleanup_loop():
            while not self._is_shutting_down:
                try:
                    await self._cleanup_idle_sandboxes()
                except Exception as e:
                    logger.error(f"Error in cleanup loop: {e}")
                await asyncio.sleep(self.cleanup_interval)

        self._cleanup_task = asyncio.create_task(cleanup_loop())

    async def _cleanup_idle_sandboxes(self) -> None:
        """Cleans up idle sandboxes."""
        current_time = asyncio.get_event_loop().time()
        to_cleanup = []

        async with self._global_lock:
            for sandbox_id, last_used in self._last_used.items():
                if (
                    sandbox_id not in self._active_operations
                    and current_time - last_used > self.idle_timeout
                ):
                    to_cleanup.append(sandbox_id)

        for sandbox_id in to_cleanup:
            try:
                await self.delete_sandbox(sandbox_id)
            except Exception as e:
                logger.error(f"Error cleaning up sandbox {sandbox_id}: {e}")

    async def cleanup(self) -> None:
        """Cleans up all resources."""
        logger.info("Starting manager cleanup...")
        self._is_shutting_down = True

        # Cancel cleanup task
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await asyncio.wait_for(self._cleanup_task, timeout=1.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass

        # Get all sandbox IDs to clean up
        async with self._global_lock:
            sandbox_ids = list(self._sandboxes.keys())

        # Concurrently clean up all sandboxes
        cleanup_tasks = []
        for sandbox_id in sandbox_ids:
            task = asyncio.create_task(self._safe_delete_sandbox(sandbox_id))
            cleanup_tasks.append(task)

        if cleanup_tasks:
            # Wait for all cleanup tasks to complete, with a timeout to avoid infinite waiting
            try:
                await asyncio.wait(cleanup_tasks, timeout=30.0)
            except asyncio.TimeoutError:
                logger.error("Sandbox cleanup timed out")

        # Clean up remaining references
        self._sandboxes.clear()
        self._last_used.clear()
        self._locks.clear()
        self._active_operations.clear()

        logger.info("Manager cleanup completed")

    async def _safe_delete_sandbox(self, sandbox_id: str) -> None:
        """Safely deletes a single sandbox.

        Args:
            sandbox_id: Sandbox ID to delete.
        """
        try:
            if sandbox_id in self._active_operations:
                logger.warning(
                    f"Sandbox {sandbox_id} has active operations, waiting for completion"
                )
                for _ in range(10):  # Wait at most 10 times
                    await asyncio.sleep(0.5)
                    if sandbox_id not in self._active_operations:
                        break
                else:
                    logger.warning(
                        f"Timeout waiting for sandbox {sandbox_id} operations to complete"
                    )

            # Get reference to sandbox object
            sandbox = self._sandboxes.get(sandbox_id)
            if sandbox:
                await sandbox.cleanup()

            # Remove sandbox record from manager
            async with self._global_lock:
                self._sandboxes.pop(sandbox_id, None)
                self._last_used.pop(sandbox_id, None)
                self._locks.pop(sandbox_id, None)
                logger.info(f"Deleted sandbox {sandbox_id}")
        except Exception as e:
            logger.error(f"Error during cleanup of sandbox {sandbox_id}: {e}")

    async def delete_sandbox(self, sandbox_id: str) -> None:
        """Deletes specified sandbox.

        Args:
            sandbox_id: Sandbox ID.
        """
        if sandbox_id not in self._sandboxes:
            return

        try:
            await self._safe_delete_sandbox(sandbox_id)
        except Exception as e:
            logger.error(f"Failed to delete sandbox {sandbox_id}: {e}")

    async def __aenter__(self) -> "SandboxManager":
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit."""
        await self.cleanup()

    def get_stats(self) -> Dict:
        """Gets manager statistics.

        Returns:
            Dict: Statistics information.
        """
        return {
            "total_sandboxes": len(self._sandboxes),
            "active_operations": len(self._active_operations),
            "max_sandboxes": self.max_sandboxes,
            "idle_timeout": self.idle_timeout,
            "cleanup_interval": self.cleanup_interval,
            "is_shutting_down": self._is_shutting_down,
        }
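Since the manager implements __aenter__/__aexit__, a typical lifecycle looks like this (a sketch; it requires a running Docker daemon and the default sandbox image):

import asyncio

from app.sandbox.core.manager import SandboxManager

async def main():
    async with SandboxManager(max_sandboxes=2, idle_timeout=600) as manager:
        sandbox_id = await manager.create_sandbox()
        async with manager.sandbox_operation(sandbox_id) as sandbox:
            print(await sandbox.run_command("echo hello"))
        print(manager.get_stats())
    # __aexit__ calls cleanup(), deleting any remaining sandboxes

asyncio.run(main())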
@ -1,462 +0,0 @@
import asyncio
import io
import os
import tarfile
import tempfile
import uuid
from typing import Dict, Optional

import docker
from docker.errors import NotFound
from docker.models.containers import Container

from app.config import SandboxSettings
from app.sandbox.core.exceptions import SandboxTimeoutError
from app.sandbox.core.terminal import AsyncDockerizedTerminal


class DockerSandbox:
    """Docker sandbox environment.

    Provides a containerized execution environment with resource limits,
    file operations, and command execution capabilities.

    Attributes:
        config: Sandbox configuration.
        volume_bindings: Volume mapping configuration.
        client: Docker client.
        container: Docker container instance.
        terminal: Container terminal interface.
    """

    def __init__(
        self,
        config: Optional[SandboxSettings] = None,
        volume_bindings: Optional[Dict[str, str]] = None,
    ):
        """Initializes a sandbox instance.

        Args:
            config: Sandbox configuration. Default configuration used if None.
            volume_bindings: Volume mappings in {host_path: container_path} format.
        """
        self.config = config or SandboxSettings()
        self.volume_bindings = volume_bindings or {}
        self.client = docker.from_env()
        self.container: Optional[Container] = None
        self.terminal: Optional[AsyncDockerizedTerminal] = None

    async def create(self) -> "DockerSandbox":
        """Creates and starts the sandbox container.

        Returns:
            Current sandbox instance.

        Raises:
            docker.errors.APIError: If Docker API call fails.
            RuntimeError: If container creation or startup fails.
        """
        try:
            # Prepare container config
            host_config = self.client.api.create_host_config(
                mem_limit=self.config.memory_limit,
                cpu_period=100000,
                cpu_quota=int(100000 * self.config.cpu_limit),
                network_mode="none" if not self.config.network_enabled else "bridge",
                binds=self._prepare_volume_bindings(),
            )

            # Generate unique container name with sandbox_ prefix
            container_name = f"sandbox_{uuid.uuid4().hex[:8]}"

            # Create container
            container = await asyncio.to_thread(
                self.client.api.create_container,
                image=self.config.image,
                command="tail -f /dev/null",
                hostname="sandbox",
                working_dir=self.config.work_dir,
                host_config=host_config,
                name=container_name,
                tty=True,
                detach=True,
            )

            self.container = self.client.containers.get(container["Id"])

            # Start container
            await asyncio.to_thread(self.container.start)

            # Initialize terminal
            self.terminal = AsyncDockerizedTerminal(
                container["Id"],
                self.config.work_dir,
                env_vars={"PYTHONUNBUFFERED": "1"},
                # Ensure Python output is not buffered
            )
            await self.terminal.init()

            return self

        except Exception as e:
            await self.cleanup()  # Ensure resources are cleaned up
            raise RuntimeError(f"Failed to create sandbox: {e}") from e

    def _prepare_volume_bindings(self) -> Dict[str, Dict[str, str]]:
        """Prepares volume binding configuration.

        Returns:
            Volume binding configuration dictionary.
        """
        bindings = {}

        # Create and add working directory mapping
        work_dir = self._ensure_host_dir(self.config.work_dir)
        bindings[work_dir] = {"bind": self.config.work_dir, "mode": "rw"}

        # Add custom volume bindings
        for host_path, container_path in self.volume_bindings.items():
            bindings[host_path] = {"bind": container_path, "mode": "rw"}

        return bindings

    @staticmethod
    def _ensure_host_dir(path: str) -> str:
        """Ensures directory exists on the host.

        Args:
            path: Directory path.

        Returns:
            Actual path on the host.
        """
        host_path = os.path.join(
            tempfile.gettempdir(),
            f"sandbox_{os.path.basename(path)}_{os.urandom(4).hex()}",
        )
        os.makedirs(host_path, exist_ok=True)
        return host_path

    async def run_command(self, cmd: str, timeout: Optional[int] = None) -> str:
        """Runs a command in the sandbox.

        Args:
            cmd: Command to execute.
            timeout: Timeout in seconds.

        Returns:
            Command output as string.

        Raises:
            RuntimeError: If sandbox not initialized or command execution fails.
            TimeoutError: If command execution times out.
        """
        if not self.terminal:
            raise RuntimeError("Sandbox not initialized")

        try:
            return await self.terminal.run_command(
                cmd, timeout=timeout or self.config.timeout
            )
        except TimeoutError:
            raise SandboxTimeoutError(
                f"Command execution timed out after {timeout or self.config.timeout} seconds"
            )

    async def read_file(self, path: str) -> str:
        """Reads a file from the container.

        Args:
            path: File path.

        Returns:
            File contents as string.

        Raises:
            FileNotFoundError: If file does not exist.
            RuntimeError: If read operation fails.
        """
        if not self.container:
            raise RuntimeError("Sandbox not initialized")

        try:
            # Get file archive
            resolved_path = self._safe_resolve_path(path)
            tar_stream, _ = await asyncio.to_thread(
                self.container.get_archive, resolved_path
            )

            # Read file content from tar stream
            content = await self._read_from_tar(tar_stream)
            return content.decode("utf-8")

        except NotFound:
            raise FileNotFoundError(f"File not found: {path}")
        except Exception as e:
            raise RuntimeError(f"Failed to read file: {e}")

    async def write_file(self, path: str, content: str) -> None:
        """Writes content to a file in the container.

        Args:
            path: Target path.
            content: File content.

        Raises:
            RuntimeError: If write operation fails.
        """
        if not self.container:
            raise RuntimeError("Sandbox not initialized")

        try:
            resolved_path = self._safe_resolve_path(path)
            parent_dir = os.path.dirname(resolved_path)

            # Create parent directory
            if parent_dir:
                await self.run_command(f"mkdir -p {parent_dir}")

            # Prepare file data
            tar_stream = await self._create_tar_stream(
                os.path.basename(path), content.encode("utf-8")
            )

            # Write file
            await asyncio.to_thread(
                self.container.put_archive, parent_dir or "/", tar_stream
            )

        except Exception as e:
            raise RuntimeError(f"Failed to write file: {e}")

    def _safe_resolve_path(self, path: str) -> str:
        """Safely resolves container path, preventing path traversal.

        Args:
            path: Original path.

        Returns:
            Resolved absolute path.

        Raises:
            ValueError: If path contains potentially unsafe patterns.
        """
        # Check for path traversal attempts
        if ".." in path.split("/"):
            raise ValueError("Path contains potentially unsafe patterns")

        resolved = (
            os.path.join(self.config.work_dir, path)
            if not os.path.isabs(path)
            else path
        )
        return resolved

    async def copy_from(self, src_path: str, dst_path: str) -> None:
        """Copies a file from the container.

        Args:
            src_path: Source file path (container).
            dst_path: Destination path (host).

        Raises:
            FileNotFoundError: If source file does not exist.
            RuntimeError: If copy operation fails.
        """
        try:
            # Ensure destination file's parent directory exists
            parent_dir = os.path.dirname(dst_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Get file stream
            resolved_src = self._safe_resolve_path(src_path)
            stream, stat = await asyncio.to_thread(
                self.container.get_archive, resolved_src
            )

            # Create temporary directory to extract file
            with tempfile.TemporaryDirectory() as tmp_dir:
                # Write stream to temporary file
                tar_path = os.path.join(tmp_dir, "temp.tar")
                with open(tar_path, "wb") as f:
                    for chunk in stream:
                        f.write(chunk)

                # Extract file
                with tarfile.open(tar_path) as tar:
                    members = tar.getmembers()
                    if not members:
                        raise FileNotFoundError(f"Source file is empty: {src_path}")

                    # If destination is a directory, we should preserve relative path structure
                    if os.path.isdir(dst_path):
                        tar.extractall(dst_path)
                    else:
                        # If destination is a file, we only extract the source file's content
                        if len(members) > 1:
                            raise RuntimeError(
                                f"Source path is a directory but destination is a file: {src_path}"
                            )

                        with open(dst_path, "wb") as dst:
                            src_file = tar.extractfile(members[0])
                            if src_file is None:
                                raise RuntimeError(
                                    f"Failed to extract file: {src_path}"
                                )
                            dst.write(src_file.read())

        except docker.errors.NotFound:
            raise FileNotFoundError(f"Source file not found: {src_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to copy file: {e}")

    async def copy_to(self, src_path: str, dst_path: str) -> None:
        """Copies a file to the container.

        Args:
            src_path: Source file path (host).
            dst_path: Destination path (container).

        Raises:
            FileNotFoundError: If source file does not exist.
            RuntimeError: If copy operation fails.
        """
        try:
            if not os.path.exists(src_path):
                raise FileNotFoundError(f"Source file not found: {src_path}")

            # Create destination directory in container
            resolved_dst = self._safe_resolve_path(dst_path)
            container_dir = os.path.dirname(resolved_dst)
            if container_dir:
                await self.run_command(f"mkdir -p {container_dir}")

            # Create tar file to upload
            with tempfile.TemporaryDirectory() as tmp_dir:
                tar_path = os.path.join(tmp_dir, "temp.tar")
                with tarfile.open(tar_path, "w") as tar:
                    # Handle directory source path
                    if os.path.isdir(src_path):
                        os.path.basename(src_path.rstrip("/"))
                        for root, _, files in os.walk(src_path):
                            for file in files:
                                file_path = os.path.join(root, file)
                                arcname = os.path.join(
                                    os.path.basename(dst_path),
                                    os.path.relpath(file_path, src_path),
                                )
                                tar.add(file_path, arcname=arcname)
                    else:
                        # Add single file to tar
                        tar.add(src_path, arcname=os.path.basename(dst_path))

                # Read tar file content
                with open(tar_path, "rb") as f:
                    data = f.read()

                # Upload to container
                await asyncio.to_thread(
                    self.container.put_archive,
                    os.path.dirname(resolved_dst) or "/",
                    data,
                )

            # Verify file was created successfully
            try:
                await self.run_command(f"test -e {resolved_dst}")
            except Exception:
                raise RuntimeError(f"Failed to verify file creation: {dst_path}")

        except FileNotFoundError:
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to copy file: {e}")

    @staticmethod
    async def _create_tar_stream(name: str, content: bytes) -> io.BytesIO:
        """Creates a tar file stream.

        Args:
            name: Filename.
            content: File content.

        Returns:
            Tar file stream.
        """
        tar_stream = io.BytesIO()
        with tarfile.open(fileobj=tar_stream, mode="w") as tar:
            tarinfo = tarfile.TarInfo(name=name)
            tarinfo.size = len(content)
            tar.addfile(tarinfo, io.BytesIO(content))
        tar_stream.seek(0)
        return tar_stream

    @staticmethod
    async def _read_from_tar(tar_stream) -> bytes:
        """Reads file content from a tar stream.

        Args:
            tar_stream: Tar file stream.

        Returns:
            File content.

        Raises:
            RuntimeError: If read operation fails.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            for chunk in tar_stream:
                tmp.write(chunk)
            tmp.seek(0)

            with tarfile.open(fileobj=tmp) as tar:
                member = tar.next()
                if not member:
                    raise RuntimeError("Empty tar archive")

                file_content = tar.extractfile(member)
                if not file_content:
                    raise RuntimeError("Failed to extract file content")

                return file_content.read()

    async def cleanup(self) -> None:
        """Cleans up sandbox resources."""
        errors = []
        try:
            if self.terminal:
                try:
                    await self.terminal.close()
                except Exception as e:
                    errors.append(f"Terminal cleanup error: {e}")
                finally:
                    self.terminal = None

            if self.container:
                try:
                    await asyncio.to_thread(self.container.stop, timeout=5)
                except Exception as e:
                    errors.append(f"Container stop error: {e}")

                try:
                    await asyncio.to_thread(self.container.remove, force=True)
                except Exception as e:
                    errors.append(f"Container remove error: {e}")
                finally:
                    self.container = None

        except Exception as e:
            errors.append(f"General cleanup error: {e}")

        if errors:
            print(f"Warning: Errors during cleanup: {', '.join(errors)}")

    async def __aenter__(self) -> "DockerSandbox":
        """Async context manager entry."""
        return await self.create()

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit."""
        await self.cleanup()
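For orientation, here is a minimal usage sketch of the DockerSandbox class removed above. It assumes a running Docker daemon and that the default SandboxSettings image is available locally; the file name and command are illustrative.

import asyncio

from app.config import SandboxSettings


async def demo_sandbox() -> None:
    # Create, use, and clean up a sandbox via the async context manager.
    async with DockerSandbox(SandboxSettings()) as sandbox:
        await sandbox.write_file("hello.py", "print('hello from the sandbox')")
        # Assumes python3 exists in the configured image.
        print(await sandbox.run_command("python3 hello.py"))


asyncio.run(demo_sandbox())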
@@ -1,346 +0,0 @@
"""
Asynchronous Docker Terminal

This module provides asynchronous terminal functionality for Docker containers,
allowing interactive command execution with timeout control.
"""

import asyncio
import re
import socket
from typing import Dict, Optional, Tuple, Union

import docker
from docker import APIClient
from docker.errors import APIError
from docker.models.containers import Container


class DockerSession:
    def __init__(self, container_id: str) -> None:
        """Initializes a Docker session.

        Args:
            container_id: ID of the Docker container.
        """
        self.api = APIClient()
        self.container_id = container_id
        self.exec_id = None
        self.socket = None

    async def create(self, working_dir: str, env_vars: Dict[str, str]) -> None:
        """Creates an interactive session with the container.

        Args:
            working_dir: Working directory inside the container.
            env_vars: Environment variables to set.

        Raises:
            RuntimeError: If socket connection fails.
        """
        startup_command = [
            "bash",
            "-c",
            f"cd {working_dir} && "
            "PROMPT_COMMAND='' "
            "PS1='$ ' "
            "exec bash --norc --noprofile",
        ]

        exec_data = self.api.exec_create(
            self.container_id,
            startup_command,
            stdin=True,
            tty=True,
            stdout=True,
            stderr=True,
            privileged=True,
            user="root",
            environment={**env_vars, "TERM": "dumb", "PS1": "$ ", "PROMPT_COMMAND": ""},
        )
        self.exec_id = exec_data["Id"]

        socket_data = self.api.exec_start(
            self.exec_id, socket=True, tty=True, stream=True, demux=True
        )

        if hasattr(socket_data, "_sock"):
            self.socket = socket_data._sock
            self.socket.setblocking(False)
        else:
            raise RuntimeError("Failed to get socket connection")

        await self._read_until_prompt()

    async def close(self) -> None:
        """Cleans up session resources.

        1. Sends exit command
        2. Closes socket connection
        3. Checks and cleans up exec instance
        """
        try:
            if self.socket:
                # Send exit command to close bash session
                try:
                    self.socket.sendall(b"exit\n")
                    # Allow time for command execution
                    await asyncio.sleep(0.1)
                except:
                    pass  # Ignore sending errors, continue cleanup

                # Close socket connection
                try:
                    self.socket.shutdown(socket.SHUT_RDWR)
                except:
                    pass  # Some platforms may not support shutdown

                self.socket.close()
                self.socket = None

            if self.exec_id:
                try:
                    # Check exec instance status
                    exec_inspect = self.api.exec_inspect(self.exec_id)
                    if exec_inspect.get("Running", False):
                        # If still running, wait for it to complete
                        await asyncio.sleep(0.5)
                except:
                    pass  # Ignore inspection errors, continue cleanup

                self.exec_id = None

        except Exception as e:
            # Log error but don't raise, ensure cleanup continues
            print(f"Warning: Error during session cleanup: {e}")

    async def _read_until_prompt(self) -> str:
        """Reads output until prompt is found.

        Returns:
            String containing output up to the prompt.

        Raises:
            socket.error: If socket communication fails.
        """
        buffer = b""
        while b"$ " not in buffer:
            try:
                chunk = self.socket.recv(4096)
                if chunk:
                    buffer += chunk
            except socket.error as e:
                if e.errno == socket.EWOULDBLOCK:
                    await asyncio.sleep(0.1)
                    continue
                raise
        return buffer.decode("utf-8")

    async def execute(self, command: str, timeout: Optional[int] = None) -> str:
        """Executes a command and returns cleaned output.

        Args:
            command: Shell command to execute.
            timeout: Maximum execution time in seconds.

        Returns:
            Command output as string with prompt markers removed.

        Raises:
            RuntimeError: If session not initialized or execution fails.
            TimeoutError: If command execution exceeds timeout.
        """
        if not self.socket:
            raise RuntimeError("Session not initialized")

        try:
            # Sanitize command to prevent shell injection
            sanitized_command = self._sanitize_command(command)
            full_command = f"{sanitized_command}\necho $?\n"
            self.socket.sendall(full_command.encode())

            async def read_output() -> str:
                buffer = b""
                result_lines = []
                command_sent = False

                while True:
                    try:
                        chunk = self.socket.recv(4096)
                        if not chunk:
                            break

                        buffer += chunk
                        lines = buffer.split(b"\n")

                        buffer = lines[-1]
                        lines = lines[:-1]

                        for line in lines:
                            line = line.rstrip(b"\r")

                            if not command_sent:
                                command_sent = True
                                continue

                            if line.strip() == b"echo $?" or line.strip().isdigit():
                                continue

                            if line.strip():
                                result_lines.append(line)

                        if buffer.endswith(b"$ "):
                            break

                    except socket.error as e:
                        if e.errno == socket.EWOULDBLOCK:
                            await asyncio.sleep(0.1)
                            continue
                        raise

                output = b"\n".join(result_lines).decode("utf-8")
                output = re.sub(r"\n\$ echo \$\$?.*$", "", output)

                return output

            if timeout:
                result = await asyncio.wait_for(read_output(), timeout)
            else:
                result = await read_output()

            return result.strip()

        except asyncio.TimeoutError:
            raise TimeoutError(f"Command execution timed out after {timeout} seconds")
        except Exception as e:
            raise RuntimeError(f"Failed to execute command: {e}")

    def _sanitize_command(self, command: str) -> str:
        """Sanitizes the command string to prevent shell injection.

        Args:
            command: Raw command string.

        Returns:
            Sanitized command string.

        Raises:
            ValueError: If command contains potentially dangerous patterns.
        """
        # Additional checks for specific risky commands
        risky_commands = [
            "rm -rf /",
            "rm -rf /*",
            "mkfs",
            "dd if=/dev/zero",
            ":(){:|:&};:",
            "chmod -R 777 /",
            "chown -R",
        ]

        for risky in risky_commands:
            if risky in command.lower():
                raise ValueError(
                    f"Command contains potentially dangerous operation: {risky}"
                )

        return command


class AsyncDockerizedTerminal:
    def __init__(
        self,
        container: Union[str, Container],
        working_dir: str = "/workspace",
        env_vars: Optional[Dict[str, str]] = None,
        default_timeout: int = 60,
    ) -> None:
        """Initializes an asynchronous terminal for Docker containers.

        Args:
            container: Docker container ID or Container object.
            working_dir: Working directory inside the container.
            env_vars: Environment variables to set.
            default_timeout: Default command execution timeout in seconds.
        """
        self.client = docker.from_env()
        self.container = (
            container
            if isinstance(container, Container)
            else self.client.containers.get(container)
        )
        self.working_dir = working_dir
        self.env_vars = env_vars or {}
        self.default_timeout = default_timeout
        self.session = None

    async def init(self) -> None:
        """Initializes the terminal environment.

        Ensures working directory exists and creates an interactive session.

        Raises:
            RuntimeError: If initialization fails.
        """
        await self._ensure_workdir()

        self.session = DockerSession(self.container.id)
        await self.session.create(self.working_dir, self.env_vars)

    async def _ensure_workdir(self) -> None:
        """Ensures working directory exists in container.

        Raises:
            RuntimeError: If directory creation fails.
        """
        try:
            await self._exec_simple(f"mkdir -p {self.working_dir}")
        except APIError as e:
            raise RuntimeError(f"Failed to create working directory: {e}")

    async def _exec_simple(self, cmd: str) -> Tuple[int, str]:
        """Executes a simple command using Docker's exec_run.

        Args:
            cmd: Command to execute.

        Returns:
            Tuple of (exit_code, output).
        """
        result = await asyncio.to_thread(
            self.container.exec_run, cmd, environment=self.env_vars
        )
        return result.exit_code, result.output.decode("utf-8")

    async def run_command(self, cmd: str, timeout: Optional[int] = None) -> str:
        """Runs a command in the container with timeout.

        Args:
            cmd: Shell command to execute.
            timeout: Maximum execution time in seconds.

        Returns:
            Command output as string.

        Raises:
            RuntimeError: If terminal not initialized.
        """
        if not self.session:
            raise RuntimeError("Terminal not initialized")

        return await self.session.execute(cmd, timeout=timeout or self.default_timeout)

    async def close(self) -> None:
        """Closes the terminal session."""
        if self.session:
            await self.session.close()

    async def __aenter__(self) -> "AsyncDockerizedTerminal":
        """Async context manager entry."""
        await self.init()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit."""
        await self.close()
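A minimal sketch of driving the removed AsyncDockerizedTerminal directly; the container ID "my_container" is a placeholder for an existing, running container.

import asyncio


async def demo_terminal() -> None:
    # "my_container" is a hypothetical ID of an already-running container.
    async with AsyncDockerizedTerminal("my_container", working_dir="/workspace") as term:
        print(await term.run_command("echo hello", timeout=10))


asyncio.run(demo_terminal())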
@@ -59,7 +59,6 @@ class Message(BaseModel):
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    name: Optional[str] = Field(default=None)
    tool_call_id: Optional[str] = Field(default=None)
    base64_image: Optional[str] = Field(default=None)

    def __add__(self, other) -> List["Message"]:
        """Supports Message + list and Message + Message operations."""
@@ -92,16 +91,12 @@ class Message(BaseModel):
            message["name"] = self.name
        if self.tool_call_id is not None:
            message["tool_call_id"] = self.tool_call_id
        if self.base64_image is not None:
            message["base64_image"] = self.base64_image
        return message

    @classmethod
    def user_message(
        cls, content: str, base64_image: Optional[str] = None
    ) -> "Message":
    def user_message(cls, content: str) -> "Message":
        """Create a user message"""
        return cls(role=Role.USER, content=content, base64_image=base64_image)
        return cls(role=Role.USER, content=content)

    @classmethod
    def system_message(cls, content: str) -> "Message":
@@ -109,50 +104,33 @@ class Message(BaseModel):
        return cls(role=Role.SYSTEM, content=content)

    @classmethod
    def assistant_message(
        cls, content: Optional[str] = None, base64_image: Optional[str] = None
    ) -> "Message":
    def assistant_message(cls, content: Optional[str] = None) -> "Message":
        """Create an assistant message"""
        return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)
        return cls(role=Role.ASSISTANT, content=content)

    @classmethod
    def tool_message(
        cls, content: str, name, tool_call_id: str, base64_image: Optional[str] = None
    ) -> "Message":
    def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
        """Create a tool message"""
        return cls(
            role=Role.TOOL,
            content=content,
            name=name,
            tool_call_id=tool_call_id,
            base64_image=base64_image,
            role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id
        )

    @classmethod
    def from_tool_calls(
        cls,
        tool_calls: List[Any],
        content: Union[str, List[str]] = "",
        base64_image: Optional[str] = None,
        **kwargs,
        cls, tool_calls: List[Any], content: Union[str, List[str]] = "", **kwargs
    ):
        """Create ToolCallsMessage from raw tool calls.

        Args:
            tool_calls: Raw tool calls from LLM
            content: Optional message content
            base64_image: Optional base64 encoded image
        """
        formatted_calls = [
            {"id": call.id, "function": call.function.model_dump(), "type": "function"}
            for call in tool_calls
        ]
        return cls(
            role=Role.ASSISTANT,
            content=content,
            tool_calls=formatted_calls,
            base64_image=base64_image,
            **kwargs,
            role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs
        )
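A short sketch of the Message factory helpers and the __add__ overload from these hunks, using the simpler signatures on the new side of the diff:

user = Message.user_message("What is 2 + 2?")
assistant = Message.assistant_message("4")
tool = Message.tool_message("4", name="calculator", tool_call_id="call_1")

# __add__ turns Message + Message (or Message + list) into a flat list:
history = user + assistant + [tool]
# history is [user, assistant, tool], ready to pass to an LLM call.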
@@ -1,6 +1,5 @@
from app.tool.base import BaseTool
from app.tool.bash import Bash
from app.tool.browser_use_tool import BrowserUseTool
from app.tool.create_chat_completion import CreateChatCompletion
from app.tool.planning import PlanningTool
from app.tool.str_replace_editor import StrReplaceEditor
@@ -11,7 +10,6 @@ from app.tool.tool_collection import ToolCollection
__all__ = [
    "BaseTool",
    "Bash",
    "BrowserUseTool",
    "Terminate",
    "StrReplaceEditor",
    "ToolCollection",
@@ -37,7 +37,6 @@ class ToolResult(BaseModel):

    output: Any = Field(default=None)
    error: Optional[str] = Field(default=None)
    base64_image: Optional[str] = Field(default=None)
    system: Optional[str] = Field(default=None)

    class Config:
@@ -59,7 +58,6 @@ class ToolResult(BaseModel):
        return ToolResult(
            output=combine_fields(self.output, other.output),
            error=combine_fields(self.error, other.error),
            base64_image=combine_fields(self.base64_image, other.base64_image, False),
            system=combine_fields(self.system, other.system),
        )
@@ -78,3 +76,7 @@ class CLIResult(ToolResult):

class ToolFailure(ToolResult):
    """A ToolResult that represents a failure."""


class AgentAwareTool:
    agent: Optional = None
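A quick sketch of how ToolResult instances combine via __add__. The combine_fields helper itself is outside this hunk, so the exact merge behavior for two non-empty string fields is an assumption (concatenation):

first = ToolResult(output="step one done. ")
second = ToolResult(output="step two done.", system="tool must be restarted")

combined = first + second
print(combined.output)  # assumed: "step one done. step two done."
print(combined.system)  # "tool must be restarted"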
@@ -3,7 +3,7 @@ import os
from typing import Optional

from app.exceptions import ToolError
from app.tool.base import BaseTool, CLIResult
from app.tool.base import BaseTool, CLIResult, ToolResult


_BASH_DESCRIPTION = """Execute a bash command in the terminal.
@@ -57,7 +57,7 @@ class _BashSession:
        if not self._started:
            raise ToolError("Session has not started.")
        if self._process.returncode is not None:
            return CLIResult(
            return ToolResult(
                system="tool must be restarted",
                error=f"bash has exited with returncode {self._process.returncode}",
            )
@@ -140,7 +140,7 @@ class Bash(BaseTool):
            self._session = _BashSession()
            await self._session.start()

            return CLIResult(system="tool has been restarted.")
            return ToolResult(system="tool has been restarted.")

        if self._session is None:
            self._session = _BashSession()
@@ -1,7 +1,6 @@
import asyncio
import base64
import json
from typing import Generic, Optional, TypeVar
from typing import Optional

from browser_use import Browser as BrowserUseBrowser
from browser_use import BrowserConfig
@@ -11,45 +10,31 @@ from pydantic import Field, field_validator
from pydantic_core.core_schema import ValidationInfo

from app.config import config
from app.llm import LLM
from app.tool.base import BaseTool, ToolResult
from app.tool.web_search import WebSearch


MAX_LENGTH = 2000

_BROWSER_DESCRIPTION = """
Interact with a web browser to perform various actions such as navigation, element interaction, content extraction, and tab management. This tool provides a comprehensive set of browser automation capabilities:

Navigation:
- 'go_to_url': Go to a specific URL in the current tab
- 'go_back': Go back
- 'refresh': Refresh the current page
- 'web_search': Search the query in the current tab; the query should be a concrete search query like humans use on the web, not vague or overly long. Prefer the single most important items.

Element Interaction:
- 'click_element': Click an element by index
- 'input_text': Input text into a form element
- 'scroll_down'/'scroll_up': Scroll the page (with optional pixel amount)
- 'scroll_to_text': If you don't find something you want to interact with, scroll to it
- 'send_keys': Send strings of special keys like Escape, Backspace, Insert, PageDown, Delete, Enter. Shortcuts such as `Control+o` and `Control+Shift+T` are supported as well. This gets used in keyboard.press.
- 'get_dropdown_options': Get all options from a dropdown
- 'select_dropdown_option': Select a dropdown option for an interactive element index by the text of the option you want to select

Content Extraction:
- 'extract_content': Extract page content to retrieve specific information from the page, e.g. all company names, a specific description, links with companies in structured format, or simply links

Tab Management:
Interact with a web browser to perform various actions such as navigation, element interaction,
content extraction, and tab management. Supported actions include:
- 'navigate': Go to a specific URL
- 'click': Click an element by index
- 'input_text': Input text into an element
- 'screenshot': Capture a screenshot
- 'get_html': Get page HTML content
- 'get_text': Get text content of the page
- 'read_links': Get all links on the page
- 'execute_js': Execute JavaScript code
- 'scroll': Scroll the page
- 'switch_tab': Switch to a specific tab
- 'open_tab': Open a new tab with a URL
- 'new_tab': Open a new tab
- 'close_tab': Close the current tab

Utility:
- 'wait': Wait for a specified number of seconds
- 'refresh': Refresh the current page
"""

Context = TypeVar("Context")


class BrowserUseTool(BaseTool, Generic[Context]):
class BrowserUseTool(BaseTool):
    name: str = "browser_use"
    description: str = _BROWSER_DESCRIPTION
    parameters: dict = {
@@ -58,79 +43,52 @@ class BrowserUseTool(BaseTool, Generic[Context]):
        "action": {
            "type": "string",
            "enum": [
                "go_to_url",
                "click_element",
                "navigate",
                "click",
                "input_text",
                "scroll_down",
                "scroll_up",
                "scroll_to_text",
                "send_keys",
                "get_dropdown_options",
                "select_dropdown_option",
                "go_back",
                "web_search",
                "wait",
                "extract_content",
                "screenshot",
                "get_html",
                "get_text",
                "execute_js",
                "scroll",
                "switch_tab",
                "open_tab",
                "new_tab",
                "close_tab",
                "refresh",
            ],
            "description": "The browser action to perform",
        },
        "url": {
            "type": "string",
            "description": "URL for 'go_to_url' or 'open_tab' actions",
            "description": "URL for 'navigate' or 'new_tab' actions",
        },
        "index": {
            "type": "integer",
            "description": "Element index for 'click_element', 'input_text', 'get_dropdown_options', or 'select_dropdown_option' actions",
            "description": "Element index for 'click' or 'input_text' actions",
        },
        "text": {
        "text": {"type": "string", "description": "Text for 'input_text' action"},
        "script": {
            "type": "string",
            "description": "Text for 'input_text', 'scroll_to_text', or 'select_dropdown_option' actions",
            "description": "JavaScript code for 'execute_js' action",
        },
        "scroll_amount": {
            "type": "integer",
            "description": "Pixels to scroll (positive for down, negative for up) for 'scroll_down' or 'scroll_up' actions",
            "description": "Pixels to scroll (positive for down, negative for up) for 'scroll' action",
        },
        "tab_id": {
            "type": "integer",
            "description": "Tab ID for 'switch_tab' action",
        },
        "query": {
            "type": "string",
            "description": "Search query for 'web_search' action",
        },
        "goal": {
            "type": "string",
            "description": "Extraction goal for 'extract_content' action",
        },
        "keys": {
            "type": "string",
            "description": "Keys to send for 'send_keys' action",
        },
        "seconds": {
            "type": "integer",
            "description": "Seconds to wait for 'wait' action",
        },
    },
    "required": ["action"],
    "dependencies": {
        "go_to_url": ["url"],
        "click_element": ["index"],
        "navigate": ["url"],
        "click": ["index"],
        "input_text": ["index", "text"],
        "execute_js": ["script"],
        "switch_tab": ["tab_id"],
        "open_tab": ["url"],
        "scroll_down": ["scroll_amount"],
        "scroll_up": ["scroll_amount"],
        "scroll_to_text": ["text"],
        "send_keys": ["keys"],
        "get_dropdown_options": ["index"],
        "select_dropdown_option": ["index", "text"],
        "go_back": [],
        "web_search": ["query"],
        "wait": ["seconds"],
        "extract_content": ["goal"],
        "new_tab": ["url"],
        "scroll": ["scroll_amount"],
    },
}
@@ -138,12 +96,6 @@ class BrowserUseTool(BaseTool, Generic[Context]):
    browser: Optional[BrowserUseBrowser] = Field(default=None, exclude=True)
    context: Optional[BrowserContext] = Field(default=None, exclude=True)
    dom_service: Optional[DomService] = Field(default=None, exclude=True)
    web_search_tool: WebSearch = Field(default_factory=WebSearch, exclude=True)

    # Context for generic functionality
    tool_context: Optional[Context] = Field(default=None, exclude=True)

    llm: Optional[LLM] = Field(default_factory=LLM)

    @field_validator("parameters", mode="before")
    def validate_parameters(cls, v: dict, info: ValidationInfo) -> dict:
@@ -206,12 +158,9 @@ class BrowserUseTool(BaseTool, Generic[Context]):
        url: Optional[str] = None,
        index: Optional[int] = None,
        text: Optional[str] = None,
        script: Optional[str] = None,
        scroll_amount: Optional[int] = None,
        tab_id: Optional[int] = None,
        query: Optional[str] = None,
        goal: Optional[str] = None,
        keys: Optional[str] = None,
        seconds: Optional[int] = None,
        **kwargs,
    ) -> ToolResult:
        """
@@ -221,13 +170,10 @@ class BrowserUseTool(BaseTool, Generic[Context]):
            action: The browser action to perform
            url: URL for navigation or new tab
            index: Element index for click or input actions
            text: Text for input action or search query
            text: Text for input action
            script: JavaScript code for execution
            scroll_amount: Pixels to scroll for scroll action
            tab_id: Tab ID for switch_tab action
            query: Search query for Google search
            goal: Extraction goal for content extraction
            keys: Keys to send for keyboard actions
            seconds: Seconds to wait
            **kwargs: Additional arguments

        Returns:
@@ -237,68 +183,15 @@ class BrowserUseTool(BaseTool, Generic[Context]):
        try:
            context = await self._ensure_browser_initialized()

            # Get max content length from config
            max_content_length = getattr(
                config.browser_config, "max_content_length", 2000
            )

            # Navigation actions
            if action == "go_to_url":
            if action == "navigate":
                if not url:
                    return ToolResult(
                        error="URL is required for 'go_to_url' action"
                    )
                page = await context.get_current_page()
                await page.goto(url)
                await page.wait_for_load_state()
                    return ToolResult(error="URL is required for 'navigate' action")
                await context.navigate_to(url)
                return ToolResult(output=f"Navigated to {url}")

            elif action == "go_back":
                await context.go_back()
                return ToolResult(output="Navigated back")

            elif action == "refresh":
                await context.refresh_page()
                return ToolResult(output="Refreshed current page")

            elif action == "web_search":
                if not query:
                    return ToolResult(
                        error="Query is required for 'web_search' action"
                    )
                search_results = await self.web_search_tool.execute(query)

                if search_results:
                    # Navigate to the first search result
                    first_result = search_results[0]
                    if isinstance(first_result, dict) and "url" in first_result:
                        url_to_navigate = first_result["url"]
                    elif isinstance(first_result, str):
                        url_to_navigate = first_result
                    else:
                        return ToolResult(
                            error=f"Invalid search result format: {first_result}"
                        )

                    page = await context.get_current_page()
                    await page.goto(url_to_navigate)
                    await page.wait_for_load_state()

                    return ToolResult(
                        output=f"Searched for '{query}' and navigated to first result: {url_to_navigate}\nAll results:"
                        + "\n".join([str(r) for r in search_results])
                    )
                else:
                    return ToolResult(
                        error=f"No search results found for '{query}'"
                    )

            # Element interaction actions
            elif action == "click_element":
            elif action == "click":
                if index is None:
                    return ToolResult(
                        error="Index is required for 'click_element' action"
                    )
                    return ToolResult(error="Index is required for 'click' action")
                element = await context.get_dom_element_by_index(index)
                if not element:
                    return ToolResult(error=f"Element with index {index} not found")
@@ -321,208 +214,72 @@ class BrowserUseTool(BaseTool, Generic[Context]):
                    output=f"Input '{text}' into element at index {index}"
                )

            elif action == "scroll_down" or action == "scroll_up":
                direction = 1 if action == "scroll_down" else -1
                amount = (
                    scroll_amount
                    if scroll_amount is not None
                    else context.config.browser_window_size["height"]
            elif action == "screenshot":
                screenshot = await context.take_screenshot(full_page=True)
                return ToolResult(
                    output=f"Screenshot captured (base64 length: {len(screenshot)})",
                    system=screenshot,
                )

            elif action == "get_html":
                html = await context.get_page_html()
                truncated = (
                    html[:MAX_LENGTH] + "..." if len(html) > MAX_LENGTH else html
                )
                return ToolResult(output=truncated)

            elif action == "get_text":
                text = await context.execute_javascript("document.body.innerText")
                return ToolResult(output=text)

            elif action == "read_links":
                links = await context.execute_javascript(
                    "document.querySelectorAll('a[href]').forEach((elem) => {if (elem.innerText) {console.log(elem.innerText, elem.href)}})"
                )
                return ToolResult(output=links)

            elif action == "execute_js":
                if not script:
                    return ToolResult(
                        error="Script is required for 'execute_js' action"
                    )
                result = await context.execute_javascript(script)
                return ToolResult(output=str(result))

            elif action == "scroll":
                if scroll_amount is None:
                    return ToolResult(
                        error="Scroll amount is required for 'scroll' action"
                    )
                await context.execute_javascript(
                    f"window.scrollBy(0, {direction * amount});"
                    f"window.scrollBy(0, {scroll_amount});"
                )
                direction = "down" if scroll_amount > 0 else "up"
                return ToolResult(
                    output=f"Scrolled {'down' if direction > 0 else 'up'} by {amount} pixels"
                    output=f"Scrolled {direction} by {abs(scroll_amount)} pixels"
                )

            elif action == "scroll_to_text":
                if not text:
                    return ToolResult(
                        error="Text is required for 'scroll_to_text' action"
                    )
                page = await context.get_current_page()
                try:
                    locator = page.get_by_text(text, exact=False)
                    await locator.scroll_into_view_if_needed()
                    return ToolResult(output=f"Scrolled to text: '{text}'")
                except Exception as e:
                    return ToolResult(error=f"Failed to scroll to text: {str(e)}")

            elif action == "send_keys":
                if not keys:
                    return ToolResult(
                        error="Keys are required for 'send_keys' action"
                    )
                page = await context.get_current_page()
                await page.keyboard.press(keys)
                return ToolResult(output=f"Sent keys: {keys}")

            elif action == "get_dropdown_options":
                if index is None:
                    return ToolResult(
                        error="Index is required for 'get_dropdown_options' action"
                    )
                element = await context.get_dom_element_by_index(index)
                if not element:
                    return ToolResult(error=f"Element with index {index} not found")
                page = await context.get_current_page()
                options = await page.evaluate(
                    """
                    (xpath) => {
                        const select = document.evaluate(xpath, document, null,
                            XPathResult.FIRST_ORDERED_NODE_TYPE, null).singleNodeValue;
                        if (!select) return null;
                        return Array.from(select.options).map(opt => ({
                            text: opt.text,
                            value: opt.value,
                            index: opt.index
                        }));
                    }
                    """,
                    element.xpath,
                )
                return ToolResult(output=f"Dropdown options: {options}")

            elif action == "select_dropdown_option":
                if index is None or not text:
                    return ToolResult(
                        error="Index and text are required for 'select_dropdown_option' action"
                    )
                element = await context.get_dom_element_by_index(index)
                if not element:
                    return ToolResult(error=f"Element with index {index} not found")
                page = await context.get_current_page()
                await page.select_option(element.xpath, label=text)
                return ToolResult(
                    output=f"Selected option '{text}' from dropdown at index {index}"
                )

            # Content extraction actions
            elif action == "extract_content":
                if not goal:
                    return ToolResult(
                        error="Goal is required for 'extract_content' action"
                    )
                page = await context.get_current_page()
                try:
                    # Get page content and convert to markdown for better processing
                    html_content = await page.content()

                    # Import markdownify here to avoid global import
                    try:
                        import markdownify

                        content = markdownify.markdownify(html_content)
                    except ImportError:
                        # Fallback if markdownify is not available
                        content = html_content

                    # Create prompt for LLM
                    prompt_text = """
Your task is to extract the content of the page. You will be given a page and a goal, and you should extract all relevant information around this goal from the page. If the goal is vague, summarize the page. Respond in json format.
Extraction goal: {goal}

Page content:
{page}
"""
                    # Format the prompt with the goal and content
                    max_content_length = min(50000, len(content))
                    formatted_prompt = prompt_text.format(
                        goal=goal, page=content[:max_content_length]
                    )

                    # Create a proper message list for the LLM
                    from app.schema import Message

                    messages = [Message.user_message(formatted_prompt)]

                    # Define extraction function for the tool
                    extraction_function = {
                        "type": "function",
                        "function": {
                            "name": "extract_content",
                            "description": "Extract specific information from a webpage based on a goal",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "extracted_content": {
                                        "type": "object",
                                        "description": "The content extracted from the page according to the goal",
                                    }
                                },
                                "required": ["extracted_content"],
                            },
                        },
                    }

                    # Use LLM to extract content with required function calling
                    response = await self.llm.ask_tool(
                        messages,
                        tools=[extraction_function],
                        tool_choice="required",
                    )

                    # Extract content from function call response
                    if (
                        response
                        and response.tool_calls
                        and len(response.tool_calls) > 0
                    ):
                        # Get the first tool call arguments
                        tool_call = response.tool_calls[0]
                        # Parse the JSON arguments
                        try:
                            args = json.loads(tool_call.function.arguments)
                            extracted_content = args.get("extracted_content", {})
                            # Format extracted content as JSON string
                            content_json = json.dumps(
                                extracted_content, indent=2, ensure_ascii=False
                            )
                            msg = f"Extracted from page:\n{content_json}\n"
                        except Exception as e:
                            msg = f"Error parsing extraction result: {str(e)}\nRaw response: {tool_call.function.arguments}"
                    else:
                        msg = "No content was extracted from the page."

                    return ToolResult(output=msg)
                except Exception as e:
                    # Provide a more helpful error message
                    error_msg = f"Failed to extract content: {str(e)}"
                    try:
                        # Try to return a portion of the page content as fallback
                        return ToolResult(
                            output=f"{error_msg}\nHere's a portion of the page content:\n{content[:2000]}..."
                        )
                    except:
                        # If all else fails, just return the error
                        return ToolResult(error=error_msg)

            # Tab management actions
            elif action == "switch_tab":
                if tab_id is None:
                    return ToolResult(
                        error="Tab ID is required for 'switch_tab' action"
                    )
                await context.switch_to_tab(tab_id)
                page = await context.get_current_page()
                await page.wait_for_load_state()
                return ToolResult(output=f"Switched to tab {tab_id}")

            elif action == "open_tab":
            elif action == "new_tab":
                if not url:
                    return ToolResult(error="URL is required for 'open_tab' action")
                    return ToolResult(error="URL is required for 'new_tab' action")
                await context.create_new_tab(url)
                return ToolResult(output=f"Opened new tab with {url}")
                return ToolResult(output=f"Opened new tab with URL {url}")

            elif action == "close_tab":
                await context.close_current_tab()
                return ToolResult(output="Closed current tab")

            # Utility actions
            elif action == "wait":
                seconds_to_wait = seconds if seconds is not None else 3
                await asyncio.sleep(seconds_to_wait)
                return ToolResult(output=f"Waited for {seconds_to_wait} seconds")
            elif action == "refresh":
                await context.refresh_page()
                return ToolResult(output="Refreshed current page")

            else:
                return ToolResult(error=f"Unknown action: {action}")
@@ -530,67 +287,21 @@ Page content:
        except Exception as e:
            return ToolResult(error=f"Browser action '{action}' failed: {str(e)}")

    async def get_current_state(
        self, context: Optional[BrowserContext] = None
    ) -> ToolResult:
        """
        Get the current browser state as a ToolResult.
        If context is not provided, uses self.context.
        """
        try:
            # Use provided context or fall back to self.context
            ctx = context or self.context
            if not ctx:
                return ToolResult(error="Browser context not initialized")

            state = await ctx.get_state()

            # Create a viewport_info dictionary if it doesn't exist
            viewport_height = 0
            if hasattr(state, "viewport_info") and state.viewport_info:
                viewport_height = state.viewport_info.height
            elif hasattr(ctx, "config") and hasattr(ctx.config, "browser_window_size"):
                viewport_height = ctx.config.browser_window_size.get("height", 0)

            # Take a screenshot for the state
            page = await ctx.get_current_page()

            await page.bring_to_front()
            await page.wait_for_load_state()

            screenshot = await page.screenshot(
                full_page=True, animations="disabled", type="jpeg", quality=100
            )

            screenshot = base64.b64encode(screenshot).decode("utf-8")

            # Build the state info with all required fields
            state_info = {
                "url": state.url,
                "title": state.title,
                "tabs": [tab.model_dump() for tab in state.tabs],
                "help": "[0], [1], [2], etc., represent clickable indices corresponding to the elements listed. Clicking on these indices will navigate to or interact with the respective content behind them.",
                "interactive_elements": (
                    state.element_tree.clickable_elements_to_string()
                    if state.element_tree
                    else ""
                ),
                "scroll_info": {
                    "pixels_above": getattr(state, "pixels_above", 0),
                    "pixels_below": getattr(state, "pixels_below", 0),
                    "total_height": getattr(state, "pixels_above", 0)
                    + getattr(state, "pixels_below", 0)
                    + viewport_height,
                },
                "viewport_height": viewport_height,
            }

            return ToolResult(
                output=json.dumps(state_info, indent=4, ensure_ascii=False),
                base64_image=screenshot,
            )
        except Exception as e:
            return ToolResult(error=f"Failed to get browser state: {str(e)}")
    async def get_current_state(self) -> ToolResult:
        """Get the current browser state as a ToolResult."""
        async with self.lock:
            try:
                context = await self._ensure_browser_initialized()
                state = await context.get_state()
                state_info = {
                    "url": state.url,
                    "title": state.title,
                    "tabs": [tab.model_dump() for tab in state.tabs],
                    "interactive_elements": state.element_tree.clickable_elements_to_string(),
                }
                return ToolResult(output=json.dumps(state_info))
            except Exception as e:
                return ToolResult(error=f"Failed to get browser state: {str(e)}")

    async def cleanup(self):
        """Clean up browser resources."""
@@ -612,10 +323,3 @@ Page content:
            loop = asyncio.new_event_loop()
            loop.run_until_complete(self.cleanup())
            loop.close()

    @classmethod
    def create_with_context(cls, context: Context) -> "BrowserUseTool[Context]":
        """Factory method to create a BrowserUseTool with a specific context."""
        tool = cls()
        tool.tool_context = context
        return tool
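A minimal sketch of invoking the simplified BrowserUseTool from the new side of this diff; it assumes a working browser-use backend is installed and configured.

import asyncio


async def demo_browser() -> None:
    tool = BrowserUseTool()
    result = await tool.execute(action="navigate", url="https://example.com")
    print(result.output or result.error)
    await tool.cleanup()


asyncio.run(demo_browser())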
@@ -1,158 +0,0 @@
"""File operation interfaces and implementations for local and sandbox environments."""

import asyncio
from pathlib import Path
from typing import Optional, Protocol, Tuple, Union, runtime_checkable

from app.config import SandboxSettings
from app.exceptions import ToolError
from app.sandbox.client import SANDBOX_CLIENT


PathLike = Union[str, Path]


@runtime_checkable
class FileOperator(Protocol):
    """Interface for file operations in different environments."""

    async def read_file(self, path: PathLike) -> str:
        """Read content from a file."""
        ...

    async def write_file(self, path: PathLike, content: str) -> None:
        """Write content to a file."""
        ...

    async def is_directory(self, path: PathLike) -> bool:
        """Check if path points to a directory."""
        ...

    async def exists(self, path: PathLike) -> bool:
        """Check if path exists."""
        ...

    async def run_command(
        self, cmd: str, timeout: Optional[float] = 120.0
    ) -> Tuple[int, str, str]:
        """Run a shell command and return (return_code, stdout, stderr)."""
        ...


class LocalFileOperator(FileOperator):
    """File operations implementation for local filesystem."""

    encoding: str = "utf-8"

    async def read_file(self, path: PathLike) -> str:
        """Read content from a local file."""
        try:
            return Path(path).read_text(encoding=self.encoding)
        except Exception as e:
            raise ToolError(f"Failed to read {path}: {str(e)}") from None

    async def write_file(self, path: PathLike, content: str) -> None:
        """Write content to a local file."""
        try:
            Path(path).write_text(content, encoding=self.encoding)
        except Exception as e:
            raise ToolError(f"Failed to write to {path}: {str(e)}") from None

    async def is_directory(self, path: PathLike) -> bool:
        """Check if path points to a directory."""
        return Path(path).is_dir()

    async def exists(self, path: PathLike) -> bool:
        """Check if path exists."""
        return Path(path).exists()

    async def run_command(
        self, cmd: str, timeout: Optional[float] = 120.0
    ) -> Tuple[int, str, str]:
        """Run a shell command locally."""
        process = await asyncio.create_subprocess_shell(
            cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )

        try:
            stdout, stderr = await asyncio.wait_for(
                process.communicate(), timeout=timeout
            )
            return (
                process.returncode or 0,
                stdout.decode(),
                stderr.decode(),
            )
        except asyncio.TimeoutError as exc:
            try:
                process.kill()
            except ProcessLookupError:
                pass
            raise TimeoutError(
                f"Command '{cmd}' timed out after {timeout} seconds"
            ) from exc


class SandboxFileOperator(FileOperator):
    """File operations implementation for sandbox environment."""

    def __init__(self):
        self.sandbox_client = SANDBOX_CLIENT

    async def _ensure_sandbox_initialized(self):
        """Ensure sandbox is initialized."""
        if not self.sandbox_client.sandbox:
            await self.sandbox_client.create(config=SandboxSettings())

    async def read_file(self, path: PathLike) -> str:
        """Read content from a file in sandbox."""
        await self._ensure_sandbox_initialized()
        try:
            return await self.sandbox_client.read_file(str(path))
        except Exception as e:
            raise ToolError(f"Failed to read {path} in sandbox: {str(e)}") from None

    async def write_file(self, path: PathLike, content: str) -> None:
        """Write content to a file in sandbox."""
        await self._ensure_sandbox_initialized()
        try:
            await self.sandbox_client.write_file(str(path), content)
        except Exception as e:
            raise ToolError(f"Failed to write to {path} in sandbox: {str(e)}") from None

    async def is_directory(self, path: PathLike) -> bool:
        """Check if path points to a directory in sandbox."""
        await self._ensure_sandbox_initialized()
        result = await self.sandbox_client.run_command(
            f"test -d {path} && echo 'true' || echo 'false'"
        )
        return result.strip() == "true"

    async def exists(self, path: PathLike) -> bool:
        """Check if path exists in sandbox."""
        await self._ensure_sandbox_initialized()
        result = await self.sandbox_client.run_command(
            f"test -e {path} && echo 'true' || echo 'false'"
        )
        return result.strip() == "true"

    async def run_command(
        self, cmd: str, timeout: Optional[float] = 120.0
    ) -> Tuple[int, str, str]:
        """Run a command in sandbox environment."""
        await self._ensure_sandbox_initialized()
        try:
            stdout = await self.sandbox_client.run_command(
                cmd, timeout=int(timeout) if timeout else None
            )
            return (
                0,  # Always return 0 since we don't have explicit return code from sandbox
                stdout,
                "",  # No stderr capture in the current sandbox implementation
            )
        except TimeoutError as exc:
            raise TimeoutError(
                f"Command '{cmd}' timed out after {timeout} seconds in sandbox"
            ) from exc
        except Exception as exc:
            return 1, "", f"Error executing command in sandbox: {str(exc)}"
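A small sketch of the FileOperator protocol in use; LocalFileOperator targets the local filesystem, so this runs without a sandbox (the /tmp path is illustrative).

import asyncio


async def demo_file_operator() -> None:
    operator = LocalFileOperator()
    await operator.write_file("/tmp/demo.txt", "hello")
    print(await operator.read_file("/tmp/demo.txt"))
    code, out, err = await operator.run_command("echo done", timeout=10.0)
    print(code, out.strip())


asyncio.run(demo_file_operator())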
115 app/tool/mcp.py
@@ -1,115 +0,0 @@
from contextlib import AsyncExitStack
from typing import List, Optional

from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import TextContent

from app.logger import logger
from app.tool.base import BaseTool, ToolResult
from app.tool.tool_collection import ToolCollection


class MCPClientTool(BaseTool):
    """Represents a tool proxy that can be called on the MCP server from the client side."""

    session: Optional[ClientSession] = None

    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool by making a remote call to the MCP server."""
        if not self.session:
            return ToolResult(error="Not connected to MCP server")

        try:
            result = await self.session.call_tool(self.name, kwargs)
            content_str = ", ".join(
                item.text for item in result.content if isinstance(item, TextContent)
            )
            return ToolResult(output=content_str or "No output returned.")
        except Exception as e:
            return ToolResult(error=f"Error executing tool: {str(e)}")


class MCPClients(ToolCollection):
    """
    A collection of tools that connects to an MCP server and manages available tools through the Model Context Protocol.
    """

    session: Optional[ClientSession] = None
    exit_stack: AsyncExitStack = None
    description: str = "MCP client tools for server interaction"

    def __init__(self):
        super().__init__()  # Initialize with empty tools list
        self.name = "mcp"  # Keep name for backward compatibility
        self.exit_stack = AsyncExitStack()

    async def connect_sse(self, server_url: str) -> None:
        """Connect to an MCP server using SSE transport."""
        if not server_url:
            raise ValueError("Server URL is required.")
        if self.session:
            await self.disconnect()

        streams_context = sse_client(url=server_url)
        streams = await self.exit_stack.enter_async_context(streams_context)
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(*streams)
        )

        await self._initialize_and_list_tools()

    async def connect_stdio(self, command: str, args: List[str]) -> None:
        """Connect to an MCP server using stdio transport."""
        if not command:
            raise ValueError("Server command is required.")
        if self.session:
            await self.disconnect()

        server_params = StdioServerParameters(command=command, args=args)
        stdio_transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        read, write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(read, write)
        )

        await self._initialize_and_list_tools()

    async def _initialize_and_list_tools(self) -> None:
        """Initialize session and populate tool map."""
        if not self.session:
            raise RuntimeError("Session not initialized.")

        await self.session.initialize()
        response = await self.session.list_tools()

        # Clear existing tools
        self.tools = tuple()
        self.tool_map = {}

        # Create proper tool objects for each server tool
        for tool in response.tools:
            server_tool = MCPClientTool(
                name=tool.name,
                description=tool.description,
                parameters=tool.inputSchema,
                session=self.session,
            )
            self.tool_map[tool.name] = server_tool
|
||||
|
||||
self.tools = tuple(self.tool_map.values())
|
||||
logger.info(
|
||||
f"Connected to server with tools: {[tool.name for tool in response.tools]}"
|
||||
)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the MCP server and clean up resources."""
|
||||
if self.session and self.exit_stack:
|
||||
await self.exit_stack.aclose()
|
||||
self.session = None
|
||||
self.tools = tuple()
|
||||
self.tool_map = {}
|
||||
logger.info("Disconnected from MCP server")
|
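The connect/disconnect lifecycle above leans entirely on `AsyncExitStack`: each transport and session is entered on the stack, so `disconnect` only has to call `aclose()` and everything unwinds in reverse order. A runnable sketch of that pattern with stand-in context managers (not the real `mcp` transports):

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def fake_transport(name: str):
    # Stand-in for sse_client/stdio_client: yields a "stream" and
    # guarantees cleanup when the stack closes.
    print(f"open {name}")
    try:
        yield f"{name}-stream"
    finally:
        print(f"close {name}")


async def main() -> None:
    stack = AsyncExitStack()
    stream = await stack.enter_async_context(fake_transport("transport"))
    session = await stack.enter_async_context(fake_transport(f"session({stream})"))
    # ... use the session ...
    await stack.aclose()  # closes the session first, then the transport (LIFO)


asyncio.run(main())
```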
43 app/tool/run.py Normal file
@ -0,0 +1,43 @@
|
||||
"""Utility to run shell commands asynchronously with a timeout."""
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
|
||||
MAX_RESPONSE_LEN: int = 16000
|
||||
|
||||
|
||||
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
|
||||
"""Truncate content and append a notice if content exceeds the specified length."""
|
||||
return (
|
||||
content
|
||||
if not truncate_after or len(content) <= truncate_after
|
||||
else content[:truncate_after] + TRUNCATED_MESSAGE
|
||||
)
|
||||
|
||||
|
||||
async def run(
|
||||
cmd: str,
|
||||
timeout: float | None = 120.0, # seconds
|
||||
truncate_after: int | None = MAX_RESPONSE_LEN,
|
||||
):
|
||||
"""Run a shell command asynchronously with a timeout."""
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
return (
|
||||
process.returncode or 0,
|
||||
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
|
||||
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
raise TimeoutError(
|
||||
f"Command '{cmd}' timed out after {timeout} seconds"
|
||||
) from exc
|
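A short usage sketch for the new `app/tool/run.py` module, assuming it is importable as shown and a POSIX shell is available:

```python
import asyncio

from app.tool.run import maybe_truncate, run  # assumes the module above is on the path


async def demo() -> None:
    code, out, err = await run("printf 'hello\\n'", timeout=10.0)
    print(code, out.strip(), err)
    # Output longer than the limit is clipped and tagged with TRUNCATED_MESSAGE:
    print(maybe_truncate("x" * 20, truncate_after=10))


asyncio.run(demo())
```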
@ -1,6 +1,5 @@
|
||||
from app.tool.search.baidu_search import BaiduSearchEngine
|
||||
from app.tool.search.base import WebSearchEngine
|
||||
from app.tool.search.bing_search import BingSearchEngine
|
||||
from app.tool.search.duckduckgo_search import DuckDuckGoSearchEngine
|
||||
from app.tool.search.google_search import GoogleSearchEngine
|
||||
|
||||
@ -10,5 +9,4 @@ __all__ = [
|
||||
"BaiduSearchEngine",
|
||||
"DuckDuckGoSearchEngine",
|
||||
"GoogleSearchEngine",
|
||||
"BingSearchEngine",
|
||||
]
|
||||
|
@ -1,146 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from app.logger import logger
|
||||
from app.tool.search.base import WebSearchEngine
|
||||
|
||||
|
||||
ABSTRACT_MAX_LENGTH = 300
|
||||
|
||||
USER_AGENTS = [
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/68.0.3440.106 Safari/537.36",
|
||||
"Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)",
|
||||
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Ubuntu Chromium/49.0.2623.108 Chrome/49.0.2623.108 Safari/537.36",
|
||||
"Mozilla/5.0 (Windows; U; Windows NT 5.1; pt-BR) AppleWebKit/533.3 (KHTML, like Gecko) QtWeb Internet Browser/3.7 http://www.QtWeb.net",
|
||||
"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36",
|
||||
"Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US) AppleWebKit/532.2 (KHTML, like Gecko) ChromePlus/4.0.222.3 Chrome/4.0.222.3 Safari/532.2",
|
||||
"Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.8.1.4pre) Gecko/20070404 K-Ninja/2.1.3",
|
||||
"Mozilla/5.0 (Future Star Technologies Corp.; Star-Blade OS; x86_64; U; en-US) iNet Browser 4.7",
|
||||
"Mozilla/5.0 (Windows; U; Windows NT 6.1; rv:2.2) Gecko/20110201",
|
||||
"Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.8.1.13) Gecko/20080414 Firefox/2.0.0.13 Pogo/2.0.0.13.6866",
|
||||
]
|
||||
|
||||
HEADERS = {
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": USER_AGENTS[0],
|
||||
"Referer": "https://www.bing.com/",
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
"Accept-Language": "zh-CN,zh;q=0.9",
|
||||
}
|
||||
|
||||
BING_HOST_URL = "https://www.bing.com"
|
||||
BING_SEARCH_URL = "https://www.bing.com/search?q="
|
||||
|
||||
|
||||
class BingSearchEngine(WebSearchEngine):
|
||||
session: requests.Session = None
|
||||
|
||||
def __init__(self, **data):
|
||||
"""Initialize the BingSearch tool with a requests session."""
|
||||
super().__init__(**data)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update(HEADERS)
|
||||
|
||||
def _search_sync(self, query: str, num_results: int = 10) -> List[str]:
|
||||
"""
|
||||
Synchronous Bing search implementation to retrieve a list of URLs matching a query.
|
||||
|
||||
Args:
|
||||
query (str): The search query to submit to Bing. Must not be empty.
|
||||
num_results (int, optional): The maximum number of URLs to return. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of URLs from the search results, capped at `num_results`.
|
||||
Returns an empty list if the query is empty or no results are found.
|
||||
|
||||
Notes:
|
||||
- Pagination is handled by incrementing the `first` parameter and following `next_url` links.
|
||||
- If fewer results than `num_results` are available, all found URLs are returned.
|
||||
"""
|
||||
if not query:
|
||||
return []
|
||||
|
||||
list_result = []
|
||||
first = 1
|
||||
next_url = BING_SEARCH_URL + query
|
||||
|
||||
while len(list_result) < num_results:
|
||||
data, next_url = self._parse_html(
|
||||
next_url, rank_start=len(list_result), first=first
|
||||
)
|
||||
if data:
|
||||
list_result.extend([item["url"] for item in data])
|
||||
if not next_url:
|
||||
break
|
||||
first += 10
|
||||
|
||||
return list_result[:num_results]
|
||||
|
||||
def _parse_html(self, url: str, rank_start: int = 0, first: int = 1) -> tuple:
|
||||
"""
|
||||
Parse Bing search result HTML synchronously to extract search results and the next page URL.
|
||||
|
||||
Args:
|
||||
url (str): The URL of the Bing search results page to parse.
|
||||
rank_start (int, optional): The starting rank for numbering the search results. Defaults to 0.
|
||||
first (int, optional): Unused parameter (possibly legacy). Defaults to 1.
|
||||
Returns:
|
||||
tuple: A tuple containing:
|
||||
- list: A list of dictionaries with keys 'title', 'abstract', 'url', and 'rank' for each result.
|
||||
- str or None: The URL of the next results page, or None if there is no next page.
|
||||
"""
|
||||
try:
|
||||
res = self.session.get(url=url)
|
||||
res.encoding = "utf-8"
|
||||
root = BeautifulSoup(res.text, "lxml")
|
||||
|
||||
list_data = []
|
||||
ol_results = root.find("ol", id="b_results")
|
||||
if not ol_results:
|
||||
return [], None
|
||||
|
||||
for li in ol_results.find_all("li", class_="b_algo"):
|
||||
title = ""
|
||||
url = ""
|
||||
abstract = ""
|
||||
try:
|
||||
h2 = li.find("h2")
|
||||
if h2:
|
||||
title = h2.text.strip()
|
||||
url = h2.a["href"].strip()
|
||||
|
||||
p = li.find("p")
|
||||
if p:
|
||||
abstract = p.text.strip()
|
||||
|
||||
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
|
||||
abstract = abstract[:ABSTRACT_MAX_LENGTH]
|
||||
|
||||
rank_start += 1
|
||||
list_data.append(
|
||||
{
|
||||
"title": title,
|
||||
"abstract": abstract,
|
||||
"url": url,
|
||||
"rank": rank_start,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
next_btn = root.find("a", title="Next page")
|
||||
if not next_btn:
|
||||
return list_data, None
|
||||
|
||||
next_url = BING_HOST_URL + next_btn["href"]
|
||||
return list_data, next_url
|
||||
except Exception as e:
|
||||
logger.warning(f"Error parsing HTML: {e}")
|
||||
return [], None
|
||||
|
||||
def perform_search(self, query, num_results=10, *args, **kwargs):
|
||||
"""Bing search engine."""
|
||||
return self._search_sync(query, num_results=num_results)
|
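The selectors used by `_parse_html` can be exercised offline. A sketch that feeds a fabricated, Bing-shaped HTML snippet to BeautifulSoup (requires `beautifulsoup4`; the stdlib `html.parser` stands in for `lxml` here):

```python
from bs4 import BeautifulSoup

html = """
<ol id="b_results">
  <li class="b_algo"><h2><a href="https://example.com">Example</a></h2>
      <p>An example abstract.</p></li>
</ol>
"""

root = BeautifulSoup(html, "html.parser")
results = []
for rank, li in enumerate(root.find("ol", id="b_results").find_all("li", class_="b_algo"), start=1):
    h2 = li.find("h2")
    p = li.find("p")
    results.append(
        {
            "title": h2.text.strip(),
            "url": h2.a["href"].strip(),
            "abstract": p.text.strip() if p else "",
            "rank": rank,
        }
    )
print(results)
```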
@ -1,19 +1,11 @@
|
||||
"""File and directory manipulation tool with sandbox support."""
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, DefaultDict, List, Literal, Optional, get_args
|
||||
from typing import Literal, get_args
|
||||
|
||||
from app.config import config
|
||||
from app.exceptions import ToolError
|
||||
from app.tool import BaseTool
|
||||
from app.tool.base import CLIResult, ToolResult
|
||||
from app.tool.file_operators import (
|
||||
FileOperator,
|
||||
LocalFileOperator,
|
||||
PathLike,
|
||||
SandboxFileOperator,
|
||||
)
|
||||
from app.tool.run import run
|
||||
|
||||
|
||||
Command = Literal[
|
||||
@ -23,17 +15,12 @@ Command = Literal[
|
||||
"insert",
|
||||
"undo_edit",
|
||||
]
|
||||
|
||||
# Constants
|
||||
SNIPPET_LINES: int = 4
|
||||
MAX_RESPONSE_LEN: int = 16000
|
||||
TRUNCATED_MESSAGE: str = (
|
||||
"<response clipped><NOTE>To save on context only part of this file has been shown to you. "
|
||||
"You should retry this tool after you have searched inside the file with `grep -n` "
|
||||
"in order to find the line numbers of what you are looking for.</NOTE>"
|
||||
)
|
||||
|
||||
# Tool description
|
||||
MAX_RESPONSE_LEN: int = 16000
|
||||
|
||||
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
|
||||
|
||||
_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files
|
||||
* State is persistent across command calls and discussions with the user
|
||||
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
|
||||
@ -48,17 +35,17 @@ Notes for using the `str_replace` command:
|
||||
"""
|
||||
|
||||
|
||||
def maybe_truncate(
|
||||
content: str, truncate_after: Optional[int] = MAX_RESPONSE_LEN
|
||||
) -> str:
|
||||
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
|
||||
"""Truncate content and append a notice if content exceeds the specified length."""
|
||||
if not truncate_after or len(content) <= truncate_after:
|
||||
return content
|
||||
return content[:truncate_after] + TRUNCATED_MESSAGE
|
||||
return (
|
||||
content
|
||||
if not truncate_after or len(content) <= truncate_after
|
||||
else content[:truncate_after] + TRUNCATED_MESSAGE
|
||||
)
|
||||
|
||||
|
||||
class StrReplaceEditor(BaseTool):
|
||||
"""A tool for viewing, creating, and editing files with sandbox support."""
|
||||
"""A tool for executing bash commands"""
|
||||
|
||||
name: str = "str_replace_editor"
|
||||
description: str = _STR_REPLACE_EDITOR_DESCRIPTION
|
||||
@ -98,18 +85,8 @@ class StrReplaceEditor(BaseTool):
|
||||
},
|
||||
"required": ["command", "path"],
|
||||
}
|
||||
_file_history: DefaultDict[PathLike, List[str]] = defaultdict(list)
|
||||
_local_operator: LocalFileOperator = LocalFileOperator()
|
||||
_sandbox_operator: SandboxFileOperator = SandboxFileOperator()
|
||||
|
||||
# def _get_operator(self, use_sandbox: bool) -> FileOperator:
|
||||
def _get_operator(self) -> FileOperator:
|
||||
"""Get the appropriate file operator based on execution mode."""
|
||||
return (
|
||||
self._sandbox_operator
|
||||
if config.sandbox.use_sandbox
|
||||
else self._local_operator
|
||||
)
|
||||
_file_history: list = defaultdict(list)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
@ -121,30 +98,24 @@ class StrReplaceEditor(BaseTool):
|
||||
old_str: str | None = None,
|
||||
new_str: str | None = None,
|
||||
insert_line: int | None = None,
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Execute a file operation command."""
|
||||
# Get the appropriate file operator
|
||||
operator = self._get_operator()
|
||||
|
||||
# Validate path and command combination
|
||||
await self.validate_path(command, Path(path), operator)
|
||||
|
||||
# Execute the appropriate command
|
||||
_path = Path(path)
|
||||
self.validate_path(command, _path)
|
||||
if command == "view":
|
||||
result = await self.view(path, view_range, operator)
|
||||
result = await self.view(_path, view_range)
|
||||
elif command == "create":
|
||||
if file_text is None:
|
||||
raise ToolError("Parameter `file_text` is required for command: create")
|
||||
await operator.write_file(path, file_text)
|
||||
self._file_history[path].append(file_text)
|
||||
result = ToolResult(output=f"File created successfully at: {path}")
|
||||
self.write_file(_path, file_text)
|
||||
self._file_history[_path].append(file_text)
|
||||
result = ToolResult(output=f"File created successfully at: {_path}")
|
||||
elif command == "str_replace":
|
||||
if old_str is None:
|
||||
raise ToolError(
|
||||
"Parameter `old_str` is required for command: str_replace"
|
||||
)
|
||||
result = await self.str_replace(path, old_str, new_str, operator)
|
||||
result = self.str_replace(_path, old_str, new_str)
|
||||
elif command == "insert":
|
||||
if insert_line is None:
|
||||
raise ToolError(
|
||||
@ -152,145 +123,92 @@ class StrReplaceEditor(BaseTool):
|
||||
)
|
||||
if new_str is None:
|
||||
raise ToolError("Parameter `new_str` is required for command: insert")
|
||||
result = await self.insert(path, insert_line, new_str, operator)
|
||||
result = self.insert(_path, insert_line, new_str)
|
||||
elif command == "undo_edit":
|
||||
result = await self.undo_edit(path, operator)
|
||||
result = self.undo_edit(_path)
|
||||
else:
|
||||
# This should be caught by type checking, but we include it for safety
|
||||
raise ToolError(
|
||||
f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
|
||||
)
|
||||
|
||||
return str(result)
|
||||
|
||||
async def validate_path(
|
||||
self, command: str, path: Path, operator: FileOperator
|
||||
) -> None:
|
||||
"""Validate path and command combination based on execution environment."""
|
||||
# Check if path is absolute
|
||||
def validate_path(self, command: str, path: Path):
|
||||
"""
|
||||
Check that the path/command combination is valid.
|
||||
"""
|
||||
# Check if it's an absolute path
|
||||
if not path.is_absolute():
|
||||
raise ToolError(f"The path {path} is not an absolute path")
|
||||
|
||||
# Only check if path exists for non-create commands
|
||||
if command != "create":
|
||||
if not await operator.exists(path):
|
||||
raise ToolError(
|
||||
f"The path {path} does not exist. Please provide a valid path."
|
||||
)
|
||||
|
||||
# Check if path is a directory
|
||||
is_dir = await operator.is_directory(path)
|
||||
if is_dir and command != "view":
|
||||
suggested_path = Path("") / path
|
||||
raise ToolError(
|
||||
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
|
||||
)
|
||||
# Check if path exists
|
||||
if not path.exists() and command != "create":
|
||||
raise ToolError(
|
||||
f"The path {path} does not exist. Please provide a valid path."
|
||||
)
|
||||
if path.exists() and command == "create":
|
||||
raise ToolError(
|
||||
f"File already exists at: {path}. Cannot overwrite files using command `create`."
|
||||
)
|
||||
# Check if the path points to a directory
|
||||
if path.is_dir():
|
||||
if command != "view":
|
||||
raise ToolError(
|
||||
f"The path {path} is a directory and only the `view` command can be used on directories"
|
||||
)
|
||||
|
||||
# Check if file exists for create command
|
||||
elif command == "create":
|
||||
exists = await operator.exists(path)
|
||||
if exists:
|
||||
raise ToolError(
|
||||
f"File already exists at: {path}. Cannot overwrite files using command `create`."
|
||||
)
|
||||
|
||||
async def view(
|
||||
self,
|
||||
path: PathLike,
|
||||
view_range: Optional[List[int]] = None,
|
||||
operator: FileOperator = None,
|
||||
) -> CLIResult:
|
||||
"""Display file or directory content."""
|
||||
# Determine if path is a directory
|
||||
is_dir = await operator.is_directory(path)
|
||||
|
||||
if is_dir:
|
||||
# Directory handling
|
||||
async def view(self, path: Path, view_range: list[int] | None = None):
|
||||
"""Implement the view command"""
|
||||
if path.is_dir():
|
||||
if view_range:
|
||||
raise ToolError(
|
||||
"The `view_range` parameter is not allowed when `path` points to a directory."
|
||||
)
|
||||
|
||||
return await self._view_directory(path, operator)
|
||||
else:
|
||||
# File handling
|
||||
return await self._view_file(path, operator, view_range)
|
||||
|
||||
@staticmethod
|
||||
async def _view_directory(path: PathLike, operator: FileOperator) -> CLIResult:
|
||||
"""Display directory contents."""
|
||||
find_cmd = f"find {path} -maxdepth 2 -not -path '*/\\.*'"
|
||||
|
||||
# Execute command using the operator
|
||||
returncode, stdout, stderr = await operator.run_command(find_cmd)
|
||||
|
||||
if not stderr:
|
||||
stdout = (
|
||||
f"Here's the files and directories up to 2 levels deep in {path}, "
|
||||
f"excluding hidden items:\n{stdout}\n"
|
||||
_, stdout, stderr = await run(
|
||||
rf"find {path} -maxdepth 2 -not -path '*/\.*'"
|
||||
)
|
||||
if not stderr:
|
||||
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
|
||||
return CLIResult(output=stdout, error=stderr)
|
||||
|
||||
return CLIResult(output=stdout, error=stderr)
|
||||
|
||||
async def _view_file(
|
||||
self,
|
||||
path: PathLike,
|
||||
operator: FileOperator,
|
||||
view_range: Optional[List[int]] = None,
|
||||
) -> CLIResult:
|
||||
"""Display file content, optionally within a specified line range."""
|
||||
# Read file content
|
||||
file_content = await operator.read_file(path)
|
||||
file_content = self.read_file(path)
|
||||
init_line = 1
|
||||
|
||||
# Apply view range if specified
|
||||
if view_range:
|
||||
if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
|
||||
raise ToolError(
|
||||
"Invalid `view_range`. It should be a list of two integers."
|
||||
)
|
||||
|
||||
file_lines = file_content.split("\n")
|
||||
n_lines_file = len(file_lines)
|
||||
init_line, final_line = view_range
|
||||
|
||||
# Validate view range
|
||||
if init_line < 1 or init_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be "
|
||||
f"within the range of lines of the file: {[1, n_lines_file]}"
|
||||
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
|
||||
)
|
||||
if final_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be "
|
||||
f"smaller than the number of lines in the file: `{n_lines_file}`"
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
|
||||
)
|
||||
if final_line != -1 and final_line < init_line:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be "
|
||||
f"larger or equal than its first `{init_line}`"
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
|
||||
)
|
||||
|
||||
# Apply range
|
||||
if final_line == -1:
|
||||
file_content = "\n".join(file_lines[init_line - 1 :])
|
||||
else:
|
||||
file_content = "\n".join(file_lines[init_line - 1 : final_line])
|
||||
|
||||
# Format and return result
|
||||
return CLIResult(
|
||||
output=self._make_output(file_content, str(path), init_line=init_line)
|
||||
)
|
||||
|
||||
async def str_replace(
|
||||
self,
|
||||
path: PathLike,
|
||||
old_str: str,
|
||||
new_str: Optional[str] = None,
|
||||
operator: FileOperator = None,
|
||||
) -> CLIResult:
|
||||
"""Replace a unique string in a file with a new string."""
|
||||
# Read file content and expand tabs
|
||||
file_content = (await operator.read_file(path)).expandtabs()
|
||||
def str_replace(self, path: Path, old_str: str, new_str: str | None):
|
||||
"""Implement the str_replace command, which replaces old_str with new_str in the file content"""
|
||||
# Read the file content
|
||||
file_content = self.read_file(path).expandtabs()
|
||||
old_str = old_str.expandtabs()
|
||||
new_str = new_str.expandtabs() if new_str is not None else ""
|
||||
|
||||
@ -301,7 +219,6 @@ class StrReplaceEditor(BaseTool):
|
||||
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
|
||||
)
|
||||
elif occurrences > 1:
|
||||
# Find line numbers of occurrences
|
||||
file_content_lines = file_content.split("\n")
|
||||
lines = [
|
||||
idx + 1
|
||||
@ -309,17 +226,16 @@ class StrReplaceEditor(BaseTool):
|
||||
if old_str in line
|
||||
]
|
||||
raise ToolError(
|
||||
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` "
|
||||
f"in lines {lines}. Please ensure it is unique"
|
||||
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
|
||||
)
|
||||
|
||||
# Replace old_str with new_str
|
||||
new_file_content = file_content.replace(old_str, new_str)
|
||||
|
||||
# Write the new content to the file
|
||||
await operator.write_file(path, new_file_content)
|
||||
self.write_file(path, new_file_content)
|
||||
|
||||
# Save the original content to history
|
||||
# Save the content to history
|
||||
self._file_history[path].append(file_content)
|
||||
|
||||
# Create a snippet of the edited section
|
||||
@ -337,50 +253,36 @@ class StrReplaceEditor(BaseTool):
|
||||
|
||||
return CLIResult(output=success_msg)
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
path: PathLike,
|
||||
insert_line: int,
|
||||
new_str: str,
|
||||
operator: FileOperator = None,
|
||||
) -> CLIResult:
|
||||
"""Insert text at a specific line in a file."""
|
||||
# Read and prepare content
|
||||
file_text = (await operator.read_file(path)).expandtabs()
|
||||
def insert(self, path: Path, insert_line: int, new_str: str):
|
||||
"""Implement the insert command, which inserts new_str at the specified line in the file content."""
|
||||
file_text = self.read_file(path).expandtabs()
|
||||
new_str = new_str.expandtabs()
|
||||
file_text_lines = file_text.split("\n")
|
||||
n_lines_file = len(file_text_lines)
|
||||
|
||||
# Validate insert_line
|
||||
if insert_line < 0 or insert_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `insert_line` parameter: {insert_line}. It should be within "
|
||||
f"the range of lines of the file: {[0, n_lines_file]}"
|
||||
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
|
||||
)
|
||||
|
||||
# Perform insertion
|
||||
new_str_lines = new_str.split("\n")
|
||||
new_file_text_lines = (
|
||||
file_text_lines[:insert_line]
|
||||
+ new_str_lines
|
||||
+ file_text_lines[insert_line:]
|
||||
)
|
||||
|
||||
# Create a snippet for preview
|
||||
snippet_lines = (
|
||||
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
|
||||
+ new_str_lines
|
||||
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
|
||||
)
|
||||
|
||||
# Join lines and write to file
|
||||
new_file_text = "\n".join(new_file_text_lines)
|
||||
snippet = "\n".join(snippet_lines)
|
||||
|
||||
await operator.write_file(path, new_file_text)
|
||||
self.write_file(path, new_file_text)
|
||||
self._file_history[path].append(file_text)
|
||||
|
||||
# Prepare success message
|
||||
success_msg = f"The file {path} has been edited. "
|
||||
success_msg += self._make_output(
|
||||
snippet,
|
||||
@ -388,43 +290,51 @@ class StrReplaceEditor(BaseTool):
|
||||
max(1, insert_line - SNIPPET_LINES + 1),
|
||||
)
|
||||
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
|
||||
|
||||
return CLIResult(output=success_msg)
|
||||
|
||||
async def undo_edit(
|
||||
self, path: PathLike, operator: FileOperator = None
|
||||
) -> CLIResult:
|
||||
"""Revert the last edit made to a file."""
|
||||
def undo_edit(self, path: Path):
|
||||
"""Implement the undo_edit command."""
|
||||
if not self._file_history[path]:
|
||||
raise ToolError(f"No edit history found for {path}.")
|
||||
|
||||
old_text = self._file_history[path].pop()
|
||||
await operator.write_file(path, old_text)
|
||||
self.write_file(path, old_text)
|
||||
|
||||
return CLIResult(
|
||||
output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
|
||||
)
|
||||
|
||||
def read_file(self, path: Path):
|
||||
"""Read the content of a file from a given path; raise a ToolError if an error occurs."""
|
||||
try:
|
||||
return path.read_text()
|
||||
except Exception as e:
|
||||
raise ToolError(f"Ran into {e} while trying to read {path}") from None
|
||||
|
||||
def write_file(self, path: Path, file: str):
|
||||
"""Write the content of a file to a given path; raise a ToolError if an error occurs."""
|
||||
try:
|
||||
path.write_text(file)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Ran into {e} while trying to write to {path}") from None
|
||||
|
||||
def _make_output(
|
||||
self,
|
||||
file_content: str,
|
||||
file_descriptor: str,
|
||||
init_line: int = 1,
|
||||
expand_tabs: bool = True,
|
||||
) -> str:
|
||||
"""Format file content for display with line numbers."""
|
||||
):
|
||||
"""Generate output for the CLI based on the content of a file."""
|
||||
file_content = maybe_truncate(file_content)
|
||||
if expand_tabs:
|
||||
file_content = file_content.expandtabs()
|
||||
|
||||
# Add line numbers to each line
|
||||
file_content = "\n".join(
|
||||
[
|
||||
f"{i + init_line:6}\t{line}"
|
||||
for i, line in enumerate(file_content.split("\n"))
|
||||
]
|
||||
)
|
||||
|
||||
return (
|
||||
f"Here's the result of running `cat -n` on {file_descriptor}:\n"
|
||||
+ file_content
|
||||
|
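For reference, the uniqueness rule that both versions of `str_replace` enforce, distilled into a standalone function (a sketch of the semantics, not the tool's actual code path):

```python
def str_replace_once(text: str, old: str, new: str) -> str:
    """Replace old with new only if old occurs exactly once, as the editor requires."""
    occurrences = text.count(old)
    if occurrences == 0:
        raise ValueError(f"old_str {old!r} did not appear verbatim")
    if occurrences > 1:
        lines = [i + 1 for i, line in enumerate(text.split("\n")) if old in line]
        raise ValueError(f"multiple occurrences of {old!r} in lines {lines}; make it unique")
    return text.replace(old, new)


print(str_replace_once("a = 1\nb = 2\n", "b = 2", "b = 3"))
```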
@ -8,9 +8,6 @@ from app.tool.base import BaseTool, ToolFailure, ToolResult
|
||||
class ToolCollection:
|
||||
"""A collection of defined tools."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, *tools: BaseTool):
|
||||
self.tools = tools
|
||||
self.tool_map = {tool.name: tool for tool in tools}
|
||||
|
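The `ToolCollection` constructor above keeps two views of the same tools: an ordered tuple and a name index for O(1) lookup. A toy equivalent:

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class Tool:
    name: str


class Collection:
    def __init__(self, *tools: Tool) -> None:
        # Same shape as ToolCollection: an ordered tuple plus a name index.
        self.tools: Tuple[Tool, ...] = tools
        self.tool_map: Dict[str, Tool] = {tool.name: tool for tool in tools}


c = Collection(Tool("bash"), Tool("editor"))
print(c.tool_map["editor"].name)
```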
@ -4,11 +4,9 @@ from typing import List
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from app.config import config
|
||||
from app.logger import logger
|
||||
from app.tool.base import BaseTool
|
||||
from app.tool.search import (
|
||||
BaiduSearchEngine,
|
||||
BingSearchEngine,
|
||||
DuckDuckGoSearchEngine,
|
||||
GoogleSearchEngine,
|
||||
WebSearchEngine,
|
||||
@ -39,14 +37,11 @@ class WebSearch(BaseTool):
|
||||
"google": GoogleSearchEngine(),
|
||||
"baidu": BaiduSearchEngine(),
|
||||
"duckduckgo": DuckDuckGoSearchEngine(),
|
||||
"bing": BingSearchEngine(),
|
||||
}
|
||||
|
||||
async def execute(self, query: str, num_results: int = 10) -> List[str]:
|
||||
"""
|
||||
Execute a Web search and return a list of URLs.
|
||||
Tries engines in order based on configuration, falling back if an engine fails with errors.
|
||||
If all engines fail, it will wait and retry up to the configured number of times.
|
||||
|
||||
Args:
|
||||
query (str): The search query to submit to the search engine.
|
||||
@ -55,109 +50,37 @@ class WebSearch(BaseTool):
|
||||
Returns:
|
||||
List[str]: A list of URLs matching the search query.
|
||||
"""
|
||||
# Get retry settings from config
|
||||
retry_delay = 60 # Default to 60 seconds
|
||||
max_retries = 3 # Default to 3 retries
|
||||
|
||||
if config.search_config:
|
||||
retry_delay = getattr(config.search_config, "retry_delay", 60)
|
||||
max_retries = getattr(config.search_config, "max_retries", 3)
|
||||
|
||||
# Try searching with retries when all engines fail
|
||||
for retry_count in range(
|
||||
max_retries + 1
|
||||
): # +1 because first try is not a retry
|
||||
links = await self._try_all_engines(query, num_results)
|
||||
if links:
|
||||
return links
|
||||
|
||||
if retry_count < max_retries:
|
||||
# All engines failed, wait and retry
|
||||
logger.warning(
|
||||
f"All search engines failed. Waiting {retry_delay} seconds before retry {retry_count + 1}/{max_retries}..."
|
||||
)
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.error(
|
||||
f"All search engines failed after {max_retries} retries. Giving up."
|
||||
)
|
||||
|
||||
return []
|
||||
|
||||
async def _try_all_engines(self, query: str, num_results: int) -> List[str]:
|
||||
"""
|
||||
Try all search engines in the configured order.
|
||||
|
||||
Args:
|
||||
query (str): The search query to submit to the search engine.
|
||||
num_results (int): The number of search results to return.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of URLs matching the search query, or empty list if all engines fail.
|
||||
"""
|
||||
engine_order = self._get_engine_order()
|
||||
failed_engines = []
|
||||
|
||||
for engine_name in engine_order:
|
||||
engine = self._search_engine[engine_name]
|
||||
try:
|
||||
logger.info(f"🔎 Attempting search with {engine_name.capitalize()}...")
|
||||
links = await self._perform_search_with_engine(
|
||||
engine, query, num_results
|
||||
)
|
||||
if links:
|
||||
if failed_engines:
|
||||
logger.info(
|
||||
f"Search successful with {engine_name.capitalize()} after trying: {', '.join(failed_engines)}"
|
||||
)
|
||||
return links
|
||||
except Exception as e:
|
||||
failed_engines.append(engine_name.capitalize())
|
||||
is_rate_limit = "429" in str(e) or "Too Many Requests" in str(e)
|
||||
|
||||
if is_rate_limit:
|
||||
logger.warning(
|
||||
f"⚠️ {engine_name.capitalize()} search engine rate limit exceeded, trying next engine..."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"⚠️ {engine_name.capitalize()} search failed with error: {e}"
|
||||
)
|
||||
|
||||
if failed_engines:
|
||||
logger.error(f"All search engines failed: {', '.join(failed_engines)}")
|
||||
print(f"Search engine '{engine_name}' failed with error: {e}")
|
||||
return []
|
||||
|
||||
def _get_engine_order(self) -> List[str]:
|
||||
"""
|
||||
Determines the order in which to try search engines.
|
||||
Preferred engine is first (based on configuration), followed by fallback engines,
|
||||
and then the remaining engines.
|
||||
Preferred engine is first (based on configuration), followed by the remaining engines.
|
||||
|
||||
Returns:
|
||||
List[str]: Ordered list of search engine names.
|
||||
"""
|
||||
preferred = "google"
|
||||
fallbacks = []
|
||||
|
||||
if config.search_config:
|
||||
if config.search_config.engine:
|
||||
preferred = config.search_config.engine.lower()
|
||||
if config.search_config.fallback_engines:
|
||||
fallbacks = [
|
||||
engine.lower() for engine in config.search_config.fallback_engines
|
||||
]
|
||||
if config.search_config and config.search_config.engine:
|
||||
preferred = config.search_config.engine.lower()
|
||||
|
||||
engine_order = []
|
||||
# Add preferred engine first
|
||||
if preferred in self._search_engine:
|
||||
engine_order.append(preferred)
|
||||
|
||||
# Add configured fallback engines in order
|
||||
for fallback in fallbacks:
|
||||
if fallback in self._search_engine and fallback not in engine_order:
|
||||
engine_order.append(fallback)
|
||||
|
||||
for key in self._search_engine:
|
||||
if key not in engine_order:
|
||||
engine_order.append(key)
|
||||
return engine_order
|
||||
|
||||
@retry(
|
||||
|
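The engine-ordering logic in `_get_engine_order` (both the old and new variants above) reduces to "preferred engine first, then configured fallbacks, then everything else, skipping duplicates". A standalone sketch with hypothetical config values:

```python
from typing import Dict, List


def engine_order(available: Dict[str, object], preferred: str, fallbacks: List[str]) -> List[str]:
    # Preferred first, then configured fallbacks, then the rest -- no duplicates.
    order: List[str] = []
    for name in [preferred, *fallbacks, *available]:
        if name in available and name not in order:
            order.append(name)
    return order


engines = {"google": ..., "baidu": ..., "duckduckgo": ..., "bing": ...}
print(engine_order(engines, preferred="duckduckgo", fallbacks=["baidu"]))
# ['duckduckgo', 'baidu', 'google', 'bing']
```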
@ -6,14 +6,6 @@ api_key = "YOUR_API_KEY" # Your API key
|
||||
max_tokens = 8192 # Maximum number of tokens in the response
|
||||
temperature = 0.0 # Controls randomness
|
||||
|
||||
# [llm] # Amazon Bedrock
|
||||
# api_type = "aws" # Required
|
||||
# model = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" # Bedrock supported modelID
|
||||
# base_url = "bedrock-runtime.us-west-2.amazonaws.com" # Not used now
|
||||
# max_tokens = 8192
|
||||
# temperature = 1.0
|
||||
# api_key = "bear" # Required but not used for Bedrock
|
||||
|
||||
# [llm] #AZURE OPENAI:
|
||||
# api_type= 'azure'
|
||||
# model = "YOUR_MODEL_NAME" #"gpt-4o-mini"
|
||||
@ -73,20 +65,3 @@ temperature = 0.0 # Controls randomness for vision mod
|
||||
# [search]
|
||||
# Search engine for agent to use. Default is "Google", can be set to "Baidu" or "DuckDuckGo".
|
||||
#engine = "Google"
|
||||
# Fallback engine order. Default is ["DuckDuckGo", "Baidu"] - will try in this order after primary engine fails.
|
||||
#fallback_engines = ["DuckDuckGo", "Baidu"]
|
||||
# Seconds to wait before retrying all engines again when they all fail due to rate limits. Default is 60.
|
||||
#retry_delay = 60
|
||||
# Maximum number of times to retry all engines when all fail. Default is 3.
|
||||
#max_retries = 3
|
||||
|
||||
|
||||
## Sandbox configuration
|
||||
#[sandbox]
|
||||
#use_sandbox = false
|
||||
#image = "python:3.12-slim"
|
||||
#work_dir = "/workspace"
|
||||
#memory_limit = "1g" # 512m
|
||||
#cpu_limit = 2.0
|
||||
#timeout = 300
|
||||
#network_enabled = true
|
||||
|
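A sketch of reading the `[search]` and `[sandbox]` tables shown above with the standard-library `tomllib` (Python 3.11+; the `tomli` dependency listed below covers older interpreters). The config path is assumed:

```python
import tomllib  # Python 3.11+; use tomli on older interpreters

with open("config/config.toml", "rb") as f:  # path assumed
    cfg = tomllib.load(f)

search = cfg.get("search", {})
print(search.get("engine", "Google"))
print(search.get("retry_delay", 60), search.get("max_retries", 3))

sandbox = cfg.get("sandbox", {})
print(sandbox.get("use_sandbox", False), sandbox.get("image", "python:3.12-slim"))
```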
@ -20,16 +20,6 @@ baidusearch~=1.0.3
|
||||
duckduckgo_search~=7.5.1
|
||||
|
||||
aiofiles~=24.1.0
|
||||
pydantic_core~=2.27.2
|
||||
pydantic_core~=2.32.0
|
||||
colorama~=0.4.6
|
||||
playwright~=1.50.0
|
||||
|
||||
docker~=7.1.0
|
||||
pytest~=8.3.5
|
||||
pytest-asyncio~=0.25.3
|
||||
|
||||
mcp~=1.4.1
|
||||
httpx>=0.27.0
|
||||
tomli>=2.0.0
|
||||
|
||||
boto3~=1.37.16
|
||||
|
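On the `pydantic_core` pin change above: `~=` is a compatible-release specifier, so `~=2.27.2` means `>=2.27.2,<2.28.0` and excludes 2.32.x, while `~=2.32.0` admits it. This can be checked with the `packaging` library:

```python
from packaging.specifiers import SpecifierSet  # pip install packaging

old_pin, new_pin = SpecifierSet("~=2.27.2"), SpecifierSet("~=2.32.0")
print("2.32.0" in old_pin)  # False: ~=2.27.2 means >=2.27.2,<2.28.0
print("2.32.0" in new_pin)  # True:  ~=2.32.0 means >=2.32.0,<2.33.0
```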
116 run_mcp.py
@ -1,116 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from app.agent.mcp import MCPAgent
|
||||
from app.config import config
|
||||
from app.logger import logger
|
||||
|
||||
|
||||
class MCPRunner:
|
||||
"""Runner class for MCP Agent with proper path handling and configuration."""
|
||||
|
||||
def __init__(self):
|
||||
self.root_path = config.root_path
|
||||
self.server_reference = "app.mcp.server"
|
||||
self.agent = MCPAgent()
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
connection_type: str,
|
||||
server_url: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the MCP agent with the appropriate connection."""
|
||||
logger.info(f"Initializing MCPAgent with {connection_type} connection...")
|
||||
|
||||
if connection_type == "stdio":
|
||||
await self.agent.initialize(
|
||||
connection_type="stdio",
|
||||
command=sys.executable,
|
||||
args=["-m", self.server_reference],
|
||||
)
|
||||
else: # sse
|
||||
await self.agent.initialize(connection_type="sse", server_url=server_url)
|
||||
|
||||
logger.info(f"Connected to MCP server via {connection_type}")
|
||||
|
||||
async def run_interactive(self) -> None:
|
||||
"""Run the agent in interactive mode."""
|
||||
print("\nMCP Agent Interactive Mode (type 'exit' to quit)\n")
|
||||
while True:
|
||||
user_input = input("\nEnter your request: ")
|
||||
if user_input.lower() in ["exit", "quit", "q"]:
|
||||
break
|
||||
response = await self.agent.run(user_input)
|
||||
print(f"\nAgent: {response}")
|
||||
|
||||
async def run_single_prompt(self, prompt: str) -> None:
|
||||
"""Run the agent with a single prompt."""
|
||||
await self.agent.run(prompt)
|
||||
|
||||
async def run_default(self) -> None:
|
||||
"""Run the agent in default mode."""
|
||||
prompt = input("Enter your prompt: ")
|
||||
if not prompt.strip():
|
||||
logger.warning("Empty prompt provided.")
|
||||
return
|
||||
|
||||
logger.warning("Processing your request...")
|
||||
await self.agent.run(prompt)
|
||||
logger.info("Request processing completed.")
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Clean up agent resources."""
|
||||
await self.agent.cleanup()
|
||||
logger.info("Session ended")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(description="Run the MCP Agent")
|
||||
parser.add_argument(
|
||||
"--connection",
|
||||
"-c",
|
||||
choices=["stdio", "sse"],
|
||||
default="stdio",
|
||||
help="Connection type: stdio or sse",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server-url",
|
||||
default="http://127.0.0.1:8000/sse",
|
||||
help="URL for SSE connection",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interactive", "-i", action="store_true", help="Run in interactive mode"
|
||||
)
|
||||
parser.add_argument("--prompt", "-p", help="Single prompt to execute and exit")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def run_mcp() -> None:
|
||||
"""Main entry point for the MCP runner."""
|
||||
args = parse_args()
|
||||
runner = MCPRunner()
|
||||
|
||||
try:
|
||||
await runner.initialize(args.connection, args.server_url)
|
||||
|
||||
if args.prompt:
|
||||
await runner.run_single_prompt(args.prompt)
|
||||
elif args.interactive:
|
||||
await runner.run_interactive()
|
||||
else:
|
||||
await runner.run_default()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Program interrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error running MCPAgent: {str(e)}", exc_info=True)
|
||||
sys.exit(1)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run_mcp())
|
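The three run modes of `run_mcp.py` reduce to a small dispatch over the parsed flags: `--prompt` wins, then `--interactive`, else the default prompt loop. A self-contained sketch with the agent stubbed out:

```python
import argparse
import asyncio


async def dispatch(args: argparse.Namespace) -> None:
    # Mirrors run_mcp(): single prompt takes priority, then interactive, else default.
    if args.prompt:
        print(f"single prompt: {args.prompt}")
    elif args.interactive:
        print("interactive mode")
    else:
        print("default mode")


parser = argparse.ArgumentParser()
parser.add_argument("--connection", "-c", choices=["stdio", "sse"], default="stdio")
parser.add_argument("--server-url", default="http://127.0.0.1:8000/sse")
parser.add_argument("--interactive", "-i", action="store_true")
parser.add_argument("--prompt", "-p")
asyncio.run(dispatch(parser.parse_args(["-p", "hello"])))
```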
@ -1,11 +0,0 @@
|
||||
# coding: utf-8
|
||||
# A shortcut to launch the OpenManus MCP server; importing the server through this entry point also avoids other import issues.
|
||||
from app.mcp.server import MCPServer, parse_args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
# Create and run server (maintaining original flow)
|
||||
server = MCPServer()
|
||||
server.run(transport=args.transport)
|
2 setup.py
@ -31,7 +31,7 @@ setup(
|
||||
"browser-use~=0.1.40",
|
||||
"googlesearch-python~=1.3.0",
|
||||
"aiofiles~=24.1.0",
|
||||
"pydantic_core>=2.27.2,<2.28.0",
|
||||
"pydantic_core>=2.27.2,<2.33.0",
|
||||
"colorama~=0.4.6",
|
||||
],
|
||||
classifiers=[
|
||||
|
@ -1,110 +0,0 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.config import SandboxSettings
|
||||
from app.sandbox.client import LocalSandboxClient, create_sandbox_client
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def local_client() -> AsyncGenerator[LocalSandboxClient, None]:
|
||||
"""Creates a local sandbox client for testing."""
|
||||
client = create_sandbox_client()
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
await client.cleanup()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def temp_dir() -> Path:
|
||||
"""Creates a temporary directory for testing."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
yield Path(tmp_dir)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_creation(local_client: LocalSandboxClient):
|
||||
"""Tests sandbox creation with specific configuration."""
|
||||
config = SandboxSettings(
|
||||
image="python:3.12-slim",
|
||||
work_dir="/workspace",
|
||||
memory_limit="512m",
|
||||
cpu_limit=0.5,
|
||||
)
|
||||
|
||||
await local_client.create(config)
|
||||
result = await local_client.run_command("python3 --version")
|
||||
assert "Python 3.10" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_command_execution(local_client: LocalSandboxClient):
|
||||
"""Tests command execution in local sandbox."""
|
||||
await local_client.create()
|
||||
|
||||
result = await local_client.run_command("echo 'test'")
|
||||
assert result.strip() == "test"
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await local_client.run_command("sleep 10", timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_file_operations(local_client: LocalSandboxClient, temp_dir: Path):
|
||||
"""Tests file operations in local sandbox."""
|
||||
await local_client.create()
|
||||
|
||||
# Test write and read operations
|
||||
test_content = "Hello, World!"
|
||||
await local_client.write_file("/workspace/test.txt", test_content)
|
||||
content = await local_client.read_file("/workspace/test.txt")
|
||||
assert content.strip() == test_content
|
||||
|
||||
# Test copying file to container
|
||||
src_file = temp_dir / "src.txt"
|
||||
src_file.write_text("Copy to container")
|
||||
await local_client.copy_to(str(src_file), "/workspace/copied.txt")
|
||||
content = await local_client.read_file("/workspace/copied.txt")
|
||||
assert content.strip() == "Copy to container"
|
||||
|
||||
# Test copying file from container
|
||||
dst_file = temp_dir / "dst.txt"
|
||||
await local_client.copy_from("/workspace/test.txt", str(dst_file))
|
||||
assert dst_file.read_text().strip() == test_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_volume_binding(local_client: LocalSandboxClient, temp_dir: Path):
|
||||
"""Tests volume binding in local sandbox."""
|
||||
bind_path = str(temp_dir)
|
||||
volume_bindings = {bind_path: "/data"}
|
||||
|
||||
await local_client.create(volume_bindings=volume_bindings)
|
||||
|
||||
test_file = temp_dir / "test.txt"
|
||||
test_file.write_text("Volume test")
|
||||
|
||||
content = await local_client.read_file("/data/test.txt")
|
||||
assert "Volume test" in content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_error_handling(local_client: LocalSandboxClient):
|
||||
"""Tests error handling in local sandbox."""
|
||||
await local_client.create()
|
||||
|
||||
with pytest.raises(Exception) as exc:
|
||||
await local_client.read_file("/nonexistent.txt")
|
||||
assert "not found" in str(exc.value).lower()
|
||||
|
||||
with pytest.raises(Exception) as exc:
|
||||
await local_client.copy_from("/nonexistent.txt", "local.txt")
|
||||
assert "not found" in str(exc.value).lower()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-v", __file__])
|
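The tests above rely on the yield-then-cleanup fixture shape from `pytest-asyncio`. A minimal, self-contained example of that pattern (requires `pytest` and `pytest-asyncio`; the `FakeClient` here is an illustrative stub):

```python
import pytest
import pytest_asyncio


class FakeClient:
    async def run(self, cmd: str) -> str:
        return cmd.upper()

    async def cleanup(self) -> None:
        pass


@pytest_asyncio.fixture
async def client():
    c = FakeClient()
    try:
        yield c  # the test body runs here
    finally:
        await c.cleanup()  # always runs, whether the test passes or fails


@pytest.mark.asyncio
async def test_run(client):
    assert await client.run("echo") == "ECHO"
```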
@ -1,104 +0,0 @@
|
||||
"""Tests for the AsyncDockerizedTerminal implementation."""
|
||||
|
||||
import docker
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.sandbox.core.terminal import AsyncDockerizedTerminal
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def docker_client():
|
||||
"""Fixture providing a Docker client."""
|
||||
return docker.from_env()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def docker_container(docker_client):
|
||||
"""Fixture providing a test Docker container."""
|
||||
container = docker_client.containers.run(
|
||||
"python:3.12-slim",
|
||||
"tail -f /dev/null",
|
||||
name="test_container",
|
||||
detach=True,
|
||||
remove=True,
|
||||
)
|
||||
yield container
|
||||
container.stop()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def terminal(docker_container):
|
||||
"""Fixture providing an initialized AsyncDockerizedTerminal instance."""
|
||||
terminal = AsyncDockerizedTerminal(
|
||||
docker_container,
|
||||
working_dir="/workspace",
|
||||
env_vars={"TEST_VAR": "test_value"},
|
||||
default_timeout=30,
|
||||
)
|
||||
await terminal.init()
|
||||
yield terminal
|
||||
await terminal.close()
|
||||
|
||||
|
||||
class TestAsyncDockerizedTerminal:
|
||||
"""Test cases for AsyncDockerizedTerminal."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_command_execution(self, terminal):
|
||||
"""Test basic command execution functionality."""
|
||||
result = await terminal.run_command("echo 'Hello World'")
|
||||
assert "Hello World" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_environment_variables(self, terminal):
|
||||
"""Test environment variable setting and access."""
|
||||
result = await terminal.run_command("echo $TEST_VAR")
|
||||
assert "test_value" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_working_directory(self, terminal):
|
||||
"""Test working directory setup."""
|
||||
result = await terminal.run_command("pwd")
|
||||
assert "/workspace" == result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_command_timeout(self, docker_container):
|
||||
"""Test command timeout functionality."""
|
||||
terminal = AsyncDockerizedTerminal(docker_container, default_timeout=1)
|
||||
await terminal.init()
|
||||
try:
|
||||
with pytest.raises(TimeoutError):
|
||||
await terminal.run_command("sleep 5")
|
||||
finally:
|
||||
await terminal.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_commands(self, terminal):
|
||||
"""Test execution of multiple commands in sequence."""
|
||||
cmd1 = await terminal.run_command("echo 'First'")
|
||||
cmd2 = await terminal.run_command("echo 'Second'")
|
||||
assert "First" in cmd1
|
||||
assert "Second" in cmd2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_cleanup(self, docker_container):
|
||||
"""Test proper cleanup of resources."""
|
||||
terminal = AsyncDockerizedTerminal(docker_container)
|
||||
await terminal.init()
|
||||
assert terminal.session is not None
|
||||
await terminal.close()
|
||||
# Verify session is properly cleaned up
|
||||
# Note: session object still exists, but internal connection is closed
|
||||
assert terminal.session is not None
|
||||
|
||||
|
||||
# Configure pytest-asyncio
|
||||
def pytest_configure(config):
|
||||
"""Configure pytest-asyncio."""
|
||||
config.addinivalue_line("asyncio_mode", "strict")
|
||||
config.addinivalue_line("asyncio_default_fixture_loop_scope", "function")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-v", __file__])
|
@ -1,152 +0,0 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.sandbox.core.sandbox import DockerSandbox, SandboxSettings
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def sandbox_config():
|
||||
"""Creates sandbox configuration for testing."""
|
||||
return SandboxSettings(
|
||||
image="python:3.12-slim",
|
||||
work_dir="/workspace",
|
||||
memory_limit="1g",
|
||||
cpu_limit=0.5,
|
||||
network_enabled=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def sandbox(sandbox_config):
|
||||
"""Creates and manages a test sandbox instance."""
|
||||
sandbox = DockerSandbox(sandbox_config)
|
||||
await sandbox.create()
|
||||
try:
|
||||
yield sandbox
|
||||
finally:
|
||||
await sandbox.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_working_directory(sandbox):
|
||||
"""Tests sandbox working directory configuration."""
|
||||
result = await sandbox.terminal.run_command("pwd")
|
||||
assert result.strip() == "/workspace"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_file_operations(sandbox):
|
||||
"""Tests sandbox file read/write operations."""
|
||||
# Test file writing
|
||||
test_content = "Hello from sandbox!"
|
||||
await sandbox.write_file("/workspace/test.txt", test_content)
|
||||
|
||||
# Test file reading
|
||||
content = await sandbox.read_file("/workspace/test.txt")
|
||||
assert content.strip() == test_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_python_execution(sandbox):
|
||||
"""Tests Python code execution in sandbox."""
|
||||
# Write test file
|
||||
await sandbox.write_file("/workspace/test.txt", "Hello from file!")
|
||||
|
||||
# Write Python script
|
||||
python_code = """
|
||||
print("Hello from Python!")
|
||||
with open('/workspace/test.txt') as f:
|
||||
print(f.read())
|
||||
"""
|
||||
await sandbox.write_file("/workspace/test.py", python_code)
|
||||
|
||||
# Execute script and verify output
|
||||
result = await sandbox.terminal.run_command("python3 /workspace/test.py")
|
||||
assert "Hello from Python!" in result
|
||||
assert "Hello from file!" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_file_persistence(sandbox):
|
||||
"""Tests file persistence in sandbox."""
|
||||
# Create multiple files
|
||||
files = {
|
||||
"file1.txt": "Content 1",
|
||||
"file2.txt": "Content 2",
|
||||
"nested/file3.txt": "Content 3",
|
||||
}
|
||||
|
||||
# Write files
|
||||
for path, content in files.items():
|
||||
await sandbox.write_file(f"/workspace/{path}", content)
|
||||
|
||||
# Verify file contents
|
||||
for path, expected_content in files.items():
|
||||
content = await sandbox.read_file(f"/workspace/{path}")
|
||||
assert content.strip() == expected_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_python_environment(sandbox):
|
||||
"""Tests Python environment configuration."""
|
||||
# Test Python version
|
||||
result = await sandbox.terminal.run_command("python3 --version")
|
||||
assert "Python 3.10" in result
|
||||
|
||||
# Test basic module imports
|
||||
python_code = """
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
print("Python is working!")
|
||||
"""
|
||||
await sandbox.write_file("/workspace/env_test.py", python_code)
|
||||
result = await sandbox.terminal.run_command("python3 /workspace/env_test.py")
|
||||
assert "Python is working!" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_network_access(sandbox):
|
||||
"""Tests sandbox network access."""
|
||||
if not sandbox.config.network_enabled:
|
||||
pytest.skip("Network access is disabled")
|
||||
|
||||
# Test network connectivity
|
||||
await sandbox.terminal.run_command("apt update && apt install curl -y")
|
||||
result = await sandbox.terminal.run_command("curl -I https://www.example.com")
|
||||
assert "HTTP/2 200" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_cleanup(sandbox_config):
|
||||
"""Tests sandbox cleanup process."""
|
||||
sandbox = DockerSandbox(sandbox_config)
|
||||
await sandbox.create()
|
||||
|
||||
# Create test files
|
||||
await sandbox.write_file("/workspace/test.txt", "test")
|
||||
container_id = sandbox.terminal.container.id
|
||||
# Perform cleanup
|
||||
await sandbox.cleanup()
|
||||
|
||||
# Verify container has been removed
|
||||
import docker
|
||||
|
||||
client = docker.from_env()
|
||||
containers = client.containers.list(all=True)
|
||||
assert not any(c.id == container_id for c in containers)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_error_handling():
|
||||
"""Tests error handling with invalid configuration."""
|
||||
# Test invalid configuration
|
||||
invalid_config = SandboxSettings(image="nonexistent:latest", work_dir="/invalid")
|
||||
|
||||
sandbox = DockerSandbox(invalid_config)
|
||||
with pytest.raises(Exception):
|
||||
await sandbox.create()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-v", __file__])
|
@ -1,138 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.sandbox.core.manager import SandboxManager
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def manager() -> AsyncGenerator[SandboxManager, None]:
|
||||
"""Creates a sandbox manager instance.
|
||||
|
||||
Uses function scope to ensure each test case has its own manager instance.
|
||||
"""
|
||||
manager = SandboxManager(max_sandboxes=2, idle_timeout=60, cleanup_interval=30)
|
||||
try:
|
||||
yield manager
|
||||
finally:
|
||||
# Ensure all resources are cleaned up
|
||||
await manager.cleanup()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_file():
|
||||
"""Creates a temporary test file."""
|
||||
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
|
||||
f.write("test content")
|
||||
path = f.name
|
||||
try:
|
||||
yield path
|
||||
finally:
|
||||
if os.path.exists(path):
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_sandbox(manager):
|
||||
"""Tests sandbox creation."""
|
||||
# Create default sandbox
|
||||
sandbox_id = await manager.create_sandbox()
|
||||
assert sandbox_id in manager._sandboxes
|
||||
assert sandbox_id in manager._last_used
|
||||
|
||||
# Verify sandbox functionality
|
||||
sandbox = await manager.get_sandbox(sandbox_id)
|
||||
result = await sandbox.run_command("echo 'test'")
|
||||
assert result.strip() == "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_sandboxes_limit(manager):
|
||||
"""Tests maximum sandbox limit enforcement."""
|
||||
created_sandboxes = []
|
||||
try:
|
||||
# Create maximum number of sandboxes
|
||||
for _ in range(manager.max_sandboxes):
|
||||
sandbox_id = await manager.create_sandbox()
|
||||
created_sandboxes.append(sandbox_id)
|
||||
|
||||
# Verify created sandbox count
|
||||
assert len(manager._sandboxes) == manager.max_sandboxes
|
||||
|
||||
# Attempting to create additional sandbox should fail
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await manager.create_sandbox()
|
||||
|
||||
# Verify error message
|
||||
expected_message = (
|
||||
f"Maximum number of sandboxes ({manager.max_sandboxes}) reached"
|
||||
)
|
||||
assert str(exc_info.value) == expected_message
|
||||
|
||||
finally:
|
||||
# Clean up all created sandboxes
|
||||
for sandbox_id in created_sandboxes:
|
||||
try:
|
||||
await manager.delete_sandbox(sandbox_id)
|
||||
except Exception as e:
|
||||
print(f"Failed to cleanup sandbox {sandbox_id}: {e}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_sandbox(manager):
|
||||
"""Tests retrieving a non-existent sandbox."""
|
||||
with pytest.raises(KeyError, match="Sandbox .* not found"):
|
||||
await manager.get_sandbox("nonexistent-id")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_cleanup(manager):
|
||||
"""Tests sandbox cleanup functionality."""
|
||||
sandbox_id = await manager.create_sandbox()
|
||||
assert sandbox_id in manager._sandboxes
|
||||
|
||||
await manager.delete_sandbox(sandbox_id)
|
||||
assert sandbox_id not in manager._sandboxes
|
||||
assert sandbox_id not in manager._last_used
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_idle_sandbox_cleanup(manager):
|
||||
"""Tests automatic cleanup of idle sandboxes."""
|
||||
# Set short idle timeout
|
||||
manager.idle_timeout = 0.1
|
||||
|
||||
sandbox_id = await manager.create_sandbox()
|
||||
assert sandbox_id in manager._sandboxes
|
||||
|
||||
# Wait longer than idle timeout
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Trigger cleanup
|
||||
await manager._cleanup_idle_sandboxes()
|
||||
assert sandbox_id not in manager._sandboxes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_cleanup(manager):
|
||||
"""Tests manager cleanup functionality."""
|
||||
# Create multiple sandboxes
|
||||
sandbox_ids = []
|
||||
for _ in range(2):
|
||||
sandbox_id = await manager.create_sandbox()
|
||||
sandbox_ids.append(sandbox_id)
|
||||
|
||||
# Clean up all resources
|
||||
await manager.cleanup()
|
||||
|
||||
# Verify all sandboxes have been cleaned up
|
||||
assert not manager._sandboxes
|
||||
assert not manager._last_used
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-v", __file__])
|
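The idle-cleanup behavior exercised by `test_idle_sandbox_cleanup` boils down to comparing last-used timestamps against a timeout. A toy, synchronous version of that bookkeeping (illustrative only; the real manager also stops and removes the container):

```python
import time
from typing import Dict, List


class IdleReaper:
    """Toy version of the manager's idle cleanup: drop entries unused for too long."""

    def __init__(self, idle_timeout: float) -> None:
        self.idle_timeout = idle_timeout
        self._last_used: Dict[str, float] = {}

    def touch(self, sandbox_id: str) -> None:
        self._last_used[sandbox_id] = time.monotonic()

    def cleanup_idle(self) -> List[str]:
        now = time.monotonic()
        stale = [sid for sid, t in self._last_used.items() if now - t > self.idle_timeout]
        for sid in stale:
            del self._last_used[sid]  # the real manager also tears down the sandbox here
        return stale


reaper = IdleReaper(idle_timeout=0.1)
reaper.touch("sb-1")
time.sleep(0.2)
print(reaper.cleanup_idle())  # ['sb-1']
```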
@ -1 +0,0 @@
|
||||
This is a sample file. Files generated by OpenManus are stored in the current folder by default.
|