mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
fix: fix bigquery credentials and bigquery tool to make it compatible with python 3.9 and make the credential serializable in session
PiperOrigin-RevId: 763332829
This commit is contained in:
parent
55cb36edfe
commit
694eca08e5
@ -12,6 +12,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -33,15 +35,31 @@ from ..tool_context import ToolContext
|
|||||||
BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache"
|
BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache"
|
||||||
|
|
||||||
|
|
||||||
class BigQueryCredentials(BaseModel):
|
class BigQueryCredentialsConfig(BaseModel):
|
||||||
"""Configuration for Google API tools. (Experimental)"""
|
"""Configuration for Google API tools. (Experimental)"""
|
||||||
|
|
||||||
# Configure the model to allow arbitrary types like Credentials
|
# Configure the model to allow arbitrary types like Credentials
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
credentials: Optional[Credentials] = None
|
credentials: Optional[Credentials] = None
|
||||||
"""the existing oauth credentials to use. If set will override client ID,
|
"""the existing oauth credentials to use. If set,this credential will be used
|
||||||
client secret, and scopes."""
|
for every end user, end users don't need to be involved in the oauthflow. This
|
||||||
|
field is mutually exclusive with client_id, client_secret and scopes.
|
||||||
|
Don't set this field unless you are sure this credential has the permission to
|
||||||
|
access every end user's data.
|
||||||
|
|
||||||
|
Example usage: when the agent is deployed in Google Cloud environment and
|
||||||
|
the service account (used as application default credentials) has access to
|
||||||
|
all the required BigQuery resource. Setting this credential to allow user to
|
||||||
|
access the BigQuery resource without end users going through oauth flow.
|
||||||
|
|
||||||
|
To get application default credential: `google.auth.default(...)`. See more
|
||||||
|
details in https://cloud.google.com/docs/authentication/application-default-credentials.
|
||||||
|
|
||||||
|
When the deployed environment cannot provide a pre-existing credential,
|
||||||
|
consider setting below client_id, client_secret and scope for end users to go
|
||||||
|
through oauth flow, so that agent can access the user data.
|
||||||
|
"""
|
||||||
client_id: Optional[str] = None
|
client_id: Optional[str] = None
|
||||||
"""the oauth client ID to use."""
|
"""the oauth client ID to use."""
|
||||||
client_secret: Optional[str] = None
|
client_secret: Optional[str] = None
|
||||||
@ -51,12 +69,20 @@ class BigQueryCredentials(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def __post_init__(self) -> "BigQueryCredentials":
|
def __post_init__(self) -> BigQueryCredentialsConfig:
|
||||||
"""Validate that either credentials or client ID/secret are provided."""
|
"""Validate that either credentials or client ID/secret are provided."""
|
||||||
if not self.credentials and (not self.client_id or not self.client_secret):
|
if not self.credentials and (not self.client_id or not self.client_secret):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Must provide either credentials or client_id abd client_secret pair."
|
"Must provide either credentials or client_id abd client_secret pair."
|
||||||
)
|
)
|
||||||
|
if self.credentials and (
|
||||||
|
self.client_id or self.client_secret or self.scopes
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot provide both existing credentials and"
|
||||||
|
" client_id/client_secret/scopes."
|
||||||
|
)
|
||||||
|
|
||||||
if self.credentials:
|
if self.credentials:
|
||||||
self.client_id = self.credentials.client_id
|
self.client_id = self.credentials.client_id
|
||||||
self.client_secret = self.credentials.client_secret
|
self.client_secret = self.credentials.client_secret
|
||||||
@ -71,14 +97,14 @@ class BigQueryCredentialsManager:
|
|||||||
the same authenticated session without duplicating OAuth logic.
|
the same authenticated session without duplicating OAuth logic.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, credentials: BigQueryCredentials):
|
def __init__(self, credentials_config: BigQueryCredentialsConfig):
|
||||||
"""Initialize the credential manager.
|
"""Initialize the credential manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
credential_config: Configuration containing OAuth details or existing
|
credentials_config: Credentials containing client id and client secrete
|
||||||
credentials
|
or default credentials
|
||||||
"""
|
"""
|
||||||
self.credentials = credentials
|
self.credentials_config = credentials_config
|
||||||
|
|
||||||
async def get_valid_credentials(
|
async def get_valid_credentials(
|
||||||
self, tool_context: ToolContext
|
self, tool_context: ToolContext
|
||||||
@ -87,18 +113,23 @@ class BigQueryCredentialsManager:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_context: The tool context for OAuth flow and state management
|
tool_context: The tool context for OAuth flow and state management
|
||||||
required_scopes: Set of OAuth scopes required by the calling tool
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Valid Credentials object, or None if OAuth flow is needed
|
Valid Credentials object, or None if OAuth flow is needed
|
||||||
"""
|
"""
|
||||||
# First, try to get cached credentials from the instance
|
# First, try to get credentials from the tool context
|
||||||
creds = self.credentials.credentials
|
creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None)
|
||||||
|
creds = (
|
||||||
|
Credentials.from_authorized_user_info(
|
||||||
|
creds_json, self.credentials_config.scopes
|
||||||
|
)
|
||||||
|
if creds_json
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# If credentails are empty
|
# If credentails are empty use the default credential
|
||||||
if not creds:
|
if not creds:
|
||||||
creds = tool_context.get(BIGQUERY_TOKEN_CACHE_KEY, None)
|
creds = self.credentials_config.credentials
|
||||||
self.credentials.credentials = creds
|
|
||||||
|
|
||||||
# Check if we have valid credentials
|
# Check if we have valid credentials
|
||||||
if creds and creds.valid:
|
if creds and creds.valid:
|
||||||
@ -110,7 +141,7 @@ class BigQueryCredentialsManager:
|
|||||||
creds.refresh(Request())
|
creds.refresh(Request())
|
||||||
if creds.valid:
|
if creds.valid:
|
||||||
# Cache the refreshed credentials
|
# Cache the refreshed credentials
|
||||||
self.credentials.credentials = creds
|
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json()
|
||||||
return creds
|
return creds
|
||||||
except RefreshError:
|
except RefreshError:
|
||||||
# Refresh failed, need to re-authenticate
|
# Refresh failed, need to re-authenticate
|
||||||
@ -140,7 +171,7 @@ class BigQueryCredentialsManager:
|
|||||||
tokenUrl="https://oauth2.googleapis.com/token",
|
tokenUrl="https://oauth2.googleapis.com/token",
|
||||||
scopes={
|
scopes={
|
||||||
scope: f"Access to {scope}"
|
scope: f"Access to {scope}"
|
||||||
for scope in self.credentials.scopes
|
for scope in self.credentials_config.scopes
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -149,8 +180,8 @@ class BigQueryCredentialsManager:
|
|||||||
auth_credential = AuthCredential(
|
auth_credential = AuthCredential(
|
||||||
auth_type=AuthCredentialTypes.OAUTH2,
|
auth_type=AuthCredentialTypes.OAUTH2,
|
||||||
oauth2=OAuth2Auth(
|
oauth2=OAuth2Auth(
|
||||||
client_id=self.credentials.client_id,
|
client_id=self.credentials_config.client_id,
|
||||||
client_secret=self.credentials.client_secret,
|
client_secret=self.credentials_config.client_secret,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -165,14 +196,14 @@ class BigQueryCredentialsManager:
|
|||||||
token=auth_response.oauth2.access_token,
|
token=auth_response.oauth2.access_token,
|
||||||
refresh_token=auth_response.oauth2.refresh_token,
|
refresh_token=auth_response.oauth2.refresh_token,
|
||||||
token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
|
token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
|
||||||
client_id=self.credentials.client_id,
|
client_id=self.credentials_config.client_id,
|
||||||
client_secret=self.credentials.client_secret,
|
client_secret=self.credentials_config.client_secret,
|
||||||
scopes=list(self.credentials.scopes),
|
scopes=list(self.credentials_config.scopes),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Cache the new credentials
|
# Cache the new credentials
|
||||||
self.credentials.credentials = creds
|
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json()
|
||||||
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds
|
|
||||||
return creds
|
return creds
|
||||||
else:
|
else:
|
||||||
# Request OAuth flow
|
# Request OAuth flow
|
||||||
|
@ -17,13 +17,13 @@ import inspect
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import override
|
|
||||||
|
|
||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..function_tool import FunctionTool
|
from ..function_tool import FunctionTool
|
||||||
from ..tool_context import ToolContext
|
from ..tool_context import ToolContext
|
||||||
from .bigquery_credentials import BigQueryCredentials
|
from .bigquery_credentials import BigQueryCredentialsConfig
|
||||||
from .bigquery_credentials import BigQueryCredentialsManager
|
from .bigquery_credentials import BigQueryCredentialsManager
|
||||||
|
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ class BigQueryTool(FunctionTool):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
func: Callable[..., Any],
|
func: Callable[..., Any],
|
||||||
credentials: Optional[BigQueryCredentials] = None,
|
credentials: Optional[BigQueryCredentialsConfig] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the Google API tool.
|
"""Initialize the Google API tool.
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials
|
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
|
||||||
# Mock the Google OAuth and API dependencies
|
# Mock the Google OAuth and API dependencies
|
||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
import pytest
|
import pytest
|
||||||
@ -39,7 +39,7 @@ class TestBigQueryCredentials:
|
|||||||
mock_creds.client_secret = "test_client_secret"
|
mock_creds.client_secret = "test_client_secret"
|
||||||
mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"]
|
mock_creds.scopes = ["https://www.googleapis.com/auth/calendar"]
|
||||||
|
|
||||||
config = BigQueryCredentials(credentials=mock_creds)
|
config = BigQueryCredentialsConfig(credentials=mock_creds)
|
||||||
|
|
||||||
# Verify that the credentials are properly stored and attributes are extracted
|
# Verify that the credentials are properly stored and attributes are extracted
|
||||||
assert config.credentials == mock_creds
|
assert config.credentials == mock_creds
|
||||||
@ -53,7 +53,7 @@ class TestBigQueryCredentials:
|
|||||||
This tests the scenario where users want to create new OAuth credentials
|
This tests the scenario where users want to create new OAuth credentials
|
||||||
from scratch using their application's client ID and secret.
|
from scratch using their application's client ID and secret.
|
||||||
"""
|
"""
|
||||||
config = BigQueryCredentials(
|
config = BigQueryCredentialsConfig(
|
||||||
client_id="test_client_id",
|
client_id="test_client_id",
|
||||||
client_secret="test_client_secret",
|
client_secret="test_client_secret",
|
||||||
scopes=["https://www.googleapis.com/auth/bigquery"],
|
scopes=["https://www.googleapis.com/auth/bigquery"],
|
||||||
@ -77,7 +77,7 @@ class TestBigQueryCredentials:
|
|||||||
" pair"
|
" pair"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
BigQueryCredentials(client_id="test_client_id")
|
BigQueryCredentialsConfig(client_id="test_client_id")
|
||||||
|
|
||||||
def test_missing_client_id_raises_error(self):
|
def test_missing_client_id_raises_error(self):
|
||||||
"""Test that missing client ID raises appropriate validation error."""
|
"""Test that missing client ID raises appropriate validation error."""
|
||||||
@ -88,7 +88,7 @@ class TestBigQueryCredentials:
|
|||||||
" pair"
|
" pair"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
BigQueryCredentials(client_secret="test_client_secret")
|
BigQueryCredentialsConfig(client_secret="test_client_secret")
|
||||||
|
|
||||||
def test_empty_configuration_raises_error(self):
|
def test_empty_configuration_raises_error(self):
|
||||||
"""Test that completely empty configuration is rejected.
|
"""Test that completely empty configuration is rejected.
|
||||||
@ -103,4 +103,4 @@ class TestBigQueryCredentials:
|
|||||||
" pair"
|
" pair"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
BigQueryCredentials()
|
BigQueryCredentialsConfig()
|
||||||
|
@ -19,7 +19,7 @@ from unittest.mock import patch
|
|||||||
from google.adk.auth import AuthConfig
|
from google.adk.auth import AuthConfig
|
||||||
from google.adk.tools import ToolContext
|
from google.adk.tools import ToolContext
|
||||||
from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY
|
from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY
|
||||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials
|
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
|
||||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
|
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
|
||||||
from google.auth.exceptions import RefreshError
|
from google.auth.exceptions import RefreshError
|
||||||
# Mock the Google OAuth and API dependencies
|
# Mock the Google OAuth and API dependencies
|
||||||
@ -46,15 +46,13 @@ class TestBigQueryCredentialsManager:
|
|||||||
context = Mock(spec=ToolContext)
|
context = Mock(spec=ToolContext)
|
||||||
context.get_auth_response = Mock(return_value=None)
|
context.get_auth_response = Mock(return_value=None)
|
||||||
context.request_credential = Mock()
|
context.request_credential = Mock()
|
||||||
# Mock the get method and state dictionary for caching tests
|
|
||||||
context.get = Mock(return_value=None)
|
|
||||||
context.state = {}
|
context.state = {}
|
||||||
return context
|
return context
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def credentials_config(self):
|
def credentials_config(self):
|
||||||
"""Create a basic credentials configuration for testing."""
|
"""Create a basic credentials configuration for testing."""
|
||||||
return BigQueryCredentials(
|
return BigQueryCredentialsConfig(
|
||||||
client_id="test_client_id",
|
client_id="test_client_id",
|
||||||
client_secret="test_client_secret",
|
client_secret="test_client_secret",
|
||||||
scopes=["https://www.googleapis.com/auth/calendar"],
|
scopes=["https://www.googleapis.com/auth/calendar"],
|
||||||
@ -77,7 +75,7 @@ class TestBigQueryCredentialsManager:
|
|||||||
# Create mock credentials that are already valid
|
# Create mock credentials that are already valid
|
||||||
mock_creds = Mock(spec=Credentials)
|
mock_creds = Mock(spec=Credentials)
|
||||||
mock_creds.valid = True
|
mock_creds.valid = True
|
||||||
manager.credentials.credentials = mock_creds
|
manager.credentials_config.credentials = mock_creds
|
||||||
|
|
||||||
result = await manager.get_valid_credentials(mock_tool_context)
|
result = await manager.get_valid_credentials(mock_tool_context)
|
||||||
|
|
||||||
@ -85,8 +83,6 @@ class TestBigQueryCredentialsManager:
|
|||||||
# Verify no OAuth flow was triggered
|
# Verify no OAuth flow was triggered
|
||||||
mock_tool_context.get_auth_response.assert_not_called()
|
mock_tool_context.get_auth_response.assert_not_called()
|
||||||
mock_tool_context.request_credential.assert_not_called()
|
mock_tool_context.request_credential.assert_not_called()
|
||||||
# Verify cache retrieval wasn't needed since we had valid creds
|
|
||||||
mock_tool_context.get.assert_not_called()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_credentials_from_cache_when_none_in_manager(
|
async def test_get_credentials_from_cache_when_none_in_manager(
|
||||||
@ -99,25 +95,37 @@ class TestBigQueryCredentialsManager:
|
|||||||
doesn't have them loaded.
|
doesn't have them loaded.
|
||||||
"""
|
"""
|
||||||
# Manager starts with no credentials
|
# Manager starts with no credentials
|
||||||
manager.credentials.credentials = None
|
manager.credentials_config.credentials = None
|
||||||
|
|
||||||
# Create mock cached credentials that are valid
|
# Create mock cached credentials JSON that would be stored in cache
|
||||||
mock_cached_creds = Mock(spec=Credentials)
|
mock_cached_creds_json = {
|
||||||
mock_cached_creds.valid = True
|
"token": "cached_token",
|
||||||
|
"refresh_token": "cached_refresh_token",
|
||||||
|
"client_id": "test_client_id",
|
||||||
|
"client_secret": "test_client_secret",
|
||||||
|
}
|
||||||
|
|
||||||
# Set up the tool context to return cached credentials
|
# Set up the tool context state to contain cached credentials
|
||||||
mock_tool_context.get.return_value = mock_cached_creds
|
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
|
||||||
|
|
||||||
result = await manager.get_valid_credentials(mock_tool_context)
|
# Mock the Credentials.from_authorized_user_info method
|
||||||
|
with patch(
|
||||||
|
"google.oauth2.credentials.Credentials.from_authorized_user_info"
|
||||||
|
) as mock_from_json:
|
||||||
|
mock_creds = Mock(spec=Credentials)
|
||||||
|
mock_creds.valid = True
|
||||||
|
mock_from_json.return_value = mock_creds
|
||||||
|
|
||||||
# Verify credentials were retrieved from cache
|
result = await manager.get_valid_credentials(mock_tool_context)
|
||||||
mock_tool_context.get.assert_called_once_with(
|
|
||||||
BIGQUERY_TOKEN_CACHE_KEY, None
|
# Verify credentials were created from cached JSON
|
||||||
)
|
mock_from_json.assert_called_once_with(
|
||||||
# Verify cached credentials were loaded into manager
|
mock_cached_creds_json, manager.credentials_config.scopes
|
||||||
assert manager.credentials.credentials == mock_cached_creds
|
)
|
||||||
# Verify valid cached credentials were returned
|
# Verify loaded credentials were not cached into manager
|
||||||
assert result == mock_cached_creds
|
assert manager.credentials_config.credentials is None
|
||||||
|
# Verify valid cached credentials were returned
|
||||||
|
assert result == mock_creds
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_no_credentials_in_manager_or_cache(
|
async def test_no_credentials_in_manager_or_cache(
|
||||||
@ -129,17 +137,13 @@ class TestBigQueryCredentialsManager:
|
|||||||
requiring a new OAuth flow to be initiated.
|
requiring a new OAuth flow to be initiated.
|
||||||
"""
|
"""
|
||||||
# Manager starts with no credentials
|
# Manager starts with no credentials
|
||||||
manager.credentials.credentials = None
|
manager.credentials_config.credentials = None
|
||||||
# Cache also returns None
|
# Cache is also empty (state dict doesn't contain the key)
|
||||||
mock_tool_context.get.return_value = None
|
|
||||||
|
|
||||||
result = await manager.get_valid_credentials(mock_tool_context)
|
result = await manager.get_valid_credentials(mock_tool_context)
|
||||||
|
|
||||||
# Should trigger OAuth flow and return None (flow in progress)
|
# Should trigger OAuth flow and return None (flow in progress)
|
||||||
assert result is None
|
assert result is None
|
||||||
mock_tool_context.get.assert_called_once_with(
|
|
||||||
BIGQUERY_TOKEN_CACHE_KEY, None
|
|
||||||
)
|
|
||||||
mock_tool_context.request_credential.assert_called_once()
|
mock_tool_context.request_credential.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -152,33 +156,62 @@ class TestBigQueryCredentialsManager:
|
|||||||
This tests the interaction between caching and refresh functionality,
|
This tests the interaction between caching and refresh functionality,
|
||||||
ensuring that expired cached credentials can be refreshed properly.
|
ensuring that expired cached credentials can be refreshed properly.
|
||||||
"""
|
"""
|
||||||
# Manager starts with no credentials
|
# Manager starts with no default credentials
|
||||||
manager.credentials.credentials = None
|
manager.credentials_config.credentials = None
|
||||||
|
|
||||||
|
# Create mock cached credentials JSON
|
||||||
|
mock_cached_creds_json = {
|
||||||
|
"token": "expired_token",
|
||||||
|
"refresh_token": "valid_refresh_token",
|
||||||
|
"client_id": "test_client_id",
|
||||||
|
"client_secret": "test_client_secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_refreshed_creds_json = {
|
||||||
|
"token": "new_token",
|
||||||
|
"refresh_token": "valid_refresh_token",
|
||||||
|
"client_id": "test_client_id",
|
||||||
|
"client_secret": "test_client_secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Set up the tool context state to contain cached credentials
|
||||||
|
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = mock_cached_creds_json
|
||||||
|
|
||||||
# Create expired cached credentials with refresh token
|
# Create expired cached credentials with refresh token
|
||||||
mock_cached_creds = Mock(spec=Credentials)
|
mock_cached_creds = Mock(spec=Credentials)
|
||||||
mock_cached_creds.valid = False
|
mock_cached_creds.valid = False
|
||||||
mock_cached_creds.expired = True
|
mock_cached_creds.expired = True
|
||||||
mock_cached_creds.refresh_token = "refresh_token"
|
mock_cached_creds.refresh_token = "valid_refresh_token"
|
||||||
|
mock_cached_creds.to_json.return_value = mock_refreshed_creds_json
|
||||||
|
|
||||||
# Mock successful refresh
|
# Mock successful refresh
|
||||||
def mock_refresh(request):
|
def mock_refresh(request):
|
||||||
mock_cached_creds.valid = True
|
mock_cached_creds.valid = True
|
||||||
|
|
||||||
mock_cached_creds.refresh = Mock(side_effect=mock_refresh)
|
mock_cached_creds.refresh = Mock(side_effect=mock_refresh)
|
||||||
mock_tool_context.get.return_value = mock_cached_creds
|
|
||||||
|
|
||||||
result = await manager.get_valid_credentials(mock_tool_context)
|
# Mock the Credentials.from_authorized_user_info method
|
||||||
|
with patch(
|
||||||
|
"google.oauth2.credentials.Credentials.from_authorized_user_info"
|
||||||
|
) as mock_from_json:
|
||||||
|
mock_from_json.return_value = mock_cached_creds
|
||||||
|
|
||||||
# Verify credentials were retrieved from cache
|
result = await manager.get_valid_credentials(mock_tool_context)
|
||||||
mock_tool_context.get.assert_called_once_with(
|
|
||||||
BIGQUERY_TOKEN_CACHE_KEY, None
|
# Verify credentials were created from cached JSON
|
||||||
)
|
mock_from_json.assert_called_once_with(
|
||||||
# Verify refresh was attempted and succeeded
|
mock_cached_creds_json, manager.credentials_config.scopes
|
||||||
mock_cached_creds.refresh.assert_called_once()
|
)
|
||||||
# Verify refreshed credentials were loaded into manager
|
# Verify refresh was attempted and succeeded
|
||||||
assert manager.credentials.credentials == mock_cached_creds
|
mock_cached_creds.refresh.assert_called_once()
|
||||||
assert result == mock_cached_creds
|
# Verify refreshed credentials were not cached into manager
|
||||||
|
assert manager.credentials_config.credentials is None
|
||||||
|
# Verify refreshed credentials were cached
|
||||||
|
assert (
|
||||||
|
"new_token"
|
||||||
|
== mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]["token"]
|
||||||
|
)
|
||||||
|
assert result == mock_cached_creds
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("google.auth.transport.requests.Request")
|
@patch("google.auth.transport.requests.Request")
|
||||||
@ -201,14 +234,14 @@ class TestBigQueryCredentialsManager:
|
|||||||
mock_creds.valid = True
|
mock_creds.valid = True
|
||||||
|
|
||||||
mock_creds.refresh = Mock(side_effect=mock_refresh)
|
mock_creds.refresh = Mock(side_effect=mock_refresh)
|
||||||
manager.credentials.credentials = mock_creds
|
manager.credentials_config.credentials = mock_creds
|
||||||
|
|
||||||
result = await manager.get_valid_credentials(mock_tool_context)
|
result = await manager.get_valid_credentials(mock_tool_context)
|
||||||
|
|
||||||
assert result == mock_creds
|
assert result == mock_creds
|
||||||
mock_creds.refresh.assert_called_once()
|
mock_creds.refresh.assert_called_once()
|
||||||
# Verify credentials were cached after successful refresh
|
# Verify credentials were cached after successful refresh
|
||||||
assert manager.credentials.credentials == mock_creds
|
assert manager.credentials_config.credentials == mock_creds
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("google.auth.transport.requests.Request")
|
@patch("google.auth.transport.requests.Request")
|
||||||
@ -226,7 +259,7 @@ class TestBigQueryCredentialsManager:
|
|||||||
mock_creds.expired = True
|
mock_creds.expired = True
|
||||||
mock_creds.refresh_token = "expired_refresh_token"
|
mock_creds.refresh_token = "expired_refresh_token"
|
||||||
mock_creds.refresh = Mock(side_effect=RefreshError("Refresh failed"))
|
mock_creds.refresh = Mock(side_effect=RefreshError("Refresh failed"))
|
||||||
manager.credentials.credentials = mock_creds
|
manager.credentials_config.credentials = mock_creds
|
||||||
|
|
||||||
result = await manager.get_valid_credentials(mock_tool_context)
|
result = await manager.get_valid_credentials(mock_tool_context)
|
||||||
|
|
||||||
@ -250,16 +283,39 @@ class TestBigQueryCredentialsManager:
|
|||||||
mock_auth_response.oauth2.refresh_token = "new_refresh_token"
|
mock_auth_response.oauth2.refresh_token = "new_refresh_token"
|
||||||
mock_tool_context.get_auth_response.return_value = mock_auth_response
|
mock_tool_context.get_auth_response.return_value = mock_auth_response
|
||||||
|
|
||||||
result = await manager.get_valid_credentials(mock_tool_context)
|
# Create a mock credentials instance that will represent our created credentials
|
||||||
|
mock_creds = Mock(spec=Credentials)
|
||||||
|
# Make the JSON match what a real Credentials object would produce
|
||||||
|
mock_creds_json = (
|
||||||
|
'{"token": "new_access_token", "refresh_token": "new_refresh_token",'
|
||||||
|
' "token_uri": "https://oauth2.googleapis.com/token", "client_id":'
|
||||||
|
' "test_client_id", "client_secret": "test_client_secret", "scopes":'
|
||||||
|
' ["https://www.googleapis.com/auth/calendar"], "universe_domain":'
|
||||||
|
' "googleapis.com", "account": ""}'
|
||||||
|
)
|
||||||
|
mock_creds.to_json.return_value = mock_creds_json
|
||||||
|
|
||||||
# Verify new credentials were created and cached
|
# Use the full module path as it appears in the project structure
|
||||||
assert isinstance(result, Credentials)
|
with patch(
|
||||||
assert result.token == "new_access_token"
|
"google.adk.tools.bigquery.bigquery_credentials.Credentials",
|
||||||
assert result.refresh_token == "new_refresh_token"
|
return_value=mock_creds,
|
||||||
# Verify credentials are cached in manager
|
) as mock_credentials_class:
|
||||||
assert manager.credentials.credentials == result
|
result = await manager.get_valid_credentials(mock_tool_context)
|
||||||
# Verify credentials are also cached in tool context state
|
|
||||||
assert mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] == result
|
# Verify new credentials were created
|
||||||
|
assert result == mock_creds
|
||||||
|
# Verify credentials are created with correct parameters
|
||||||
|
mock_credentials_class.assert_called_once()
|
||||||
|
call_kwargs = mock_credentials_class.call_args[1]
|
||||||
|
assert call_kwargs["token"] == "new_access_token"
|
||||||
|
assert call_kwargs["refresh_token"] == "new_refresh_token"
|
||||||
|
|
||||||
|
# Verify credentials are not cached in manager
|
||||||
|
assert manager.credentials_config.credentials is None
|
||||||
|
# Verify credentials are also cached in tool context state
|
||||||
|
assert (
|
||||||
|
mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] == mock_creds_json
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_oauth_flow_in_progress(self, manager, mock_tool_context):
|
async def test_oauth_flow_in_progress(self, manager, mock_tool_context):
|
||||||
@ -269,6 +325,7 @@ class TestBigQueryCredentialsManager:
|
|||||||
and the user hasn't completed authorization yet.
|
and the user hasn't completed authorization yet.
|
||||||
"""
|
"""
|
||||||
# No existing credentials, no auth response (flow not completed)
|
# No existing credentials, no auth response (flow not completed)
|
||||||
|
manager.credentials_config.credentials = None
|
||||||
mock_tool_context.get_auth_response.return_value = None
|
mock_tool_context.get_auth_response.return_value = None
|
||||||
|
|
||||||
result = await manager.get_valid_credentials(mock_tool_context)
|
result = await manager.get_valid_credentials(mock_tool_context)
|
||||||
@ -300,28 +357,69 @@ class TestBigQueryCredentialsManager:
|
|||||||
mock_auth_response.oauth2.refresh_token = "cached_refresh_token"
|
mock_auth_response.oauth2.refresh_token = "cached_refresh_token"
|
||||||
mock_tool_context.get_auth_response.return_value = mock_auth_response
|
mock_tool_context.get_auth_response.return_value = mock_auth_response
|
||||||
|
|
||||||
# Complete OAuth flow with first manager
|
# Create the mock credentials instance that will be returned by the constructor
|
||||||
result1 = await manager1.get_valid_credentials(mock_tool_context)
|
mock_creds = Mock(spec=Credentials)
|
||||||
|
# Make sure our mock JSON matches the structure that real Credentials objects produce
|
||||||
|
mock_creds_json = (
|
||||||
|
'{"token": "cached_access_token", "refresh_token":'
|
||||||
|
' "cached_refresh_token", "token_uri":'
|
||||||
|
' "https://oauth2.googleapis.com/token", "client_id": "test_client_id",'
|
||||||
|
' "client_secret": "test_client_secret", "scopes":'
|
||||||
|
' ["https://www.googleapis.com/auth/calendar"], "universe_domain":'
|
||||||
|
' "googleapis.com", "account": ""}'
|
||||||
|
)
|
||||||
|
mock_creds.to_json.return_value = mock_creds_json
|
||||||
|
mock_creds.valid = True
|
||||||
|
|
||||||
# Verify credentials were cached in tool context
|
# Use the correct module path - without the 'src.' prefix
|
||||||
assert BIGQUERY_TOKEN_CACHE_KEY in mock_tool_context.state
|
with patch(
|
||||||
cached_creds = mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]
|
"google.adk.tools.bigquery.bigquery_credentials.Credentials",
|
||||||
|
return_value=mock_creds,
|
||||||
|
) as mock_credentials_class:
|
||||||
|
# Complete OAuth flow with first manager
|
||||||
|
result1 = await manager1.get_valid_credentials(mock_tool_context)
|
||||||
|
|
||||||
|
# Verify credentials were cached in tool context
|
||||||
|
assert BIGQUERY_TOKEN_CACHE_KEY in mock_tool_context.state
|
||||||
|
cached_creds_json = mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]
|
||||||
|
assert cached_creds_json == mock_creds_json
|
||||||
|
|
||||||
# Create second manager instance (simulating new request/session)
|
# Create second manager instance (simulating new request/session)
|
||||||
manager2 = BigQueryCredentialsManager(credentials_config)
|
manager2 = BigQueryCredentialsManager(credentials_config)
|
||||||
|
credentials_config.credentials = None
|
||||||
|
|
||||||
# Reset auth response to None (no new OAuth flow available)
|
# Reset auth response to None (no new OAuth flow available)
|
||||||
mock_tool_context.get_auth_response.return_value = None
|
mock_tool_context.get_auth_response.return_value = None
|
||||||
# Set up get method to return cached credentials
|
|
||||||
mock_tool_context.get.return_value = cached_creds
|
|
||||||
|
|
||||||
# Get credentials with second manager
|
# Mock the from_authorized_user_info method for the second manager
|
||||||
result2 = await manager2.get_valid_credentials(mock_tool_context)
|
with patch(
|
||||||
|
"google.adk.tools.bigquery.bigquery_credentials.Credentials.from_authorized_user_info"
|
||||||
|
) as mock_from_json:
|
||||||
|
mock_cached_creds = Mock(spec=Credentials)
|
||||||
|
mock_cached_creds.valid = True
|
||||||
|
mock_from_json.return_value = mock_cached_creds
|
||||||
|
|
||||||
# Verify second manager retrieved cached credentials successfully
|
# Get credentials with second manager
|
||||||
assert result2 == cached_creds
|
result2 = await manager2.get_valid_credentials(mock_tool_context)
|
||||||
assert manager2.credentials.credentials == cached_creds
|
|
||||||
# Verify no new OAuth flow was requested
|
# Verify second manager retrieved cached credentials successfully
|
||||||
assert (
|
assert result2 == mock_cached_creds
|
||||||
mock_tool_context.request_credential.call_count == 0
|
assert manager2.credentials_config.credentials is None
|
||||||
) # Only from first manager
|
assert (
|
||||||
|
cached_creds_json == mock_tool_context.state[BIGQUERY_TOKEN_CACHE_KEY]
|
||||||
|
)
|
||||||
|
# The from_authorized_user_info should be called with the complete JSON structure
|
||||||
|
mock_from_json.assert_called_once()
|
||||||
|
# Extract the actual argument that was passed to verify it's the right JSON structure
|
||||||
|
actual_json_arg = mock_from_json.call_args[0][0]
|
||||||
|
# We need to parse and compare the structure rather than exact string match
|
||||||
|
# since the order of keys in JSON might differ
|
||||||
|
import json
|
||||||
|
|
||||||
|
expected_data = json.loads(mock_creds_json)
|
||||||
|
actual_data = (
|
||||||
|
actual_json_arg
|
||||||
|
if isinstance(actual_json_arg, dict)
|
||||||
|
else json.loads(actual_json_arg)
|
||||||
|
)
|
||||||
|
assert actual_data == expected_data
|
||||||
|
@ -17,7 +17,7 @@ from unittest.mock import Mock
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from google.adk.tools import ToolContext
|
from google.adk.tools import ToolContext
|
||||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentials
|
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
|
||||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
|
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager
|
||||||
from google.adk.tools.bigquery.bigquery_tool import BigQueryTool
|
from google.adk.tools.bigquery.bigquery_tool import BigQueryTool
|
||||||
# Mock the Google OAuth and API dependencies
|
# Mock the Google OAuth and API dependencies
|
||||||
@ -78,7 +78,7 @@ class TestBigQueryTool:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def credentials_config(self):
|
def credentials_config(self):
|
||||||
"""Create credentials configuration for testing."""
|
"""Create credentials configuration for testing."""
|
||||||
return BigQueryCredentials(
|
return BigQueryCredentialsConfig(
|
||||||
client_id="test_client_id",
|
client_id="test_client_id",
|
||||||
client_secret="test_client_secret",
|
client_secret="test_client_secret",
|
||||||
scopes=["https://www.googleapis.com/auth/bigquery"],
|
scopes=["https://www.googleapis.com/auth/bigquery"],
|
||||||
|
Loading…
Reference in New Issue
Block a user