Merge pull request #623 from zyren123/main

print the token usage information
commit f518fc59b7
Author: mannaandpoem
Date:   2025-03-18 23:23:26 +08:00 (committed by GitHub)

@@ -196,6 +196,7 @@ class LLM:
         # Add token counting related attributes
         self.total_input_tokens = 0
+        self.total_completion_tokens = 0
         self.max_input_tokens = (
             llm_config.max_input_tokens
             if hasattr(llm_config, "max_input_tokens")
@@ -229,12 +230,15 @@ class LLM:
     def count_message_tokens(self, messages: List[dict]) -> int:
         return self.token_counter.count_message_tokens(messages)
 
-    def update_token_count(self, input_tokens: int) -> None:
+    def update_token_count(self, input_tokens: int, completion_tokens: int = 0) -> None:
         """Update token counts"""
         # Only track tokens if max_input_tokens is set
         self.total_input_tokens += input_tokens
+        self.total_completion_tokens += completion_tokens
         logger.info(
-            f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}"
+            f"Token usage: Input={input_tokens}, Completion={completion_tokens}, "
+            f"Cumulative Input={self.total_input_tokens}, Cumulative Completion={self.total_completion_tokens}, "
+            f"Total={input_tokens + completion_tokens}, Cumulative Total={self.total_input_tokens + self.total_completion_tokens}"
         )
 
     def check_token_limit(self, input_tokens: int) -> bool:
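
The counter logic above is self-contained enough to try on its own. The sketch below mirrors the patched update_token_count in a hypothetical stand-in class, assuming only the standard logging module; everything except the field and method names is illustrative.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TokenTracker:
    """Hypothetical stand-in for the counters the patch adds to LLM."""

    def __init__(self) -> None:
        self.total_input_tokens = 0
        self.total_completion_tokens = 0

    def update_token_count(self, input_tokens: int, completion_tokens: int = 0) -> None:
        """Accumulate usage and log it, mirroring the patched method."""
        self.total_input_tokens += input_tokens
        self.total_completion_tokens += completion_tokens
        logger.info(
            f"Token usage: Input={input_tokens}, Completion={completion_tokens}, "
            f"Cumulative Input={self.total_input_tokens}, "
            f"Cumulative Completion={self.total_completion_tokens}, "
            f"Total={input_tokens + completion_tokens}, "
            f"Cumulative Total={self.total_input_tokens + self.total_completion_tokens}"
        )


tracker = TokenTracker()
tracker.update_token_count(120, 45)  # Total=165, Cumulative Total=165
tracker.update_token_count(80, 30)   # Total=110, Cumulative Total=275

Keeping completion_tokens optional with a default of 0 preserves every existing call site that still reports only prompt tokens.
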
@@ -407,7 +411,9 @@ class LLM:
                 raise ValueError("Empty or invalid response from LLM")
 
             # Update token counts
-            self.update_token_count(response.usage.prompt_tokens)
+            self.update_token_count(
+                response.usage.prompt_tokens, response.usage.completion_tokens
+            )
 
             return response.choices[0].message.content
@@ -418,9 +424,11 @@ class LLM:
             response = await self.client.chat.completions.create(**params)
 
             collected_messages = []
+            completion_text = ""
             async for chunk in response:
                 chunk_message = chunk.choices[0].delta.content or ""
                 collected_messages.append(chunk_message)
+                completion_text += chunk_message
                 print(chunk_message, end="", flush=True)
 
             print()  # Newline after streaming
@@ -428,6 +436,13 @@ class LLM:
             if not full_response:
                 raise ValueError("Empty response from streaming LLM")
 
+            # estimate completion tokens for streaming response
+            completion_tokens = self.count_tokens(completion_text)
+            logger.info(
+                f"Estimated completion tokens for streaming response: {completion_tokens}"
+            )
+            self.total_completion_tokens += completion_tokens
+
             return full_response
 
         except TokenLimitExceeded:
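
A streaming chat completion does not return a usage block by default, which is why the hunk above concatenates the chunks and estimates after the stream ends. The diff's count_tokens helper is defined elsewhere in the class; a plausible tiktoken-backed equivalent is sketched below (the model name and the fallback encoding are assumptions, not taken from the patch).

import tiktoken


def count_tokens(text: str, model: str = "gpt-4o") -> int:
    """Estimate the token count of text with the model's tokenizer."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unrecognized model names fall back to a general-purpose encoding
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))


completion_text = "".join(["The answer ", "is ", "42."])  # collected stream chunks
print(count_tokens(completion_text))  # an estimate, not the exact billed usage

This stays an estimate; recent versions of the OpenAI API can report exact usage for streams when the request sets stream_options={"include_usage": True}.
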
@@ -694,7 +709,9 @@ class LLM:
                 raise ValueError("Invalid or empty response from LLM")
 
             # Update token counts
-            self.update_token_count(response.usage.prompt_tokens)
+            self.update_token_count(
+                response.usage.prompt_tokens, response.usage.completion_tokens
+            )
 
             return response.choices[0].message
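
On the non-streaming paths, the call sites consume the exact counts the API reports, as in the hunk above. A minimal sketch of how the widened method plugs into an OpenAI SDK call (the client setup and the ask wrapper are illustrative, not part of the patch):

from openai import AsyncOpenAI

client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment


async def ask(llm, params: dict) -> str:
    """Run one completion and feed its reported usage into the patched counter."""
    response = await client.chat.completions.create(**params)
    # Non-streaming responses carry a usage block with prompt_tokens and
    # completion_tokens, which is exactly what the diff forwards
    llm.update_token_count(
        response.usage.prompt_tokens, response.usage.completion_tokens
    )
    return response.choices[0].message.content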