diff --git a/app/agent/toolcall.py b/app/agent/toolcall.py
index 7d6afa0..29e5af4 100644
--- a/app/agent/toolcall.py
+++ b/app/agent/toolcall.py
@@ -33,7 +33,6 @@ class ToolCallAgent(ReActAgent):
 
     max_steps: int = 30
     max_observe: Optional[Union[int, bool]] = None
-    max_input_tokens: Optional[int] = None
 
     async def think(self) -> bool:
         """Process current state and decide next actions using tools"""
@@ -51,13 +50,15 @@ class ToolCallAgent(ReActAgent):
                 tools=self.available_tools.to_params(),
                 tool_choice=self.tool_choices,
             )
-        except ValueError as e:
+        except ValueError:
             raise
         except Exception as e:
             # Check if this is a RetryError containing TokenLimitExceeded
             if hasattr(e, "__cause__") and isinstance(e.__cause__, TokenLimitExceeded):
                 token_limit_error = e.__cause__
-                logger.error(f"🚨 Token limit error (from RetryError): {token_limit_error}")
+                logger.error(
+                    f"🚨 Token limit error (from RetryError): {token_limit_error}"
+                )
                 self.memory.add_message(
                     Message.assistant_message(
                         f"Maximum token limit reached, cannot continue execution: {str(token_limit_error)}"
diff --git a/app/config.py b/app/config.py
index 8f78151..51356a0 100644
--- a/app/config.py
+++ b/app/config.py
@@ -20,7 +20,10 @@ class LLMSettings(BaseModel):
     base_url: str = Field(..., description="API base URL")
     api_key: str = Field(..., description="API key")
     max_tokens: int = Field(4096, description="Maximum number of tokens per request")
-    max_input_tokens: Optional[int] = Field(None, description="Maximum input tokens to use across all requests (None for unlimited)")
+    max_input_tokens: Optional[int] = Field(
+        None,
+        description="Maximum input tokens to use across all requests (None for unlimited)",
+    )
     temperature: float = Field(1.0, description="Sampling temperature")
     api_type: str = Field(..., description="AzureOpenai or Openai")
     api_version: str = Field(..., description="Azure Openai version if AzureOpenai")
diff --git a/app/exceptions.py b/app/exceptions.py
index 42195de..fc90087 100644
--- a/app/exceptions.py
+++ b/app/exceptions.py
@@ -4,10 +4,10 @@ class ToolError(Exception):
     def __init__(self, message):
         self.message = message
 
+
 class OpenManusError(Exception):
     """Base exception for all OpenManus errors"""
-    pass
+
 
 class TokenLimitExceeded(OpenManusError):
     """Exception raised when the token limit is exceeded"""
-    pass
diff --git a/app/llm.py b/app/llm.py
index 6777e5e..1cc640b 100644
--- a/app/llm.py
+++ b/app/llm.py
@@ -1,5 +1,6 @@
 from typing import Dict, List, Optional, Union
 
+import tiktoken
 from openai import (
     APIError,
     AsyncAzureOpenAI,
@@ -8,8 +9,12 @@ from openai import (
     OpenAIError,
     RateLimitError,
 )
-import tiktoken
-from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 
 from app.config import LLMSettings, config
 from app.exceptions import TokenLimitExceeded
@@ -51,18 +56,22 @@ class LLM:
             self.api_key = llm_config.api_key
             self.api_version = llm_config.api_version
             self.base_url = llm_config.base_url
-
+
             # Add token counting related attributes
             self.total_input_tokens = 0
-            self.max_input_tokens = llm_config.max_input_tokens if hasattr(llm_config, "max_input_tokens") else None
-
+            self.max_input_tokens = (
+                llm_config.max_input_tokens
+                if hasattr(llm_config, "max_input_tokens")
+                else None
+            )
+
             # Initialize tokenizer
             try:
                 self.tokenizer = tiktoken.encoding_for_model(self.model)
             except KeyError:
                 # If the model is not in tiktoken's presets, use cl100k_base as default
                 self.tokenizer = tiktoken.get_encoding("cl100k_base")
-
+
             if self.api_type == "azure":
                 self.client = AsyncAzureOpenAI(
                     base_url=self.base_url,
@@ -77,51 +86,57 @@ class LLM:
         if not text:
             return 0
         return len(self.tokenizer.encode(text))
-
+
     def count_message_tokens(self, messages: List[dict]) -> int:
         """Calculate the number of tokens in a message list"""
         token_count = 0
         for message in messages:
             # Base token count for each message (according to OpenAI's calculation method)
             token_count += 4  # Base token count for each message
-
+
             # Calculate tokens for the role
             if "role" in message:
                 token_count += self.count_tokens(message["role"])
-
+
             # Calculate tokens for the content
             if "content" in message and message["content"]:
                 token_count += self.count_tokens(message["content"])
-
+
             # Calculate tokens for tool calls
             if "tool_calls" in message and message["tool_calls"]:
                 for tool_call in message["tool_calls"]:
                     if "function" in tool_call:
                         # Function name
                         if "name" in tool_call["function"]:
-                            token_count += self.count_tokens(tool_call["function"]["name"])
+                            token_count += self.count_tokens(
+                                tool_call["function"]["name"]
+                            )
                         # Function arguments
                         if "arguments" in tool_call["function"]:
-                            token_count += self.count_tokens(tool_call["function"]["arguments"])
-
+                            token_count += self.count_tokens(
+                                tool_call["function"]["arguments"]
+                            )
+
             # Calculate tokens for tool responses
             if "name" in message and message["name"]:
                 token_count += self.count_tokens(message["name"])
-
+
             if "tool_call_id" in message and message["tool_call_id"]:
                 token_count += self.count_tokens(message["tool_call_id"])
-
+
             # Add extra tokens for message format
             token_count += 2  # Extra tokens for message format
-
+
         return token_count
-
+
     def update_token_count(self, input_tokens: int) -> None:
         """Update token counts"""
         # Only track tokens if max_input_tokens is set
         self.total_input_tokens += input_tokens
-        logger.info(f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}")
-
+        logger.info(
+            f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}"
+        )
+
     def check_token_limit(self, input_tokens: int) -> bool:
         """Check if token limits are exceeded"""
         if self.max_input_tokens is not None:
@@ -131,9 +146,12 @@ class LLM:
 
     def get_limit_error_message(self, input_tokens: int) -> str:
         """Generate error message for token limit exceeded"""
-        if self.max_input_tokens is not None and (self.total_input_tokens + input_tokens) > self.max_input_tokens:
+        if (
+            self.max_input_tokens is not None
+            and (self.total_input_tokens + input_tokens) > self.max_input_tokens
+        ):
             return f"Request may exceed input token limit (Current: {self.total_input_tokens}, Needed: {input_tokens}, Max: {self.max_input_tokens})"
-
+
         return "Token limit exceeded"
 
     @staticmethod
@@ -187,7 +205,9 @@ class LLM:
     @retry(
         wait=wait_random_exponential(min=1, max=60),
         stop=stop_after_attempt(6),
-        retry=retry_if_exception_type((OpenAIError, Exception, ValueError)),  # Don't retry TokenLimitExceeded
+        retry=retry_if_exception_type(
+            (OpenAIError, Exception, ValueError)
+        ),  # Don't retry TokenLimitExceeded
     )
     async def ask(
         self,
@@ -221,10 +241,10 @@ class LLM:
                 messages = system_msgs + self.format_messages(messages)
             else:
                 messages = self.format_messages(messages)
-
+
             # Calculate input token count
             input_tokens = self.count_message_tokens(messages)
-
+
             # Check if token limits are exceeded
             if not self.check_token_limit(input_tokens):
                 error_message = self.get_limit_error_message(input_tokens)
@@ -252,15 +272,15 @@ class LLM:
 
                 if not response.choices or not response.choices[0].message.content:
                     raise ValueError("Empty or invalid response from LLM")
-
+
                 # Update token counts
                 self.update_token_count(response.usage.prompt_tokens)
-
+
                 return response.choices[0].message.content
 
             # Streaming request, For streaming, update estimated token count before making the request
             self.update_token_count(input_tokens)
-
+
             params["stream"] = True
             response = await self.client.chat.completions.create(**params)
 
@@ -274,7 +294,7 @@ class LLM:
             full_response = "".join(collected_messages).strip()
             if not full_response:
                 raise ValueError("Empty response from streaming LLM")
-
+
             return full_response
 
         except TokenLimitExceeded:
@@ -299,7 +319,9 @@ class LLM:
     @retry(
         wait=wait_random_exponential(min=1, max=60),
         stop=stop_after_attempt(6),
-        retry=retry_if_exception_type((OpenAIError, Exception, ValueError)),  # Don't retry TokenLimitExceeded
+        retry=retry_if_exception_type(
+            (OpenAIError, Exception, ValueError)
+        ),  # Don't retry TokenLimitExceeded
     )
     async def ask_tool(
         self,
@@ -346,21 +368,21 @@ class LLM:
 
             # Calculate input token count
             input_tokens = self.count_message_tokens(messages)
-
+
             # If there are tools, calculate token count for tool descriptions
             tools_tokens = 0
             if tools:
                 for tool in tools:
                     tools_tokens += self.count_tokens(str(tool))
-
+
             input_tokens += tools_tokens
-
+
             # Check if token limits are exceeded
             if not self.check_token_limit(input_tokens):
                 error_message = self.get_limit_error_message(input_tokens)
                 # Raise a special exception that won't be retried
                 raise TokenLimitExceeded(error_message)
-
+
             # Validate tools if provided
             if tools:
                 for tool in tools:
@@ -391,7 +413,7 @@ class LLM:
             if not response.choices or not response.choices[0].message:
                 print(response)
                 raise ValueError("Invalid or empty response from LLM")
-
+
             # Update token counts
             self.update_token_count(response.usage.prompt_tokens)
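Not part of the patch: the sketch below shows one way the token-accounting helpers added to `app/llm.py` could be exercised in isolation, assuming a normal OpenManus checkout (importing `app.llm` pulls in `app.config`, which reads `config.toml`). It bypasses `LLM`'s constructor so no OpenAI/Azure client is created, and the budget of 10 tokens and the sample messages are made-up values chosen only to trip the limit check.

```python
# Illustrative sketch, not part of the diff: drive count_message_tokens,
# check_token_limit, get_limit_error_message and update_token_count directly.
import tiktoken

from app.exceptions import TokenLimitExceeded
from app.llm import LLM

# Skip LLM's constructor entirely and hand-set the three attributes
# the helpers above actually read.
llm = object.__new__(LLM)
llm.tokenizer = tiktoken.get_encoding("cl100k_base")
llm.total_input_tokens = 0
llm.max_input_tokens = 10  # deliberately smaller than the messages below

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the design of the token limiter."},
]

needed = llm.count_message_tokens(messages)
if llm.check_token_limit(needed):
    llm.update_token_count(needed)  # bookkeeping, as ask()/ask_tool() do
else:
    # Mirrors the pre-request check in ask()/ask_tool(): fail without retrying
    raise TokenLimitExceeded(llm.get_limit_error_message(needed))
```

With a budget this small the `TokenLimitExceeded` branch is taken, which is the same non-retried error path the agent surfaces in `ToolCallAgent.think()`.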