mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
feat: add customized bigquer tool wrapper class to facilitate developer to handcraft bigquery api tool
PiperOrigin-RevId: 762626700
This commit is contained in:
parent
0e284f45ff
commit
756a326033
28
src/google/adk/tools/bigquery/__init__.py
Normal file
28
src/google/adk/tools/bigquery/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""BigQuery Tools. (Experimental)
|
||||
|
||||
BigQuery Tools under this module are hand crafted and customized while the tools
|
||||
under google.adk.tools.google_api_tool are auto generated based on API
|
||||
definition. The rationales to have customized tool are:
|
||||
|
||||
1. BigQuery APIs have functions overlaps and LLM can't tell what tool to use
|
||||
2. BigQuery APIs have a lot of parameters with some rarely used, which are not
|
||||
LLM-friendly
|
||||
3. We want to provide more high-level tools like forecasting, RAG, segmentation,
|
||||
etc.
|
||||
4. We want to provide extra access guardrails in those tools. For example,
|
||||
execute_sql can't arbitrarily mutate existing data.
|
||||
"""
|
185
src/google/adk/tools/bigquery/bigquery_credentials.py
Normal file
185
src/google/adk/tools/bigquery/bigquery_credentials.py
Normal file
@ -0,0 +1,185 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OAuthFlowAuthorizationCode
|
||||
from fastapi.openapi.models import OAuthFlows
|
||||
from google.auth.exceptions import RefreshError
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
|
||||
from ...auth import AuthConfig
|
||||
from ...auth import AuthCredential
|
||||
from ...auth import AuthCredentialTypes
|
||||
from ...auth import OAuth2Auth
|
||||
from ..tool_context import ToolContext
|
||||
|
||||
BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache"
|
||||
|
||||
|
||||
class BigQueryCredentials(BaseModel):
|
||||
"""Configuration for Google API tools. (Experimental)"""
|
||||
|
||||
# Configure the model to allow arbitrary types like Credentials
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
credentials: Optional[Credentials] = None
|
||||
"""the existing oauth credentials to use. If set will override client ID,
|
||||
client secret, and scopes."""
|
||||
client_id: Optional[str] = None
|
||||
"""the oauth client ID to use."""
|
||||
client_secret: Optional[str] = None
|
||||
"""the oauth client secret to use."""
|
||||
scopes: Optional[List[str]] = None
|
||||
"""the scopes to use.
|
||||
"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def __post_init__(self) -> "BigQueryCredentials":
|
||||
"""Validate that either credentials or client ID/secret are provided."""
|
||||
if not self.credentials and (not self.client_id or not self.client_secret):
|
||||
raise ValueError(
|
||||
"Must provide either credentials or client_id abd client_secret pair."
|
||||
)
|
||||
if self.credentials:
|
||||
self.client_id = self.credentials.client_id
|
||||
self.client_secret = self.credentials.client_secret
|
||||
self.scopes = self.credentials.scopes
|
||||
return self
|
||||
|
||||
|
||||
class BigQueryCredentialsManager:
|
||||
"""Manages Google API credentials with automatic refresh and OAuth flow handling.
|
||||
|
||||
This class centralizes credential management so multiple tools can share
|
||||
the same authenticated session without duplicating OAuth logic.
|
||||
"""
|
||||
|
||||
def __init__(self, credentials: BigQueryCredentials):
|
||||
"""Initialize the credential manager.
|
||||
|
||||
Args:
|
||||
credential_config: Configuration containing OAuth details or existing
|
||||
credentials
|
||||
"""
|
||||
self.credentials = credentials
|
||||
|
||||
async def get_valid_credentials(
|
||||
self, tool_context: ToolContext
|
||||
) -> Optional[Credentials]:
|
||||
"""Get valid credentials, handling refresh and OAuth flow as needed.
|
||||
|
||||
Args:
|
||||
tool_context: The tool context for OAuth flow and state management
|
||||
required_scopes: Set of OAuth scopes required by the calling tool
|
||||
|
||||
Returns:
|
||||
Valid Credentials object, or None if OAuth flow is needed
|
||||
"""
|
||||
# First, try to get cached credentials from the instance
|
||||
creds = self.credentials.credentials
|
||||
|
||||
# If credentails are empty
|
||||
if not creds:
|
||||
creds = tool_context.get(BIGQUERY_TOKEN_CACHE_KEY, None)
|
||||
self.credentials.credentials = creds
|
||||
|
||||
# Check if we have valid credentials
|
||||
if creds and creds.valid:
|
||||
return creds
|
||||
|
||||
# Try to refresh expired credentials
|
||||
if creds and creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
if creds.valid:
|
||||
# Cache the refreshed credentials
|
||||
self.credentials.credentials = creds
|
||||
return creds
|
||||
except RefreshError:
|
||||
# Refresh failed, need to re-authenticate
|
||||
pass
|
||||
|
||||
# Need to perform OAuth flow
|
||||
return await self._perform_oauth_flow(tool_context)
|
||||
|
||||
async def _perform_oauth_flow(
|
||||
self, tool_context: ToolContext
|
||||
) -> Optional[Credentials]:
|
||||
"""Perform OAuth flow to get new credentials.
|
||||
|
||||
Args:
|
||||
tool_context: The tool context for OAuth flow
|
||||
required_scopes: Set of required OAuth scopes
|
||||
|
||||
Returns:
|
||||
New Credentials object, or None if flow is in progress
|
||||
"""
|
||||
|
||||
# Create OAuth configuration
|
||||
auth_scheme = OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
|
||||
tokenUrl="https://oauth2.googleapis.com/token",
|
||||
scopes={
|
||||
scope: f"Access to {scope}"
|
||||
for scope in self.credentials.scopes
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id=self.credentials.client_id,
|
||||
client_secret=self.credentials.client_secret,
|
||||
),
|
||||
)
|
||||
|
||||
# Check if OAuth response is available
|
||||
auth_response = tool_context.get_auth_response(
|
||||
AuthConfig(auth_scheme=auth_scheme, raw_auth_credential=auth_credential)
|
||||
)
|
||||
|
||||
if auth_response:
|
||||
# OAuth flow completed, create credentials
|
||||
creds = Credentials(
|
||||
token=auth_response.oauth2.access_token,
|
||||
refresh_token=auth_response.oauth2.refresh_token,
|
||||
token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
|
||||
client_id=self.credentials.client_id,
|
||||
client_secret=self.credentials.client_secret,
|
||||
scopes=list(self.credentials.scopes),
|
||||
)
|
||||
|
||||
# Cache the new credentials
|
||||
self.credentials.credentials = creds
|
||||
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds
|
||||
return creds
|
||||
else:
|
||||
# Request OAuth flow
|
||||
tool_context.request_credential(
|
||||
AuthConfig(
|
||||
auth_scheme=auth_scheme,
|
||||
raw_auth_credential=auth_credential,
|
||||
)
|
||||
)
|
||||
return None
|
116
src/google/adk/tools/bigquery/bigquery_tool.py
Normal file
116
src/google/adk/tools/bigquery/bigquery_tool.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
from typing import override
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from ..function_tool import FunctionTool
|
||||
from ..tool_context import ToolContext
|
||||
from .bigquery_credentials import BigQueryCredentials
|
||||
from .bigquery_credentials import BigQueryCredentialsManager
|
||||
|
||||
|
||||
class BigQueryTool(FunctionTool):
|
||||
"""GoogleApiTool class for tools that call Google APIs.
|
||||
|
||||
This class is for developers to handcraft customized Google API tools rather
|
||||
than auto generate Google API tools based on API specs.
|
||||
|
||||
This class handles all the OAuth complexity, credential management,
|
||||
and common Google API patterns so subclasses can focus on their
|
||||
specific functionality.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[..., Any],
|
||||
credentials: Optional[BigQueryCredentials] = None,
|
||||
):
|
||||
"""Initialize the Google API tool.
|
||||
|
||||
Args:
|
||||
func: callable that impelments the tool's logic, can accept one
|
||||
'credential" parameter
|
||||
credentials: credentials used to call Google API. If None, then we don't
|
||||
hanlde the auth logic
|
||||
"""
|
||||
super().__init__(func=func)
|
||||
self._ignore_params.append("credentials")
|
||||
self.credentials_manager = (
|
||||
BigQueryCredentialsManager(credentials) if credentials else None
|
||||
)
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, *, args: dict[str, Any], tool_context: ToolContext
|
||||
) -> Any:
|
||||
"""Main entry point for tool execution with credential handling.
|
||||
|
||||
This method handles all the OAuth complexity and then delegates
|
||||
to the subclass's run_async_with_credential method.
|
||||
"""
|
||||
try:
|
||||
# Get valid credentials
|
||||
credentials = (
|
||||
await self.credentials_manager.get_valid_credentials(tool_context)
|
||||
if self.credentials_manager
|
||||
else None
|
||||
)
|
||||
|
||||
if credentials is None and self.credentials_manager:
|
||||
# OAuth flow in progress
|
||||
return (
|
||||
"User authorization is required to access Google services for"
|
||||
f" {self.name}. Please complete the authorization flow."
|
||||
)
|
||||
|
||||
# Execute the tool's specific logic with valid credentials
|
||||
|
||||
return await self._run_async_with_credential(
|
||||
credentials, args, tool_context
|
||||
)
|
||||
|
||||
except Exception as ex:
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": str(ex),
|
||||
}
|
||||
|
||||
async def _run_async_with_credential(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
) -> Any:
|
||||
"""Execute the tool's specific logic with valid credentials.
|
||||
|
||||
Args:
|
||||
credentials: Valid Google OAuth credentials
|
||||
args: Arguments passed to the tool
|
||||
tool_context: Tool execution context
|
||||
|
||||
Returns:
|
||||
The result of the tool execution
|
||||
"""
|
||||
args_to_call = args.copy()
|
||||
signature = inspect.signature(self.func)
|
||||
if "credentials" in signature.parameters:
|
||||
args_to_call["credentials"] = credentials
|
||||
return await super().run_async(args=args_to_call, tool_context=tool_context)
|
@ -57,6 +57,7 @@ class FunctionTool(BaseTool):
|
||||
|
||||
super().__init__(name=name, description=doc)
|
||||
self.func = func
|
||||
self._ignore_params = ['tool_context', 'input_stream']
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
|
||||
@ -65,7 +66,7 @@ class FunctionTool(BaseTool):
|
||||
func=self.func,
|
||||
# The model doesn't understand the function context.
|
||||
# input_stream is for streaming tool
|
||||
ignore_params=['tool_context', 'input_stream'],
|
||||
ignore_params=self._ignore_params,
|
||||
variant=self._api_variant,
|
||||
)
|
||||
)
|
||||
|
13
tests/unittests/tools/bigquery/__init__
Normal file
13
tests/unittests/tools/bigquery/__init__
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
106
tests/unittests/tools/bigquery/test_bigquery_credentials.py
Normal file
106
tests/unittests/tools/bigquery/test_bigquery_credentials.py
Normal file
@ -0,0 +1,106 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials
|
||||
# Mock the Google OAuth and API dependencies
|
||||
from google.oauth2.credentials import Credentials
|
||||
import pytest
|
||||
|
||||
|
||||
class TestBigQueryCredentials:
|
||||
"""Test suite for BigQueryCredentials configuration validation.
|
||||
|
||||
This class tests the credential configuration logic that ensures
|
||||
either existing credentials or client ID/secret pairs are provided.
|
||||
"""
|
||||
|
||||
def test_valid_credentials_object(self):
|
||||
"""Test that providing valid Credentials object works correctly.
|
||||
|
||||
When a user already has valid OAuth credentials, they should be able
|
||||
to pass them directly without needing to provide client ID/secret.
|
||||
"""
|
||||
# Create a mock credentials object with the expected attributes
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
mock_creds.client_id = "test_client_id"
|
||||
mock_creds.client_secret = "test_client_secret"
|
||||
mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"]
|
||||
|
||||
config = BigQueryCredentials(credentials=mock_creds)
|
||||
|
||||
# Verify that the credentials are properly stored and attributes are extracted
|
||||
assert config.credentials == mock_creds
|
||||
assert config.client_id == "test_client_id"
|
||||
assert config.client_secret == "test_client_secret"
|
||||
assert config.scopes == ["https://www.googleapis.com/auth/calendar"]
|
||||
|
||||
def test_valid_client_id_secret_pair(self):
|
||||
"""Test that providing client ID and secret without credentials works.
|
||||
|
||||
This tests the scenario where users want to create new OAuth credentials
|
||||
from scratch using their application's client ID and secret.
|
||||
"""
|
||||
config = BigQueryCredentials(
|
||||
client_id="test_client_id",
|
||||
client_secret="test_client_secret",
|
||||
scopes=["https://www.googleapis.com/auth/bigquery"],
|
||||
)
|
||||
|
||||
assert config.credentials is None
|
||||
assert config.client_id == "test_client_id"
|
||||
assert config.client_secret == "test_client_secret"
|
||||
assert config.scopes == ["https://www.googleapis.com/auth/bigquery"]
|
||||
|
||||
def test_missing_client_secret_raises_error(self):
|
||||
"""Test that missing client secret raises appropriate validation error.
|
||||
|
||||
This ensures that incomplete OAuth configuration is caught early
|
||||
rather than failing during runtime.
|
||||
"""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Must provide either credentials or client_id abd client_secret"
|
||||
" pair"
|
||||
),
|
||||
):
|
||||
BigQueryCredentials(client_id="test_client_id")
|
||||
|
||||
def test_missing_client_id_raises_error(self):
|
||||
"""Test that missing client ID raises appropriate validation error."""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Must provide either credentials or client_id abd client_secret"
|
||||
" pair"
|
||||
),
|
||||
):
|
||||
BigQueryCredentials(client_secret="test_client_secret")
|
||||
|
||||
def test_empty_configuration_raises_error(self):
|
||||
"""Test that completely empty configuration is rejected.
|
||||
|
||||
Users must provide either existing credentials or the components
|
||||
needed to create new ones.
|
||||
"""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Must provide either credentials or client_id abd client_secret"
|
||||
" pair"
|
||||
),
|
||||
):
|
||||
BigQueryCredentials()
|
@ -0,0 +1,327 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.auth import AuthConfig
|
||||
from google.adk.tools import ToolContext
|
||||
from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY
|
||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials
|
||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
|
||||
from google.auth.exceptions import RefreshError
|
||||
# Mock the Google OAuth and API dependencies
|
||||
from google.oauth2.credentials import Credentials
|
||||
import pytest
|
||||
|
||||
|
||||
class TestBigQueryCredentialsManager:
|
||||
"""Test suite for BigQueryCredentialsManager OAuth flow handling.
|
||||
|
||||
This class tests the complex credential management logic including
|
||||
credential validation, refresh, OAuth flow orchestration, and the
|
||||
new token caching functionality through tool_context.state.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_context(self):
|
||||
"""Create a mock ToolContext for testing.
|
||||
|
||||
The ToolContext is the interface between tools and the broader
|
||||
agent framework, handling OAuth flows and state management.
|
||||
Now includes state dictionary for testing caching behavior.
|
||||
"""
|
||||
context = Mock(spec=ToolContext)
|
||||
context.get_auth_response = Mock(return_value=None)
|
||||
context.request_credential = Mock()
|
||||
# Mock the get method and state dictionary for caching tests
|
||||
context.get = Mock(return_value=None)
|
||||
context.state = {}
|
||||
return context
|
||||
|
||||
@pytest.fixture
|
||||
def credentials_config(self):
|
||||
"""Create a basic credentials configuration for testing."""
|
||||
return BigQueryCredentials(
|
||||
client_id="test_client_id",
|
||||
client_secret="test_client_secret",
|
||||
scopes=["https://www.googleapis.com/auth/calendar"],
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self, credentials_config):
|
||||
"""Create a credentials manager instance for testing."""
|
||||
return BigQueryCredentialsManager(credentials_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_valid_credentials_with_valid_existing_creds(
|
||||
self, manager, mock_tool_context
|
||||
):
|
||||
"""Test that valid existing credentials are returned immediately.
|
||||
|
||||
When credentials are already valid, no refresh or OAuth flow
|
||||
should be needed. This is the optimal happy path scenario.
|
||||
"""
|
||||
# Create mock credentials that are already valid
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
mock_creds.valid = True
|
||||
manager.credentials.credentials = mock_creds
|
||||
|
||||
result = await manager.get_valid_credentials(mock_tool_context)
|
||||
|
||||
assert result == mock_creds
|
||||
# Verify no OAuth flow was triggered
|
||||
mock_tool_context.get_auth_response.assert_not_called()
|
||||
mock_tool_context.request_credential.assert_not_called()
|
||||
# Verify cache retrieval wasn't needed since we had valid creds
|
||||
mock_tool_context.get.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credentials_from_cache_when_none_in_manager(
|
||||
self, manager, mock_tool_context
|
||||
):
|
||||
"""Test retrieving credentials from tool_context cache when manager has none.
|
||||
|
||||
This tests the new caching functionality where credentials can be
|
||||
retrieved from the tool context state when the manager instance
|
||||
doesn't have them loaded.
|
||||
"""
|
||||
# Manager starts with no credentials
|
||||
manager.credentials.credentials = None
|
||||
|
||||
# Create mock cached credentials that are valid
|
||||
mock_cached_creds = Mock(spec=Credentials)
|
||||
mock_cached_creds.valid = True
|
||||
|
||||
# Set up the tool context to return cached credentials
|
||||
mock_tool_context.get.return_value = mock_cached_creds
|
||||
|
||||
result = await manager.get_valid_credentials(mock_tool_context)
|
||||
|
||||
# Verify credentials were retrieved from cache
|
||||
mock_tool_context.get.assert_called_once_with(
|
||||
BIGQUERY_TOKEN_CACHE_KEY, None
|
||||
)
|
||||
# Verify cached credentials were loaded into manager
|
||||
assert manager.credentials.credentials == mock_cached_creds
|
||||
# Verify valid cached credentials were returned
|
||||
assert result == mock_cached_creds
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_credentials_in_manager_or_cache(
|
||||
self, manager, mock_tool_context
|
||||
):
|
||||
"""Test OAuth flow when no credentials exist in manager or cache.
|
||||
|
||||
This tests the scenario where both the manager and cache are empty,
|
||||
requiring a new OAuth flow to be initiated.
|
||||
"""
|
||||
# Manager starts with no credentials
|
||||
manager.credentials.credentials = None
|
||||
# Cache also returns None
|
||||
mock_tool_context.get.return_value = None
|
||||
|
||||
result = await manager.get_valid_credentials(mock_tool_context)
|
||||
|
||||
# Should trigger OAuth flow and return None (flow in progress)
|
||||
assert result is None
|
||||
mock_tool_context.get.assert_called_once_with(
|
||||
BIGQUERY_TOKEN_CACHE_KEY, None
|
||||
)
|
||||
mock_tool_context.request_credential.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("google.auth.transport.requests.Request")
|
||||
async def test_refresh_cached_credentials_success(
|
||||
self, mock_request_class, manager, mock_tool_context
|
||||
):
|
||||
"""Test successful refresh of expired credentials retrieved from cache.
|
||||
|
||||
This tests the interaction between caching and refresh functionality,
|
||||
ensuring that expired cached credentials can be refreshed properly.
|
||||
"""
|
||||
# Manager starts with no credentials
|
||||
manager.credentials.credentials = None
|
||||
|
||||
# Create expired cached credentials with refresh token
|
||||
mock_cached_creds = Mock(spec=Credentials)
|
||||
mock_cached_creds.valid = False
|
||||
mock_cached_creds.expired = True
|
||||
mock_cached_creds.refresh_token = "refresh_token"
|
||||
|
||||
# Mock successful refresh
|
||||
def mock_refresh(request):
|
||||
mock_cached_creds.valid = True
|
||||
|
||||
mock_cached_creds.refresh = Mock(side_effect=mock_refresh)
|
||||
mock_tool_context.get.return_value = mock_cached_creds
|
||||
|
||||
result = await manager.get_valid_credentials(mock_tool_context)
|
||||
|
||||
# Verify credentials were retrieved from cache
|
||||
mock_tool_context.get.assert_called_once_with(
|
||||
BIGQUERY_TOKEN_CACHE_KEY, None
|
||||
)
|
||||
# Verify refresh was attempted and succeeded
|
||||
mock_cached_creds.refresh.assert_called_once()
|
||||
# Verify refreshed credentials were loaded into manager
|
||||
assert manager.credentials.credentials == mock_cached_creds
|
||||
assert result == mock_cached_creds
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("google.auth.transport.requests.Request")
|
||||
async def test_get_valid_credentials_with_refresh_success(
|
||||
self, mock_request_class, manager, mock_tool_context
|
||||
):
|
||||
"""Test successful credential refresh when tokens are expired.
|
||||
|
||||
This tests the automatic token refresh capability that prevents
|
||||
users from having to re-authenticate for every expired token.
|
||||
"""
|
||||
# Create expired credentials with refresh token
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
mock_creds.valid = False
|
||||
mock_creds.expired = True
|
||||
mock_creds.refresh_token = "refresh_token"
|
||||
|
||||
# Mock successful refresh
|
||||
def mock_refresh(request):
|
||||
mock_creds.valid = True
|
||||
|
||||
mock_creds.refresh = Mock(side_effect=mock_refresh)
|
||||
manager.credentials.credentials = mock_creds
|
||||
|
||||
result = await manager.get_valid_credentials(mock_tool_context)
|
||||
|
||||
assert result == mock_creds
|
||||
mock_creds.refresh.assert_called_once()
|
||||
# Verify credentials were cached after successful refresh
|
||||
assert manager.credentials.credentials == mock_creds
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("google.auth.transport.requests.Request")
|
||||
async def test_get_valid_credentials_with_refresh_failure(
|
||||
self, mock_request_class, manager, mock_tool_context
|
||||
):
|
||||
"""Test OAuth flow trigger when credential refresh fails.
|
||||
|
||||
When refresh tokens expire or become invalid, the system should
|
||||
gracefully fall back to requesting a new OAuth flow.
|
||||
"""
|
||||
# Create expired credentials that fail to refresh
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
mock_creds.valid = False
|
||||
mock_creds.expired = True
|
||||
mock_creds.refresh_token = "expired_refresh_token"
|
||||
mock_creds.refresh = Mock(side_effect=RefreshError("Refresh failed"))
|
||||
manager.credentials.credentials = mock_creds
|
||||
|
||||
result = await manager.get_valid_credentials(mock_tool_context)
|
||||
|
||||
# Should trigger OAuth flow and return None (flow in progress)
|
||||
assert result is None
|
||||
mock_tool_context.request_credential.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_flow_completion_with_caching(
|
||||
self, manager, mock_tool_context
|
||||
):
|
||||
"""Test successful OAuth flow completion with proper credential caching.
|
||||
|
||||
This tests the happy path where a user completes the OAuth flow
|
||||
and the system successfully creates and caches new credentials
|
||||
in both the manager and the tool context state.
|
||||
"""
|
||||
# Mock OAuth response indicating completed flow
|
||||
mock_auth_response = Mock()
|
||||
mock_auth_response.oauth2.access_token = "new_access_token"
|
||||
mock_auth_response.oauth2.refresh_token = "new_refresh_token"
|
||||
mock_tool_context.get_auth_response.return_value = mock_auth_response
|
||||
|
||||
result = await manager.get_valid_credentials(mock_tool_context)
|
||||
|
||||
# Verify new credentials were created and cached
|
||||
assert isinstance(result, Credentials)
|
||||
assert result.token == "new_access_token"
|
||||
assert result.refresh_token == "new_refresh_token"
|
||||
# Verify credentials are cached in manager
|
||||
assert manager.credentials.credentials == result
|
||||
# Verify credentials are also cached in tool context state
|
||||
assert mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] == result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_flow_in_progress(self, manager, mock_tool_context):
|
||||
"""Test OAuth flow initiation when no auth response is available.
|
||||
|
||||
This tests the case where the OAuth flow needs to be started,
|
||||
and the user hasn't completed authorization yet.
|
||||
"""
|
||||
# No existing credentials, no auth response (flow not completed)
|
||||
mock_tool_context.get_auth_response.return_value = None
|
||||
|
||||
result = await manager.get_valid_credentials(mock_tool_context)
|
||||
|
||||
# Should return None and request credential flow
|
||||
assert result is None
|
||||
mock_tool_context.request_credential.assert_called_once()
|
||||
|
||||
# Verify the auth configuration includes correct scopes and endpoints
|
||||
call_args = mock_tool_context.request_credential.call_args[0][0]
|
||||
assert isinstance(call_args, AuthConfig)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_persistence_across_manager_instances(
|
||||
self, credentials_config, mock_tool_context
|
||||
):
|
||||
"""Test that cached credentials persist across different manager instances.
|
||||
|
||||
This tests the key benefit of the tool context caching - that
|
||||
credentials can be shared between different instances of the
|
||||
credential manager, avoiding redundant OAuth flows.
|
||||
"""
|
||||
# Create first manager instance and simulate OAuth completion
|
||||
manager1 = BigQueryCredentialsManager(credentials_config)
|
||||
|
||||
# Mock OAuth response for first manager
|
||||
mock_auth_response = Mock()
|
||||
mock_auth_response.oauth2.access_token = "cached_access_token"
|
||||
mock_auth_response.oauth2.refresh_token = "cached_refresh_token"
|
||||
mock_tool_context.get_auth_response.return_value = mock_auth_response
|
||||
|
||||
# Complete OAuth flow with first manager
|
||||
result1 = await manager1.get_valid_credentials(mock_tool_context)
|
||||
|
||||
# Verify credentials were cached in tool context
|
||||
assert BIGQUERY_TOKEN_CACHE_KEY in mock_tool_context.state
|
||||
cached_creds = mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]
|
||||
|
||||
# Create second manager instance (simulating new request/session)
|
||||
manager2 = BigQueryCredentialsManager(credentials_config)
|
||||
|
||||
# Reset auth response to None (no new OAuth flow available)
|
||||
mock_tool_context.get_auth_response.return_value = None
|
||||
# Set up get method to return cached credentials
|
||||
mock_tool_context.get.return_value = cached_creds
|
||||
|
||||
# Get credentials with second manager
|
||||
result2 = await manager2.get_valid_credentials(mock_tool_context)
|
||||
|
||||
# Verify second manager retrieved cached credentials successfully
|
||||
assert result2 == cached_creds
|
||||
assert manager2.credentials.credentials == cached_creds
|
||||
# Verify no new OAuth flow was requested
|
||||
assert (
|
||||
mock_tool_context.request_credential.call_count == 0
|
||||
) # Only from first manager
|
259
tests/unittests/tools/bigquery/test_bigquery_tool.py
Normal file
259
tests/unittests/tools/bigquery/test_bigquery_tool.py
Normal file
@ -0,0 +1,259 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.tools import ToolContext
|
||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials
|
||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
|
||||
from google.adk.tools.bigquery.bigquery_tool import BigQueryTool
|
||||
# Mock the Google OAuth and API dependencies
|
||||
from google.oauth2.credentials import Credentials
|
||||
import pytest
|
||||
|
||||
|
||||
class TestBigQueryTool:
|
||||
"""Test suite for BigQueryTool OAuth integration and execution.
|
||||
|
||||
This class tests the high-level tool execution logic that combines
|
||||
credential management with actual function execution.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_context(self):
|
||||
"""Create a mock ToolContext for testing tool execution."""
|
||||
context = Mock(spec=ToolContext)
|
||||
context.get_auth_response = Mock(return_value=None)
|
||||
context.request_credential = Mock()
|
||||
return context
|
||||
|
||||
@pytest.fixture
|
||||
def sample_function(self):
|
||||
"""Create a sample function that accepts credentials for testing.
|
||||
|
||||
This simulates a real Google API tool function that needs
|
||||
authenticated credentials to perform its work.
|
||||
"""
|
||||
|
||||
def sample_func(param1: str, credentials: Credentials = None) -> dict:
|
||||
"""Sample function that uses Google API credentials."""
|
||||
if credentials:
|
||||
return {"result": f"Success with {param1}", "authenticated": True}
|
||||
else:
|
||||
return {"result": f"Success with {param1}", "authenticated": False}
|
||||
|
||||
return sample_func
|
||||
|
||||
@pytest.fixture
|
||||
def async_sample_function(self):
|
||||
"""Create an async sample function for testing async execution paths."""
|
||||
|
||||
async def async_sample_func(
|
||||
param1: str, credentials: Credentials = None
|
||||
) -> dict:
|
||||
"""Async sample function that uses Google API credentials."""
|
||||
if credentials:
|
||||
return {"result": f"Async success with {param1}", "authenticated": True}
|
||||
else:
|
||||
return {
|
||||
"result": f"Async success with {param1}",
|
||||
"authenticated": False,
|
||||
}
|
||||
|
||||
return async_sample_func
|
||||
|
||||
@pytest.fixture
|
||||
def credentials_config(self):
|
||||
"""Create credentials configuration for testing."""
|
||||
return BigQueryCredentials(
|
||||
client_id="test_client_id",
|
||||
client_secret="test_client_secret",
|
||||
scopes=["https://www.googleapis.com/auth/bigquery"],
|
||||
)
|
||||
|
||||
def test_tool_initialization_with_credentials(
|
||||
self, sample_function, credentials_config
|
||||
):
|
||||
"""Test that BigQueryTool initializes correctly with credentials.
|
||||
|
||||
The tool should properly inherit from FunctionTool while adding
|
||||
Google API specific credential management capabilities.
|
||||
"""
|
||||
tool = BigQueryTool(func=sample_function, credentials=credentials_config)
|
||||
|
||||
assert tool.func == sample_function
|
||||
assert tool.credentials_manager is not None
|
||||
assert isinstance(tool.credentials_manager, BigQueryCredentialsManager)
|
||||
# Verify that 'credentials' parameter is ignored in function signature analysis
|
||||
assert "credentials" in tool._ignore_params
|
||||
|
||||
def test_tool_initialization_without_credentials(self, sample_function):
|
||||
"""Test tool initialization when no credential management is needed.
|
||||
|
||||
Some tools might handle authentication externally or use service
|
||||
accounts, so credential management should be optional.
|
||||
"""
|
||||
tool = BigQueryTool(func=sample_function, credentials=None)
|
||||
|
||||
assert tool.func == sample_function
|
||||
assert tool.credentials_manager is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_valid_credentials(
|
||||
self, sample_function, credentials_config, mock_tool_context
|
||||
):
|
||||
"""Test successful tool execution with valid credentials.
|
||||
|
||||
This tests the main happy path where credentials are available
|
||||
and the underlying function executes successfully.
|
||||
"""
|
||||
tool = BigQueryTool(func=sample_function, credentials=credentials_config)
|
||||
|
||||
# Mock the credentials manager to return valid credentials
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
with patch.object(
|
||||
tool.credentials_manager,
|
||||
"get_valid_credentials",
|
||||
return_value=mock_creds,
|
||||
) as mock_get_creds:
|
||||
|
||||
result = await tool.run_async(
|
||||
args={"param1": "test_value"}, tool_context=mock_tool_context
|
||||
)
|
||||
|
||||
mock_get_creds.assert_called_once_with(mock_tool_context)
|
||||
assert result["result"] == "Success with test_value"
|
||||
assert result["authenticated"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_oauth_flow_in_progress(
|
||||
self, sample_function, credentials_config, mock_tool_context
|
||||
):
|
||||
"""Test tool behavior when OAuth flow is in progress.
|
||||
|
||||
When credentials aren't available and OAuth flow is needed,
|
||||
the tool should return a user-friendly message rather than failing.
|
||||
"""
|
||||
tool = BigQueryTool(func=sample_function, credentials=credentials_config)
|
||||
|
||||
# Mock credentials manager to return None (OAuth flow in progress)
|
||||
with patch.object(
|
||||
tool.credentials_manager, "get_valid_credentials", return_value=None
|
||||
) as mock_get_creds:
|
||||
|
||||
result = await tool.run_async(
|
||||
args={"param1": "test_value"}, tool_context=mock_tool_context
|
||||
)
|
||||
|
||||
mock_get_creds.assert_called_once_with(mock_tool_context)
|
||||
assert "authorization is required" in result.lower()
|
||||
assert tool.name in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_without_credentials_manager(
|
||||
self, sample_function, mock_tool_context
|
||||
):
|
||||
"""Test tool execution when no credential management is configured.
|
||||
|
||||
Tools without credential managers should execute normally,
|
||||
passing None for credentials if the function accepts them.
|
||||
"""
|
||||
tool = BigQueryTool(func=sample_function, credentials=None)
|
||||
|
||||
result = await tool.run_async(
|
||||
args={"param1": "test_value"}, tool_context=mock_tool_context
|
||||
)
|
||||
|
||||
assert result["result"] == "Success with test_value"
|
||||
assert result["authenticated"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_async_function(
|
||||
self, async_sample_function, credentials_config, mock_tool_context
|
||||
):
|
||||
"""Test that async functions are properly handled.
|
||||
|
||||
The tool should correctly detect and execute async functions,
|
||||
which is important for tools that make async API calls.
|
||||
"""
|
||||
tool = BigQueryTool(
|
||||
func=async_sample_function, credentials=credentials_config
|
||||
)
|
||||
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
with patch.object(
|
||||
tool.credentials_manager,
|
||||
"get_valid_credentials",
|
||||
return_value=mock_creds,
|
||||
):
|
||||
|
||||
result = await tool.run_async(
|
||||
args={"param1": "test_value"}, tool_context=mock_tool_context
|
||||
)
|
||||
|
||||
assert result["result"] == "Async success with test_value"
|
||||
assert result["authenticated"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_exception_handling(
|
||||
self, credentials_config, mock_tool_context
|
||||
):
|
||||
"""Test that exceptions in tool execution are properly handled.
|
||||
|
||||
Tools should gracefully handle errors and return structured
|
||||
error responses rather than letting exceptions propagate.
|
||||
"""
|
||||
|
||||
def failing_function(param1: str, credentials: Credentials = None) -> dict:
|
||||
raise ValueError("Something went wrong")
|
||||
|
||||
tool = BigQueryTool(func=failing_function, credentials=credentials_config)
|
||||
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
with patch.object(
|
||||
tool.credentials_manager,
|
||||
"get_valid_credentials",
|
||||
return_value=mock_creds,
|
||||
):
|
||||
|
||||
result = await tool.run_async(
|
||||
args={"param1": "test_value"}, tool_context=mock_tool_context
|
||||
)
|
||||
|
||||
assert result["status"] == "ERROR"
|
||||
assert "Something went wrong" in result["error_details"]
|
||||
|
||||
def test_function_signature_analysis(self, credentials_config):
|
||||
"""Test that function signature analysis correctly handles credentials parameter.
|
||||
|
||||
The tool should properly identify and handle the credentials parameter
|
||||
while preserving other parameter analysis for LLM function calling.
|
||||
"""
|
||||
|
||||
def complex_function(
|
||||
required_param: str,
|
||||
optional_param: str = "default",
|
||||
credentials: Credentials = None,
|
||||
) -> dict:
|
||||
return {"success": True}
|
||||
|
||||
tool = BigQueryTool(func=complex_function, credentials=credentials_config)
|
||||
|
||||
# The 'credentials' parameter should be ignored in mandatory args analysis
|
||||
mandatory_args = tool._get_mandatory_args()
|
||||
assert "required_param" in mandatory_args
|
||||
assert "credentials" not in mandatory_args
|
||||
assert "optional_param" not in mandatory_args
|
Loading…
Reference in New Issue
Block a user