Make tool_call one span for telemetry

Also renamed tool_call to execute_tool and added the attributes recommended in https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span.

PiperOrigin-RevId: 764594179
This commit is contained in:
Selcuk Gun
2025-05-29 00:01:18 -07:00
committed by Copybara-Service
parent 96b36b70dc
commit 999a7fe69d
7 changed files with 779 additions and 168 deletions

View File

@@ -0,0 +1,99 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import Optional
from unittest import mock
from google.adk import telemetry
from google.adk.agents import Agent
from google.adk.events.event import Event
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.tools.function_tool import FunctionTool
from google.genai import types
from ... import testing_utils
async def invoke_tool() -> Optional[Event]:
  """Sends a single `simple_fn` function call through the function-call handler.

  Wraps a trivial function in a `FunctionTool`, attaches it to an agent backed
  by a mock model, builds an event carrying one function call with args
  `{'a': 1, 'b': 2}`, and hands it to `handle_function_calls_async`.

  Returns:
    Whatever `handle_function_calls_async` returns for that event.
  """

  # Keep this name: the tests in this file assert on 'simple_fn' as the tool
  # name and as part of the span name.
  def simple_fn(**kwargs) -> Dict[str, Any]:
    return {'result': 'test'}

  simple_tool = FunctionTool(simple_fn)
  test_agent = Agent(
      name='agent',
      model=testing_utils.MockModel.create(responses=[]),
      tools=[simple_tool],
  )
  ctx = await testing_utils.create_invocation_context(
      agent=test_agent, user_content=''
  )

  call_part = types.Part(
      function_call=types.FunctionCall(
          name=simple_tool.name, args={'a': 1, 'b': 2}
      )
  )
  function_call_event = Event(
      invocation_id=ctx.invocation_id,
      author=test_agent.name,
      content=types.Content(parts=[call_part]),
  )

  return await handle_function_calls_async(
      ctx,
      function_call_event,
      {simple_tool.name: simple_tool},
  )
async def test_simple_function_with_mocked_tracer(monkeypatch):
  """Checks that each tool invocation opens one span and is traced once.

  The tracer's `start_as_current_span` and the `trace_tool_call` name inside
  the functions module are replaced with mocks, the tool is invoked twice, and
  the recorded calls are inspected.
  """
  # Context manager handed back by the patched `start_as_current_span`.
  span_cm = mock.MagicMock()
  span_cm.__enter__.return_value = mock.Mock(name='span_mock')
  start_span = mock.Mock(return_value=span_cm)
  monkeypatch.setattr(telemetry.tracer, 'start_as_current_span', start_span)

  trace_tool_call_mock = mock.Mock()
  monkeypatch.setattr(
      'google.adk.flows.llm_flows.functions.trace_tool_call',
      trace_tool_call_mock,
  )

  # Two invocations, so every counter below is expected to read 2.
  for _ in range(2):
    produced_event = await invoke_tool()
    assert produced_event is not None

  assert start_span.call_count == 2
  start_span.assert_any_call('execute_tool simple_fn')
  assert span_cm.__enter__.call_count == 2
  assert span_cm.__exit__.call_count == 2

  assert trace_tool_call_mock.call_count == 2
  for traced_call in trace_tool_call_mock.call_args_list:
    traced_kwargs = traced_call.kwargs
    assert traced_kwargs['tool'].name == 'simple_fn'
    assert traced_kwargs['args'] == {'a': 1, 'b': 2}
    assert 'function_response_event' in traced_kwargs
    first_part = traced_kwargs['function_response_event'].content.parts[0]
    assert first_part.function_response.response == {'result': 'test'}

View File

@@ -1,5 +1,22 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any
from typing import Dict
from typing import Optional
from unittest import mock
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
@@ -7,19 +24,55 @@ from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.sessions import InMemorySessionService
from google.adk.telemetry import trace_call_llm
from google.adk.telemetry import trace_merged_tool_calls
from google.adk.telemetry import trace_tool_call
from google.adk.tools.base_tool import BaseTool
from google.genai import types
import pytest
class Event:
  """Minimal event stand-in that serves as the spec for autospecced mocks.

  Only the attributes and the method the tests touch are declared here.
  """

  def __init__(self, event_id: str, event_content: Any):
    self.id, self.content = event_id, event_content

  def model_dumps_json(self, exclude_none: bool = False) -> str:
    # Placeholder so the spec exposes this method; mocks built from the spec
    # supply the actual return value.
    # NOTE(review): pydantic spells its serializer `model_dump_json`; this stub
    # presumably mirrors the name the code under test calls -- confirm.
    return ''
@pytest.fixture
def mock_span_fixture():
  """Provides a fresh `MagicMock` that tests install as the current span."""
  span_mock = mock.MagicMock()
  return span_mock
@pytest.fixture
def mock_tool_fixture():
  """Provides a `BaseTool`-spec mock with a fixed name and description."""
  sample_tool = mock.Mock(spec=BaseTool)
  # `name` cannot be passed to the Mock constructor as an attribute value (it
  # names the mock itself), so both attributes are set after creation.
  sample_tool.configure_mock(
      name='sample_tool',
      description='A sample tool for testing.',
  )
  return sample_tool
@pytest.fixture
def mock_event_fixture():
  """Provides an autospecced `Event` instance with a default JSON payload.

  Tests may overwrite `id`, `content`, or the `model_dumps_json` return value.
  """
  autospecced_event = mock.create_autospec(Event, instance=True)
  default_json = '{"default_event_key": "default_event_value"}'
  autospecced_event.model_dumps_json.return_value = default_json
  return autospecced_event
async def _create_invocation_context(
agent: LlmAgent, state: Optional[dict[str, Any]] = None
) -> InvocationContext:
session_service = InMemorySessionService()
session = await session_service.create_session(
app_name="test_app", user_id="test_user", state=state
app_name='test_app', user_id='test_user', state=state
)
invocation_context = InvocationContext(
invocation_id="test_id",
invocation_id='test_id',
agent=agent,
session=session,
session_service=session_service,
@@ -28,38 +81,216 @@ async def _create_invocation_context(
@pytest.mark.asyncio
async def test_trace_call_llm_function_response_includes_part_from_bytes(
    monkeypatch, mock_span_fixture
):
  """Checks `trace_call_llm` copes with function responses that carry bytes.

  The request holds two function responses: one whose result is raw `bytes`
  and one whose result is a `types.Part` built from bytes. The test expects the
  serialized request recorded on the span to contain the
  '<not serializable>' placeholder exactly twice.
  """
  # Route attribute writes to the mock span instead of a real one.
  monkeypatch.setattr(
      'opentelemetry.trace.get_current_span', lambda: mock_span_fixture
  )

  agent = LlmAgent(name='test_agent')
  invocation_context = await _create_invocation_context(agent)
  llm_request = LlmRequest(
      contents=[
          types.Content(
              role='user',
              parts=[
                  types.Part.from_function_response(
                      name='test_function_1',
                      response={
                          'result': b'test_data',
                      },
                  ),
              ],
          ),
          types.Content(
              role='user',
              parts=[
                  types.Part.from_function_response(
                      name='test_function_2',
                      response={
                          'result': types.Part.from_bytes(
                              data=b'test_data',
                              mime_type='application/octet-stream',
                          ),
                      },
                  ),
              ],
          ),
      ],
      config=types.GenerateContentConfig(system_instruction=''),
  )
  llm_response = LlmResponse(turn_complete=True)

  trace_call_llm(invocation_context, 'test_event_id', llm_request, llm_response)

  expected_calls = [
      mock.call('gen_ai.system', 'gcp.vertex.agent'),
  ]
  assert mock_span_fixture.set_attribute.call_count == 7
  mock_span_fixture.set_attribute.assert_has_calls(expected_calls)

  # Pull out the value of the first recorded llm_request attribute, if any.
  llm_request_json_str = next(
      (
          call_obj.args[1]
          for call_obj in mock_span_fixture.set_attribute.call_args_list
          if call_obj.args[0] == 'gcp.vertex.agent.llm_request'
      ),
      None,
  )

  assert (
      llm_request_json_str is not None
  ), "Attribute 'gcp.vertex.agent.llm_request' was not set on the span."
  assert llm_request_json_str.count('<not serializable>') == 2
def test_trace_tool_call_with_scalar_response(
    monkeypatch, mock_span_fixture, mock_tool_fixture, mock_event_fixture
):
  """Checks the span attributes set for a `{'result': <scalar>}` response."""
  # Arrange
  monkeypatch.setattr(
      'opentelemetry.trace.get_current_span', lambda: mock_span_fixture
  )

  tool_args: Dict[str, Any] = {'param_a': 'value_a', 'param_b': 100}
  call_id: str = 'tool_call_id_001'
  event_id: str = 'event_id_001'
  scalar_result: Any = 'Scalar result'

  response_part = types.Part(
      function_response=types.FunctionResponse(
          id=call_id,
          name='test_function_1',
          response={'result': scalar_result},
      )
  )
  mock_event_fixture.id = event_id
  mock_event_fixture.content = types.Content(
      role='user', parts=[response_part]
  )

  # Act
  trace_tool_call(
      tool=mock_tool_fixture,
      args=tool_args,
      function_response_event=mock_event_fixture,
  )

  # Assert
  set_attribute = mock_span_fixture.set_attribute
  assert set_attribute.call_count == 10

  expected_attributes = [
      ('gen_ai.system', 'gcp.vertex.agent'),
      ('gen_ai.operation.name', 'execute_tool'),
      ('gen_ai.tool.name', mock_tool_fixture.name),
      ('gen_ai.tool.description', mock_tool_fixture.description),
      ('gen_ai.tool.call.id', call_id),
      ('gcp.vertex.agent.tool_call_args', json.dumps(tool_args)),
      ('gcp.vertex.agent.event_id', event_id),
      (
          'gcp.vertex.agent.tool_response',
          json.dumps({'result': scalar_result}),
      ),
      ('gcp.vertex.agent.llm_request', '{}'),
      ('gcp.vertex.agent.llm_response', '{}'),
  ]
  set_attribute.assert_has_calls(
      [mock.call(key, value) for key, value in expected_attributes],
      any_order=True,
  )
def test_trace_tool_call_with_dict_response(
    monkeypatch, mock_span_fixture, mock_tool_fixture, mock_event_fixture
):
  """Checks the span attributes set when the tool response is a dict."""
  # Arrange
  monkeypatch.setattr(
      'opentelemetry.trace.get_current_span', lambda: mock_span_fixture
  )

  tool_args: Dict[str, Any] = {'query': 'details', 'id_list': [1, 2, 3]}
  call_id: str = 'tool_call_id_002'
  event_id: str = 'event_id_dict_002'
  dict_result: Dict[str, Any] = {
      'data': 'structured_data',
      'count': 5,
  }

  response_part = types.Part(
      function_response=types.FunctionResponse(
          id=call_id,
          name='test_function_1',
          response=dict_result,
      )
  )
  mock_event_fixture.id = event_id
  mock_event_fixture.content = types.Content(
      role='user', parts=[response_part]
  )

  # Act
  trace_tool_call(
      tool=mock_tool_fixture,
      args=tool_args,
      function_response_event=mock_event_fixture,
  )

  # Assert
  expected_attributes = [
      ('gen_ai.system', 'gcp.vertex.agent'),
      ('gen_ai.operation.name', 'execute_tool'),
      ('gen_ai.tool.name', mock_tool_fixture.name),
      ('gen_ai.tool.description', mock_tool_fixture.description),
      ('gen_ai.tool.call.id', call_id),
      ('gcp.vertex.agent.tool_call_args', json.dumps(tool_args)),
      ('gcp.vertex.agent.event_id', event_id),
      ('gcp.vertex.agent.tool_response', json.dumps(dict_result)),
      ('gcp.vertex.agent.llm_request', '{}'),
      ('gcp.vertex.agent.llm_response', '{}'),
  ]
  set_attribute = mock_span_fixture.set_attribute
  assert set_attribute.call_count == 10
  set_attribute.assert_has_calls(
      [mock.call(key, value) for key, value in expected_attributes],
      any_order=True,
  )
def test_trace_merged_tool_calls_sets_correct_attributes(
    monkeypatch, mock_span_fixture, mock_event_fixture
):
  """Checks the span attributes set by `trace_merged_tool_calls`.

  Also verifies the event mock is serialized exactly once, with
  `exclude_none=True`.
  """
  # Arrange
  monkeypatch.setattr(
      'opentelemetry.trace.get_current_span', lambda: mock_span_fixture
  )

  merged_event_id = 'merged_evt_id_001'
  serialized_event = (
      '{"custom_event_payload": true, "details": "merged_details"}'
  )
  mock_event_fixture.model_dumps_json.return_value = serialized_event

  # Act
  trace_merged_tool_calls(
      response_event_id=merged_event_id,
      function_response_event=mock_event_fixture,
  )

  # Assert
  expected_attributes = [
      ('gen_ai.system', 'gcp.vertex.agent'),
      ('gen_ai.operation.name', 'execute_tool'),
      ('gen_ai.tool.name', '(merged tools)'),
      ('gen_ai.tool.description', '(merged tools)'),
      ('gen_ai.tool.call.id', merged_event_id),
      ('gcp.vertex.agent.tool_call_args', 'N/A'),
      ('gcp.vertex.agent.event_id', merged_event_id),
      ('gcp.vertex.agent.tool_response', serialized_event),
      ('gcp.vertex.agent.llm_request', '{}'),
      ('gcp.vertex.agent.llm_response', '{}'),
  ]
  set_attribute = mock_span_fixture.set_attribute
  assert set_attribute.call_count == 10
  set_attribute.assert_has_calls(
      [mock.call(key, value) for key, value in expected_attributes],
      any_order=True,
  )
  mock_event_fixture.model_dumps_json.assert_called_once_with(exclude_none=True)