diff --git a/src/google/adk/tools/bigquery/__init__.py b/src/google/adk/tools/bigquery/__init__.py new file mode 100644 index 0000000..72054bb --- /dev/null +++ b/src/google/adk/tools/bigquery/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BigQuery Tools. (Experimental) + +BigQuery Tools under this module are hand crafted and customized while the tools +under google.adk.tools.google_api_tool are auto generated based on API +definition. The rationales to have customized tool are: + +1. BigQuery APIs have functions overlaps and LLM can't tell what tool to use +2. BigQuery APIs have a lot of parameters with some rarely used, which are not + LLM-friendly +3. We want to provide more high-level tools like forecasting, RAG, segmentation, + etc. +4. We want to provide extra access guardrails in those tools. For example, + execute_sql can't arbitrarily mutate existing data. +""" diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py new file mode 100644 index 0000000..3738685 --- /dev/null +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -0,0 +1,185 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List +from typing import Optional + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.auth.exceptions import RefreshError +from google.auth.transport.requests import Request +from google.oauth2.credentials import Credentials +from pydantic import BaseModel +from pydantic import model_validator + +from ...auth import AuthConfig +from ...auth import AuthCredential +from ...auth import AuthCredentialTypes +from ...auth import OAuth2Auth +from ..tool_context import ToolContext + +BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" + + +class BigQueryCredentials(BaseModel): + """Configuration for Google API tools. (Experimental)""" + + # Configure the model to allow arbitrary types like Credentials + model_config = {"arbitrary_types_allowed": True} + + credentials: Optional[Credentials] = None + """the existing oauth credentials to use. If set will override client ID, + client secret, and scopes.""" + client_id: Optional[str] = None + """the oauth client ID to use.""" + client_secret: Optional[str] = None + """the oauth client secret to use.""" + scopes: Optional[List[str]] = None + """the scopes to use. + """ + + @model_validator(mode="after") + def __post_init__(self) -> "BigQueryCredentials": + """Validate that either credentials or client ID/secret are provided.""" + if not self.credentials and (not self.client_id or not self.client_secret): + raise ValueError( + "Must provide either credentials or client_id abd client_secret pair." + ) + if self.credentials: + self.client_id = self.credentials.client_id + self.client_secret = self.credentials.client_secret + self.scopes = self.credentials.scopes + return self + + +class BigQueryCredentialsManager: + """Manages Google API credentials with automatic refresh and OAuth flow handling. + + This class centralizes credential management so multiple tools can share + the same authenticated session without duplicating OAuth logic. + """ + + def __init__(self, credentials: BigQueryCredentials): + """Initialize the credential manager. + + Args: + credential_config: Configuration containing OAuth details or existing + credentials + """ + self.credentials = credentials + + async def get_valid_credentials( + self, tool_context: ToolContext + ) -> Optional[Credentials]: + """Get valid credentials, handling refresh and OAuth flow as needed. + + Args: + tool_context: The tool context for OAuth flow and state management + required_scopes: Set of OAuth scopes required by the calling tool + + Returns: + Valid Credentials object, or None if OAuth flow is needed + """ + # First, try to get cached credentials from the instance + creds = self.credentials.credentials + + # If credentails are empty + if not creds: + creds = tool_context.get(BIGQUERY_TOKEN_CACHE_KEY, None) + self.credentials.credentials = creds + + # Check if we have valid credentials + if creds and creds.valid: + return creds + + # Try to refresh expired credentials + if creds and creds.expired and creds.refresh_token: + try: + creds.refresh(Request()) + if creds.valid: + # Cache the refreshed credentials + self.credentials.credentials = creds + return creds + except RefreshError: + # Refresh failed, need to re-authenticate + pass + + # Need to perform OAuth flow + return await self._perform_oauth_flow(tool_context) + + async def _perform_oauth_flow( + self, tool_context: ToolContext + ) -> Optional[Credentials]: + """Perform OAuth flow to get new credentials. + + Args: + tool_context: The tool context for OAuth flow + required_scopes: Set of required OAuth scopes + + Returns: + New Credentials object, or None if flow is in progress + """ + + # Create OAuth configuration + auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://accounts.google.com/o/oauth2/auth", + tokenUrl="https://oauth2.googleapis.com/token", + scopes={ + scope: f"Access to {scope}" + for scope in self.credentials.scopes + }, + ) + ) + ) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=self.credentials.client_id, + client_secret=self.credentials.client_secret, + ), + ) + + # Check if OAuth response is available + auth_response = tool_context.get_auth_response( + AuthConfig(auth_scheme=auth_scheme, raw_auth_credential=auth_credential) + ) + + if auth_response: + # OAuth flow completed, create credentials + creds = Credentials( + token=auth_response.oauth2.access_token, + refresh_token=auth_response.oauth2.refresh_token, + token_uri=auth_scheme.flows.authorizationCode.tokenUrl, + client_id=self.credentials.client_id, + client_secret=self.credentials.client_secret, + scopes=list(self.credentials.scopes), + ) + + # Cache the new credentials + self.credentials.credentials = creds + tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds + return creds + else: + # Request OAuth flow + tool_context.request_credential( + AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=auth_credential, + ) + ) + return None diff --git a/src/google/adk/tools/bigquery/bigquery_tool.py b/src/google/adk/tools/bigquery/bigquery_tool.py new file mode 100644 index 0000000..ba0dd48 --- /dev/null +++ b/src/google/adk/tools/bigquery/bigquery_tool.py @@ -0,0 +1,116 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any +from typing import Callable +from typing import Optional +from typing import override + +from google.oauth2.credentials import Credentials + +from ..function_tool import FunctionTool +from ..tool_context import ToolContext +from .bigquery_credentials import BigQueryCredentials +from .bigquery_credentials import BigQueryCredentialsManager + + +class BigQueryTool(FunctionTool): + """GoogleApiTool class for tools that call Google APIs. + + This class is for developers to handcraft customized Google API tools rather + than auto generate Google API tools based on API specs. + + This class handles all the OAuth complexity, credential management, + and common Google API patterns so subclasses can focus on their + specific functionality. + """ + + def __init__( + self, + func: Callable[..., Any], + credentials: Optional[BigQueryCredentials] = None, + ): + """Initialize the Google API tool. + + Args: + func: callable that impelments the tool's logic, can accept one + 'credential" parameter + credentials: credentials used to call Google API. If None, then we don't + hanlde the auth logic + """ + super().__init__(func=func) + self._ignore_params.append("credentials") + self.credentials_manager = ( + BigQueryCredentialsManager(credentials) if credentials else None + ) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + """Main entry point for tool execution with credential handling. + + This method handles all the OAuth complexity and then delegates + to the subclass's run_async_with_credential method. + """ + try: + # Get valid credentials + credentials = ( + await self.credentials_manager.get_valid_credentials(tool_context) + if self.credentials_manager + else None + ) + + if credentials is None and self.credentials_manager: + # OAuth flow in progress + return ( + "User authorization is required to access Google services for" + f" {self.name}. Please complete the authorization flow." + ) + + # Execute the tool's specific logic with valid credentials + + return await self._run_async_with_credential( + credentials, args, tool_context + ) + + except Exception as ex: + return { + "status": "ERROR", + "error_details": str(ex), + } + + async def _run_async_with_credential( + self, + credentials: Credentials, + args: dict[str, Any], + tool_context: ToolContext, + ) -> Any: + """Execute the tool's specific logic with valid credentials. + + Args: + credentials: Valid Google OAuth credentials + args: Arguments passed to the tool + tool_context: Tool execution context + + Returns: + The result of the tool execution + """ + args_to_call = args.copy() + signature = inspect.signature(self.func) + if "credentials" in signature.parameters: + args_to_call["credentials"] = credentials + return await super().run_async(args=args_to_call, tool_context=tool_context) diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 069108c..30c1a11 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -57,6 +57,7 @@ class FunctionTool(BaseTool): super().__init__(name=name, description=doc) self.func = func + self._ignore_params = ['tool_context', 'input_stream'] @override def _get_declaration(self) -> Optional[types.FunctionDeclaration]: @@ -65,7 +66,7 @@ class FunctionTool(BaseTool): func=self.func, # The model doesn't understand the function context. # input_stream is for streaming tool - ignore_params=['tool_context', 'input_stream'], + ignore_params=self._ignore_params, variant=self._api_variant, ) ) diff --git a/tests/unittests/tools/bigquery/__init__ b/tests/unittests/tools/bigquery/__init__ new file mode 100644 index 0000000..0a2669d --- /dev/null +++ b/tests/unittests/tools/bigquery/__init__ @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials.py b/tests/unittests/tools/bigquery/test_bigquery_credentials.py new file mode 100644 index 0000000..7937ccc --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials.py @@ -0,0 +1,106 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials +# Mock the Google OAuth and API dependencies +from google.oauth2.credentials import Credentials +import pytest + + +class TestBigQueryCredentials: + """Test suite for BigQueryCredentials configuration validation. + + This class tests the credential configuration logic that ensures + either existing credentials or client ID/secret pairs are provided. + """ + + def test_valid_credentials_object(self): + """Test that providing valid Credentials object works correctly. + + When a user already has valid OAuth credentials, they should be able + to pass them directly without needing to provide client ID/secret. + """ + # Create a mock credentials object with the expected attributes + mock_creds = Mock(spec=Credentials) + mock_creds.client_id = "test_client_id" + mock_creds.client_secret = "test_client_secret" + mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"] + + config = BigQueryCredentials(credentials=mock_creds) + + # Verify that the credentials are properly stored and attributes are extracted + assert config.credentials == mock_creds + assert config.client_id == "test_client_id" + assert config.client_secret == "test_client_secret" + assert config.scopes == ["https://www.googleapis.com/auth/calendar"] + + def test_valid_client_id_secret_pair(self): + """Test that providing client ID and secret without credentials works. + + This tests the scenario where users want to create new OAuth credentials + from scratch using their application's client ID and secret. + """ + config = BigQueryCredentials( + client_id="test_client_id", + client_secret="test_client_secret", + scopes=["https://www.googleapis.com/auth/bigquery"], + ) + + assert config.credentials is None + assert config.client_id == "test_client_id" + assert config.client_secret == "test_client_secret" + assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + + def test_missing_client_secret_raises_error(self): + """Test that missing client secret raises appropriate validation error. + + This ensures that incomplete OAuth configuration is caught early + rather than failing during runtime. + """ + with pytest.raises( + ValueError, + match=( + "Must provide either credentials or client_id abd client_secret" + " pair" + ), + ): + BigQueryCredentials(client_id="test_client_id") + + def test_missing_client_id_raises_error(self): + """Test that missing client ID raises appropriate validation error.""" + with pytest.raises( + ValueError, + match=( + "Must provide either credentials or client_id abd client_secret" + " pair" + ), + ): + BigQueryCredentials(client_secret="test_client_secret") + + def test_empty_configuration_raises_error(self): + """Test that completely empty configuration is rejected. + + Users must provide either existing credentials or the components + needed to create new ones. + """ + with pytest.raises( + ValueError, + match=( + "Must provide either credentials or client_id abd client_secret" + " pair" + ), + ): + BigQueryCredentials() diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py b/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py new file mode 100644 index 0000000..d9d594d --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py @@ -0,0 +1,327 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.auth import AuthConfig +from google.adk.tools import ToolContext +from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY +from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials +from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager +from google.auth.exceptions import RefreshError +# Mock the Google OAuth and API dependencies +from google.oauth2.credentials import Credentials +import pytest + + +class TestBigQueryCredentialsManager: + """Test suite for BigQueryCredentialsManager OAuth flow handling. + + This class tests the complex credential management logic including + credential validation, refresh, OAuth flow orchestration, and the + new token caching functionality through tool_context.state. + """ + + @pytest.fixture + def mock_tool_context(self): + """Create a mock ToolContext for testing. + + The ToolContext is the interface between tools and the broader + agent framework, handling OAuth flows and state management. + Now includes state dictionary for testing caching behavior. + """ + context = Mock(spec=ToolContext) + context.get_auth_response = Mock(return_value=None) + context.request_credential = Mock() + # Mock the get method and state dictionary for caching tests + context.get = Mock(return_value=None) + context.state = {} + return context + + @pytest.fixture + def credentials_config(self): + """Create a basic credentials configuration for testing.""" + return BigQueryCredentials( + client_id="test_client_id", + client_secret="test_client_secret", + scopes=["https://www.googleapis.com/auth/calendar"], + ) + + @pytest.fixture + def manager(self, credentials_config): + """Create a credentials manager instance for testing.""" + return BigQueryCredentialsManager(credentials_config) + + @pytest.mark.asyncio + async def test_get_valid_credentials_with_valid_existing_creds( + self, manager, mock_tool_context + ): + """Test that valid existing credentials are returned immediately. + + When credentials are already valid, no refresh or OAuth flow + should be needed. This is the optimal happy path scenario. + """ + # Create mock credentials that are already valid + mock_creds = Mock(spec=Credentials) + mock_creds.valid = True + manager.credentials.credentials = mock_creds + + result = await manager.get_valid_credentials(mock_tool_context) + + assert result == mock_creds + # Verify no OAuth flow was triggered + mock_tool_context.get_auth_response.assert_not_called() + mock_tool_context.request_credential.assert_not_called() + # Verify cache retrieval wasn't needed since we had valid creds + mock_tool_context.get.assert_not_called() + + @pytest.mark.asyncio + async def test_get_credentials_from_cache_when_none_in_manager( + self, manager, mock_tool_context + ): + """Test retrieving credentials from tool_context cache when manager has none. + + This tests the new caching functionality where credentials can be + retrieved from the tool context state when the manager instance + doesn't have them loaded. + """ + # Manager starts with no credentials + manager.credentials.credentials = None + + # Create mock cached credentials that are valid + mock_cached_creds = Mock(spec=Credentials) + mock_cached_creds.valid = True + + # Set up the tool context to return cached credentials + mock_tool_context.get.return_value = mock_cached_creds + + result = await manager.get_valid_credentials(mock_tool_context) + + # Verify credentials were retrieved from cache + mock_tool_context.get.assert_called_once_with( + BIGQUERY_TOKEN_CACHE_KEY, None + ) + # Verify cached credentials were loaded into manager + assert manager.credentials.credentials == mock_cached_creds + # Verify valid cached credentials were returned + assert result == mock_cached_creds + + @pytest.mark.asyncio + async def test_no_credentials_in_manager_or_cache( + self, manager, mock_tool_context + ): + """Test OAuth flow when no credentials exist in manager or cache. + + This tests the scenario where both the manager and cache are empty, + requiring a new OAuth flow to be initiated. + """ + # Manager starts with no credentials + manager.credentials.credentials = None + # Cache also returns None + mock_tool_context.get.return_value = None + + result = await manager.get_valid_credentials(mock_tool_context) + + # Should trigger OAuth flow and return None (flow in progress) + assert result is None + mock_tool_context.get.assert_called_once_with( + BIGQUERY_TOKEN_CACHE_KEY, None + ) + mock_tool_context.request_credential.assert_called_once() + + @pytest.mark.asyncio + @patch("google.auth.transport.requests.Request") + async def test_refresh_cached_credentials_success( + self, mock_request_class, manager, mock_tool_context + ): + """Test successful refresh of expired credentials retrieved from cache. + + This tests the interaction between caching and refresh functionality, + ensuring that expired cached credentials can be refreshed properly. + """ + # Manager starts with no credentials + manager.credentials.credentials = None + + # Create expired cached credentials with refresh token + mock_cached_creds = Mock(spec=Credentials) + mock_cached_creds.valid = False + mock_cached_creds.expired = True + mock_cached_creds.refresh_token = "refresh_token" + + # Mock successful refresh + def mock_refresh(request): + mock_cached_creds.valid = True + + mock_cached_creds.refresh = Mock(side_effect=mock_refresh) + mock_tool_context.get.return_value = mock_cached_creds + + result = await manager.get_valid_credentials(mock_tool_context) + + # Verify credentials were retrieved from cache + mock_tool_context.get.assert_called_once_with( + BIGQUERY_TOKEN_CACHE_KEY, None + ) + # Verify refresh was attempted and succeeded + mock_cached_creds.refresh.assert_called_once() + # Verify refreshed credentials were loaded into manager + assert manager.credentials.credentials == mock_cached_creds + assert result == mock_cached_creds + + @pytest.mark.asyncio + @patch("google.auth.transport.requests.Request") + async def test_get_valid_credentials_with_refresh_success( + self, mock_request_class, manager, mock_tool_context + ): + """Test successful credential refresh when tokens are expired. + + This tests the automatic token refresh capability that prevents + users from having to re-authenticate for every expired token. + """ + # Create expired credentials with refresh token + mock_creds = Mock(spec=Credentials) + mock_creds.valid = False + mock_creds.expired = True + mock_creds.refresh_token = "refresh_token" + + # Mock successful refresh + def mock_refresh(request): + mock_creds.valid = True + + mock_creds.refresh = Mock(side_effect=mock_refresh) + manager.credentials.credentials = mock_creds + + result = await manager.get_valid_credentials(mock_tool_context) + + assert result == mock_creds + mock_creds.refresh.assert_called_once() + # Verify credentials were cached after successful refresh + assert manager.credentials.credentials == mock_creds + + @pytest.mark.asyncio + @patch("google.auth.transport.requests.Request") + async def test_get_valid_credentials_with_refresh_failure( + self, mock_request_class, manager, mock_tool_context + ): + """Test OAuth flow trigger when credential refresh fails. + + When refresh tokens expire or become invalid, the system should + gracefully fall back to requesting a new OAuth flow. + """ + # Create expired credentials that fail to refresh + mock_creds = Mock(spec=Credentials) + mock_creds.valid = False + mock_creds.expired = True + mock_creds.refresh_token = "expired_refresh_token" + mock_creds.refresh = Mock(side_effect=RefreshError("Refresh failed")) + manager.credentials.credentials = mock_creds + + result = await manager.get_valid_credentials(mock_tool_context) + + # Should trigger OAuth flow and return None (flow in progress) + assert result is None + mock_tool_context.request_credential.assert_called_once() + + @pytest.mark.asyncio + async def test_oauth_flow_completion_with_caching( + self, manager, mock_tool_context + ): + """Test successful OAuth flow completion with proper credential caching. + + This tests the happy path where a user completes the OAuth flow + and the system successfully creates and caches new credentials + in both the manager and the tool context state. + """ + # Mock OAuth response indicating completed flow + mock_auth_response = Mock() + mock_auth_response.oauth2.access_token = "new_access_token" + mock_auth_response.oauth2.refresh_token = "new_refresh_token" + mock_tool_context.get_auth_response.return_value = mock_auth_response + + result = await manager.get_valid_credentials(mock_tool_context) + + # Verify new credentials were created and cached + assert isinstance(result, Credentials) + assert result.token == "new_access_token" + assert result.refresh_token == "new_refresh_token" + # Verify credentials are cached in manager + assert manager.credentials.credentials == result + # Verify credentials are also cached in tool context state + assert mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] == result + + @pytest.mark.asyncio + async def test_oauth_flow_in_progress(self, manager, mock_tool_context): + """Test OAuth flow initiation when no auth response is available. + + This tests the case where the OAuth flow needs to be started, + and the user hasn't completed authorization yet. + """ + # No existing credentials, no auth response (flow not completed) + mock_tool_context.get_auth_response.return_value = None + + result = await manager.get_valid_credentials(mock_tool_context) + + # Should return None and request credential flow + assert result is None + mock_tool_context.request_credential.assert_called_once() + + # Verify the auth configuration includes correct scopes and endpoints + call_args = mock_tool_context.request_credential.call_args[0][0] + assert isinstance(call_args, AuthConfig) + + @pytest.mark.asyncio + async def test_cache_persistence_across_manager_instances( + self, credentials_config, mock_tool_context + ): + """Test that cached credentials persist across different manager instances. + + This tests the key benefit of the tool context caching - that + credentials can be shared between different instances of the + credential manager, avoiding redundant OAuth flows. + """ + # Create first manager instance and simulate OAuth completion + manager1 = BigQueryCredentialsManager(credentials_config) + + # Mock OAuth response for first manager + mock_auth_response = Mock() + mock_auth_response.oauth2.access_token = "cached_access_token" + mock_auth_response.oauth2.refresh_token = "cached_refresh_token" + mock_tool_context.get_auth_response.return_value = mock_auth_response + + # Complete OAuth flow with first manager + result1 = await manager1.get_valid_credentials(mock_tool_context) + + # Verify credentials were cached in tool context + assert BIGQUERY_TOKEN_CACHE_KEY in mock_tool_context.state + cached_creds = mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] + + # Create second manager instance (simulating new request/session) + manager2 = BigQueryCredentialsManager(credentials_config) + + # Reset auth response to None (no new OAuth flow available) + mock_tool_context.get_auth_response.return_value = None + # Set up get method to return cached credentials + mock_tool_context.get.return_value = cached_creds + + # Get credentials with second manager + result2 = await manager2.get_valid_credentials(mock_tool_context) + + # Verify second manager retrieved cached credentials successfully + assert result2 == cached_creds + assert manager2.credentials.credentials == cached_creds + # Verify no new OAuth flow was requested + assert ( + mock_tool_context.request_credential.call_count == 0 + ) # Only from first manager diff --git a/tests/unittests/tools/bigquery/test_bigquery_tool.py b/tests/unittests/tools/bigquery/test_bigquery_tool.py new file mode 100644 index 0000000..2accd82 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_tool.py @@ -0,0 +1,259 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import Mock +from unittest.mock import patch + +from google.adk.tools import ToolContext +from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials +from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager +from google.adk.tools.bigquery.bigquery_tool import BigQueryTool +# Mock the Google OAuth and API dependencies +from google.oauth2.credentials import Credentials +import pytest + + +class TestBigQueryTool: + """Test suite for BigQueryTool OAuth integration and execution. + + This class tests the high-level tool execution logic that combines + credential management with actual function execution. + """ + + @pytest.fixture + def mock_tool_context(self): + """Create a mock ToolContext for testing tool execution.""" + context = Mock(spec=ToolContext) + context.get_auth_response = Mock(return_value=None) + context.request_credential = Mock() + return context + + @pytest.fixture + def sample_function(self): + """Create a sample function that accepts credentials for testing. + + This simulates a real Google API tool function that needs + authenticated credentials to perform its work. + """ + + def sample_func(param1: str, credentials: Credentials = None) -> dict: + """Sample function that uses Google API credentials.""" + if credentials: + return {"result": f"Success with {param1}", "authenticated": True} + else: + return {"result": f"Success with {param1}", "authenticated": False} + + return sample_func + + @pytest.fixture + def async_sample_function(self): + """Create an async sample function for testing async execution paths.""" + + async def async_sample_func( + param1: str, credentials: Credentials = None + ) -> dict: + """Async sample function that uses Google API credentials.""" + if credentials: + return {"result": f"Async success with {param1}", "authenticated": True} + else: + return { + "result": f"Async success with {param1}", + "authenticated": False, + } + + return async_sample_func + + @pytest.fixture + def credentials_config(self): + """Create credentials configuration for testing.""" + return BigQueryCredentials( + client_id="test_client_id", + client_secret="test_client_secret", + scopes=["https://www.googleapis.com/auth/bigquery"], + ) + + def test_tool_initialization_with_credentials( + self, sample_function, credentials_config + ): + """Test that BigQueryTool initializes correctly with credentials. + + The tool should properly inherit from FunctionTool while adding + Google API specific credential management capabilities. + """ + tool = BigQueryTool(func=sample_function, credentials=credentials_config) + + assert tool.func == sample_function + assert tool.credentials_manager is not None + assert isinstance(tool.credentials_manager, BigQueryCredentialsManager) + # Verify that 'credentials' parameter is ignored in function signature analysis + assert "credentials" in tool._ignore_params + + def test_tool_initialization_without_credentials(self, sample_function): + """Test tool initialization when no credential management is needed. + + Some tools might handle authentication externally or use service + accounts, so credential management should be optional. + """ + tool = BigQueryTool(func=sample_function, credentials=None) + + assert tool.func == sample_function + assert tool.credentials_manager is None + + @pytest.mark.asyncio + async def test_run_async_with_valid_credentials( + self, sample_function, credentials_config, mock_tool_context + ): + """Test successful tool execution with valid credentials. + + This tests the main happy path where credentials are available + and the underlying function executes successfully. + """ + tool = BigQueryTool(func=sample_function, credentials=credentials_config) + + # Mock the credentials manager to return valid credentials + mock_creds = Mock(spec=Credentials) + with patch.object( + tool.credentials_manager, + "get_valid_credentials", + return_value=mock_creds, + ) as mock_get_creds: + + result = await tool.run_async( + args={"param1": "test_value"}, tool_context=mock_tool_context + ) + + mock_get_creds.assert_called_once_with(mock_tool_context) + assert result["result"] == "Success with test_value" + assert result["authenticated"] is True + + @pytest.mark.asyncio + async def test_run_async_oauth_flow_in_progress( + self, sample_function, credentials_config, mock_tool_context + ): + """Test tool behavior when OAuth flow is in progress. + + When credentials aren't available and OAuth flow is needed, + the tool should return a user-friendly message rather than failing. + """ + tool = BigQueryTool(func=sample_function, credentials=credentials_config) + + # Mock credentials manager to return None (OAuth flow in progress) + with patch.object( + tool.credentials_manager, "get_valid_credentials", return_value=None + ) as mock_get_creds: + + result = await tool.run_async( + args={"param1": "test_value"}, tool_context=mock_tool_context + ) + + mock_get_creds.assert_called_once_with(mock_tool_context) + assert "authorization is required" in result.lower() + assert tool.name in result + + @pytest.mark.asyncio + async def test_run_async_without_credentials_manager( + self, sample_function, mock_tool_context + ): + """Test tool execution when no credential management is configured. + + Tools without credential managers should execute normally, + passing None for credentials if the function accepts them. + """ + tool = BigQueryTool(func=sample_function, credentials=None) + + result = await tool.run_async( + args={"param1": "test_value"}, tool_context=mock_tool_context + ) + + assert result["result"] == "Success with test_value" + assert result["authenticated"] is False + + @pytest.mark.asyncio + async def test_run_async_with_async_function( + self, async_sample_function, credentials_config, mock_tool_context + ): + """Test that async functions are properly handled. + + The tool should correctly detect and execute async functions, + which is important for tools that make async API calls. + """ + tool = BigQueryTool( + func=async_sample_function, credentials=credentials_config + ) + + mock_creds = Mock(spec=Credentials) + with patch.object( + tool.credentials_manager, + "get_valid_credentials", + return_value=mock_creds, + ): + + result = await tool.run_async( + args={"param1": "test_value"}, tool_context=mock_tool_context + ) + + assert result["result"] == "Async success with test_value" + assert result["authenticated"] is True + + @pytest.mark.asyncio + async def test_run_async_exception_handling( + self, credentials_config, mock_tool_context + ): + """Test that exceptions in tool execution are properly handled. + + Tools should gracefully handle errors and return structured + error responses rather than letting exceptions propagate. + """ + + def failing_function(param1: str, credentials: Credentials = None) -> dict: + raise ValueError("Something went wrong") + + tool = BigQueryTool(func=failing_function, credentials=credentials_config) + + mock_creds = Mock(spec=Credentials) + with patch.object( + tool.credentials_manager, + "get_valid_credentials", + return_value=mock_creds, + ): + + result = await tool.run_async( + args={"param1": "test_value"}, tool_context=mock_tool_context + ) + + assert result["status"] == "ERROR" + assert "Something went wrong" in result["error_details"] + + def test_function_signature_analysis(self, credentials_config): + """Test that function signature analysis correctly handles credentials parameter. + + The tool should properly identify and handle the credentials parameter + while preserving other parameter analysis for LLM function calling. + """ + + def complex_function( + required_param: str, + optional_param: str = "default", + credentials: Credentials = None, + ) -> dict: + return {"success": True} + + tool = BigQueryTool(func=complex_function, credentials=credentials_config) + + # The 'credentials' parameter should be ignored in mandatory args analysis + mandatory_args = tool._get_mandatory_args() + assert "required_param" in mandatory_args + assert "credentials" not in mandatory_args + assert "optional_param" not in mandatory_args