diff --git a/app/llm.py b/app/llm.py
index 5b599e8..3314062 100644
--- a/app/llm.py
+++ b/app/llm.py
@@ -14,6 +14,7 @@
 from app.config import LLMSettings, config
 from app.logger import logger  # Assuming a logger is set up in your app
 from app.schema import Message, TOOL_CHOICE_TYPE, ROLE_VALUES, TOOL_CHOICE_VALUES, ToolChoice
+REASONING_MODELS = ["o1", "o3-mini"]
 
 class LLM:
     _instances: Dict[str, "LLM"] = {}
@@ -133,27 +134,30 @@ class LLM:
         else:
             messages = self.format_messages(messages)
 
+        params = {
+            "model": self.model,
+            "messages": messages,
+        }
+
+        if self.model in REASONING_MODELS:
+            params["max_completion_tokens"] = self.max_tokens
+        else:
+            params["max_tokens"] = self.max_tokens
+            params["temperature"] = temperature or self.temperature
+
         if not stream:
             # Non-streaming request
-            response = await self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_tokens=self.max_tokens,
-                temperature=temperature or self.temperature,
-                stream=False,
-            )
+            params["stream"] = False
+
+            response = await self.client.chat.completions.create(**params)
+
             if not response.choices or not response.choices[0].message.content:
                 raise ValueError("Empty or invalid response from LLM")
             return response.choices[0].message.content
 
         # Streaming request
-        response = await self.client.chat.completions.create(
-            model=self.model,
-            messages=messages,
-            max_tokens=self.max_tokens,
-            temperature=temperature or self.temperature,
-            stream=True,
-        )
+        params["stream"] = True
+        response = await self.client.chat.completions.create(**params)
 
         collected_messages = []
         async for chunk in response:
@@ -230,16 +234,22 @@ class LLM:
                 raise ValueError("Each tool must be a dict with 'type' field")
 
             # Set up the completion request
-            response = await self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                temperature=temperature or self.temperature,
-                max_tokens=self.max_tokens,
-                tools=tools,
-                tool_choice=tool_choice,
-                timeout=timeout,
+            params = {
+                "model": self.model,
+                "messages": messages,
+                "tools": tools,
+                "tool_choice": tool_choice,
+                "timeout": timeout,
                 **kwargs,
-            )
+            }
+
+            if self.model in REASONING_MODELS:
+                params["max_completion_tokens"] = self.max_tokens
+            else:
+                params["max_tokens"] = self.max_tokens
+                params["temperature"] = temperature or self.temperature
+
+            response = await self.client.chat.completions.create(**params)
 
             # Check if response is valid
             if not response.choices or not response.choices[0].message:
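
Note: the core idea of this change is routing token and sampling parameters by model family: OpenAI's reasoning models (o1, o3-mini) take `max_completion_tokens` and do not accept the classic `max_tokens`/`temperature` parameters, while other chat models keep the old ones. Below is a minimal standalone sketch of that routing against the AsyncOpenAI client; `build_params` and `ask` are illustrative helpers for this note, not functions in app/llm.py.

    # Sketch of the parameter routing introduced in this diff.
    # Assumes the official `openai` Python package; helper names are illustrative.
    from openai import AsyncOpenAI

    REASONING_MODELS = ["o1", "o3-mini"]


    def build_params(model: str, messages: list, max_tokens: int, temperature: float) -> dict:
        # Reasoning models use `max_completion_tokens` and ignore/reject
        # `max_tokens` and non-default `temperature`; other chat models
        # keep the classic parameters.
        params = {"model": model, "messages": messages}
        if model in REASONING_MODELS:
            params["max_completion_tokens"] = max_tokens
        else:
            params["max_tokens"] = max_tokens
            params["temperature"] = temperature
        return params


    async def ask(client: AsyncOpenAI, model: str, messages: list) -> str:
        # Non-streaming request mirroring the diff's validation logic.
        params = build_params(model, messages, max_tokens=4096, temperature=0.7)
        response = await client.chat.completions.create(**params, stream=False)
        if not response.choices or not response.choices[0].message.content:
            raise ValueError("Empty or invalid response from LLM")
        return response.choices[0].message.content

Building a single `params` dict and branching once before each `create(**params)` call keeps the streaming, non-streaming, and tool-call paths consistent, so adding a future model family only requires extending `REASONING_MODELS`.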