fix: fix bigquery credentials and bigquery tool to make it compatible with python 3.9 and make the credential serializable in session

PiperOrigin-RevId: 763332829
This commit is contained in:
Xiang (Sean) Zhou
2025-05-26 01:57:40 -07:00
committed by Copybara-Service
parent 55cb36edfe
commit 694eca08e5
5 changed files with 233 additions and 104 deletions

View File

@@ -14,7 +14,7 @@
from unittest.mock import Mock
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
# Mock the Google OAuth and API dependencies
from google.oauth2.credentials import Credentials
import pytest
@@ -39,7 +39,7 @@ class TestBigQueryCredentials:
mock_creds.client_secret = "test_client_secret"
mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"]
config = BigQueryCredentials(credentials=mock_creds)
config = BigQueryCredentialsConfig(credentials=mock_creds)
# Verify that the credentials are properly stored and attributes are extracted
assert config.credentials == mock_creds
@@ -53,7 +53,7 @@ class TestBigQueryCredentials:
This tests the scenario where users want to create new OAuth credentials
from scratch using their application's client ID and secret.
"""
config = BigQueryCredentials(
config = BigQueryCredentialsConfig(
client_id="test_client_id",
client_secret="test_client_secret",
scopes=["https://www.googleapis.com/auth/bigquery"],
@@ -77,7 +77,7 @@ class TestBigQueryCredentials:
" pair"
),
):
BigQueryCredentials(client_id="test_client_id")
BigQueryCredentialsConfig(client_id="test_client_id")
def test_missing_client_id_raises_error(self):
"""Test that missing client ID raises appropriate validation error."""
@@ -88,7 +88,7 @@ class TestBigQueryCredentials:
" pair"
),
):
BigQueryCredentials(client_secret="test_client_secret")
BigQueryCredentialsConfig(client_secret="test_client_secret")
def test_empty_configuration_raises_error(self):
"""Test that completely empty configuration is rejected.
@@ -103,4 +103,4 @@ class TestBigQueryCredentials:
" pair"
),
):
BigQueryCredentials()
BigQueryCredentialsConfig()

View File

@@ -19,7 +19,7 @@ from unittest.mock import patch
from google.adk.auth import AuthConfig
from google.adk.tools import ToolContext
from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
from google.auth.exceptions import RefreshError
# Mock the Google OAuth and API dependencies
@@ -46,15 +46,13 @@ class TestBigQueryCredentialsManager:
context = Mock(spec=ToolContext)
context.get_auth_response = Mock(return_value=None)
context.request_credential = Mock()
# Mock the get method and state dictionary for caching tests
context.get = Mock(return_value=None)
context.state = {}
return context
@pytest.fixture
def credentials_config(self):
"""Create a basic credentials configuration for testing."""
return BigQueryCredentials(
return BigQueryCredentialsConfig(
client_id="test_client_id",
client_secret="test_client_secret",
scopes=["https://www.googleapis.com/auth/calendar"],
@@ -77,7 +75,7 @@ class TestBigQueryCredentialsManager:
# Create mock credentials that are already valid
mock_creds = Mock(spec=Credentials)
mock_creds.valid = True
manager.credentials.credentials = mock_creds
manager.credentials_config.credentials = mock_creds
result = await manager.get_valid_credentials(mock_tool_context)
@@ -85,8 +83,6 @@ class TestBigQueryCredentialsManager:
# Verify no OAuth flow was triggered
mock_tool_context.get_auth_response.assert_not_called()
mock_tool_context.request_credential.assert_not_called()
# Verify cache retrieval wasn't needed since we had valid creds
mock_tool_context.get.assert_not_called()
@pytest.mark.asyncio
async def test_get_credentials_from_cache_when_none_in_manager(
@@ -99,25 +95,37 @@ class TestBigQueryCredentialsManager:
doesn't have them loaded.
"""
# Manager starts with no credentials
manager.credentials.credentials = None
manager.credentials_config.credentials = None
# Create mock cached credentials that are valid
mock_cached_creds = Mock(spec=Credentials)
mock_cached_creds.valid = True
# Create mock cached credentials JSON that would be stored in cache
mock_cached_creds_json = {
"token": "cached_token",
"refresh_token": "cached_refresh_token",
"client_id": "test_client_id",
"client_secret": "test_client_secret",
}
# Set up the tool context to return cached credentials
mock_tool_context.get.return_value = mock_cached_creds
# Set up the tool context state to contain cached credentials
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
result = await manager.get_valid_credentials(mock_tool_context)
# Mock the Credentials.from_authorized_user_info method
with patch(
"google.oauth2.credentials.Credentials.from_authorized_user_info"
) as mock_from_json:
mock_creds = Mock(spec=Credentials)
mock_creds.valid = True
mock_from_json.return_value = mock_creds
# Verify credentials were retrieved from cache
mock_tool_context.get.assert_called_once_with(
BIGQUERY_TOKEN_CACHE_KEY, None
)
# Verify cached credentials were loaded into manager
assert manager.credentials.credentials == mock_cached_creds
# Verify valid cached credentials were returned
assert result == mock_cached_creds
result = await manager.get_valid_credentials(mock_tool_context)
# Verify credentials were created from cached JSON
mock_from_json.assert_called_once_with(
mock_cached_creds_json, manager.credentials_config.scopes
)
# Verify loaded credentials were not cached into manager
assert manager.credentials_config.credentials is None
# Verify valid cached credentials were returned
assert result == mock_creds
@pytest.mark.asyncio
async def test_no_credentials_in_manager_or_cache(
@@ -129,17 +137,13 @@ class TestBigQueryCredentialsManager:
requiring a new OAuth flow to be initiated.
"""
# Manager starts with no credentials
manager.credentials.credentials = None
# Cache also returns None
mock_tool_context.get.return_value = None
manager.credentials_config.credentials = None
# Cache is also empty (state dict doesn't contain the key)
result = await manager.get_valid_credentials(mock_tool_context)
# Should trigger OAuth flow and return None (flow in progress)
assert result is None
mock_tool_context.get.assert_called_once_with(
BIGQUERY_TOKEN_CACHE_KEY, None
)
mock_tool_context.request_credential.assert_called_once()
@pytest.mark.asyncio
@@ -152,33 +156,62 @@ class TestBigQueryCredentialsManager:
This tests the interaction between caching and refresh functionality,
ensuring that expired cached credentials can be refreshed properly.
"""
# Manager starts with no credentials
manager.credentials.credentials = None
# Manager starts with no default credentials
manager.credentials_config.credentials = None
# Create mock cached credentials JSON
mock_cached_creds_json = {
"token": "expired_token",
"refresh_token": "valid_refresh_token",
"client_id": "test_client_id",
"client_secret": "test_client_secret",
}
mock_refreshed_creds_json = {
"token": "new_token",
"refresh_token": "valid_refresh_token",
"client_id": "test_client_id",
"client_secret": "test_client_secret",
}
# Set up the tool context state to contain cached credentials
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
# Create expired cached credentials with refresh token
mock_cached_creds = Mock(spec=Credentials)
mock_cached_creds.valid = False
mock_cached_creds.expired = True
mock_cached_creds.refresh_token = "refresh_token"
mock_cached_creds.refresh_token = "valid_refresh_token"
mock_cached_creds.to_json.return_value = mock_refreshed_creds_json
# Mock successful refresh
def mock_refresh(request):
mock_cached_creds.valid = True
mock_cached_creds.refresh = Mock(side_effect=mock_refresh)
mock_tool_context.get.return_value = mock_cached_creds
result = await manager.get_valid_credentials(mock_tool_context)
# Mock the Credentials.from_authorized_user_info method
with patch(
"google.oauth2.credentials.Credentials.from_authorized_user_info"
) as mock_from_json:
mock_from_json.return_value = mock_cached_creds
# Verify credentials were retrieved from cache
mock_tool_context.get.assert_called_once_with(
BIGQUERY_TOKEN_CACHE_KEY, None
)
# Verify refresh was attempted and succeeded
mock_cached_creds.refresh.assert_called_once()
# Verify refreshed credentials were loaded into manager
assert manager.credentials.credentials == mock_cached_creds
assert result == mock_cached_creds
result = await manager.get_valid_credentials(mock_tool_context)
# Verify credentials were created from cached JSON
mock_from_json.assert_called_once_with(
mock_cached_creds_json, manager.credentials_config.scopes
)
# Verify refresh was attempted and succeeded
mock_cached_creds.refresh.assert_called_once()
# Verify refreshed credentials were not cached into manager
assert manager.credentials_config.credentials is None
# Verify refreshed credentials were cached
assert (
"new_token"
== mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]["token"]
)
assert result == mock_cached_creds
@pytest.mark.asyncio
@patch("google.auth.transport.requests.Request")
@@ -201,14 +234,14 @@ class TestBigQueryCredentialsManager:
mock_creds.valid = True
mock_creds.refresh = Mock(side_effect=mock_refresh)
manager.credentials.credentials = mock_creds
manager.credentials_config.credentials = mock_creds
result = await manager.get_valid_credentials(mock_tool_context)
assert result == mock_creds
mock_creds.refresh.assert_called_once()
# Verify credentials were cached after successful refresh
assert manager.credentials.credentials == mock_creds
assert manager.credentials_config.credentials == mock_creds
@pytest.mark.asyncio
@patch("google.auth.transport.requests.Request")
@@ -226,7 +259,7 @@ class TestBigQueryCredentialsManager:
mock_creds.expired = True
mock_creds.refresh_token = "expired_refresh_token"
mock_creds.refresh = Mock(side_effect=RefreshError("Refresh failed"))
manager.credentials.credentials = mock_creds
manager.credentials_config.credentials = mock_creds
result = await manager.get_valid_credentials(mock_tool_context)
@@ -250,16 +283,39 @@ class TestBigQueryCredentialsManager:
mock_auth_response.oauth2.refresh_token = "new_refresh_token"
mock_tool_context.get_auth_response.return_value = mock_auth_response
result = await manager.get_valid_credentials(mock_tool_context)
# Create a mock credentials instance that will represent our created credentials
mock_creds = Mock(spec=Credentials)
# Make the JSON match what a real Credentials object would produce
mock_creds_json = (
'{"token": "new_access_token", "refresh_token": "new_refresh_token",'
' "token_uri": "https://oauth2.googleapis.com/token", "client_id":'
' "test_client_id", "client_secret": "test_client_secret", "scopes":'
' ["https://www.googleapis.com/auth/calendar"], "universe_domain":'
' "googleapis.com", "account": ""}'
)
mock_creds.to_json.return_value = mock_creds_json
# Verify new credentials were created and cached
assert isinstance(result, Credentials)
assert result.token == "new_access_token"
assert result.refresh_token == "new_refresh_token"
# Verify credentials are cached in manager
assert manager.credentials.credentials == result
# Verify credentials are also cached in tool context state
assert mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] == result
# Use the full module path as it appears in the project structure
with patch(
"google.adk.tools.bigquery.bigquery_credentials.Credentials",
return_value=mock_creds,
) as mock_credentials_class:
result = await manager.get_valid_credentials(mock_tool_context)
# Verify new credentials were created
assert result == mock_creds
# Verify credentials are created with correct parameters
mock_credentials_class.assert_called_once()
call_kwargs = mock_credentials_class.call_args[1]
assert call_kwargs["token"] == "new_access_token"
assert call_kwargs["refresh_token"] == "new_refresh_token"
# Verify credentials are not cached in manager
assert manager.credentials_config.credentials is None
# Verify credentials are also cached in tool context state
assert (
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] == mock_creds_json
)
@pytest.mark.asyncio
async def test_oauth_flow_in_progress(self, manager, mock_tool_context):
@@ -269,6 +325,7 @@ class TestBigQueryCredentialsManager:
and the user hasn't completed authorization yet.
"""
# No existing credentials, no auth response (flow not completed)
manager.credentials_config.credentials = None
mock_tool_context.get_auth_response.return_value = None
result = await manager.get_valid_credentials(mock_tool_context)
@@ -300,28 +357,69 @@ class TestBigQueryCredentialsManager:
mock_auth_response.oauth2.refresh_token = "cached_refresh_token"
mock_tool_context.get_auth_response.return_value = mock_auth_response
# Complete OAuth flow with first manager
result1 = await manager1.get_valid_credentials(mock_tool_context)
# Create the mock credentials instance that will be returned by the constructor
mock_creds = Mock(spec=Credentials)
# Make sure our mock JSON matches the structure that real Credentials objects produce
mock_creds_json = (
'{"token": "cached_access_token", "refresh_token":'
' "cached_refresh_token", "token_uri":'
' "https://oauth2.googleapis.com/token", "client_id": "test_client_id",'
' "client_secret": "test_client_secret", "scopes":'
' ["https://www.googleapis.com/auth/calendar"], "universe_domain":'
' "googleapis.com", "account": ""}'
)
mock_creds.to_json.return_value = mock_creds_json
mock_creds.valid = True
# Verify credentials were cached in tool context
assert BIGQUERY_TOKEN_CACHE_KEY in mock_tool_context.state
cached_creds = mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]
# Use the correct module path - without the 'src.' prefix
with patch(
"google.adk.tools.bigquery.bigquery_credentials.Credentials",
return_value=mock_creds,
) as mock_credentials_class:
# Complete OAuth flow with first manager
result1 = await manager1.get_valid_credentials(mock_tool_context)
# Verify credentials were cached in tool context
assert BIGQUERY_TOKEN_CACHE_KEY in mock_tool_context.state
cached_creds_json = mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]
assert cached_creds_json == mock_creds_json
# Create second manager instance (simulating new request/session)
manager2 = BigQueryCredentialsManager(credentials_config)
credentials_config.credentials = None
# Reset auth response to None (no new OAuth flow available)
mock_tool_context.get_auth_response.return_value = None
# Set up get method to return cached credentials
mock_tool_context.get.return_value = cached_creds
# Get credentials with second manager
result2 = await manager2.get_valid_credentials(mock_tool_context)
# Mock the from_authorized_user_info method for the second manager
with patch(
"google.adk.tools.bigquery.bigquery_credentials.Credentials.from_authorized_user_info"
) as mock_from_json:
mock_cached_creds = Mock(spec=Credentials)
mock_cached_creds.valid = True
mock_from_json.return_value = mock_cached_creds
# Verify second manager retrieved cached credentials successfully
assert result2 == cached_creds
assert manager2.credentials.credentials == cached_creds
# Verify no new OAuth flow was requested
assert (
mock_tool_context.request_credential.call_count == 0
) # Only from first manager
# Get credentials with second manager
result2 = await manager2.get_valid_credentials(mock_tool_context)
# Verify second manager retrieved cached credentials successfully
assert result2 == mock_cached_creds
assert manager2.credentials_config.credentials is None
assert (
cached_creds_json == mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]
)
# The from_authorized_user_info should be called with the complete JSON structure
mock_from_json.assert_called_once()
# Extract the actual argument that was passed to verify it's the right JSON structure
actual_json_arg = mock_from_json.call_args[0][0]
# We need to parse and compare the structure rather than exact string match
# since the order of keys in JSON might differ
import json
expected_data = json.loads(mock_creds_json)
actual_data = (
actual_json_arg
if isinstance(actual_json_arg, dict)
else json.loads(actual_json_arg)
)
assert actual_data == expected_data

View File

@@ -17,7 +17,7 @@ from unittest.mock import Mock
from unittest.mock import patch
from google.adk.tools import ToolContext
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
from google.adk.tools.bigquery.bigquery_tool import BigQueryTool
# Mock the Google OAuth and API dependencies
@@ -78,7 +78,7 @@ class TestBigQueryTool:
@pytest.fixture
def credentials_config(self):
"""Create credentials configuration for testing."""
return BigQueryCredentials(
return BigQueryCredentialsConfig(
client_id="test_client_id",
client_secret="test_client_secret",
scopes=["https://www.googleapis.com/auth/bigquery"],