重构:使用PlanStepStatus枚举替代硬编码status

This commit is contained in:
rxdaozhang 2025-03-11 01:23:16 +08:00
parent b3c4f5ce22
commit 34ede1c082
2 changed files with 47 additions and 29 deletions

View File

@ -60,3 +60,30 @@ class BaseFlow(BaseModel, ABC):
@abstractmethod
async def execute(self, input_text: str) -> str:
"""Execute the flow with given input"""
class PlanStepStatus(str, Enum):
"""Enum class defining possible statuses of a plan step"""
NOT_STARTED = "not_started"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
BLOCKED = "blocked"
@classmethod
def get_all_statuses(cls) -> list[str]:
"""Return a list of all possible step status values"""
return [status.value for status in cls]
@classmethod
def get_active_statuses(cls) -> list[str]:
"""Return a list of values representing active statuses (not started or in progress)"""
return [cls.NOT_STARTED.value, cls.IN_PROGRESS.value]
@classmethod
def get_status_marks(cls) -> Dict[str, str]:
"""Return a mapping of statuses to their marker symbols"""
return {
cls.COMPLETED.value: "[✓]",
cls.IN_PROGRESS.value: "[→]",
cls.BLOCKED.value: "[!]",
cls.NOT_STARTED.value: "[ ]"
}

View File

@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Union
from pydantic import Field
from app.agent.base import BaseAgent
from app.flow.base import BaseFlow
from app.flow.base import BaseFlow, PlanStepStatus
from app.llm import LLM
from app.logger import logger
from app.schema import AgentState, Message
@ -183,11 +183,11 @@ class PlanningFlow(BaseFlow):
# Find first non-completed step
for i, step in enumerate(steps):
if i >= len(step_statuses):
status = "not_started"
status = PlanStepStatus.NOT_STARTED.value
else:
status = step_statuses[i]
if status in ["not_started", "in_progress"]:
if status in PlanStepStatus.get_active_statuses():
# Extract step type/category if available
step_info = {"text": step}
@ -204,17 +204,17 @@ class PlanningFlow(BaseFlow):
command="mark_step",
plan_id=self.active_plan_id,
step_index=i,
step_status="in_progress",
step_status=PlanStepStatus.IN_PROGRESS.value,
)
except Exception as e:
logger.warning(f"Error marking step as in_progress: {e}")
# Update step status directly if needed
if i < len(step_statuses):
step_statuses[i] = "in_progress"
step_statuses[i] = PlanStepStatus.IN_PROGRESS.value
else:
while len(step_statuses) < i:
step_statuses.append("not_started")
step_statuses.append("in_progress")
step_statuses.append(PlanStepStatus.NOT_STARTED.value)
step_statuses.append(PlanStepStatus.IN_PROGRESS.value)
plan_data["step_statuses"] = step_statuses
@ -266,7 +266,7 @@ class PlanningFlow(BaseFlow):
command="mark_step",
plan_id=self.active_plan_id,
step_index=self.current_step_index,
step_status="completed",
step_status=PlanStepStatus.COMPLETED.value,
)
logger.info(
f"Marked step {self.current_step_index} as completed in plan {self.active_plan_id}"
@ -280,10 +280,10 @@ class PlanningFlow(BaseFlow):
# Ensure the step_statuses list is long enough
while len(step_statuses) <= self.current_step_index:
step_statuses.append("not_started")
step_statuses.append(PlanStepStatus.NOT_STARTED.value)
# Update the status
step_statuses[self.current_step_index] = "completed"
step_statuses[self.current_step_index] = PlanStepStatus.COMPLETED.value
plan_data["step_statuses"] = step_statuses
async def _get_plan_text(self) -> str:
@ -311,23 +311,18 @@ class PlanningFlow(BaseFlow):
# Ensure step_statuses and step_notes match the number of steps
while len(step_statuses) < len(steps):
step_statuses.append("not_started")
step_statuses.append(PlanStepStatus.NOT_STARTED.value)
while len(step_notes) < len(steps):
step_notes.append("")
# Count steps by status
status_counts = {
"completed": 0,
"in_progress": 0,
"blocked": 0,
"not_started": 0,
}
status_counts = {status: 0 for status in PlanStepStatus.get_all_statuses()}
for status in step_statuses:
if status in status_counts:
status_counts[status] += 1
completed = status_counts["completed"]
completed = status_counts[PlanStepStatus.COMPLETED.value]
total = len(steps)
progress = (completed / total) * 100 if total > 0 else 0
@ -337,22 +332,18 @@ class PlanningFlow(BaseFlow):
plan_text += (
f"Progress: {completed}/{total} steps completed ({progress:.1f}%)\n"
)
plan_text += f"Status: {status_counts['completed']} completed, {status_counts['in_progress']} in progress, "
plan_text += f"{status_counts['blocked']} blocked, {status_counts['not_started']} not started\n\n"
plan_text += f"Status: {status_counts[PlanStepStatus.COMPLETED.value]} completed, {status_counts[PlanStepStatus.IN_PROGRESS.value]} in progress, "
plan_text += f"{status_counts[PlanStepStatus.BLOCKED.value]} blocked, {status_counts[PlanStepStatus.NOT_STARTED.value]} not started\n\n"
plan_text += "Steps:\n"
status_marks = PlanStepStatus.get_status_marks()
for i, (step, status, notes) in enumerate(
zip(steps, step_statuses, step_notes)
):
if status == "completed":
status_mark = "[✓]"
elif status == "in_progress":
status_mark = "[→]"
elif status == "blocked":
status_mark = "[!]"
else: # not_started
status_mark = "[ ]"
# Use status marks to indicate step status
status_mark = status_marks.get(status, status_marks[PlanStepStatus.NOT_STARTED.value])
plan_text += f"{i}. {status_mark} {step}\n"
if notes:
plan_text += f" Notes: {notes}\n"