update format_messages
commit 7b38dd7fbc
parent d5a662cbcc

app/llm.py | 81
@@ -30,6 +30,14 @@ from app.schema import (
 
 
 REASONING_MODELS = ["o1", "o3-mini"]
+MULTIMODAL_MODELS = [
+    "gpt-4-vision-preview",
+    "gpt-4o",
+    "gpt-4o-mini",
+    "claude-3-opus-20240229",
+    "claude-3-sonnet-20240229",
+    "claude-3-haiku-20240307",
+]
 
 
 class TokenCounter:
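Note: the new MULTIMODAL_MODELS list is what the rest of this commit keys off — a model is treated as vision-capable only if its name appears here. A minimal sketch of the membership check the later hunks introduce (the llm variable is illustrative, not part of the diff):

    supports_images = llm.model in MULTIMODAL_MODELS  # "gpt-4o" -> True, "o3-mini" -> False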
@@ -259,12 +267,15 @@ class LLM:
         return "Token limit exceeded"
 
     @staticmethod
-    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
+    def format_messages(
+        messages: List[Union[dict, Message]], supports_images: bool = False
+    ) -> List[dict]:
         """
         Format messages for LLM by converting them to OpenAI message format.
 
         Args:
             messages: List of messages that can be either dict or Message objects
+            supports_images: Flag indicating if the target model supports image inputs
 
         Returns:
             List[dict]: List of formatted messages in OpenAI format
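Note: format_messages stays a @staticmethod, so the new flag can be exercised directly. A hedged usage sketch (the message text and the truncated base64 value are illustrative):

    msg = {"role": "user", "content": "Describe this image", "base64_image": "<base64 data>"}

    # Vision-capable model: the image is folded into the content (see the next hunks)
    formatted = LLM.format_messages([msg], supports_images=True)

    # Text-only model: base64_image is dropped and the plain text content is kept
    formatted = LLM.format_messages([msg], supports_images=False)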
@@ -288,20 +299,20 @@ class LLM:
             if isinstance(message, Message):
                 message = message.to_dict()
 
-            if not isinstance(message, dict):
-                raise TypeError(f"Unsupported message type: {type(message)}")
-
-            # Validate required fields
             if isinstance(message, dict):
                 # If message is a dict, ensure it has required fields
                 if "role" not in message:
                     raise ValueError("Message dict must contain 'role' field")
 
-                # Process base64 images if present
-                if message.get("base64_image"):
+                # Process base64 images if present and model supports images
+                if supports_images and message.get("base64_image"):
                     # Initialize or convert content to appropriate format
                     if not message.get("content"):
                         message["content"] = []
                     elif isinstance(message["content"], str):
-                        message["content"] = [{"type": "text", "text": message["content"]}]
+                        message["content"] = [
+                            {"type": "text", "text": message["content"]}
+                        ]
                     elif isinstance(message["content"], list):
                         # Convert string items to proper text objects
                         message["content"] = [
@@ -325,17 +336,21 @@ class LLM:
 
                     # Remove the base64_image field
                     del message["base64_image"]
+                # If model doesn't support images but message has base64_image, handle gracefully
+                elif not supports_images and message.get("base64_image"):
+                    # Just remove the base64_image field and keep the text content
+                    del message["base64_image"]
 
                 # Only include messages with content or tool_calls
                 if "content" in message or "tool_calls" in message:
                     formatted_messages.append(message)
-                # else: do not include the message
+            else:
+                raise TypeError(f"Unsupported message type: {type(message)}")
 
-        # Validate all roles
-        invalid_roles = [
-            msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES
-        ]
-        if invalid_roles:
-            raise ValueError(f"Invalid role: {invalid_roles[0]['role']}")
+        # Validate all messages have required fields
+        for msg in formatted_messages:
+            if msg["role"] not in ROLE_VALUES:
+                raise ValueError(f"Invalid role: {msg['role']}")
 
         return formatted_messages
 
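Note: for a vision-capable model, the two hunks above turn a message carrying base64_image into OpenAI's multi-part content shape. A sketch of the expected result (the image_url part is produced by the unchanged append code between the hunks, so its exact form is an assumption here; values are illustrative):

    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image"},
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<base64 data>"}},
        ],
    }

With supports_images=False the same input comes out as {"role": "user", "content": "Describe this image"} with the base64_image field removed.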
@@ -372,12 +387,15 @@ class LLM:
             Exception: For unexpected errors
         """
         try:
-            # Format system and user messages
+            # Check if the model supports images
+            supports_images = self.model in MULTIMODAL_MODELS
+
+            # Format system and user messages with image support check
             if system_msgs:
-                system_msgs = self.format_messages(system_msgs)
-                messages = system_msgs + self.format_messages(messages)
+                system_msgs = self.format_messages(system_msgs, supports_images)
+                messages = system_msgs + self.format_messages(messages, supports_images)
             else:
-                messages = self.format_messages(messages)
+                messages = self.format_messages(messages, supports_images)
 
             # Calculate input token count
             input_tokens = self.count_message_tokens(messages)
@@ -499,8 +517,15 @@
             Exception: For unexpected errors
         """
         try:
-            # Format messages
-            formatted_messages = self.format_messages(messages)
+            # For ask_with_images, we always set supports_images to True because
+            # this method should only be called with models that support images
+            if self.model not in MULTIMODAL_MODELS:
+                raise ValueError(
+                    f"Model {self.model} does not support images. Use a model from {MULTIMODAL_MODELS}"
+                )
+
+            # Format messages with image support
+            formatted_messages = self.format_messages(messages, supports_images=True)
 
             # Ensure the last message is from the user to attach images
             if not formatted_messages or formatted_messages[-1]["role"] != "user":
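Note: ask_with_images now fails fast with a ValueError instead of silently formatting the images away. A hedged caller-side sketch (the ask_with_images/ask call signatures are assumed from context and are not shown in this diff):

    try:
        reply = await llm.ask_with_images(messages, images)
    except ValueError:
        # Configured model is not in MULTIMODAL_MODELS; fall back to a text-only request
        reply = await llm.ask(messages)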
@@ -539,7 +564,10 @@
 
             # Add system messages if provided
             if system_msgs:
-                all_messages = self.format_messages(system_msgs) + formatted_messages
+                all_messages = (
+                    self.format_messages(system_msgs, supports_images=True)
+                    + formatted_messages
+                )
             else:
                 all_messages = formatted_messages
 
@@ -653,12 +681,15 @@
             if tool_choice not in TOOL_CHOICE_VALUES:
                 raise ValueError(f"Invalid tool_choice: {tool_choice}")
 
+            # Check if the model supports images
+            supports_images = self.model in MULTIMODAL_MODELS
+
             # Format messages
             if system_msgs:
-                system_msgs = self.format_messages(system_msgs)
-                messages = system_msgs + self.format_messages(messages)
+                system_msgs = self.format_messages(system_msgs, supports_images)
+                messages = system_msgs + self.format_messages(messages, supports_images)
             else:
-                messages = self.format_messages(messages)
+                messages = self.format_messages(messages, supports_images)
 
             # Calculate input token count
             input_tokens = self.count_message_tokens(messages)