diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 352a1b9..20b747c 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -98,6 +98,7 @@ class Gemini(BaseLlm): ) response = None text = '' + usage_metadata = None # for sse, similar as bidi (see receive method in gemini_llm_connecton.py), # we need to mark those text content as partial and after all partial # contents are sent, we send an accumulated event which contains all the @@ -106,6 +107,7 @@ class Gemini(BaseLlm): async for response in responses: logger.info(_build_response_log(response)) llm_response = LlmResponse.create(response) + usage_metadata = llm_response.usage_metadata if ( llm_response.content and llm_response.content.parts @@ -123,7 +125,7 @@ class Gemini(BaseLlm): content=types.ModelContent( parts=[types.Part.from_text(text=text)], ), - usage_metadata=llm_response.usage_metadata, + usage_metadata=usage_metadata, ) text = '' yield llm_response @@ -137,6 +139,7 @@ class Gemini(BaseLlm): content=types.ModelContent( parts=[types.Part.from_text(text=text)], ), + usage_metadata=usage_metadata, ) else: