diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index 5d49cee..90fbbee 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -66,7 +66,8 @@ class OAuth2Auth(BaseModelWithConfig): redirect_uri: Optional[str] = None auth_response_uri: Optional[str] = None auth_code: Optional[str] = None - token: Optional[Dict[str, Any]] = None + access_token: Optional[str] = None + refresh_token: Optional[str] = None class ServiceAccountCredential(BaseModelWithConfig): diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index a218715..a0cabc2 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -82,7 +82,8 @@ class AuthHandler: or not auth_credential.oauth2 or not auth_credential.oauth2.client_id or not auth_credential.oauth2.client_secret - or auth_credential.oauth2.token + or auth_credential.oauth2.access_token + or auth_credential.oauth2.refresh_token ): return self.auth_config.exchanged_auth_credential @@ -93,7 +94,7 @@ class AuthHandler: redirect_uri=auth_credential.oauth2.redirect_uri, state=auth_credential.oauth2.state, ) - token = client.fetch_token( + tokens = client.fetch_token( token_endpoint, authorization_response=auth_credential.oauth2.auth_response_uri, code=auth_credential.oauth2.auth_code, @@ -102,7 +103,10 @@ class AuthHandler: updated_credential = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(token=dict(token)), + oauth2=OAuth2Auth( + access_token=tokens.get("access_token"), + refresh_token=tokens.get("refresh_token"), + ), ) return updated_credential diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py index 267d4a9..dafa4c2 100644 --- a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py +++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py @@ -69,7 +69,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger): HTTP bearer token cannot be generated, return the original credential. """ - if "access_token" not in auth_credential.oauth2.token: + if not auth_credential.oauth2.access_token: return auth_credential # Return the access token as a bearer token. @@ -78,7 +78,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger): http=HttpAuth( scheme="bearer", credentials=HttpCredentials( - token=auth_credential.oauth2.token["access_token"] + token=auth_credential.oauth2.access_token ), ), ) @@ -111,7 +111,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger): return auth_credential # If access token is exchanged, exchange a HTTPBearer token. - if auth_credential.oauth2.token: + if auth_credential.oauth2.access_token: return self.generate_auth_token(auth_credential) return None diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index 39ce3ee..6a86e8d 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -126,12 +126,8 @@ def oauth2_credentials_with_token(): client_id="mock_client_id", client_secret="mock_client_secret", redirect_uri="https://example.com/callback", - token={ - "access_token": "mock_access_token", - "token_type": "bearer", - "expires_in": 3600, - "refresh_token": "mock_refresh_token", - }, + access_token="mock_access_token", + refresh_token="mock_refresh_token", ), ) @@ -458,7 +454,7 @@ class TestParseAndStoreAuthResponse: """Test with an OAuth auth scheme.""" mock_exchange_token.return_value = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth(token={"access_token": "exchanged_token"}), + oauth2=OAuth2Auth(access_token="exchanged_token"), ) handler = AuthHandler(auth_config_with_exchanged) @@ -573,6 +569,6 @@ class TestExchangeAuthToken: handler = AuthHandler(auth_config_with_auth_code) result = handler.exchange_auth_token() - assert result.oauth2.token["access_token"] == "mock_access_token" - assert result.oauth2.token["refresh_token"] == "mock_refresh_token" + assert result.oauth2.access_token == "mock_access_token" + assert result.oauth2.refresh_token == "mock_refresh_token" assert result.auth_type == AuthCredentialTypes.OAUTH2 diff --git a/tests/unittests/flows/llm_flows/test_functions_request_euc.py b/tests/unittests/flows/llm_flows/test_functions_request_euc.py index 5c6b784..6dcb6f9 100644 --- a/tests/unittests/flows/llm_flows/test_functions_request_euc.py +++ b/tests/unittests/flows/llm_flows/test_functions_request_euc.py @@ -246,7 +246,7 @@ def test_function_get_auth_response(): oauth2=OAuth2Auth( client_id='oauth_client_id_1', client_secret='oauth_client_secret1', - token={'access_token': 'token1'}, + access_token='token1', ), ), ) @@ -277,7 +277,7 @@ def test_function_get_auth_response(): oauth2=OAuth2Auth( client_id='oauth_client_id_2', client_secret='oauth_client_secret2', - token={'access_token': 'token2'}, + access_token='token2', ), ), ) diff --git a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py index c028e0e..5b59fae 100644 --- a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py +++ b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py @@ -110,7 +110,7 @@ def test_generate_auth_token_success( client_secret="test_secret", redirect_uri="http://localhost:8080", auth_response_uri="https://example.com/callback?code=test_code", - token={"access_token": "test_access_token"}, + access_token="test_access_token", ), ) updated_credential = oauth2_exchanger.generate_auth_token(auth_credential) @@ -131,7 +131,7 @@ def test_exchange_credential_generate_auth_token( client_secret="test_secret", redirect_uri="http://localhost:8080", auth_response_uri="https://example.com/callback?code=test_code", - token={"access_token": "test_access_token"}, + access_token="test_access_token", ), )