OpenManus/app/tool/create_chat_completion.py
2025-03-06 22:57:07 +08:00

170 lines
5.5 KiB
Python

from types import UnionType
from typing import Any, List, Optional, Type, Union, get_args, get_origin

from pydantic import BaseModel, Field

from app.tool import BaseTool
class CreateChatCompletion(BaseTool):
    """Tool whose arguments carry a completion in a caller-chosen shape.

    The JSON schema stored in ``parameters`` is derived from
    ``response_type``: ``str`` and other plain types are wrapped in a single
    ``response`` property, generic containers and unions are described
    structurally, and pydantic models expose their own fields directly.
    ``execute`` turns the received arguments back into ``response_type``.
    """

    name: str = "create_chat_completion"
    description: str = (
        "Creates a structured completion with specified output formatting."
    )

    # Python type -> JSON schema type name; anything not listed is described
    # as "string" by the schema builders below.
    type_mapping: dict = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        dict: "object",
        list: "array",
    }
    response_type: Optional[Type] = None
    required: List[str] = Field(default_factory=lambda: ["response"])

    def __init__(self, response_type: Optional[Type] = str):
        """Initialize with a specific response type.

        Args:
            response_type: Type the completion should be converted to.
                Defaults to ``str``.
        """
        super().__init__()
        self.response_type = response_type
        self.parameters = self._build_parameters()

    def _is_model_type(self, type_hint: Any) -> bool:
        """Return True when ``type_hint`` is a pydantic ``BaseModel`` subclass."""
        # Checking the origin first means parameterized generics such as
        # ``list[int]`` are rejected before they can reach issubclass(),
        # which raises TypeError for arguments that are not classes.
        return (
            get_origin(type_hint) is None
            and isinstance(type_hint, type)
            and issubclass(type_hint, BaseModel)
        )

    def _wrap_response(self, response_schema: dict) -> dict:
        """Wrap a value schema in an object with a single ``response`` property."""
        return {
            "type": "object",
            "properties": {"response": response_schema},
            "required": self.required,
        }

    def _build_parameters(self) -> dict:
        """Build the parameters schema based on ``response_type``.

        Returns:
            A JSON-schema ``object`` description of the tool arguments.
        """
        if self.response_type is str:
            return self._wrap_response(
                {
                    "type": "string",
                    "description": "The response text that should be delivered to the user.",
                }
            )

        if self._is_model_type(self.response_type):
            schema = self.response_type.model_json_schema()
            parameters = {
                "type": "object",
                "properties": schema.get("properties", {}),
                # Falls back to self.required when the model schema carries
                # no "required" key.
                # NOTE(review): that fallback names "response", which may not
                # be a property of the model - confirm this is intended.
                "required": schema.get("required", self.required),
            }
            # Nested models appear as "$ref" entries pointing into a
            # top-level "$defs" table; carry it over so they still resolve.
            if "$defs" in schema:
                parameters["$defs"] = schema["$defs"]
            return parameters

        return self._create_type_schema(self.response_type)

    def _create_type_schema(self, type_hint: Type) -> dict:
        """Create the parameters schema for a non-model, non-``str`` type.

        Args:
            type_hint: A plain type or a ``typing`` construct such as
                ``List[int]``, ``Dict[str, float]`` or ``Optional[int]``.

        Returns:
            An object schema whose ``response`` property describes
            ``type_hint``.
        """
        origin = get_origin(type_hint)

        # Plain (non-generic) types keep the "Response of type ..." wording.
        if origin is None:
            # getattr: ``None`` (the field default) has no ``__name__``.
            type_name = getattr(type_hint, "__name__", "any")
            return self._wrap_response(
                {
                    "type": self.type_mapping.get(type_hint, "string"),
                    "description": f"Response of type {type_name}",
                }
            )

        # ``Union[X, Y]`` and the PEP 604 spelling ``X | Y`` report different
        # origins, so both are checked.
        if origin is Union or origin is UnionType:
            return self._create_union_schema(get_args(type_hint))

        # Every other generic (list, dict, tuple, set, ...) is described by
        # _get_type_info. Calling _build_parameters() here instead would
        # recurse forever, since it calls straight back into this method.
        return self._wrap_response(self._get_type_info(type_hint))

    def _get_type_info(self, type_hint: Type) -> dict:
        """Get the JSON schema fragment describing a single type.

        Handles pydantic models, ``None``, lists, dicts and unions
        (recursively), and falls back to ``type_mapping`` for everything
        else.
        """
        if self._is_model_type(type_hint):
            return type_hint.model_json_schema()

        # The ``None`` member of Optional[X] is JSON ``null``, not a string.
        if type_hint is type(None):
            return {"type": "null", "description": "Value of type NoneType"}

        origin = get_origin(type_hint)
        args = get_args(type_hint)

        if origin is list:
            item_type = args[0] if args else Any
            return {"type": "array", "items": self._get_type_info(item_type)}

        if origin is dict:
            value_type = args[1] if len(args) > 1 else Any
            return {
                "type": "object",
                "additionalProperties": self._get_type_info(value_type),
            }

        if origin is Union or origin is UnionType:
            return {"anyOf": [self._get_type_info(arg) for arg in args]}

        type_name = getattr(type_hint, "__name__", "any")
        if origin in (tuple, set, frozenset):
            # JSON has no tuple/set; these travel as arrays.
            json_type = "array"
        else:
            json_type = self.type_mapping.get(type_hint, "string")
        return {"type": json_type, "description": f"Value of type {type_name}"}

    def _create_union_schema(self, types: tuple) -> dict:
        """Create the parameters schema for a union of the given member types."""
        return self._wrap_response(
            {"anyOf": [self._get_type_info(member) for member in types]}
        )

    async def execute(self, required: list | None = None, **kwargs) -> Any:
        """Execute the chat completion with type conversion.

        Args:
            required: List of required field names, or None to use
                ``self.required``.
            **kwargs: Response data.

        Returns:
            A dict of the requested fields when more than one field is
            required; otherwise the single field's value converted according
            to ``response_type``. If the conversion raises ``ValueError`` or
            ``TypeError`` the unconverted value is returned.
        """
        fields = required or self.required

        # Several required fields: return them together as a dictionary.
        if isinstance(fields, list) and len(fields) > 1:
            return {field: kwargs.get(field, "") for field in fields}

        # One field, or nothing usable: fall back to the "response" key.
        if isinstance(fields, list) and fields:
            field_name = fields[0]
        else:
            field_name = "response"
        result = kwargs.get(field_name, "")

        if self.response_type is str:
            return result

        if self._is_model_type(self.response_type):
            return self.response_type(**kwargs)

        if get_origin(self.response_type) in (list, dict):
            return result  # Assuming result is already in correct format

        # bool("false") is True, so textual booleans are parsed explicitly
        # before falling back to the constructor call below.
        if self.response_type is bool and isinstance(result, str):
            text = result.strip().lower()
            if text in ("true", "false"):
                return text == "true"

        try:
            return self.response_type(result)
        except (ValueError, TypeError):
            # Best effort: unions, ``None`` and similar are not callable
            # converters, so hand back the raw value.
            return result