diff --git a/app/config.py b/app/config.py
index 2275fb4..88a4928 100644
--- a/app/config.py
+++ b/app/config.py
@@ -21,7 +21,8 @@ class LLMSettings(BaseModel):
     api_key: str = Field(..., description="API key")
     max_tokens: int = Field(4096, description="Maximum number of tokens per request")
     temperature: float = Field(1.0, description="Sampling temperature")
-
+    api_type: str = Field(..., description="AzureOpenai or Openai")
+    api_version: str = Field(..., description="Azure Openai version if AzureOpenai")
 
 class AppConfig(BaseModel):
     llm: Dict[str, LLMSettings]
@@ -76,6 +77,9 @@
             "api_key": base_llm.get("api_key"),
             "max_tokens": base_llm.get("max_tokens", 4096),
             "temperature": base_llm.get("temperature", 1.0),
+            "api_type": base_llm.get("api_type", ""),
+            "api_version": base_llm.get("api_version", ""),
+
         }
 
         config_dict = {
diff --git a/app/llm.py b/app/llm.py
index 91a4b8f..af5d74c 100644
--- a/app/llm.py
+++ b/app/llm.py
@@ -6,6 +6,7 @@ from openai import (
     AuthenticationError,
     OpenAIError,
     RateLimitError,
+    AsyncAzureOpenAI
 )
 from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -35,8 +36,19 @@ class LLM:
         self.model = llm_config.model
         self.max_tokens = llm_config.max_tokens
         self.temperature = llm_config.temperature
-        self.client = AsyncOpenAI(
-            api_key=llm_config.api_key, base_url=llm_config.base_url
+        self.api_type = llm_config.api_type
+        self.api_key = llm_config.api_key
+        self.api_version = llm_config.api_version
+        self.base_url = llm_config.base_url
+        if self.api_type == "azure":
+            self.client = AsyncAzureOpenAI(
+                base_url=self.base_url,
+                api_key=self.api_key,
+                api_version=self.api_version
+            )
+        else:
+            self.client = AsyncOpenAI(
+                api_key=self.api_key, base_url=self.base_url
         )
diff --git a/config/config.example.toml b/config/config.example.toml
index de71832..a41e548 100644
--- a/config/config.example.toml
+++ b/config/config.example.toml
@@ -6,6 +6,15 @@ api_key = "sk-..."
 max_tokens = 4096
 temperature = 0.0
 
+#[llm] # Alternative: AZURE OPENAI — uncomment this block and remove the [llm] table above (TOML forbids declaring a table twice)
+#api_type = 'azure'
+#model = "YOUR_MODEL_NAME" #"gpt-4o-mini"
+#base_url = "{YOUR_AZURE_ENDPOINT.rstrip('/')}/openai/deployments/{AZURE_DEPLOYMENT_ID}"
+#api_key = "AZURE API KEY"
+#max_tokens = 8096
+#temperature = 0.0
+#api_version = "AZURE API VERSION" #"2024-08-01-preview"
+
 # Optional configuration for specific LLM models
 [llm.vision]
 model = "claude-3-5-sonnet"