import math
from typing import Dict, List, Optional, Union

import tiktoken
from openai import (
    APIError,
    AsyncAzureOpenAI,
    AsyncOpenAI,
    AuthenticationError,
    OpenAIError,
    RateLimitError,
)
from tenacity import (
    retry,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from app.config import LLMSettings, config
from app.exceptions import TokenLimitExceeded
from app.logger import logger  # Assuming a logger is set up in your app
from app.schema import (
    ROLE_VALUES,
    TOOL_CHOICE_TYPE,
    TOOL_CHOICE_VALUES,
    Message,
    ToolChoice,
)


REASONING_MODELS = ["o1", "o3-mini"]


class TokenCounter:
    # Token constants
    BASE_MESSAGE_TOKENS = 4
    FORMAT_TOKENS = 2
    LOW_DETAIL_IMAGE_TOKENS = 85
    HIGH_DETAIL_TILE_TOKENS = 170

    # Image processing constants
    MAX_SIZE = 2048
    HIGH_DETAIL_TARGET_SHORT_SIDE = 768
    TILE_SIZE = 512

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def count_text(self, text: str) -> int:
        """Calculate tokens for a text string"""
        return 0 if not text else len(self.tokenizer.encode(text))

    def count_image(self, image_item: dict) -> int:
        """
        Calculate tokens for an image based on detail level and dimensions.

        For "low" detail: fixed 85 tokens.
        For "high" detail:
        1. Scale to fit in a 2048x2048 square
        2. Scale the shortest side to 768px
        3. Count 512px tiles (170 tokens each)
        4. Add 85 tokens
        """
        detail = image_item.get("detail", "medium")

        # For low detail, always return the fixed token count
        if detail == "low":
            return self.LOW_DETAIL_IMAGE_TOKENS

        # For medium detail (the default here), use the high-detail calculation,
        # since OpenAI doesn't specify a separate calculation for medium
        if detail in ("high", "medium"):
            # If dimensions are provided in the image_item
            if "dimensions" in image_item:
                width, height = image_item["dimensions"]
                return self._calculate_high_detail_tokens(width, height)

        # Default values when dimensions aren't available or the detail level is unknown
        if detail == "high":
            # Default to a 1024x1024 image calculation for high detail
            return self._calculate_high_detail_tokens(1024, 1024)  # 765 tokens
        elif detail == "medium":
            # Default to a medium-sized image for medium detail
            return 1024  # This matches the original default
        else:
            # For unknown detail levels, use medium as the default
            return 1024

    def _calculate_high_detail_tokens(self, width: int, height: int) -> int:
        """Calculate tokens for high-detail images based on dimensions"""
        # Step 1: Scale to fit in a MAX_SIZE x MAX_SIZE square
        if width > self.MAX_SIZE or height > self.MAX_SIZE:
            scale = self.MAX_SIZE / max(width, height)
            width = int(width * scale)
            height = int(height * scale)

        # Step 2: Scale so the shortest side is HIGH_DETAIL_TARGET_SHORT_SIDE
        scale = self.HIGH_DETAIL_TARGET_SHORT_SIDE / min(width, height)
        scaled_width = int(width * scale)
        scaled_height = int(height * scale)

        # Step 3: Count the number of 512px tiles
        tiles_x = math.ceil(scaled_width / self.TILE_SIZE)
        tiles_y = math.ceil(scaled_height / self.TILE_SIZE)
        total_tiles = tiles_x * tiles_y

        # Step 4: Calculate the final token count
        return (
            total_tiles * self.HIGH_DETAIL_TILE_TOKENS
        ) + self.LOW_DETAIL_IMAGE_TOKENS

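    # Worked example (illustrative): a 1024x1024 image already fits within
    # 2048x2048, so step 1 is a no-op. Step 2 scales the shortest side to 768,
    # giving 768x768. Step 3 counts ceil(768/512) = 2 tiles per side, i.e. 4
    # tiles. Step 4 yields 4 * 170 + 85 = 765 tokens, which is the high-detail
    # default used in count_image above.
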
    def count_content(self, content: Union[str, List[Union[str, dict]]]) -> int:
        """Calculate tokens for message content"""
        if not content:
            return 0

        if isinstance(content, str):
            return self.count_text(content)

        token_count = 0
        for item in content:
            if isinstance(item, str):
                token_count += self.count_text(item)
            elif isinstance(item, dict):
                if "text" in item:
                    token_count += self.count_text(item["text"])
                elif "image_url" in item:
                    token_count += self.count_image(item)
        return token_count

    def count_tool_calls(self, tool_calls: List[dict]) -> int:
        """Calculate tokens for tool calls"""
        token_count = 0
        for tool_call in tool_calls:
            if "function" in tool_call:
                function = tool_call["function"]
                token_count += self.count_text(function.get("name", ""))
                token_count += self.count_text(function.get("arguments", ""))
        return token_count

    def count_message_tokens(self, messages: List[dict]) -> int:
        """Calculate the total number of tokens in a message list"""
        total_tokens = self.FORMAT_TOKENS  # Base format tokens

        for message in messages:
            tokens = self.BASE_MESSAGE_TOKENS  # Base tokens per message

            # Add role tokens
            tokens += self.count_text(message.get("role", ""))

            # Add content tokens
            if "content" in message:
                tokens += self.count_content(message["content"])

            # Add tool call tokens
            if "tool_calls" in message:
                tokens += self.count_tool_calls(message["tool_calls"])

            # Add name and tool_call_id tokens
            tokens += self.count_text(message.get("name", ""))
            tokens += self.count_text(message.get("tool_call_id", ""))

            total_tokens += tokens

        return total_tokens

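# Illustrative standalone usage (counts are estimates and vary by encoding;
# cl100k_base is assumed here, as in LLM's tokenizer fallback below):
#
#   counter = TokenCounter(tiktoken.get_encoding("cl100k_base"))
#   tokens = counter.count_message_tokens(
#       [{"role": "user", "content": "Hello"}]
#   )

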
class LLM:
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

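    # Note: instances are cached per config_name, so LLM() and LLM("default")
    # return the same object, while e.g. LLM("vision") would create and cache a
    # separate instance (assuming a "vision" entry exists in the configuration).
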
    def __init__(
        self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if not hasattr(self, "client"):  # Only initialize if not already initialized
            llm_config = llm_config or config.llm
            llm_config = llm_config.get(config_name, llm_config["default"])
            self.model = llm_config.model
            self.max_tokens = llm_config.max_tokens
            self.temperature = llm_config.temperature
            self.api_type = llm_config.api_type
            self.api_key = llm_config.api_key
            self.api_version = llm_config.api_version
            self.base_url = llm_config.base_url

            # Add token-counting related attributes
            self.total_input_tokens = 0
            self.max_input_tokens = (
                llm_config.max_input_tokens
                if hasattr(llm_config, "max_input_tokens")
                else None
            )

            # Initialize the tokenizer
            try:
                self.tokenizer = tiktoken.encoding_for_model(self.model)
            except KeyError:
                # If the model is not in tiktoken's presets, fall back to cl100k_base
                self.tokenizer = tiktoken.get_encoding("cl100k_base")

            if self.api_type == "azure":
                self.client = AsyncAzureOpenAI(
                    base_url=self.base_url,
                    api_key=self.api_key,
                    api_version=self.api_version,
                )
            else:
                self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

            self.token_counter = TokenCounter(self.tokenizer)

    def count_tokens(self, text: str) -> int:
        """Calculate the number of tokens in a text"""
        if not text:
            return 0
        return len(self.tokenizer.encode(text))

    def count_message_tokens(self, messages: List[dict]) -> int:
        """Calculate the total number of tokens in a message list"""
        return self.token_counter.count_message_tokens(messages)

    def update_token_count(self, input_tokens: int) -> None:
        """Update the cumulative input token count"""
        # Tokens are accumulated unconditionally; the limit itself is enforced
        # separately in check_token_limit
        self.total_input_tokens += input_tokens
        logger.info(
            f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}"
        )

    def check_token_limit(self, input_tokens: int) -> bool:
        """Return True if adding input_tokens stays within the token limit"""
        if self.max_input_tokens is not None:
            return (self.total_input_tokens + input_tokens) <= self.max_input_tokens
        # If max_input_tokens is not set, there is no limit to enforce
        return True

    def get_limit_error_message(self, input_tokens: int) -> str:
        """Generate an error message for exceeding the token limit"""
        if (
            self.max_input_tokens is not None
            and (self.total_input_tokens + input_tokens) > self.max_input_tokens
        ):
            return (
                f"Request may exceed input token limit "
                f"(Current: {self.total_input_tokens}, Needed: {input_tokens}, "
                f"Max: {self.max_input_tokens})"
            )

        return "Token limit exceeded"

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for the LLM by converting them to the OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided

        Examples:
            >>> msgs = [
            ...     Message.system_message("You are a helpful assistant"),
            ...     {"role": "user", "content": "Hello"},
            ...     Message.user_message("How are you?")
            ... ]
            >>> formatted = LLM.format_messages(msgs)
        """
        formatted_messages = []

        for message in messages:
            # Convert Message objects to dictionaries
            if isinstance(message, Message):
                message = message.to_dict()

            if not isinstance(message, dict):
                raise TypeError(f"Unsupported message type: {type(message)}")

            # Validate required fields
            if "role" not in message:
                raise ValueError("Message dict must contain 'role' field")

            # Process base64 images if present
            if message.get("base64_image"):
                # Initialize or convert content to the appropriate format
                if not message.get("content"):
                    message["content"] = []
                elif isinstance(message["content"], str):
                    message["content"] = [{"type": "text", "text": message["content"]}]
                elif isinstance(message["content"], list):
                    # Convert string items to proper text objects
                    message["content"] = [
                        {"type": "text", "text": item} if isinstance(item, str) else item
                        for item in message["content"]
                    ]

                # Add the image to the content
                message["content"].append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{message['base64_image']}"
                        },
                    }
                )

                # Remove the base64_image field
                del message["base64_image"]

            # Only include messages with content or tool_calls
            if "content" in message or "tool_calls" in message:
                formatted_messages.append(message)

        # Validate all roles
        invalid_roles = [
            msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES
        ]
        if invalid_roles:
            raise ValueError(f"Invalid role: {invalid_roles[0]['role']}")

        return formatted_messages

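    # Illustrative example of the base64 path above: a message such as
    #   {"role": "user", "content": "hi", "base64_image": "<b64>"}
    # is expanded by format_messages into multimodal content:
    #   {"role": "user", "content": [
    #       {"type": "text", "text": "hi"},
    #       {"type": "image_url",
    #        "image_url": {"url": "data:image/jpeg;base64,<b64>"}},
    #   ]}
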
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        # Don't retry TokenLimitExceeded; other errors are retried with backoff
        retry=retry_if_not_exception_type(TokenLimitExceeded),
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            TokenLimitExceeded: If token limits are exceeded
            ValueError: If messages are invalid or the response is empty
            OpenAIError: If the API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Format system and user messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Calculate the input token count
            input_tokens = self.count_message_tokens(messages)

            # Check whether the token limit would be exceeded
            if not self.check_token_limit(input_tokens):
                error_message = self.get_limit_error_message(input_tokens)
                # Raise a special exception that won't be retried
                raise TokenLimitExceeded(error_message)

            params = {
                "model": self.model,
                "messages": messages,
            }

            if self.model in REASONING_MODELS:
                params["max_completion_tokens"] = self.max_tokens
            else:
                params["max_tokens"] = self.max_tokens
                params["temperature"] = (
                    temperature if temperature is not None else self.temperature
                )

            if not stream:
                # Non-streaming request
                params["stream"] = False

                response = await self.client.chat.completions.create(**params)

                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")

                # Update token counts
                self.update_token_count(response.usage.prompt_tokens)

                return response.choices[0].message.content

            # Streaming request: update the estimated token count before making
            # the request, since streamed responses don't include usage data
            self.update_token_count(input_tokens)

            params["stream"] = True
            response = await self.client.chat.completions.create(**params)

            collected_messages = []
            async for chunk in response:
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                print(chunk_message, end="", flush=True)

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")

            return full_response

        except TokenLimitExceeded:
            # Re-raise token limit errors without logging
            raise
        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
            raise

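    # Illustrative usage (assumes a running event loop and a valid config):
    #   llm = LLM()
    #   reply = await llm.ask([Message.user_message("Hello")], stream=False)
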
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        # Don't retry TokenLimitExceeded; other errors are retried with backoff
        retry=retry_if_not_exception_type(TokenLimitExceeded),
    )
    async def ask_with_images(
        self,
        messages: List[Union[dict, Message]],
        images: List[Union[str, dict]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = False,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt with images to the LLM and get the response.

        Args:
            messages: List of conversation messages
            images: List of image URLs or image data dictionaries
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            TokenLimitExceeded: If token limits are exceeded
            ValueError: If messages are invalid or the response is empty
            OpenAIError: If the API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Format messages
            formatted_messages = self.format_messages(messages)

            # Ensure the last message is from the user so images can be attached
            if not formatted_messages or formatted_messages[-1]["role"] != "user":
                raise ValueError(
                    "The last message must be from the user to attach images"
                )

            # Process the last user message to include images
            last_message = formatted_messages[-1]

            # Convert content to multimodal format if needed
            content = last_message.get("content")
            if isinstance(content, str):
                multimodal_content = [{"type": "text", "text": content}]
            elif isinstance(content, list):
                multimodal_content = content
            else:
                multimodal_content = []

            # Add images to the content
            for image in images:
                if isinstance(image, str):
                    multimodal_content.append(
                        {"type": "image_url", "image_url": {"url": image}}
                    )
                elif isinstance(image, dict) and "url" in image:
                    multimodal_content.append({"type": "image_url", "image_url": image})
                elif isinstance(image, dict) and "image_url" in image:
                    multimodal_content.append(image)
                else:
                    raise ValueError(f"Unsupported image format: {image}")

            # Update the message with multimodal content
            last_message["content"] = multimodal_content

            # Prepend system messages if provided
            if system_msgs:
                all_messages = self.format_messages(system_msgs) + formatted_messages
            else:
                all_messages = formatted_messages

            # Calculate tokens and check limits
            input_tokens = self.count_message_tokens(all_messages)
            if not self.check_token_limit(input_tokens):
                raise TokenLimitExceeded(self.get_limit_error_message(input_tokens))

            # Set up API parameters
            params = {
                "model": self.model,
                "messages": all_messages,
                "stream": stream,
            }

            # Add model-specific parameters
            if self.model in REASONING_MODELS:
                params["max_completion_tokens"] = self.max_tokens
            else:
                params["max_tokens"] = self.max_tokens
                params["temperature"] = (
                    temperature if temperature is not None else self.temperature
                )

            # Handle a non-streaming request
            if not stream:
                response = await self.client.chat.completions.create(**params)

                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")

                self.update_token_count(response.usage.prompt_tokens)
                return response.choices[0].message.content

            # Handle a streaming request
            self.update_token_count(input_tokens)
            response = await self.client.chat.completions.create(**params)

            collected_messages = []
            async for chunk in response:
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                print(chunk_message, end="", flush=True)

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()

            if not full_response:
                raise ValueError("Empty response from streaming LLM")

            return full_response

        except TokenLimitExceeded:
            raise
        except ValueError as ve:
            logger.error(f"Validation error in ask_with_images: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_with_images: {e}")
            raise

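    # Illustrative call shapes accepted by ask_with_images; the URLs are
    # placeholders:
    #   await llm.ask_with_images(
    #       [Message.user_message("Describe this picture")],
    #       images=[
    #           "https://example.com/cat.png",  # plain URL string
    #           {"url": "data:image/jpeg;base64,...", "detail": "low"},  # url dict
    #       ],
    #   )
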
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        # Don't retry TokenLimitExceeded; other errors are retried with backoff
        retry=retry_if_not_exception_type(TokenLimitExceeded),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 300,
        tools: Optional[List[dict]] = None,
        tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO,  # type: ignore
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask the LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            ChatCompletionMessage: The model's response

        Raises:
            TokenLimitExceeded: If token limits are exceeded
            ValueError: If tools, tool_choice, or messages are invalid
            OpenAIError: If the API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Validate tool_choice
            if tool_choice not in TOOL_CHOICE_VALUES:
                raise ValueError(f"Invalid tool_choice: {tool_choice}")

            # Format messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Calculate the input token count
            input_tokens = self.count_message_tokens(messages)

            # If there are tools, estimate the token count of their descriptions
            # (str(tool) is a rough approximation of the serialized schema)
            tools_tokens = 0
            if tools:
                for tool in tools:
                    tools_tokens += self.count_tokens(str(tool))

            input_tokens += tools_tokens

            # Check whether the token limit would be exceeded
            if not self.check_token_limit(input_tokens):
                error_message = self.get_limit_error_message(input_tokens)
                # Raise a special exception that won't be retried
                raise TokenLimitExceeded(error_message)

            # Validate tools if provided
            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError("Each tool must be a dict with a 'type' field")

            # Set up the completion request
            params = {
                "model": self.model,
                "messages": messages,
                "tools": tools,
                "tool_choice": tool_choice,
                "timeout": timeout,
                **kwargs,
            }

            if self.model in REASONING_MODELS:
                params["max_completion_tokens"] = self.max_tokens
            else:
                params["max_tokens"] = self.max_tokens
                params["temperature"] = (
                    temperature if temperature is not None else self.temperature
                )

            response = await self.client.chat.completions.create(**params)

            # Check whether the response is valid
            if not response.choices or not response.choices[0].message:
                logger.error(f"Invalid response structure: {response}")
                raise ValueError("Invalid or empty response from LLM")

            # Update token counts
            self.update_token_count(response.usage.prompt_tokens)

            return response.choices[0].message

        except TokenLimitExceeded:
            # Re-raise token limit errors without logging
            raise
        except ValueError as ve:
            logger.error(f"Validation error in ask_tool: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise
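
    # Illustrative ask_tool call with a single function tool; the tool name and
    # schema below are hypothetical placeholders:
    #   response = await llm.ask_tool(
    #       [Message.user_message("What's the weather in Paris?")],
    #       tools=[{
    #           "type": "function",
    #           "function": {
    #               "name": "get_weather",
    #               "description": "Look up current weather for a city",
    #               "parameters": {
    #                   "type": "object",
    #                   "properties": {"city": {"type": "string"}},
    #                   "required": ["city"],
    #               },
    #           },
    #       }],
    #       tool_choice=ToolChoice.AUTO,
    #   )
    #   # response.tool_calls, if present, can then be dispatched by name.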