diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index f4d9f3b..1352728 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging import re -import time from typing import Any from typing import Optional @@ -69,7 +69,8 @@ class VertexAiSessionService(BaseSessionService): if state: session_json_dict['session_state'] = state - api_response = self.api_client.request( + api_client = _get_api_client(self.project, self.location) + api_response = await api_client.async_request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions', request_dict=session_json_dict, @@ -81,7 +82,7 @@ class VertexAiSessionService(BaseSessionService): max_retry_attempt = 5 while max_retry_attempt >= 0: - lro_response = self.api_client.request( + lro_response = await api_client.async_request( http_method='GET', path=f'operations/{operation_id}', request_dict={}, @@ -90,11 +91,11 @@ class VertexAiSessionService(BaseSessionService): if lro_response.get('done', None): break - time.sleep(1) + await asyncio.sleep(1) max_retry_attempt -= 1 # Get session resource - get_session_api_response = self.api_client.request( + get_session_api_response = await api_client.async_request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, @@ -124,7 +125,8 @@ class VertexAiSessionService(BaseSessionService): reasoning_engine_id = _parse_reasoning_engine_id(app_name) # Get session resource - get_session_api_response = self.api_client.request( + api_client = _get_api_client(self.project, self.location) + get_session_api_response = await api_client.async_request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, @@ -142,7 +144,7 @@ class VertexAiSessionService(BaseSessionService): last_update_time=update_timestamp, ) - list_events_api_response = self.api_client.request( + list_events_api_response = await api_client.async_request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events', request_dict={}, @@ -181,7 +183,8 @@ class VertexAiSessionService(BaseSessionService): ) -> ListSessionsResponse: reasoning_engine_id = _parse_reasoning_engine_id(app_name) - api_response = self.api_client.request( + api_client = _get_api_client(self.project, self.location) + api_response = await api_client.async_request( http_method='GET', path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}', request_dict={}, @@ -207,7 +210,8 @@ class VertexAiSessionService(BaseSessionService): self, *, app_name: str, user_id: str, session_id: str ) -> None: reasoning_engine_id = _parse_reasoning_engine_id(app_name) - self.api_client.request( + api_client = _get_api_client(self.project, self.location) + await api_client.async_request( http_method='DELETE', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', request_dict={}, @@ -219,15 +223,25 @@ class VertexAiSessionService(BaseSessionService): await super().append_event(session=session, event=event) reasoning_engine_id = _parse_reasoning_engine_id(session.app_name) - self.api_client.request( + api_client = _get_api_client(self.project, self.location) + await api_client.async_request( http_method='POST', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent', request_dict=_convert_event_to_json(event), ) - return event +def _get_api_client(project: str, location: str): + """Instantiates an API client for the given project and location. + + It needs to be instantiated inside each request so that the event loop + management. + """ + client = genai.Client(vertexai=True, project=project, location=location) + return client._api_client + + def _convert_event_to_json(event: Event): metadata_json = { 'partial': event.partial, diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 71145ce..83351bc 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -15,6 +15,7 @@ import re import this from typing import Any +from unittest import mock from dateutil.parser import isoparse from google.adk.events import Event @@ -123,7 +124,9 @@ class MockApiClient: this.session_dict: dict[str, Any] = {} this.event_dict: dict[str, list[Any]] = {} - def request(self, http_method: str, path: str, request_dict: dict[str, Any]): + async def async_request( + self, http_method: str, path: str, request_dict: dict[str, Any] + ): """Mocks the API Client request method.""" if http_method == 'GET': if re.match(SESSION_REGEX, path): @@ -194,22 +197,31 @@ class MockApiClient: def mock_vertex_ai_session_service(): """Creates a mock Vertex AI Session service for testing.""" - service = VertexAiSessionService( + return VertexAiSessionService( project='test-project', location='test-location' ) - service.api_client = MockApiClient() - service.api_client.session_dict = { + + +@pytest.fixture +def mock_get_api_client(): + api_client = MockApiClient() + api_client.session_dict = { '1': MOCK_SESSION_JSON_1, '2': MOCK_SESSION_JSON_2, '3': MOCK_SESSION_JSON_3, } - service.api_client.event_dict = { + api_client.event_dict = { '1': MOCK_EVENT_JSON, } - return service + with mock.patch( + "google.adk.sessions.vertex_ai_session_service._get_api_client", + return_value=api_client, + ): + yield @pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') async def test_get_empty_session(): session_service = mock_vertex_ai_session_service() with pytest.raises(ValueError) as excinfo: @@ -220,6 +232,7 @@ async def test_get_empty_session(): @pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') async def test_get_and_delete_session(): session_service = mock_vertex_ai_session_service() @@ -241,6 +254,7 @@ async def test_get_and_delete_session(): @pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') async def test_list_sessions(): session_service = mock_vertex_ai_session_service() sessions = await session_service.list_sessions(app_name='123', user_id='user') @@ -250,6 +264,7 @@ async def test_list_sessions(): @pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') async def test_create_session(): session_service = mock_vertex_ai_session_service() @@ -269,6 +284,7 @@ async def test_create_session(): @pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') async def test_create_session_with_custom_session_id(): session_service = mock_vertex_ai_session_service()