From 9d4ca4ed44cf10bc87f577873faa49af469acc25 Mon Sep 17 00:00:00 2001 From: Danny Park Date: Wed, 28 May 2025 19:46:03 -0700 Subject: [PATCH] fix: Fix filtering by user_id for vertex ai session service listing When the user id contains special characters (i.e. an email), we have added in extra url parsing to address those characters. We have also added an if statement to use the correct url when there is no user_id supplied. Copybara import of the project: -- ef8499001afaea40bd037c4e9946b883e23a5854 by Danny Park : -- 773cd2b50d15b9b056b47b6155df492b0ca8034c by Danny Park : COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/996 from dpark27:fix/list_vertex_ai_sessions d351d7f6017c03165129adc7d0212f21d1340d88 PiperOrigin-RevId: 764522026 --- .../adk/sessions/vertex_ai_session_service.py | 9 ++++- .../test_vertex_ai_session_service.py | 34 ++++++++----------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index a147bbe..475377f 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -16,8 +16,10 @@ from __future__ import annotations import asyncio import logging import re +import time from typing import Any from typing import Optional +import urllib.parse from dateutil import parser from google.genai import types @@ -186,10 +188,15 @@ class VertexAiSessionService(BaseSessionService): ) -> ListSessionsResponse: reasoning_engine_id = _parse_reasoning_engine_id(app_name) + path = f"reasoningEngines/{reasoning_engine_id}/sessions" + if user_id: + parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe="") + path = path + f"?filter=user_id={parsed_user_id}" + api_client = _get_api_client(self.project, self.location) api_response = await api_client.async_request( http_method='GET', - path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}', + path=path, request_dict={}, ) diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 8144cfe..1794d7a 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -111,7 +111,7 @@ MOCK_SESSION = Session( SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$' -SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=([^/]+)$' +SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$' # %22 represents double-quotes in a URL-encoded string EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$' LRO_REGEX = r'^operations/([^/]+)$' @@ -127,7 +127,7 @@ class MockApiClient: async def async_request( self, http_method: str, path: str, request_dict: dict[str, Any] ): - """Mocks the API Client request method.""" + """Mocks the API Client request method""" if http_method == 'GET': if re.match(SESSION_REGEX, path): match = re.match(SESSION_REGEX, path) @@ -149,20 +149,14 @@ class MockApiClient: elif re.match(EVENTS_REGEX, path): match = re.match(EVENTS_REGEX, path) if match: - return { - 'sessionEvents': ( - self.event_dict[match.group(2)] - if match.group(2) in self.event_dict - else [] - ) - } + session_id = match.group(2) + return {'sessionEvents': self.event_dict.get(session_id, [])} elif re.match(LRO_REGEX, path): + # Mock long-running operation as completed return { - 'name': ( - 'projects/test-project/locations/test-location/' - 'reasoningEngines/123/sessions/4' - ), + 'name': path, 'done': True, + 'response': self.session_dict['4'] # Return the created session } else: raise ValueError(f'Unsupported path: {path}') @@ -225,10 +219,10 @@ def mock_get_api_client(): async def test_get_empty_session(): session_service = mock_vertex_ai_session_service() with pytest.raises(ValueError) as excinfo: - assert await session_service.get_session( + await session_service.get_session( app_name='123', user_id='user', session_id='0' ) - assert str(excinfo.value) == 'Session not found: 0' + assert str(excinfo.value) == 'Session not found: 0' @pytest.mark.asyncio @@ -247,10 +241,10 @@ async def test_get_and_delete_session(): app_name='123', user_id='user', session_id='1' ) with pytest.raises(ValueError) as excinfo: - assert await session_service.get_session( + await session_service.get_session( app_name='123', user_id='user', session_id='1' ) - assert str(excinfo.value) == 'Session not found: 1' + assert str(excinfo.value) == 'Session not found: 1' @pytest.mark.asyncio @@ -292,6 +286,6 @@ async def test_create_session_with_custom_session_id(): await session_service.create_session( app_name='123', user_id='user', session_id='1' ) - assert str(excinfo.value) == ( - 'User-provided Session id is not supported for VertexAISessionService.' - ) + assert str(excinfo.value) == ( + 'User-provided Session id is not supported for VertexAISessionService.' + )