init project
This commit is contained in:
commit
d028e64a98
34
.gitattributes
vendored
Normal file
34
.gitattributes
vendored
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# HTML code is incorrectly calculated into statistics, so ignore them
|
||||||
|
*.html linguist-detectable=false
|
||||||
|
|
||||||
|
# Auto detect text files and perform LF normalization
|
||||||
|
* text=auto eol=lf
|
||||||
|
|
||||||
|
# Ensure shell scripts use LF (Linux style) line endings on Windows
|
||||||
|
*.sh text eol=lf
|
||||||
|
|
||||||
|
# Treat specific binary files as binary and prevent line ending conversion
|
||||||
|
*.png binary
|
||||||
|
*.jpg binary
|
||||||
|
*.gif binary
|
||||||
|
*.ico binary
|
||||||
|
*.jpeg binary
|
||||||
|
*.mp3 binary
|
||||||
|
*.zip binary
|
||||||
|
*.bin binary
|
||||||
|
|
||||||
|
|
||||||
|
# Preserve original line endings for specific document files
|
||||||
|
*.doc text eol=crlf
|
||||||
|
*.docx text eol=crlf
|
||||||
|
*.pdf binary
|
||||||
|
|
||||||
|
# Ensure source code and script files use LF line endings
|
||||||
|
*.py text eol=lf
|
||||||
|
*.js text eol=lf
|
||||||
|
*.html text eol=lf
|
||||||
|
*.css text eol=lf
|
||||||
|
|
||||||
|
# Specify custom diff driver for specific file types
|
||||||
|
*.md diff=markdown
|
||||||
|
*.json diff=json
|
180
.gitignore
vendored
Normal file
180
.gitignore
vendored
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# UV
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
#uv.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||||
|
.pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# PyPI configuration file
|
||||||
|
.pypirc
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
logs/
|
||||||
|
|
||||||
|
# Data
|
||||||
|
data/
|
||||||
|
|
||||||
|
# Workspace
|
||||||
|
workspace/
|
39
.pre-commit-config.yaml
Normal file
39
.pre-commit-config.yaml
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/psf/black
|
||||||
|
rev: 23.1.0
|
||||||
|
hooks:
|
||||||
|
- id: black
|
||||||
|
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v4.4.0
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-added-large-files
|
||||||
|
|
||||||
|
- repo: https://github.com/PyCQA/autoflake
|
||||||
|
rev: v2.0.1
|
||||||
|
hooks:
|
||||||
|
- id: autoflake
|
||||||
|
args: [
|
||||||
|
--remove-all-unused-imports,
|
||||||
|
--ignore-init-module-imports, # 忽略 __init__.py 中的导入
|
||||||
|
--expand-star-imports,
|
||||||
|
--remove-duplicate-keys,
|
||||||
|
--remove-unused-variables,
|
||||||
|
--recursive,
|
||||||
|
--in-place,
|
||||||
|
--exclude=__init__.py, # 排除 __init__.py 文件
|
||||||
|
]
|
||||||
|
files: \.py$ # 只处理 Python 文件
|
||||||
|
|
||||||
|
- repo: https://github.com/pycqa/isort
|
||||||
|
rev: 5.12.0
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
args: [
|
||||||
|
"--profile", "black", # 使用 black 兼容的配置
|
||||||
|
"--filter-files",
|
||||||
|
"--lines-after-imports=2",
|
||||||
|
]
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2025 manna_and_poem
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
14
app/agent/__init__.py
Normal file
14
app/agent/__init__.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from app.agent.base import BaseAgent
|
||||||
|
from app.agent.planning import PlanningAgent
|
||||||
|
from app.agent.react import ReActAgent
|
||||||
|
from app.agent.swe import SWEAgent
|
||||||
|
from app.agent.toolcall import ToolCallAgent
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseAgent",
|
||||||
|
"PlanningAgent",
|
||||||
|
"ReActAgent",
|
||||||
|
"SWEAgent",
|
||||||
|
"ToolCallAgent",
|
||||||
|
]
|
191
app/agent/base.py
Normal file
191
app/agent/base.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
from app.llm import LLM
|
||||||
|
from app.logger import logger
|
||||||
|
from app.schema import AgentState, Memory, Message
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgent(BaseModel, ABC):
|
||||||
|
"""Abstract base class for managing agent state and execution.
|
||||||
|
|
||||||
|
Provides foundational functionality for state transitions, memory management,
|
||||||
|
and a step-based execution loop. Subclasses must implement the `step` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Core attributes
|
||||||
|
name: str = Field(..., description="Unique name of the agent")
|
||||||
|
description: Optional[str] = Field(None, description="Optional agent description")
|
||||||
|
|
||||||
|
# Prompts
|
||||||
|
system_prompt: Optional[str] = Field(
|
||||||
|
None, description="System-level instruction prompt"
|
||||||
|
)
|
||||||
|
next_step_prompt: Optional[str] = Field(
|
||||||
|
None, description="Prompt for determining next action"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dependencies
|
||||||
|
llm: LLM = Field(default_factory=LLM, description="Language model instance")
|
||||||
|
memory: Memory = Field(default_factory=Memory, description="Agent's memory store")
|
||||||
|
state: AgentState = Field(
|
||||||
|
default=AgentState.IDLE, description="Current agent state"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execution control
|
||||||
|
max_steps: int = Field(default=10, description="Maximum steps before termination")
|
||||||
|
current_step: int = Field(default=0, description="Current step in execution")
|
||||||
|
|
||||||
|
duplicate_threshold: int = 2
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
extra = "allow" # Allow extra fields for flexibility in subclasses
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def initialize_agent(self) -> "BaseAgent":
|
||||||
|
"""Initialize agent with default settings if not provided."""
|
||||||
|
if self.llm is None or not isinstance(self.llm, LLM):
|
||||||
|
self.llm = LLM(config_name=self.name.lower())
|
||||||
|
if not isinstance(self.memory, Memory):
|
||||||
|
self.memory = Memory()
|
||||||
|
return self
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def state_context(self, new_state: AgentState):
|
||||||
|
"""Context manager for safe agent state transitions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_state: The state to transition to during the context.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
None: Allows execution within the new state.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the new_state is invalid.
|
||||||
|
"""
|
||||||
|
if not isinstance(new_state, AgentState):
|
||||||
|
raise ValueError(f"Invalid state: {new_state}")
|
||||||
|
|
||||||
|
previous_state = self.state
|
||||||
|
self.state = new_state
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
except Exception as e:
|
||||||
|
self.state = AgentState.ERROR # Transition to ERROR on failure
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self.state = previous_state # Revert to previous state
|
||||||
|
|
||||||
|
def update_memory(
|
||||||
|
self,
|
||||||
|
role: Literal["user", "system", "assistant", "tool"],
|
||||||
|
content: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""Add a message to the agent's memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
role: The role of the message sender (user, system, assistant, tool).
|
||||||
|
content: The message content.
|
||||||
|
**kwargs: Additional arguments (e.g., tool_call_id for tool messages).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the role is unsupported.
|
||||||
|
"""
|
||||||
|
message_map = {
|
||||||
|
"user": Message.user_message,
|
||||||
|
"system": Message.system_message,
|
||||||
|
"assistant": Message.assistant_message,
|
||||||
|
"tool": lambda content, **kw: Message.tool_message(content, **kw),
|
||||||
|
}
|
||||||
|
|
||||||
|
if role not in message_map:
|
||||||
|
raise ValueError(f"Unsupported message role: {role}")
|
||||||
|
|
||||||
|
msg_factory = message_map[role]
|
||||||
|
msg = msg_factory(content, **kwargs) if role == "tool" else msg_factory(content)
|
||||||
|
self.memory.add_message(msg)
|
||||||
|
|
||||||
|
async def run(self, request: Optional[str] = None) -> str:
|
||||||
|
"""Execute the agent's main loop asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Optional initial user request to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string summarizing the execution results.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the agent is not in IDLE state at start.
|
||||||
|
"""
|
||||||
|
if self.state != AgentState.IDLE:
|
||||||
|
raise RuntimeError(f"Cannot run agent from state: {self.state}")
|
||||||
|
|
||||||
|
if request:
|
||||||
|
self.update_memory("user", request)
|
||||||
|
|
||||||
|
results: List[str] = []
|
||||||
|
async with self.state_context(AgentState.RUNNING):
|
||||||
|
while (
|
||||||
|
self.current_step < self.max_steps and self.state != AgentState.FINISHED
|
||||||
|
):
|
||||||
|
self.current_step += 1
|
||||||
|
logger.info(f"Executing step {self.current_step}/{self.max_steps}")
|
||||||
|
step_result = await self.step()
|
||||||
|
|
||||||
|
# Check for stuck state
|
||||||
|
if self.is_stuck():
|
||||||
|
self.handle_stuck_state()
|
||||||
|
|
||||||
|
results.append(f"Step {self.current_step}: {step_result}")
|
||||||
|
|
||||||
|
if self.current_step >= self.max_steps:
|
||||||
|
results.append(f"Terminated: Reached max steps ({self.max_steps})")
|
||||||
|
|
||||||
|
return "\n".join(results) if results else "No steps executed"
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def step(self) -> str:
|
||||||
|
"""Execute a single step in the agent's workflow.
|
||||||
|
|
||||||
|
Must be implemented by subclasses to define specific behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def handle_stuck_state(self):
|
||||||
|
"""Handle stuck state by adding a prompt to change strategy"""
|
||||||
|
stuck_prompt = "\
|
||||||
|
Observed duplicate responses. Consider new strategies and avoid repeating ineffective paths already attempted."
|
||||||
|
self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt}"
|
||||||
|
logger.warning(f"Agent detected stuck state. Added prompt: {stuck_prompt}")
|
||||||
|
|
||||||
|
def is_stuck(self) -> bool:
|
||||||
|
"""Check if the agent is stuck in a loop by detecting duplicate content"""
|
||||||
|
if len(self.memory.messages) < 2:
|
||||||
|
return False
|
||||||
|
|
||||||
|
last_message = self.memory.messages[-1]
|
||||||
|
if not last_message.content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Count identical content occurrences
|
||||||
|
duplicate_count = sum(
|
||||||
|
1
|
||||||
|
for msg in reversed(self.memory.messages[:-1])
|
||||||
|
if msg.role == "assistant" and msg.content == last_message.content
|
||||||
|
)
|
||||||
|
|
||||||
|
return duplicate_count >= self.duplicate_threshold
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self) -> List[Message]:
|
||||||
|
"""Retrieve a list of messages from the agent's memory."""
|
||||||
|
return self.memory.messages
|
||||||
|
|
||||||
|
@messages.setter
|
||||||
|
def messages(self, value: List[Message]):
|
||||||
|
"""Set the list of messages in the agent's memory."""
|
||||||
|
self.memory.messages = value
|
285
app/agent/planning.py
Normal file
285
app/agent/planning.py
Normal file
@ -0,0 +1,285 @@
|
|||||||
|
import time
|
||||||
|
from typing import Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import Field, model_validator
|
||||||
|
|
||||||
|
from app.agent.toolcall import ToolCallAgent
|
||||||
|
from app.logger import logger
|
||||||
|
from app.schema import Message, ToolCall
|
||||||
|
from app.tool import PlanningTool, Terminate, ToolCollection
|
||||||
|
|
||||||
|
|
||||||
|
PLANNING_SYSTEM_PROMPT = """
|
||||||
|
You are an expert Planning Agent tasked with solving complex problems by creating and managing structured plans.
|
||||||
|
Your job is:
|
||||||
|
1. Analyze requests to understand the task scope
|
||||||
|
2. Create clear, actionable plans with the `planning` tool
|
||||||
|
3. Execute steps using available tools as needed
|
||||||
|
4. Track progress and adapt plans dynamically
|
||||||
|
5. Use `finish` to conclude when the task is complete
|
||||||
|
|
||||||
|
Available tools will vary by task but may include:
|
||||||
|
- `planning`: Create, update, and track plans (commands: create, update, mark_step, etc.)
|
||||||
|
- `finish`: End the task when complete
|
||||||
|
|
||||||
|
Break tasks into logical, sequential steps. Think about dependencies and verification methods.
|
||||||
|
"""
|
||||||
|
|
||||||
|
NEXT_STEP_PROMPT = """
|
||||||
|
Based on the current state, what's your next step?
|
||||||
|
Consider:
|
||||||
|
1. Do you need to create or refine a plan?
|
||||||
|
2. Are you ready to execute a specific step?
|
||||||
|
3. Have you completed the task?
|
||||||
|
|
||||||
|
Provide reasoning, then select the appropriate tool or action.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class PlanningAgent(ToolCallAgent):
|
||||||
|
"""
|
||||||
|
An agent that creates and manages plans to solve tasks.
|
||||||
|
|
||||||
|
This agent uses a planning tool to create and manage structured plans,
|
||||||
|
and tracks progress through individual steps until task completion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "planning"
|
||||||
|
description: str = "An agent that creates and manages plans to solve tasks"
|
||||||
|
|
||||||
|
system_prompt: str = PLANNING_SYSTEM_PROMPT
|
||||||
|
next_step_prompt: str = NEXT_STEP_PROMPT
|
||||||
|
|
||||||
|
available_tools: ToolCollection = Field(
|
||||||
|
default_factory=lambda: ToolCollection(PlanningTool(), Terminate())
|
||||||
|
)
|
||||||
|
tool_choices: Literal["none", "auto", "required"] = "auto"
|
||||||
|
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
|
||||||
|
|
||||||
|
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||||
|
active_plan_id: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
# Add a dictionary to track the step status for each tool call
|
||||||
|
step_execution_tracker: Dict[str, Dict] = Field(default_factory=dict)
|
||||||
|
current_step_index: Optional[int] = None
|
||||||
|
|
||||||
|
max_steps: int = 20
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def initialize_plan_and_verify_tools(self) -> "PlanningAgent":
|
||||||
|
"""Initialize the agent with a default plan ID and validate required tools."""
|
||||||
|
self.active_plan_id = f"plan_{int(time.time())}"
|
||||||
|
|
||||||
|
if "planning" not in self.available_tools.tool_map:
|
||||||
|
self.available_tools.add_tool(PlanningTool())
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def think(self) -> bool:
|
||||||
|
"""Decide the next action based on plan status."""
|
||||||
|
prompt = (
|
||||||
|
f"CURRENT PLAN STATUS:\n{await self.get_plan()}\n\n{self.next_step_prompt}"
|
||||||
|
if self.active_plan_id
|
||||||
|
else self.next_step_prompt
|
||||||
|
)
|
||||||
|
self.messages.append(Message.user_message(prompt))
|
||||||
|
|
||||||
|
# Get the current step index before thinking
|
||||||
|
self.current_step_index = await self._get_current_step_index()
|
||||||
|
|
||||||
|
result = await super().think()
|
||||||
|
|
||||||
|
# After thinking, if we decided to execute a tool and it's not a planning tool or special tool,
|
||||||
|
# associate it with the current step for tracking
|
||||||
|
if result and self.tool_calls:
|
||||||
|
latest_tool_call = self.tool_calls[0] # Get the most recent tool call
|
||||||
|
if (
|
||||||
|
latest_tool_call.function.name != "planning"
|
||||||
|
and latest_tool_call.function.name not in self.special_tool_names
|
||||||
|
and self.current_step_index is not None
|
||||||
|
):
|
||||||
|
self.step_execution_tracker[latest_tool_call.id] = {
|
||||||
|
"step_index": self.current_step_index,
|
||||||
|
"tool_name": latest_tool_call.function.name,
|
||||||
|
"status": "pending", # Will be updated after execution
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def act(self) -> str:
|
||||||
|
"""Execute a step and track its completion status."""
|
||||||
|
result = await super().act()
|
||||||
|
|
||||||
|
# After executing the tool, update the plan status
|
||||||
|
if self.tool_calls:
|
||||||
|
latest_tool_call = self.tool_calls[0]
|
||||||
|
|
||||||
|
# Update the execution status to completed
|
||||||
|
if latest_tool_call.id in self.step_execution_tracker:
|
||||||
|
self.step_execution_tracker[latest_tool_call.id]["status"] = "completed"
|
||||||
|
self.step_execution_tracker[latest_tool_call.id]["result"] = result
|
||||||
|
|
||||||
|
# Update the plan status if this was a non-planning, non-special tool
|
||||||
|
if (
|
||||||
|
latest_tool_call.function.name != "planning"
|
||||||
|
and latest_tool_call.function.name not in self.special_tool_names
|
||||||
|
):
|
||||||
|
await self.update_plan_status(latest_tool_call.id)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_plan(self) -> str:
|
||||||
|
"""Retrieve the current plan status."""
|
||||||
|
if not self.active_plan_id:
|
||||||
|
return "No active plan. Please create a plan first."
|
||||||
|
|
||||||
|
result = await self.available_tools.execute(
|
||||||
|
name="planning",
|
||||||
|
tool_input={"command": "get", "plan_id": self.active_plan_id},
|
||||||
|
)
|
||||||
|
return result.output if hasattr(result, "output") else str(result)
|
||||||
|
|
||||||
|
async def run(self, request: Optional[str] = None) -> str:
|
||||||
|
"""Run the agent with an optional initial request."""
|
||||||
|
if request:
|
||||||
|
await self.create_initial_plan(request)
|
||||||
|
return await super().run()
|
||||||
|
|
||||||
|
async def update_plan_status(self, tool_call_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Update the current plan progress based on completed tool execution.
|
||||||
|
Only marks a step as completed if the associated tool has been successfully executed.
|
||||||
|
"""
|
||||||
|
if not self.active_plan_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
if tool_call_id not in self.step_execution_tracker:
|
||||||
|
logger.warning(f"No step tracking found for tool call {tool_call_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
tracker = self.step_execution_tracker[tool_call_id]
|
||||||
|
if tracker["status"] != "completed":
|
||||||
|
logger.warning(f"Tool call {tool_call_id} has not completed successfully")
|
||||||
|
return
|
||||||
|
|
||||||
|
step_index = tracker["step_index"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Mark the step as completed
|
||||||
|
await self.available_tools.execute(
|
||||||
|
name="planning",
|
||||||
|
tool_input={
|
||||||
|
"command": "mark_step",
|
||||||
|
"plan_id": self.active_plan_id,
|
||||||
|
"step_index": step_index,
|
||||||
|
"step_status": "completed",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Marked step {step_index} as completed in plan {self.active_plan_id}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to update plan status: {e}")
|
||||||
|
|
||||||
|
async def _get_current_step_index(self) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Parse the current plan to identify the first non-completed step's index.
|
||||||
|
Returns None if no active step is found.
|
||||||
|
"""
|
||||||
|
if not self.active_plan_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
plan = await self.get_plan()
|
||||||
|
|
||||||
|
try:
|
||||||
|
plan_lines = plan.splitlines()
|
||||||
|
steps_index = -1
|
||||||
|
|
||||||
|
# Find the index of the "Steps:" line
|
||||||
|
for i, line in enumerate(plan_lines):
|
||||||
|
if line.strip() == "Steps:":
|
||||||
|
steps_index = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if steps_index == -1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find the first non-completed step
|
||||||
|
for i, line in enumerate(plan_lines[steps_index + 1 :], start=0):
|
||||||
|
if "[ ]" in line or "[→]" in line: # not_started or in_progress
|
||||||
|
# Mark current step as in_progress
|
||||||
|
await self.available_tools.execute(
|
||||||
|
name="planning",
|
||||||
|
tool_input={
|
||||||
|
"command": "mark_step",
|
||||||
|
"plan_id": self.active_plan_id,
|
||||||
|
"step_index": i,
|
||||||
|
"step_status": "in_progress",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return i
|
||||||
|
|
||||||
|
return None # No active step found
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error finding current step index: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def create_initial_plan(self, request: str) -> None:
|
||||||
|
"""Create an initial plan based on the request."""
|
||||||
|
logger.info(f"Creating initial plan with ID: {self.active_plan_id}")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
Message.user_message(
|
||||||
|
f"Analyze the request and create a plan with ID {self.active_plan_id}: {request}"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
self.memory.add_messages(messages)
|
||||||
|
response = await self.llm.ask_tool(
|
||||||
|
messages=messages,
|
||||||
|
system_msgs=[Message.system_message(self.system_prompt)],
|
||||||
|
tools=self.available_tools.to_params(),
|
||||||
|
tool_choice="required",
|
||||||
|
)
|
||||||
|
assistant_msg = Message.from_tool_calls(
|
||||||
|
content=response.content, tool_calls=response.tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
|
self.memory.add_message(assistant_msg)
|
||||||
|
|
||||||
|
plan_created = False
|
||||||
|
for tool_call in response.tool_calls:
|
||||||
|
if tool_call.function.name == "planning":
|
||||||
|
result = await self.execute_tool(tool_call)
|
||||||
|
logger.info(
|
||||||
|
f"Executed tool {tool_call.function.name} with result: {result}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add tool response to memory
|
||||||
|
tool_msg = Message.tool_message(
|
||||||
|
content=result,
|
||||||
|
tool_call_id=tool_call.id,
|
||||||
|
name=tool_call.function.name,
|
||||||
|
)
|
||||||
|
self.memory.add_message(tool_msg)
|
||||||
|
plan_created = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not plan_created:
|
||||||
|
logger.warning("No plan created from initial request")
|
||||||
|
tool_msg = Message.assistant_message(
|
||||||
|
"Error: Parameter `plan_id` is required for command: create"
|
||||||
|
)
|
||||||
|
self.memory.add_message(tool_msg)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Configure and run the agent
|
||||||
|
agent = PlanningAgent(available_tools=ToolCollection(PlanningTool(), Terminate()))
|
||||||
|
result = await agent.run("Help me plan a trip to the moon")
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(main())
|
38
app/agent/react.py
Normal file
38
app/agent/react.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.agent.base import BaseAgent
|
||||||
|
from app.llm import LLM
|
||||||
|
from app.schema import AgentState, Memory
|
||||||
|
|
||||||
|
|
||||||
|
class ReActAgent(BaseAgent, ABC):
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
system_prompt: Optional[str] = None
|
||||||
|
next_step_prompt: Optional[str] = None
|
||||||
|
|
||||||
|
llm: Optional[LLM] = Field(default_factory=LLM)
|
||||||
|
memory: Memory = Field(default_factory=Memory)
|
||||||
|
state: AgentState = AgentState.IDLE
|
||||||
|
|
||||||
|
max_steps: int = 10
|
||||||
|
current_step: int = 0
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def think(self) -> bool:
|
||||||
|
"""Process current state and decide next action"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def act(self) -> str:
|
||||||
|
"""Execute decided actions"""
|
||||||
|
|
||||||
|
async def step(self) -> str:
|
||||||
|
"""Execute a single step: think and act."""
|
||||||
|
should_act = await self.think()
|
||||||
|
if not should_act:
|
||||||
|
return "Thinking complete - no action needed"
|
||||||
|
return await self.act()
|
37
app/agent/swe.py
Normal file
37
app/agent/swe.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.agent.toolcall import ToolCallAgent
|
||||||
|
from app.prompt.swe import NEXT_STEP_TEMPLATE, SYSTEM_PROMPT
|
||||||
|
from app.tool import Bash, StrReplaceEditor, Terminate, ToolCollection
|
||||||
|
|
||||||
|
|
||||||
|
class SWEAgent(ToolCallAgent):
|
||||||
|
"""An agent that implements the SWEAgent paradigm for executing code and natural conversations."""
|
||||||
|
|
||||||
|
name: str = "swe"
|
||||||
|
description: str = "an autonomous AI programmer that interacts directly with the computer to solve tasks."
|
||||||
|
|
||||||
|
system_prompt: str = SYSTEM_PROMPT
|
||||||
|
next_step_prompt: str = NEXT_STEP_TEMPLATE
|
||||||
|
|
||||||
|
available_tools: ToolCollection = ToolCollection(
|
||||||
|
Bash(), StrReplaceEditor(), Terminate()
|
||||||
|
)
|
||||||
|
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
|
||||||
|
|
||||||
|
max_steps: int = 30
|
||||||
|
|
||||||
|
bash: Bash = Field(default_factory=Bash)
|
||||||
|
working_dir: str = "."
|
||||||
|
|
||||||
|
async def think(self) -> bool:
|
||||||
|
"""Process current state and decide next action"""
|
||||||
|
# Update working directory
|
||||||
|
self.working_dir = await self.bash.execute("pwd")
|
||||||
|
self.next_step_prompt = self.next_step_prompt.format(
|
||||||
|
current_dir=self.working_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
return await super().think()
|
171
app/agent/toolcall.py
Normal file
171
app/agent/toolcall.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any, List, Literal
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.agent.react import ReActAgent
|
||||||
|
from app.logger import logger
|
||||||
|
from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
|
||||||
|
from app.schema import AgentState, Message, ToolCall
|
||||||
|
from app.tool import CreateChatCompletion, Terminate, ToolCollection
|
||||||
|
|
||||||
|
|
||||||
|
TOOL_CALL_REQUIRED = "Tool calls required but none provided"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallAgent(ReActAgent):
|
||||||
|
"""Base agent class for handling tool/function calls with enhanced abstraction"""
|
||||||
|
|
||||||
|
name: str = "toolcall"
|
||||||
|
description: str = "an agent that can execute tool calls."
|
||||||
|
|
||||||
|
system_prompt: str = SYSTEM_PROMPT
|
||||||
|
next_step_prompt: str = NEXT_STEP_PROMPT
|
||||||
|
|
||||||
|
available_tools: ToolCollection = ToolCollection(
|
||||||
|
CreateChatCompletion(), Terminate()
|
||||||
|
)
|
||||||
|
tool_choices: Literal["none", "auto", "required"] = "auto"
|
||||||
|
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
|
||||||
|
|
||||||
|
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||||
|
|
||||||
|
max_steps: int = 30
|
||||||
|
|
||||||
|
async def think(self) -> bool:
|
||||||
|
"""Process current state and decide next actions using tools"""
|
||||||
|
if self.next_step_prompt:
|
||||||
|
user_msg = Message.user_message(self.next_step_prompt)
|
||||||
|
self.messages += [user_msg]
|
||||||
|
|
||||||
|
# Get response with tool options
|
||||||
|
response = await self.llm.ask_tool(
|
||||||
|
messages=self.messages,
|
||||||
|
system_msgs=[Message.system_message(self.system_prompt)]
|
||||||
|
if self.system_prompt
|
||||||
|
else None,
|
||||||
|
tools=self.available_tools.to_params(),
|
||||||
|
tool_choice=self.tool_choices,
|
||||||
|
)
|
||||||
|
self.tool_calls = response.tool_calls
|
||||||
|
|
||||||
|
# Log response info
|
||||||
|
logger.info(f"Tool content: {response.content}")
|
||||||
|
logger.info(
|
||||||
|
f"Tool calls count: {len(response.tool_calls) if response.tool_calls else 0}"
|
||||||
|
)
|
||||||
|
logger.info(f"Tool calls: {response.tool_calls}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Handle different tool_choices modes
|
||||||
|
if self.tool_choices == "none":
|
||||||
|
if response.tool_calls:
|
||||||
|
logger.warning("Tool calls provided when tool_choice is 'none'")
|
||||||
|
if response.content:
|
||||||
|
self.memory.add_message(Message.assistant_message(response.content))
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Create and add assistant message
|
||||||
|
assistant_msg = (
|
||||||
|
Message.from_tool_calls(
|
||||||
|
content=response.content, tool_calls=self.tool_calls
|
||||||
|
)
|
||||||
|
if self.tool_calls
|
||||||
|
else Message.assistant_message(response.content)
|
||||||
|
)
|
||||||
|
self.memory.add_message(assistant_msg)
|
||||||
|
|
||||||
|
if self.tool_choices == "required" and not self.tool_calls:
|
||||||
|
return True # Will be handled in act()
|
||||||
|
|
||||||
|
# For 'auto' mode, continue with content if no commands but content exists
|
||||||
|
if self.tool_choices == "auto" and not self.tool_calls:
|
||||||
|
return bool(response.content)
|
||||||
|
|
||||||
|
return bool(self.tool_calls)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in thinking phase: {e}")
|
||||||
|
self.memory.add_message(
|
||||||
|
Message.assistant_message(
|
||||||
|
f"Error encountered while processing: {str(e)}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def act(self) -> str:
|
||||||
|
"""Execute tool calls and handle their results"""
|
||||||
|
if not self.tool_calls:
|
||||||
|
if self.tool_choices == "required":
|
||||||
|
raise ValueError(TOOL_CALL_REQUIRED)
|
||||||
|
|
||||||
|
# Return last message content if no tool calls
|
||||||
|
return self.messages[-1].content or "No content or commands to execute"
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for command in self.tool_calls:
|
||||||
|
result = await self.execute_tool(command)
|
||||||
|
logger.info(f"Executed tool {command.function.name} with result: {result}")
|
||||||
|
|
||||||
|
# Add tool response to memory
|
||||||
|
tool_msg = Message.tool_message(
|
||||||
|
content=result, tool_call_id=command.id, name=command.function.name
|
||||||
|
)
|
||||||
|
self.memory.add_message(tool_msg)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return "\n\n".join(results)
|
||||||
|
|
||||||
|
async def execute_tool(self, command: ToolCall) -> str:
|
||||||
|
"""Execute a single tool call with robust error handling"""
|
||||||
|
if not command or not command.function or not command.function.name:
|
||||||
|
return "Error: Invalid command format"
|
||||||
|
|
||||||
|
name = command.function.name
|
||||||
|
if name not in self.available_tools.tool_map:
|
||||||
|
return f"Error: Unknown tool '{name}'"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Parse arguments
|
||||||
|
args = json.loads(command.function.arguments or "{}")
|
||||||
|
|
||||||
|
# Execute the tool
|
||||||
|
result = await self.available_tools.execute(name=name, tool_input=args)
|
||||||
|
|
||||||
|
# Format result for display
|
||||||
|
observation = (
|
||||||
|
f"Observed output of cmd `{name}` executed:\n{str(result)}"
|
||||||
|
if result
|
||||||
|
else f"Cmd `{name}` completed with no output"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle special tools like `finish`
|
||||||
|
await self._handle_special_tool(name=name, result=result)
|
||||||
|
|
||||||
|
return observation
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
error_msg = f"Error parsing arguments for {name}: Invalid JSON format"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return f"Error: {error_msg}"
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error executing tool {name}: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return f"Error: {error_msg}"
|
||||||
|
|
||||||
|
async def _handle_special_tool(self, name: str, result: Any, **kwargs):
|
||||||
|
"""Handle special tool execution and state changes"""
|
||||||
|
if not self._is_special_tool(name):
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._should_finish_execution(name=name, result=result, **kwargs):
|
||||||
|
# Set agent state to finished
|
||||||
|
self.state = AgentState.FINISHED
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _should_finish_execution(**kwargs) -> bool:
|
||||||
|
"""Determine if tool execution should finish the agent"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _is_special_tool(self, name: str) -> bool:
|
||||||
|
"""Check if tool name is in special tools list"""
|
||||||
|
return name.lower() in [n.lower() for n in self.special_tool_names]
|
112
app/config.py
Normal file
112
app/config.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
import threading
|
||||||
|
import tomllib
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
def get_project_root() -> Path:
|
||||||
|
"""Get the project root directory"""
|
||||||
|
return Path(__file__).resolve().parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_ROOT = get_project_root()
|
||||||
|
WORKSPACE_ROOT = PROJECT_ROOT / "workspace"
|
||||||
|
|
||||||
|
|
||||||
|
class LLMSettings(BaseModel):
|
||||||
|
model: str = Field(..., description="Model name")
|
||||||
|
base_url: str = Field(..., description="API base URL")
|
||||||
|
api_key: str = Field(..., description="API key")
|
||||||
|
max_tokens: int = Field(4096, description="Maximum number of tokens per request")
|
||||||
|
temperature: float = Field(1.0, description="Sampling temperature")
|
||||||
|
|
||||||
|
|
||||||
|
class ScreenshotSettings(BaseModel):
|
||||||
|
api_key: Optional[str] = Field(None, description="Screenshot API key")
|
||||||
|
base_url: Optional[str] = Field(None, description="Screenshot service URL")
|
||||||
|
|
||||||
|
|
||||||
|
class AppConfig(BaseModel):
|
||||||
|
llm: Dict[str, LLMSettings]
|
||||||
|
screenshot: Optional[ScreenshotSettings] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
_instance = None
|
||||||
|
_lock = threading.Lock()
|
||||||
|
_initialized = False
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not self._initialized:
|
||||||
|
with self._lock:
|
||||||
|
if not self._initialized:
|
||||||
|
self._config = None
|
||||||
|
self._load_initial_config()
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_config_path() -> Path:
|
||||||
|
root = PROJECT_ROOT
|
||||||
|
config_path = root / "config" / "config.toml"
|
||||||
|
if config_path.exists():
|
||||||
|
return config_path
|
||||||
|
example_path = root / "config" / "config.example.toml"
|
||||||
|
if example_path.exists():
|
||||||
|
return example_path
|
||||||
|
raise FileNotFoundError("No configuration file found in config directory")
|
||||||
|
|
||||||
|
def _load_config(self) -> dict:
|
||||||
|
config_path = self._get_config_path()
|
||||||
|
with config_path.open("rb") as f:
|
||||||
|
return tomllib.load(f)
|
||||||
|
|
||||||
|
def _load_initial_config(self):
|
||||||
|
raw_config = self._load_config()
|
||||||
|
base_llm = raw_config.get("llm", {})
|
||||||
|
llm_overrides = {
|
||||||
|
k: v for k, v in raw_config.get("llm", {}).items() if isinstance(v, dict)
|
||||||
|
}
|
||||||
|
|
||||||
|
default_settings = {
|
||||||
|
"model": base_llm.get("model"),
|
||||||
|
"base_url": base_llm.get("base_url"),
|
||||||
|
"api_key": base_llm.get("api_key"),
|
||||||
|
"max_tokens": base_llm.get("max_tokens", 4096),
|
||||||
|
"temperature": base_llm.get("temperature", 1.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
config_dict = {
|
||||||
|
"llm": {
|
||||||
|
"default": default_settings,
|
||||||
|
**{
|
||||||
|
name: {**default_settings, **override_config}
|
||||||
|
for name, override_config in llm_overrides.items()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add screenshot config if present
|
||||||
|
if screenshot_config := raw_config.get("screenshot"):
|
||||||
|
config_dict["screenshot"] = screenshot_config
|
||||||
|
|
||||||
|
self._config = AppConfig(**config_dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def screenshot(self) -> Optional[ScreenshotSettings]:
|
||||||
|
return self._config.screenshot
|
||||||
|
|
||||||
|
@property
|
||||||
|
def llm(self) -> Dict[str, LLMSettings]:
|
||||||
|
return self._config.llm
|
||||||
|
|
||||||
|
|
||||||
|
config = Config()
|
12
app/exceptions.py
Normal file
12
app/exceptions.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
class ToolError(Exception):
|
||||||
|
"""Raised when a tool encounters an error."""
|
||||||
|
|
||||||
|
def __init__(self, message):
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
|
class BrowserException(Exception):
|
||||||
|
"""Base exception for browser-related errors."""
|
||||||
|
|
||||||
|
def __init__(self, message):
|
||||||
|
super().__init__(message)
|
0
app/flow/__init__.py
Normal file
0
app/flow/__init__.py
Normal file
60
app/flow/base.py
Normal file
60
app/flow/base.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from app.agent.base import BaseAgent
|
||||||
|
from app.agent.toolcall import ToolCallAgent
|
||||||
|
|
||||||
|
|
||||||
|
class FlowType(str, Enum):
|
||||||
|
PLANNING = "planning"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFlow(ABC):
|
||||||
|
"""Base class for execution flows supporting multiple agents"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, agents: Union[BaseAgent, List[BaseAgent], Dict[str, BaseAgent]], **kwargs
|
||||||
|
):
|
||||||
|
# Handle different ways of providing agents
|
||||||
|
if isinstance(agents, BaseAgent):
|
||||||
|
self.agents = {"default": agents}
|
||||||
|
elif isinstance(agents, list):
|
||||||
|
self.agents = {f"agent_{i}": agent for i, agent in enumerate(agents)}
|
||||||
|
else:
|
||||||
|
self.agents = agents
|
||||||
|
|
||||||
|
self.tools = kwargs.get("tools")
|
||||||
|
self.primary_agent_key = kwargs.get("primary_agent", None)
|
||||||
|
|
||||||
|
# If primary agent not specified, use first agent
|
||||||
|
if not self.primary_agent_key and self.agents:
|
||||||
|
self.primary_agent_key = next(iter(self.agents))
|
||||||
|
|
||||||
|
self._setup_agents()
|
||||||
|
|
||||||
|
def _setup_agents(self):
|
||||||
|
"""Configure all agents with tools and initial setup"""
|
||||||
|
if self.tools:
|
||||||
|
for agent_key, agent in self.agents.items():
|
||||||
|
if isinstance(agent, ToolCallAgent):
|
||||||
|
agent.available_tools = self.tools
|
||||||
|
|
||||||
|
@property
|
||||||
|
def primary_agent(self) -> Optional[BaseAgent]:
|
||||||
|
"""Get the primary agent for the flow"""
|
||||||
|
return self.agents.get(self.primary_agent_key)
|
||||||
|
|
||||||
|
def get_agent(self, key: str) -> Optional[BaseAgent]:
|
||||||
|
"""Get a specific agent by key"""
|
||||||
|
return self.agents.get(key)
|
||||||
|
|
||||||
|
def add_agent(self, key: str, agent: BaseAgent) -> None:
|
||||||
|
"""Add a new agent to the flow"""
|
||||||
|
self.agents[key] = agent
|
||||||
|
if isinstance(agent, ToolCallAgent) and self.tools:
|
||||||
|
agent.available_tools = self.tools
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def execute(self, input_text: str) -> str:
|
||||||
|
"""Execute the flow with given input"""
|
25
app/flow/flow_factory.py
Normal file
25
app/flow/flow_factory.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
|
from app.agent.base import BaseAgent
|
||||||
|
from app.flow.base import BaseFlow, FlowType
|
||||||
|
from app.flow.planning import PlanningFlow
|
||||||
|
|
||||||
|
|
||||||
|
class FlowFactory:
|
||||||
|
"""Factory for creating different types of flows with support for multiple agents"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_flow(
|
||||||
|
flow_type: FlowType,
|
||||||
|
agents: Union[BaseAgent, List[BaseAgent], Dict[str, BaseAgent]],
|
||||||
|
**kwargs,
|
||||||
|
) -> BaseFlow:
|
||||||
|
flows = {
|
||||||
|
FlowType.PLANNING: PlanningFlow,
|
||||||
|
}
|
||||||
|
|
||||||
|
flow_class = flows.get(flow_type)
|
||||||
|
if not flow_class:
|
||||||
|
raise ValueError(f"Unknown flow type: {flow_type}")
|
||||||
|
|
||||||
|
return flow_class(agents, **kwargs)
|
402
app/flow/planning.py
Normal file
402
app/flow/planning.py
Normal file
@ -0,0 +1,402 @@
|
|||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from app.agent.base import BaseAgent
|
||||||
|
from app.flow.base import BaseFlow
|
||||||
|
from app.logger import logger
|
||||||
|
from app.schema import AgentState
|
||||||
|
from app.tool import PlanningTool, ToolCollection
|
||||||
|
|
||||||
|
|
||||||
|
class PlanningFlow(BaseFlow):
|
||||||
|
"""A flow that manages planning and execution of tasks using agents."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, agents: Union[BaseAgent, List[BaseAgent], Dict[str, BaseAgent]], **kwargs
|
||||||
|
):
|
||||||
|
# Initialize planning tool first
|
||||||
|
self.planning_tool = self._initialize_planning_tool(kwargs.get("tools"))
|
||||||
|
|
||||||
|
# If tools were provided, ensure planning tool is included
|
||||||
|
tools = kwargs.get("tools")
|
||||||
|
if tools:
|
||||||
|
planning_tool_exists = any(
|
||||||
|
isinstance(tool, PlanningTool) for tool in tools.tools
|
||||||
|
)
|
||||||
|
if not planning_tool_exists:
|
||||||
|
tools.add_tool(self.planning_tool)
|
||||||
|
else:
|
||||||
|
# Create a new tool collection with at least the planning tool
|
||||||
|
tools = ToolCollection(self.planning_tool)
|
||||||
|
kwargs["tools"] = tools
|
||||||
|
|
||||||
|
super().__init__(agents, **kwargs)
|
||||||
|
|
||||||
|
# Define agent roles
|
||||||
|
self.planner_key = kwargs.get("planner", self.primary_agent_key)
|
||||||
|
self.executor_keys = kwargs.get("executors", list(self.agents.keys()))
|
||||||
|
|
||||||
|
# Planning state tracking
|
||||||
|
self.active_plan_id = kwargs.get("plan_id", f"plan_{int(time.time())}")
|
||||||
|
self.current_step_index = None
|
||||||
|
|
||||||
|
# Ensure the planning tool has been initialized properly
|
||||||
|
if not hasattr(self.planning_tool, "_plans"):
|
||||||
|
self.planning_tool._plans = {}
|
||||||
|
|
||||||
|
def _initialize_planning_tool(
|
||||||
|
self, tools: Optional[ToolCollection]
|
||||||
|
) -> PlanningTool:
|
||||||
|
"""Initialize planning tool, reusing existing one if available"""
|
||||||
|
if tools:
|
||||||
|
for tool in tools.tools:
|
||||||
|
if isinstance(tool, PlanningTool):
|
||||||
|
return tool
|
||||||
|
return PlanningTool()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def planner(self) -> Optional[BaseAgent]:
|
||||||
|
"""Get the planning agent"""
|
||||||
|
return (
|
||||||
|
self.agents.get(self.planner_key)
|
||||||
|
if self.planner_key
|
||||||
|
else self.primary_agent
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_executor(self, step_type: Optional[str] = None) -> BaseAgent:
|
||||||
|
"""
|
||||||
|
Get an appropriate executor agent for the current step.
|
||||||
|
Can be extended to select agents based on step type/requirements.
|
||||||
|
"""
|
||||||
|
# If step type is provided and matches an agent key, use that agent
|
||||||
|
if step_type and step_type in self.agents:
|
||||||
|
return self.agents[step_type]
|
||||||
|
|
||||||
|
# Otherwise use the first available executor or fall back to primary agent
|
||||||
|
for key in self.executor_keys:
|
||||||
|
if key in self.agents:
|
||||||
|
return self.agents[key]
|
||||||
|
|
||||||
|
# Fallback to primary agent
|
||||||
|
return self.primary_agent
|
||||||
|
|
||||||
|
async def execute(self, input_text: str) -> str:
|
||||||
|
"""Execute the planning flow with agents."""
|
||||||
|
try:
|
||||||
|
if not self.primary_agent:
|
||||||
|
raise ValueError("No primary agent available")
|
||||||
|
|
||||||
|
# Create initial plan if input provided
|
||||||
|
if input_text:
|
||||||
|
await self._create_initial_plan(input_text)
|
||||||
|
|
||||||
|
# Verify plan was created successfully
|
||||||
|
if self.active_plan_id not in self.planning_tool._plans:
|
||||||
|
logger.error(
|
||||||
|
f"Plan creation failed. Plan ID {self.active_plan_id} not found in planning tool."
|
||||||
|
)
|
||||||
|
return f"Failed to create plan for: {input_text}"
|
||||||
|
|
||||||
|
result = ""
|
||||||
|
while True:
|
||||||
|
# Get current step to execute
|
||||||
|
self.current_step_index, step_info = await self._get_current_step_info()
|
||||||
|
|
||||||
|
# Exit if no more steps or plan completed
|
||||||
|
if self.current_step_index is None:
|
||||||
|
result += await self._finalize_plan()
|
||||||
|
break
|
||||||
|
|
||||||
|
# Execute current step with appropriate agent
|
||||||
|
step_type = step_info.get("type") if step_info else None
|
||||||
|
executor = self.get_executor(step_type)
|
||||||
|
step_result = await self._execute_step(executor, step_info)
|
||||||
|
result += step_result + "\n"
|
||||||
|
|
||||||
|
# Check if agent wants to terminate
|
||||||
|
if hasattr(executor, "state") and executor.state == AgentState.FINISHED:
|
||||||
|
break
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in PlanningFlow: {str(e)}")
|
||||||
|
return f"Execution failed: {str(e)}"
|
||||||
|
|
||||||
|
async def _create_initial_plan(self, request: str) -> None:
|
||||||
|
"""Create an initial plan based on the request using an appropriate agent."""
|
||||||
|
logger.info(f"Creating initial plan with ID: {self.active_plan_id}")
|
||||||
|
|
||||||
|
agent = self.planner if self.planner else self.primary_agent
|
||||||
|
|
||||||
|
# First, directly create an empty plan to ensure the plan ID exists
|
||||||
|
self.planning_tool._plans[self.active_plan_id] = {
|
||||||
|
"title": f"Plan for: {request[:50]}{'...' if len(request) > 50 else ''}",
|
||||||
|
"description": f"Auto-generated plan for request: {request}",
|
||||||
|
"steps": [],
|
||||||
|
"step_status": {},
|
||||||
|
"created_at": time.time(),
|
||||||
|
"updated_at": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use agent.run to create the plan
|
||||||
|
plan_prompt = f"""
|
||||||
|
I need you to create a detailed plan to accomplish this task:
|
||||||
|
|
||||||
|
{request}
|
||||||
|
|
||||||
|
Please create a plan with ID {self.active_plan_id} using the planning tool.
|
||||||
|
The plan should include all necessary steps to complete the task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
plan_result = await agent.run(plan_prompt)
|
||||||
|
logger.info(f"Plan creation result: {plan_result[:200]}...")
|
||||||
|
|
||||||
|
# Verify the plan was created
|
||||||
|
if (
|
||||||
|
self.active_plan_id not in self.planning_tool._plans
|
||||||
|
or not self.planning_tool._plans[self.active_plan_id].get("steps")
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"Plan may not have been created properly. Creating default plan."
|
||||||
|
)
|
||||||
|
await self._create_default_plan(request)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating plan: {e}")
|
||||||
|
await self._create_default_plan(request)
|
||||||
|
|
||||||
|
async def _create_default_plan(self, request: str) -> None:
|
||||||
|
"""Create a default plan if the agent fails to create one."""
|
||||||
|
try:
|
||||||
|
# Try using the planning tool directly
|
||||||
|
await self.planning_tool.execute(
|
||||||
|
command="create",
|
||||||
|
plan_id=self.active_plan_id,
|
||||||
|
title=f"Plan for: {request[:50]}{'...' if len(request) > 50 else ''}",
|
||||||
|
description=f"Auto-generated plan for request: {request}",
|
||||||
|
steps=["Analyze request", "Execute task", "Verify results"],
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create default plan with planning tool: {e}")
|
||||||
|
# Create plan directly in the planning tool's storage
|
||||||
|
self.planning_tool._plans[self.active_plan_id] = {
|
||||||
|
"title": f"Emergency Plan for: {request[:50]}{'...' if len(request) > 50 else ''}",
|
||||||
|
"description": f"Emergency auto-generated plan for request: {request}",
|
||||||
|
"steps": ["Analyze request", "Execute task", "Verify results"],
|
||||||
|
"step_status": {
|
||||||
|
"0": "not_started",
|
||||||
|
"1": "not_started",
|
||||||
|
"2": "not_started",
|
||||||
|
},
|
||||||
|
"created_at": time.time(),
|
||||||
|
"updated_at": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Created default plan with ID: {self.active_plan_id}")
|
||||||
|
|
||||||
|
async def _get_current_step_info(self) -> tuple[Optional[int], Optional[dict]]:
|
||||||
|
"""
|
||||||
|
Parse the current plan to identify the first non-completed step's index and info.
|
||||||
|
Returns (None, None) if no active step is found.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
not self.active_plan_id
|
||||||
|
or self.active_plan_id not in self.planning_tool._plans
|
||||||
|
):
|
||||||
|
logger.error(f"Plan with ID {self.active_plan_id} not found")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Direct access to step status from planning tool storage
|
||||||
|
plan_data = self.planning_tool._plans[self.active_plan_id]
|
||||||
|
steps = plan_data.get("steps", [])
|
||||||
|
step_status = plan_data.get("step_status", {})
|
||||||
|
|
||||||
|
# Find first non-completed step
|
||||||
|
for i, step in enumerate(steps):
|
||||||
|
status = step_status.get(str(i), "not_started")
|
||||||
|
if status in ["not_started", "in_progress"]:
|
||||||
|
# Extract step type/category if available
|
||||||
|
step_info = {"text": step}
|
||||||
|
|
||||||
|
# Try to extract step type from the text (e.g., [SEARCH] or [CODE])
|
||||||
|
import re
|
||||||
|
|
||||||
|
type_match = re.search(r"\[([A-Z_]+)\]", step)
|
||||||
|
if type_match:
|
||||||
|
step_info["type"] = type_match.group(1).lower()
|
||||||
|
|
||||||
|
# Mark current step as in_progress
|
||||||
|
try:
|
||||||
|
await self.planning_tool.execute(
|
||||||
|
command="mark_step",
|
||||||
|
plan_id=self.active_plan_id,
|
||||||
|
step_index=i,
|
||||||
|
step_status="in_progress",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error marking step as in_progress: {e}")
|
||||||
|
# Update step status directly
|
||||||
|
step_status[str(i)] = "in_progress"
|
||||||
|
plan_data["step_status"] = step_status
|
||||||
|
plan_data["updated_at"] = time.time()
|
||||||
|
|
||||||
|
return i, step_info
|
||||||
|
|
||||||
|
return None, None # No active step found
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error finding current step index: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
async def _execute_step(self, executor: BaseAgent, step_info: dict) -> str:
|
||||||
|
"""Execute the current step with the specified agent using agent.run()."""
|
||||||
|
# Prepare context for the agent with current plan status
|
||||||
|
plan_status = await self._get_plan_text()
|
||||||
|
step_text = step_info.get("text", f"Step {self.current_step_index}")
|
||||||
|
|
||||||
|
# Create a prompt for the agent to execute the current step
|
||||||
|
step_prompt = f"""
|
||||||
|
CURRENT PLAN STATUS:
|
||||||
|
{plan_status}
|
||||||
|
|
||||||
|
YOUR CURRENT TASK:
|
||||||
|
You are now working on step {self.current_step_index}: "{step_text}"
|
||||||
|
|
||||||
|
Please execute this step using the appropriate tools. When you're done, provide a summary of what you accomplished.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Use agent.run() to execute the step
|
||||||
|
try:
|
||||||
|
step_result = await executor.run(step_prompt)
|
||||||
|
|
||||||
|
# Mark the step as completed after successful execution
|
||||||
|
await self._mark_step_completed()
|
||||||
|
|
||||||
|
return step_result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error executing step {self.current_step_index}: {e}")
|
||||||
|
return f"Error executing step {self.current_step_index}: {str(e)}"
|
||||||
|
|
||||||
|
async def _mark_step_completed(self) -> None:
|
||||||
|
"""Mark the current step as completed."""
|
||||||
|
if self.current_step_index is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Mark the step as completed
|
||||||
|
await self.planning_tool.execute(
|
||||||
|
command="mark_step",
|
||||||
|
plan_id=self.active_plan_id,
|
||||||
|
step_index=self.current_step_index,
|
||||||
|
step_status="completed",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Marked step {self.current_step_index} as completed in plan {self.active_plan_id}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to update plan status: {e}")
|
||||||
|
# Update step status directly in planning tool storage
|
||||||
|
if self.active_plan_id in self.planning_tool._plans:
|
||||||
|
plan_data = self.planning_tool._plans[self.active_plan_id]
|
||||||
|
step_status = plan_data.get("step_status", {})
|
||||||
|
step_status[str(self.current_step_index)] = "completed"
|
||||||
|
plan_data["step_status"] = step_status
|
||||||
|
plan_data["updated_at"] = time.time()
|
||||||
|
|
||||||
|
async def _get_plan_text(self) -> str:
|
||||||
|
"""Get the current plan as formatted text."""
|
||||||
|
try:
|
||||||
|
result = await self.planning_tool.execute(
|
||||||
|
command="get", plan_id=self.active_plan_id
|
||||||
|
)
|
||||||
|
return result.output if hasattr(result, "output") else str(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting plan: {e}")
|
||||||
|
return self._generate_plan_text_from_storage()
|
||||||
|
|
||||||
|
def _generate_plan_text_from_storage(self) -> str:
|
||||||
|
"""Generate plan text directly from storage if the planning tool fails."""
|
||||||
|
try:
|
||||||
|
if self.active_plan_id not in self.planning_tool._plans:
|
||||||
|
return f"Error: Plan with ID {self.active_plan_id} not found"
|
||||||
|
|
||||||
|
plan_data = self.planning_tool._plans[self.active_plan_id]
|
||||||
|
title = plan_data.get("title", "Untitled Plan")
|
||||||
|
description = plan_data.get("description", "")
|
||||||
|
steps = plan_data.get("steps", [])
|
||||||
|
step_status = plan_data.get("step_status", {})
|
||||||
|
|
||||||
|
# Count steps by status
|
||||||
|
status_counts = {
|
||||||
|
"completed": 0,
|
||||||
|
"in_progress": 0,
|
||||||
|
"blocked": 0,
|
||||||
|
"not_started": 0,
|
||||||
|
}
|
||||||
|
for status in step_status.values():
|
||||||
|
if status in status_counts:
|
||||||
|
status_counts[status] += 1
|
||||||
|
|
||||||
|
completed = status_counts["completed"]
|
||||||
|
total = len(steps)
|
||||||
|
progress = (completed / total) * 100 if total > 0 else 0
|
||||||
|
|
||||||
|
plan_text = f"Plan: {title} (ID: {self.active_plan_id})\n"
|
||||||
|
plan_text += "=" * len(plan_text) + "\n\n"
|
||||||
|
plan_text += f"{description}\n\n" if description else ""
|
||||||
|
plan_text += (
|
||||||
|
f"Progress: {completed}/{total} steps completed ({progress:.1f}%)\n"
|
||||||
|
)
|
||||||
|
plan_text += f"Status: {status_counts['completed']} completed, {status_counts['in_progress']} in progress, "
|
||||||
|
plan_text += f"{status_counts['blocked']} blocked, {status_counts['not_started']} not started\n\n"
|
||||||
|
plan_text += "Steps:\n"
|
||||||
|
|
||||||
|
for i, step in enumerate(steps):
|
||||||
|
status = step_status.get(str(i), "not_started")
|
||||||
|
if status == "completed":
|
||||||
|
status_mark = "[✓]"
|
||||||
|
elif status == "in_progress":
|
||||||
|
status_mark = "[→]"
|
||||||
|
elif status == "blocked":
|
||||||
|
status_mark = "[!]"
|
||||||
|
else: # not_started
|
||||||
|
status_mark = "[ ]"
|
||||||
|
|
||||||
|
plan_text += f"{i}. {status_mark} {step}\n"
|
||||||
|
|
||||||
|
return plan_text
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating plan text from storage: {e}")
|
||||||
|
return f"Error: Unable to retrieve plan with ID {self.active_plan_id}"
|
||||||
|
|
||||||
|
async def _get_plan(self) -> dict:
|
||||||
|
"""Get the current plan as a dictionary."""
|
||||||
|
if (
|
||||||
|
not self.active_plan_id
|
||||||
|
or self.active_plan_id not in self.planning_tool._plans
|
||||||
|
):
|
||||||
|
return {}
|
||||||
|
return self.planning_tool._plans[self.active_plan_id]
|
||||||
|
|
||||||
|
async def _finalize_plan(self) -> str:
|
||||||
|
"""Finalize the plan and provide a summary using an appropriate agent."""
|
||||||
|
agent = self.planner if self.planner else self.primary_agent
|
||||||
|
plan_text = await self._get_plan_text()
|
||||||
|
|
||||||
|
# Create a summary prompt
|
||||||
|
summary_prompt = f"""
|
||||||
|
The plan has been completed. Here is the final plan status:
|
||||||
|
|
||||||
|
{plan_text}
|
||||||
|
|
||||||
|
Please provide a summary of what was accomplished and any final thoughts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Use agent.run() to generate the summary
|
||||||
|
try:
|
||||||
|
summary = await agent.run(summary_prompt)
|
||||||
|
return f"Plan completed:\n\n{summary}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error finalizing plan: {e}")
|
||||||
|
return "Plan completed. Error generating summary."
|
254
app/llm.py
Normal file
254
app/llm.py
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
from typing import Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from openai import (
|
||||||
|
APIError,
|
||||||
|
AsyncOpenAI,
|
||||||
|
AuthenticationError,
|
||||||
|
OpenAIError,
|
||||||
|
RateLimitError,
|
||||||
|
)
|
||||||
|
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||||
|
|
||||||
|
from app.config import LLMSettings, config
|
||||||
|
from app.logger import logger # Assuming a logger is set up in your app
|
||||||
|
from app.schema import Message
|
||||||
|
|
||||||
|
|
||||||
|
class LLM:
|
||||||
|
_instances: Dict[str, "LLM"] = {}
|
||||||
|
|
||||||
|
def __new__(
|
||||||
|
cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
|
||||||
|
):
|
||||||
|
if config_name not in cls._instances:
|
||||||
|
instance = super().__new__(cls)
|
||||||
|
instance.__init__(config_name, llm_config)
|
||||||
|
cls._instances[config_name] = instance
|
||||||
|
return cls._instances[config_name]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
|
||||||
|
):
|
||||||
|
if not hasattr(self, "client"): # Only initialize if not already initialized
|
||||||
|
llm_config = llm_config or config.llm
|
||||||
|
llm_config = llm_config.get(config_name, llm_config["default"])
|
||||||
|
self.model = llm_config.model
|
||||||
|
self.max_tokens = llm_config.max_tokens
|
||||||
|
self.temperature = llm_config.temperature
|
||||||
|
self.client = AsyncOpenAI(
|
||||||
|
api_key=llm_config.api_key, base_url=llm_config.base_url
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
|
||||||
|
"""
|
||||||
|
Format messages for LLM by converting them to OpenAI message format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages that can be either dict or Message objects
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[dict]: List of formatted messages in OpenAI format
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If messages are invalid or missing required fields
|
||||||
|
TypeError: If unsupported message types are provided
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> msgs = [
|
||||||
|
... Message.system_message("You are a helpful assistant"),
|
||||||
|
... {"role": "user", "content": "Hello"},
|
||||||
|
... Message.user_message("How are you?")
|
||||||
|
... ]
|
||||||
|
>>> formatted = LLM.format_messages(msgs)
|
||||||
|
"""
|
||||||
|
formatted_messages = []
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
if isinstance(message, dict):
|
||||||
|
# If message is already a dict, ensure it has required fields
|
||||||
|
if "role" not in message:
|
||||||
|
raise ValueError("Message dict must contain 'role' field")
|
||||||
|
formatted_messages.append(message)
|
||||||
|
elif isinstance(message, Message):
|
||||||
|
# If message is a Message object, convert it to dict
|
||||||
|
formatted_messages.append(message.to_dict())
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported message type: {type(message)}")
|
||||||
|
|
||||||
|
# Validate all messages have required fields
|
||||||
|
for msg in formatted_messages:
|
||||||
|
if msg["role"] not in ["system", "user", "assistant", "tool"]:
|
||||||
|
raise ValueError(f"Invalid role: {msg['role']}")
|
||||||
|
if "content" not in msg and "tool_calls" not in msg:
|
||||||
|
raise ValueError(
|
||||||
|
"Message must contain either 'content' or 'tool_calls'"
|
||||||
|
)
|
||||||
|
|
||||||
|
return formatted_messages
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
wait=wait_random_exponential(min=1, max=60),
|
||||||
|
stop=stop_after_attempt(6),
|
||||||
|
)
|
||||||
|
async def ask(
|
||||||
|
self,
|
||||||
|
messages: List[Union[dict, Message]],
|
||||||
|
system_msgs: Optional[List[Union[dict, Message]]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Send a prompt to the LLM and get the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of conversation messages
|
||||||
|
system_msgs: Optional system messages to prepend
|
||||||
|
stream (bool): Whether to stream the response
|
||||||
|
temperature (float): Sampling temperature for the response
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The generated response
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If messages are invalid or response is empty
|
||||||
|
OpenAIError: If API call fails after retries
|
||||||
|
Exception: For unexpected errors
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Format system and user messages
|
||||||
|
if system_msgs:
|
||||||
|
system_msgs = self.format_messages(system_msgs)
|
||||||
|
messages = system_msgs + self.format_messages(messages)
|
||||||
|
else:
|
||||||
|
messages = self.format_messages(messages)
|
||||||
|
|
||||||
|
if not stream:
|
||||||
|
# Non-streaming request
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
temperature=temperature or self.temperature,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
if not response.choices or not response.choices[0].message.content:
|
||||||
|
raise ValueError("Empty or invalid response from LLM")
|
||||||
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
# Streaming request
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
temperature=temperature or self.temperature,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
collected_messages = []
|
||||||
|
async for chunk in response:
|
||||||
|
chunk_message = chunk.choices[0].delta.content or ""
|
||||||
|
collected_messages.append(chunk_message)
|
||||||
|
print(chunk_message, end="", flush=True)
|
||||||
|
|
||||||
|
print() # Newline after streaming
|
||||||
|
full_response = "".join(collected_messages).strip()
|
||||||
|
if not full_response:
|
||||||
|
raise ValueError("Empty response from streaming LLM")
|
||||||
|
return full_response
|
||||||
|
|
||||||
|
except ValueError as ve:
|
||||||
|
logger.error(f"Validation error: {ve}")
|
||||||
|
raise
|
||||||
|
except OpenAIError as oe:
|
||||||
|
logger.error(f"OpenAI API error: {oe}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error in ask: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
wait=wait_random_exponential(min=1, max=60),
|
||||||
|
stop=stop_after_attempt(6),
|
||||||
|
)
|
||||||
|
async def ask_tool(
|
||||||
|
self,
|
||||||
|
messages: List[Union[dict, Message]],
|
||||||
|
system_msgs: Optional[List[Union[dict, Message]]] = None,
|
||||||
|
timeout: int = 60,
|
||||||
|
tools: Optional[List[dict]] = None,
|
||||||
|
tool_choice: Literal["none", "auto", "required"] = "auto",
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Ask LLM using functions/tools and return the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of conversation messages
|
||||||
|
system_msgs: Optional system messages to prepend
|
||||||
|
timeout: Request timeout in seconds
|
||||||
|
tools: List of tools to use
|
||||||
|
tool_choice: Tool choice strategy
|
||||||
|
temperature: Sampling temperature for the response
|
||||||
|
**kwargs: Additional completion arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatCompletionMessage: The model's response
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If tools, tool_choice, or messages are invalid
|
||||||
|
OpenAIError: If API call fails after retries
|
||||||
|
Exception: For unexpected errors
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate tool_choice
|
||||||
|
if tool_choice not in ["none", "auto", "required"]:
|
||||||
|
raise ValueError(f"Invalid tool_choice: {tool_choice}")
|
||||||
|
|
||||||
|
# Format messages
|
||||||
|
if system_msgs:
|
||||||
|
system_msgs = self.format_messages(system_msgs)
|
||||||
|
messages = system_msgs + self.format_messages(messages)
|
||||||
|
else:
|
||||||
|
messages = self.format_messages(messages)
|
||||||
|
|
||||||
|
# Validate tools if provided
|
||||||
|
if tools:
|
||||||
|
for tool in tools:
|
||||||
|
if not isinstance(tool, dict) or "type" not in tool:
|
||||||
|
raise ValueError("Each tool must be a dict with 'type' field")
|
||||||
|
|
||||||
|
# Set up the completion request
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=temperature or self.temperature,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if response is valid
|
||||||
|
if not response.choices or not response.choices[0].message:
|
||||||
|
print(response)
|
||||||
|
raise ValueError("Invalid or empty response from LLM")
|
||||||
|
|
||||||
|
return response.choices[0].message
|
||||||
|
|
||||||
|
except ValueError as ve:
|
||||||
|
logger.error(f"Validation error in ask_tool: {ve}")
|
||||||
|
raise
|
||||||
|
except OpenAIError as oe:
|
||||||
|
if isinstance(oe, AuthenticationError):
|
||||||
|
logger.error("Authentication failed. Check API key.")
|
||||||
|
elif isinstance(oe, RateLimitError):
|
||||||
|
logger.error("Rate limit exceeded. Consider increasing retry attempts.")
|
||||||
|
elif isinstance(oe, APIError):
|
||||||
|
logger.error(f"API error: {oe}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error in ask_tool: {e}")
|
||||||
|
raise
|
42
app/logger.py
Normal file
42
app/logger.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from loguru import logger as _logger
|
||||||
|
|
||||||
|
from app.config import PROJECT_ROOT
|
||||||
|
|
||||||
|
|
||||||
|
_print_level = "INFO"
|
||||||
|
|
||||||
|
|
||||||
|
def define_log_level(print_level="INFO", logfile_level="DEBUG", name: str = None):
|
||||||
|
"""Adjust the log level to above level"""
|
||||||
|
global _print_level
|
||||||
|
_print_level = print_level
|
||||||
|
|
||||||
|
current_date = datetime.now()
|
||||||
|
formatted_date = current_date.strftime("%Y%m%d")
|
||||||
|
log_name = (
|
||||||
|
f"{name}_{formatted_date}" if name else formatted_date
|
||||||
|
) # name a log with prefix name
|
||||||
|
|
||||||
|
_logger.remove()
|
||||||
|
_logger.add(sys.stderr, level=print_level)
|
||||||
|
_logger.add(PROJECT_ROOT / f"logs/{log_name}.txt", level=logfile_level)
|
||||||
|
return _logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = define_log_level()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logger.info("Starting application")
|
||||||
|
logger.debug("Debug message")
|
||||||
|
logger.warning("Warning message")
|
||||||
|
logger.error("Error message")
|
||||||
|
logger.critical("Critical message")
|
||||||
|
|
||||||
|
try:
|
||||||
|
raise ValueError("Test error")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"An error occurred: {e}")
|
21
app/loop.py
Normal file
21
app/loop.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from app.agent.base import BaseAgent
|
||||||
|
from app.flow.base import FlowType
|
||||||
|
from app.flow.flow_factory import FlowFactory
|
||||||
|
from app.tool import BaseTool, ToolCollection
|
||||||
|
|
||||||
|
|
||||||
|
async def loop(
|
||||||
|
agent: BaseAgent,
|
||||||
|
tools: Optional[List[BaseTool]] = None,
|
||||||
|
flow_type: FlowType = FlowType.PLANNING,
|
||||||
|
input_text: str = "",
|
||||||
|
**loop_kwargs,
|
||||||
|
) -> str:
|
||||||
|
"""Main entry point for running an agent with specified flow type"""
|
||||||
|
tool_collection = ToolCollection(*tools) if tools else None
|
||||||
|
flow = FlowFactory.create_flow(
|
||||||
|
flow_type, agent, tool_collection=tool_collection, **loop_kwargs
|
||||||
|
)
|
||||||
|
return await flow.execute(input_text)
|
0
app/prompt/__init__.py
Normal file
0
app/prompt/__init__.py
Normal file
72
app/prompt/swe.py
Normal file
72
app/prompt/swe.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
SYSTEM_PROMPT = """SETTING: You are an autonomous programmer, and you're working directly in the command line with a special interface.
|
||||||
|
|
||||||
|
The special interface consists of a file editor that shows you {{WINDOW}} lines of a file at a time.
|
||||||
|
In addition to typical bash commands, you can also use specific commands to help you navigate and edit files.
|
||||||
|
To call a command, you need to invoke it with a function call/tool call.
|
||||||
|
|
||||||
|
Please note that THE EDIT COMMAND REQUIRES PROPER INDENTATION.
|
||||||
|
If you'd like to add the line ' print(x)' you must fully write that out, with all those spaces before the code! Indentation is important and code that is not indented correctly will fail and require fixing before it can be run.
|
||||||
|
|
||||||
|
RESPONSE FORMAT:
|
||||||
|
Your shell prompt is formatted as follows:
|
||||||
|
(Open file: <path>)
|
||||||
|
(Current directory: <cwd>)
|
||||||
|
bash-$
|
||||||
|
|
||||||
|
First, you should _always_ include a general thought about what you're going to do next.
|
||||||
|
Then, for every response, you must include exactly _ONE_ tool call/function call.
|
||||||
|
|
||||||
|
Remember, you should always include a _SINGLE_ tool call/function call and then wait for a response from the shell before continuing with more discussion and commands. Everything you include in the DISCUSSION section will be saved for future reference.
|
||||||
|
If you'd like to issue two commands at once, PLEASE DO NOT DO THAT! Please instead first submit just the first tool call, and then after receiving a response you'll be able to issue the second tool call.
|
||||||
|
Note that the environment does NOT support interactive session commands (e.g. python, vim), so please do not invoke them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
NEXT_STEP_TEMPLATE = """{{observation}}
|
||||||
|
(Open file: {{open_file}})
|
||||||
|
(Current directory: {{working_dir}})
|
||||||
|
bash-$
|
||||||
|
"""
|
||||||
|
|
||||||
|
NEXT_STEP_NO_OUTPUT_TEMPLATE = """Your command ran successfully and did not produce any output.
|
||||||
|
(Open file: {{open_file}})
|
||||||
|
(Current directory: {{working_dir}})
|
||||||
|
bash-$
|
||||||
|
"""
|
||||||
|
|
||||||
|
INSTANCE_TEMPLATE = """We're currently solving the following issue within our repository. Here's the issue text:
|
||||||
|
ISSUE:
|
||||||
|
{{problem_statement}}
|
||||||
|
|
||||||
|
INSTRUCTIONS:
|
||||||
|
Now, you're going to solve this issue on your own. Your terminal session has started and you're in the repository's root directory. You can use any bash commands or the special interface to help you. Edit all the files you need to and run any checks or tests that you want.
|
||||||
|
Remember, YOU SHOULD ALWAYS INCLUDE EXACTLY ONE TOOL CALL/FUNCTION CALL PER RESPONSE.
|
||||||
|
When you're satisfied with all of the changes you've made, you can submit your changes to the code base by simply running the submit command.
|
||||||
|
Note however that you cannot use any interactive session commands (e.g. python, vim) in this environment, but you can write scripts and run them. E.g. you can write a python script and then run it with the python <script_name>.py`.
|
||||||
|
|
||||||
|
NOTE ABOUT THE EDIT COMMAND: Indentation really matters! When editing a file, make sure to insert appropriate indentation before each line!
|
||||||
|
|
||||||
|
IMPORTANT TIPS:
|
||||||
|
1. Always start by trying to replicate the bug that the issues discusses.
|
||||||
|
If the issue includes code for reproducing the bug, we recommend that you re-implement that in your environment, and run it to make sure you can reproduce the bug.
|
||||||
|
Then start trying to fix it.
|
||||||
|
When you think you've fixed the bug, re-run the bug reproduction script to make sure that the bug has indeed been fixed.
|
||||||
|
|
||||||
|
If the bug reproduction script does not print anything when it successfully runs, we recommend adding a print("Script completed successfully, no errors.") command at the end of the file,
|
||||||
|
so that you can be sure that the script indeed ran fine all the way through.
|
||||||
|
|
||||||
|
2. If you run a command and it doesn't work, try running a different command. A command that did not work once will not work the second time unless you modify it!
|
||||||
|
|
||||||
|
3. If you open a file and need to get to an area around a specific line that is not in the first 100 lines, say line 583, don't just use the scroll_down command multiple times. Instead, use the goto 583 command. It's much quicker.
|
||||||
|
|
||||||
|
4. If the bug reproduction script requires inputting/reading a specific file, such as buggy-input.png, and you'd like to understand how to input that file, conduct a search in the existing repo code, to see whether someone else has already done that. Do this by running the command: find_file "buggy-input.png" If that doesn't work, use the linux 'find' command.
|
||||||
|
|
||||||
|
5. Always make sure to look at the currently open file and the current working directory (which appears right after the currently open file). The currently open file might be in a different directory than the working directory! Note that some commands, such as 'create', open files, so they might change the current open file.
|
||||||
|
|
||||||
|
6. When editing files, it is easy to accidentally specify a wrong line number or to write code with incorrect indentation. Always check the code after you issue an edit to make sure that it reflects what you wanted to accomplish. If it didn't, issue another command to fix it.
|
||||||
|
|
||||||
|
7. Do not try to install any packages with `pip`, `conda`, or any other way. This will usually not work. If the environment is not set up correctly, try to fix the issue without executing python code or running any tests that require the package installed.
|
||||||
|
|
||||||
|
|
||||||
|
(Open file: {{open_file}})
|
||||||
|
(Current directory: {{working_dir}})
|
||||||
|
bash-$"""
|
5
app/prompt/toolcall.py
Normal file
5
app/prompt/toolcall.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
SYSTEM_PROMPT = "You are an agent that can execute tool calls"
|
||||||
|
|
||||||
|
NEXT_STEP_PROMPT = (
|
||||||
|
"If you want to stop interaction, use `terminate` tool/function call."
|
||||||
|
)
|
0
app/runtime/__init__.py
Normal file
0
app/runtime/__init__.py
Normal file
259
app/runtime/browser_env.py
Normal file
259
app/runtime/browser_env.py
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
import atexit
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import multiprocessing
|
||||||
|
import platform
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import browsergym.core # noqa F401 (we register the openended task as a gym environment)
|
||||||
|
import gymnasium as gym
|
||||||
|
import html2text
|
||||||
|
import numpy as np
|
||||||
|
import tenacity
|
||||||
|
from browsergym.utils.obs import flatten_dom_to_str
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from app.exceptions import BrowserException
|
||||||
|
from app.logger import logger
|
||||||
|
from app.utils.shutdown_listener import should_continue, should_exit
|
||||||
|
|
||||||
|
|
||||||
|
BROWSER_EVAL_GET_GOAL_ACTION = "GET_EVAL_GOAL"
|
||||||
|
BROWSER_EVAL_GET_REWARDS_ACTION = "GET_EVAL_REWARDS"
|
||||||
|
|
||||||
|
|
||||||
|
class BrowserEnv:
|
||||||
|
def __init__(self, browsergym_eval_env: str | None = None, headless: bool = False):
|
||||||
|
"""
|
||||||
|
Initialize the browser environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
browsergym_eval_env: Optional evaluation environment name
|
||||||
|
headless: Whether to run the browser in headless mode (no UI)
|
||||||
|
"""
|
||||||
|
self.html_text_converter = self.get_html_text_converter()
|
||||||
|
self.eval_mode = False
|
||||||
|
self.eval_dir = ""
|
||||||
|
self.browsergym_eval_env = browsergym_eval_env
|
||||||
|
self.eval_mode = bool(browsergym_eval_env)
|
||||||
|
self.headless = headless
|
||||||
|
|
||||||
|
# Set multiprocessing start method
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
multiprocessing.set_start_method("spawn", force=True)
|
||||||
|
else:
|
||||||
|
multiprocessing.set_start_method("fork", force=True)
|
||||||
|
|
||||||
|
self.browser_side, self.agent_side = multiprocessing.Pipe()
|
||||||
|
self.process = None # Initialize process as None
|
||||||
|
self.init_browser()
|
||||||
|
atexit.register(self.close)
|
||||||
|
|
||||||
|
def get_html_text_converter(self):
|
||||||
|
html_text_converter = html2text.HTML2Text()
|
||||||
|
# ignore links and images
|
||||||
|
html_text_converter.ignore_links = False
|
||||||
|
html_text_converter.ignore_images = True
|
||||||
|
# use alt text for images
|
||||||
|
html_text_converter.images_to_alt = True
|
||||||
|
# disable auto text wrapping
|
||||||
|
html_text_converter.body_width = 0
|
||||||
|
return html_text_converter
|
||||||
|
|
||||||
|
@tenacity.retry(
|
||||||
|
wait=tenacity.wait_fixed(1),
|
||||||
|
stop=tenacity.stop_after_attempt(5),
|
||||||
|
retry=tenacity.retry_if_exception_type(BrowserException),
|
||||||
|
)
|
||||||
|
def init_browser(self):
|
||||||
|
logger.debug(f"Starting browser env (headless: {self.headless})...")
|
||||||
|
try:
|
||||||
|
self.process = multiprocessing.Process(
|
||||||
|
target=self.browser_process, args=(self.headless,)
|
||||||
|
)
|
||||||
|
self.process.start()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start browser process: {e}")
|
||||||
|
if self.process is not None:
|
||||||
|
self.process.terminate()
|
||||||
|
raise BrowserException("Failed to start browser environment.")
|
||||||
|
|
||||||
|
if not self.check_alive():
|
||||||
|
self.close()
|
||||||
|
raise BrowserException("Failed to start browser environment.")
|
||||||
|
|
||||||
|
def browser_process(self, headless: bool):
|
||||||
|
if self.eval_mode:
|
||||||
|
assert self.browsergym_eval_env is not None
|
||||||
|
logger.debug("Initializing browser env for web browsing evaluation.")
|
||||||
|
if "webarena" in self.browsergym_eval_env:
|
||||||
|
import browsergym.webarena # noqa F401 register webarena tasks as gym environments
|
||||||
|
elif "miniwob" in self.browsergym_eval_env:
|
||||||
|
import browsergym.miniwob # noqa F401 register miniwob tasks as gym environments
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported browsergym eval env: {self.browsergym_eval_env}"
|
||||||
|
)
|
||||||
|
env = gym.make(
|
||||||
|
self.browsergym_eval_env,
|
||||||
|
tags_to_mark="all",
|
||||||
|
headless=headless,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
env = gym.make(
|
||||||
|
"browsergym/openended",
|
||||||
|
task_kwargs={"start_url": "about:blank", "goal": "PLACEHOLDER_GOAL"},
|
||||||
|
wait_for_user_message=False,
|
||||||
|
headless=headless,
|
||||||
|
disable_env_checker=True,
|
||||||
|
tags_to_mark="all",
|
||||||
|
)
|
||||||
|
|
||||||
|
obs, info = env.reset()
|
||||||
|
|
||||||
|
# EVAL ONLY: save the goal into file for evaluation
|
||||||
|
self.eval_goal = None
|
||||||
|
self.eval_rewards: list[float] = [0]
|
||||||
|
if self.eval_mode:
|
||||||
|
logger.debug(f"Browsing goal: {obs['goal']}")
|
||||||
|
self.eval_goal = obs["goal"]
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Browser env started in {'headless' if headless else 'visible'} mode."
|
||||||
|
)
|
||||||
|
while should_continue():
|
||||||
|
try:
|
||||||
|
if self.browser_side.poll(timeout=0.01):
|
||||||
|
unique_request_id, action_data = self.browser_side.recv()
|
||||||
|
|
||||||
|
# shutdown the browser environment
|
||||||
|
if unique_request_id == "SHUTDOWN":
|
||||||
|
logger.debug("SHUTDOWN recv, shutting down browser env...")
|
||||||
|
env.close()
|
||||||
|
return
|
||||||
|
elif unique_request_id == "IS_ALIVE":
|
||||||
|
self.browser_side.send(("ALIVE", None))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# EVAL ONLY: Get evaluation info
|
||||||
|
if action_data["action"] == BROWSER_EVAL_GET_GOAL_ACTION:
|
||||||
|
self.browser_side.send(
|
||||||
|
(unique_request_id, {"text_content": self.eval_goal})
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
elif action_data["action"] == BROWSER_EVAL_GET_REWARDS_ACTION:
|
||||||
|
self.browser_side.send(
|
||||||
|
(
|
||||||
|
unique_request_id,
|
||||||
|
{"text_content": json.dumps(self.eval_rewards)},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
action = action_data["action"]
|
||||||
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
|
|
||||||
|
# EVAL ONLY: Save the rewards into file for evaluation
|
||||||
|
if self.eval_mode:
|
||||||
|
self.eval_rewards.append(reward)
|
||||||
|
|
||||||
|
# add text content of the page
|
||||||
|
html_str = flatten_dom_to_str(obs["dom_object"])
|
||||||
|
obs["text_content"] = self.html_text_converter.handle(html_str)
|
||||||
|
# make observation serializable
|
||||||
|
obs["screenshot"] = self.image_to_png_base64_url(obs["screenshot"])
|
||||||
|
obs["active_page_index"] = obs["active_page_index"].item()
|
||||||
|
obs["elapsed_time"] = obs["elapsed_time"].item()
|
||||||
|
self.browser_side.send((unique_request_id, obs))
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.debug("Browser env process interrupted by user.")
|
||||||
|
try:
|
||||||
|
env.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
|
||||||
|
def step(self, action_str: str, timeout: float = 30) -> dict:
|
||||||
|
"""Execute an action in the browser environment and return the observation."""
|
||||||
|
unique_request_id = str(uuid.uuid4())
|
||||||
|
self.agent_side.send((unique_request_id, {"action": action_str}))
|
||||||
|
start_time = time.time()
|
||||||
|
while True:
|
||||||
|
if should_exit() or time.time() - start_time > timeout:
|
||||||
|
raise TimeoutError("Browser environment took too long to respond.")
|
||||||
|
if self.agent_side.poll(timeout=0.01):
|
||||||
|
response_id, obs = self.agent_side.recv()
|
||||||
|
if response_id == unique_request_id:
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def check_alive(self, timeout: float = 60):
|
||||||
|
self.agent_side.send(("IS_ALIVE", None))
|
||||||
|
if self.agent_side.poll(timeout=timeout):
|
||||||
|
response_id, _ = self.agent_side.recv()
|
||||||
|
if response_id == "ALIVE":
|
||||||
|
return True
|
||||||
|
logger.debug(f"Browser env is not alive. Response ID: {response_id}")
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if (
|
||||||
|
not hasattr(self, "process")
|
||||||
|
or self.process is None
|
||||||
|
or not self.process.is_alive()
|
||||||
|
):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self.agent_side.send(("SHUTDOWN", None))
|
||||||
|
self.process.join(5) # Wait for the process to terminate
|
||||||
|
if self.process.is_alive():
|
||||||
|
logger.error(
|
||||||
|
"Browser process did not terminate, forcefully terminating..."
|
||||||
|
)
|
||||||
|
self.process.terminate()
|
||||||
|
self.process.join(5) # Wait for the process to terminate
|
||||||
|
if self.process.is_alive():
|
||||||
|
self.process.kill()
|
||||||
|
self.process.join(5) # Wait for the process to terminate
|
||||||
|
self.agent_side.close()
|
||||||
|
self.browser_side.close()
|
||||||
|
except Exception:
|
||||||
|
logger.error("Encountered an error when closing browser env", exc_info=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def image_to_png_base64_url(
|
||||||
|
image: np.ndarray | Image.Image, add_data_prefix: bool = False
|
||||||
|
):
|
||||||
|
"""Convert a numpy array to a base64 encoded png image url."""
|
||||||
|
if isinstance(image, np.ndarray):
|
||||||
|
image = Image.fromarray(image)
|
||||||
|
if image.mode in ("RGBA", "LA"):
|
||||||
|
image = image.convert("RGB")
|
||||||
|
buffered = io.BytesIO()
|
||||||
|
image.save(buffered, format="PNG")
|
||||||
|
|
||||||
|
image_base64 = base64.b64encode(buffered.getvalue()).decode()
|
||||||
|
return (
|
||||||
|
f"data:image/png;base64,{image_base64}"
|
||||||
|
if add_data_prefix
|
||||||
|
else f"{image_base64}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def image_to_jpg_base64_url(
|
||||||
|
image: np.ndarray | Image.Image, add_data_prefix: bool = False
|
||||||
|
):
|
||||||
|
"""Convert a numpy array to a base64 encoded jpeg image url."""
|
||||||
|
if isinstance(image, np.ndarray):
|
||||||
|
image = Image.fromarray(image)
|
||||||
|
if image.mode in ("RGBA", "LA"):
|
||||||
|
image = image.convert("RGB")
|
||||||
|
buffered = io.BytesIO()
|
||||||
|
image.save(buffered, format="JPEG")
|
||||||
|
|
||||||
|
image_base64 = base64.b64encode(buffered.getvalue()).decode()
|
||||||
|
return (
|
||||||
|
f"data:image/jpeg;base64,{image_base64}"
|
||||||
|
if add_data_prefix
|
||||||
|
else f"{image_base64}"
|
||||||
|
)
|
135
app/schema.py
Normal file
135
app/schema.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Any, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class AgentState(str, Enum):
|
||||||
|
"""Agent execution states"""
|
||||||
|
|
||||||
|
IDLE = "IDLE"
|
||||||
|
RUNNING = "RUNNING"
|
||||||
|
FINISHED = "FINISHED"
|
||||||
|
ERROR = "ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
class Function(BaseModel):
|
||||||
|
name: str
|
||||||
|
arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall(BaseModel):
|
||||||
|
"""Represents a tool/function call in a message"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
type: str = "function"
|
||||||
|
function: Function
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
"""Represents a chat message in the conversation"""
|
||||||
|
|
||||||
|
role: Literal["system", "user", "assistant", "tool"] = Field(...)
|
||||||
|
content: Optional[str] = Field(default=None)
|
||||||
|
tool_calls: Optional[List[ToolCall]] = Field(default=None)
|
||||||
|
name: Optional[str] = Field(default=None)
|
||||||
|
tool_call_id: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
def __add__(self, other) -> List["Message"]:
|
||||||
|
"""支持 Message + list 或 Message + Message 的操作"""
|
||||||
|
if isinstance(other, list):
|
||||||
|
return [self] + other
|
||||||
|
elif isinstance(other, Message):
|
||||||
|
return [self, other]
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __radd__(self, other) -> List["Message"]:
|
||||||
|
"""支持 list + Message 的操作"""
|
||||||
|
if isinstance(other, list):
|
||||||
|
return other + [self]
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert message to dictionary format"""
|
||||||
|
message = {"role": self.role}
|
||||||
|
if self.content is not None:
|
||||||
|
message["content"] = self.content
|
||||||
|
if self.tool_calls is not None:
|
||||||
|
message["tool_calls"] = [tool_call.dict() for tool_call in self.tool_calls]
|
||||||
|
if self.name is not None:
|
||||||
|
message["name"] = self.name
|
||||||
|
if self.tool_call_id is not None:
|
||||||
|
message["tool_call_id"] = self.tool_call_id
|
||||||
|
return message
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def user_message(cls, content: str) -> "Message":
|
||||||
|
"""Create a user message"""
|
||||||
|
return cls(role="user", content=content)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def system_message(cls, content: str) -> "Message":
|
||||||
|
"""Create a system message"""
|
||||||
|
return cls(role="system", content=content)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def assistant_message(cls, content: Optional[str] = None) -> "Message":
|
||||||
|
"""Create an assistant message"""
|
||||||
|
return cls(role="assistant", content=content)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
|
||||||
|
"""Create a tool message"""
|
||||||
|
return cls(role="tool", content=content, name=name, tool_call_id=tool_call_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tool_calls(
|
||||||
|
cls, tool_calls: List[Any], content: Union[str, List[str]] = "", **kwargs
|
||||||
|
):
|
||||||
|
"""Create ToolCallsMessage from raw tool calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_calls: Raw tool calls from LLM
|
||||||
|
content: Optional message content
|
||||||
|
"""
|
||||||
|
formatted_calls = [
|
||||||
|
{"id": call.id, "function": call.function.model_dump(), "type": "function"}
|
||||||
|
for call in tool_calls
|
||||||
|
]
|
||||||
|
return cls(
|
||||||
|
role="assistant", content=content, tool_calls=formatted_calls, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Memory(BaseModel):
|
||||||
|
messages: List[Message] = Field(default_factory=list)
|
||||||
|
max_messages: int = Field(default=100)
|
||||||
|
|
||||||
|
def add_message(self, message: Message) -> None:
|
||||||
|
"""Add a message to memory"""
|
||||||
|
self.messages.append(message)
|
||||||
|
# Optional: Implement message limit
|
||||||
|
if len(self.messages) > self.max_messages:
|
||||||
|
self.messages = self.messages[-self.max_messages :]
|
||||||
|
|
||||||
|
def add_messages(self, messages: List[Message]) -> None:
|
||||||
|
"""Add multiple messages to memory"""
|
||||||
|
self.messages.extend(messages)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear all messages"""
|
||||||
|
self.messages.clear()
|
||||||
|
|
||||||
|
def get_recent_messages(self, n: int) -> List[Message]:
|
||||||
|
"""Get n most recent messages"""
|
||||||
|
return self.messages[-n:]
|
||||||
|
|
||||||
|
def to_dict_list(self) -> List[dict]:
|
||||||
|
"""Convert messages to list of dicts"""
|
||||||
|
return [msg.to_dict() for msg in self.messages]
|
18
app/tool/__init__.py
Normal file
18
app/tool/__init__.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from app.tool.base import BaseTool
|
||||||
|
from app.tool.bash import Bash
|
||||||
|
from app.tool.create_chat_completion import CreateChatCompletion
|
||||||
|
from app.tool.planning import PlanningTool
|
||||||
|
from app.tool.str_replace_editor import StrReplaceEditor
|
||||||
|
from app.tool.terminate import Terminate
|
||||||
|
from app.tool.tool_collection import ToolCollection
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseTool",
|
||||||
|
"Bash",
|
||||||
|
"Terminate",
|
||||||
|
"StrReplaceEditor",
|
||||||
|
"ToolCollection",
|
||||||
|
"CreateChatCompletion",
|
||||||
|
"PlanningTool",
|
||||||
|
]
|
82
app/tool/base.py
Normal file
82
app/tool/base.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTool(ABC, BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
parameters: Optional[dict] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
async def __call__(self, **kwargs) -> Any:
|
||||||
|
"""Execute the tool with given parameters."""
|
||||||
|
return await self.execute(**kwargs)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def execute(self, **kwargs) -> Any:
|
||||||
|
"""Execute the tool with given parameters."""
|
||||||
|
|
||||||
|
def to_param(self) -> Dict:
|
||||||
|
"""Convert tool to function call format."""
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"parameters": self.parameters,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResult(BaseModel):
|
||||||
|
"""Represents the result of a tool execution."""
|
||||||
|
|
||||||
|
output: Any = Field(default=None)
|
||||||
|
error: Optional[str] = Field(default=None)
|
||||||
|
system: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return any(getattr(self, field) for field in self.__fields__)
|
||||||
|
|
||||||
|
def __add__(self, other: "ToolResult"):
|
||||||
|
def combine_fields(
|
||||||
|
field: Optional[str], other_field: Optional[str], concatenate: bool = True
|
||||||
|
):
|
||||||
|
if field and other_field:
|
||||||
|
if concatenate:
|
||||||
|
return field + other_field
|
||||||
|
raise ValueError("Cannot combine tool results")
|
||||||
|
return field or other_field
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
output=combine_fields(self.output, other.output),
|
||||||
|
error=combine_fields(self.error, other.error),
|
||||||
|
system=combine_fields(self.system, other.system),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"Error: {self.error}" if self.error else self.output
|
||||||
|
|
||||||
|
def replace(self, **kwargs):
|
||||||
|
"""Returns a new ToolResult with the given fields replaced."""
|
||||||
|
# return self.copy(update=kwargs)
|
||||||
|
return type(self)(**{**self.dict(), **kwargs})
|
||||||
|
|
||||||
|
|
||||||
|
class CLIResult(ToolResult):
|
||||||
|
"""A ToolResult that can be rendered as a CLI output."""
|
||||||
|
|
||||||
|
|
||||||
|
class ToolFailure(ToolResult):
|
||||||
|
"""A ToolResult that represents a failure."""
|
||||||
|
|
||||||
|
|
||||||
|
class AgentAwareTool:
|
||||||
|
agent: Optional = None
|
158
app/tool/bash.py
Normal file
158
app/tool/bash.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.exceptions import ToolError
|
||||||
|
from app.tool.base import BaseTool, CLIResult, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
_BASH_DESCRIPTION = """Execute a bash command in the terminal.
|
||||||
|
* Long running commands: For commands that may run indefinitely, it should be run in the background and the output should be redirected to a file, e.g. command = `python3 app.py > server.log 2>&1 &`.
|
||||||
|
* Interactive: If a bash command returns exit code `-1`, this means the process is not yet finished. The assistant must then send a second call to terminal with an empty `command` (which will retrieve any additional logs), or it can send additional text (set `command` to the text) to STDIN of the running process, or it can send command=`ctrl+c` to interrupt the process.
|
||||||
|
* Timeout: If a command execution result says "Command timed out. Sending SIGINT to the process", the assistant should retry running the command in the background.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class _BashSession:
|
||||||
|
"""A session of a bash shell."""
|
||||||
|
|
||||||
|
_started: bool
|
||||||
|
_process: asyncio.subprocess.Process
|
||||||
|
|
||||||
|
command: str = "/bin/bash"
|
||||||
|
_output_delay: float = 0.2 # seconds
|
||||||
|
_timeout: float = 120.0 # seconds
|
||||||
|
_sentinel: str = "<<exit>>"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._started = False
|
||||||
|
self._timed_out = False
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._started:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._process = await asyncio.create_subprocess_shell(
|
||||||
|
self.command,
|
||||||
|
preexec_fn=os.setsid,
|
||||||
|
shell=True,
|
||||||
|
bufsize=0,
|
||||||
|
stdin=asyncio.subprocess.PIPE,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._started = True
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Terminate the bash shell."""
|
||||||
|
if not self._started:
|
||||||
|
raise ToolError("Session has not started.")
|
||||||
|
if self._process.returncode is not None:
|
||||||
|
return
|
||||||
|
self._process.terminate()
|
||||||
|
|
||||||
|
async def run(self, command: str):
|
||||||
|
"""Execute a command in the bash shell."""
|
||||||
|
if not self._started:
|
||||||
|
raise ToolError("Session has not started.")
|
||||||
|
if self._process.returncode is not None:
|
||||||
|
return ToolResult(
|
||||||
|
system="tool must be restarted",
|
||||||
|
error=f"bash has exited with returncode {self._process.returncode}",
|
||||||
|
)
|
||||||
|
if self._timed_out:
|
||||||
|
raise ToolError(
|
||||||
|
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||||
|
)
|
||||||
|
|
||||||
|
# we know these are not None because we created the process with PIPEs
|
||||||
|
assert self._process.stdin
|
||||||
|
assert self._process.stdout
|
||||||
|
assert self._process.stderr
|
||||||
|
|
||||||
|
# send command to the process
|
||||||
|
self._process.stdin.write(
|
||||||
|
command.encode() + f"; echo '{self._sentinel}'\n".encode()
|
||||||
|
)
|
||||||
|
await self._process.stdin.drain()
|
||||||
|
|
||||||
|
# read output from the process, until the sentinel is found
|
||||||
|
try:
|
||||||
|
async with asyncio.timeout(self._timeout):
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(self._output_delay)
|
||||||
|
# if we read directly from stdout/stderr, it will wait forever for
|
||||||
|
# EOF. use the StreamReader buffer directly instead.
|
||||||
|
output = (
|
||||||
|
self._process.stdout._buffer.decode()
|
||||||
|
) # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
if self._sentinel in output:
|
||||||
|
# strip the sentinel and break
|
||||||
|
output = output[: output.index(self._sentinel)]
|
||||||
|
break
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._timed_out = True
|
||||||
|
raise ToolError(
|
||||||
|
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||||
|
) from None
|
||||||
|
|
||||||
|
if output.endswith("\n"):
|
||||||
|
output = output[:-1]
|
||||||
|
|
||||||
|
error = (
|
||||||
|
self._process.stderr._buffer.decode()
|
||||||
|
) # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
if error.endswith("\n"):
|
||||||
|
error = error[:-1]
|
||||||
|
|
||||||
|
# clear the buffers so that the next output can be read correctly
|
||||||
|
self._process.stdout._buffer.clear() # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
self._process.stderr._buffer.clear() # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
|
||||||
|
return CLIResult(output=output, error=error)
|
||||||
|
|
||||||
|
|
||||||
|
class Bash(BaseTool):
|
||||||
|
"""A tool for executing bash commands"""
|
||||||
|
|
||||||
|
name: str = "bash"
|
||||||
|
description: str = _BASH_DESCRIPTION
|
||||||
|
parameters: dict = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The bash command to execute. Can be empty to view additional logs when previous exit code is `-1`. Can be `ctrl+c` to interrupt the currently running process.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
}
|
||||||
|
|
||||||
|
_session: Optional[_BashSession] = None
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self, command: str | None = None, restart: bool = False, **kwargs
|
||||||
|
) -> CLIResult:
|
||||||
|
if restart:
|
||||||
|
if self._session:
|
||||||
|
self._session.stop()
|
||||||
|
self._session = _BashSession()
|
||||||
|
await self._session.start()
|
||||||
|
|
||||||
|
return ToolResult(system="tool has been restarted.")
|
||||||
|
|
||||||
|
if self._session is None:
|
||||||
|
self._session = _BashSession()
|
||||||
|
await self._session.start()
|
||||||
|
|
||||||
|
if command is not None:
|
||||||
|
return await self._session.run(command)
|
||||||
|
|
||||||
|
raise ToolError("no command provided.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
bash = Bash()
|
||||||
|
rst = asyncio.run(bash.execute("ls -l"))
|
||||||
|
print(rst)
|
169
app/tool/create_chat_completion.py
Normal file
169
app/tool/create_chat_completion.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
from typing import Any, List, Optional, Type, Union, get_args, get_origin
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.tool import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class CreateChatCompletion(BaseTool):
|
||||||
|
name: str = "create_chat_completion"
|
||||||
|
description: str = (
|
||||||
|
"Creates a structured completion with specified output formatting."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Type mapping for JSON schema
|
||||||
|
type_mapping: dict = {
|
||||||
|
str: "string",
|
||||||
|
int: "integer",
|
||||||
|
float: "number",
|
||||||
|
bool: "boolean",
|
||||||
|
dict: "object",
|
||||||
|
list: "array",
|
||||||
|
}
|
||||||
|
response_type: Optional[Type] = None
|
||||||
|
required: List[str] = Field(default_factory=lambda: ["response"])
|
||||||
|
|
||||||
|
def __init__(self, response_type: Optional[Type] = str):
|
||||||
|
"""Initialize with a specific response type."""
|
||||||
|
super().__init__()
|
||||||
|
self.response_type = response_type
|
||||||
|
self.parameters = self._build_parameters()
|
||||||
|
|
||||||
|
def _build_parameters(self) -> dict:
|
||||||
|
"""Build parameters schema based on response type."""
|
||||||
|
if self.response_type == str:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"response": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The response text that should be delivered to the user.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": self.required,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(self.response_type, type) and issubclass(
|
||||||
|
self.response_type, BaseModel
|
||||||
|
):
|
||||||
|
schema = self.response_type.model_json_schema()
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": schema["properties"],
|
||||||
|
"required": schema.get("required", self.required),
|
||||||
|
}
|
||||||
|
|
||||||
|
return self._create_type_schema(self.response_type)
|
||||||
|
|
||||||
|
def _create_type_schema(self, type_hint: Type) -> dict:
|
||||||
|
"""Create a JSON schema for the given type."""
|
||||||
|
origin = get_origin(type_hint)
|
||||||
|
args = get_args(type_hint)
|
||||||
|
|
||||||
|
# Handle primitive types
|
||||||
|
if origin is None:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"response": {
|
||||||
|
"type": self.type_mapping.get(type_hint, "string"),
|
||||||
|
"description": f"Response of type {type_hint.__name__}",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": self.required,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle List type
|
||||||
|
if origin is list:
|
||||||
|
item_type = args[0] if args else Any
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"response": {
|
||||||
|
"type": "array",
|
||||||
|
"items": self._get_type_info(item_type),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": self.required,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle Dict type
|
||||||
|
if origin is dict:
|
||||||
|
value_type = args[1] if len(args) > 1 else Any
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"response": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": self._get_type_info(value_type),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": self.required,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle Union type
|
||||||
|
if origin is Union:
|
||||||
|
return self._create_union_schema(args)
|
||||||
|
|
||||||
|
return self._build_parameters()
|
||||||
|
|
||||||
|
def _get_type_info(self, type_hint: Type) -> dict:
|
||||||
|
"""Get type information for a single type."""
|
||||||
|
if isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
||||||
|
return type_hint.model_json_schema()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": self.type_mapping.get(type_hint, "string"),
|
||||||
|
"description": f"Value of type {getattr(type_hint, '__name__', 'any')}",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_union_schema(self, types: tuple) -> dict:
|
||||||
|
"""Create schema for Union types."""
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"response": {"anyOf": [self._get_type_info(t) for t in types]}
|
||||||
|
},
|
||||||
|
"required": self.required,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, required: list | None = None, **kwargs) -> Any:
|
||||||
|
"""Execute the chat completion with type conversion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
required: List of required field names or None
|
||||||
|
**kwargs: Response data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Converted response based on response_type
|
||||||
|
"""
|
||||||
|
required = required or self.required
|
||||||
|
|
||||||
|
# Handle case when required is a list
|
||||||
|
if isinstance(required, list) and len(required) > 0:
|
||||||
|
if len(required) == 1:
|
||||||
|
required_field = required[0]
|
||||||
|
result = kwargs.get(required_field, "")
|
||||||
|
else:
|
||||||
|
# Return multiple fields as a dictionary
|
||||||
|
return {field: kwargs.get(field, "") for field in required}
|
||||||
|
else:
|
||||||
|
required_field = "response"
|
||||||
|
result = kwargs.get(required_field, "")
|
||||||
|
|
||||||
|
# Type conversion logic
|
||||||
|
if self.response_type == str:
|
||||||
|
return result
|
||||||
|
|
||||||
|
if isinstance(self.response_type, type) and issubclass(
|
||||||
|
self.response_type, BaseModel
|
||||||
|
):
|
||||||
|
return self.response_type(**kwargs)
|
||||||
|
|
||||||
|
if get_origin(self.response_type) in (list, dict):
|
||||||
|
return result # Assuming result is already in correct format
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.response_type(result)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return result
|
364
app/tool/planning.py
Normal file
364
app/tool/planning.py
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
# tool/planning.py
|
||||||
|
from typing import Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
from app.exceptions import ToolError
|
||||||
|
from app.tool.base import BaseTool, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
_PLANNING_TOOL_DESCRIPTION = """
|
||||||
|
A planning tool that allows the agent to create and manage plans for solving complex tasks.
|
||||||
|
The tool provides functionality for creating plans, updating plan steps, and tracking progress.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class PlanningTool(BaseTool):
|
||||||
|
"""
|
||||||
|
A planning tool that allows the agent to create and manage plans for solving complex tasks.
|
||||||
|
The tool provides functionality for creating plans, updating plan steps, and tracking progress.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "planning"
|
||||||
|
description: str = _PLANNING_TOOL_DESCRIPTION
|
||||||
|
parameters: dict = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"description": "The command to execute. Available commands: create, update, list, get, set_active, mark_step, delete.",
|
||||||
|
"enum": [
|
||||||
|
"create",
|
||||||
|
"update",
|
||||||
|
"list",
|
||||||
|
"get",
|
||||||
|
"set_active",
|
||||||
|
"mark_step",
|
||||||
|
"delete",
|
||||||
|
],
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"plan_id": {
|
||||||
|
"description": "Unique identifier for the plan. Required for create, update, set_active, and delete commands. Optional for get and mark_step (uses active plan if not specified).",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"title": {
|
||||||
|
"description": "Title for the plan. Required for create command, optional for update command.",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"steps": {
|
||||||
|
"description": "List of plan steps. Required for create command, optional for update command.",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
"step_index": {
|
||||||
|
"description": "Index of the step to update (0-based). Required for mark_step command.",
|
||||||
|
"type": "integer",
|
||||||
|
},
|
||||||
|
"step_status": {
|
||||||
|
"description": "Status to set for a step. Used with mark_step command.",
|
||||||
|
"enum": ["not_started", "in_progress", "completed", "blocked"],
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"step_notes": {
|
||||||
|
"description": "Additional notes for a step. Optional for mark_step command.",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
"additionalProperties": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
_plans = {} # Dictionary to store plans by plan_id
|
||||||
|
_current_plan_id = None # Track the current active plan
|
||||||
|
|
||||||
|
# async def __call__(
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
command: Literal[
|
||||||
|
"create", "update", "list", "get", "set_active", "mark_step", "delete"
|
||||||
|
],
|
||||||
|
plan_id: Optional[str] = None,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
steps: Optional[List[str]] = None,
|
||||||
|
step_index: Optional[int] = None,
|
||||||
|
step_status: Optional[
|
||||||
|
Literal["not_started", "in_progress", "completed", "blocked"]
|
||||||
|
] = None,
|
||||||
|
step_notes: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Execute the planning tool with the given command and parameters.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- command: The operation to perform
|
||||||
|
- plan_id: Unique identifier for the plan
|
||||||
|
- title: Title for the plan (used with create command)
|
||||||
|
- steps: List of steps for the plan (used with create command)
|
||||||
|
- step_index: Index of the step to update (used with mark_step command)
|
||||||
|
- step_status: Status to set for a step (used with mark_step command)
|
||||||
|
- step_notes: Additional notes for a step (used with mark_step command)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if command == "create":
|
||||||
|
return self._create_plan(plan_id, title, steps)
|
||||||
|
elif command == "update":
|
||||||
|
return self._update_plan(plan_id, title, steps)
|
||||||
|
elif command == "list":
|
||||||
|
return self._list_plans()
|
||||||
|
elif command == "get":
|
||||||
|
return self._get_plan(plan_id)
|
||||||
|
elif command == "set_active":
|
||||||
|
return self._set_active_plan(plan_id)
|
||||||
|
elif command == "mark_step":
|
||||||
|
return self._mark_step(plan_id, step_index, step_status, step_notes)
|
||||||
|
elif command == "delete":
|
||||||
|
return self._delete_plan(plan_id)
|
||||||
|
else:
|
||||||
|
raise ToolError(
|
||||||
|
f"Unrecognized command: {command}. Allowed commands are: create, update, list, get, set_active, mark_step, delete"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_plan(
|
||||||
|
self, plan_id: Optional[str], title: Optional[str], steps: Optional[List[str]]
|
||||||
|
) -> ToolResult:
|
||||||
|
"""Create a new plan with the given ID, title, and steps."""
|
||||||
|
if not plan_id:
|
||||||
|
raise ToolError("Parameter `plan_id` is required for command: create")
|
||||||
|
|
||||||
|
if plan_id in self._plans:
|
||||||
|
raise ToolError(
|
||||||
|
f"A plan with ID '{plan_id}' already exists. Use 'update' to modify existing plans."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not title:
|
||||||
|
raise ToolError("Parameter `title` is required for command: create")
|
||||||
|
|
||||||
|
if (
|
||||||
|
not steps
|
||||||
|
or not isinstance(steps, list)
|
||||||
|
or not all(isinstance(step, str) for step in steps)
|
||||||
|
):
|
||||||
|
raise ToolError(
|
||||||
|
"Parameter `steps` must be a non-empty list of strings for command: create"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new plan with initialized step statuses
|
||||||
|
plan = {
|
||||||
|
"plan_id": plan_id,
|
||||||
|
"title": title,
|
||||||
|
"steps": steps,
|
||||||
|
"step_statuses": ["not_started"] * len(steps),
|
||||||
|
"step_notes": [""] * len(steps),
|
||||||
|
}
|
||||||
|
|
||||||
|
self._plans[plan_id] = plan
|
||||||
|
self._current_plan_id = plan_id # Set as active plan
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
output=f"Plan created successfully with ID: {plan_id}\n\n{self._format_plan(plan)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_plan(
|
||||||
|
self, plan_id: Optional[str], title: Optional[str], steps: Optional[List[str]]
|
||||||
|
) -> ToolResult:
|
||||||
|
"""Update an existing plan with new title or steps."""
|
||||||
|
if not plan_id:
|
||||||
|
raise ToolError("Parameter `plan_id` is required for command: update")
|
||||||
|
|
||||||
|
if plan_id not in self._plans:
|
||||||
|
raise ToolError(f"No plan found with ID: {plan_id}")
|
||||||
|
|
||||||
|
plan = self._plans[plan_id]
|
||||||
|
|
||||||
|
if title:
|
||||||
|
plan["title"] = title
|
||||||
|
|
||||||
|
if steps:
|
||||||
|
if not isinstance(steps, list) or not all(
|
||||||
|
isinstance(step, str) for step in steps
|
||||||
|
):
|
||||||
|
raise ToolError(
|
||||||
|
"Parameter `steps` must be a list of strings for command: update"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Preserve existing step statuses for unchanged steps
|
||||||
|
old_steps = plan["steps"]
|
||||||
|
old_statuses = plan["step_statuses"]
|
||||||
|
old_notes = plan["step_notes"]
|
||||||
|
|
||||||
|
# Create new step statuses and notes
|
||||||
|
new_statuses = []
|
||||||
|
new_notes = []
|
||||||
|
|
||||||
|
for i, step in enumerate(steps):
|
||||||
|
# If the step exists at the same position in old steps, preserve status and notes
|
||||||
|
if i < len(old_steps) and step == old_steps[i]:
|
||||||
|
new_statuses.append(old_statuses[i])
|
||||||
|
new_notes.append(old_notes[i])
|
||||||
|
else:
|
||||||
|
new_statuses.append("not_started")
|
||||||
|
new_notes.append("")
|
||||||
|
|
||||||
|
plan["steps"] = steps
|
||||||
|
plan["step_statuses"] = new_statuses
|
||||||
|
plan["step_notes"] = new_notes
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
output=f"Plan updated successfully: {plan_id}\n\n{self._format_plan(plan)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _list_plans(self) -> ToolResult:
|
||||||
|
"""List all available plans."""
|
||||||
|
if not self._plans:
|
||||||
|
return ToolResult(
|
||||||
|
output="No plans available. Create a plan with the 'create' command."
|
||||||
|
)
|
||||||
|
|
||||||
|
output = "Available plans:\n"
|
||||||
|
for plan_id, plan in self._plans.items():
|
||||||
|
current_marker = " (active)" if plan_id == self._current_plan_id else ""
|
||||||
|
completed = sum(
|
||||||
|
1 for status in plan["step_statuses"] if status == "completed"
|
||||||
|
)
|
||||||
|
total = len(plan["steps"])
|
||||||
|
progress = f"{completed}/{total} steps completed"
|
||||||
|
output += f"• {plan_id}{current_marker}: {plan['title']} - {progress}\n"
|
||||||
|
|
||||||
|
return ToolResult(output=output)
|
||||||
|
|
||||||
|
def _get_plan(self, plan_id: Optional[str]) -> ToolResult:
|
||||||
|
"""Get details of a specific plan."""
|
||||||
|
if not plan_id:
|
||||||
|
# If no plan_id is provided, use the current active plan
|
||||||
|
if not self._current_plan_id:
|
||||||
|
raise ToolError(
|
||||||
|
"No active plan. Please specify a plan_id or set an active plan."
|
||||||
|
)
|
||||||
|
plan_id = self._current_plan_id
|
||||||
|
|
||||||
|
if plan_id not in self._plans:
|
||||||
|
raise ToolError(f"No plan found with ID: {plan_id}")
|
||||||
|
|
||||||
|
plan = self._plans[plan_id]
|
||||||
|
return ToolResult(output=self._format_plan(plan))
|
||||||
|
|
||||||
|
def _set_active_plan(self, plan_id: Optional[str]) -> ToolResult:
|
||||||
|
"""Set a plan as the active plan."""
|
||||||
|
if not plan_id:
|
||||||
|
raise ToolError("Parameter `plan_id` is required for command: set_active")
|
||||||
|
|
||||||
|
if plan_id not in self._plans:
|
||||||
|
raise ToolError(f"No plan found with ID: {plan_id}")
|
||||||
|
|
||||||
|
self._current_plan_id = plan_id
|
||||||
|
return ToolResult(
|
||||||
|
output=f"Plan '{plan_id}' is now the active plan.\n\n{self._format_plan(self._plans[plan_id])}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _mark_step(
|
||||||
|
self,
|
||||||
|
plan_id: Optional[str],
|
||||||
|
step_index: Optional[int],
|
||||||
|
step_status: Optional[str],
|
||||||
|
step_notes: Optional[str],
|
||||||
|
) -> ToolResult:
|
||||||
|
"""Mark a step with a specific status and optional notes."""
|
||||||
|
if not plan_id:
|
||||||
|
# If no plan_id is provided, use the current active plan
|
||||||
|
if not self._current_plan_id:
|
||||||
|
raise ToolError(
|
||||||
|
"No active plan. Please specify a plan_id or set an active plan."
|
||||||
|
)
|
||||||
|
plan_id = self._current_plan_id
|
||||||
|
|
||||||
|
if plan_id not in self._plans:
|
||||||
|
raise ToolError(f"No plan found with ID: {plan_id}")
|
||||||
|
|
||||||
|
if step_index is None:
|
||||||
|
raise ToolError("Parameter `step_index` is required for command: mark_step")
|
||||||
|
|
||||||
|
plan = self._plans[plan_id]
|
||||||
|
|
||||||
|
if step_index < 0 or step_index >= len(plan["steps"]):
|
||||||
|
raise ToolError(
|
||||||
|
f"Invalid step_index: {step_index}. Valid indices range from 0 to {len(plan['steps'])-1}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if step_status and step_status not in [
|
||||||
|
"not_started",
|
||||||
|
"in_progress",
|
||||||
|
"completed",
|
||||||
|
"blocked",
|
||||||
|
]:
|
||||||
|
raise ToolError(
|
||||||
|
f"Invalid step_status: {step_status}. Valid statuses are: not_started, in_progress, completed, blocked"
|
||||||
|
)
|
||||||
|
|
||||||
|
if step_status:
|
||||||
|
plan["step_statuses"][step_index] = step_status
|
||||||
|
|
||||||
|
if step_notes:
|
||||||
|
plan["step_notes"][step_index] = step_notes
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
output=f"Step {step_index} updated in plan '{plan_id}'.\n\n{self._format_plan(plan)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _delete_plan(self, plan_id: Optional[str]) -> ToolResult:
|
||||||
|
"""Delete a plan."""
|
||||||
|
if not plan_id:
|
||||||
|
raise ToolError("Parameter `plan_id` is required for command: delete")
|
||||||
|
|
||||||
|
if plan_id not in self._plans:
|
||||||
|
raise ToolError(f"No plan found with ID: {plan_id}")
|
||||||
|
|
||||||
|
del self._plans[plan_id]
|
||||||
|
|
||||||
|
# If the deleted plan was the active plan, clear the active plan
|
||||||
|
if self._current_plan_id == plan_id:
|
||||||
|
self._current_plan_id = None
|
||||||
|
|
||||||
|
return ToolResult(output=f"Plan '{plan_id}' has been deleted.")
|
||||||
|
|
||||||
|
def _format_plan(self, plan: Dict) -> str:
|
||||||
|
"""Format a plan for display."""
|
||||||
|
output = f"Plan: {plan['title']} (ID: {plan['plan_id']})\n"
|
||||||
|
output += "=" * len(output) + "\n\n"
|
||||||
|
|
||||||
|
# Calculate progress statistics
|
||||||
|
total_steps = len(plan["steps"])
|
||||||
|
completed = sum(1 for status in plan["step_statuses"] if status == "completed")
|
||||||
|
in_progress = sum(
|
||||||
|
1 for status in plan["step_statuses"] if status == "in_progress"
|
||||||
|
)
|
||||||
|
blocked = sum(1 for status in plan["step_statuses"] if status == "blocked")
|
||||||
|
not_started = sum(
|
||||||
|
1 for status in plan["step_statuses"] if status == "not_started"
|
||||||
|
)
|
||||||
|
|
||||||
|
output += f"Progress: {completed}/{total_steps} steps completed "
|
||||||
|
if total_steps > 0:
|
||||||
|
percentage = (completed / total_steps) * 100
|
||||||
|
output += f"({percentage:.1f}%)\n"
|
||||||
|
else:
|
||||||
|
output += "(0%)\n"
|
||||||
|
|
||||||
|
output += f"Status: {completed} completed, {in_progress} in progress, {blocked} blocked, {not_started} not started\n\n"
|
||||||
|
output += "Steps:\n"
|
||||||
|
|
||||||
|
# Add each step with its status and notes
|
||||||
|
for i, (step, status, notes) in enumerate(
|
||||||
|
zip(plan["steps"], plan["step_statuses"], plan["step_notes"])
|
||||||
|
):
|
||||||
|
status_symbol = {
|
||||||
|
"not_started": "[ ]",
|
||||||
|
"in_progress": "[→]",
|
||||||
|
"completed": "[✓]",
|
||||||
|
"blocked": "[!]",
|
||||||
|
}.get(status, "[ ]")
|
||||||
|
|
||||||
|
output += f"{i}. {status_symbol} {step}\n"
|
||||||
|
if notes:
|
||||||
|
output += f" Notes: {notes}\n"
|
||||||
|
|
||||||
|
return output
|
43
app/tool/run.py
Normal file
43
app/tool/run.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
"""Utility to run shell commands asynchronously with a timeout."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
|
||||||
|
MAX_RESPONSE_LEN: int = 16000
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
|
||||||
|
"""Truncate content and append a notice if content exceeds the specified length."""
|
||||||
|
return (
|
||||||
|
content
|
||||||
|
if not truncate_after or len(content) <= truncate_after
|
||||||
|
else content[:truncate_after] + TRUNCATED_MESSAGE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
cmd: str,
|
||||||
|
timeout: float | None = 120.0, # seconds
|
||||||
|
truncate_after: int | None = MAX_RESPONSE_LEN,
|
||||||
|
):
|
||||||
|
"""Run a shell command asynchronously with a timeout."""
|
||||||
|
process = await asyncio.create_subprocess_shell(
|
||||||
|
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||||
|
return (
|
||||||
|
process.returncode or 0,
|
||||||
|
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
|
||||||
|
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError as exc:
|
||||||
|
try:
|
||||||
|
process.kill()
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Command '{cmd}' timed out after {timeout} seconds"
|
||||||
|
) from exc
|
342
app/tool/str_replace_editor.py
Normal file
342
app/tool/str_replace_editor.py
Normal file
@ -0,0 +1,342 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal, get_args
|
||||||
|
|
||||||
|
from app.exceptions import ToolError
|
||||||
|
from app.tool import BaseTool
|
||||||
|
from app.tool.base import CLIResult, ToolResult
|
||||||
|
from app.tool.run import run
|
||||||
|
|
||||||
|
|
||||||
|
Command = Literal[
|
||||||
|
"view",
|
||||||
|
"create",
|
||||||
|
"str_replace",
|
||||||
|
"insert",
|
||||||
|
"undo_edit",
|
||||||
|
]
|
||||||
|
SNIPPET_LINES: int = 4
|
||||||
|
|
||||||
|
MAX_RESPONSE_LEN: int = 16000
|
||||||
|
|
||||||
|
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
|
||||||
|
|
||||||
|
_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files
|
||||||
|
* State is persistent across command calls and discussions with the user
|
||||||
|
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
|
||||||
|
* The `create` command cannot be used if the specified `path` already exists as a file
|
||||||
|
* If a `command` generates a long output, it will be truncated and marked with `<response clipped>`
|
||||||
|
* The `undo_edit` command will revert the last edit made to the file at `path`
|
||||||
|
|
||||||
|
Notes for using the `str_replace` command:
|
||||||
|
* The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces!
|
||||||
|
* If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique
|
||||||
|
* The `new_str` parameter should contain the edited lines that should replace the `old_str`
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
|
||||||
|
"""Truncate content and append a notice if content exceeds the specified length."""
|
||||||
|
return (
|
||||||
|
content
|
||||||
|
if not truncate_after or len(content) <= truncate_after
|
||||||
|
else content[:truncate_after] + TRUNCATED_MESSAGE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StrReplaceEditor(BaseTool):
|
||||||
|
"""A tool for executing bash commands"""
|
||||||
|
|
||||||
|
name: str = "str_replace_editor"
|
||||||
|
description: str = _STR_REPLACE_EDITOR_DESCRIPTION
|
||||||
|
parameters: dict = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"description": "The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.",
|
||||||
|
"enum": ["view", "create", "str_replace", "insert", "undo_edit"],
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"description": "Absolute path to file or directory.",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"file_text": {
|
||||||
|
"description": "Required parameter of `create` command, with the content of the file to be created.",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"old_str": {
|
||||||
|
"description": "Required parameter of `str_replace` command containing the string in `path` to replace.",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"new_str": {
|
||||||
|
"description": "Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"insert_line": {
|
||||||
|
"description": "Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.",
|
||||||
|
"type": "integer",
|
||||||
|
},
|
||||||
|
"view_range": {
|
||||||
|
"description": "Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.",
|
||||||
|
"items": {"type": "integer"},
|
||||||
|
"type": "array",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command", "path"],
|
||||||
|
}
|
||||||
|
|
||||||
|
_file_history: list = defaultdict(list)
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
command: Command,
|
||||||
|
path: str,
|
||||||
|
file_text: str | None = None,
|
||||||
|
view_range: list[int] | None = None,
|
||||||
|
old_str: str | None = None,
|
||||||
|
new_str: str | None = None,
|
||||||
|
insert_line: int | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
_path = Path(path)
|
||||||
|
self.validate_path(command, _path)
|
||||||
|
if command == "view":
|
||||||
|
result = await self.view(_path, view_range)
|
||||||
|
elif command == "create":
|
||||||
|
if file_text is None:
|
||||||
|
raise ToolError("Parameter `file_text` is required for command: create")
|
||||||
|
self.write_file(_path, file_text)
|
||||||
|
self._file_history[_path].append(file_text)
|
||||||
|
result = ToolResult(output=f"File created successfully at: {_path}")
|
||||||
|
elif command == "str_replace":
|
||||||
|
if old_str is None:
|
||||||
|
raise ToolError(
|
||||||
|
"Parameter `old_str` is required for command: str_replace"
|
||||||
|
)
|
||||||
|
result = self.str_replace(_path, old_str, new_str)
|
||||||
|
elif command == "insert":
|
||||||
|
if insert_line is None:
|
||||||
|
raise ToolError(
|
||||||
|
"Parameter `insert_line` is required for command: insert"
|
||||||
|
)
|
||||||
|
if new_str is None:
|
||||||
|
raise ToolError("Parameter `new_str` is required for command: insert")
|
||||||
|
result = self.insert(_path, insert_line, new_str)
|
||||||
|
elif command == "undo_edit":
|
||||||
|
result = self.undo_edit(_path)
|
||||||
|
else:
|
||||||
|
raise ToolError(
|
||||||
|
f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
|
||||||
|
)
|
||||||
|
return str(result)
|
||||||
|
|
||||||
|
def validate_path(self, command: str, path: Path):
|
||||||
|
"""
|
||||||
|
Check that the path/command combination is valid.
|
||||||
|
"""
|
||||||
|
# Check if its an absolute path
|
||||||
|
if not path.is_absolute():
|
||||||
|
suggested_path = Path("") / path
|
||||||
|
raise ToolError(
|
||||||
|
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
|
||||||
|
)
|
||||||
|
# Check if path exists
|
||||||
|
if not path.exists() and command != "create":
|
||||||
|
raise ToolError(
|
||||||
|
f"The path {path} does not exist. Please provide a valid path."
|
||||||
|
)
|
||||||
|
if path.exists() and command == "create":
|
||||||
|
raise ToolError(
|
||||||
|
f"File already exists at: {path}. Cannot overwrite files using command `create`."
|
||||||
|
)
|
||||||
|
# Check if the path points to a directory
|
||||||
|
if path.is_dir():
|
||||||
|
if command != "view":
|
||||||
|
raise ToolError(
|
||||||
|
f"The path {path} is a directory and only the `view` command can be used on directories"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def view(self, path: Path, view_range: list[int] | None = None):
|
||||||
|
"""Implement the view command"""
|
||||||
|
if path.is_dir():
|
||||||
|
if view_range:
|
||||||
|
raise ToolError(
|
||||||
|
"The `view_range` parameter is not allowed when `path` points to a directory."
|
||||||
|
)
|
||||||
|
|
||||||
|
_, stdout, stderr = await run(
|
||||||
|
rf"find {path} -maxdepth 2 -not -path '*/\.*'"
|
||||||
|
)
|
||||||
|
if not stderr:
|
||||||
|
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
|
||||||
|
return CLIResult(output=stdout, error=stderr)
|
||||||
|
|
||||||
|
file_content = self.read_file(path)
|
||||||
|
init_line = 1
|
||||||
|
if view_range:
|
||||||
|
if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
|
||||||
|
raise ToolError(
|
||||||
|
"Invalid `view_range`. It should be a list of two integers."
|
||||||
|
)
|
||||||
|
file_lines = file_content.split("\n")
|
||||||
|
n_lines_file = len(file_lines)
|
||||||
|
init_line, final_line = view_range
|
||||||
|
if init_line < 1 or init_line > n_lines_file:
|
||||||
|
raise ToolError(
|
||||||
|
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
|
||||||
|
)
|
||||||
|
if final_line > n_lines_file:
|
||||||
|
raise ToolError(
|
||||||
|
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
|
||||||
|
)
|
||||||
|
if final_line != -1 and final_line < init_line:
|
||||||
|
raise ToolError(
|
||||||
|
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if final_line == -1:
|
||||||
|
file_content = "\n".join(file_lines[init_line - 1 :])
|
||||||
|
else:
|
||||||
|
file_content = "\n".join(file_lines[init_line - 1 : final_line])
|
||||||
|
|
||||||
|
return CLIResult(
|
||||||
|
output=self._make_output(file_content, str(path), init_line=init_line)
|
||||||
|
)
|
||||||
|
|
||||||
|
def str_replace(self, path: Path, old_str: str, new_str: str | None):
|
||||||
|
"""Implement the str_replace command, which replaces old_str with new_str in the file content"""
|
||||||
|
# Read the file content
|
||||||
|
file_content = self.read_file(path).expandtabs()
|
||||||
|
old_str = old_str.expandtabs()
|
||||||
|
new_str = new_str.expandtabs() if new_str is not None else ""
|
||||||
|
|
||||||
|
# Check if old_str is unique in the file
|
||||||
|
occurrences = file_content.count(old_str)
|
||||||
|
if occurrences == 0:
|
||||||
|
raise ToolError(
|
||||||
|
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
|
||||||
|
)
|
||||||
|
elif occurrences > 1:
|
||||||
|
file_content_lines = file_content.split("\n")
|
||||||
|
lines = [
|
||||||
|
idx + 1
|
||||||
|
for idx, line in enumerate(file_content_lines)
|
||||||
|
if old_str in line
|
||||||
|
]
|
||||||
|
raise ToolError(
|
||||||
|
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace old_str with new_str
|
||||||
|
new_file_content = file_content.replace(old_str, new_str)
|
||||||
|
|
||||||
|
# Write the new content to the file
|
||||||
|
self.write_file(path, new_file_content)
|
||||||
|
|
||||||
|
# Save the content to history
|
||||||
|
self._file_history[path].append(file_content)
|
||||||
|
|
||||||
|
# Create a snippet of the edited section
|
||||||
|
replacement_line = file_content.split(old_str)[0].count("\n")
|
||||||
|
start_line = max(0, replacement_line - SNIPPET_LINES)
|
||||||
|
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
|
||||||
|
snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
|
||||||
|
|
||||||
|
# Prepare the success message
|
||||||
|
success_msg = f"The file {path} has been edited. "
|
||||||
|
success_msg += self._make_output(
|
||||||
|
snippet, f"a snippet of {path}", start_line + 1
|
||||||
|
)
|
||||||
|
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
|
||||||
|
|
||||||
|
return CLIResult(output=success_msg)
|
||||||
|
|
||||||
|
def insert(self, path: Path, insert_line: int, new_str: str):
|
||||||
|
"""Implement the insert command, which inserts new_str at the specified line in the file content."""
|
||||||
|
file_text = self.read_file(path).expandtabs()
|
||||||
|
new_str = new_str.expandtabs()
|
||||||
|
file_text_lines = file_text.split("\n")
|
||||||
|
n_lines_file = len(file_text_lines)
|
||||||
|
|
||||||
|
if insert_line < 0 or insert_line > n_lines_file:
|
||||||
|
raise ToolError(
|
||||||
|
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
new_str_lines = new_str.split("\n")
|
||||||
|
new_file_text_lines = (
|
||||||
|
file_text_lines[:insert_line]
|
||||||
|
+ new_str_lines
|
||||||
|
+ file_text_lines[insert_line:]
|
||||||
|
)
|
||||||
|
snippet_lines = (
|
||||||
|
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
|
||||||
|
+ new_str_lines
|
||||||
|
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
|
||||||
|
)
|
||||||
|
|
||||||
|
new_file_text = "\n".join(new_file_text_lines)
|
||||||
|
snippet = "\n".join(snippet_lines)
|
||||||
|
|
||||||
|
self.write_file(path, new_file_text)
|
||||||
|
self._file_history[path].append(file_text)
|
||||||
|
|
||||||
|
success_msg = f"The file {path} has been edited. "
|
||||||
|
success_msg += self._make_output(
|
||||||
|
snippet,
|
||||||
|
"a snippet of the edited file",
|
||||||
|
max(1, insert_line - SNIPPET_LINES + 1),
|
||||||
|
)
|
||||||
|
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
|
||||||
|
return CLIResult(output=success_msg)
|
||||||
|
|
||||||
|
def undo_edit(self, path: Path):
|
||||||
|
"""Implement the undo_edit command."""
|
||||||
|
if not self._file_history[path]:
|
||||||
|
raise ToolError(f"No edit history found for {path}.")
|
||||||
|
|
||||||
|
old_text = self._file_history[path].pop()
|
||||||
|
self.write_file(path, old_text)
|
||||||
|
|
||||||
|
return CLIResult(
|
||||||
|
output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def read_file(self, path: Path):
|
||||||
|
"""Read the content of a file from a given path; raise a ToolError if an error occurs."""
|
||||||
|
try:
|
||||||
|
return path.read_text()
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolError(f"Ran into {e} while trying to read {path}") from None
|
||||||
|
|
||||||
|
def write_file(self, path: Path, file: str):
|
||||||
|
"""Write the content of a file to a given path; raise a ToolError if an error occurs."""
|
||||||
|
try:
|
||||||
|
path.write_text(file)
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolError(f"Ran into {e} while trying to write to {path}") from None
|
||||||
|
|
||||||
|
def _make_output(
|
||||||
|
self,
|
||||||
|
file_content: str,
|
||||||
|
file_descriptor: str,
|
||||||
|
init_line: int = 1,
|
||||||
|
expand_tabs: bool = True,
|
||||||
|
):
|
||||||
|
"""Generate output for the CLI based on the content of a file."""
|
||||||
|
file_content = maybe_truncate(file_content)
|
||||||
|
if expand_tabs:
|
||||||
|
file_content = file_content.expandtabs()
|
||||||
|
file_content = "\n".join(
|
||||||
|
[
|
||||||
|
f"{i + init_line:6}\t{line}"
|
||||||
|
for i, line in enumerate(file_content.split("\n"))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"Here's the result of running `cat -n` on {file_descriptor}:\n"
|
||||||
|
+ file_content
|
||||||
|
+ "\n"
|
||||||
|
)
|
24
app/tool/terminate.py
Normal file
24
app/tool/terminate.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from app.tool.base import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
_TERMINATE_DESCRIPTION = """Terminate the interaction when the request is met OR if the assistant cannot proceed further with the task."""
|
||||||
|
|
||||||
|
|
||||||
|
class Terminate(BaseTool):
|
||||||
|
name: str = "terminate"
|
||||||
|
description: str = _TERMINATE_DESCRIPTION
|
||||||
|
parameters: dict = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"status": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The finish status of the interaction.",
|
||||||
|
"enum": ["success", "failure"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["status"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, status: str) -> str:
|
||||||
|
"""Finish the current execution"""
|
||||||
|
return f"The interaction has been completed with status: {status}"
|
55
app/tool/tool_collection.py
Normal file
55
app/tool/tool_collection.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""Collection classes for managing multiple tools."""
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from app.exceptions import ToolError
|
||||||
|
from app.tool.base import BaseTool, ToolFailure, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCollection:
|
||||||
|
"""A collection of defined tools."""
|
||||||
|
|
||||||
|
def __init__(self, *tools: BaseTool):
|
||||||
|
self.tools = tools
|
||||||
|
self.tool_map = {tool.name: tool for tool in tools}
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.tools)
|
||||||
|
|
||||||
|
def to_params(self) -> List[Dict[str, Any]]:
|
||||||
|
return [tool.to_param() for tool in self.tools]
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self, *, name: str, tool_input: Dict[str, Any] = None
|
||||||
|
) -> ToolResult:
|
||||||
|
tool = self.tool_map.get(name)
|
||||||
|
if not tool:
|
||||||
|
return ToolFailure(error=f"Tool {name} is invalid")
|
||||||
|
try:
|
||||||
|
result = await tool(**tool_input)
|
||||||
|
return result
|
||||||
|
except ToolError as e:
|
||||||
|
return ToolFailure(error=e.message)
|
||||||
|
|
||||||
|
async def execute_all(self) -> List[ToolResult]:
|
||||||
|
"""Execute all tools in the collection sequentially."""
|
||||||
|
results = []
|
||||||
|
for tool in self.tools:
|
||||||
|
try:
|
||||||
|
result = await tool()
|
||||||
|
results.append(result)
|
||||||
|
except ToolError as e:
|
||||||
|
results.append(ToolFailure(error=e.message))
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_tool(self, name: str) -> BaseTool:
|
||||||
|
return self.tool_map.get(name)
|
||||||
|
|
||||||
|
def add_tool(self, tool: BaseTool):
|
||||||
|
self.tools += (tool,)
|
||||||
|
self.tool_map[tool.name] = tool
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add_tools(self, *tools: BaseTool):
|
||||||
|
for tool in tools:
|
||||||
|
self.add_tool(tool)
|
||||||
|
return self
|
0
app/utils/__init__.py
Normal file
0
app/utils/__init__.py
Normal file
112
app/utils/extract_html_content.py
Normal file
112
app/utils/extract_html_content.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def extract_html_content(text: str, stack: str = "react-tailwind") -> str:
|
||||||
|
"""
|
||||||
|
Extract code content from LLM response based on technology stack.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The raw text response from LLM
|
||||||
|
stack: Technology stack ("react-tailwind", "html-tailwind", "svg")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Extracted code content
|
||||||
|
"""
|
||||||
|
# Remove markdown code blocks if present
|
||||||
|
text = re.sub(r"```[\w]*\n|```", "", text)
|
||||||
|
|
||||||
|
if stack == "svg":
|
||||||
|
# Extract SVG content
|
||||||
|
svg_match = re.search(r"(<svg.*?>.*?</svg>)", text, re.DOTALL)
|
||||||
|
if svg_match:
|
||||||
|
return svg_match.group(1)
|
||||||
|
elif stack == "react-tailwind":
|
||||||
|
# Extract React component content
|
||||||
|
react_match = re.search(r"(export default function.*?})\s*$", text, re.DOTALL)
|
||||||
|
if react_match:
|
||||||
|
return react_match.group(1)
|
||||||
|
# Alternative: look for const/function component definition
|
||||||
|
alt_match = re.search(
|
||||||
|
r"((?:const|function)\s+\w+\s*=?\s*(?:\([^)]*\))?\s*=>?\s*{.*?})\s*$",
|
||||||
|
text,
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
if alt_match:
|
||||||
|
return alt_match.group(1)
|
||||||
|
|
||||||
|
# Default: try to extract content within <html> tags
|
||||||
|
html_match = re.search(r"(<html.*?>.*?</html>)", text, re.DOTALL)
|
||||||
|
if html_match:
|
||||||
|
return html_match.group(1)
|
||||||
|
|
||||||
|
# If no specific patterns match, try to extract any HTML-like content
|
||||||
|
body_match = re.search(r"(<body.*?>.*?</body>)", text, re.DOTALL)
|
||||||
|
if body_match:
|
||||||
|
return f"<html>\n{body_match.group(1)}\n</html>"
|
||||||
|
|
||||||
|
div_match = re.search(r"(<div.*?>.*?</div>)", text, re.DOTALL)
|
||||||
|
if div_match:
|
||||||
|
return f"<html>\n<body>\n{div_match.group(1)}\n</body>\n</html>"
|
||||||
|
|
||||||
|
# If no patterns match, clean up the text and return it
|
||||||
|
cleaned_text = text.strip()
|
||||||
|
print(
|
||||||
|
f"[Code Extraction] No specific pattern found for stack '{stack}'. Raw content:\n{cleaned_text}"
|
||||||
|
)
|
||||||
|
return cleaned_text
|
||||||
|
|
||||||
|
|
||||||
|
def clean_code_content(code: str) -> str:
|
||||||
|
"""
|
||||||
|
Clean and format the extracted code content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: Raw code content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Cleaned and formatted code
|
||||||
|
"""
|
||||||
|
# Remove leading/trailing whitespace
|
||||||
|
code = code.strip()
|
||||||
|
|
||||||
|
# Remove extra blank lines
|
||||||
|
code = re.sub(r"\n\s*\n", "\n\n", code)
|
||||||
|
|
||||||
|
# Ensure proper indentation
|
||||||
|
lines = code.split("\n")
|
||||||
|
indent_level = 0
|
||||||
|
formatted_lines = []
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
# Adjust indent level based on brackets/braces
|
||||||
|
stripped_line = line.strip()
|
||||||
|
if stripped_line.endswith("{"):
|
||||||
|
formatted_lines.append(" " * indent_level + stripped_line)
|
||||||
|
indent_level += 1
|
||||||
|
elif stripped_line.startswith("}"):
|
||||||
|
indent_level = max(0, indent_level - 1)
|
||||||
|
formatted_lines.append(" " * indent_level + stripped_line)
|
||||||
|
else:
|
||||||
|
formatted_lines.append(" " * indent_level + stripped_line)
|
||||||
|
|
||||||
|
return "\n".join(formatted_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_code_content(text: str, stack: str = "react-tailwind") -> str:
|
||||||
|
"""
|
||||||
|
Main function to extract and clean code content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Raw text from LLM response
|
||||||
|
stack: Technology stack being used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Final cleaned and formatted code
|
||||||
|
"""
|
||||||
|
# Extract the relevant code content
|
||||||
|
extracted_content = extract_html_content(text, stack)
|
||||||
|
|
||||||
|
# Clean and format the code
|
||||||
|
cleaned_content = clean_code_content(extracted_content)
|
||||||
|
|
||||||
|
return cleaned_content
|
74
app/utils/shutdown_listener.py
Normal file
74
app/utils/shutdown_listener.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
"""
|
||||||
|
This module monitors the app for shutdown signals
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import signal
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from types import FrameType
|
||||||
|
|
||||||
|
from uvicorn.server import HANDLED_SIGNALS
|
||||||
|
|
||||||
|
from app.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
_should_exit = None
|
||||||
|
|
||||||
|
|
||||||
|
def _register_signal_handler(sig: signal.Signals):
|
||||||
|
original_handler = None
|
||||||
|
|
||||||
|
def handler(sig_: int, frame: FrameType | None):
|
||||||
|
logger.debug(f"shutdown_signal:{sig_}")
|
||||||
|
global _should_exit
|
||||||
|
_should_exit = True
|
||||||
|
if original_handler:
|
||||||
|
original_handler(sig_, frame) # type: ignore[unreachable]
|
||||||
|
|
||||||
|
original_handler = signal.signal(sig, handler)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_signal_handlers():
|
||||||
|
global _should_exit
|
||||||
|
if _should_exit is not None:
|
||||||
|
return
|
||||||
|
_should_exit = False
|
||||||
|
|
||||||
|
logger.debug("_register_signal_handlers")
|
||||||
|
|
||||||
|
# Check if we're in the main thread of the main interpreter
|
||||||
|
if threading.current_thread() is threading.main_thread():
|
||||||
|
logger.debug("_register_signal_handlers:main_thread")
|
||||||
|
for sig in HANDLED_SIGNALS:
|
||||||
|
_register_signal_handler(sig)
|
||||||
|
else:
|
||||||
|
logger.debug("_register_signal_handlers:not_main_thread")
|
||||||
|
|
||||||
|
|
||||||
|
def should_exit() -> bool:
|
||||||
|
_register_signal_handlers()
|
||||||
|
return bool(_should_exit)
|
||||||
|
|
||||||
|
|
||||||
|
def should_continue() -> bool:
|
||||||
|
_register_signal_handlers()
|
||||||
|
return not _should_exit
|
||||||
|
|
||||||
|
|
||||||
|
def sleep_if_should_continue(timeout: float):
|
||||||
|
if timeout <= 1:
|
||||||
|
time.sleep(timeout)
|
||||||
|
return
|
||||||
|
start_time = time.time()
|
||||||
|
while (time.time() - start_time) < timeout and should_continue():
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_sleep_if_should_continue(timeout: float):
|
||||||
|
if timeout <= 1:
|
||||||
|
await asyncio.sleep(timeout)
|
||||||
|
return
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < timeout and should_continue():
|
||||||
|
await asyncio.sleep(1)
|
13
config/config.example.toml
Normal file
13
config/config.example.toml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# Global LLM configuration
|
||||||
|
[llm]
|
||||||
|
model = "deepseek-chat"
|
||||||
|
base_url = "https://api.deepseek.com/v1"
|
||||||
|
api_key = "sk-..."
|
||||||
|
max_tokens = 4096
|
||||||
|
temperature = 0.0
|
||||||
|
|
||||||
|
# Optional configuration for specific LLM models
|
||||||
|
[llm.vision]
|
||||||
|
model = "..."
|
||||||
|
base_url = "..."
|
||||||
|
api_key = "sk-..."
|
23
main.py
Normal file
23
main.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from app.agent import ToolCallAgent
|
||||||
|
from app.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
agent = ToolCallAgent()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
prompt = input("Enter your prompt (or 'exit' to quit): ")
|
||||||
|
if prompt.lower() == "exit":
|
||||||
|
logger.info("Goodbye!")
|
||||||
|
break
|
||||||
|
logger.warning("Processing your request...")
|
||||||
|
await agent.run(prompt)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.warning("Goodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
14
requirements.txt
Normal file
14
requirements.txt
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
pydantic~=2.10.4
|
||||||
|
openai~=1.58.1
|
||||||
|
tenacity~=9.0.0
|
||||||
|
pyyaml~=6.0.2
|
||||||
|
loguru~=0.7.3
|
||||||
|
numpy
|
||||||
|
datasets~=3.2.0
|
||||||
|
|
||||||
|
html2text~=2024.2.26
|
||||||
|
gymnasium~=1.0.0
|
||||||
|
pillow~=10.4.0
|
||||||
|
browsergym~=0.13.3
|
||||||
|
uvicorn~=0.34.0
|
||||||
|
unidiff~=0.7.5
|
19
run_loop.py
Normal file
19
run_loop.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from app.agent import ToolCallAgent
|
||||||
|
from app.flow.base import FlowType
|
||||||
|
from app.flow.flow_factory import FlowFactory
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
agent = ToolCallAgent()
|
||||||
|
|
||||||
|
flow = FlowFactory.create_flow(
|
||||||
|
flow_type=FlowType.PLANNING,
|
||||||
|
agents=agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
flow.execute("Create a web app that shows Japan travel destinations")
|
||||||
|
)
|
||||||
|
print(result)
|
Loading…
x
Reference in New Issue
Block a user