81 lines
2.4 KiB
Python
81 lines
2.4 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, Optional
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
# Abstract pydantic model for a tool: subclasses implement ``execute``, and
# calling an instance awaits ``execute`` with the same keyword arguments.
class BaseTool(ABC, BaseModel):
    # ``name``, ``description`` and ``parameters`` are emitted verbatim by
    # ``to_param``; ``parameters`` stays None when not provided.
    name: str
    description: str
    parameters: Optional[dict] = None

    class Config:
        arbitrary_types_allowed = True

    async def __call__(self, **kwargs) -> Any:
        """Await ``execute`` with the supplied keyword arguments."""
        outcome = await self.execute(**kwargs)
        return outcome

    @abstractmethod
    async def execute(self, **kwargs) -> Any:
        """Run the tool; concrete subclasses must provide this coroutine."""

    def to_param(self) -> Dict:
        """Describe this tool as a function-call style dictionary."""
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        }
        return {"type": "function", "function": function_spec}
|
|
|
|
|
|
class ToolResult(BaseModel):
    """Represents the result of a tool execution."""

    # Every field defaults to None, so a bare ToolResult() is falsy (see
    # __bool__) and renders as an empty string (see __str__).
    output: Any = Field(default=None)
    error: Optional[str] = Field(default=None)
    base64_image: Optional[str] = Field(default=None)
    system: Optional[str] = Field(default=None)

    class Config:
        arbitrary_types_allowed = True

    def __bool__(self) -> bool:
        """Return True when at least one field holds a truthy value."""
        return any(getattr(self, field) for field in self.__fields__)

    def __add__(self, other: "ToolResult") -> "ToolResult":
        """Merge two results field by field into a new ToolResult.

        ``output``, ``error`` and ``system`` are joined with ``+`` when both
        sides are truthy; otherwise whichever side is truthy is kept.

        Raises:
            ValueError: If both results carry a ``base64_image``, since
                images are never concatenated.
        """

        def combine_fields(
            field: Optional[str], other_field: Optional[str], concatenate: bool = True
        ):
            if field and other_field:
                if concatenate:
                    return field + other_field
                raise ValueError("Cannot combine tool results")
            # At most one side is truthy here: keep it (or the falsy right side).
            return field or other_field

        return ToolResult(
            output=combine_fields(self.output, other.output),
            error=combine_fields(self.error, other.error),
            base64_image=combine_fields(self.base64_image, other.base64_image, False),
            system=combine_fields(self.system, other.system),
        )

    def __str__(self) -> str:
        """Render the error (prefixed with ``Error: ``) or else the output.

        ``__str__`` must return a ``str``. ``output`` is typed ``Any`` and
        defaults to None, so it is coerced instead of returned raw (which
        would raise ``TypeError`` for None or any non-string output).
        """
        if self.error:
            return f"Error: {self.error}"
        if self.output is None:
            return ""
        return str(self.output)

    def replace(self, **kwargs):
        """Returns a new ToolResult with the given fields replaced.

        The new object is built through the constructor of ``type(self)``
        from this instance's ``dict()`` overlaid with ``kwargs``, so a
        subclass instance yields an instance of the same subclass.
        """
        return type(self)(**{**self.dict(), **kwargs})
|
|
|
|
|
|
class CLIResult(ToolResult):
    """ToolResult variant intended to be rendered as CLI output."""
|
|
|
|
|
|
class ToolFailure(ToolResult):
    """ToolResult variant that signals a failed tool execution."""
|