From d6c6bb4b2489a8b7a4713e4747c30d6df0c07961 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 28 May 2025 00:58:53 -0700 Subject: [PATCH] feat: add BigQuery first-party tools. These tools support getting BigQuery dataset/table metadata and query results. PiperOrigin-RevId: 764139132 --- contributing/samples/bigquery/README.md | 83 ++++++ contributing/samples/bigquery/__init__.py | 15 ++ contributing/samples/bigquery/agent.py | 58 ++++ src/google/adk/tools/bigquery/__init__.py | 12 +- .../tools/bigquery/bigquery_credentials.py | 3 +- .../adk/tools/bigquery/bigquery_toolset.py | 86 ++++++ src/google/adk/tools/bigquery/client.py | 33 +++ .../adk/tools/bigquery/metadata_tool.py | 249 ++++++++++++++++++ src/google/adk/tools/bigquery/query_tool.py | 76 ++++++ .../test_bigquery_credentials_manager.py | 21 +- .../bigquery_tool/test_bigquery_toolset.py | 123 +++++++++ 11 files changed, 748 insertions(+), 11 deletions(-) create mode 100644 contributing/samples/bigquery/README.md create mode 100644 contributing/samples/bigquery/__init__.py create mode 100644 contributing/samples/bigquery/agent.py create mode 100644 src/google/adk/tools/bigquery/bigquery_toolset.py create mode 100644 src/google/adk/tools/bigquery/client.py create mode 100644 src/google/adk/tools/bigquery/metadata_tool.py create mode 100644 src/google/adk/tools/bigquery/query_tool.py create mode 100644 tests/unittests/tools/bigquery_tool/test_bigquery_toolset.py diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md new file mode 100644 index 0000000..cd4583c --- /dev/null +++ b/contributing/samples/bigquery/README.md @@ -0,0 +1,83 @@ +# BigQuery Tools Sample + +## Introduction + +This sample agent demonstrates the BigQuery first-party tools in ADK, +distributed via the `google.adk.tools.bigquery` module. These tools include: + +1. `list_dataset_ids` + + Fetches BigQuery dataset ids present in a GCP project. + +1. 
`get_dataset_info` + + Fetches metadata about a BigQuery dataset. + +1. `list_table_ids` + + Fetches table ids present in a BigQuery dataset. + +1. `get_table_info` + + Fetches metadata about a BigQuery table. + +1. `execute_sql` + + Runs a SQL query in BigQuery. + +## How to use + +Set up environment variables in your `.env` file for using +[Google AI Studio](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-ai-studio) +or +[Google Cloud Vertex AI](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai) +for the LLM service for your agent. For example, for using Google AI Studio you +would set: + +* GOOGLE_GENAI_USE_VERTEXAI=FALSE +* GOOGLE_API_KEY={your api key} + +### With Application Default Credentials + +This mode is useful for quick development when the agent builder is the only +user interacting with the agent. The tools are initialized with the default +credentials present on the machine running the agent. + +1. Create application default credentials on the machine where the agent will +be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc. + +1. Set `RUN_WITH_ADC=True` in `agent.py` and run the agent + +### With Interactive OAuth + +1. Follow +https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name. +to get your client id and client secret. Be sure to choose "web" as your client +type. + +1. Follow https://developers.google.com/workspace/guides/configure-oauth-consent to add scope "https://www.googleapis.com/auth/bigquery". + +1. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs". + + Note: localhost here is just a hostname that you use to access the dev ui; + replace it with the actual hostname you use to access the dev ui. + +1. For the first run, allow pop-ups for localhost in Chrome. + +1. 
Configure your `.env` file to add two more variables before running the agent: + + * OAUTH_CLIENT_ID={your client id} + * OAUTH_CLIENT_SECRET={your client secret} + + Note: don't create a separate .env; instead, add them to the same .env file that + stores your Vertex AI or Google AI Studio credentials + +1. Set `RUN_WITH_ADC=False` in `agent.py` and run the agent + +## Sample prompts + +* which weather datasets exist in bigquery public data? +* tell me more about noaa_lightning +* which tables exist in the ml_datasets dataset? +* show more details about the penguins table +* compute penguins population per island. diff --git a/contributing/samples/bigquery/__init__.py b/contributing/samples/bigquery/__init__.py new file mode 100644 index 0000000..c48963c --- /dev/null +++ b/contributing/samples/bigquery/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/bigquery/agent.py b/contributing/samples/bigquery/agent.py new file mode 100644 index 0000000..81a0a18 --- /dev/null +++ b/contributing/samples/bigquery/agent.py @@ -0,0 +1,58 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from google.adk.agents import llm_agent +from google.adk.tools.bigquery import BigQueryCredentialsConfig +from google.adk.tools.bigquery import BigQueryToolset +import google.auth + + +RUN_WITH_ADC = False + + +if RUN_WITH_ADC: + # Initialize the tools to use the application default credentials. + application_default_credentials, _ = google.auth.default() + credentials_config = BigQueryCredentialsConfig( + credentials=application_default_credentials + ) +else: + # Initialize the tools to do interactive OAuth + # The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET + # must be set + credentials_config = BigQueryCredentialsConfig( + client_id=os.getenv("OAUTH_CLIENT_ID"), + client_secret=os.getenv("OAUTH_CLIENT_SECRET"), + scopes=["https://www.googleapis.com/auth/bigquery"], + ) + +bigquery_toolset = BigQueryToolset(credentials_config=credentials_config) + +# The variable name `root_agent` determines what your root agent is for the +# debug CLI +root_agent = llm_agent.Agent( + model="gemini-2.0-flash", + name="hello_agent", + description=( + "Agent to answer questions about BigQuery data and models and execute" + " SQL queries." + ), + instruction="""\ + You are a data science agent with access to several BigQuery tools. + Make use of those tools to answer the user's questions. 
+ """, + tools=[bigquery_toolset], +) diff --git a/src/google/adk/tools/bigquery/__init__.py b/src/google/adk/tools/bigquery/__init__.py index 72054bb..af3c776 100644 --- a/src/google/adk/tools/bigquery/__init__.py +++ b/src/google/adk/tools/bigquery/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""BigQuery Tools. (Experimental) +"""BigQuery Tools (Experimental). BigQuery Tools under this module are hand crafted and customized while the tools under google.adk.tools.google_api_tool are auto generated based on API @@ -26,3 +26,13 @@ definition. The rationales to have customized tool are: 4. We want to provide extra access guardrails in those tools. For example, execute_sql can't arbitrarily mutate existing data. """ + +from .bigquery_credentials import BigQueryCredentialsConfig +from .bigquery_tool import BigQueryTool +from .bigquery_toolset import BigQueryToolset + +__all__ = [ + "BigQueryTool", + "BigQueryToolset", + "BigQueryCredentialsConfig", +] diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index f334b40..8b3854a 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -14,6 +14,7 @@ from __future__ import annotations +import json from typing import List from typing import Optional @@ -121,7 +122,7 @@ class BigQueryCredentialsManager: creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None) creds = ( Credentials.from_authorized_user_info( - creds_json, self.credentials_config.scopes + json.loads(creds_json), self.credentials_config.scopes ) if creds_json else None diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py new file mode 100644 index 0000000..241c010 --- /dev/null +++ b/src/google/adk/tools/bigquery/bigquery_toolset.py @@ -0,0 +1,86 @@ +# Copyright 
2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import List +from typing import Optional +from typing import Union + +from google.adk.agents.readonly_context import ReadonlyContext +from typing_extensions import override + +from . import metadata_tool +from . import query_tool +from ...tools.base_tool import BaseTool +from ...tools.base_toolset import BaseToolset +from ...tools.base_toolset import ToolPredicate +from .bigquery_credentials import BigQueryCredentialsConfig +from .bigquery_tool import BigQueryTool + + +class BigQueryToolset(BaseToolset): + """BigQuery Toolset contains tools for interacting with BigQuery data and metadata.""" + + def __init__( + self, + *, + tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + credentials_config: Optional[BigQueryCredentialsConfig] = None, + ): + self._credentials_config = credentials_config + self.tool_filter = tool_filter + + def _is_tool_selected( + self, tool: BaseTool, readonly_context: ReadonlyContext + ) -> bool: + if self.tool_filter is None: + return True + + if isinstance(self.tool_filter, ToolPredicate): + return self.tool_filter(tool, readonly_context) + + if isinstance(self.tool_filter, list): + return tool.name in self.tool_filter + + return False + + @override + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> List[BaseTool]: + """Get tools from the toolset.""" + all_tools = [ + 
BigQueryTool( + func=func, + credentials=self._credentials_config, + ) + for func in [ + metadata_tool.get_dataset_info, + metadata_tool.get_table_info, + metadata_tool.list_dataset_ids, + metadata_tool.list_table_ids, + query_tool.execute_sql, + ] + ] + + return [ + tool + for tool in all_tools + if self._is_tool_selected(tool, readonly_context) + ] + + @override + async def close(self): + pass diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py new file mode 100644 index 0000000..d72761b --- /dev/null +++ b/src/google/adk/tools/bigquery/client.py @@ -0,0 +1,33 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import google.api_core.client_info +from google.cloud import bigquery +from google.oauth2.credentials import Credentials + +USER_AGENT = "adk-bigquery-tool" + + +def get_bigquery_client(*, credentials: Credentials) -> bigquery.Client: + """Get a BigQuery client.""" + + client_info = google.api_core.client_info.ClientInfo(user_agent=USER_AGENT) + + bigquery_client = bigquery.Client( + credentials=credentials, client_info=client_info + ) + + return bigquery_client diff --git a/src/google/adk/tools/bigquery/metadata_tool.py b/src/google/adk/tools/bigquery/metadata_tool.py new file mode 100644 index 0000000..c4b866d --- /dev/null +++ b/src/google/adk/tools/bigquery/metadata_tool.py @@ -0,0 +1,249 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud import bigquery +from google.oauth2.credentials import Credentials + +from ...tools.bigquery import client + + +def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]: + """List BigQuery dataset ids in a Google Cloud project. + + Args: + project_id (str): The Google Cloud project id. + credentials (Credentials): The credentials to use for the request. + + Returns: + list[str]: List of the BigQuery dataset ids present in the project. 
+ + Examples: + >>> list_dataset_ids("bigquery-public-data") + ['america_health_rankings', + 'american_community_survey', + 'aml_ai_input_dataset', + 'austin_311', + 'austin_bikeshare', + 'austin_crime', + 'austin_incidents', + 'austin_waste', + 'baseball', + 'bbc_news'] + """ + try: + bq_client = client.get_bigquery_client(credentials=credentials) + + datasets = [] + for dataset in bq_client.list_datasets(project_id): + datasets.append(dataset.dataset_id) + return datasets + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } + + +def get_dataset_info( + project_id: str, dataset_id: str, credentials: Credentials +) -> dict: + """Get metadata information about a BigQuery dataset. + + Args: + project_id (str): The Google Cloud project id containing the dataset. + dataset_id (str): The BigQuery dataset id. + credentials (Credentials): The credentials to use for the request. + + Returns: + dict: Dictionary representing the properties of the dataset. + + Examples: + >>> get_dataset_info("bigquery-public-data", "ml_datasets") + { + "kind": "bigquery#dataset", + "etag": "PNC5907iQbzeVcAru/2L3A==", + "id": "bigquery-public-data:ml_datasets", + "selfLink": + "https://bigquery.googleapis.com/bigquery/v2/projects/bigquery-public-data/datasets/ml_datasets", + "datasetReference": { + "datasetId": "ml_datasets", + "projectId": "bigquery-public-data" + }, + "access": [ + { + "role": "OWNER", + "groupByEmail": "cloud-datasets-eng@google.com" + }, + { + "role": "READER", + "iamMember": "allUsers" + }, + { + "role": "READER", + "groupByEmail": "bqml-eng@google.com" + } + ], + "creationTime": "1553208775542", + "lastModifiedTime": "1686338918114", + "location": "US", + "type": "DEFAULT", + "maxTimeTravelHours": "168" + } + """ + try: + bq_client = client.get_bigquery_client(credentials=credentials) + dataset = bq_client.get_dataset( + bigquery.DatasetReference(project_id, dataset_id) + ) + return dataset.to_api_repr() + except Exception as ex: + 
return { + "status": "ERROR", + "error_details": str(ex), + } + + +def list_table_ids( + project_id: str, dataset_id: str, credentials: Credentials +) -> list[str]: + """List table ids in a BigQuery dataset. + + Args: + project_id (str): The Google Cloud project id containing the dataset. + dataset_id (str): The BigQuery dataset id. + credentials (Credentials): The credentials to use for the request. + + Returns: + list[str]: List of the table ids present in the dataset. + + Examples: + >>> list_table_ids("bigquery-public-data", "ml_datasets") + ['census_adult_income', + 'credit_card_default', + 'holidays_and_events_for_forecasting', + 'iris', + 'penguins', + 'ulb_fraud_detection'] + """ + try: + bq_client = client.get_bigquery_client(credentials=credentials) + + tables = [] + for table in bq_client.list_tables( + bigquery.DatasetReference(project_id, dataset_id) + ): + tables.append(table.table_id) + return tables + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } + + +def get_table_info( + project_id: str, dataset_id: str, table_id: str, credentials: Credentials +) -> dict: + """Get metadata information about a BigQuery table. + + Args: + project_id (str): The Google Cloud project id containing the dataset. + dataset_id (str): The BigQuery dataset id containing the table. + table_id (str): The BigQuery table id. + credentials (Credentials): The credentials to use for the request. + + Returns: + dict: Dictionary representing the properties of the table. 
+ + Examples: + >>> get_table_info("bigquery-public-data", "ml_datasets", "penguins") + { + "kind": "bigquery#table", + "etag": "X0ZkRohSGoYvWemRYEgOHA==", + "id": "bigquery-public-data:ml_datasets.penguins", + "selfLink": + "https://bigquery.googleapis.com/bigquery/v2/projects/bigquery-public-data/datasets/ml_datasets/tables/penguins", + "tableReference": { + "projectId": "bigquery-public-data", + "datasetId": "ml_datasets", + "tableId": "penguins" + }, + "schema": { + "fields": [ + { + "name": "species", + "type": "STRING", + "mode": "REQUIRED" + }, + { + "name": "island", + "type": "STRING", + "mode": "NULLABLE" + }, + { + "name": "culmen_length_mm", + "type": "FLOAT", + "mode": "NULLABLE" + }, + { + "name": "culmen_depth_mm", + "type": "FLOAT", + "mode": "NULLABLE" + }, + { + "name": "flipper_length_mm", + "type": "FLOAT", + "mode": "NULLABLE" + }, + { + "name": "body_mass_g", + "type": "FLOAT", + "mode": "NULLABLE" + }, + { + "name": "sex", + "type": "STRING", + "mode": "NULLABLE" + } + ] + }, + "numBytes": "28947", + "numLongTermBytes": "28947", + "numRows": "344", + "creationTime": "1619804743188", + "lastModifiedTime": "1634584675234", + "type": "TABLE", + "location": "US", + "numTimeTravelPhysicalBytes": "0", + "numTotalLogicalBytes": "28947", + "numActiveLogicalBytes": "0", + "numLongTermLogicalBytes": "28947", + "numTotalPhysicalBytes": "5350", + "numActivePhysicalBytes": "0", + "numLongTermPhysicalBytes": "5350", + "numCurrentPhysicalBytes": "5350" + } + """ + try: + bq_client = client.get_bigquery_client(credentials=credentials) + return bq_client.get_table( + bigquery.TableReference( + bigquery.DatasetReference(project_id, dataset_id), table_id + ) + ).to_api_repr() + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py new file mode 100644 index 0000000..8144401 --- /dev/null +++ 
b/src/google/adk/tools/bigquery/query_tool.py @@ -0,0 +1,76 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.oauth2.credentials import Credentials + +from ...tools.bigquery import client + +MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50 + + +def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict: + """Run a BigQuery SQL query in the project and return the result. + + Args: + project_id (str): The GCP project id in which the query should be + executed. + query (str): The BigQuery SQL query to be executed. + credentials (Credentials): The credentials to use for the request. + + Returns: + dict: Dictionary representing the result of the query. + If the result contains the key "result_is_likely_truncated" with + value True, it means that there may be additional rows matching the + query not returned in the result. + + Examples: + >>> execute_sql("bigframes-dev", + ... "SELECT island, COUNT(*) AS population " + ... 
"FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + { + "rows": [ + { + "island": "Dream", + "population": 124 + }, + { + "island": "Biscoe", + "population": 168 + }, + { + "island": "Torgersen", + "population": 52 + } + ] + } + """ + + try: + bq_client = client.get_bigquery_client(credentials=credentials) + row_iterator = bq_client.query_and_wait( + query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS + ) + rows = [{key: val for key, val in row.items()} for row in row_iterator] + result = {"rows": rows} + if ( + MAX_DOWNLOADED_QUERY_RESULT_ROWS is not None + and len(rows) == MAX_DOWNLOADED_QUERY_RESULT_ROWS + ): + result["result_is_likely_truncated"] = True + return result + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py b/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py index e6fd38a..95d8b00 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py @@ -13,6 +13,7 @@ # limitations under the License. 
+import json from unittest.mock import Mock from unittest.mock import patch @@ -98,12 +99,12 @@ class TestBigQueryCredentialsManager: manager.credentials_config.credentials = None # Create mock cached credentials JSON that would be stored in cache - mock_cached_creds_json = { + mock_cached_creds_json = json.dumps({ "token": "cached_token", "refresh_token": "cached_refresh_token", "client_id": "test_client_id", "client_secret": "test_client_secret", - } + }) # Set up the tool context state to contain cached credentials mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json @@ -120,7 +121,7 @@ class TestBigQueryCredentialsManager: # Verify credentials were created from cached JSON mock_from_json.assert_called_once_with( - mock_cached_creds_json, manager.credentials_config.scopes + json.loads(mock_cached_creds_json), manager.credentials_config.scopes ) # Verify loaded credentials were not cached into manager assert manager.credentials_config.credentials is None @@ -160,19 +161,19 @@ class TestBigQueryCredentialsManager: manager.credentials_config.credentials = None # Create mock cached credentials JSON - mock_cached_creds_json = { + mock_cached_creds_json = json.dumps({ "token": "expired_token", "refresh_token": "valid_refresh_token", "client_id": "test_client_id", "client_secret": "test_client_secret", - } + }) - mock_refreshed_creds_json = { + mock_refreshed_creds_json = json.dumps({ "token": "new_token", "refresh_token": "valid_refresh_token", "client_id": "test_client_id", "client_secret": "test_client_secret", - } + }) # Set up the tool context state to contain cached credentials mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json @@ -200,7 +201,7 @@ class TestBigQueryCredentialsManager: # Verify credentials were created from cached JSON mock_from_json.assert_called_once_with( - mock_cached_creds_json, manager.credentials_config.scopes + json.loads(mock_cached_creds_json), manager.credentials_config.scopes ) # Verify 
refresh was attempted and succeeded mock_cached_creds.refresh.assert_called_once() @@ -209,7 +210,9 @@ class TestBigQueryCredentialsManager: # Verify refreshed credentials were cached assert ( "new_token" - == mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]["token"] + == json.loads(mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY])[ + "token" + ] ) assert result == mock_cached_creds diff --git a/tests/unittests/tools/bigquery_tool/test_bigquery_toolset.py b/tests/unittests/tools/bigquery_tool/test_bigquery_toolset.py new file mode 100644 index 0000000..ea9990b --- /dev/null +++ b/tests/unittests/tools/bigquery_tool/test_bigquery_toolset.py @@ -0,0 +1,123 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.tools.bigquery import BigQueryCredentialsConfig +from google.adk.tools.bigquery import BigQueryTool +from google.adk.tools.bigquery import BigQueryToolset +import pytest + + +@pytest.mark.asyncio +async def test_bigquery_toolset_tools_default(): + """Test default BigQuery toolset. + + This test verifies the behavior of the BigQuery toolset when no filter is + specified. 
+ """ + credentials_config = BigQueryCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = BigQueryToolset(credentials_config=credentials_config) + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == 5 + assert all([isinstance(tool, BigQueryTool) for tool in tools]) + + expected_tool_names = set([ + "list_dataset_ids", + "get_dataset_info", + "list_table_ids", + "get_table_info", + "execute_sql", + ]) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + "selected_tools", + [ + pytest.param([], id="None"), + pytest.param( + ["list_dataset_ids", "get_dataset_info"], id="dataset-metadata" + ), + pytest.param(["list_table_ids", "get_table_info"], id="table-metadata"), + pytest.param(["execute_sql"], id="query"), + ], +) +@pytest.mark.asyncio +async def test_bigquery_toolset_tools_selective(selected_tools): + """Test BigQuery toolset with filter. + + This test verifies the behavior of the BigQuery toolset when filter is + specified. A use case for this would be when the agent builder wants to + use only a subset of the tools provided by the toolset. 
+ """ + credentials_config = BigQueryCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = BigQueryToolset( + credentials_config=credentials_config, tool_filter=selected_tools + ) + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(selected_tools) + assert all([isinstance(tool, BigQueryTool) for tool in tools]) + + expected_tool_names = set(selected_tools) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + ("selected_tools", "returned_tools"), + [ + pytest.param(["unknown"], [], id="all-unknown"), + pytest.param( + ["unknown", "execute_sql"], + ["execute_sql"], + id="mixed-known-unknown", + ), + ], +) +@pytest.mark.asyncio +async def test_bigquery_toolset_unknown_tool_raises( + selected_tools, returned_tools +): + """Test BigQuery toolset with filter. + + This test verifies the behavior of the BigQuery toolset when filter is + specified with an unknown tool. + """ + credentials_config = BigQueryCredentialsConfig( + client_id="abc", client_secret="def" + ) + + toolset = BigQueryToolset( + credentials_config=credentials_config, tool_filter=selected_tools + ) + + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(returned_tools) + assert all([isinstance(tool, BigQueryTool) for tool in tools]) + + expected_tool_names = set(returned_tools) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names