from typing import Dict, List, Optional, Union

from openai import (
    APIError,
    AsyncAzureOpenAI,
    AsyncOpenAI,
    AuthenticationError,
    OpenAIError,
    RateLimitError,
)
from tenacity import retry, stop_after_attempt, wait_random_exponential

from app.config import LLMSettings, config
from app.logger import logger  # Assuming a logger is set up in your app
from app.schema import Message, TOOL_CHOICE_TYPE, ROLE_VALUES, TOOL_CHOICE_VALUES, ToolChoice


class LLM:
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        # One cached instance per config name. Note that Python calls __init__
        # again automatically after __new__ returns, so __init__ guards against
        # re-initialization with its hasattr check.
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]
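
    # Singleton sketch (illustrative): lookups by the same config name share
    # one object, while distinct names get distinct cached instances.
    #
    #     llm_a = LLM("default")
    #     llm_b = LLM("default")
    #     assert llm_a is llm_b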

    def __init__(
        self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if not hasattr(self, "client"):  # Only initialize if not already initialized
            llm_config = llm_config or config.llm
            llm_config = llm_config.get(config_name, llm_config["default"])
            self.model = llm_config.model
            self.max_tokens = llm_config.max_tokens
            self.temperature = llm_config.temperature
            self.api_type = llm_config.api_type
            self.api_key = llm_config.api_key
            self.api_version = llm_config.api_version
            self.base_url = llm_config.base_url
            if self.api_type == "azure":
                self.client = AsyncAzureOpenAI(
                    base_url=self.base_url,
                    api_key=self.api_key,
                    api_version=self.api_version,
                )
            else:
                self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
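
    # Config sketch (illustrative TOML; the field names mirror the attributes
    # read above, but the actual file layout is defined by app.config):
    #
    #     [llm.default]
    #     model = "gpt-4o"
    #     api_type = "openai"
    #     api_key = "sk-..."
    #     base_url = "https://api.openai.com/v1"
    #     max_tokens = 4096
    #     temperature = 0.7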

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for the LLM by converting them to the OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided

        Examples:
            >>> msgs = [
            ...     Message.system_message("You are a helpful assistant"),
            ...     {"role": "user", "content": "Hello"},
            ...     Message.user_message("How are you?"),
            ... ]
            >>> formatted = LLM.format_messages(msgs)
        """
        formatted_messages = []

        for message in messages:
            if isinstance(message, dict):
                # If the message is already a dict, ensure it has a role
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")
                formatted_messages.append(message)
            elif isinstance(message, Message):
                # If the message is a Message object, convert it to a dict
                formatted_messages.append(message.to_dict())
            else:
                raise TypeError(f"Unsupported message type: {type(message)}")

        # Validate that all messages have the required fields
        for msg in formatted_messages:
            if msg["role"] not in ROLE_VALUES:
                raise ValueError(f"Invalid role: {msg['role']}")
            if "content" not in msg and "tool_calls" not in msg:
                raise ValueError(
                    "Message must contain either 'content' or 'tool_calls'"
                )

        return formatted_messages
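
    # Validation sketch (illustrative): an assistant message carrying only
    # tool_calls, with no content, still passes the checks above.
    #
    #     LLM.format_messages([{"role": "assistant", "tool_calls": [...]}])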

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream: Whether to stream the response
            temperature: Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            ValueError: If messages are invalid or the response is empty
            OpenAIError: If the API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Format system and user messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            if not stream:
                # Non-streaming request. Check temperature against None rather
                # than using `or`, so an explicit temperature of 0 is honored.
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=(
                        temperature if temperature is not None else self.temperature
                    ),
                    stream=False,
                )
                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")
                return response.choices[0].message.content

            # Streaming request
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=(
                    temperature if temperature is not None else self.temperature
                ),
                stream=True,
            )

            collected_messages = []
            async for chunk in response:
                # Skip chunks without choices (Azure, for example, can emit
                # such chunks carrying content-filter annotations)
                if not chunk.choices:
                    continue
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                print(chunk_message, end="", flush=True)

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")
            return full_response

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
            raise
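
    # Note (sketch): tenacity's @retry re-runs on *any* exception by default,
    # including the ValueError raised for invalid input above. To retry only
    # transient API failures, one could add, e.g.:
    #
    #     retry=retry_if_exception_type((RateLimitError, APIError))
    #
    # (retry_if_exception_type is importable from tenacity.)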

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 300,
        tools: Optional[List[dict]] = None,
        tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO,  # type: ignore
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask the LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            ChatCompletionMessage: The model's response

        Raises:
            ValueError: If tools, tool_choice, or messages are invalid
            OpenAIError: If the API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Validate tool_choice
            if tool_choice not in TOOL_CHOICE_VALUES:
                raise ValueError(f"Invalid tool_choice: {tool_choice}")

            # Format messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Validate tools if provided
            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError("Each tool must be a dict with 'type' field")

            # Set up the completion request
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=(
                    temperature if temperature is not None else self.temperature
                ),
                max_tokens=self.max_tokens,
                tools=tools,
                tool_choice=tool_choice,
                timeout=timeout,
                **kwargs,
            )

            # Check that the response is valid; log the raw response instead of
            # printing it to stdout
            if not response.choices or not response.choices[0].message:
                logger.error(f"Invalid or empty response from LLM: {response}")
                raise ValueError("Invalid or empty response from LLM")

            return response.choices[0].message

        except ValueError as ve:
            logger.error(f"Validation error in ask_tool: {ve}")
            raise
        except OpenAIError as oe:
            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise
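

# Minimal usage sketch (illustrative, not part of the module proper): assumes a
# valid "default" entry in the app config and a reachable API endpoint. The
# `_demo` coroutine name is hypothetical.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = LLM()  # shared "default" instance
        reply = await llm.ask(
            [Message.user_message("Say hello in one sentence.")],
            stream=False,
        )
        print(reply)

    asyncio.run(_demo())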