fix: fix bigquery credentials and bigquery tool to make it compatible with python 3.9 and make the credential serializable in session

PiperOrigin-RevId: 763332829
This commit is contained in:
Xiang (Sean) Zhou
2025-05-26 01:57:40 -07:00
committed by Copybara-Service
parent 55cb36edfe
commit 694eca08e5
5 changed files with 233 additions and 104 deletions

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import List
from typing import Optional
@@ -33,15 +35,31 @@ from ..tool_context import ToolContext
BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache"
class BigQueryCredentials(BaseModel):
class BigQueryCredentialsConfig(BaseModel):
"""Configuration for Google API tools. (Experimental)"""
# Configure the model to allow arbitrary types like Credentials
model_config = {"arbitrary_types_allowed": True}
credentials: Optional[Credentials] = None
"""the existing oauth credentials to use. If set will override client ID,
client secret, and scopes."""
"""the existing oauth credentials to use. If set,this credential will be used
for every end user, end users don't need to be involved in the oauthflow. This
field is mutually exclusive with client_id, client_secret and scopes.
Don't set this field unless you are sure this credential has the permission to
access every end user's data.
Example usage: when the agent is deployed in Google Cloud environment and
the service account (used as application default credentials) has access to
all the required BigQuery resource. Setting this credential to allow user to
access the BigQuery resource without end users going through oauth flow.
To get application default credential: `google.auth.default(...)`. See more
details in https://cloud.google.com/docs/authentication/application-default-credentials.
When the deployed environment cannot provide a pre-existing credential,
consider setting below client_id, client_secret and scope for end users to go
through oauth flow, so that agent can access the user data.
"""
client_id: Optional[str] = None
"""the oauth client ID to use."""
client_secret: Optional[str] = None
@@ -51,12 +69,20 @@ class BigQueryCredentials(BaseModel):
"""
@model_validator(mode="after")
def __post_init__(self) -> "BigQueryCredentials":
def __post_init__(self) -> BigQueryCredentialsConfig:
"""Validate that either credentials or client ID/secret are provided."""
if not self.credentials and (not self.client_id or not self.client_secret):
raise ValueError(
"Must provide either credentials or client_id abd client_secret pair."
)
if self.credentials and (
self.client_id or self.client_secret or self.scopes
):
raise ValueError(
"Cannot provide both existing credentials and"
" client_id/client_secret/scopes."
)
if self.credentials:
self.client_id = self.credentials.client_id
self.client_secret = self.credentials.client_secret
@@ -71,14 +97,14 @@ class BigQueryCredentialsManager:
the same authenticated session without duplicating OAuth logic.
"""
def __init__(self, credentials: BigQueryCredentials):
def __init__(self, credentials_config: BigQueryCredentialsConfig):
"""Initialize the credential manager.
Args:
credential_config: Configuration containing OAuth details or existing
credentials
credentials_config: Credentials containing client id and client secrete
or default credentials
"""
self.credentials = credentials
self.credentials_config = credentials_config
async def get_valid_credentials(
self, tool_context: ToolContext
@@ -87,18 +113,23 @@ class BigQueryCredentialsManager:
Args:
tool_context: The tool context for OAuth flow and state management
required_scopes: Set of OAuth scopes required by the calling tool
Returns:
Valid Credentials object, or None if OAuth flow is needed
"""
# First, try to get cached credentials from the instance
creds = self.credentials.credentials
# First, try to get credentials from the tool context
creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None)
creds = (
Credentials.from_authorized_user_info(
creds_json, self.credentials_config.scopes
)
if creds_json
else None
)
# If credentails are empty
# If credentails are empty use the default credential
if not creds:
creds = tool_context.get(BIGQUERY_TOKEN_CACHE_KEY, None)
self.credentials.credentials = creds
creds = self.credentials_config.credentials
# Check if we have valid credentials
if creds and creds.valid:
@@ -110,7 +141,7 @@ class BigQueryCredentialsManager:
creds.refresh(Request())
if creds.valid:
# Cache the refreshed credentials
self.credentials.credentials = creds
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json()
return creds
except RefreshError:
# Refresh failed, need to re-authenticate
@@ -140,7 +171,7 @@ class BigQueryCredentialsManager:
tokenUrl="https://oauth2.googleapis.com/token",
scopes={
scope: f"Access to {scope}"
for scope in self.credentials.scopes
for scope in self.credentials_config.scopes
},
)
)
@@ -149,8 +180,8 @@ class BigQueryCredentialsManager:
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id=self.credentials.client_id,
client_secret=self.credentials.client_secret,
client_id=self.credentials_config.client_id,
client_secret=self.credentials_config.client_secret,
),
)
@@ -165,14 +196,14 @@ class BigQueryCredentialsManager:
token=auth_response.oauth2.access_token,
refresh_token=auth_response.oauth2.refresh_token,
token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
client_id=self.credentials.client_id,
client_secret=self.credentials.client_secret,
scopes=list(self.credentials.scopes),
client_id=self.credentials_config.client_id,
client_secret=self.credentials_config.client_secret,
scopes=list(self.credentials_config.scopes),
)
# Cache the new credentials
self.credentials.credentials = creds
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json()
return creds
else:
# Request OAuth flow

View File

@@ -17,13 +17,13 @@ import inspect
from typing import Any
from typing import Callable
from typing import Optional
from typing import override
from google.oauth2.credentials import Credentials
from typing_extensions import override
from ..function_tool import FunctionTool
from ..tool_context import ToolContext
from .bigquery_credentials import BigQueryCredentials
from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_credentials import BigQueryCredentialsManager
@@ -41,7 +41,7 @@ class BigQueryTool(FunctionTool):
def __init__(
self,
func: Callable[..., Any],
credentials: Optional[BigQueryCredentials] = None,
credentials: Optional[BigQueryCredentialsConfig] = None,
):
"""Initialize the Google API tool.