feat: add BigQuery first-party tools.

These tools support getting BigQuery dataset/table metadata and query results.

PiperOrigin-RevId: 764139132
Google Team Member 2025-05-28 00:58:53 -07:00, committed by Copybara-Service
parent 46282eeb0d
commit d6c6bb4b24
11 changed files with 748 additions and 11 deletions

View File

@ -0,0 +1,83 @@
# BigQuery Tools Sample
## Introduction
This sample agent demonstrates the BigQuery first-party tools in ADK,
distributed via the `google.adk.tools.bigquery` module. These tools include:
1. `list_dataset_ids`

    Fetches BigQuery dataset ids present in a GCP project.

1. `get_dataset_info`

    Fetches metadata about a BigQuery dataset.

1. `list_table_ids`

    Fetches table ids present in a BigQuery dataset.

1. `get_table_info`

    Fetches metadata about a BigQuery table.

1. `execute_sql`

    Runs a SQL query in BigQuery.
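Each of these tools is a plain Python function that takes explicit credentials. For orientation, a hypothetical direct call outside any agent (assuming application default credentials that the BigQuery client accepts, and the module layout introduced in this commit; `metadata_tool` is an internal module, not re-exported from the package root):

```python
import google.auth

# Import path assumed from the file layout in this commit.
from google.adk.tools.bigquery import metadata_tool

# Pick up application default credentials from the environment.
credentials, _ = google.auth.default()
print(metadata_tool.list_dataset_ids("bigquery-public-data", credentials))
```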
## How to use
Set up environment variables in your `.env` file to use
[Google AI Studio](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-ai-studio)
or
[Google Cloud Vertex AI](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai)
as the LLM service for your agent. For example, to use Google AI Studio you
would set:

* GOOGLE_GENAI_USE_VERTEXAI=FALSE
* GOOGLE_API_KEY={your api key}
### With Application Default Credentials
This mode is useful for quick development when the agent builder is the only
user interacting with the agent. The tools are initialized with the default
credentials present on the machine running the agent.
1. Create application default credentials on the machine where the agent will
be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc.
1. Set `RUN_WITH_ADC=True` in `agent.py` and run the agent (the sketch below
shows the corresponding initialization).
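With `RUN_WITH_ADC=True`, the toolset is initialized roughly as follows; this is a minimal sketch mirroring the sample `agent.py` in this commit:

```python
import google.auth

from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset

# Pick up whatever application default credentials exist on this machine.
application_default_credentials, _ = google.auth.default()

credentials_config = BigQueryCredentialsConfig(
    credentials=application_default_credentials
)
bigquery_toolset = BigQueryToolset(credentials_config=credentials_config)
```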
### With Interactive OAuth
1. Follow
https://developers.google.com/identity/protocols/oauth2 to obtain your
OAuth 2.0 client id and client secret. Be sure to choose "web" as your
client type.
1. Follow https://developers.google.com/workspace/guides/configure-oauth-consent to add the scope "https://www.googleapis.com/auth/bigquery".
1. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs".

    Note: `localhost` here is just the hostname you use to access the dev UI;
    replace it with the actual hostname you use.
1. On the first run, allow popups for localhost in Chrome.
1. Add two more variables to your `.env` file before running the agent:
    * OAUTH_CLIENT_ID={your client id}
    * OAUTH_CLIENT_SECRET={your client secret}

    Note: don't create a separate `.env` file; add these variables to the same
    `.env` file that stores your Vertex AI or Google AI Studio credentials.
1. Set `RUN_WITH_ADC=False` in `agent.py` and run the agent.
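With `RUN_WITH_ADC=False`, the sample instead configures interactive OAuth; again a minimal sketch mirroring `agent.py` in this commit:

```python
import os

from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset

# OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET come from the .env file above.
credentials_config = BigQueryCredentialsConfig(
    client_id=os.getenv("OAUTH_CLIENT_ID"),
    client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
    scopes=["https://www.googleapis.com/auth/bigquery"],
)
bigquery_toolset = BigQueryToolset(credentials_config=credentials_config)
```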
## Sample prompts
* which weather datasets exist in bigquery public data?
* tell me more about noaa_lightning
* which tables exist in the ml_datasets dataset?
* show more details about the penguins table
* compute penguins population per island.

View File

@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@ -0,0 +1,58 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from google.adk.agents import llm_agent
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset
import google.auth
RUN_WITH_ADC = False
if RUN_WITH_ADC:
# Initialize the tools to use the application default credentials.
application_default_credentials, _ = google.auth.default()
credentials_config = BigQueryCredentialsConfig(
credentials=application_default_credentials
)
else:
  # Initialize the tools to use interactive OAuth. The environment variables
  # OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET must be set.
credentials_config = BigQueryCredentialsConfig(
client_id=os.getenv("OAUTH_CLIENT_ID"),
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
scopes=["https://www.googleapis.com/auth/bigquery"],
)
bigquery_toolset = BigQueryToolset(credentials_config=credentials_config)
# The variable name `root_agent` determines what your root agent is for the
# debug CLI
root_agent = llm_agent.Agent(
model="gemini-2.0-flash",
name="hello_agent",
description=(
"Agent to answer questions about BigQuery data and models and execute"
" SQL queries."
),
instruction="""\
You are a data science agent with access to several BigQuery tools.
Make use of those tools to answer the user's questions.
""",
tools=[bigquery_toolset],
)
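To try the sample, you would typically start the ADK dev UI from the samples directory (assuming the standard ADK CLI, e.g. `adk web`) and select this agent; when interactive OAuth is enabled, the consent flow is driven through that UI.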

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""BigQuery Tools. (Experimental)
+"""BigQuery Tools (Experimental).
BigQuery Tools under this module are hand crafted and customized while the tools
under google.adk.tools.google_api_tool are auto generated based on API
@ -26,3 +26,13 @@ definition. The rationales to have customized tool are:
4. We want to provide extra access guardrails in those tools. For example,
   execute_sql can't arbitrarily mutate existing data.
"""
from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_tool import BigQueryTool
from .bigquery_toolset import BigQueryToolset
__all__ = [
"BigQueryTool",
"BigQueryToolset",
"BigQueryCredentialsConfig",
]

View File

@ -14,6 +14,7 @@
from __future__ import annotations
+import json
from typing import List
from typing import Optional
@ -121,7 +122,7 @@ class BigQueryCredentialsManager:
    creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None)
    creds = (
        Credentials.from_authorized_user_info(
-            creds_json, self.credentials_config.scopes
+            json.loads(creds_json), self.credentials_config.scopes
        )
        if creds_json
        else None
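The `json.loads` fix matters because the token cache stores the credentials as a JSON string (presumably produced by `Credentials.to_json()`), while `from_authorized_user_info` expects a dict. A minimal sketch of the round trip, assuming the standard `google-auth` API:

```python
import json

from google.oauth2.credentials import Credentials

def cache_roundtrip(creds: Credentials, scopes: list[str]) -> Credentials:
  # Serialize to the JSON string form stored in the tool context state.
  creds_json = creds.to_json()
  # from_authorized_user_info takes a dict, hence the json.loads.
  return Credentials.from_authorized_user_info(json.loads(creds_json), scopes)
```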

View File

@ -0,0 +1,86 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
from google.adk.agents.readonly_context import ReadonlyContext
from typing_extensions import override
from . import metadata_tool
from . import query_tool
from ...tools.base_tool import BaseTool
from ...tools.base_toolset import BaseToolset
from ...tools.base_toolset import ToolPredicate
from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_tool import BigQueryTool
class BigQueryToolset(BaseToolset):
"""BigQuery Toolset contains tools for interacting with BigQuery data and metadata."""
def __init__(
self,
*,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
credentials_config: Optional[BigQueryCredentialsConfig] = None,
):
self._credentials_config = credentials_config
self.tool_filter = tool_filter
def _is_tool_selected(
self, tool: BaseTool, readonly_context: ReadonlyContext
) -> bool:
if self.tool_filter is None:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False
@override
async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> List[BaseTool]:
"""Get tools from the toolset."""
all_tools = [
BigQueryTool(
func=func,
credentials=self._credentials_config,
)
for func in [
metadata_tool.get_dataset_info,
metadata_tool.get_table_info,
metadata_tool.list_dataset_ids,
metadata_tool.list_table_ids,
query_tool.execute_sql,
]
]
return [
tool
for tool in all_tools
if self._is_tool_selected(tool, readonly_context)
]
@override
async def close(self):
pass
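Besides a list of tool names, `tool_filter` accepts a `ToolPredicate`. A hypothetical predicate filter, assuming `ToolPredicate` is a runtime-checkable callable protocol (which the `isinstance` check above suggests):

```python
from typing import Optional

from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.tools.base_tool import BaseTool

def metadata_only(
    tool: BaseTool, readonly_context: Optional[ReadonlyContext] = None
) -> bool:
  # Keep every tool except the query-execution one.
  return tool.name != "execute_sql"

toolset = BigQueryToolset(
    credentials_config=credentials_config,
    tool_filter=metadata_only,
)
```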

View File

@ -0,0 +1,33 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import google.api_core.client_info
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
USER_AGENT = "adk-bigquery-tool"
def get_bigquery_client(*, credentials: Credentials) -> bigquery.Client:
"""Get a BigQuery client."""
client_info = google.api_core.client_info.ClientInfo(user_agent=USER_AGENT)
bigquery_client = bigquery.Client(
credentials=credentials, client_info=client_info
)
return bigquery_client

View File

@ -0,0 +1,249 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
from ...tools.bigquery import client
def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]:
"""List BigQuery dataset ids in a Google Cloud project.
Args:
project_id (str): The Google Cloud project id.
credentials (Credentials): The credentials to use for the request.
Returns:
list[str]: List of the BigQuery dataset ids present in the project.
Examples:
>>> list_dataset_ids("bigquery-public-data")
['america_health_rankings',
'american_community_survey',
'aml_ai_input_dataset',
'austin_311',
'austin_bikeshare',
'austin_crime',
'austin_incidents',
'austin_waste',
'baseball',
'bbc_news']
"""
try:
bq_client = client.get_bigquery_client(credentials=credentials)
datasets = []
for dataset in bq_client.list_datasets(project_id):
datasets.append(dataset.dataset_id)
return datasets
except Exception as ex:
return {
"status": "ERROR",
"error_details": str(ex),
}
def get_dataset_info(
project_id: str, dataset_id: str, credentials: Credentials
) -> dict:
"""Get metadata information about a BigQuery dataset.
Args:
project_id (str): The Google Cloud project id containing the dataset.
dataset_id (str): The BigQuery dataset id.
credentials (Credentials): The credentials to use for the request.
Returns:
dict: Dictionary representing the properties of the dataset.
Examples:
>>> get_dataset_info("bigquery-public-data", "ml_datasets")
{
"kind": "bigquery#dataset",
"etag": "PNC5907iQbzeVcAru/2L3A==",
"id": "bigquery-public-data:ml_datasets",
"selfLink":
"https://bigquery.googleapis.com/bigquery/v2/projects/bigquery-public-data/datasets/ml_datasets",
"datasetReference": {
"datasetId": "ml_datasets",
"projectId": "bigquery-public-data"
},
"access": [
{
"role": "OWNER",
"groupByEmail": "cloud-datasets-eng@google.com"
},
{
"role": "READER",
"iamMember": "allUsers"
},
{
"role": "READER",
"groupByEmail": "bqml-eng@google.com"
}
],
"creationTime": "1553208775542",
"lastModifiedTime": "1686338918114",
"location": "US",
"type": "DEFAULT",
"maxTimeTravelHours": "168"
}
"""
try:
bq_client = client.get_bigquery_client(credentials=credentials)
dataset = bq_client.get_dataset(
bigquery.DatasetReference(project_id, dataset_id)
)
return dataset.to_api_repr()
except Exception as ex:
return {
"status": "ERROR",
"error_details": str(ex),
}
def list_table_ids(
project_id: str, dataset_id: str, credentials: Credentials
) -> list[str]:
"""List table ids in a BigQuery dataset.
Args:
project_id (str): The Google Cloud project id containing the dataset.
dataset_id (str): The BigQuery dataset id.
credentials (Credentials): The credentials to use for the request.
Returns:
list[str]: List of the table ids present in the dataset.
Examples:
>>> list_table_ids("bigquery-public-data", "ml_datasets")
['census_adult_income',
'credit_card_default',
'holidays_and_events_for_forecasting',
'iris',
'penguins',
'ulb_fraud_detection']
"""
try:
bq_client = client.get_bigquery_client(credentials=credentials)
tables = []
for table in bq_client.list_tables(
bigquery.DatasetReference(project_id, dataset_id)
):
tables.append(table.table_id)
return tables
except Exception as ex:
return {
"status": "ERROR",
"error_details": str(ex),
}
def get_table_info(
project_id: str, dataset_id: str, table_id: str, credentials: Credentials
) -> dict:
"""Get metadata information about a BigQuery table.
Args:
project_id (str): The Google Cloud project id containing the dataset.
dataset_id (str): The BigQuery dataset id containing the table.
table_id (str): The BigQuery table id.
credentials (Credentials): The credentials to use for the request.
Returns:
dict: Dictionary representing the properties of the table.
Examples:
>>> get_table_info("bigquery-public-data", "ml_datasets", "penguins")
{
"kind": "bigquery#table",
"etag": "X0ZkRohSGoYvWemRYEgOHA==",
"id": "bigquery-public-data:ml_datasets.penguins",
"selfLink":
"https://bigquery.googleapis.com/bigquery/v2/projects/bigquery-public-data/datasets/ml_datasets/tables/penguins",
"tableReference": {
"projectId": "bigquery-public-data",
"datasetId": "ml_datasets",
"tableId": "penguins"
},
"schema": {
"fields": [
{
"name": "species",
"type": "STRING",
"mode": "REQUIRED"
},
{
"name": "island",
"type": "STRING",
"mode": "NULLABLE"
},
{
"name": "culmen_length_mm",
"type": "FLOAT",
"mode": "NULLABLE"
},
{
"name": "culmen_depth_mm",
"type": "FLOAT",
"mode": "NULLABLE"
},
{
"name": "flipper_length_mm",
"type": "FLOAT",
"mode": "NULLABLE"
},
{
"name": "body_mass_g",
"type": "FLOAT",
"mode": "NULLABLE"
},
{
"name": "sex",
"type": "STRING",
"mode": "NULLABLE"
}
]
},
"numBytes": "28947",
"numLongTermBytes": "28947",
"numRows": "344",
"creationTime": "1619804743188",
"lastModifiedTime": "1634584675234",
"type": "TABLE",
"location": "US",
"numTimeTravelPhysicalBytes": "0",
"numTotalLogicalBytes": "28947",
"numActiveLogicalBytes": "0",
"numLongTermLogicalBytes": "28947",
"numTotalPhysicalBytes": "5350",
"numActivePhysicalBytes": "0",
"numLongTermPhysicalBytes": "5350",
"numCurrentPhysicalBytes": "5350"
}
"""
try:
bq_client = client.get_bigquery_client(credentials=credentials)
return bq_client.get_table(
bigquery.TableReference(
bigquery.DatasetReference(project_id, dataset_id), table_id
)
).to_api_repr()
except Exception as ex:
return {
"status": "ERROR",
"error_details": str(ex),
}

View File

@ -0,0 +1,76 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.oauth2.credentials import Credentials
from ...tools.bigquery import client
MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
"""Run a BigQuery SQL query in the project and return the result.
Args:
project_id (str): The GCP project id in which the query should be
executed.
query (str): The BigQuery SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
Returns:
dict: Dictionary representing the result of the query.
If the result contains the key "result_is_likely_truncated" with
value True, it means that there may be additional rows matching the
query not returned in the result.
Examples:
>>> execute_sql("bigframes-dev",
... "SELECT island, COUNT(*) AS population "
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
{
"rows": [
{
"island": "Dream",
"population": 124
},
{
"island": "Biscoe",
"population": 168
},
{
"island": "Torgersen",
"population": 52
}
]
}
"""
try:
bq_client = client.get_bigquery_client(credentials=credentials)
row_iterator = bq_client.query_and_wait(
query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS
)
rows = [{key: val for key, val in row.items()} for row in row_iterator]
result = {"rows": rows}
if (
MAX_DOWNLOADED_QUERY_RESULT_ROWS is not None
and len(rows) == MAX_DOWNLOADED_QUERY_RESULT_ROWS
):
result["result_is_likely_truncated"] = True
return result
except Exception as ex:
return {
"status": "ERROR",
"error_details": str(ex),
}
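Callers can use the truncation marker to decide whether to refine a query. A hypothetical consumer-side check, assuming the result shapes documented above:

```python
result = execute_sql(project_id, query, credentials)
if "status" in result:
  # The tool returned an error payload instead of rows.
  print(result["error_details"])
elif result.get("result_is_likely_truncated"):
  # Only the first MAX_DOWNLOADED_QUERY_RESULT_ROWS rows came back;
  # narrow the query (e.g., aggregate or add a LIMIT) and re-run.
  pass
```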

View File

@ -13,6 +13,7 @@
# limitations under the License.
+import json
from unittest.mock import Mock
from unittest.mock import patch
@ -98,12 +99,12 @@ class TestBigQueryCredentialsManager:
    manager.credentials_config.credentials = None
    # Create mock cached credentials JSON that would be stored in cache
-    mock_cached_creds_json = {
+    mock_cached_creds_json = json.dumps({
        "token": "cached_token",
        "refresh_token": "cached_refresh_token",
        "client_id": "test_client_id",
        "client_secret": "test_client_secret",
-    }
+    })
    # Set up the tool context state to contain cached credentials
    mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
@ -120,7 +121,7 @@ class TestBigQueryCredentialsManager:
    # Verify credentials were created from cached JSON
    mock_from_json.assert_called_once_with(
-        mock_cached_creds_json, manager.credentials_config.scopes
+        json.loads(mock_cached_creds_json), manager.credentials_config.scopes
    )
    # Verify loaded credentials were not cached into manager
    assert manager.credentials_config.credentials is None
@ -160,19 +161,19 @@ class TestBigQueryCredentialsManager:
    manager.credentials_config.credentials = None
    # Create mock cached credentials JSON
-    mock_cached_creds_json = {
+    mock_cached_creds_json = json.dumps({
        "token": "expired_token",
        "refresh_token": "valid_refresh_token",
        "client_id": "test_client_id",
        "client_secret": "test_client_secret",
-    }
+    })
-    mock_refreshed_creds_json = {
+    mock_refreshed_creds_json = json.dumps({
        "token": "new_token",
        "refresh_token": "valid_refresh_token",
        "client_id": "test_client_id",
        "client_secret": "test_client_secret",
-    }
+    })
    # Set up the tool context state to contain cached credentials
    mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
@ -200,7 +201,7 @@ class TestBigQueryCredentialsManager:
    # Verify credentials were created from cached JSON
    mock_from_json.assert_called_once_with(
-        mock_cached_creds_json, manager.credentials_config.scopes
+        json.loads(mock_cached_creds_json), manager.credentials_config.scopes
    )
    # Verify refresh was attempted and succeeded
    mock_cached_creds.refresh.assert_called_once()
@ -209,7 +210,9 @@ class TestBigQueryCredentialsManager:
    # Verify refreshed credentials were cached
    assert (
        "new_token"
-        == mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]["token"]
+        == json.loads(mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY])[
+            "token"
+        ]
    )
    assert result == mock_cached_creds

View File

@ -0,0 +1,123 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryTool
from google.adk.tools.bigquery import BigQueryToolset
import pytest
@pytest.mark.asyncio
async def test_bigquery_toolset_tools_default():
"""Test default BigQuery toolset.
This test verifies the behavior of the BigQuery toolset when no filter is
specified.
"""
credentials_config = BigQueryCredentialsConfig(
client_id="abc", client_secret="def"
)
toolset = BigQueryToolset(credentials_config=credentials_config)
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == 5
assert all([isinstance(tool, BigQueryTool) for tool in tools])
expected_tool_names = set([
"list_dataset_ids",
"get_dataset_info",
"list_table_ids",
"get_table_info",
"execute_sql",
])
actual_tool_names = set([tool.name for tool in tools])
assert actual_tool_names == expected_tool_names
@pytest.mark.parametrize(
"selected_tools",
[
pytest.param([], id="None"),
pytest.param(
["list_dataset_ids", "get_dataset_info"], id="dataset-metadata"
),
pytest.param(["list_table_ids", "get_table_info"], id="table-metadata"),
pytest.param(["execute_sql"], id="query"),
],
)
@pytest.mark.asyncio
async def test_bigquery_toolset_tools_selective(selected_tools):
"""Test BigQuery toolset with filter.
This test verifies the behavior of the BigQuery toolset when filter is
specified. A use case for this would be when the agent builder wants to
use only a subset of the tools provided by the toolset.
"""
credentials_config = BigQueryCredentialsConfig(
client_id="abc", client_secret="def"
)
toolset = BigQueryToolset(
credentials_config=credentials_config, tool_filter=selected_tools
)
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == len(selected_tools)
assert all([isinstance(tool, BigQueryTool) for tool in tools])
expected_tool_names = set(selected_tools)
actual_tool_names = set([tool.name for tool in tools])
assert actual_tool_names == expected_tool_names
@pytest.mark.parametrize(
("selected_tools", "returned_tools"),
[
pytest.param(["unknown"], [], id="all-unknown"),
pytest.param(
["unknown", "execute_sql"],
["execute_sql"],
id="mixed-known-unknown",
),
],
)
@pytest.mark.asyncio
async def test_bigquery_toolset_unknown_tool(
    selected_tools, returned_tools
):
  """Test BigQuery toolset with a filter containing unknown tool names.

  This test verifies that unknown tool names in the filter are silently
  ignored rather than raising an error.
  """
credentials_config = BigQueryCredentialsConfig(
client_id="abc", client_secret="def"
)
toolset = BigQueryToolset(
credentials_config=credentials_config, tool_filter=selected_tools
)
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == len(returned_tools)
assert all([isinstance(tool, BigQueryTool) for tool in tools])
expected_tool_names = set(returned_tools)
actual_tool_names = set([tool.name for tool in tools])
assert actual_tool_names == expected_tool_names