import time from typing import Dict, List, Optional from pydantic import Field, model_validator from app.agent.toolcall import ToolCallAgent from app.logger import logger from app.prompt.planning import NEXT_STEP_PROMPT, PLANNING_SYSTEM_PROMPT from app.schema import Message, TOOL_CHOICE_TYPE, ToolCall, ToolChoice from app.tool import PlanningTool, Terminate, ToolCollection class PlanningAgent(ToolCallAgent): """ An agent that creates and manages plans to solve tasks. This agent uses a planning tool to create and manage structured plans, and tracks progress through individual steps until task completion. """ name: str = "planning" description: str = "An agent that creates and manages plans to solve tasks" system_prompt: str = PLANNING_SYSTEM_PROMPT next_step_prompt: str = NEXT_STEP_PROMPT available_tools: ToolCollection = Field( default_factory=lambda: ToolCollection(PlanningTool(), Terminate()) ) tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) tool_calls: List[ToolCall] = Field(default_factory=list) active_plan_id: Optional[str] = Field(default=None) # Add a dictionary to track the step status for each tool call step_execution_tracker: Dict[str, Dict] = Field(default_factory=dict) current_step_index: Optional[int] = None max_steps: int = 20 @model_validator(mode="after") def initialize_plan_and_verify_tools(self) -> "PlanningAgent": """Initialize the agent with a default plan ID and validate required tools.""" self.active_plan_id = f"plan_{int(time.time())}" if "planning" not in self.available_tools.tool_map: self.available_tools.add_tool(PlanningTool()) return self async def think(self) -> bool: """Decide the next action based on plan status.""" prompt = ( f"CURRENT PLAN STATUS:\n{await self.get_plan()}\n\n{self.next_step_prompt}" if self.active_plan_id else self.next_step_prompt ) self.messages.append(Message.user_message(prompt)) # Get the current step index before thinking self.current_step_index = await self._get_current_step_index() result = await super().think() # After thinking, if we decided to execute a tool and it's not a planning tool or special tool, # associate it with the current step for tracking if result and self.tool_calls: latest_tool_call = self.tool_calls[0] # Get the most recent tool call if ( latest_tool_call.function.name != "planning" and latest_tool_call.function.name not in self.special_tool_names and self.current_step_index is not None ): self.step_execution_tracker[latest_tool_call.id] = { "step_index": self.current_step_index, "tool_name": latest_tool_call.function.name, "status": "pending", # Will be updated after execution } return result async def act(self) -> str: """Execute a step and track its completion status.""" result = await super().act() # After executing the tool, update the plan status if self.tool_calls: latest_tool_call = self.tool_calls[0] # Update the execution status to completed if latest_tool_call.id in self.step_execution_tracker: self.step_execution_tracker[latest_tool_call.id]["status"] = "completed" self.step_execution_tracker[latest_tool_call.id]["result"] = result # Update the plan status if this was a non-planning, non-special tool if ( latest_tool_call.function.name != "planning" and latest_tool_call.function.name not in self.special_tool_names ): await self.update_plan_status(latest_tool_call.id) return result async def get_plan(self) -> str: """Retrieve the current plan status.""" if not self.active_plan_id: return "No active plan. Please create a plan first." result = await self.available_tools.execute( name="planning", tool_input={"command": "get", "plan_id": self.active_plan_id}, ) return result.output if hasattr(result, "output") else str(result) async def run(self, request: Optional[str] = None) -> str: """Run the agent with an optional initial request.""" if request: await self.create_initial_plan(request) return await super().run() async def update_plan_status(self, tool_call_id: str) -> None: """ Update the current plan progress based on completed tool execution. Only marks a step as completed if the associated tool has been successfully executed. """ if not self.active_plan_id: return if tool_call_id not in self.step_execution_tracker: logger.warning(f"No step tracking found for tool call {tool_call_id}") return tracker = self.step_execution_tracker[tool_call_id] if tracker["status"] != "completed": logger.warning(f"Tool call {tool_call_id} has not completed successfully") return step_index = tracker["step_index"] try: # Mark the step as completed await self.available_tools.execute( name="planning", tool_input={ "command": "mark_step", "plan_id": self.active_plan_id, "step_index": step_index, "step_status": "completed", }, ) logger.info( f"Marked step {step_index} as completed in plan {self.active_plan_id}" ) except Exception as e: logger.warning(f"Failed to update plan status: {e}") async def _get_current_step_index(self) -> Optional[int]: """ Parse the current plan to identify the first non-completed step's index. Returns None if no active step is found. """ if not self.active_plan_id: return None plan = await self.get_plan() try: plan_lines = plan.splitlines() steps_index = -1 # Find the index of the "Steps:" line for i, line in enumerate(plan_lines): if line.strip() == "Steps:": steps_index = i break if steps_index == -1: return None # Find the first non-completed step for i, line in enumerate(plan_lines[steps_index + 1 :], start=0): if "[ ]" in line or "[→]" in line: # not_started or in_progress # Mark current step as in_progress await self.available_tools.execute( name="planning", tool_input={ "command": "mark_step", "plan_id": self.active_plan_id, "step_index": i, "step_status": "in_progress", }, ) return i return None # No active step found except Exception as e: logger.warning(f"Error finding current step index: {e}") return None async def create_initial_plan(self, request: str) -> None: """Create an initial plan based on the request.""" logger.info(f"Creating initial plan with ID: {self.active_plan_id}") messages = [ Message.user_message( f"Analyze the request and create a plan with ID {self.active_plan_id}: {request}" ) ] self.memory.add_messages(messages) response = await self.llm.ask_tool( messages=messages, system_msgs=[Message.system_message(self.system_prompt)], tools=self.available_tools.to_params(), tool_choice=ToolChoice.REQUIRED, ) assistant_msg = Message.from_tool_calls( content=response.content, tool_calls=response.tool_calls ) self.memory.add_message(assistant_msg) plan_created = False for tool_call in response.tool_calls: if tool_call.function.name == "planning": result = await self.execute_tool(tool_call) logger.info( f"Executed tool {tool_call.function.name} with result: {result}" ) # Add tool response to memory tool_msg = Message.tool_message( content=result, tool_call_id=tool_call.id, name=tool_call.function.name, ) self.memory.add_message(tool_msg) plan_created = True break if not plan_created: logger.warning("No plan created from initial request") tool_msg = Message.assistant_message( "Error: Parameter `plan_id` is required for command: create" ) self.memory.add_message(tool_msg) async def main(): # Configure and run the agent agent = PlanningAgent(available_tools=ToolCollection(PlanningTool(), Terminate())) result = await agent.run("Help me plan a trip to the moon") print(result) if __name__ == "__main__": import asyncio asyncio.run(main())