From 53b14325cebc8a2757d237c9371e825638b0e457 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Tue, 20 May 2025 15:19:47 -0700 Subject: [PATCH] fix: Use sync request method in VertexAiSessionService. The api_client has it own event loop management. PiperOrigin-RevId: 761250268 --- .../adk/sessions/vertex_ai_session_service.py | 14 +++++++------- .../sessions/test_vertex_ai_session_service.py | 16 ---------------- 2 files changed, 7 insertions(+), 23 deletions(-) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index a6d9053..b3cbd93 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -68,7 +68,7 @@ class VertexAiSessionService(BaseSessionService): if state: session_json_dict['session_state'] = state - api_response = await self.api_client.async_request( + api_response = self.api_client.request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions', request_dict=session_json_dict, @@ -80,7 +80,7 @@ class VertexAiSessionService(BaseSessionService): max_retry_attempt = 5 while max_retry_attempt >= 0: - lro_response = await self.api_client.async_request( + lro_response = self.api_client.request( http_method='GET', path=f'operations/{operation_id}', request_dict={}, @@ -93,7 +93,7 @@ class VertexAiSessionService(BaseSessionService): max_retry_attempt -= 1 # Get session resource - get_session_api_response = await self.api_client.async_request( + get_session_api_response = self.api_client.request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, @@ -123,7 +123,7 @@ class VertexAiSessionService(BaseSessionService): reasoning_engine_id = _parse_reasoning_engine_id(app_name) # Get session resource - get_session_api_response = await self.api_client.async_request( + get_session_api_response = self.api_client.request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, @@ -141,7 +141,7 @@ class VertexAiSessionService(BaseSessionService): last_update_time=update_timestamp, ) - list_events_api_response = await self.api_client.async_request( + list_events_api_response = self.api_client.request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events', request_dict={}, @@ -206,7 +206,7 @@ class VertexAiSessionService(BaseSessionService): self, *, app_name: str, user_id: str, session_id: str ) -> None: reasoning_engine_id = _parse_reasoning_engine_id(app_name) - await self.api_client.async_request( + self.api_client.request( http_method='DELETE', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, @@ -218,7 +218,7 @@ class VertexAiSessionService(BaseSessionService): await super().append_event(session=session, event=event) reasoning_engine_id = _parse_reasoning_engine_id(session.app_name) - await self.api_client.async_request( + self.api_client.request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent', request_dict=_convert_event_to_json(event), diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 0722662..ba9f945 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -125,22 +125,6 @@ class MockApiClient: this.event_dict: dict[str, list[Any]] = {} def request(self, http_method: str, path: str, request_dict: dict[str, Any]): - """Mocks the API Client request method.""" - if http_method == 'GET': - if re.match(SESSIONS_REGEX, path): - match = re.match(SESSIONS_REGEX, path) - return { - 'sessions': [ - session - for session in self.session_dict.values() - if session['userId'] == match.group(2) - ], - } - raise ValueError(f'Unsupported sync path: {path}') - - async def async_request( - self, http_method: str, path: str, request_dict: dict[str, Any] - ): """Mocks the API Client request method.""" if http_method == 'GET': if re.match(SESSION_REGEX, path):