fix: fix bigquery credentials and bigquery tool to make it compatible with python 3.9 and make the credential serializable in session

PiperOrigin-RevId: 763332829
This commit is contained in:
Xiang (Sean) Zhou 2025-05-26 01:57:40 -07:00 committed by Copybara-Service
parent 55cb36edfe
commit 694eca08e5
5 changed files with 233 additions and 104 deletions

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
from typing import List from typing import List
from typing import Optional from typing import Optional
@ -33,15 +35,31 @@ from ..tool_context import ToolContext
BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache"
class BigQueryCredentials(BaseModel): class BigQueryCredentialsConfig(BaseModel):
"""Configuration for Google API tools. (Experimental)""" """Configuration for Google API tools. (Experimental)"""
# Configure the model to allow arbitrary types like Credentials # Configure the model to allow arbitrary types like Credentials
model_config = {"arbitrary_types_allowed": True} model_config = {"arbitrary_types_allowed": True}
credentials: Optional[Credentials] = None credentials: Optional[Credentials] = None
"""the existing oauth credentials to use. If set will override client ID, """the existing oauth credentials to use. If set,this credential will be used
client secret, and scopes.""" for every end user, end users don't need to be involved in the oauthflow. This
field is mutually exclusive with client_id, client_secret and scopes.
Don't set this field unless you are sure this credential has the permission to
access every end user's data.
Example usage: when the agent is deployed in Google Cloud environment and
the service account (used as application default credentials) has access to
all the required BigQuery resource. Setting this credential to allow user to
access the BigQuery resource without end users going through oauth flow.
To get application default credential: `google.auth.default(...)`. See more
details in https://cloud.google.com/docs/authentication/application-default-credentials.
When the deployed environment cannot provide a pre-existing credential,
consider setting below client_id, client_secret and scope for end users to go
through oauth flow, so that agent can access the user data.
"""
client_id: Optional[str] = None client_id: Optional[str] = None
"""the oauth client ID to use.""" """the oauth client ID to use."""
client_secret: Optional[str] = None client_secret: Optional[str] = None
@ -51,12 +69,20 @@ class BigQueryCredentials(BaseModel):
""" """
@model_validator(mode="after") @model_validator(mode="after")
def __post_init__(self) -> "BigQueryCredentials": def __post_init__(self) -> BigQueryCredentialsConfig:
"""Validate that either credentials or client ID/secret are provided.""" """Validate that either credentials or client ID/secret are provided."""
if not self.credentials and (not self.client_id or not self.client_secret): if not self.credentials and (not self.client_id or not self.client_secret):
raise ValueError( raise ValueError(
"Must provide either credentials or client_id abd client_secret pair." "Must provide either credentials or client_id abd client_secret pair."
) )
if self.credentials and (
self.client_id or self.client_secret or self.scopes
):
raise ValueError(
"Cannot provide both existing credentials and"
" client_id/client_secret/scopes."
)
if self.credentials: if self.credentials:
self.client_id = self.credentials.client_id self.client_id = self.credentials.client_id
self.client_secret = self.credentials.client_secret self.client_secret = self.credentials.client_secret
@ -71,14 +97,14 @@ class BigQueryCredentialsManager:
the same authenticated session without duplicating OAuth logic. the same authenticated session without duplicating OAuth logic.
""" """
def __init__(self, credentials: BigQueryCredentials): def __init__(self, credentials_config: BigQueryCredentialsConfig):
"""Initialize the credential manager. """Initialize the credential manager.
Args: Args:
credential_config: Configuration containing OAuth details or existing credentials_config: Credentials containing client id and client secrete
credentials or default credentials
""" """
self.credentials = credentials self.credentials_config = credentials_config
async def get_valid_credentials( async def get_valid_credentials(
self, tool_context: ToolContext self, tool_context: ToolContext
@ -87,18 +113,23 @@ class BigQueryCredentialsManager:
Args: Args:
tool_context: The tool context for OAuth flow and state management tool_context: The tool context for OAuth flow and state management
required_scopes: Set of OAuth scopes required by the calling tool
Returns: Returns:
Valid Credentials object, or None if OAuth flow is needed Valid Credentials object, or None if OAuth flow is needed
""" """
# First, try to get cached credentials from the instance # First, try to get credentials from the tool context
creds = self.credentials.credentials creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None)
creds = (
Credentials.from_authorized_user_info(
creds_json, self.credentials_config.scopes
)
if creds_json
else None
)
# If credentails are empty # If credentails are empty use the default credential
if not creds: if not creds:
creds = tool_context.get(BIGQUERY_TOKEN_CACHE_KEY, None) creds = self.credentials_config.credentials
self.credentials.credentials = creds
# Check if we have valid credentials # Check if we have valid credentials
if creds and creds.valid: if creds and creds.valid:
@ -110,7 +141,7 @@ class BigQueryCredentialsManager:
creds.refresh(Request()) creds.refresh(Request())
if creds.valid: if creds.valid:
# Cache the refreshed credentials # Cache the refreshed credentials
self.credentials.credentials = creds tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json()
return creds return creds
except RefreshError: except RefreshError:
# Refresh failed, need to re-authenticate # Refresh failed, need to re-authenticate
@ -140,7 +171,7 @@ class BigQueryCredentialsManager:
tokenUrl="https://oauth2.googleapis.com/token", tokenUrl="https://oauth2.googleapis.com/token",
scopes={ scopes={
scope: f"Access to {scope}" scope: f"Access to {scope}"
for scope in self.credentials.scopes for scope in self.credentials_config.scopes
}, },
) )
) )
@ -149,8 +180,8 @@ class BigQueryCredentialsManager:
auth_credential = AuthCredential( auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2, auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth( oauth2=OAuth2Auth(
client_id=self.credentials.client_id, client_id=self.credentials_config.client_id,
client_secret=self.credentials.client_secret, client_secret=self.credentials_config.client_secret,
), ),
) )
@ -165,14 +196,14 @@ class BigQueryCredentialsManager:
token=auth_response.oauth2.access_token, token=auth_response.oauth2.access_token,
refresh_token=auth_response.oauth2.refresh_token, refresh_token=auth_response.oauth2.refresh_token,
token_uri=auth_scheme.flows.authorizationCode.tokenUrl, token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
client_id=self.credentials.client_id, client_id=self.credentials_config.client_id,
client_secret=self.credentials.client_secret, client_secret=self.credentials_config.client_secret,
scopes=list(self.credentials.scopes), scopes=list(self.credentials_config.scopes),
) )
# Cache the new credentials # Cache the new credentials
self.credentials.credentials = creds tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json()
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds
return creds return creds
else: else:
# Request OAuth flow # Request OAuth flow

View File

@ -17,13 +17,13 @@ import inspect
from typing import Any from typing import Any
from typing import Callable from typing import Callable
from typing import Optional from typing import Optional
from typing import override
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from typing_extensions import override
from ..function_tool import FunctionTool from ..function_tool import FunctionTool
from ..tool_context import ToolContext from ..tool_context import ToolContext
from .bigquery_credentials import BigQueryCredentials from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_credentials import BigQueryCredentialsManager from .bigquery_credentials import BigQueryCredentialsManager
@ -41,7 +41,7 @@ class BigQueryTool(FunctionTool):
def __init__( def __init__(
self, self,
func: Callable[..., Any], func: Callable[..., Any],
credentials: Optional[BigQueryCredentials] = None, credentials: Optional[BigQueryCredentialsConfig] = None,
): ):
"""Initialize the Google API tool. """Initialize the Google API tool.

View File

@ -14,7 +14,7 @@
from unittest.mock import Mock from unittest.mock import Mock
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
# Mock the Google OAuth and API dependencies # Mock the Google OAuth and API dependencies
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
import pytest import pytest
@ -39,7 +39,7 @@ class TestBigQueryCredentials:
mock_creds.client_secret = "test_client_secret" mock_creds.client_secret = "test_client_secret"
mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"] mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"]
config = BigQueryCredentials(credentials=mock_creds) config = BigQueryCredentialsConfig(credentials=mock_creds)
# Verify that the credentials are properly stored and attributes are extracted # Verify that the credentials are properly stored and attributes are extracted
assert config.credentials == mock_creds assert config.credentials == mock_creds
@ -53,7 +53,7 @@ class TestBigQueryCredentials:
This tests the scenario where users want to create new OAuth credentials This tests the scenario where users want to create new OAuth credentials
from scratch using their application's client ID and secret. from scratch using their application's client ID and secret.
""" """
config = BigQueryCredentials( config = BigQueryCredentialsConfig(
client_id="test_client_id", client_id="test_client_id",
client_secret="test_client_secret", client_secret="test_client_secret",
scopes=["https://www.googleapis.com/auth/bigquery"], scopes=["https://www.googleapis.com/auth/bigquery"],
@ -77,7 +77,7 @@ class TestBigQueryCredentials:
" pair" " pair"
), ),
): ):
BigQueryCredentials(client_id="test_client_id") BigQueryCredentialsConfig(client_id="test_client_id")
def test_missing_client_id_raises_error(self): def test_missing_client_id_raises_error(self):
"""Test that missing client ID raises appropriate validation error.""" """Test that missing client ID raises appropriate validation error."""
@ -88,7 +88,7 @@ class TestBigQueryCredentials:
" pair" " pair"
), ),
): ):
BigQueryCredentials(client_secret="test_client_secret") BigQueryCredentialsConfig(client_secret="test_client_secret")
def test_empty_configuration_raises_error(self): def test_empty_configuration_raises_error(self):
"""Test that completely empty configuration is rejected. """Test that completely empty configuration is rejected.
@ -103,4 +103,4 @@ class TestBigQueryCredentials:
" pair" " pair"
), ),
): ):
BigQueryCredentials() BigQueryCredentialsConfig()

View File

@ -19,7 +19,7 @@ from unittest.mock import patch
from google.adk.auth import AuthConfig from google.adk.auth import AuthConfig
from google.adk.tools import ToolContext from google.adk.tools import ToolContext
from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
from google.auth.exceptions import RefreshError from google.auth.exceptions import RefreshError
# Mock the Google OAuth and API dependencies # Mock the Google OAuth and API dependencies
@ -46,15 +46,13 @@ class TestBigQueryCredentialsManager:
context = Mock(spec=ToolContext) context = Mock(spec=ToolContext)
context.get_auth_response = Mock(return_value=None) context.get_auth_response = Mock(return_value=None)
context.request_credential = Mock() context.request_credential = Mock()
# Mock the get method and state dictionary for caching tests
context.get = Mock(return_value=None)
context.state = {} context.state = {}
return context return context
@pytest.fixture @pytest.fixture
def credentials_config(self): def credentials_config(self):
"""Create a basic credentials configuration for testing.""" """Create a basic credentials configuration for testing."""
return BigQueryCredentials( return BigQueryCredentialsConfig(
client_id="test_client_id", client_id="test_client_id",
client_secret="test_client_secret", client_secret="test_client_secret",
scopes=["https://www.googleapis.com/auth/calendar"], scopes=["https://www.googleapis.com/auth/calendar"],
@ -77,7 +75,7 @@ class TestBigQueryCredentialsManager:
# Create mock credentials that are already valid # Create mock credentials that are already valid
mock_creds = Mock(spec=Credentials) mock_creds = Mock(spec=Credentials)
mock_creds.valid = True mock_creds.valid = True
manager.credentials.credentials = mock_creds manager.credentials_config.credentials = mock_creds
result = await manager.get_valid_credentials(mock_tool_context) result = await manager.get_valid_credentials(mock_tool_context)
@ -85,8 +83,6 @@ class TestBigQueryCredentialsManager:
# Verify no OAuth flow was triggered # Verify no OAuth flow was triggered
mock_tool_context.get_auth_response.assert_not_called() mock_tool_context.get_auth_response.assert_not_called()
mock_tool_context.request_credential.assert_not_called() mock_tool_context.request_credential.assert_not_called()
# Verify cache retrieval wasn't needed since we had valid creds
mock_tool_context.get.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_credentials_from_cache_when_none_in_manager( async def test_get_credentials_from_cache_when_none_in_manager(
@ -99,25 +95,37 @@ class TestBigQueryCredentialsManager:
doesn't have them loaded. doesn't have them loaded.
""" """
# Manager starts with no credentials # Manager starts with no credentials
manager.credentials.credentials = None manager.credentials_config.credentials = None
# Create mock cached credentials that are valid # Create mock cached credentials JSON that would be stored in cache
mock_cached_creds = Mock(spec=Credentials) mock_cached_creds_json = {
mock_cached_creds.valid = True "token": "cached_token",
"refresh_token": "cached_refresh_token",
"client_id": "test_client_id",
"client_secret": "test_client_secret",
}
# Set up the tool context to return cached credentials # Set up the tool context state to contain cached credentials
mock_tool_context.get.return_value = mock_cached_creds mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
# Mock the Credentials.from_authorized_user_info method
with patch(
"google.oauth2.credentials.Credentials.from_authorized_user_info"
) as mock_from_json:
mock_creds = Mock(spec=Credentials)
mock_creds.valid = True
mock_from_json.return_value = mock_creds
result = await manager.get_valid_credentials(mock_tool_context) result = await manager.get_valid_credentials(mock_tool_context)
# Verify credentials were retrieved from cache # Verify credentials were created from cached JSON
mock_tool_context.get.assert_called_once_with( mock_from_json.assert_called_once_with(
BIGQUERY_TOKEN_CACHE_KEY, None mock_cached_creds_json, manager.credentials_config.scopes
) )
# Verify cached credentials were loaded into manager # Verify loaded credentials were not cached into manager
assert manager.credentials.credentials == mock_cached_creds assert manager.credentials_config.credentials is None
# Verify valid cached credentials were returned # Verify valid cached credentials were returned
assert result == mock_cached_creds assert result == mock_creds
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_credentials_in_manager_or_cache( async def test_no_credentials_in_manager_or_cache(
@ -129,17 +137,13 @@ class TestBigQueryCredentialsManager:
requiring a new OAuth flow to be initiated. requiring a new OAuth flow to be initiated.
""" """
# Manager starts with no credentials # Manager starts with no credentials
manager.credentials.credentials = None manager.credentials_config.credentials = None
# Cache also returns None # Cache is also empty (state dict doesn't contain the key)
mock_tool_context.get.return_value = None
result = await manager.get_valid_credentials(mock_tool_context) result = await manager.get_valid_credentials(mock_tool_context)
# Should trigger OAuth flow and return None (flow in progress) # Should trigger OAuth flow and return None (flow in progress)
assert result is None assert result is None
mock_tool_context.get.assert_called_once_with(
BIGQUERY_TOKEN_CACHE_KEY, None
)
mock_tool_context.request_credential.assert_called_once() mock_tool_context.request_credential.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -152,32 +156,61 @@ class TestBigQueryCredentialsManager:
This tests the interaction between caching and refresh functionality, This tests the interaction between caching and refresh functionality,
ensuring that expired cached credentials can be refreshed properly. ensuring that expired cached credentials can be refreshed properly.
""" """
# Manager starts with no credentials # Manager starts with no default credentials
manager.credentials.credentials = None manager.credentials_config.credentials = None
# Create mock cached credentials JSON
mock_cached_creds_json = {
"token": "expired_token",
"refresh_token": "valid_refresh_token",
"client_id": "test_client_id",
"client_secret": "test_client_secret",
}
mock_refreshed_creds_json = {
"token": "new_token",
"refresh_token": "valid_refresh_token",
"client_id": "test_client_id",
"client_secret": "test_client_secret",
}
# Set up the tool context state to contain cached credentials
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
# Create expired cached credentials with refresh token # Create expired cached credentials with refresh token
mock_cached_creds = Mock(spec=Credentials) mock_cached_creds = Mock(spec=Credentials)
mock_cached_creds.valid = False mock_cached_creds.valid = False
mock_cached_creds.expired = True mock_cached_creds.expired = True
mock_cached_creds.refresh_token = "refresh_token" mock_cached_creds.refresh_token = "valid_refresh_token"
mock_cached_creds.to_json.return_value = mock_refreshed_creds_json
# Mock successful refresh # Mock successful refresh
def mock_refresh(request): def mock_refresh(request):
mock_cached_creds.valid = True mock_cached_creds.valid = True
mock_cached_creds.refresh = Mock(side_effect=mock_refresh) mock_cached_creds.refresh = Mock(side_effect=mock_refresh)
mock_tool_context.get.return_value = mock_cached_creds
# Mock the Credentials.from_authorized_user_info method
with patch(
"google.oauth2.credentials.Credentials.from_authorized_user_info"
) as mock_from_json:
mock_from_json.return_value = mock_cached_creds
result = await manager.get_valid_credentials(mock_tool_context) result = await manager.get_valid_credentials(mock_tool_context)
# Verify credentials were retrieved from cache # Verify credentials were created from cached JSON
mock_tool_context.get.assert_called_once_with( mock_from_json.assert_called_once_with(
BIGQUERY_TOKEN_CACHE_KEY, None mock_cached_creds_json, manager.credentials_config.scopes
) )
# Verify refresh was attempted and succeeded # Verify refresh was attempted and succeeded
mock_cached_creds.refresh.assert_called_once() mock_cached_creds.refresh.assert_called_once()
# Verify refreshed credentials were loaded into manager # Verify refreshed credentials were not cached into manager
assert manager.credentials.credentials == mock_cached_creds assert manager.credentials_config.credentials is None
# Verify refreshed credentials were cached
assert (
"new_token"
== mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]["token"]
)
assert result == mock_cached_creds assert result == mock_cached_creds
@pytest.mark.asyncio @pytest.mark.asyncio
@ -201,14 +234,14 @@ class TestBigQueryCredentialsManager:
mock_creds.valid = True mock_creds.valid = True
mock_creds.refresh = Mock(side_effect=mock_refresh) mock_creds.refresh = Mock(side_effect=mock_refresh)
manager.credentials.credentials = mock_creds manager.credentials_config.credentials = mock_creds
result = await manager.get_valid_credentials(mock_tool_context) result = await manager.get_valid_credentials(mock_tool_context)
assert result == mock_creds assert result == mock_creds
mock_creds.refresh.assert_called_once() mock_creds.refresh.assert_called_once()
# Verify credentials were cached after successful refresh # Verify credentials were cached after successful refresh
assert manager.credentials.credentials == mock_creds assert manager.credentials_config.credentials == mock_creds
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("google.auth.transport.requests.Request") @patch("google.auth.transport.requests.Request")
@ -226,7 +259,7 @@ class TestBigQueryCredentialsManager:
mock_creds.expired = True mock_creds.expired = True
mock_creds.refresh_token = "expired_refresh_token" mock_creds.refresh_token = "expired_refresh_token"
mock_creds.refresh = Mock(side_effect=RefreshError("Refresh failed")) mock_creds.refresh = Mock(side_effect=RefreshError("Refresh failed"))
manager.credentials.credentials = mock_creds manager.credentials_config.credentials = mock_creds
result = await manager.get_valid_credentials(mock_tool_context) result = await manager.get_valid_credentials(mock_tool_context)
@ -250,16 +283,39 @@ class TestBigQueryCredentialsManager:
mock_auth_response.oauth2.refresh_token = "new_refresh_token" mock_auth_response.oauth2.refresh_token = "new_refresh_token"
mock_tool_context.get_auth_response.return_value = mock_auth_response mock_tool_context.get_auth_response.return_value = mock_auth_response
# Create a mock credentials instance that will represent our created credentials
mock_creds = Mock(spec=Credentials)
# Make the JSON match what a real Credentials object would produce
mock_creds_json = (
'{"token": "new_access_token", "refresh_token": "new_refresh_token",'
' "token_uri": "https://oauth2.googleapis.com/token", "client_id":'
' "test_client_id", "client_secret": "test_client_secret", "scopes":'
' ["https://www.googleapis.com/auth/calendar"], "universe_domain":'
' "googleapis.com", "account": ""}'
)
mock_creds.to_json.return_value = mock_creds_json
# Use the full module path as it appears in the project structure
with patch(
"google.adk.tools.bigquery.bigquery_credentials.Credentials",
return_value=mock_creds,
) as mock_credentials_class:
result = await manager.get_valid_credentials(mock_tool_context) result = await manager.get_valid_credentials(mock_tool_context)
# Verify new credentials were created and cached # Verify new credentials were created
assert isinstance(result, Credentials) assert result == mock_creds
assert result.token == "new_access_token" # Verify credentials are created with correct parameters
assert result.refresh_token == "new_refresh_token" mock_credentials_class.assert_called_once()
# Verify credentials are cached in manager call_kwargs = mock_credentials_class.call_args[1]
assert manager.credentials.credentials == result assert call_kwargs["token"] == "new_access_token"
assert call_kwargs["refresh_token"] == "new_refresh_token"
# Verify credentials are not cached in manager
assert manager.credentials_config.credentials is None
# Verify credentials are also cached in tool context state # Verify credentials are also cached in tool context state
assert mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] == result assert (
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] == mock_creds_json
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_flow_in_progress(self, manager, mock_tool_context): async def test_oauth_flow_in_progress(self, manager, mock_tool_context):
@ -269,6 +325,7 @@ class TestBigQueryCredentialsManager:
and the user hasn't completed authorization yet. and the user hasn't completed authorization yet.
""" """
# No existing credentials, no auth response (flow not completed) # No existing credentials, no auth response (flow not completed)
manager.credentials_config.credentials = None
mock_tool_context.get_auth_response.return_value = None mock_tool_context.get_auth_response.return_value = None
result = await manager.get_valid_credentials(mock_tool_context) result = await manager.get_valid_credentials(mock_tool_context)
@ -300,28 +357,69 @@ class TestBigQueryCredentialsManager:
mock_auth_response.oauth2.refresh_token = "cached_refresh_token" mock_auth_response.oauth2.refresh_token = "cached_refresh_token"
mock_tool_context.get_auth_response.return_value = mock_auth_response mock_tool_context.get_auth_response.return_value = mock_auth_response
# Create the mock credentials instance that will be returned by the constructor
mock_creds = Mock(spec=Credentials)
# Make sure our mock JSON matches the structure that real Credentials objects produce
mock_creds_json = (
'{"token": "cached_access_token", "refresh_token":'
' "cached_refresh_token", "token_uri":'
' "https://oauth2.googleapis.com/token", "client_id": "test_client_id",'
' "client_secret": "test_client_secret", "scopes":'
' ["https://www.googleapis.com/auth/calendar"], "universe_domain":'
' "googleapis.com", "account": ""}'
)
mock_creds.to_json.return_value = mock_creds_json
mock_creds.valid = True
# Use the correct module path - without the 'src.' prefix
with patch(
"google.adk.tools.bigquery.bigquery_credentials.Credentials",
return_value=mock_creds,
) as mock_credentials_class:
# Complete OAuth flow with first manager # Complete OAuth flow with first manager
result1 = await manager1.get_valid_credentials(mock_tool_context) result1 = await manager1.get_valid_credentials(mock_tool_context)
# Verify credentials were cached in tool context # Verify credentials were cached in tool context
assert BIGQUERY_TOKEN_CACHE_KEY in mock_tool_context.state assert BIGQUERY_TOKEN_CACHE_KEY in mock_tool_context.state
cached_creds = mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] cached_creds_json = mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]
assert cached_creds_json == mock_creds_json
# Create second manager instance (simulating new request/session) # Create second manager instance (simulating new request/session)
manager2 = BigQueryCredentialsManager(credentials_config) manager2 = BigQueryCredentialsManager(credentials_config)
credentials_config.credentials = None
# Reset auth response to None (no new OAuth flow available) # Reset auth response to None (no new OAuth flow available)
mock_tool_context.get_auth_response.return_value = None mock_tool_context.get_auth_response.return_value = None
# Set up get method to return cached credentials
mock_tool_context.get.return_value = cached_creds # Mock the from_authorized_user_info method for the second manager
with patch(
"google.adk.tools.bigquery.bigquery_credentials.Credentials.from_authorized_user_info"
) as mock_from_json:
mock_cached_creds = Mock(spec=Credentials)
mock_cached_creds.valid = True
mock_from_json.return_value = mock_cached_creds
# Get credentials with second manager # Get credentials with second manager
result2 = await manager2.get_valid_credentials(mock_tool_context) result2 = await manager2.get_valid_credentials(mock_tool_context)
# Verify second manager retrieved cached credentials successfully # Verify second manager retrieved cached credentials successfully
assert result2 == cached_creds assert result2 == mock_cached_creds
assert manager2.credentials.credentials == cached_creds assert manager2.credentials_config.credentials is None
# Verify no new OAuth flow was requested
assert ( assert (
mock_tool_context.request_credential.call_count == 0 cached_creds_json == mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]
) # Only from first manager )
# The from_authorized_user_info should be called with the complete JSON structure
mock_from_json.assert_called_once()
# Extract the actual argument that was passed to verify it's the right JSON structure
actual_json_arg = mock_from_json.call_args[0][0]
# We need to parse and compare the structure rather than exact string match
# since the order of keys in JSON might differ
import json
expected_data = json.loads(mock_creds_json)
actual_data = (
actual_json_arg
if isinstance(actual_json_arg, dict)
else json.loads(actual_json_arg)
)
assert actual_data == expected_data

View File

@ -17,7 +17,7 @@ from unittest.mock import Mock
from unittest.mock import patch from unittest.mock import patch
from google.adk.tools import ToolContext from google.adk.tools import ToolContext
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
from google.adk.tools.bigquery.bigquery_tool import BigQueryTool from google.adk.tools.bigquery.bigquery_tool import BigQueryTool
# Mock the Google OAuth and API dependencies # Mock the Google OAuth and API dependencies
@ -78,7 +78,7 @@ class TestBigQueryTool:
@pytest.fixture @pytest.fixture
def credentials_config(self): def credentials_config(self):
"""Create credentials configuration for testing.""" """Create credentials configuration for testing."""
return BigQueryCredentials( return BigQueryCredentialsConfig(
client_id="test_client_id", client_id="test_client_id",
client_secret="test_client_secret", client_secret="test_client_secret",
scopes=["https://www.googleapis.com/auth/bigquery"], scopes=["https://www.googleapis.com/auth/bigquery"],