Add token usage to gemini (streaming), litellm and anthropic

Also includes a token_usage sample that showcases the token usage of subagents with different models under a parent agent.

PiperOrigin-RevId: 759347015
Authored by Selcuk Gun on 2025-05-15 16:22:04 -07:00; committed by Copybara-Service
parent 4d5760917d
commit 509db3f9fb
8 changed files with 515 additions and 45 deletions
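
For orientation, here is a minimal sketch of how the usage metadata this change attaches to events could be tallied per subagent, which is what the token_usage sample demonstrates end to end. It is illustrative only and not part of the diff: tally_usage_by_author is a hypothetical helper, and it assumes the event fields the sample's main() relies on (event.author and event.usage_metadata).

from collections import defaultdict


def tally_usage_by_author(events) -> dict[str, int]:
  """Sums total_token_count per event author (e.g. per subagent)."""
  totals: dict[str, int] = defaultdict(int)
  for event in events:  # events already collected from Runner.run_async
    usage = getattr(event, 'usage_metadata', None)  # assumed field, per this change
    if usage and usage.total_token_count:
      totals[event.author] += usage.total_token_count
  return dict(totals)

Feeding it the events collected by the sample below would yield one total per model-backed subagent.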


@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent


@@ -0,0 +1,97 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from google.adk import Agent
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.models.anthropic_llm import Claude
from google.adk.models.lite_llm import LiteLlm
from google.adk.planners import BuiltInPlanner
from google.adk.planners import PlanReActPlanner
from google.adk.tools.tool_context import ToolContext
from google.genai import types
def roll_die(sides: int, tool_context: ToolContext) -> int:
"""Roll a die and return the rolled result.
Args:
sides: The integer number of sides the die has.
Returns:
An integer of the result of rolling the die.
"""
result = random.randint(1, sides)
if 'rolls' not in tool_context.state:
tool_context.state['rolls'] = []
tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
return result
roll_agent_with_openai = LlmAgent(
model=LiteLlm(model='openai/gpt-4o'),
description='Handles rolling dice of different sizes.',
name='roll_agent_with_openai',
instruction="""
You are responsible for rolling dice based on the user's request.
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
""",
tools=[roll_die],
)
roll_agent_with_claude = LlmAgent(
model=Claude(model='claude-3-7-sonnet@20250219'),
description='Handles rolling dice of different sizes.',
name='roll_agent_with_claude',
instruction="""
You are responsible for rolling dice based on the user's request.
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
""",
tools=[roll_die],
)
roll_agent_with_litellm_claude = LlmAgent(
model=LiteLlm(model='vertex_ai/claude-3-7-sonnet'),
description='Handles rolling dice of different sizes.',
name='roll_agent_with_litellm_claude',
instruction="""
You are responsible for rolling dice based on the user's request.
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
""",
tools=[roll_die],
)
roll_agent_with_gemini = LlmAgent(
model='gemini-2.0-flash',
description='Handles rolling dice of different sizes.',
name='roll_agent_with_gemini',
instruction="""
You are responsible for rolling dice based on the user's request.
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
""",
tools=[roll_die],
)
root_agent = SequentialAgent(
name='code_pipeline_agent',
sub_agents=[
roll_agent_with_openai,
roll_agent_with_claude,
roll_agent_with_litellm_claude,
roll_agent_with_gemini,
],
)


@@ -0,0 +1,102 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import time
import warnings
import agent
from dotenv import load_dotenv
from google.adk import Runner
from google.adk.agents.run_config import RunConfig
from google.adk.artifacts import InMemoryArtifactService
from google.adk.cli.utils import logs
from google.adk.sessions import InMemorySessionService
from google.adk.sessions import Session
from google.genai import types
load_dotenv(override=True)
warnings.filterwarnings('ignore', category=UserWarning)
logs.log_to_tmp_folder()
async def main():
app_name = 'my_app'
user_id_1 = 'user1'
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name=app_name,
agent=agent.root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
session_11 = session_service.create_session(
app_name=app_name, user_id=user_id_1
)
total_prompt_tokens = 0
total_candidate_tokens = 0
total_tokens = 0
async def run_prompt(session: Session, new_message: str):
nonlocal total_prompt_tokens
nonlocal total_candidate_tokens
nonlocal total_tokens
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
if event.usage_metadata:
total_prompt_tokens += event.usage_metadata.prompt_token_count or 0
total_candidate_tokens += (
event.usage_metadata.candidates_token_count or 0
)
total_tokens += event.usage_metadata.total_token_count or 0
print(
'Turn tokens:'
f' {event.usage_metadata.total_token_count} (prompt={event.usage_metadata.prompt_token_count},'
f' candidates={event.usage_metadata.candidates_token_count})'
)
print(
f'Session tokens: {total_tokens} (prompt={total_prompt_tokens},'
f' candidates={total_candidate_tokens})'
)
start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
await run_prompt(session_11, 'Hi')
await run_prompt(session_11, 'Roll a die with 100 sides')
print(
await artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id_1, session_id=session_11.id
)
)
end_time = time.time()
print('------------------------------------')
print('End time:', end_time)
print('Total time:', end_time - start_time)
if __name__ == '__main__':
asyncio.run(main())


@@ -140,15 +140,15 @@ def message_to_generate_content_response(
role="model",
parts=[content_block_to_part(cb) for cb in message.content],
),
usage_metadata=types.GenerateContentResponseUsageMetadata(
prompt_token_count=message.usage.input_tokens,
candidates_token_count=message.usage.output_tokens,
total_token_count=(
message.usage.input_tokens + message.usage.output_tokens
),
),
# TODO: Deal with these later.
# finish_reason=to_google_genai_finish_reason(message.stop_reason),
# usage_metadata=types.GenerateContentResponseUsageMetadata(
# prompt_token_count=message.usage.input_tokens,
# candidates_token_count=message.usage.output_tokens,
# total_token_count=(
# message.usage.input_tokens + message.usage.output_tokens
# ),
# ),
)
@@ -196,6 +196,12 @@ def function_declaration_to_tool_param(
class Claude(BaseLlm):
""" "Integration with Claude models served from Vertex AI.
Attributes:
model: The name of the Claude model.
"""
model: str = "claude-3-5-sonnet-v2@20241022"
@staticmethod


@@ -121,6 +121,7 @@ class Gemini(BaseLlm):
content=types.ModelContent(
parts=[types.Part.from_text(text=text)],
),
usage_metadata=llm_response.usage_metadata,
)
text = ''
yield llm_response
@@ -174,7 +175,7 @@ class Gemini(BaseLlm):
@cached_property
def _live_api_client(self) -> Client:
if self._api_backend == 'vertex':
#use beta version for vertex api
# use beta version for vertex api
api_version = 'v1beta1'
# use default api version for vertex
return Client(


@@ -67,6 +67,12 @@ class TextChunk(BaseModel):
text: str
class UsageMetadataChunk(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class LiteLLMClient:
"""Provides acompletion method (for better testability)."""
@@ -344,15 +350,20 @@ def _function_declaration_to_tool_param(
def _model_response_to_chunk(
response: ModelResponse,
) -> Generator[
Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None
Tuple[
Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
Optional[str],
],
None,
None,
]:
"""Converts a litellm message to text or function chunk.
"""Converts a litellm message to text, function or usage metadata chunk.
Args:
response: The response from the model.
Yields:
A tuple of text or function chunk and finish reason.
A tuple of text or function or usage metadata chunk and finish reason.
"""
message = None
@@ -384,11 +395,21 @@ def _model_response_to_chunk(
if not message:
yield None, None
# Ideally usage would arrive with the last ModelResponseStream that has a
# finish_reason set, but that is not what we observe from litellm. So we send
# it as a separate chunk to be set on the llm_response.
if response.get("usage", None):
yield UsageMetadataChunk(
prompt_tokens=response["usage"].get("prompt_tokens", 0),
completion_tokens=response["usage"].get("completion_tokens", 0),
total_tokens=response["usage"].get("total_tokens", 0),
), None
def _model_response_to_generate_content_response(
response: ModelResponse,
) -> LlmResponse:
"""Converts a litellm response to LlmResponse.
"""Converts a litellm response to LlmResponse. Also adds usage metadata.
Args:
response: The model response.
@@ -403,7 +424,15 @@ def _model_response_to_generate_content_response(
if not message:
raise ValueError("No message in response")
return _message_to_generate_content_response(message)
llm_response = _message_to_generate_content_response(message)
if response.get("usage", None):
llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=response["usage"].get("prompt_tokens", 0),
candidates_token_count=response["usage"].get("completion_tokens", 0),
total_token_count=response["usage"].get("total_tokens", 0),
)
return llm_response
def _message_to_generate_content_response(
@@ -628,6 +657,10 @@ class LiteLlm(BaseLlm):
function_args = ""
function_id = None
completion_args["stream"] = True
aggregated_llm_response = None
aggregated_llm_response_with_tool_call = None
usage_metadata = None
for part in self.llm_client.completion(**completion_args):
for chunk, finish_reason in _model_response_to_chunk(part):
if isinstance(chunk, FunctionChunk):
@@ -645,8 +678,16 @@
),
is_partial=True,
)
elif isinstance(chunk, UsageMetadataChunk):
usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=chunk.prompt_tokens,
candidates_token_count=chunk.completion_tokens,
total_token_count=chunk.total_tokens,
)
if finish_reason == "tool_calls" and function_id:
yield _message_to_generate_content_response(
aggregated_llm_response_with_tool_call = (
_message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content="",
@@ -662,15 +703,30 @@
],
)
)
)
function_name = ""
function_args = ""
function_id = None
elif finish_reason == "stop" and text:
yield _message_to_generate_content_response(
aggregated_llm_response = _message_to_generate_content_response(
ChatCompletionAssistantMessage(role="assistant", content=text)
)
text = ""
# Wait until streaming ends before yielding the llm_response, since litellm
# tends to send the chunk containing usage_metadata after the chunk whose
# finish_reason is set to tool_calls or stop.
if aggregated_llm_response:
if usage_metadata:
aggregated_llm_response.usage_metadata = usage_metadata
usage_metadata = None
yield aggregated_llm_response
if aggregated_llm_response_with_tool_call:
if usage_metadata:
aggregated_llm_response_with_tool_call.usage_metadata = usage_metadata
yield aggregated_llm_response_with_tool_call
else:
response = await self.llm_client.acompletion(**completion_args)
yield _model_response_to_generate_content_response(response)
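
The two inline comments in this file's diff (usage arriving as a separate chunk, and the final llm_response being held back until streaming ends) amount to the buffering pattern sketched below. This is illustrative only, assuming the chunk ordering those comments describe; Chunk, Final, and buffer_final_response are stand-in names, not litellm or ADK types.

from dataclasses import dataclass
from typing import Iterable, Iterator, Optional


@dataclass
class Chunk:
  text: str = ''
  finish_reason: Optional[str] = None
  usage: Optional[dict] = None


@dataclass
class Final:
  text: str
  usage: Optional[dict] = None


def buffer_final_response(chunks: Iterable[Chunk]) -> Iterator[object]:
  """Streams partial text immediately but holds the final response until the
  stream ends, so a usage chunk that trails the finish_reason can be attached."""
  pending: Optional[Final] = None
  usage: Optional[dict] = None
  text = ''
  for chunk in chunks:
    if chunk.usage:
      usage = chunk.usage  # may arrive after the finish_reason chunk
    if chunk.text:
      text += chunk.text
      yield chunk.text  # partial output is passed through as-is
    if chunk.finish_reason == 'stop':
      pending = Final(text=text)  # buffer instead of yielding immediately
      text = ''
  if pending:
    pending.usage = usage  # attach whatever usage arrived, even if it came last
    yield pending

A stream like [Chunk(text='hi'), Chunk(finish_reason='stop'), Chunk(usage={'total_tokens': 15})] still ends with a Final that carries the usage.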


@@ -25,6 +25,7 @@ from google.adk.models.lite_llm import FunctionChunk
from google.adk.models.lite_llm import LiteLlm
from google.adk.models.lite_llm import LiteLLMClient
from google.adk.models.lite_llm import TextChunk
from google.adk.models.lite_llm import UsageMetadataChunk
from google.adk.models.llm_request import LlmRequest
from google.genai import types
from litellm import ChatCompletionAssistantMessage
@@ -315,14 +316,11 @@ litellm_append_user_content_test_cases = [
)
def test_maybe_append_user_content(lite_llm_instance, llm_request, expected_output):
lite_llm_instance._maybe_append_user_content(
llm_request
)
lite_llm_instance._maybe_append_user_content(llm_request)
assert len(llm_request.contents) == expected_output
function_declaration_test_cases = [
(
"simple_function",
@@ -567,6 +565,80 @@ async def test_generate_content_async_with_tool_response(
assert kwargs["messages"][2]["content"] == '{"result": "test_result"}'
@pytest.mark.asyncio
async def test_generate_content_async(mock_acompletion, lite_llm_instance):
async for response in lite_llm_instance.generate_content_async(
LLM_REQUEST_WITH_FUNCTION_DECLARATION
):
assert response.content.role == "model"
assert response.content.parts[0].text == "Test response"
assert response.content.parts[1].function_call.name == "test_function"
assert response.content.parts[1].function_call.args == {
"test_arg": "test_value"
}
assert response.content.parts[1].function_call.id == "test_tool_call_id"
mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "user"
assert kwargs["messages"][0]["content"] == "Test prompt"
assert kwargs["tools"][0]["function"]["name"] == "test_function"
assert (
kwargs["tools"][0]["function"]["description"]
== "Test function description"
)
assert (
kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
"type"
]
== "string"
)
@pytest.mark.asyncio
async def test_generate_content_async_with_usage_metadata(
lite_llm_instance, mock_acompletion
):
mock_response_with_usage_metadata = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content="Test response",
)
)
],
usage={
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
},
)
mock_acompletion.return_value = mock_response_with_usage_metadata
llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
),
],
config=types.GenerateContentConfig(
system_instruction="test instruction",
),
)
async for response in lite_llm_instance.generate_content_async(llm_request):
assert response.content.role == "model"
assert response.content.parts[0].text == "Test response"
assert response.usage_metadata.prompt_token_count == 10
assert response.usage_metadata.candidates_token_count == 5
assert response.usage_metadata.total_token_count == 15
mock_acompletion.assert_called_once()
def test_content_to_message_param_user_message():
content = types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
@@ -704,7 +776,7 @@ def test_to_litellm_role():
@pytest.mark.parametrize(
"response, expected_chunk, expected_finished",
"response, expected_chunks, expected_finished",
[
(
ModelResponse(
@@ -716,7 +788,35 @@ def test_to_litellm_role():
}
]
),
[
TextChunk(text="this is a test"),
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
],
"stop",
),
(
ModelResponse(
choices=[
{
"message": {
"content": "this is a test",
}
}
],
usage={
"prompt_tokens": 3,
"completion_tokens": 5,
"total_tokens": 8,
},
),
[
TextChunk(text="this is a test"),
UsageMetadataChunk(
prompt_tokens=3, completion_tokens=5, total_tokens=8
),
],
"stop",
),
(
@@ -741,28 +841,53 @@
)
]
),
[
FunctionChunk(id="1", name="test_function", args='{"key": "va'),
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
],
None,
),
(
ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
[
None,
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
],
"tool_calls",
),
(ModelResponse(choices=[{}]), None, "stop"),
(
ModelResponse(choices=[{}]),
[
None,
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
],
"stop",
),
],
)
def test_model_response_to_chunk(response, expected_chunk, expected_finished):
def test_model_response_to_chunk(response, expected_chunks, expected_finished):
result = list(_model_response_to_chunk(response))
assert len(result) == 1
assert len(result) == 2
chunk, finished = result[0]
if expected_chunk:
assert isinstance(chunk, type(expected_chunk))
assert chunk == expected_chunk
if expected_chunks:
assert isinstance(chunk, type(expected_chunks[0]))
assert chunk == expected_chunks[0]
else:
assert chunk is None
assert finished == expected_finished
usage_chunk, _ = result[1]
assert usage_chunk is not None
assert usage_chunk.prompt_tokens == expected_chunks[1].prompt_tokens
assert usage_chunk.completion_tokens == expected_chunks[1].completion_tokens
assert usage_chunk.total_tokens == expected_chunks[1].total_tokens
@pytest.mark.asyncio
async def test_acompletion_additional_args(mock_acompletion, mock_client):
@@ -893,3 +1018,71 @@ async def test_generate_content_async_stream(
]
== "string"
)
@pytest.mark.asyncio
async def test_generate_content_async_stream_with_usage_metadata(
mock_completion, lite_llm_instance
):
streaming_model_response_with_usage_metadata = [
*STREAMING_MODEL_RESPONSE,
ModelResponse(
usage={
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
},
choices=[
StreamingChoices(
finish_reason=None,
)
],
),
]
mock_completion.return_value = iter(
streaming_model_response_with_usage_metadata
)
responses = [
response
async for response in lite_llm_instance.generate_content_async(
LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
)
]
assert len(responses) == 4
assert responses[0].content.role == "model"
assert responses[0].content.parts[0].text == "zero, "
assert responses[1].content.role == "model"
assert responses[1].content.parts[0].text == "one, "
assert responses[2].content.role == "model"
assert responses[2].content.parts[0].text == "two:"
assert responses[3].content.role == "model"
assert responses[3].content.parts[0].function_call.name == "test_function"
assert responses[3].content.parts[0].function_call.args == {
"test_arg": "test_value"
}
assert responses[3].content.parts[0].function_call.id == "test_tool_call_id"
assert responses[3].usage_metadata.prompt_token_count == 10
assert responses[3].usage_metadata.candidates_token_count == 5
assert responses[3].usage_metadata.total_token_count == 15
mock_completion.assert_called_once()
_, kwargs = mock_completion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "user"
assert kwargs["messages"][0]["content"] == "Test prompt"
assert kwargs["tools"][0]["function"]["name"] == "test_function"
assert (
kwargs["tools"][0]["function"]["description"]
== "Test function description"
)
assert (
kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
"type"
]
== "string"
)