Mirror of https://github.com/EvolutionAPI/adk-python.git, synced 2025-12-19 03:42:22 -06:00
Add token usage to gemini (streaming), litellm and anthropic

Also included is a token_usage sample that showcases the token usage of subagents with different models under a parent agent.

PiperOrigin-RevId: 759347015
committed by Copybara-Service
parent 4d5760917d
commit 509db3f9fb
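
For context, here is a minimal sketch of how a caller might observe the new usage fields on a LiteLlm response, assuming a configured litellm provider. The model id is a placeholder; the imports and attributes are the ones exercised by the tests in the diff below.

import asyncio

from google.adk.models.lite_llm import LiteLlm
from google.adk.models.llm_request import LlmRequest
from google.genai import types


async def main():
  model = LiteLlm(model="openai/gpt-4o")  # placeholder litellm model id
  llm_request = LlmRequest(
      contents=[
          types.Content(
              role="user", parts=[types.Part.from_text(text="Hello!")]
          )
      ],
      config=types.GenerateContentConfig(),
  )
  # Non-streaming: each yielded response may now carry usage_metadata.
  async for response in model.generate_content_async(llm_request):
    if response.usage_metadata:
      print(
          response.usage_metadata.prompt_token_count,
          response.usage_metadata.candidates_token_count,
          response.usage_metadata.total_token_count,
      )


if __name__ == "__main__":
  asyncio.run(main())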
@@ -25,6 +25,7 @@ from google.adk.models.lite_llm import FunctionChunk
 from google.adk.models.lite_llm import LiteLlm
 from google.adk.models.lite_llm import LiteLLMClient
 from google.adk.models.lite_llm import TextChunk
+from google.adk.models.lite_llm import UsageMetadataChunk
 from google.adk.models.llm_request import LlmRequest
 from google.genai import types
 from litellm import ChatCompletionAssistantMessage
@@ -314,13 +315,10 @@ litellm_append_user_content_test_cases = [
     litellm_append_user_content_test_cases
 )
 def test_maybe_append_user_content(lite_llm_instance, llm_request, expected_output):
-
-  lite_llm_instance._maybe_append_user_content(
-      llm_request
-  )
-
-  assert len(llm_request.contents) == expected_output
+  lite_llm_instance._maybe_append_user_content(llm_request)
+
+  assert len(llm_request.contents) == expected_output


 function_declaration_test_cases = [
@@ -567,6 +565,80 @@ async def test_generate_content_async_with_tool_response(
   assert kwargs["messages"][2]["content"] == '{"result": "test_result"}'


+@pytest.mark.asyncio
+async def test_generate_content_async(mock_acompletion, lite_llm_instance):
+
+  async for response in lite_llm_instance.generate_content_async(
+      LLM_REQUEST_WITH_FUNCTION_DECLARATION
+  ):
+    assert response.content.role == "model"
+    assert response.content.parts[0].text == "Test response"
+    assert response.content.parts[1].function_call.name == "test_function"
+    assert response.content.parts[1].function_call.args == {
+        "test_arg": "test_value"
+    }
+    assert response.content.parts[1].function_call.id == "test_tool_call_id"
+
+  mock_acompletion.assert_called_once()
+
+  _, kwargs = mock_acompletion.call_args
+  assert kwargs["model"] == "test_model"
+  assert kwargs["messages"][0]["role"] == "user"
+  assert kwargs["messages"][0]["content"] == "Test prompt"
+  assert kwargs["tools"][0]["function"]["name"] == "test_function"
+  assert (
+      kwargs["tools"][0]["function"]["description"]
+      == "Test function description"
+  )
+  assert (
+      kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
+          "type"
+      ]
+      == "string"
+  )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_with_usage_metadata(
+    lite_llm_instance, mock_acompletion
+):
+  mock_response_with_usage_metadata = ModelResponse(
+      choices=[
+          Choices(
+              message=ChatCompletionAssistantMessage(
+                  role="assistant",
+                  content="Test response",
+              )
+          )
+      ],
+      usage={
+          "prompt_tokens": 10,
+          "completion_tokens": 5,
+          "total_tokens": 15,
+      },
+  )
+  mock_acompletion.return_value = mock_response_with_usage_metadata
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="Test prompt")]
+          ),
+      ],
+      config=types.GenerateContentConfig(
+          system_instruction="test instruction",
+      ),
+  )
+  async for response in lite_llm_instance.generate_content_async(llm_request):
+    assert response.content.role == "model"
+    assert response.content.parts[0].text == "Test response"
+    assert response.usage_metadata.prompt_token_count == 10
+    assert response.usage_metadata.candidates_token_count == 5
+    assert response.usage_metadata.total_token_count == 15
+
+  mock_acompletion.assert_called_once()
+
+
 def test_content_to_message_param_user_message():
   content = types.Content(
       role="user", parts=[types.Part.from_text(text="Test prompt")]
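
The assertions above imply a straightforward mapping from litellm's usage dict onto the google.genai usage metadata. A hedged sketch of that mapping (not the library's actual converter; usage_from_litellm is an invented name for illustration):

from google.genai import types


def usage_from_litellm(
    usage: dict,
) -> types.GenerateContentResponseUsageMetadata:
  # prompt_tokens -> prompt_token_count, completion_tokens ->
  # candidates_token_count, total_tokens -> total_token_count, matching
  # test_generate_content_async_with_usage_metadata above.
  return types.GenerateContentResponseUsageMetadata(
      prompt_token_count=usage.get("prompt_tokens", 0),
      candidates_token_count=usage.get("completion_tokens", 0),
      total_token_count=usage.get("total_tokens", 0),
  )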
@@ -704,7 +776,7 @@ def test_to_litellm_role():


 @pytest.mark.parametrize(
-    "response, expected_chunk, expected_finished",
+    "response, expected_chunks, expected_finished",
     [
         (
             ModelResponse(
@@ -716,7 +788,35 @@ def test_to_litellm_role():
                     }
                 ]
             ),
-            TextChunk(text="this is a test"),
+            [
+                TextChunk(text="this is a test"),
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
             "stop",
         ),
+        (
+            ModelResponse(
+                choices=[
+                    {
+                        "message": {
+                            "content": "this is a test",
+                        }
+                    }
+                ],
+                usage={
+                    "prompt_tokens": 3,
+                    "completion_tokens": 5,
+                    "total_tokens": 8,
+                },
+            ),
+            [
+                TextChunk(text="this is a test"),
+                UsageMetadataChunk(
+                    prompt_tokens=3, completion_tokens=5, total_tokens=8
+                ),
+            ],
+            "stop",
+        ),
         (
@@ -741,28 +841,53 @@ def test_to_litellm_role():
                     )
                 ]
             ),
-            FunctionChunk(id="1", name="test_function", args='{"key": "va'),
+            [
+                FunctionChunk(id="1", name="test_function", args='{"key": "va'),
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
             None,
         ),
         (
             ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
-            None,
+            [
+                None,
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
             "tool_calls",
         ),
-        (ModelResponse(choices=[{}]), None, "stop"),
+        (
+            ModelResponse(choices=[{}]),
+            [
+                None,
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
+            "stop",
+        ),
     ],
 )
-def test_model_response_to_chunk(response, expected_chunk, expected_finished):
+def test_model_response_to_chunk(response, expected_chunks, expected_finished):
   result = list(_model_response_to_chunk(response))
-  assert len(result) == 1
+  assert len(result) == 2
   chunk, finished = result[0]
-  if expected_chunk:
-    assert isinstance(chunk, type(expected_chunk))
-    assert chunk == expected_chunk
+  if expected_chunks:
+    assert isinstance(chunk, type(expected_chunks[0]))
+    assert chunk == expected_chunks[0]
   else:
     assert chunk is None
   assert finished == expected_finished
+
+  usage_chunk, _ = result[1]
+  assert usage_chunk is not None
+  assert usage_chunk.prompt_tokens == expected_chunks[1].prompt_tokens
+  assert usage_chunk.completion_tokens == expected_chunks[1].completion_tokens
+  assert usage_chunk.total_tokens == expected_chunks[1].total_tokens


 @pytest.mark.asyncio
 async def test_acompletion_additional_args(mock_acompletion, mock_client):
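
The updated test above pins down the new contract of _model_response_to_chunk: every ModelResponse yields its content chunk (or None) with a finish reason, followed by a trailing UsageMetadataChunk that is zero-filled when the response carries no usage. A hedged consumer sketch of that contract (the helper is private, so treat this as illustration rather than supported API; summarize_model_response is an invented name):

from google.adk.models.lite_llm import _model_response_to_chunk
from google.adk.models.lite_llm import FunctionChunk
from google.adk.models.lite_llm import TextChunk
from google.adk.models.lite_llm import UsageMetadataChunk


def summarize_model_response(response):
  """Collects content chunks and the trailing usage chunk of one response."""
  parts = []
  usage = None
  for chunk, _finish_reason in _model_response_to_chunk(response):
    if isinstance(chunk, TextChunk):
      parts.append(chunk.text)
    elif isinstance(chunk, FunctionChunk):
      # Function-call fragments; args may be partial JSON mid-stream.
      parts.append(f"<call {chunk.name}: {chunk.args}>")
    elif isinstance(chunk, UsageMetadataChunk):
      usage = (chunk.prompt_tokens, chunk.completion_tokens, chunk.total_tokens)
  return "".join(parts), usage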
@@ -893,3 +1018,71 @@ async def test_generate_content_async_stream(
       ]
       == "string"
   )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_with_usage_metadata(
+    mock_completion, lite_llm_instance
+):
+
+  streaming_model_response_with_usage_metadata = [
+      *STREAMING_MODEL_RESPONSE,
+      ModelResponse(
+          usage={
+              "prompt_tokens": 10,
+              "completion_tokens": 5,
+              "total_tokens": 15,
+          },
+          choices=[
+              StreamingChoices(
+                  finish_reason=None,
+              )
+          ],
+      ),
+  ]
+
+  mock_completion.return_value = iter(
+      streaming_model_response_with_usage_metadata
+  )
+
+  responses = [
+      response
+      async for response in lite_llm_instance.generate_content_async(
+          LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
+      )
+  ]
+  assert len(responses) == 4
+  assert responses[0].content.role == "model"
+  assert responses[0].content.parts[0].text == "zero, "
+  assert responses[1].content.role == "model"
+  assert responses[1].content.parts[0].text == "one, "
+  assert responses[2].content.role == "model"
+  assert responses[2].content.parts[0].text == "two:"
+  assert responses[3].content.role == "model"
+  assert responses[3].content.parts[0].function_call.name == "test_function"
+  assert responses[3].content.parts[0].function_call.args == {
+      "test_arg": "test_value"
+  }
+  assert responses[3].content.parts[0].function_call.id == "test_tool_call_id"
+
+  assert responses[3].usage_metadata.prompt_token_count == 10
+  assert responses[3].usage_metadata.candidates_token_count == 5
+  assert responses[3].usage_metadata.total_token_count == 15
+
+  mock_completion.assert_called_once()
+
+  _, kwargs = mock_completion.call_args
+  assert kwargs["model"] == "test_model"
+  assert kwargs["messages"][0]["role"] == "user"
+  assert kwargs["messages"][0]["content"] == "Test prompt"
+  assert kwargs["tools"][0]["function"]["name"] == "test_function"
+  assert (
+      kwargs["tools"][0]["function"]["description"]
+      == "Test function description"
+  )
+  assert (
+      kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
+          "type"
+      ]
+      == "string"
+  )
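
The token_usage sample referenced in the commit message is not part of this diff. A hedged sketch of the shape such a sample likely takes, with subagents on different models under one parent so per-model usage can be compared; all names, model ids, and instructions here are illustrative assumptions, not the sample's actual code:

from google.adk.agents import LlmAgent
from google.adk.models.lite_llm import LiteLlm

gemini_agent = LlmAgent(
    name="gemini_helper",
    model="gemini-2.0-flash",  # assumed Gemini model id
    instruction="Handle short factual questions.",
)
claude_agent = LlmAgent(
    name="claude_helper",
    # Anthropic routed through litellm; model id is a placeholder.
    model=LiteLlm(model="anthropic/claude-3-5-sonnet-20241022"),
    instruction="Handle long-form writing.",
)
root_agent = LlmAgent(
    name="token_usage_demo",
    model="gemini-2.0-flash",  # assumed Gemini model id
    instruction="Route each request to the most suitable subagent.",
    sub_agents=[gemini_agent, claude_agent],
)

With the changes in this commit, each model response carries usage_metadata (prompt_token_count, candidates_token_count, total_token_count), so a runner loop can sum tokens per subagent whether the backend is Gemini (streaming), litellm, or anthropic.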