Merge pull request #516 from matengm1/refactor/standardize-tool-choice-literals
Standardize literals for role and tool choice type definitions
This commit is contained in:
commit
067c59e39c
@ -1,12 +1,12 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import List, Literal, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from app.llm import LLM
|
from app.llm import LLM
|
||||||
from app.logger import logger
|
from app.logger import logger
|
||||||
from app.schema import AgentState, Memory, Message
|
from app.schema import AgentState, Memory, Message, ROLE_TYPE
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(BaseModel, ABC):
|
class BaseAgent(BaseModel, ABC):
|
||||||
@ -82,7 +82,7 @@ class BaseAgent(BaseModel, ABC):
|
|||||||
|
|
||||||
def update_memory(
|
def update_memory(
|
||||||
self,
|
self,
|
||||||
role: Literal["user", "system", "assistant", "tool"],
|
role: ROLE_TYPE, # type: ignore
|
||||||
content: str,
|
content: str,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Dict, List, Literal, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import Field, model_validator
|
from pydantic import Field, model_validator
|
||||||
|
|
||||||
from app.agent.toolcall import ToolCallAgent
|
from app.agent.toolcall import ToolCallAgent
|
||||||
from app.logger import logger
|
from app.logger import logger
|
||||||
from app.prompt.planning import NEXT_STEP_PROMPT, PLANNING_SYSTEM_PROMPT
|
from app.prompt.planning import NEXT_STEP_PROMPT, PLANNING_SYSTEM_PROMPT
|
||||||
from app.schema import Message, ToolCall
|
from app.schema import Message, TOOL_CHOICE_TYPE, ToolCall, ToolChoice
|
||||||
from app.tool import PlanningTool, Terminate, ToolCollection
|
from app.tool import PlanningTool, Terminate, ToolCollection
|
||||||
|
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ class PlanningAgent(ToolCallAgent):
|
|||||||
available_tools: ToolCollection = Field(
|
available_tools: ToolCollection = Field(
|
||||||
default_factory=lambda: ToolCollection(PlanningTool(), Terminate())
|
default_factory=lambda: ToolCollection(PlanningTool(), Terminate())
|
||||||
)
|
)
|
||||||
tool_choices: Literal["none", "auto", "required"] = "auto"
|
tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore
|
||||||
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
|
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
|
||||||
|
|
||||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||||
@ -212,7 +212,7 @@ class PlanningAgent(ToolCallAgent):
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
system_msgs=[Message.system_message(self.system_prompt)],
|
system_msgs=[Message.system_message(self.system_prompt)],
|
||||||
tools=self.available_tools.to_params(),
|
tools=self.available_tools.to_params(),
|
||||||
tool_choice="required",
|
tool_choice=ToolChoice.REQUIRED,
|
||||||
)
|
)
|
||||||
assistant_msg = Message.from_tool_calls(
|
assistant_msg = Message.from_tool_calls(
|
||||||
content=response.content, tool_calls=response.tool_calls
|
content=response.content, tool_calls=response.tool_calls
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from typing import Any, List, Literal, Optional, Union
|
from typing import Any, List, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -6,7 +7,7 @@ from pydantic import Field
|
|||||||
from app.agent.react import ReActAgent
|
from app.agent.react import ReActAgent
|
||||||
from app.logger import logger
|
from app.logger import logger
|
||||||
from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
|
from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
|
||||||
from app.schema import AgentState, Message, ToolCall
|
from app.schema import AgentState, Message, ToolCall, TOOL_CHOICE_TYPE, ToolChoice
|
||||||
from app.tool import CreateChatCompletion, Terminate, ToolCollection
|
from app.tool import CreateChatCompletion, Terminate, ToolCollection
|
||||||
|
|
||||||
|
|
||||||
@ -25,7 +26,7 @@ class ToolCallAgent(ReActAgent):
|
|||||||
available_tools: ToolCollection = ToolCollection(
|
available_tools: ToolCollection = ToolCollection(
|
||||||
CreateChatCompletion(), Terminate()
|
CreateChatCompletion(), Terminate()
|
||||||
)
|
)
|
||||||
tool_choices: Literal["none", "auto", "required"] = "auto"
|
tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore
|
||||||
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
|
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
|
||||||
|
|
||||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||||
@ -62,7 +63,7 @@ class ToolCallAgent(ReActAgent):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Handle different tool_choices modes
|
# Handle different tool_choices modes
|
||||||
if self.tool_choices == "none":
|
if self.tool_choices == ToolChoice.NONE:
|
||||||
if response.tool_calls:
|
if response.tool_calls:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"🤔 Hmm, {self.name} tried to use tools when they weren't available!"
|
f"🤔 Hmm, {self.name} tried to use tools when they weren't available!"
|
||||||
@ -82,11 +83,11 @@ class ToolCallAgent(ReActAgent):
|
|||||||
)
|
)
|
||||||
self.memory.add_message(assistant_msg)
|
self.memory.add_message(assistant_msg)
|
||||||
|
|
||||||
if self.tool_choices == "required" and not self.tool_calls:
|
if self.tool_choices == ToolChoice.REQUIRED and not self.tool_calls:
|
||||||
return True # Will be handled in act()
|
return True # Will be handled in act()
|
||||||
|
|
||||||
# For 'auto' mode, continue with content if no commands but content exists
|
# For 'auto' mode, continue with content if no commands but content exists
|
||||||
if self.tool_choices == "auto" and not self.tool_calls:
|
if self.tool_choices == ToolChoice.AUTO and not self.tool_calls:
|
||||||
return bool(response.content)
|
return bool(response.content)
|
||||||
|
|
||||||
return bool(self.tool_calls)
|
return bool(self.tool_calls)
|
||||||
@ -102,7 +103,7 @@ class ToolCallAgent(ReActAgent):
|
|||||||
async def act(self) -> str:
|
async def act(self) -> str:
|
||||||
"""Execute tool calls and handle their results"""
|
"""Execute tool calls and handle their results"""
|
||||||
if not self.tool_calls:
|
if not self.tool_calls:
|
||||||
if self.tool_choices == "required":
|
if self.tool_choices == ToolChoice.REQUIRED:
|
||||||
raise ValueError(TOOL_CALL_REQUIRED)
|
raise ValueError(TOOL_CALL_REQUIRED)
|
||||||
|
|
||||||
# Return last message content if no tool calls
|
# Return last message content if no tool calls
|
||||||
|
@ -8,7 +8,7 @@ from app.agent.base import BaseAgent
|
|||||||
from app.flow.base import BaseFlow, PlanStepStatus
|
from app.flow.base import BaseFlow, PlanStepStatus
|
||||||
from app.llm import LLM
|
from app.llm import LLM
|
||||||
from app.logger import logger
|
from app.logger import logger
|
||||||
from app.schema import AgentState, Message
|
from app.schema import AgentState, Message, ToolChoice
|
||||||
from app.tool import PlanningTool
|
from app.tool import PlanningTool
|
||||||
|
|
||||||
|
|
||||||
@ -124,7 +124,7 @@ class PlanningFlow(BaseFlow):
|
|||||||
messages=[user_message],
|
messages=[user_message],
|
||||||
system_msgs=[system_message],
|
system_msgs=[system_message],
|
||||||
tools=[self.planning_tool.to_param()],
|
tools=[self.planning_tool.to_param()],
|
||||||
tool_choice="required",
|
tool_choice=ToolChoice.REQUIRED,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process tool calls if present
|
# Process tool calls if present
|
||||||
|
10
app/llm.py
10
app/llm.py
@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Literal, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from openai import (
|
from openai import (
|
||||||
APIError,
|
APIError,
|
||||||
@ -12,7 +12,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||||||
|
|
||||||
from app.config import LLMSettings, config
|
from app.config import LLMSettings, config
|
||||||
from app.logger import logger # Assuming a logger is set up in your app
|
from app.logger import logger # Assuming a logger is set up in your app
|
||||||
from app.schema import Message
|
from app.schema import Message, TOOL_CHOICE_TYPE, ROLE_VALUES, TOOL_CHOICE_VALUES, ToolChoice
|
||||||
|
|
||||||
|
|
||||||
class LLM:
|
class LLM:
|
||||||
@ -88,7 +88,7 @@ class LLM:
|
|||||||
|
|
||||||
# Validate all messages have required fields
|
# Validate all messages have required fields
|
||||||
for msg in formatted_messages:
|
for msg in formatted_messages:
|
||||||
if msg["role"] not in ["system", "user", "assistant", "tool"]:
|
if msg["role"] not in ROLE_VALUES:
|
||||||
raise ValueError(f"Invalid role: {msg['role']}")
|
raise ValueError(f"Invalid role: {msg['role']}")
|
||||||
if "content" not in msg and "tool_calls" not in msg:
|
if "content" not in msg and "tool_calls" not in msg:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -187,7 +187,7 @@ class LLM:
|
|||||||
system_msgs: Optional[List[Union[dict, Message]]] = None,
|
system_msgs: Optional[List[Union[dict, Message]]] = None,
|
||||||
timeout: int = 300,
|
timeout: int = 300,
|
||||||
tools: Optional[List[dict]] = None,
|
tools: Optional[List[dict]] = None,
|
||||||
tool_choice: Literal["none", "auto", "required"] = "auto",
|
tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO, # type: ignore
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -213,7 +213,7 @@ class LLM:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Validate tool_choice
|
# Validate tool_choice
|
||||||
if tool_choice not in ["none", "auto", "required"]:
|
if tool_choice not in TOOL_CHOICE_VALUES:
|
||||||
raise ValueError(f"Invalid tool_choice: {tool_choice}")
|
raise ValueError(f"Invalid tool_choice: {tool_choice}")
|
||||||
|
|
||||||
# Format messages
|
# Format messages
|
||||||
|
@ -3,6 +3,24 @@ from typing import Any, List, Literal, Optional, Union
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class Role(str, Enum):
|
||||||
|
"""Message role options"""
|
||||||
|
SYSTEM = "system"
|
||||||
|
USER = "user"
|
||||||
|
ASSISTANT = "assistant"
|
||||||
|
TOOL = "tool"
|
||||||
|
|
||||||
|
ROLE_VALUES = tuple(role.value for role in Role)
|
||||||
|
ROLE_TYPE = Literal[ROLE_VALUES] # type: ignore
|
||||||
|
|
||||||
|
class ToolChoice(str, Enum):
|
||||||
|
"""Tool choice options"""
|
||||||
|
NONE = "none"
|
||||||
|
AUTO = "auto"
|
||||||
|
REQUIRED = "required"
|
||||||
|
|
||||||
|
TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice)
|
||||||
|
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES] # type: ignore
|
||||||
|
|
||||||
class AgentState(str, Enum):
|
class AgentState(str, Enum):
|
||||||
"""Agent execution states"""
|
"""Agent execution states"""
|
||||||
@ -29,7 +47,7 @@ class ToolCall(BaseModel):
|
|||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
"""Represents a chat message in the conversation"""
|
"""Represents a chat message in the conversation"""
|
||||||
|
|
||||||
role: Literal["system", "user", "assistant", "tool"] = Field(...)
|
role: ROLE_TYPE = Field(...) # type: ignore
|
||||||
content: Optional[str] = Field(default=None)
|
content: Optional[str] = Field(default=None)
|
||||||
tool_calls: Optional[List[ToolCall]] = Field(default=None)
|
tool_calls: Optional[List[ToolCall]] = Field(default=None)
|
||||||
name: Optional[str] = Field(default=None)
|
name: Optional[str] = Field(default=None)
|
||||||
@ -71,22 +89,22 @@ class Message(BaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def user_message(cls, content: str) -> "Message":
|
def user_message(cls, content: str) -> "Message":
|
||||||
"""Create a user message"""
|
"""Create a user message"""
|
||||||
return cls(role="user", content=content)
|
return cls(role=Role.USER, content=content)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def system_message(cls, content: str) -> "Message":
|
def system_message(cls, content: str) -> "Message":
|
||||||
"""Create a system message"""
|
"""Create a system message"""
|
||||||
return cls(role="system", content=content)
|
return cls(role=Role.SYSTEM, content=content)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def assistant_message(cls, content: Optional[str] = None) -> "Message":
|
def assistant_message(cls, content: Optional[str] = None) -> "Message":
|
||||||
"""Create an assistant message"""
|
"""Create an assistant message"""
|
||||||
return cls(role="assistant", content=content)
|
return cls(role=Role.ASSISTANT, content=content)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
|
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
|
||||||
"""Create a tool message"""
|
"""Create a tool message"""
|
||||||
return cls(role="tool", content=content, name=name, tool_call_id=tool_call_id)
|
return cls(role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tool_calls(
|
def from_tool_calls(
|
||||||
@ -103,7 +121,7 @@ class Message(BaseModel):
|
|||||||
for call in tool_calls
|
for call in tool_calls
|
||||||
]
|
]
|
||||||
return cls(
|
return cls(
|
||||||
role="assistant", content=content, tool_calls=formatted_calls, **kwargs
|
role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user