Merge branch 'mannaandpoem:main' into patch/mcp-server
commit acb435f9f5

app/bedrock.py (new file, 275 lines)
@@ -0,0 +1,275 @@
from typing import Dict, List, Literal, Optional, Union
import boto3
import json
import time
import uuid
from datetime import datetime
import sys

# Global variables to track the current tool use ID across function calls
# Tmp solution
CURRENT_TOOLUSE_ID = None


# Class to handle OpenAI-style response formatting
class OpenAIResponse:
    def __init__(self, data):
        # Recursively convert nested dicts and lists to OpenAIResponse objects
        for key, value in data.items():
            if isinstance(value, dict):
                value = OpenAIResponse(value)
            elif isinstance(value, list):
                value = [OpenAIResponse(item) if isinstance(item, dict) else item for item in value]
            setattr(self, key, value)

    def model_dump(self, *args, **kwargs):
        # Convert object to dict and add timestamp
        data = self.__dict__
        data['created_at'] = datetime.now().isoformat()
        return data


# Main client class for interacting with Amazon Bedrock
class BedrockClient:
    def __init__(self):
        # Initialize Bedrock client, you need to configure AWS env first
        try:
            self.client = boto3.client('bedrock-runtime')
            self.chat = Chat(self.client)
        except Exception as e:
            print(f"Error initializing Bedrock client: {e}")
            sys.exit(1)


# Chat interface class
class Chat:
    def __init__(self, client):
        self.completions = ChatCompletions(client)


# Core class handling chat completions functionality
class ChatCompletions:
    def __init__(self, client):
        self.client = client

    def _convert_openai_tools_to_bedrock_format(self, tools):
        # Convert OpenAI function calling format to Bedrock tool format
        bedrock_tools = []
        for tool in tools:
            if tool.get('type') == 'function':
                function = tool.get('function', {})
                bedrock_tool = {
                    "toolSpec": {
                        "name": function.get('name', ''),
                        "description": function.get('description', ''),
                        "inputSchema": {
                            "json": {
                                "type": "object",
                                "properties": function.get('parameters', {}).get('properties', {}),
                                "required": function.get('parameters', {}).get('required', [])
                            }
                        }
                    }
                }
                bedrock_tools.append(bedrock_tool)
        return bedrock_tools

    def _convert_openai_messages_to_bedrock_format(self, messages):
        # Convert OpenAI message format to Bedrock message format
        bedrock_messages = []
        system_prompt = []
        for message in messages:
            if message.get('role') == 'system':
                system_prompt = [{"text": message.get('content')}]
            elif message.get('role') == 'user':
                bedrock_message = {
                    "role": message.get('role', 'user'),
                    "content": [{"text": message.get('content')}]
                }
                bedrock_messages.append(bedrock_message)
            elif message.get('role') == 'assistant':
                bedrock_message = {
                    "role": "assistant",
                    "content": [{"text": message.get('content')}]
                }
                openai_tool_calls = message.get('tool_calls', [])
                if openai_tool_calls:
                    bedrock_tool_use = {
                        "toolUseId": openai_tool_calls[0]['id'],
                        "name": openai_tool_calls[0]['function']['name'],
                        "input": json.loads(openai_tool_calls[0]['function']['arguments'])
                    }
                    bedrock_message['content'].append({"toolUse": bedrock_tool_use})
                    global CURRENT_TOOLUSE_ID
                    CURRENT_TOOLUSE_ID = openai_tool_calls[0]['id']
                bedrock_messages.append(bedrock_message)
            elif message.get('role') == 'tool':
                bedrock_message = {
                    "role": "user",
                    "content": [
                        {
                            "toolResult": {
                                "toolUseId": CURRENT_TOOLUSE_ID,
                                "content": [{"text": message.get('content')}]
                            }
                        }
                    ]
                }
                bedrock_messages.append(bedrock_message)
            else:
                raise ValueError(f"Invalid role: {message.get('role')}")
        return system_prompt, bedrock_messages

    def _convert_bedrock_response_to_openai_format(self, bedrock_response):
        # Convert Bedrock response format to OpenAI format
        content = ""
        if bedrock_response.get('output', {}).get('message', {}).get('content'):
            content_array = bedrock_response['output']['message']['content']
            content = "".join(item.get('text', '') for item in content_array)
        if content == "":
            content = "."

        # Handle tool calls in response
        openai_tool_calls = []
        if bedrock_response.get('output', {}).get('message', {}).get('content'):
            for content_item in bedrock_response['output']['message']['content']:
                if content_item.get('toolUse'):
                    bedrock_tool_use = content_item['toolUse']
                    global CURRENT_TOOLUSE_ID
                    CURRENT_TOOLUSE_ID = bedrock_tool_use['toolUseId']
                    openai_tool_call = {
                        'id': CURRENT_TOOLUSE_ID,
                        'type': 'function',
                        'function': {
                            'name': bedrock_tool_use['name'],
                            'arguments': json.dumps(bedrock_tool_use['input'])
                        }
                    }
                    openai_tool_calls.append(openai_tool_call)

        # Construct final OpenAI format response
        openai_format = {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "created": int(time.time()),
            "object": "chat.completion",
            "system_fingerprint": None,
            "choices": [
                {
                    "finish_reason": bedrock_response.get('stopReason', 'end_turn'),
                    "index": 0,
                    "message": {
                        "content": content,
                        "role": bedrock_response.get('output', {}).get('message', {}).get('role', 'assistant'),
                        "tool_calls": openai_tool_calls if openai_tool_calls != [] else None,
                        "function_call": None
                    }
                }
            ],
            "usage": {
                "completion_tokens": bedrock_response.get('usage', {}).get('outputTokens', 0),
                "prompt_tokens": bedrock_response.get('usage', {}).get('inputTokens', 0),
                "total_tokens": bedrock_response.get('usage', {}).get('totalTokens', 0)
            }
        }
        return OpenAIResponse(openai_format)

    async def _invoke_bedrock(
        self,
        model: str,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        **kwargs
    ) -> OpenAIResponse:
        # Non-streaming invocation of Bedrock model
        system_prompt, bedrock_messages = self._convert_openai_messages_to_bedrock_format(messages)
        response = self.client.converse(
            modelId=model,
            system=system_prompt,
            messages=bedrock_messages,
            inferenceConfig={"temperature": temperature, "maxTokens": max_tokens},
            toolConfig={"tools": tools} if tools else None,
        )
        openai_response = self._convert_bedrock_response_to_openai_format(response)
        return openai_response

    async def _invoke_bedrock_stream(
        self,
        model: str,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        **kwargs
    ) -> OpenAIResponse:
        # Streaming invocation of Bedrock model
        system_prompt, bedrock_messages = self._convert_openai_messages_to_bedrock_format(messages)
        response = self.client.converse_stream(
            modelId=model,
            system=system_prompt,
            messages=bedrock_messages,
            inferenceConfig={"temperature": temperature, "maxTokens": max_tokens},
            toolConfig={"tools": tools} if tools else None,
        )

        # Initialize response structure
        bedrock_response = {
            'output': {
                'message': {
                    'role': '',
                    'content': []
                }
            },
            'stopReason': '',
            'usage': {},
            'metrics': {}
        }
        bedrock_response_text = ""
        bedrock_response_tool_input = ""

        # Process streaming response
        stream = response.get('stream')
        if stream:
            for event in stream:
                if event.get('messageStart', {}).get('role'):
                    bedrock_response['output']['message']['role'] = event['messageStart']['role']
                if event.get('contentBlockDelta', {}).get('delta', {}).get('text'):
                    bedrock_response_text += event['contentBlockDelta']['delta']['text']
                    print(event['contentBlockDelta']['delta']['text'], end='', flush=True)
                if event.get('contentBlockStop', {}).get('contentBlockIndex') == 0:
                    bedrock_response['output']['message']['content'].append({"text": bedrock_response_text})
                if event.get('contentBlockStart', {}).get('start', {}).get('toolUse'):
                    bedrock_tool_use = event['contentBlockStart']['start']['toolUse']
                    tool_use = {
                        "toolUseId": bedrock_tool_use['toolUseId'],
                        "name": bedrock_tool_use['name'],
                    }
                    bedrock_response['output']['message']['content'].append({"toolUse": tool_use})
                    global CURRENT_TOOLUSE_ID
                    CURRENT_TOOLUSE_ID = bedrock_tool_use['toolUseId']
                if event.get('contentBlockDelta', {}).get('delta', {}).get('toolUse'):
                    bedrock_response_tool_input += event['contentBlockDelta']['delta']['toolUse']['input']
                    print(event['contentBlockDelta']['delta']['toolUse']['input'], end='', flush=True)
                if event.get('contentBlockStop', {}).get('contentBlockIndex') == 1:
                    bedrock_response['output']['message']['content'][1]['toolUse']['input'] = json.loads(bedrock_response_tool_input)
        print()
        openai_response = self._convert_bedrock_response_to_openai_format(bedrock_response)
        return openai_response

    def create(
        self,
        model: str,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        stream: Optional[bool] = True,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        **kwargs
    ) -> OpenAIResponse:
        # Main entry point for chat completion
        bedrock_tools = []
        if tools is not None:
            bedrock_tools = self._convert_openai_tools_to_bedrock_format(tools)
        if stream:
            return self._invoke_bedrock_stream(model, messages, max_tokens, temperature, bedrock_tools, tool_choice, **kwargs)
        else:
            return self._invoke_bedrock(model, messages, max_tokens, temperature, bedrock_tools, tool_choice, **kwargs)
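For context, a minimal usage sketch of the OpenAI-style surface this file exposes. It is illustrative only and not part of the commit; the model ID mirrors the commented Bedrock example in the config changes below.

# Hypothetical example, not part of this commit: driving BedrockClient through
# the same chat.completions.create(...) interface the OpenAI client exposes.
import asyncio

from app.bedrock import BedrockClient

async def demo():
    client = BedrockClient()  # needs AWS credentials/region configured in the environment
    response = await client.chat.completions.create(
        model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",  # any Converse-capable model ID
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello."},
        ],
        max_tokens=512,
        temperature=0.0,
        stream=False,  # True streams deltas to stdout while the response is built
    )
    print(response.choices[0].message.content)

asyncio.run(demo())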
@@ -37,6 +37,18 @@ class ProxySettings(BaseModel):


class SearchSettings(BaseModel):
    engine: str = Field(default="Google", description="Search engine for the LLM to use")
    fallback_engines: List[str] = Field(
        default_factory=lambda: ["DuckDuckGo", "Baidu"],
        description="Fallback search engines to try if the primary engine fails",
    )
    retry_delay: int = Field(
        default=60,
        description="Seconds to wait before retrying all engines again after they all fail",
    )
    max_retries: int = Field(
        default=3,
        description="Maximum number of times to retry all engines when all fail",
    )


class BrowserSettings(BaseModel):
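A quick sanity check of the defaults these settings encode (illustrative only, and assuming SearchSettings lives in app/config.py alongside ProxySettings):

from app.config import SearchSettings  # assumed location of the new model

settings = SearchSettings()  # behaviour when config.toml has no [search] section
assert settings.engine == "Google"
assert settings.fallback_engines == ["DuckDuckGo", "Baidu"]
assert settings.retry_delay == 60
assert settings.max_retries == 3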
@@ -28,6 +28,7 @@ from app.schema import (
    Message,
    ToolChoice,
)
from app.bedrock import BedrockClient


REASONING_MODELS = ["o1", "o3-mini"]
@@ -225,6 +226,8 @@ class LLM:
                api_key=self.api_key,
                api_version=self.api_version,
            )
        elif self.api_type == "aws":
            self.client = BedrockClient()
        else:
            self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
@@ -4,6 +4,7 @@ from typing import List
from tenacity import retry, stop_after_attempt, wait_exponential

from app.config import config
from app.logger import logger
from app.tool.base import BaseTool
from app.tool.search import (
    BaiduSearchEngine,
@@ -44,6 +45,8 @@ class WebSearch(BaseTool):
    async def execute(self, query: str, num_results: int = 10) -> List[str]:
        """
        Execute a Web search and return a list of URLs.
        Tries engines in order based on configuration, falling back if an engine fails with errors.
        If all engines fail, it will wait and retry up to the configured number of times.

        Args:
            query (str): The search query to submit to the search engine.
@@ -52,37 +55,109 @@ class WebSearch(BaseTool):
        Returns:
            List[str]: A list of URLs matching the search query.
        """
        # Get retry settings from config
        retry_delay = 60  # Default to 60 seconds
        max_retries = 3  # Default to 3 retries

        if config.search_config:
            retry_delay = getattr(config.search_config, "retry_delay", 60)
            max_retries = getattr(config.search_config, "max_retries", 3)

        # Try searching with retries when all engines fail
        for retry_count in range(
            max_retries + 1
        ):  # +1 because first try is not a retry
            links = await self._try_all_engines(query, num_results)
            if links:
                return links

            if retry_count < max_retries:
                # All engines failed, wait and retry
                logger.warning(
                    f"All search engines failed. Waiting {retry_delay} seconds before retry {retry_count + 1}/{max_retries}..."
                )
                await asyncio.sleep(retry_delay)
            else:
                logger.error(
                    f"All search engines failed after {max_retries} retries. Giving up."
                )

        return []

    async def _try_all_engines(self, query: str, num_results: int) -> List[str]:
        """
        Try all search engines in the configured order.

        Args:
            query (str): The search query to submit to the search engine.
            num_results (int): The number of search results to return.

        Returns:
            List[str]: A list of URLs matching the search query, or an empty list if all engines fail.
        """
        engine_order = self._get_engine_order()
        failed_engines = []

        for engine_name in engine_order:
            engine = self._search_engine[engine_name]
            try:
                logger.info(f"🔎 Attempting search with {engine_name.capitalize()}...")
                links = await self._perform_search_with_engine(
                    engine, query, num_results
                )
                if links:
                    if failed_engines:
                        logger.info(
                            f"Search successful with {engine_name.capitalize()} after trying: {', '.join(failed_engines)}"
                        )
                    return links
            except Exception as e:
                failed_engines.append(engine_name.capitalize())
                is_rate_limit = "429" in str(e) or "Too Many Requests" in str(e)

                if is_rate_limit:
                    logger.warning(
                        f"⚠️ {engine_name.capitalize()} search engine rate limit exceeded, trying next engine..."
                    )
                else:
                    logger.warning(
                        f"⚠️ {engine_name.capitalize()} search failed with error: {e}"
                    )

        if failed_engines:
            logger.error(f"All search engines failed: {', '.join(failed_engines)}")
        return []

    def _get_engine_order(self) -> List[str]:
        """
        Determines the order in which to try search engines.
        Preferred engine is first (based on configuration), followed by fallback engines,
        and then the remaining engines.

        Returns:
            List[str]: Ordered list of search engine names.
        """
        preferred = "google"
        fallbacks = []

        if config.search_config:
            if config.search_config.engine:
                preferred = config.search_config.engine.lower()
            if config.search_config.fallback_engines:
                fallbacks = [
                    engine.lower() for engine in config.search_config.fallback_engines
                ]

        engine_order = []
        # Add preferred engine first
        if preferred in self._search_engine:
            engine_order.append(preferred)

        # Add configured fallback engines in order
        for fallback in fallbacks:
            if fallback in self._search_engine and fallback not in engine_order:
                engine_order.append(fallback)

        # Add any remaining engines
        for key in self._search_engine:
            if key not in engine_order:
                engine_order.append(key)

        return engine_order

    @retry(
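For reviewers, a minimal standalone sketch of the ordering rule in _get_engine_order above. The engine registry here is an assumption based on the engines referenced in this diff; it is not code from the commit.

from typing import List

# Hypothetical stand-in for the keys of WebSearch._search_engine.
AVAILABLE_ENGINES = ["google", "baidu", "duckduckgo"]

def engine_order(preferred: str, fallbacks: List[str]) -> List[str]:
    # Preferred engine first, then configured fallbacks, then anything left over.
    order: List[str] = []
    for name in [preferred.lower()] + [f.lower() for f in fallbacks] + AVAILABLE_ENGINES:
        if name in AVAILABLE_ENGINES and name not in order:
            order.append(name)
    return order

# With the defaults engine = "Google", fallback_engines = ["DuckDuckGo", "Baidu"]:
print(engine_order("Google", ["DuckDuckGo", "Baidu"]))  # ['google', 'duckduckgo', 'baidu']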
@@ -6,6 +6,14 @@ api_key = "YOUR_API_KEY" # Your API key
max_tokens = 8192 # Maximum number of tokens in the response
temperature = 0.0 # Controls randomness

# [llm] # Amazon Bedrock
# api_type = "aws" # Required
# model = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" # Bedrock supported modelID
# base_url = "bedrock-runtime.us-west-2.amazonaws.com" # Not used now
# max_tokens = 8192
# temperature = 1.0
# api_key = "bear" # Required but not used for Bedrock

# [llm] # AZURE OPENAI:
# api_type = 'azure'
# model = "YOUR_MODEL_NAME" # "gpt-4o-mini"
@@ -65,6 +73,13 @@ temperature = 0.0 # Controls randomness for vision model
# [search]
# Search engine for agent to use. Default is "Google", can be set to "Baidu" or "DuckDuckGo".
#engine = "Google"
# Fallback engine order. Default is ["DuckDuckGo", "Baidu"] - will try in this order after primary engine fails.
#fallback_engines = ["DuckDuckGo", "Baidu"]
# Seconds to wait before retrying all engines again when they all fail due to rate limits. Default is 60.
#retry_delay = 60
# Maximum number of times to retry all engines when all fail. Default is 3.
#max_retries = 3


## Sandbox configuration
#[sandbox]
@@ -75,3 +90,4 @@ temperature = 0.0 # Controls randomness for vision model
#cpu_limit = 2.0
#timeout = 300
#network_enabled = true
@@ -31,3 +31,5 @@ pytest-asyncio~=0.25.3
mcp~=1.4.1
httpx>=0.27.0
tomli>=2.0.0

boto3~=1.37.16