format code and remove max_input_tokens for ToolCallAgent

liangxinbing 2025-03-15 14:43:07 +08:00
parent 86399b97d6
commit 65a3898592
4 changed files with 66 additions and 40 deletions


@@ -33,7 +33,6 @@ class ToolCallAgent(ReActAgent):
     max_steps: int = 30
     max_observe: Optional[Union[int, bool]] = None
-    max_input_tokens: Optional[int] = None
 
     async def think(self) -> bool:
         """Process current state and decide next actions using tools"""
@@ -51,13 +50,15 @@ class ToolCallAgent(ReActAgent):
                 tools=self.available_tools.to_params(),
                 tool_choice=self.tool_choices,
             )
-        except ValueError as e:
+        except ValueError:
             raise
         except Exception as e:
             # Check if this is a RetryError containing TokenLimitExceeded
             if hasattr(e, "__cause__") and isinstance(e.__cause__, TokenLimitExceeded):
                 token_limit_error = e.__cause__
-                logger.error(f"🚨 Token limit error (from RetryError): {token_limit_error}")
+                logger.error(
+                    f"🚨 Token limit error (from RetryError): {token_limit_error}"
+                )
                 self.memory.add_message(
                     Message.assistant_message(
                         f"Maximum token limit reached, cannot continue execution: {str(token_limit_error)}"


@@ -20,7 +20,10 @@ class LLMSettings(BaseModel):
     base_url: str = Field(..., description="API base URL")
     api_key: str = Field(..., description="API key")
     max_tokens: int = Field(4096, description="Maximum number of tokens per request")
-    max_input_tokens: Optional[int] = Field(None, description="Maximum input tokens to use across all requests (None for unlimited)")
+    max_input_tokens: Optional[int] = Field(
+        None,
+        description="Maximum input tokens to use across all requests (None for unlimited)",
+    )
     temperature: float = Field(1.0, description="Sampling temperature")
     api_type: str = Field(..., description="AzureOpenai or Openai")
     api_version: str = Field(..., description="Azure Openai version if AzureOpenai")
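A standalone sketch of the reformatted field's behavior (pydantic BaseModel/Field as in the diff; LLMSettingsSketch is a hypothetical stand-in, not the project's config loader): None leaves the cumulative input budget uncapped, while an integer sets a hard cap.

from typing import Optional

from pydantic import BaseModel, Field


class LLMSettingsSketch(BaseModel):
    max_tokens: int = Field(4096, description="Maximum number of tokens per request")
    max_input_tokens: Optional[int] = Field(
        None,
        description="Maximum input tokens to use across all requests (None for unlimited)",
    )


unlimited = LLMSettingsSketch()  # max_input_tokens=None -> no cumulative cap
capped = LLMSettingsSketch(max_input_tokens=32_000)  # 32k-token cumulative input budget
print(unlimited.max_input_tokens, capped.max_input_tokens)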


@@ -4,10 +4,10 @@ class ToolError(Exception):
     def __init__(self, message):
         self.message = message
 
 
 class OpenManusError(Exception):
     """Base exception for all OpenManus errors"""
-    pass
 
 
 class TokenLimitExceeded(OpenManusError):
     """Exception raised when the token limit is exceeded"""
-    pass
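Dropping the pass statements is safe because a docstring is already a valid class body. A small stand-in sketch (mirroring, not importing, the classes above) showing that TokenLimitExceeded is still caught through its OpenManusError base:

class OpenManusError(Exception):
    """Base exception for all OpenManus errors"""


class TokenLimitExceeded(OpenManusError):
    """Exception raised when the token limit is exceeded"""


try:
    raise TokenLimitExceeded("request would exceed max_input_tokens")
except OpenManusError as err:
    print(type(err).__name__, err)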


@@ -1,5 +1,6 @@
 from typing import Dict, List, Optional, Union
 
+import tiktoken
 from openai import (
     APIError,
     AsyncAzureOpenAI,
@@ -8,8 +9,12 @@ from openai import (
     OpenAIError,
     RateLimitError,
 )
-import tiktoken
-from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 
 from app.config import LLMSettings, config
 from app.exceptions import TokenLimitExceeded
@@ -51,18 +56,22 @@ class LLM:
             self.api_key = llm_config.api_key
             self.api_version = llm_config.api_version
             self.base_url = llm_config.base_url
 
             # Add token counting related attributes
             self.total_input_tokens = 0
-            self.max_input_tokens = llm_config.max_input_tokens if hasattr(llm_config, "max_input_tokens") else None
+            self.max_input_tokens = (
+                llm_config.max_input_tokens
+                if hasattr(llm_config, "max_input_tokens")
+                else None
+            )
 
             # Initialize tokenizer
             try:
                 self.tokenizer = tiktoken.encoding_for_model(self.model)
             except KeyError:
                 # If the model is not in tiktoken's presets, use cl100k_base as default
                 self.tokenizer = tiktoken.get_encoding("cl100k_base")
 
             if self.api_type == "azure":
                 self.client = AsyncAzureOpenAI(
                     base_url=self.base_url,
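A standalone sketch of the tokenizer fallback above (assumes tiktoken is installed; the model name is illustrative): models unknown to tiktoken fall back to the cl100k_base encoding, which is then used for all token counting.

import tiktoken


def get_tokenizer(model: str):
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        # Model not in tiktoken's presets -> use cl100k_base as a default
        return tiktoken.get_encoding("cl100k_base")


tok = get_tokenizer("my-custom-model")  # hypothetical model name
print(len(tok.encode("Maximum token limit reached")))  # token count for a short string
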
@@ -77,51 +86,57 @@ class LLM:
         if not text:
             return 0
         return len(self.tokenizer.encode(text))
 
     def count_message_tokens(self, messages: List[dict]) -> int:
         """Calculate the number of tokens in a message list"""
         token_count = 0
         for message in messages:
             # Base token count for each message (according to OpenAI's calculation method)
             token_count += 4  # Base token count for each message
 
             # Calculate tokens for the role
             if "role" in message:
                 token_count += self.count_tokens(message["role"])
 
             # Calculate tokens for the content
             if "content" in message and message["content"]:
                 token_count += self.count_tokens(message["content"])
 
             # Calculate tokens for tool calls
             if "tool_calls" in message and message["tool_calls"]:
                 for tool_call in message["tool_calls"]:
                     if "function" in tool_call:
                         # Function name
                         if "name" in tool_call["function"]:
-                            token_count += self.count_tokens(tool_call["function"]["name"])
+                            token_count += self.count_tokens(
+                                tool_call["function"]["name"]
+                            )
 
                         # Function arguments
                         if "arguments" in tool_call["function"]:
-                            token_count += self.count_tokens(tool_call["function"]["arguments"])
+                            token_count += self.count_tokens(
+                                tool_call["function"]["arguments"]
+                            )
 
             # Calculate tokens for tool responses
             if "name" in message and message["name"]:
                 token_count += self.count_tokens(message["name"])
             if "tool_call_id" in message and message["tool_call_id"]:
                 token_count += self.count_tokens(message["tool_call_id"])
 
             # Add extra tokens for message format
             token_count += 2  # Extra tokens for message format
 
         return token_count
 
     def update_token_count(self, input_tokens: int) -> None:
         """Update token counts"""
         # Only track tokens if max_input_tokens is set
         self.total_input_tokens += input_tokens
-        logger.info(f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}")
+        logger.info(
+            f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}"
+        )
 
     def check_token_limit(self, input_tokens: int) -> bool:
         """Check if token limits are exceeded"""
         if self.max_input_tokens is not None:
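A simplified, standalone sketch of the counting heuristic above (assumes tiktoken and the cl100k_base encoding; tool_call fields are omitted for brevity; the 4- and 2-token per-message overheads follow the comments in the diff):

from typing import List

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")


def count_message_tokens(messages: List[dict]) -> int:
    def count(text: str) -> int:
        return len(enc.encode(text)) if text else 0

    total = 0
    for message in messages:
        total += 4  # base overhead per message
        total += count(message.get("role", ""))
        total += count(message.get("content") or "")
        total += 2  # extra tokens for message format
    return total


messages = [
    {"role": "system", "content": "You are a helpful agent."},
    {"role": "user", "content": "Summarize the repository."},
]
print(count_message_tokens(messages))  # prints the estimated prompt token count
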
@@ -131,9 +146,12 @@ class LLM:
     def get_limit_error_message(self, input_tokens: int) -> str:
         """Generate error message for token limit exceeded"""
-        if self.max_input_tokens is not None and (self.total_input_tokens + input_tokens) > self.max_input_tokens:
+        if (
+            self.max_input_tokens is not None
+            and (self.total_input_tokens + input_tokens) > self.max_input_tokens
+        ):
             return f"Request may exceed input token limit (Current: {self.total_input_tokens}, Needed: {input_tokens}, Max: {self.max_input_tokens})"
 
         return "Token limit exceeded"
 
     @staticmethod
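Hypothetical numbers illustrating the cumulative check above: with a 4,096-token budget, 3,900 tokens already consumed and 300 more needed, 3,900 + 300 > 4,096, so the request is rejected before it is sent.

max_input_tokens = 4096
total_input_tokens = 3900  # already consumed across earlier requests
input_tokens = 300  # estimate for the next request

if (total_input_tokens + input_tokens) > max_input_tokens:
    print(
        f"Request may exceed input token limit "
        f"(Current: {total_input_tokens}, Needed: {input_tokens}, Max: {max_input_tokens})"
    )
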
@@ -187,7 +205,9 @@ class LLM:
     @retry(
         wait=wait_random_exponential(min=1, max=60),
         stop=stop_after_attempt(6),
-        retry=retry_if_exception_type((OpenAIError, Exception, ValueError)),  # Don't retry TokenLimitExceeded
+        retry=retry_if_exception_type(
+            (OpenAIError, Exception, ValueError)
+        ),  # Don't retry TokenLimitExceeded
     )
     async def ask(
         self,
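The reformatted decorator keeps the original retry policy unchanged. As a hedged aside, a minimal sketch (tenacity assumed installed; TokenLimitExceeded is a stand-in class, not the project's import) of one way the "don't retry TokenLimitExceeded" intent in the trailing comment can be expressed explicitly with tenacity predicates:

from tenacity import (
    retry,
    retry_if_exception_type,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)


class TokenLimitExceeded(Exception):
    """Stand-in for the project's token-limit exception."""


@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    retry=retry_if_exception_type(Exception)
    & retry_if_not_exception_type(TokenLimitExceeded),
)
async def ask_sketch():
    ...  # call the LLM; TokenLimitExceeded propagates immediately, other errors are retried
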
@@ -221,10 +241,10 @@ class LLM:
                 messages = system_msgs + self.format_messages(messages)
             else:
                 messages = self.format_messages(messages)
 
             # Calculate input token count
             input_tokens = self.count_message_tokens(messages)
 
             # Check if token limits are exceeded
             if not self.check_token_limit(input_tokens):
                 error_message = self.get_limit_error_message(input_tokens)
@@ -252,15 +272,15 @@ class LLM:
                 if not response.choices or not response.choices[0].message.content:
                     raise ValueError("Empty or invalid response from LLM")
 
                 # Update token counts
                 self.update_token_count(response.usage.prompt_tokens)
 
                 return response.choices[0].message.content
 
             # Streaming request, For streaming, update estimated token count before making the request
             self.update_token_count(input_tokens)
 
             params["stream"] = True
             response = await self.client.chat.completions.create(**params)
@@ -274,7 +294,7 @@ class LLM:
             full_response = "".join(collected_messages).strip()
             if not full_response:
                 raise ValueError("Empty response from streaming LLM")
 
             return full_response
 
         except TokenLimitExceeded:
@@ -299,7 +319,9 @@ class LLM:
     @retry(
         wait=wait_random_exponential(min=1, max=60),
         stop=stop_after_attempt(6),
-        retry=retry_if_exception_type((OpenAIError, Exception, ValueError)),  # Don't retry TokenLimitExceeded
+        retry=retry_if_exception_type(
+            (OpenAIError, Exception, ValueError)
+        ),  # Don't retry TokenLimitExceeded
     )
     async def ask_tool(
         self,
@@ -346,21 +368,21 @@ class LLM:
             # Calculate input token count
             input_tokens = self.count_message_tokens(messages)
 
             # If there are tools, calculate token count for tool descriptions
             tools_tokens = 0
             if tools:
                 for tool in tools:
                     tools_tokens += self.count_tokens(str(tool))
 
             input_tokens += tools_tokens
 
             # Check if token limits are exceeded
             if not self.check_token_limit(input_tokens):
                 error_message = self.get_limit_error_message(input_tokens)
                 # Raise a special exception that won't be retried
                 raise TokenLimitExceeded(error_message)
 
             # Validate tools if provided
             if tools:
                 for tool in tools:
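A standalone sketch (assumes tiktoken; the tool schema is illustrative) of how tool descriptions feed the estimate above: each tool dict is serialized with str() and tokenized, and the result is added to input_tokens before the limit check.

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")


def count_tokens(text: str) -> int:
    return len(enc.encode(text)) if text else 0


tool = {
    "type": "function",
    "function": {
        "name": "terminate",  # illustrative tool name
        "description": "Finish the current task.",
        "parameters": {"type": "object", "properties": {}},
    },
}

tools_tokens = count_tokens(str(tool))  # the whole dict is serialized and tokenized
print(tools_tokens)
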
@@ -391,7 +413,7 @@ class LLM:
             if not response.choices or not response.choices[0].message:
                 print(response)
                 raise ValueError("Invalid or empty response from LLM")
 
             # Update token counts
             self.update_token_count(response.usage.prompt_tokens)