From 509db3f9fb3c710dac5bcb72284506acb85687bb Mon Sep 17 00:00:00 2001 From: Selcuk Gun Date: Thu, 15 May 2025 16:22:04 -0700 Subject: [PATCH] Add token usage to gemini (streaming), litellm and anthropic Also included a token_usage sample that showcases the token usage of subagents with different models under a parent agent. PiperOrigin-RevId: 759347015 --- contributing/samples/token_usage/__init__.py | 15 ++ contributing/samples/token_usage/agent.py | 97 ++++++++ .../samples/token_usage/asyncio_run.py | 102 ++++++++ .../adk/flows/llm_flows/base_llm_flow.py | 2 +- src/google/adk/models/anthropic_llm.py | 20 +- src/google/adk/models/google_llm.py | 3 +- src/google/adk/models/lite_llm.py | 96 ++++++-- tests/unittests/models/test_litellm.py | 225 ++++++++++++++++-- 8 files changed, 515 insertions(+), 45 deletions(-) create mode 100755 contributing/samples/token_usage/__init__.py create mode 100755 contributing/samples/token_usage/agent.py create mode 100755 contributing/samples/token_usage/asyncio_run.py diff --git a/contributing/samples/token_usage/__init__.py b/contributing/samples/token_usage/__init__.py new file mode 100755 index 0000000..c48963c --- /dev/null +++ b/contributing/samples/token_usage/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/token_usage/agent.py b/contributing/samples/token_usage/agent.py new file mode 100755 index 0000000..65990ce --- /dev/null +++ b/contributing/samples/token_usage/agent.py @@ -0,0 +1,97 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +from google.adk import Agent +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.models.anthropic_llm import Claude +from google.adk.models.lite_llm import LiteLlm +from google.adk.planners import BuiltInPlanner +from google.adk.planners import PlanReActPlanner +from google.adk.tools.tool_context import ToolContext +from google.genai import types + + +def roll_die(sides: int, tool_context: ToolContext) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. 
+ """ + result = random.randint(1, sides) + if 'rolls' not in tool_context.state: + tool_context.state['rolls'] = [] + + tool_context.state['rolls'] = tool_context.state['rolls'] + [result] + return result + + +roll_agent_with_openai = LlmAgent( + model=LiteLlm(model='openai/gpt-4o'), + description='Handles rolling dice of different sizes.', + name='roll_agent_with_openai', + instruction=""" + You are responsible for rolling dice based on the user's request. + When asked to roll a die, you must call the roll_die tool with the number of sides as an integer. + """, + tools=[roll_die], +) + +roll_agent_with_claude = LlmAgent( + model=Claude(model='claude-3-7-sonnet@20250219'), + description='Handles rolling dice of different sizes.', + name='roll_agent_with_claude', + instruction=""" + You are responsible for rolling dice based on the user's request. + When asked to roll a die, you must call the roll_die tool with the number of sides as an integer. + """, + tools=[roll_die], +) + +roll_agent_with_litellm_claude = LlmAgent( + model=LiteLlm(model='vertex_ai/claude-3-7-sonnet'), + description='Handles rolling dice of different sizes.', + name='roll_agent_with_litellm_claude', + instruction=""" + You are responsible for rolling dice based on the user's request. + When asked to roll a die, you must call the roll_die tool with the number of sides as an integer. + """, + tools=[roll_die], +) + +roll_agent_with_gemini = LlmAgent( + model='gemini-2.0-flash', + description='Handles rolling dice of different sizes.', + name='roll_agent_with_gemini', + instruction=""" + You are responsible for rolling dice based on the user's request. + When asked to roll a die, you must call the roll_die tool with the number of sides as an integer. + """, + tools=[roll_die], +) + +root_agent = SequentialAgent( + name='code_pipeline_agent', + sub_agents=[ + roll_agent_with_openai, + roll_agent_with_claude, + roll_agent_with_litellm_claude, + roll_agent_with_gemini, + ], +) diff --git a/contributing/samples/token_usage/asyncio_run.py b/contributing/samples/token_usage/asyncio_run.py new file mode 100755 index 0000000..f169756 --- /dev/null +++ b/contributing/samples/token_usage/asyncio_run.py @@ -0,0 +1,102 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import time +import warnings + +import agent +from dotenv import load_dotenv +from google.adk import Runner +from google.adk.agents.run_config import RunConfig +from google.adk.artifacts import InMemoryArtifactService +from google.adk.cli.utils import logs +from google.adk.sessions import InMemorySessionService +from google.adk.sessions import Session +from google.genai import types + +load_dotenv(override=True) +warnings.filterwarnings('ignore', category=UserWarning) +logs.log_to_tmp_folder() + + +async def main(): + app_name = 'my_app' + user_id_1 = 'user1' + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name=app_name, + agent=agent.root_agent, + artifact_service=artifact_service, + session_service=session_service, + ) + session_11 = session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + total_prompt_tokens = 0 + total_candidate_tokens = 0 + total_tokens = 0 + + async def run_prompt(session: Session, new_message: str): + nonlocal total_prompt_tokens + nonlocal total_candidate_tokens + nonlocal total_tokens + content = types.Content( + role='user', parts=[types.Part.from_text(text=new_message)] + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if event.content.parts and event.content.parts[0].text: + print(f'** {event.author}: {event.content.parts[0].text}') + if event.usage_metadata: + total_prompt_tokens += event.usage_metadata.prompt_token_count or 0 + total_candidate_tokens += ( + event.usage_metadata.candidates_token_count or 0 + ) + total_tokens += event.usage_metadata.total_token_count or 0 + print( + 'Turn tokens:' + f' {event.usage_metadata.total_token_count} (prompt={event.usage_metadata.prompt_token_count},' + f' candidates={event.usage_metadata.candidates_token_count})' + ) + + print( + f'Session tokens: {total_tokens} (prompt={total_prompt_tokens},' + f' candidates={total_candidate_tokens})' + ) + + start_time = time.time() + print('Start time:', start_time) + print('------------------------------------') + await run_prompt(session_11, 'Hi') + await run_prompt(session_11, 'Roll a die with 100 sides') + print( + await artifact_service.list_artifact_keys( + app_name=app_name, user_id=user_id_1, session_id=session_11.id + ) + ) + end_time = time.time() + print('------------------------------------') + print('End time:', end_time) + print('Total time:', end_time - start_time) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index b6b45fc..bece41f 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -218,7 +218,7 @@ class BaseLlmFlow(ABC): When the model returns transcription, the author is "user". Otherwise, the author is the agent name(not 'model'). - + Args: llm_response: The LLM response from the LLM call. 
""" diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index 1544edd..0fc7502 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -140,15 +140,15 @@ def message_to_generate_content_response( role="model", parts=[content_block_to_part(cb) for cb in message.content], ), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=message.usage.input_tokens, + candidates_token_count=message.usage.output_tokens, + total_token_count=( + message.usage.input_tokens + message.usage.output_tokens + ), + ), # TODO: Deal with these later. # finish_reason=to_google_genai_finish_reason(message.stop_reason), - # usage_metadata=types.GenerateContentResponseUsageMetadata( - # prompt_token_count=message.usage.input_tokens, - # candidates_token_count=message.usage.output_tokens, - # total_token_count=( - # message.usage.input_tokens + message.usage.output_tokens - # ), - # ), ) @@ -196,6 +196,12 @@ def function_declaration_to_tool_param( class Claude(BaseLlm): + """ "Integration with Claude models served from Vertex AI. + + Attributes: + model: The name of the Claude model. + """ + model: str = "claude-3-5-sonnet-v2@20241022" @staticmethod diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index e4f21e5..5b80602 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -121,6 +121,7 @@ class Gemini(BaseLlm): content=types.ModelContent( parts=[types.Part.from_text(text=text)], ), + usage_metadata=llm_response.usage_metadata, ) text = '' yield llm_response @@ -174,7 +175,7 @@ class Gemini(BaseLlm): @cached_property def _live_api_client(self) -> Client: if self._api_backend == 'vertex': - #use beta version for vertex api + # use beta version for vertex api api_version = 'v1beta1' # use default api version for vertex return Client( diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 63abd78..27c856d 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -67,6 +67,12 @@ class TextChunk(BaseModel): text: str +class UsageMetadataChunk(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + class LiteLLMClient: """Provides acompletion method (for better testability).""" @@ -344,15 +350,20 @@ def _function_declaration_to_tool_param( def _model_response_to_chunk( response: ModelResponse, ) -> Generator[ - Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None + Tuple[ + Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]], + Optional[str], + ], + None, + None, ]: - """Converts a litellm message to text or function chunk. + """Converts a litellm message to text, function or usage metadata chunk. Args: response: The response from the model. Yields: - A tuple of text or function chunk and finish reason. + A tuple of text or function or usage metadata chunk and finish reason. """ message = None @@ -384,11 +395,21 @@ def _model_response_to_chunk( if not message: yield None, None + # Ideally usage would be expected with the last ModelResponseStream with a + # finish_reason set. But this is not the case we are observing from litellm. + # So we are sending it as a separate chunk to be set on the llm_response. 
+ if response.get("usage", None): + yield UsageMetadataChunk( + prompt_tokens=response["usage"].get("prompt_tokens", 0), + completion_tokens=response["usage"].get("completion_tokens", 0), + total_tokens=response["usage"].get("total_tokens", 0), + ), None + def _model_response_to_generate_content_response( response: ModelResponse, ) -> LlmResponse: - """Converts a litellm response to LlmResponse. + """Converts a litellm response to LlmResponse. Also adds usage metadata. Args: response: The model response. @@ -403,7 +424,15 @@ def _model_response_to_generate_content_response( if not message: raise ValueError("No message in response") - return _message_to_generate_content_response(message) + + llm_response = _message_to_generate_content_response(message) + if response.get("usage", None): + llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=response["usage"].get("prompt_tokens", 0), + candidates_token_count=response["usage"].get("completion_tokens", 0), + total_token_count=response["usage"].get("total_tokens", 0), + ) + return llm_response def _message_to_generate_content_response( @@ -628,6 +657,10 @@ class LiteLlm(BaseLlm): function_args = "" function_id = None completion_args["stream"] = True + aggregated_llm_response = None + aggregated_llm_response_with_tool_call = None + usage_metadata = None + for part in self.llm_client.completion(**completion_args): for chunk, finish_reason in _model_response_to_chunk(part): if isinstance(chunk, FunctionChunk): @@ -645,32 +678,55 @@ class LiteLlm(BaseLlm): ), is_partial=True, ) + elif isinstance(chunk, UsageMetadataChunk): + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=chunk.prompt_tokens, + candidates_token_count=chunk.completion_tokens, + total_token_count=chunk.total_tokens, + ) + if finish_reason == "tool_calls" and function_id: - yield _message_to_generate_content_response( - ChatCompletionAssistantMessage( - role="assistant", - content="", - tool_calls=[ - ChatCompletionMessageToolCall( - type="function", - id=function_id, - function=Function( - name=function_name, - arguments=function_args, - ), - ) - ], + aggregated_llm_response_with_tool_call = ( + _message_to_generate_content_response( + ChatCompletionAssistantMessage( + role="assistant", + content="", + tool_calls=[ + ChatCompletionMessageToolCall( + type="function", + id=function_id, + function=Function( + name=function_name, + arguments=function_args, + ), + ) + ], + ) ) ) function_name = "" function_args = "" function_id = None elif finish_reason == "stop" and text: - yield _message_to_generate_content_response( + aggregated_llm_response = _message_to_generate_content_response( ChatCompletionAssistantMessage(role="assistant", content=text) ) text = "" + # waiting until streaming ends to yield the llm_response as litellm tends + # to send chunk that contains usage_metadata after the chunk with + # finish_reason set to tool_calls or stop. 
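+      # The buffered text response (if any) is yielded first and takes the
+      # usage metadata; the buffered tool-call response only carries it when
+      # no text response was produced, so token counts are not double counted.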
+ if aggregated_llm_response: + if usage_metadata: + aggregated_llm_response.usage_metadata = usage_metadata + usage_metadata = None + yield aggregated_llm_response + + if aggregated_llm_response_with_tool_call: + if usage_metadata: + aggregated_llm_response_with_tool_call.usage_metadata = usage_metadata + yield aggregated_llm_response_with_tool_call + else: response = await self.llm_client.acompletion(**completion_args) yield _model_response_to_generate_content_response(response) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 6c6af41..9e74d2b 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -25,6 +25,7 @@ from google.adk.models.lite_llm import FunctionChunk from google.adk.models.lite_llm import LiteLlm from google.adk.models.lite_llm import LiteLLMClient from google.adk.models.lite_llm import TextChunk +from google.adk.models.lite_llm import UsageMetadataChunk from google.adk.models.llm_request import LlmRequest from google.genai import types from litellm import ChatCompletionAssistantMessage @@ -314,13 +315,10 @@ litellm_append_user_content_test_cases = [ litellm_append_user_content_test_cases ) def test_maybe_append_user_content(lite_llm_instance, llm_request, expected_output): - - lite_llm_instance._maybe_append_user_content( - llm_request - ) - - assert len(llm_request.contents) == expected_output + lite_llm_instance._maybe_append_user_content(llm_request) + + assert len(llm_request.contents) == expected_output function_declaration_test_cases = [ @@ -567,6 +565,80 @@ async def test_generate_content_async_with_tool_response( assert kwargs["messages"][2]["content"] == '{"result": "test_result"}' +@pytest.mark.asyncio +async def test_generate_content_async(mock_acompletion, lite_llm_instance): + + async for response in lite_llm_instance.generate_content_async( + LLM_REQUEST_WITH_FUNCTION_DECLARATION + ): + assert response.content.role == "model" + assert response.content.parts[0].text == "Test response" + assert response.content.parts[1].function_call.name == "test_function" + assert response.content.parts[1].function_call.args == { + "test_arg": "test_value" + } + assert response.content.parts[1].function_call.id == "test_tool_call_id" + + mock_acompletion.assert_called_once() + + _, kwargs = mock_acompletion.call_args + assert kwargs["model"] == "test_model" + assert kwargs["messages"][0]["role"] == "user" + assert kwargs["messages"][0]["content"] == "Test prompt" + assert kwargs["tools"][0]["function"]["name"] == "test_function" + assert ( + kwargs["tools"][0]["function"]["description"] + == "Test function description" + ) + assert ( + kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][ + "type" + ] + == "string" + ) + + +@pytest.mark.asyncio +async def test_generate_content_async_with_usage_metadata( + lite_llm_instance, mock_acompletion +): + mock_response_with_usage_metadata = ModelResponse( + choices=[ + Choices( + message=ChatCompletionAssistantMessage( + role="assistant", + content="Test response", + ) + ) + ], + usage={ + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + ) + mock_acompletion.return_value = mock_response_with_usage_metadata + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="Test prompt")] + ), + ], + config=types.GenerateContentConfig( + system_instruction="test instruction", + ), + ) + async for response in 
lite_llm_instance.generate_content_async(llm_request): + assert response.content.role == "model" + assert response.content.parts[0].text == "Test response" + assert response.usage_metadata.prompt_token_count == 10 + assert response.usage_metadata.candidates_token_count == 5 + assert response.usage_metadata.total_token_count == 15 + + mock_acompletion.assert_called_once() + + def test_content_to_message_param_user_message(): content = types.Content( role="user", parts=[types.Part.from_text(text="Test prompt")] @@ -704,7 +776,7 @@ def test_to_litellm_role(): @pytest.mark.parametrize( - "response, expected_chunk, expected_finished", + "response, expected_chunks, expected_finished", [ ( ModelResponse( @@ -716,7 +788,35 @@ def test_to_litellm_role(): } ] ), - TextChunk(text="this is a test"), + [ + TextChunk(text="this is a test"), + UsageMetadataChunk( + prompt_tokens=0, completion_tokens=0, total_tokens=0 + ), + ], + "stop", + ), + ( + ModelResponse( + choices=[ + { + "message": { + "content": "this is a test", + } + } + ], + usage={ + "prompt_tokens": 3, + "completion_tokens": 5, + "total_tokens": 8, + }, + ), + [ + TextChunk(text="this is a test"), + UsageMetadataChunk( + prompt_tokens=3, completion_tokens=5, total_tokens=8 + ), + ], "stop", ), ( @@ -741,28 +841,53 @@ def test_to_litellm_role(): ) ] ), - FunctionChunk(id="1", name="test_function", args='{"key": "va'), + [ + FunctionChunk(id="1", name="test_function", args='{"key": "va'), + UsageMetadataChunk( + prompt_tokens=0, completion_tokens=0, total_tokens=0 + ), + ], None, ), ( ModelResponse(choices=[{"finish_reason": "tool_calls"}]), - None, + [ + None, + UsageMetadataChunk( + prompt_tokens=0, completion_tokens=0, total_tokens=0 + ), + ], "tool_calls", ), - (ModelResponse(choices=[{}]), None, "stop"), + ( + ModelResponse(choices=[{}]), + [ + None, + UsageMetadataChunk( + prompt_tokens=0, completion_tokens=0, total_tokens=0 + ), + ], + "stop", + ), ], ) -def test_model_response_to_chunk(response, expected_chunk, expected_finished): +def test_model_response_to_chunk(response, expected_chunks, expected_finished): result = list(_model_response_to_chunk(response)) - assert len(result) == 1 + assert len(result) == 2 chunk, finished = result[0] - if expected_chunk: - assert isinstance(chunk, type(expected_chunk)) - assert chunk == expected_chunk + if expected_chunks: + assert isinstance(chunk, type(expected_chunks[0])) + assert chunk == expected_chunks[0] else: assert chunk is None assert finished == expected_finished + usage_chunk, _ = result[1] + assert usage_chunk is not None + assert usage_chunk.prompt_tokens == expected_chunks[1].prompt_tokens + assert usage_chunk.completion_tokens == expected_chunks[1].completion_tokens + assert usage_chunk.total_tokens == expected_chunks[1].total_tokens + @pytest.mark.asyncio async def test_acompletion_additional_args(mock_acompletion, mock_client): @@ -893,3 +1018,71 @@ async def test_generate_content_async_stream( ] == "string" ) + + +@pytest.mark.asyncio +async def test_generate_content_async_stream_with_usage_metadata( + mock_completion, lite_llm_instance +): + + streaming_model_response_with_usage_metadata = [ + *STREAMING_MODEL_RESPONSE, + ModelResponse( + usage={ + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + choices=[ + StreamingChoices( + finish_reason=None, + ) + ], + ), + ] + + mock_completion.return_value = iter( + streaming_model_response_with_usage_metadata + ) + + responses = [ + response + async for response in lite_llm_instance.generate_content_async( + 
LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True + ) + ] + assert len(responses) == 4 + assert responses[0].content.role == "model" + assert responses[0].content.parts[0].text == "zero, " + assert responses[1].content.role == "model" + assert responses[1].content.parts[0].text == "one, " + assert responses[2].content.role == "model" + assert responses[2].content.parts[0].text == "two:" + assert responses[3].content.role == "model" + assert responses[3].content.parts[0].function_call.name == "test_function" + assert responses[3].content.parts[0].function_call.args == { + "test_arg": "test_value" + } + assert responses[3].content.parts[0].function_call.id == "test_tool_call_id" + + assert responses[3].usage_metadata.prompt_token_count == 10 + assert responses[3].usage_metadata.candidates_token_count == 5 + assert responses[3].usage_metadata.total_token_count == 15 + + mock_completion.assert_called_once() + + _, kwargs = mock_completion.call_args + assert kwargs["model"] == "test_model" + assert kwargs["messages"][0]["role"] == "user" + assert kwargs["messages"][0]["content"] == "Test prompt" + assert kwargs["tools"][0]["function"]["name"] == "test_function" + assert ( + kwargs["tools"][0]["function"]["description"] + == "Test function description" + ) + assert ( + kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][ + "type" + ] + == "string" + )
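A minimal, illustrative sketch of consuming the streaming usage metadata added
by this change directly from LiteLlm (not part of the patch; the model id, the
prompt, and any provider credentials are placeholders that the caller must
supply):

import asyncio

from google.adk.models.lite_llm import LiteLlm
from google.adk.models.llm_request import LlmRequest
from google.genai import types


async def main():
  # Placeholder model id; any litellm-supported model string should work.
  llm = LiteLlm(model='openai/gpt-4o')
  request = LlmRequest(
      contents=[
          types.Content(
              role='user', parts=[types.Part.from_text(text='Hi there')]
          )
      ],
      config=types.GenerateContentConfig(),
  )
  total_tokens = 0
  async for response in llm.generate_content_async(request, stream=True):
    # Partial chunks are yielded without usage; the aggregated final
    # response(s) carry usage_metadata populated from litellm's usage chunk.
    if response.usage_metadata:
      total_tokens += response.usage_metadata.total_token_count or 0
  print('total tokens:', total_tokens)


if __name__ == '__main__':
  asyncio.run(main())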