Standardize literals for role and tool choice type definitions

This commit is contained in:
Matt Eng 2025-03-12 00:15:31 -07:00
parent af8023de43
commit eac3a6e24e
6 changed files with 45 additions and 27 deletions

View File

@ -1,12 +1,12 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import List, Literal, Optional from typing import List, Optional
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from app.llm import LLM from app.llm import LLM
from app.logger import logger from app.logger import logger
from app.schema import AgentState, Memory, Message from app.schema import AgentState, Memory, Message, ROLE_TYPE
class BaseAgent(BaseModel, ABC): class BaseAgent(BaseModel, ABC):
@ -82,7 +82,7 @@ class BaseAgent(BaseModel, ABC):
def update_memory( def update_memory(
self, self,
role: Literal["user", "system", "assistant", "tool"], role: ROLE_TYPE, # type: ignore
content: str, content: str,
**kwargs, **kwargs,
) -> None: ) -> None:

View File

@ -1,12 +1,12 @@
import time import time
from typing import Dict, List, Literal, Optional from typing import Dict, List, Optional
from pydantic import Field, model_validator from pydantic import Field, model_validator
from app.agent.toolcall import ToolCallAgent from app.agent.toolcall import ToolCallAgent
from app.logger import logger from app.logger import logger
from app.prompt.planning import NEXT_STEP_PROMPT, PLANNING_SYSTEM_PROMPT from app.prompt.planning import NEXT_STEP_PROMPT, PLANNING_SYSTEM_PROMPT
from app.schema import Message, ToolCall from app.schema import Message, TOOL_CHOICE_TYPE, ToolCall, ToolChoice
from app.tool import PlanningTool, Terminate, ToolCollection from app.tool import PlanningTool, Terminate, ToolCollection
@ -27,7 +27,7 @@ class PlanningAgent(ToolCallAgent):
available_tools: ToolCollection = Field( available_tools: ToolCollection = Field(
default_factory=lambda: ToolCollection(PlanningTool(), Terminate()) default_factory=lambda: ToolCollection(PlanningTool(), Terminate())
) )
tool_choices: Literal["none", "auto", "required"] = "auto" tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
tool_calls: List[ToolCall] = Field(default_factory=list) tool_calls: List[ToolCall] = Field(default_factory=list)
@ -212,7 +212,7 @@ class PlanningAgent(ToolCallAgent):
messages=messages, messages=messages,
system_msgs=[Message.system_message(self.system_prompt)], system_msgs=[Message.system_message(self.system_prompt)],
tools=self.available_tools.to_params(), tools=self.available_tools.to_params(),
tool_choice="required", tool_choice=ToolChoice.REQUIRED,
) )
assistant_msg = Message.from_tool_calls( assistant_msg = Message.from_tool_calls(
content=response.content, tool_calls=response.tool_calls content=response.content, tool_calls=response.tool_calls

View File

@ -1,12 +1,12 @@
import json import json
from typing import Any, List, Literal from typing import Any, List
from pydantic import Field from pydantic import Field
from app.agent.react import ReActAgent from app.agent.react import ReActAgent
from app.logger import logger from app.logger import logger
from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
from app.schema import AgentState, Message, ToolCall from app.schema import AgentState, Message, ToolCall, TOOL_CHOICE_TYPE, ToolChoice
from app.tool import CreateChatCompletion, Terminate, ToolCollection from app.tool import CreateChatCompletion, Terminate, ToolCollection
@ -25,7 +25,7 @@ class ToolCallAgent(ReActAgent):
available_tools: ToolCollection = ToolCollection( available_tools: ToolCollection = ToolCollection(
CreateChatCompletion(), Terminate() CreateChatCompletion(), Terminate()
) )
tool_choices: Literal["none", "auto", "required"] = "auto" tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
tool_calls: List[ToolCall] = Field(default_factory=list) tool_calls: List[ToolCall] = Field(default_factory=list)
@ -61,7 +61,7 @@ class ToolCallAgent(ReActAgent):
try: try:
# Handle different tool_choices modes # Handle different tool_choices modes
if self.tool_choices == "none": if self.tool_choices == ToolChoice.NONE:
if response.tool_calls: if response.tool_calls:
logger.warning( logger.warning(
f"🤔 Hmm, {self.name} tried to use tools when they weren't available!" f"🤔 Hmm, {self.name} tried to use tools when they weren't available!"
@ -81,11 +81,11 @@ class ToolCallAgent(ReActAgent):
) )
self.memory.add_message(assistant_msg) self.memory.add_message(assistant_msg)
if self.tool_choices == "required" and not self.tool_calls: if self.tool_choices == ToolChoice.REQUIRED and not self.tool_calls:
return True # Will be handled in act() return True # Will be handled in act()
# For 'auto' mode, continue with content if no commands but content exists # For 'auto' mode, continue with content if no commands but content exists
if self.tool_choices == "auto" and not self.tool_calls: if self.tool_choices == ToolChoice.AUTO and not self.tool_calls:
return bool(response.content) return bool(response.content)
return bool(self.tool_calls) return bool(self.tool_calls)
@ -101,7 +101,7 @@ class ToolCallAgent(ReActAgent):
async def act(self) -> str: async def act(self) -> str:
"""Execute tool calls and handle their results""" """Execute tool calls and handle their results"""
if not self.tool_calls: if not self.tool_calls:
if self.tool_choices == "required": if self.tool_choices == ToolChoice.REQUIRED:
raise ValueError(TOOL_CALL_REQUIRED) raise ValueError(TOOL_CALL_REQUIRED)
# Return last message content if no tool calls # Return last message content if no tool calls

View File

@ -8,7 +8,7 @@ from app.agent.base import BaseAgent
from app.flow.base import BaseFlow, PlanStepStatus from app.flow.base import BaseFlow, PlanStepStatus
from app.llm import LLM from app.llm import LLM
from app.logger import logger from app.logger import logger
from app.schema import AgentState, Message from app.schema import AgentState, Message, ToolChoice
from app.tool import PlanningTool from app.tool import PlanningTool
@ -124,7 +124,7 @@ class PlanningFlow(BaseFlow):
messages=[user_message], messages=[user_message],
system_msgs=[system_message], system_msgs=[system_message],
tools=[self.planning_tool.to_param()], tools=[self.planning_tool.to_param()],
tool_choice="required", tool_choice=ToolChoice.REQUIRED,
) )
# Process tool calls if present # Process tool calls if present

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Literal, Optional, Union from typing import Dict, List, Optional, Union
from openai import ( from openai import (
APIError, APIError,
@ -12,7 +12,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from app.config import LLMSettings, config from app.config import LLMSettings, config
from app.logger import logger # Assuming a logger is set up in your app from app.logger import logger # Assuming a logger is set up in your app
from app.schema import Message from app.schema import Message, TOOL_CHOICE_TYPE, ROLE_VALUES, TOOL_CHOICE_VALUES, ToolChoice
class LLM: class LLM:
@ -88,7 +88,7 @@ class LLM:
# Validate all messages have required fields # Validate all messages have required fields
for msg in formatted_messages: for msg in formatted_messages:
if msg["role"] not in ["system", "user", "assistant", "tool"]: if msg["role"] not in ROLE_VALUES:
raise ValueError(f"Invalid role: {msg['role']}") raise ValueError(f"Invalid role: {msg['role']}")
if "content" not in msg and "tool_calls" not in msg: if "content" not in msg and "tool_calls" not in msg:
raise ValueError( raise ValueError(
@ -187,7 +187,7 @@ class LLM:
system_msgs: Optional[List[Union[dict, Message]]] = None, system_msgs: Optional[List[Union[dict, Message]]] = None,
timeout: int = 60, timeout: int = 60,
tools: Optional[List[dict]] = None, tools: Optional[List[dict]] = None,
tool_choice: Literal["none", "auto", "required"] = "auto", tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO, # type: ignore
temperature: Optional[float] = None, temperature: Optional[float] = None,
**kwargs, **kwargs,
): ):
@ -213,7 +213,7 @@ class LLM:
""" """
try: try:
# Validate tool_choice # Validate tool_choice
if tool_choice not in ["none", "auto", "required"]: if tool_choice not in TOOL_CHOICE_VALUES:
raise ValueError(f"Invalid tool_choice: {tool_choice}") raise ValueError(f"Invalid tool_choice: {tool_choice}")
# Format messages # Format messages

View File

@ -3,6 +3,24 @@ from typing import Any, List, Literal, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class Role(str, Enum):
"""Message role options"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
ROLE_VALUES = tuple(role.value for role in Role)
ROLE_TYPE = Literal[ROLE_VALUES] # type: ignore
class ToolChoice(str, Enum):
"""Tool choice options"""
NONE = "none"
AUTO = "auto"
REQUIRED = "required"
TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice)
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES] # type: ignore
class AgentState(str, Enum): class AgentState(str, Enum):
"""Agent execution states""" """Agent execution states"""
@ -29,7 +47,7 @@ class ToolCall(BaseModel):
class Message(BaseModel): class Message(BaseModel):
"""Represents a chat message in the conversation""" """Represents a chat message in the conversation"""
role: Literal["system", "user", "assistant", "tool"] = Field(...) role: ROLE_TYPE = Field(...) # type: ignore
content: Optional[str] = Field(default=None) content: Optional[str] = Field(default=None)
tool_calls: Optional[List[ToolCall]] = Field(default=None) tool_calls: Optional[List[ToolCall]] = Field(default=None)
name: Optional[str] = Field(default=None) name: Optional[str] = Field(default=None)
@ -71,22 +89,22 @@ class Message(BaseModel):
@classmethod @classmethod
def user_message(cls, content: str) -> "Message": def user_message(cls, content: str) -> "Message":
"""Create a user message""" """Create a user message"""
return cls(role="user", content=content) return cls(role=Role.USER, content=content)
@classmethod @classmethod
def system_message(cls, content: str) -> "Message": def system_message(cls, content: str) -> "Message":
"""Create a system message""" """Create a system message"""
return cls(role="system", content=content) return cls(role=Role.SYSTEM, content=content)
@classmethod @classmethod
def assistant_message(cls, content: Optional[str] = None) -> "Message": def assistant_message(cls, content: Optional[str] = None) -> "Message":
"""Create an assistant message""" """Create an assistant message"""
return cls(role="assistant", content=content) return cls(role=Role.ASSISTANT, content=content)
@classmethod @classmethod
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message": def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
"""Create a tool message""" """Create a tool message"""
return cls(role="tool", content=content, name=name, tool_call_id=tool_call_id) return cls(role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id)
@classmethod @classmethod
def from_tool_calls( def from_tool_calls(
@ -103,7 +121,7 @@ class Message(BaseModel):
for call in tool_calls for call in tool_calls
] ]
return cls( return cls(
role="assistant", content=content, tool_calls=formatted_calls, **kwargs role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs
) )