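"""Async LLM client wrapper.

Defines the :class:`LLM` helper used across the app: a per-config singleton
around ``AsyncOpenAI`` / ``AsyncAzureOpenAI`` with token counting, an optional
input-token budget, tenacity-based retries, and both plain (``ask``) and
tool-calling (``ask_tool``) chat completion methods.
"""
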
from typing import Dict, List, Optional, Union

import tiktoken
from openai import (
    APIError,
    AsyncAzureOpenAI,
    AsyncOpenAI,
    AuthenticationError,
    OpenAIError,
    RateLimitError,
)
from tenacity import (
    retry,
    retry_if_exception_type,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from app.config import LLMSettings, config
from app.exceptions import TokenLimitExceeded
from app.logger import logger  # Assuming a logger is set up in your app
from app.schema import (
    ROLE_VALUES,
    TOOL_CHOICE_TYPE,
    TOOL_CHOICE_VALUES,
    Message,
    ToolChoice,
)
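

# Models that use ``max_completion_tokens`` instead of ``max_tokens`` in chat completion requests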
REASONING_MODELS = ["o1", "o3-mini"]


class LLM:
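    """Chat-completion client cached per configuration name.

    ``LLM("default")`` always returns the same instance for a given config
    name, so token usage accumulates across calls made with that config.
    """
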
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

    def __init__(
        self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if not hasattr(self, "client"):  # Only initialize if not already initialized
            llm_config = llm_config or config.llm
            llm_config = llm_config.get(config_name, llm_config["default"])
            self.model = llm_config.model
            self.max_tokens = llm_config.max_tokens
            self.temperature = llm_config.temperature
            self.api_type = llm_config.api_type
            self.api_key = llm_config.api_key
            self.api_version = llm_config.api_version
            self.base_url = llm_config.base_url

            # Add token counting related attributes
            self.total_input_tokens = 0
            self.total_completion_tokens = 0
            self.max_input_tokens = (
                llm_config.max_input_tokens
                if hasattr(llm_config, "max_input_tokens")
                else None
            )

            # Initialize tokenizer
            try:
                self.tokenizer = tiktoken.encoding_for_model(self.model)
            except KeyError:
                # If the model is not in tiktoken's presets, use cl100k_base as default
                self.tokenizer = tiktoken.get_encoding("cl100k_base")

            if self.api_type == "azure":
                self.client = AsyncAzureOpenAI(
                    base_url=self.base_url,
                    api_key=self.api_key,
                    api_version=self.api_version,
                )
            else:
                self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

    def count_tokens(self, text: str) -> int:
        """Calculate the number of tokens in a text"""
        if not text:
            return 0
        return len(self.tokenizer.encode(text))

    def count_message_tokens(self, messages: List[dict]) -> int:
        """Calculate the number of tokens in a message list"""
        token_count = 0
        for message in messages:
            # Base token count for each message (according to OpenAI's calculation method)
            token_count += 4

            # Calculate tokens for the role
            if "role" in message:
                token_count += self.count_tokens(message["role"])

            # Calculate tokens for the content
            if "content" in message and message["content"]:
                token_count += self.count_tokens(message["content"])

            # Calculate tokens for tool calls
            if "tool_calls" in message and message["tool_calls"]:
                for tool_call in message["tool_calls"]:
                    if "function" in tool_call:
                        # Function name
                        if "name" in tool_call["function"]:
                            token_count += self.count_tokens(
                                tool_call["function"]["name"]
                            )
                        # Function arguments
                        if "arguments" in tool_call["function"]:
                            token_count += self.count_tokens(
                                tool_call["function"]["arguments"]
                            )

            # Calculate tokens for tool responses
            if "name" in message and message["name"]:
                token_count += self.count_tokens(message["name"])

            if "tool_call_id" in message and message["tool_call_id"]:
                token_count += self.count_tokens(message["tool_call_id"])

            # Add extra tokens for message format
            token_count += 2

        return token_count

    def update_token_count(self, input_tokens: int, completion_tokens: int = 0) -> None:
        """Update cumulative token counts and log usage."""
        # Counts are tracked regardless of whether max_input_tokens is set
        self.total_input_tokens += input_tokens
        self.total_completion_tokens += completion_tokens
        logger.info(
            f"Token usage: Input={input_tokens}, Completion={completion_tokens}, "
            f"Cumulative Input={self.total_input_tokens}, Cumulative Completion={self.total_completion_tokens}, "
            f"Total={input_tokens + completion_tokens}, Cumulative Total={self.total_input_tokens + self.total_completion_tokens}"
        )

    def check_token_limit(self, input_tokens: int) -> bool:
        """Return True if the request stays within the input token limit."""
        if self.max_input_tokens is not None:
            return (self.total_input_tokens + input_tokens) <= self.max_input_tokens
        # If max_input_tokens is not set, there is no limit to enforce
        return True

    def get_limit_error_message(self, input_tokens: int) -> str:
        """Generate an error message for an exceeded token limit."""
        if (
            self.max_input_tokens is not None
            and (self.total_input_tokens + input_tokens) > self.max_input_tokens
        ):
            return f"Request may exceed input token limit (Current: {self.total_input_tokens}, Needed: {input_tokens}, Max: {self.max_input_tokens})"

        return "Token limit exceeded"

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided

        Examples:
            >>> msgs = [
            ...     Message.system_message("You are a helpful assistant"),
            ...     {"role": "user", "content": "Hello"},
            ...     Message.user_message("How are you?")
            ... ]
            >>> formatted = LLM.format_messages(msgs)
        """
        formatted_messages = []

        for message in messages:
            if isinstance(message, Message):
                message = message.to_dict()
            if isinstance(message, dict):
                # If message is a dict, ensure it has the required fields
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")
                if "content" in message or "tool_calls" in message:
                    formatted_messages.append(message)
                # Messages with neither content nor tool_calls are dropped
            else:
                raise TypeError(f"Unsupported message type: {type(message)}")

        # Validate that every message has a supported role
        for msg in formatted_messages:
            if msg["role"] not in ROLE_VALUES:
                raise ValueError(f"Invalid role: {msg['role']}")

        return formatted_messages

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type((OpenAIError, Exception, ValueError))
        & retry_if_not_exception_type(TokenLimitExceeded),  # Don't retry TokenLimitExceeded
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            TokenLimitExceeded: If token limits are exceeded
            ValueError: If messages are invalid or response is empty
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
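
        Examples:
            Illustrative only; assumes a valid "default" entry in ``config.llm``.

            >>> llm = LLM()
            >>> reply = await llm.ask(
            ...     [{"role": "user", "content": "Hello"}],
            ...     stream=False,
            ... )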
"""
|
||
try:
|
||
# Format system and user messages
|
||
if system_msgs:
|
||
system_msgs = self.format_messages(system_msgs)
|
||
messages = system_msgs + self.format_messages(messages)
|
||
else:
|
||
messages = self.format_messages(messages)
|
||
|
||
# Calculate input token count
|
||
input_tokens = self.count_message_tokens(messages)
|
||
|
||
# Check if token limits are exceeded
|
||
if not self.check_token_limit(input_tokens):
|
||
error_message = self.get_limit_error_message(input_tokens)
|
||
# Raise a special exception that won't be retried
|
||
raise TokenLimitExceeded(error_message)
|
||
|
||
params = {
|
||
"model": self.model,
|
||
"messages": messages,
|
||
}
|
||
|
||
if self.model in REASONING_MODELS:
|
||
params["max_completion_tokens"] = self.max_tokens
|
||
else:
|
||
params["max_tokens"] = self.max_tokens
|
||
params["temperature"] = (
|
||
temperature if temperature is not None else self.temperature
|
||
)
|
||
|
||
if not stream:
|
||
# Non-streaming request
|
||
params["stream"] = False
|
||
|
||
response = await self.client.chat.completions.create(**params)
|
||
|
||
if not response.choices or not response.choices[0].message.content:
|
||
raise ValueError("Empty or invalid response from LLM")
|
||
|
||
# Update token counts
|
||
self.update_token_count(
|
||
response.usage.prompt_tokens, response.usage.completion_tokens
|
||
)
|
||
|
||
return response.choices[0].message.content
|
||
|
||
# Streaming request, For streaming, update estimated token count before making the request
|
||
self.update_token_count(input_tokens)
|
||
|
||
params["stream"] = True
|
||
response = await self.client.chat.completions.create(**params)
|
||
|
||
collected_messages = []
|
||
completion_text = ""
|
||
async for chunk in response:
|
||
chunk_message = chunk.choices[0].delta.content or ""
|
||
collected_messages.append(chunk_message)
|
||
completion_text += chunk_message
|
||
print(chunk_message, end="", flush=True)
|
||
|
||
print() # Newline after streaming
|
||
full_response = "".join(collected_messages).strip()
|
||
if not full_response:
|
||
raise ValueError("Empty response from streaming LLM")
|
||
|
||
# 对于流式响应,估算completion tokens
|
||
completion_tokens = self.count_tokens(completion_text)
|
||
logger.info(
|
||
f"Estimated completion tokens for streaming response: {completion_tokens}"
|
||
)
|
||
self.total_completion_tokens += completion_tokens
|
||
|
||
return full_response
|
||
|
||
except TokenLimitExceeded:
|
||
# Re-raise token limit errors without logging
|
||
raise
|
||
except ValueError as ve:
|
||
logger.error(f"Validation error: {ve}")
|
||
raise
|
||
except OpenAIError as oe:
|
||
logger.error(f"OpenAI API error: {oe}")
|
||
if isinstance(oe, AuthenticationError):
|
||
logger.error("Authentication failed. Check API key.")
|
||
elif isinstance(oe, RateLimitError):
|
||
logger.error("Rate limit exceeded. Consider increasing retry attempts.")
|
||
elif isinstance(oe, APIError):
|
||
logger.error(f"API error: {oe}")
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"Unexpected error in ask: {e}")
|
||
raise
|
||
|
||
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type((OpenAIError, Exception, ValueError))
        & retry_if_not_exception_type(TokenLimitExceeded),  # Don't retry TokenLimitExceeded
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 300,
        tools: Optional[List[dict]] = None,
        tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO,  # type: ignore
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            ChatCompletionMessage: The model's response

        Raises:
            TokenLimitExceeded: If token limits are exceeded
            ValueError: If tools, tool_choice, or messages are invalid
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
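
        Examples:
            Illustrative only; the tool schema below is an assumed example,
            not part of this module.

            >>> tool = {
            ...     "type": "function",
            ...     "function": {
            ...         "name": "get_weather",
            ...         "description": "Look up the current weather for a city",
            ...         "parameters": {
            ...             "type": "object",
            ...             "properties": {"city": {"type": "string"}},
            ...             "required": ["city"],
            ...         },
            ...     },
            ... }
            >>> msg = await LLM().ask_tool(
            ...     [{"role": "user", "content": "What's the weather in Paris?"}],
            ...     tools=[tool],
            ...     tool_choice=ToolChoice.AUTO,
            ... )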
"""
|
||
try:
|
||
# Validate tool_choice
|
||
if tool_choice not in TOOL_CHOICE_VALUES:
|
||
raise ValueError(f"Invalid tool_choice: {tool_choice}")
|
||
|
||
# Format messages
|
||
if system_msgs:
|
||
system_msgs = self.format_messages(system_msgs)
|
||
messages = system_msgs + self.format_messages(messages)
|
||
else:
|
||
messages = self.format_messages(messages)
|
||
|
||
# Calculate input token count
|
||
input_tokens = self.count_message_tokens(messages)
|
||
|
||
# If there are tools, calculate token count for tool descriptions
|
||
tools_tokens = 0
|
||
if tools:
|
||
for tool in tools:
|
||
tools_tokens += self.count_tokens(str(tool))
|
||
|
||
input_tokens += tools_tokens
|
||
|
||
# Check if token limits are exceeded
|
||
if not self.check_token_limit(input_tokens):
|
||
error_message = self.get_limit_error_message(input_tokens)
|
||
# Raise a special exception that won't be retried
|
||
raise TokenLimitExceeded(error_message)
|
||
|
||
# Validate tools if provided
|
||
if tools:
|
||
for tool in tools:
|
||
if not isinstance(tool, dict) or "type" not in tool:
|
||
raise ValueError("Each tool must be a dict with 'type' field")
|
||
|
||
# Set up the completion request
|
||
params = {
|
||
"model": self.model,
|
||
"messages": messages,
|
||
"tools": tools,
|
||
"tool_choice": tool_choice,
|
||
"timeout": timeout,
|
||
**kwargs,
|
||
}
|
||
|
||
if self.model in REASONING_MODELS:
|
||
params["max_completion_tokens"] = self.max_tokens
|
||
else:
|
||
params["max_tokens"] = self.max_tokens
|
||
params["temperature"] = (
|
||
temperature if temperature is not None else self.temperature
|
||
)
|
||
|
||
response = await self.client.chat.completions.create(**params)
|
||
|
||
# Check if response is valid
|
||
if not response.choices or not response.choices[0].message:
|
||
print(response)
|
||
raise ValueError("Invalid or empty response from LLM")
|
||
|
||
# Update token counts
|
||
self.update_token_count(
|
||
response.usage.prompt_tokens, response.usage.completion_tokens
|
||
)
|
||
|
||
return response.choices[0].message
|
||
|
||
except TokenLimitExceeded:
|
||
# Re-raise token limit errors without logging
|
||
raise
|
||
except ValueError as ve:
|
||
logger.error(f"Validation error in ask_tool: {ve}")
|
||
raise
|
||
except OpenAIError as oe:
|
||
logger.error(f"OpenAI API error: {oe}")
|
||
if isinstance(oe, AuthenticationError):
|
||
logger.error("Authentication failed. Check API key.")
|
||
elif isinstance(oe, RateLimitError):
|
||
logger.error("Rate limit exceeded. Consider increasing retry attempts.")
|
||
elif isinstance(oe, APIError):
|
||
logger.error(f"API error: {oe}")
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"Unexpected error in ask_tool: {e}")
|
||
raise
|