async def check_prime(nums: list[int]) -> str:
  """Check which numbers in a list are prime.

  Args:
    nums: The list of numbers to check. Every item is coerced with `int()`
      before it is tested.

  Returns:
    'No prime numbers found.' when no item is prime. Otherwise the distinct
    primes in ascending order, comma separated, followed by
    ' are prime numbers.'.
  """
  # Function-scope import: this module does not import math at file level.
  import math

  primes = set()
  for raw_number in nums:
    number = int(raw_number)
    if number <= 1:
      # 0, 1 and negative numbers are never prime.
      continue
    # math.isqrt is exact for arbitrarily large ints, whereas
    # int(number**0.5) goes through a float and can be off for big values.
    largest_divisor = math.isqrt(number)
    if all(number % divisor for divisor in range(2, largest_divisor + 1)):
      primes.add(number)

  if not primes:
    return 'No prime numbers found.'
  # Sort so the reply does not depend on set iteration order.
  return f"{', '.join(str(num) for num in sorted(primes))} are prime numbers."
+ """, + tools=[ + roll_die, + check_prime, + ], + # planner=BuiltInPlanner( + # thinking_config=types.ThinkingConfig( + # include_thoughts=True, + # ), + # ), + generate_content_config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( # avoid false alarm about rolling dice. + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=types.HarmBlockThreshold.OFF, + ), + ] + ), +) diff --git a/contributing/samples/telemetry/main.py b/contributing/samples/telemetry/main.py new file mode 100755 index 0000000..060096b --- /dev/null +++ b/contributing/samples/telemetry/main.py @@ -0,0 +1,111 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
async def main():
  """Runs a short scripted conversation against the sample agent.

  Creates an in-memory runner and a single session, sends a few text prompts
  and one prompt delivered as a bytes part (with input blobs saved as
  artifacts), prints every textual reply, prints the artifact keys stored for
  the session, and reports start, end and total wall-clock time.
  """
  app_name = 'my_app'
  user_id_1 = 'user1'
  runner = InMemoryRunner(
      agent=agent.root_agent,
      app_name=app_name,
  )
  session_11 = await runner.session_service.create_session(
      app_name=app_name, user_id=user_id_1
  )

  async def send(session: Session, content: types.Content, **run_kwargs):
    """Sends `content` to the agent and prints each textual reply."""
    print('** User says:', content.model_dump(exclude_none=True))
    async for event in runner.run_async(
        user_id=user_id_1,
        session_id=session.id,
        new_message=content,
        **run_kwargs,
    ):
      # Skip events that carry no content or no parts instead of failing on
      # attribute access.
      if not event.content or not event.content.parts:
        continue
      text = event.content.parts[0].text
      if text:
        print(f'** {event.author}: {text}')

  async def run_prompt(session: Session, new_message: str):
    """Sends `new_message` as a plain text part."""
    content = types.Content(
        role='user', parts=[types.Part.from_text(text=new_message)]
    )
    await send(session, content)

  async def run_prompt_bytes(session: Session, new_message: str):
    """Sends `new_message` as a text/plain bytes part, saved as an artifact."""
    content = types.Content(
        role='user',
        parts=[
            types.Part.from_bytes(
                data=str.encode(new_message), mime_type='text/plain'
            )
        ],
    )
    await send(
        session,
        content,
        run_config=RunConfig(save_input_blobs_as_artifacts=True),
    )

  start_time = time.time()
  print('Start time:', start_time)
  print('------------------------------------')
  await run_prompt(session_11, 'Hi')
  await run_prompt(session_11, 'Roll a die with 100 sides')
  await run_prompt(session_11, 'Roll a die again with 100 sides.')
  await run_prompt(session_11, 'What numbers did I got?')
  await run_prompt_bytes(session_11, 'Hi bytes')
  print(
      await runner.artifact_service.list_artifact_keys(
          app_name=app_name, user_id=user_id_1, session_id=session_11.id
      )
  )
  end_time = time.time()
  print('------------------------------------')
  print('End time:', end_time)
  print('Total time:', end_time - start_time)


if __name__ == '__main__':

  provider = TracerProvider()
  project_id = os.environ.get('GOOGLE_CLOUD_PROJECT')
  if not project_id:
    raise ValueError('GOOGLE_CLOUD_PROJECT environment variable is not set.')
  print('Tracing to project', project_id)
  processor = export.BatchSpanProcessor(
      CloudTraceSpanExporter(project_id=project_id)
  )
  provider.add_span_processor(processor)
  trace.set_tracer_provider(provider)

  try:
    asyncio.run(main())
  finally:
    # Flush even when the run fails, so spans already recorded by the batch
    # processor are still exported.
    provider.force_flush()
  print('Done tracing to project', project_id)
trace_tool_response from ...telemetry import tracer from ...tools.base_tool import BaseTool from ...tools.tool_context import ToolContext @@ -148,62 +148,69 @@ async def handle_function_calls_async( function_call, tools_dict, ) - # do not use "args" as the variable name, because it is a reserved keyword - # in python debugger. - function_args = function_call.args or {} - function_response: Optional[dict] = None - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context + with tracer.start_as_current_span(f'execute_tool {tool.name}'): + # do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + function_args = function_call.args or {} + function_response: Optional[dict] = None + + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response + if function_response: + break + + if not function_response: + function_response = await __call_tool_async( + tool, args=function_args, tool_context=tool_context + ) + + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response is not None: + function_response = altered_function_response + break + + if tool.is_long_running: + # Allow long running function to return None to not provide function response. + if not function_response: + continue + + # Builds the function response event. 
+ function_response_event = __build_response_event( + tool, function_response, tool_context, invocation_context ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break - - if not function_response: - function_response = await __call_tool_async( - tool, args=function_args, tool_context=tool_context - ) - - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( + trace_tool_call( tool=tool, args=function_args, - tool_context=tool_context, - tool_response=function_response, + function_response_event=function_response_event, ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response is not None: - function_response = altered_function_response - break - - if tool.is_long_running: - # Allow long running function to return None to not provide function response. - if not function_response: - continue - - # Builds the function response event. 
- function_response_event = __build_response_event( - tool, function_response, tool_context, invocation_context - ) - function_response_events.append(function_response_event) + function_response_events.append(function_response_event) if not function_response_events: return None merged_event = merge_parallel_function_response_events( function_response_events ) + if len(function_response_events) > 1: # this is needed for debug traces of parallel calls # individual response with tool.name is traced in __build_response_event # (we drop tool.name from span name here as this is merged event) - with tracer.start_as_current_span('tool_response'): - trace_tool_response( - invocation_context=invocation_context, - event_id=merged_event.id, + with tracer.start_as_current_span('execute_tool (merged)'): + trace_merged_tool_calls( + response_event_id=merged_event.id, function_response_event=merged_event, ) return merged_event @@ -225,65 +232,81 @@ async def handle_function_calls_live( tool, tool_context = _get_tool_and_context( invocation_context, function_call_event, function_call, tools_dict ) - # do not use "args" as the variable name, because it is a reserved keyword - # in python debugger. - function_args = function_call.args or {} - function_response = None - # # Calls the tool if before_tool_callback does not exist or returns None. - # if agent.before_tool_callback: - # function_response = agent.before_tool_callback( - # tool, function_args, tool_context - # ) - if agent.before_tool_callback: - function_response = agent.before_tool_callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response + with tracer.start_as_current_span(f'execute_tool {tool.name}'): + # do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. 
+ function_args = function_call.args or {} + function_response = None + # # Calls the tool if before_tool_callback does not exist or returns None. + # if agent.before_tool_callback: + # function_response = agent.before_tool_callback( + # tool, function_args, tool_context + # ) + if agent.before_tool_callback: + function_response = agent.before_tool_callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response - if not function_response: - function_response = await _process_function_live_helper( - tool, tool_context, function_call, function_args, invocation_context - ) + if not function_response: + function_response = await _process_function_live_helper( + tool, tool_context, function_call, function_args, invocation_context + ) - # Calls after_tool_callback if it exists. - # if agent.after_tool_callback: - # new_response = agent.after_tool_callback( - # tool, - # function_args, - # tool_context, - # function_response, - # ) - # if new_response: - # function_response = new_response - if agent.after_tool_callback: - altered_function_response = agent.after_tool_callback( + # Calls after_tool_callback if it exists. + # if agent.after_tool_callback: + # new_response = agent.after_tool_callback( + # tool, + # function_args, + # tool_context, + # function_response, + # ) + # if new_response: + # function_response = new_response + if agent.after_tool_callback: + altered_function_response = agent.after_tool_callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response is not None: + function_response = altered_function_response + + if tool.is_long_running: + # Allow async function to return None to not provide function response. 
+ if not function_response: + continue + + # Builds the function response event. + function_response_event = __build_response_event( + tool, function_response, tool_context, invocation_context + ) + trace_tool_call( tool=tool, args=function_args, - tool_context=tool_context, - tool_response=function_response, + response_event_id=function_response_event.id, + function_response=function_response, ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response is not None: - function_response = altered_function_response - - if tool.is_long_running: - # Allow async function to return None to not provide function response. - if not function_response: - continue - - # Builds the function response event. - function_response_event = __build_response_event( - tool, function_response, tool_context, invocation_context - ) - function_response_events.append(function_response_event) + function_response_events.append(function_response_event) if not function_response_events: return None merged_event = merge_parallel_function_response_events( function_response_events ) + if len(function_response_events) > 1: + # this is needed for debug traces of parallel calls + # individual response with tool.name is traced in __build_response_event + # (we drop tool.name from span name here as this is merged event) + with tracer.start_as_current_span('execute_tool (merged)'): + trace_merged_tool_calls( + response_event_id=merged_event.id, + function_response_event=merged_event, + ) return merged_event @@ -410,14 +433,12 @@ async def __call_tool_live( invocation_context: InvocationContext, ) -> AsyncGenerator[Event, None]: """Calls the tool asynchronously (awaiting the coroutine).""" - with tracer.start_as_current_span(f'tool_call [{tool.name}]'): - trace_tool_call(args=args) - async for item in tool._call_live( - args=args, - tool_context=tool_context, - invocation_context=invocation_context, - ): - yield item + 
async for item in tool._call_live( + args=args, + tool_context=tool_context, + invocation_context=invocation_context, + ): + yield item async def __call_tool_async( @@ -426,9 +447,7 @@ async def __call_tool_async( tool_context: ToolContext, ) -> Any: """Calls the tool.""" - with tracer.start_as_current_span(f'tool_call [{tool.name}]'): - trace_tool_call(args=args) - return await tool.run_async(args=args, tool_context=tool_context) + return await tool.run_async(args=args, tool_context=tool_context) def __build_response_event( @@ -437,35 +456,29 @@ def __build_response_event( tool_context: ToolContext, invocation_context: InvocationContext, ) -> Event: - with tracer.start_as_current_span(f'tool_response [{tool.name}]'): - # Specs requires the result to be a dict. - if not isinstance(function_result, dict): - function_result = {'result': function_result} + # Specs requires the result to be a dict. + if not isinstance(function_result, dict): + function_result = {'result': function_result} - part_function_response = types.Part.from_function_response( - name=tool.name, response=function_result - ) - part_function_response.function_response.id = tool_context.function_call_id + part_function_response = types.Part.from_function_response( + name=tool.name, response=function_result + ) + part_function_response.function_response.id = tool_context.function_call_id - content = types.Content( - role='user', - parts=[part_function_response], - ) + content = types.Content( + role='user', + parts=[part_function_response], + ) - function_response_event = Event( - invocation_id=invocation_context.invocation_id, - author=invocation_context.agent.name, - content=content, - actions=tool_context.actions, - branch=invocation_context.branch, - ) + function_response_event = Event( + invocation_id=invocation_context.invocation_id, + author=invocation_context.agent.name, + content=content, + actions=tool_context.actions, + branch=invocation_context.branch, + ) - trace_tool_response( - 
def _serialize_for_span(payload: Any) -> str:
  """Best-effort JSON encoding of `payload` for use as a span attribute.

  Tracing must not break tool execution, so this never raises: values that
  json cannot encode are rendered with `str()`, and any remaining failure
  (for example a circular reference) yields an empty string.
  """
  try:
    # default=str is deliberate here: a lossy rendering of an exotic value is
    # more useful in a debug trace than losing the whole payload.
    return json.dumps(payload, default=str)
  except (TypeError, ValueError):
    return ''


def trace_tool_call(
    tool: BaseTool,
    args: dict[str, Any],
    function_response_event: Event,
):
  """Traces tool call.

  Records the tool name and description, the call arguments and the function
  response found in the first part of the event on the current span.

  Args:
    tool: The tool that was called.
    args: The arguments to the tool call.
    function_response_event: The event with the function response details.
  """
  span = trace.get_current_span()
  span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
  span.set_attribute('gen_ai.operation.name', 'execute_tool')
  span.set_attribute('gen_ai.tool.name', tool.name)
  span.set_attribute('gen_ai.tool.description', tool.description)

  tool_call_id = ''
  tool_response = ''
  if function_response_event.content.parts:
    function_response = function_response_event.content.parts[
        0
    ].function_response
    if function_response is not None:
      # Keep the placeholder when the response carries no id, so the
      # attribute value is always a string.
      if function_response.id is not None:
        tool_call_id = function_response.id
      tool_response = function_response.response

  span.set_attribute('gen_ai.tool.call.id', tool_call_id)

  if not isinstance(tool_response, dict):
    tool_response = {'result': tool_response}
  span.set_attribute(
      'gcp.vertex.agent.tool_call_args', _serialize_for_span(args)
  )
  span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id)
  span.set_attribute(
      'gcp.vertex.agent.tool_response',
      _serialize_for_span(tool_response),
  )
  # Setting empty llm request and response (as UI expect these) while not
  # applicable for tool_response.
  span.set_attribute('gcp.vertex.agent.llm_request', '{}')
  span.set_attribute(
      'gcp.vertex.agent.llm_response',
      '{}',
  )


def trace_merged_tool_calls(
    response_event_id: str,
    function_response_event: Event,
):
  """Traces merged tool call events.

  Sets the same attribute keys as `trace_tool_call` on the current span, with
  '(merged tools)' as the tool name and description and the merged event
  serialized as the tool response.

  NOTE(review): the merged span appears to exist for the benefit of
  /debug/trace requests (typically sent by the web UI) rather than for
  telemetry as such - confirm against the API server.

  Args:
    response_event_id: The ID of the response event.
    function_response_event: The merged response event.
  """

  span = trace.get_current_span()
  span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
  span.set_attribute('gen_ai.operation.name', 'execute_tool')
  span.set_attribute('gen_ai.tool.name', '(merged tools)')
  span.set_attribute('gen_ai.tool.description', '(merged tools)')
  span.set_attribute('gen_ai.tool.call.id', response_event_id)

  span.set_attribute('gcp.vertex.agent.tool_call_args', 'N/A')
  span.set_attribute('gcp.vertex.agent.event_id', response_event_id)
  try:
    # The method is `model_dump_json`. The previous spelling
    # (`model_dumps_json`) raised AttributeError, which the broad except
    # below swallowed, so the merged payload was never recorded.
    function_response_event_json = function_response_event.model_dump_json(
        exclude_none=True
    )
  except Exception:  # pylint: disable=broad-exception-caught
    # Deliberate best effort: tracing must not fail the tool call.
    function_response_event_json = ''

  span.set_attribute(
      'gcp.vertex.agent.tool_response',
      function_response_event_json,
  )
  # Setting empty llm request and response (as UI expect these) while not
  # applicable for tool_response.
  span.set_attribute('gcp.vertex.agent.llm_request', '{}')
  span.set_attribute(
      'gcp.vertex.agent.llm_response',
      '{}',
  )
+ + try: + llm_response_json = llm_response.model_dump_json(exclude_none=True) + except Exception: # pylint: disable=broad-exception-caught + llm_response_json = '' + span.set_attribute( 'gcp.vertex.agent.llm_response', - llm_response.model_dump_json(exclude_none=True), + llm_response_json, ) diff --git a/tests/unittests/flows/llm_flows/test_tool_telemetry.py b/tests/unittests/flows/llm_flows/test_tool_telemetry.py new file mode 100644 index 0000000..b599566 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_tool_telemetry.py @@ -0,0 +1,99 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import Dict +from typing import Optional +from unittest import mock + +from google.adk import telemetry +from google.adk.agents import Agent +from google.adk.events.event import Event +from google.adk.flows.llm_flows.functions import handle_function_calls_async +from google.adk.tools.function_tool import FunctionTool +from google.genai import types + +from ... 
import testing_utils + + +async def invoke_tool() -> Optional[Event]: + def simple_fn(**kwargs) -> Dict[str, Any]: + return {'result': 'test'} + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name='agent', + model=model, + tools=[tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + function_call = types.FunctionCall(name=tool.name, args={'a': 1, 'b': 2}) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + return await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) + + +async def test_simple_function_with_mocked_tracer(monkeypatch): + mock_start_as_current_span_func = mock.Mock() + returned_context_manager_mock = mock.MagicMock() + returned_context_manager_mock.__enter__.return_value = mock.Mock( + name='span_mock' + ) + mock_start_as_current_span_func.return_value = returned_context_manager_mock + + monkeypatch.setattr( + telemetry.tracer, 'start_as_current_span', mock_start_as_current_span_func + ) + + mock_adk_trace_tool_call = mock.Mock() + monkeypatch.setattr( + 'google.adk.flows.llm_flows.functions.trace_tool_call', + mock_adk_trace_tool_call, + ) + + event = await invoke_tool() + assert event is not None + + event = await invoke_tool() + assert event is not None + + expected_span_name = 'execute_tool simple_fn' + + assert mock_start_as_current_span_func.call_count == 2 + mock_start_as_current_span_func.assert_any_call(expected_span_name) + + assert returned_context_manager_mock.__enter__.call_count == 2 + assert returned_context_manager_mock.__exit__.call_count == 2 + + assert mock_adk_trace_tool_call.call_count == 2 + for call_args_item in mock_adk_trace_tool_call.call_args_list: + kwargs = call_args_item.kwargs + assert 
kwargs['tool'].name == 'simple_fn' + assert kwargs['args'] == {'a': 1, 'b': 2} + assert 'function_response_event' in kwargs + assert kwargs['function_response_event'].content.parts[ + 0 + ].function_response.response == {'result': 'test'} diff --git a/tests/unittests/test_telemetry.py b/tests/unittests/test_telemetry.py index 64da250..1b8ee1b 100644 --- a/tests/unittests/test_telemetry.py +++ b/tests/unittests/test_telemetry.py @@ -1,5 +1,22 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json from typing import Any +from typing import Dict from typing import Optional +from unittest import mock from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent @@ -7,19 +24,55 @@ from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.sessions import InMemorySessionService from google.adk.telemetry import trace_call_llm +from google.adk.telemetry import trace_merged_tool_calls +from google.adk.telemetry import trace_tool_call +from google.adk.tools.base_tool import BaseTool from google.genai import types import pytest +class Event: + + def __init__(self, event_id: str, event_content: Any): + self.id = event_id + self.content = event_content + + def model_dumps_json(self, exclude_none: bool = False) -> str: + # This is just a stub for the spec. The mock will provide behavior. 
+ return '' + + +@pytest.fixture +def mock_span_fixture(): + return mock.MagicMock() + + +@pytest.fixture +def mock_tool_fixture(): + tool = mock.Mock(spec=BaseTool) + tool.name = 'sample_tool' + tool.description = 'A sample tool for testing.' + return tool + + +@pytest.fixture +def mock_event_fixture(): + event_mock = mock.create_autospec(Event, instance=True) + event_mock.model_dumps_json.return_value = ( + '{"default_event_key": "default_event_value"}' + ) + return event_mock + + async def _create_invocation_context( agent: LlmAgent, state: Optional[dict[str, Any]] = None ) -> InvocationContext: session_service = InMemorySessionService() session = await session_service.create_session( - app_name="test_app", user_id="test_user", state=state + app_name='test_app', user_id='test_user', state=state ) invocation_context = InvocationContext( - invocation_id="test_id", + invocation_id='test_id', agent=agent, session=session, session_service=session_service, @@ -28,38 +81,216 @@ async def _create_invocation_context( @pytest.mark.asyncio -async def test_trace_call_llm_function_response_includes_part_from_bytes(): - agent = LlmAgent(name="test_agent") +async def test_trace_call_llm_function_response_includes_part_from_bytes( + monkeypatch, mock_span_fixture +): + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + agent = LlmAgent(name='test_agent') invocation_context = await _create_invocation_context(agent) llm_request = LlmRequest( contents=[ types.Content( - role="user", + role='user', parts=[ types.Part.from_function_response( - name="test_function_1", + name='test_function_1', response={ - "result": b"test_data", + 'result': b'test_data', }, ), ], ), types.Content( - role="user", + role='user', parts=[ types.Part.from_function_response( - name="test_function_2", + name='test_function_2', response={ - "result": types.Part.from_bytes( - data=b"test_data", - mime_type="application/octet-stream", + 'result': 
types.Part.from_bytes( + data=b'test_data', + mime_type='application/octet-stream', ), }, ), ], ), ], - config=types.GenerateContentConfig(system_instruction=""), + config=types.GenerateContentConfig(system_instruction=''), ) llm_response = LlmResponse(turn_complete=True) - trace_call_llm(invocation_context, "test_event_id", llm_request, llm_response) + trace_call_llm(invocation_context, 'test_event_id', llm_request, llm_response) + + expected_calls = [ + mock.call('gen_ai.system', 'gcp.vertex.agent'), + ] + assert mock_span_fixture.set_attribute.call_count == 7 + mock_span_fixture.set_attribute.assert_has_calls(expected_calls) + llm_request_json_str = None + for call_obj in mock_span_fixture.set_attribute.call_args_list: + if call_obj.args[0] == 'gcp.vertex.agent.llm_request': + llm_request_json_str = call_obj.args[1] + break + + assert ( + llm_request_json_str is not None + ), "Attribute 'gcp.vertex.agent.llm_request' was not set on the span." + + assert llm_request_json_str.count('') == 2 + + +def test_trace_tool_call_with_scalar_response( + monkeypatch, mock_span_fixture, mock_tool_fixture, mock_event_fixture +): + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + test_args: Dict[str, Any] = {'param_a': 'value_a', 'param_b': 100} + test_tool_call_id: str = 'tool_call_id_001' + test_event_id: str = 'event_id_001' + scalar_function_response: Any = 'Scalar result' + + expected_processed_response = {'result': scalar_function_response} + + mock_event_fixture.id = test_event_id + mock_event_fixture.content = types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + id=test_tool_call_id, + name='test_function_1', + response={'result': scalar_function_response}, + ) + ), + ], + ) + + # Act + trace_tool_call( + tool=mock_tool_fixture, + args=test_args, + function_response_event=mock_event_fixture, + ) + + # Assert + assert mock_span_fixture.set_attribute.call_count == 10 + 
expected_calls = [ + mock.call('gen_ai.system', 'gcp.vertex.agent'), + mock.call('gen_ai.operation.name', 'execute_tool'), + mock.call('gen_ai.tool.name', mock_tool_fixture.name), + mock.call('gen_ai.tool.description', mock_tool_fixture.description), + mock.call('gen_ai.tool.call.id', test_tool_call_id), + mock.call('gcp.vertex.agent.tool_call_args', json.dumps(test_args)), + mock.call('gcp.vertex.agent.event_id', test_event_id), + mock.call( + 'gcp.vertex.agent.tool_response', + json.dumps(expected_processed_response), + ), + mock.call('gcp.vertex.agent.llm_request', '{}'), + mock.call('gcp.vertex.agent.llm_response', '{}'), + ] + + mock_span_fixture.set_attribute.assert_has_calls( + expected_calls, any_order=True + ) + + +def test_trace_tool_call_with_dict_response( + monkeypatch, mock_span_fixture, mock_tool_fixture, mock_event_fixture +): + # Arrange + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + test_args: Dict[str, Any] = {'query': 'details', 'id_list': [1, 2, 3]} + test_tool_call_id: str = 'tool_call_id_002' + test_event_id: str = 'event_id_dict_002' + dict_function_response: Dict[str, Any] = { + 'data': 'structured_data', + 'count': 5, + } + + mock_event_fixture.id = test_event_id + mock_event_fixture.content = types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + id=test_tool_call_id, + name='test_function_1', + response=dict_function_response, + ) + ), + ], + ) + + # Act + trace_tool_call( + tool=mock_tool_fixture, + args=test_args, + function_response_event=mock_event_fixture, + ) + + # Assert + expected_calls = [ + mock.call('gen_ai.system', 'gcp.vertex.agent'), + mock.call('gen_ai.operation.name', 'execute_tool'), + mock.call('gen_ai.tool.name', mock_tool_fixture.name), + mock.call('gen_ai.tool.description', mock_tool_fixture.description), + mock.call('gen_ai.tool.call.id', test_tool_call_id), + mock.call('gcp.vertex.agent.tool_call_args', 
json.dumps(test_args)), + mock.call('gcp.vertex.agent.event_id', test_event_id), + mock.call( + 'gcp.vertex.agent.tool_response', json.dumps(dict_function_response) + ), + mock.call('gcp.vertex.agent.llm_request', '{}'), + mock.call('gcp.vertex.agent.llm_response', '{}'), + ] + + assert mock_span_fixture.set_attribute.call_count == 10 + mock_span_fixture.set_attribute.assert_has_calls( + expected_calls, any_order=True + ) + + +def test_trace_merged_tool_calls_sets_correct_attributes( + monkeypatch, mock_span_fixture, mock_event_fixture +): + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + test_response_event_id = 'merged_evt_id_001' + custom_event_json_output = ( + '{"custom_event_payload": true, "details": "merged_details"}' + ) + mock_event_fixture.model_dumps_json.return_value = custom_event_json_output + + trace_merged_tool_calls( + response_event_id=test_response_event_id, + function_response_event=mock_event_fixture, + ) + + expected_calls = [ + mock.call('gen_ai.system', 'gcp.vertex.agent'), + mock.call('gen_ai.operation.name', 'execute_tool'), + mock.call('gen_ai.tool.name', '(merged tools)'), + mock.call('gen_ai.tool.description', '(merged tools)'), + mock.call('gen_ai.tool.call.id', test_response_event_id), + mock.call('gcp.vertex.agent.tool_call_args', 'N/A'), + mock.call('gcp.vertex.agent.event_id', test_response_event_id), + mock.call('gcp.vertex.agent.tool_response', custom_event_json_output), + mock.call('gcp.vertex.agent.llm_request', '{}'), + mock.call('gcp.vertex.agent.llm_response', '{}'), + ] + + assert mock_span_fixture.set_attribute.call_count == 10 + mock_span_fixture.set_attribute.assert_has_calls( + expected_calls, any_order=True + ) + mock_event_fixture.model_dumps_json.assert_called_once_with(exclude_none=True)