Merge pull request #573 from nezhazheng/main

Use the max_input_tokens configuration to constrain the agent’s token usage.
commit d35cd5ccf0 by mannaandpoem, 2025-03-15 14:11:19 +08:00 (committed by GitHub)
6 changed files with 173 additions and 11 deletions
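In short: LLM now counts prompt tokens with tiktoken, accumulates them across requests, and raises a non-retryable TokenLimitExceeded once the configured ceiling would be crossed; ToolCallAgent catches that condition and finishes the run instead of looping. The limit is opt-in: uncomment max_input_tokens under [llm] in config.toml (leaving it unset means unlimited). A minimal, illustrative sketch of the accounting, with an arbitrary ceiling and names mirroring the diff below (not part of the commit):

# Illustrative sketch only, not part of the commit.
MAX_INPUT_TOKENS = 100_000  # corresponds to the [llm] max_input_tokens setting; None means unlimited
total_input_tokens = 0      # cumulative prompt tokens across all requests in this process

def check_token_limit(needed: int) -> bool:
    """True if the next request still fits under the configured ceiling."""
    if MAX_INPUT_TOKENS is None:
        return True
    return total_input_tokens + needed <= MAX_INPUT_TOKENS

# Before each request the LLM wrapper estimates `needed`, raises TokenLimitExceeded
# if check_token_limit(needed) is False, and otherwise adds the prompt tokens
# reported by the API response to total_input_tokens.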

View File: app/agent/toolcall.py

@@ -4,6 +4,7 @@ from typing import Any, List, Optional, Union
from pydantic import Field

from app.agent.react import ReActAgent
+from app.exceptions import TokenLimitExceeded
from app.logger import logger
from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
from app.schema import TOOL_CHOICE_TYPE, AgentState, Message, ToolCall, ToolChoice
@@ -32,6 +33,7 @@ class ToolCallAgent(ReActAgent):
    max_steps: int = 30
    max_observe: Optional[Union[int, bool]] = None
+    max_input_tokens: Optional[int] = None

    async def think(self) -> bool:
        """Process current state and decide next actions using tools"""
@@ -39,15 +41,32 @@ class ToolCallAgent(ReActAgent):
            user_msg = Message.user_message(self.next_step_prompt)
            self.messages += [user_msg]

-        # Get response with tool options
-        response = await self.llm.ask_tool(
-            messages=self.messages,
-            system_msgs=[Message.system_message(self.system_prompt)]
-            if self.system_prompt
-            else None,
-            tools=self.available_tools.to_params(),
-            tool_choice=self.tool_choices,
-        )
+        try:
+            # Get response with tool options
+            response = await self.llm.ask_tool(
+                messages=self.messages,
+                system_msgs=[Message.system_message(self.system_prompt)]
+                if self.system_prompt
+                else None,
+                tools=self.available_tools.to_params(),
+                tool_choice=self.tool_choices,
+            )
+        except ValueError as e:
+            raise
+        except Exception as e:
+            # Check if this is a RetryError containing TokenLimitExceeded
+            if hasattr(e, "__cause__") and isinstance(e.__cause__, TokenLimitExceeded):
+                token_limit_error = e.__cause__
+                logger.error(f"🚨 Token limit error (from RetryError): {token_limit_error}")
+                self.memory.add_message(
+                    Message.assistant_message(
+                        f"Maximum token limit reached, cannot continue execution: {str(token_limit_error)}"
+                    )
+                )
+                self.state = AgentState.FINISHED
+                return False
+            raise

        self.tool_calls = response.tool_calls

        # Log response info
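Because the retry layer re-raises the final failure with the underlying exception attached, the agent inspects the exception's __cause__ chain instead of catching TokenLimitExceeded directly. A small, self-contained sketch of that pattern (stand-in names, independent of tenacity; not part of the commit):

class TokenLimitExceeded(Exception):
    """Stand-in for app.exceptions.TokenLimitExceeded."""

def call_llm():
    # Simulate a wrapper error whose __cause__ is the token-limit exception,
    # similar to what the retry layer surfaces to the agent.
    raise RuntimeError("retries exhausted") from TokenLimitExceeded("over budget")

try:
    call_llm()
except Exception as e:
    if isinstance(e.__cause__, TokenLimitExceeded):
        print(f"Token limit hit, stopping gracefully: {e.__cause__}")
    else:
        raise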

View File: app/config.py

@@ -20,6 +20,7 @@ class LLMSettings(BaseModel):
    base_url: str = Field(..., description="API base URL")
    api_key: str = Field(..., description="API key")
    max_tokens: int = Field(4096, description="Maximum number of tokens per request")
+    max_input_tokens: Optional[int] = Field(None, description="Maximum input tokens to use across all requests (None for unlimited)")
    temperature: float = Field(1.0, description="Sampling temperature")
    api_type: str = Field(..., description="AzureOpenai or Openai")
    api_version: str = Field(..., description="Azure Openai version if AzureOpenai")
@@ -118,6 +119,7 @@ class Config:
                "base_url": base_llm.get("base_url"),
                "api_key": base_llm.get("api_key"),
                "max_tokens": base_llm.get("max_tokens", 4096),
+                "max_input_tokens": base_llm.get("max_input_tokens"),
                "temperature": base_llm.get("temperature", 1.0),
                "api_type": base_llm.get("api_type", ""),
                "api_version": base_llm.get("api_version", ""),

View File: app/exceptions.py

@@ -3,3 +3,11 @@ class ToolError(Exception):
    def __init__(self, message):
        self.message = message

+
+class OpenManusError(Exception):
+    """Base exception for all OpenManus errors"""
+    pass
+
+
+class TokenLimitExceeded(OpenManusError):
+    """Exception raised when the token limit is exceeded"""
+    pass
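Since TokenLimitExceeded derives from OpenManusError, callers that only care about "any OpenManus failure" can catch the base class. A quick illustration (the raise-site helper is hypothetical, not from the commit):

from app.exceptions import OpenManusError, TokenLimitExceeded

def reserve_budget(needed: int, remaining: int) -> None:
    # Hypothetical helper: refuse work that would exceed the input-token budget.
    if needed > remaining:
        raise TokenLimitExceeded(f"need {needed} tokens, only {remaining} left")

try:
    reserve_budget(needed=5000, remaining=1200)
except OpenManusError as err:
    # Catching the base class also covers TokenLimitExceeded.
    print(f"OpenManus error: {err}")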

View File: app/llm.py

@@ -8,9 +8,11 @@ from openai import (
    OpenAIError,
    RateLimitError,
)

-from tenacity import retry, stop_after_attempt, wait_random_exponential
+import tiktoken
+from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type

from app.config import LLMSettings, config
+from app.exceptions import TokenLimitExceeded
from app.logger import logger  # Assuming a logger is set up in your app
from app.schema import (
    ROLE_VALUES,
@@ -49,6 +51,18 @@ class LLM:
            self.api_key = llm_config.api_key
            self.api_version = llm_config.api_version
            self.base_url = llm_config.base_url
+
+            # Add token counting related attributes
+            self.total_input_tokens = 0
+            self.max_input_tokens = llm_config.max_input_tokens if hasattr(llm_config, "max_input_tokens") else None
+
+            # Initialize tokenizer
+            try:
+                self.tokenizer = tiktoken.encoding_for_model(self.model)
+            except KeyError:
+                # If the model is not in tiktoken's presets, use cl100k_base as default
+                self.tokenizer = tiktoken.get_encoding("cl100k_base")
+
            if self.api_type == "azure":
                self.client = AsyncAzureOpenAI(
                    base_url=self.base_url,
@@ -58,6 +72,70 @@
            else:
                self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

+    def count_tokens(self, text: str) -> int:
+        """Calculate the number of tokens in a text"""
+        if not text:
+            return 0
+        return len(self.tokenizer.encode(text))
+
+    def count_message_tokens(self, messages: List[dict]) -> int:
+        """Calculate the number of tokens in a message list"""
+        token_count = 0
+        for message in messages:
+            # Base token count for each message (according to OpenAI's calculation method)
+            token_count += 4  # Base token count for each message
+
+            # Calculate tokens for the role
+            if "role" in message:
+                token_count += self.count_tokens(message["role"])
+
+            # Calculate tokens for the content
+            if "content" in message and message["content"]:
+                token_count += self.count_tokens(message["content"])
+
+            # Calculate tokens for tool calls
+            if "tool_calls" in message and message["tool_calls"]:
+                for tool_call in message["tool_calls"]:
+                    if "function" in tool_call:
+                        # Function name
+                        if "name" in tool_call["function"]:
+                            token_count += self.count_tokens(tool_call["function"]["name"])
+                        # Function arguments
+                        if "arguments" in tool_call["function"]:
+                            token_count += self.count_tokens(tool_call["function"]["arguments"])
+
+            # Calculate tokens for tool responses
+            if "name" in message and message["name"]:
+                token_count += self.count_tokens(message["name"])
+
+            if "tool_call_id" in message and message["tool_call_id"]:
+                token_count += self.count_tokens(message["tool_call_id"])
+
+            # Add extra tokens for message format
+            token_count += 2  # Extra tokens for message format
+
+        return token_count
+
+    def update_token_count(self, input_tokens: int) -> None:
+        """Update token counts"""
+        # Only track tokens if max_input_tokens is set
+        self.total_input_tokens += input_tokens
+        logger.info(f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}")
+
+    def check_token_limit(self, input_tokens: int) -> bool:
+        """Check if token limits are exceeded"""
+        if self.max_input_tokens is not None:
+            return (self.total_input_tokens + input_tokens) <= self.max_input_tokens
+        # If max_input_tokens is not set, always return True
+        return True
+
+    def get_limit_error_message(self, input_tokens: int) -> str:
+        """Generate error message for token limit exceeded"""
+        if self.max_input_tokens is not None and (self.total_input_tokens + input_tokens) > self.max_input_tokens:
+            return f"Request may exceed input token limit (Current: {self.total_input_tokens}, Needed: {input_tokens}, Max: {self.max_input_tokens})"
+
+        return "Token limit exceeded"
+
    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
@@ -109,6 +187,7 @@ class LLM:
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
+        retry=retry_if_exception_type((OpenAIError, Exception, ValueError)),  # Don't retry TokenLimitExceeded
    )
    async def ask(
        self,
@@ -130,6 +209,7 @@
            str: The generated response

        Raises:
+            TokenLimitExceeded: If token limits are exceeded
            ValueError: If messages are invalid or response is empty
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
@@ -141,6 +221,15 @@ class LLM:
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

+            # Calculate input token count
+            input_tokens = self.count_message_tokens(messages)
+
+            # Check if token limits are exceeded
+            if not self.check_token_limit(input_tokens):
+                error_message = self.get_limit_error_message(input_tokens)
+                # Raise a special exception that won't be retried
+                raise TokenLimitExceeded(error_message)
+
            params = {
                "model": self.model,
@@ -161,9 +250,15 @@
                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")

+                # Update token counts
+                self.update_token_count(response.usage.prompt_tokens)
+
                return response.choices[0].message.content

-            # Streaming request
+            # Streaming request, For streaming, update estimated token count before making the request
+            self.update_token_count(input_tokens)
+
            params["stream"] = True
            response = await self.client.chat.completions.create(**params)
@@ -177,13 +272,23 @@
            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")

            return full_response

+        except TokenLimitExceeded:
+            # Re-raise token limit errors without logging
+            raise
        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
+            if isinstance(oe, AuthenticationError):
+                logger.error("Authentication failed. Check API key.")
+            elif isinstance(oe, RateLimitError):
+                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
+            elif isinstance(oe, APIError):
+                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
@@ -192,6 +297,7 @@
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
+        retry=retry_if_exception_type((OpenAIError, Exception, ValueError)),  # Don't retry TokenLimitExceeded
    )
    async def ask_tool(
        self,
@@ -219,6 +325,7 @@
            ChatCompletionMessage: The model's response

        Raises:
+            TokenLimitExceeded: If token limits are exceeded
            ValueError: If tools, tool_choice, or messages are invalid
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
@@ -235,6 +342,23 @@
            else:
                messages = self.format_messages(messages)

+            # Calculate input token count
+            input_tokens = self.count_message_tokens(messages)
+
+            # If there are tools, calculate token count for tool descriptions
+            tools_tokens = 0
+            if tools:
+                for tool in tools:
+                    tools_tokens += self.count_tokens(str(tool))
+
+            input_tokens += tools_tokens
+
+            # Check if token limits are exceeded
+            if not self.check_token_limit(input_tokens):
+                error_message = self.get_limit_error_message(input_tokens)
+                # Raise a special exception that won't be retried
+                raise TokenLimitExceeded(error_message)
+
            # Validate tools if provided
            if tools:
                for tool in tools:
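Tool schemas are charged against the budget by tokenizing their string representation, which is a deliberately rough estimate rather than an exact count. A standalone illustration with tiktoken (the sample tool definition is made up; not part of the commit):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

sample_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}

# Mirrors ask_tool's accounting: str(tool) is encoded and added to the input estimate.
tools_tokens = len(enc.encode(str(sample_tool)))
print(f"Estimated tokens for tool schema: {tools_tokens}")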
@@ -263,13 +387,20 @@
            if not response.choices or not response.choices[0].message:
                print(response)
                raise ValueError("Invalid or empty response from LLM")

+            # Update token counts
+            self.update_token_count(response.usage.prompt_tokens)
+
            return response.choices[0].message

+        except TokenLimitExceeded:
+            # Re-raise token limit errors without logging
+            raise
        except ValueError as ve:
            logger.error(f"Validation error in ask_tool: {ve}")
            raise
        except OpenAIError as oe:
+            logger.error(f"OpenAI API error: {oe}")
            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):

View File: config/config.example.toml

@@ -5,6 +5,7 @@ base_url = "https://api.openai.com/v1"  # API endpoint URL
api_key = "sk-..."  # Your API key
max_tokens = 8192  # Maximum number of tokens in the response
temperature = 0.0  # Controls randomness
+#max_input_tokens = 100000  # Maximum input tokens to use across all requests (set to null or delete this line for unlimited)

# [llm] #AZURE OPENAI:
# api_type= 'azure'

View File: requirements.txt

@@ -6,6 +6,7 @@ loguru~=0.7.3
numpy
datasets~=3.2.0
fastapi~=0.115.11
+tiktoken~=0.9.0
html2text~=2024.2.26
gymnasium~=1.0.0