From 91d14a3a473f9e419caef879d3cf267d2501fbd5 Mon Sep 17 00:00:00 2001 From: liangxinbing <1580466765@qq.com> Date: Tue, 18 Mar 2025 02:30:30 +0800 Subject: [PATCH] update llm, schema, BaseTool and BaseAgent --- app/agent/base.py | 8 ++++--- app/llm.py | 60 +++++++++++++++++++++++++++++++++++++---------- app/schema.py | 38 +++++++++++++++++++++++------- app/tool/base.py | 6 ++--- 4 files changed, 85 insertions(+), 27 deletions(-) diff --git a/app/agent/base.py b/app/agent/base.py index fa3db30..9ece688 100644 --- a/app/agent/base.py +++ b/app/agent/base.py @@ -84,6 +84,7 @@ class BaseAgent(BaseModel, ABC): self, role: ROLE_TYPE, # type: ignore content: str, + base64_image: Optional[str] = None, **kwargs, ) -> None: """Add a message to the agent's memory. @@ -91,6 +92,7 @@ class BaseAgent(BaseModel, ABC): Args: role: The role of the message sender (user, system, assistant, tool). content: The message content. + base64_image: Optional base64 encoded image. **kwargs: Additional arguments (e.g., tool_call_id for tool messages). Raises: @@ -106,9 +108,9 @@ class BaseAgent(BaseModel, ABC): if role not in message_map: raise ValueError(f"Unsupported message role: {role}") - msg_factory = message_map[role] - msg = msg_factory(content, **kwargs) if role == "tool" else msg_factory(content) - self.memory.add_message(msg) + # Create message with appropriate parameters based on role + kwargs = {"base64_image": base64_image, **(kwargs if role == "tool" else {})} + self.memory.add_message(message_map[role](content, **kwargs)) async def run(self, request: Optional[str] = None) -> str: """Execute the agent's main loop asynchronously. diff --git a/app/llm.py b/app/llm.py index e354125..eccedf1 100644 --- a/app/llm.py +++ b/app/llm.py @@ -280,22 +280,58 @@ class LLM: formatted_messages = [] for message in messages: + # Convert Message objects to dictionaries if isinstance(message, Message): message = message.to_dict() - if isinstance(message, dict): - # If message is a dict, ensure it has required fields - if "role" not in message: - raise ValueError("Message dict must contain 'role' field") - if "content" in message or "tool_calls" in message: - formatted_messages.append(message) - # else: do not include the message - else: + + if not isinstance(message, dict): raise TypeError(f"Unsupported message type: {type(message)}") - # Validate all messages have required fields - for msg in formatted_messages: - if msg["role"] not in ROLE_VALUES: - raise ValueError(f"Invalid role: {msg['role']}") + # Validate required fields + if "role" not in message: + raise ValueError("Message dict must contain 'role' field") + + # Process base64 images if present + if message.get("base64_image"): + # Initialize or convert content to appropriate format + if not message.get("content"): + message["content"] = [] + elif isinstance(message["content"], str): + message["content"] = [{"type": "text", "text": message["content"]}] + elif isinstance(message["content"], list): + # Convert string items to proper text objects + message["content"] = [ + ( + {"type": "text", "text": item} + if isinstance(item, str) + else item + ) + for item in message["content"] + ] + + # Add the image to content + message["content"].append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{message['base64_image']}" + }, + } + ) + + # Remove the base64_image field + del message["base64_image"] + + # Only include messages with content or tool_calls + if "content" in message or "tool_calls" in message: + formatted_messages.append(message) + + # Validate all roles + invalid_roles = [ + msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES + ] + if invalid_roles: + raise ValueError(f"Invalid role: {invalid_roles[0]['role']}") return formatted_messages diff --git a/app/schema.py b/app/schema.py index fb89c3c..de18c4f 100644 --- a/app/schema.py +++ b/app/schema.py @@ -59,6 +59,7 @@ class Message(BaseModel): tool_calls: Optional[List[ToolCall]] = Field(default=None) name: Optional[str] = Field(default=None) tool_call_id: Optional[str] = Field(default=None) + base64_image: Optional[str] = Field(default=None) def __add__(self, other) -> List["Message"]: """支持 Message + list 或 Message + Message 的操作""" @@ -91,12 +92,16 @@ class Message(BaseModel): message["name"] = self.name if self.tool_call_id is not None: message["tool_call_id"] = self.tool_call_id + if self.base64_image is not None: + message["base64_image"] = self.base64_image return message @classmethod - def user_message(cls, content: str) -> "Message": + def user_message( + cls, content: str, base64_image: Optional[str] = None + ) -> "Message": """Create a user message""" - return cls(role=Role.USER, content=content) + return cls(role=Role.USER, content=content, base64_image=base64_image) @classmethod def system_message(cls, content: str) -> "Message": @@ -104,33 +109,50 @@ class Message(BaseModel): return cls(role=Role.SYSTEM, content=content) @classmethod - def assistant_message(cls, content: Optional[str] = None) -> "Message": + def assistant_message( + cls, content: Optional[str] = None, base64_image: Optional[str] = None + ) -> "Message": """Create an assistant message""" - return cls(role=Role.ASSISTANT, content=content) + return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image) @classmethod - def tool_message(cls, content: str, name, tool_call_id: str) -> "Message": + def tool_message( + cls, content: str, name, tool_call_id: str, base64_image: Optional[str] = None + ) -> "Message": """Create a tool message""" return cls( - role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id + role=Role.TOOL, + content=content, + name=name, + tool_call_id=tool_call_id, + base64_image=base64_image, ) @classmethod def from_tool_calls( - cls, tool_calls: List[Any], content: Union[str, List[str]] = "", **kwargs + cls, + tool_calls: List[Any], + content: Union[str, List[str]] = "", + base64_image: Optional[str] = None, + **kwargs, ): """Create ToolCallsMessage from raw tool calls. Args: tool_calls: Raw tool calls from LLM content: Optional message content + base64_image: Optional base64 encoded image """ formatted_calls = [ {"id": call.id, "function": call.function.model_dump(), "type": "function"} for call in tool_calls ] return cls( - role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs + role=Role.ASSISTANT, + content=content, + tool_calls=formatted_calls, + base64_image=base64_image, + **kwargs, ) diff --git a/app/tool/base.py b/app/tool/base.py index ae3c9f5..ba4084d 100644 --- a/app/tool/base.py +++ b/app/tool/base.py @@ -37,6 +37,7 @@ class ToolResult(BaseModel): output: Any = Field(default=None) error: Optional[str] = Field(default=None) + base64_image: Optional[str] = Field(default=None) system: Optional[str] = Field(default=None) class Config: @@ -58,6 +59,7 @@ class ToolResult(BaseModel): return ToolResult( output=combine_fields(self.output, other.output), error=combine_fields(self.error, other.error), + base64_image=combine_fields(self.base64_image, other.base64_image, False), system=combine_fields(self.system, other.system), ) @@ -76,7 +78,3 @@ class CLIResult(ToolResult): class ToolFailure(ToolResult): """A ToolResult that represents a failure.""" - - -class AgentAwareTool: - agent: Optional = None