diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 067e1aa..ccf7e2b 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -14,7 +14,8 @@ from __future__ import annotations -from typing import Any +import inspect +from typing import Any, Awaitable, Union from typing import AsyncGenerator from typing import Callable from typing import final @@ -37,10 +38,15 @@ if TYPE_CHECKING: tracer = trace.get_tracer('gcp.vertex.agent') -BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]] +BeforeAgentCallback = Callable[ + [CallbackContext], + Union[Awaitable[Optional[types.Content]], Optional[types.Content]], +] - -AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]] +AfterAgentCallback = Callable[ + [CallbackContext], + Union[Awaitable[Optional[types.Content]], Optional[types.Content]], +] class BaseAgent(BaseModel): @@ -119,7 +125,7 @@ class BaseAgent(BaseModel): with tracer.start_as_current_span(f'agent_run [{self.name}]'): ctx = self._create_invocation_context(parent_context) - if event := self.__handle_before_agent_callback(ctx): + if event := await self.__handle_before_agent_callback(ctx): yield event if ctx.end_invocation: return @@ -130,7 +136,7 @@ class BaseAgent(BaseModel): if ctx.end_invocation: return - if event := self.__handle_after_agent_callback(ctx): + if event := await self.__handle_after_agent_callback(ctx): yield event @final @@ -230,7 +236,7 @@ class BaseAgent(BaseModel): invocation_context.branch = f'{parent_context.branch}.{self.name}' return invocation_context - def __handle_before_agent_callback( + async def __handle_before_agent_callback( self, ctx: InvocationContext ) -> Optional[Event]: """Runs the before_agent_callback if it exists. @@ -248,6 +254,9 @@ class BaseAgent(BaseModel): callback_context=callback_context ) + if inspect.isawaitable(before_agent_callback_content): + before_agent_callback_content = await before_agent_callback_content + if before_agent_callback_content: ret_event = Event( invocation_id=ctx.invocation_id, @@ -269,7 +278,7 @@ class BaseAgent(BaseModel): return ret_event - def __handle_after_agent_callback( + async def __handle_after_agent_callback( self, invocation_context: InvocationContext ) -> Optional[Event]: """Runs the after_agent_callback if it exists. @@ -287,6 +296,9 @@ class BaseAgent(BaseModel): callback_context=callback_context ) + if inspect.isawaitable(after_agent_callback_content): + after_agent_callback_content = await after_agent_callback_content + if after_agent_callback_content or callback_context.state.has_delta(): ret_event = Event( invocation_id=invocation_context.invocation_id, diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 67e2d31..7bde529 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -49,11 +49,12 @@ logger = logging.getLogger(__name__) BeforeModelCallback: TypeAlias = Callable[ - [CallbackContext, LlmRequest], Optional[LlmResponse] + [CallbackContext, LlmRequest], + Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]], ] AfterModelCallback: TypeAlias = Callable[ [CallbackContext, LlmResponse], - Optional[LlmResponse], + Union[Awaitable[Optional[LlmResponse]], Optional[LlmResponse]], ] BeforeToolCallback: TypeAlias = Callable[ [BaseTool, dict[str, Any], ToolContext], diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 31904e3..d1105e3 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -16,6 +16,7 @@ from __future__ import annotations from abc import ABC import asyncio +import inspect import logging from typing import AsyncGenerator from typing import cast @@ -199,7 +200,7 @@ class BaseLlmFlow(ABC): return "user" else: return invocation_context.agent.name - + assert invocation_context.live_request_queue try: while True: @@ -447,7 +448,7 @@ class BaseLlmFlow(ABC): model_response_event: Event, ) -> AsyncGenerator[LlmResponse, None]: # Runs before_model_callback if it exists. - if response := self._handle_before_model_callback( + if response := await self._handle_before_model_callback( invocation_context, llm_request, model_response_event ): yield response @@ -460,7 +461,7 @@ class BaseLlmFlow(ABC): invocation_context.live_request_queue = LiveRequestQueue() async for llm_response in self.run_live(invocation_context): # Runs after_model_callback if it exists. - if altered_llm_response := self._handle_after_model_callback( + if altered_llm_response := await self._handle_after_model_callback( invocation_context, llm_response, model_response_event ): llm_response = altered_llm_response @@ -489,14 +490,14 @@ class BaseLlmFlow(ABC): llm_response, ) # Runs after_model_callback if it exists. - if altered_llm_response := self._handle_after_model_callback( + if altered_llm_response := await self._handle_after_model_callback( invocation_context, llm_response, model_response_event ): llm_response = altered_llm_response yield llm_response - def _handle_before_model_callback( + async def _handle_before_model_callback( self, invocation_context: InvocationContext, llm_request: LlmRequest, @@ -514,11 +515,16 @@ class BaseLlmFlow(ABC): callback_context = CallbackContext( invocation_context, event_actions=model_response_event.actions ) - return agent.before_model_callback( + before_model_callback_content = agent.before_model_callback( callback_context=callback_context, llm_request=llm_request ) - def _handle_after_model_callback( + if inspect.isawaitable(before_model_callback_content): + before_model_callback_content = await before_model_callback_content + + return before_model_callback_content + + async def _handle_after_model_callback( self, invocation_context: InvocationContext, llm_response: LlmResponse, @@ -536,10 +542,15 @@ class BaseLlmFlow(ABC): callback_context = CallbackContext( invocation_context, event_actions=model_response_event.actions ) - return agent.after_model_callback( + after_model_callback_content = agent.after_model_callback( callback_context=callback_context, llm_response=llm_response ) + if inspect.isawaitable(after_model_callback_content): + after_model_callback_content = await after_model_callback_content + + return after_model_callback_content + def _finalize_model_response_event( self, llm_request: LlmRequest, diff --git a/tests/__init__.py b/tests/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/integration/fixture/__init__.py b/tests/integration/fixture/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/integration/fixture/__init__.py +++ b/tests/integration/fixture/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/integration/fixture/callback_agent/__init__.py b/tests/integration/fixture/callback_agent/__init__.py index 44f7dab..c48963c 100644 --- a/tests/integration/fixture/callback_agent/__init__.py +++ b/tests/integration/fixture/callback_agent/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import agent \ No newline at end of file +from . import agent diff --git a/tests/integration/models/__init__.py b/tests/integration/models/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/integration/models/__init__.py +++ b/tests/integration/models/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/integration/test_evalute_agent_in_fixture.py b/tests/integration/test_evalute_agent_in_fixture.py index 234e71b..3899a3f 100644 --- a/tests/integration/test_evalute_agent_in_fixture.py +++ b/tests/integration/test_evalute_agent_in_fixture.py @@ -19,6 +19,7 @@ import os from google.adk.evaluation import AgentEvaluator import pytest + def agent_eval_artifacts_in_fixture(): """Get all agents from fixture folder.""" agent_eval_artifacts = [] diff --git a/tests/integration/tools/__init__.py b/tests/integration/tools/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/integration/tools/__init__.py +++ b/tests/integration/tools/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/__init__.py b/tests/unittests/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/__init__.py +++ b/tests/unittests/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/agents/__init__.py b/tests/unittests/agents/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/agents/__init__.py +++ b/tests/unittests/agents/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index d4e7387..9733586 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -33,16 +33,34 @@ def _before_agent_callback_noop(callback_context: CallbackContext) -> None: pass +async def _async_before_agent_callback_noop( + callback_context: CallbackContext, +) -> None: + pass + + def _before_agent_callback_bypass_agent( callback_context: CallbackContext, ) -> types.Content: return types.Content(parts=[types.Part(text='agent run is bypassed.')]) +async def _async_before_agent_callback_bypass_agent( + callback_context: CallbackContext, +) -> types.Content: + return types.Content(parts=[types.Part(text='agent run is bypassed.')]) + + def _after_agent_callback_noop(callback_context: CallbackContext) -> None: pass +async def _async_after_agent_callback_noop( + callback_context: CallbackContext, +) -> None: + pass + + def _after_agent_callback_append_agent_reply( callback_context: CallbackContext, ) -> types.Content: @@ -51,6 +69,14 @@ def _after_agent_callback_append_agent_reply( ) +async def _async_after_agent_callback_append_agent_reply( + callback_context: CallbackContext, +) -> types.Content: + return types.Content( + parts=[types.Part(text='Agent reply from after agent callback.')] + ) + + class _IncompleteAgent(BaseAgent): pass @@ -158,6 +184,34 @@ async def test_run_async_before_agent_callback_noop( spy_run_async_impl.assert_called_once() +@pytest.mark.asyncio +async def test_run_async_with_async_before_agent_callback_noop( + request: pytest.FixtureRequest, + mocker: pytest_mock.MockerFixture, +) -> Union[types.Content, None]: + # Arrange + agent = _TestingAgent( + name=f'{request.function.__name__}_test_agent', + before_agent_callback=_async_before_agent_callback_noop, + ) + parent_ctx = _create_parent_invocation_context( + request.function.__name__, agent + ) + spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) + spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback') + + # Act + _ = [e async for e in agent.run_async(parent_ctx)] + + # Assert + spy_before_agent_callback.assert_called_once() + _, kwargs = spy_before_agent_callback.call_args + assert 'callback_context' in kwargs + assert isinstance(kwargs['callback_context'], CallbackContext) + + spy_run_async_impl.assert_called_once() + + @pytest.mark.asyncio async def test_run_async_before_agent_callback_bypass_agent( request: pytest.FixtureRequest, @@ -185,6 +239,33 @@ async def test_run_async_before_agent_callback_bypass_agent( assert events[0].content.parts[0].text == 'agent run is bypassed.' +@pytest.mark.asyncio +async def test_run_async_with_async_before_agent_callback_bypass_agent( + request: pytest.FixtureRequest, + mocker: pytest_mock.MockerFixture, +): + # Arrange + agent = _TestingAgent( + name=f'{request.function.__name__}_test_agent', + before_agent_callback=_async_before_agent_callback_bypass_agent, + ) + parent_ctx = _create_parent_invocation_context( + request.function.__name__, agent + ) + spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__) + spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback') + + # Act + events = [e async for e in agent.run_async(parent_ctx)] + + # Assert + spy_before_agent_callback.assert_called_once() + spy_run_async_impl.assert_not_called() + + assert len(events) == 1 + assert events[0].content.parts[0].text == 'agent run is bypassed.' + + @pytest.mark.asyncio async def test_run_async_after_agent_callback_noop( request: pytest.FixtureRequest, @@ -211,6 +292,32 @@ async def test_run_async_after_agent_callback_noop( assert len(events) == 1 +@pytest.mark.asyncio +async def test_run_async_with_async_after_agent_callback_noop( + request: pytest.FixtureRequest, + mocker: pytest_mock.MockerFixture, +): + # Arrange + agent = _TestingAgent( + name=f'{request.function.__name__}_test_agent', + after_agent_callback=_async_after_agent_callback_noop, + ) + parent_ctx = _create_parent_invocation_context( + request.function.__name__, agent + ) + spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback') + + # Act + events = [e async for e in agent.run_async(parent_ctx)] + + # Assert + spy_after_agent_callback.assert_called_once() + _, kwargs = spy_after_agent_callback.call_args + assert 'callback_context' in kwargs + assert isinstance(kwargs['callback_context'], CallbackContext) + assert len(events) == 1 + + @pytest.mark.asyncio async def test_run_async_after_agent_callback_append_reply( request: pytest.FixtureRequest, @@ -236,6 +343,31 @@ async def test_run_async_after_agent_callback_append_reply( ) +@pytest.mark.asyncio +async def test_run_async_with_async_after_agent_callback_append_reply( + request: pytest.FixtureRequest, +): + # Arrange + agent = _TestingAgent( + name=f'{request.function.__name__}_test_agent', + after_agent_callback=_async_after_agent_callback_append_agent_reply, + ) + parent_ctx = _create_parent_invocation_context( + request.function.__name__, agent + ) + + # Act + events = [e async for e in agent.run_async(parent_ctx)] + + # Assert + assert len(events) == 2 + assert events[1].author == agent.name + assert ( + events[1].content.parts[0].text + == 'Agent reply from after agent callback.' + ) + + @pytest.mark.asyncio async def test_run_async_incomplete_agent(request: pytest.FixtureRequest): agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent') diff --git a/tests/unittests/agents/test_llm_agent_callbacks.py b/tests/unittests/agents/test_llm_agent_callbacks.py index 377e1cf..99a606e 100644 --- a/tests/unittests/agents/test_llm_agent_callbacks.py +++ b/tests/unittests/agents/test_llm_agent_callbacks.py @@ -56,10 +56,44 @@ class MockAfterModelCallback(BaseModel): ) +class MockAsyncBeforeModelCallback(BaseModel): + mock_response: str + + async def __call__( + self, + callback_context: CallbackContext, + llm_request: LlmRequest, + ) -> LlmResponse: + return LlmResponse( + content=utils.ModelContent( + [types.Part.from_text(text=self.mock_response)] + ) + ) + + +class MockAsyncAfterModelCallback(BaseModel): + mock_response: str + + async def __call__( + self, + callback_context: CallbackContext, + llm_response: LlmResponse, + ) -> LlmResponse: + return LlmResponse( + content=utils.ModelContent( + [types.Part.from_text(text=self.mock_response)] + ) + ) + + def noop_callback(**kwargs) -> Optional[LlmResponse]: pass +async def async_noop_callback(**kwargs) -> Optional[LlmResponse]: + pass + + @pytest.mark.asyncio async def test_before_model_callback(): responses = ['model_response'] @@ -98,26 +132,6 @@ async def test_before_model_callback_noop(): ] -@pytest.mark.asyncio -async def test_before_model_callback_end(): - responses = ['model_response'] - mock_model = utils.MockModel.create(responses=responses) - agent = Agent( - name='root_agent', - model=mock_model, - before_model_callback=MockBeforeModelCallback( - mock_response='before_model_callback', - ), - ) - - runner = utils.TestInMemoryRunner(agent) - assert utils.simplify_events( - await runner.run_async_with_new_session('test') - ) == [ - ('root_agent', 'before_model_callback'), - ] - - @pytest.mark.asyncio async def test_after_model_callback(): responses = ['model_response'] @@ -136,3 +150,61 @@ async def test_after_model_callback(): ) == [ ('root_agent', 'after_model_callback'), ] + + +@pytest.mark.asyncio +async def test_async_before_model_callback(): + responses = ['model_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_model_callback=MockAsyncBeforeModelCallback( + mock_response='async_before_model_callback' + ), + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'async_before_model_callback'), + ] + + +@pytest.mark.asyncio +async def test_async_before_model_callback_noop(): + responses = ['model_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_model_callback=async_noop_callback, + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'model_response'), + ] + + +@pytest.mark.asyncio +async def test_async_after_model_callback(): + responses = ['model_response'] + mock_model = utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + after_model_callback=MockAsyncAfterModelCallback( + mock_response='async_after_model_callback' + ), + ) + + runner = utils.TestInMemoryRunner(agent) + assert utils.simplify_events( + await runner.run_async_with_new_session('test') + ) == [ + ('root_agent', 'async_after_model_callback'), + ] diff --git a/tests/unittests/artifacts/__init__.py b/tests/unittests/artifacts/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/artifacts/__init__.py +++ b/tests/unittests/artifacts/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/fast_api/__init__.py b/tests/unittests/fast_api/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/fast_api/__init__.py +++ b/tests/unittests/fast_api/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/flows/__init__.py b/tests/unittests/flows/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/flows/__init__.py +++ b/tests/unittests/flows/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/flows/llm_flows/__init__.py b/tests/unittests/flows/llm_flows/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/flows/llm_flows/__init__.py +++ b/tests/unittests/flows/llm_flows/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/models/__init__.py b/tests/unittests/models/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/models/__init__.py +++ b/tests/unittests/models/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/sessions/__init__.py b/tests/unittests/sessions/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/sessions/__init__.py +++ b/tests/unittests/sessions/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/streaming/__init__.py b/tests/unittests/streaming/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/streaming/__init__.py +++ b/tests/unittests/streaming/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/tools/__init__.py b/tests/unittests/tools/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/tools/__init__.py +++ b/tests/unittests/tools/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py b/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py index 975073f..228c0aa 100644 --- a/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py +++ b/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py @@ -74,9 +74,12 @@ class TestConnectionsClient: mock_response.raise_for_status.return_value = None mock_response.json.return_value = {"data": "test"} - with mock.patch.object( - client, "_get_access_token", return_value=mock_credentials.token - ), mock.patch("requests.get", return_value=mock_response): + with ( + mock.patch.object( + client, "_get_access_token", return_value=mock_credentials.token + ), + mock.patch("requests.get", return_value=mock_response), + ): response = client._execute_api_call("https://test.url") assert response.json() == {"data": "test"} requests.get.assert_called_once_with( @@ -121,9 +124,12 @@ class TestConnectionsClient: f"HTTP error {status_code}: {response_text}" ) - with mock.patch.object( - client, "_get_access_token", return_value=mock_credentials.token - ), mock.patch("requests.get", return_value=mock_response): + with ( + mock.patch.object( + client, "_get_access_token", return_value=mock_credentials.token + ), + mock.patch("requests.get", return_value=mock_response), + ): with pytest.raises( ValueError, match="Invalid request. Please check the provided" ): @@ -140,9 +146,12 @@ class TestConnectionsClient: "Internal Server Error" ) - with mock.patch.object( - client, "_get_access_token", return_value=mock_credentials.token - ), mock.patch("requests.get", return_value=mock_response): + with ( + mock.patch.object( + client, "_get_access_token", return_value=mock_credentials.token + ), + mock.patch("requests.get", return_value=mock_response), + ): with pytest.raises(ValueError, match="Request error: "): client._execute_api_call("https://test.url") @@ -151,10 +160,13 @@ class TestConnectionsClient: ): credentials = {"email": "test@example.com"} client = ConnectionsClient(project, location, connection_name, credentials) - with mock.patch.object( - client, "_get_access_token", return_value=mock_credentials.token - ), mock.patch( - "requests.get", side_effect=Exception("Something went wrong") + with ( + mock.patch.object( + client, "_get_access_token", return_value=mock_credentials.token + ), + mock.patch( + "requests.get", side_effect=Exception("Something went wrong") + ), ): with pytest.raises( Exception, match="An unexpected error occurred: Something went wrong" @@ -539,10 +551,13 @@ class TestConnectionsClient: mock_creds.token = "sa_token" mock_creds.expired = False - with mock.patch( - "google.oauth2.service_account.Credentials.from_service_account_info", - return_value=mock_creds, - ), mock.patch.object(mock_creds, "refresh", return_value=None): + with ( + mock.patch( + "google.oauth2.service_account.Credentials.from_service_account_info", + return_value=mock_creds, + ), + mock.patch.object(mock_creds, "refresh", return_value=None), + ): token = client._get_access_token() assert token == "sa_token" google.oauth2.service_account.Credentials.from_service_account_info.assert_called_once_with( @@ -555,10 +570,13 @@ class TestConnectionsClient: self, project, location, connection_name, mock_credentials ): client = ConnectionsClient(project, location, connection_name, None) - with mock.patch( - "google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential", - return_value=(mock_credentials, "test_project_id"), - ), mock.patch.object(mock_credentials, "refresh", return_value=None): + with ( + mock.patch( + "google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential", + return_value=(mock_credentials, "test_project_id"), + ), + mock.patch.object(mock_credentials, "refresh", return_value=None), + ): token = client._get_access_token() assert token == "test_token" diff --git a/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py b/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py index 469fa62..e58377e 100644 --- a/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py +++ b/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py @@ -114,11 +114,14 @@ class TestIntegrationClient: mock_response.status_code = 200 mock_response.json.return_value = {"openApiSpec": json.dumps(expected_spec)} - with mock.patch.object( - IntegrationClient, - "_get_access_token", - return_value=mock_credentials.token, - ), mock.patch("requests.post", return_value=mock_response): + with ( + mock.patch.object( + IntegrationClient, + "_get_access_token", + return_value=mock_credentials.token, + ), + mock.patch("requests.post", return_value=mock_response), + ): client = IntegrationClient( project=project, location=location, @@ -202,11 +205,14 @@ class TestIntegrationClient: f"HTTP error {status_code}: {response_text}" ) - with mock.patch.object( - IntegrationClient, - "_get_access_token", - return_value=mock_credentials.token, - ), mock.patch("requests.post", return_value=mock_response): + with ( + mock.patch.object( + IntegrationClient, + "_get_access_token", + return_value=mock_credentials.token, + ), + mock.patch("requests.post", return_value=mock_response), + ): client = IntegrationClient( project=project, location=location, @@ -243,11 +249,14 @@ class TestIntegrationClient: "Internal Server Error" ) - with mock.patch.object( - IntegrationClient, - "_get_access_token", - return_value=mock_credentials.token, - ), mock.patch("requests.post", return_value=mock_response): + with ( + mock.patch.object( + IntegrationClient, + "_get_access_token", + return_value=mock_credentials.token, + ), + mock.patch("requests.post", return_value=mock_response), + ): client = IntegrationClient( project=project, location=location, @@ -270,12 +279,15 @@ class TestIntegrationClient: mock_credentials, mock_connections_client, ): - with mock.patch.object( - IntegrationClient, - "_get_access_token", - return_value=mock_credentials.token, - ), mock.patch( - "requests.post", side_effect=Exception("Something went wrong") + with ( + mock.patch.object( + IntegrationClient, + "_get_access_token", + return_value=mock_credentials.token, + ), + mock.patch( + "requests.post", side_effect=Exception("Something went wrong") + ), ): client = IntegrationClient( project=project, @@ -486,10 +498,13 @@ class TestIntegrationClient: mock_creds.token = "sa_token" mock_creds.expired = False - with mock.patch( - "google.oauth2.service_account.Credentials.from_service_account_info", - return_value=mock_creds, - ), mock.patch.object(mock_creds, "refresh", return_value=None): + with ( + mock.patch( + "google.oauth2.service_account.Credentials.from_service_account_info", + return_value=mock_creds, + ), + mock.patch.object(mock_creds, "refresh", return_value=None), + ): client = IntegrationClient( project=project, location=location, @@ -518,10 +533,13 @@ class TestIntegrationClient: mock_credentials, ): mock_credentials.expired = False - with mock.patch( - "google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential", - return_value=(mock_credentials, "test_project_id"), - ), mock.patch.object(mock_credentials, "refresh", return_value=None): + with ( + mock.patch( + "google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential", + return_value=(mock_credentials, "test_project_id"), + ), + mock.patch.object(mock_credentials, "refresh", return_value=None), + ): client = IntegrationClient( project=project, location=location, @@ -538,12 +556,15 @@ class TestIntegrationClient: def test_get_access_token_no_valid_credentials( self, project, location, integration_name, trigger_name, connection_name ): - with mock.patch( - "google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential", - return_value=(None, None), - ), mock.patch( - "google.oauth2.service_account.Credentials.from_service_account_info", - return_value=None, + with ( + mock.patch( + "google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential", + return_value=(None, None), + ), + mock.patch( + "google.oauth2.service_account.Credentials.from_service_account_info", + return_value=None, + ), ): client = IntegrationClient( project=project, @@ -587,9 +608,12 @@ class TestIntegrationClient: service_account_json=None, ) client.credential_cache = mock_credentials # Simulate a cached credential - with mock.patch("google.auth.default") as mock_default, mock.patch( - "google.oauth2.service_account.Credentials.from_service_account_info" - ) as mock_sa: + with ( + mock.patch("google.auth.default") as mock_default, + mock.patch( + "google.oauth2.service_account.Credentials.from_service_account_info" + ) as mock_sa, + ): token = client._get_access_token() assert token == "cached_token" mock_default.assert_not_called() diff --git a/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py b/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py index b960dd6..28dbb9d 100644 --- a/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py +++ b/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py @@ -236,8 +236,7 @@ def test_initialization_with_connection_and_actions( ) mock_connections_client.return_value.get_connection_details.assert_called_once() mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( - tool_name, - tool_instructions + tool_name, tool_instructions ) mock_openapi_action_spec_parser.return_value.parse.assert_called_once() assert len(toolset.get_tools()) == 1 @@ -390,6 +389,5 @@ def test_initialization_with_connection_details( tool_instructions=tool_instructions, ) mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with( - tool_name, - tool_instructions + tool_name, tool_instructions ) diff --git a/tests/unittests/tools/retrieval/__init__.py b/tests/unittests/tools/retrieval/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/tools/retrieval/__init__.py +++ b/tests/unittests/tools/retrieval/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -