Standardize literals for role and tool choice type definitions
This commit is contained in:
parent
af8023de43
commit
eac3a6e24e
@ -1,12 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Literal, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.llm import LLM
|
||||
from app.logger import logger
|
||||
from app.schema import AgentState, Memory, Message
|
||||
from app.schema import AgentState, Memory, Message, ROLE_TYPE
|
||||
|
||||
|
||||
class BaseAgent(BaseModel, ABC):
|
||||
@ -82,7 +82,7 @@ class BaseAgent(BaseModel, ABC):
|
||||
|
||||
def update_memory(
|
||||
self,
|
||||
role: Literal["user", "system", "assistant", "tool"],
|
||||
role: ROLE_TYPE, # type: ignore
|
||||
content: str,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
|
@ -1,12 +1,12 @@
|
||||
import time
|
||||
from typing import Dict, List, Literal, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from app.agent.toolcall import ToolCallAgent
|
||||
from app.logger import logger
|
||||
from app.prompt.planning import NEXT_STEP_PROMPT, PLANNING_SYSTEM_PROMPT
|
||||
from app.schema import Message, ToolCall
|
||||
from app.schema import Message, TOOL_CHOICE_TYPE, ToolCall, ToolChoice
|
||||
from app.tool import PlanningTool, Terminate, ToolCollection
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ class PlanningAgent(ToolCallAgent):
|
||||
available_tools: ToolCollection = Field(
|
||||
default_factory=lambda: ToolCollection(PlanningTool(), Terminate())
|
||||
)
|
||||
tool_choices: Literal["none", "auto", "required"] = "auto"
|
||||
tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore
|
||||
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
|
||||
|
||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
@ -212,7 +212,7 @@ class PlanningAgent(ToolCallAgent):
|
||||
messages=messages,
|
||||
system_msgs=[Message.system_message(self.system_prompt)],
|
||||
tools=self.available_tools.to_params(),
|
||||
tool_choice="required",
|
||||
tool_choice=ToolChoice.REQUIRED,
|
||||
)
|
||||
assistant_msg = Message.from_tool_calls(
|
||||
content=response.content, tool_calls=response.tool_calls
|
||||
|
@ -1,12 +1,12 @@
|
||||
import json
|
||||
from typing import Any, List, Literal
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.agent.react import ReActAgent
|
||||
from app.logger import logger
|
||||
from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
|
||||
from app.schema import AgentState, Message, ToolCall
|
||||
from app.schema import AgentState, Message, ToolCall, TOOL_CHOICE_TYPE, ToolChoice
|
||||
from app.tool import CreateChatCompletion, Terminate, ToolCollection
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ class ToolCallAgent(ReActAgent):
|
||||
available_tools: ToolCollection = ToolCollection(
|
||||
CreateChatCompletion(), Terminate()
|
||||
)
|
||||
tool_choices: Literal["none", "auto", "required"] = "auto"
|
||||
tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore
|
||||
special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
|
||||
|
||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
@ -61,7 +61,7 @@ class ToolCallAgent(ReActAgent):
|
||||
|
||||
try:
|
||||
# Handle different tool_choices modes
|
||||
if self.tool_choices == "none":
|
||||
if self.tool_choices == ToolChoice.NONE:
|
||||
if response.tool_calls:
|
||||
logger.warning(
|
||||
f"🤔 Hmm, {self.name} tried to use tools when they weren't available!"
|
||||
@ -81,11 +81,11 @@ class ToolCallAgent(ReActAgent):
|
||||
)
|
||||
self.memory.add_message(assistant_msg)
|
||||
|
||||
if self.tool_choices == "required" and not self.tool_calls:
|
||||
if self.tool_choices == ToolChoice.REQUIRED and not self.tool_calls:
|
||||
return True # Will be handled in act()
|
||||
|
||||
# For 'auto' mode, continue with content if no commands but content exists
|
||||
if self.tool_choices == "auto" and not self.tool_calls:
|
||||
if self.tool_choices == ToolChoice.AUTO and not self.tool_calls:
|
||||
return bool(response.content)
|
||||
|
||||
return bool(self.tool_calls)
|
||||
@ -101,7 +101,7 @@ class ToolCallAgent(ReActAgent):
|
||||
async def act(self) -> str:
|
||||
"""Execute tool calls and handle their results"""
|
||||
if not self.tool_calls:
|
||||
if self.tool_choices == "required":
|
||||
if self.tool_choices == ToolChoice.REQUIRED:
|
||||
raise ValueError(TOOL_CALL_REQUIRED)
|
||||
|
||||
# Return last message content if no tool calls
|
||||
|
@ -8,7 +8,7 @@ from app.agent.base import BaseAgent
|
||||
from app.flow.base import BaseFlow, PlanStepStatus
|
||||
from app.llm import LLM
|
||||
from app.logger import logger
|
||||
from app.schema import AgentState, Message
|
||||
from app.schema import AgentState, Message, ToolChoice
|
||||
from app.tool import PlanningTool
|
||||
|
||||
|
||||
@ -124,7 +124,7 @@ class PlanningFlow(BaseFlow):
|
||||
messages=[user_message],
|
||||
system_msgs=[system_message],
|
||||
tools=[self.planning_tool.to_param()],
|
||||
tool_choice="required",
|
||||
tool_choice=ToolChoice.REQUIRED,
|
||||
)
|
||||
|
||||
# Process tool calls if present
|
||||
|
10
app/llm.py
10
app/llm.py
@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from openai import (
|
||||
APIError,
|
||||
@ -12,7 +12,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from app.config import LLMSettings, config
|
||||
from app.logger import logger # Assuming a logger is set up in your app
|
||||
from app.schema import Message
|
||||
from app.schema import Message, TOOL_CHOICE_TYPE, ROLE_VALUES, TOOL_CHOICE_VALUES, ToolChoice
|
||||
|
||||
|
||||
class LLM:
|
||||
@ -88,7 +88,7 @@ class LLM:
|
||||
|
||||
# Validate all messages have required fields
|
||||
for msg in formatted_messages:
|
||||
if msg["role"] not in ["system", "user", "assistant", "tool"]:
|
||||
if msg["role"] not in ROLE_VALUES:
|
||||
raise ValueError(f"Invalid role: {msg['role']}")
|
||||
if "content" not in msg and "tool_calls" not in msg:
|
||||
raise ValueError(
|
||||
@ -187,7 +187,7 @@ class LLM:
|
||||
system_msgs: Optional[List[Union[dict, Message]]] = None,
|
||||
timeout: int = 60,
|
||||
tools: Optional[List[dict]] = None,
|
||||
tool_choice: Literal["none", "auto", "required"] = "auto",
|
||||
tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO, # type: ignore
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -213,7 +213,7 @@ class LLM:
|
||||
"""
|
||||
try:
|
||||
# Validate tool_choice
|
||||
if tool_choice not in ["none", "auto", "required"]:
|
||||
if tool_choice not in TOOL_CHOICE_VALUES:
|
||||
raise ValueError(f"Invalid tool_choice: {tool_choice}")
|
||||
|
||||
# Format messages
|
||||
|
@ -3,6 +3,24 @@ from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Role(str, Enum):
|
||||
"""Message role options"""
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
ROLE_VALUES = tuple(role.value for role in Role)
|
||||
ROLE_TYPE = Literal[ROLE_VALUES] # type: ignore
|
||||
|
||||
class ToolChoice(str, Enum):
|
||||
"""Tool choice options"""
|
||||
NONE = "none"
|
||||
AUTO = "auto"
|
||||
REQUIRED = "required"
|
||||
|
||||
TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice)
|
||||
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES] # type: ignore
|
||||
|
||||
class AgentState(str, Enum):
|
||||
"""Agent execution states"""
|
||||
@ -29,7 +47,7 @@ class ToolCall(BaseModel):
|
||||
class Message(BaseModel):
|
||||
"""Represents a chat message in the conversation"""
|
||||
|
||||
role: Literal["system", "user", "assistant", "tool"] = Field(...)
|
||||
role: ROLE_TYPE = Field(...) # type: ignore
|
||||
content: Optional[str] = Field(default=None)
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None)
|
||||
name: Optional[str] = Field(default=None)
|
||||
@ -71,22 +89,22 @@ class Message(BaseModel):
|
||||
@classmethod
|
||||
def user_message(cls, content: str) -> "Message":
|
||||
"""Create a user message"""
|
||||
return cls(role="user", content=content)
|
||||
return cls(role=Role.USER, content=content)
|
||||
|
||||
@classmethod
|
||||
def system_message(cls, content: str) -> "Message":
|
||||
"""Create a system message"""
|
||||
return cls(role="system", content=content)
|
||||
return cls(role=Role.SYSTEM, content=content)
|
||||
|
||||
@classmethod
|
||||
def assistant_message(cls, content: Optional[str] = None) -> "Message":
|
||||
"""Create an assistant message"""
|
||||
return cls(role="assistant", content=content)
|
||||
return cls(role=Role.ASSISTANT, content=content)
|
||||
|
||||
@classmethod
|
||||
def tool_message(cls, content: str, name, tool_call_id: str) -> "Message":
|
||||
"""Create a tool message"""
|
||||
return cls(role="tool", content=content, name=name, tool_call_id=tool_call_id)
|
||||
return cls(role=Role.TOOL, content=content, name=name, tool_call_id=tool_call_id)
|
||||
|
||||
@classmethod
|
||||
def from_tool_calls(
|
||||
@ -103,7 +121,7 @@ class Message(BaseModel):
|
||||
for call in tool_calls
|
||||
]
|
||||
return cls(
|
||||
role="assistant", content=content, tool_calls=formatted_calls, **kwargs
|
||||
role=Role.ASSISTANT, content=content, tool_calls=formatted_calls, **kwargs
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user