From eac3a6e24efd63e03e71abc46a547f4691ca3da7 Mon Sep 17 00:00:00 2001 From: Matt Eng Date: Wed, 12 Mar 2025 00:15:31 -0700 Subject: [PATCH] Standardize literals for role and tool choice type definitions --- app/agent/base.py | 6 +++--- app/agent/planning.py | 8 ++++---- app/agent/toolcall.py | 14 +++++++------- app/flow/planning.py | 4 ++-- app/llm.py | 10 +++++----- app/schema.py | 30 ++++++++++++++++++++++++------ 6 files changed, 45 insertions(+), 27 deletions(-) diff --git a/app/agent/base.py b/app/agent/base.py index 43132cf..cdff133 100644 --- a/app/agent/base.py +++ b/app/agent/base.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod from contextlib import asynccontextmanager -from typing import List, Literal, Optional +from typing import List, Optional from pydantic import BaseModel, Field, model_validator from app.llm import LLM from app.logger import logger -from app.schema import AgentState, Memory, Message +from app.schema import AgentState, Memory, Message, ROLE_TYPE class BaseAgent(BaseModel, ABC): @@ -82,7 +82,7 @@ class BaseAgent(BaseModel, ABC): def update_memory( self, - role: Literal["user", "system", "assistant", "tool"], + role: ROLE_TYPE, # type: ignore content: str, **kwargs, ) -> None: diff --git a/app/agent/planning.py b/app/agent/planning.py index 0965a88..cbd15a0 100644 --- a/app/agent/planning.py +++ b/app/agent/planning.py @@ -1,12 +1,12 @@ import time -from typing import Dict, List, Literal, Optional +from typing import Dict, List, Optional from pydantic import Field, model_validator from app.agent.toolcall import ToolCallAgent from app.logger import logger from app.prompt.planning import NEXT_STEP_PROMPT, PLANNING_SYSTEM_PROMPT -from app.schema import Message, ToolCall +from app.schema import Message, TOOL_CHOICE_TYPE, ToolCall, ToolChoice from app.tool import PlanningTool, Terminate, ToolCollection @@ -27,7 +27,7 @@ class PlanningAgent(ToolCallAgent): available_tools: ToolCollection = Field( default_factory=lambda: ToolCollection(PlanningTool(), Terminate()) ) - tool_choices: Literal["none", "auto", "required"] = "auto" + tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) tool_calls: List[ToolCall] = Field(default_factory=list) @@ -212,7 +212,7 @@ class PlanningAgent(ToolCallAgent): messages=messages, system_msgs=[Message.system_message(self.system_prompt)], tools=self.available_tools.to_params(), - tool_choice="required", + tool_choice=ToolChoice.REQUIRED, ) assistant_msg = Message.from_tool_calls( content=response.content, tool_calls=response.tool_calls diff --git a/app/agent/toolcall.py b/app/agent/toolcall.py index b3b6439..a90dc85 100644 --- a/app/agent/toolcall.py +++ b/app/agent/toolcall.py @@ -1,12 +1,12 @@ import json -from typing import Any, List, Literal +from typing import Any, List from pydantic import Field from app.agent.react import ReActAgent from app.logger import logger from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT -from app.schema import AgentState, Message, ToolCall +from app.schema import AgentState, Message, ToolCall, TOOL_CHOICE_TYPE, ToolChoice from app.tool import CreateChatCompletion, Terminate, ToolCollection @@ -25,7 +25,7 @@ class ToolCallAgent(ReActAgent): available_tools: ToolCollection = ToolCollection( CreateChatCompletion(), Terminate() ) - tool_choices: Literal["none", "auto", "required"] = "auto" + tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) tool_calls: List[ToolCall] = Field(default_factory=list) @@ -61,7 +61,7 @@ class ToolCallAgent(ReActAgent): try: # Handle different tool_choices modes - if self.tool_choices == "none": + if self.tool_choices == ToolChoice.NONE: if response.tool_calls: logger.warning( f"🤔 Hmm, {self.name} tried to use tools when they weren't available!" @@ -81,11 +81,11 @@ class ToolCallAgent(ReActAgent): ) self.memory.add_message(assistant_msg) - if self.tool_choices == "required" and not self.tool_calls: + if self.tool_choices == ToolChoice.REQUIRED and not self.tool_calls: return True # Will be handled in act() # For 'auto' mode, continue with content if no commands but content exists - if self.tool_choices == "auto" and not self.tool_calls: + if self.tool_choices == ToolChoice.AUTO and not self.tool_calls: return bool(response.content) return bool(self.tool_calls) @@ -101,7 +101,7 @@ class ToolCallAgent(ReActAgent): async def act(self) -> str: """Execute tool calls and handle their results""" if not self.tool_calls: - if self.tool_choices == "required": + if self.tool_choices == ToolChoice.REQUIRED: raise ValueError(TOOL_CALL_REQUIRED) # Return last message content if no tool calls diff --git a/app/flow/planning.py b/app/flow/planning.py index 808949f..a12bbe4 100644 --- a/app/flow/planning.py +++ b/app/flow/planning.py @@ -8,7 +8,7 @@ from app.agent.base import BaseAgent from app.flow.base import BaseFlow, PlanStepStatus from app.llm import LLM from app.logger import logger -from app.schema import AgentState, Message +from app.schema import AgentState, Message, ToolChoice from app.tool import PlanningTool @@ -124,7 +124,7 @@ class PlanningFlow(BaseFlow): messages=[user_message], system_msgs=[system_message], tools=[self.planning_tool.to_param()], - tool_choice="required", + tool_choice=ToolChoice.REQUIRED, ) # Process tool calls if present diff --git a/app/llm.py b/app/llm.py index 8f62782..39a0dc9 100644 --- a/app/llm.py +++ b/app/llm.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Optional, Union from openai import ( APIError, @@ -12,7 +12,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from app.config import LLMSettings, config from app.logger import logger # Assuming a logger is set up in your app -from app.schema import Message +from app.schema import Message, TOOL_CHOICE_TYPE, ROLE_VALUES, TOOL_CHOICE_VALUES, ToolChoice class LLM: @@ -88,7 +88,7 @@ class LLM: # Validate all messages have required fields for msg in formatted_messages: - if msg["role"] not in ["system", "user", "assistant", "tool"]: + if msg["role"] not in ROLE_VALUES: raise ValueError(f"Invalid role: {msg['role']}") if "content" not in msg and "tool_calls" not in msg: raise ValueError( @@ -187,7 +187,7 @@ class LLM: system_msgs: Optional[List[Union[dict, Message]]] = None, timeout: int = 60, tools: Optional[List[dict]] = None, - tool_choice: Literal["none", "auto", "required"] = "auto", + tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO, # type: ignore temperature: Optional[float] = None, **kwargs, ): @@ -213,7 +213,7 @@ class LLM: """ try: # Validate tool_choice - if tool_choice not in ["none", "auto", "required"]: + if tool_choice not in TOOL_CHOICE_VALUES: raise ValueError(f"Invalid tool_choice: {tool_choice}") # Format messages diff --git a/app/schema.py b/app/schema.py index 0b09d45..30ccf6c 100644 --- a/app/schema.py +++ b/app/schema.py @@ -3,6 +3,24 @@ from typing import Any, List, Literal, Optional, Union from pydantic import BaseModel, Field +class Role(str, Enum): + """Message role options""" + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + +ROLE_VALUES = tuple(role.value for role in Role) +ROLE_TYPE = Literal[ROLE_VALUES] # type: ignore + +class ToolChoice(str, Enum): + """Tool choice options""" + NONE = "none" + AUTO = "auto" + REQUIRED = "required" + +TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice) +TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES] # type: ignore class AgentState(str, Enum): """Agent execution states""" @@ -29,7 +47,7 @@ class ToolCall(BaseModel): class Message(BaseModel): """Represents a chat message in the conversation""" - role: Literal["system", "user", "assistant", "tool"] = Field(...) + role: ROLE_TYPE = Field(...) # type: ignore content: Optional[str] = Field(default=None) tool_calls: Optional[List[ToolCall]] = Field(default=None) name: Optional[str] = Field(default=None) @@ -71,22 +89,22 @@ class Message(BaseModel): @classmethod def user_message(cls, content: str) -> "Message": """Create a user message""" - return cls(role="user", content=content) + return cls(role=Role.USER, content=content) @classmethod def system_message(cls, content: str) -> "Message": """Create a system message""" - return cls(role="system", content=content) + return cls(role=Role.SYSTEM, content=content) @classmethod def assistant_message(cls, content: Optional[str] = None) -> "Message": """Create an assistant message""" - return cls(role="assistant", content=content) + return cls(role=Role.ASSISTANT, content=content) @classmethod def tool_message(cls, content: str, name, tool_call_id: str) -> "Message": """Create a tool message""" - return cls(role="tool", content=content, name=name, tool_call_id=tool_call_id) + return cls(role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id) @classmethod def from_tool_calls( @@ -103,7 +121,7 @@ class Message(BaseModel): for call in tool_calls ] return cls( - role="assistant", content=content, tool_calls=formatted_calls, **kwargs + role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs )