update llm, schema, BaseTool and BaseAgent

This commit is contained in:
liangxinbing 2025-03-18 02:30:30 +08:00
parent 5cf34f82df
commit 91d14a3a47
4 changed files with 85 additions and 27 deletions

View File

@ -84,6 +84,7 @@ class BaseAgent(BaseModel, ABC):
self,
role: ROLE_TYPE, # type: ignore
content: str,
base64_image: Optional[str] = None,
**kwargs,
) -> None:
"""Add a message to the agent's memory.
@ -91,6 +92,7 @@ class BaseAgent(BaseModel, ABC):
Args:
role: The role of the message sender (user, system, assistant, tool).
content: The message content.
base64_image: Optional base64 encoded image.
**kwargs: Additional arguments (e.g., tool_call_id for tool messages).
Raises:
@ -106,9 +108,9 @@ class BaseAgent(BaseModel, ABC):
if role not in message_map:
raise ValueError(f"Unsupported message role: {role}")
msg_factory = message_map[role]
msg = msg_factory(content, **kwargs) if role == "tool" else msg_factory(content)
self.memory.add_message(msg)
# Create message with appropriate parameters based on role
kwargs = {"base64_image": base64_image, **(kwargs if role == "tool" else {})}
self.memory.add_message(message_map[role](content, **kwargs))
async def run(self, request: Optional[str] = None) -> str:
"""Execute the agent's main loop asynchronously.

View File

@ -280,22 +280,58 @@ class LLM:
formatted_messages = []
for message in messages:
# Convert Message objects to dictionaries
if isinstance(message, Message):
message = message.to_dict()
if isinstance(message, dict):
# If message is a dict, ensure it has required fields
if "role" not in message:
raise ValueError("Message dict must contain 'role' field")
if "content" in message or "tool_calls" in message:
formatted_messages.append(message)
# else: do not include the message
else:
if not isinstance(message, dict):
raise TypeError(f"Unsupported message type: {type(message)}")
# Validate all messages have required fields
for msg in formatted_messages:
if msg["role"] not in ROLE_VALUES:
raise ValueError(f"Invalid role: {msg['role']}")
# Validate required fields
if "role" not in message:
raise ValueError("Message dict must contain 'role' field")
# Process base64 images if present
if message.get("base64_image"):
# Initialize or convert content to appropriate format
if not message.get("content"):
message["content"] = []
elif isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
elif isinstance(message["content"], list):
# Convert string items to proper text objects
message["content"] = [
(
{"type": "text", "text": item}
if isinstance(item, str)
else item
)
for item in message["content"]
]
# Add the image to content
message["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{message['base64_image']}"
},
}
)
# Remove the base64_image field
del message["base64_image"]
# Only include messages with content or tool_calls
if "content" in message or "tool_calls" in message:
formatted_messages.append(message)
# Validate all roles
invalid_roles = [
msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES
]
if invalid_roles:
raise ValueError(f"Invalid role: {invalid_roles[0]['role']}")
return formatted_messages

View File

@ -59,6 +59,7 @@ class Message(BaseModel):
tool_calls: Optional[List[ToolCall]] = Field(default=None)
name: Optional[str] = Field(default=None)
tool_call_id: Optional[str] = Field(default=None)
base64_image: Optional[str] = Field(default=None)
def __add__(self, other) -> List["Message"]:
"""支持 Message + list 或 Message + Message 的操作"""
@ -91,12 +92,16 @@ class Message(BaseModel):
message["name"] = self.name
if self.tool_call_id is not None:
message["tool_call_id"] = self.tool_call_id
if self.base64_image is not None:
message["base64_image"] = self.base64_image
return message
@classmethod
def user_message(cls, content: str) -> "Message":
def user_message(
cls, content: str, base64_image: Optional[str] = None
) -> "Message":
"""Create a user message"""
return cls(role=Role.USER, content=content)
return cls(role=Role.USER, content=content, base64_image=base64_image)
@classmethod
def system_message(cls, content: str) -> "Message":
@ -104,33 +109,50 @@ class Message(BaseModel):
return cls(role=Role.SYSTEM, content=content)
@classmethod
def assistant_message(cls, content: Optional[str] = None) -> "Message":
def assistant_message(
cls, content: Optional[str] = None, base64_image: Optional[str] = None
) -> "Message":
"""Create an assistant message"""
return cls(role=Role.ASSISTANT, content=content)
return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)
@classmethod
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
def tool_message(
cls, content: str, name, tool_call_id: str, base64_image: Optional[str] = None
) -> "Message":
"""Create a tool message"""
return cls(
role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id
role=Role.TOOL,
content=content,
name=name,
tool_call_id=tool_call_id,
base64_image=base64_image,
)
@classmethod
def from_tool_calls(
cls, tool_calls: List[Any], content: Union[str, List[str]] = "", **kwargs
cls,
tool_calls: List[Any],
content: Union[str, List[str]] = "",
base64_image: Optional[str] = None,
**kwargs,
):
"""Create ToolCallsMessage from raw tool calls.
Args:
tool_calls: Raw tool calls from LLM
content: Optional message content
base64_image: Optional base64 encoded image
"""
formatted_calls = [
{"id": call.id, "function": call.function.model_dump(), "type": "function"}
for call in tool_calls
]
return cls(
role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs
role=Role.ASSISTANT,
content=content,
tool_calls=formatted_calls,
base64_image=base64_image,
**kwargs,
)

View File

@ -37,6 +37,7 @@ class ToolResult(BaseModel):
output: Any = Field(default=None)
error: Optional[str] = Field(default=None)
base64_image: Optional[str] = Field(default=None)
system: Optional[str] = Field(default=None)
class Config:
@ -58,6 +59,7 @@ class ToolResult(BaseModel):
return ToolResult(
output=combine_fields(self.output, other.output),
error=combine_fields(self.error, other.error),
base64_image=combine_fields(self.base64_image, other.base64_image, False),
system=combine_fields(self.system, other.system),
)
@ -76,7 +78,3 @@ class CLIResult(ToolResult):
class ToolFailure(ToolResult):
"""A ToolResult that represents a failure."""
class AgentAwareTool:
agent: Optional = None