Make tool_call one span for telemetry

Also renamed tool_call to execute_tool and added attributes as recommended in https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span.

PiperOrigin-RevId: 764594179
This commit is contained in:
Selcuk Gun
2025-05-29 00:01:18 -07:00
committed by Copybara-Service
parent 96b36b70dc
commit 999a7fe69d
7 changed files with 779 additions and 168 deletions

View File

@@ -98,7 +98,7 @@ class ApiServerSpanExporter(export.SpanExporter):
if (
span.name == "call_llm"
or span.name == "send_data"
or span.name.startswith("tool_response")
or span.name.startswith("execute_tool")
):
attributes = dict(span.attributes)
attributes["trace_id"] = span.get_span_context().trace_id

View File

@@ -32,8 +32,8 @@ from ...agents.invocation_context import InvocationContext
from ...auth.auth_tool import AuthToolArguments
from ...events.event import Event
from ...events.event_actions import EventActions
from ...telemetry import trace_merged_tool_calls
from ...telemetry import trace_tool_call
from ...telemetry import trace_tool_response
from ...telemetry import tracer
from ...tools.base_tool import BaseTool
from ...tools.tool_context import ToolContext
@@ -148,62 +148,69 @@ async def handle_function_calls_async(
function_call,
tools_dict,
)
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
function_response: Optional[dict] = None
for callback in agent.canonical_before_tool_callbacks:
function_response = callback(
tool=tool, args=function_args, tool_context=tool_context
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
function_response: Optional[dict] = None
for callback in agent.canonical_before_tool_callbacks:
function_response = callback(
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response
if function_response:
break
if not function_response:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
for callback in agent.canonical_after_tool_callbacks:
altered_function_response = callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
break
if tool.is_long_running:
# Allow long running function to return None to not provide function response.
if not function_response:
continue
# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
if inspect.isawaitable(function_response):
function_response = await function_response
if function_response:
break
if not function_response:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
for callback in agent.canonical_after_tool_callbacks:
altered_function_response = callback(
trace_tool_call(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
function_response_event=function_response_event,
)
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
break
if tool.is_long_running:
# Allow long running function to return None to not provide function response.
if not function_response:
continue
# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
function_response_events.append(function_response_event)
function_response_events.append(function_response_event)
if not function_response_events:
return None
merged_event = merge_parallel_function_response_events(
function_response_events
)
if len(function_response_events) > 1:
# this is needed for debug traces of parallel calls
# individual response with tool.name is traced in __build_response_event
# (we drop tool.name from span name here as this is merged event)
with tracer.start_as_current_span('tool_response'):
trace_tool_response(
invocation_context=invocation_context,
event_id=merged_event.id,
with tracer.start_as_current_span('execute_tool (merged)'):
trace_merged_tool_calls(
response_event_id=merged_event.id,
function_response_event=merged_event,
)
return merged_event
@@ -225,65 +232,81 @@ async def handle_function_calls_live(
tool, tool_context = _get_tool_and_context(
invocation_context, function_call_event, function_call, tools_dict
)
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
function_response = None
# # Calls the tool if before_tool_callback does not exist or returns None.
# if agent.before_tool_callback:
# function_response = agent.before_tool_callback(
# tool, function_args, tool_context
# )
if agent.before_tool_callback:
function_response = agent.before_tool_callback(
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
function_response = None
# # Calls the tool if before_tool_callback does not exist or returns None.
# if agent.before_tool_callback:
# function_response = agent.before_tool_callback(
# tool, function_args, tool_context
# )
if agent.before_tool_callback:
function_response = agent.before_tool_callback(
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response
if not function_response:
function_response = await _process_function_live_helper(
tool, tool_context, function_call, function_args, invocation_context
)
if not function_response:
function_response = await _process_function_live_helper(
tool, tool_context, function_call, function_args, invocation_context
)
# Calls after_tool_callback if it exists.
# if agent.after_tool_callback:
# new_response = agent.after_tool_callback(
# tool,
# function_args,
# tool_context,
# function_response,
# )
# if new_response:
# function_response = new_response
if agent.after_tool_callback:
altered_function_response = agent.after_tool_callback(
# Calls after_tool_callback if it exists.
# if agent.after_tool_callback:
# new_response = agent.after_tool_callback(
# tool,
# function_args,
# tool_context,
# function_response,
# )
# if new_response:
# function_response = new_response
if agent.after_tool_callback:
altered_function_response = agent.after_tool_callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
if tool.is_long_running:
# Allow async function to return None to not provide function response.
if not function_response:
continue
# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
trace_tool_call(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
response_event_id=function_response_event.id,
function_response=function_response,
)
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
if tool.is_long_running:
# Allow async function to return None to not provide function response.
if not function_response:
continue
# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
function_response_events.append(function_response_event)
function_response_events.append(function_response_event)
if not function_response_events:
return None
merged_event = merge_parallel_function_response_events(
function_response_events
)
if len(function_response_events) > 1:
# this is needed for debug traces of parallel calls
# individual response with tool.name is traced in __build_response_event
# (we drop tool.name from span name here as this is merged event)
with tracer.start_as_current_span('execute_tool (merged)'):
trace_merged_tool_calls(
response_event_id=merged_event.id,
function_response_event=merged_event,
)
return merged_event
@@ -410,14 +433,12 @@ async def __call_tool_live(
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""Calls the tool asynchronously (awaiting the coroutine)."""
with tracer.start_as_current_span(f'tool_call [{tool.name}]'):
trace_tool_call(args=args)
async for item in tool._call_live(
args=args,
tool_context=tool_context,
invocation_context=invocation_context,
):
yield item
async for item in tool._call_live(
args=args,
tool_context=tool_context,
invocation_context=invocation_context,
):
yield item
async def __call_tool_async(
@@ -426,9 +447,7 @@ async def __call_tool_async(
tool_context: ToolContext,
) -> Any:
"""Calls the tool."""
with tracer.start_as_current_span(f'tool_call [{tool.name}]'):
trace_tool_call(args=args)
return await tool.run_async(args=args, tool_context=tool_context)
return await tool.run_async(args=args, tool_context=tool_context)
def __build_response_event(
@@ -437,35 +456,29 @@ def __build_response_event(
tool_context: ToolContext,
invocation_context: InvocationContext,
) -> Event:
with tracer.start_as_current_span(f'tool_response [{tool.name}]'):
# Specs requires the result to be a dict.
if not isinstance(function_result, dict):
function_result = {'result': function_result}
# Specs requires the result to be a dict.
if not isinstance(function_result, dict):
function_result = {'result': function_result}
part_function_response = types.Part.from_function_response(
name=tool.name, response=function_result
)
part_function_response.function_response.id = tool_context.function_call_id
part_function_response = types.Part.from_function_response(
name=tool.name, response=function_result
)
part_function_response.function_response.id = tool_context.function_call_id
content = types.Content(
role='user',
parts=[part_function_response],
)
content = types.Content(
role='user',
parts=[part_function_response],
)
function_response_event = Event(
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
content=content,
actions=tool_context.actions,
branch=invocation_context.branch,
)
function_response_event = Event(
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
content=content,
actions=tool_context.actions,
branch=invocation_context.branch,
)
trace_tool_response(
invocation_context=invocation_context,
event_id=function_response_event.id,
function_response_event=function_response_event,
)
return function_response_event
return function_response_event
def merge_parallel_function_response_events(

View File

@@ -21,6 +21,8 @@
# Agent Development Kit should be focused on the higher-level
# constructs of the framework that are not observable by the SDK.
from __future__ import annotations
import json
from typing import Any
@@ -31,51 +33,91 @@ from .agents.invocation_context import InvocationContext
from .events.event import Event
from .models.llm_request import LlmRequest
from .models.llm_response import LlmResponse
from .tools.base_tool import BaseTool
tracer = trace.get_tracer('gcp.vertex.agent')
def trace_tool_call(
tool: BaseTool,
args: dict[str, Any],
function_response_event: Event,
):
"""Traces tool call.
Args:
tool: The tool that was called.
args: The arguments to the tool call.
function_response_event: The event with the function response details.
"""
span = trace.get_current_span()
span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
span.set_attribute('gen_ai.operation.name', 'execute_tool')
span.set_attribute('gen_ai.tool.name', tool.name)
span.set_attribute('gen_ai.tool.description', tool.description)
tool_call_id = '<not specified>'
tool_response = '<not specified>'
if function_response_event.content.parts:
function_response = function_response_event.content.parts[
0
].function_response
if function_response is not None:
tool_call_id = function_response.id
tool_response = function_response.response
span.set_attribute('gen_ai.tool.call.id', tool_call_id)
if not isinstance(tool_response, dict):
tool_response = {'result': tool_response}
span.set_attribute('gcp.vertex.agent.tool_call_args', json.dumps(args))
def trace_tool_response(
invocation_context: InvocationContext,
event_id: str,
function_response_event: Event,
):
"""Traces tool response event.
This function records details about the tool response event as attributes on
the current OpenTelemetry span.
Args:
invocation_context: The invocation context for the current agent run.
event_id: The ID of the event.
function_response_event: The function response event which can be either
merged function response for parallel function calls or individual
function response for sequential function calls.
"""
span = trace.get_current_span()
span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
span.set_attribute(
'gcp.vertex.agent.invocation_id', invocation_context.invocation_id
)
span.set_attribute('gcp.vertex.agent.event_id', event_id)
span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id)
span.set_attribute(
'gcp.vertex.agent.tool_response',
function_response_event.model_dump_json(exclude_none=True),
json.dumps(tool_response),
)
# Setting empty llm request and response (as UI expect these) while not
# applicable for tool_response.
span.set_attribute('gcp.vertex.agent.llm_request', '{}')
span.set_attribute(
'gcp.vertex.agent.llm_response',
'{}',
)
def trace_merged_tool_calls(
    response_event_id: str,
    function_response_event: Event,
):
  """Traces merged tool call events.

  Records details about a merged (parallel) tool-call response as attributes
  on the current OpenTelemetry span. Calling this function is not needed for
  telemetry purposes. This is provided for preventing /debug/trace requests
  (typically sent by web UI).

  Args:
    response_event_id: The ID of the response event.
    function_response_event: The merged response event.
  """
  span = trace.get_current_span()
  span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
  span.set_attribute('gen_ai.operation.name', 'execute_tool')
  # A merged event aggregates several parallel tool calls, so no single
  # tool name/description/args applies here.
  span.set_attribute('gen_ai.tool.name', '(merged tools)')
  span.set_attribute('gen_ai.tool.description', '(merged tools)')
  span.set_attribute('gen_ai.tool.call.id', response_event_id)

  span.set_attribute('gcp.vertex.agent.tool_call_args', 'N/A')
  span.set_attribute('gcp.vertex.agent.event_id', response_event_id)
  try:
    # BUG FIX: the original called the non-existent `model_dumps_json`,
    # so this `try` always raised AttributeError and every merged event
    # was recorded as '<not serializable>'. The pydantic serializer is
    # `model_dump_json`.
    function_response_event_json = function_response_event.model_dump_json(
        exclude_none=True
    )
  except Exception:  # pylint: disable=broad-exception-caught
    # Best-effort: serialization failures must not break tool execution.
    function_response_event_json = '<not serializable>'
  span.set_attribute(
      'gcp.vertex.agent.tool_response',
      function_response_event_json,
  )
  # Setting empty llm request and response (as UI expect these) while not
  # applicable for tool_response.
  span.set_attribute('gcp.vertex.agent.llm_request', '{}')
  # NOTE(review): the UI may also expect an empty
  # 'gcp.vertex.agent.llm_response' attribute here, mirroring llm_request —
  # confirm against the (not visible) tail of trace_tool_call.
@@ -123,9 +165,15 @@ def trace_call_llm(
),
)
# Consider removing once GenAI SDK provides a way to record this info.
try:
llm_response_json = llm_response.model_dump_json(exclude_none=True)
except Exception: # pylint: disable=broad-exception-caught
llm_response_json = '<not serializable>'
span.set_attribute(
'gcp.vertex.agent.llm_response',
llm_response.model_dump_json(exclude_none=True),
llm_response_json,
)