170 lines
5.5 KiB
Python
170 lines
5.5 KiB
Python
from typing import Any, List, Optional, Type, Union, get_args, get_origin
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from app.tool import BaseTool
|
|
|
|
|
|
class CreateChatCompletion(BaseTool):
|
|
name: str = "create_chat_completion"
|
|
description: str = (
|
|
"Creates a structured completion with specified output formatting."
|
|
)
|
|
|
|
# Type mapping for JSON schema
|
|
type_mapping: dict = {
|
|
str: "string",
|
|
int: "integer",
|
|
float: "number",
|
|
bool: "boolean",
|
|
dict: "object",
|
|
list: "array",
|
|
}
|
|
response_type: Optional[Type] = None
|
|
required: List[str] = Field(default_factory=lambda: ["response"])
|
|
|
|
def __init__(self, response_type: Optional[Type] = str):
|
|
"""Initialize with a specific response type."""
|
|
super().__init__()
|
|
self.response_type = response_type
|
|
self.parameters = self._build_parameters()
|
|
|
|
def _build_parameters(self) -> dict:
|
|
"""Build parameters schema based on response type."""
|
|
if self.response_type == str:
|
|
return {
|
|
"type": "object",
|
|
"properties": {
|
|
"response": {
|
|
"type": "string",
|
|
"description": "The response text that should be delivered to the user.",
|
|
},
|
|
},
|
|
"required": self.required,
|
|
}
|
|
|
|
if isinstance(self.response_type, type) and issubclass(
|
|
self.response_type, BaseModel
|
|
):
|
|
schema = self.response_type.model_json_schema()
|
|
return {
|
|
"type": "object",
|
|
"properties": schema["properties"],
|
|
"required": schema.get("required", self.required),
|
|
}
|
|
|
|
return self._create_type_schema(self.response_type)
|
|
|
|
def _create_type_schema(self, type_hint: Type) -> dict:
|
|
"""Create a JSON schema for the given type."""
|
|
origin = get_origin(type_hint)
|
|
args = get_args(type_hint)
|
|
|
|
# Handle primitive types
|
|
if origin is None:
|
|
return {
|
|
"type": "object",
|
|
"properties": {
|
|
"response": {
|
|
"type": self.type_mapping.get(type_hint, "string"),
|
|
"description": f"Response of type {type_hint.__name__}",
|
|
}
|
|
},
|
|
"required": self.required,
|
|
}
|
|
|
|
# Handle List type
|
|
if origin is list:
|
|
item_type = args[0] if args else Any
|
|
return {
|
|
"type": "object",
|
|
"properties": {
|
|
"response": {
|
|
"type": "array",
|
|
"items": self._get_type_info(item_type),
|
|
}
|
|
},
|
|
"required": self.required,
|
|
}
|
|
|
|
# Handle Dict type
|
|
if origin is dict:
|
|
value_type = args[1] if len(args) > 1 else Any
|
|
return {
|
|
"type": "object",
|
|
"properties": {
|
|
"response": {
|
|
"type": "object",
|
|
"additionalProperties": self._get_type_info(value_type),
|
|
}
|
|
},
|
|
"required": self.required,
|
|
}
|
|
|
|
# Handle Union type
|
|
if origin is Union:
|
|
return self._create_union_schema(args)
|
|
|
|
return self._build_parameters()
|
|
|
|
def _get_type_info(self, type_hint: Type) -> dict:
|
|
"""Get type information for a single type."""
|
|
if isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
|
|
return type_hint.model_json_schema()
|
|
|
|
return {
|
|
"type": self.type_mapping.get(type_hint, "string"),
|
|
"description": f"Value of type {getattr(type_hint, '__name__', 'any')}",
|
|
}
|
|
|
|
def _create_union_schema(self, types: tuple) -> dict:
|
|
"""Create schema for Union types."""
|
|
return {
|
|
"type": "object",
|
|
"properties": {
|
|
"response": {"anyOf": [self._get_type_info(t) for t in types]}
|
|
},
|
|
"required": self.required,
|
|
}
|
|
|
|
async def execute(self, required: list | None = None, **kwargs) -> Any:
|
|
"""Execute the chat completion with type conversion.
|
|
|
|
Args:
|
|
required: List of required field names or None
|
|
**kwargs: Response data
|
|
|
|
Returns:
|
|
Converted response based on response_type
|
|
"""
|
|
required = required or self.required
|
|
|
|
# Handle case when required is a list
|
|
if isinstance(required, list) and len(required) > 0:
|
|
if len(required) == 1:
|
|
required_field = required[0]
|
|
result = kwargs.get(required_field, "")
|
|
else:
|
|
# Return multiple fields as a dictionary
|
|
return {field: kwargs.get(field, "") for field in required}
|
|
else:
|
|
required_field = "response"
|
|
result = kwargs.get(required_field, "")
|
|
|
|
# Type conversion logic
|
|
if self.response_type == str:
|
|
return result
|
|
|
|
if isinstance(self.response_type, type) and issubclass(
|
|
self.response_type, BaseModel
|
|
):
|
|
return self.response_type(**kwargs)
|
|
|
|
if get_origin(self.response_type) in (list, dict):
|
|
return result # Assuming result is already in correct format
|
|
|
|
try:
|
|
return self.response_type(result)
|
|
except (ValueError, TypeError):
|
|
return result
|