update llm, schema, BaseTool and BaseAgent
This commit is contained in:
parent
5cf34f82df
commit
91d14a3a47
@ -84,6 +84,7 @@ class BaseAgent(BaseModel, ABC):
|
|||||||
self,
|
self,
|
||||||
role: ROLE_TYPE, # type: ignore
|
role: ROLE_TYPE, # type: ignore
|
||||||
content: str,
|
content: str,
|
||||||
|
base64_image: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a message to the agent's memory.
|
"""Add a message to the agent's memory.
|
||||||
@ -91,6 +92,7 @@ class BaseAgent(BaseModel, ABC):
|
|||||||
Args:
|
Args:
|
||||||
role: The role of the message sender (user, system, assistant, tool).
|
role: The role of the message sender (user, system, assistant, tool).
|
||||||
content: The message content.
|
content: The message content.
|
||||||
|
base64_image: Optional base64 encoded image.
|
||||||
**kwargs: Additional arguments (e.g., tool_call_id for tool messages).
|
**kwargs: Additional arguments (e.g., tool_call_id for tool messages).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -106,9 +108,9 @@ class BaseAgent(BaseModel, ABC):
|
|||||||
if role not in message_map:
|
if role not in message_map:
|
||||||
raise ValueError(f"Unsupported message role: {role}")
|
raise ValueError(f"Unsupported message role: {role}")
|
||||||
|
|
||||||
msg_factory = message_map[role]
|
# Create message with appropriate parameters based on role
|
||||||
msg = msg_factory(content, **kwargs) if role == "tool" else msg_factory(content)
|
kwargs = {"base64_image": base64_image, **(kwargs if role == "tool" else {})}
|
||||||
self.memory.add_message(msg)
|
self.memory.add_message(message_map[role](content, **kwargs))
|
||||||
|
|
||||||
async def run(self, request: Optional[str] = None) -> str:
|
async def run(self, request: Optional[str] = None) -> str:
|
||||||
"""Execute the agent's main loop asynchronously.
|
"""Execute the agent's main loop asynchronously.
|
||||||
|
60
app/llm.py
60
app/llm.py
@ -280,22 +280,58 @@ class LLM:
|
|||||||
formatted_messages = []
|
formatted_messages = []
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
|
# Convert Message objects to dictionaries
|
||||||
if isinstance(message, Message):
|
if isinstance(message, Message):
|
||||||
message = message.to_dict()
|
message = message.to_dict()
|
||||||
if isinstance(message, dict):
|
|
||||||
# If message is a dict, ensure it has required fields
|
if not isinstance(message, dict):
|
||||||
if "role" not in message:
|
|
||||||
raise ValueError("Message dict must contain 'role' field")
|
|
||||||
if "content" in message or "tool_calls" in message:
|
|
||||||
formatted_messages.append(message)
|
|
||||||
# else: do not include the message
|
|
||||||
else:
|
|
||||||
raise TypeError(f"Unsupported message type: {type(message)}")
|
raise TypeError(f"Unsupported message type: {type(message)}")
|
||||||
|
|
||||||
# Validate all messages have required fields
|
# Validate required fields
|
||||||
for msg in formatted_messages:
|
if "role" not in message:
|
||||||
if msg["role"] not in ROLE_VALUES:
|
raise ValueError("Message dict must contain 'role' field")
|
||||||
raise ValueError(f"Invalid role: {msg['role']}")
|
|
||||||
|
# Process base64 images if present
|
||||||
|
if message.get("base64_image"):
|
||||||
|
# Initialize or convert content to appropriate format
|
||||||
|
if not message.get("content"):
|
||||||
|
message["content"] = []
|
||||||
|
elif isinstance(message["content"], str):
|
||||||
|
message["content"] = [{"type": "text", "text": message["content"]}]
|
||||||
|
elif isinstance(message["content"], list):
|
||||||
|
# Convert string items to proper text objects
|
||||||
|
message["content"] = [
|
||||||
|
(
|
||||||
|
{"type": "text", "text": item}
|
||||||
|
if isinstance(item, str)
|
||||||
|
else item
|
||||||
|
)
|
||||||
|
for item in message["content"]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add the image to content
|
||||||
|
message["content"].append(
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:image/jpeg;base64,{message['base64_image']}"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove the base64_image field
|
||||||
|
del message["base64_image"]
|
||||||
|
|
||||||
|
# Only include messages with content or tool_calls
|
||||||
|
if "content" in message or "tool_calls" in message:
|
||||||
|
formatted_messages.append(message)
|
||||||
|
|
||||||
|
# Validate all roles
|
||||||
|
invalid_roles = [
|
||||||
|
msg for msg in formatted_messages if msg["role"] not in ROLE_VALUES
|
||||||
|
]
|
||||||
|
if invalid_roles:
|
||||||
|
raise ValueError(f"Invalid role: {invalid_roles[0]['role']}")
|
||||||
|
|
||||||
return formatted_messages
|
return formatted_messages
|
||||||
|
|
||||||
|
@ -59,6 +59,7 @@ class Message(BaseModel):
|
|||||||
tool_calls: Optional[List[ToolCall]] = Field(default=None)
|
tool_calls: Optional[List[ToolCall]] = Field(default=None)
|
||||||
name: Optional[str] = Field(default=None)
|
name: Optional[str] = Field(default=None)
|
||||||
tool_call_id: Optional[str] = Field(default=None)
|
tool_call_id: Optional[str] = Field(default=None)
|
||||||
|
base64_image: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
def __add__(self, other) -> List["Message"]:
|
def __add__(self, other) -> List["Message"]:
|
||||||
"""支持 Message + list 或 Message + Message 的操作"""
|
"""支持 Message + list 或 Message + Message 的操作"""
|
||||||
@ -91,12 +92,16 @@ class Message(BaseModel):
|
|||||||
message["name"] = self.name
|
message["name"] = self.name
|
||||||
if self.tool_call_id is not None:
|
if self.tool_call_id is not None:
|
||||||
message["tool_call_id"] = self.tool_call_id
|
message["tool_call_id"] = self.tool_call_id
|
||||||
|
if self.base64_image is not None:
|
||||||
|
message["base64_image"] = self.base64_image
|
||||||
return message
|
return message
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def user_message(cls, content: str) -> "Message":
|
def user_message(
|
||||||
|
cls, content: str, base64_image: Optional[str] = None
|
||||||
|
) -> "Message":
|
||||||
"""Create a user message"""
|
"""Create a user message"""
|
||||||
return cls(role=Role.USER, content=content)
|
return cls(role=Role.USER, content=content, base64_image=base64_image)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def system_message(cls, content: str) -> "Message":
|
def system_message(cls, content: str) -> "Message":
|
||||||
@ -104,33 +109,50 @@ class Message(BaseModel):
|
|||||||
return cls(role=Role.SYSTEM, content=content)
|
return cls(role=Role.SYSTEM, content=content)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def assistant_message(cls, content: Optional[str] = None) -> "Message":
|
def assistant_message(
|
||||||
|
cls, content: Optional[str] = None, base64_image: Optional[str] = None
|
||||||
|
) -> "Message":
|
||||||
"""Create an assistant message"""
|
"""Create an assistant message"""
|
||||||
return cls(role=Role.ASSISTANT, content=content)
|
return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
|
def tool_message(
|
||||||
|
cls, content: str, name, tool_call_id: str, base64_image: Optional[str] = None
|
||||||
|
) -> "Message":
|
||||||
"""Create a tool message"""
|
"""Create a tool message"""
|
||||||
return cls(
|
return cls(
|
||||||
role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id
|
role=Role.TOOL,
|
||||||
|
content=content,
|
||||||
|
name=name,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
base64_image=base64_image,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tool_calls(
|
def from_tool_calls(
|
||||||
cls, tool_calls: List[Any], content: Union[str, List[str]] = "", **kwargs
|
cls,
|
||||||
|
tool_calls: List[Any],
|
||||||
|
content: Union[str, List[str]] = "",
|
||||||
|
base64_image: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Create ToolCallsMessage from raw tool calls.
|
"""Create ToolCallsMessage from raw tool calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_calls: Raw tool calls from LLM
|
tool_calls: Raw tool calls from LLM
|
||||||
content: Optional message content
|
content: Optional message content
|
||||||
|
base64_image: Optional base64 encoded image
|
||||||
"""
|
"""
|
||||||
formatted_calls = [
|
formatted_calls = [
|
||||||
{"id": call.id, "function": call.function.model_dump(), "type": "function"}
|
{"id": call.id, "function": call.function.model_dump(), "type": "function"}
|
||||||
for call in tool_calls
|
for call in tool_calls
|
||||||
]
|
]
|
||||||
return cls(
|
return cls(
|
||||||
role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs
|
role=Role.ASSISTANT,
|
||||||
|
content=content,
|
||||||
|
tool_calls=formatted_calls,
|
||||||
|
base64_image=base64_image,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,6 +37,7 @@ class ToolResult(BaseModel):
|
|||||||
|
|
||||||
output: Any = Field(default=None)
|
output: Any = Field(default=None)
|
||||||
error: Optional[str] = Field(default=None)
|
error: Optional[str] = Field(default=None)
|
||||||
|
base64_image: Optional[str] = Field(default=None)
|
||||||
system: Optional[str] = Field(default=None)
|
system: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -58,6 +59,7 @@ class ToolResult(BaseModel):
|
|||||||
return ToolResult(
|
return ToolResult(
|
||||||
output=combine_fields(self.output, other.output),
|
output=combine_fields(self.output, other.output),
|
||||||
error=combine_fields(self.error, other.error),
|
error=combine_fields(self.error, other.error),
|
||||||
|
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||||
system=combine_fields(self.system, other.system),
|
system=combine_fields(self.system, other.system),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -76,7 +78,3 @@ class CLIResult(ToolResult):
|
|||||||
|
|
||||||
class ToolFailure(ToolResult):
|
class ToolFailure(ToolResult):
|
||||||
"""A ToolResult that represents a failure."""
|
"""A ToolResult that represents a failure."""
|
||||||
|
|
||||||
|
|
||||||
class AgentAwareTool:
|
|
||||||
agent: Optional = None
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user