diff --git a/app/llm.py b/app/llm.py
index eccedf1..37de566 100644
--- a/app/llm.py
+++ b/app/llm.py
@@ -196,6 +196,7 @@ class LLM:
 
             # Add token counting related attributes
             self.total_input_tokens = 0
+            self.total_completion_tokens = 0
             self.max_input_tokens = (
                 llm_config.max_input_tokens
                 if hasattr(llm_config, "max_input_tokens")
@@ -229,12 +230,15 @@ class LLM:
     def count_message_tokens(self, messages: List[dict]) -> int:
         return self.token_counter.count_message_tokens(messages)
 
-    def update_token_count(self, input_tokens: int) -> None:
+    def update_token_count(self, input_tokens: int, completion_tokens: int = 0) -> None:
         """Update token counts"""
         # Only track tokens if max_input_tokens is set
         self.total_input_tokens += input_tokens
+        self.total_completion_tokens += completion_tokens
         logger.info(
-            f"Token usage: Input={input_tokens}, Cumulative Input={self.total_input_tokens}"
+            f"Token usage: Input={input_tokens}, Completion={completion_tokens}, "
+            f"Cumulative Input={self.total_input_tokens}, Cumulative Completion={self.total_completion_tokens}, "
+            f"Total={input_tokens + completion_tokens}, Cumulative Total={self.total_input_tokens + self.total_completion_tokens}"
         )
 
     def check_token_limit(self, input_tokens: int) -> bool:
@@ -407,7 +411,9 @@ class LLM:
                     raise ValueError("Empty or invalid response from LLM")
 
                 # Update token counts
-                self.update_token_count(response.usage.prompt_tokens)
+                self.update_token_count(
+                    response.usage.prompt_tokens, response.usage.completion_tokens
+                )
 
                 return response.choices[0].message.content
 
@@ -418,9 +424,11 @@ class LLM:
             response = await self.client.chat.completions.create(**params)
 
             collected_messages = []
+            completion_text = ""
             async for chunk in response:
                 chunk_message = chunk.choices[0].delta.content or ""
                 collected_messages.append(chunk_message)
+                completion_text += chunk_message
                 print(chunk_message, end="", flush=True)
 
             print()  # Newline after streaming
@@ -428,6 +436,13 @@ class LLM:
             if not full_response:
                 raise ValueError("Empty response from streaming LLM")
 
+            # estimate completion tokens for streaming response
+            completion_tokens = self.count_tokens(completion_text)
+            logger.info(
+                f"Estimated completion tokens for streaming response: {completion_tokens}"
+            )
+            self.total_completion_tokens += completion_tokens
+
             return full_response
 
         except TokenLimitExceeded:
@@ -694,7 +709,9 @@ class LLM:
                 raise ValueError("Invalid or empty response from LLM")
 
             # Update token counts
-            self.update_token_count(response.usage.prompt_tokens)
+            self.update_token_count(
+                response.usage.prompt_tokens, response.usage.completion_tokens
+            )
 
             return response.choices[0].message
 
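
For reference, a minimal standalone sketch of the streaming estimate used in the patch (not part of the diff itself): streamed chat completions return no usage object, so the patch approximates completion tokens by encoding the text accumulated from the chunks. The helper name, the tiktoken dependency, and the model fallback below are illustrative assumptions standing in for the class's own count_tokens.

# Illustrative sketch only: a hypothetical tiktoken-based stand-in for self.count_tokens.
import tiktoken

def estimate_completion_tokens(completion_text: str, model: str = "gpt-4o") -> int:
    # Streaming responses carry no usage data, so count tokens over the collected text.
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: fall back to a generic encoding
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(completion_text))

# Example: chunks collected during streaming, joined and counted once at the end
chunks = ["Hello", ", ", "world", "!"]
print(estimate_completion_tokens("".join(chunks)))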