feat: add BigQuery first-party tools.

These tools support retrieving BigQuery dataset/table metadata and executing SQL queries.

PiperOrigin-RevId: 764139132
Authored by: Google Team Member
Date: 2025-05-28 00:58:53 -07:00
Committed by: Copybara-Service
Parent: 46282eeb0d
Commit: d6c6bb4b24
11 changed files with 748 additions and 11 deletions
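
For orientation, a minimal usage sketch of the toolset surface that the tests in this commit exercise; the OAuth client values are placeholders and any surrounding agent wiring is omitted:

# Sketch only; shows the API the new tests below exercise.
import asyncio

from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset


async def main():
  # Placeholder OAuth client values; real values come from a Google Cloud
  # OAuth client.
  credentials_config = BigQueryCredentialsConfig(
      client_id="your-client-id", client_secret="your-client-secret"
  )

  # Without a tool_filter, all five first-party tools are exposed.
  toolset = BigQueryToolset(credentials_config=credentials_config)
  tools = await toolset.get_tools()
  print(sorted(tool.name for tool in tools))
  # Per the tests: execute_sql, get_dataset_info, get_table_info,
  # list_dataset_ids, list_table_ids


asyncio.run(main())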


@@ -13,6 +13,7 @@
 # limitations under the License.
+import json
 from unittest.mock import Mock
 from unittest.mock import patch
@@ -98,12 +99,12 @@ class TestBigQueryCredentialsManager:
     manager.credentials_config.credentials = None
 
     # Create mock cached credentials JSON that would be stored in cache
-    mock_cached_creds_json = {
+    mock_cached_creds_json = json.dumps({
         "token": "cached_token",
         "refresh_token": "cached_refresh_token",
         "client_id": "test_client_id",
         "client_secret": "test_client_secret",
-    }
+    })
 
     # Set up the tool context state to contain cached credentials
     mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
@@ -120,7 +121,7 @@ class TestBigQueryCredentialsManager:
     # Verify credentials were created from cached JSON
     mock_from_json.assert_called_once_with(
-        mock_cached_creds_json, manager.credentials_config.scopes
+        json.loads(mock_cached_creds_json), manager.credentials_config.scopes
     )
 
     # Verify loaded credentials were not cached into manager
     assert manager.credentials_config.credentials is None
@@ -160,19 +161,19 @@ class TestBigQueryCredentialsManager:
     manager.credentials_config.credentials = None
 
     # Create mock cached credentials JSON
-    mock_cached_creds_json = {
+    mock_cached_creds_json = json.dumps({
         "token": "expired_token",
         "refresh_token": "valid_refresh_token",
         "client_id": "test_client_id",
         "client_secret": "test_client_secret",
-    }
+    })
 
-    mock_refreshed_creds_json = {
+    mock_refreshed_creds_json = json.dumps({
         "token": "new_token",
         "refresh_token": "valid_refresh_token",
         "client_id": "test_client_id",
         "client_secret": "test_client_secret",
-    }
+    })
 
     # Set up the tool context state to contain cached credentials
     mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
@@ -200,7 +201,7 @@ class TestBigQueryCredentialsManager:
     # Verify credentials were created from cached JSON
     mock_from_json.assert_called_once_with(
-        mock_cached_creds_json, manager.credentials_config.scopes
+        json.loads(mock_cached_creds_json), manager.credentials_config.scopes
     )
 
     # Verify refresh was attempted and succeeded
     mock_cached_creds.refresh.assert_called_once()
@@ -209,7 +210,9 @@ class TestBigQueryCredentialsManager:
     # Verify refreshed credentials were cached
     assert (
         "new_token"
-        == mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]["token"]
+        == json.loads(mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY])[
+            "token"
+        ]
     )
 
     assert result == mock_cached_creds
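
The hunks above switch the cached credential payload from a dict to a JSON string. A rough sketch of the round trip the tests now assert, assuming the manager rebuilds credentials with google.oauth2.credentials.Credentials.from_authorized_user_info (the tests only patch a generic from-JSON hook, so that exact constructor is an assumption); state and cache_key below are illustrative stand-ins for the tool context state and BIGQUERY_TOKEN_CACHE_KEY:

# Illustrative only; not the credentials manager's actual implementation.
from __future__ import annotations

import json

from google.oauth2.credentials import Credentials


def cache_credentials(state: dict, cache_key: str, creds: Credentials) -> None:
  # Store a JSON string in the cache, not a raw dict.
  state[cache_key] = creds.to_json()


def load_cached_credentials(
    state: dict, cache_key: str, scopes: list[str]
) -> Credentials | None:
  cached = state.get(cache_key)
  if not cached:
    return None
  # Parse the JSON string back into a dict before rebuilding credentials,
  # mirroring the json.loads(...) argument asserted in the tests above.
  return Credentials.from_authorized_user_info(json.loads(cached), scopes)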


@@ -0,0 +1,123 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryTool
from google.adk.tools.bigquery import BigQueryToolset
import pytest


@pytest.mark.asyncio
async def test_bigquery_toolset_tools_default():
  """Test default BigQuery toolset.

  This test verifies the behavior of the BigQuery toolset when no filter is
  specified.
  """
  credentials_config = BigQueryCredentialsConfig(
      client_id="abc", client_secret="def"
  )

  toolset = BigQueryToolset(credentials_config=credentials_config)
  tools = await toolset.get_tools()
  assert tools is not None

  assert len(tools) == 5
  assert all([isinstance(tool, BigQueryTool) for tool in tools])

  expected_tool_names = set([
      "list_dataset_ids",
      "get_dataset_info",
      "list_table_ids",
      "get_table_info",
      "execute_sql",
  ])
  actual_tool_names = set([tool.name for tool in tools])
  assert actual_tool_names == expected_tool_names


@pytest.mark.parametrize(
    "selected_tools",
    [
        pytest.param([], id="None"),
        pytest.param(
            ["list_dataset_ids", "get_dataset_info"], id="dataset-metadata"
        ),
        pytest.param(["list_table_ids", "get_table_info"], id="table-metadata"),
        pytest.param(["execute_sql"], id="query"),
    ],
)
@pytest.mark.asyncio
async def test_bigquery_toolset_tools_selective(selected_tools):
  """Test BigQuery toolset with filter.

  This test verifies the behavior of the BigQuery toolset when filter is
  specified. A use case for this would be when the agent builder wants to
  use only a subset of the tools provided by the toolset.
  """
  credentials_config = BigQueryCredentialsConfig(
      client_id="abc", client_secret="def"
  )

  toolset = BigQueryToolset(
      credentials_config=credentials_config, tool_filter=selected_tools
  )
  tools = await toolset.get_tools()
  assert tools is not None

  assert len(tools) == len(selected_tools)
  assert all([isinstance(tool, BigQueryTool) for tool in tools])

  expected_tool_names = set(selected_tools)
  actual_tool_names = set([tool.name for tool in tools])
  assert actual_tool_names == expected_tool_names


@pytest.mark.parametrize(
    ("selected_tools", "returned_tools"),
    [
        pytest.param(["unknown"], [], id="all-unknown"),
        pytest.param(
            ["unknown", "execute_sql"],
            ["execute_sql"],
            id="mixed-known-unknown",
        ),
    ],
)
@pytest.mark.asyncio
async def test_bigquery_toolset_unknown_tool_raises(
    selected_tools, returned_tools
):
  """Test BigQuery toolset with filter.

  This test verifies the behavior of the BigQuery toolset when filter is
  specified with an unknown tool.
  """
  credentials_config = BigQueryCredentialsConfig(
      client_id="abc", client_secret="def"
  )

  toolset = BigQueryToolset(
      credentials_config=credentials_config, tool_filter=selected_tools
  )
  tools = await toolset.get_tools()
  assert tools is not None

  assert len(tools) == len(returned_tools)
  assert all([isinstance(tool, BigQueryTool) for tool in tools])

  expected_tool_names = set(returned_tools)
  actual_tool_names = set([tool.name for tool in tools])
  assert actual_tool_names == expected_tool_names
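
Finally, a sketch of how the tool_filter behavior verified above might be used when building an agent. Note the tests show that unrecognized names are silently dropped rather than raising. The Agent constructor arguments and the model id are illustrative placeholders, and passing a toolset directly in the tools list is an assumption about the ADK agent API:

# Sketch: expose only the metadata tools to an agent via tool_filter.
from google.adk.agents import Agent
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset

credentials_config = BigQueryCredentialsConfig(
    client_id="your-client-id", client_secret="your-client-secret"
)

metadata_toolset = BigQueryToolset(
    credentials_config=credentials_config,
    tool_filter=[
        "list_dataset_ids",
        "get_dataset_info",
        "list_table_ids",
        "get_table_info",
    ],
)

agent = Agent(
    name="bigquery_metadata_agent",  # illustrative name
    model="gemini-2.0-flash",  # placeholder model id
    instruction="Answer questions about BigQuery datasets and tables.",
    tools=[metadata_toolset],
)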