diff --git a/app/bedrock.py b/app/bedrock.py
new file mode 100644
index 0000000..6572a56
--- /dev/null
+++ b/app/bedrock.py
@@ -0,0 +1,275 @@
+from typing import Dict, List, Literal, Optional, Union
+import boto3
+import json
+import time
+import uuid
+from datetime import datetime
+import sys
+
+# Global variable to track the current tool use ID across function calls
+# Temporary workaround
+CURRENT_TOOLUSE_ID = None
+
+# Class to handle OpenAI-style response formatting
+class OpenAIResponse:
+    def __init__(self, data):
+        # Recursively convert nested dicts and lists to OpenAIResponse objects
+        for key, value in data.items():
+            if isinstance(value, dict):
+                value = OpenAIResponse(value)
+            elif isinstance(value, list):
+                value = [OpenAIResponse(item) if isinstance(item, dict) else item for item in value]
+            setattr(self, key, value)
+
+    def model_dump(self, *args, **kwargs):
+        # Convert object to dict and add a timestamp
+        data = self.__dict__
+        data['created_at'] = datetime.now().isoformat()
+        return data
+
+# Main client class for interacting with Amazon Bedrock
+class BedrockClient:
+    def __init__(self):
+        # Initialize the Bedrock client; AWS credentials must be configured in the environment first
+        try:
+            self.client = boto3.client('bedrock-runtime')
+            self.chat = Chat(self.client)
+        except Exception as e:
+            print(f"Error initializing Bedrock client: {e}")
+            sys.exit(1)
+
+# Chat interface class
+class Chat:
+    def __init__(self, client):
+        self.completions = ChatCompletions(client)
+
+# Core class handling chat completions functionality
+class ChatCompletions:
+    def __init__(self, client):
+        self.client = client
+
+    def _convert_openai_tools_to_bedrock_format(self, tools):
+        # Convert OpenAI function calling format to Bedrock tool format
+        bedrock_tools = []
+        for tool in tools:
+            if tool.get('type') == 'function':
+                function = tool.get('function', {})
+                bedrock_tool = {
+                    "toolSpec": {
+                        "name": function.get('name', ''),
+                        "description": function.get('description', ''),
+                        "inputSchema": {
+                            "json": {
+                                "type": "object",
+                                "properties": function.get('parameters', {}).get('properties', {}),
+                                "required": function.get('parameters', {}).get('required', [])
+                            }
+                        }
+                    }
+                }
+                bedrock_tools.append(bedrock_tool)
+        return bedrock_tools
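For reference only (not part of this patch), a minimal sketch of the mapping performed by `_convert_openai_tools_to_bedrock_format`, using a hypothetical `get_weather` tool definition:

```python
# Hypothetical OpenAI-style tool definition (illustrative only).
openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# The conversion above is expected to turn this into a Bedrock toolSpec:
# {
#     "toolSpec": {
#         "name": "get_weather",
#         "description": "Look up the current weather for a city",
#         "inputSchema": {
#             "json": {
#                 "type": "object",
#                 "properties": {"city": {"type": "string"}},
#                 "required": ["city"],
#             }
#         }
#     }
# }
```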
+
+    def _convert_openai_messages_to_bedrock_format(self, messages):
+        # Convert OpenAI message format to Bedrock message format
+        bedrock_messages = []
+        system_prompt = []
+        for message in messages:
+            if message.get('role') == 'system':
+                system_prompt = [{"text": message.get('content')}]
+            elif message.get('role') == 'user':
+                bedrock_message = {
+                    "role": message.get('role', 'user'),
+                    "content": [{"text": message.get('content')}]
+                }
+                bedrock_messages.append(bedrock_message)
+            elif message.get('role') == 'assistant':
+                bedrock_message = {
+                    "role": "assistant",
+                    "content": [{"text": message.get('content')}]
+                }
+                openai_tool_calls = message.get('tool_calls', [])
+                if openai_tool_calls:
+                    bedrock_tool_use = {
+                        "toolUseId": openai_tool_calls[0]['id'],
+                        "name": openai_tool_calls[0]['function']['name'],
+                        "input": json.loads(openai_tool_calls[0]['function']['arguments'])
+                    }
+                    bedrock_message['content'].append({"toolUse": bedrock_tool_use})
+                    global CURRENT_TOOLUSE_ID
+                    CURRENT_TOOLUSE_ID = openai_tool_calls[0]['id']
+                bedrock_messages.append(bedrock_message)
+            elif message.get('role') == 'tool':
+                bedrock_message = {
+                    "role": "user",
+                    "content": [
+                        {
+                            "toolResult": {
+                                "toolUseId": CURRENT_TOOLUSE_ID,
+                                "content": [{"text": message.get('content')}]
+                            }
+                        }
+                    ]
+                }
+                bedrock_messages.append(bedrock_message)
+            else:
+                raise ValueError(f"Invalid role: {message.get('role')}")
+        return system_prompt, bedrock_messages
+
+    def _convert_bedrock_response_to_openai_format(self, bedrock_response):
+        # Convert Bedrock response format to OpenAI format
+        content = ""
+        if bedrock_response.get('output', {}).get('message', {}).get('content'):
+            content_array = bedrock_response['output']['message']['content']
+            content = "".join(item.get('text', '') for item in content_array)
+        if content == "": content = "."
+
+        # Handle tool calls in response
+        openai_tool_calls = []
+        if bedrock_response.get('output', {}).get('message', {}).get('content'):
+            for content_item in bedrock_response['output']['message']['content']:
+                if content_item.get('toolUse'):
+                    bedrock_tool_use = content_item['toolUse']
+                    global CURRENT_TOOLUSE_ID
+                    CURRENT_TOOLUSE_ID = bedrock_tool_use['toolUseId']
+                    openai_tool_call = {
+                        'id': CURRENT_TOOLUSE_ID,
+                        'type': 'function',
+                        'function': {
+                            'name': bedrock_tool_use['name'],
+                            'arguments': json.dumps(bedrock_tool_use['input'])
+                        }
+                    }
+                    openai_tool_calls.append(openai_tool_call)
+
+        # Construct final OpenAI format response
+        openai_format = {
+            "id": f"chatcmpl-{uuid.uuid4()}",
+            "created": int(time.time()),
+            "object": "chat.completion",
+            "system_fingerprint": None,
+            "choices": [
+                {
+                    "finish_reason": bedrock_response.get('stopReason', 'end_turn'),
+                    "index": 0,
+                    "message": {
+                        "content": content,
+                        "role": bedrock_response.get('output', {}).get('message', {}).get('role', 'assistant'),
+                        "tool_calls": openai_tool_calls if openai_tool_calls != [] else None,
+                        "function_call": None
+                    }
+                }
+            ],
+            "usage": {
+                "completion_tokens": bedrock_response.get('usage', {}).get('outputTokens', 0),
+                "prompt_tokens": bedrock_response.get('usage', {}).get('inputTokens', 0),
+                "total_tokens": bedrock_response.get('usage', {}).get('totalTokens', 0)
+            }
+        }
+        return OpenAIResponse(openai_format)
+
+    async def _invoke_bedrock(
+        self,
+        model: str,
+        messages: List[Dict[str, str]],
+        max_tokens: int,
+        temperature: float,
+        tools: Optional[List[dict]] = None,
+        tool_choice: Literal["none", "auto", "required"] = "auto",
+        **kwargs
+    ) -> OpenAIResponse:
+        # Non-streaming invocation of Bedrock model
+        system_prompt, bedrock_messages = self._convert_openai_messages_to_bedrock_format(messages)
+        response = self.client.converse(
+            modelId=model,
+            system=system_prompt,
+            messages=bedrock_messages,
+            inferenceConfig={"temperature": temperature, "maxTokens": max_tokens},
+            toolConfig={"tools": tools} if tools else None,
+        )
+        openai_response = self._convert_bedrock_response_to_openai_format(response)
+        return openai_response
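Again purely illustrative (not part of the patch): the message conversion above splits an OpenAI-style history into a Bedrock `system` block plus `messages`, roughly as follows for a hypothetical two-turn exchange:

```python
openai_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 = 4."},
]

# _convert_openai_messages_to_bedrock_format(openai_messages) should return:
# system_prompt = [{"text": "You are a helpful assistant."}]
# bedrock_messages = [
#     {"role": "user", "content": [{"text": "What's 2 + 2?"}]},
#     {"role": "assistant", "content": [{"text": "2 + 2 = 4."}]},
# ]
```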
+
+    async def _invoke_bedrock_stream(
+        self,
+        model: str,
+        messages: List[Dict[str, str]],
+        max_tokens: int,
+        temperature: float,
+        tools: Optional[List[dict]] = None,
+        tool_choice: Literal["none", "auto", "required"] = "auto",
+        **kwargs
+    ) -> OpenAIResponse:
+        # Streaming invocation of Bedrock model
+        system_prompt, bedrock_messages = self._convert_openai_messages_to_bedrock_format(messages)
+        response = self.client.converse_stream(
+            modelId=model,
+            system=system_prompt,
+            messages=bedrock_messages,
+            inferenceConfig={"temperature": temperature, "maxTokens": max_tokens},
+            toolConfig={"tools": tools} if tools else None,
+        )
+
+        # Initialize response structure
+        bedrock_response = {
+            'output': {
+                'message': {
+                    'role': '',
+                    'content': []
+                }
+            },
+            'stopReason': '',
+            'usage': {},
+            'metrics': {}
+        }
+        bedrock_response_text = ""
+        bedrock_response_tool_input = ""
+
+        # Process streaming response
+        stream = response.get('stream')
+        if stream:
+            for event in stream:
+                if event.get('messageStart', {}).get('role'):
+                    bedrock_response['output']['message']['role'] = event['messageStart']['role']
+                if event.get('contentBlockDelta', {}).get('delta', {}).get('text'):
+                    bedrock_response_text += event['contentBlockDelta']['delta']['text']
+                    print(event['contentBlockDelta']['delta']['text'], end='', flush=True)
+                if event.get('contentBlockStop', {}).get('contentBlockIndex') == 0:
+                    bedrock_response['output']['message']['content'].append({"text": bedrock_response_text})
+                if event.get('contentBlockStart', {}).get('start', {}).get('toolUse'):
+                    bedrock_tool_use = event['contentBlockStart']['start']['toolUse']
+                    tool_use = {
+                        "toolUseId": bedrock_tool_use['toolUseId'],
+                        "name": bedrock_tool_use['name'],
+                    }
+                    bedrock_response['output']['message']['content'].append({"toolUse": tool_use})
+                    global CURRENT_TOOLUSE_ID
+                    CURRENT_TOOLUSE_ID = bedrock_tool_use['toolUseId']
+                if event.get('contentBlockDelta', {}).get('delta', {}).get('toolUse'):
+                    bedrock_response_tool_input += event['contentBlockDelta']['delta']['toolUse']['input']
+                    print(event['contentBlockDelta']['delta']['toolUse']['input'], end='', flush=True)
+                if event.get('contentBlockStop', {}).get('contentBlockIndex') == 1:
+                    bedrock_response['output']['message']['content'][1]['toolUse']['input'] = json.loads(bedrock_response_tool_input)
+        print()
+        openai_response = self._convert_bedrock_response_to_openai_format(bedrock_response)
+        return openai_response
+
+    def create(
+        self,
+        model: str,
+        messages: List[Dict[str, str]],
+        max_tokens: int,
+        temperature: float,
+        stream: Optional[bool] = True,
+        tools: Optional[List[dict]] = None,
+        tool_choice: Literal["none", "auto", "required"] = "auto",
+        **kwargs
+    ) -> OpenAIResponse:
+        # Main entry point for chat completion
+        bedrock_tools = []
+        if tools is not None:
+            bedrock_tools = self._convert_openai_tools_to_bedrock_format(tools)
+        if stream:
+            return self._invoke_bedrock_stream(model, messages, max_tokens, temperature, bedrock_tools, tool_choice, **kwargs)
+        else:
+            return self._invoke_bedrock(model, messages, max_tokens, temperature, bedrock_tools, tool_choice, **kwargs)
diff --git a/app/config.py b/app/config.py
index 9459707..7088c42 100644
--- a/app/config.py
+++ b/app/config.py
@@ -37,6 +37,18 @@ class ProxySettings(BaseModel):

 class SearchSettings(BaseModel):
     engine: str = Field(default="Google", description="Search engine the llm to use")
+    fallback_engines: List[str] = Field(
+        default_factory=lambda: ["DuckDuckGo", "Baidu"],
+        description="Fallback search engines to try if the primary engine fails",
+    )
+    retry_delay: int = Field(
+        default=60,
+        description="Seconds to wait before retrying all engines again after they all fail",
+    )
+    max_retries: int = Field(
+        default=3,
+        description="Maximum number of times to retry all engines when all fail",
+    )


 class BrowserSettings(BaseModel):
diff --git a/app/llm.py b/app/llm.py
index 1a4e05b..d0bb231 100644
--- a/app/llm.py
+++ b/app/llm.py
@@ -28,6 +28,7 @@ from app.schema import (
     Message,
     ToolChoice,
 )
+from app.bedrock import BedrockClient


 REASONING_MODELS = ["o1", "o3-mini"]
@@ -225,6 +226,8 @@ class LLM:
                 api_key=self.api_key,
                 api_version=self.api_version,
             )
+        elif self.api_type == "aws":
+            self.client = BedrockClient()
         else:
             self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
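Not part of the patch, but for orientation: with `api_type = "aws"`, `LLM` swaps in `BedrockClient`, whose `chat.completions.create` mirrors the AsyncOpenAI call sites; since the underlying `_invoke_bedrock*` methods are coroutines, callers await the result. A rough usage sketch, assuming AWS credentials and Bedrock model access are configured; the model ID and tool are placeholders:

```python
import asyncio

from app.bedrock import BedrockClient

# Hypothetical OpenAI-style tool; any function schema works here.
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

async def main() -> None:
    client = BedrockClient()  # requires AWS credentials in the environment
    response = await client.chat.completions.create(
        model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
        messages=[{"role": "user", "content": "What's the weather in Tokyo?"}],
        max_tokens=512,
        temperature=0.0,
        stream=False,
        tools=[WEATHER_TOOL],
    )
    message = response.choices[0].message
    print(message.content)
    if message.tool_calls:
        # Tool call arguments come back as a JSON string, as in the OpenAI format.
        print(message.tool_calls[0].function.name, message.tool_calls[0].function.arguments)

if __name__ == "__main__":
    asyncio.run(main())
```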
diff --git a/app/tool/web_search.py b/app/tool/web_search.py
index cb13934..950bb65 100644
--- a/app/tool/web_search.py
+++ b/app/tool/web_search.py
@@ -4,6 +4,7 @@ from typing import List
 from tenacity import retry, stop_after_attempt, wait_exponential

 from app.config import config
+from app.logger import logger
 from app.tool.base import BaseTool
 from app.tool.search import (
     BaiduSearchEngine,
@@ -44,6 +45,8 @@ class WebSearch(BaseTool):
     async def execute(self, query: str, num_results: int = 10) -> List[str]:
         """
         Execute a Web search and return a list of URLs.
+        Tries engines in order based on configuration, falling back if an engine fails with errors.
+        If all engines fail, it will wait and retry up to the configured number of times.

         Args:
             query (str): The search query to submit to the search engine.
@@ -52,37 +55,109 @@ class WebSearch(BaseTool):
         Returns:
             List[str]: A list of URLs matching the search query.
         """
+        # Get retry settings from config
+        retry_delay = 60  # Default to 60 seconds
+        max_retries = 3  # Default to 3 retries
+
+        if config.search_config:
+            retry_delay = getattr(config.search_config, "retry_delay", 60)
+            max_retries = getattr(config.search_config, "max_retries", 3)
+
+        # Try searching with retries when all engines fail
+        for retry_count in range(
+            max_retries + 1
+        ):  # +1 because the first try is not a retry
+            links = await self._try_all_engines(query, num_results)
+            if links:
+                return links
+
+            if retry_count < max_retries:
+                # All engines failed, wait and retry
+                logger.warning(
+                    f"All search engines failed. Waiting {retry_delay} seconds before retry {retry_count + 1}/{max_retries}..."
+                )
+                await asyncio.sleep(retry_delay)
+            else:
+                logger.error(
+                    f"All search engines failed after {max_retries} retries. Giving up."
+                )
+
+        return []
+
+    async def _try_all_engines(self, query: str, num_results: int) -> List[str]:
+        """
+        Try all search engines in the configured order.
+
+        Args:
+            query (str): The search query to submit to the search engine.
+            num_results (int): The number of search results to return.
+
+        Returns:
+            List[str]: A list of URLs matching the search query, or an empty list if all engines fail.
+        """
         engine_order = self._get_engine_order()
+        failed_engines = []
+
         for engine_name in engine_order:
             engine = self._search_engine[engine_name]
             try:
+                logger.info(f"🔎 Attempting search with {engine_name.capitalize()}...")
                 links = await self._perform_search_with_engine(
                     engine, query, num_results
                 )
                 if links:
+                    if failed_engines:
+                        logger.info(
+                            f"Search successful with {engine_name.capitalize()} after trying: {', '.join(failed_engines)}"
+                        )
                     return links
             except Exception as e:
-                print(f"Search engine '{engine_name}' failed with error: {e}")
+                failed_engines.append(engine_name.capitalize())
+                is_rate_limit = "429" in str(e) or "Too Many Requests" in str(e)
+
+                if is_rate_limit:
+                    logger.warning(
+                        f"⚠️ {engine_name.capitalize()} search engine rate limit exceeded, trying next engine..."
+                    )
+                else:
+                    logger.warning(
+                        f"⚠️ {engine_name.capitalize()} search failed with error: {e}"
+                    )
+
+        if failed_engines:
+            logger.error(f"All search engines failed: {', '.join(failed_engines)}")
         return []
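A short usage sketch of the tool itself (illustrative, not part of the patch), assuming a default `WebSearch()` instance and the retry and fallback settings described above:

```python
import asyncio

from app.tool.web_search import WebSearch

async def main() -> None:
    # Tries the preferred engine first, then each configured fallback; if every
    # engine fails, execute() sleeps retry_delay seconds and retries up to
    # max_retries times before returning an empty list.
    links = await WebSearch().execute(query="OpenManus agent framework", num_results=5)
    for url in links:
        print(url)

if __name__ == "__main__":
    asyncio.run(main())
```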
""" preferred = "google" - if config.search_config and config.search_config.engine: - preferred = config.search_config.engine.lower() + fallbacks = [] + + if config.search_config: + if config.search_config.engine: + preferred = config.search_config.engine.lower() + if config.search_config.fallback_engines: + fallbacks = [ + engine.lower() for engine in config.search_config.fallback_engines + ] engine_order = [] + # Add preferred engine first if preferred in self._search_engine: engine_order.append(preferred) - for key in self._search_engine: - if key not in engine_order: - engine_order.append(key) + + # Add configured fallback engines in order + for fallback in fallbacks: + if fallback in self._search_engine and fallback not in engine_order: + engine_order.append(fallback) + return engine_order @retry( diff --git a/config/config.example.toml b/config/config.example.toml index d5750a2..db8045d 100644 --- a/config/config.example.toml +++ b/config/config.example.toml @@ -6,6 +6,14 @@ api_key = "YOUR_API_KEY" # Your API key max_tokens = 8192 # Maximum number of tokens in the response temperature = 0.0 # Controls randomness +# [llm] # Amazon Bedrock +# api_type = "aws" # Required +# model = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" # Bedrock supported modelID +# base_url = "bedrock-runtime.us-west-2.amazonaws.com" # Not used now +# max_tokens = 8192 +# temperature = 1.0 +# api_key = "bear" # Required but not used for Bedrock + # [llm] #AZURE OPENAI: # api_type= 'azure' # model = "YOUR_MODEL_NAME" #"gpt-4o-mini" @@ -65,6 +73,13 @@ temperature = 0.0 # Controls randomness for vision mod # [search] # Search engine for agent to use. Default is "Google", can be set to "Baidu" or "DuckDuckGo". #engine = "Google" +# Fallback engine order. Default is ["DuckDuckGo", "Baidu"] - will try in this order after primary engine fails. +#fallback_engines = ["DuckDuckGo", "Baidu"] +# Seconds to wait before retrying all engines again when they all fail due to rate limits. Default is 60. +#retry_delay = 60 +# Maximum number of times to retry all engines when all fail. Default is 3. +#max_retries = 3 + ## Sandbox configuration #[sandbox] @@ -75,3 +90,4 @@ temperature = 0.0 # Controls randomness for vision mod #cpu_limit = 2.0 #timeout = 300 #network_enabled = true +>>>>>>>>> Temporary merge branch 2 diff --git a/requirements.txt b/requirements.txt index 7e7b82f..9bac8a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,3 +31,5 @@ pytest-asyncio~=0.25.3 mcp~=1.4.1 httpx>=0.27.0 tomli>=2.0.0 + +boto3~=1.37.16