Compare commits
6 commits: main...llm-module

Commits: 13f98a0fde, 487b44fda8, 737abe4f90, cbef113736, dc8ed530f0, 62cfcb182e
@@ -143,6 +143,8 @@ Join our networking group on Feishu and share your experience with other develop
Thanks to [anthropic-computer-use](https://github.com/anthropics/anthropic-quickstarts/tree/main/computer-use-demo)
and [browser-use](https://github.com/browser-use/browser-use) for providing basic support for this project!

+Additionally, we are grateful to [AAAJ](https://github.com/metauto-ai/agent-as-a-judge), [MetaGPT](https://github.com/geekan/MetaGPT) and [OpenHands](https://github.com/All-Hands-AI/OpenHands).
+
OpenManus is built by contributors from MetaGPT. Huge thanks to this agent community!

## Cite
@@ -145,4 +145,6 @@ python run_flow.py
Special thanks to [anthropic-computer-use](https://github.com/anthropics/anthropic-quickstarts/tree/main/computer-use-demo)
and [browser-use](https://github.com/browser-use/browser-use) for providing basic support for this project!

+Additionally, we thank [AAAJ](https://github.com/metauto-ai/agent-as-a-judge), [MetaGPT](https://github.com/geekan/MetaGPT) and [OpenHands](https://github.com/All-Hands-AI/OpenHands).
+
OpenManus is built by contributors from the MetaGPT community. Thanks to this vibrant community of agent developers!
@@ -4,7 +4,7 @@ from typing import List, Literal, Optional
from pydantic import BaseModel, Field, model_validator

-from app.llm import LLM
+from app.llm.inference import LLM
from app.logger import logger
from app.schema import AgentState, Memory, Message
@@ -4,7 +4,7 @@ from typing import Optional
from pydantic import Field

from app.agent.base import BaseAgent
-from app.llm import LLM
+from app.llm.inference import LLM
from app.schema import AgentState, Memory
@@ -6,7 +6,7 @@ from pydantic import Field
from app.agent.base import BaseAgent
from app.flow.base import BaseFlow, PlanStepStatus
-from app.llm import LLM
+from app.llm.inference import LLM
from app.logger import logger
from app.schema import AgentState, Message
from app.tool import PlanningTool
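The three hunks above only change the import path: code that previously did `from app.llm import LLM` now imports the same singleton class from the new `app.llm` package. A minimal sketch of a resulting call site (illustrative, not part of this diff; it relies only on the singleton behaviour shown in app/llm/inference.py below and on the "vision" config name from config.toml):

from app.llm.inference import LLM  # new location of the LLM class

# Instances are cached per config name in LLM._instances, so repeated
# construction returns the same object.
default_llm = LLM()          # backed by the [llm] section of config.toml
vision_llm = LLM("vision")   # backed by the [llm.vision] section
assert default_llm is LLM()  # same singleton on every call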
app/llm.py (deleted, 264 lines)
@@ -1,264 +0,0 @@
from typing import Dict, List, Literal, Optional, Union

from openai import (
    APIError,
    AsyncAzureOpenAI,
    AsyncOpenAI,
    AuthenticationError,
    OpenAIError,
    RateLimitError,
)
from tenacity import retry, stop_after_attempt, wait_random_exponential

from app.config import LLMSettings, config
from app.logger import logger  # Assuming a logger is set up in your app
from app.schema import Message


class LLM:
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

    def __init__(
        self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if not hasattr(self, "client"):  # Only initialize if not already initialized
            llm_config = llm_config or config.llm
            llm_config = llm_config.get(config_name, llm_config["default"])
            self.model = llm_config.model
            self.max_tokens = llm_config.max_tokens
            self.temperature = llm_config.temperature
            self.api_type = llm_config.api_type
            self.api_key = llm_config.api_key
            self.api_version = llm_config.api_version
            self.base_url = llm_config.base_url
            if self.api_type == "azure":
                self.client = AsyncAzureOpenAI(
                    base_url=self.base_url,
                    api_key=self.api_key,
                    api_version=self.api_version,
                )
            else:
                self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided

        Examples:
            >>> msgs = [
            ...     Message.system_message("You are a helpful assistant"),
            ...     {"role": "user", "content": "Hello"},
            ...     Message.user_message("How are you?")
            ... ]
            >>> formatted = LLM.format_messages(msgs)
        """
        formatted_messages = []

        for message in messages:
            if isinstance(message, dict):
                # If message is already a dict, ensure it has required fields
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")
                formatted_messages.append(message)
            elif isinstance(message, Message):
                # If message is a Message object, convert it to dict
                formatted_messages.append(message.to_dict())
            else:
                raise TypeError(f"Unsupported message type: {type(message)}")

        # Validate all messages have required fields
        for msg in formatted_messages:
            if msg["role"] not in ["system", "user", "assistant", "tool"]:
                raise ValueError(f"Invalid role: {msg['role']}")
            if "content" not in msg and "tool_calls" not in msg:
                raise ValueError(
                    "Message must contain either 'content' or 'tool_calls'"
                )

        return formatted_messages

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            ValueError: If messages are invalid or response is empty
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Format system and user messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            if not stream:
                # Non-streaming request
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=temperature or self.temperature,
                    stream=False,
                )
                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")
                return response.choices[0].message.content

            # Streaming request
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=temperature or self.temperature,
                stream=True,
            )

            collected_messages = []
            async for chunk in response:
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                print(chunk_message, end="", flush=True)

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")
            return full_response

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except OpenAIError as oe:
            logger.error(f"OpenAI API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
            raise

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 60,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            ChatCompletionMessage: The model's response

        Raises:
            ValueError: If tools, tool_choice, or messages are invalid
            OpenAIError: If API call fails after retries
            Exception: For unexpected errors
        """
        try:
            # Validate tool_choice
            if tool_choice not in ["none", "auto", "required"]:
                raise ValueError(f"Invalid tool_choice: {tool_choice}")

            # Format messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Validate tools if provided
            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError("Each tool must be a dict with 'type' field")

            # Set up the completion request
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature or self.temperature,
                max_tokens=self.max_tokens,
                tools=tools,
                tool_choice=tool_choice,
                timeout=timeout,
                **kwargs,
            )

            # Check if response is valid
            if not response.choices or not response.choices[0].message:
                print(response)
                raise ValueError("Invalid or empty response from LLM")

            return response.choices[0].message

        except ValueError as ve:
            logger.error(f"Validation error in ask_tool: {ve}")
            raise
        except OpenAIError as oe:
            if isinstance(oe, AuthenticationError):
                logger.error("Authentication failed. Check API key.")
            elif isinstance(oe, RateLimitError):
                logger.error("Rate limit exceeded. Consider increasing retry attempts.")
            elif isinstance(oe, APIError):
                logger.error(f"API error: {oe}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise
app/llm/__init__.py (new file, empty)

app/llm/cost.py (new file, 50 lines)
@@ -0,0 +1,50 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-


class Cost:
    """
    Cost class can record various costs during running and evaluation.
    Currently we define the following costs:
        accumulated_cost: the total cost (USD $) of the current LLM.
    """

    def __init__(self) -> None:
        self._accumulated_cost: float = 0.0
        self._costs: list[float] = []

    @property
    def accumulated_cost(self) -> float:
        return self._accumulated_cost

    @accumulated_cost.setter
    def accumulated_cost(self, value: float) -> None:
        if value < 0:
            raise ValueError("Total cost cannot be negative.")
        self._accumulated_cost = value

    @property
    def costs(self) -> list:
        return self._costs

    def add_cost(self, value: float) -> None:
        if value < 0:
            raise ValueError("Added cost cannot be negative.")
        self._accumulated_cost += value
        self._costs.append(value)

    def get(self):
        """
        Return the costs in a dictionary.
        """
        return {"accumulated_cost": self._accumulated_cost, "costs": self._costs}

    def log(self):
        """
        Log the costs.
        """
        cost = self.get()
        logs = ""
        for key, value in cost.items():
            logs += f"{key}: {value}\n"
        return logs
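As a quick illustration (not part of the diff), the tracker simply accumulates per-call costs in USD; the amounts below are made up:

# Illustrative usage of the Cost tracker added above.
tracker = Cost()
tracker.add_cost(0.0023)   # cost of one completion, in USD
tracker.add_cost(0.0040)
print(tracker.accumulated_cost)  # ~0.0063 (floating point)
print(tracker.log())             # "accumulated_cost: ...\ncosts: [0.0023, 0.004]\n"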
app/llm/inference.py (new file, 525 lines)
@@ -0,0 +1,525 @@
import base64
import os
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import litellm
from litellm import completion, completion_cost
from litellm.exceptions import (
    APIConnectionError,
    RateLimitError,
    ServiceUnavailableError,
)
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from app.config import LLMSettings, config
from app.llm.cost import Cost
from app.logger import logger
from app.schema import Message


class LLM:
    _instances: Dict[str, "LLM"] = {}

    def __new__(
        cls, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if config_name not in cls._instances:
            instance = super().__new__(cls)
            instance.__init__(config_name, llm_config)
            cls._instances[config_name] = instance
        return cls._instances[config_name]

    def __init__(
        self, config_name: str = "default", llm_config: Optional[LLMSettings] = None
    ):
        if not hasattr(
            self, "initialized"
        ):  # Only initialize if not already initialized
            llm_config = llm_config or config.llm
            llm_config = llm_config.get(config_name, llm_config["default"])

            self.model = getattr(llm_config, "model", "gpt-3.5-turbo")
            self.max_tokens = getattr(llm_config, "max_tokens", 4096)
            self.temperature = getattr(llm_config, "temperature", 0.7)
            self.top_p = getattr(llm_config, "top_p", 0.9)
            self.api_type = getattr(llm_config, "api_type", "openai")
            self.api_key = getattr(
                llm_config, "api_key", os.environ.get("OPENAI_API_KEY", "")
            )
            self.api_version = getattr(llm_config, "api_version", "")
            self.base_url = getattr(llm_config, "base_url", "https://api.openai.com/v1")
            self.timeout = getattr(llm_config, "timeout", 60)
            self.num_retries = getattr(llm_config, "num_retries", 3)
            self.retry_min_wait = getattr(llm_config, "retry_min_wait", 1)
            self.retry_max_wait = getattr(llm_config, "retry_max_wait", 10)
            self.custom_llm_provider = getattr(llm_config, "custom_llm_provider", None)

            # Get model info if available
            self.model_info = None
            try:
                self.model_info = litellm.get_model_info(self.model)
            except Exception as e:
                logger.warning(f"Could not get model info for {self.model}: {e}")

            # Configure litellm
            if self.api_type == "azure":
                litellm.api_base = self.base_url
                litellm.api_key = self.api_key
                litellm.api_version = self.api_version
            else:
                litellm.api_key = self.api_key
                if self.base_url:
                    litellm.api_base = self.base_url

            # Initialize cost tracker
            self.cost_tracker = Cost()
            self.initialized = True

            # Initialize completion function
            self._initialize_completion_function()

    def _initialize_completion_function(self):
        """Initialize the completion function with retry logic"""

        def attempt_on_error(retry_state):
            logger.error(
                f"{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number}"
            )
            return True

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.num_retries),
            wait=wait_random_exponential(
                min=self.retry_min_wait, max=self.retry_max_wait
            ),
            retry=retry_if_exception_type(
                (RateLimitError, APIConnectionError, ServiceUnavailableError)
            ),
            after=attempt_on_error,
        )
        def wrapper(*args, **kwargs):
            model_name = self.model
            if self.api_type == "azure":
                model_name = f"azure/{self.model}"

            # Set default parameters if not provided
            if "max_tokens" not in kwargs:
                kwargs["max_tokens"] = self.max_tokens
            if "temperature" not in kwargs:
                kwargs["temperature"] = self.temperature
            if "top_p" not in kwargs:
                kwargs["top_p"] = self.top_p
            if "timeout" not in kwargs:
                kwargs["timeout"] = self.timeout

            kwargs["model"] = model_name

            # Add API credentials if not in kwargs
            if "api_key" not in kwargs:
                kwargs["api_key"] = self.api_key
            if "base_url" not in kwargs and self.base_url:
                kwargs["base_url"] = self.base_url
            if "api_version" not in kwargs and self.api_version:
                kwargs["api_version"] = self.api_version
            if "custom_llm_provider" not in kwargs and self.custom_llm_provider:
                kwargs["custom_llm_provider"] = self.custom_llm_provider

            resp = completion(**kwargs)
            return resp

        self._completion = wrapper

    @staticmethod
    def format_messages(messages: List[Union[dict, Message]]) -> List[dict]:
        """
        Format messages for LLM by converting them to OpenAI message format.

        Args:
            messages: List of messages that can be either dict or Message objects

        Returns:
            List[dict]: List of formatted messages in OpenAI format

        Raises:
            ValueError: If messages are invalid or missing required fields
            TypeError: If unsupported message types are provided
        """
        formatted_messages = []

        for message in messages:
            if isinstance(message, dict):
                # If message is already a dict, ensure it has required fields
                if "role" not in message:
                    raise ValueError("Message dict must contain 'role' field")
                formatted_messages.append(message)
            elif isinstance(message, Message):
                # If message is a Message object, convert it to dict
                formatted_messages.append(message.to_dict())
            else:
                raise TypeError(f"Unsupported message type: {type(message)}")

        # Validate all messages have required fields
        for msg in formatted_messages:
            if msg["role"] not in ["system", "user", "assistant", "tool"]:
                raise ValueError(f"Invalid role: {msg['role']}")
            if "content" not in msg and "tool_calls" not in msg:
                raise ValueError(
                    "Message must contain either 'content' or 'tool_calls'"
                )

        return formatted_messages

    def _calculate_and_track_cost(self, response) -> float:
        """
        Calculate and track the cost of an LLM API call.

        Args:
            response: The response from litellm

        Returns:
            float: The calculated cost
        """
        try:
            # Use litellm's completion_cost function
            cost = completion_cost(completion_response=response)

            # Add the cost to our tracker
            if cost > 0:
                self.cost_tracker.add_cost(cost)
                logger.info(
                    f"Added cost: ${cost:.6f}, Total: ${self.cost_tracker.accumulated_cost:.6f}"
                )

            return cost
        except Exception as e:
            logger.warning(f"Cost calculation failed: {e}")
            return 0.0

    def is_local(self) -> bool:
        """
        Check if the model is running locally.

        Returns:
            bool: True if the model is running locally, False otherwise
        """
        if self.base_url:
            return any(
                substring in self.base_url
                for substring in ["localhost", "127.0.0.1", "0.0.0.0"]
            )
        if self.model and (
            self.model.startswith("ollama") or "local" in self.model.lower()
        ):
            return True
        return False

    def do_completion(self, *args, **kwargs) -> Tuple[Any, float, float]:
        """
        Perform a completion request and track cost.

        Returns:
            Tuple[Any, float, float]: (response, current_cost, accumulated_cost)
        """
        response = self._completion(*args, **kwargs)

        # Calculate and track cost
        current_cost = self._calculate_and_track_cost(response)

        return response, current_cost, self.cost_tracker.accumulated_cost

    @staticmethod
    def encode_image(image_path: str) -> str:
        """
        Encode an image to base64.

        Args:
            image_path: Path to the image file

        Returns:
            str: Base64-encoded image
        """
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def prepare_messages(
        self, text: str, image_path: Optional[str] = None
    ) -> List[dict]:
        """
        Prepare messages for completion, including multimodal content if needed.

        Args:
            text: Text content
            image_path: Optional path to an image file

        Returns:
            List[dict]: Formatted messages
        """
        messages = [{"role": "user", "content": text}]
        if image_path:
            base64_image = self.encode_image(image_path)
            messages[0]["content"] = [
                {"type": "text", "text": text},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
            ]
        return messages

    def do_multimodal_completion(
        self, text: str, image_path: str
    ) -> Tuple[Any, float, float]:
        """
        Perform a multimodal completion with text and image.

        Args:
            text: Text prompt
            image_path: Path to the image file

        Returns:
            Tuple[Any, float, float]: (response, current_cost, accumulated_cost)
        """
        messages = self.prepare_messages(text, image_path=image_path)
        return self.do_completion(messages=messages)

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        stream: bool = True,
        temperature: Optional[float] = None,
    ) -> str:
        """
        Send a prompt to the LLM and get the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            stream (bool): Whether to stream the response
            temperature (float): Sampling temperature for the response

        Returns:
            str: The generated response

        Raises:
            ValueError: If messages are invalid or response is empty
            Exception: For unexpected errors
        """
        try:
            # Format system and user messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            model_name = self.model
            if self.api_type == "azure":
                # For Azure, litellm expects model name in format: azure/<deployment_name>
                model_name = f"azure/{self.model}"

            if not stream:
                # Non-streaming request
                response = await litellm.acompletion(
                    model=model_name,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=temperature or self.temperature,
                    stream=False,
                )

                # Calculate and track cost
                self._calculate_and_track_cost(response)

                if not response.choices or not response.choices[0].message.content:
                    raise ValueError("Empty or invalid response from LLM")
                return response.choices[0].message.content

            # Streaming request
            collected_messages = []
            async for chunk in await litellm.acompletion(
                model=model_name,
                messages=messages,
                max_tokens=self.max_tokens,
                temperature=temperature or self.temperature,
                stream=True,
            ):
                chunk_message = chunk.choices[0].delta.content or ""
                collected_messages.append(chunk_message)
                print(chunk_message, end="", flush=True)

                # For streaming responses, cost is calculated on the last chunk
                if hasattr(chunk, "usage") and chunk.usage:
                    self._calculate_and_track_cost(chunk)

            print()  # Newline after streaming
            full_response = "".join(collected_messages).strip()
            if not full_response:
                raise ValueError("Empty response from streaming LLM")
            return full_response

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask: {e}")
            raise

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
    )
    async def ask_tool(
        self,
        messages: List[Union[dict, Message]],
        system_msgs: Optional[List[Union[dict, Message]]] = None,
        timeout: int = 60,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """
        Ask LLM using functions/tools and return the response.

        Args:
            messages: List of conversation messages
            system_msgs: Optional system messages to prepend
            timeout: Request timeout in seconds
            tools: List of tools to use
            tool_choice: Tool choice strategy
            temperature: Sampling temperature for the response
            **kwargs: Additional completion arguments

        Returns:
            The model's response

        Raises:
            ValueError: If tools, tool_choice, or messages are invalid
            Exception: For unexpected errors
        """
        try:
            # Validate tool_choice
            if tool_choice not in ["none", "auto", "required"]:
                raise ValueError(f"Invalid tool_choice: {tool_choice}")

            # Format messages
            if system_msgs:
                system_msgs = self.format_messages(system_msgs)
                messages = system_msgs + self.format_messages(messages)
            else:
                messages = self.format_messages(messages)

            # Validate tools if provided
            if tools:
                for tool in tools:
                    if not isinstance(tool, dict) or "type" not in tool:
                        raise ValueError("Each tool must be a dict with 'type' field")

            model_name = self.model
            if self.api_type == "azure":
                # For Azure, litellm expects model name in format: azure/<deployment_name>
                model_name = f"azure/{self.model}"

            # Set up the completion request
            response = await litellm.acompletion(
                model=model_name,
                messages=messages,
                temperature=temperature or self.temperature,
                max_tokens=self.max_tokens,
                tools=tools,
                tool_choice=tool_choice,
                timeout=timeout,
                **kwargs,
            )

            # Calculate and track cost
            self._calculate_and_track_cost(response)

            # Check if response is valid
            if not response.choices or not response.choices[0].message:
                print(response)
                raise ValueError("Invalid or empty response from LLM")

            return response.choices[0].message

        except ValueError as ve:
            logger.error(f"Validation error: {ve}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error in ask_tool: {e}")
            raise

    def get_cost(self):
        """
        Get the current cost information.

        Returns:
            dict: Dictionary containing accumulated cost and individual costs
        """
        return self.cost_tracker.get()

    def log_cost(self):
        """
        Log the current cost information.

        Returns:
            str: Formatted string of cost information
        """
        return self.cost_tracker.log()

    def get_token_count(self, messages):
        """
        Get the token count for a list of messages.

        Args:
            messages: List of messages

        Returns:
            int: Token count
        """
        return litellm.token_counter(model=self.model, messages=messages)

    def __str__(self):
        return f"LLM(model={self.model}, base_url={self.base_url})"

    def __repr__(self):
        return str(self)


# Example usage
if __name__ == "__main__":
    # Load environment variables if needed
    from dotenv import load_dotenv

    load_dotenv()

    # Create LLM instance
    llm = LLM()

    # Test text completion
    messages = llm.prepare_messages("Hello, how are you?")
    response, cost, total_cost = llm.do_completion(messages=messages)
    print(f"Response: {response['choices'][0]['message']['content']}")
    print(f"Cost: ${cost:.6f}, Total cost: ${total_cost:.6f}")

    # Test multimodal if image path is available
    image_path = os.getenv("TEST_IMAGE_PATH")
    if image_path and os.path.exists(image_path):
        multimodal_response, mm_cost, mm_total_cost = llm.do_multimodal_completion(
            "What's in this image?", image_path
        )
        print(
            f"Multimodal response: {multimodal_response['choices'][0]['message']['content']}"
        )
        print(f"Cost: ${mm_cost:.6f}, Total cost: ${mm_total_cost:.6f}")
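Beyond the synchronous example that closes the new module, agents in this repo use the async API. A short sketch of what a call might look like with the new module; the prompt text is illustrative, Message.user_message is the schema helper referenced in the removed app/llm.py docstring, and a valid API key (config.toml or OPENAI_API_KEY) is assumed:

import asyncio

from app.llm.inference import LLM
from app.schema import Message


async def main() -> None:
    llm = LLM()  # shared "default" instance
    reply = await llm.ask(
        messages=[Message.user_message("Give me a one-line summary of OpenManus.")],
        system_msgs=[{"role": "system", "content": "You are a helpful assistant."}],
        stream=False,
    )
    print(reply)
    print(llm.log_cost())  # accumulated litellm cost for this instance


if __name__ == "__main__":
    asyncio.run(main())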
@@ -15,4 +15,4 @@ Terminate: End the current interaction when the task is complete or when you nee
Based on user needs, proactively select the most appropriate tool or combination of tools. For complex tasks, you can break down the problem and use different tools step by step to solve it. After using each tool, clearly explain the execution results and suggest the next steps.

Always maintain a helpful, informative tone throughout the interaction. If you encounter any limitations or need more details, clearly communicate this to the user before terminating.
"""
@@ -11,6 +11,7 @@ from pydantic_core.core_schema import ValidationInfo
from app.tool.base import BaseTool, ToolResult

MAX_LENGTH = 2000

_BROWSER_DESCRIPTION = """

@@ -181,7 +182,9 @@ class BrowserUseTool(BaseTool):
                elif action == "get_html":
                    html = await context.get_page_html()
-                    truncated = html[:MAX_LENGTH] + "..." if len(html) > MAX_LENGTH else html
+                    truncated = (
+                        html[:MAX_LENGTH] + "..." if len(html) > MAX_LENGTH else html
+                    )
                    return ToolResult(output=truncated)

                elif action == "get_text":
@@ -1,7 +1,7 @@
# Global LLM configuration
[llm]
-model = "claude-3-5-sonnet"
+model = "gpt-4o" #"claude-3-5-sonnet"
-base_url = "https://api.openai.com/v1"
+base_url = "https://api.openai.com/v1" # "https://api.anthropic.com"
api_key = "sk-..."
max_tokens = 4096
temperature = 0.0

@@ -17,6 +17,6 @@ temperature = 0.0
# Optional configuration for specific LLM models
[llm.vision]
-model = "claude-3-5-sonnet"
+model = "gpt-4o" # "claude-3-5-sonnet"
-base_url = "https://api.openai.com/v1"
+base_url = "https://api.openai.com/v1" # "https://api.anthropic.com"
api_key = "sk-..."
@@ -20,3 +20,4 @@ aiofiles~=24.1.0
pydantic_core~=2.27.2
colorama~=0.4.6
playwright~=1.49.1
+litellm~=1.63.6