From 7b38dd7fbc3a6197e88e0fa98a16efe30125bf36 Mon Sep 17 00:00:00 2001 From: liangxinbing <1580466765@qq.com> Date: Wed, 19 Mar 2025 13:18:45 +0800 Subject: [PATCH] update format_messages --- app/llm.py | 145 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 88 insertions(+), 57 deletions(-) diff --git a/app/llm.py b/app/llm.py index 37de566..a67a86f 100644 --- a/app/llm.py +++ b/app/llm.py @@ -30,6 +30,14 @@ from app.schema import ( REASONING_MODELS = ["o1", "o3-mini"] +MULTIMODAL_MODELS = [ + "gpt-4-vision-preview", + "gpt-4o", + "gpt-4o-mini", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", +] class TokenCounter: @@ -259,12 +267,15 @@ class LLM: return "Token limit exceeded" @staticmethod - def format_messages(messages: List[Union[dict, Message]]) -> List[dict]: + def format_messages( + messages: List[Union[dict, Message]], supports_images: bool = False + ) -> List[dict]: """ Format messages for LLM by converting them to OpenAI message format. Args: messages: List of messages that can be either dict or Message objects + supports_images: Flag indicating if the target model supports image inputs Returns: List[dict]: List of formatted messages in OpenAI format @@ -288,54 +299,58 @@ class LLM: if isinstance(message, Message): message = message.to_dict() - if not isinstance(message, dict): + if isinstance(message, dict): + # If message is a dict, ensure it has required fields + if "role" not in message: + raise ValueError("Message dict must contain 'role' field") + + # Process base64 images if present and model supports images + if supports_images and message.get("base64_image"): + # Initialize or convert content to appropriate format + if not message.get("content"): + message["content"] = [] + elif isinstance(message["content"], str): + message["content"] = [ + {"type": "text", "text": message["content"]} + ] + elif isinstance(message["content"], list): + # Convert string items to proper text objects + message["content"] = [ + ( + {"type": "text", "text": item} + if isinstance(item, str) + else item + ) + for item in message["content"] + ] + + # Add the image to content + message["content"].append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{message['base64_image']}" + }, + } + ) + + # Remove the base64_image field + del message["base64_image"] + # If model doesn't support images but message has base64_image, handle gracefully + elif not supports_images and message.get("base64_image"): + # Just remove the base64_image field and keep the text content + del message["base64_image"] + + if "content" in message or "tool_calls" in message: + formatted_messages.append(message) + # else: do not include the message + else: raise TypeError(f"Unsupported message type: {type(message)}") - # Validate required fields - if "role" not in message: - raise ValueError("Message dict must contain 'role' field") - - # Process base64 images if present - if message.get("base64_image"): - # Initialize or convert content to appropriate format - if not message.get("content"): - message["content"] = [] - elif isinstance(message["content"], str): - message["content"] = [{"type": "text", "text": message["content"]}] - elif isinstance(message["content"], list): - # Convert string items to proper text objects - message["content"] = [ - ( - {"type": "text", "text": item} - if isinstance(item, str) - else item - ) - for item in message["content"] - ] - - # Add the image to content - message["content"].append( - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{message['base64_image']}" - }, - } - ) - - # Remove the base64_image field - del message["base64_image"] - - # Only include messages with content or tool_calls - if "content" in message or "tool_calls" in message: - formatted_messages.append(message) - - # Validate all roles - invalid_roles = [ - msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES - ] - if invalid_roles: - raise ValueError(f"Invalid role: {invalid_roles[0]['role']}") + # Validate all messages have required fields + for msg in formatted_messages: + if msg["role"] not in ROLE_VALUES: + raise ValueError(f"Invalid role: {msg['role']}") return formatted_messages @@ -372,12 +387,15 @@ class LLM: Exception: For unexpected errors """ try: - # Format system and user messages + # Check if the model supports images + supports_images = self.model in MULTIMODAL_MODELS + + # Format system and user messages with image support check if system_msgs: - system_msgs = self.format_messages(system_msgs) - messages = system_msgs + self.format_messages(messages) + system_msgs = self.format_messages(system_msgs, supports_images) + messages = system_msgs + self.format_messages(messages, supports_images) else: - messages = self.format_messages(messages) + messages = self.format_messages(messages, supports_images) # Calculate input token count input_tokens = self.count_message_tokens(messages) @@ -499,8 +517,15 @@ class LLM: Exception: For unexpected errors """ try: - # Format messages - formatted_messages = self.format_messages(messages) + # For ask_with_images, we always set supports_images to True because + # this method should only be called with models that support images + if self.model not in MULTIMODAL_MODELS: + raise ValueError( + f"Model {self.model} does not support images. Use a model from {MULTIMODAL_MODELS}" + ) + + # Format messages with image support + formatted_messages = self.format_messages(messages, supports_images=True) # Ensure the last message is from the user to attach images if not formatted_messages or formatted_messages[-1]["role"] != "user": @@ -539,7 +564,10 @@ class LLM: # Add system messages if provided if system_msgs: - all_messages = self.format_messages(system_msgs) + formatted_messages + all_messages = ( + self.format_messages(system_msgs, supports_images=True) + + formatted_messages + ) else: all_messages = formatted_messages @@ -653,12 +681,15 @@ class LLM: if tool_choice not in TOOL_CHOICE_VALUES: raise ValueError(f"Invalid tool_choice: {tool_choice}") + # Check if the model supports images + supports_images = self.model in MULTIMODAL_MODELS + # Format messages if system_msgs: - system_msgs = self.format_messages(system_msgs) - messages = system_msgs + self.format_messages(messages) + system_msgs = self.format_messages(system_msgs, supports_images) + messages = system_msgs + self.format_messages(messages, supports_images) else: - messages = self.format_messages(messages) + messages = self.format_messages(messages, supports_images) # Calculate input token count input_tokens = self.count_message_tokens(messages)