feat(auth)!: expose access_token and refresh_token at top level of auth credentails

BREAKING CHANGE: `token` attribute of OAuth2Auth credentials used to be a dict containing both access_token and refresh_token, given that may cause confusions, now we replace it with access_token and refresh_token at top level of the auth credentials

PiperOrigin-RevId: 750346172
This commit is contained in:
Xiang (Sean) Zhou 2025-04-22 15:22:51 -07:00 committed by Copybara-Service
parent 49d8c0fbb2
commit 956fb912e8
6 changed files with 21 additions and 20 deletions

View File

@ -66,7 +66,8 @@ class OAuth2Auth(BaseModelWithConfig):
redirect_uri: Optional[str] = None redirect_uri: Optional[str] = None
auth_response_uri: Optional[str] = None auth_response_uri: Optional[str] = None
auth_code: Optional[str] = None auth_code: Optional[str] = None
token: Optional[Dict[str, Any]] = None access_token: Optional[str] = None
refresh_token: Optional[str] = None
class ServiceAccountCredential(BaseModelWithConfig): class ServiceAccountCredential(BaseModelWithConfig):

View File

@ -82,7 +82,8 @@ class AuthHandler:
or not auth_credential.oauth2 or not auth_credential.oauth2
or not auth_credential.oauth2.client_id or not auth_credential.oauth2.client_id
or not auth_credential.oauth2.client_secret or not auth_credential.oauth2.client_secret
or auth_credential.oauth2.token or auth_credential.oauth2.access_token
or auth_credential.oauth2.refresh_token
): ):
return self.auth_config.exchanged_auth_credential return self.auth_config.exchanged_auth_credential
@ -93,7 +94,7 @@ class AuthHandler:
redirect_uri=auth_credential.oauth2.redirect_uri, redirect_uri=auth_credential.oauth2.redirect_uri,
state=auth_credential.oauth2.state, state=auth_credential.oauth2.state,
) )
token = client.fetch_token( tokens = client.fetch_token(
token_endpoint, token_endpoint,
authorization_response=auth_credential.oauth2.auth_response_uri, authorization_response=auth_credential.oauth2.auth_response_uri,
code=auth_credential.oauth2.auth_code, code=auth_credential.oauth2.auth_code,
@ -102,7 +103,10 @@ class AuthHandler:
updated_credential = AuthCredential( updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2, auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(token=dict(token)), oauth2=OAuth2Auth(
access_token=tokens.get("access_token"),
refresh_token=tokens.get("refresh_token"),
),
) )
return updated_credential return updated_credential

View File

@ -69,7 +69,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
HTTP bearer token cannot be generated, return the original credential. HTTP bearer token cannot be generated, return the original credential.
""" """
if "access_token" not in auth_credential.oauth2.token: if not auth_credential.oauth2.access_token:
return auth_credential return auth_credential
# Return the access token as a bearer token. # Return the access token as a bearer token.
@ -78,7 +78,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
http=HttpAuth( http=HttpAuth(
scheme="bearer", scheme="bearer",
credentials=HttpCredentials( credentials=HttpCredentials(
token=auth_credential.oauth2.token["access_token"] token=auth_credential.oauth2.access_token
), ),
), ),
) )
@ -111,7 +111,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
return auth_credential return auth_credential
# If access token is exchanged, exchange a HTTPBearer token. # If access token is exchanged, exchange a HTTPBearer token.
if auth_credential.oauth2.token: if auth_credential.oauth2.access_token:
return self.generate_auth_token(auth_credential) return self.generate_auth_token(auth_credential)
return None return None

View File

@ -126,12 +126,8 @@ def oauth2_credentials_with_token():
client_id="mock_client_id", client_id="mock_client_id",
client_secret="mock_client_secret", client_secret="mock_client_secret",
redirect_uri="https://example.com/callback", redirect_uri="https://example.com/callback",
token={ access_token="mock_access_token",
"access_token": "mock_access_token", refresh_token="mock_refresh_token",
"token_type": "bearer",
"expires_in": 3600,
"refresh_token": "mock_refresh_token",
},
), ),
) )
@ -458,7 +454,7 @@ class TestParseAndStoreAuthResponse:
"""Test with an OAuth auth scheme.""" """Test with an OAuth auth scheme."""
mock_exchange_token.return_value = AuthCredential( mock_exchange_token.return_value = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2, auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(token={"access_token": "exchanged_token"}), oauth2=OAuth2Auth(access_token="exchanged_token"),
) )
handler = AuthHandler(auth_config_with_exchanged) handler = AuthHandler(auth_config_with_exchanged)
@ -573,6 +569,6 @@ class TestExchangeAuthToken:
handler = AuthHandler(auth_config_with_auth_code) handler = AuthHandler(auth_config_with_auth_code)
result = handler.exchange_auth_token() result = handler.exchange_auth_token()
assert result.oauth2.token["access_token"] == "mock_access_token" assert result.oauth2.access_token == "mock_access_token"
assert result.oauth2.token["refresh_token"] == "mock_refresh_token" assert result.oauth2.refresh_token == "mock_refresh_token"
assert result.auth_type == AuthCredentialTypes.OAUTH2 assert result.auth_type == AuthCredentialTypes.OAUTH2

View File

@ -246,7 +246,7 @@ def test_function_get_auth_response():
oauth2=OAuth2Auth( oauth2=OAuth2Auth(
client_id='oauth_client_id_1', client_id='oauth_client_id_1',
client_secret='oauth_client_secret1', client_secret='oauth_client_secret1',
token={'access_token': 'token1'}, access_token='token1',
), ),
), ),
) )
@ -277,7 +277,7 @@ def test_function_get_auth_response():
oauth2=OAuth2Auth( oauth2=OAuth2Auth(
client_id='oauth_client_id_2', client_id='oauth_client_id_2',
client_secret='oauth_client_secret2', client_secret='oauth_client_secret2',
token={'access_token': 'token2'}, access_token='token2',
), ),
), ),
) )

View File

@ -110,7 +110,7 @@ def test_generate_auth_token_success(
client_secret="test_secret", client_secret="test_secret",
redirect_uri="http://localhost:8080", redirect_uri="http://localhost:8080",
auth_response_uri="https://example.com/callback?code=test_code", auth_response_uri="https://example.com/callback?code=test_code",
token={"access_token": "test_access_token"}, access_token="test_access_token",
), ),
) )
updated_credential = oauth2_exchanger.generate_auth_token(auth_credential) updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
@ -131,7 +131,7 @@ def test_exchange_credential_generate_auth_token(
client_secret="test_secret", client_secret="test_secret",
redirect_uri="http://localhost:8080", redirect_uri="http://localhost:8080",
auth_response_uri="https://example.com/callback?code=test_code", auth_response_uri="https://example.com/callback?code=test_code",
token={"access_token": "test_access_token"}, access_token="test_access_token",
), ),
) )