Add token usage to gemini (streaming), litellm and anthropic

Also included is a token_usage sample that showcases how token usage is reported for sub-agents running different models under a parent agent.

PiperOrigin-RevId: 759347015
Selcuk Gun 2025-05-15 16:22:04 -07:00 committed by Copybara-Service
parent 4d5760917d
commit 509db3f9fb
8 changed files with 515 additions and 45 deletions
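
As context for the diffs below: callers can total the new per-event usage metadata themselves, much as the token_usage sample's main.py does. A minimal sketch (the run_and_count helper name and the pre-built runner, session, and user_id are illustrative assumptions, not part of this commit):

from google.genai import types

async def run_and_count(runner, session, user_id: str, prompt: str):
  # Sum usage over every event produced for one user turn; events from
  # models that do not report usage simply have no usage_metadata.
  prompt_tokens = candidate_tokens = total_tokens = 0
  content = types.Content(
      role='user', parts=[types.Part.from_text(text=prompt)]
  )
  async for event in runner.run_async(
      user_id=user_id, session_id=session.id, new_message=content
  ):
    if event.usage_metadata:
      prompt_tokens += event.usage_metadata.prompt_token_count or 0
      candidate_tokens += event.usage_metadata.candidates_token_count or 0
      total_tokens += event.usage_metadata.total_token_count or 0
  return prompt_tokens, candidate_tokens, total_tokens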

View File

@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -0,0 +1,97 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from google.adk import Agent
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.models.anthropic_llm import Claude
from google.adk.models.lite_llm import LiteLlm
from google.adk.planners import BuiltInPlanner
from google.adk.planners import PlanReActPlanner
from google.adk.tools.tool_context import ToolContext
from google.genai import types
def roll_die(sides: int, tool_context: ToolContext) -> int:
"""Roll a die and return the rolled result.
Args:
sides: The integer number of sides the die has.
Returns:
An integer of the result of rolling the die.
"""
result = random.randint(1, sides)
if 'rolls' not in tool_context.state:
tool_context.state['rolls'] = []
tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
return result
roll_agent_with_openai = LlmAgent(
model=LiteLlm(model='openai/gpt-4o'),
description='Handles rolling dice of different sizes.',
name='roll_agent_with_openai',
instruction="""
You are responsible for rolling dice based on the user's request.
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
""",
tools=[roll_die],
)
roll_agent_with_claude = LlmAgent(
model=Claude(model='claude-3-7-sonnet@20250219'),
description='Handles rolling dice of different sizes.',
name='roll_agent_with_claude',
instruction="""
You are responsible for rolling dice based on the user's request.
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
""",
tools=[roll_die],
)
roll_agent_with_litellm_claude = LlmAgent(
model=LiteLlm(model='vertex_ai/claude-3-7-sonnet'),
description='Handles rolling dice of different sizes.',
name='roll_agent_with_litellm_claude',
instruction="""
You are responsible for rolling dice based on the user's request.
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
""",
tools=[roll_die],
)
roll_agent_with_gemini = LlmAgent(
model='gemini-2.0-flash',
description='Handles rolling dice of different sizes.',
name='roll_agent_with_gemini',
instruction="""
You are responsible for rolling dice based on the user's request.
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
""",
tools=[roll_die],
)
root_agent = SequentialAgent(
name='code_pipeline_agent',
sub_agents=[
roll_agent_with_openai,
roll_agent_with_claude,
roll_agent_with_litellm_claude,
roll_agent_with_gemini,
],
)

View File

@@ -0,0 +1,102 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import time
import warnings
import agent
from dotenv import load_dotenv
from google.adk import Runner
from google.adk.agents.run_config import RunConfig
from google.adk.artifacts import InMemoryArtifactService
from google.adk.cli.utils import logs
from google.adk.sessions import InMemorySessionService
from google.adk.sessions import Session
from google.genai import types
load_dotenv(override=True)
warnings.filterwarnings('ignore', category=UserWarning)
logs.log_to_tmp_folder()
async def main():
app_name = 'my_app'
user_id_1 = 'user1'
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name=app_name,
agent=agent.root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
session_11 = session_service.create_session(
app_name=app_name, user_id=user_id_1
)
total_prompt_tokens = 0
total_candidate_tokens = 0
total_tokens = 0
async def run_prompt(session: Session, new_message: str):
nonlocal total_prompt_tokens
nonlocal total_candidate_tokens
nonlocal total_tokens
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
if event.usage_metadata:
total_prompt_tokens += event.usage_metadata.prompt_token_count or 0
total_candidate_tokens += (
event.usage_metadata.candidates_token_count or 0
)
total_tokens += event.usage_metadata.total_token_count or 0
print(
'Turn tokens:'
f' {event.usage_metadata.total_token_count} (prompt={event.usage_metadata.prompt_token_count},'
f' candidates={event.usage_metadata.candidates_token_count})'
)
print(
f'Session tokens: {total_tokens} (prompt={total_prompt_tokens},'
f' candidates={total_candidate_tokens})'
)
start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
await run_prompt(session_11, 'Hi')
await run_prompt(session_11, 'Roll a die with 100 sides')
print(
await artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id_1, session_id=session_11.id
)
)
end_time = time.time()
print('------------------------------------')
print('End time:', end_time)
print('Total time:', end_time - start_time)
if __name__ == '__main__':
asyncio.run(main())

View File

@@ -218,7 +218,7 @@ class BaseLlmFlow(ABC):
When the model returns transcription, the author is "user". Otherwise, the
author is the agent name(not 'model').
Args:
llm_response: The LLM response from the LLM call.
"""

View File

@@ -140,15 +140,15 @@ def message_to_generate_content_response(
role="model",
parts=[content_block_to_part(cb) for cb in message.content],
),
usage_metadata=types.GenerateContentResponseUsageMetadata(
prompt_token_count=message.usage.input_tokens,
candidates_token_count=message.usage.output_tokens,
total_token_count=(
message.usage.input_tokens + message.usage.output_tokens
),
),
# TODO: Deal with these later.
# finish_reason=to_google_genai_finish_reason(message.stop_reason),
# usage_metadata=types.GenerateContentResponseUsageMetadata(
# prompt_token_count=message.usage.input_tokens,
# candidates_token_count=message.usage.output_tokens,
# total_token_count=(
# message.usage.input_tokens + message.usage.output_tokens
# ),
# ),
)
@@ -196,6 +196,12 @@ def function_declaration_to_tool_param(
class Claude(BaseLlm):
"""Integration with Claude models served from Vertex AI.
Attributes:
model: The name of the Claude model.
"""
model: str = "claude-3-5-sonnet-v2@20241022"
@staticmethod

View File

@@ -121,6 +121,7 @@ class Gemini(BaseLlm):
content=types.ModelContent(
parts=[types.Part.from_text(text=text)],
),
usage_metadata=llm_response.usage_metadata,
)
text = ''
yield llm_response
@@ -174,7 +175,7 @@ class Gemini(BaseLlm):
@cached_property
def _live_api_client(self) -> Client:
if self._api_backend == 'vertex':
#use beta version for vertex api
# use beta version for vertex api
api_version = 'v1beta1'
# use default api version for vertex
return Client(

View File

@@ -67,6 +67,12 @@ class TextChunk(BaseModel):
text: str
class UsageMetadataChunk(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class LiteLLMClient:
"""Provides acompletion method (for better testability)."""
@@ -344,15 +350,20 @@ def _function_declaration_to_tool_param(
def _model_response_to_chunk(
response: ModelResponse,
) -> Generator[
Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None
Tuple[
Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
Optional[str],
],
None,
None,
]:
"""Converts a litellm message to text or function chunk.
"""Converts a litellm message to text, function or usage metadata chunk.
Args:
response: The response from the model.
Yields:
A tuple of text or function chunk and finish reason.
A tuple of text or function or usage metadata chunk and finish reason.
"""
message = None
@@ -384,11 +395,21 @@ def _model_response_to_chunk(
if not message:
yield None, None
# Ideally usage would be expected with the last ModelResponseStream with a
# finish_reason set. But this is not the case we are observing from litellm.
# So we are sending it as a separate chunk to be set on the llm_response.
if response.get("usage", None):
yield UsageMetadataChunk(
prompt_tokens=response["usage"].get("prompt_tokens", 0),
completion_tokens=response["usage"].get("completion_tokens", 0),
total_tokens=response["usage"].get("total_tokens", 0),
), None
def _model_response_to_generate_content_response(
response: ModelResponse,
) -> LlmResponse:
"""Converts a litellm response to LlmResponse.
"""Converts a litellm response to LlmResponse. Also adds usage metadata.
Args:
response: The model response.
@@ -403,7 +424,15 @@ def _model_response_to_generate_content_response(
if not message:
raise ValueError("No message in response")
return _message_to_generate_content_response(message)
llm_response = _message_to_generate_content_response(message)
if response.get("usage", None):
llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=response["usage"].get("prompt_tokens", 0),
candidates_token_count=response["usage"].get("completion_tokens", 0),
total_token_count=response["usage"].get("total_tokens", 0),
)
return llm_response
def _message_to_generate_content_response(
@@ -628,6 +657,10 @@ class LiteLlm(BaseLlm):
function_args = ""
function_id = None
completion_args["stream"] = True
aggregated_llm_response = None
aggregated_llm_response_with_tool_call = None
usage_metadata = None
for part in self.llm_client.completion(**completion_args):
for chunk, finish_reason in _model_response_to_chunk(part):
if isinstance(chunk, FunctionChunk):
@@ -645,32 +678,55 @@
),
is_partial=True,
)
elif isinstance(chunk, UsageMetadataChunk):
usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=chunk.prompt_tokens,
candidates_token_count=chunk.completion_tokens,
total_token_count=chunk.total_tokens,
)
if finish_reason == "tool_calls" and function_id:
yield _message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionMessageToolCall(
type="function",
id=function_id,
function=Function(
name=function_name,
arguments=function_args,
),
)
],
)
)
aggregated_llm_response_with_tool_call = (
_message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionMessageToolCall(
type="function",
id=function_id,
function=Function(
name=function_name,
arguments=function_args,
),
)
],
)
)
)
function_name = ""
function_args = ""
function_id = None
elif finish_reason == "stop" and text:
yield _message_to_generate_content_response(
aggregated_llm_response = _message_to_generate_content_response(
ChatCompletionAssistantMessage(role="assistant", content=text)
)
text = ""
# waiting until streaming ends to yield the llm_response as litellm tends
# to send chunk that contains usage_metadata after the chunk with
# finish_reason set to tool_calls or stop.
if aggregated_llm_response:
if usage_metadata:
aggregated_llm_response.usage_metadata = usage_metadata
usage_metadata = None
yield aggregated_llm_response
if aggregated_llm_response_with_tool_call:
if usage_metadata:
aggregated_llm_response_with_tool_call.usage_metadata = usage_metadata
yield aggregated_llm_response_with_tool_call
else:
response = await self.llm_client.acompletion(**completion_args)
yield _model_response_to_generate_content_response(response)

View File

@@ -25,6 +25,7 @@ from google.adk.models.lite_llm import FunctionChunk
from google.adk.models.lite_llm import LiteLlm
from google.adk.models.lite_llm import LiteLLMClient
from google.adk.models.lite_llm import TextChunk
from google.adk.models.lite_llm import UsageMetadataChunk
from google.adk.models.llm_request import LlmRequest
from google.genai import types
from litellm import ChatCompletionAssistantMessage
@@ -314,13 +315,10 @@ litellm_append_user_content_test_cases = [
litellm_append_user_content_test_cases
)
def test_maybe_append_user_content(lite_llm_instance, llm_request, expected_output):
lite_llm_instance._maybe_append_user_content(
llm_request
)
assert len(llm_request.contents) == expected_output
lite_llm_instance._maybe_append_user_content(llm_request)
assert len(llm_request.contents) == expected_output
function_declaration_test_cases = [
@@ -567,6 +565,80 @@ async def test_generate_content_async_with_tool_response(
assert kwargs["messages"][2]["content"] == '{"result": "test_result"}'
@pytest.mark.asyncio
async def test_generate_content_async(mock_acompletion, lite_llm_instance):
async for response in lite_llm_instance.generate_content_async(
LLM_REQUEST_WITH_FUNCTION_DECLARATION
):
assert response.content.role == "model"
assert response.content.parts[0].text == "Test response"
assert response.content.parts[1].function_call.name == "test_function"
assert response.content.parts[1].function_call.args == {
"test_arg": "test_value"
}
assert response.content.parts[1].function_call.id == "test_tool_call_id"
mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "user"
assert kwargs["messages"][0]["content"] == "Test prompt"
assert kwargs["tools"][0]["function"]["name"] == "test_function"
assert (
kwargs["tools"][0]["function"]["description"]
== "Test function description"
)
assert (
kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
"type"
]
== "string"
)
@pytest.mark.asyncio
async def test_generate_content_async_with_usage_metadata(
lite_llm_instance, mock_acompletion
):
mock_response_with_usage_metadata = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content="Test response",
)
)
],
usage={
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
},
)
mock_acompletion.return_value = mock_response_with_usage_metadata
llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
),
],
config=types.GenerateContentConfig(
system_instruction="test instruction",
),
)
async for response in lite_llm_instance.generate_content_async(llm_request):
assert response.content.role == "model"
assert response.content.parts[0].text == "Test response"
assert response.usage_metadata.prompt_token_count == 10
assert response.usage_metadata.candidates_token_count == 5
assert response.usage_metadata.total_token_count == 15
mock_acompletion.assert_called_once()
def test_content_to_message_param_user_message():
content = types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
@@ -704,7 +776,7 @@ def test_to_litellm_role():
@pytest.mark.parametrize(
"response, expected_chunk, expected_finished",
"response, expected_chunks, expected_finished",
[
(
ModelResponse(
@@ -716,7 +788,35 @@ def test_to_litellm_role():
}
]
),
TextChunk(text="this is a test"),
[
TextChunk(text="this is a test"),
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
],
"stop",
),
(
ModelResponse(
choices=[
{
"message": {
"content": "this is a test",
}
}
],
usage={
"prompt_tokens": 3,
"completion_tokens": 5,
"total_tokens": 8,
},
),
[
TextChunk(text="this is a test"),
UsageMetadataChunk(
prompt_tokens=3, completion_tokens=5, total_tokens=8
),
],
"stop", "stop",
), ),
( (
@@ -741,28 +841,53 @@ def test_to_litellm_role():
)
]
),
FunctionChunk(id="1", name="test_function", args='{"key": "va'),
[
FunctionChunk(id="1", name="test_function", args='{"key": "va'),
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
],
None,
),
(
ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
None,
[
None,
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
],
"tool_calls", "tool_calls",
), ),
(ModelResponse(choices=[{}]), None, "stop"), (
ModelResponse(choices=[{}]),
[
None,
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
],
"stop",
),
],
)
def test_model_response_to_chunk(response, expected_chunk, expected_finished):
def test_model_response_to_chunk(response, expected_chunks, expected_finished):
result = list(_model_response_to_chunk(response))
assert len(result) == 1
assert len(result) == 2
chunk, finished = result[0]
if expected_chunk:
assert isinstance(chunk, type(expected_chunk))
assert chunk == expected_chunk
if expected_chunks:
assert isinstance(chunk, type(expected_chunks[0]))
assert chunk == expected_chunks[0]
else:
assert chunk is None
assert finished == expected_finished
usage_chunk, _ = result[1]
assert usage_chunk is not None
assert usage_chunk.prompt_tokens == expected_chunks[1].prompt_tokens
assert usage_chunk.completion_tokens == expected_chunks[1].completion_tokens
assert usage_chunk.total_tokens == expected_chunks[1].total_tokens
@pytest.mark.asyncio
async def test_acompletion_additional_args(mock_acompletion, mock_client):
@@ -893,3 +1018,71 @@ async def test_generate_content_async_stream(
]
== "string"
)
@pytest.mark.asyncio
async def test_generate_content_async_stream_with_usage_metadata(
mock_completion, lite_llm_instance
):
streaming_model_response_with_usage_metadata = [
*STREAMING_MODEL_RESPONSE,
ModelResponse(
usage={
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
},
choices=[
StreamingChoices(
finish_reason=None,
)
],
),
]
mock_completion.return_value = iter(
streaming_model_response_with_usage_metadata
)
responses = [
response
async for response in lite_llm_instance.generate_content_async(
LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
)
]
assert len(responses) == 4
assert responses[0].content.role == "model"
assert responses[0].content.parts[0].text == "zero, "
assert responses[1].content.role == "model"
assert responses[1].content.parts[0].text == "one, "
assert responses[2].content.role == "model"
assert responses[2].content.parts[0].text == "two:"
assert responses[3].content.role == "model"
assert responses[3].content.parts[0].function_call.name == "test_function"
assert responses[3].content.parts[0].function_call.args == {
"test_arg": "test_value"
}
assert responses[3].content.parts[0].function_call.id == "test_tool_call_id"
assert responses[3].usage_metadata.prompt_token_count == 10
assert responses[3].usage_metadata.candidates_token_count == 5
assert responses[3].usage_metadata.total_token_count == 15
mock_completion.assert_called_once()
_, kwargs = mock_completion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "user"
assert kwargs["messages"][0]["content"] == "Test prompt"
assert kwargs["tools"][0]["function"]["name"] == "test_function"
assert (
kwargs["tools"][0]["function"]["description"]
== "Test function description"
)
assert (
kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
"type"
]
== "string"
)