This commit is contained in:
gantnocap 2025-03-18 00:41:11 +08:00
commit 0072174023
3 changed files with 308 additions and 56 deletions

View File

@ -24,4 +24,4 @@ jobs:
dashboard_show_total_reactions: true
top_issues: true
top_pull_requests: true
top_list_size: 32
top_list_size: 14

View File

@ -1,3 +1,4 @@
import math
from typing import Dict, List, Optional, Union
import tiktoken
@ -31,6 +32,142 @@ from app.schema import (
REASONING_MODELS = ["o1", "o3-mini"]
class TokenCounter:
# Token constants
BASE_MESSAGE_TOKENS = 4
FORMAT_TOKENS = 2
LOW_DETAIL_IMAGE_TOKENS = 85
HIGH_DETAIL_TILE_TOKENS = 170
# Image processing constants
MAX_SIZE = 2048
HIGH_DETAIL_TARGET_SHORT_SIDE = 768
TILE_SIZE = 512
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def count_text(self, text: str) -> int:
"""Calculate tokens for a text string"""
return 0 if not text else len(self.tokenizer.encode(text))
def count_image(self, image_item: dict) -> int:
"""
Calculate tokens for an image based on detail level and dimensions
For "low" detail: fixed 85 tokens
For "high" detail:
1. Scale to fit in 2048x2048 square
2. Scale shortest side to 768px
3. Count 512px tiles (170 tokens each)
4. Add 85 tokens
"""
detail = image_item.get("detail", "medium")
# For low detail, always return fixed token count
if detail == "low":
return self.LOW_DETAIL_IMAGE_TOKENS
# For medium detail (default in OpenAI), use high detail calculation
# OpenAI doesn't specify a separate calculation for medium
# For high detail, calculate based on dimensions if available
if detail == "high" or detail == "medium":
# If dimensions are provided in the image_item
if "dimensions" in image_item:
width, height = image_item["dimensions"]
return self._calculate_high_detail_tokens(width, height)
# Default values when dimensions aren't available or detail level is unknown
if detail == "high":
# Default to a 1024x1024 image calculation for high detail
return self._calculate_high_detail_tokens(1024, 1024) # 765 tokens
elif detail == "medium":
# Default to a medium-sized image for medium detail
return 1024 # This matches the original default
else:
# For unknown detail levels, use medium as default
return 1024
def _calculate_high_detail_tokens(self, width: int, height: int) -> int:
"""Calculate tokens for high detail images based on dimensions"""
# Step 1: Scale to fit in MAX_SIZE x MAX_SIZE square
if width > self.MAX_SIZE or height > self.MAX_SIZE:
scale = self.MAX_SIZE / max(width, height)
width = int(width * scale)
height = int(height * scale)
# Step 2: Scale so shortest side is HIGH_DETAIL_TARGET_SHORT_SIDE
scale = self.HIGH_DETAIL_TARGET_SHORT_SIDE / min(width, height)
scaled_width = int(width * scale)
scaled_height = int(height * scale)
# Step 3: Count number of 512px tiles
tiles_x = math.ceil(scaled_width / self.TILE_SIZE)
tiles_y = math.ceil(scaled_height / self.TILE_SIZE)
total_tiles = tiles_x * tiles_y
# Step 4: Calculate final token count
return (
total_tiles * self.HIGH_DETAIL_TILE_TOKENS
) + self.LOW_DETAIL_IMAGE_TOKENS
def count_content(self, content: Union[str, List[Union[str, dict]]]) -> int:
"""Calculate tokens for message content"""
if not content:
return 0
if isinstance(content, str):
return self.count_text(content)
token_count = 0
for item in content:
if isinstance(item, str):
token_count += self.count_text(item)
elif isinstance(item, dict):
if "text" in item:
token_count += self.count_text(item["text"])
elif "image_url" in item:
token_count += self.count_image(item)
return token_count
def count_tool_calls(self, tool_calls: List[dict]) -> int:
"""Calculate tokens for tool calls"""
token_count = 0
for tool_call in tool_calls:
if "function" in tool_call:
function = tool_call["function"]
token_count += self.count_text(function.get("name", ""))
token_count += self.count_text(function.get("arguments", ""))
return token_count
def count_message_tokens(self, messages: List[dict]) -> int:
"""Calculate the total number of tokens in a message list"""
total_tokens = self.FORMAT_TOKENS # Base format tokens
for message in messages:
tokens = self.BASE_MESSAGE_TOKENS # Base tokens per message
# Add role tokens
tokens += self.count_text(message.get("role", ""))
# Add content tokens
if "content" in message:
tokens += self.count_content(message["content"])
# Add tool calls tokens
if "tool_calls" in message:
tokens += self.count_tool_calls(message["tool_calls"])
# Add name and tool_call_id tokens
tokens += self.count_text(message.get("name", ""))
tokens += self.count_text(message.get("tool_call_id", ""))
total_tokens += tokens
return total_tokens
class LLM:
_instances: Dict[str, "LLM"] = {}
@ -81,6 +218,8 @@ class LLM:
else:
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
self.token_counter = TokenCounter(self.tokenizer)
def count_tokens(self, text: str) -> int:
"""Calculate the number of tokens in a text"""
if not text:
@ -88,46 +227,7 @@ class LLM:
return len(self.tokenizer.encode(text))
def count_message_tokens(self, messages: List[dict]) -> int:
"""Calculate the number of tokens in a message list"""
token_count = 0
for message in messages:
# Base token count for each message (according to OpenAI's calculation method)
token_count += 4 # Base token count for each message
# Calculate tokens for the role
if "role" in message:
token_count += self.count_tokens(message["role"])
# Calculate tokens for the content
if "content" in message and message["content"]:
token_count += self.count_tokens(message["content"])
# Calculate tokens for tool calls
if "tool_calls" in message and message["tool_calls"]:
for tool_call in message["tool_calls"]:
if "function" in tool_call:
# Function name
if "name" in tool_call["function"]:
token_count += self.count_tokens(
tool_call["function"]["name"]
)
# Function arguments
if "arguments" in tool_call["function"]:
token_count += self.count_tokens(
tool_call["function"]["arguments"]
)
# Calculate tokens for tool responses
if "name" in message and message["name"]:
token_count += self.count_tokens(message["name"])
if "tool_call_id" in message and message["tool_call_id"]:
token_count += self.count_tokens(message["tool_call_id"])
# Add extra tokens for message format
token_count += 2 # Extra tokens for message format
return token_count
return self.token_counter.count_message_tokens(messages)
def update_token_count(self, input_tokens: int) -> None:
"""Update token counts"""
@ -313,6 +413,152 @@ class LLM:
logger.error(f"Unexpected error in ask: {e}")
raise
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
retry=retry_if_exception_type(
(OpenAIError, Exception, ValueError)
), # Don't retry TokenLimitExceeded
)
async def ask_with_images(
self,
messages: List[Union[dict, Message]],
images: List[Union[str, dict]],
system_msgs: Optional[List[Union[dict, Message]]] = None,
stream: bool = False,
temperature: Optional[float] = None,
) -> str:
"""
Send a prompt with images to the LLM and get the response.
Args:
messages: List of conversation messages
images: List of image URLs or image data dictionaries
system_msgs: Optional system messages to prepend
stream (bool): Whether to stream the response
temperature (float): Sampling temperature for the response
Returns:
str: The generated response
Raises:
TokenLimitExceeded: If token limits are exceeded
ValueError: If messages are invalid or response is empty
OpenAIError: If API call fails after retries
Exception: For unexpected errors
"""
try:
# Format messages
formatted_messages = self.format_messages(messages)
# Ensure the last message is from the user to attach images
if not formatted_messages or formatted_messages[-1]["role"] != "user":
raise ValueError(
"The last message must be from the user to attach images"
)
# Process the last user message to include images
last_message = formatted_messages[-1]
# Convert content to multimodal format if needed
content = last_message["content"]
multimodal_content = (
[{"type": "text", "text": content}]
if isinstance(content, str)
else content
if isinstance(content, list)
else []
)
# Add images to content
for image in images:
if isinstance(image, str):
multimodal_content.append(
{"type": "image_url", "image_url": {"url": image}}
)
elif isinstance(image, dict) and "url" in image:
multimodal_content.append({"type": "image_url", "image_url": image})
elif isinstance(image, dict) and "image_url" in image:
multimodal_content.append(image)
else:
raise ValueError(f"Unsupported image format: {image}")
# Update the message with multimodal content
last_message["content"] = multimodal_content
# Add system messages if provided
if system_msgs:
all_messages = self.format_messages(system_msgs) + formatted_messages
else:
all_messages = formatted_messages
# Calculate tokens and check limits
input_tokens = self.count_message_tokens(all_messages)
if not self.check_token_limit(input_tokens):
raise TokenLimitExceeded(self.get_limit_error_message(input_tokens))
# Set up API parameters
params = {
"model": self.model,
"messages": all_messages,
"stream": stream,
}
# Add model-specific parameters
if self.model in REASONING_MODELS:
params["max_completion_tokens"] = self.max_tokens
else:
params["max_tokens"] = self.max_tokens
params["temperature"] = (
temperature if temperature is not None else self.temperature
)
# Handle non-streaming request
if not stream:
response = await self.client.chat.completions.create(**params)
if not response.choices or not response.choices[0].message.content:
raise ValueError("Empty or invalid response from LLM")
self.update_token_count(response.usage.prompt_tokens)
return response.choices[0].message.content
# Handle streaming request
self.update_token_count(input_tokens)
response = await self.client.chat.completions.create(**params)
collected_messages = []
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or ""
collected_messages.append(chunk_message)
print(chunk_message, end="", flush=True)
print() # Newline after streaming
full_response = "".join(collected_messages).strip()
if not full_response:
raise ValueError("Empty response from streaming LLM")
return full_response
except TokenLimitExceeded:
raise
except ValueError as ve:
logger.error(f"Validation error in ask_with_images: {ve}")
raise
except OpenAIError as oe:
logger.error(f"OpenAI API error: {oe}")
if isinstance(oe, AuthenticationError):
logger.error("Authentication failed. Check API key.")
elif isinstance(oe, RateLimitError):
logger.error("Rate limit exceeded. Consider increasing retry attempts.")
elif isinstance(oe, APIError):
logger.error(f"API error: {oe}")
raise
except Exception as e:
logger.error(f"Unexpected error in ask_with_images: {e}")
raise
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),

View File

@ -31,6 +31,7 @@ content extraction, and tab management. Supported actions include:
- 'new_tab': Open a new tab
- 'close_tab': Close the current tab
- 'refresh': Refresh the current page
- 'get_current_state': Get the current browser state including URL, title, tabs, and interactive elements
"""
@ -45,6 +46,7 @@ class BrowserUseTool(BaseTool):
"enum": [
"navigate",
"click",
"get_current_state",
"input_text",
"screenshot",
"get_html",
@ -64,7 +66,7 @@ class BrowserUseTool(BaseTool):
},
"index": {
"type": "integer",
"description": "Element index for 'click' or 'input_text' actions",
"description": "Element index (retrieved using get_current_state) for 'click' or 'input_text' actions",
},
"text": {"type": "string", "description": "Text for 'input_text' action"},
"script": {
@ -201,6 +203,9 @@ class BrowserUseTool(BaseTool):
output += f" - Downloaded file to {download_path}"
return ToolResult(output=output)
elif action == "get_current_state":
return await self.get_current_state(context)
elif action == "input_text":
if index is None or not text:
return ToolResult(
@ -287,21 +292,22 @@ class BrowserUseTool(BaseTool):
except Exception as e:
return ToolResult(error=f"Browser action '{action}' failed: {str(e)}")
async def get_current_state(self) -> ToolResult:
async def get_current_state(self, context: BrowserContext) -> ToolResult:
"""Get the current browser state as a ToolResult."""
async with self.lock:
try:
context = await self._ensure_browser_initialized()
state = await context.get_state()
state_info = {
"url": state.url,
"title": state.title,
"tabs": [tab.model_dump() for tab in state.tabs],
"interactive_elements": state.element_tree.clickable_elements_to_string(),
}
return ToolResult(output=json.dumps(state_info))
except Exception as e:
return ToolResult(error=f"Failed to get browser state: {str(e)}")
try:
state = await context.get_state()
state_info = {
"url": state.url,
"title": state.title,
"tabs": [tab.model_dump() for tab in state.tabs],
"help": "[0], [1], [2], etc., represent clickable indices corresponding to the elements listed. Clicking on these indices will navigate to or interact with the respective content behind them.",
"interactive_elements": state.element_tree.clickable_elements_to_string(),
}
return ToolResult(
output=json.dumps(state_info, indent=4, ensure_ascii=False)
)
except Exception as e:
return ToolResult(error=f"Failed to get browser state: {str(e)}")
async def cleanup(self):
"""Clean up browser resources."""