update format_messages

This commit is contained in:
liangxinbing 2025-03-19 13:18:45 +08:00
parent d5a662cbcc
commit 7b38dd7fbc

View File

@ -30,6 +30,14 @@ from app.schema import (
REASONING_MODELS = ["o1", "o3-mini"] REASONING_MODELS = ["o1", "o3-mini"]
MULTIMODAL_MODELS = [
"gpt-4-vision-preview",
"gpt-4o",
"gpt-4o-mini",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
]
class TokenCounter: class TokenCounter:
@ -259,12 +267,15 @@ class LLM:
return "Token limit exceeded" return "Token limit exceeded"
@staticmethod @staticmethod
def format_messages(messages: List[Union[dict, Message]]) -> List[dict]: def format_messages(
messages: List[Union[dict, Message]], supports_images: bool = False
) -> List[dict]:
""" """
Format messages for LLM by converting them to OpenAI message format. Format messages for LLM by converting them to OpenAI message format.
Args: Args:
messages: List of messages that can be either dict or Message objects messages: List of messages that can be either dict or Message objects
supports_images: Flag indicating if the target model supports image inputs
Returns: Returns:
List[dict]: List of formatted messages in OpenAI format List[dict]: List of formatted messages in OpenAI format
@ -288,54 +299,58 @@ class LLM:
if isinstance(message, Message): if isinstance(message, Message):
message = message.to_dict() message = message.to_dict()
if not isinstance(message, dict): if isinstance(message, dict):
# If message is a dict, ensure it has required fields
if "role" not in message:
raise ValueError("Message dict must contain 'role' field")
# Process base64 images if present and model supports images
if supports_images and message.get("base64_image"):
# Initialize or convert content to appropriate format
if not message.get("content"):
message["content"] = []
elif isinstance(message["content"], str):
message["content"] = [
{"type": "text", "text": message["content"]}
]
elif isinstance(message["content"], list):
# Convert string items to proper text objects
message["content"] = [
(
{"type": "text", "text": item}
if isinstance(item, str)
else item
)
for item in message["content"]
]
# Add the image to content
message["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{message['base64_image']}"
},
}
)
# Remove the base64_image field
del message["base64_image"]
# If model doesn't support images but message has base64_image, handle gracefully
elif not supports_images and message.get("base64_image"):
# Just remove the base64_image field and keep the text content
del message["base64_image"]
if "content" in message or "tool_calls" in message:
formatted_messages.append(message)
# else: do not include the message
else:
raise TypeError(f"Unsupported message type: {type(message)}") raise TypeError(f"Unsupported message type: {type(message)}")
# Validate required fields # Validate all messages have required fields
if "role" not in message: for msg in formatted_messages:
raise ValueError("Message dict must contain 'role' field") if msg["role"] not in ROLE_VALUES:
raise ValueError(f"Invalid role: {msg['role']}")
# Process base64 images if present
if message.get("base64_image"):
# Initialize or convert content to appropriate format
if not message.get("content"):
message["content"] = []
elif isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
elif isinstance(message["content"], list):
# Convert string items to proper text objects
message["content"] = [
(
{"type": "text", "text": item}
if isinstance(item, str)
else item
)
for item in message["content"]
]
# Add the image to content
message["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{message['base64_image']}"
},
}
)
# Remove the base64_image field
del message["base64_image"]
# Only include messages with content or tool_calls
if "content" in message or "tool_calls" in message:
formatted_messages.append(message)
# Validate all roles
invalid_roles = [
msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES
]
if invalid_roles:
raise ValueError(f"Invalid role: {invalid_roles[0]['role']}")
return formatted_messages return formatted_messages
@ -372,12 +387,15 @@ class LLM:
Exception: For unexpected errors Exception: For unexpected errors
""" """
try: try:
# Format system and user messages # Check if the model supports images
supports_images = self.model in MULTIMODAL_MODELS
# Format system and user messages with image support check
if system_msgs: if system_msgs:
system_msgs = self.format_messages(system_msgs) system_msgs = self.format_messages(system_msgs, supports_images)
messages = system_msgs + self.format_messages(messages) messages = system_msgs + self.format_messages(messages, supports_images)
else: else:
messages = self.format_messages(messages) messages = self.format_messages(messages, supports_images)
# Calculate input token count # Calculate input token count
input_tokens = self.count_message_tokens(messages) input_tokens = self.count_message_tokens(messages)
@ -499,8 +517,15 @@ class LLM:
Exception: For unexpected errors Exception: For unexpected errors
""" """
try: try:
# Format messages # For ask_with_images, we always set supports_images to True because
formatted_messages = self.format_messages(messages) # this method should only be called with models that support images
if self.model not in MULTIMODAL_MODELS:
raise ValueError(
f"Model {self.model} does not support images. Use a model from {MULTIMODAL_MODELS}"
)
# Format messages with image support
formatted_messages = self.format_messages(messages, supports_images=True)
# Ensure the last message is from the user to attach images # Ensure the last message is from the user to attach images
if not formatted_messages or formatted_messages[-1]["role"] != "user": if not formatted_messages or formatted_messages[-1]["role"] != "user":
@ -539,7 +564,10 @@ class LLM:
# Add system messages if provided # Add system messages if provided
if system_msgs: if system_msgs:
all_messages = self.format_messages(system_msgs) + formatted_messages all_messages = (
self.format_messages(system_msgs, supports_images=True)
+ formatted_messages
)
else: else:
all_messages = formatted_messages all_messages = formatted_messages
@ -653,12 +681,15 @@ class LLM:
if tool_choice not in TOOL_CHOICE_VALUES: if tool_choice not in TOOL_CHOICE_VALUES:
raise ValueError(f"Invalid tool_choice: {tool_choice}") raise ValueError(f"Invalid tool_choice: {tool_choice}")
# Check if the model supports images
supports_images = self.model in MULTIMODAL_MODELS
# Format messages # Format messages
if system_msgs: if system_msgs:
system_msgs = self.format_messages(system_msgs) system_msgs = self.format_messages(system_msgs, supports_images)
messages = system_msgs + self.format_messages(messages) messages = system_msgs + self.format_messages(messages, supports_images)
else: else:
messages = self.format_messages(messages) messages = self.format_messages(messages, supports_images)
# Calculate input token count # Calculate input token count
input_tokens = self.count_message_tokens(messages) input_tokens = self.count_message_tokens(messages)