Use the max_input_tokens configuration to constrain the agent’s token usage.

This commit is contained in:
zhengshuli 2025-03-13 15:46:51 +08:00
parent 7db0b2fbf0
commit 9b0b69a5e1
6 changed files with 173 additions and 11 deletions

View File

@ -4,6 +4,7 @@ from typing import Any, List, Optional, Union
from pydantic import Field
from app.agent.react import ReActAgent
from app.exceptions import TokenLimitExceeded
from app.logger import logger
from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
from app.schema import TOOL_CHOICE_TYPE, AgentState, Message, ToolCall, ToolChoice
@ -32,6 +33,7 @@ class ToolCallAgent(ReActAgent):
max_steps: int = 30
max_observe: Optional[Union[int, bool]] = None
max_input_tokens: Optional[int] = None
async def think(self) -> bool:
"""Process current state and decide next actions using tools"""
@ -39,6 +41,7 @@ class ToolCallAgent(ReActAgent):
user_msg = Message.user_message(self.next_step_prompt)
self.messages += [user_msg]
try:
# Get response with tool options
response = await self.llm.ask_tool(
messages=self.messages,
@ -48,6 +51,22 @@ class ToolCallAgent(ReActAgent):
tools=self.available_tools.to_params(),
tool_choice=self.tool_choices,
)
except ValueError as e:
raise
except Exception as e:
# Check if this is a RetryError containing TokenLimitExceeded
if hasattr(e, "__cause__") and isinstance(e.__cause__, TokenLimitExceeded):
token_limit_error = e.__cause__
logger.error(f"🚨 Token limit error (from RetryError): {token_limit_error}")
self.memory.add_message(
Message.assistant_message(
f"Maximum token limit reached, cannot continue execution: {str(token_limit_error)}"
)
)
self.state = AgentState.FINISHED
return False
raise
self.tool_calls = response.tool_calls
# Log response info

View File

@ -20,6 +20,7 @@ class LLMSettings(BaseModel):
base_url: str = Field(..., description="API base URL")
api_key: str = Field(..., description="API key")
max_tokens: int = Field(4096, description="Maximum number of tokens per request")
max_input_tokens: Optional[int] = Field(None, description="Maximum input tokens to use across all requests (None for unlimited)")
temperature: float = Field(1.0, description="Sampling temperature")
api_type: str = Field(..., description="AzureOpenai or Openai")
api_version: str = Field(..., description="Azure Openai version if AzureOpenai")
@ -118,6 +119,7 @@ class Config:
"base_url": base_llm.get("base_url"),
"api_key": base_llm.get("api_key"),
"max_tokens": base_llm.get("max_tokens", 4096),
"max_input_tokens": base_llm.get("max_input_tokens"),
"temperature": base_llm.get("temperature", 1.0),
"api_type": base_llm.get("api_type", ""),
"api_version": base_llm.get("api_version", ""),

View File

@ -3,3 +3,11 @@ class ToolError(Exception):
def __init__(self, message):
self.message = message
class OpenManusError(Exception):
"""Base exception for all OpenManus errors"""
pass
class TokenLimitExceeded(OpenManusError):
"""Exception raised when the token limit is exceeded"""
pass

View File

@ -8,9 +8,11 @@ from openai import (
OpenAIError,
RateLimitError,
)
from tenacity import retry, stop_after_attempt, wait_random_exponential
import tiktoken
from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type
from app.config import LLMSettings, config
from app.exceptions import TokenLimitExceeded
from app.logger import logger # Assuming a logger is set up in your app
from app.schema import (
ROLE_VALUES,
@ -49,6 +51,18 @@ class LLM:
self.api_key = llm_config.api_key
self.api_version = llm_config.api_version
self.base_url = llm_config.base_url
# Add token counting related attributes
self.total_input_tokens = 0
self.max_input_tokens = llm_config.max_input_tokens if hasattr(llm_config, "max_input_tokens") else None
# Initialize tokenizer
try:
self.tokenizer = tiktoken.encoding_for_model(self.model)
except KeyError:
# If the model is not in tiktoken's presets, use cl100k_base as default
self.tokenizer = tiktoken.get_encoding("cl100k_base")
if self.api_type == "azure":
self.client = AsyncAzureOpenAI(
base_url=self.base_url,
@ -58,6 +72,70 @@ class LLM:
else:
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
def count_tokens(self, text: str) -> int:
"""Calculate the number of tokens in a text"""
if not text:
return 0
return len(self.tokenizer.encode(text))
def count_message_tokens(self, messages: List[dict]) -> int:
"""Calculate the number of tokens in a message list"""
token_count = 0
for message in messages:
# Base token count for each message (according to OpenAI's calculation method)
token_count += 4 # Base token count for each message
# Calculate tokens for the role
if "role" in message:
token_count += self.count_tokens(message["role"])
# Calculate tokens for the content
if "content" in message and message["content"]:
token_count += self.count_tokens(message["content"])
# Calculate tokens for tool calls
if "tool_calls" in message and message["tool_calls"]:
for tool_call in message["tool_calls"]:
if "function" in tool_call:
# Function name
if "name" in tool_call["function"]:
token_count += self.count_tokens(tool_call["function"]["name"])
# Function arguments
if "arguments" in tool_call["function"]:
token_count += self.count_tokens(tool_call["function"]["arguments"])
# Calculate tokens for tool responses
if "name" in message and message["name"]:
token_count += self.count_tokens(message["name"])
if "tool_call_id" in message and message["tool_call_id"]:
token_count += self.count_tokens(message["tool_call_id"])
# Add extra tokens for message format
token_count += 2 # Extra tokens for message format
return token_count
def update_token_count(self, input_tokens: int) -> None:
"""Update token counts"""
# Only track tokens if max_input_tokens is set
self.total_input_tokens += input_tokens
logger.info(f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}")
def check_token_limit(self, input_tokens: int) -> bool:
"""Check if token limits are exceeded"""
if self.max_input_tokens is not None:
return (self.total_input_tokens + input_tokens) <= self.max_input_tokens
# If max_input_tokens is not set, always return True
return True
def get_limit_error_message(self, input_tokens: int) -> str:
"""Generate error message for token limit exceeded"""
if self.max_input_tokens is not None and (self.total_input_tokens + input_tokens) > self.max_input_tokens:
return f"Request may exceed input token limit (Current: {self.total_input_tokens}, Needed: {input_tokens}, Max: {self.max_input_tokens})"
return "Token limit exceeded"
@staticmethod
def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
"""
@ -109,6 +187,7 @@ class LLM:
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
retry=retry_if_exception_type((OpenAIError, Exception, ValueError)), # Don't retry TokenLimitExceeded
)
async def ask(
self,
@ -130,6 +209,7 @@ class LLM:
str: The generated response
Raises:
TokenLimitExceeded: If token limits are exceeded
ValueError: If messages are invalid or response is empty
OpenAIError: If API call fails after retries
Exception: For unexpected errors
@ -142,6 +222,15 @@ class LLM:
else:
messages = self.format_messages(messages)
# Calculate input token count
input_tokens = self.count_message_tokens(messages)
# Check if token limits are exceeded
if not self.check_token_limit(input_tokens):
error_message = self.get_limit_error_message(input_tokens)
# Raise a special exception that won't be retried
raise TokenLimitExceeded(error_message)
params = {
"model": self.model,
"messages": messages,
@ -161,9 +250,15 @@ class LLM:
if not response.choices or not response.choices[0].message.content:
raise ValueError("Empty or invalid response from LLM")
# Update token counts
self.update_token_count(response.usage.prompt_tokens)
return response.choices[0].message.content
# Streaming request
# Streaming request, For streaming, update estimated token count before making the request
self.update_token_count(input_tokens)
params["stream"] = True
response = await self.client.chat.completions.create(**params)
@ -177,13 +272,23 @@ class LLM:
full_response = "".join(collected_messages).strip()
if not full_response:
raise ValueError("Empty response from streaming LLM")
return full_response
except TokenLimitExceeded:
# Re-raise token limit errors without logging
raise
except ValueError as ve:
logger.error(f"Validation error: {ve}")
raise
except OpenAIError as oe:
logger.error(f"OpenAI API error: {oe}")
if isinstance(oe, AuthenticationError):
logger.error("Authentication failed. Check API key.")
elif isinstance(oe, RateLimitError):
logger.error("Rate limit exceeded. Consider increasing retry attempts.")
elif isinstance(oe, APIError):
logger.error(f"API error: {oe}")
raise
except Exception as e:
logger.error(f"Unexpected error in ask: {e}")
@ -192,6 +297,7 @@ class LLM:
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
retry=retry_if_exception_type((OpenAIError, Exception, ValueError)), # Don't retry TokenLimitExceeded
)
async def ask_tool(
self,
@ -219,6 +325,7 @@ class LLM:
ChatCompletionMessage: The model's response
Raises:
TokenLimitExceeded: If token limits are exceeded
ValueError: If tools, tool_choice, or messages are invalid
OpenAIError: If API call fails after retries
Exception: For unexpected errors
@ -235,6 +342,23 @@ class LLM:
else:
messages = self.format_messages(messages)
# Calculate input token count
input_tokens = self.count_message_tokens(messages)
# If there are tools, calculate token count for tool descriptions
tools_tokens = 0
if tools:
for tool in tools:
tools_tokens += self.count_tokens(str(tool))
input_tokens += tools_tokens
# Check if token limits are exceeded
if not self.check_token_limit(input_tokens):
error_message = self.get_limit_error_message(input_tokens)
# Raise a special exception that won't be retried
raise TokenLimitExceeded(error_message)
# Validate tools if provided
if tools:
for tool in tools:
@ -264,12 +388,19 @@ class LLM:
print(response)
raise ValueError("Invalid or empty response from LLM")
# Update token counts
self.update_token_count(response.usage.prompt_tokens)
return response.choices[0].message
except TokenLimitExceeded:
# Re-raise token limit errors without logging
raise
except ValueError as ve:
logger.error(f"Validation error in ask_tool: {ve}")
raise
except OpenAIError as oe:
logger.error(f"OpenAI API error: {oe}")
if isinstance(oe, AuthenticationError):
logger.error("Authentication failed. Check API key.")
elif isinstance(oe, RateLimitError):

View File

@ -5,6 +5,7 @@ base_url = "https://api.openai.com/v1" # API endpoint URL
api_key = "sk-..." # Your API key
max_tokens = 8192 # Maximum number of tokens in the response
temperature = 0.0 # Controls randomness
#max_input_tokens = 100000 # Maximum input tokens to use across all requests (set to null or delete this line for unlimited)
# [llm] #AZURE OPENAI:
# api_type= 'azure'

View File

@ -6,6 +6,7 @@ loguru~=0.7.3
numpy
datasets~=3.2.0
fastapi~=0.115.11
tiktoken~=0.9.0
html2text~=2024.2.26
gymnasium~=1.0.0