mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 19:32:21 -06:00
feat: add customized bigquer tool wrapper class to facilitate developer to handcraft bigquery api tool
PiperOrigin-RevId: 762626700
This commit is contained in:
committed by
Copybara-Service
parent
0e284f45ff
commit
756a326033
28
src/google/adk/tools/bigquery/__init__.py
Normal file
28
src/google/adk/tools/bigquery/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""BigQuery Tools. (Experimental)
|
||||
|
||||
BigQuery Tools under this module are hand crafted and customized while the tools
|
||||
under google.adk.tools.google_api_tool are auto generated based on API
|
||||
definition. The rationales to have customized tool are:
|
||||
|
||||
1. BigQuery APIs have functions overlaps and LLM can't tell what tool to use
|
||||
2. BigQuery APIs have a lot of parameters with some rarely used, which are not
|
||||
LLM-friendly
|
||||
3. We want to provide more high-level tools like forecasting, RAG, segmentation,
|
||||
etc.
|
||||
4. We want to provide extra access guardrails in those tools. For example,
|
||||
execute_sql can't arbitrarily mutate existing data.
|
||||
"""
|
||||
185
src/google/adk/tools/bigquery/bigquery_credentials.py
Normal file
185
src/google/adk/tools/bigquery/bigquery_credentials.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OAuthFlowAuthorizationCode
|
||||
from fastapi.openapi.models import OAuthFlows
|
||||
from google.auth.exceptions import RefreshError
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
|
||||
from ...auth import AuthConfig
|
||||
from ...auth import AuthCredential
|
||||
from ...auth import AuthCredentialTypes
|
||||
from ...auth import OAuth2Auth
|
||||
from ..tool_context import ToolContext
|
||||
|
||||
BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache"
|
||||
|
||||
|
||||
class BigQueryCredentials(BaseModel):
|
||||
"""Configuration for Google API tools. (Experimental)"""
|
||||
|
||||
# Configure the model to allow arbitrary types like Credentials
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
credentials: Optional[Credentials] = None
|
||||
"""the existing oauth credentials to use. If set will override client ID,
|
||||
client secret, and scopes."""
|
||||
client_id: Optional[str] = None
|
||||
"""the oauth client ID to use."""
|
||||
client_secret: Optional[str] = None
|
||||
"""the oauth client secret to use."""
|
||||
scopes: Optional[List[str]] = None
|
||||
"""the scopes to use.
|
||||
"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def __post_init__(self) -> "BigQueryCredentials":
|
||||
"""Validate that either credentials or client ID/secret are provided."""
|
||||
if not self.credentials and (not self.client_id or not self.client_secret):
|
||||
raise ValueError(
|
||||
"Must provide either credentials or client_id abd client_secret pair."
|
||||
)
|
||||
if self.credentials:
|
||||
self.client_id = self.credentials.client_id
|
||||
self.client_secret = self.credentials.client_secret
|
||||
self.scopes = self.credentials.scopes
|
||||
return self
|
||||
|
||||
|
||||
class BigQueryCredentialsManager:
|
||||
"""Manages Google API credentials with automatic refresh and OAuth flow handling.
|
||||
|
||||
This class centralizes credential management so multiple tools can share
|
||||
the same authenticated session without duplicating OAuth logic.
|
||||
"""
|
||||
|
||||
def __init__(self, credentials: BigQueryCredentials):
|
||||
"""Initialize the credential manager.
|
||||
|
||||
Args:
|
||||
credential_config: Configuration containing OAuth details or existing
|
||||
credentials
|
||||
"""
|
||||
self.credentials = credentials
|
||||
|
||||
async def get_valid_credentials(
|
||||
self, tool_context: ToolContext
|
||||
) -> Optional[Credentials]:
|
||||
"""Get valid credentials, handling refresh and OAuth flow as needed.
|
||||
|
||||
Args:
|
||||
tool_context: The tool context for OAuth flow and state management
|
||||
required_scopes: Set of OAuth scopes required by the calling tool
|
||||
|
||||
Returns:
|
||||
Valid Credentials object, or None if OAuth flow is needed
|
||||
"""
|
||||
# First, try to get cached credentials from the instance
|
||||
creds = self.credentials.credentials
|
||||
|
||||
# If credentails are empty
|
||||
if not creds:
|
||||
creds = tool_context.get(BIGQUERY_TOKEN_CACHE_KEY, None)
|
||||
self.credentials.credentials = creds
|
||||
|
||||
# Check if we have valid credentials
|
||||
if creds and creds.valid:
|
||||
return creds
|
||||
|
||||
# Try to refresh expired credentials
|
||||
if creds and creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
# Cache the refreshed credentials
|
||||
self.credentials.credentials = creds
|
||||
return creds
|
||||
except RefreshError:
|
||||
# Refresh failed, need to re-authenticate
|
||||
pass
|
||||
|
||||
# Need to perform OAuth flow
|
||||
return await self._perform_oauth_flow(tool_context)
|
||||
|
||||
async def _perform_oauth_flow(
|
||||
self, tool_context: ToolContext
|
||||
) -> Optional[Credentials]:
|
||||
"""Perform OAuth flow to get new credentials.
|
||||
|
||||
Args:
|
||||
tool_context: The tool context for OAuth flow
|
||||
required_scopes: Set of required OAuth scopes
|
||||
|
||||
Returns:
|
||||
New Credentials object, or None if flow is in progress
|
||||
"""
|
||||
|
||||
# Create OAuth configuration
|
||||
auth_scheme = OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
|
||||
tokenUrl="https://oauth2.googleapis.com/token",
|
||||
scopes={
|
||||
scope: f"Access to {scope}"
|
||||
for scope in self.credentials.scopes
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id=self.credentials.client_id,
|
||||
client_secret=self.credentials.client_secret,
|
||||
),
|
||||
)
|
||||
|
||||
# Check if OAuth response is available
|
||||
auth_response = tool_context.get_auth_response(
|
||||
AuthConfig(auth_scheme=auth_scheme, raw_auth_credential=auth_credential)
|
||||
)
|
||||
|
||||
if auth_response:
|
||||
# OAuth flow completed, create credentials
|
||||
creds = Credentials(
|
||||
token=auth_response.oauth2.access_token,
|
||||
refresh_token=auth_response.oauth2.refresh_token,
|
||||
token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
|
||||
client_id=self.credentials.client_id,
|
||||
client_secret=self.credentials.client_secret,
|
||||
scopes=list(self.credentials.scopes),
|
||||
)
|
||||
|
||||
# Cache the new credentials
|
||||
self.credentials.credentials = creds
|
||||
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds
|
||||
return creds
|
||||
else:
|
||||
# Request OAuth flow
|
||||
tool_context.request_credential(
|
||||
AuthConfig(
|
||||
auth_scheme=auth_scheme,
|
||||
raw_auth_credential=auth_credential,
|
||||
)
|
||||
)
|
||||
return None
|
||||
116
src/google/adk/tools/bigquery/bigquery_tool.py
Normal file
116
src/google/adk/tools/bigquery/bigquery_tool.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
from typing import override
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from ..function_tool import FunctionTool
|
||||
from ..tool_context import ToolContext
|
||||
from .bigquery_credentials import BigQueryCredentials
|
||||
from .bigquery_credentials import BigQueryCredentialsManager
|
||||
|
||||
|
||||
class BigQueryTool(FunctionTool):
|
||||
"""GoogleApiTool class for tools that call Google APIs.
|
||||
|
||||
This class is for developers to handcraft customized Google API tools rather
|
||||
than auto generate Google API tools based on API specs.
|
||||
|
||||
This class handles all the OAuth complexity, credential management,
|
||||
and common Google API patterns so subclasses can focus on their
|
||||
specific functionality.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[..., Any],
|
||||
credentials: Optional[BigQueryCredentials] = None,
|
||||
):
|
||||
"""Initialize the Google API tool.
|
||||
|
||||
Args:
|
||||
func: callable that impelments the tool's logic, can accept one
|
||||
'credential" parameter
|
||||
credentials: credentials used to call Google API. If None, then we don't
|
||||
hanlde the auth logic
|
||||
"""
|
||||
super().__init__(func=func)
|
||||
self._ignore_params.append("credentials")
|
||||
self.credentials_manager = (
|
||||
BigQueryCredentialsManager(credentials) if credentials else None
|
||||
)
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, *, args: dict[str, Any], tool_context: ToolContext
|
||||
) -> Any:
|
||||
"""Main entry point for tool execution with credential handling.
|
||||
|
||||
This method handles all the OAuth complexity and then delegates
|
||||
to the subclass's run_async_with_credential method.
|
||||
"""
|
||||
try:
|
||||
# Get valid credentials
|
||||
credentials = (
|
||||
await self.credentials_manager.get_valid_credentials(tool_context)
|
||||
if self.credentials_manager
|
||||
else None
|
||||
)
|
||||
|
||||
if credentials is None and self.credentials_manager:
|
||||
# OAuth flow in progress
|
||||
return (
|
||||
"User authorization is required to access Google services for"
|
||||
f" {self.name}. Please complete the authorization flow."
|
||||
)
|
||||
|
||||
# Execute the tool's specific logic with valid credentials
|
||||
|
||||
return await self._run_async_with_credential(
|
||||
credentials, args, tool_context
|
||||
)
|
||||
|
||||
except Exception as ex:
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": str(ex),
|
||||
}
|
||||
|
||||
async def _run_async_with_credential(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
) -> Any:
|
||||
"""Execute the tool's specific logic with valid credentials.
|
||||
|
||||
Args:
|
||||
credentials: Valid Google OAuth credentials
|
||||
args: Arguments passed to the tool
|
||||
tool_context: Tool execution context
|
||||
|
||||
Returns:
|
||||
The result of the tool execution
|
||||
"""
|
||||
args_to_call = args.copy()
|
||||
signature = inspect.signature(self.func)
|
||||
if "credentials" in signature.parameters:
|
||||
args_to_call["credentials"] = credentials
|
||||
return await super().run_async(args=args_to_call, tool_context=tool_context)
|
||||
@@ -57,6 +57,7 @@ class FunctionTool(BaseTool):
|
||||
|
||||
super().__init__(name=name, description=doc)
|
||||
self.func = func
|
||||
self._ignore_params = ['tool_context', 'input_stream']
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
|
||||
@@ -65,7 +66,7 @@ class FunctionTool(BaseTool):
|
||||
func=self.func,
|
||||
# The model doesn't understand the function context.
|
||||
# input_stream is for streaming tool
|
||||
ignore_params=['tool_context', 'input_stream'],
|
||||
ignore_params=self._ignore_params,
|
||||
variant=self._api_variant,
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user