Merge pull request #510 from the0807/feature/o3-mini

Support OpenAI Reasoning Models (o1, o3-mini)

This commit is contained in:
commit e844dfca34
56  app/llm.py
@@ -14,6 +14,7 @@ from app.config import LLMSettings, config
 from app.logger import logger  # Assuming a logger is set up in your app
 from app.schema import Message, TOOL_CHOICE_TYPE, ROLE_VALUES, TOOL_CHOICE_VALUES, ToolChoice
 
+REASONING_MODELS = ["o1", "o3-mini"]
 
 class LLM:
     _instances: Dict[str, "LLM"] = {}
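
Note: the REASONING_MODELS list added above is what gates the request parameters in the hunks below. OpenAI's reasoning models (o1, o3-mini) expect max_completion_tokens instead of max_tokens and, at the time of this PR, accept only the default temperature, so the change routes those settings by model name. A minimal, self-contained sketch of that selection logic (the build_request_params helper is illustrative only, not part of this diff):

    from typing import Any, Dict, List, Optional

    REASONING_MODELS = ["o1", "o3-mini"]


    def build_request_params(
        model: str,
        messages: List[dict],
        max_tokens: int,
        temperature: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Choose the token-limit parameter based on the model family."""
        params: Dict[str, Any] = {"model": model, "messages": messages}
        if model in REASONING_MODELS:
            # Reasoning models: the token budget goes in max_completion_tokens,
            # and no explicit temperature is sent.
            params["max_completion_tokens"] = max_tokens
        else:
            params["max_tokens"] = max_tokens
            params["temperature"] = temperature if temperature is not None else 1.0
        return params


    # The same call site works for both model families.
    print(build_request_params("o3-mini", [{"role": "user", "content": "hi"}], 300))
    print(build_request_params("gpt-4o", [{"role": "user", "content": "hi"}], 300, 0.2))
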
@@ -133,27 +134,30 @@ class LLM:
             else:
                 messages = self.format_messages(messages)
 
+            params = {
+                "model": self.model,
+                "messages": messages,
+            }
+
+            if self.model in REASONING_MODELS:
+                params["max_completion_tokens"] = self.max_tokens
+            else:
+                params["max_tokens"] = self.max_tokens
+                params["temperature"] = temperature or self.temperature
+
             if not stream:
                 # Non-streaming request
-                response = await self.client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    max_tokens=self.max_tokens,
-                    temperature=temperature or self.temperature,
-                    stream=False,
-                )
+                params["stream"] = False
+
+                response = await self.client.chat.completions.create(**params)
+
                 if not response.choices or not response.choices[0].message.content:
                     raise ValueError("Empty or invalid response from LLM")
                 return response.choices[0].message.content
 
             # Streaming request
-            response = await self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_tokens=self.max_tokens,
-                temperature=temperature or self.temperature,
-                stream=True,
-            )
+            params["stream"] = True
+            response = await self.client.chat.completions.create(**params)
 
             collected_messages = []
             async for chunk in response:
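
The hunk above stops right where the streaming loop begins; the chunk handling that follows it is unchanged by this PR. For reference, a self-contained sketch of how streamed chunks from the async OpenAI client are typically accumulated (model, prompt, and token budget are placeholders, not values from app/llm.py):

    import asyncio

    from openai import AsyncOpenAI  # assumes the official openai>=1.x SDK


    async def stream_reply() -> str:
        client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
        response = await client.chat.completions.create(
            model="o3-mini",  # one of the REASONING_MODELS
            messages=[{"role": "user", "content": "Say hello."}],
            max_completion_tokens=100,  # reasoning models: not max_tokens
            stream=True,
        )

        collected: list[str] = []
        async for chunk in response:
            # Each chunk carries an incremental delta; content may be None (e.g. role-only chunks).
            piece = chunk.choices[0].delta.content or ""
            collected.append(piece)
            print(piece, end="", flush=True)
        return "".join(collected)


    if __name__ == "__main__":
        asyncio.run(stream_reply())
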
@@ -230,16 +234,22 @@ class LLM:
                         raise ValueError("Each tool must be a dict with 'type' field")
 
             # Set up the completion request
-            response = await self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                temperature=temperature or self.temperature,
-                max_tokens=self.max_tokens,
-                tools=tools,
-                tool_choice=tool_choice,
-                timeout=timeout,
-                **kwargs,
-            )
+            params = {
+                "model": self.model,
+                "messages": messages,
+                "tools": tools,
+                "tool_choice": tool_choice,
+                "timeout": timeout,
+                **kwargs,
+            }
+
+            if self.model in REASONING_MODELS:
+                params["max_completion_tokens"] = self.max_tokens
+            else:
+                params["max_tokens"] = self.max_tokens
+                params["temperature"] = temperature or self.temperature
+
+            response = await self.client.chat.completions.create(**params)
 
             # Check if response is valid
             if not response.choices or not response.choices[0].message:
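
The tool-calling path gets the same model gate. As a rough, self-contained illustration of the kind of request this produces for a reasoning model (the get_weather tool and all literal values are made up for the example, not taken from app/llm.py):

    import asyncio

    from openai import AsyncOpenAI  # assumes the official openai>=1.x SDK

    # A made-up tool definition in the Chat Completions "tools" format.
    WEATHER_TOOL = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }


    async def main() -> None:
        client = AsyncOpenAI()
        params = {
            "model": "o3-mini",
            "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
            "tools": [WEATHER_TOOL],
            "tool_choice": "auto",
            "timeout": 300,
        }
        # Reasoning models take max_completion_tokens; max_tokens is rejected for them,
        # and only the default temperature is accepted, so neither is sent here.
        params["max_completion_tokens"] = 1000
        response = await client.chat.completions.create(**params)
        message = response.choices[0].message
        print(message.tool_calls or message.content)


    if __name__ == "__main__":
        asyncio.run(main())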