Merge pull request #516 from matengm1/refactor/standardize-tool-choice-literals

Standardize literals for role and tool choice type definitions
This commit is contained in:
Isaac 2025-03-12 20:59:04 +08:00 committed by GitHub
commit 067c59e39c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 45 additions and 26 deletions

View File

@ -1,12 +1,12 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import List, Literal, Optional
from typing import List, Optional
from pydantic import BaseModel, Field, model_validator
from app.llm import LLM
from app.logger import logger
from app.schema import AgentState, Memory, Message
from app.schema import AgentState, Memory, Message, ROLE_TYPE
class BaseAgent(BaseModel, ABC):
@ -82,7 +82,7 @@ class BaseAgent(BaseModel, ABC):
def update_memory(
self,
role: Literal["user", "system", "assistant", "tool"],
role: ROLE_TYPE, # type: ignore
content: str,
**kwargs,
) -> None:

View File

@ -1,12 +1,12 @@
import time
from typing import Dict, List, Literal, Optional
from typing import Dict, List, Optional
from pydantic import Field, model_validator
from app.agent.toolcall import ToolCallAgent
from app.logger import logger
from app.prompt.planning import NEXT_STEP_PROMPT, PLANNING_SYSTEM_PROMPT
from app.schema import Message, ToolCall
from app.schema import Message, TOOL_CHOICE_TYPE, ToolCall, ToolChoice
from app.tool import PlanningTool, Terminate, ToolCollection
@ -27,7 +27,7 @@ class PlanningAgent(ToolCallAgent):
available_tools: ToolCollection = Field(
default_factory=lambda: ToolCollection(PlanningTool(), Terminate())
)
tool_choices: Literal["none", "auto", "required"] = "auto"
tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
tool_calls: List[ToolCall] = Field(default_factory=list)
@ -212,7 +212,7 @@ class PlanningAgent(ToolCallAgent):
messages=messages,
system_msgs=[Message.system_message(self.system_prompt)],
tools=self.available_tools.to_params(),
tool_choice="required",
tool_choice=ToolChoice.REQUIRED,
)
assistant_msg = Message.from_tool_calls(
content=response.content, tool_calls=response.tool_calls

View File

@ -1,4 +1,5 @@
import json
from typing import Any, List, Literal, Optional, Union
from pydantic import Field
@ -6,7 +7,7 @@ from pydantic import Field
from app.agent.react import ReActAgent
from app.logger import logger
from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
from app.schema import AgentState, Message, ToolCall
from app.schema import AgentState, Message, ToolCall, TOOL_CHOICE_TYPE, ToolChoice
from app.tool import CreateChatCompletion, Terminate, ToolCollection
@ -25,7 +26,7 @@ class ToolCallAgent(ReActAgent):
available_tools: ToolCollection = ToolCollection(
CreateChatCompletion(), Terminate()
)
tool_choices: Literal["none", "auto", "required"] = "auto"
tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
tool_calls: List[ToolCall] = Field(default_factory=list)
@ -62,7 +63,7 @@ class ToolCallAgent(ReActAgent):
try:
# Handle different tool_choices modes
if self.tool_choices == "none":
if self.tool_choices == ToolChoice.NONE:
if response.tool_calls:
logger.warning(
f"🤔 Hmm, {self.name} tried to use tools when they weren't available!"
@ -82,11 +83,11 @@ class ToolCallAgent(ReActAgent):
)
self.memory.add_message(assistant_msg)
if self.tool_choices == "required" and not self.tool_calls:
if self.tool_choices == ToolChoice.REQUIRED and not self.tool_calls:
return True # Will be handled in act()
# For 'auto' mode, continue with content if no commands but content exists
if self.tool_choices == "auto" and not self.tool_calls:
if self.tool_choices == ToolChoice.AUTO and not self.tool_calls:
return bool(response.content)
return bool(self.tool_calls)
@ -102,7 +103,7 @@ class ToolCallAgent(ReActAgent):
async def act(self) -> str:
"""Execute tool calls and handle their results"""
if not self.tool_calls:
if self.tool_choices == "required":
if self.tool_choices == ToolChoice.REQUIRED:
raise ValueError(TOOL_CALL_REQUIRED)
# Return last message content if no tool calls

View File

@ -8,7 +8,7 @@ from app.agent.base import BaseAgent
from app.flow.base import BaseFlow, PlanStepStatus
from app.llm import LLM
from app.logger import logger
from app.schema import AgentState, Message
from app.schema import AgentState, Message, ToolChoice
from app.tool import PlanningTool
@ -124,7 +124,7 @@ class PlanningFlow(BaseFlow):
messages=[user_message],
system_msgs=[system_message],
tools=[self.planning_tool.to_param()],
tool_choice="required",
tool_choice=ToolChoice.REQUIRED,
)
# Process tool calls if present

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, List, Optional, Union
from openai import (
APIError,
@ -12,7 +12,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from app.config import LLMSettings, config
from app.logger import logger # Assuming a logger is set up in your app
from app.schema import Message
from app.schema import Message, TOOL_CHOICE_TYPE, ROLE_VALUES, TOOL_CHOICE_VALUES, ToolChoice
class LLM:
@ -88,7 +88,7 @@ class LLM:
# Validate all messages have required fields
for msg in formatted_messages:
if msg["role"] not in ["system", "user", "assistant", "tool"]:
if msg["role"] not in ROLE_VALUES:
raise ValueError(f"Invalid role: {msg['role']}")
if "content" not in msg and "tool_calls" not in msg:
raise ValueError(
@ -187,7 +187,7 @@ class LLM:
system_msgs: Optional[List[Union[dict, Message]]] = None,
timeout: int = 300,
tools: Optional[List[dict]] = None,
tool_choice: Literal["none", "auto", "required"] = "auto",
tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO, # type: ignore
temperature: Optional[float] = None,
**kwargs,
):
@ -213,7 +213,7 @@ class LLM:
"""
try:
# Validate tool_choice
if tool_choice not in ["none", "auto", "required"]:
if tool_choice not in TOOL_CHOICE_VALUES:
raise ValueError(f"Invalid tool_choice: {tool_choice}")
# Format messages

View File

@ -3,6 +3,24 @@ from typing import Any, List, Literal, Optional, Union
from pydantic import BaseModel, Field
class Role(str, Enum):
"""Message role options"""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
ROLE_VALUES = tuple(role.value for role in Role)
ROLE_TYPE = Literal[ROLE_VALUES] # type: ignore
class ToolChoice(str, Enum):
"""Tool choice options"""
NONE = "none"
AUTO = "auto"
REQUIRED = "required"
TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice)
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES] # type: ignore
class AgentState(str, Enum):
"""Agent execution states"""
@ -29,7 +47,7 @@ class ToolCall(BaseModel):
class Message(BaseModel):
"""Represents a chat message in the conversation"""
role: Literal["system", "user", "assistant", "tool"] = Field(...)
role: ROLE_TYPE = Field(...) # type: ignore
content: Optional[str] = Field(default=None)
tool_calls: Optional[List[ToolCall]] = Field(default=None)
name: Optional[str] = Field(default=None)
@ -71,22 +89,22 @@ class Message(BaseModel):
@classmethod
def user_message(cls, content: str) -> "Message":
"""Create a user message"""
return cls(role="user", content=content)
return cls(role=Role.USER, content=content)
@classmethod
def system_message(cls, content: str) -> "Message":
"""Create a system message"""
return cls(role="system", content=content)
return cls(role=Role.SYSTEM, content=content)
@classmethod
def assistant_message(cls, content: Optional[str] = None) -> "Message":
"""Create an assistant message"""
return cls(role="assistant", content=content)
return cls(role=Role.ASSISTANT, content=content)
@classmethod
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
"""Create a tool message"""
return cls(role="tool", content=content, name=name, tool_call_id=tool_call_id)
return cls(role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id)
@classmethod
def from_tool_calls(
@ -103,7 +121,7 @@ class Message(BaseModel):
for call in tool_calls
]
return cls(
role="assistant", content=content, tool_calls=formatted_calls, **kwargs
role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs
)