diff --git a/.gitignore b/.gitignore index 653fd83..bd23c42 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,4 @@ data/ # Workspace workspace/ +.DS_Store diff --git a/app/config.py b/app/config.py index 2275fb4..f1968d2 100644 --- a/app/config.py +++ b/app/config.py @@ -23,8 +23,23 @@ class LLMSettings(BaseModel): temperature: float = Field(1.0, description="Sampling temperature") +class SandboxConfig(BaseModel): + """Configuration for the execution sandbox""" + + use_sandbox: bool = Field(False, description="Whether to use the sandbox") + image: str = Field("python:3.10-slim", description="Base image") + work_dir: str = Field("/workspace", description="Container working directory") + memory_limit: str = Field("512m", description="Memory limit") + cpu_limit: float = Field(1.0, description="CPU limit") + timeout: int = Field(300, description="Default command timeout (seconds)") + network_enabled: bool = Field( + False, description="Whether network access is allowed" + ) + + class AppConfig(BaseModel): llm: Dict[str, LLMSettings] + sandbox: SandboxConfig class Config: @@ -85,7 +100,8 @@ class Config: name: {**default_settings, **override_config} for name, override_config in llm_overrides.items() }, - } + }, + "sandbox": raw_config.get("sandbox", {}), } self._config = AppConfig(**config_dict) @@ -94,5 +110,9 @@ class Config: def llm(self) -> Dict[str, LLMSettings]: return self._config.llm + @property + def sandbox(self) -> SandboxConfig: + return self._config.sandbox + config = Config() diff --git a/app/sandbox/__init__.py b/app/sandbox/__init__.py new file mode 100644 index 0000000..ccf0df6 --- /dev/null +++ b/app/sandbox/__init__.py @@ -0,0 +1,30 @@ +""" +Docker Sandbox Module + +Provides secure containerized execution environment with resource limits +and isolation for running untrusted code. 
+""" +from app.sandbox.client import ( + BaseSandboxClient, + LocalSandboxClient, + create_sandbox_client, +) +from app.sandbox.core.exceptions import ( + SandboxError, + SandboxResourceError, + SandboxTimeoutError, +) +from app.sandbox.core.manager import SandboxManager +from app.sandbox.core.sandbox import DockerSandbox + + +__all__ = [ + "DockerSandbox", + "SandboxManager", + "BaseSandboxClient", + "LocalSandboxClient", + "create_sandbox_client", + "SandboxError", + "SandboxTimeoutError", + "SandboxResourceError", +] diff --git a/app/sandbox/client.py b/app/sandbox/client.py new file mode 100644 index 0000000..e2c412f --- /dev/null +++ b/app/sandbox/client.py @@ -0,0 +1,201 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional, Protocol + +from app.config import SandboxConfig +from app.sandbox.core.sandbox import DockerSandbox + + +class SandboxFileOperations(Protocol): + """Protocol for sandbox file operations.""" + + async def copy_from(self, container_path: str, local_path: str) -> None: + """Copies file from container to local. + + Args: + container_path: File path in container. + local_path: Local destination path. + """ + ... + + async def copy_to(self, local_path: str, container_path: str) -> None: + """Copies file from local to container. + + Args: + local_path: Local source file path. + container_path: Destination path in container. + """ + ... + + async def read_file(self, path: str) -> str: + """Reads file content from container. + + Args: + path: File path in container. + + Returns: + str: File content. + """ + ... + + async def write_file(self, path: str, content: str) -> None: + """Writes content to file in container. + + Args: + path: File path in container. + content: Content to write. + """ + ... 
+
+
+class BaseSandboxClient(ABC):
+    """Base sandbox client interface."""
+
+    @abstractmethod
+    async def create(
+        self,
+        config: Optional[SandboxConfig] = None,
+        volume_bindings: Optional[Dict[str, str]] = None,
+    ) -> None:
+        """Creates sandbox."""
+
+    @abstractmethod
+    async def run_command(self, command: str, timeout: Optional[int] = None) -> str:
+        """Executes command."""
+
+    @abstractmethod
+    async def copy_from(self, container_path: str, local_path: str) -> None:
+        """Copies file from container."""
+
+    @abstractmethod
+    async def copy_to(self, local_path: str, container_path: str) -> None:
+        """Copies file to container."""
+
+    @abstractmethod
+    async def read_file(self, path: str) -> str:
+        """Reads file."""
+
+    @abstractmethod
+    async def write_file(self, path: str, content: str) -> None:
+        """Writes file."""
+
+    @abstractmethod
+    async def cleanup(self) -> None:
+        """Cleans up resources."""
+
+
+class LocalSandboxClient(BaseSandboxClient):
+    """Local sandbox client implementation."""
+
+    def __init__(self):
+        """Initializes local sandbox client."""
+        self.sandbox: Optional[DockerSandbox] = None
+
+    async def create(
+        self,
+        config: Optional[SandboxConfig] = None,
+        volume_bindings: Optional[Dict[str, str]] = None,
+    ) -> None:
+        """Creates a sandbox.
+
+        Args:
+            config: Sandbox configuration.
+            volume_bindings: Volume mappings.
+
+        Raises:
+            RuntimeError: If sandbox creation fails.
+        """
+        self.sandbox = DockerSandbox(config, volume_bindings)
+        await self.sandbox.create()
+
+    async def run_command(self, command: str, timeout: Optional[int] = None) -> str:
+        """Runs command in sandbox.
+
+        Args:
+            command: Command to execute.
+            timeout: Execution timeout in seconds.
+
+        Returns:
+            Command output.
+
+        Raises:
+            RuntimeError: If sandbox not initialized.
+        """
+        if not self.sandbox:
+            raise RuntimeError("Sandbox not initialized")
+        return await self.sandbox.run_command(command, timeout)
+
+    async def copy_from(self, container_path: str, local_path: str) -> None:
+        """Copies file from container to local.
+
+        Args:
+            container_path: File path in container.
+            local_path: Local destination path.
+
+        Raises:
+            RuntimeError: If sandbox not initialized.
+        """
+        if not self.sandbox:
+            raise RuntimeError("Sandbox not initialized")
+        await self.sandbox.copy_from(container_path, local_path)
+
+    async def copy_to(self, local_path: str, container_path: str) -> None:
+        """Copies file from local to container.
+
+        Args:
+            local_path: Local source file path.
+            container_path: Destination path in container.
+
+        Raises:
+            RuntimeError: If sandbox not initialized.
+        """
+        if not self.sandbox:
+            raise RuntimeError("Sandbox not initialized")
+        await self.sandbox.copy_to(local_path, container_path)
+
+    async def read_file(self, path: str) -> str:
+        """Reads file from container.
+
+        Args:
+            path: File path in container.
+
+        Returns:
+            File content.
+
+        Raises:
+            RuntimeError: If sandbox not initialized.
+        """
+        if not self.sandbox:
+            raise RuntimeError("Sandbox not initialized")
+        return await self.sandbox.read_file(path)
+
+    async def write_file(self, path: str, content: str) -> None:
+        """Writes file to container.
+
+        Args:
+            path: File path in container.
+            content: File content.
+
+        Raises:
+            RuntimeError: If sandbox not initialized.
+        """
+        if not self.sandbox:
+            raise RuntimeError("Sandbox not initialized")
+        await self.sandbox.write_file(path, content)
+
+    async def cleanup(self) -> None:
+        """Cleans up resources."""
+        if self.sandbox:
+            await self.sandbox.cleanup()
+            self.sandbox = None
+
+
+async def create_sandbox_client() -> LocalSandboxClient:
+    """Creates a sandbox client.
+
+    Returns:
+        LocalSandboxClient: Sandbox client instance.
+    """
+    return LocalSandboxClient()
+
+
+SANDBOX_CLIENT = LocalSandboxClient()  # create_sandbox_client() is async and would bind a coroutine here
diff --git a/app/sandbox/core/exceptions.py b/app/sandbox/core/exceptions.py
new file mode 100644
index 0000000..5c1f0e8
--- /dev/null
+++ b/app/sandbox/core/exceptions.py
@@ -0,0 +1,17 @@
+"""Exception classes for the sandbox system.
+
+This module defines custom exceptions used throughout the sandbox system to
+handle various error conditions in a structured way.
+"""
+
+
+class SandboxError(Exception):
+    """Base exception for sandbox-related errors."""
+
+
+class SandboxTimeoutError(SandboxError):
+    """Exception raised when a sandbox operation times out."""
+
+
+class SandboxResourceError(SandboxError):
+    """Exception raised for resource-related errors."""
diff --git a/app/sandbox/core/manager.py b/app/sandbox/core/manager.py
new file mode 100644
index 0000000..6269d72
--- /dev/null
+++ b/app/sandbox/core/manager.py
@@ -0,0 +1,342 @@
+import asyncio
+import uuid
+from contextlib import asynccontextmanager
+from typing import Dict, Optional, Set
+
+import docker
+from docker.errors import APIError, ImageNotFound
+
+from app.config import SandboxConfig
+from app.logger import logger
+from app.sandbox.core.sandbox import DockerSandbox
+
+
+class SandboxManager:
+    """Docker sandbox manager.
+
+    Manages multiple DockerSandbox instances lifecycle including creation,
+    monitoring, and cleanup. Provides concurrent access control and automatic
+    cleanup mechanisms for sandbox resources.
+
+    Attributes:
+        max_sandboxes: Maximum allowed number of sandboxes.
+        idle_timeout: Sandbox idle timeout in seconds.
+        cleanup_interval: Cleanup check interval in seconds.
+        _sandboxes: Active sandbox instance mapping.
+        _last_used: Last used time record for sandboxes.
+    """
+
+    def __init__(
+        self,
+        max_sandboxes: int = 100,
+        idle_timeout: int = 3600,
+        cleanup_interval: int = 300,
+    ):
+        """Initializes sandbox manager.
+
+        Args:
+            max_sandboxes: Maximum sandbox count limit.
+            idle_timeout: Idle timeout in seconds.
+            cleanup_interval: Cleanup check interval in seconds.
+        """
+        self.max_sandboxes = max_sandboxes
+        self.idle_timeout = idle_timeout
+        self.cleanup_interval = cleanup_interval
+
+        # Docker client
+        self._client = docker.from_env()
+
+        # Resource mappings
+        self._sandboxes: Dict[str, DockerSandbox] = {}
+        self._last_used: Dict[str, float] = {}
+
+        # Concurrency control
+        self._locks: Dict[str, asyncio.Lock] = {}
+        self._global_lock = asyncio.Lock()
+        self._active_operations: Set[str] = set()
+
+        # Cleanup task
+        self._cleanup_task: Optional[asyncio.Task] = None
+        self._is_shutting_down = False
+
+        # Start automatic cleanup
+        self.start_cleanup_task()
+
+    async def ensure_image(self, image: str) -> bool:
+        """Ensures Docker image is available.
+
+        Args:
+            image: Image name.
+
+        Returns:
+            bool: Whether image is available.
+        """
+        try:
+            self._client.images.get(image)
+            return True
+        except ImageNotFound:
+            try:
+                logger.info(f"Pulling image {image}...")
+                await asyncio.get_event_loop().run_in_executor(
+                    None, self._client.images.pull, image
+                )
+                return True
+            except (APIError, Exception) as e:
+                logger.error(f"Failed to pull image {image}: {e}")
+                return False
+
+    @asynccontextmanager
+    async def sandbox_operation(self, sandbox_id: str):
+        """Context manager for sandbox operations.
+
+        Provides concurrency control and usage time updates.
+
+        Args:
+            sandbox_id: Sandbox ID.
+
+        Raises:
+            KeyError: If sandbox not found.
+        """
+        # Reject unknown IDs first so that looking up a non-existent sandbox
+        # does not leave an orphaned entry in _locks.
+        if sandbox_id not in self._sandboxes:
+            raise KeyError(f"Sandbox {sandbox_id} not found")
+
+        if sandbox_id not in self._locks:
+            self._locks[sandbox_id] = asyncio.Lock()
+
+        async with self._locks[sandbox_id]:
+            # Re-check: the sandbox may have been deleted while we waited.
+            if sandbox_id not in self._sandboxes:
+                raise KeyError(f"Sandbox {sandbox_id} not found")
+
+            self._active_operations.add(sandbox_id)
+            try:
+                self._last_used[sandbox_id] = asyncio.get_event_loop().time()
+                yield self._sandboxes[sandbox_id]
+            finally:
+                # discard(): cleanup() may already have cleared the set.
+                self._active_operations.discard(sandbox_id)
+
+    async def create_sandbox(
+        self,
+        config: Optional[SandboxConfig] = None,
+        volume_bindings: Optional[Dict[str, str]] = None,
+    ) -> str:
+        """Creates a new sandbox instance.
+
+        Args:
+            config: Sandbox configuration.
+            volume_bindings: Volume mapping configuration.
+
+        Returns:
+            str: Sandbox ID.
+
+        Raises:
+            RuntimeError: If max sandbox count reached or creation fails.
+        """
+        # Covers managers that were constructed before a loop was running.
+        self.start_cleanup_task()
+
+        async with self._global_lock:
+            if len(self._sandboxes) >= self.max_sandboxes:
+                raise RuntimeError(
+                    f"Maximum number of sandboxes ({self.max_sandboxes}) reached"
+                )
+
+            config = config or SandboxConfig()
+            if not await self.ensure_image(config.image):
+                raise RuntimeError(f"Failed to ensure Docker image: {config.image}")
+
+            sandbox_id = str(uuid.uuid4())
+            try:
+                sandbox = DockerSandbox(config, volume_bindings)
+                await sandbox.create()
+
+                self._sandboxes[sandbox_id] = sandbox
+                self._last_used[sandbox_id] = asyncio.get_event_loop().time()
+                self._locks[sandbox_id] = asyncio.Lock()
+
+                logger.info(f"Created sandbox {sandbox_id}")
+                return sandbox_id
+
+            except Exception as e:
+                logger.error(f"Failed to create sandbox: {e}")
+                # _global_lock is held here and asyncio.Lock is not reentrant,
+                # so roll back inline instead of calling delete_sandbox(),
+                # which would wait on the same lock forever.
+                registered = self._sandboxes.pop(sandbox_id, None)
+                self._last_used.pop(sandbox_id, None)
+                self._locks.pop(sandbox_id, None)
+                if registered is not None:
+                    await registered.cleanup()
+                raise RuntimeError(f"Failed to create sandbox: {e}") from e
+
+    async def get_sandbox(self, sandbox_id: str) -> DockerSandbox:
+        """Gets a sandbox instance.
+
+        Args:
+            sandbox_id: Sandbox ID.
+
+        Returns:
+            DockerSandbox: Sandbox instance.
+
+        Raises:
+            KeyError: If sandbox does not exist.
+        """
+        async with self.sandbox_operation(sandbox_id) as sandbox:
+            return sandbox
+
+    def start_cleanup_task(self) -> None:
+        """Starts automatic cleanup task.
+
+        Does nothing if the task is already running. If no event loop is
+        running yet (the manager was constructed from synchronous code), the
+        start is deferred and retried by create_sandbox().
+        """
+        if self._cleanup_task is not None and not self._cleanup_task.done():
+            return
+
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            # Creating a task needs a running loop; try again later.
+            return
+
+        async def cleanup_loop():
+            while not self._is_shutting_down:
+                try:
+                    await self._cleanup_idle_sandboxes()
+                except Exception as e:
+                    logger.error(f"Error in cleanup loop: {e}")
+                await asyncio.sleep(self.cleanup_interval)
+
+        self._cleanup_task = loop.create_task(cleanup_loop())
+
+    async def _cleanup_idle_sandboxes(self) -> None:
+        """Cleans up idle sandboxes."""
+        current_time = asyncio.get_event_loop().time()
+        to_cleanup = []
+
+        async with self._global_lock:
+            for sandbox_id, last_used in self._last_used.items():
+                if (
+                    sandbox_id not in self._active_operations
+                    and current_time - last_used > self.idle_timeout
+                ):
+                    to_cleanup.append(sandbox_id)
+
+        for sandbox_id in to_cleanup:
+            try:
+                await self.delete_sandbox(sandbox_id)
+            except Exception as e:
+                logger.error(f"Error cleaning up sandbox {sandbox_id}: {e}")
+
+    async def cleanup(self) -> None:
+        """Cleans up all resources."""
+        logger.info("Starting manager cleanup...")
+        self._is_shutting_down = True
+
+        # Cancel cleanup task
+        if self._cleanup_task:
+            self._cleanup_task.cancel()
+            try:
+                await asyncio.wait_for(self._cleanup_task, timeout=1.0)
+            except (asyncio.CancelledError, asyncio.TimeoutError):
+                pass
+
+        # Get all sandbox IDs to clean up
+        async with self._global_lock:
+            sandbox_ids = list(self._sandboxes.keys())
+
+        # Concurrently clean up all sandboxes
+        cleanup_tasks = []
+        for sandbox_id in sandbox_ids:
+            task = asyncio.create_task(self._safe_delete_sandbox(sandbox_id))
+            cleanup_tasks.append(task)
+
+        if cleanup_tasks:
+            # asyncio.wait() does not raise on timeout; it returns the tasks
+            # that are still pending, so that set has to be checked explicitly.
+            _, pending = await asyncio.wait(cleanup_tasks, timeout=30.0)
+            if pending:
+                logger.error("Sandbox cleanup timed out")
+
+        # Clean up remaining references
+        self._sandboxes.clear()
+        self._last_used.clear()
+        self._locks.clear()
+        self._active_operations.clear()
+
+        logger.info("Manager cleanup completed")
+
+    async def _safe_delete_sandbox(self, sandbox_id: str) -> None:
+        """Safely deletes a single sandbox.
+
+        Args:
+            sandbox_id: Sandbox ID to delete.
+        """
+        try:
+            if sandbox_id in self._active_operations:
+                logger.warning(
+                    f"Sandbox {sandbox_id} has active operations, waiting for completion"
+                )
+                for _ in range(10):  # Wait at most 10 times
+                    await asyncio.sleep(0.5)
+                    if sandbox_id not in self._active_operations:
+                        break
+                else:
+                    logger.warning(
+                        f"Timeout waiting for sandbox {sandbox_id} operations to complete"
+                    )
+
+            # Get reference to sandbox object
+            sandbox = self._sandboxes.get(sandbox_id)
+            if sandbox:
+                await sandbox.cleanup()
+
+            # Remove sandbox record from manager
+            async with self._global_lock:
+                self._sandboxes.pop(sandbox_id, None)
+                self._last_used.pop(sandbox_id, None)
+                self._locks.pop(sandbox_id, None)
+                logger.info(f"Deleted sandbox {sandbox_id}")
+        except Exception as e:
+            logger.error(f"Error during cleanup of sandbox {sandbox_id}: {e}")
+
+    async def delete_sandbox(self, sandbox_id: str) -> None:
+        """Deletes specified sandbox.
+
+        Args:
+            sandbox_id: Sandbox ID.
+        """
+        if sandbox_id not in self._sandboxes:
+            return
+
+        try:
+            await self._safe_delete_sandbox(sandbox_id)
+        except Exception as e:
+            logger.error(f"Failed to delete sandbox {sandbox_id}: {e}")
+
+    async def __aenter__(self) -> "SandboxManager":
+        """Async context manager entry."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+        """Async context manager exit."""
+        await self.cleanup()
+
+    def get_stats(self) -> Dict:
+        """Gets manager statistics.
+
+        Returns:
+            Dict: Statistics information.
+        """
+        return {
+            "total_sandboxes": len(self._sandboxes),
+            "active_operations": len(self._active_operations),
+            "max_sandboxes": self.max_sandboxes,
+            "idle_timeout": self.idle_timeout,
+            "cleanup_interval": self.cleanup_interval,
+            "is_shutting_down": self._is_shutting_down,
+        }
diff --git a/app/sandbox/core/sandbox.py b/app/sandbox/core/sandbox.py
new file mode 100644
index 0000000..400b12a
--- /dev/null
+++ b/app/sandbox/core/sandbox.py
@@ -0,0 +1,470 @@
+import asyncio
+import io
+import os
+import shlex
+import tarfile
+import tempfile
+import uuid
+from typing import Dict, Optional
+
+import docker
+from docker.errors import NotFound
+from docker.models.containers import Container
+
+from app.config import SandboxConfig
+from app.sandbox.core.exceptions import SandboxTimeoutError
+from app.sandbox.core.terminal import AsyncDockerizedTerminal
+
+
+class DockerSandbox:
+    """Docker sandbox environment.
+
+    Provides a containerized execution environment with resource limits,
+    file operations, and command execution capabilities.
+
+    Attributes:
+        config: Sandbox configuration.
+        volume_bindings: Volume mapping configuration.
+        client: Docker client.
+        container: Docker container instance.
+        terminal: Container terminal interface.
+    """
+
+    def __init__(
+        self,
+        config: Optional[SandboxConfig] = None,
+        volume_bindings: Optional[Dict[str, str]] = None,
+    ):
+        """Initializes a sandbox instance.
+
+        Args:
+            config: Sandbox configuration. Default configuration used if None.
+            volume_bindings: Volume mappings in {host_path: container_path} format.
+        """
+        self.config = config or SandboxConfig()
+        self.volume_bindings = volume_bindings or {}
+        self.client = docker.from_env()
+        self.container: Optional[Container] = None
+        self.terminal: Optional[AsyncDockerizedTerminal] = None
+
+    async def create(self) -> "DockerSandbox":
+        """Creates and starts the sandbox container.
+
+        Returns:
+            Current sandbox instance.
+
+        Raises:
+            docker.errors.APIError: If Docker API call fails.
+            RuntimeError: If container creation or startup fails.
+        """
+        try:
+            # Prepare container config
+            host_config = self.client.api.create_host_config(
+                mem_limit=self.config.memory_limit,
+                cpu_period=100000,
+                cpu_quota=int(100000 * self.config.cpu_limit),
+                network_mode="none" if not self.config.network_enabled else "bridge",
+                binds=self._prepare_volume_bindings(),
+            )
+
+            # Generate unique container name with sandbox_ prefix
+            container_name = f"sandbox_{uuid.uuid4().hex[:8]}"
+
+            # Create container
+            container = await asyncio.to_thread(
+                self.client.api.create_container,
+                image=self.config.image,
+                command="tail -f /dev/null",
+                hostname="sandbox",
+                working_dir=self.config.work_dir,
+                host_config=host_config,
+                name=container_name,
+                tty=True,
+                detach=True,
+            )
+
+            self.container = self.client.containers.get(container["Id"])
+
+            # Start container
+            await asyncio.to_thread(self.container.start)
+
+            # Initialize terminal
+            self.terminal = AsyncDockerizedTerminal(
+                container["Id"],
+                self.config.work_dir,
+                env_vars={"PYTHONUNBUFFERED": "1"}
+                # Ensure Python output is not buffered
+            )
+            await self.terminal.init()
+
+            return self
+
+        except Exception as e:
+            await self.cleanup()  # Ensure resources are cleaned up
+            raise RuntimeError(f"Failed to create sandbox: {e}") from e
+
+    def _prepare_volume_bindings(self) -> Dict[str, Dict[str, str]]:
+        """Prepares volume binding configuration.
+
+        Returns:
+            Volume binding configuration dictionary.
+        """
+        bindings = {}
+
+        # Create and add working directory mapping
+        work_dir = self._ensure_host_dir(self.config.work_dir)
+        bindings[work_dir] = {"bind": self.config.work_dir, "mode": "rw"}
+
+        # Add custom volume bindings
+        for host_path, container_path in self.volume_bindings.items():
+            bindings[host_path] = {"bind": container_path, "mode": "rw"}
+
+        return bindings
+
+    @staticmethod
+    def _ensure_host_dir(path: str) -> str:
+        """Ensures directory exists on the host.
+
+        Args:
+            path: Directory path.
+
+        Returns:
+            Actual path on the host.
+        """
+        host_path = os.path.join(
+            tempfile.gettempdir(),
+            f"sandbox_{os.path.basename(path)}_{os.urandom(4).hex()}",
+        )
+        os.makedirs(host_path, exist_ok=True)
+        return host_path
+
+    async def run_command(self, cmd: str, timeout: Optional[int] = None) -> str:
+        """Runs a command in the sandbox.
+
+        Args:
+            cmd: Command to execute.
+            timeout: Timeout in seconds.
+
+        Returns:
+            Command output as string.
+
+        Raises:
+            RuntimeError: If sandbox not initialized or command execution fails.
+            TimeoutError: If command execution times out.
+        """
+        if not self.terminal:
+            raise RuntimeError("Sandbox not initialized")
+
+        try:
+            return await self.terminal.run_command(
+                cmd, timeout=timeout or self.config.timeout
+            )
+        except TimeoutError:
+            raise SandboxTimeoutError(
+                f"Command execution timed out after {timeout or self.config.timeout} seconds"
+            )
+
+    async def read_file(self, path: str) -> str:
+        """Reads a file from the container.
+
+        Args:
+            path: File path.
+
+        Returns:
+            File contents as string.
+
+        Raises:
+            FileNotFoundError: If file does not exist.
+            RuntimeError: If read operation fails.
+        """
+        if not self.container:
+            raise RuntimeError("Sandbox not initialized")
+
+        try:
+            # Get file archive
+            resolved_path = self._safe_resolve_path(path)
+            tar_stream, _ = await asyncio.to_thread(
+                self.container.get_archive, resolved_path
+            )
+
+            # Read file content from tar stream
+            content = await self._read_from_tar(tar_stream)
+            return content.decode("utf-8")
+
+        except NotFound:
+            raise FileNotFoundError(f"File not found: {path}")
+        except Exception as e:
+            raise RuntimeError(f"Failed to read file: {e}") from e
+
+    async def write_file(self, path: str, content: str) -> None:
+        """Writes content to a file in the container.
+
+        Args:
+            path: Target path.
+            content: File content.
+
+        Raises:
+            RuntimeError: If write operation fails.
+        """
+        if not self.container:
+            raise RuntimeError("Sandbox not initialized")
+
+        try:
+            resolved_path = self._safe_resolve_path(path)
+            parent_dir = os.path.dirname(resolved_path)
+
+            # Create parent directory
+            if parent_dir:
+                await self.run_command(f"mkdir -p {shlex.quote(parent_dir)}")
+
+            # Prepare file data
+            tar_stream = await self._create_tar_stream(
+                os.path.basename(path), content.encode("utf-8")
+            )
+
+            # Write file
+            await asyncio.to_thread(
+                self.container.put_archive, parent_dir or "/", tar_stream
+            )
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to write file: {e}") from e
+
+    def _safe_resolve_path(self, path: str) -> str:
+        """Safely resolves container path, preventing path traversal.
+
+        Args:
+            path: Original path.
+
+        Returns:
+            Resolved absolute path.
+
+        Raises:
+            ValueError: If path contains potentially unsafe patterns.
+        """
+        # Check for path traversal attempts
+        if ".." in path.split("/"):
+            raise ValueError("Path contains potentially unsafe patterns")
+
+        resolved = (
+            os.path.join(self.config.work_dir, path)
+            if not os.path.isabs(path)
+            else path
+        )
+        return resolved
+
+    async def copy_from(self, src_path: str, dst_path: str) -> None:
+        """Copies a file from the container.
+
+        Args:
+            src_path: Source file path (container).
+            dst_path: Destination path (host).
+
+        Raises:
+            FileNotFoundError: If source file does not exist.
+            RuntimeError: If copy operation fails.
+        """
+        if not self.container:
+            raise RuntimeError("Sandbox not initialized")
+
+        try:
+            # Ensure destination file's parent directory exists
+            parent_dir = os.path.dirname(dst_path)
+            if parent_dir:
+                os.makedirs(parent_dir, exist_ok=True)
+
+            # Get file stream
+            resolved_src = self._safe_resolve_path(src_path)
+            stream, stat = await asyncio.to_thread(
+                self.container.get_archive, resolved_src
+            )
+
+            # Create temporary directory to extract file
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                # Write stream to temporary file
+                tar_path = os.path.join(tmp_dir, "temp.tar")
+                with open(tar_path, "wb") as f:
+                    for chunk in stream:
+                        f.write(chunk)
+
+                # Extract file
+                with tarfile.open(tar_path) as tar:
+                    members = tar.getmembers()
+                    if not members:
+                        raise FileNotFoundError(f"Source file is empty: {src_path}")
+
+                    # If destination is a directory, we should preserve relative path structure
+                    if os.path.isdir(dst_path):
+                        # NOTE(review): this archive comes from a container that
+                        # may run untrusted code; consider extractall(...,
+                        # filter="data") on Python versions that support it.
+                        tar.extractall(dst_path)
+                    else:
+                        # If destination is a file, we only extract the source file's content
+                        if len(members) > 1:
+                            raise RuntimeError(
+                                f"Source path is a directory but destination is a file: {src_path}"
+                            )
+
+                        with open(dst_path, "wb") as dst:
+                            src_file = tar.extractfile(members[0])
+                            if src_file is None:
+                                raise RuntimeError(
+                                    f"Failed to extract file: {src_path}"
+                                )
+                            dst.write(src_file.read())
+
+        except docker.errors.NotFound:
+            raise FileNotFoundError(f"Source file not found: {src_path}")
+        except Exception as e:
+            raise RuntimeError(f"Failed to copy file: {e}") from e
+
+    async def copy_to(self, src_path: str, dst_path: str) -> None:
+        """Copies a file to the container.
+
+        Args:
+            src_path: Source file path (host).
+            dst_path: Destination path (container).
+
+        Raises:
+            FileNotFoundError: If source file does not exist.
+            RuntimeError: If copy operation fails.
+        """
+        try:
+            if not os.path.exists(src_path):
+                raise FileNotFoundError(f"Source file not found: {src_path}")
+
+            # Create destination directory in container
+            resolved_dst = self._safe_resolve_path(dst_path)
+            container_dir = os.path.dirname(resolved_dst)
+            if container_dir:
+                await self.run_command(f"mkdir -p {shlex.quote(container_dir)}")
+
+            # Create tar file to upload
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                tar_path = os.path.join(tmp_dir, "temp.tar")
+                with tarfile.open(tar_path, "w") as tar:
+                    # Handle directory source path
+                    if os.path.isdir(src_path):
+                        for root, _, files in os.walk(src_path):
+                            for file in files:
+                                file_path = os.path.join(root, file)
+                                arcname = os.path.join(
+                                    os.path.basename(dst_path),
+                                    os.path.relpath(file_path, src_path),
+                                )
+                                tar.add(file_path, arcname=arcname)
+                    else:
+                        # Add single file to tar
+                        tar.add(src_path, arcname=os.path.basename(dst_path))
+
+                # Read tar file content
+                with open(tar_path, "rb") as f:
+                    data = f.read()
+
+                # Upload to container
+                await asyncio.to_thread(
+                    self.container.put_archive,
+                    os.path.dirname(resolved_dst) or "/",
+                    data,
+                )
+
+                # Verify file was created successfully
+                # NOTE(review): DockerSession.execute() drops the `echo $?`
+                # status line, so a failing `test -e` may not raise - confirm.
+                try:
+                    await self.run_command(f"test -e {shlex.quote(resolved_dst)}")
+                except Exception:
+                    raise RuntimeError(f"Failed to verify file creation: {dst_path}")
+
+        except FileNotFoundError:
+            raise
+        except Exception as e:
+            raise RuntimeError(f"Failed to copy file: {e}") from e
+
+    @staticmethod
+    async def _create_tar_stream(name: str, content: bytes) -> io.BytesIO:
+        """Creates a tar file stream.
+
+        Args:
+            name: Filename.
+            content: File content.
+
+        Returns:
+            Tar file stream.
+        """
+        tar_stream = io.BytesIO()
+        with tarfile.open(fileobj=tar_stream, mode="w") as tar:
+            tarinfo = tarfile.TarInfo(name=name)
+            tarinfo.size = len(content)
+            tar.addfile(tarinfo, io.BytesIO(content))
+        tar_stream.seek(0)
+        return tar_stream
+
+    @staticmethod
+    async def _read_from_tar(tar_stream) -> bytes:
+        """Reads file content from a tar stream.
+
+        Args:
+            tar_stream: Tar file stream.
+
+        Returns:
+            File content.
+
+        Raises:
+            RuntimeError: If read operation fails.
+        """
+        with tempfile.NamedTemporaryFile() as tmp:
+            for chunk in tar_stream:
+                tmp.write(chunk)
+            tmp.seek(0)
+
+            with tarfile.open(fileobj=tmp) as tar:
+                member = tar.next()
+                if not member:
+                    raise RuntimeError("Empty tar archive")
+
+                file_content = tar.extractfile(member)
+                if not file_content:
+                    raise RuntimeError("Failed to extract file content")
+
+                return file_content.read()
+
+    async def cleanup(self) -> None:
+        """Cleans up sandbox resources."""
+        errors = []
+        try:
+            if self.terminal:
+                try:
+                    await self.terminal.close()
+                except Exception as e:
+                    errors.append(f"Terminal cleanup error: {e}")
+                finally:
+                    self.terminal = None
+
+            if self.container:
+                try:
+                    await asyncio.to_thread(self.container.stop, timeout=5)
+                except Exception as e:
+                    errors.append(f"Container stop error: {e}")
+
+                try:
+                    await asyncio.to_thread(self.container.remove, force=True)
+                except Exception as e:
+                    errors.append(f"Container remove error: {e}")
+                finally:
+                    self.container = None
+
+        except Exception as e:
+            errors.append(f"General cleanup error: {e}")
+
+        if errors:
+            print(f"Warning: Errors during cleanup: {', '.join(errors)}")
+
+    async def __aenter__(self) -> "DockerSandbox":
+        """Async context manager entry."""
+        return await self.create()
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+        """Async context manager exit."""
+        await self.cleanup()
diff --git a/app/sandbox/core/terminal.py b/app/sandbox/core/terminal.py
new file mode 100644
index 0000000..aee5184
--- /dev/null
+++ b/app/sandbox/core/terminal.py
@@ -0,0 +1,356 @@
+"""
+Asynchronous Docker Terminal
+
+This module provides asynchronous terminal functionality for Docker containers,
+allowing interactive command execution with timeout control.
+"""
+
+import asyncio
+import re
+import socket
+from typing import Dict, Optional, Tuple, Union
+
+import docker
+from docker import APIClient
+from docker.errors import APIError
+from docker.models.containers import Container
+
+
+class DockerSession:
+    def __init__(self, container_id: str) -> None:
+        """Initializes a Docker session.
+
+        Args:
+            container_id: ID of the Docker container.
+        """
+        # Use the same environment-derived connection settings as the
+        # docker.from_env() clients that create the container.
+        self.api: APIClient = docker.from_env().api
+        self.container_id = container_id
+        self.exec_id = None
+        self.socket = None
+
+    async def create(self, working_dir: str, env_vars: Dict[str, str]) -> None:
+        """Creates an interactive session with the container.
+
+        Args:
+            working_dir: Working directory inside the container.
+            env_vars: Environment variables to set.
+
+        Raises:
+            RuntimeError: If socket connection fails.
+        """
+        startup_command = [
+            "bash",
+            "-c",
+            f"cd {working_dir} && "
+            "PROMPT_COMMAND='' "
+            "PS1='$ ' "
+            "exec bash --norc --noprofile",
+        ]
+
+        exec_data = self.api.exec_create(
+            self.container_id,
+            startup_command,
+            stdin=True,
+            tty=True,
+            stdout=True,
+            stderr=True,
+            # A sandbox for untrusted code must not hand the shell extended
+            # privileges on the host, so the exec is explicitly unprivileged.
+            privileged=False,
+            user="root",
+            environment={**env_vars, "TERM": "dumb", "PS1": "$ ", "PROMPT_COMMAND": ""},
+        )
+        self.exec_id = exec_data["Id"]
+
+        socket_data = self.api.exec_start(
+            self.exec_id, socket=True, tty=True, stream=True, demux=True
+        )
+
+        if hasattr(socket_data, "_sock"):
+            self.socket = socket_data._sock
+            self.socket.setblocking(False)
+        else:
+            raise RuntimeError("Failed to get socket connection")
+
+        await self._read_until_prompt()
+
+    async def close(self) -> None:
+        """Cleans up session resources.
+
+        1. Sends exit command
+        2. Closes socket connection
+        3. Checks and cleans up exec instance
+        """
+        try:
+            if self.socket:
+                # Send exit command to close bash session
+                try:
+                    self.socket.sendall(b"exit\n")
+                    # Allow time for command execution
+                    await asyncio.sleep(0.1)
+                except Exception:
+                    pass  # Ignore sending errors, continue cleanup
+
+                # Close socket connection
+                try:
+                    self.socket.shutdown(socket.SHUT_RDWR)
+                except Exception:
+                    pass  # Some platforms may not support shutdown
+
+                self.socket.close()
+                self.socket = None
+
+            if self.exec_id:
+                try:
+                    # Check exec instance status
+                    exec_inspect = self.api.exec_inspect(self.exec_id)
+                    if exec_inspect.get("Running", False):
+                        # If still running, wait for it to complete
+                        await asyncio.sleep(0.5)
+                except Exception:
+                    pass  # Ignore inspection errors, continue cleanup
+
+                self.exec_id = None
+
+        except Exception as e:
+            # Log error but don't raise, ensure cleanup continues
+            print(f"Warning: Error during session cleanup: {e}")
+
+    async def _read_until_prompt(self) -> str:
+        """Reads output until prompt is found.
+
+        Returns:
+            String containing output up to the prompt.
+
+        Raises:
+            socket.error: If socket communication fails or the stream closes.
+        """
+        buffer = b""
+        while b"$ " not in buffer:
+            try:
+                chunk = self.socket.recv(4096)
+            except socket.error as e:
+                if e.errno == socket.EWOULDBLOCK:
+                    await asyncio.sleep(0.1)
+                    continue
+                raise
+            if not chunk:
+                # b"" means the stream was closed; looping on would spin
+                # forever without ever yielding to the event loop.
+                raise ConnectionError("Docker exec socket closed before prompt")
+            buffer += chunk
+        return buffer.decode("utf-8")
+
+    async def execute(self, command: str, timeout: Optional[int] = None) -> str:
+        """Executes a command and returns cleaned output.
+
+        Args:
+            command: Shell command to execute.
+            timeout: Maximum execution time in seconds.
+
+        Returns:
+            Command output as string with prompt markers removed.
+
+        Raises:
+            RuntimeError: If session not initialized or execution fails.
+            TimeoutError: If command execution exceeds timeout.
+        """
+        if not self.socket:
+            raise RuntimeError("Session not initialized")
+
+        try:
+            # Sanitize command to prevent shell injection
+            sanitized_command = self._sanitize_command(command)
+            full_command = f"{sanitized_command}\necho $?\n"
+            self.socket.sendall(full_command.encode())
+
+            async def read_output() -> str:
+                buffer = b""
+                result_lines = []
+                command_sent = False
+
+                while True:
+                    try:
+                        chunk = self.socket.recv(4096)
+                        if not chunk:
+                            break
+
+                        buffer += chunk
+                        lines = buffer.split(b"\n")
+
+                        buffer = lines[-1]
+                        lines = lines[:-1]
+
+                        for line in lines:
+                            line = line.rstrip(b"\r")
+
+                            if not command_sent:
+                                command_sent = True
+                                continue
+
+                            if line.strip() == b"echo $?" or line.strip().isdigit():
+                                continue
+
+                            if line.strip():
+                                result_lines.append(line)
+
+                        if buffer.endswith(b"$ "):
+                            break
+
+                    except socket.error as e:
+                        if e.errno == socket.EWOULDBLOCK:
+                            await asyncio.sleep(0.1)
+                            continue
+                        raise
+
+                output = b"\n".join(result_lines).decode("utf-8")
+                output = re.sub(r"\n\$ echo \$\$?.*$", "", output)
+
+                return output
+
+            if timeout:
+                result = await asyncio.wait_for(read_output(), timeout)
+            else:
+                result = await read_output()
+
+            return result.strip()
+
+        except asyncio.TimeoutError:
+            raise TimeoutError(f"Command execution timed out after {timeout} seconds")
+        except Exception as e:
+            raise RuntimeError(f"Failed to execute command: {e}") from e
+
+    def _sanitize_command(self, command: str) -> str:
+        """Rejects commands that contain known destructive patterns.
+
+        Args:
+            command: Raw command string.
+
+        Returns:
+            The command string, unchanged.
+
+        Raises:
+            ValueError: If command contains potentially dangerous patterns.
+        """
+
+        # Additional checks for specific risky commands
+        risky_commands = [
+            "rm -rf /",
+            "rm -rf /*",
+            "mkfs",
+            "dd if=/dev/zero",
+            ":(){:|:&};:",
+            "chmod -R 777 /",
+            "chown -R",
+        ]
+
+        # Both sides are lower-cased: the patterns themselves contain upper-case
+        # flags ("-R"), which could never match an already lower-cased command.
+        lowered = command.lower()
+        for risky in risky_commands:
+            if risky.lower() in lowered:
+                raise ValueError(
+                    f"Command contains potentially dangerous operation: {risky}"
+                )
+
+        return command
+
+
+class AsyncDockerizedTerminal:
+    def __init__(
+        self,
+        container: Union[str, Container],
+        working_dir: str = "/workspace",
+        env_vars: Optional[Dict[str, str]] = None,
+        default_timeout: int = 60,
+    ) -> None:
+        """Initializes an asynchronous terminal for Docker containers.
+
+        Args:
+            container: Docker container ID or Container object.
+            working_dir: Working directory inside the container.
+            env_vars: Environment variables to set.
+            default_timeout: Default command execution timeout in seconds.
+        """
+        self.client = docker.from_env()
+        self.container = (
+            container
+            if isinstance(container, Container)
+            else self.client.containers.get(container)
+        )
+        self.working_dir = working_dir
+        self.env_vars = env_vars or {}
+        self.default_timeout = default_timeout
+        self.session = None
+
+    async def init(self) -> None:
+        """Initializes the terminal environment.
+
+        Ensures working directory exists and creates an interactive session.
+
+        Raises:
+            RuntimeError: If initialization fails.
+        """
+        await self._ensure_workdir()
+
+        self.session = DockerSession(self.container.id)
+        await self.session.create(self.working_dir, self.env_vars)
+
+    async def _ensure_workdir(self) -> None:
+        """Ensures working directory exists in container.
+
+        Raises:
+            RuntimeError: If directory creation fails.
+        """
+        try:
+            await self._exec_simple(f"mkdir -p {self.working_dir}")
+        except APIError as e:
+            raise RuntimeError(f"Failed to create working directory: {e}")
+
+    async def _exec_simple(self, cmd: str) -> Tuple[int, str]:
+        """Executes a simple command using Docker's exec_run.
+
+        Args:
+            cmd: Command to execute.
+ + Returns: + Tuple of (exit_code, output). + """ + result = await asyncio.to_thread( + self.container.exec_run, cmd, environment=self.env_vars + ) + return result.exit_code, result.output.decode("utf-8") + + async def run_command(self, cmd: str, timeout: Optional[int] = None) -> str: + """Runs a command in the container with timeout. + + Args: + cmd: Shell command to execute. + timeout: Maximum execution time in seconds. + + Returns: + Command output as string. + + Raises: + RuntimeError: If terminal not initialized. + """ + if not self.session: + raise RuntimeError("Terminal not initialized") + + return await self.session.execute(cmd, timeout=timeout or self.default_timeout) + + async def close(self) -> None: + """Closes the terminal session.""" + if self.session: + await self.session.close() + + async def __aenter__(self) -> "AsyncDockerizedTerminal": + """Async context manager entry.""" + await self.init() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Async context manager exit.""" + await self.close() diff --git a/app/tool/file_operators.py b/app/tool/file_operators.py new file mode 100644 index 0000000..5f38d30 --- /dev/null +++ b/app/tool/file_operators.py @@ -0,0 +1,156 @@ +"""File operation interfaces and implementations for local and sandbox environments.""" + +import asyncio +from pathlib import Path +from typing import Optional, Protocol, Tuple, Union, runtime_checkable + +from app.config import SandboxConfig +from app.exceptions import ToolError +from app.sandbox.client import SANDBOX_CLIENT + + +PathLike = Union[str, Path] + + +@runtime_checkable +class FileOperator(Protocol): + """Interface for file operations in different environments.""" + + async def read_file(self, path: PathLike) -> str: + """Read content from a file.""" + ... + + async def write_file(self, path: PathLike, content: str) -> None: + """Write content to a file.""" + ... 
+ + async def is_directory(self, path: PathLike) -> bool: + """Check if path points to a directory.""" + ... + + async def exists(self, path: PathLike) -> bool: + """Check if path exists.""" + ... + + async def run_command( + self, cmd: str, timeout: Optional[float] = 120.0 + ) -> Tuple[int, str, str]: + """Run a shell command and return (return_code, stdout, stderr).""" + ... + + +class LocalFileOperator(FileOperator): + """File operations implementation for local filesystem.""" + + async def read_file(self, path: PathLike) -> str: + """Read content from a local file.""" + try: + return Path(path).read_text() + except Exception as e: + raise ToolError(f"Failed to read {path}: {str(e)}") from None + + async def write_file(self, path: PathLike, content: str) -> None: + """Write content to a local file.""" + try: + Path(path).write_text(content) + except Exception as e: + raise ToolError(f"Failed to write to {path}: {str(e)}") from None + + async def is_directory(self, path: PathLike) -> bool: + """Check if path points to a directory.""" + return Path(path).is_dir() + + async def exists(self, path: PathLike) -> bool: + """Check if path exists.""" + return Path(path).exists() + + async def run_command( + self, cmd: str, timeout: Optional[float] = 120.0 + ) -> Tuple[int, str, str]: + """Run a shell command locally.""" + process = await asyncio.create_subprocess_shell( + cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), timeout=timeout + ) + return ( + process.returncode or 0, + stdout.decode(), + stderr.decode(), + ) + except asyncio.TimeoutError as exc: + try: + process.kill() + except ProcessLookupError: + pass + raise TimeoutError( + f"Command '{cmd}' timed out after {timeout} seconds" + ) from exc + + +class SandboxFileOperator(FileOperator): + """File operations implementation for sandbox environment.""" + + def __init__(self): + self.sandbox_client = SANDBOX_CLIENT 
+ + async def _ensure_sandbox_initialized(self): + """Ensure sandbox is initialized.""" + if not self.sandbox_client.sandbox: + await self.sandbox_client.create(config=SandboxConfig()) + + async def read_file(self, path: PathLike) -> str: + """Read content from a file in sandbox.""" + await self._ensure_sandbox_initialized() + try: + return await self.sandbox_client.read_file(str(path)) + except Exception as e: + raise ToolError(f"Failed to read {path} in sandbox: {str(e)}") from None + + async def write_file(self, path: PathLike, content: str) -> None: + """Write content to a file in sandbox.""" + await self._ensure_sandbox_initialized() + try: + await self.sandbox_client.write_file(str(path), content) + except Exception as e: + raise ToolError(f"Failed to write to {path} in sandbox: {str(e)}") from None + + async def is_directory(self, path: PathLike) -> bool: + """Check if path points to a directory in sandbox.""" + await self._ensure_sandbox_initialized() + result = await self.sandbox_client.run_command( + f"test -d {path} && echo 'true' || echo 'false'" + ) + return result.strip() == "true" + + async def exists(self, path: PathLike) -> bool: + """Check if path exists in sandbox.""" + await self._ensure_sandbox_initialized() + result = await self.sandbox_client.run_command( + f"test -e {path} && echo 'true' || echo 'false'" + ) + return result.strip() == "true" + + async def run_command( + self, cmd: str, timeout: Optional[float] = 120.0 + ) -> Tuple[int, str, str]: + """Run a command in sandbox environment.""" + await self._ensure_sandbox_initialized() + try: + stdout = await self.sandbox_client.run_command( + cmd, timeout=int(timeout) if timeout else None + ) + return ( + 0, # Always return 0 since we don't have explicit return code from sandbox + stdout, + "", # No stderr capture in the current sandbox implementation + ) + except TimeoutError as exc: + raise TimeoutError( + f"Command '{cmd}' timed out after {timeout} seconds in sandbox" + ) from exc + except 
Exception as exc: + return 1, "", f"Error executing command in sandbox: {str(exc)}" diff --git a/app/tool/run.py b/app/tool/run.py deleted file mode 100644 index 8896c58..0000000 --- a/app/tool/run.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Utility to run shell commands asynchronously with a timeout.""" - -import asyncio - - -TRUNCATED_MESSAGE: str = "To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for." -MAX_RESPONSE_LEN: int = 16000 - - -def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN): - """Truncate content and append a notice if content exceeds the specified length.""" - return ( - content - if not truncate_after or len(content) <= truncate_after - else content[:truncate_after] + TRUNCATED_MESSAGE - ) - - -async def run( - cmd: str, - timeout: float | None = 120.0, # seconds - truncate_after: int | None = MAX_RESPONSE_LEN, -): - """Run a shell command asynchronously with a timeout.""" - process = await asyncio.create_subprocess_shell( - cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) - - try: - stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout) - return ( - process.returncode or 0, - maybe_truncate(stdout.decode(), truncate_after=truncate_after), - maybe_truncate(stderr.decode(), truncate_after=truncate_after), - ) - except asyncio.TimeoutError as exc: - try: - process.kill() - except ProcessLookupError: - pass - raise TimeoutError( - f"Command '{cmd}' timed out after {timeout} seconds" - ) from exc diff --git a/app/tool/str_replace_editor.py b/app/tool/str_replace_editor.py index 4094565..47e3d38 100644 --- a/app/tool/str_replace_editor.py +++ b/app/tool/str_replace_editor.py @@ -1,11 +1,18 @@ -from collections import defaultdict -from pathlib import Path -from typing import Literal, get_args +"""File and directory manipulation 
tool with sandbox support.""" +from collections import defaultdict +from typing import Any, DefaultDict, List, Literal, Optional, get_args + +from app.config import config from app.exceptions import ToolError from app.tool import BaseTool from app.tool.base import CLIResult, ToolResult -from app.tool.run import run +from app.tool.file_operators import ( + FileOperator, + LocalFileOperator, + PathLike, + SandboxFileOperator, +) Command = Literal[ @@ -15,12 +22,17 @@ Command = Literal[ "insert", "undo_edit", ] + +# Constants SNIPPET_LINES: int = 4 - MAX_RESPONSE_LEN: int = 16000 +TRUNCATED_MESSAGE: str = ( + "To save on context only part of this file has been shown to you. " + "You should retry this tool after you have searched inside the file with `grep -n` " + "in order to find the line numbers of what you are looking for." +) -TRUNCATED_MESSAGE: str = "To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for." - +# Tool description _STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files * State is persistent across command calls and discussions with the user * If `path` is a file, `view` displays the result of applying `cat -n`. 
If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep @@ -35,17 +47,17 @@ Notes for using the `str_replace` command: """ -def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN): +def maybe_truncate( + content: str, truncate_after: Optional[int] = MAX_RESPONSE_LEN +) -> str: """Truncate content and append a notice if content exceeds the specified length.""" - return ( - content - if not truncate_after or len(content) <= truncate_after - else content[:truncate_after] + TRUNCATED_MESSAGE - ) + if not truncate_after or len(content) <= truncate_after: + return content + return content[:truncate_after] + TRUNCATED_MESSAGE class StrReplaceEditor(BaseTool): - """A tool for executing bash commands""" + """A tool for viewing, creating, and editing files with sandbox support.""" name: str = "str_replace_editor" description: str = _STR_REPLACE_EDITOR_DESCRIPTION @@ -85,8 +97,19 @@ class StrReplaceEditor(BaseTool): }, "required": ["command", "path"], } + _file_history: DefaultDict[PathLike, List[str]] = defaultdict(list) + _local_operator: LocalFileOperator = LocalFileOperator() + # todo: Sandbox resources need to be destroyed at the appropriate time. 
+ _sandbox_operator: SandboxFileOperator = SandboxFileOperator() - _file_history: list = defaultdict(list) + # def _get_operator(self, use_sandbox: bool) -> FileOperator: + def _get_operator(self) -> FileOperator: + """Get the appropriate file operator based on execution mode.""" + return ( + self._sandbox_operator + if config.sandbox.use_sandbox + else self._local_operator + ) async def execute( self, @@ -98,24 +121,30 @@ class StrReplaceEditor(BaseTool): old_str: str | None = None, new_str: str | None = None, insert_line: int | None = None, - **kwargs, + **kwargs: Any, ) -> str: - _path = Path(path) - self.validate_path(command, _path) + """Execute a file operation command.""" + # Get the appropriate file operator + operator = self._get_operator() + + # Validate path and command combination + await self.validate_path(command, path, operator) + + # Execute the appropriate command if command == "view": - result = await self.view(_path, view_range) + result = await self.view(path, view_range, operator) elif command == "create": if file_text is None: raise ToolError("Parameter `file_text` is required for command: create") - self.write_file(_path, file_text) - self._file_history[_path].append(file_text) - result = ToolResult(output=f"File created successfully at: {_path}") + await operator.write_file(path, file_text) + self._file_history[path].append(file_text) + result = ToolResult(output=f"File created successfully at: {path}") elif command == "str_replace": if old_str is None: raise ToolError( "Parameter `old_str` is required for command: str_replace" ) - result = self.str_replace(_path, old_str, new_str) + result = await self.str_replace(path, old_str, new_str, operator) elif command == "insert": if insert_line is None: raise ToolError( @@ -123,92 +152,149 @@ class StrReplaceEditor(BaseTool): ) if new_str is None: raise ToolError("Parameter `new_str` is required for command: insert") - result = self.insert(_path, insert_line, new_str) + result = await 
self.insert(path, insert_line, new_str, operator) elif command == "undo_edit": - result = self.undo_edit(_path) + result = await self.undo_edit(path, operator) else: + # This should be caught by type checking, but we include it for safety raise ToolError( f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}' ) + return str(result) - def validate_path(self, command: str, path: Path): - """ - Check that the path/command combination is valid. - """ - # Check if its an absolute path - if not path.is_absolute(): - suggested_path = Path("") / path + async def validate_path( + self, command: str, path: str, operator: FileOperator + ) -> None: + """Validate path and command combination based on execution environment.""" + # Check if path is absolute + if not path.startswith("/"): + suggested_path = f"/{path}" raise ToolError( - f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?" + f"The path {path} is not an absolute path, it should start with `/`. " + f"Maybe you meant {suggested_path}?" ) - # Check if path exists - if not path.exists() and command != "create": - raise ToolError( - f"The path {path} does not exist. Please provide a valid path." - ) - if path.exists() and command == "create": - raise ToolError( - f"File already exists at: {path}. Cannot overwrite files using command `create`." - ) - # Check if the path points to a directory - if path.is_dir(): - if command != "view": + + # Only check if path exists for non-create commands + if command != "create": + if not await operator.exists(path): + raise ToolError( + f"The path {path} does not exist. Please provide a valid path." 
+ ) + + # Check if path is a directory + is_dir = await operator.is_directory(path) + if is_dir and command != "view": raise ToolError( f"The path {path} is a directory and only the `view` command can be used on directories" ) - async def view(self, path: Path, view_range: list[int] | None = None): - """Implement the view command""" - if path.is_dir(): + # Check if file exists for create command + elif command == "create": + exists = await operator.exists(path) + if exists: + raise ToolError( + f"File already exists at: {path}. Cannot overwrite files using command `create`." + ) + + async def view( + self, + path: PathLike, + view_range: Optional[List[int]] = None, + operator: FileOperator = None, + ) -> CLIResult: + """Display file or directory content.""" + # Determine if path is a directory + is_dir = await operator.is_directory(path) + + if is_dir: + # Directory handling if view_range: raise ToolError( "The `view_range` parameter is not allowed when `path` points to a directory." ) - _, stdout, stderr = await run( - rf"find {path} -maxdepth 2 -not -path '*/\.*'" - ) - if not stderr: - stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n" - return CLIResult(output=stdout, error=stderr) + return await self._view_directory(path, operator) + else: + # File handling + return await self._view_file(path, operator, view_range) - file_content = self.read_file(path) + @staticmethod + async def _view_directory(path: PathLike, operator: FileOperator) -> CLIResult: + """Display directory contents.""" + find_cmd = f"find {path} -maxdepth 2 -not -path '*/\\.*'" + + # Execute command using the operator + returncode, stdout, stderr = await operator.run_command(find_cmd) + + if not stderr: + stdout = ( + f"Here's the files and directories up to 2 levels deep in {path}, " + f"excluding hidden items:\n{stdout}\n" + ) + + return CLIResult(output=stdout, error=stderr) + + async def _view_file( + self, + path: PathLike, + 
operator: FileOperator, + view_range: Optional[List[int]] = None, + ) -> CLIResult: + """Display file content, optionally within a specified line range.""" + # Read file content + file_content = await operator.read_file(path) init_line = 1 + + # Apply view range if specified if view_range: if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range): raise ToolError( "Invalid `view_range`. It should be a list of two integers." ) + file_lines = file_content.split("\n") n_lines_file = len(file_lines) init_line, final_line = view_range + + # Validate view range if init_line < 1 or init_line > n_lines_file: raise ToolError( - f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}" + f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be " + f"within the range of lines of the file: {[1, n_lines_file]}" ) if final_line > n_lines_file: raise ToolError( - f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`" + f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be " + f"smaller than the number of lines in the file: `{n_lines_file}`" ) if final_line != -1 and final_line < init_line: raise ToolError( - f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`" + f"Invalid `view_range`: {view_range}. 
Its second element `{final_line}` should be " + f"larger or equal than its first `{init_line}`" ) + # Apply range if final_line == -1: file_content = "\n".join(file_lines[init_line - 1 :]) else: file_content = "\n".join(file_lines[init_line - 1 : final_line]) + # Format and return result return CLIResult( output=self._make_output(file_content, str(path), init_line=init_line) ) - def str_replace(self, path: Path, old_str: str, new_str: str | None): - """Implement the str_replace command, which replaces old_str with new_str in the file content""" - # Read the file content - file_content = self.read_file(path).expandtabs() + async def str_replace( + self, + path: PathLike, + old_str: str, + new_str: Optional[str] = None, + operator: FileOperator = None, + ) -> CLIResult: + """Replace a unique string in a file with a new string.""" + # Read file content and expand tabs + file_content = (await operator.read_file(path)).expandtabs() old_str = old_str.expandtabs() new_str = new_str.expandtabs() if new_str is not None else "" @@ -219,6 +305,7 @@ class StrReplaceEditor(BaseTool): f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}." ) elif occurrences > 1: + # Find line numbers of occurrences file_content_lines = file_content.split("\n") lines = [ idx + 1 @@ -226,16 +313,17 @@ class StrReplaceEditor(BaseTool): if old_str in line ] raise ToolError( - f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique" + f"No replacement was performed. Multiple occurrences of old_str `{old_str}` " + f"in lines {lines}. 
Please ensure it is unique" ) # Replace old_str with new_str new_file_content = file_content.replace(old_str, new_str) # Write the new content to the file - self.write_file(path, new_file_content) + await operator.write_file(path, new_file_content) - # Save the content to history + # Save the original content to history self._file_history[path].append(file_content) # Create a snippet of the edited section @@ -253,36 +341,50 @@ class StrReplaceEditor(BaseTool): return CLIResult(output=success_msg) - def insert(self, path: Path, insert_line: int, new_str: str): - """Implement the insert command, which inserts new_str at the specified line in the file content.""" - file_text = self.read_file(path).expandtabs() + async def insert( + self, + path: PathLike, + insert_line: int, + new_str: str, + operator: FileOperator = None, + ) -> CLIResult: + """Insert text at a specific line in a file.""" + # Read and prepare content + file_text = (await operator.read_file(path)).expandtabs() new_str = new_str.expandtabs() file_text_lines = file_text.split("\n") n_lines_file = len(file_text_lines) + # Validate insert_line if insert_line < 0 or insert_line > n_lines_file: raise ToolError( - f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}" + f"Invalid `insert_line` parameter: {insert_line}. 
It should be within " + f"the range of lines of the file: {[0, n_lines_file]}" ) + # Perform insertion new_str_lines = new_str.split("\n") new_file_text_lines = ( file_text_lines[:insert_line] + new_str_lines + file_text_lines[insert_line:] ) + + # Create a snippet for preview snippet_lines = ( file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line] + new_str_lines + file_text_lines[insert_line : insert_line + SNIPPET_LINES] ) + # Join lines and write to file new_file_text = "\n".join(new_file_text_lines) snippet = "\n".join(snippet_lines) - self.write_file(path, new_file_text) + await operator.write_file(path, new_file_text) self._file_history[path].append(file_text) + # Prepare success message success_msg = f"The file {path} has been edited. " success_msg += self._make_output( snippet, @@ -290,51 +392,43 @@ class StrReplaceEditor(BaseTool): max(1, insert_line - SNIPPET_LINES + 1), ) success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary." + return CLIResult(output=success_msg) - def undo_edit(self, path: Path): - """Implement the undo_edit command.""" + async def undo_edit( + self, path: PathLike, operator: FileOperator = None + ) -> CLIResult: + """Revert the last edit made to a file.""" if not self._file_history[path]: raise ToolError(f"No edit history found for {path}.") old_text = self._file_history[path].pop() - self.write_file(path, old_text) + await operator.write_file(path, old_text) return CLIResult( output=f"Last edit to {path} undone successfully. 
{self._make_output(old_text, str(path))}" ) - def read_file(self, path: Path): - """Read the content of a file from a given path; raise a ToolError if an error occurs.""" - try: - return path.read_text() - except Exception as e: - raise ToolError(f"Ran into {e} while trying to read {path}") from None - - def write_file(self, path: Path, file: str): - """Write the content of a file to a given path; raise a ToolError if an error occurs.""" - try: - path.write_text(file) - except Exception as e: - raise ToolError(f"Ran into {e} while trying to write to {path}") from None - def _make_output( self, file_content: str, file_descriptor: str, init_line: int = 1, expand_tabs: bool = True, - ): - """Generate output for the CLI based on the content of a file.""" + ) -> str: + """Format file content for display with line numbers.""" file_content = maybe_truncate(file_content) if expand_tabs: file_content = file_content.expandtabs() + + # Add line numbers to each line file_content = "\n".join( [ f"{i + init_line:6}\t{line}" for i, line in enumerate(file_content.split("\n")) ] ) + return ( f"Here's the result of running `cat -n` on {file_descriptor}:\n" + file_content diff --git a/config/config.example.toml b/config/config.example.toml index de71832..68d6a62 100644 --- a/config/config.example.toml +++ b/config/config.example.toml @@ -11,3 +11,13 @@ temperature = 0.0 model = "claude-3-5-sonnet" base_url = "https://api.openai.com/v1" api_key = "sk-..." 
+ +# Sandbox configuration +[sandbox] +use_sandbox = false +image = "python:3.10-slim" +work_dir = "/workspace" +memory_limit = "1g" # 512m +cpu_limit = 2.0 +timeout = 300 +network_enabled = false diff --git a/requirements.txt b/requirements.txt index a4128d2..eae7eff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,7 @@ aiofiles~=24.1.0 pydantic_core~=2.27.2 colorama~=0.4.6 playwright~=1.49.1 + +docker~=7.1.0 +pytest~=8.3.5 +pytest-asyncio~=0.25.3 diff --git a/tests/sandbox/test_client.py b/tests/sandbox/test_client.py new file mode 100644 index 0000000..5b05fed --- /dev/null +++ b/tests/sandbox/test_client.py @@ -0,0 +1,110 @@ +import tempfile +from pathlib import Path +from typing import AsyncGenerator + +import pytest +import pytest_asyncio + +from app.config import SandboxConfig +from app.sandbox.client import LocalSandboxClient, create_sandbox_client + + +@pytest_asyncio.fixture(scope="function") +async def local_client() -> AsyncGenerator[LocalSandboxClient, None]: + """Creates a local sandbox client for testing.""" + client = await create_sandbox_client() + try: + yield client + finally: + await client.cleanup() + + +@pytest.fixture(scope="function") +def temp_dir() -> Path: + """Creates a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as tmp_dir: + yield Path(tmp_dir) + + +@pytest.mark.asyncio +async def test_sandbox_creation(local_client: LocalSandboxClient): + """Tests sandbox creation with specific configuration.""" + config = SandboxConfig( + image="python:3.10-slim", + work_dir="/workspace", + memory_limit="512m", + cpu_limit=0.5, + ) + + await local_client.create(config) + result = await local_client.run_command("python3 --version") + assert "Python 3.10" in result + + +@pytest.mark.asyncio +async def test_local_command_execution(local_client: LocalSandboxClient): + """Tests command execution in local sandbox.""" + await local_client.create() + + result = await local_client.run_command("echo 'test'") + 
assert result.strip() == "test" + + with pytest.raises(Exception): + await local_client.run_command("sleep 10", timeout=1) + + +@pytest.mark.asyncio +async def test_local_file_operations(local_client: LocalSandboxClient, temp_dir: Path): + """Tests file operations in local sandbox.""" + await local_client.create() + + # Test write and read operations + test_content = "Hello, World!" + await local_client.write_file("/workspace/test.txt", test_content) + content = await local_client.read_file("/workspace/test.txt") + assert content.strip() == test_content + + # Test copying file to container + src_file = temp_dir / "src.txt" + src_file.write_text("Copy to container") + await local_client.copy_to(str(src_file), "/workspace/copied.txt") + content = await local_client.read_file("/workspace/copied.txt") + assert content.strip() == "Copy to container" + + # Test copying file from container + dst_file = temp_dir / "dst.txt" + await local_client.copy_from("/workspace/test.txt", str(dst_file)) + assert dst_file.read_text().strip() == test_content + + +@pytest.mark.asyncio +async def test_local_volume_binding(local_client: LocalSandboxClient, temp_dir: Path): + """Tests volume binding in local sandbox.""" + bind_path = str(temp_dir) + volume_bindings = {bind_path: "/data"} + + await local_client.create(volume_bindings=volume_bindings) + + test_file = temp_dir / "test.txt" + test_file.write_text("Volume test") + + content = await local_client.read_file("/data/test.txt") + assert "Volume test" in content + + +@pytest.mark.asyncio +async def test_local_error_handling(local_client: LocalSandboxClient): + """Tests error handling in local sandbox.""" + await local_client.create() + + with pytest.raises(Exception) as exc: + await local_client.read_file("/nonexistent.txt") + assert "not found" in str(exc.value).lower() + + with pytest.raises(Exception) as exc: + await local_client.copy_from("/nonexistent.txt", "local.txt") + assert "not found" in str(exc.value).lower() + + +if 
__name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/sandbox/test_docker_terminal.py b/tests/sandbox/test_docker_terminal.py new file mode 100644 index 0000000..7903d95 --- /dev/null +++ b/tests/sandbox/test_docker_terminal.py @@ -0,0 +1,104 @@ +"""Tests for the AsyncDockerizedTerminal implementation.""" + +import docker +import pytest +import pytest_asyncio + +from app.sandbox.core.terminal import AsyncDockerizedTerminal + + +@pytest.fixture(scope="module") +def docker_client(): + """Fixture providing a Docker client.""" + return docker.from_env() + + +@pytest_asyncio.fixture(scope="module") +async def docker_container(docker_client): + """Fixture providing a test Docker container.""" + container = docker_client.containers.run( + "python:3.10-slim", + "tail -f /dev/null", + name="test_container", + detach=True, + remove=True, + ) + yield container + container.stop() + + +@pytest_asyncio.fixture +async def terminal(docker_container): + """Fixture providing an initialized AsyncDockerizedTerminal instance.""" + terminal = AsyncDockerizedTerminal( + docker_container, + working_dir="/workspace", + env_vars={"TEST_VAR": "test_value"}, + default_timeout=30, + ) + await terminal.init() + yield terminal + await terminal.close() + + +class TestAsyncDockerizedTerminal: + """Test cases for AsyncDockerizedTerminal.""" + + @pytest.mark.asyncio + async def test_basic_command_execution(self, terminal): + """Test basic command execution functionality.""" + result = await terminal.run_command("echo 'Hello World'") + assert "Hello World" in result + + @pytest.mark.asyncio + async def test_environment_variables(self, terminal): + """Test environment variable setting and access.""" + result = await terminal.run_command("echo $TEST_VAR") + assert "test_value" in result + + @pytest.mark.asyncio + async def test_working_directory(self, terminal): + """Test working directory setup.""" + result = await terminal.run_command("pwd") + assert "/workspace" == result + + 
@pytest.mark.asyncio + async def test_command_timeout(self, docker_container): + """Test command timeout functionality.""" + terminal = AsyncDockerizedTerminal(docker_container, default_timeout=1) + await terminal.init() + try: + with pytest.raises(TimeoutError): + await terminal.run_command("sleep 5") + finally: + await terminal.close() + + @pytest.mark.asyncio + async def test_multiple_commands(self, terminal): + """Test execution of multiple commands in sequence.""" + cmd1 = await terminal.run_command("echo 'First'") + cmd2 = await terminal.run_command("echo 'Second'") + assert "First" in cmd1 + assert "Second" in cmd2 + + @pytest.mark.asyncio + async def test_session_cleanup(self, docker_container): + """Test proper cleanup of resources.""" + terminal = AsyncDockerizedTerminal(docker_container) + await terminal.init() + assert terminal.session is not None + await terminal.close() + # Verify session is properly cleaned up + # Note: session object still exists, but internal connection is closed + assert terminal.session is not None + + +# Configure pytest-asyncio +def pytest_configure(config): + """Configure pytest-asyncio.""" + config.addinivalue_line("asyncio_mode", "strict") + config.addinivalue_line("asyncio_default_fixture_loop_scope", "function") + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/sandbox/test_sandbox.py b/tests/sandbox/test_sandbox.py new file mode 100644 index 0000000..a690c98 --- /dev/null +++ b/tests/sandbox/test_sandbox.py @@ -0,0 +1,152 @@ +import pytest +import pytest_asyncio + +from app.sandbox.core.sandbox import DockerSandbox, SandboxConfig + + +@pytest.fixture(scope="module") +def sandbox_config(): + """Creates sandbox configuration for testing.""" + return SandboxConfig( + image="python:3.10-slim", + work_dir="/workspace", + memory_limit="1g", + cpu_limit=0.5, + network_enabled=True, + ) + + +@pytest_asyncio.fixture(scope="module") +async def sandbox(sandbox_config): + """Creates and manages a test 
sandbox instance.""" + sandbox = DockerSandbox(sandbox_config) + await sandbox.create() + try: + yield sandbox + finally: + await sandbox.cleanup() + + +@pytest.mark.asyncio +async def test_sandbox_working_directory(sandbox): + """Tests sandbox working directory configuration.""" + result = await sandbox.terminal.run_command("pwd") + assert result.strip() == "/workspace" + + +@pytest.mark.asyncio +async def test_sandbox_file_operations(sandbox): + """Tests sandbox file read/write operations.""" + # Test file writing + test_content = "Hello from sandbox!" + await sandbox.write_file("/workspace/test.txt", test_content) + + # Test file reading + content = await sandbox.read_file("/workspace/test.txt") + assert content.strip() == test_content + + +@pytest.mark.asyncio +async def test_sandbox_python_execution(sandbox): + """Tests Python code execution in sandbox.""" + # Write test file + await sandbox.write_file("/workspace/test.txt", "Hello from file!") + + # Write Python script + python_code = """ +print("Hello from Python!") +with open('/workspace/test.txt') as f: + print(f.read()) +""" + await sandbox.write_file("/workspace/test.py", python_code) + + # Execute script and verify output + result = await sandbox.terminal.run_command("python3 /workspace/test.py") + assert "Hello from Python!" in result + assert "Hello from file!" 
in result + + +@pytest.mark.asyncio +async def test_sandbox_file_persistence(sandbox): + """Tests file persistence in sandbox.""" + # Create multiple files + files = { + "file1.txt": "Content 1", + "file2.txt": "Content 2", + "nested/file3.txt": "Content 3", + } + + # Write files + for path, content in files.items(): + await sandbox.write_file(f"/workspace/{path}", content) + + # Verify file contents + for path, expected_content in files.items(): + content = await sandbox.read_file(f"/workspace/{path}") + assert content.strip() == expected_content + + +@pytest.mark.asyncio +async def test_sandbox_python_environment(sandbox): + """Tests Python environment configuration.""" + # Test Python version + result = await sandbox.terminal.run_command("python3 --version") + assert "Python 3.10" in result + + # Test basic module imports + python_code = """ +import sys +import os +import json +print("Python is working!") +""" + await sandbox.write_file("/workspace/env_test.py", python_code) + result = await sandbox.terminal.run_command("python3 /workspace/env_test.py") + assert "Python is working!" 
in result + + +@pytest.mark.asyncio +async def test_sandbox_network_access(sandbox): + """Tests sandbox network access.""" + if not sandbox.config.network_enabled: + pytest.skip("Network access is disabled") + + # Test network connectivity + await sandbox.terminal.run_command("apt update && apt install curl -y") + result = await sandbox.terminal.run_command("curl -I https://www.example.com") + assert "HTTP/2 200" in result + + +@pytest.mark.asyncio +async def test_sandbox_cleanup(sandbox_config): + """Tests sandbox cleanup process.""" + sandbox = DockerSandbox(sandbox_config) + await sandbox.create() + + # Create test files + await sandbox.write_file("/workspace/test.txt", "test") + container_id = sandbox.terminal.container.id + # Perform cleanup + await sandbox.cleanup() + + # Verify container has been removed + import docker + + client = docker.from_env() + containers = client.containers.list(all=True) + assert not any(c.id == container_id for c in containers) + + +@pytest.mark.asyncio +async def test_sandbox_error_handling(): + """Tests error handling with invalid configuration.""" + # Test invalid configuration + invalid_config = SandboxConfig(image="nonexistent:latest", work_dir="/invalid") + + sandbox = DockerSandbox(invalid_config) + with pytest.raises(Exception): + await sandbox.create() + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/sandbox/test_sandbox_manager.py b/tests/sandbox/test_sandbox_manager.py new file mode 100644 index 0000000..09f498d --- /dev/null +++ b/tests/sandbox/test_sandbox_manager.py @@ -0,0 +1,138 @@ +import asyncio +import os +import tempfile +from typing import AsyncGenerator + +import pytest +import pytest_asyncio + +from app.sandbox.core.manager import SandboxManager + + +@pytest_asyncio.fixture(scope="function") +async def manager() -> AsyncGenerator[SandboxManager, None]: + """Creates a sandbox manager instance. + + Uses function scope to ensure each test case has its own manager instance. 
+ """ + manager = SandboxManager(max_sandboxes=2, idle_timeout=60, cleanup_interval=30) + try: + yield manager + finally: + # Ensure all resources are cleaned up + await manager.cleanup() + + +@pytest.fixture +def temp_file(): + """Creates a temporary test file.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: + f.write("test content") + path = f.name + try: + yield path + finally: + if os.path.exists(path): + os.unlink(path) + + +@pytest.mark.asyncio +async def test_create_sandbox(manager): + """Tests sandbox creation.""" + # Create default sandbox + sandbox_id = await manager.create_sandbox() + assert sandbox_id in manager._sandboxes + assert sandbox_id in manager._last_used + + # Verify sandbox functionality + sandbox = await manager.get_sandbox(sandbox_id) + result = await sandbox.run_command("echo 'test'") + assert result.strip() == "test" + + +@pytest.mark.asyncio +async def test_max_sandboxes_limit(manager): + """Tests maximum sandbox limit enforcement.""" + created_sandboxes = [] + try: + # Create maximum number of sandboxes + for _ in range(manager.max_sandboxes): + sandbox_id = await manager.create_sandbox() + created_sandboxes.append(sandbox_id) + + # Verify created sandbox count + assert len(manager._sandboxes) == manager.max_sandboxes + + # Attempting to create additional sandbox should fail + with pytest.raises(RuntimeError) as exc_info: + await manager.create_sandbox() + + # Verify error message + expected_message = ( + f"Maximum number of sandboxes ({manager.max_sandboxes}) reached" + ) + assert str(exc_info.value) == expected_message + + finally: + # Clean up all created sandboxes + for sandbox_id in created_sandboxes: + try: + await manager.delete_sandbox(sandbox_id) + except Exception as e: + print(f"Failed to cleanup sandbox {sandbox_id}: {e}") + + +@pytest.mark.asyncio +async def test_get_nonexistent_sandbox(manager): + """Tests retrieving a non-existent sandbox.""" + with pytest.raises(KeyError, match="Sandbox .* not 
found"): + await manager.get_sandbox("nonexistent-id") + + +@pytest.mark.asyncio +async def test_sandbox_cleanup(manager): + """Tests sandbox cleanup functionality.""" + sandbox_id = await manager.create_sandbox() + assert sandbox_id in manager._sandboxes + + await manager.delete_sandbox(sandbox_id) + assert sandbox_id not in manager._sandboxes + assert sandbox_id not in manager._last_used + + +@pytest.mark.asyncio +async def test_idle_sandbox_cleanup(manager): + """Tests automatic cleanup of idle sandboxes.""" + # Set short idle timeout + manager.idle_timeout = 0.1 + + sandbox_id = await manager.create_sandbox() + assert sandbox_id in manager._sandboxes + + # Wait longer than idle timeout + await asyncio.sleep(0.2) + + # Trigger cleanup + await manager._cleanup_idle_sandboxes() + assert sandbox_id not in manager._sandboxes + + +@pytest.mark.asyncio +async def test_manager_cleanup(manager): + """Tests manager cleanup functionality.""" + # Create multiple sandboxes + sandbox_ids = [] + for _ in range(2): + sandbox_id = await manager.create_sandbox() + sandbox_ids.append(sandbox_id) + + # Clean up all resources + await manager.cleanup() + + # Verify all sandboxes have been cleaned up + assert not manager._sandboxes + assert not manager._last_used + + +if __name__ == "__main__": + pytest.main(["-v", __file__])