"""Schema types for agent chat messages, tool calls, and conversation memory."""

from enum import Enum
from typing import Any, List, Literal, Optional, Union

from pydantic import BaseModel, Field


class Role(str, Enum):
    """Message role options"""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


ROLE_VALUES = tuple(role.value for role in Role)
ROLE_TYPE = Literal[ROLE_VALUES]  # type: ignore


class ToolChoice(str, Enum):
    """Tool choice options"""

    NONE = "none"
    AUTO = "auto"
    REQUIRED = "required"


TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice)
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES]  # type: ignore


class AgentState(str, Enum):
    """Agent execution states"""

    IDLE = "IDLE"
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"
    ERROR = "ERROR"


class Function(BaseModel):
    """Name and (string-encoded) arguments of a called function."""

    name: str
    arguments: str


class ToolCall(BaseModel):
    """Represents a tool/function call in a message"""

    id: str
    type: str = "function"
    function: Function


class Message(BaseModel):
    """Represents a chat message in the conversation"""

    role: ROLE_TYPE = Field(...)  # type: ignore
    content: Optional[str] = Field(default=None)
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    name: Optional[str] = Field(default=None)
    tool_call_id: Optional[str] = Field(default=None)
    base64_image: Optional[str] = Field(default=None)

    def __add__(self, other) -> List["Message"]:
        """Support ``Message + list`` and ``Message + Message``.

        Returns a new list; neither operand is mutated.

        Raises:
            TypeError: if ``other`` is neither a list nor a Message.
        """
        if isinstance(other, list):
            return [self] + other
        elif isinstance(other, Message):
            return [self, other]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
            )

    def __radd__(self, other) -> List["Message"]:
        """Support ``list + Message``, returning a new list.

        Raises:
            TypeError: if ``other`` is not a list.
        """
        if isinstance(other, list):
            return other + [self]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'"
            )

    def to_dict(self) -> dict:
        """Convert message to dictionary format.

        ``role`` is always present; every other field is included only when
        it is not None.
        """
        message = {"role": self.role}
        if self.content is not None:
            message["content"] = self.content
        if self.tool_calls is not None:
            # model_dump() is the pydantic v2 API (``.dict()`` is deprecated);
            # from_tool_calls below already relies on model_dump().
            message["tool_calls"] = [
                tool_call.model_dump() for tool_call in self.tool_calls
            ]
        if self.name is not None:
            message["name"] = self.name
        if self.tool_call_id is not None:
            message["tool_call_id"] = self.tool_call_id
        if self.base64_image is not None:
            message["base64_image"] = self.base64_image
        return message

    @classmethod
    def user_message(
        cls, content: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a user message"""
        return cls(role=Role.USER, content=content, base64_image=base64_image)

    @classmethod
    def system_message(cls, content: str) -> "Message":
        """Create a system message"""
        return cls(role=Role.SYSTEM, content=content)

    @classmethod
    def assistant_message(
        cls, content: Optional[str] = None, base64_image: Optional[str] = None
    ) -> "Message":
        """Create an assistant message"""
        return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)

    @classmethod
    def tool_message(
        cls,
        content: str,
        name: Optional[str],
        tool_call_id: str,
        base64_image: Optional[str] = None,
    ) -> "Message":
        """Create a tool message.

        Args:
            content: Tool output text.
            name: Name stored on the message's ``name`` field.
            tool_call_id: Id of the tool call this message responds to.
            base64_image: Optional base64 encoded image.
        """
        return cls(
            role=Role.TOOL,
            content=content,
            name=name,
            tool_call_id=tool_call_id,
            base64_image=base64_image,
        )

    @classmethod
    def from_tool_calls(
        cls,
        tool_calls: List[Any],
        content: Union[str, List[str]] = "",
        base64_image: Optional[str] = None,
        **kwargs,
    ):
        """Create an assistant message carrying the given raw tool calls.

        Args:
            tool_calls: Raw tool calls from the LLM. Each item must expose
                ``.id`` and a ``.function`` object with ``model_dump()``.
            content: Optional message content.
                NOTE(review): annotated as also accepting ``List[str]``, but
                ``Message.content`` is ``Optional[str]``, so a list will likely
                fail validation — confirm intended behavior.
            base64_image: Optional base64 encoded image.
            **kwargs: Extra fields forwarded to the Message constructor.
        """
        formatted_calls = [
            {"id": call.id, "function": call.function.model_dump(), "type": "function"}
            for call in tool_calls
        ]
        return cls(
            role=Role.ASSISTANT,
            content=content,
            tool_calls=formatted_calls,
            base64_image=base64_image,
            **kwargs,
        )


class Memory(BaseModel):
    """Bounded, ordered conversation history (oldest first)."""

    messages: List[Message] = Field(default_factory=list)
    # Maximum number of messages retained; the oldest are dropped beyond it.
    max_messages: int = Field(default=100)

    def add_message(self, message: Message) -> None:
        """Add a message to memory, enforcing the ``max_messages`` limit."""
        self.add_messages([message])

    def add_messages(self, messages: List[Message]) -> None:
        """Add multiple messages to memory, enforcing ``max_messages``.

        Only the newest ``max_messages`` entries are kept; the oldest ones are
        discarded. Previously this method bypassed the limit entirely.
        """
        self.messages.extend(messages)
        excess = len(self.messages) - self.max_messages
        if excess > 0:
            # Trim from the front in place. Slicing with ``[-max_messages:]``
            # would keep *everything* when max_messages is 0, since -0 == 0.
            del self.messages[:excess]

    def clear(self) -> None:
        """Clear all messages"""
        self.messages.clear()

    def get_recent_messages(self, n: int) -> List[Message]:
        """Get the ``n`` most recent messages (oldest first).

        Returns a new list. ``n <= 0`` yields an empty list; a plain
        ``messages[-n:]`` would return *all* messages for ``n == 0``.
        """
        if n <= 0:
            return []
        return self.messages[-n:]

    def to_dict_list(self) -> List[dict]:
        """Convert messages to list of dicts"""
        return [msg.to_dict() for msg in self.messages]