diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 9b9b156..5fbfab2 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -23,6 +23,7 @@ from typing import cast from typing import Optional from typing import TYPE_CHECKING +from google.genai import types from websockets.exceptions import ConnectionClosedOK from . import functions @@ -50,6 +51,8 @@ if TYPE_CHECKING: logger = logging.getLogger('google_adk.' + __name__) +_ADK_AGENT_NAME_LABEL_KEY = 'adk_agent_name' + class BaseLlmFlow(ABC): """A basic flow that calls the LLM in a loop until a final response is generated. @@ -499,6 +502,16 @@ class BaseLlmFlow(ABC): yield response return + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.labels = llm_request.config.labels or {} + + # Add agent name as a label to the llm_request. This will help with slicing + # the billing reports on a per-agent basis. + if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels: + llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = ( + invocation_context.agent.name + ) + # Calls the LLM. llm = self.__get_llm(invocation_context) with tracer.start_as_current_span('call_llm'): diff --git a/tests/unittests/flows/llm_flows/test_other_configs.py b/tests/unittests/flows/llm_flows/test_other_configs.py index 005bd1b..1f3d816 100644 --- a/tests/unittests/flows/llm_flows/test_other_configs.py +++ b/tests/unittests/flows/llm_flows/test_other_configs.py @@ -44,3 +44,4 @@ def test_output_schema(): assert len(mockModel.requests) == 1 assert mockModel.requests[0].config.response_schema == CustomOutput assert mockModel.requests[0].config.response_mime_type == 'application/json' + assert mockModel.requests[0].config.labels == {'adk_agent_name': 'root_agent'}