Add token usage to Gemini (streaming), LiteLLM, and Anthropic

Also includes a token_usage sample that showcases the token usage of sub-agents running different models under a parent agent (sketched below).

PiperOrigin-RevId: 759347015
Selcuk Gun
2025-05-15 16:22:04 -07:00
committed by Copybara-Service
parent 4d5760917d
commit 509db3f9fb
8 changed files with 515 additions and 45 deletions
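The token_usage sample itself is not shown on this page, so here is a rough Python sketch of the idea it describes: a parent agent delegating to sub-agents on different model backends (Gemini directly, Claude via the Vertex integration changed below, and an OpenAI model via LiteLLM), with per-event token counts read from usage_metadata. The module paths, model ids, and runner/session call shapes are assumptions and may differ from the actual sample and ADK version.

```python
import asyncio

from google.adk.agents import LlmAgent
from google.adk.models.anthropic_llm import Claude  # module path assumed
from google.adk.models.lite_llm import LiteLlm
from google.adk.runners import InMemoryRunner
from google.genai import types

# Sub-agents on different model backends; model ids are placeholders.
gemini_agent = LlmAgent(
    name="gemini_child", model="gemini-2.0-flash",
    instruction="Answer questions briefly.")
claude_agent = LlmAgent(
    name="claude_child", model=Claude(),
    instruction="Answer questions briefly.")
litellm_agent = LlmAgent(
    name="litellm_child", model=LiteLlm(model="openai/gpt-4o"),
    instruction="Answer questions briefly.")

root_agent = LlmAgent(
    name="parent",
    model="gemini-2.0-flash",
    instruction="Delegate to the most suitable child agent.",
    sub_agents=[gemini_agent, claude_agent, litellm_agent],
)


async def main():
  runner = InMemoryRunner(agent=root_agent, app_name="token_usage_demo")
  session = await runner.session_service.create_session(  # may be sync in older ADK versions
      app_name="token_usage_demo", user_id="user"
  )
  message = types.Content(role="user", parts=[types.Part.from_text(text="Hi")])
  async for event in runner.run_async(
      user_id="user", session_id=session.id, new_message=message
  ):
    # After this commit, each model call carries usage_metadata, so
    # per-author token counts can be aggregated here.
    if event.usage_metadata:
      print(event.author,
            event.usage_metadata.prompt_token_count,
            event.usage_metadata.candidates_token_count,
            event.usage_metadata.total_token_count)


if __name__ == "__main__":
  asyncio.run(main())
```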

View File

@@ -218,7 +218,7 @@ class BaseLlmFlow(ABC):
    When the model returns transcription, the author is "user". Otherwise, the
    author is the agent name (not 'model').

    Args:
      llm_response: The LLM response from the LLM call.
    """

View File

@@ -140,15 +140,15 @@ def message_to_generate_content_response(
          role="model",
          parts=[content_block_to_part(cb) for cb in message.content],
      ),
      usage_metadata=types.GenerateContentResponseUsageMetadata(
          prompt_token_count=message.usage.input_tokens,
          candidates_token_count=message.usage.output_tokens,
          total_token_count=(
              message.usage.input_tokens + message.usage.output_tokens
          ),
      ),
      # TODO: Deal with these later.
      # finish_reason=to_google_genai_finish_reason(message.stop_reason),
      # usage_metadata=types.GenerateContentResponseUsageMetadata(
      #     prompt_token_count=message.usage.input_tokens,
      #     candidates_token_count=message.usage.output_tokens,
      #     total_token_count=(
      #         message.usage.input_tokens + message.usage.output_tokens
      #     ),
      # ),
  )
@@ -196,6 +196,12 @@ def function_declaration_to_tool_param(
class Claude(BaseLlm):
  """Integration with Claude models served from Vertex AI.

  Attributes:
    model: The name of the Claude model.
  """

  model: str = "claude-3-5-sonnet-v2@20241022"
@staticmethod
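A rough pytest-style check of the new Anthropic mapping could look like the sketch below. The google.adk.models.anthropic_llm import path and the exact set of required fields on anthropic.types.Message are assumptions based on this diff and the Anthropic SDK, not code from the commit.

```python
from anthropic import types as anthropic_types

# Module path assumed from the class shown above; adjust to the actual file.
from google.adk.models.anthropic_llm import message_to_generate_content_response


def test_usage_is_mapped_to_usage_metadata():
  message = anthropic_types.Message(
      id="msg_123",
      type="message",
      role="assistant",
      model="claude-3-5-sonnet-v2@20241022",
      content=[anthropic_types.TextBlock(type="text", text="Hello")],
      stop_reason="end_turn",
      stop_sequence=None,
      usage=anthropic_types.Usage(input_tokens=12, output_tokens=5),
  )

  llm_response = message_to_generate_content_response(message)

  # input/output tokens map to prompt/candidates counts; total is their sum.
  assert llm_response.usage_metadata.prompt_token_count == 12
  assert llm_response.usage_metadata.candidates_token_count == 5
  assert llm_response.usage_metadata.total_token_count == 17
```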

View File

@@ -121,6 +121,7 @@ class Gemini(BaseLlm):
              content=types.ModelContent(
                  parts=[types.Part.from_text(text=text)],
              ),
              usage_metadata=llm_response.usage_metadata,
          )
          text = ''
        yield llm_response
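With usage_metadata now attached to the aggregated partial responses, a caller can keep the most recent non-empty value seen while streaming and treat it as the totals for the call. The helper below is a hypothetical consumer written for this note, not part of the diff, and it assumes the final chunk carries the authoritative (cumulative) counts.

```python
from typing import AsyncIterator, Optional, Tuple

from google.adk.models.llm_response import LlmResponse
from google.genai import types


async def collect_stream(
    responses: AsyncIterator[LlmResponse],
) -> Tuple[str, Optional[types.GenerateContentResponseUsageMetadata]]:
  """Hypothetical helper: gathers streamed text and the last usage_metadata."""
  text = ""
  usage: Optional[types.GenerateContentResponseUsageMetadata] = None
  async for llm_response in responses:
    if (
        llm_response.partial
        and llm_response.content
        and llm_response.content.parts
        and llm_response.content.parts[0].text
    ):
      # Only accumulate partial chunks, so the aggregated text response
      # yielded above is not double-counted.
      text += llm_response.content.parts[0].text
    if llm_response.usage_metadata:
      # Later chunks overwrite earlier ones; assumes the final chunk carries
      # the cumulative counts for the call.
      usage = llm_response.usage_metadata
  return text, usage
```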
@@ -174,7 +175,7 @@ class Gemini(BaseLlm):
@cached_property
def _live_api_client(self) -> Client:
if self._api_backend == 'vertex':
#use beta version for vertex api
# use beta version for vertex api
api_version = 'v1beta1'
# use default api version for vertex
return Client(

View File

@@ -67,6 +67,12 @@ class TextChunk(BaseModel):
  text: str


class UsageMetadataChunk(BaseModel):
  prompt_tokens: int
  completion_tokens: int
  total_tokens: int


class LiteLLMClient:
  """Provides acompletion method (for better testability)."""
@@ -344,15 +350,20 @@ def _function_declaration_to_tool_param(
def _model_response_to_chunk(
    response: ModelResponse,
) -> Generator[
    Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None
    Tuple[
        Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
        Optional[str],
    ],
    None,
    None,
]:
  """Converts a litellm message to text or function chunk.
  """Converts a litellm message to text, function or usage metadata chunk.

  Args:
    response: The response from the model.

  Yields:
    A tuple of text or function chunk and finish reason.
    A tuple of text or function or usage metadata chunk and finish reason.
  """

  message = None
@@ -384,11 +395,21 @@ def _model_response_to_chunk(
  if not message:
    yield None, None

  # Ideally usage would be expected with the last ModelResponseStream with a
  # finish_reason set. But this is not the case we are observing from litellm.
  # So we are sending it as a separate chunk to be set on the llm_response.
  if response.get("usage", None):
    yield UsageMetadataChunk(
        prompt_tokens=response["usage"].get("prompt_tokens", 0),
        completion_tokens=response["usage"].get("completion_tokens", 0),
        total_tokens=response["usage"].get("total_tokens", 0),
    ), None


def _model_response_to_generate_content_response(
    response: ModelResponse,
) -> LlmResponse:
  """Converts a litellm response to LlmResponse.
  """Converts a litellm response to LlmResponse. Also adds usage metadata.

  Args:
    response: The model response.
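Because usage now arrives as its own trailing chunk rather than on the finish_reason chunk, consumers of _model_response_to_chunk have to branch on the chunk type. The sketch below is a hypothetical consumer fed with pre-built chunks (avoiding any claim about how a litellm ModelResponse is constructed); only the three chunk classes from this module are assumed importable.

```python
from google.adk.models.lite_llm import (  # names shown in this diff
    FunctionChunk,
    TextChunk,
    UsageMetadataChunk,
)


def summarize_chunks(chunks):
  """Hypothetical consumer: folds a chunk stream into text plus token totals."""
  text_parts, total_tokens = [], 0
  for chunk, finish_reason in chunks:
    if isinstance(chunk, TextChunk):
      text_parts.append(chunk.text)
    elif isinstance(chunk, FunctionChunk):
      pass  # accumulate tool-call name/args here
    elif isinstance(chunk, UsageMetadataChunk):
      # Usage arrives as its own trailing chunk, after the finish_reason chunk.
      total_tokens = chunk.total_tokens
  return "".join(text_parts), total_tokens


# Example: feed pre-built chunks rather than a real litellm ModelResponse.
chunks = [
    (TextChunk(text="Hello "), None),
    (TextChunk(text="world"), "stop"),
    (UsageMetadataChunk(prompt_tokens=7, completion_tokens=2, total_tokens=9),
     None),
]
print(summarize_chunks(chunks))  # ('Hello world', 9)
```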
@@ -403,7 +424,15 @@ def _model_response_to_generate_content_response(
  if not message:
    raise ValueError("No message in response")

  return _message_to_generate_content_response(message)
  llm_response = _message_to_generate_content_response(message)
  if response.get("usage", None):
    llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
        prompt_token_count=response["usage"].get("prompt_tokens", 0),
        candidates_token_count=response["usage"].get("completion_tokens", 0),
        total_token_count=response["usage"].get("total_tokens", 0),
    )
  return llm_response


def _message_to_generate_content_response(
@@ -628,6 +657,10 @@ class LiteLlm(BaseLlm):
      function_args = ""
      function_id = None
      completion_args["stream"] = True
      aggregated_llm_response = None
      aggregated_llm_response_with_tool_call = None
      usage_metadata = None
      for part in self.llm_client.completion(**completion_args):
        for chunk, finish_reason in _model_response_to_chunk(part):
          if isinstance(chunk, FunctionChunk):
@@ -645,32 +678,55 @@ class LiteLlm(BaseLlm):
                ),
                is_partial=True,
            )
          elif isinstance(chunk, UsageMetadataChunk):
            usage_metadata = types.GenerateContentResponseUsageMetadata(
                prompt_token_count=chunk.prompt_tokens,
                candidates_token_count=chunk.completion_tokens,
                total_token_count=chunk.total_tokens,
            )
          if finish_reason == "tool_calls" and function_id:
            yield _message_to_generate_content_response(
                ChatCompletionAssistantMessage(
                    role="assistant",
                    content="",
                    tool_calls=[
                        ChatCompletionMessageToolCall(
                            type="function",
                            id=function_id,
                            function=Function(
                                name=function_name,
                                arguments=function_args,
                            ),
                        )
                    ],
            aggregated_llm_response_with_tool_call = (
                _message_to_generate_content_response(
                    ChatCompletionAssistantMessage(
                        role="assistant",
                        content="",
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                type="function",
                                id=function_id,
                                function=Function(
                                    name=function_name,
                                    arguments=function_args,
                                ),
                            )
                        ],
                    )
                )
            )
            function_name = ""
            function_args = ""
            function_id = None
          elif finish_reason == "stop" and text:
            yield _message_to_generate_content_response(
            aggregated_llm_response = _message_to_generate_content_response(
                ChatCompletionAssistantMessage(role="assistant", content=text)
            )
            text = ""
      # waiting until streaming ends to yield the llm_response as litellm tends
      # to send a chunk that contains usage_metadata after the chunk with
      # finish_reason set to tool_calls or stop.
      if aggregated_llm_response:
        if usage_metadata:
          aggregated_llm_response.usage_metadata = usage_metadata
          usage_metadata = None
        yield aggregated_llm_response

      if aggregated_llm_response_with_tool_call:
        if usage_metadata:
          aggregated_llm_response_with_tool_call.usage_metadata = usage_metadata
        yield aggregated_llm_response_with_tool_call
    else:
      response = await self.llm_client.acompletion(**completion_args)
      yield _model_response_to_generate_content_response(response)
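Stripped of the litellm specifics, the deferred-yield logic above reduces to a small pattern: buffer the finished response and only emit it once the stream ends, attaching whatever usage value arrived late. A generic sketch of that pattern follows; it is not code from the diff, and the item shape (text_delta, finish_reason, total_tokens) is invented for illustration.

```python
from typing import AsyncIterator, Optional, Tuple


async def yield_with_trailing_usage(
    stream: AsyncIterator[Tuple[Optional[str], Optional[str], Optional[int]]],
) -> AsyncIterator[Tuple[str, Optional[int]]]:
  """Generic sketch: hold the finished response until a late usage value lands.

  Each stream item is (text_delta, finish_reason, total_tokens); the usage
  value tends to arrive *after* the item carrying finish_reason, mirroring
  what the litellm streaming code above works around.
  """
  buffered_text = ""
  finished: Optional[str] = None
  usage: Optional[int] = None
  async for text_delta, finish_reason, total_tokens in stream:
    if text_delta:
      buffered_text += text_delta
    if finish_reason == "stop":
      finished = buffered_text  # don't yield yet; usage may still follow
    if total_tokens is not None:
      usage = total_tokens
  if finished is not None:
    yield finished, usage
```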