update format_messages
This commit is contained in:
parent
d5a662cbcc
commit
7b38dd7fbc
81
app/llm.py
81
app/llm.py
@ -30,6 +30,14 @@ from app.schema import (
|
|||||||
|
|
||||||
|
|
||||||
REASONING_MODELS = ["o1", "o3-mini"]
|
REASONING_MODELS = ["o1", "o3-mini"]
|
||||||
|
MULTIMODAL_MODELS = [
|
||||||
|
"gpt-4-vision-preview",
|
||||||
|
"gpt-4o",
|
||||||
|
"gpt-4o-mini",
|
||||||
|
"claude-3-opus-20240229",
|
||||||
|
"claude-3-sonnet-20240229",
|
||||||
|
"claude-3-haiku-20240307",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class TokenCounter:
|
class TokenCounter:
|
||||||
@ -259,12 +267,15 @@ class LLM:
|
|||||||
return "Token limit exceeded"
|
return "Token limit exceeded"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
|
def format_messages(
|
||||||
|
messages: List[Union[dict, Message]], supports_images: bool = False
|
||||||
|
) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Format messages for LLM by converting them to OpenAI message format.
|
Format messages for LLM by converting them to OpenAI message format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of messages that can be either dict or Message objects
|
messages: List of messages that can be either dict or Message objects
|
||||||
|
supports_images: Flag indicating if the target model supports image inputs
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[dict]: List of formatted messages in OpenAI format
|
List[dict]: List of formatted messages in OpenAI format
|
||||||
@ -288,20 +299,20 @@ class LLM:
|
|||||||
if isinstance(message, Message):
|
if isinstance(message, Message):
|
||||||
message = message.to_dict()
|
message = message.to_dict()
|
||||||
|
|
||||||
if not isinstance(message, dict):
|
if isinstance(message, dict):
|
||||||
raise TypeError(f"Unsupported message type: {type(message)}")
|
# If message is a dict, ensure it has required fields
|
||||||
|
|
||||||
# Validate required fields
|
|
||||||
if "role" not in message:
|
if "role" not in message:
|
||||||
raise ValueError("Message dict must contain 'role' field")
|
raise ValueError("Message dict must contain 'role' field")
|
||||||
|
|
||||||
# Process base64 images if present
|
# Process base64 images if present and model supports images
|
||||||
if message.get("base64_image"):
|
if supports_images and message.get("base64_image"):
|
||||||
# Initialize or convert content to appropriate format
|
# Initialize or convert content to appropriate format
|
||||||
if not message.get("content"):
|
if not message.get("content"):
|
||||||
message["content"] = []
|
message["content"] = []
|
||||||
elif isinstance(message["content"], str):
|
elif isinstance(message["content"], str):
|
||||||
message["content"] = [{"type": "text", "text": message["content"]}]
|
message["content"] = [
|
||||||
|
{"type": "text", "text": message["content"]}
|
||||||
|
]
|
||||||
elif isinstance(message["content"], list):
|
elif isinstance(message["content"], list):
|
||||||
# Convert string items to proper text objects
|
# Convert string items to proper text objects
|
||||||
message["content"] = [
|
message["content"] = [
|
||||||
@ -325,17 +336,21 @@ class LLM:
|
|||||||
|
|
||||||
# Remove the base64_image field
|
# Remove the base64_image field
|
||||||
del message["base64_image"]
|
del message["base64_image"]
|
||||||
|
# If model doesn't support images but message has base64_image, handle gracefully
|
||||||
|
elif not supports_images and message.get("base64_image"):
|
||||||
|
# Just remove the base64_image field and keep the text content
|
||||||
|
del message["base64_image"]
|
||||||
|
|
||||||
# Only include messages with content or tool_calls
|
|
||||||
if "content" in message or "tool_calls" in message:
|
if "content" in message or "tool_calls" in message:
|
||||||
formatted_messages.append(message)
|
formatted_messages.append(message)
|
||||||
|
# else: do not include the message
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported message type: {type(message)}")
|
||||||
|
|
||||||
# Validate all roles
|
# Validate all messages have required fields
|
||||||
invalid_roles = [
|
for msg in formatted_messages:
|
||||||
msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES
|
if msg["role"] not in ROLE_VALUES:
|
||||||
]
|
raise ValueError(f"Invalid role: {msg['role']}")
|
||||||
if invalid_roles:
|
|
||||||
raise ValueError(f"Invalid role: {invalid_roles[0]['role']}")
|
|
||||||
|
|
||||||
return formatted_messages
|
return formatted_messages
|
||||||
|
|
||||||
@ -372,12 +387,15 @@ class LLM:
|
|||||||
Exception: For unexpected errors
|
Exception: For unexpected errors
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Format system and user messages
|
# Check if the model supports images
|
||||||
|
supports_images = self.model in MULTIMODAL_MODELS
|
||||||
|
|
||||||
|
# Format system and user messages with image support check
|
||||||
if system_msgs:
|
if system_msgs:
|
||||||
system_msgs = self.format_messages(system_msgs)
|
system_msgs = self.format_messages(system_msgs, supports_images)
|
||||||
messages = system_msgs + self.format_messages(messages)
|
messages = system_msgs + self.format_messages(messages, supports_images)
|
||||||
else:
|
else:
|
||||||
messages = self.format_messages(messages)
|
messages = self.format_messages(messages, supports_images)
|
||||||
|
|
||||||
# Calculate input token count
|
# Calculate input token count
|
||||||
input_tokens = self.count_message_tokens(messages)
|
input_tokens = self.count_message_tokens(messages)
|
||||||
@ -499,8 +517,15 @@ class LLM:
|
|||||||
Exception: For unexpected errors
|
Exception: For unexpected errors
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Format messages
|
# For ask_with_images, we always set supports_images to True because
|
||||||
formatted_messages = self.format_messages(messages)
|
# this method should only be called with models that support images
|
||||||
|
if self.model not in MULTIMODAL_MODELS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model {self.model} does not support images. Use a model from {MULTIMODAL_MODELS}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Format messages with image support
|
||||||
|
formatted_messages = self.format_messages(messages, supports_images=True)
|
||||||
|
|
||||||
# Ensure the last message is from the user to attach images
|
# Ensure the last message is from the user to attach images
|
||||||
if not formatted_messages or formatted_messages[-1]["role"] != "user":
|
if not formatted_messages or formatted_messages[-1]["role"] != "user":
|
||||||
@ -539,7 +564,10 @@ class LLM:
|
|||||||
|
|
||||||
# Add system messages if provided
|
# Add system messages if provided
|
||||||
if system_msgs:
|
if system_msgs:
|
||||||
all_messages = self.format_messages(system_msgs) + formatted_messages
|
all_messages = (
|
||||||
|
self.format_messages(system_msgs, supports_images=True)
|
||||||
|
+ formatted_messages
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
all_messages = formatted_messages
|
all_messages = formatted_messages
|
||||||
|
|
||||||
@ -653,12 +681,15 @@ class LLM:
|
|||||||
if tool_choice not in TOOL_CHOICE_VALUES:
|
if tool_choice not in TOOL_CHOICE_VALUES:
|
||||||
raise ValueError(f"Invalid tool_choice: {tool_choice}")
|
raise ValueError(f"Invalid tool_choice: {tool_choice}")
|
||||||
|
|
||||||
|
# Check if the model supports images
|
||||||
|
supports_images = self.model in MULTIMODAL_MODELS
|
||||||
|
|
||||||
# Format messages
|
# Format messages
|
||||||
if system_msgs:
|
if system_msgs:
|
||||||
system_msgs = self.format_messages(system_msgs)
|
system_msgs = self.format_messages(system_msgs, supports_images)
|
||||||
messages = system_msgs + self.format_messages(messages)
|
messages = system_msgs + self.format_messages(messages, supports_images)
|
||||||
else:
|
else:
|
||||||
messages = self.format_messages(messages)
|
messages = self.format_messages(messages, supports_images)
|
||||||
|
|
||||||
# Calculate input token count
|
# Calculate input token count
|
||||||
input_tokens = self.count_message_tokens(messages)
|
input_tokens = self.count_message_tokens(messages)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user