feat: add BigQuery first-party tools.

These tools support getting BigQuery dataset/table metadata and query results.

PiperOrigin-RevId: 764139132
Google Team Member
2025-05-28 00:58:53 -07:00
committed by Copybara-Service
parent 46282eeb0d
commit d6c6bb4b24
11 changed files with 748 additions and 11 deletions

View File: google/adk/tools/bigquery/__init__.py

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""BigQuery Tools. (Experimental)
"""BigQuery Tools (Experimental).
BigQuery Tools under this module are hand-crafted and customized, while the tools
under google.adk.tools.google_api_tool are auto-generated based on API
@@ -26,3 +26,13 @@ definition. The rationales to have customized tool are:
4. We want to provide extra access guardrails in those tools. For example,
execute_sql can't arbitrarily mutate existing data.
"""
from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_tool import BigQueryTool
from .bigquery_toolset import BigQueryToolset
__all__ = [
"BigQueryTool",
"BigQueryToolset",
"BigQueryCredentialsConfig",
]
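
For orientation, a minimal wiring sketch (not part of this commit). It assumes the ADK Agent class accepts toolsets directly in tools, uses placeholder agent and model names, and omits credentials_config (optional per the toolset signature further below) because BigQueryCredentialsConfig's fields are not shown in this diff; in practice a credentials config would normally be supplied.

# Minimal wiring sketch, not part of this commit. Agent name and model id are
# placeholders; credentials_config is omitted because its fields are not shown here.
from google.adk.agents import Agent
from google.adk.tools.bigquery import BigQueryToolset

bigquery_toolset = BigQueryToolset()

root_agent = Agent(
    name="bigquery_agent",      # placeholder
    model="gemini-2.0-flash",   # placeholder
    instruction="Answer questions about BigQuery datasets and tables.",
    tools=[bigquery_toolset],
)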

View File: google/adk/tools/bigquery/bigquery_credentials.py

@@ -14,6 +14,7 @@
from __future__ import annotations
import json
from typing import List
from typing import Optional
@@ -121,7 +122,7 @@ class BigQueryCredentialsManager:
creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None)
creds = (
Credentials.from_authorized_user_info(
creds_json, self.credentials_config.scopes
json.loads(creds_json), self.credentials_config.scopes
)
if creds_json
else None
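
The json.loads change above matters because the cached value is the JSON string produced by Credentials.to_json(), while Credentials.from_authorized_user_info expects a dict. A standalone round-trip sketch (token values are fake placeholders):

# Standalone sketch of the cache round trip the fix enables; token values are fake.
import json
from google.oauth2.credentials import Credentials

scopes = ["https://www.googleapis.com/auth/bigquery"]
creds = Credentials(
    token="fake-access-token",
    refresh_token="fake-refresh-token",
    token_uri="https://oauth2.googleapis.com/token",
    client_id="fake-client-id",
    client_secret="fake-client-secret",
    scopes=scopes,
)

cached = creds.to_json()  # what gets stored in tool_context.state: a JSON *string*

# Before the fix, from_authorized_user_info received this str instead of a dict.
restored = Credentials.from_authorized_user_info(json.loads(cached), scopes)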

View File: google/adk/tools/bigquery/bigquery_toolset.py

@@ -0,0 +1,86 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
from google.adk.agents.readonly_context import ReadonlyContext
from typing_extensions import override
from . import metadata_tool
from . import query_tool
from ...tools.base_tool import BaseTool
from ...tools.base_toolset import BaseToolset
from ...tools.base_toolset import ToolPredicate
from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_tool import BigQueryTool
class BigQueryToolset(BaseToolset):
"""BigQuery Toolset contains tools for interacting with BigQuery data and metadata."""
def __init__(
self,
*,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
credentials_config: Optional[BigQueryCredentialsConfig] = None,
):
self._credentials_config = credentials_config
self.tool_filter = tool_filter
def _is_tool_selected(
self, tool: BaseTool, readonly_context: ReadonlyContext
) -> bool:
if self.tool_filter is None:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False
@override
async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> List[BaseTool]:
"""Get tools from the toolset."""
all_tools = [
BigQueryTool(
func=func,
credentials=self._credentials_config,
)
for func in [
metadata_tool.get_dataset_info,
metadata_tool.get_table_info,
metadata_tool.list_dataset_ids,
metadata_tool.list_table_ids,
query_tool.execute_sql,
]
]
return [
tool
for tool in all_tools
if self._is_tool_selected(tool, readonly_context)
]
@override
async def close(self):
pass
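
A usage sketch for the filter above (not part of this commit), assuming tool names mirror the wrapped function names; tool_filter also accepts a ToolPredicate callable instead of a name list.

# Sketch, not part of this commit: restrict the toolset to the metadata tools by
# name and inspect what get_tools returns. Tool names are expected to mirror the
# wrapped function names listed in get_tools above.
import asyncio

from google.adk.tools.bigquery import BigQueryToolset

metadata_only = BigQueryToolset(
    tool_filter=[
        "list_dataset_ids",
        "list_table_ids",
        "get_dataset_info",
        "get_table_info",
    ]
)

async def show_tool_names() -> None:
    tools = await metadata_only.get_tools()
    print(sorted(tool.name for tool in tools))  # execute_sql is filtered out

asyncio.run(show_tool_names())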

View File: google/adk/tools/bigquery/client.py

@@ -0,0 +1,33 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import google.api_core.client_info
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
USER_AGENT = "adk-bigquery-tool"
def get_bigquery_client(*, credentials: Credentials) -> bigquery.Client:
"""Get a BigQuery client."""
client_info = google.api_core.client_info.ClientInfo(user_agent=USER_AGENT)
bigquery_client = bigquery.Client(
credentials=credentials, client_info=client_info
)
return bigquery_client
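
For ad-hoc testing outside the toolset, the helper can also be fed Application Default Credentials (a sketch, not part of this commit; it assumes ADC is already configured locally and relies on bigquery.Client accepting any google-auth credentials object, despite the user-credential annotation above).

# Sketch, not part of this commit: build the client with Application Default
# Credentials instead of the OAuth flow managed by the toolset. Requires
# `gcloud auth application-default login` (or equivalent) beforehand.
import google.auth
from google.adk.tools.bigquery.client import get_bigquery_client

adc_credentials, _ = google.auth.default(
    scopes=["https://www.googleapis.com/auth/bigquery"]
)
bq_client = get_bigquery_client(credentials=adc_credentials)

# Simple smoke test: list a few public datasets.
for dataset in bq_client.list_datasets("bigquery-public-data", max_results=3):
    print(dataset.dataset_id)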

View File: google/adk/tools/bigquery/metadata_tool.py

@@ -0,0 +1,249 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
from ...tools.bigquery import client
def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]:
"""List BigQuery dataset ids in a Google Cloud project.
Args:
project_id (str): The Google Cloud project id.
credentials (Credentials): The credentials to use for the request.
Returns:
list[str]: List of the BigQuery dataset ids present in the project.
Examples:
>>> list_dataset_ids("bigquery-public-data")
['america_health_rankings',
'american_community_survey',
'aml_ai_input_dataset',
'austin_311',
'austin_bikeshare',
'austin_crime',
'austin_incidents',
'austin_waste',
'baseball',
'bbc_news']
"""
try:
bq_client = client.get_bigquery_client(credentials=credentials)
datasets = []
for dataset in bq_client.list_datasets(project_id):
datasets.append(dataset.dataset_id)
return datasets
except Exception as ex:
return {
"status": "ERROR",
"error_details": str(ex),
}
def get_dataset_info(
project_id: str, dataset_id: str, credentials: Credentials
) -> dict:
"""Get metadata information about a BigQuery dataset.
Args:
project_id (str): The Google Cloud project id containing the dataset.
dataset_id (str): The BigQuery dataset id.
credentials (Credentials): The credentials to use for the request.
Returns:
dict: Dictionary representing the properties of the dataset.
Examples:
>>> get_dataset_info("bigquery-public-data", "penguins")
{
"kind": "bigquery#dataset",
"etag": "PNC5907iQbzeVcAru/2L3A==",
"id": "bigquery-public-data:ml_datasets",
"selfLink":
"https://bigquery.googleapis.com/bigquery/v2/projects/bigquery-public-data/datasets/ml_datasets",
"datasetReference": {
"datasetId": "ml_datasets",
"projectId": "bigquery-public-data"
},
"access": [
{
"role": "OWNER",
"groupByEmail": "cloud-datasets-eng@google.com"
},
{
"role": "READER",
"iamMember": "allUsers"
},
{
"role": "READER",
"groupByEmail": "bqml-eng@google.com"
}
],
"creationTime": "1553208775542",
"lastModifiedTime": "1686338918114",
"location": "US",
"type": "DEFAULT",
"maxTimeTravelHours": "168"
}
"""
try:
bq_client = client.get_bigquery_client(credentials=credentials)
dataset = bq_client.get_dataset(
bigquery.DatasetReference(project_id, dataset_id)
)
return dataset.to_api_repr()
except Exception as ex:
return {
"status": "ERROR",
"error_details": str(ex),
}
def list_table_ids(
project_id: str, dataset_id: str, credentials: Credentials
) -> list[str]:
"""List table ids in a BigQuery dataset.
Args:
project_id (str): The Google Cloud project id containing the dataset.
dataset_id (str): The BigQuery dataset id.
credentials (Credentials): The credentials to use for the request.
Returns:
list[str]: List of the table ids present in the dataset.
Examples:
>>> list_table_ids("bigquery-public-data", "ml_datasets")
['census_adult_income',
'credit_card_default',
'holidays_and_events_for_forecasting',
'iris',
'penguins',
'ulb_fraud_detection']
"""
try:
bq_client = client.get_bigquery_client(credentials=credentials)
tables = []
for table in bq_client.list_tables(
bigquery.DatasetReference(project_id, dataset_id)
):
tables.append(table.table_id)
return tables
except Exception as ex:
return {
"status": "ERROR",
"error_details": str(ex),
}
def get_table_info(
project_id: str, dataset_id: str, table_id: str, credentials: Credentials
) -> dict:
"""Get metadata information about a BigQuery table.
Args:
project_id (str): The Google Cloud project id containing the dataset.
dataset_id (str): The BigQuery dataset id containing the table.
table_id (str): The BigQuery table id.
credentials (Credentials): The credentials to use for the request.
Returns:
dict: Dictionary representing the properties of the table.
Examples:
>>> get_table_info("bigquery-public-data", "ml_datasets", "penguins")
{
"kind": "bigquery#table",
"etag": "X0ZkRohSGoYvWemRYEgOHA==",
"id": "bigquery-public-data:ml_datasets.penguins",
"selfLink":
"https://bigquery.googleapis.com/bigquery/v2/projects/bigquery-public-data/datasets/ml_datasets/tables/penguins",
"tableReference": {
"projectId": "bigquery-public-data",
"datasetId": "ml_datasets",
"tableId": "penguins"
},
"schema": {
"fields": [
{
"name": "species",
"type": "STRING",
"mode": "REQUIRED"
},
{
"name": "island",
"type": "STRING",
"mode": "NULLABLE"
},
{
"name": "culmen_length_mm",
"type": "FLOAT",
"mode": "NULLABLE"
},
{
"name": "culmen_depth_mm",
"type": "FLOAT",
"mode": "NULLABLE"
},
{
"name": "flipper_length_mm",
"type": "FLOAT",
"mode": "NULLABLE"
},
{
"name": "body_mass_g",
"type": "FLOAT",
"mode": "NULLABLE"
},
{
"name": "sex",
"type": "STRING",
"mode": "NULLABLE"
}
]
},
"numBytes": "28947",
"numLongTermBytes": "28947",
"numRows": "344",
"creationTime": "1619804743188",
"lastModifiedTime": "1634584675234",
"type": "TABLE",
"location": "US",
"numTimeTravelPhysicalBytes": "0",
"numTotalLogicalBytes": "28947",
"numActiveLogicalBytes": "0",
"numLongTermLogicalBytes": "28947",
"numTotalPhysicalBytes": "5350",
"numActivePhysicalBytes": "0",
"numLongTermPhysicalBytes": "5350",
"numCurrentPhysicalBytes": "5350"
}
"""
try:
bq_client = client.get_bigquery_client(credentials=credentials)
return bq_client.get_table(
bigquery.TableReference(
bigquery.DatasetReference(project_id, dataset_id), table_id
)
).to_api_repr()
except Exception as ex:
return {
"status": "ERROR",
"error_details": str(ex),
}
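
Note that these functions return an error dict ({"status": "ERROR", ...}) instead of raising. A hedged sketch of a caller that branches on that contract (safe_list_dataset_ids is a hypothetical helper, not part of this commit):

# Sketch, not part of this commit: safe_list_dataset_ids is a hypothetical helper
# illustrating the error contract of the functions above. `credentials` is any
# google-auth credentials object accepted by the BigQuery client.
from google.adk.tools.bigquery import metadata_tool

def safe_list_dataset_ids(project_id: str, credentials) -> list[str]:
    """Return dataset ids, or an empty list if the metadata call reported an error."""
    result = metadata_tool.list_dataset_ids(project_id, credentials)
    if isinstance(result, dict) and result.get("status") == "ERROR":
        print("list_dataset_ids failed:", result["error_details"])
        return []
    return result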

View File: google/adk/tools/bigquery/query_tool.py

@@ -0,0 +1,76 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.oauth2.credentials import Credentials
from ...tools.bigquery import client
MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
"""Run a BigQuery SQL query in the project and return the result.
Args:
project_id (str): The GCP project id in which the query should be
executed.
query (str): The BigQuery SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
Returns:
dict: Dictionary representing the result of the query.
If the result contains the key "result_is_likely_truncated" with
value True, it means that there may be additional rows matching the
query not returned in the result.
Examples:
>>> execute_sql("bigframes-dev",
... "SELECT island, COUNT(*) AS population "
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
{
"rows": [
{
"island": "Dream",
"population": 124
},
{
"island": "Biscoe",
"population": 168
},
{
"island": "Torgersen",
"population": 52
}
]
}
"""
try:
bq_client = client.get_bigquery_client(credentials=credentials)
row_iterator = bq_client.query_and_wait(
query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS
)
rows = [{key: val for key, val in row.items()} for row in row_iterator]
result = {"rows": rows}
if (
MAX_DOWNLOADED_QUERY_RESULT_ROWS is not None
and len(rows) == MAX_DOWNLOADED_QUERY_RESULT_ROWS
):
result["result_is_likely_truncated"] = True
return result
except Exception as ex:
return {
"status": "ERROR",
"error_details": str(ex),
}
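
Because results are capped at MAX_DOWNLOADED_QUERY_RESULT_ROWS, callers can watch for the result_is_likely_truncated flag. A hedged sketch (run_query_checked is a hypothetical wrapper, not part of this commit):

# Sketch, not part of this commit: run_query_checked is a hypothetical wrapper
# that surfaces both the error dict and the truncation flag produced by
# execute_sql above. `credentials` is any google-auth credentials object
# accepted by the BigQuery client.
from google.adk.tools.bigquery import query_tool

def run_query_checked(project_id: str, query: str, credentials) -> list[dict]:
    result = query_tool.execute_sql(project_id, query, credentials)
    if result.get("status") == "ERROR":
        raise RuntimeError(f"query failed: {result['error_details']}")
    if result.get("result_is_likely_truncated"):
        # Only the first MAX_DOWNLOADED_QUERY_RESULT_ROWS (50) rows were downloaded;
        # narrow or aggregate the query to see the complete result.
        print("warning: query result likely truncated at 50 rows")
    return result["rows"]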