from typing import Dict, List, Optional, Union

from openai import (
    APIError,
    AsyncAzureOpenAI,
    AsyncOpenAI,
    AuthenticationError,
    OpenAIError,
    RateLimitError,
)
from tenacity import retry, stop_after_attempt, wait_random_exponential

from app.config import LLMSettings, config
from app.logger import logger  # Assuming a logger is set up in your app
from app.schema import (
    ROLE_VALUES,
    TOOL_CHOICE_TYPE,
    TOOL_CHOICE_VALUES,
    Message,
    ToolChoice,
)

# Reasoning models take `max_completion_tokens` instead of `max_tokens`
# and do not accept a custom `temperature`.
REASONING_MODELS = ["o1", "o3-mini"]


class LLM:
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        # One shared instance per config name (singleton-per-config).
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

    def __init__(
        self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if not hasattr(self, "client"):  # Only initialize if not already initialized
            llm_config = llm_config or config.llm
            llm_config = llm_config.get(config_name, llm_config["default"])
            self.model = llm_config.model
            self.max_tokens = llm_config.max_tokens
            self.temperature = llm_config.temperature
            self.api_type = llm_config.api_type
            self.api_key = llm_config.api_key
            self.api_version = llm_config.api_version
            self.base_url = llm_config.base_url
            if self.api_type == "azure":
                self.client = AsyncAzureOpenAI(
                    base_url=self.base_url,
                    api_key=self.api_key,
                    api_version=self.api_version,
                )
            else:
                self.client = AsyncOpenAI(
                    api_key=self.api_key, base_url=self.base_url
                )
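
    # Illustrative sketch only: judging from the attributes read in __init__
    # above, each entry in the LLM config (schema defined by
    # app.config.LLMSettings) carries roughly these fields. The TOML layout
    # and values below are assumptions, not this module's contract:
    #
    #     [llm.default]
    #     model = "gpt-4o"
    #     base_url = "https://api.openai.com/v1"
    #     api_key = "sk-..."
    #     api_type = "openai"   # or "azure" (then api_version is required)
    #     api_version = ""
    #     max_tokens = 4096
    #     temperature = 0.0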

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided

        Examples:
            >>> msgs = [
            ...     Message.system_message("You are a helpful assistant"),
            ...     {"role": "user", "content": "Hello"},
            ...     Message.user_message("How are you?")
            ... ]
            >>> formatted = LLM.format_messages(msgs)
        """
        formatted_messages = []

        for message in messages:
            if isinstance(message, Message):
                message = message.to_dict()
            if isinstance(message, dict):
                # If message is a dict, ensure it has the required fields
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")
                if "content" in message or "tool_calls" in message:
                    formatted_messages.append(message)
                # else: silently drop messages with neither content nor tool_calls
            else:
                raise TypeError(f"Unsupported message type: {type(message)}")

        # Validate all messages have a recognized role
        for msg in formatted_messages:
            if msg["role"] not in ROLE_VALUES:
                raise ValueError(f"Invalid role: {msg['role']}")

        return formatted_messages

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            ValueError: If messages are invalid or response is empty
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Format system and user messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            params = {
                "model": self.model,
                "messages": messages,
            }

            if self.model in REASONING_MODELS:
                params["max_completion_tokens"] = self.max_tokens
            else:
                params["max_tokens"] = self.max_tokens
                # Check `is not None` so an explicit temperature of 0 is honored
                params["temperature"] = (
                    temperature if temperature is not None else self.temperature
                )

            if not stream:
                # Non-streaming request
                params["stream"] = False
                response = await self.client.chat.completions.create(**params)
                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")
                return response.choices[0].message.content

            # Streaming request: collect deltas while echoing them to stdout
            params["stream"] = True
            response = await self.client.chat.completions.create(**params)

            collected_messages = []
            async for chunk in response:
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                print(chunk_message, end="", flush=True)

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")
            return full_response

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
            raise
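
    # A hedged sketch of what a `tools` entry for ask_tool below might look
    # like. The validation in ask_tool only requires a dict with a "type"
    # field; the nested layout here follows the conventional OpenAI
    # function-calling schema, and the tool name/fields are hypothetical:
    #
    #     weather_tool = {
    #         "type": "function",
    #         "function": {
    #             "name": "get_weather",
    #             "description": "Get the current weather for a city",
    #             "parameters": {
    #                 "type": "object",
    #                 "properties": {"city": {"type": "string"}},
    #                 "required": ["city"],
    #             },
    #         },
    #     }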

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 300,
        tools: Optional[List[dict]] = None,
        tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO,  # type: ignore
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            ChatCompletionMessage: The model's response

        Raises:
            ValueError: If tools, tool_choice, or messages are invalid
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Validate tool_choice
            if tool_choice not in TOOL_CHOICE_VALUES:
                raise ValueError(f"Invalid tool_choice: {tool_choice}")

            # Format system and user messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Validate tools if provided
            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError("Each tool must be a dict with 'type' field")

            # Set up the completion request
            params = {
                "model": self.model,
                "messages": messages,
                "tools": tools,
                "tool_choice": tool_choice,
                "timeout": timeout,
                **kwargs,
            }

            if self.model in REASONING_MODELS:
                params["max_completion_tokens"] = self.max_tokens
            else:
                params["max_tokens"] = self.max_tokens
                # Check `is not None` so an explicit temperature of 0 is honored
                params["temperature"] = (
                    temperature if temperature is not None else self.temperature
                )

            response = await self.client.chat.completions.create(**params)

            # Check if response is valid
            if not response.choices or not response.choices[0].message:
                logger.error(f"Invalid response object: {response}")
                raise ValueError("Invalid or empty response from LLM")

            return response.choices[0].message

        except ValueError as ve:
            logger.error(f"Validation error in ask_tool: {ve}")
            raise
        except OpenAIError as oe:
            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise
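

# Minimal usage sketch, assuming a valid "default" entry exists in config.llm
# and its API key is set; the prompt and non-streaming call below are
# illustrative only, not part of this module's contract.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = LLM()  # repeat calls reuse the shared "default" instance
        reply = await llm.ask(
            messages=[{"role": "user", "content": "Say hello in one word."}],
            stream=False,
        )
        print(reply)

    asyncio.run(_demo())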