194 lines
6.7 KiB
Python
194 lines
6.7 KiB
Python
from abc import ABC, abstractmethod
|
||
from contextlib import asynccontextmanager
|
||
from typing import List, Literal, Optional
|
||
|
||
from pydantic import BaseModel, Field, model_validator
|
||
|
||
from app.llm import LLM
|
||
from app.logger import logger
|
||
from app.schema import AgentState, Memory, Message
|
||
|
||
|
||
class BaseAgent(BaseModel, ABC):
    """Abstract base class for managing agent state and execution.

    Provides foundational functionality for state transitions, memory management,
    and a step-based execution loop. Subclasses must implement the `step` method.
    """

    # Core attributes
    name: str = Field(..., description="Unique name of the agent")
    description: Optional[str] = Field(None, description="Optional agent description")

    # Prompts
    system_prompt: Optional[str] = Field(
        None, description="System-level instruction prompt"
    )
    next_step_prompt: Optional[str] = Field(
        None, description="Prompt for determining next action"
    )

    # Dependencies
    llm: LLM = Field(default_factory=LLM, description="Language model instance")
    memory: Memory = Field(default_factory=Memory, description="Agent's memory store")
    state: AgentState = Field(
        default=AgentState.IDLE, description="Current agent state"
    )

    # Execution control
    max_steps: int = Field(default=10, description="Maximum steps before termination")
    current_step: int = Field(default=0, description="Current step in execution")

    # Minimum number of earlier identical assistant messages for is_stuck()
    # to report the agent as stuck.
    duplicate_threshold: int = 2

    class Config:
        arbitrary_types_allowed = True
        extra = "allow"  # Allow extra fields for flexibility in subclasses

    @model_validator(mode="after")
    def initialize_agent(self) -> "BaseAgent":
        """Ensure `llm` and `memory` hold usable instances after validation.

        Returns:
            The agent itself, as required by an "after" model validator.
        """
        # isinstance() is False for None, so this also covers a missing LLM.
        if not isinstance(self.llm, LLM):
            self.llm = LLM(config_name=self.name.lower())
        if not isinstance(self.memory, Memory):
            self.memory = Memory()
        return self

    @asynccontextmanager
    async def state_context(self, new_state: AgentState):
        """Temporarily switch the agent to `new_state` for the enclosed block.

        The state that was active on entry is always restored on exit, whether
        the block completes normally or raises; exceptions are re-raised.

        Args:
            new_state: The state to transition to during the context.

        Yields:
            None: Allows execution within the new state.

        Raises:
            ValueError: If `new_state` is not an `AgentState`.
        """
        if not isinstance(new_state, AgentState):
            raise ValueError(f"Invalid state: {new_state}")

        previous_state = self.state
        self.state = new_state
        try:
            yield
        except Exception:
            # NOTE: ERROR is transient here. The `finally` clause below runs
            # right after and restores `previous_state`, so callers do not see
            # ERROR once the context has exited.
            self.state = AgentState.ERROR
            raise  # bare raise keeps the original traceback intact
        finally:
            self.state = previous_state  # Revert to previous state

    def update_memory(
        self,
        role: Literal["user", "system", "assistant", "tool"],
        content: str,
        **kwargs,
    ) -> None:
        """Add a message to the agent's memory.

        Args:
            role: The role of the message sender (user, system, assistant, tool).
            content: The message content.
            **kwargs: Additional arguments (e.g., tool_call_id for tool messages).
                Only forwarded for the "tool" role; ignored for all others.

        Raises:
            ValueError: If the role is unsupported.
        """
        message_map = {
            "user": Message.user_message,
            "system": Message.system_message,
            "assistant": Message.assistant_message,
            "tool": Message.tool_message,
        }

        if role not in message_map:
            raise ValueError(f"Unsupported message role: {role}")

        msg_factory = message_map[role]
        msg = msg_factory(content, **kwargs) if role == "tool" else msg_factory(content)
        self.memory.add_message(msg)

    async def run(self, request: Optional[str] = None) -> str:
        """Execute the agent's main loop asynchronously.

        Steps are executed until the agent reaches FINISHED or `max_steps`
        is hit. While looping, the agent is in RUNNING; afterwards the state it
        had on entry (IDLE) is restored by `state_context`.

        Args:
            request: Optional initial user request to process. A falsy value
                (None or "") adds nothing to memory.

        Returns:
            A newline-joined summary with one "Step N: <result>" line per step,
            plus a "Terminated: ..." line if the step budget ran out before
            the agent finished, or "No steps executed" if nothing ran.

        Raises:
            RuntimeError: If the agent is not in IDLE state at start.
        """
        if self.state != AgentState.IDLE:
            raise RuntimeError(f"Cannot run agent from state: {self.state}")

        if request:
            self.update_memory("user", request)

        results: List[str] = []
        async with self.state_context(AgentState.RUNNING):
            while (
                self.current_step < self.max_steps and self.state != AgentState.FINISHED
            ):
                self.current_step += 1
                logger.info(f"Executing step {self.current_step}/{self.max_steps}")
                step_result = await self.step()

                # Check for stuck state
                if self.is_stuck():
                    self.handle_stuck_state()

                results.append(f"Step {self.current_step}: {step_result}")

            # NOTE(review): current_step is only reset when the step budget is
            # exhausted; a run that ends in FINISHED earlier leaves it in place,
            # so a later run() continues counting from there.
            if self.current_step >= self.max_steps:
                # Finishing on the very last allowed step is a normal completion,
                # not a forced termination, so it must not be reported as one.
                finished = self.state == AgentState.FINISHED
                # Reset the counter so the next run() starts with a full budget.
                self.current_step = 0
                self.state = AgentState.IDLE
                if not finished:
                    results.append(f"Terminated: Reached max steps ({self.max_steps})")

        return "\n".join(results) if results else "No steps executed"

    @abstractmethod
    async def step(self) -> str:
        """Execute a single step in the agent's workflow.

        Must be implemented by subclasses to define specific behavior. The
        returned string is recorded in `run()`'s summary; setting
        `self.state` to `AgentState.FINISHED` stops the loop.
        """

    def handle_stuck_state(self):
        """Prepend a change-strategy hint to `next_step_prompt`.

        A missing (None/empty) prompt is replaced by the hint alone, and the
        hint is not prepended again if the prompt already starts with it, so
        repeated stuck detections do not keep growing the prompt.
        """
        stuck_prompt = (
            "Observed duplicate responses. Consider new strategies and avoid "
            "repeating ineffective paths already attempted."
        )
        current_prompt = self.next_step_prompt or ""
        if not current_prompt.startswith(stuck_prompt):
            self.next_step_prompt = (
                f"{stuck_prompt}\n{current_prompt}" if current_prompt else stuck_prompt
            )
        logger.warning(f"Agent detected stuck state. Added prompt: {stuck_prompt}")

    def is_stuck(self) -> bool:
        """Check if the agent is stuck in a loop by detecting duplicate content.

        Returns:
            True when at least `duplicate_threshold` earlier assistant messages
            have content identical to the most recent message; False when
            memory holds fewer than two messages or the last one is empty.
        """
        if len(self.memory.messages) < 2:
            return False

        last_message = self.memory.messages[-1]
        if not last_message.content:
            return False

        # NOTE(review): the role of last_message itself is not checked, so a
        # non-assistant message repeating earlier assistant text also counts.
        duplicate_count = sum(
            1
            for msg in self.memory.messages[:-1]
            if msg.role == "assistant" and msg.content == last_message.content
        )

        return duplicate_count >= self.duplicate_threshold

    @property
    def messages(self) -> List[Message]:
        """Retrieve a list of messages from the agent's memory."""
        return self.memory.messages

    @messages.setter
    def messages(self, value: List[Message]):
        """Set the list of messages in the agent's memory."""
        self.memory.messages = value
|