update llm, schema, BaseTool and BaseAgent
This commit is contained in:
parent
5cf34f82df
commit
91d14a3a47
@ -84,6 +84,7 @@ class BaseAgent(BaseModel, ABC):
|
||||
self,
|
||||
role: ROLE_TYPE, # type: ignore
|
||||
content: str,
|
||||
base64_image: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Add a message to the agent's memory.
|
||||
@ -91,6 +92,7 @@ class BaseAgent(BaseModel, ABC):
|
||||
Args:
|
||||
role: The role of the message sender (user, system, assistant, tool).
|
||||
content: The message content.
|
||||
base64_image: Optional base64 encoded image.
|
||||
**kwargs: Additional arguments (e.g., tool_call_id for tool messages).
|
||||
|
||||
Raises:
|
||||
@ -106,9 +108,9 @@ class BaseAgent(BaseModel, ABC):
|
||||
if role not in message_map:
|
||||
raise ValueError(f"Unsupported message role: {role}")
|
||||
|
||||
msg_factory = message_map[role]
|
||||
msg = msg_factory(content, **kwargs) if role == "tool" else msg_factory(content)
|
||||
self.memory.add_message(msg)
|
||||
# Create message with appropriate parameters based on role
|
||||
kwargs = {"base64_image": base64_image, **(kwargs if role == "tool" else {})}
|
||||
self.memory.add_message(message_map[role](content, **kwargs))
|
||||
|
||||
async def run(self, request: Optional[str] = None) -> str:
|
||||
"""Execute the agent's main loop asynchronously.
|
||||
|
60
app/llm.py
60
app/llm.py
@ -280,22 +280,58 @@ class LLM:
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
# Convert Message objects to dictionaries
|
||||
if isinstance(message, Message):
|
||||
message = message.to_dict()
|
||||
if isinstance(message, dict):
|
||||
# If message is a dict, ensure it has required fields
|
||||
if "role" not in message:
|
||||
raise ValueError("Message dict must contain 'role' field")
|
||||
if "content" in message or "tool_calls" in message:
|
||||
formatted_messages.append(message)
|
||||
# else: do not include the message
|
||||
else:
|
||||
|
||||
if not isinstance(message, dict):
|
||||
raise TypeError(f"Unsupported message type: {type(message)}")
|
||||
|
||||
# Validate all messages have required fields
|
||||
for msg in formatted_messages:
|
||||
if msg["role"] not in ROLE_VALUES:
|
||||
raise ValueError(f"Invalid role: {msg['role']}")
|
||||
# Validate required fields
|
||||
if "role" not in message:
|
||||
raise ValueError("Message dict must contain 'role' field")
|
||||
|
||||
# Process base64 images if present
|
||||
if message.get("base64_image"):
|
||||
# Initialize or convert content to appropriate format
|
||||
if not message.get("content"):
|
||||
message["content"] = []
|
||||
elif isinstance(message["content"], str):
|
||||
message["content"] = [{"type": "text", "text": message["content"]}]
|
||||
elif isinstance(message["content"], list):
|
||||
# Convert string items to proper text objects
|
||||
message["content"] = [
|
||||
(
|
||||
{"type": "text", "text": item}
|
||||
if isinstance(item, str)
|
||||
else item
|
||||
)
|
||||
for item in message["content"]
|
||||
]
|
||||
|
||||
# Add the image to content
|
||||
message["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{message['base64_image']}"
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Remove the base64_image field
|
||||
del message["base64_image"]
|
||||
|
||||
# Only include messages with content or tool_calls
|
||||
if "content" in message or "tool_calls" in message:
|
||||
formatted_messages.append(message)
|
||||
|
||||
# Validate all roles
|
||||
invalid_roles = [
|
||||
msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES
|
||||
]
|
||||
if invalid_roles:
|
||||
raise ValueError(f"Invalid role: {invalid_roles[0]['role']}")
|
||||
|
||||
return formatted_messages
|
||||
|
||||
|
@ -59,6 +59,7 @@ class Message(BaseModel):
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None)
|
||||
name: Optional[str] = Field(default=None)
|
||||
tool_call_id: Optional[str] = Field(default=None)
|
||||
base64_image: Optional[str] = Field(default=None)
|
||||
|
||||
def __add__(self, other) -> List["Message"]:
|
||||
"""支持 Message + list 或 Message + Message 的操作"""
|
||||
@ -91,12 +92,16 @@ class Message(BaseModel):
|
||||
message["name"] = self.name
|
||||
if self.tool_call_id is not None:
|
||||
message["tool_call_id"] = self.tool_call_id
|
||||
if self.base64_image is not None:
|
||||
message["base64_image"] = self.base64_image
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def user_message(cls, content: str) -> "Message":
|
||||
def user_message(
|
||||
cls, content: str, base64_image: Optional[str] = None
|
||||
) -> "Message":
|
||||
"""Create a user message"""
|
||||
return cls(role=Role.USER, content=content)
|
||||
return cls(role=Role.USER, content=content, base64_image=base64_image)
|
||||
|
||||
@classmethod
|
||||
def system_message(cls, content: str) -> "Message":
|
||||
@ -104,33 +109,50 @@ class Message(BaseModel):
|
||||
return cls(role=Role.SYSTEM, content=content)
|
||||
|
||||
@classmethod
|
||||
def assistant_message(cls, content: Optional[str] = None) -> "Message":
|
||||
def assistant_message(
|
||||
cls, content: Optional[str] = None, base64_image: Optional[str] = None
|
||||
) -> "Message":
|
||||
"""Create an assistant message"""
|
||||
return cls(role=Role.ASSISTANT, content=content)
|
||||
return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)
|
||||
|
||||
@classmethod
|
||||
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
|
||||
def tool_message(
|
||||
cls, content: str, name, tool_call_id: str, base64_image: Optional[str] = None
|
||||
) -> "Message":
|
||||
"""Create a tool message"""
|
||||
return cls(
|
||||
role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id
|
||||
role=Role.TOOL,
|
||||
content=content,
|
||||
name=name,
|
||||
tool_call_id=tool_call_id,
|
||||
base64_image=base64_image,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_tool_calls(
|
||||
cls, tool_calls: List[Any], content: Union[str, List[str]] = "", **kwargs
|
||||
cls,
|
||||
tool_calls: List[Any],
|
||||
content: Union[str, List[str]] = "",
|
||||
base64_image: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create ToolCallsMessage from raw tool calls.
|
||||
|
||||
Args:
|
||||
tool_calls: Raw tool calls from LLM
|
||||
content: Optional message content
|
||||
base64_image: Optional base64 encoded image
|
||||
"""
|
||||
formatted_calls = [
|
||||
{"id": call.id, "function": call.function.model_dump(), "type": "function"}
|
||||
for call in tool_calls
|
||||
]
|
||||
return cls(
|
||||
role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs
|
||||
role=Role.ASSISTANT,
|
||||
content=content,
|
||||
tool_calls=formatted_calls,
|
||||
base64_image=base64_image,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
@ -37,6 +37,7 @@ class ToolResult(BaseModel):
|
||||
|
||||
output: Any = Field(default=None)
|
||||
error: Optional[str] = Field(default=None)
|
||||
base64_image: Optional[str] = Field(default=None)
|
||||
system: Optional[str] = Field(default=None)
|
||||
|
||||
class Config:
|
||||
@ -58,6 +59,7 @@ class ToolResult(BaseModel):
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
)
|
||||
|
||||
@ -76,7 +78,3 @@ class CLIResult(ToolResult):
|
||||
|
||||
class ToolFailure(ToolResult):
|
||||
"""A ToolResult that represents a failure."""
|
||||
|
||||
|
||||
class AgentAwareTool:
|
||||
agent: Optional = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user