Add token usage to Gemini (streaming), LiteLLM, and Anthropic

Also includes a token_usage sample that showcases the token usage of sub-agents running different models under a parent agent.

PiperOrigin-RevId: 759347015
Selcuk Gun authored on 2025-05-15 16:22:04 -07:00; committed by Copybara-Service
parent 4d5760917d
commit 509db3f9fb
8 changed files with 515 additions and 45 deletions
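
For context on the sample mentioned above: every model response carries a usage_metadata field with prompt, candidate, and total token counts, and the sample tallies those counts across a parent agent and its sub-agents on different models. A minimal sketch of that kind of per-model accounting, assuming only the google-genai usage type asserted in the tests below (the function and model names here are illustrative, not the shipped sample):

    from collections import defaultdict

    from google.genai import types


    def tally_usage(responses):
      """Sums token counts per model from (model_name, usage_metadata) pairs."""
      totals = defaultdict(lambda: {"prompt": 0, "candidates": 0, "total": 0})
      for model_name, usage in responses:
        if usage is None:  # Responses without usage metadata are skipped.
          continue
        totals[model_name]["prompt"] += usage.prompt_token_count or 0
        totals[model_name]["candidates"] += usage.candidates_token_count or 0
        totals[model_name]["total"] += usage.total_token_count or 0
      return dict(totals)


    # Example: a parent on one model plus a sub-agent on another, as in the sample.
    gemini_usage = types.GenerateContentResponseUsageMetadata(
        prompt_token_count=10, candidates_token_count=5, total_token_count=15
    )
    print(tally_usage([("gemini", gemini_usage), ("claude", None)]))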


@@ -25,6 +25,7 @@ from google.adk.models.lite_llm import FunctionChunk
 from google.adk.models.lite_llm import LiteLlm
 from google.adk.models.lite_llm import LiteLLMClient
 from google.adk.models.lite_llm import TextChunk
+from google.adk.models.lite_llm import UsageMetadataChunk
 from google.adk.models.llm_request import LlmRequest
 from google.genai import types
 from litellm import ChatCompletionAssistantMessage
@@ -314,13 +315,10 @@ litellm_append_user_content_test_cases = [
     litellm_append_user_content_test_cases
 )
 def test_maybe_append_user_content(lite_llm_instance, llm_request, expected_output):
-  lite_llm_instance._maybe_append_user_content(
-      llm_request
-  )
-  assert len(llm_request.contents) == expected_output
+  lite_llm_instance._maybe_append_user_content(llm_request)
+  assert len(llm_request.contents) == expected_output
 
 
 function_declaration_test_cases = [
@@ -567,6 +565,80 @@ async def test_generate_content_async_with_tool_response(
assert kwargs["messages"][2]["content"] == '{"result": "test_result"}'
@pytest.mark.asyncio
async def test_generate_content_async(mock_acompletion, lite_llm_instance):
async for response in lite_llm_instance.generate_content_async(
LLM_REQUEST_WITH_FUNCTION_DECLARATION
):
assert response.content.role == "model"
assert response.content.parts[0].text == "Test response"
assert response.content.parts[1].function_call.name == "test_function"
assert response.content.parts[1].function_call.args == {
"test_arg": "test_value"
}
assert response.content.parts[1].function_call.id == "test_tool_call_id"
mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "user"
assert kwargs["messages"][0]["content"] == "Test prompt"
assert kwargs["tools"][0]["function"]["name"] == "test_function"
assert (
kwargs["tools"][0]["function"]["description"]
== "Test function description"
)
assert (
kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
"type"
]
== "string"
)
@pytest.mark.asyncio
async def test_generate_content_async_with_usage_metadata(
lite_llm_instance, mock_acompletion
):
mock_response_with_usage_metadata = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content="Test response",
)
)
],
usage={
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
},
)
mock_acompletion.return_value = mock_response_with_usage_metadata
llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
),
],
config=types.GenerateContentConfig(
system_instruction="test instruction",
),
)
async for response in lite_llm_instance.generate_content_async(llm_request):
assert response.content.role == "model"
assert response.content.parts[0].text == "Test response"
assert response.usage_metadata.prompt_token_count == 10
assert response.usage_metadata.candidates_token_count == 5
assert response.usage_metadata.total_token_count == 15
mock_acompletion.assert_called_once()
def test_content_to_message_param_user_message():
content = types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
@@ -704,7 +776,7 @@ def test_to_litellm_role():
 @pytest.mark.parametrize(
-    "response, expected_chunk, expected_finished",
+    "response, expected_chunks, expected_finished",
     [
         (
             ModelResponse(
@@ -716,7 +788,35 @@ def test_to_litellm_role():
                     }
                 ]
             ),
-            TextChunk(text="this is a test"),
+            [
+                TextChunk(text="this is a test"),
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
             "stop",
         ),
+        (
+            ModelResponse(
+                choices=[
+                    {
+                        "message": {
+                            "content": "this is a test",
+                        }
+                    }
+                ],
+                usage={
+                    "prompt_tokens": 3,
+                    "completion_tokens": 5,
+                    "total_tokens": 8,
+                },
+            ),
+            [
+                TextChunk(text="this is a test"),
+                UsageMetadataChunk(
+                    prompt_tokens=3, completion_tokens=5, total_tokens=8
+                ),
+            ],
+            "stop",
+        ),
         (
@@ -741,28 +841,53 @@ def test_to_litellm_role():
                     )
                 ]
             ),
-            FunctionChunk(id="1", name="test_function", args='{"key": "va'),
+            [
+                FunctionChunk(id="1", name="test_function", args='{"key": "va'),
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
             None,
         ),
         (
             ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
-            None,
+            [
+                None,
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
             "tool_calls",
         ),
-        (ModelResponse(choices=[{}]), None, "stop"),
+        (
+            ModelResponse(choices=[{}]),
+            [
+                None,
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
+            "stop",
+        ),
     ],
 )
-def test_model_response_to_chunk(response, expected_chunk, expected_finished):
+def test_model_response_to_chunk(response, expected_chunks, expected_finished):
   result = list(_model_response_to_chunk(response))
-  assert len(result) == 1
+  assert len(result) == 2
   chunk, finished = result[0]
-  if expected_chunk:
-    assert isinstance(chunk, type(expected_chunk))
-    assert chunk == expected_chunk
+  if expected_chunks:
+    assert isinstance(chunk, type(expected_chunks[0]))
+    assert chunk == expected_chunks[0]
   else:
     assert chunk is None
   assert finished == expected_finished
+
+  usage_chunk, _ = result[1]
+  assert usage_chunk is not None
+  assert usage_chunk.prompt_tokens == expected_chunks[1].prompt_tokens
+  assert usage_chunk.completion_tokens == expected_chunks[1].completion_tokens
+  assert usage_chunk.total_tokens == expected_chunks[1].total_tokens
 
 
 @pytest.mark.asyncio
 async def test_acompletion_additional_args(mock_acompletion, mock_client):
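
Note: as the updated parametrization shows, _model_response_to_chunk now always yields two (chunk, finish_reason) pairs — the content chunk first, then a trailing UsageMetadataChunk that is zeroed when the response carries no usage counters (the test ignores the finish reason on the trailing pair). In spirit, the generator's tail behaves like this sketch (illustrative only, not the actual implementation; a plain usage dict is assumed to keep it self-contained):

    from google.adk.models.lite_llm import UsageMetadataChunk


    def yield_trailing_usage_chunk(response):
      """Yields the trailing usage pair; counters default to zero when absent."""
      usage = getattr(response, "usage", None) or {}
      yield (
          UsageMetadataChunk(
              prompt_tokens=usage.get("prompt_tokens", 0),
              completion_tokens=usage.get("completion_tokens", 0),
              total_tokens=usage.get("total_tokens", 0),
          ),
          None,
      )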
@@ -893,3 +1018,71 @@ async def test_generate_content_async_stream(
       ]
       == "string"
   )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_with_usage_metadata(
+    mock_completion, lite_llm_instance
+):
+  streaming_model_response_with_usage_metadata = [
+      *STREAMING_MODEL_RESPONSE,
+      ModelResponse(
+          usage={
+              "prompt_tokens": 10,
+              "completion_tokens": 5,
+              "total_tokens": 15,
+          },
+          choices=[
+              StreamingChoices(
+                  finish_reason=None,
+              )
+          ],
+      ),
+  ]
+
+  mock_completion.return_value = iter(
+      streaming_model_response_with_usage_metadata
+  )
+
+  responses = [
+      response
+      async for response in lite_llm_instance.generate_content_async(
+          LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
+      )
+  ]
+  assert len(responses) == 4
+  assert responses[0].content.role == "model"
+  assert responses[0].content.parts[0].text == "zero, "
+  assert responses[1].content.role == "model"
+  assert responses[1].content.parts[0].text == "one, "
+  assert responses[2].content.role == "model"
+  assert responses[2].content.parts[0].text == "two:"
+  assert responses[3].content.role == "model"
+  assert responses[3].content.parts[0].function_call.name == "test_function"
+  assert responses[3].content.parts[0].function_call.args == {
+      "test_arg": "test_value"
+  }
+  assert responses[3].content.parts[0].function_call.id == "test_tool_call_id"
+  assert responses[3].usage_metadata.prompt_token_count == 10
+  assert responses[3].usage_metadata.candidates_token_count == 5
+  assert responses[3].usage_metadata.total_token_count == 15
+
+  mock_completion.assert_called_once()
+  _, kwargs = mock_completion.call_args
+  assert kwargs["model"] == "test_model"
+  assert kwargs["messages"][0]["role"] == "user"
+  assert kwargs["messages"][0]["content"] == "Test prompt"
+  assert kwargs["tools"][0]["function"]["name"] == "test_function"
+  assert (
+      kwargs["tools"][0]["function"]["description"]
+      == "Test function description"
+  )
+  assert (
+      kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
+          "type"
+      ]
+      == "string"
+  )
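
Note: in the streaming case the usage counters arrive on a trailing chunk of the stream, and the test asserts that they surface on the final aggregated response (responses[3]) rather than on each partial. The accumulation pattern, sketched over plain dicts rather than litellm's ModelResponse objects:

    def aggregate_stream(chunks):
      """Joins streamed text deltas and keeps the last non-empty usage block."""
      text_parts, usage = [], {}
      for chunk in chunks:
        for choice in chunk.get("choices", []):
          delta = choice.get("delta") or {}
          if delta.get("content"):
            text_parts.append(delta["content"])
        if chunk.get("usage"):  # Usage typically rides on the final chunk.
          usage = chunk["usage"]
      return "".join(text_parts), usage


    text, usage = aggregate_stream([
        {"choices": [{"delta": {"content": "zero, "}}]},
        {"choices": [{"delta": {"content": "one"}}]},
        {"choices": [], "usage": {"prompt_tokens": 10, "total_tokens": 15}},
    ])
    assert text == "zero, one" and usage["total_tokens"] == 15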