OpenManus/app/llm.py

440 lines
16 KiB
Python

from typing import Dict, List, Optional, Union
import tiktoken
from openai import (
APIError,
AsyncAzureOpenAI,
AsyncOpenAI,
AuthenticationError,
OpenAIError,
RateLimitError,
)
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from app.config import LLMSettings, config
from app.exceptions import TokenLimitExceeded
from app.logger import logger # Assuming a logger is set up in your app
from app.schema import (
ROLE_VALUES,
TOOL_CHOICE_TYPE,
TOOL_CHOICE_VALUES,
Message,
ToolChoice,
)
# Model names for which LLM.ask / LLM.ask_tool send `max_completion_tokens`
# (and no `temperature`) instead of `max_tokens` + `temperature`.
REASONING_MODELS = ["o1", "o3-mini"]
class LLM:
    """Async chat-completion client, cached per configuration name.

    ``LLM(name)`` always returns the same object for a given ``name``; the
    first call builds the OpenAI/Azure client and the tokenizer, later calls
    reuse them. The instance also keeps a running total of input tokens so
    that requests can be refused once ``max_input_tokens`` would be exceeded.
    """

    # Cache of constructed instances, keyed by configuration name.
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls,
        config_name: str = "default",
        llm_config: Optional[Dict[str, LLMSettings]] = None,
    ):
        """Return the cached instance for ``config_name``, creating it on first use."""
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            # Initialise *before* caching so that a failing __init__ never
            # leaves a half-built object in the cache. Python will call
            # __init__ again after __new__ returns; the guard in __init__
            # turns that second call into a no-op.
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

    def __init__(
        self,
        config_name: str = "default",
        llm_config: Optional[Dict[str, LLMSettings]] = None,
    ):
        """Set up model settings, tokenizer and API client (once per instance).

        Args:
            config_name: Key of the settings entry to use.
            llm_config: Mapping from configuration name to settings; falls
                back to ``config.llm`` when omitted. If ``config_name`` is not
                present, the ``"default"`` entry is used.
        """
        if hasattr(self, "client"):
            # Already initialised (see __new__); keep the existing state.
            return

        configs = llm_config or config.llm
        # Resolve the fallback lazily: the "default" entry is only required
        # when the requested name is actually missing.
        settings = configs.get(config_name)
        if settings is None:
            settings = configs["default"]

        self.model = settings.model
        self.max_tokens = settings.max_tokens
        self.temperature = settings.temperature
        self.api_type = settings.api_type
        self.api_key = settings.api_key
        self.api_version = settings.api_version
        self.base_url = settings.base_url

        # Token accounting: running total and optional cumulative cap.
        self.total_input_tokens = 0
        self.max_input_tokens = getattr(settings, "max_input_tokens", None)

        try:
            self.tokenizer = tiktoken.encoding_for_model(self.model)
        except KeyError:
            # Model unknown to tiktoken: fall back to the cl100k_base encoding.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

        if self.api_type == "azure":
            self.client = AsyncAzureOpenAI(
                base_url=self.base_url,
                api_key=self.api_key,
                api_version=self.api_version,
            )
        else:
            self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text`` (0 for empty/None)."""
        if not text:
            return 0
        return len(self.tokenizer.encode(text))

    def count_message_tokens(self, messages: List[dict]) -> int:
        """Estimate the number of input tokens for a list of message dicts.

        Each message contributes a fixed overhead of 4 plus the tokens of its
        role, content, tool-call function names/arguments, ``name`` and
        ``tool_call_id``. A further 2 tokens are added once for the list.
        """
        total = 0
        for message in messages:
            total += 4  # fixed per-message overhead
            total += self.count_tokens(message.get("role"))
            total += self.count_tokens(message.get("content"))

            for tool_call in message.get("tool_calls") or []:
                function = tool_call.get("function")
                if not function:
                    continue
                total += self.count_tokens(function.get("name"))
                total += self.count_tokens(function.get("arguments"))

            # Fields carried by tool-response messages.
            total += self.count_tokens(message.get("name"))
            total += self.count_tokens(message.get("tool_call_id"))

        total += 2  # fixed overhead for the message list as a whole
        return total

    def update_token_count(self, input_tokens: int) -> None:
        """Add ``input_tokens`` to the running total and log the new total.

        The total is tracked unconditionally, whether or not
        ``max_input_tokens`` is configured.
        """
        self.total_input_tokens += input_tokens
        logger.info(
            f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}"
        )

    def check_token_limit(self, input_tokens: int) -> bool:
        """Return True if ``input_tokens`` more tokens still fit under the cap.

        Always True when no ``max_input_tokens`` is configured.
        """
        if self.max_input_tokens is None:
            return True
        return (self.total_input_tokens + input_tokens) <= self.max_input_tokens

    def get_limit_error_message(self, input_tokens: int) -> str:
        """Build the error text used when a request would exceed the token cap."""
        if (
            self.max_input_tokens is not None
            and (self.total_input_tokens + input_tokens) > self.max_input_tokens
        ):
            return f"Request may exceed input token limit (Current: {self.total_input_tokens}, Needed: {input_tokens}, Max: {self.max_input_tokens})"
        return "Token limit exceeded"

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided

        Examples:
            >>> msgs = [
            ...     Message.system_message("You are a helpful assistant"),
            ...     {"role": "user", "content": "Hello"},
            ...     Message.user_message("How are you?")
            ... ]
            >>> formatted = LLM.format_messages(msgs)
        """
        formatted_messages = []
        for message in messages:
            if isinstance(message, dict):
                # Dicts are passed through as-is but must at least name a role.
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")
                formatted_messages.append(message)
            elif isinstance(message, Message):
                formatted_messages.append(message.to_dict())
            else:
                raise TypeError(f"Unsupported message type: {type(message)}")

        # Second pass: validate every message after conversion.
        for msg in formatted_messages:
            if msg["role"] not in ROLE_VALUES:
                raise ValueError(f"Invalid role: {msg['role']}")
            if "content" not in msg and "tool_calls" not in msg:
                raise ValueError(
                    "Message must contain either 'content' or 'tool_calls'"
                )
        return formatted_messages

    def _merge_messages(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]],
    ) -> List[dict]:
        """Format ``messages`` and prepend the formatted ``system_msgs``, if any."""
        if system_msgs:
            return self.format_messages(system_msgs) + self.format_messages(messages)
        return self.format_messages(messages)

    def _ensure_within_limit(self, input_tokens: int) -> None:
        """Raise TokenLimitExceeded if ``input_tokens`` would exceed the cap."""
        if not self.check_token_limit(input_tokens):
            raise TokenLimitExceeded(self.get_limit_error_message(input_tokens))

    def _sampling_params(self, temperature: Optional[float]) -> dict:
        """Return the length/temperature parameters appropriate for the model.

        Models listed in REASONING_MODELS get ``max_completion_tokens`` only;
        all others get ``max_tokens`` and ``temperature`` (the per-call value
        when given, otherwise the configured one).
        """
        if self.model in REASONING_MODELS:
            return {"max_completion_tokens": self.max_tokens}
        return {
            "max_tokens": self.max_tokens,
            "temperature": temperature if temperature is not None else self.temperature,
        }

    def _record_usage(self, response, estimated_tokens: int) -> None:
        """Record the prompt tokens reported by the API.

        ``usage`` is optional on completion responses; when it (or its
        ``prompt_tokens``) is missing, the local estimate is recorded instead
        of failing on an otherwise successful response.
        """
        usage = getattr(response, "usage", None)
        prompt_tokens = getattr(usage, "prompt_tokens", None)
        self.update_token_count(
            prompt_tokens if prompt_tokens is not None else estimated_tokens
        )

    @staticmethod
    def _log_openai_error(oe: OpenAIError) -> None:
        """Log an OpenAI client error with a hint for the common subclasses."""
        logger.error(f"OpenAI API error: {oe}")
        if isinstance(oe, AuthenticationError):
            logger.error("Authentication failed. Check API key.")
        elif isinstance(oe, RateLimitError):
            logger.error("Rate limit exceeded. Consider increasing retry attempts.")
        elif isinstance(oe, APIError):
            logger.error(f"API error: {oe}")

    # NOTE(review): `Exception` in the tuple below makes the other two entries
    # redundant and means *every* Exception subclass is retried -- presumably
    # including TokenLimitExceeded, despite the original intent not to retry
    # it. With tenacity's default `reraise=False`, an error that persists
    # through all attempts reaches the caller wrapped in tenacity.RetryError.
    # Left unchanged because callers may rely on that wrapping.
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type((OpenAIError, Exception, ValueError)),
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        When streaming, each received piece of text is also printed to stdout
        as it arrives.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            TokenLimitExceeded: If token limits are exceeded
            ValueError: If messages are invalid or response is empty
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors

            Note that the retry decorator may surface any of the above wrapped
            in ``tenacity.RetryError`` once all attempts are used up.
        """
        try:
            messages = self._merge_messages(messages, system_msgs)

            input_tokens = self.count_message_tokens(messages)
            self._ensure_within_limit(input_tokens)

            params = {
                "model": self.model,
                "messages": messages,
                **self._sampling_params(temperature),
            }

            if not stream:
                params["stream"] = False
                response = await self.client.chat.completions.create(**params)
                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")
                self._record_usage(response, input_tokens)
                return response.choices[0].message.content

            params["stream"] = True
            response = await self.client.chat.completions.create(**params)
            # Streaming gives no usage figures here, so record the estimate --
            # but only once the request has been accepted, so that failed
            # attempts retried by the decorator do not inflate the total.
            self.update_token_count(input_tokens)

            collected_messages = []
            async for chunk in response:
                # Some chunks carry no choices at all; skip them instead of
                # failing on choices[0].
                if not chunk.choices:
                    continue
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                print(chunk_message, end="", flush=True)
            print()  # Newline after streaming

            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")
            return full_response
        except TokenLimitExceeded:
            # Re-raise token limit errors without logging
            raise
        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except OpenAIError as oe:
            self._log_openai_error(oe)
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
            raise

    # NOTE(review): same retry predicate as `ask` -- every Exception subclass
    # is retried (see the note there); left unchanged for caller compatibility.
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_exception_type((OpenAIError, Exception, ValueError)),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 300,
        tools: Optional[List[dict]] = None,
        tool_choice: TOOL_CHOICE_TYPE = ToolChoice.AUTO,  # type: ignore
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            ChatCompletionMessage: The model's response

        Raises:
            TokenLimitExceeded: If token limits are exceeded
            ValueError: If tools, tool_choice, or messages are invalid
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors

            Note that the retry decorator may surface any of the above wrapped
            in ``tenacity.RetryError`` once all attempts are used up.
        """
        try:
            if tool_choice not in TOOL_CHOICE_VALUES:
                raise ValueError(f"Invalid tool_choice: {tool_choice}")

            messages = self._merge_messages(messages, system_msgs)

            # Estimate input size: messages plus the textual form of each tool.
            input_tokens = self.count_message_tokens(messages)
            if tools:
                input_tokens += sum(self.count_tokens(str(tool)) for tool in tools)
            self._ensure_within_limit(input_tokens)

            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError("Each tool must be a dict with 'type' field")

            params = {
                "model": self.model,
                "messages": messages,
                "tools": tools,
                "tool_choice": tool_choice,
                "timeout": timeout,
                **kwargs,
            }
            # Applied last so these keys win over same-named entries in kwargs.
            params.update(self._sampling_params(temperature))

            response = await self.client.chat.completions.create(**params)

            if not response.choices or not response.choices[0].message:
                logger.error(f"Invalid or empty response from LLM: {response}")
                raise ValueError("Invalid or empty response from LLM")

            self._record_usage(response, input_tokens)
            return response.choices[0].message
        except TokenLimitExceeded:
            # Re-raise token limit errors without logging
            raise
        except ValueError as ve:
            logger.error(f"Validation error in ask_tool: {ve}")
            raise
        except OpenAIError as oe:
            self._log_openai_error(oe)
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise