This commit improves message handling in the LLM class to gracefully handle messages without 'content' or 'tool_calls' fields. Previously, the system raised a ValueError when encountering such messages, causing crashes with models like Google's Gemini that sometimes return messages with different structures.

Key changes:
- Reordered message processing to check for Message objects first
- Changed the validation approach to silently skip malformed messages instead of crashing
- Removed the strict ValueError when content/tool_calls are missing

This change maintains compatibility with correctly formatted messages while improving robustness across different LLM providers.
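
For illustration, under the new behavior a message carrying neither field is simply dropped (a minimal sketch; the `app.llm` import path is an assumption, not confirmed by this diff):

    from app.llm import LLM

    msgs = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant"},  # neither 'content' nor 'tool_calls': previously raised ValueError
    ]
    assert LLM.format_messages(msgs) == [{"role": "user", "content": "Hello"}]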
from typing import Dict, List, Optional, Union

from openai import (
    APIError,
    AsyncAzureOpenAI,
    AsyncOpenAI,
    AuthenticationError,
    OpenAIError,
    RateLimitError,
)
from tenacity import retry, stop_after_attempt, wait_random_exponential

from app.config import LLMSettings, config
from app.logger import logger  # Assuming a logger is set up in your app
from app.schema import (
    ROLE_VALUES,
    TOOL_CHOICE_TYPE,
    TOOL_CHOICE_VALUES,
    Message,
    ToolChoice,
)

# Reasoning models take `max_completion_tokens` instead of `max_tokens`
REASONING_MODELS = ["o1", "o3-mini"]


class LLM:
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        # Reuse one instance per config name (singleton-per-config)
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

    def __init__(
        self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if not hasattr(self, "client"):  # Only initialize if not already initialized
            llm_config = llm_config or config.llm
            llm_config = llm_config.get(config_name, llm_config["default"])
            self.model = llm_config.model
            self.max_tokens = llm_config.max_tokens
            self.temperature = llm_config.temperature
            self.api_type = llm_config.api_type
            self.api_key = llm_config.api_key
            self.api_version = llm_config.api_version
            self.base_url = llm_config.base_url
            if self.api_type == "azure":
                self.client = AsyncAzureOpenAI(
                    base_url=self.base_url,
                    api_key=self.api_key,
                    api_version=self.api_version,
                )
            else:
                self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided

        Examples:
            >>> msgs = [
            ...     Message.system_message("You are a helpful assistant"),
            ...     {"role": "user", "content": "Hello"},
            ...     Message.user_message("How are you?")
            ... ]
            >>> formatted = LLM.format_messages(msgs)
        """
        formatted_messages = []

        for message in messages:
            # Convert Message objects to dicts first so both input types
            # go through the same validation below
            if isinstance(message, Message):
                message = message.to_dict()
            if isinstance(message, dict):
                # If message is a dict, ensure it has required fields
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")
                if "content" in message or "tool_calls" in message:
                    formatted_messages.append(message)
                # else: silently skip messages with neither 'content' nor 'tool_calls'
            else:
                raise TypeError(f"Unsupported message type: {type(message)}")

        # Validate that all messages carry a supported role
        for msg in formatted_messages:
            if msg["role"] not in ROLE_VALUES:
                raise ValueError(f"Invalid role: {msg['role']}")

        return formatted_messages

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            ValueError: If messages are invalid or response is empty
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
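
        Examples (illustrative only; assumes a configured "default" LLM and
        an async caller):
            >>> llm = LLM()
            >>> reply = await llm.ask(
            ...     [Message.user_message("Hello!")],
            ...     stream=False,
            ... )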
"""
|
|
try:
|
|
# Format system and user messages
|
|
if system_msgs:
|
|
system_msgs = self.format_messages(system_msgs)
|
|
messages = system_msgs + self.format_messages(messages)
|
|
else:
|
|
messages = self.format_messages(messages)
|
|
|
|
params = {
|
|
"model": self.model,
|
|
"messages": messages,
|
|
}
|
|
|
|
if self.model in REASONING_MODELS:
|
|
params["max_completion_tokens"] = self.max_tokens
|
|
else:
|
|
params["max_tokens"] = self.max_tokens
|
|
params["temperature"] = temperature or self.temperature
|
|
|
|
if not stream:
|
|
# Non-streaming request
|
|
params["stream"] = False
|
|
|
|
response = await self.client.chat.completions.create(**params)
|
|
|
|
if not response.choices or not response.choices[0].message.content:
|
|
raise ValueError("Empty or invalid response from LLM")
|
|
return response.choices[0].message.content
|
|
|
|
# Streaming request
|
|
params["stream"] = True
|
|
response = await self.client.chat.completions.create(**params)
|
|
|
|
collected_messages = []
|
|
async for chunk in response:
|
|
chunk_message = chunk.choices[0].delta.content or ""
|
|
collected_messages.append(chunk_message)
|
|
print(chunk_message, end="", flush=True)
|
|
|
|
print() # Newline after streaming
|
|
full_response = "".join(collected_messages).strip()
|
|
if not full_response:
|
|
raise ValueError("Empty response from streaming LLM")
|
|
return full_response
|
|
|
|
except ValueError as ve:
|
|
logger.error(f"Validation error: {ve}")
|
|
raise
|
|
except OpenAIError as oe:
|
|
logger.error(f"OpenAI API error: {oe}")
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Unexpected error in ask: {e}")
|
|
raise
|
|
|
|
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 300,
        tools: Optional[List[dict]] = None,
        tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO,  # type: ignore
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            ChatCompletionMessage: The model's response

        Raises:
            ValueError: If tools, tool_choice, or messages are invalid
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
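
        Examples (illustrative only; `get_weather` is a made-up placeholder in
        OpenAI function-calling format, not a tool defined in this repo):
            >>> llm = LLM()
            >>> msg = await llm.ask_tool(
            ...     [Message.user_message("What's the weather in Paris?")],
            ...     tools=[{
            ...         "type": "function",
            ...         "function": {
            ...             "name": "get_weather",
            ...             "description": "Look up current weather for a city",
            ...             "parameters": {
            ...                 "type": "object",
            ...                 "properties": {"city": {"type": "string"}},
            ...                 "required": ["city"],
            ...             },
            ...         },
            ...     }],
            ... )
            >>> msg.tool_calls  # populated when the model opts to call the tool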
"""
|
|
try:
|
|
# Validate tool_choice
|
|
if tool_choice not in TOOL_CHOICE_VALUES:
|
|
raise ValueError(f"Invalid tool_choice: {tool_choice}")
|
|
|
|
# Format messages
|
|
if system_msgs:
|
|
system_msgs = self.format_messages(system_msgs)
|
|
messages = system_msgs + self.format_messages(messages)
|
|
else:
|
|
messages = self.format_messages(messages)
|
|
|
|
# Validate tools if provided
|
|
if tools:
|
|
for tool in tools:
|
|
if not isinstance(tool, dict) or "type" not in tool:
|
|
raise ValueError("Each tool must be a dict with 'type' field")
|
|
|
|
# Set up the completion request
|
|
params = {
|
|
"model": self.model,
|
|
"messages": messages,
|
|
"tools": tools,
|
|
"tool_choice": tool_choice,
|
|
"timeout": timeout,
|
|
**kwargs,
|
|
}
|
|
|
|
if self.model in REASONING_MODELS:
|
|
params["max_completion_tokens"] = self.max_tokens
|
|
else:
|
|
params["max_tokens"] = self.max_tokens
|
|
params["temperature"] = temperature or self.temperature
|
|
|
|
response = await self.client.chat.completions.create(**params)
|
|
|
|
# Check if response is valid
|
|
if not response.choices or not response.choices[0].message:
|
|
print(response)
|
|
raise ValueError("Invalid or empty response from LLM")
|
|
|
|
return response.choices[0].message
|
|
|
|
except ValueError as ve:
|
|
logger.error(f"Validation error in ask_tool: {ve}")
|
|
raise
|
|
except OpenAIError as oe:
|
|
if isinstance(oe, AuthenticationError):
|
|
logger.error("Authentication failed. Check API key.")
|
|
elif isinstance(oe, RateLimitError):
|
|
logger.error("Rate limit exceeded. Consider increasing retry attempts.")
|
|
elif isinstance(oe, APIError):
|
|
logger.error(f"API error: {oe}")
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Unexpected error in ask_tool: {e}")
|
|
raise
|
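
# A minimal usage sketch (not part of the original module); assumes a valid
# "default" entry in the app's LLM config with a working API key and base URL.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        llm = LLM()  # reuses the per-config singleton for "default"
        reply = await llm.ask(
            [Message.user_message("Say hello in one sentence.")],
            stream=False,
        )
        print(reply)

    asyncio.run(_demo())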