mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 23:17:35 -06:00
feat: add BigQuery first-party tools.
These tools support getting BigQuery dataset/table metadata and query results. PiperOrigin-RevId: 764139132
This commit is contained in:
parent
46282eeb0d
commit
d6c6bb4b24
83
contributing/samples/bigquery/README.md
Normal file
83
contributing/samples/bigquery/README.md
Normal file
@ -0,0 +1,83 @@
|
||||
# BigQuery Tools Sample
|
||||
|
||||
## Introduction
|
||||
|
||||
This sample agent demonstrates the BigQuery first-party tools in ADK,
|
||||
distributed via the `google.adk.tools.bigquery` module. These tools include:
|
||||
|
||||
1. `list_dataset_ids`
|
||||
|
||||
Fetches BigQuery dataset ids present in a GCP project.
|
||||
|
||||
1. `get_dataset_info`
|
||||
|
||||
Fetches metadata about a BigQuery dataset.
|
||||
|
||||
1. `list_table_ids`
|
||||
|
||||
Fetches table ids present in a BigQuery dataset.
|
||||
|
||||
1. `get_table_info`
|
||||
|
||||
Fetches metadata about a BigQuery table.
|
||||
|
||||
1. `execute_sql`
|
||||
|
||||
Runs a SQL query in BigQuery.
|
||||
|
||||
## How to use
|
||||
|
||||
Set up environment variables in your `.env` file for using
|
||||
[Google AI Studio](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-ai-studio)
|
||||
or
|
||||
[Google Cloud Vertex AI](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai)
|
||||
for the LLM service for your agent. For example, for using Google AI Studio you
|
||||
would set:
|
||||
|
||||
* GOOGLE_GENAI_USE_VERTEXAI=FALSE
|
||||
* GOOGLE_API_KEY={your api key}
|
||||
|
||||
### With Application Default Credentials
|
||||
|
||||
This mode is useful for quick development when the agent builder is the only
|
||||
user interacting with the agent. The tools are initialized with the default
|
||||
credentials present on the machine running the agent.
|
||||
|
||||
1. Create application default credentials on the machine where the agent would
|
||||
be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc.
|
||||
|
||||
1. Set `RUN_WITH_ADC=True` in `agent.py` and run the agent
|
||||
|
||||
### With Interactive OAuth
|
||||
|
||||
1. Follow
|
||||
https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name.
|
||||
to get your client id and client secret. Be sure to choose "web" as your client
|
||||
type.
|
||||
|
||||
1. Follow https://developers.google.com/workspace/guides/configure-oauth-consent to add scope "https://www.googleapis.com/auth/bigquery".
|
||||
|
||||
1. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs".
|
||||
|
||||
Note: localhost here is just a hostname that you use to access the dev ui,
|
||||
replace it with the actual hostname you use to access the dev ui.
|
||||
|
||||
1. For 1st run, allow popup for localhost in Chrome.
|
||||
|
||||
1. Configure your `.env` file to add two more variables before running the agent:
|
||||
|
||||
* OAUTH_CLIENT_ID={your client id}
|
||||
* OAUTH_CLIENT_SECRET={your client secret}
|
||||
|
||||
Note: don't create a separate .env, instead put it to the same .env file that
|
||||
stores your Vertex AI or Dev ML credentials
|
||||
|
||||
1. Set `RUN_WITH_ADC=False` in `agent.py` and run the agent
|
||||
|
||||
## Sample prompts
|
||||
|
||||
* which weather datasets exist in bigquery public data?
|
||||
* tell me more about noaa_lightning
|
||||
* which tables exist in the ml_datasets dataset?
|
||||
* show more details about the penguins table
|
||||
* compute penguins population per island.
|
15
contributing/samples/bigquery/__init__.py
Normal file
15
contributing/samples/bigquery/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
58
contributing/samples/bigquery/agent.py
Normal file
58
contributing/samples/bigquery/agent.py
Normal file
@ -0,0 +1,58 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from google.adk.agents import llm_agent
|
||||
from google.adk.tools.bigquery import BigQueryCredentialsConfig
|
||||
from google.adk.tools.bigquery import BigQueryToolset
|
||||
import google.auth
|
||||
|
||||
|
||||
RUN_WITH_ADC = False
|
||||
|
||||
|
||||
if RUN_WITH_ADC:
|
||||
# Initialize the tools to use the application default credentials.
|
||||
application_default_credentials, _ = google.auth.default()
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
credentials=application_default_credentials
|
||||
)
|
||||
else:
|
||||
# Initiaze the tools to do interactive OAuth
|
||||
# The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET
|
||||
# must be set
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
client_id=os.getenv("OAUTH_CLIENT_ID"),
|
||||
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
|
||||
scopes=["https://www.googleapis.com/auth/bigquery"],
|
||||
)
|
||||
|
||||
bigquery_toolset = BigQueryToolset(credentials_config=credentials_config)
|
||||
|
||||
# The variable name `root_agent` determines what your root agent is for the
|
||||
# debug CLI
|
||||
root_agent = llm_agent.Agent(
|
||||
model="gemini-2.0-flash",
|
||||
name="hello_agent",
|
||||
description=(
|
||||
"Agent to answer questions about BigQuery data and models and execute"
|
||||
" SQL queries."
|
||||
),
|
||||
instruction="""\
|
||||
You are a data science agent with access to several BigQuery tools.
|
||||
Make use of those tools to answer the user's questions.
|
||||
""",
|
||||
tools=[bigquery_toolset],
|
||||
)
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""BigQuery Tools. (Experimental)
|
||||
"""BigQuery Tools (Experimental).
|
||||
|
||||
BigQuery Tools under this module are hand crafted and customized while the tools
|
||||
under google.adk.tools.google_api_tool are auto generated based on API
|
||||
@ -26,3 +26,13 @@ definition. The rationales to have customized tool are:
|
||||
4. We want to provide extra access guardrails in those tools. For example,
|
||||
execute_sql can't arbitrarily mutate existing data.
|
||||
"""
|
||||
|
||||
from .bigquery_credentials import BigQueryCredentialsConfig
|
||||
from .bigquery_tool import BigQueryTool
|
||||
from .bigquery_toolset import BigQueryToolset
|
||||
|
||||
__all__ = [
|
||||
"BigQueryTool",
|
||||
"BigQueryToolset",
|
||||
"BigQueryCredentialsConfig",
|
||||
]
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
@ -121,7 +122,7 @@ class BigQueryCredentialsManager:
|
||||
creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None)
|
||||
creds = (
|
||||
Credentials.from_authorized_user_info(
|
||||
creds_json, self.credentials_config.scopes
|
||||
json.loads(creds_json), self.credentials_config.scopes
|
||||
)
|
||||
if creds_json
|
||||
else None
|
||||
|
86
src/google/adk/tools/bigquery/bigquery_toolset.py
Normal file
86
src/google/adk/tools/bigquery/bigquery_toolset.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
from typing_extensions import override
|
||||
|
||||
from . import metadata_tool
|
||||
from . import query_tool
|
||||
from ...tools.base_tool import BaseTool
|
||||
from ...tools.base_toolset import BaseToolset
|
||||
from ...tools.base_toolset import ToolPredicate
|
||||
from .bigquery_credentials import BigQueryCredentialsConfig
|
||||
from .bigquery_tool import BigQueryTool
|
||||
|
||||
|
||||
class BigQueryToolset(BaseToolset):
|
||||
"""BigQuery Toolset contains tools for interacting with BigQuery data and metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
credentials_config: Optional[BigQueryCredentialsConfig] = None,
|
||||
):
|
||||
self._credentials_config = credentials_config
|
||||
self.tool_filter = tool_filter
|
||||
|
||||
def _is_tool_selected(
|
||||
self, tool: BaseTool, readonly_context: ReadonlyContext
|
||||
) -> bool:
|
||||
if self.tool_filter is None:
|
||||
return True
|
||||
|
||||
if isinstance(self.tool_filter, ToolPredicate):
|
||||
return self.tool_filter(tool, readonly_context)
|
||||
|
||||
if isinstance(self.tool_filter, list):
|
||||
return tool.name in self.tool_filter
|
||||
|
||||
return False
|
||||
|
||||
@override
|
||||
async def get_tools(
|
||||
self, readonly_context: Optional[ReadonlyContext] = None
|
||||
) -> List[BaseTool]:
|
||||
"""Get tools from the toolset."""
|
||||
all_tools = [
|
||||
BigQueryTool(
|
||||
func=func,
|
||||
credentials=self._credentials_config,
|
||||
)
|
||||
for func in [
|
||||
metadata_tool.get_dataset_info,
|
||||
metadata_tool.get_table_info,
|
||||
metadata_tool.list_dataset_ids,
|
||||
metadata_tool.list_table_ids,
|
||||
query_tool.execute_sql,
|
||||
]
|
||||
]
|
||||
|
||||
return [
|
||||
tool
|
||||
for tool in all_tools
|
||||
if self._is_tool_selected(tool, readonly_context)
|
||||
]
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
pass
|
33
src/google/adk/tools/bigquery/client.py
Normal file
33
src/google/adk/tools/bigquery/client.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import google.api_core.client_info
|
||||
from google.cloud import bigquery
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
USER_AGENT = "adk-bigquery-tool"
|
||||
|
||||
|
||||
def get_bigquery_client(*, credentials: Credentials) -> bigquery.Client:
|
||||
"""Get a BigQuery client."""
|
||||
|
||||
client_info = google.api_core.client_info.ClientInfo(user_agent=USER_AGENT)
|
||||
|
||||
bigquery_client = bigquery.Client(
|
||||
credentials=credentials, client_info=client_info
|
||||
)
|
||||
|
||||
return bigquery_client
|
249
src/google/adk/tools/bigquery/metadata_tool.py
Normal file
249
src/google/adk/tools/bigquery/metadata_tool.py
Normal file
@ -0,0 +1,249 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.cloud import bigquery
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from ...tools.bigquery import client
|
||||
|
||||
|
||||
def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]:
|
||||
"""List BigQuery dataset ids in a Google Cloud project.
|
||||
|
||||
Args:
|
||||
project_id (str): The Google Cloud project id.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
|
||||
Returns:
|
||||
list[str]: List of the BigQuery dataset ids present in the project.
|
||||
|
||||
Examples:
|
||||
>>> list_dataset_ids("bigquery-public-data")
|
||||
['america_health_rankings',
|
||||
'american_community_survey',
|
||||
'aml_ai_input_dataset',
|
||||
'austin_311',
|
||||
'austin_bikeshare',
|
||||
'austin_crime',
|
||||
'austin_incidents',
|
||||
'austin_waste',
|
||||
'baseball',
|
||||
'bbc_news']
|
||||
"""
|
||||
try:
|
||||
bq_client = client.get_bigquery_client(credentials=credentials)
|
||||
|
||||
datasets = []
|
||||
for dataset in bq_client.list_datasets(project_id):
|
||||
datasets.append(dataset.dataset_id)
|
||||
return datasets
|
||||
except Exception as ex:
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": str(ex),
|
||||
}
|
||||
|
||||
|
||||
def get_dataset_info(
|
||||
project_id: str, dataset_id: str, credentials: Credentials
|
||||
) -> dict:
|
||||
"""Get metadata information about a BigQuery dataset.
|
||||
|
||||
Args:
|
||||
project_id (str): The Google Cloud project id containing the dataset.
|
||||
dataset_id (str): The BigQuery dataset id.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary representing the properties of the dataset.
|
||||
|
||||
Examples:
|
||||
>>> get_dataset_info("bigquery-public-data", "penguins")
|
||||
{
|
||||
"kind": "bigquery#dataset",
|
||||
"etag": "PNC5907iQbzeVcAru/2L3A==",
|
||||
"id": "bigquery-public-data:ml_datasets",
|
||||
"selfLink":
|
||||
"https://bigquery.googleapis.com/bigquery/v2/projects/bigquery-public-data/datasets/ml_datasets",
|
||||
"datasetReference": {
|
||||
"datasetId": "ml_datasets",
|
||||
"projectId": "bigquery-public-data"
|
||||
},
|
||||
"access": [
|
||||
{
|
||||
"role": "OWNER",
|
||||
"groupByEmail": "cloud-datasets-eng@google.com"
|
||||
},
|
||||
{
|
||||
"role": "READER",
|
||||
"iamMember": "allUsers"
|
||||
},
|
||||
{
|
||||
"role": "READER",
|
||||
"groupByEmail": "bqml-eng@google.com"
|
||||
}
|
||||
],
|
||||
"creationTime": "1553208775542",
|
||||
"lastModifiedTime": "1686338918114",
|
||||
"location": "US",
|
||||
"type": "DEFAULT",
|
||||
"maxTimeTravelHours": "168"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
bq_client = client.get_bigquery_client(credentials=credentials)
|
||||
dataset = bq_client.get_dataset(
|
||||
bigquery.DatasetReference(project_id, dataset_id)
|
||||
)
|
||||
return dataset.to_api_repr()
|
||||
except Exception as ex:
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": str(ex),
|
||||
}
|
||||
|
||||
|
||||
def list_table_ids(
|
||||
project_id: str, dataset_id: str, credentials: Credentials
|
||||
) -> list[str]:
|
||||
"""List table ids in a BigQuery dataset.
|
||||
|
||||
Args:
|
||||
project_id (str): The Google Cloud project id containing the dataset.
|
||||
dataset_id (str): The BigQuery dataset id.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
|
||||
Returns:
|
||||
list[str]: List of the tables ids present in the dataset.
|
||||
|
||||
Examples:
|
||||
>>> list_table_ids("bigquery-public-data", "ml_datasets")
|
||||
['census_adult_income',
|
||||
'credit_card_default',
|
||||
'holidays_and_events_for_forecasting',
|
||||
'iris',
|
||||
'penguins',
|
||||
'ulb_fraud_detection']
|
||||
"""
|
||||
try:
|
||||
bq_client = client.get_bigquery_client(credentials=credentials)
|
||||
|
||||
tables = []
|
||||
for table in bq_client.list_tables(
|
||||
bigquery.DatasetReference(project_id, dataset_id)
|
||||
):
|
||||
tables.append(table.table_id)
|
||||
return tables
|
||||
except Exception as ex:
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": str(ex),
|
||||
}
|
||||
|
||||
|
||||
def get_table_info(
|
||||
project_id: str, dataset_id: str, table_id: str, credentials: Credentials
|
||||
) -> dict:
|
||||
"""Get metadata information about a BigQuery table.
|
||||
|
||||
Args:
|
||||
project_id (str): The Google Cloud project id containing the dataset.
|
||||
dataset_id (str): The BigQuery dataset id containing the table.
|
||||
table_id (str): The BigQuery table id.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary representing the properties of the table.
|
||||
|
||||
Examples:
|
||||
>>> get_table_info("bigquery-public-data", "ml_datasets", "penguins")
|
||||
{
|
||||
"kind": "bigquery#table",
|
||||
"etag": "X0ZkRohSGoYvWemRYEgOHA==",
|
||||
"id": "bigquery-public-data:ml_datasets.penguins",
|
||||
"selfLink":
|
||||
"https://bigquery.googleapis.com/bigquery/v2/projects/bigquery-public-data/datasets/ml_datasets/tables/penguins",
|
||||
"tableReference": {
|
||||
"projectId": "bigquery-public-data",
|
||||
"datasetId": "ml_datasets",
|
||||
"tableId": "penguins"
|
||||
},
|
||||
"schema": {
|
||||
"fields": [
|
||||
{
|
||||
"name": "species",
|
||||
"type": "STRING",
|
||||
"mode": "REQUIRED"
|
||||
},
|
||||
{
|
||||
"name": "island",
|
||||
"type": "STRING",
|
||||
"mode": "NULLABLE"
|
||||
},
|
||||
{
|
||||
"name": "culmen_length_mm",
|
||||
"type": "FLOAT",
|
||||
"mode": "NULLABLE"
|
||||
},
|
||||
{
|
||||
"name": "culmen_depth_mm",
|
||||
"type": "FLOAT",
|
||||
"mode": "NULLABLE"
|
||||
},
|
||||
{
|
||||
"name": "flipper_length_mm",
|
||||
"type": "FLOAT",
|
||||
"mode": "NULLABLE"
|
||||
},
|
||||
{
|
||||
"name": "body_mass_g",
|
||||
"type": "FLOAT",
|
||||
"mode": "NULLABLE"
|
||||
},
|
||||
{
|
||||
"name": "sex",
|
||||
"type": "STRING",
|
||||
"mode": "NULLABLE"
|
||||
}
|
||||
]
|
||||
},
|
||||
"numBytes": "28947",
|
||||
"numLongTermBytes": "28947",
|
||||
"numRows": "344",
|
||||
"creationTime": "1619804743188",
|
||||
"lastModifiedTime": "1634584675234",
|
||||
"type": "TABLE",
|
||||
"location": "US",
|
||||
"numTimeTravelPhysicalBytes": "0",
|
||||
"numTotalLogicalBytes": "28947",
|
||||
"numActiveLogicalBytes": "0",
|
||||
"numLongTermLogicalBytes": "28947",
|
||||
"numTotalPhysicalBytes": "5350",
|
||||
"numActivePhysicalBytes": "0",
|
||||
"numLongTermPhysicalBytes": "5350",
|
||||
"numCurrentPhysicalBytes": "5350"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
bq_client = client.get_bigquery_client(credentials=credentials)
|
||||
return bq_client.get_table(
|
||||
bigquery.TableReference(
|
||||
bigquery.DatasetReference(project_id, dataset_id), table_id
|
||||
)
|
||||
).to_api_repr()
|
||||
except Exception as ex:
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": str(ex),
|
||||
}
|
76
src/google/adk/tools/bigquery/query_tool.py
Normal file
76
src/google/adk/tools/bigquery/query_tool.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from ...tools.bigquery import client
|
||||
|
||||
MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
|
||||
|
||||
|
||||
def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
|
||||
"""Run a BigQuery SQL query in the project and return the result.
|
||||
|
||||
Args:
|
||||
project_id (str): The GCP project id in which the query should be
|
||||
executed.
|
||||
query (str): The BigQuery SQL query to be executed.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary representing the result of the query.
|
||||
If the result contains the key "result_is_likely_truncated" with
|
||||
value True, it means that there may be additional rows matching the
|
||||
query not returned in the result.
|
||||
|
||||
Examples:
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "SELECT island, COUNT(*) AS population "
|
||||
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
|
||||
{
|
||||
"rows": [
|
||||
{
|
||||
"island": "Dream",
|
||||
"population": 124
|
||||
},
|
||||
{
|
||||
"island": "Biscoe",
|
||||
"population": 168
|
||||
},
|
||||
{
|
||||
"island": "Torgersen",
|
||||
"population": 52
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
try:
|
||||
bq_client = client.get_bigquery_client(credentials=credentials)
|
||||
row_iterator = bq_client.query_and_wait(
|
||||
query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS
|
||||
)
|
||||
rows = [{key: val for key, val in row.items()} for row in row_iterator]
|
||||
result = {"rows": rows}
|
||||
if (
|
||||
MAX_DOWNLOADED_QUERY_RESULT_ROWS is not None
|
||||
and len(rows) == MAX_DOWNLOADED_QUERY_RESULT_ROWS
|
||||
):
|
||||
result["result_is_likely_truncated"] = True
|
||||
return result
|
||||
except Exception as ex:
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": str(ex),
|
||||
}
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -98,12 +99,12 @@ class TestBigQueryCredentialsManager:
|
||||
manager.credentials_config.credentials = None
|
||||
|
||||
# Create mock cached credentials JSON that would be stored in cache
|
||||
mock_cached_creds_json = {
|
||||
mock_cached_creds_json = json.dumps({
|
||||
"token": "cached_token",
|
||||
"refresh_token": "cached_refresh_token",
|
||||
"client_id": "test_client_id",
|
||||
"client_secret": "test_client_secret",
|
||||
}
|
||||
})
|
||||
|
||||
# Set up the tool context state to contain cached credentials
|
||||
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
|
||||
@ -120,7 +121,7 @@ class TestBigQueryCredentialsManager:
|
||||
|
||||
# Verify credentials were created from cached JSON
|
||||
mock_from_json.assert_called_once_with(
|
||||
mock_cached_creds_json, manager.credentials_config.scopes
|
||||
json.loads(mock_cached_creds_json), manager.credentials_config.scopes
|
||||
)
|
||||
# Verify loaded credentials were not cached into manager
|
||||
assert manager.credentials_config.credentials is None
|
||||
@ -160,19 +161,19 @@ class TestBigQueryCredentialsManager:
|
||||
manager.credentials_config.credentials = None
|
||||
|
||||
# Create mock cached credentials JSON
|
||||
mock_cached_creds_json = {
|
||||
mock_cached_creds_json = json.dumps({
|
||||
"token": "expired_token",
|
||||
"refresh_token": "valid_refresh_token",
|
||||
"client_id": "test_client_id",
|
||||
"client_secret": "test_client_secret",
|
||||
}
|
||||
})
|
||||
|
||||
mock_refreshed_creds_json = {
|
||||
mock_refreshed_creds_json = json.dumps({
|
||||
"token": "new_token",
|
||||
"refresh_token": "valid_refresh_token",
|
||||
"client_id": "test_client_id",
|
||||
"client_secret": "test_client_secret",
|
||||
}
|
||||
})
|
||||
|
||||
# Set up the tool context state to contain cached credentials
|
||||
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
|
||||
@ -200,7 +201,7 @@ class TestBigQueryCredentialsManager:
|
||||
|
||||
# Verify credentials were created from cached JSON
|
||||
mock_from_json.assert_called_once_with(
|
||||
mock_cached_creds_json, manager.credentials_config.scopes
|
||||
json.loads(mock_cached_creds_json), manager.credentials_config.scopes
|
||||
)
|
||||
# Verify refresh was attempted and succeeded
|
||||
mock_cached_creds.refresh.assert_called_once()
|
||||
@ -209,7 +210,9 @@ class TestBigQueryCredentialsManager:
|
||||
# Verify refreshed credentials were cached
|
||||
assert (
|
||||
"new_token"
|
||||
== mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]["token"]
|
||||
== json.loads(mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY])[
|
||||
"token"
|
||||
]
|
||||
)
|
||||
assert result == mock_cached_creds
|
||||
|
||||
|
123
tests/unittests/tools/bigquery_tool/test_bigquery_toolset.py
Normal file
123
tests/unittests/tools/bigquery_tool/test_bigquery_toolset.py
Normal file
@ -0,0 +1,123 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.tools.bigquery import BigQueryCredentialsConfig
|
||||
from google.adk.tools.bigquery import BigQueryTool
|
||||
from google.adk.tools.bigquery import BigQueryToolset
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bigquery_toolset_tools_default():
|
||||
"""Test default BigQuery toolset.
|
||||
|
||||
This test verifies the behavior of the BigQuery toolset when no filter is
|
||||
specified.
|
||||
"""
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
client_id="abc", client_secret="def"
|
||||
)
|
||||
toolset = BigQueryToolset(credentials_config=credentials_config)
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
|
||||
assert len(tools) == 5
|
||||
assert all([isinstance(tool, BigQueryTool) for tool in tools])
|
||||
|
||||
expected_tool_names = set([
|
||||
"list_dataset_ids",
|
||||
"get_dataset_info",
|
||||
"list_table_ids",
|
||||
"get_table_info",
|
||||
"execute_sql",
|
||||
])
|
||||
actual_tool_names = set([tool.name for tool in tools])
|
||||
assert actual_tool_names == expected_tool_names
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"selected_tools",
|
||||
[
|
||||
pytest.param([], id="None"),
|
||||
pytest.param(
|
||||
["list_dataset_ids", "get_dataset_info"], id="dataset-metadata"
|
||||
),
|
||||
pytest.param(["list_table_ids", "get_table_info"], id="table-metadata"),
|
||||
pytest.param(["execute_sql"], id="query"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_bigquery_toolset_tools_selective(selected_tools):
|
||||
"""Test BigQuery toolset with filter.
|
||||
|
||||
This test verifies the behavior of the BigQuery toolset when filter is
|
||||
specified. A use case for this would be when the agent builder wants to
|
||||
use only a subset of the tools provided by the toolset.
|
||||
"""
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
client_id="abc", client_secret="def"
|
||||
)
|
||||
toolset = BigQueryToolset(
|
||||
credentials_config=credentials_config, tool_filter=selected_tools
|
||||
)
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
|
||||
assert len(tools) == len(selected_tools)
|
||||
assert all([isinstance(tool, BigQueryTool) for tool in tools])
|
||||
|
||||
expected_tool_names = set(selected_tools)
|
||||
actual_tool_names = set([tool.name for tool in tools])
|
||||
assert actual_tool_names == expected_tool_names
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("selected_tools", "returned_tools"),
|
||||
[
|
||||
pytest.param(["unknown"], [], id="all-unknown"),
|
||||
pytest.param(
|
||||
["unknown", "execute_sql"],
|
||||
["execute_sql"],
|
||||
id="mixed-known-unknown",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_bigquery_toolset_unknown_tool_raises(
|
||||
selected_tools, returned_tools
|
||||
):
|
||||
"""Test BigQuery toolset with filter.
|
||||
|
||||
This test verifies the behavior of the BigQuery toolset when filter is
|
||||
specified with an unknown tool.
|
||||
"""
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
client_id="abc", client_secret="def"
|
||||
)
|
||||
|
||||
toolset = BigQueryToolset(
|
||||
credentials_config=credentials_config, tool_filter=selected_tools
|
||||
)
|
||||
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
|
||||
assert len(tools) == len(returned_tools)
|
||||
assert all([isinstance(tool, BigQueryTool) for tool in tools])
|
||||
|
||||
expected_tool_names = set(returned_tools)
|
||||
actual_tool_names = set([tool.name for tool in tools])
|
||||
assert actual_tool_names == expected_tool_names
|
Loading…
Reference in New Issue
Block a user