Use LiteLLM to support flexible LLM providers and add cost calculation
Parent: 111a2bc6b1
Commit: 62cfcb182e
@@ -4,7 +4,7 @@ from typing import List, Literal, Optional

 from pydantic import BaseModel, Field, model_validator

-from app.llm import LLM
+from app.llm.inference import LLM
 from app.logger import logger
 from app.schema import AgentState, Memory, Message


@@ -4,7 +4,7 @@ from typing import Optional
 from pydantic import Field

 from app.agent.base import BaseAgent
-from app.llm import LLM
+from app.llm.inference import LLM
 from app.schema import AgentState, Memory


@@ -17,12 +17,12 @@ WORKSPACE_ROOT = PROJECT_ROOT / "workspace"

 class LLMSettings(BaseModel):
     model: str = Field(..., description="Model name")
-    base_url: str = Field(..., description="API base URL")
+    base_url: str = Field("https://api.openai.com/v1", description="API base URL")
     api_key: str = Field(..., description="API key")
     max_tokens: int = Field(4096, description="Maximum number of tokens per request")
     temperature: float = Field(1.0, description="Sampling temperature")
-    api_type: str = Field(..., description="AzureOpenai or Openai")
-    api_version: str = Field(..., description="Azure Openai version if AzureOpenai")
+    api_type: str = Field("anthropic", description="AzureOpenai or Openai")
+    api_version: str = Field("", description="Azure Openai version if AzureOpenai")


 class AppConfig(BaseModel):

@@ -74,11 +74,11 @@ class Config:

         default_settings = {
             "model": base_llm.get("model"),
-            "base_url": base_llm.get("base_url"),
+            "base_url": base_llm.get("base_url", "https://api.anthropic.com"),
             "api_key": base_llm.get("api_key"),
             "max_tokens": base_llm.get("max_tokens", 4096),
             "temperature": base_llm.get("temperature", 1.0),
-            "api_type": base_llm.get("api_type", ""),
+            "api_type": base_llm.get("api_type", "anthropic"),
             "api_version": base_llm.get("api_version", ""),
         }


@@ -6,7 +6,7 @@ from pydantic import Field

 from app.agent.base import BaseAgent
 from app.flow.base import BaseFlow, PlanStepStatus
-from app.llm import LLM
+from app.llm.inference import LLM
 from app.logger import logger
 from app.schema import AgentState, Message
 from app.tool import PlanningTool
app/llm.py (deleted, 264 lines):

from typing import Dict, List, Literal, Optional, Union

from openai import (
    APIError,
    AsyncAzureOpenAI,
    AsyncOpenAI,
    AuthenticationError,
    OpenAIError,
    RateLimitError,
)
from tenacity import retry, stop_after_attempt, wait_random_exponential

from app.config import LLMSettings, config
from app.logger import logger  # Assuming a logger is set up in your app
from app.schema import Message


class LLM:
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

    def __init__(
        self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if not hasattr(self, "client"):  # Only initialize if not already initialized
            llm_config = llm_config or config.llm
            llm_config = llm_config.get(config_name, llm_config["default"])
            self.model = llm_config.model
            self.max_tokens = llm_config.max_tokens
            self.temperature = llm_config.temperature
            self.api_type = llm_config.api_type
            self.api_key = llm_config.api_key
            self.api_version = llm_config.api_version
            self.base_url = llm_config.base_url
            if self.api_type == "azure":
                self.client = AsyncAzureOpenAI(
                    base_url=self.base_url,
                    api_key=self.api_key,
                    api_version=self.api_version,
                )
            else:
                self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided

        Examples:
            >>> msgs = [
            ...     Message.system_message("You are a helpful assistant"),
            ...     {"role": "user", "content": "Hello"},
            ...     Message.user_message("How are you?")
            ... ]
            >>> formatted = LLM.format_messages(msgs)
        """
        formatted_messages = []

        for message in messages:
            if isinstance(message, dict):
                # If message is already a dict, ensure it has required fields
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")
                formatted_messages.append(message)
            elif isinstance(message, Message):
                # If message is a Message object, convert it to dict
                formatted_messages.append(message.to_dict())
            else:
                raise TypeError(f"Unsupported message type: {type(message)}")

        # Validate all messages have required fields
        for msg in formatted_messages:
            if msg["role"] not in ["system", "user", "assistant", "tool"]:
                raise ValueError(f"Invalid role: {msg['role']}")
            if "content" not in msg and "tool_calls" not in msg:
                raise ValueError(
                    "Message must contain either 'content' or 'tool_calls'"
                )

        return formatted_messages

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            ValueError: If messages are invalid or response is empty
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Format system and user messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            if not stream:
                # Non-streaming request
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=temperature or self.temperature,
                    stream=False,
                )
                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")
                return response.choices[0].message.content

            # Streaming request
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=temperature or self.temperature,
                stream=True,
            )

            collected_messages = []
            async for chunk in response:
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                print(chunk_message, end="", flush=True)

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")
            return full_response

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
            raise

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 60,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            ChatCompletionMessage: The model's response

        Raises:
            ValueError: If tools, tool_choice, or messages are invalid
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Validate tool_choice
            if tool_choice not in ["none", "auto", "required"]:
                raise ValueError(f"Invalid tool_choice: {tool_choice}")

            # Format messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Validate tools if provided
            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError("Each tool must be a dict with 'type' field")

            # Set up the completion request
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature or self.temperature,
                max_tokens=self.max_tokens,
                tools=tools,
                tool_choice=tool_choice,
                timeout=timeout,
                **kwargs,
            )

            # Check if response is valid
            if not response.choices or not response.choices[0].message:
                print(response)
                raise ValueError("Invalid or empty response from LLM")

            return response.choices[0].message

        except ValueError as ve:
            logger.error(f"Validation error in ask_tool: {ve}")
            raise
        except OpenAIError as oe:
            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise
app/llm/__init__.py (new, empty file)
app/llm/cost.py (new file, 50 lines):

#!/usr/bin/env python
# -*- coding: utf-8 -*-


class Cost:
    """
    Cost class can record various costs during running and evaluation.
    Currently we define the following costs:
        accumulated_cost: the total cost (USD $) of the current LLM.
    """

    def __init__(self) -> None:
        self._accumulated_cost: float = 0.0
        self._costs: list[float] = []

    @property
    def accumulated_cost(self) -> float:
        return self._accumulated_cost

    @accumulated_cost.setter
    def accumulated_cost(self, value: float) -> None:
        if value < 0:
            raise ValueError("Total cost cannot be negative.")
        self._accumulated_cost = value

    @property
    def costs(self) -> list:
        return self._costs

    def add_cost(self, value: float) -> None:
        if value < 0:
            raise ValueError("Added cost cannot be negative.")
        self._accumulated_cost += value
        self._costs.append(value)

    def get(self):
        """
        Return the costs in a dictionary.
        """
        return {"accumulated_cost": self._accumulated_cost, "costs": self._costs}

    def log(self):
        """
        Log the costs.
        """
        cost = self.get()
        logs = ""
        for key, value in cost.items():
            logs += f"{key}: {value}\n"
        return logs
app/llm/inference.py (new file, 505 lines):

from typing import Dict, List, Literal, Optional, Union, Tuple, Any
import os
import base64

import litellm
from litellm import completion, completion_cost
from litellm.exceptions import (
    APIConnectionError,
    RateLimitError,
    ServiceUnavailableError,
)
from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type

from app.config import LLMSettings, config
from app.logger import logger
from app.schema import Message
from app.llm.cost import Cost


class LLM:
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

    def __init__(
        self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if not hasattr(self, "initialized"):  # Only initialize if not already initialized
            llm_config = llm_config or config.llm
            llm_config = llm_config.get(config_name, llm_config["default"])

            self.model = getattr(llm_config, "model", "gpt-3.5-turbo")
            self.max_tokens = getattr(llm_config, "max_tokens", 4096)
            self.temperature = getattr(llm_config, "temperature", 0.7)
            self.top_p = getattr(llm_config, "top_p", 0.9)
            self.api_type = getattr(llm_config, "api_type", "openai")
            self.api_key = getattr(llm_config, "api_key", os.environ.get("OPENAI_API_KEY", ""))
            self.api_version = getattr(llm_config, "api_version", "")
            self.base_url = getattr(llm_config, "base_url", "https://api.openai.com/v1")
            self.timeout = getattr(llm_config, "timeout", 60)
            self.num_retries = getattr(llm_config, "num_retries", 3)
            self.retry_min_wait = getattr(llm_config, "retry_min_wait", 1)
            self.retry_max_wait = getattr(llm_config, "retry_max_wait", 10)
            self.custom_llm_provider = getattr(llm_config, "custom_llm_provider", None)

            # Get model info if available
            self.model_info = None
            try:
                self.model_info = litellm.get_model_info(self.model)
            except Exception as e:
                logger.warning(f"Could not get model info for {self.model}: {e}")

            # Configure litellm
            if self.api_type == "azure":
                litellm.api_base = self.base_url
                litellm.api_key = self.api_key
                litellm.api_version = self.api_version
            else:
                litellm.api_key = self.api_key
                if self.base_url:
                    litellm.api_base = self.base_url

            # Initialize cost tracker
            self.cost_tracker = Cost()
            self.initialized = True

            # Initialize completion function
            self._initialize_completion_function()

    def _initialize_completion_function(self):
        """Initialize the completion function with retry logic"""

        def attempt_on_error(retry_state):
            logger.error(
                f"{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number}"
            )
            return True

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.num_retries),
            wait=wait_random_exponential(
                min=self.retry_min_wait, max=self.retry_max_wait
            ),
            retry=retry_if_exception_type(
                (RateLimitError, APIConnectionError, ServiceUnavailableError)
            ),
            after=attempt_on_error,
        )
        def wrapper(*args, **kwargs):
            model_name = self.model
            if self.api_type == "azure":
                model_name = f"azure/{self.model}"

            # Set default parameters if not provided
            if "max_tokens" not in kwargs:
                kwargs["max_tokens"] = self.max_tokens
            if "temperature" not in kwargs:
                kwargs["temperature"] = self.temperature
            if "top_p" not in kwargs:
                kwargs["top_p"] = self.top_p
            if "timeout" not in kwargs:
                kwargs["timeout"] = self.timeout

            kwargs["model"] = model_name

            # Add API credentials if not in kwargs
            if "api_key" not in kwargs:
                kwargs["api_key"] = self.api_key
            if "base_url" not in kwargs and self.base_url:
                kwargs["base_url"] = self.base_url
            if "api_version" not in kwargs and self.api_version:
                kwargs["api_version"] = self.api_version
            if "custom_llm_provider" not in kwargs and self.custom_llm_provider:
                kwargs["custom_llm_provider"] = self.custom_llm_provider

            resp = completion(**kwargs)
            return resp

        self._completion = wrapper

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided
        """
        formatted_messages = []

        for message in messages:
            if isinstance(message, dict):
                # If message is already a dict, ensure it has required fields
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")
                formatted_messages.append(message)
            elif isinstance(message, Message):
                # If message is a Message object, convert it to dict
                formatted_messages.append(message.to_dict())
            else:
                raise TypeError(f"Unsupported message type: {type(message)}")

        # Validate all messages have required fields
        for msg in formatted_messages:
            if msg["role"] not in ["system", "user", "assistant", "tool"]:
                raise ValueError(f"Invalid role: {msg['role']}")
            if "content" not in msg and "tool_calls" not in msg:
                raise ValueError(
                    "Message must contain either 'content' or 'tool_calls'"
                )

        return formatted_messages

    def _calculate_and_track_cost(self, response) -> float:
        """
        Calculate and track the cost of an LLM API call.

        Args:
            response: The response from litellm

        Returns:
            float: The calculated cost
        """
        try:
            # Use litellm's completion_cost function
            cost = completion_cost(completion_response=response)

            # Add the cost to our tracker
            if cost > 0:
                self.cost_tracker.add_cost(cost)
                logger.info(f"Added cost: ${cost:.6f}, Total: ${self.cost_tracker.accumulated_cost:.6f}")

            return cost
        except Exception as e:
            logger.warning(f"Cost calculation failed: {e}")
            return 0.0

    def is_local(self) -> bool:
        """
        Check if the model is running locally.

        Returns:
            bool: True if the model is running locally, False otherwise
        """
        if self.base_url:
            return any(
                substring in self.base_url
                for substring in ["localhost", "127.0.0.1", "0.0.0.0"]
            )
        if self.model and (self.model.startswith("ollama") or "local" in self.model.lower()):
            return True
        return False

    def do_completion(self, *args, **kwargs) -> Tuple[Any, float, float]:
        """
        Perform a completion request and track cost.

        Returns:
            Tuple[Any, float, float]: (response, current_cost, accumulated_cost)
        """
        response = self._completion(*args, **kwargs)

        # Calculate and track cost
        current_cost = self._calculate_and_track_cost(response)

        return response, current_cost, self.cost_tracker.accumulated_cost

    @staticmethod
    def encode_image(image_path: str) -> str:
        """
        Encode an image to base64.

        Args:
            image_path: Path to the image file

        Returns:
            str: Base64-encoded image
        """
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def prepare_messages(self, text: str, image_path: Optional[str] = None) -> List[dict]:
        """
        Prepare messages for completion, including multimodal content if needed.

        Args:
            text: Text content
            image_path: Optional path to an image file

        Returns:
            List[dict]: Formatted messages
        """
        messages = [{"role": "user", "content": text}]
        if image_path:
            base64_image = self.encode_image(image_path)
            messages[0]["content"] = [
                {"type": "text", "text": text},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
            ]
        return messages

    def do_multimodal_completion(self, text: str, image_path: str) -> Tuple[Any, float, float]:
        """
        Perform a multimodal completion with text and image.

        Args:
            text: Text prompt
            image_path: Path to the image file

        Returns:
            Tuple[Any, float, float]: (response, current_cost, accumulated_cost)
        """
        messages = self.prepare_messages(text, image_path=image_path)
        return self.do_completion(messages=messages)

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            ValueError: If messages are invalid or response is empty
            Exception: For unexpected errors
        """
        try:
            # Format system and user messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            model_name = self.model
            if self.api_type == "azure":
                # For Azure, litellm expects model name in format: azure/<deployment_name>
                model_name = f"azure/{self.model}"

            if not stream:
                # Non-streaming request
                response = await litellm.acompletion(
                    model=model_name,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=temperature or self.temperature,
                    stream=False,
                )

                # Calculate and track cost
                self._calculate_and_track_cost(response)

                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")
                return response.choices[0].message.content

            # Streaming request
            collected_messages = []
            async for chunk in await litellm.acompletion(
                model=model_name,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=temperature or self.temperature,
                stream=True,
            ):
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                print(chunk_message, end="", flush=True)

                # For streaming responses, cost is calculated on the last chunk
                if hasattr(chunk, 'usage') and chunk.usage:
                    self._calculate_and_track_cost(chunk)

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")
            return full_response

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
            raise

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 60,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            The model's response

        Raises:
            ValueError: If tools, tool_choice, or messages are invalid
            Exception: For unexpected errors
        """
        try:
            # Validate tool_choice
            if tool_choice not in ["none", "auto", "required"]:
                raise ValueError(f"Invalid tool_choice: {tool_choice}")

            # Format messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Validate tools if provided
            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError("Each tool must be a dict with 'type' field")

            model_name = self.model
            if self.api_type == "azure":
                # For Azure, litellm expects model name in format: azure/<deployment_name>
                model_name = f"azure/{self.model}"

            # Set up the completion request
            response = await litellm.acompletion(
                model=model_name,
                messages=messages,
                temperature=temperature or self.temperature,
                max_tokens=self.max_tokens,
                tools=tools,
                tool_choice=tool_choice,
                timeout=timeout,
                **kwargs,
            )

            # Calculate and track cost
            self._calculate_and_track_cost(response)

            # Check if response is valid
            if not response.choices or not response.choices[0].message:
                print(response)
                raise ValueError("Invalid or empty response from LLM")

            return response.choices[0].message

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise

    def get_cost(self):
        """
        Get the current cost information.

        Returns:
            dict: Dictionary containing accumulated cost and individual costs
        """
        return self.cost_tracker.get()

    def log_cost(self):
        """
        Log the current cost information.

        Returns:
            str: Formatted string of cost information
        """
        return self.cost_tracker.log()

    def get_token_count(self, messages):
        """
        Get the token count for a list of messages.

        Args:
            messages: List of messages

        Returns:
            int: Token count
        """
        return litellm.token_counter(model=self.model, messages=messages)

    def __str__(self):
        return f"LLM(model={self.model}, base_url={self.base_url})"

    def __repr__(self):
        return str(self)


# Example usage
if __name__ == "__main__":
    # Load environment variables if needed
    from dotenv import load_dotenv
    load_dotenv()

    # Create LLM instance
    llm = LLM()

    # Test text completion
    messages = llm.prepare_messages("Hello, how are you?")
    response, cost, total_cost = llm.do_completion(messages=messages)
    print(f"Response: {response['choices'][0]['message']['content']}")
    print(f"Cost: ${cost:.6f}, Total cost: ${total_cost:.6f}")

    # Test multimodal if image path is available
    image_path = os.getenv("TEST_IMAGE_PATH")
    if image_path and os.path.exists(image_path):
        multimodal_response, mm_cost, mm_total_cost = llm.do_multimodal_completion(
            "What's in this image?", image_path
        )
        print(f"Multimodal response: {multimodal_response['choices'][0]['message']['content']}")
        print(f"Cost: ${mm_cost:.6f}, Total cost: ${mm_total_cost:.6f}")
@@ -1,7 +1,7 @@
 # Global LLM configuration
 [llm]
-model = "claude-3-5-sonnet"
-base_url = "https://api.openai.com/v1"
+model = "gpt-4o" #"claude-3-5-sonnet"
+base_url = "https://api.openai.com/v1" # "https://api.anthropic.com"
 api_key = "sk-..."
 max_tokens = 4096
 temperature = 0.0

@@ -17,6 +17,6 @@ temperature = 0.0

 # Optional configuration for specific LLM models
 [llm.vision]
-model = "claude-3-5-sonnet"
-base_url = "https://api.openai.com/v1"
+model = "gpt-4o" # "claude-3-5-sonnet"
+base_url = "https://api.openai.com/v1" # "https://api.anthropic.com"
 api_key = "sk-..."
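
A minimal usage sketch (assumed caller code, not part of this commit) of how an agent might drive the new inference module and read back the accumulated cost; it only uses names introduced by the diff (LLM, ask, get_cost, log_cost) plus Message from app.schema:

import asyncio

from app.llm.inference import LLM
from app.schema import Message


async def main():
    # Singleton per config name; reads the [llm] section of config.toml
    llm = LLM()

    # ask() routes through litellm.acompletion and tracks cost per call
    reply = await llm.ask(
        messages=[Message.user_message("Summarize LiteLLM in one sentence.")],
        stream=False,
    )
    print(reply)

    # Cost bookkeeping comes from litellm's completion_cost via the Cost tracker
    print(llm.get_cost())   # {"accumulated_cost": ..., "costs": [...]}
    print(llm.log_cost())   # "accumulated_cost: ...\ncosts: [...]"


if __name__ == "__main__":
    asyncio.run(main())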