Make tool_call one span for telemetry

Also renamed tool_call as execute_tool and added attributes as recommended in https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span.

PiperOrigin-RevId: 764594179
This commit is contained in:
Selcuk Gun 2025-05-29 00:01:18 -07:00 committed by Copybara-Service
parent 96b36b70dc
commit 999a7fe69d
7 changed files with 779 additions and 168 deletions

View File

@ -0,0 +1,109 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from google.adk import Agent
from google.adk.planners import BuiltInPlanner
from google.adk.planners import PlanReActPlanner
from google.adk.tools.tool_context import ToolContext
from google.genai import types
def roll_die(sides: int, tool_context: ToolContext) -> int:
"""Roll a die and return the rolled result.
Args:
sides: The integer number of sides the die has.
Returns:
An integer of the result of rolling the die.
"""
result = random.randint(1, sides)
if not 'rolls' in tool_context.state:
tool_context.state['rolls'] = []
tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
return result
async def check_prime(nums: list[int]) -> str:
"""Check if a given list of numbers are prime.
Args:
nums: The list of numbers to check.
Returns:
A str indicating which number is prime.
"""
primes = set()
for number in nums:
number = int(number)
if number <= 1:
continue
is_prime = True
for i in range(2, int(number**0.5) + 1):
if number % i == 0:
is_prime = False
break
if is_prime:
primes.add(number)
return (
'No prime numbers found.'
if not primes
else f"{', '.join(str(num) for num in primes)} are prime numbers."
)
root_agent = Agent(
model='gemini-2.0-flash',
name='data_processing_agent',
description=(
'hello world agent that can roll a dice of 8 sides and check prime'
' numbers.'
),
instruction="""
You roll dice and answer questions about the outcome of the dice rolls.
You can roll dice of different sizes.
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
It is ok to discuss previous dice roles, and comment on the dice rolls.
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
You should never roll a die on your own.
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
You should not check prime numbers before calling the tool.
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
3. When you respond, you must include the roll_die result from step 1.
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
You should not rely on the previous history on prime results.
""",
tools=[
roll_die,
check_prime,
],
# planner=BuiltInPlanner(
# thinking_config=types.ThinkingConfig(
# include_thoughts=True,
# ),
# ),
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting( # avoid false alarm about rolling dice.
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
)

View File

@ -0,0 +1,111 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import time
import agent
from dotenv import load_dotenv
from google.adk.agents.run_config import RunConfig
from google.adk.runners import InMemoryRunner
from google.adk.sessions import Session
from google.genai import types
from opentelemetry import trace
from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
from opentelemetry.sdk.trace import export
from opentelemetry.sdk.trace import TracerProvider
load_dotenv(override=True)
async def main():
app_name = 'my_app'
user_id_1 = 'user1'
runner = InMemoryRunner(
agent=agent.root_agent,
app_name=app_name,
)
session_11 = await runner.session_service.create_session(
app_name=app_name, user_id=user_id_1
)
async def run_prompt(session: Session, new_message: str):
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
async def run_prompt_bytes(session: Session, new_message: str):
content = types.Content(
role='user',
parts=[
types.Part.from_bytes(
data=str.encode(new_message), mime_type='text/plain'
)
],
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
run_config=RunConfig(save_input_blobs_as_artifacts=True),
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
await run_prompt(session_11, 'Hi')
await run_prompt(session_11, 'Roll a die with 100 sides')
await run_prompt(session_11, 'Roll a die again with 100 sides.')
await run_prompt(session_11, 'What numbers did I got?')
await run_prompt_bytes(session_11, 'Hi bytes')
print(
await runner.artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id_1, session_id=session_11.id
)
)
end_time = time.time()
print('------------------------------------')
print('End time:', end_time)
print('Total time:', end_time - start_time)
if __name__ == '__main__':
provider = TracerProvider()
project_id = os.environ.get('GOOGLE_CLOUD_PROJECT')
if not project_id:
raise ValueError('GOOGLE_CLOUD_PROJECT environment variable is not set.')
print('Tracing to project', project_id)
processor = export.BatchSpanProcessor(
CloudTraceSpanExporter(project_id=project_id)
)
provider.add_span_processor(processor)
trace.set_tracer_provider(provider)
asyncio.run(main())
provider.force_flush()
print('Done tracing to project', project_id)

View File

@ -98,7 +98,7 @@ class ApiServerSpanExporter(export.SpanExporter):
if ( if (
span.name == "call_llm" span.name == "call_llm"
or span.name == "send_data" or span.name == "send_data"
or span.name.startswith("tool_response") or span.name.startswith("execute_tool")
): ):
attributes = dict(span.attributes) attributes = dict(span.attributes)
attributes["trace_id"] = span.get_span_context().trace_id attributes["trace_id"] = span.get_span_context().trace_id

View File

@ -32,8 +32,8 @@ from ...agents.invocation_context import InvocationContext
from ...auth.auth_tool import AuthToolArguments from ...auth.auth_tool import AuthToolArguments
from ...events.event import Event from ...events.event import Event
from ...events.event_actions import EventActions from ...events.event_actions import EventActions
from ...telemetry import trace_merged_tool_calls
from ...telemetry import trace_tool_call from ...telemetry import trace_tool_call
from ...telemetry import trace_tool_response
from ...telemetry import tracer from ...telemetry import tracer
from ...tools.base_tool import BaseTool from ...tools.base_tool import BaseTool
from ...tools.tool_context import ToolContext from ...tools.tool_context import ToolContext
@ -148,62 +148,69 @@ async def handle_function_calls_async(
function_call, function_call,
tools_dict, tools_dict,
) )
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
function_response: Optional[dict] = None
for callback in agent.canonical_before_tool_callbacks: with tracer.start_as_current_span(f'execute_tool {tool.name}'):
function_response = callback( # do not use "args" as the variable name, because it is a reserved keyword
tool=tool, args=function_args, tool_context=tool_context # in python debugger.
function_args = function_call.args or {}
function_response: Optional[dict] = None
for callback in agent.canonical_before_tool_callbacks:
function_response = callback(
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response
if function_response:
break
if not function_response:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
for callback in agent.canonical_after_tool_callbacks:
altered_function_response = callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
break
if tool.is_long_running:
# Allow long running function to return None to not provide function response.
if not function_response:
continue
# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
) )
if inspect.isawaitable(function_response): trace_tool_call(
function_response = await function_response
if function_response:
break
if not function_response:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
for callback in agent.canonical_after_tool_callbacks:
altered_function_response = callback(
tool=tool, tool=tool,
args=function_args, args=function_args,
tool_context=tool_context, function_response_event=function_response_event,
tool_response=function_response,
) )
if inspect.isawaitable(altered_function_response): function_response_events.append(function_response_event)
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
break
if tool.is_long_running:
# Allow long running function to return None to not provide function response.
if not function_response:
continue
# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
function_response_events.append(function_response_event)
if not function_response_events: if not function_response_events:
return None return None
merged_event = merge_parallel_function_response_events( merged_event = merge_parallel_function_response_events(
function_response_events function_response_events
) )
if len(function_response_events) > 1: if len(function_response_events) > 1:
# this is needed for debug traces of parallel calls # this is needed for debug traces of parallel calls
# individual response with tool.name is traced in __build_response_event # individual response with tool.name is traced in __build_response_event
# (we drop tool.name from span name here as this is merged event) # (we drop tool.name from span name here as this is merged event)
with tracer.start_as_current_span('tool_response'): with tracer.start_as_current_span('execute_tool (merged)'):
trace_tool_response( trace_merged_tool_calls(
invocation_context=invocation_context, response_event_id=merged_event.id,
event_id=merged_event.id,
function_response_event=merged_event, function_response_event=merged_event,
) )
return merged_event return merged_event
@ -225,65 +232,81 @@ async def handle_function_calls_live(
tool, tool_context = _get_tool_and_context( tool, tool_context = _get_tool_and_context(
invocation_context, function_call_event, function_call, tools_dict invocation_context, function_call_event, function_call, tools_dict
) )
# do not use "args" as the variable name, because it is a reserved keyword with tracer.start_as_current_span(f'execute_tool {tool.name}'):
# in python debugger. # do not use "args" as the variable name, because it is a reserved keyword
function_args = function_call.args or {} # in python debugger.
function_response = None function_args = function_call.args or {}
# # Calls the tool if before_tool_callback does not exist or returns None. function_response = None
# if agent.before_tool_callback: # # Calls the tool if before_tool_callback does not exist or returns None.
# function_response = agent.before_tool_callback( # if agent.before_tool_callback:
# tool, function_args, tool_context # function_response = agent.before_tool_callback(
# ) # tool, function_args, tool_context
if agent.before_tool_callback: # )
function_response = agent.before_tool_callback( if agent.before_tool_callback:
tool=tool, args=function_args, tool_context=tool_context function_response = agent.before_tool_callback(
) tool=tool, args=function_args, tool_context=tool_context
if inspect.isawaitable(function_response): )
function_response = await function_response if inspect.isawaitable(function_response):
function_response = await function_response
if not function_response: if not function_response:
function_response = await _process_function_live_helper( function_response = await _process_function_live_helper(
tool, tool_context, function_call, function_args, invocation_context tool, tool_context, function_call, function_args, invocation_context
) )
# Calls after_tool_callback if it exists. # Calls after_tool_callback if it exists.
# if agent.after_tool_callback: # if agent.after_tool_callback:
# new_response = agent.after_tool_callback( # new_response = agent.after_tool_callback(
# tool, # tool,
# function_args, # function_args,
# tool_context, # tool_context,
# function_response, # function_response,
# ) # )
# if new_response: # if new_response:
# function_response = new_response # function_response = new_response
if agent.after_tool_callback: if agent.after_tool_callback:
altered_function_response = agent.after_tool_callback( altered_function_response = agent.after_tool_callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
if tool.is_long_running:
# Allow async function to return None to not provide function response.
if not function_response:
continue
# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
trace_tool_call(
tool=tool, tool=tool,
args=function_args, args=function_args,
tool_context=tool_context, response_event_id=function_response_event.id,
tool_response=function_response, function_response=function_response,
) )
if inspect.isawaitable(altered_function_response): function_response_events.append(function_response_event)
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
if tool.is_long_running:
# Allow async function to return None to not provide function response.
if not function_response:
continue
# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
function_response_events.append(function_response_event)
if not function_response_events: if not function_response_events:
return None return None
merged_event = merge_parallel_function_response_events( merged_event = merge_parallel_function_response_events(
function_response_events function_response_events
) )
if len(function_response_events) > 1:
# this is needed for debug traces of parallel calls
# individual response with tool.name is traced in __build_response_event
# (we drop tool.name from span name here as this is merged event)
with tracer.start_as_current_span('execute_tool (merged)'):
trace_merged_tool_calls(
response_event_id=merged_event.id,
function_response_event=merged_event,
)
return merged_event return merged_event
@ -410,14 +433,12 @@ async def __call_tool_live(
invocation_context: InvocationContext, invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]: ) -> AsyncGenerator[Event, None]:
"""Calls the tool asynchronously (awaiting the coroutine).""" """Calls the tool asynchronously (awaiting the coroutine)."""
with tracer.start_as_current_span(f'tool_call [{tool.name}]'): async for item in tool._call_live(
trace_tool_call(args=args) args=args,
async for item in tool._call_live( tool_context=tool_context,
args=args, invocation_context=invocation_context,
tool_context=tool_context, ):
invocation_context=invocation_context, yield item
):
yield item
async def __call_tool_async( async def __call_tool_async(
@ -426,9 +447,7 @@ async def __call_tool_async(
tool_context: ToolContext, tool_context: ToolContext,
) -> Any: ) -> Any:
"""Calls the tool.""" """Calls the tool."""
with tracer.start_as_current_span(f'tool_call [{tool.name}]'): return await tool.run_async(args=args, tool_context=tool_context)
trace_tool_call(args=args)
return await tool.run_async(args=args, tool_context=tool_context)
def __build_response_event( def __build_response_event(
@ -437,35 +456,29 @@ def __build_response_event(
tool_context: ToolContext, tool_context: ToolContext,
invocation_context: InvocationContext, invocation_context: InvocationContext,
) -> Event: ) -> Event:
with tracer.start_as_current_span(f'tool_response [{tool.name}]'): # Specs requires the result to be a dict.
# Specs requires the result to be a dict. if not isinstance(function_result, dict):
if not isinstance(function_result, dict): function_result = {'result': function_result}
function_result = {'result': function_result}
part_function_response = types.Part.from_function_response( part_function_response = types.Part.from_function_response(
name=tool.name, response=function_result name=tool.name, response=function_result
) )
part_function_response.function_response.id = tool_context.function_call_id part_function_response.function_response.id = tool_context.function_call_id
content = types.Content( content = types.Content(
role='user', role='user',
parts=[part_function_response], parts=[part_function_response],
) )
function_response_event = Event( function_response_event = Event(
invocation_id=invocation_context.invocation_id, invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name, author=invocation_context.agent.name,
content=content, content=content,
actions=tool_context.actions, actions=tool_context.actions,
branch=invocation_context.branch, branch=invocation_context.branch,
) )
trace_tool_response( return function_response_event
invocation_context=invocation_context,
event_id=function_response_event.id,
function_response_event=function_response_event,
)
return function_response_event
def merge_parallel_function_response_events( def merge_parallel_function_response_events(

View File

@ -21,6 +21,8 @@
# Agent Development Kit should be focused on the higher-level # Agent Development Kit should be focused on the higher-level
# constructs of the framework that are not observable by the SDK. # constructs of the framework that are not observable by the SDK.
from __future__ import annotations
import json import json
from typing import Any from typing import Any
@ -31,51 +33,91 @@ from .agents.invocation_context import InvocationContext
from .events.event import Event from .events.event import Event
from .models.llm_request import LlmRequest from .models.llm_request import LlmRequest
from .models.llm_response import LlmResponse from .models.llm_response import LlmResponse
from .tools.base_tool import BaseTool
tracer = trace.get_tracer('gcp.vertex.agent') tracer = trace.get_tracer('gcp.vertex.agent')
def trace_tool_call( def trace_tool_call(
tool: BaseTool,
args: dict[str, Any], args: dict[str, Any],
function_response_event: Event,
): ):
"""Traces tool call. """Traces tool call.
Args: Args:
tool: The tool that was called.
args: The arguments to the tool call. args: The arguments to the tool call.
function_response_event: The event with the function response details.
""" """
span = trace.get_current_span() span = trace.get_current_span()
span.set_attribute('gen_ai.system', 'gcp.vertex.agent') span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
span.set_attribute('gen_ai.operation.name', 'execute_tool')
span.set_attribute('gen_ai.tool.name', tool.name)
span.set_attribute('gen_ai.tool.description', tool.description)
tool_call_id = '<not specified>'
tool_response = '<not specified>'
if function_response_event.content.parts:
function_response = function_response_event.content.parts[
0
].function_response
if function_response is not None:
tool_call_id = function_response.id
tool_response = function_response.response
span.set_attribute('gen_ai.tool.call.id', tool_call_id)
if not isinstance(tool_response, dict):
tool_response = {'result': tool_response}
span.set_attribute('gcp.vertex.agent.tool_call_args', json.dumps(args)) span.set_attribute('gcp.vertex.agent.tool_call_args', json.dumps(args))
span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id)
def trace_tool_response(
invocation_context: InvocationContext,
event_id: str,
function_response_event: Event,
):
"""Traces tool response event.
This function records details about the tool response event as attributes on
the current OpenTelemetry span.
Args:
invocation_context: The invocation context for the current agent run.
event_id: The ID of the event.
function_response_event: The function response event which can be either
merged function response for parallel function calls or individual
function response for sequential function calls.
"""
span = trace.get_current_span()
span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
span.set_attribute(
'gcp.vertex.agent.invocation_id', invocation_context.invocation_id
)
span.set_attribute('gcp.vertex.agent.event_id', event_id)
span.set_attribute( span.set_attribute(
'gcp.vertex.agent.tool_response', 'gcp.vertex.agent.tool_response',
function_response_event.model_dump_json(exclude_none=True), json.dumps(tool_response),
)
# Setting empty llm request and response (as UI expect these) while not
# applicable for tool_response.
span.set_attribute('gcp.vertex.agent.llm_request', '{}')
span.set_attribute(
'gcp.vertex.agent.llm_response',
'{}',
) )
def trace_merged_tool_calls(
response_event_id: str,
function_response_event: Event,
):
"""Traces merged tool call events.
Calling this function is not needed for telemetry purposes. This is provided
for preventing /debug/trace requests (typically sent by web UI).
Args:
response_event_id: The ID of the response event.
function_response_event: The merged response event.
"""
span = trace.get_current_span()
span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
span.set_attribute('gen_ai.operation.name', 'execute_tool')
span.set_attribute('gen_ai.tool.name', '(merged tools)')
span.set_attribute('gen_ai.tool.description', '(merged tools)')
span.set_attribute('gen_ai.tool.call.id', response_event_id)
span.set_attribute('gcp.vertex.agent.tool_call_args', 'N/A')
span.set_attribute('gcp.vertex.agent.event_id', response_event_id)
try:
function_response_event_json = function_response_event.model_dumps_json(
exclude_none=True
)
except Exception: # pylint: disable=broad-exception-caught
function_response_event_json = '<not serializable>'
span.set_attribute(
'gcp.vertex.agent.tool_response',
function_response_event_json,
)
# Setting empty llm request and response (as UI expect these) while not # Setting empty llm request and response (as UI expect these) while not
# applicable for tool_response. # applicable for tool_response.
span.set_attribute('gcp.vertex.agent.llm_request', '{}') span.set_attribute('gcp.vertex.agent.llm_request', '{}')
@ -123,9 +165,15 @@ def trace_call_llm(
), ),
) )
# Consider removing once GenAI SDK provides a way to record this info. # Consider removing once GenAI SDK provides a way to record this info.
try:
llm_response_json = llm_response.model_dump_json(exclude_none=True)
except Exception: # pylint: disable=broad-exception-caught
llm_response_json = '<not serializable>'
span.set_attribute( span.set_attribute(
'gcp.vertex.agent.llm_response', 'gcp.vertex.agent.llm_response',
llm_response.model_dump_json(exclude_none=True), llm_response_json,
) )

View File

@ -0,0 +1,99 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import Optional
from unittest import mock
from google.adk import telemetry
from google.adk.agents import Agent
from google.adk.events.event import Event
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.tools.function_tool import FunctionTool
from google.genai import types
from ... import testing_utils
async def invoke_tool() -> Optional[Event]:
def simple_fn(**kwargs) -> Dict[str, Any]:
return {'result': 'test'}
tool = FunctionTool(simple_fn)
model = testing_utils.MockModel.create(responses=[])
agent = Agent(
name='agent',
model=model,
tools=[tool],
)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content=''
)
function_call = types.FunctionCall(name=tool.name, args={'a': 1, 'b': 2})
content = types.Content(parts=[types.Part(function_call=function_call)])
event = Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
content=content,
)
tools_dict = {tool.name: tool}
return await handle_function_calls_async(
invocation_context,
event,
tools_dict,
)
async def test_simple_function_with_mocked_tracer(monkeypatch):
mock_start_as_current_span_func = mock.Mock()
returned_context_manager_mock = mock.MagicMock()
returned_context_manager_mock.__enter__.return_value = mock.Mock(
name='span_mock'
)
mock_start_as_current_span_func.return_value = returned_context_manager_mock
monkeypatch.setattr(
telemetry.tracer, 'start_as_current_span', mock_start_as_current_span_func
)
mock_adk_trace_tool_call = mock.Mock()
monkeypatch.setattr(
'google.adk.flows.llm_flows.functions.trace_tool_call',
mock_adk_trace_tool_call,
)
event = await invoke_tool()
assert event is not None
event = await invoke_tool()
assert event is not None
expected_span_name = 'execute_tool simple_fn'
assert mock_start_as_current_span_func.call_count == 2
mock_start_as_current_span_func.assert_any_call(expected_span_name)
assert returned_context_manager_mock.__enter__.call_count == 2
assert returned_context_manager_mock.__exit__.call_count == 2
assert mock_adk_trace_tool_call.call_count == 2
for call_args_item in mock_adk_trace_tool_call.call_args_list:
kwargs = call_args_item.kwargs
assert kwargs['tool'].name == 'simple_fn'
assert kwargs['args'] == {'a': 1, 'b': 2}
assert 'function_response_event' in kwargs
assert kwargs['function_response_event'].content.parts[
0
].function_response.response == {'result': 'test'}

View File

@ -1,5 +1,22 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any from typing import Any
from typing import Dict
from typing import Optional from typing import Optional
from unittest import mock
from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.llm_agent import LlmAgent
@ -7,19 +24,55 @@ from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse from google.adk.models.llm_response import LlmResponse
from google.adk.sessions import InMemorySessionService from google.adk.sessions import InMemorySessionService
from google.adk.telemetry import trace_call_llm from google.adk.telemetry import trace_call_llm
from google.adk.telemetry import trace_merged_tool_calls
from google.adk.telemetry import trace_tool_call
from google.adk.tools.base_tool import BaseTool
from google.genai import types from google.genai import types
import pytest import pytest
class Event:
def __init__(self, event_id: str, event_content: Any):
self.id = event_id
self.content = event_content
def model_dumps_json(self, exclude_none: bool = False) -> str:
# This is just a stub for the spec. The mock will provide behavior.
return ''
@pytest.fixture
def mock_span_fixture():
return mock.MagicMock()
@pytest.fixture
def mock_tool_fixture():
tool = mock.Mock(spec=BaseTool)
tool.name = 'sample_tool'
tool.description = 'A sample tool for testing.'
return tool
@pytest.fixture
def mock_event_fixture():
event_mock = mock.create_autospec(Event, instance=True)
event_mock.model_dumps_json.return_value = (
'{"default_event_key": "default_event_value"}'
)
return event_mock
async def _create_invocation_context( async def _create_invocation_context(
agent: LlmAgent, state: Optional[dict[str, Any]] = None agent: LlmAgent, state: Optional[dict[str, Any]] = None
) -> InvocationContext: ) -> InvocationContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = await session_service.create_session( session = await session_service.create_session(
app_name="test_app", user_id="test_user", state=state app_name='test_app', user_id='test_user', state=state
) )
invocation_context = InvocationContext( invocation_context = InvocationContext(
invocation_id="test_id", invocation_id='test_id',
agent=agent, agent=agent,
session=session, session=session,
session_service=session_service, session_service=session_service,
@ -28,38 +81,216 @@ async def _create_invocation_context(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trace_call_llm_function_response_includes_part_from_bytes(): async def test_trace_call_llm_function_response_includes_part_from_bytes(
agent = LlmAgent(name="test_agent") monkeypatch, mock_span_fixture
):
monkeypatch.setattr(
'opentelemetry.trace.get_current_span', lambda: mock_span_fixture
)
agent = LlmAgent(name='test_agent')
invocation_context = await _create_invocation_context(agent) invocation_context = await _create_invocation_context(agent)
llm_request = LlmRequest( llm_request = LlmRequest(
contents=[ contents=[
types.Content( types.Content(
role="user", role='user',
parts=[ parts=[
types.Part.from_function_response( types.Part.from_function_response(
name="test_function_1", name='test_function_1',
response={ response={
"result": b"test_data", 'result': b'test_data',
}, },
), ),
], ],
), ),
types.Content( types.Content(
role="user", role='user',
parts=[ parts=[
types.Part.from_function_response( types.Part.from_function_response(
name="test_function_2", name='test_function_2',
response={ response={
"result": types.Part.from_bytes( 'result': types.Part.from_bytes(
data=b"test_data", data=b'test_data',
mime_type="application/octet-stream", mime_type='application/octet-stream',
), ),
}, },
), ),
], ],
), ),
], ],
config=types.GenerateContentConfig(system_instruction=""), config=types.GenerateContentConfig(system_instruction=''),
) )
llm_response = LlmResponse(turn_complete=True) llm_response = LlmResponse(turn_complete=True)
trace_call_llm(invocation_context, "test_event_id", llm_request, llm_response) trace_call_llm(invocation_context, 'test_event_id', llm_request, llm_response)
expected_calls = [
mock.call('gen_ai.system', 'gcp.vertex.agent'),
]
assert mock_span_fixture.set_attribute.call_count == 7
mock_span_fixture.set_attribute.assert_has_calls(expected_calls)
llm_request_json_str = None
for call_obj in mock_span_fixture.set_attribute.call_args_list:
if call_obj.args[0] == 'gcp.vertex.agent.llm_request':
llm_request_json_str = call_obj.args[1]
break
assert (
llm_request_json_str is not None
), "Attribute 'gcp.vertex.agent.llm_request' was not set on the span."
assert llm_request_json_str.count('<not serializable>') == 2
def test_trace_tool_call_with_scalar_response(
monkeypatch, mock_span_fixture, mock_tool_fixture, mock_event_fixture
):
monkeypatch.setattr(
'opentelemetry.trace.get_current_span', lambda: mock_span_fixture
)
test_args: Dict[str, Any] = {'param_a': 'value_a', 'param_b': 100}
test_tool_call_id: str = 'tool_call_id_001'
test_event_id: str = 'event_id_001'
scalar_function_response: Any = 'Scalar result'
expected_processed_response = {'result': scalar_function_response}
mock_event_fixture.id = test_event_id
mock_event_fixture.content = types.Content(
role='user',
parts=[
types.Part(
function_response=types.FunctionResponse(
id=test_tool_call_id,
name='test_function_1',
response={'result': scalar_function_response},
)
),
],
)
# Act
trace_tool_call(
tool=mock_tool_fixture,
args=test_args,
function_response_event=mock_event_fixture,
)
# Assert
assert mock_span_fixture.set_attribute.call_count == 10
expected_calls = [
mock.call('gen_ai.system', 'gcp.vertex.agent'),
mock.call('gen_ai.operation.name', 'execute_tool'),
mock.call('gen_ai.tool.name', mock_tool_fixture.name),
mock.call('gen_ai.tool.description', mock_tool_fixture.description),
mock.call('gen_ai.tool.call.id', test_tool_call_id),
mock.call('gcp.vertex.agent.tool_call_args', json.dumps(test_args)),
mock.call('gcp.vertex.agent.event_id', test_event_id),
mock.call(
'gcp.vertex.agent.tool_response',
json.dumps(expected_processed_response),
),
mock.call('gcp.vertex.agent.llm_request', '{}'),
mock.call('gcp.vertex.agent.llm_response', '{}'),
]
mock_span_fixture.set_attribute.assert_has_calls(
expected_calls, any_order=True
)
def test_trace_tool_call_with_dict_response(
monkeypatch, mock_span_fixture, mock_tool_fixture, mock_event_fixture
):
# Arrange
monkeypatch.setattr(
'opentelemetry.trace.get_current_span', lambda: mock_span_fixture
)
test_args: Dict[str, Any] = {'query': 'details', 'id_list': [1, 2, 3]}
test_tool_call_id: str = 'tool_call_id_002'
test_event_id: str = 'event_id_dict_002'
dict_function_response: Dict[str, Any] = {
'data': 'structured_data',
'count': 5,
}
mock_event_fixture.id = test_event_id
mock_event_fixture.content = types.Content(
role='user',
parts=[
types.Part(
function_response=types.FunctionResponse(
id=test_tool_call_id,
name='test_function_1',
response=dict_function_response,
)
),
],
)
# Act
trace_tool_call(
tool=mock_tool_fixture,
args=test_args,
function_response_event=mock_event_fixture,
)
# Assert
expected_calls = [
mock.call('gen_ai.system', 'gcp.vertex.agent'),
mock.call('gen_ai.operation.name', 'execute_tool'),
mock.call('gen_ai.tool.name', mock_tool_fixture.name),
mock.call('gen_ai.tool.description', mock_tool_fixture.description),
mock.call('gen_ai.tool.call.id', test_tool_call_id),
mock.call('gcp.vertex.agent.tool_call_args', json.dumps(test_args)),
mock.call('gcp.vertex.agent.event_id', test_event_id),
mock.call(
'gcp.vertex.agent.tool_response', json.dumps(dict_function_response)
),
mock.call('gcp.vertex.agent.llm_request', '{}'),
mock.call('gcp.vertex.agent.llm_response', '{}'),
]
assert mock_span_fixture.set_attribute.call_count == 10
mock_span_fixture.set_attribute.assert_has_calls(
expected_calls, any_order=True
)
def test_trace_merged_tool_calls_sets_correct_attributes(
monkeypatch, mock_span_fixture, mock_event_fixture
):
monkeypatch.setattr(
'opentelemetry.trace.get_current_span', lambda: mock_span_fixture
)
test_response_event_id = 'merged_evt_id_001'
custom_event_json_output = (
'{"custom_event_payload": true, "details": "merged_details"}'
)
mock_event_fixture.model_dumps_json.return_value = custom_event_json_output
trace_merged_tool_calls(
response_event_id=test_response_event_id,
function_response_event=mock_event_fixture,
)
expected_calls = [
mock.call('gen_ai.system', 'gcp.vertex.agent'),
mock.call('gen_ai.operation.name', 'execute_tool'),
mock.call('gen_ai.tool.name', '(merged tools)'),
mock.call('gen_ai.tool.description', '(merged tools)'),
mock.call('gen_ai.tool.call.id', test_response_event_id),
mock.call('gcp.vertex.agent.tool_call_args', 'N/A'),
mock.call('gcp.vertex.agent.event_id', test_response_event_id),
mock.call('gcp.vertex.agent.tool_response', custom_event_json_output),
mock.call('gcp.vertex.agent.llm_request', '{}'),
mock.call('gcp.vertex.agent.llm_response', '{}'),
]
assert mock_span_fixture.set_attribute.call_count == 10
mock_span_fixture.set_attribute.assert_has_calls(
expected_calls, any_order=True
)
mock_event_fixture.model_dumps_json.assert_called_once_with(exclude_none=True)