Support OpenAI Reasoning Models (o1, o3-mini)

Author: the0807
Date:   2025-03-12 14:33:32 +09:00
Parent: af8023de43
Commit: 983e8f0d4b
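
OpenAI's reasoning models (the o1 family and o3-mini) reject the legacy max_tokens parameter in favor of max_completion_tokens, and they only run at the default sampling temperature. This commit therefore assembles the chat-completion request as a params dict and branches on the model name before calling the API. A minimal standalone sketch of the pattern, assuming the official openai Python client (the helper name is illustrative, not part of this commit):

    from openai import AsyncOpenAI

    REASONING_MODELS = ["o1", "o3-mini"]

    async def create_completion(client: AsyncOpenAI, model: str, messages: list,
                                max_tokens: int, temperature: float):
        # Illustrative helper: reasoning models take max_completion_tokens and
        # pin temperature to its default; other chat models keep max_tokens.
        params = {"model": model, "messages": messages}
        if model in REASONING_MODELS:
            params["max_completion_tokens"] = max_tokens
        else:
            params["max_tokens"] = max_tokens
            params["temperature"] = temperature
        return await client.chat.completions.create(**params)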


@@ -14,6 +14,7 @@ from app.config import LLMSettings, config
 from app.logger import logger  # Assuming a logger is set up in your app
 from app.schema import Message
 
+REASONING_MODELS = ["o1", "o3-mini"]
 
 class LLM:
     _instances: Dict[str, "LLM"] = {}
@@ -133,27 +134,30 @@ class LLM:
         else:
             messages = self.format_messages(messages)
 
+        params = {
+            "model": self.model,
+            "messages": messages,
+        }
+
+        if self.model in REASONING_MODELS:
+            params["max_completion_tokens"] = self.max_tokens
+        else:
+            params["max_tokens"] = self.max_tokens
+            params["temperature"] = temperature or self.temperature
+
         if not stream:
             # Non-streaming request
-            response = await self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_tokens=self.max_tokens,
-                temperature=temperature or self.temperature,
-                stream=False,
-            )
+            params["stream"] = False
+            response = await self.client.chat.completions.create(**params)
+
             if not response.choices or not response.choices[0].message.content:
                 raise ValueError("Empty or invalid response from LLM")
             return response.choices[0].message.content
 
         # Streaming request
-        response = await self.client.chat.completions.create(
-            model=self.model,
-            messages=messages,
-            max_tokens=self.max_tokens,
-            temperature=temperature or self.temperature,
-            stream=True,
-        )
+        params["stream"] = True
+        response = await self.client.chat.completions.create(**params)
 
         collected_messages = []
         async for chunk in response:
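
The streaming branch reuses the same params dict with stream switched on. A hedged sketch of collecting the streamed deltas with the openai v1 async client, mirroring the collected_messages loop above (variable names illustrative):

    collected_messages = []
    async for chunk in response:
        # Each chunk carries an incremental delta; content can be None on
        # role or finish chunks, so fall back to an empty string.
        delta = chunk.choices[0].delta.content or ""
        collected_messages.append(delta)
    full_reply = "".join(collected_messages).strip()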
@@ -230,16 +234,22 @@ class LLM:
                 raise ValueError("Each tool must be a dict with 'type' field")
 
             # Set up the completion request
-            response = await self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                temperature=temperature or self.temperature,
-                max_tokens=self.max_tokens,
-                tools=tools,
-                tool_choice=tool_choice,
-                timeout=timeout,
+            params = {
+                "model": self.model,
+                "messages": messages,
+                "tools": tools,
+                "tool_choice": tool_choice,
+                "timeout": timeout,
                 **kwargs,
-            )
+            }
+
+            if self.model in REASONING_MODELS:
+                params["max_completion_tokens"] = self.max_tokens
+            else:
+                params["max_tokens"] = self.max_tokens
+                params["temperature"] = temperature or self.temperature
+
+            response = await self.client.chat.completions.create(**params)
 
             # Check if response is valid
             if not response.choices or not response.choices[0].message:
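
The tool-call path applies the same max_completion_tokens/temperature branching before the request is sent. For reference, the tools argument validated above follows OpenAI's standard function-tool schema; this example definition is illustrative only:

    tools = [
        {
            "type": "function",  # the 'type' field the validation above checks for
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ]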