diff --git a/app/agent/toolcall.py b/app/agent/toolcall.py
index ecf0bb4..7d6afa0 100644
--- a/app/agent/toolcall.py
+++ b/app/agent/toolcall.py
@@ -4,6 +4,7 @@ from typing import Any, List, Optional, Union
 from pydantic import Field
 
 from app.agent.react import ReActAgent
+from app.exceptions import TokenLimitExceeded
 from app.logger import logger
 from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
 from app.schema import TOOL_CHOICE_TYPE, AgentState, Message, ToolCall, ToolChoice
@@ -32,6 +33,7 @@ class ToolCallAgent(ReActAgent):
 
     max_steps: int = 30
     max_observe: Optional[Union[int, bool]] = None
+    max_input_tokens: Optional[int] = None
 
     async def think(self) -> bool:
         """Process current state and decide next actions using tools"""
@@ -39,15 +41,32 @@ class ToolCallAgent(ReActAgent):
             user_msg = Message.user_message(self.next_step_prompt)
             self.messages += [user_msg]
 
-        # Get response with tool options
-        response = await self.llm.ask_tool(
-            messages=self.messages,
-            system_msgs=[Message.system_message(self.system_prompt)]
-            if self.system_prompt
-            else None,
-            tools=self.available_tools.to_params(),
-            tool_choice=self.tool_choices,
-        )
+        try:
+            # Get response with tool options
+            response = await self.llm.ask_tool(
+                messages=self.messages,
+                system_msgs=[Message.system_message(self.system_prompt)]
+                if self.system_prompt
+                else None,
+                tools=self.available_tools.to_params(),
+                tool_choice=self.tool_choices,
+            )
+        except ValueError as e:
+            raise
+        except Exception as e:
+            # Check if this is a RetryError containing TokenLimitExceeded
+            if hasattr(e, "__cause__") and isinstance(e.__cause__, TokenLimitExceeded):
+                token_limit_error = e.__cause__
+                logger.error(f"🚨 Token limit error (from RetryError): {token_limit_error}")
+                self.memory.add_message(
+                    Message.assistant_message(
+                        f"Maximum token limit reached, cannot continue execution: {str(token_limit_error)}"
+                    )
+                )
+                self.state = AgentState.FINISHED
+                return False
+            raise
+
         self.tool_calls = response.tool_calls
 
         # Log response info
diff --git a/app/config.py b/app/config.py
index 0a267d7..8f78151 100644
--- a/app/config.py
+++ b/app/config.py
@@ -20,6 +20,7 @@ class LLMSettings(BaseModel):
     base_url: str = Field(..., description="API base URL")
     api_key: str = Field(..., description="API key")
     max_tokens: int = Field(4096, description="Maximum number of tokens per request")
+    max_input_tokens: Optional[int] = Field(None, description="Maximum input tokens to use across all requests (None for unlimited)")
     temperature: float = Field(1.0, description="Sampling temperature")
     api_type: str = Field(..., description="AzureOpenai or Openai")
     api_version: str = Field(..., description="Azure Openai version if AzureOpenai")
@@ -118,6 +119,7 @@ class Config:
             "base_url": base_llm.get("base_url"),
             "api_key": base_llm.get("api_key"),
             "max_tokens": base_llm.get("max_tokens", 4096),
+            "max_input_tokens": base_llm.get("max_input_tokens"),
             "temperature": base_llm.get("temperature", 1.0),
             "api_type": base_llm.get("api_type", ""),
             "api_version": base_llm.get("api_version", ""),
diff --git a/app/exceptions.py b/app/exceptions.py
index 57a0148..42195de 100644
--- a/app/exceptions.py
+++ b/app/exceptions.py
@@ -3,3 +3,11 @@ class ToolError(Exception):
 
     def __init__(self, message):
         self.message = message
+
+class OpenManusError(Exception):
+    """Base exception for all OpenManus errors"""
+    pass
+
+class TokenLimitExceeded(OpenManusError):
+    """Exception raised when the token limit is exceeded"""
+    pass
diff --git a/app/llm.py b/app/llm.py
index 8c085ae..817bbfe 100644
--- a/app/llm.py
+++ b/app/llm.py
@@ -8,9 +8,11 @@ from openai import (
     OpenAIError,
     RateLimitError,
 )
-from tenacity import retry, stop_after_attempt, wait_random_exponential
+import tiktoken
+from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type
 
 from app.config import LLMSettings, config
+from app.exceptions import TokenLimitExceeded
 from app.logger import logger  # Assuming a logger is set up in your app
 from app.schema import (
     ROLE_VALUES,
@@ -49,6 +51,18 @@ class LLM:
             self.api_key = llm_config.api_key
             self.api_version = llm_config.api_version
             self.base_url = llm_config.base_url
+
+            # Add token counting related attributes
+            self.total_input_tokens = 0
+            self.max_input_tokens = llm_config.max_input_tokens if hasattr(llm_config, "max_input_tokens") else None
+
+            # Initialize tokenizer
+            try:
+                self.tokenizer = tiktoken.encoding_for_model(self.model)
+            except KeyError:
+                # If the model is not in tiktoken's presets, use cl100k_base as default
+                self.tokenizer = tiktoken.get_encoding("cl100k_base")
+
             if self.api_type == "azure":
                 self.client = AsyncAzureOpenAI(
                     base_url=self.base_url,
@@ -58,6 +72,70 @@ class LLM:
             else:
                 self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
 
+    def count_tokens(self, text: str) -> int:
+        """Calculate the number of tokens in a text"""
+        if not text:
+            return 0
+        return len(self.tokenizer.encode(text))
+
+    def count_message_tokens(self, messages: List[dict]) -> int:
+        """Calculate the number of tokens in a message list"""
+        token_count = 0
+        for message in messages:
+            # Base token count for each message (according to OpenAI's calculation method)
+            token_count += 4  # Base token count for each message
+
+            # Calculate tokens for the role
+            if "role" in message:
+                token_count += self.count_tokens(message["role"])
+
+            # Calculate tokens for the content
+            if "content" in message and message["content"]:
+                token_count += self.count_tokens(message["content"])
+
+            # Calculate tokens for tool calls
+            if "tool_calls" in message and message["tool_calls"]:
+                for tool_call in message["tool_calls"]:
+                    if "function" in tool_call:
+                        # Function name
+                        if "name" in tool_call["function"]:
+                            token_count += self.count_tokens(tool_call["function"]["name"])
+                        # Function arguments
+                        if "arguments" in tool_call["function"]:
+                            token_count += self.count_tokens(tool_call["function"]["arguments"])
+
+            # Calculate tokens for tool responses
+            if "name" in message and message["name"]:
+                token_count += self.count_tokens(message["name"])
+
+            if "tool_call_id" in message and message["tool_call_id"]:
+                token_count += self.count_tokens(message["tool_call_id"])
+
+            # Add extra tokens for message format
+            token_count += 2  # Extra tokens for message format
+
+        return token_count
+
+    def update_token_count(self, input_tokens: int) -> None:
+        """Update token counts"""
+        # Only track tokens if max_input_tokens is set
+        self.total_input_tokens += input_tokens
+        logger.info(f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}")
+
+    def check_token_limit(self, input_tokens: int) -> bool:
+        """Check if token limits are exceeded"""
+        if self.max_input_tokens is not None:
+            return (self.total_input_tokens + input_tokens) <= self.max_input_tokens
+        # If max_input_tokens is not set, always return True
+        return True
+
+    def get_limit_error_message(self, input_tokens: int) -> str:
+        """Generate error message for token limit exceeded"""
+        if self.max_input_tokens is not None and (self.total_input_tokens + input_tokens) > self.max_input_tokens:
+            return f"Request may exceed input token limit (Current: {self.total_input_tokens}, Needed: {input_tokens}, Max: {self.max_input_tokens})"
+
+        return "Token limit exceeded"
+
     @staticmethod
     def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
         """
@@ -109,6 +187,7 @@ class LLM:
     @retry(
         wait=wait_random_exponential(min=1, max=60),
         stop=stop_after_attempt(6),
+        retry=retry_if_exception_type((OpenAIError, Exception, ValueError)),  # Don't retry TokenLimitExceeded
     )
     async def ask(
         self,
@@ -130,6 +209,7 @@ class LLM:
             str: The generated response
 
         Raises:
+            TokenLimitExceeded: If token limits are exceeded
             ValueError: If messages are invalid or response is empty
             OpenAIError: If API call fails after retries
             Exception: For unexpected errors
@@ -141,6 +221,15 @@ class LLM:
                 messages = system_msgs + self.format_messages(messages)
             else:
                 messages = self.format_messages(messages)
+
+            # Calculate input token count
+            input_tokens = self.count_message_tokens(messages)
+
+            # Check if token limits are exceeded
+            if not self.check_token_limit(input_tokens):
+                error_message = self.get_limit_error_message(input_tokens)
+                # Raise a special exception that won't be retried
+                raise TokenLimitExceeded(error_message)
 
             params = {
                 "model": self.model,
@@ -161,9 +250,15 @@ class LLM:
 
                 if not response.choices or not response.choices[0].message.content:
                     raise ValueError("Empty or invalid response from LLM")
+
+                # Update token counts
+                self.update_token_count(response.usage.prompt_tokens)
+
                 return response.choices[0].message.content
 
-            # Streaming request
+            # Streaming request, For streaming, update estimated token count before making the request
+            self.update_token_count(input_tokens)
+
             params["stream"] = True
             response = await self.client.chat.completions.create(**params)
 
@@ -177,13 +272,23 @@ class LLM:
             full_response = "".join(collected_messages).strip()
             if not full_response:
                 raise ValueError("Empty response from streaming LLM")
+
             return full_response
 
+        except TokenLimitExceeded:
+            # Re-raise token limit errors without logging
+            raise
         except ValueError as ve:
             logger.error(f"Validation error: {ve}")
             raise
         except OpenAIError as oe:
             logger.error(f"OpenAI API error: {oe}")
+            if isinstance(oe, AuthenticationError):
+                logger.error("Authentication failed. Check API key.")
+            elif isinstance(oe, RateLimitError):
+                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
+            elif isinstance(oe, APIError):
+                logger.error(f"API error: {oe}")
             raise
         except Exception as e:
             logger.error(f"Unexpected error in ask: {e}")
@@ -192,6 +297,7 @@ class LLM:
     @retry(
         wait=wait_random_exponential(min=1, max=60),
         stop=stop_after_attempt(6),
+        retry=retry_if_exception_type((OpenAIError, Exception, ValueError)),  # Don't retry TokenLimitExceeded
    )
     async def ask_tool(
         self,
@@ -219,6 +325,7 @@ class LLM:
             ChatCompletionMessage: The model's response
 
         Raises:
+            TokenLimitExceeded: If token limits are exceeded
             ValueError: If tools, tool_choice, or messages are invalid
             OpenAIError: If API call fails after retries
             Exception: For unexpected errors
@@ -235,6 +342,23 @@ class LLM:
             else:
                 messages = self.format_messages(messages)
 
+            # Calculate input token count
+            input_tokens = self.count_message_tokens(messages)
+
+            # If there are tools, calculate token count for tool descriptions
+            tools_tokens = 0
+            if tools:
+                for tool in tools:
+                    tools_tokens += self.count_tokens(str(tool))
+
+            input_tokens += tools_tokens
+
+            # Check if token limits are exceeded
+            if not self.check_token_limit(input_tokens):
+                error_message = self.get_limit_error_message(input_tokens)
+                # Raise a special exception that won't be retried
+                raise TokenLimitExceeded(error_message)
+
             # Validate tools if provided
             if tools:
                 for tool in tools:
@@ -263,13 +387,20 @@ class LLM:
             if not response.choices or not response.choices[0].message:
                 print(response)
                 raise ValueError("Invalid or empty response from LLM")
+
+            # Update token counts
+            self.update_token_count(response.usage.prompt_tokens)
 
             return response.choices[0].message
 
+        except TokenLimitExceeded:
+            # Re-raise token limit errors without logging
+            raise
         except ValueError as ve:
             logger.error(f"Validation error in ask_tool: {ve}")
             raise
         except OpenAIError as oe:
+            logger.error(f"OpenAI API error: {oe}")
             if isinstance(oe, AuthenticationError):
                 logger.error("Authentication failed. Check API key.")
             elif isinstance(oe, RateLimitError):
diff --git a/config/config.example.toml b/config/config.example.toml
index 762f42c..b720088 100644
--- a/config/config.example.toml
+++ b/config/config.example.toml
@@ -5,6 +5,7 @@ base_url = "https://api.openai.com/v1" # API endpoint URL
 api_key = "sk-..." # Your API key
 max_tokens = 8192 # Maximum number of tokens in the response
 temperature = 0.0 # Controls randomness
+#max_input_tokens = 100000 # Maximum input tokens to use across all requests (set to null or delete this line for unlimited)
 
 # [llm] #AZURE OPENAI:
 # api_type= 'azure'
diff --git a/requirements.txt b/requirements.txt
index 60ad38e..534eca9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ loguru~=0.7.3
 numpy
 datasets~=3.2.0
 fastapi~=0.115.11
+tiktoken~=0.9.0
 html2text~=2024.2.26
 gymnasium~=1.0.0
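To make the intent of the patch easier to see, here is a minimal standalone sketch of the input-token budgeting it adds to `app/llm.py`. This is not part of the patch: `TokenBudget` and `charge` are illustrative names, and it only keeps `content` and `role` counting; the tiktoken fallback, the 4/2-token per-message overheads, the cumulative total, and the error message format mirror `count_message_tokens`, `check_token_limit`, and `get_limit_error_message` in the diff above.

```python
# Standalone sketch (not part of the patch) of the token-budget check added to app/llm.py.
# TokenBudget and charge are hypothetical names for illustration only.
from typing import List, Optional

import tiktoken


class TokenLimitExceeded(Exception):
    """Raised when a request would push cumulative input tokens past the budget."""


class TokenBudget:
    def __init__(self, model: str = "gpt-4o", max_input_tokens: Optional[int] = None):
        try:
            self.tokenizer = tiktoken.encoding_for_model(model)
        except KeyError:
            # Model not in tiktoken's presets: fall back to cl100k_base, as the patch does
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.max_input_tokens = max_input_tokens
        self.total_input_tokens = 0

    def count_message_tokens(self, messages: List[dict]) -> int:
        """Rough prompt size: per-message overhead plus role/content tokens."""
        token_count = 0
        for message in messages:
            token_count += 4  # base overhead per message
            for key in ("role", "content"):
                if message.get(key):
                    token_count += len(self.tokenizer.encode(message[key]))
            token_count += 2  # extra tokens for message format
        return token_count

    def charge(self, messages: List[dict]) -> int:
        """Check the cumulative budget before an API call, then record the usage."""
        needed = self.count_message_tokens(messages)
        if (
            self.max_input_tokens is not None
            and self.total_input_tokens + needed > self.max_input_tokens
        ):
            raise TokenLimitExceeded(
                f"Request may exceed input token limit (Current: {self.total_input_tokens}, "
                f"Needed: {needed}, Max: {self.max_input_tokens})"
            )
        self.total_input_tokens += needed
        return needed


if __name__ == "__main__":
    budget = TokenBudget(max_input_tokens=50)
    print(budget.charge([{"role": "user", "content": "Hello"}]))  # small request fits
    try:
        budget.charge([{"role": "user", "content": "lorem ipsum " * 100}])
    except TokenLimitExceeded as exc:
        print(f"Blocked before calling the API: {exc}")
```

In the patch itself, this check runs inside `LLM.ask` and `LLM.ask_tool` before the API call (with tool descriptions also counted in `ask_tool`), and `ToolCallAgent.think` treats a `TokenLimitExceeded` surfacing from the retry wrapper as a terminal condition: it records an assistant message and sets the agent state to `FINISHED` instead of crashing.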