From 9b0b69a5e1e6bb6f9f9ceeaaf0cb53886af44dc3 Mon Sep 17 00:00:00 2001 From: zhengshuli Date: Thu, 13 Mar 2025 15:46:51 +0800 Subject: [PATCH 1/5] =?UTF-8?q?Use=20the=20max=5Finput=5Ftokens=20configur?= =?UTF-8?q?ation=20to=20constrain=20the=20agent=E2=80=99s=20token=20usage.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/agent/toolcall.py | 37 +++++++--- app/config.py | 2 + app/exceptions.py | 8 +++ app/llm.py | 135 ++++++++++++++++++++++++++++++++++++- config/config.example.toml | 1 + requirements.txt | 1 + 6 files changed, 173 insertions(+), 11 deletions(-) diff --git a/app/agent/toolcall.py b/app/agent/toolcall.py index ecf0bb4..7d6afa0 100644 --- a/app/agent/toolcall.py +++ b/app/agent/toolcall.py @@ -4,6 +4,7 @@ from typing import Any, List, Optional, Union from pydantic import Field from app.agent.react import ReActAgent +from app.exceptions import TokenLimitExceeded from app.logger import logger from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT from app.schema import TOOL_CHOICE_TYPE, AgentState, Message, ToolCall, ToolChoice @@ -32,6 +33,7 @@ class ToolCallAgent(ReActAgent): max_steps: int = 30 max_observe: Optional[Union[int, bool]] = None + max_input_tokens: Optional[int] = None async def think(self) -> bool: """Process current state and decide next actions using tools""" @@ -39,15 +41,32 @@ class ToolCallAgent(ReActAgent): user_msg = Message.user_message(self.next_step_prompt) self.messages += [user_msg] - # Get response with tool options - response = await self.llm.ask_tool( - messages=self.messages, - system_msgs=[Message.system_message(self.system_prompt)] - if self.system_prompt - else None, - tools=self.available_tools.to_params(), - tool_choice=self.tool_choices, - ) + try: + # Get response with tool options + response = await self.llm.ask_tool( + messages=self.messages, + system_msgs=[Message.system_message(self.system_prompt)] + if self.system_prompt + else None, + tools=self.available_tools.to_params(), + tool_choice=self.tool_choices, + ) + except ValueError as e: + raise + except Exception as e: + # Check if this is a RetryError containing TokenLimitExceeded + if hasattr(e, "__cause__") and isinstance(e.__cause__, TokenLimitExceeded): + token_limit_error = e.__cause__ + logger.error(f"🚨 Token limit error (from RetryError): {token_limit_error}") + self.memory.add_message( + Message.assistant_message( + f"Maximum token limit reached, cannot continue execution: {str(token_limit_error)}" + ) + ) + self.state = AgentState.FINISHED + return False + raise + self.tool_calls = response.tool_calls # Log response info diff --git a/app/config.py b/app/config.py index 0a267d7..8f78151 100644 --- a/app/config.py +++ b/app/config.py @@ -20,6 +20,7 @@ class LLMSettings(BaseModel): base_url: str = Field(..., description="API base URL") api_key: str = Field(..., description="API key") max_tokens: int = Field(4096, description="Maximum number of tokens per request") + max_input_tokens: Optional[int] = Field(None, description="Maximum input tokens to use across all requests (None for unlimited)") temperature: float = Field(1.0, description="Sampling temperature") api_type: str = Field(..., description="AzureOpenai or Openai") api_version: str = Field(..., description="Azure Openai version if AzureOpenai") @@ -118,6 +119,7 @@ class Config: "base_url": base_llm.get("base_url"), "api_key": base_llm.get("api_key"), "max_tokens": base_llm.get("max_tokens", 4096), + "max_input_tokens": base_llm.get("max_input_tokens"), 
"temperature": base_llm.get("temperature", 1.0), "api_type": base_llm.get("api_type", ""), "api_version": base_llm.get("api_version", ""), diff --git a/app/exceptions.py b/app/exceptions.py index 57a0148..42195de 100644 --- a/app/exceptions.py +++ b/app/exceptions.py @@ -3,3 +3,11 @@ class ToolError(Exception): def __init__(self, message): self.message = message + +class OpenManusError(Exception): + """Base exception for all OpenManus errors""" + pass + +class TokenLimitExceeded(OpenManusError): + """Exception raised when the token limit is exceeded""" + pass diff --git a/app/llm.py b/app/llm.py index 8c085ae..817bbfe 100644 --- a/app/llm.py +++ b/app/llm.py @@ -8,9 +8,11 @@ from openai import ( OpenAIError, RateLimitError, ) -from tenacity import retry, stop_after_attempt, wait_random_exponential +import tiktoken +from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type from app.config import LLMSettings, config +from app.exceptions import TokenLimitExceeded from app.logger import logger # Assuming a logger is set up in your app from app.schema import ( ROLE_VALUES, @@ -49,6 +51,18 @@ class LLM: self.api_key = llm_config.api_key self.api_version = llm_config.api_version self.base_url = llm_config.base_url + + # Add token counting related attributes + self.total_input_tokens = 0 + self.max_input_tokens = llm_config.max_input_tokens if hasattr(llm_config, "max_input_tokens") else None + + # Initialize tokenizer + try: + self.tokenizer = tiktoken.encoding_for_model(self.model) + except KeyError: + # If the model is not in tiktoken's presets, use cl100k_base as default + self.tokenizer = tiktoken.get_encoding("cl100k_base") + if self.api_type == "azure": self.client = AsyncAzureOpenAI( base_url=self.base_url, @@ -58,6 +72,70 @@ class LLM: else: self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url) + def count_tokens(self, text: str) -> int: + """Calculate the number of tokens in a text""" + if not text: + return 0 + return len(self.tokenizer.encode(text)) + + def count_message_tokens(self, messages: List[dict]) -> int: + """Calculate the number of tokens in a message list""" + token_count = 0 + for message in messages: + # Base token count for each message (according to OpenAI's calculation method) + token_count += 4 # Base token count for each message + + # Calculate tokens for the role + if "role" in message: + token_count += self.count_tokens(message["role"]) + + # Calculate tokens for the content + if "content" in message and message["content"]: + token_count += self.count_tokens(message["content"]) + + # Calculate tokens for tool calls + if "tool_calls" in message and message["tool_calls"]: + for tool_call in message["tool_calls"]: + if "function" in tool_call: + # Function name + if "name" in tool_call["function"]: + token_count += self.count_tokens(tool_call["function"]["name"]) + # Function arguments + if "arguments" in tool_call["function"]: + token_count += self.count_tokens(tool_call["function"]["arguments"]) + + # Calculate tokens for tool responses + if "name" in message and message["name"]: + token_count += self.count_tokens(message["name"]) + + if "tool_call_id" in message and message["tool_call_id"]: + token_count += self.count_tokens(message["tool_call_id"]) + + # Add extra tokens for message format + token_count += 2 # Extra tokens for message format + + return token_count + + def update_token_count(self, input_tokens: int) -> None: + """Update token counts""" + # Only track tokens if max_input_tokens is set + 
self.total_input_tokens += input_tokens + logger.info(f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}") + + def check_token_limit(self, input_tokens: int) -> bool: + """Check if token limits are exceeded""" + if self.max_input_tokens is not None: + return (self.total_input_tokens + input_tokens) <= self.max_input_tokens + # If max_input_tokens is not set, always return True + return True + + def get_limit_error_message(self, input_tokens: int) -> str: + """Generate error message for token limit exceeded""" + if self.max_input_tokens is not None and (self.total_input_tokens + input_tokens) > self.max_input_tokens: + return f"Request may exceed input token limit (Current: {self.total_input_tokens}, Needed: {input_tokens}, Max: {self.max_input_tokens})" + + return "Token limit exceeded" + @staticmethod def format_messages(messages: List[Union[dict, Message]]) -> List[dict]: """ @@ -109,6 +187,7 @@ class LLM: @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), + retry=retry_if_exception_type((OpenAIError, Exception, ValueError)), # Don't retry TokenLimitExceeded ) async def ask( self, @@ -130,6 +209,7 @@ class LLM: str: The generated response Raises: + TokenLimitExceeded: If token limits are exceeded ValueError: If messages are invalid or response is empty OpenAIError: If API call fails after retries Exception: For unexpected errors @@ -141,6 +221,15 @@ class LLM: messages = system_msgs + self.format_messages(messages) else: messages = self.format_messages(messages) + + # Calculate input token count + input_tokens = self.count_message_tokens(messages) + + # Check if token limits are exceeded + if not self.check_token_limit(input_tokens): + error_message = self.get_limit_error_message(input_tokens) + # Raise a special exception that won't be retried + raise TokenLimitExceeded(error_message) params = { "model": self.model, @@ -161,9 +250,15 @@ class LLM: if not response.choices or not response.choices[0].message.content: raise ValueError("Empty or invalid response from LLM") + + # Update token counts + self.update_token_count(response.usage.prompt_tokens) + return response.choices[0].message.content - # Streaming request + # Streaming request, For streaming, update estimated token count before making the request + self.update_token_count(input_tokens) + params["stream"] = True response = await self.client.chat.completions.create(**params) @@ -177,13 +272,23 @@ class LLM: full_response = "".join(collected_messages).strip() if not full_response: raise ValueError("Empty response from streaming LLM") + return full_response + except TokenLimitExceeded: + # Re-raise token limit errors without logging + raise except ValueError as ve: logger.error(f"Validation error: {ve}") raise except OpenAIError as oe: logger.error(f"OpenAI API error: {oe}") + if isinstance(oe, AuthenticationError): + logger.error("Authentication failed. Check API key.") + elif isinstance(oe, RateLimitError): + logger.error("Rate limit exceeded. 
Consider increasing retry attempts.") + elif isinstance(oe, APIError): + logger.error(f"API error: {oe}") raise except Exception as e: logger.error(f"Unexpected error in ask: {e}") @@ -192,6 +297,7 @@ class LLM: @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), + retry=retry_if_exception_type((OpenAIError, Exception, ValueError)), # Don't retry TokenLimitExceeded ) async def ask_tool( self, @@ -219,6 +325,7 @@ class LLM: ChatCompletionMessage: The model's response Raises: + TokenLimitExceeded: If token limits are exceeded ValueError: If tools, tool_choice, or messages are invalid OpenAIError: If API call fails after retries Exception: For unexpected errors @@ -235,6 +342,23 @@ class LLM: else: messages = self.format_messages(messages) + # Calculate input token count + input_tokens = self.count_message_tokens(messages) + + # If there are tools, calculate token count for tool descriptions + tools_tokens = 0 + if tools: + for tool in tools: + tools_tokens += self.count_tokens(str(tool)) + + input_tokens += tools_tokens + + # Check if token limits are exceeded + if not self.check_token_limit(input_tokens): + error_message = self.get_limit_error_message(input_tokens) + # Raise a special exception that won't be retried + raise TokenLimitExceeded(error_message) + # Validate tools if provided if tools: for tool in tools: @@ -263,13 +387,20 @@ class LLM: if not response.choices or not response.choices[0].message: print(response) raise ValueError("Invalid or empty response from LLM") + + # Update token counts + self.update_token_count(response.usage.prompt_tokens) return response.choices[0].message + except TokenLimitExceeded: + # Re-raise token limit errors without logging + raise except ValueError as ve: logger.error(f"Validation error in ask_tool: {ve}") raise except OpenAIError as oe: + logger.error(f"OpenAI API error: {oe}") if isinstance(oe, AuthenticationError): logger.error("Authentication failed. Check API key.") elif isinstance(oe, RateLimitError): diff --git a/config/config.example.toml b/config/config.example.toml index 762f42c..b720088 100644 --- a/config/config.example.toml +++ b/config/config.example.toml @@ -5,6 +5,7 @@ base_url = "https://api.openai.com/v1" # API endpoint URL api_key = "sk-..." 
# Your API key max_tokens = 8192 # Maximum number of tokens in the response temperature = 0.0 # Controls randomness +#max_input_tokens = 100000 # Maximum input tokens to use across all requests (set to null or delete this line for unlimited) # [llm] #AZURE OPENAI: # api_type= 'azure' diff --git a/requirements.txt b/requirements.txt index 60ad38e..534eca9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ loguru~=0.7.3 numpy datasets~=3.2.0 fastapi~=0.115.11 +tiktoken~=0.9.0 html2text~=2024.2.26 gymnasium~=1.0.0 From b17c9d31a9e94f2a31eadf5ec47a551b3775658c Mon Sep 17 00:00:00 2001 From: Matt Eng Date: Fri, 14 Mar 2025 20:39:23 -0700 Subject: [PATCH 2/5] Fix temperature using default if 0 --- app/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/llm.py b/app/llm.py index 8c085ae..98343e6 100644 --- a/app/llm.py +++ b/app/llm.py @@ -151,7 +151,7 @@ class LLM: params["max_completion_tokens"] = self.max_tokens else: params["max_tokens"] = self.max_tokens - params["temperature"] = temperature or self.temperature + params["temperature"] = temperature if temperature is not None else self.temperature if not stream: # Non-streaming request @@ -255,7 +255,7 @@ class LLM: params["max_completion_tokens"] = self.max_tokens else: params["max_tokens"] = self.max_tokens - params["temperature"] = temperature or self.temperature + params["temperature"] = temperature if temperature is not None else self.temperature response = await self.client.chat.completions.create(**params) From 49ccd72815487a67ac72cadea7919ab1e9b4ec1a Mon Sep 17 00:00:00 2001 From: Matt Eng Date: Fri, 14 Mar 2025 21:41:43 -0700 Subject: [PATCH 3/5] Reformat --- app/llm.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/app/llm.py b/app/llm.py index 98343e6..39871f6 100644 --- a/app/llm.py +++ b/app/llm.py @@ -151,7 +151,9 @@ class LLM: params["max_completion_tokens"] = self.max_tokens else: params["max_tokens"] = self.max_tokens - params["temperature"] = temperature if temperature is not None else self.temperature + params["temperature"] = ( + temperature if temperature is not None else self.temperature + ) if not stream: # Non-streaming request @@ -255,7 +257,9 @@ class LLM: params["max_completion_tokens"] = self.max_tokens else: params["max_tokens"] = self.max_tokens - params["temperature"] = temperature if temperature is not None else self.temperature + params["temperature"] = ( + temperature if temperature is not None else self.temperature + ) response = await self.client.chat.completions.create(**params) From 86399b97d66b61c22fdc6325b266bf7ae8d729d9 Mon Sep 17 00:00:00 2001 From: liangxinbing <1580466765@qq.com> Date: Sat, 15 Mar 2025 14:40:01 +0800 Subject: [PATCH 4/5] add dependabot.yml --- .github/dependabot.yml | 58 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..4f94bce --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,58 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 + groups: + # Group critical packages that might need careful review + core-dependencies: + patterns: + - "pydantic*" + - "openai" + - "fastapi" + - "tiktoken" + browsergym-related: + patterns: + - "browsergym*" + - "browser-use" + - "playwright" + search-tools: + patterns: + - "googlesearch-python" + - "baidusearch" + - 
"duckduckgo_search" + pre-commit: + patterns: + - "pre-commit" + security-all: + applies-to: "security-updates" + patterns: + - "*" + version-all: + applies-to: "version-updates" + patterns: + - "*" + exclude-patterns: + - "pydantic*" + - "openai" + - "fastapi" + - "tiktoken" + - "browsergym*" + - "browser-use" + - "playwright" + - "googlesearch-python" + - "baidusearch" + - "duckduckgo_search" + - "pre-commit" + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 + groups: + actions: + patterns: + - "*" From 65a3898592115ff50b73188e385d9e61ad320272 Mon Sep 17 00:00:00 2001 From: liangxinbing <1580466765@qq.com> Date: Sat, 15 Mar 2025 14:43:07 +0800 Subject: [PATCH 5/5] format code and remove max_input_tokens for ToolCallAgent --- app/agent/toolcall.py | 7 ++-- app/config.py | 5 ++- app/exceptions.py | 4 +- app/llm.py | 90 +++++++++++++++++++++++++++---------------- 4 files changed, 66 insertions(+), 40 deletions(-) diff --git a/app/agent/toolcall.py b/app/agent/toolcall.py index 7d6afa0..29e5af4 100644 --- a/app/agent/toolcall.py +++ b/app/agent/toolcall.py @@ -33,7 +33,6 @@ class ToolCallAgent(ReActAgent): max_steps: int = 30 max_observe: Optional[Union[int, bool]] = None - max_input_tokens: Optional[int] = None async def think(self) -> bool: """Process current state and decide next actions using tools""" @@ -51,13 +50,15 @@ class ToolCallAgent(ReActAgent): tools=self.available_tools.to_params(), tool_choice=self.tool_choices, ) - except ValueError as e: + except ValueError: raise except Exception as e: # Check if this is a RetryError containing TokenLimitExceeded if hasattr(e, "__cause__") and isinstance(e.__cause__, TokenLimitExceeded): token_limit_error = e.__cause__ - logger.error(f"🚨 Token limit error (from RetryError): {token_limit_error}") + logger.error( + f"🚨 Token limit error (from RetryError): {token_limit_error}" + ) self.memory.add_message( Message.assistant_message( f"Maximum token limit reached, cannot continue execution: {str(token_limit_error)}" diff --git a/app/config.py b/app/config.py index 8f78151..51356a0 100644 --- a/app/config.py +++ b/app/config.py @@ -20,7 +20,10 @@ class LLMSettings(BaseModel): base_url: str = Field(..., description="API base URL") api_key: str = Field(..., description="API key") max_tokens: int = Field(4096, description="Maximum number of tokens per request") - max_input_tokens: Optional[int] = Field(None, description="Maximum input tokens to use across all requests (None for unlimited)") + max_input_tokens: Optional[int] = Field( + None, + description="Maximum input tokens to use across all requests (None for unlimited)", + ) temperature: float = Field(1.0, description="Sampling temperature") api_type: str = Field(..., description="AzureOpenai or Openai") api_version: str = Field(..., description="Azure Openai version if AzureOpenai") diff --git a/app/exceptions.py b/app/exceptions.py index 42195de..fc90087 100644 --- a/app/exceptions.py +++ b/app/exceptions.py @@ -4,10 +4,10 @@ class ToolError(Exception): def __init__(self, message): self.message = message + class OpenManusError(Exception): """Base exception for all OpenManus errors""" - pass + class TokenLimitExceeded(OpenManusError): """Exception raised when the token limit is exceeded""" - pass diff --git a/app/llm.py b/app/llm.py index 6777e5e..1cc640b 100644 --- a/app/llm.py +++ b/app/llm.py @@ -1,5 +1,6 @@ from typing import Dict, List, Optional, Union +import tiktoken from openai import ( APIError, AsyncAzureOpenAI, 
@@ -8,8 +9,12 @@ from openai import ( OpenAIError, RateLimitError, ) -import tiktoken -from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_random_exponential, +) from app.config import LLMSettings, config from app.exceptions import TokenLimitExceeded @@ -51,18 +56,22 @@ class LLM: self.api_key = llm_config.api_key self.api_version = llm_config.api_version self.base_url = llm_config.base_url - + # Add token counting related attributes self.total_input_tokens = 0 - self.max_input_tokens = llm_config.max_input_tokens if hasattr(llm_config, "max_input_tokens") else None - + self.max_input_tokens = ( + llm_config.max_input_tokens + if hasattr(llm_config, "max_input_tokens") + else None + ) + # Initialize tokenizer try: self.tokenizer = tiktoken.encoding_for_model(self.model) except KeyError: # If the model is not in tiktoken's presets, use cl100k_base as default self.tokenizer = tiktoken.get_encoding("cl100k_base") - + if self.api_type == "azure": self.client = AsyncAzureOpenAI( base_url=self.base_url, @@ -77,51 +86,57 @@ class LLM: if not text: return 0 return len(self.tokenizer.encode(text)) - + def count_message_tokens(self, messages: List[dict]) -> int: """Calculate the number of tokens in a message list""" token_count = 0 for message in messages: # Base token count for each message (according to OpenAI's calculation method) token_count += 4 # Base token count for each message - + # Calculate tokens for the role if "role" in message: token_count += self.count_tokens(message["role"]) - + # Calculate tokens for the content if "content" in message and message["content"]: token_count += self.count_tokens(message["content"]) - + # Calculate tokens for tool calls if "tool_calls" in message and message["tool_calls"]: for tool_call in message["tool_calls"]: if "function" in tool_call: # Function name if "name" in tool_call["function"]: - token_count += self.count_tokens(tool_call["function"]["name"]) + token_count += self.count_tokens( + tool_call["function"]["name"] + ) # Function arguments if "arguments" in tool_call["function"]: - token_count += self.count_tokens(tool_call["function"]["arguments"]) - + token_count += self.count_tokens( + tool_call["function"]["arguments"] + ) + # Calculate tokens for tool responses if "name" in message and message["name"]: token_count += self.count_tokens(message["name"]) - + if "tool_call_id" in message and message["tool_call_id"]: token_count += self.count_tokens(message["tool_call_id"]) - + # Add extra tokens for message format token_count += 2 # Extra tokens for message format - + return token_count - + def update_token_count(self, input_tokens: int) -> None: """Update token counts""" # Only track tokens if max_input_tokens is set self.total_input_tokens += input_tokens - logger.info(f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}") - + logger.info( + f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}" + ) + def check_token_limit(self, input_tokens: int) -> bool: """Check if token limits are exceeded""" if self.max_input_tokens is not None: @@ -131,9 +146,12 @@ class LLM: def get_limit_error_message(self, input_tokens: int) -> str: """Generate error message for token limit exceeded""" - if self.max_input_tokens is not None and (self.total_input_tokens + input_tokens) > self.max_input_tokens: + if ( + self.max_input_tokens is not None + and (self.total_input_tokens + 
input_tokens) > self.max_input_tokens + ): return f"Request may exceed input token limit (Current: {self.total_input_tokens}, Needed: {input_tokens}, Max: {self.max_input_tokens})" - + return "Token limit exceeded" @staticmethod @@ -187,7 +205,9 @@ class LLM: @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), - retry=retry_if_exception_type((OpenAIError, Exception, ValueError)), # Don't retry TokenLimitExceeded + retry=retry_if_exception_type( + (OpenAIError, Exception, ValueError) + ), # Don't retry TokenLimitExceeded ) async def ask( self, @@ -221,10 +241,10 @@ class LLM: messages = system_msgs + self.format_messages(messages) else: messages = self.format_messages(messages) - + # Calculate input token count input_tokens = self.count_message_tokens(messages) - + # Check if token limits are exceeded if not self.check_token_limit(input_tokens): error_message = self.get_limit_error_message(input_tokens) @@ -252,15 +272,15 @@ class LLM: if not response.choices or not response.choices[0].message.content: raise ValueError("Empty or invalid response from LLM") - + # Update token counts self.update_token_count(response.usage.prompt_tokens) - + return response.choices[0].message.content # Streaming request, For streaming, update estimated token count before making the request self.update_token_count(input_tokens) - + params["stream"] = True response = await self.client.chat.completions.create(**params) @@ -274,7 +294,7 @@ class LLM: full_response = "".join(collected_messages).strip() if not full_response: raise ValueError("Empty response from streaming LLM") - + return full_response except TokenLimitExceeded: @@ -299,7 +319,9 @@ class LLM: @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), - retry=retry_if_exception_type((OpenAIError, Exception, ValueError)), # Don't retry TokenLimitExceeded + retry=retry_if_exception_type( + (OpenAIError, Exception, ValueError) + ), # Don't retry TokenLimitExceeded ) async def ask_tool( self, @@ -346,21 +368,21 @@ class LLM: # Calculate input token count input_tokens = self.count_message_tokens(messages) - + # If there are tools, calculate token count for tool descriptions tools_tokens = 0 if tools: for tool in tools: tools_tokens += self.count_tokens(str(tool)) - + input_tokens += tools_tokens - + # Check if token limits are exceeded if not self.check_token_limit(input_tokens): error_message = self.get_limit_error_message(input_tokens) # Raise a special exception that won't be retried raise TokenLimitExceeded(error_message) - + # Validate tools if provided if tools: for tool in tools: @@ -391,7 +413,7 @@ class LLM: if not response.choices or not response.choices[0].message: print(response) raise ValueError("Invalid or empty response from LLM") - + # Update token counts self.update_token_count(response.usage.prompt_tokens)
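
Note (not part of the patches above): the token budgeting this series adds to LLM boils down to counting prompt tokens with tiktoken and refusing a request once a cumulative cap is reached. The following is a minimal, self-contained sketch of that idea; the names TokenBudget and charge are hypothetical, the per-message overhead is the same rough heuristic the patch uses rather than an exact OpenAI count, and tool-call fields are omitted for brevity.

import tiktoken


class TokenLimitExceeded(Exception):
    """Raised when a request would push cumulative input tokens past the cap."""


class TokenBudget:
    """Hypothetical stand-alone version of the accounting LLM gains in this series."""

    def __init__(self, model: str = "gpt-4o", max_input_tokens: int | None = None):
        self.max_input_tokens = max_input_tokens  # None means unlimited
        self.total_input_tokens = 0
        try:
            self.tokenizer = tiktoken.encoding_for_model(model)
        except KeyError:
            # Models unknown to tiktoken fall back to the cl100k_base encoding
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def count_message_tokens(self, messages: list[dict]) -> int:
        # Rough heuristic: ~4 tokens of framing per message plus the encoded
        # role/content/name fields, plus 2 tokens of reply priming overall.
        tokens = 0
        for message in messages:
            tokens += 4
            for key in ("role", "content", "name", "tool_call_id"):
                value = message.get(key)
                if value:
                    tokens += len(self.tokenizer.encode(str(value)))
        return tokens + 2

    def charge(self, messages: list[dict]) -> int:
        """Reserve tokens for a request, or raise before any API call is made."""
        needed = self.count_message_tokens(messages)
        if (
            self.max_input_tokens is not None
            and self.total_input_tokens + needed > self.max_input_tokens
        ):
            raise TokenLimitExceeded(
                f"Current: {self.total_input_tokens}, "
                f"Needed: {needed}, Max: {self.max_input_tokens}"
            )
        self.total_input_tokens += needed
        return needed


budget = TokenBudget(max_input_tokens=100_000)
budget.charge([{"role": "user", "content": "Summarize the repository layout."}])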
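
One subtlety behind the app/agent/toolcall.py hunk: the retry predicate retry_if_exception_type((OpenAIError, Exception, ValueError)) matches any Exception subclass, so TokenLimitExceeded is still retried until the attempts run out; tenacity then raises RetryError chained from the last underlying error, which is why the agent inspects e.__cause__ rather than catching TokenLimitExceeded directly. Below is a minimal synchronous sketch of that unwrapping pattern (names are illustrative, not taken from the patches; the real code path is async).

from tenacity import retry, retry_if_exception_type, stop_after_attempt


class TokenLimitExceeded(Exception):
    pass


@retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(Exception))
def ask_tool_stub():
    # Stand-in for LLM.ask_tool(); always over budget in this toy example.
    raise TokenLimitExceeded("Current: 99000, Needed: 5000, Max: 100000")


try:
    ask_tool_stub()
except Exception as exc:
    # tenacity raises RetryError chained from the last attempt's exception,
    # so the original error is recoverable via __cause__.
    if isinstance(exc.__cause__, TokenLimitExceeded):
        print(f"Token limit hit, stopping the agent: {exc.__cause__}")
    else:
        raise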