Merge pull request #510 from the0807/feature/o3-mini

Support OpenAI Reasoning Models (o1, o3-mini)
Commit e844dfca34 by mannaandpoem, 2025-03-14 11:47:34 +08:00, committed by GitHub.

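Background for reviewers: OpenAI's reasoning models (o1, o3-mini) reject `max_tokens` and sampling parameters such as `temperature`, and expect `max_completion_tokens` instead. Below is a minimal standalone sketch of the parameter split this PR introduces, assuming the official `openai` Python client; everything except the client calls is illustrative, not code from this repository:

    import asyncio
    from openai import AsyncOpenAI

    REASONING_MODELS = ["o1", "o3-mini"]  # same gate this PR adds in app/llm.py

    async def complete(client: AsyncOpenAI, model: str, prompt: str, max_tokens: int) -> str:
        params = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
        }
        if model in REASONING_MODELS:
            # Reasoning models spend tokens on hidden reasoning as well as
            # visible output, budgeted together via max_completion_tokens;
            # they reject max_tokens and sampling knobs like temperature.
            params["max_completion_tokens"] = max_tokens
        else:
            params["max_tokens"] = max_tokens
            params["temperature"] = 0.7  # arbitrary value, for illustration only
        response = await client.chat.completions.create(**params)
        return response.choices[0].message.content

    # Usage (requires OPENAI_API_KEY in the environment):
    # asyncio.run(complete(AsyncOpenAI(), "o3-mini", "Hello", max_tokens=256))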

@@ -14,6 +14,7 @@ from app.config import LLMSettings, config
 from app.logger import logger  # Assuming a logger is set up in your app
 from app.schema import Message, TOOL_CHOICE_TYPE, ROLE_VALUES, TOOL_CHOICE_VALUES, ToolChoice
 
+REASONING_MODELS = ["o1", "o3-mini"]
 
 class LLM:
     _instances: Dict[str, "LLM"] = {}
@@ -133,27 +134,30 @@ class LLM:
         else:
             messages = self.format_messages(messages)
 
+        params = {
+            "model": self.model,
+            "messages": messages,
+        }
+
+        if self.model in REASONING_MODELS:
+            params["max_completion_tokens"] = self.max_tokens
+        else:
+            params["max_tokens"] = self.max_tokens
+            params["temperature"] = temperature or self.temperature
+
         if not stream:
             # Non-streaming request
-            response = await self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_tokens=self.max_tokens,
-                temperature=temperature or self.temperature,
-                stream=False,
-            )
+            params["stream"] = False
+            response = await self.client.chat.completions.create(**params)
 
             if not response.choices or not response.choices[0].message.content:
                 raise ValueError("Empty or invalid response from LLM")
 
             return response.choices[0].message.content
 
         # Streaming request
-        response = await self.client.chat.completions.create(
-            model=self.model,
-            messages=messages,
-            max_tokens=self.max_tokens,
-            temperature=temperature or self.temperature,
-            stream=True,
-        )
+        params["stream"] = True
+        response = await self.client.chat.completions.create(**params)
 
         collected_messages = []
         async for chunk in response:
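Design note on the hunk above: `params` is built once and only the `stream` flag differs between the two branches, so the reasoning-model special case lives in a single place per method. For context, the streaming loop whose first lines are visible above assembles chunks roughly like this; a hedged sketch assuming the openai client's chunk shape, since the rest of the loop is outside this hunk:

    collected_messages = []
    async for chunk in response:
        # Each streamed chunk carries an incremental delta; `content` is None
        # on role-only and final chunks, hence the `or ""` guard.
        chunk_message = chunk.choices[0].delta.content or ""
        collected_messages.append(chunk_message)
    full_response = "".join(collected_messages).strip()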
@@ -230,16 +234,22 @@ class LLM:
                 raise ValueError("Each tool must be a dict with 'type' field")
 
         # Set up the completion request
-        response = await self.client.chat.completions.create(
-            model=self.model,
-            messages=messages,
-            temperature=temperature or self.temperature,
-            max_tokens=self.max_tokens,
-            tools=tools,
-            tool_choice=tool_choice,
-            timeout=timeout,
-            **kwargs,
-        )
+        params = {
+            "model": self.model,
+            "messages": messages,
+            "tools": tools,
+            "tool_choice": tool_choice,
+            "timeout": timeout,
+            **kwargs,
+        }
+
+        if self.model in REASONING_MODELS:
+            params["max_completion_tokens"] = self.max_tokens
+        else:
+            params["max_tokens"] = self.max_tokens
+            params["temperature"] = temperature or self.temperature
+
+        response = await self.client.chat.completions.create(**params)
 
         # Check if response is valid
         if not response.choices or not response.choices[0].message:
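The tool-calling path gets the same treatment: everything model-agnostic goes into `params` up front, then the token/temperature split is applied before the single `create(**params)` call. A hypothetical usage sketch, not part of this PR; the enclosing method's name and the constructor call are assumed from context, and the tool definition is purely illustrative:

    async def demo() -> None:
        llm = LLM()  # the class edited above; default construction assumed
        weather_tool = {
            "type": "function",  # passes the "dict with 'type' field" check above
            "function": {
                "name": "get_weather",  # hypothetical tool, for illustration only
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
        # With model "o3-mini" this request goes out with max_completion_tokens
        # and no temperature; a non-reasoning model gets max_tokens + temperature.
        message = await llm.ask_tool(  # method name assumed from the hunk's context
            messages=[{"role": "user", "content": "What's the weather in Paris?"}],
            tools=[weather_tool],
            tool_choice="auto",
        )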