Compare commits
24 commits: first_hack...main

SHA1
c432ec9286
65470c2ae9
30342247c0
a61ef9b737
82e3140357
d0492a500e
35209978e1
3dd990e554
e218c0655f
5d18b5dc69
567bffb441
acb435f9f5
d63e88f089
c3de3ad6f7
08a20f6880
2a13cb49f3
4e10b42b30
59a92257be
855caad4d9
95e3487402
fe44fe726d
c7858c2eb4
9fa12e594c
711c2805e4
app/bedrock.py (new file, 334 lines)
@@ -0,0 +1,334 @@

import json
import sys
import time
import uuid
from datetime import datetime
from typing import Dict, List, Literal, Optional

import boto3


# Global variables to track the current tool use ID across function calls
# Tmp solution
CURRENT_TOOLUSE_ID = None


# Class to handle OpenAI-style response formatting
class OpenAIResponse:
    def __init__(self, data):
        # Recursively convert nested dicts and lists to OpenAIResponse objects
        for key, value in data.items():
            if isinstance(value, dict):
                value = OpenAIResponse(value)
            elif isinstance(value, list):
                value = [
                    OpenAIResponse(item) if isinstance(item, dict) else item
                    for item in value
                ]
            setattr(self, key, value)

    def model_dump(self, *args, **kwargs):
        # Convert object to dict and add timestamp
        data = self.__dict__
        data["created_at"] = datetime.now().isoformat()
        return data


# Main client class for interacting with Amazon Bedrock
class BedrockClient:
    def __init__(self):
        # Initialize Bedrock client, you need to configure AWS env first
        try:
            self.client = boto3.client("bedrock-runtime")
            self.chat = Chat(self.client)
        except Exception as e:
            print(f"Error initializing Bedrock client: {e}")
            sys.exit(1)


# Chat interface class
class Chat:
    def __init__(self, client):
        self.completions = ChatCompletions(client)


# Core class handling chat completions functionality
class ChatCompletions:
    def __init__(self, client):
        self.client = client

    def _convert_openai_tools_to_bedrock_format(self, tools):
        # Convert OpenAI function calling format to Bedrock tool format
        bedrock_tools = []
        for tool in tools:
            if tool.get("type") == "function":
                function = tool.get("function", {})
                bedrock_tool = {
                    "toolSpec": {
                        "name": function.get("name", ""),
                        "description": function.get("description", ""),
                        "inputSchema": {
                            "json": {
                                "type": "object",
                                "properties": function.get("parameters", {}).get(
                                    "properties", {}
                                ),
                                "required": function.get("parameters", {}).get(
                                    "required", []
                                ),
                            }
                        },
                    }
                }
                bedrock_tools.append(bedrock_tool)
        return bedrock_tools

    def _convert_openai_messages_to_bedrock_format(self, messages):
        # Convert OpenAI message format to Bedrock message format
        bedrock_messages = []
        system_prompt = []
        for message in messages:
            if message.get("role") == "system":
                system_prompt = [{"text": message.get("content")}]
            elif message.get("role") == "user":
                bedrock_message = {
                    "role": message.get("role", "user"),
                    "content": [{"text": message.get("content")}],
                }
                bedrock_messages.append(bedrock_message)
            elif message.get("role") == "assistant":
                bedrock_message = {
                    "role": "assistant",
                    "content": [{"text": message.get("content")}],
                }
                openai_tool_calls = message.get("tool_calls", [])
                if openai_tool_calls:
                    bedrock_tool_use = {
                        "toolUseId": openai_tool_calls[0]["id"],
                        "name": openai_tool_calls[0]["function"]["name"],
                        "input": json.loads(
                            openai_tool_calls[0]["function"]["arguments"]
                        ),
                    }
                    bedrock_message["content"].append({"toolUse": bedrock_tool_use})
                    global CURRENT_TOOLUSE_ID
                    CURRENT_TOOLUSE_ID = openai_tool_calls[0]["id"]
                bedrock_messages.append(bedrock_message)
            elif message.get("role") == "tool":
                bedrock_message = {
                    "role": "user",
                    "content": [
                        {
                            "toolResult": {
                                "toolUseId": CURRENT_TOOLUSE_ID,
                                "content": [{"text": message.get("content")}],
                            }
                        }
                    ],
                }
                bedrock_messages.append(bedrock_message)
            else:
                raise ValueError(f"Invalid role: {message.get('role')}")
        return system_prompt, bedrock_messages

    def _convert_bedrock_response_to_openai_format(self, bedrock_response):
        # Convert Bedrock response format to OpenAI format
        content = ""
        if bedrock_response.get("output", {}).get("message", {}).get("content"):
            content_array = bedrock_response["output"]["message"]["content"]
            content = "".join(item.get("text", "") for item in content_array)
        if content == "":
            content = "."

        # Handle tool calls in response
        openai_tool_calls = []
        if bedrock_response.get("output", {}).get("message", {}).get("content"):
            for content_item in bedrock_response["output"]["message"]["content"]:
                if content_item.get("toolUse"):
                    bedrock_tool_use = content_item["toolUse"]
                    global CURRENT_TOOLUSE_ID
                    CURRENT_TOOLUSE_ID = bedrock_tool_use["toolUseId"]
                    openai_tool_call = {
                        "id": CURRENT_TOOLUSE_ID,
                        "type": "function",
                        "function": {
                            "name": bedrock_tool_use["name"],
                            "arguments": json.dumps(bedrock_tool_use["input"]),
                        },
                    }
                    openai_tool_calls.append(openai_tool_call)

        # Construct final OpenAI format response
        openai_format = {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "created": int(time.time()),
            "object": "chat.completion",
            "system_fingerprint": None,
            "choices": [
                {
                    "finish_reason": bedrock_response.get("stopReason", "end_turn"),
                    "index": 0,
                    "message": {
                        "content": content,
                        "role": bedrock_response.get("output", {})
                        .get("message", {})
                        .get("role", "assistant"),
                        "tool_calls": openai_tool_calls
                        if openai_tool_calls != []
                        else None,
                        "function_call": None,
                    },
                }
            ],
            "usage": {
                "completion_tokens": bedrock_response.get("usage", {}).get(
                    "outputTokens", 0
                ),
                "prompt_tokens": bedrock_response.get("usage", {}).get(
                    "inputTokens", 0
                ),
                "total_tokens": bedrock_response.get("usage", {}).get("totalTokens", 0),
            },
        }
        return OpenAIResponse(openai_format)

    async def _invoke_bedrock(
        self,
        model: str,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        **kwargs,
    ) -> OpenAIResponse:
        # Non-streaming invocation of Bedrock model
        (
            system_prompt,
            bedrock_messages,
        ) = self._convert_openai_messages_to_bedrock_format(messages)
        response = self.client.converse(
            modelId=model,
            system=system_prompt,
            messages=bedrock_messages,
            inferenceConfig={"temperature": temperature, "maxTokens": max_tokens},
            toolConfig={"tools": tools} if tools else None,
        )
        openai_response = self._convert_bedrock_response_to_openai_format(response)
        return openai_response

    async def _invoke_bedrock_stream(
        self,
        model: str,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        **kwargs,
    ) -> OpenAIResponse:
        # Streaming invocation of Bedrock model
        (
            system_prompt,
            bedrock_messages,
        ) = self._convert_openai_messages_to_bedrock_format(messages)
        response = self.client.converse_stream(
            modelId=model,
            system=system_prompt,
            messages=bedrock_messages,
            inferenceConfig={"temperature": temperature, "maxTokens": max_tokens},
            toolConfig={"tools": tools} if tools else None,
        )

        # Initialize response structure
        bedrock_response = {
            "output": {"message": {"role": "", "content": []}},
            "stopReason": "",
            "usage": {},
            "metrics": {},
        }
        bedrock_response_text = ""
        bedrock_response_tool_input = ""

        # Process streaming response
        stream = response.get("stream")
        if stream:
            for event in stream:
                if event.get("messageStart", {}).get("role"):
                    bedrock_response["output"]["message"]["role"] = event[
                        "messageStart"
                    ]["role"]
                if event.get("contentBlockDelta", {}).get("delta", {}).get("text"):
                    bedrock_response_text += event["contentBlockDelta"]["delta"]["text"]
                    print(
                        event["contentBlockDelta"]["delta"]["text"], end="", flush=True
                    )
                if event.get("contentBlockStop", {}).get("contentBlockIndex") == 0:
                    bedrock_response["output"]["message"]["content"].append(
                        {"text": bedrock_response_text}
                    )
                if event.get("contentBlockStart", {}).get("start", {}).get("toolUse"):
                    bedrock_tool_use = event["contentBlockStart"]["start"]["toolUse"]
                    tool_use = {
                        "toolUseId": bedrock_tool_use["toolUseId"],
                        "name": bedrock_tool_use["name"],
                    }
                    bedrock_response["output"]["message"]["content"].append(
                        {"toolUse": tool_use}
                    )
                    global CURRENT_TOOLUSE_ID
                    CURRENT_TOOLUSE_ID = bedrock_tool_use["toolUseId"]
                if event.get("contentBlockDelta", {}).get("delta", {}).get("toolUse"):
                    bedrock_response_tool_input += event["contentBlockDelta"]["delta"][
                        "toolUse"
                    ]["input"]
                    print(
                        event["contentBlockDelta"]["delta"]["toolUse"]["input"],
                        end="",
                        flush=True,
                    )
                if event.get("contentBlockStop", {}).get("contentBlockIndex") == 1:
                    bedrock_response["output"]["message"]["content"][1]["toolUse"][
                        "input"
                    ] = json.loads(bedrock_response_tool_input)
        print()
        openai_response = self._convert_bedrock_response_to_openai_format(
            bedrock_response
        )
        return openai_response

    def create(
        self,
        model: str,
        messages: List[Dict[str, str]],
        max_tokens: int,
        temperature: float,
        stream: Optional[bool] = True,
        tools: Optional[List[dict]] = None,
        tool_choice: Literal["none", "auto", "required"] = "auto",
        **kwargs,
    ) -> OpenAIResponse:
        # Main entry point for chat completion
        bedrock_tools = []
        if tools is not None:
            bedrock_tools = self._convert_openai_tools_to_bedrock_format(tools)
        if stream:
            return self._invoke_bedrock_stream(
                model,
                messages,
                max_tokens,
                temperature,
                bedrock_tools,
                tool_choice,
                **kwargs,
            )
        else:
            return self._invoke_bedrock(
                model,
                messages,
                max_tokens,
                temperature,
                bedrock_tools,
                tool_choice,
                **kwargs,
            )
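For orientation, a minimal usage sketch of the new client (illustrative only, not part of the diff). It assumes AWS credentials and a region are already configured so boto3 can reach bedrock-runtime, and it borrows the model ID shown later in config.example.toml. Note that create() itself is synchronous but returns an awaitable:

import asyncio

from app.bedrock import BedrockClient


async def demo():
    client = BedrockClient()
    # Same call shape as the OpenAI SDK: client.chat.completions.create(...)
    response = await client.chat.completions.create(
        model="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
        messages=[{"role": "user", "content": "Hello!"}],
        max_tokens=512,
        temperature=0.0,
        stream=False,  # True streams deltas to stdout instead
    )
    print(response.choices[0].message.content)


asyncio.run(demo())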
app/config.py
@@ -37,6 +37,18 @@ class ProxySettings(BaseModel):

 class SearchSettings(BaseModel):
     engine: str = Field(default="Google", description="Search engine the llm to use")
+    fallback_engines: List[str] = Field(
+        default_factory=lambda: ["DuckDuckGo", "Baidu"],
+        description="Fallback search engines to try if the primary engine fails",
+    )
+    retry_delay: int = Field(
+        default=60,
+        description="Seconds to wait before retrying all engines again after they all fail",
+    )
+    max_retries: int = Field(
+        default=3,
+        description="Maximum number of times to retry all engines when all fail",
+    )


 class BrowserSettings(BaseModel):
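These fields map one-to-one onto the [search] keys added to config.example.toml later in this compare; a minimal block exercising them (values mirror the defaults above) would look like:

[search]
engine = "Google"
fallback_engines = ["DuckDuckGo", "Baidu"]
retry_delay = 60
max_retries = 3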
app/flow/base.py
@@ -1,5 +1,4 @@
 from abc import ABC, abstractmethod
-from enum import Enum
 from typing import Dict, List, Optional, Union

 from pydantic import BaseModel
@@ -7,10 +6,6 @@ from pydantic import BaseModel
 from app.agent.base import BaseAgent


-class FlowType(str, Enum):
-    PLANNING = "planning"
-
-
 class BaseFlow(BaseModel, ABC):
     """Base class for execution flows supporting multiple agents"""

@@ -60,32 +55,3 @@ class BaseFlow(BaseModel, ABC):
     @abstractmethod
     async def execute(self, input_text: str) -> str:
         """Execute the flow with given input"""
-
-
-class PlanStepStatus(str, Enum):
-    """Enum class defining possible statuses of a plan step"""
-
-    NOT_STARTED = "not_started"
-    IN_PROGRESS = "in_progress"
-    COMPLETED = "completed"
-    BLOCKED = "blocked"
-
-    @classmethod
-    def get_all_statuses(cls) -> list[str]:
-        """Return a list of all possible step status values"""
-        return [status.value for status in cls]
-
-    @classmethod
-    def get_active_statuses(cls) -> list[str]:
-        """Return a list of values representing active statuses (not started or in progress)"""
-        return [cls.NOT_STARTED.value, cls.IN_PROGRESS.value]
-
-    @classmethod
-    def get_status_marks(cls) -> Dict[str, str]:
-        """Return a mapping of statuses to their marker symbols"""
-        return {
-            cls.COMPLETED.value: "[✓]",
-            cls.IN_PROGRESS.value: "[→]",
-            cls.BLOCKED.value: "[!]",
-            cls.NOT_STARTED.value: "[ ]",
-        }
app/flow/flow_factory.py
@@ -1,10 +1,15 @@
+from enum import Enum
 from typing import Dict, List, Union

 from app.agent.base import BaseAgent
-from app.flow.base import BaseFlow, FlowType
+from app.flow.base import BaseFlow
 from app.flow.planning import PlanningFlow


+class FlowType(str, Enum):
+    PLANNING = "planning"
+
+
 class FlowFactory:
     """Factory for creating different types of flows with support for multiple agents"""

app/flow/planning.py
@@ -1,17 +1,47 @@
 import json
 import time
+from enum import Enum
 from typing import Dict, List, Optional, Union

 from pydantic import Field

 from app.agent.base import BaseAgent
-from app.flow.base import BaseFlow, PlanStepStatus
+from app.flow.base import BaseFlow
 from app.llm import LLM
 from app.logger import logger
 from app.schema import AgentState, Message, ToolChoice
 from app.tool import PlanningTool


+class PlanStepStatus(str, Enum):
+    """Enum class defining possible statuses of a plan step"""
+
+    NOT_STARTED = "not_started"
+    IN_PROGRESS = "in_progress"
+    COMPLETED = "completed"
+    BLOCKED = "blocked"
+
+    @classmethod
+    def get_all_statuses(cls) -> list[str]:
+        """Return a list of all possible step status values"""
+        return [status.value for status in cls]
+
+    @classmethod
+    def get_active_statuses(cls) -> list[str]:
+        """Return a list of values representing active statuses (not started or in progress)"""
+        return [cls.NOT_STARTED.value, cls.IN_PROGRESS.value]
+
+    @classmethod
+    def get_status_marks(cls) -> Dict[str, str]:
+        """Return a mapping of statuses to their marker symbols"""
+        return {
+            cls.COMPLETED.value: "[✓]",
+            cls.IN_PROGRESS.value: "[→]",
+            cls.BLOCKED.value: "[!]",
+            cls.NOT_STARTED.value: "[ ]",
+        }
+
+
 class PlanningFlow(BaseFlow):
     """A flow that manages planning and execution of tasks using agents."""

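The enum's helpers are what PlanningFlow uses to render step lists; a tiny illustration (hypothetical, not part of the diff):

from app.flow.planning import PlanStepStatus

marks = PlanStepStatus.get_status_marks()
for step, status in [("research", "completed"), ("draft", "in_progress")]:
    print(f"{marks[status]} {step}")
# [✓] research
# [→] draft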
app/llm.py
@@ -18,6 +18,7 @@ from tenacity import (
     wait_random_exponential,
 )

+from app.bedrock import BedrockClient
 from app.config import LLMSettings, config
 from app.exceptions import TokenLimitExceeded
 from app.logger import logger  # Assuming a logger is set up in your app
@@ -225,6 +226,8 @@ class LLM:
                 api_key=self.api_key,
                 api_version=self.api_version,
             )
+        elif self.api_type == "aws":
+            self.client = BedrockClient()
         else:
             self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)

app/mcp/server.py
@@ -1,30 +1,19 @@
+import logging
+import sys
+
+
+logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stderr)])
+
 import argparse
 import asyncio
 import atexit
 import json
-import logging
-import os
-import sys
 from inspect import Parameter, Signature
 from typing import Any, Dict, Optional

 from mcp.server.fastmcp import FastMCP

-# Add directories to Python path (needed for proper importing)
-current_dir = os.path.dirname(os.path.abspath(__file__))
-parent_dir = os.path.dirname(current_dir)
-root_dir = os.path.dirname(parent_dir)
-sys.path.insert(0, parent_dir)
-sys.path.insert(0, current_dir)
-sys.path.insert(0, root_dir)
-
-# Configure logging (using the same format as original)
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-)
-logger = logging.getLogger("mcp-server")
-
+from app.logger import logger
 from app.tool.base import BaseTool
 from app.tool.bash import Bash
 from app.tool.browser_use_tool import BrowserUseTool
@@ -45,11 +34,6 @@ class MCPServer:
         self.tools["editor"] = StrReplaceEditor()
         self.tools["terminate"] = Terminate()

-        from app.logger import logger as app_logger
-
-        global logger
-        logger = app_logger
-
     def register_tool(self, tool: BaseTool, method_name: Optional[str] = None) -> None:
         """Register a tool with parameter validation and documentation."""
         tool_name = method_name or tool.name
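One reason for routing logs to stderr before the other imports run: with the stdio transport, stdout carries the MCP JSON-RPC stream, so any stray log line on stdout would corrupt the protocol. A minimal sketch of the pattern this diff adopts:

import logging
import sys

# Keep stdout clean for the JSON-RPC stream; send all logging to stderr.
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stderr)])
logging.getLogger(__name__).info("safe: goes to stderr, not the protocol stream")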
app/tool/web_search.py
@@ -4,6 +4,7 @@ from typing import List
 from tenacity import retry, stop_after_attempt, wait_exponential

 from app.config import config
+from app.logger import logger
 from app.tool.base import BaseTool
 from app.tool.search import (
     BaiduSearchEngine,
@@ -44,6 +45,8 @@ class WebSearch(BaseTool):
     async def execute(self, query: str, num_results: int = 10) -> List[str]:
         """
         Execute a Web search and return a list of URLs.
+        Tries engines in order based on configuration, falling back if an engine fails with errors.
+        If all engines fail, it will wait and retry up to the configured number of times.

         Args:
             query (str): The search query to submit to the search engine.
@@ -52,37 +55,109 @@ class WebSearch(BaseTool):
         Returns:
             List[str]: A list of URLs matching the search query.
         """
+        # Get retry settings from config
+        retry_delay = 60  # Default to 60 seconds
+        max_retries = 3  # Default to 3 retries
+
+        if config.search_config:
+            retry_delay = getattr(config.search_config, "retry_delay", 60)
+            max_retries = getattr(config.search_config, "max_retries", 3)
+
+        # Try searching with retries when all engines fail
+        for retry_count in range(
+            max_retries + 1
+        ):  # +1 because first try is not a retry
+            links = await self._try_all_engines(query, num_results)
+            if links:
+                return links
+
+            if retry_count < max_retries:
+                # All engines failed, wait and retry
+                logger.warning(
+                    f"All search engines failed. Waiting {retry_delay} seconds before retry {retry_count + 1}/{max_retries}..."
+                )
+                await asyncio.sleep(retry_delay)
+            else:
+                logger.error(
+                    f"All search engines failed after {max_retries} retries. Giving up."
+                )
+
+        return []
+
+    async def _try_all_engines(self, query: str, num_results: int) -> List[str]:
+        """
+        Try all search engines in the configured order.
+
+        Args:
+            query (str): The search query to submit to the search engine.
+            num_results (int): The number of search results to return.
+
+        Returns:
+            List[str]: A list of URLs matching the search query, or empty list if all engines fail.
+        """
         engine_order = self._get_engine_order()
+        failed_engines = []
+
         for engine_name in engine_order:
             engine = self._search_engine[engine_name]
             try:
+                logger.info(f"🔎 Attempting search with {engine_name.capitalize()}...")
                 links = await self._perform_search_with_engine(
                     engine, query, num_results
                 )
                 if links:
+                    if failed_engines:
+                        logger.info(
+                            f"Search successful with {engine_name.capitalize()} after trying: {', '.join(failed_engines)}"
+                        )
                     return links
             except Exception as e:
-                print(f"Search engine '{engine_name}' failed with error: {e}")
+                failed_engines.append(engine_name.capitalize())
+                is_rate_limit = "429" in str(e) or "Too Many Requests" in str(e)
+
+                if is_rate_limit:
+                    logger.warning(
+                        f"⚠️ {engine_name.capitalize()} search engine rate limit exceeded, trying next engine..."
+                    )
+                else:
+                    logger.warning(
+                        f"⚠️ {engine_name.capitalize()} search failed with error: {e}"
+                    )
+
+        if failed_engines:
+            logger.error(f"All search engines failed: {', '.join(failed_engines)}")
         return []

     def _get_engine_order(self) -> List[str]:
         """
         Determines the order in which to try search engines.
-        Preferred engine is first (based on configuration), followed by the remaining engines.
+        Preferred engine is first (based on configuration), followed by fallback engines,
+        and then the remaining engines.

         Returns:
             List[str]: Ordered list of search engine names.
         """
         preferred = "google"
-        if config.search_config and config.search_config.engine:
-            preferred = config.search_config.engine.lower()
+        fallbacks = []
+
+        if config.search_config:
+            if config.search_config.engine:
+                preferred = config.search_config.engine.lower()
+            if config.search_config.fallback_engines:
+                fallbacks = [
+                    engine.lower() for engine in config.search_config.fallback_engines
+                ]
+
         engine_order = []
+        # Add preferred engine first
         if preferred in self._search_engine:
             engine_order.append(preferred)
-        for key in self._search_engine:
-            if key not in engine_order:
-                engine_order.append(key)
+
+        # Add configured fallback engines in order
+        for fallback in fallbacks:
+            if fallback in self._search_engine and fallback not in engine_order:
+                engine_order.append(fallback)
+
         return engine_order

     @retry(
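To make the new ordering concrete, here is a self-contained sketch of the `_get_engine_order` logic (illustrative; the lowercase engine keys are an assumption based on the tool's registry). Note that engines named neither as preferred nor as fallbacks are no longer appended, which the old loop over `self._search_engine` used to do:

def engine_order(preferred: str, fallbacks: list[str], available: list[str]) -> list[str]:
    # Preferred engine first, then configured fallbacks, skipping duplicates.
    order = []
    if preferred in available:
        order.append(preferred)
    for fallback in fallbacks:
        if fallback in available and fallback not in order:
            order.append(fallback)
    return order


print(engine_order("google", ["duckduckgo", "baidu"], ["google", "baidu", "duckduckgo"]))
# ['google', 'duckduckgo', 'baidu']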
config/config.example.toml
@@ -6,6 +6,14 @@ api_key = "YOUR_API_KEY" # Your API key
 max_tokens = 8192 # Maximum number of tokens in the response
 temperature = 0.0 # Controls randomness

+# [llm] # Amazon Bedrock
+# api_type = "aws" # Required
+# model = "us.anthropic.claude-3-7-sonnet-20250219-v1:0" # Bedrock supported modelID
+# base_url = "bedrock-runtime.us-west-2.amazonaws.com" # Not used now
+# max_tokens = 8192
+# temperature = 1.0
+# api_key = "bear" # Required but not used for Bedrock
+
 # [llm] #AZURE OPENAI:
 # api_type= 'azure'
 # model = "YOUR_MODEL_NAME" #"gpt-4o-mini"
@@ -65,6 +73,13 @@ temperature = 0.0 # Controls randomness for vision mod
 # [search]
 # Search engine for agent to use. Default is "Google", can be set to "Baidu" or "DuckDuckGo".
 #engine = "Google"
+# Fallback engine order. Default is ["DuckDuckGo", "Baidu"] - will try in this order after primary engine fails.
+#fallback_engines = ["DuckDuckGo", "Baidu"]
+# Seconds to wait before retrying all engines again when they all fail due to rate limits. Default is 60.
+#retry_delay = 60
+# Maximum number of times to retry all engines when all fail. Default is 3.
+#max_retries = 3
+

 ## Sandbox configuration
 #[sandbox]
requirements.txt
@@ -31,3 +31,5 @@ pytest-asyncio~=0.25.3
 mcp~=1.4.1
 httpx>=0.27.0
 tomli>=2.0.0
+
+boto3~=1.37.16
run_mcp.py (21 changed lines)
@@ -13,10 +13,14 @@ class MCPRunner:

     def __init__(self):
         self.root_path = config.root_path
-        self.server_script = self.root_path / "app" / "mcp" / "server.py"
+        self.server_reference = "app.mcp.server"
         self.agent = MCPAgent()

-    async def initialize(self, connection_type: str, server_url: str = None) -> None:
+    async def initialize(
+        self,
+        connection_type: str,
+        server_url: str | None = None,
+    ) -> None:
         """Initialize the MCP agent with the appropriate connection."""
         logger.info(f"Initializing MCPAgent with {connection_type} connection...")

@@ -24,7 +28,7 @@ class MCPRunner:
             await self.agent.initialize(
                 connection_type="stdio",
                 command=sys.executable,
-                args=[str(self.server_script)],
+                args=["-m", self.server_reference],
             )
         else:  # sse
             await self.agent.initialize(connection_type="sse", server_url=server_url)
@@ -47,9 +51,14 @@ class MCPRunner:

     async def run_default(self) -> None:
         """Run the agent in default mode."""
-        await self.agent.run(
-            "Hello, what tools are available to me? Terminate after you have listed the tools."
-        )
+        prompt = input("Enter your prompt: ")
+        if not prompt.strip():
+            logger.warning("Empty prompt provided.")
+            return
+
+        logger.warning("Processing your request...")
+        await self.agent.run(prompt)
+        logger.info("Request processing completed.")

     async def cleanup(self) -> None:
         """Clean up agent resources."""
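The switch from a script path to args=["-m", self.server_reference] means the server child process is now spawned as a module rather than a file. A sketch of the resulting command (illustrative):

import sys

server_reference = "app.mcp.server"
cmd = [sys.executable, "-m", server_reference]
# e.g. ['/usr/bin/python3', '-m', 'app.mcp.server']; run from the project root,
# -m puts the working directory on sys.path, so app.* imports resolve without
# the sys.path.insert() hacks removed from app/mcp/server.py above.
print(cmd)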
run_mcp_server.py (new file, 11 lines)
@@ -0,0 +1,11 @@

# coding: utf-8
# A shortcut to launch OpenManus MCP server, where its introduction also solves other import issues.
from app.mcp.server import MCPServer, parse_args


if __name__ == "__main__":
    args = parse_args()

    # Create and run server (maintaining original flow)
    server = MCPServer()
    server.run(transport=args.transport)