add: Added a sandbox for executing commands within docker containers

parent 5226628c3e
commit 15024e320a

.gitignore (vendored), 1 line changed
@@ -178,3 +178,4 @@ data/
 # Workspace
 workspace/
+.DS_Store
app/config.py, changed
@@ -23,8 +23,23 @@ class LLMSettings(BaseModel):
     temperature: float = Field(1.0, description="Sampling temperature")
 
 
+class SandboxConfig(BaseModel):
+    """Configuration for the execution sandbox"""
+
+    use_sandbox: bool = Field(False, description="Whether to use the sandbox")
+    image: str = Field("python:3.10-slim", description="Base image")
+    work_dir: str = Field("/workspace", description="Container working directory")
+    memory_limit: str = Field("512m", description="Memory limit")
+    cpu_limit: float = Field(1.0, description="CPU limit")
+    timeout: int = Field(300, description="Default command timeout (seconds)")
+    network_enabled: bool = Field(
+        False, description="Whether network access is allowed"
+    )
+
+
 class AppConfig(BaseModel):
     llm: Dict[str, LLMSettings]
+    sandbox: SandboxConfig
 
 
 class Config:

@@ -85,7 +100,8 @@ class Config:
                     name: {**default_settings, **override_config}
                     for name, override_config in llm_overrides.items()
                 },
-            }
+            },
+            "sandbox": raw_config.get("sandbox", {}),
         }
 
         self._config = AppConfig(**config_dict)

@@ -94,5 +110,9 @@ class Config:
     def llm(self) -> Dict[str, LLMSettings]:
         return self._config.llm
 
+    @property
+    def sandbox(self) -> SandboxConfig:
+        return self._config.sandbox
+
 
 config = Config()
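Usage sketch (not part of the commit): with these changes, sandbox settings fall back to the field defaults declared above whenever the raw config has no "sandbox" section, and ad-hoc settings can be built directly from the Pydantic model. The image name passed to SandboxConfig below is only an example value.

    from app.config import SandboxConfig, config

    # The global config now exposes the sandbox settings via the added property.
    print(config.sandbox.image)    # "python:3.10-slim" unless overridden
    print(config.sandbox.timeout)  # 300 seconds unless overridden

    # Ad-hoc sandboxes can be configured directly through the model.
    custom = SandboxConfig(image="python:3.12-slim", network_enabled=True, timeout=120)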
app/sandbox/__init__.py (new file, 30 lines)
@@ -0,0 +1,30 @@
"""
Docker Sandbox Module

Provides secure containerized execution environment with resource limits
and isolation for running untrusted code.
"""
from app.sandbox.client import (
    BaseSandboxClient,
    LocalSandboxClient,
    create_sandbox_client,
)
from app.sandbox.core.exceptions import (
    SandboxError,
    SandboxResourceError,
    SandboxTimeoutError,
)
from app.sandbox.core.manager import SandboxManager
from app.sandbox.core.sandbox import DockerSandbox


__all__ = [
    "DockerSandbox",
    "SandboxManager",
    "BaseSandboxClient",
    "LocalSandboxClient",
    "create_sandbox_client",
    "SandboxError",
    "SandboxTimeoutError",
    "SandboxResourceError",
]
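The re-exports above make the package root the intended import point; a minimal sketch:

    # Import the public API through the package root rather than the submodules.
    from app.sandbox import DockerSandbox, SandboxManager, create_sandbox_client
    from app.sandbox import SandboxError, SandboxResourceError, SandboxTimeoutError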
app/sandbox/client.py (new file, 201 lines)
@@ -0,0 +1,201 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional, Protocol

from app.config import SandboxConfig
from app.sandbox.core.sandbox import DockerSandbox


class SandboxFileOperations(Protocol):
    """Protocol for sandbox file operations."""

    async def copy_from(self, container_path: str, local_path: str) -> None:
        """Copies file from container to local.

        Args:
            container_path: File path in container.
            local_path: Local destination path.
        """
        ...

    async def copy_to(self, local_path: str, container_path: str) -> None:
        """Copies file from local to container.

        Args:
            local_path: Local source file path.
            container_path: Destination path in container.
        """
        ...

    async def read_file(self, path: str) -> str:
        """Reads file content from container.

        Args:
            path: File path in container.

        Returns:
            str: File content.
        """
        ...

    async def write_file(self, path: str, content: str) -> None:
        """Writes content to file in container.

        Args:
            path: File path in container.
            content: Content to write.
        """
        ...


class BaseSandboxClient(ABC):
    """Base sandbox client interface."""

    @abstractmethod
    async def create(
        self,
        config: Optional[SandboxConfig] = None,
        volume_bindings: Optional[Dict[str, str]] = None,
    ) -> None:
        """Creates sandbox."""

    @abstractmethod
    async def run_command(self, command: str, timeout: Optional[int] = None) -> str:
        """Executes command."""

    @abstractmethod
    async def copy_from(self, container_path: str, local_path: str) -> None:
        """Copies file from container."""

    @abstractmethod
    async def copy_to(self, local_path: str, container_path: str) -> None:
        """Copies file to container."""

    @abstractmethod
    async def read_file(self, path: str) -> str:
        """Reads file."""

    @abstractmethod
    async def write_file(self, path: str, content: str) -> None:
        """Writes file."""

    @abstractmethod
    async def cleanup(self) -> None:
        """Cleans up resources."""


class LocalSandboxClient(BaseSandboxClient):
    """Local sandbox client implementation."""

    def __init__(self):
        """Initializes local sandbox client."""
        self.sandbox: Optional[DockerSandbox] = None

    async def create(
        self,
        config: Optional[SandboxConfig] = None,
        volume_bindings: Optional[Dict[str, str]] = None,
    ) -> None:
        """Creates a sandbox.

        Args:
            config: Sandbox configuration.
            volume_bindings: Volume mappings.

        Raises:
            RuntimeError: If sandbox creation fails.
        """
        self.sandbox = DockerSandbox(config, volume_bindings)
        await self.sandbox.create()

    async def run_command(self, command: str, timeout: Optional[int] = None) -> str:
        """Runs command in sandbox.

        Args:
            command: Command to execute.
            timeout: Execution timeout in seconds.

        Returns:
            Command output.

        Raises:
            RuntimeError: If sandbox not initialized.
        """
        if not self.sandbox:
            raise RuntimeError("Sandbox not initialized")
        return await self.sandbox.run_command(command, timeout)

    async def copy_from(self, container_path: str, local_path: str) -> None:
        """Copies file from container to local.

        Args:
            container_path: File path in container.
            local_path: Local destination path.

        Raises:
            RuntimeError: If sandbox not initialized.
        """
        if not self.sandbox:
            raise RuntimeError("Sandbox not initialized")
        await self.sandbox.copy_from(container_path, local_path)

    async def copy_to(self, local_path: str, container_path: str) -> None:
        """Copies file from local to container.

        Args:
            local_path: Local source file path.
            container_path: Destination path in container.

        Raises:
            RuntimeError: If sandbox not initialized.
        """
        if not self.sandbox:
            raise RuntimeError("Sandbox not initialized")
        await self.sandbox.copy_to(local_path, container_path)

    async def read_file(self, path: str) -> str:
        """Reads file from container.

        Args:
            path: File path in container.

        Returns:
            File content.

        Raises:
            RuntimeError: If sandbox not initialized.
        """
        if not self.sandbox:
            raise RuntimeError("Sandbox not initialized")
        return await self.sandbox.read_file(path)

    async def write_file(self, path: str, content: str) -> None:
        """Writes file to container.

        Args:
            path: File path in container.
            content: File content.

        Raises:
            RuntimeError: If sandbox not initialized.
        """
        if not self.sandbox:
            raise RuntimeError("Sandbox not initialized")
        await self.sandbox.write_file(path, content)

    async def cleanup(self) -> None:
        """Cleans up resources."""
        if self.sandbox:
            await self.sandbox.cleanup()
            self.sandbox = None


async def create_sandbox_client() -> LocalSandboxClient:
    """Creates a sandbox client.

    Returns:
        LocalSandboxClient: Sandbox client instance.
    """
    return LocalSandboxClient()


SANDBOX_CLIENT = create_sandbox_client()
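A minimal end-to-end sketch (not part of the commit) of driving the client, using only the methods defined above. It assumes a local Docker daemon with the default "python:3.10-slim" image already available, and it awaits create_sandbox_client explicitly since that is a coroutine function.

    import asyncio

    from app.sandbox.client import create_sandbox_client


    async def main() -> None:
        client = await create_sandbox_client()
        try:
            await client.create()  # default SandboxConfig: python:3.10-slim, no network
            await client.write_file("hello.py", "print('hello from the sandbox')")
            print(await client.run_command("python hello.py", timeout=30))
        finally:
            await client.cleanup()


    asyncio.run(main())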
app/sandbox/core/exceptions.py (new file, 17 lines)
@@ -0,0 +1,17 @@
"""Exception classes for the sandbox system.

This module defines custom exceptions used throughout the sandbox system to
handle various error conditions in a structured way.
"""


class SandboxError(Exception):
    """Base exception for sandbox-related errors."""


class SandboxTimeoutError(SandboxError):
    """Exception raised when a sandbox operation times out."""


class SandboxResourceError(SandboxError):
    """Exception raised for resource-related errors."""
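Because both specific errors subclass SandboxError, callers can catch narrowly or broadly; a sketch with placeholder bodies:

    from app.sandbox.core.exceptions import SandboxError, SandboxTimeoutError

    try:
        ...  # any sandbox operation
    except SandboxTimeoutError:
        ...  # handle slow commands specifically
    except SandboxError:
        ...  # catch-all for other sandbox failures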
app/sandbox/core/manager.py (new file, 313 lines)
@@ -0,0 +1,313 @@
import asyncio
import uuid
from contextlib import asynccontextmanager
from typing import Dict, Optional, Set

import docker
from docker.errors import APIError, ImageNotFound

from app.config import SandboxConfig
from app.logger import logger
from app.sandbox.core.sandbox import DockerSandbox


class SandboxManager:
    """Docker sandbox manager.

    Manages multiple DockerSandbox instances lifecycle including creation,
    monitoring, and cleanup. Provides concurrent access control and automatic
    cleanup mechanisms for sandbox resources.

    Attributes:
        max_sandboxes: Maximum allowed number of sandboxes.
        idle_timeout: Sandbox idle timeout in seconds.
        cleanup_interval: Cleanup check interval in seconds.
        _sandboxes: Active sandbox instance mapping.
        _last_used: Last used time record for sandboxes.
    """

    def __init__(
        self,
        max_sandboxes: int = 100,
        idle_timeout: int = 3600,
        cleanup_interval: int = 300,
    ):
        """Initializes sandbox manager.

        Args:
            max_sandboxes: Maximum sandbox count limit.
            idle_timeout: Idle timeout in seconds.
            cleanup_interval: Cleanup check interval in seconds.
        """
        self.max_sandboxes = max_sandboxes
        self.idle_timeout = idle_timeout
        self.cleanup_interval = cleanup_interval

        # Docker client
        self._client = docker.from_env()

        # Resource mappings
        self._sandboxes: Dict[str, DockerSandbox] = {}
        self._last_used: Dict[str, float] = {}

        # Concurrency control
        self._locks: Dict[str, asyncio.Lock] = {}
        self._global_lock = asyncio.Lock()
        self._active_operations: Set[str] = set()

        # Cleanup task
        self._cleanup_task: Optional[asyncio.Task] = None
        self._is_shutting_down = False

        # Start automatic cleanup
        self.start_cleanup_task()

    async def ensure_image(self, image: str) -> bool:
        """Ensures Docker image is available.

        Args:
            image: Image name.

        Returns:
            bool: Whether image is available.
        """
        try:
            self._client.images.get(image)
            return True
        except ImageNotFound:
            try:
                logger.info(f"Pulling image {image}...")
                await asyncio.get_event_loop().run_in_executor(
                    None, self._client.images.pull, image
                )
                return True
            except (APIError, Exception) as e:
                logger.error(f"Failed to pull image {image}: {e}")
                return False

    @asynccontextmanager
    async def sandbox_operation(self, sandbox_id: str):
        """Context manager for sandbox operations.

        Provides concurrency control and usage time updates.

        Args:
            sandbox_id: Sandbox ID.

        Raises:
            KeyError: If sandbox not found.
        """
        if sandbox_id not in self._locks:
            self._locks[sandbox_id] = asyncio.Lock()

        async with self._locks[sandbox_id]:
            if sandbox_id not in self._sandboxes:
                raise KeyError(f"Sandbox {sandbox_id} not found")

            self._active_operations.add(sandbox_id)
            try:
                self._last_used[sandbox_id] = asyncio.get_event_loop().time()
                yield self._sandboxes[sandbox_id]
            finally:
                self._active_operations.remove(sandbox_id)

    async def create_sandbox(
        self,
        config: Optional[SandboxConfig] = None,
        volume_bindings: Optional[Dict[str, str]] = None,
    ) -> str:
        """Creates a new sandbox instance.

        Args:
            config: Sandbox configuration.
            volume_bindings: Volume mapping configuration.

        Returns:
            str: Sandbox ID.

        Raises:
            RuntimeError: If max sandbox count reached or creation fails.
        """
        async with self._global_lock:
            if len(self._sandboxes) >= self.max_sandboxes:
                raise RuntimeError(
                    f"Maximum number of sandboxes ({self.max_sandboxes}) reached"
                )

            config = config or SandboxConfig()
            if not await self.ensure_image(config.image):
                raise RuntimeError(f"Failed to ensure Docker image: {config.image}")

            sandbox_id = str(uuid.uuid4())
            try:
                sandbox = DockerSandbox(config, volume_bindings)
                await sandbox.create()

                self._sandboxes[sandbox_id] = sandbox
                self._last_used[sandbox_id] = asyncio.get_event_loop().time()
                self._locks[sandbox_id] = asyncio.Lock()

                logger.info(f"Created sandbox {sandbox_id}")
                return sandbox_id

            except Exception as e:
                logger.error(f"Failed to create sandbox: {e}")
                if sandbox_id in self._sandboxes:
                    await self.delete_sandbox(sandbox_id)
                raise RuntimeError(f"Failed to create sandbox: {e}")

    async def get_sandbox(self, sandbox_id: str) -> DockerSandbox:
        """Gets a sandbox instance.

        Args:
            sandbox_id: Sandbox ID.

        Returns:
            DockerSandbox: Sandbox instance.

        Raises:
            KeyError: If sandbox does not exist.
        """
        async with self.sandbox_operation(sandbox_id) as sandbox:
            return sandbox

    def start_cleanup_task(self) -> None:
        """Starts automatic cleanup task."""

        async def cleanup_loop():
            while not self._is_shutting_down:
                try:
                    await self._cleanup_idle_sandboxes()
                except Exception as e:
                    logger.error(f"Error in cleanup loop: {e}")
                await asyncio.sleep(self.cleanup_interval)

        self._cleanup_task = asyncio.create_task(cleanup_loop())

    async def _cleanup_idle_sandboxes(self) -> None:
        """Cleans up idle sandboxes."""
        current_time = asyncio.get_event_loop().time()
        to_cleanup = []

        async with self._global_lock:
            for sandbox_id, last_used in self._last_used.items():
                if (
                    sandbox_id not in self._active_operations
                    and current_time - last_used > self.idle_timeout
                ):
                    to_cleanup.append(sandbox_id)

        for sandbox_id in to_cleanup:
            try:
                await self.delete_sandbox(sandbox_id)
            except Exception as e:
                logger.error(f"Error cleaning up sandbox {sandbox_id}: {e}")

    async def cleanup(self) -> None:
        """Cleans up all resources."""
        logger.info("Starting manager cleanup...")
        self._is_shutting_down = True

        # Cancel cleanup task
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await asyncio.wait_for(self._cleanup_task, timeout=1.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass

        # Get all sandbox IDs to clean up
        async with self._global_lock:
            sandbox_ids = list(self._sandboxes.keys())

        # Concurrently clean up all sandboxes
        cleanup_tasks = []
        for sandbox_id in sandbox_ids:
            task = asyncio.create_task(self._safe_delete_sandbox(sandbox_id))
            cleanup_tasks.append(task)

        if cleanup_tasks:
            # Wait for all cleanup tasks to complete, with timeout to avoid infinite waiting
            try:
                await asyncio.wait(cleanup_tasks, timeout=30.0)
            except asyncio.TimeoutError:
                logger.error("Sandbox cleanup timed out")

        # Clean up remaining references
        self._sandboxes.clear()
        self._last_used.clear()
        self._locks.clear()
        self._active_operations.clear()

        logger.info("Manager cleanup completed")

    async def _safe_delete_sandbox(self, sandbox_id: str) -> None:
        """Safely deletes a single sandbox.

        Args:
            sandbox_id: Sandbox ID to delete.
        """
        try:
            if sandbox_id in self._active_operations:
                logger.warning(
                    f"Sandbox {sandbox_id} has active operations, waiting for completion"
                )
                for _ in range(10):  # Wait at most 10 times
                    await asyncio.sleep(0.5)
                    if sandbox_id not in self._active_operations:
                        break
                else:
                    logger.warning(
                        f"Timeout waiting for sandbox {sandbox_id} operations to complete"
                    )

            # Get reference to sandbox object
            sandbox = self._sandboxes.get(sandbox_id)
            if sandbox:
                await sandbox.cleanup()

            # Remove sandbox record from manager
            async with self._global_lock:
                self._sandboxes.pop(sandbox_id, None)
                self._last_used.pop(sandbox_id, None)
                self._locks.pop(sandbox_id, None)
                logger.info(f"Deleted sandbox {sandbox_id}")

        except Exception as e:
            logger.error(f"Error during cleanup of sandbox {sandbox_id}: {e}")

    async def delete_sandbox(self, sandbox_id: str) -> None:
        """Deletes specified sandbox.

        Args:
            sandbox_id: Sandbox ID.
        """
        if sandbox_id not in self._sandboxes:
            return

        try:
            await self._safe_delete_sandbox(sandbox_id)
        except Exception as e:
            logger.error(f"Failed to delete sandbox {sandbox_id}: {e}")

    async def __aenter__(self) -> "SandboxManager":
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit."""
        await self.cleanup()

    def get_stats(self) -> Dict:
        """Gets manager statistics.

        Returns:
            Dict: Statistics information.
        """
        return {
            "total_sandboxes": len(self._sandboxes),
            "active_operations": len(self._active_operations),
            "max_sandboxes": self.max_sandboxes,
            "idle_timeout": self.idle_timeout,
            "cleanup_interval": self.cleanup_interval,
            "is_shutting_down": self._is_shutting_down,
        }
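A usage sketch for the manager (not part of the commit), assuming a running Docker daemon. The manager must be constructed inside a running event loop because __init__ schedules the cleanup task, and its async context manager guarantees pool cleanup on exit.

    import asyncio

    from app.sandbox.core.manager import SandboxManager


    async def main() -> None:
        async with SandboxManager(max_sandboxes=10, idle_timeout=600) as manager:
            sandbox_id = await manager.create_sandbox()
            async with manager.sandbox_operation(sandbox_id) as sandbox:
                print(await sandbox.run_command("echo hello"))
            print(manager.get_stats())
            await manager.delete_sandbox(sandbox_id)


    asyncio.run(main())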
app/sandbox/core/sandbox.py (new file, 462 lines)
@@ -0,0 +1,462 @@
import asyncio
import io
import os
import tarfile
import tempfile
import uuid
from typing import Dict, Optional

import docker
from docker.errors import NotFound
from docker.models.containers import Container

from app.config import SandboxConfig
from app.sandbox.core.exceptions import SandboxTimeoutError
from app.sandbox.core.terminal import AsyncDockerizedTerminal


class DockerSandbox:
    """Docker sandbox environment.

    Provides a containerized execution environment with resource limits,
    file operations, and command execution capabilities.

    Attributes:
        config: Sandbox configuration.
        volume_bindings: Volume mapping configuration.
        client: Docker client.
        container: Docker container instance.
        terminal: Container terminal interface.
    """

    def __init__(
        self,
        config: Optional[SandboxConfig] = None,
        volume_bindings: Optional[Dict[str, str]] = None,
    ):
        """Initializes a sandbox instance.

        Args:
            config: Sandbox configuration. Default configuration used if None.
            volume_bindings: Volume mappings in {host_path: container_path} format.
        """
        self.config = config or SandboxConfig()
        self.volume_bindings = volume_bindings or {}
        self.client = docker.from_env()
        self.container: Optional[Container] = None
        self.terminal: Optional[AsyncDockerizedTerminal] = None

    async def create(self) -> "DockerSandbox":
        """Creates and starts the sandbox container.

        Returns:
            Current sandbox instance.

        Raises:
            docker.errors.APIError: If Docker API call fails.
            RuntimeError: If container creation or startup fails.
        """
        try:
            # Prepare container config
            host_config = self.client.api.create_host_config(
                mem_limit=self.config.memory_limit,
                cpu_period=100000,
                cpu_quota=int(100000 * self.config.cpu_limit),
                network_mode="none" if not self.config.network_enabled else "bridge",
                binds=self._prepare_volume_bindings(),
            )

            # Generate unique container name with sandbox_ prefix
            container_name = f"sandbox_{uuid.uuid4().hex[:8]}"

            # Create container
            container = await asyncio.to_thread(
                self.client.api.create_container,
                image=self.config.image,
                command="tail -f /dev/null",
                hostname="sandbox",
                working_dir=self.config.work_dir,
                host_config=host_config,
                name=container_name,
                tty=True,
                detach=True,
            )

            self.container = self.client.containers.get(container["Id"])

            # Start container
            await asyncio.to_thread(self.container.start)

            # Initialize terminal
            self.terminal = AsyncDockerizedTerminal(
                container["Id"],
                self.config.work_dir,
                env_vars={"PYTHONUNBUFFERED": "1"}
                # Ensure Python output is not buffered
            )
            await self.terminal.init()

            return self

        except Exception as e:
            await self.cleanup()  # Ensure resources are cleaned up
            raise RuntimeError(f"Failed to create sandbox: {e}") from e

    def _prepare_volume_bindings(self) -> Dict[str, Dict[str, str]]:
        """Prepares volume binding configuration.

        Returns:
            Volume binding configuration dictionary.
        """
        bindings = {}

        # Create and add working directory mapping
        work_dir = self._ensure_host_dir(self.config.work_dir)
        bindings[work_dir] = {"bind": self.config.work_dir, "mode": "rw"}

        # Add custom volume bindings
        for host_path, container_path in self.volume_bindings.items():
            bindings[host_path] = {"bind": container_path, "mode": "rw"}

        return bindings

    @staticmethod
    def _ensure_host_dir(path: str) -> str:
        """Ensures directory exists on the host.

        Args:
            path: Directory path.

        Returns:
            Actual path on the host.
        """
        host_path = os.path.join(
            tempfile.gettempdir(),
            f"sandbox_{os.path.basename(path)}_{os.urandom(4).hex()}",
        )
        os.makedirs(host_path, exist_ok=True)
        return host_path

    async def run_command(self, cmd: str, timeout: Optional[int] = None) -> str:
        """Runs a command in the sandbox.

        Args:
            cmd: Command to execute.
            timeout: Timeout in seconds.

        Returns:
            Command output as string.

        Raises:
            RuntimeError: If sandbox not initialized or command execution fails.
            TimeoutError: If command execution times out.
        """
        if not self.terminal:
            raise RuntimeError("Sandbox not initialized")

        try:
            return await self.terminal.run_command(
                cmd, timeout=timeout or self.config.timeout
            )
        except TimeoutError:
            raise SandboxTimeoutError(
                f"Command execution timed out after {timeout or self.config.timeout} seconds"
            )

    async def read_file(self, path: str) -> str:
        """Reads a file from the container.

        Args:
            path: File path.

        Returns:
            File contents as string.

        Raises:
            FileNotFoundError: If file does not exist.
            RuntimeError: If read operation fails.
        """
        if not self.container:
            raise RuntimeError("Sandbox not initialized")

        try:
            # Get file archive
            resolved_path = self._safe_resolve_path(path)
            tar_stream, _ = await asyncio.to_thread(
                self.container.get_archive, resolved_path
            )

            # Read file content from tar stream
            content = await self._read_from_tar(tar_stream)
            return content.decode("utf-8")

        except NotFound:
            raise FileNotFoundError(f"File not found: {path}")
        except Exception as e:
            raise RuntimeError(f"Failed to read file: {e}")

    async def write_file(self, path: str, content: str) -> None:
        """Writes content to a file in the container.

        Args:
            path: Target path.
            content: File content.

        Raises:
            RuntimeError: If write operation fails.
        """
        if not self.container:
            raise RuntimeError("Sandbox not initialized")

        try:
            resolved_path = self._safe_resolve_path(path)
            parent_dir = os.path.dirname(resolved_path)

            # Create parent directory
            if parent_dir:
                await self.run_command(f"mkdir -p {parent_dir}")

            # Prepare file data
            tar_stream = await self._create_tar_stream(
                os.path.basename(path), content.encode("utf-8")
            )

            # Write file
            await asyncio.to_thread(
                self.container.put_archive, parent_dir or "/", tar_stream
            )

        except Exception as e:
            raise RuntimeError(f"Failed to write file: {e}")

    def _safe_resolve_path(self, path: str) -> str:
        """Safely resolves container path, preventing path traversal.

        Args:
            path: Original path.

        Returns:
            Resolved absolute path.

        Raises:
            ValueError: If path contains potentially unsafe patterns.
        """
        # Check for path traversal attempts
        if ".." in path.split("/"):
            raise ValueError("Path contains potentially unsafe patterns")

        resolved = (
            os.path.join(self.config.work_dir, path)
            if not os.path.isabs(path)
            else path
        )
        return resolved

    async def copy_from(self, src_path: str, dst_path: str) -> None:
        """Copies a file from the container.

        Args:
            src_path: Source file path (container).
            dst_path: Destination path (host).

        Raises:
            FileNotFoundError: If source file does not exist.
            RuntimeError: If copy operation fails.
        """
        try:
            # Ensure destination file's parent directory exists
            parent_dir = os.path.dirname(dst_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Get file stream
            resolved_src = self._safe_resolve_path(src_path)
            stream, stat = await asyncio.to_thread(
                self.container.get_archive, resolved_src
            )

            # Create temporary directory to extract file
            with tempfile.TemporaryDirectory() as tmp_dir:
                # Write stream to temporary file
                tar_path = os.path.join(tmp_dir, "temp.tar")
                with open(tar_path, "wb") as f:
                    for chunk in stream:
                        f.write(chunk)

                # Extract file
                with tarfile.open(tar_path) as tar:
                    members = tar.getmembers()
                    if not members:
                        raise FileNotFoundError(f"Source file is empty: {src_path}")

                    # If destination is a directory, we should preserve relative path structure
                    if os.path.isdir(dst_path):
                        tar.extractall(dst_path)
                    else:
                        # If destination is a file, we only extract the source file's content
                        if len(members) > 1:
                            raise RuntimeError(
                                f"Source path is a directory but destination is a file: {src_path}"
                            )

                        with open(dst_path, "wb") as dst:
                            src_file = tar.extractfile(members[0])
                            if src_file is None:
                                raise RuntimeError(
                                    f"Failed to extract file: {src_path}"
                                )
                            dst.write(src_file.read())

        except docker.errors.NotFound:
            raise FileNotFoundError(f"Source file not found: {src_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to copy file: {e}")

    async def copy_to(self, src_path: str, dst_path: str) -> None:
        """Copies a file to the container.

        Args:
            src_path: Source file path (host).
            dst_path: Destination path (container).

        Raises:
            FileNotFoundError: If source file does not exist.
            RuntimeError: If copy operation fails.
        """
        try:
            if not os.path.exists(src_path):
                raise FileNotFoundError(f"Source file not found: {src_path}")

            # Create destination directory in container
            resolved_dst = self._safe_resolve_path(dst_path)
            container_dir = os.path.dirname(resolved_dst)
            if container_dir:
                await self.run_command(f"mkdir -p {container_dir}")

            # Create tar file to upload
            with tempfile.TemporaryDirectory() as tmp_dir:
                tar_path = os.path.join(tmp_dir, "temp.tar")
                with tarfile.open(tar_path, "w") as tar:
                    # Handle directory source path
                    if os.path.isdir(src_path):
                        os.path.basename(src_path.rstrip("/"))
                        for root, _, files in os.walk(src_path):
                            for file in files:
                                file_path = os.path.join(root, file)
                                arcname = os.path.join(
                                    os.path.basename(dst_path),
                                    os.path.relpath(file_path, src_path),
                                )
                                tar.add(file_path, arcname=arcname)
                    else:
                        # Add single file to tar
                        tar.add(src_path, arcname=os.path.basename(dst_path))

                # Read tar file content
                with open(tar_path, "rb") as f:
                    data = f.read()

                # Upload to container
                await asyncio.to_thread(
                    self.container.put_archive,
                    os.path.dirname(resolved_dst) or "/",
                    data,
                )

                # Verify file was created successfully
                try:
                    await self.run_command(f"test -e {resolved_dst}")
                except Exception:
                    raise RuntimeError(f"Failed to verify file creation: {dst_path}")

        except FileNotFoundError:
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to copy file: {e}")

    @staticmethod
    async def _create_tar_stream(name: str, content: bytes) -> io.BytesIO:
        """Creates a tar file stream.

        Args:
            name: Filename.
            content: File content.

        Returns:
            Tar file stream.
        """
        tar_stream = io.BytesIO()
        with tarfile.open(fileobj=tar_stream, mode="w") as tar:
            tarinfo = tarfile.TarInfo(name=name)
            tarinfo.size = len(content)
            tar.addfile(tarinfo, io.BytesIO(content))
        tar_stream.seek(0)
        return tar_stream

    @staticmethod
    async def _read_from_tar(tar_stream) -> bytes:
        """Reads file content from a tar stream.

        Args:
            tar_stream: Tar file stream.

        Returns:
            File content.

        Raises:
            RuntimeError: If read operation fails.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            for chunk in tar_stream:
                tmp.write(chunk)
            tmp.seek(0)

            with tarfile.open(fileobj=tmp) as tar:
                member = tar.next()
                if not member:
                    raise RuntimeError("Empty tar archive")

                file_content = tar.extractfile(member)
                if not file_content:
                    raise RuntimeError("Failed to extract file content")

                return file_content.read()

    async def cleanup(self) -> None:
        """Cleans up sandbox resources."""
        errors = []
        try:
            if self.terminal:
                try:
                    await self.terminal.close()
                except Exception as e:
                    errors.append(f"Terminal cleanup error: {e}")
                finally:
                    self.terminal = None

            if self.container:
                try:
                    await asyncio.to_thread(self.container.stop, timeout=5)
                except Exception as e:
                    errors.append(f"Container stop error: {e}")

                try:
                    await asyncio.to_thread(self.container.remove, force=True)
                except Exception as e:
                    errors.append(f"Container remove error: {e}")
                finally:
                    self.container = None

        except Exception as e:
            errors.append(f"General cleanup error: {e}")

        if errors:
            print(f"Warning: Errors during cleanup: {', '.join(errors)}")

    async def __aenter__(self) -> "DockerSandbox":
        """Async context manager entry."""
        return await self.create()

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit."""
        await self.cleanup()
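DockerSandbox is also usable directly; a sketch (not part of the commit) relying on the async context manager defined above. It assumes the configured image is already present locally, since DockerSandbox itself does not pull images (SandboxManager.ensure_image does).

    import asyncio

    from app.config import SandboxConfig
    from app.sandbox.core.sandbox import DockerSandbox


    async def main() -> None:
        config = SandboxConfig(memory_limit="256m", cpu_limit=0.5)
        async with DockerSandbox(config) as sandbox:
            await sandbox.write_file("script.py", "print(2 + 2)")
            print(await sandbox.run_command("python script.py"))  # prints "4"
            print(await sandbox.read_file("script.py"))


    asyncio.run(main())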
app/sandbox/core/terminal.py (new file, 346 lines)
@@ -0,0 +1,346 @@
"""
Asynchronous Docker Terminal

This module provides asynchronous terminal functionality for Docker containers,
allowing interactive command execution with timeout control.
"""

import asyncio
import re
import socket
from typing import Dict, Optional, Tuple, Union

import docker
from docker import APIClient
from docker.errors import APIError
from docker.models.containers import Container


class DockerSession:
    def __init__(self, container_id: str) -> None:
        """Initializes a Docker session.

        Args:
            container_id: ID of the Docker container.
        """
        self.api = APIClient()
        self.container_id = container_id
        self.exec_id = None
        self.socket = None

    async def create(self, working_dir: str, env_vars: Dict[str, str]) -> None:
        """Creates an interactive session with the container.

        Args:
            working_dir: Working directory inside the container.
            env_vars: Environment variables to set.

        Raises:
            RuntimeError: If socket connection fails.
        """
        startup_command = [
            "bash",
            "-c",
            f"cd {working_dir} && "
            "PROMPT_COMMAND='' "
            "PS1='$ ' "
            "exec bash --norc --noprofile",
        ]

        exec_data = self.api.exec_create(
            self.container_id,
            startup_command,
            stdin=True,
            tty=True,
            stdout=True,
            stderr=True,
            privileged=True,
            user="root",
            environment={**env_vars, "TERM": "dumb", "PS1": "$ ", "PROMPT_COMMAND": ""},
        )
        self.exec_id = exec_data["Id"]

        socket_data = self.api.exec_start(
            self.exec_id, socket=True, tty=True, stream=True, demux=True
        )

        if hasattr(socket_data, "_sock"):
            self.socket = socket_data._sock
            self.socket.setblocking(False)
        else:
            raise RuntimeError("Failed to get socket connection")

        await self._read_until_prompt()

    async def close(self) -> None:
        """Cleans up session resources.

        1. Sends exit command
        2. Closes socket connection
        3. Checks and cleans up exec instance
        """
        try:
            if self.socket:
                # Send exit command to close bash session
                try:
                    self.socket.sendall(b"exit\n")
                    # Allow time for command execution
                    await asyncio.sleep(0.1)
                except:
                    pass  # Ignore sending errors, continue cleanup

                # Close socket connection
                try:
                    self.socket.shutdown(socket.SHUT_RDWR)
                except:
                    pass  # Some platforms may not support shutdown

                self.socket.close()
                self.socket = None

            if self.exec_id:
                try:
                    # Check exec instance status
                    exec_inspect = self.api.exec_inspect(self.exec_id)
                    if exec_inspect.get("Running", False):
                        # If still running, wait for it to complete
                        await asyncio.sleep(0.5)
                except:
                    pass  # Ignore inspection errors, continue cleanup

                self.exec_id = None

        except Exception as e:
            # Log error but don't raise, ensure cleanup continues
            print(f"Warning: Error during session cleanup: {e}")

    async def _read_until_prompt(self) -> str:
        """Reads output until prompt is found.

        Returns:
            String containing output up to the prompt.

        Raises:
            socket.error: If socket communication fails.
        """
        buffer = b""
        while b"$ " not in buffer:
            try:
                chunk = self.socket.recv(4096)
                if chunk:
                    buffer += chunk
            except socket.error as e:
                if e.errno == socket.EWOULDBLOCK:
                    await asyncio.sleep(0.1)
                    continue
                raise
        return buffer.decode("utf-8")

    async def execute(self, command: str, timeout: Optional[int] = None) -> str:
        """Executes a command and returns cleaned output.

        Args:
            command: Shell command to execute.
            timeout: Maximum execution time in seconds.

        Returns:
            Command output as string with prompt markers removed.

        Raises:
            RuntimeError: If session not initialized or execution fails.
            TimeoutError: If command execution exceeds timeout.
        """
        if not self.socket:
            raise RuntimeError("Session not initialized")

        try:
            # Sanitize command to prevent shell injection
            sanitized_command = self._sanitize_command(command)
            full_command = f"{sanitized_command}\necho $?\n"
            self.socket.sendall(full_command.encode())

            async def read_output() -> str:
                buffer = b""
                result_lines = []
                command_sent = False

                while True:
                    try:
                        chunk = self.socket.recv(4096)
                        if not chunk:
                            break

                        buffer += chunk
                        lines = buffer.split(b"\n")

                        buffer = lines[-1]
                        lines = lines[:-1]

                        for line in lines:
                            line = line.rstrip(b"\r")

                            if not command_sent:
                                command_sent = True
                                continue

                            if line.strip() == b"echo $?" or line.strip().isdigit():
                                continue

                            if line.strip():
                                result_lines.append(line)

                        if buffer.endswith(b"$ "):
                            break

                    except socket.error as e:
                        if e.errno == socket.EWOULDBLOCK:
                            await asyncio.sleep(0.1)
                            continue
                        raise

                output = b"\n".join(result_lines).decode("utf-8")
                output = re.sub(r"\n\$ echo \$\$?.*$", "", output)

                return output

            if timeout:
                result = await asyncio.wait_for(read_output(), timeout)
            else:
                result = await read_output()

            return result.strip()

        except asyncio.TimeoutError:
            raise TimeoutError(f"Command execution timed out after {timeout} seconds")
        except Exception as e:
            raise RuntimeError(f"Failed to execute command: {e}")

    def _sanitize_command(self, command: str) -> str:
        """Sanitizes the command string to prevent shell injection.

        Args:
            command: Raw command string.

        Returns:
            Sanitized command string.

        Raises:
            ValueError: If command contains potentially dangerous patterns.
        """

        # Additional checks for specific risky commands
        risky_commands = [
            "rm -rf /",
            "rm -rf /*",
            "mkfs",
            "dd if=/dev/zero",
            ":(){:|:&};:",
            "chmod -R 777 /",
            "chown -R",
        ]

        for risky in risky_commands:
            if risky in command.lower():
                raise ValueError(
                    f"Command contains potentially dangerous operation: {risky}"
                )

        return command


class AsyncDockerizedTerminal:
    def __init__(
        self,
        container: Union[str, Container],
        working_dir: str = "/workspace",
        env_vars: Optional[Dict[str, str]] = None,
        default_timeout: int = 60,
    ) -> None:
        """Initializes an asynchronous terminal for Docker containers.

        Args:
            container: Docker container ID or Container object.
            working_dir: Working directory inside the container.
            env_vars: Environment variables to set.
            default_timeout: Default command execution timeout in seconds.
        """
        self.client = docker.from_env()
        self.container = (
            container
            if isinstance(container, Container)
            else self.client.containers.get(container)
        )
        self.working_dir = working_dir
        self.env_vars = env_vars or {}
        self.default_timeout = default_timeout
        self.session = None

    async def init(self) -> None:
        """Initializes the terminal environment.

        Ensures working directory exists and creates an interactive session.

        Raises:
            RuntimeError: If initialization fails.
        """
        await self._ensure_workdir()

        self.session = DockerSession(self.container.id)
        await self.session.create(self.working_dir, self.env_vars)

    async def _ensure_workdir(self) -> None:
        """Ensures working directory exists in container.

        Raises:
            RuntimeError: If directory creation fails.
        """
        try:
            await self._exec_simple(f"mkdir -p {self.working_dir}")
        except APIError as e:
            raise RuntimeError(f"Failed to create working directory: {e}")

    async def _exec_simple(self, cmd: str) -> Tuple[int, str]:
        """Executes a simple command using Docker's exec_run.

        Args:
            cmd: Command to execute.

        Returns:
            Tuple of (exit_code, output).
        """
        result = await asyncio.to_thread(
            self.container.exec_run, cmd, environment=self.env_vars
        )
        return result.exit_code, result.output.decode("utf-8")

    async def run_command(self, cmd: str, timeout: Optional[int] = None) -> str:
        """Runs a command in the container with timeout.

        Args:
            cmd: Shell command to execute.
            timeout: Maximum execution time in seconds.

        Returns:
            Command output as string.

        Raises:
            RuntimeError: If terminal not initialized.
        """
        if not self.session:
            raise RuntimeError("Terminal not initialized")

        return await self.session.execute(cmd, timeout=timeout or self.default_timeout)

    async def close(self) -> None:
        """Closes the terminal session."""
        if self.session:
            await self.session.close()

    async def __aenter__(self) -> "AsyncDockerizedTerminal":
        """Async context manager entry."""
        await self.init()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit."""
        await self.close()
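A sketch (not part of the commit) of using the terminal on its own against an already-running container that has bash available; the container name below is a placeholder.

    import asyncio

    from app.sandbox.core.terminal import AsyncDockerizedTerminal


    async def main() -> None:
        # "my_container" is a placeholder for a real container ID or name.
        async with AsyncDockerizedTerminal("my_container", working_dir="/workspace") as term:
            print(await term.run_command("ls -la", timeout=10))


    asyncio.run(main())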
app/tool/file_operators.py (new file, 156 lines)
@@ -0,0 +1,156 @@
"""File operation interfaces and implementations for local and sandbox environments."""

import asyncio
from pathlib import Path
from typing import Optional, Protocol, Tuple, Union, runtime_checkable

from app.config import SandboxConfig
from app.exceptions import ToolError
from app.sandbox.client import SANDBOX_CLIENT


PathLike = Union[str, Path]


@runtime_checkable
class FileOperator(Protocol):
    """Interface for file operations in different environments."""

    async def read_file(self, path: PathLike) -> str:
        """Read content from a file."""
        ...

    async def write_file(self, path: PathLike, content: str) -> None:
        """Write content to a file."""
        ...

    async def is_directory(self, path: PathLike) -> bool:
        """Check if path points to a directory."""
        ...

    async def exists(self, path: PathLike) -> bool:
        """Check if path exists."""
        ...

    async def run_command(
        self, cmd: str, timeout: Optional[float] = 120.0
    ) -> Tuple[int, str, str]:
        """Run a shell command and return (return_code, stdout, stderr)."""
        ...


class LocalFileOperator(FileOperator):
    """File operations implementation for local filesystem."""

    async def read_file(self, path: PathLike) -> str:
        """Read content from a local file."""
        try:
            return Path(path).read_text()
        except Exception as e:
            raise ToolError(f"Failed to read {path}: {str(e)}") from None

    async def write_file(self, path: PathLike, content: str) -> None:
        """Write content to a local file."""
        try:
            Path(path).write_text(content)
        except Exception as e:
            raise ToolError(f"Failed to write to {path}: {str(e)}") from None

    async def is_directory(self, path: PathLike) -> bool:
        """Check if path points to a directory."""
        return Path(path).is_dir()

    async def exists(self, path: PathLike) -> bool:
        """Check if path exists."""
        return Path(path).exists()

    async def run_command(
        self, cmd: str, timeout: Optional[float] = 120.0
    ) -> Tuple[int, str, str]:
        """Run a shell command locally."""
        process = await asyncio.create_subprocess_shell(
            cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )

        try:
            stdout, stderr = await asyncio.wait_for(
                process.communicate(), timeout=timeout
            )
            return (
                process.returncode or 0,
                stdout.decode(),
                stderr.decode(),
            )
        except asyncio.TimeoutError as exc:
            try:
                process.kill()
            except ProcessLookupError:
                pass
            raise TimeoutError(
                f"Command '{cmd}' timed out after {timeout} seconds"
            ) from exc


class SandboxFileOperator(FileOperator):
    """File operations implementation for sandbox environment."""

    def __init__(self):
        self.sandbox_client = SANDBOX_CLIENT

    async def _ensure_sandbox_initialized(self):
        """Ensure sandbox is initialized."""
        if not self.sandbox_client.sandbox:
            await self.sandbox_client.create(config=SandboxConfig())

    async def read_file(self, path: PathLike) -> str:
        """Read content from a file in sandbox."""
        await self._ensure_sandbox_initialized()
        try:
            return await self.sandbox_client.read_file(str(path))
        except Exception as e:
            raise ToolError(f"Failed to read {path} in sandbox: {str(e)}") from None

    async def write_file(self, path: PathLike, content: str) -> None:
        """Write content to a file in sandbox."""
        await self._ensure_sandbox_initialized()
        try:
            await self.sandbox_client.write_file(str(path), content)
        except Exception as e:
            raise ToolError(f"Failed to write to {path} in sandbox: {str(e)}") from None

    async def is_directory(self, path: PathLike) -> bool:
        """Check if path points to a directory in sandbox."""
        await self._ensure_sandbox_initialized()
        result = await self.sandbox_client.run_command(
            f"test -d {path} && echo 'true' || echo 'false'"
        )
        return result.strip() == "true"

    async def exists(self, path: PathLike) -> bool:
        """Check if path exists in sandbox."""
        await self._ensure_sandbox_initialized()
        result = await self.sandbox_client.run_command(
            f"test -e {path} && echo 'true' || echo 'false'"
        )
        return result.strip() == "true"

    async def run_command(
        self, cmd: str, timeout: Optional[float] = 120.0
    ) -> Tuple[int, str, str]:
        """Run a command in sandbox environment."""
        await self._ensure_sandbox_initialized()
        try:
            stdout = await self.sandbox_client.run_command(
                cmd, timeout=int(timeout) if timeout else None
            )
            return (
                0,  # Always return 0 since we don't have explicit return code from sandbox
                stdout,
                "",  # No stderr capture in the current sandbox implementation
            )
        except TimeoutError as exc:
            raise TimeoutError(
                f"Command '{cmd}' timed out after {timeout} seconds in sandbox"
            ) from exc
        except Exception as exc:
            return 1, "", f"Error executing command in sandbox: {str(exc)}"
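Because both implementations satisfy the same Protocol, helpers can be written once and run against either environment; an illustrative sketch (the helper name is made up, not part of the commit):

    from app.tool.file_operators import FileOperator, LocalFileOperator, PathLike


    async def count_lines(operator: FileOperator, path: PathLike) -> int:
        """Count lines in a file through any FileOperator implementation."""
        content = await operator.read_file(path)
        return len(content.splitlines())

    # Local usage: await count_lines(LocalFileOperator(), "README.md")
    # Sandbox usage would pass a SandboxFileOperator() instead.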
app/tool/run.py (deleted, 43 lines)
@@ -1,43 +0,0 @@
"""Utility to run shell commands asynchronously with a timeout."""

import asyncio


TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
MAX_RESPONSE_LEN: int = 16000


def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
    """Truncate content and append a notice if content exceeds the specified length."""
    return (
        content
        if not truncate_after or len(content) <= truncate_after
        else content[:truncate_after] + TRUNCATED_MESSAGE
    )


async def run(
    cmd: str,
    timeout: float | None = 120.0,  # seconds
    truncate_after: int | None = MAX_RESPONSE_LEN,
):
    """Run a shell command asynchronously with a timeout."""
    process = await asyncio.create_subprocess_shell(
        cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
    )

    try:
        stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
        return (
            process.returncode or 0,
            maybe_truncate(stdout.decode(), truncate_after=truncate_after),
            maybe_truncate(stderr.decode(), truncate_after=truncate_after),
        )
    except asyncio.TimeoutError as exc:
        try:
            process.kill()
        except ProcessLookupError:
            pass
        raise TimeoutError(
            f"Command '{cmd}' timed out after {timeout} seconds"
        ) from exc
@@ -1,11 +1,18 @@
-from collections import defaultdict
-from pathlib import Path
-from typing import Literal, get_args
+"""File and directory manipulation tool with sandbox support."""
+
+from collections import defaultdict
+from typing import Any, DefaultDict, List, Literal, Optional, get_args
 
+from app.config import config
 from app.exceptions import ToolError
 from app.tool import BaseTool
 from app.tool.base import CLIResult, ToolResult
-from app.tool.run import run
+from app.tool.file_operators import (
+    FileOperator,
+    LocalFileOperator,
+    PathLike,
+    SandboxFileOperator,
+)
 
 
 Command = Literal[
@@ -15,12 +22,17 @@ Command = Literal[
     "insert",
     "undo_edit",
 ]
 
+# Constants
 SNIPPET_LINES: int = 4
-
 MAX_RESPONSE_LEN: int = 16000
-TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
+TRUNCATED_MESSAGE: str = (
+    "<response clipped><NOTE>To save on context only part of this file has been shown to you. "
+    "You should retry this tool after you have searched inside the file with `grep -n` "
+    "in order to find the line numbers of what you are looking for.</NOTE>"
+)
 
+# Tool description
 _STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files
 * State is persistent across command calls and discussions with the user
 * If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
@@ -35,17 +47,17 @@ Notes for using the `str_replace` command:
 """
 
 
-def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
+def maybe_truncate(
+    content: str, truncate_after: Optional[int] = MAX_RESPONSE_LEN
+) -> str:
     """Truncate content and append a notice if content exceeds the specified length."""
-    return (
-        content
-        if not truncate_after or len(content) <= truncate_after
-        else content[:truncate_after] + TRUNCATED_MESSAGE
-    )
+    if not truncate_after or len(content) <= truncate_after:
+        return content
+    return content[:truncate_after] + TRUNCATED_MESSAGE
 
 
 class StrReplaceEditor(BaseTool):
-    """A tool for executing bash commands"""
+    """A tool for viewing, creating, and editing files with sandbox support."""
 
     name: str = "str_replace_editor"
     description: str = _STR_REPLACE_EDITOR_DESCRIPTION
@@ -85,8 +97,19 @@ class StrReplaceEditor(BaseTool):
         },
         "required": ["command", "path"],
     }
+    _file_history: DefaultDict[PathLike, List[str]] = defaultdict(list)
+    _local_operator: LocalFileOperator = LocalFileOperator()
+    # todo: Sandbox resources need to be destroyed at the appropriate time.
+    _sandbox_operator: SandboxFileOperator = SandboxFileOperator()
 
-    _file_history: list = defaultdict(list)
+    # def _get_operator(self, use_sandbox: bool) -> FileOperator:
+    def _get_operator(self) -> FileOperator:
+        """Get the appropriate file operator based on execution mode."""
+        return (
+            self._sandbox_operator
+            if config.sandbox.use_sandbox
+            else self._local_operator
+        )
 
     async def execute(
         self,
@@ -98,24 +121,30 @@ class StrReplaceEditor(BaseTool):
         old_str: str | None = None,
         new_str: str | None = None,
         insert_line: int | None = None,
-        **kwargs,
+        **kwargs: Any,
     ) -> str:
-        _path = Path(path)
-        self.validate_path(command, _path)
+        """Execute a file operation command."""
+        # Get the appropriate file operator
+        operator = self._get_operator()
+
+        # Validate path and command combination
+        await self.validate_path(command, path, operator)
+
+        # Execute the appropriate command
         if command == "view":
-            result = await self.view(_path, view_range)
+            result = await self.view(path, view_range, operator)
         elif command == "create":
             if file_text is None:
                 raise ToolError("Parameter `file_text` is required for command: create")
-            self.write_file(_path, file_text)
-            self._file_history[_path].append(file_text)
-            result = ToolResult(output=f"File created successfully at: {_path}")
+            await operator.write_file(path, file_text)
+            self._file_history[path].append(file_text)
+            result = ToolResult(output=f"File created successfully at: {path}")
         elif command == "str_replace":
             if old_str is None:
                 raise ToolError(
                     "Parameter `old_str` is required for command: str_replace"
                 )
-            result = self.str_replace(_path, old_str, new_str)
+            result = await self.str_replace(path, old_str, new_str, operator)
         elif command == "insert":
             if insert_line is None:
                 raise ToolError(
@@ -123,92 +152,149 @@ class StrReplaceEditor(BaseTool):
                 )
             if new_str is None:
                 raise ToolError("Parameter `new_str` is required for command: insert")
-            result = self.insert(_path, insert_line, new_str)
+            result = await self.insert(path, insert_line, new_str, operator)
         elif command == "undo_edit":
-            result = self.undo_edit(_path)
+            result = await self.undo_edit(path, operator)
         else:
+            # This should be caught by type checking, but we include it for safety
             raise ToolError(
                 f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
             )
 
         return str(result)
 
-    def validate_path(self, command: str, path: Path):
-        """
-        Check that the path/command combination is valid.
-        """
-        # Check if its an absolute path
-        if not path.is_absolute():
-            suggested_path = Path("") / path
+    async def validate_path(
+        self, command: str, path: str, operator: FileOperator
+    ) -> None:
+        """Validate path and command combination based on execution environment."""
+        # Check if path is absolute
+        if not path.startswith("/"):
+            suggested_path = f"/{path}"
             raise ToolError(
-                f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
+                f"The path {path} is not an absolute path, it should start with `/`. "
+                f"Maybe you meant {suggested_path}?"
             )
-        # Check if path exists
-        if not path.exists() and command != "create":
-            raise ToolError(
-                f"The path {path} does not exist. Please provide a valid path."
-            )
-        if path.exists() and command == "create":
-            raise ToolError(
-                f"File already exists at: {path}. Cannot overwrite files using command `create`."
-            )
-        # Check if the path points to a directory
-        if path.is_dir():
-            if command != "view":
+
+        # Only check if path exists for non-create commands
+        if command != "create":
+            if not await operator.exists(path):
+                raise ToolError(
+                    f"The path {path} does not exist. Please provide a valid path."
+                )
+
+            # Check if path is a directory
+            is_dir = await operator.is_directory(path)
+            if is_dir and command != "view":
                 raise ToolError(
                     f"The path {path} is a directory and only the `view` command can be used on directories"
                 )
 
-    async def view(self, path: Path, view_range: list[int] | None = None):
-        """Implement the view command"""
-        if path.is_dir():
+        # Check if file exists for create command
+        elif command == "create":
+            exists = await operator.exists(path)
+            if exists:
+                raise ToolError(
+                    f"File already exists at: {path}. Cannot overwrite files using command `create`."
+                )
+
+    async def view(
+        self,
+        path: PathLike,
+        view_range: Optional[List[int]] = None,
+        operator: FileOperator = None,
+    ) -> CLIResult:
+        """Display file or directory content."""
+        # Determine if path is a directory
+        is_dir = await operator.is_directory(path)
+
+        if is_dir:
+            # Directory handling
             if view_range:
                 raise ToolError(
                     "The `view_range` parameter is not allowed when `path` points to a directory."
                 )
 
-            _, stdout, stderr = await run(
-                rf"find {path} -maxdepth 2 -not -path '*/\.*'"
-            )
+            return await self._view_directory(path, operator)
+        else:
+            # File handling
+            return await self._view_file(path, operator, view_range)
+
+    @staticmethod
+    async def _view_directory(path: PathLike, operator: FileOperator) -> CLIResult:
+        """Display directory contents."""
+        find_cmd = f"find {path} -maxdepth 2 -not -path '*/\\.*'"
+
+        # Execute command using the operator
+        returncode, stdout, stderr = await operator.run_command(find_cmd)
+
         if not stderr:
-            stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
+            stdout = (
+                f"Here's the files and directories up to 2 levels deep in {path}, "
+                f"excluding hidden items:\n{stdout}\n"
+            )
         return CLIResult(output=stdout, error=stderr)
 
-        file_content = self.read_file(path)
+    async def _view_file(
+        self,
+        path: PathLike,
+        operator: FileOperator,
+        view_range: Optional[List[int]] = None,
+    ) -> CLIResult:
+        """Display file content, optionally within a specified line range."""
+        # Read file content
+        file_content = await operator.read_file(path)
         init_line = 1
+
+        # Apply view range if specified
         if view_range:
             if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
                 raise ToolError(
                     "Invalid `view_range`. It should be a list of two integers."
                 )
             file_lines = file_content.split("\n")
             n_lines_file = len(file_lines)
             init_line, final_line = view_range
+
+            # Validate view range
             if init_line < 1 or init_line > n_lines_file:
                 raise ToolError(
-                    f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
+                    f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be "
+                    f"within the range of lines of the file: {[1, n_lines_file]}"
                 )
             if final_line > n_lines_file:
                 raise ToolError(
-                    f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
+                    f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be "
+                    f"smaller than the number of lines in the file: `{n_lines_file}`"
                 )
             if final_line != -1 and final_line < init_line:
                 raise ToolError(
-                    f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
+                    f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be "
+                    f"larger or equal than its first `{init_line}`"
                 )
+
+        # Apply range
         if final_line == -1:
             file_content = "\n".join(file_lines[init_line - 1 :])
         else:
             file_content = "\n".join(file_lines[init_line - 1 : final_line])
+
+        # Format and return result
         return CLIResult(
             output=self._make_output(file_content, str(path), init_line=init_line)
         )
 
-    def str_replace(self, path: Path, old_str: str, new_str: str | None):
-        """Implement the str_replace command, which replaces old_str with new_str in the file content"""
-        # Read the file content
-        file_content = self.read_file(path).expandtabs()
+    async def str_replace(
+        self,
+        path: PathLike,
+        old_str: str,
+        new_str: Optional[str] = None,
+        operator: FileOperator = None,
+    ) -> CLIResult:
+        """Replace a unique string in a file with a new string."""
+        # Read file content and expand tabs
+        file_content = (await operator.read_file(path)).expandtabs()
         old_str = old_str.expandtabs()
         new_str = new_str.expandtabs() if new_str is not None else ""
 
@@ -219,6 +305,7 @@ class StrReplaceEditor(BaseTool):
                 f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
             )
         elif occurrences > 1:
+            # Find line numbers of occurrences
             file_content_lines = file_content.split("\n")
             lines = [
                 idx + 1
@@ -226,16 +313,17 @@ class StrReplaceEditor(BaseTool):
                 if old_str in line
             ]
             raise ToolError(
-                f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
+                f"No replacement was performed. Multiple occurrences of old_str `{old_str}` "
+                f"in lines {lines}. Please ensure it is unique"
             )
 
         # Replace old_str with new_str
         new_file_content = file_content.replace(old_str, new_str)
 
         # Write the new content to the file
-        self.write_file(path, new_file_content)
+        await operator.write_file(path, new_file_content)
 
-        # Save the content to history
+        # Save the original content to history
         self._file_history[path].append(file_content)
 
         # Create a snippet of the edited section
@@ -253,36 +341,50 @@ class StrReplaceEditor(BaseTool):
 
         return CLIResult(output=success_msg)
 
-    def insert(self, path: Path, insert_line: int, new_str: str):
-        """Implement the insert command, which inserts new_str at the specified line in the file content."""
-        file_text = self.read_file(path).expandtabs()
+    async def insert(
+        self,
+        path: PathLike,
+        insert_line: int,
+        new_str: str,
+        operator: FileOperator = None,
+    ) -> CLIResult:
+        """Insert text at a specific line in a file."""
+        # Read and prepare content
+        file_text = (await operator.read_file(path)).expandtabs()
         new_str = new_str.expandtabs()
         file_text_lines = file_text.split("\n")
         n_lines_file = len(file_text_lines)
 
+        # Validate insert_line
         if insert_line < 0 or insert_line > n_lines_file:
             raise ToolError(
-                f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
+                f"Invalid `insert_line` parameter: {insert_line}. It should be within "
+                f"the range of lines of the file: {[0, n_lines_file]}"
             )
 
+        # Perform insertion
         new_str_lines = new_str.split("\n")
         new_file_text_lines = (
             file_text_lines[:insert_line]
             + new_str_lines
             + file_text_lines[insert_line:]
        )
 
+        # Create a snippet for preview
         snippet_lines = (
             file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
             + new_str_lines
             + file_text_lines[insert_line : insert_line + SNIPPET_LINES]
         )
 
+        # Join lines and write to file
         new_file_text = "\n".join(new_file_text_lines)
         snippet = "\n".join(snippet_lines)
 
-        self.write_file(path, new_file_text)
+        await operator.write_file(path, new_file_text)
         self._file_history[path].append(file_text)
 
+        # Prepare success message
         success_msg = f"The file {path} has been edited. "
         success_msg += self._make_output(
             snippet,
@@ -290,51 +392,43 @@ class StrReplaceEditor(BaseTool):
             max(1, insert_line - SNIPPET_LINES + 1),
         )
         success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
 
         return CLIResult(output=success_msg)
 
-    def undo_edit(self, path: Path):
-        """Implement the undo_edit command."""
+    async def undo_edit(
+        self, path: PathLike, operator: FileOperator = None
+    ) -> CLIResult:
+        """Revert the last edit made to a file."""
         if not self._file_history[path]:
             raise ToolError(f"No edit history found for {path}.")
 
         old_text = self._file_history[path].pop()
-        self.write_file(path, old_text)
+        await operator.write_file(path, old_text)
 
         return CLIResult(
             output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
         )
 
-    def read_file(self, path: Path):
-        """Read the content of a file from a given path; raise a ToolError if an error occurs."""
-        try:
-            return path.read_text()
-        except Exception as e:
-            raise ToolError(f"Ran into {e} while trying to read {path}") from None
-
-    def write_file(self, path: Path, file: str):
-        """Write the content of a file to a given path; raise a ToolError if an error occurs."""
-        try:
-            path.write_text(file)
-        except Exception as e:
-            raise ToolError(f"Ran into {e} while trying to write to {path}") from None
-
     def _make_output(
         self,
         file_content: str,
         file_descriptor: str,
         init_line: int = 1,
         expand_tabs: bool = True,
-    ):
-        """Generate output for the CLI based on the content of a file."""
+    ) -> str:
+        """Format file content for display with line numbers."""
         file_content = maybe_truncate(file_content)
         if expand_tabs:
             file_content = file_content.expandtabs()
+
+        # Add line numbers to each line
         file_content = "\n".join(
             [
                 f"{i + init_line:6}\t{line}"
                 for i, line in enumerate(file_content.split("\n"))
             ]
         )
+
         return (
             f"Here's the result of running `cat -n` on {file_descriptor}:\n"
             + file_content
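Taken together, the editor now routes every file operation through a FileOperator. A small end-to-end sketch follows (illustrative only, not part of the diff; the import path app.tool.str_replace_editor is assumed, and whether the calls touch the host or the Docker container depends on the [sandbox] setting shown below):

# Illustrative sketch, not part of the diff: exercise the editor end to end.
import asyncio

from app.tool.str_replace_editor import StrReplaceEditor  # assumed module path


async def demo() -> None:
    editor = StrReplaceEditor()
    # create -> str_replace -> view, all dispatched through the selected FileOperator
    print(await editor.execute(command="create", path="/workspace/hello.py",
                               file_text="print('hi')\n"))
    print(await editor.execute(command="str_replace", path="/workspace/hello.py",
                               old_str="print('hi')", new_str="print('hello')"))
    print(await editor.execute(command="view", path="/workspace/hello.py"))


asyncio.run(demo())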
@@ -11,3 +11,13 @@ temperature = 0.0
 model = "claude-3-5-sonnet"
 base_url = "https://api.openai.com/v1"
 api_key = "sk-..."
+
+# Sandbox configuration
+[sandbox]
+use_sandbox = false
+image = "python:3.10-slim"
+work_dir = "/workspace"
+memory_limit = "1g"  # 512m
+cpu_limit = 2.0
+timeout = 300
+network_enabled = false
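These keys are parsed into config.sandbox at startup; a quick way to inspect the effective settings (illustrative only, not part of the diff):

# Illustrative sketch, not part of the diff: print the parsed sandbox settings.
from app.config import config

sb = config.sandbox
print(sb.use_sandbox, sb.image, sb.work_dir, sb.memory_limit, sb.cpu_limit, sb.timeout)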
@@ -19,3 +19,7 @@ aiofiles~=24.1.0
 pydantic_core~=2.27.2
 colorama~=0.4.6
 playwright~=1.49.1
+
+docker~=7.1.0
+pytest~=8.3.5
+pytest-asyncio~=0.25.3
110
tests/sandbox/test_client.py
Normal file
@@ -0,0 +1,110 @@
import tempfile
from pathlib import Path
from typing import AsyncGenerator

import pytest
import pytest_asyncio

from app.config import SandboxConfig
from app.sandbox.client import LocalSandboxClient, create_sandbox_client


@pytest_asyncio.fixture(scope="function")
async def local_client() -> AsyncGenerator[LocalSandboxClient, None]:
    """Creates a local sandbox client for testing."""
    client = await create_sandbox_client()
    try:
        yield client
    finally:
        await client.cleanup()


@pytest.fixture(scope="function")
def temp_dir() -> Path:
    """Creates a temporary directory for testing."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        yield Path(tmp_dir)


@pytest.mark.asyncio
async def test_sandbox_creation(local_client: LocalSandboxClient):
    """Tests sandbox creation with specific configuration."""
    config = SandboxConfig(
        image="python:3.10-slim",
        work_dir="/workspace",
        memory_limit="512m",
        cpu_limit=0.5,
    )

    await local_client.create(config)
    result = await local_client.run_command("python3 --version")
    assert "Python 3.10" in result


@pytest.mark.asyncio
async def test_local_command_execution(local_client: LocalSandboxClient):
    """Tests command execution in local sandbox."""
    await local_client.create()

    result = await local_client.run_command("echo 'test'")
    assert result.strip() == "test"

    with pytest.raises(Exception):
        await local_client.run_command("sleep 10", timeout=1)


@pytest.mark.asyncio
async def test_local_file_operations(local_client: LocalSandboxClient, temp_dir: Path):
    """Tests file operations in local sandbox."""
    await local_client.create()

    # Test write and read operations
    test_content = "Hello, World!"
    await local_client.write_file("/workspace/test.txt", test_content)
    content = await local_client.read_file("/workspace/test.txt")
    assert content.strip() == test_content

    # Test copying file to container
    src_file = temp_dir / "src.txt"
    src_file.write_text("Copy to container")
    await local_client.copy_to(str(src_file), "/workspace/copied.txt")
    content = await local_client.read_file("/workspace/copied.txt")
    assert content.strip() == "Copy to container"

    # Test copying file from container
    dst_file = temp_dir / "dst.txt"
    await local_client.copy_from("/workspace/test.txt", str(dst_file))
    assert dst_file.read_text().strip() == test_content


@pytest.mark.asyncio
async def test_local_volume_binding(local_client: LocalSandboxClient, temp_dir: Path):
    """Tests volume binding in local sandbox."""
    bind_path = str(temp_dir)
    volume_bindings = {bind_path: "/data"}

    await local_client.create(volume_bindings=volume_bindings)

    test_file = temp_dir / "test.txt"
    test_file.write_text("Volume test")

    content = await local_client.read_file("/data/test.txt")
    assert "Volume test" in content


@pytest.mark.asyncio
async def test_local_error_handling(local_client: LocalSandboxClient):
    """Tests error handling in local sandbox."""
    await local_client.create()

    with pytest.raises(Exception) as exc:
        await local_client.read_file("/nonexistent.txt")
    assert "not found" in str(exc.value).lower()

    with pytest.raises(Exception) as exc:
        await local_client.copy_from("/nonexistent.txt", "local.txt")
    assert "not found" in str(exc.value).lower()


if __name__ == "__main__":
    pytest.main(["-v", __file__])
104
tests/sandbox/test_docker_terminal.py
Normal file
@@ -0,0 +1,104 @@
"""Tests for the AsyncDockerizedTerminal implementation."""

import docker
import pytest
import pytest_asyncio

from app.sandbox.core.terminal import AsyncDockerizedTerminal


@pytest.fixture(scope="module")
def docker_client():
    """Fixture providing a Docker client."""
    return docker.from_env()


@pytest_asyncio.fixture(scope="module")
async def docker_container(docker_client):
    """Fixture providing a test Docker container."""
    container = docker_client.containers.run(
        "python:3.10-slim",
        "tail -f /dev/null",
        name="test_container",
        detach=True,
        remove=True,
    )
    yield container
    container.stop()


@pytest_asyncio.fixture
async def terminal(docker_container):
    """Fixture providing an initialized AsyncDockerizedTerminal instance."""
    terminal = AsyncDockerizedTerminal(
        docker_container,
        working_dir="/workspace",
        env_vars={"TEST_VAR": "test_value"},
        default_timeout=30,
    )
    await terminal.init()
    yield terminal
    await terminal.close()


class TestAsyncDockerizedTerminal:
    """Test cases for AsyncDockerizedTerminal."""

    @pytest.mark.asyncio
    async def test_basic_command_execution(self, terminal):
        """Test basic command execution functionality."""
        result = await terminal.run_command("echo 'Hello World'")
        assert "Hello World" in result

    @pytest.mark.asyncio
    async def test_environment_variables(self, terminal):
        """Test environment variable setting and access."""
        result = await terminal.run_command("echo $TEST_VAR")
        assert "test_value" in result

    @pytest.mark.asyncio
    async def test_working_directory(self, terminal):
        """Test working directory setup."""
        result = await terminal.run_command("pwd")
        assert "/workspace" == result

    @pytest.mark.asyncio
    async def test_command_timeout(self, docker_container):
        """Test command timeout functionality."""
        terminal = AsyncDockerizedTerminal(docker_container, default_timeout=1)
        await terminal.init()
        try:
            with pytest.raises(TimeoutError):
                await terminal.run_command("sleep 5")
        finally:
            await terminal.close()

    @pytest.mark.asyncio
    async def test_multiple_commands(self, terminal):
        """Test execution of multiple commands in sequence."""
        cmd1 = await terminal.run_command("echo 'First'")
        cmd2 = await terminal.run_command("echo 'Second'")
        assert "First" in cmd1
        assert "Second" in cmd2

    @pytest.mark.asyncio
    async def test_session_cleanup(self, docker_container):
        """Test proper cleanup of resources."""
        terminal = AsyncDockerizedTerminal(docker_container)
        await terminal.init()
        assert terminal.session is not None
        await terminal.close()
        # Verify session is properly cleaned up
        # Note: session object still exists, but internal connection is closed
        assert terminal.session is not None


# Configure pytest-asyncio
def pytest_configure(config):
    """Configure pytest-asyncio."""
    config.addinivalue_line("asyncio_mode", "strict")
    config.addinivalue_line("asyncio_default_fixture_loop_scope", "function")


if __name__ == "__main__":
    pytest.main(["-v", __file__])
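The same terminal can also be driven outside of pytest; a minimal sketch mirroring the fixtures above (illustrative only, not part of the diff):

# Illustrative sketch, not part of the diff: drive AsyncDockerizedTerminal directly.
import asyncio

import docker

from app.sandbox.core.terminal import AsyncDockerizedTerminal


async def main() -> None:
    client = docker.from_env()
    container = client.containers.run(
        "python:3.10-slim", "tail -f /dev/null", detach=True, remove=True
    )
    terminal = AsyncDockerizedTerminal(container, working_dir="/workspace")
    await terminal.init()
    try:
        print(await terminal.run_command("python3 --version"))
    finally:
        await terminal.close()
        container.stop()


asyncio.run(main())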
152
tests/sandbox/test_sandbox.py
Normal file
@@ -0,0 +1,152 @@
import pytest
import pytest_asyncio

from app.sandbox.core.sandbox import DockerSandbox, SandboxConfig


@pytest.fixture(scope="module")
def sandbox_config():
    """Creates sandbox configuration for testing."""
    return SandboxConfig(
        image="python:3.10-slim",
        work_dir="/workspace",
        memory_limit="1g",
        cpu_limit=0.5,
        network_enabled=True,
    )


@pytest_asyncio.fixture(scope="module")
async def sandbox(sandbox_config):
    """Creates and manages a test sandbox instance."""
    sandbox = DockerSandbox(sandbox_config)
    await sandbox.create()
    try:
        yield sandbox
    finally:
        await sandbox.cleanup()


@pytest.mark.asyncio
async def test_sandbox_working_directory(sandbox):
    """Tests sandbox working directory configuration."""
    result = await sandbox.terminal.run_command("pwd")
    assert result.strip() == "/workspace"


@pytest.mark.asyncio
async def test_sandbox_file_operations(sandbox):
    """Tests sandbox file read/write operations."""
    # Test file writing
    test_content = "Hello from sandbox!"
    await sandbox.write_file("/workspace/test.txt", test_content)

    # Test file reading
    content = await sandbox.read_file("/workspace/test.txt")
    assert content.strip() == test_content


@pytest.mark.asyncio
async def test_sandbox_python_execution(sandbox):
    """Tests Python code execution in sandbox."""
    # Write test file
    await sandbox.write_file("/workspace/test.txt", "Hello from file!")

    # Write Python script
    python_code = """
print("Hello from Python!")
with open('/workspace/test.txt') as f:
    print(f.read())
"""
    await sandbox.write_file("/workspace/test.py", python_code)

    # Execute script and verify output
    result = await sandbox.terminal.run_command("python3 /workspace/test.py")
    assert "Hello from Python!" in result
    assert "Hello from file!" in result


@pytest.mark.asyncio
async def test_sandbox_file_persistence(sandbox):
    """Tests file persistence in sandbox."""
    # Create multiple files
    files = {
        "file1.txt": "Content 1",
        "file2.txt": "Content 2",
        "nested/file3.txt": "Content 3",
    }

    # Write files
    for path, content in files.items():
        await sandbox.write_file(f"/workspace/{path}", content)

    # Verify file contents
    for path, expected_content in files.items():
        content = await sandbox.read_file(f"/workspace/{path}")
        assert content.strip() == expected_content


@pytest.mark.asyncio
async def test_sandbox_python_environment(sandbox):
    """Tests Python environment configuration."""
    # Test Python version
    result = await sandbox.terminal.run_command("python3 --version")
    assert "Python 3.10" in result

    # Test basic module imports
    python_code = """
import sys
import os
import json
print("Python is working!")
"""
    await sandbox.write_file("/workspace/env_test.py", python_code)
    result = await sandbox.terminal.run_command("python3 /workspace/env_test.py")
    assert "Python is working!" in result


@pytest.mark.asyncio
async def test_sandbox_network_access(sandbox):
    """Tests sandbox network access."""
    if not sandbox.config.network_enabled:
        pytest.skip("Network access is disabled")

    # Test network connectivity
    await sandbox.terminal.run_command("apt update && apt install curl -y")
    result = await sandbox.terminal.run_command("curl -I https://www.example.com")
    assert "HTTP/2 200" in result


@pytest.mark.asyncio
async def test_sandbox_cleanup(sandbox_config):
    """Tests sandbox cleanup process."""
    sandbox = DockerSandbox(sandbox_config)
    await sandbox.create()

    # Create test files
    await sandbox.write_file("/workspace/test.txt", "test")
    container_id = sandbox.terminal.container.id
    # Perform cleanup
    await sandbox.cleanup()

    # Verify container has been removed
    import docker

    client = docker.from_env()
    containers = client.containers.list(all=True)
    assert not any(c.id == container_id for c in containers)


@pytest.mark.asyncio
async def test_sandbox_error_handling():
    """Tests error handling with invalid configuration."""
    # Test invalid configuration
    invalid_config = SandboxConfig(image="nonexistent:latest", work_dir="/invalid")

    sandbox = DockerSandbox(invalid_config)
    with pytest.raises(Exception):
        await sandbox.create()


if __name__ == "__main__":
    pytest.main(["-v", __file__])
138
tests/sandbox/test_sandbox_manager.py
Normal file
@@ -0,0 +1,138 @@
import asyncio
import os
import tempfile
from typing import AsyncGenerator

import pytest
import pytest_asyncio

from app.sandbox.core.manager import SandboxManager


@pytest_asyncio.fixture(scope="function")
async def manager() -> AsyncGenerator[SandboxManager, None]:
    """Creates a sandbox manager instance.

    Uses function scope to ensure each test case has its own manager instance.
    """
    manager = SandboxManager(max_sandboxes=2, idle_timeout=60, cleanup_interval=30)
    try:
        yield manager
    finally:
        # Ensure all resources are cleaned up
        await manager.cleanup()


@pytest.fixture
def temp_file():
    """Creates a temporary test file."""
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
        f.write("test content")
        path = f.name
    try:
        yield path
    finally:
        if os.path.exists(path):
            os.unlink(path)


@pytest.mark.asyncio
async def test_create_sandbox(manager):
    """Tests sandbox creation."""
    # Create default sandbox
    sandbox_id = await manager.create_sandbox()
    assert sandbox_id in manager._sandboxes
    assert sandbox_id in manager._last_used

    # Verify sandbox functionality
    sandbox = await manager.get_sandbox(sandbox_id)
    result = await sandbox.run_command("echo 'test'")
    assert result.strip() == "test"


@pytest.mark.asyncio
async def test_max_sandboxes_limit(manager):
    """Tests maximum sandbox limit enforcement."""
    created_sandboxes = []
    try:
        # Create maximum number of sandboxes
        for _ in range(manager.max_sandboxes):
            sandbox_id = await manager.create_sandbox()
            created_sandboxes.append(sandbox_id)

        # Verify created sandbox count
        assert len(manager._sandboxes) == manager.max_sandboxes

        # Attempting to create additional sandbox should fail
        with pytest.raises(RuntimeError) as exc_info:
            await manager.create_sandbox()

        # Verify error message
        expected_message = (
            f"Maximum number of sandboxes ({manager.max_sandboxes}) reached"
        )
        assert str(exc_info.value) == expected_message

    finally:
        # Clean up all created sandboxes
        for sandbox_id in created_sandboxes:
            try:
                await manager.delete_sandbox(sandbox_id)
            except Exception as e:
                print(f"Failed to cleanup sandbox {sandbox_id}: {e}")


@pytest.mark.asyncio
async def test_get_nonexistent_sandbox(manager):
    """Tests retrieving a non-existent sandbox."""
    with pytest.raises(KeyError, match="Sandbox .* not found"):
        await manager.get_sandbox("nonexistent-id")


@pytest.mark.asyncio
async def test_sandbox_cleanup(manager):
    """Tests sandbox cleanup functionality."""
    sandbox_id = await manager.create_sandbox()
    assert sandbox_id in manager._sandboxes

    await manager.delete_sandbox(sandbox_id)
    assert sandbox_id not in manager._sandboxes
    assert sandbox_id not in manager._last_used


@pytest.mark.asyncio
async def test_idle_sandbox_cleanup(manager):
    """Tests automatic cleanup of idle sandboxes."""
    # Set short idle timeout
    manager.idle_timeout = 0.1

    sandbox_id = await manager.create_sandbox()
    assert sandbox_id in manager._sandboxes

    # Wait longer than idle timeout
    await asyncio.sleep(0.2)

    # Trigger cleanup
    await manager._cleanup_idle_sandboxes()
    assert sandbox_id not in manager._sandboxes


@pytest.mark.asyncio
async def test_manager_cleanup(manager):
    """Tests manager cleanup functionality."""
    # Create multiple sandboxes
    sandbox_ids = []
    for _ in range(2):
        sandbox_id = await manager.create_sandbox()
        sandbox_ids.append(sandbox_id)

    # Clean up all resources
    await manager.cleanup()

    # Verify all sandboxes have been cleaned up
    assert not manager._sandboxes
    assert not manager._last_used


if __name__ == "__main__":
    pytest.main(["-v", __file__])