import types
from typing import Any, List, Optional, Type, Union, get_args, get_origin

from pydantic import BaseModel, Field

from app.tool import BaseTool

# ``get_origin`` reports ``typing.Union`` for ``Union[X, Y]`` but
# ``types.UnionType`` for the PEP 604 spelling ``X | Y``; accept both.
_UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))


class CreateChatCompletion(BaseTool):
    """Tool whose parameter schema is derived from a desired response type.

    The JSON-schema-like ``parameters`` dict is built once in ``__init__``
    from ``response_type``; ``execute`` later converts the received keyword
    arguments back into that type where it can.
    """

    name: str = "create_chat_completion"
    description: str = (
        "Creates a structured completion with specified output formatting."
    )

    # Type mapping for JSON schema
    type_mapping: dict = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        dict: "object",
        list: "array",
    }
    response_type: Optional[Type] = None
    required: List[str] = Field(default_factory=lambda: ["response"])

    def __init__(self, response_type: Optional[Type] = str):
        """Initialize with a specific response type.

        Args:
            response_type: The type the tool's response should be converted
                to. Defaults to ``str``.
        """
        super().__init__()
        self.response_type = response_type
        self.parameters = self._build_parameters()

    def _wrap_response(self, response_schema: dict) -> dict:
        """Wrap a schema for the ``response`` property in the object envelope."""
        return {
            "type": "object",
            "properties": {"response": response_schema},
            "required": self.required,
        }

    def _build_parameters(self) -> dict:
        """Build parameters schema based on response type.

        Returns:
            A dict with ``type``, ``properties`` and ``required`` keys.
        """
        if self.response_type is str:
            return self._wrap_response(
                {
                    "type": "string",
                    "description": "The response text that should be delivered to the user.",
                }
            )

        if isinstance(self.response_type, type) and issubclass(
            self.response_type, BaseModel
        ):
            schema = self.response_type.model_json_schema()
            # NOTE(review): only "properties" and "required" are copied from
            # the model schema; confirm whether other top-level keys the
            # properties may reference need to be carried over as well.
            return {
                "type": "object",
                "properties": schema.get("properties", {}),
                "required": schema.get("required", self.required),
            }

        return self._create_type_schema(self.response_type)

    def _create_type_schema(self, type_hint: Type) -> dict:
        """Create a JSON schema for the given type.

        Handles plain (non-generic) types, ``list``/``dict`` generics and
        unions. Any other generic origin gets a generic single-value schema
        instead of recursing back into ``_build_parameters``.

        Args:
            type_hint: The type annotation to describe.

        Returns:
            An object schema with a single ``response`` property.
        """
        origin = get_origin(type_hint)
        args = get_args(type_hint)

        # Handle primitive types
        if origin is None:
            # Not every hint carries ``__name__`` (e.g. ``None``), so fall
            # back to ``repr`` rather than raising AttributeError.
            type_name = getattr(type_hint, "__name__", repr(type_hint))
            return self._wrap_response(
                {
                    "type": self.type_mapping.get(type_hint, "string"),
                    "description": f"Response of type {type_name}",
                }
            )

        # Handle List type
        if origin is list:
            item_type = args[0] if args else Any
            return self._wrap_response(
                {
                    "type": "array",
                    "items": self._get_type_info(item_type),
                }
            )

        # Handle Dict type
        if origin is dict:
            value_type = args[1] if len(args) > 1 else Any
            return self._wrap_response(
                {
                    "type": "object",
                    "additionalProperties": self._get_type_info(value_type),
                }
            )

        # Handle Union type (both ``Union[X, Y]`` and ``X | Y``)
        if origin in _UNION_ORIGINS:
            return self._create_union_schema(args)

        # Unknown generic origin. Calling ``_build_parameters`` here would
        # route straight back into this method with the same hint and never
        # terminate, so describe the value generically instead.
        origin_name = getattr(origin, "__name__", repr(origin))
        return self._wrap_response(
            {
                "type": self.type_mapping.get(origin, "string"),
                "description": f"Response of type {origin_name}",
            }
        )

    def _get_type_info(self, type_hint: Type) -> dict:
        """Get type information for a single type.

        Args:
            type_hint: The element/value/member type to describe.

        Returns:
            The model's own JSON schema for ``BaseModel`` subclasses,
            otherwise a ``{"type", "description"}`` dict.
        """
        if isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
            return type_hint.model_json_schema()

        # For parameterised generics (e.g. ``List[int]``) look up the
        # unsubscripted origin so they map to "array"/"object", not "string".
        lookup_key = get_origin(type_hint) or type_hint
        return {
            "type": self.type_mapping.get(lookup_key, "string"),
            "description": f"Value of type {getattr(type_hint, '__name__', 'any')}",
        }

    def _create_union_schema(self, types: tuple) -> dict:
        """Create schema for Union types.

        Args:
            types: The member types of the union.

        Returns:
            An object schema whose ``response`` property is an ``anyOf``
            over the member types.
        """
        return self._wrap_response(
            {"anyOf": [self._get_type_info(t) for t in types]}
        )

    async def execute(self, required: list | None = None, **kwargs) -> Any:
        """Execute the chat completion with type conversion.

        Args:
            required: List of required field names or None
            **kwargs: Response data

        Returns:
            Converted response based on response_type. When more than one
            required field is given, a dict of those fields is returned
            as-is, without conversion.
        """
        required = required or self.required

        # Handle case when required is a list
        if isinstance(required, list) and len(required) > 0:
            if len(required) > 1:
                # Return multiple fields as a dictionary
                return {field: kwargs.get(field, "") for field in required}
            required_field = required[0]
        else:
            required_field = "response"
        result = kwargs.get(required_field, "")

        # Type conversion logic
        if self.response_type is str:
            return result

        if isinstance(self.response_type, type) and issubclass(
            self.response_type, BaseModel
        ):
            return self.response_type(**kwargs)

        if get_origin(self.response_type) in (list, dict):
            return result  # Assuming result is already in correct format

        # Best effort: non-callable or non-convertible response types
        # (None, unions, bad values) fall back to the raw result.
        try:
            return self.response_type(result)
        except (ValueError, TypeError):
            return result