Use LiteLLM to support flexible LLM providers and add cost calculation
This commit is contained in:
parent 111a2bc6b1
commit 62cfcb182e
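In short, the commit replaces the hand-rolled OpenAI client in app/llm.py with a LiteLLM-backed wrapper (app/llm/inference.py) plus a small Cost tracker (app/llm/cost.py). A minimal sketch of the intended call pattern, based only on the code introduced below (the prompt text is made up):

    from app.llm.inference import LLM

    llm = LLM()  # singleton per config name, reads the [llm] section of config
    messages = llm.prepare_messages("Summarize this repository in one sentence.")
    response, cost, total_cost = llm.do_completion(messages=messages)
    print(response["choices"][0]["message"]["content"])
    print(f"this call: ${cost:.6f}, accumulated: ${total_cost:.6f}")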
@@ -4,7 +4,7 @@ from typing import List, Literal, Optional

 from pydantic import BaseModel, Field, model_validator

-from app.llm import LLM
+from app.llm.inference import LLM
 from app.logger import logger
 from app.schema import AgentState, Memory, Message
@@ -4,7 +4,7 @@ from typing import Optional

 from pydantic import Field

 from app.agent.base import BaseAgent
-from app.llm import LLM
+from app.llm.inference import LLM
 from app.schema import AgentState, Memory
@@ -17,12 +17,12 @@ WORKSPACE_ROOT = PROJECT_ROOT / "workspace"


 class LLMSettings(BaseModel):
     model: str = Field(..., description="Model name")
-    base_url: str = Field(..., description="API base URL")
+    base_url: str = Field("https://api.openai.com/v1", description="API base URL")
     api_key: str = Field(..., description="API key")
     max_tokens: int = Field(4096, description="Maximum number of tokens per request")
     temperature: float = Field(1.0, description="Sampling temperature")
-    api_type: str = Field(..., description="AzureOpenai or Openai")
-    api_version: str = Field(..., description="Azure Openai version if AzureOpenai")
+    api_type: str = Field("anthropic", description="AzureOpenai or Openai")
+    api_version: str = Field("", description="Azure Openai version if AzureOpenai")


 class AppConfig(BaseModel):
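With these defaults, only model and api_key remain required; base_url, api_type, and api_version now fall back to fixed values. A hedged illustration of constructing the settings directly (field names and defaults taken from the class above):

    from app.config import LLMSettings

    settings = LLMSettings(model="gpt-4o", api_key="sk-...")
    # base_url    -> "https://api.openai.com/v1"
    # api_type    -> "anthropic"
    # api_version -> ""
    # max_tokens  -> 4096, temperature -> 1.0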
@@ -74,11 +74,11 @@ class Config:

         default_settings = {
             "model": base_llm.get("model"),
-            "base_url": base_llm.get("base_url"),
+            "base_url": base_llm.get("base_url", "https://api.anthropic.com"),
             "api_key": base_llm.get("api_key"),
             "max_tokens": base_llm.get("max_tokens", 4096),
             "temperature": base_llm.get("temperature", 1.0),
-            "api_type": base_llm.get("api_type", ""),
+            "api_type": base_llm.get("api_type", "anthropic"),
             "api_version": base_llm.get("api_version", ""),
         }
@@ -6,7 +6,7 @@ from pydantic import Field

 from app.agent.base import BaseAgent
 from app.flow.base import BaseFlow, PlanStepStatus
-from app.llm import LLM
+from app.llm.inference import LLM
 from app.logger import logger
 from app.schema import AgentState, Message
 from app.tool import PlanningTool
app/llm.py (deleted, 264 lines)
@@ -1,264 +0,0 @@
from typing import Dict, List, Literal, Optional, Union

from openai import (
    APIError,
    AsyncAzureOpenAI,
    AsyncOpenAI,
    AuthenticationError,
    OpenAIError,
    RateLimitError,
)
from tenacity import retry, stop_after_attempt, wait_random_exponential

from app.config import LLMSettings, config
from app.logger import logger  # Assuming a logger is set up in your app
from app.schema import Message


class LLM:
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

    def __init__(
        self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if not hasattr(self, "client"):  # Only initialize if not already initialized
            llm_config = llm_config or config.llm
            llm_config = llm_config.get(config_name, llm_config["default"])
            self.model = llm_config.model
            self.max_tokens = llm_config.max_tokens
            self.temperature = llm_config.temperature
            self.api_type = llm_config.api_type
            self.api_key = llm_config.api_key
            self.api_version = llm_config.api_version
            self.base_url = llm_config.base_url
            if self.api_type == "azure":
                self.client = AsyncAzureOpenAI(
                    base_url=self.base_url,
                    api_key=self.api_key,
                    api_version=self.api_version,
                )
            else:
                self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided

        Examples:
            >>> msgs = [
            ...     Message.system_message("You are a helpful assistant"),
            ...     {"role": "user", "content": "Hello"},
            ...     Message.user_message("How are you?")
            ... ]
            >>> formatted = LLM.format_messages(msgs)
        """
        formatted_messages = []

        for message in messages:
            if isinstance(message, dict):
                # If message is already a dict, ensure it has required fields
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")
                formatted_messages.append(message)
            elif isinstance(message, Message):
                # If message is a Message object, convert it to dict
                formatted_messages.append(message.to_dict())
            else:
                raise TypeError(f"Unsupported message type: {type(message)}")

        # Validate all messages have required fields
        for msg in formatted_messages:
            if msg["role"] not in ["system", "user", "assistant", "tool"]:
                raise ValueError(f"Invalid role: {msg['role']}")
            if "content" not in msg and "tool_calls" not in msg:
                raise ValueError(
                    "Message must contain either 'content' or 'tool_calls'"
                )

        return formatted_messages

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            ValueError: If messages are invalid or response is empty
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Format system and user messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            if not stream:
                # Non-streaming request
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=temperature or self.temperature,
                    stream=False,
                )
                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")
                return response.choices[0].message.content

            # Streaming request
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=temperature or self.temperature,
                stream=True,
            )

            collected_messages = []
            async for chunk in response:
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                print(chunk_message, end="", flush=True)

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")
            return full_response

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
            raise

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 60,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            ChatCompletionMessage: The model's response

        Raises:
            ValueError: If tools, tool_choice, or messages are invalid
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Validate tool_choice
            if tool_choice not in ["none", "auto", "required"]:
                raise ValueError(f"Invalid tool_choice: {tool_choice}")

            # Format messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Validate tools if provided
            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError("Each tool must be a dict with 'type' field")

            # Set up the completion request
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature or self.temperature,
                max_tokens=self.max_tokens,
                tools=tools,
                tool_choice=tool_choice,
                timeout=timeout,
                **kwargs,
            )

            # Check if response is valid
            if not response.choices or not response.choices[0].message:
                print(response)
                raise ValueError("Invalid or empty response from LLM")

            return response.choices[0].message

        except ValueError as ve:
            logger.error(f"Validation error in ask_tool: {ve}")
            raise
        except OpenAIError as oe:
            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise
app/llm/__init__.py (new file, empty)
app/llm/cost.py (new file, 50 lines)
@@ -0,0 +1,50 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-


class Cost:
    """
    Cost class can record various costs during running and evaluation.
    Currently we define the following costs:
        accumulated_cost: the total cost (USD $) of the current LLM.
    """

    def __init__(self) -> None:
        self._accumulated_cost: float = 0.0
        self._costs: list[float] = []

    @property
    def accumulated_cost(self) -> float:
        return self._accumulated_cost

    @accumulated_cost.setter
    def accumulated_cost(self, value: float) -> None:
        if value < 0:
            raise ValueError("Total cost cannot be negative.")
        self._accumulated_cost = value

    @property
    def costs(self) -> list:
        return self._costs

    def add_cost(self, value: float) -> None:
        if value < 0:
            raise ValueError("Added cost cannot be negative.")
        self._accumulated_cost += value
        self._costs.append(value)

    def get(self):
        """
        Return the costs in a dictionary.
        """
        return {"accumulated_cost": self._accumulated_cost, "costs": self._costs}

    def log(self):
        """
        Log the costs.
        """
        cost = self.get()
        logs = ""
        for key, value in cost.items():
            logs += f"{key}: {value}\n"
        return logs
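The tracker is deliberately simple: every per-call cost is appended to a list and summed. A short usage sketch based only on the methods above (the dollar amounts are made up):

    from app.llm.cost import Cost

    tracker = Cost()
    tracker.add_cost(0.0021)  # cost of one completion, in USD
    tracker.add_cost(0.0005)
    print(tracker.accumulated_cost)  # ~0.0026, subject to float rounding
    print(tracker.get())             # {"accumulated_cost": ..., "costs": [0.0021, 0.0005]}
    print(tracker.log())             # one "key: value" line per entry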
app/llm/inference.py (new file, 505 lines)
@@ -0,0 +1,505 @@
from typing import Dict, List, Literal, Optional, Union, Tuple, Any
import os
import base64

import litellm
from litellm import completion, completion_cost
from litellm.exceptions import (
    APIConnectionError,
    RateLimitError,
    ServiceUnavailableError,
)
from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type

from app.config import LLMSettings, config
from app.logger import logger
from app.schema import Message
from app.llm.cost import Cost


class LLM:
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

    def __init__(
        self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if not hasattr(self, "initialized"):  # Only initialize if not already initialized
            llm_config = llm_config or config.llm
            llm_config = llm_config.get(config_name, llm_config["default"])

            self.model = getattr(llm_config, "model", "gpt-3.5-turbo")
            self.max_tokens = getattr(llm_config, "max_tokens", 4096)
            self.temperature = getattr(llm_config, "temperature", 0.7)
            self.top_p = getattr(llm_config, "top_p", 0.9)
            self.api_type = getattr(llm_config, "api_type", "openai")
            self.api_key = getattr(llm_config, "api_key", os.environ.get("OPENAI_API_KEY", ""))
            self.api_version = getattr(llm_config, "api_version", "")
            self.base_url = getattr(llm_config, "base_url", "https://api.openai.com/v1")
            self.timeout = getattr(llm_config, "timeout", 60)
            self.num_retries = getattr(llm_config, "num_retries", 3)
            self.retry_min_wait = getattr(llm_config, "retry_min_wait", 1)
            self.retry_max_wait = getattr(llm_config, "retry_max_wait", 10)
            self.custom_llm_provider = getattr(llm_config, "custom_llm_provider", None)

            # Get model info if available
            self.model_info = None
            try:
                self.model_info = litellm.get_model_info(self.model)
            except Exception as e:
                logger.warning(f"Could not get model info for {self.model}: {e}")

            # Configure litellm
            if self.api_type == "azure":
                litellm.api_base = self.base_url
                litellm.api_key = self.api_key
                litellm.api_version = self.api_version
            else:
                litellm.api_key = self.api_key
                if self.base_url:
                    litellm.api_base = self.base_url

            # Initialize cost tracker
            self.cost_tracker = Cost()
            self.initialized = True

            # Initialize completion function
            self._initialize_completion_function()
    def _initialize_completion_function(self):
        """Initialize the completion function with retry logic"""

        def attempt_on_error(retry_state):
            logger.error(
                f"{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number}"
            )
            return True

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.num_retries),
            wait=wait_random_exponential(
                min=self.retry_min_wait, max=self.retry_max_wait
            ),
            retry=retry_if_exception_type(
                (RateLimitError, APIConnectionError, ServiceUnavailableError)
            ),
            after=attempt_on_error,
        )
        def wrapper(*args, **kwargs):
            model_name = self.model
            if self.api_type == "azure":
                model_name = f"azure/{self.model}"

            # Set default parameters if not provided
            if "max_tokens" not in kwargs:
                kwargs["max_tokens"] = self.max_tokens
            if "temperature" not in kwargs:
                kwargs["temperature"] = self.temperature
            if "top_p" not in kwargs:
                kwargs["top_p"] = self.top_p
            if "timeout" not in kwargs:
                kwargs["timeout"] = self.timeout

            kwargs["model"] = model_name

            # Add API credentials if not in kwargs
            if "api_key" not in kwargs:
                kwargs["api_key"] = self.api_key
            if "base_url" not in kwargs and self.base_url:
                kwargs["base_url"] = self.base_url
            if "api_version" not in kwargs and self.api_version:
                kwargs["api_version"] = self.api_version
            if "custom_llm_provider" not in kwargs and self.custom_llm_provider:
                kwargs["custom_llm_provider"] = self.custom_llm_provider

            resp = completion(**kwargs)
            return resp

        self._completion = wrapper
    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided
        """
        formatted_messages = []

        for message in messages:
            if isinstance(message, dict):
                # If message is already a dict, ensure it has required fields
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")
                formatted_messages.append(message)
            elif isinstance(message, Message):
                # If message is a Message object, convert it to dict
                formatted_messages.append(message.to_dict())
            else:
                raise TypeError(f"Unsupported message type: {type(message)}")

        # Validate all messages have required fields
        for msg in formatted_messages:
            if msg["role"] not in ["system", "user", "assistant", "tool"]:
                raise ValueError(f"Invalid role: {msg['role']}")
            if "content" not in msg and "tool_calls" not in msg:
                raise ValueError(
                    "Message must contain either 'content' or 'tool_calls'"
                )

        return formatted_messages

    def _calculate_and_track_cost(self, response) -> float:
        """
        Calculate and track the cost of an LLM API call.

        Args:
            response: The response from litellm

        Returns:
            float: The calculated cost
        """
        try:
            # Use litellm's completion_cost function
            cost = completion_cost(completion_response=response)

            # Add the cost to our tracker
            if cost > 0:
                self.cost_tracker.add_cost(cost)
                logger.info(f"Added cost: ${cost:.6f}, Total: ${self.cost_tracker.accumulated_cost:.6f}")

            return cost
        except Exception as e:
            logger.warning(f"Cost calculation failed: {e}")
            return 0.0

    def is_local(self) -> bool:
        """
        Check if the model is running locally.

        Returns:
            bool: True if the model is running locally, False otherwise
        """
        if self.base_url:
            return any(
                substring in self.base_url
                for substring in ["localhost", "127.0.0.1", "0.0.0.0"]
            )
        if self.model and (self.model.startswith("ollama") or "local" in self.model.lower()):
            return True
        return False

    def do_completion(self, *args, **kwargs) -> Tuple[Any, float, float]:
        """
        Perform a completion request and track cost.

        Returns:
            Tuple[Any, float, float]: (response, current_cost, accumulated_cost)
        """
        response = self._completion(*args, **kwargs)

        # Calculate and track cost
        current_cost = self._calculate_and_track_cost(response)

        return response, current_cost, self.cost_tracker.accumulated_cost

    @staticmethod
    def encode_image(image_path: str) -> str:
        """
        Encode an image to base64.

        Args:
            image_path: Path to the image file

        Returns:
            str: Base64-encoded image
        """
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    def prepare_messages(self, text: str, image_path: Optional[str] = None) -> List[dict]:
        """
        Prepare messages for completion, including multimodal content if needed.

        Args:
            text: Text content
            image_path: Optional path to an image file

        Returns:
            List[dict]: Formatted messages
        """
        messages = [{"role": "user", "content": text}]
        if image_path:
            base64_image = self.encode_image(image_path)
            messages[0]["content"] = [
                {"type": "text", "text": text},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
            ]
        return messages

    def do_multimodal_completion(self, text: str, image_path: str) -> Tuple[Any, float, float]:
        """
        Perform a multimodal completion with text and image.

        Args:
            text: Text prompt
            image_path: Path to the image file

        Returns:
            Tuple[Any, float, float]: (response, current_cost, accumulated_cost)
        """
        messages = self.prepare_messages(text, image_path=image_path)
        return self.do_completion(messages=messages)

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            ValueError: If messages are invalid or response is empty
            Exception: For unexpected errors
        """
        try:
            # Format system and user messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            model_name = self.model
            if self.api_type == "azure":
                # For Azure, litellm expects model name in format: azure/<deployment_name>
                model_name = f"azure/{self.model}"

            if not stream:
                # Non-streaming request
                response = await litellm.acompletion(
                    model=model_name,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=temperature or self.temperature,
                    stream=False,
                )

                # Calculate and track cost
                self._calculate_and_track_cost(response)

                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")
                return response.choices[0].message.content

            # Streaming request
            collected_messages = []
            async for chunk in await litellm.acompletion(
                model=model_name,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=temperature or self.temperature,
                stream=True,
            ):
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                print(chunk_message, end="", flush=True)

                # For streaming responses, cost is calculated on the last chunk
                if hasattr(chunk, 'usage') and chunk.usage:
                    self._calculate_and_track_cost(chunk)

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")
            return full_response

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
            raise
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 60,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            The model's response

        Raises:
            ValueError: If tools, tool_choice, or messages are invalid
            Exception: For unexpected errors
        """
        try:
            # Validate tool_choice
            if tool_choice not in ["none", "auto", "required"]:
                raise ValueError(f"Invalid tool_choice: {tool_choice}")

            # Format messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Validate tools if provided
            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError("Each tool must be a dict with 'type' field")

            model_name = self.model
            if self.api_type == "azure":
                # For Azure, litellm expects model name in format: azure/<deployment_name>
                model_name = f"azure/{self.model}"

            # Set up the completion request
            response = await litellm.acompletion(
                model=model_name,
                messages=messages,
                temperature=temperature or self.temperature,
                max_tokens=self.max_tokens,
                tools=tools,
                tool_choice=tool_choice,
                timeout=timeout,
                **kwargs,
            )

            # Calculate and track cost
            self._calculate_and_track_cost(response)

            # Check if response is valid
            if not response.choices or not response.choices[0].message:
                print(response)
                raise ValueError("Invalid or empty response from LLM")

            return response.choices[0].message

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise

    def get_cost(self):
        """
        Get the current cost information.

        Returns:
            dict: Dictionary containing accumulated cost and individual costs
        """
        return self.cost_tracker.get()

    def log_cost(self):
        """
        Log the current cost information.

        Returns:
            str: Formatted string of cost information
        """
        return self.cost_tracker.log()

    def get_token_count(self, messages):
        """
        Get the token count for a list of messages.

        Args:
            messages: List of messages

        Returns:
            int: Token count
        """
        return litellm.token_counter(model=self.model, messages=messages)

    def __str__(self):
        return f"LLM(model={self.model}, base_url={self.base_url})"

    def __repr__(self):
        return str(self)
# Example usage
if __name__ == "__main__":
    # Load environment variables if needed
    from dotenv import load_dotenv
    load_dotenv()

    # Create LLM instance
    llm = LLM()

    # Test text completion
    messages = llm.prepare_messages("Hello, how are you?")
    response, cost, total_cost = llm.do_completion(messages=messages)
    print(f"Response: {response['choices'][0]['message']['content']}")
    print(f"Cost: ${cost:.6f}, Total cost: ${total_cost:.6f}")

    # Test multimodal if image path is available
    image_path = os.getenv("TEST_IMAGE_PATH")
    if image_path and os.path.exists(image_path):
        multimodal_response, mm_cost, mm_total_cost = llm.do_multimodal_completion(
            "What's in this image?", image_path
        )
        print(f"Multimodal response: {multimodal_response['choices'][0]['message']['content']}")
        print(f"Cost: ${mm_cost:.6f}, Total cost: ${mm_total_cost:.6f}")
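The agents call the async side of this class. A hedged sketch of how ask and the cost helpers combine, using only methods defined above (the prompt is made up, and the script assumes a valid [llm] config is loaded):

    import asyncio
    from app.llm.inference import LLM

    async def main():
        llm = LLM()
        answer = await llm.ask(
            messages=[{"role": "user", "content": "Say hello in five words."}],
            stream=False,
        )
        print(answer)
        print(llm.get_cost())  # {"accumulated_cost": ..., "costs": [...]}
        print(llm.log_cost())  # same data as "key: value" lines

    asyncio.run(main())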
@@ -1,7 +1,7 @@
 # Global LLM configuration
 [llm]
-model = "claude-3-5-sonnet"
-base_url = "https://api.openai.com/v1"
+model = "gpt-4o" # "claude-3-5-sonnet"
+base_url = "https://api.openai.com/v1" # "https://api.anthropic.com"
 api_key = "sk-..."
 max_tokens = 4096
 temperature = 0.0
@@ -17,6 +17,6 @@ temperature = 0.0

 # Optional configuration for specific LLM models
 [llm.vision]
-model = "claude-3-5-sonnet"
-base_url = "https://api.openai.com/v1"
+model = "gpt-4o" # "claude-3-5-sonnet"
+base_url = "https://api.openai.com/v1" # "https://api.anthropic.com"
 api_key = "sk-..."
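Named sections such as [llm.vision] map onto named LLM instances: the constructor looks up config.llm by config name and falls back to the default entry, and each name is cached as a singleton. A hedged sketch of what that implies for callers:

    from app.llm.inference import LLM

    default_llm = LLM()                     # reads the [llm] section
    vision_llm = LLM(config_name="vision")  # reads [llm.vision], falling back to [llm]

    # Each config name is a singleton, so a second call returns the same object
    assert vision_llm is LLM(config_name="vision")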