fix: Fix filtering by user_id for vertex ai session service listing

When the user id contains special characters (i.e. an email), we have added in extra url parsing to address those characters. We have also added an if statement to use the correct url when there is no user_id supplied.

Copybara import of the project:

--
ef8499001afaea40bd037c4e9946b883e23a5854 by Danny Park <dpark@calicolabs.com>:
--
773cd2b50d15b9b056b47b6155df492b0ca8034c by Danny Park <dpark@calicolabs.com>:

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/996 from dpark27:fix/list_vertex_ai_sessions d351d7f6017c03165129adc7d0212f21d1340d88
PiperOrigin-RevId: 764522026
This commit is contained in:
Danny Park 2025-05-28 19:46:03 -07:00 committed by Copybara-Service
parent fc3e374c86
commit 9d4ca4ed44
2 changed files with 22 additions and 21 deletions

View File

@ -16,8 +16,10 @@ from __future__ import annotations
import asyncio
import logging
import re
import time
from typing import Any
from typing import Optional
import urllib.parse
from dateutil import parser
from google.genai import types
@ -186,10 +188,15 @@ class VertexAiSessionService(BaseSessionService):
) -> ListSessionsResponse:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
path = f"reasoningEngines/{reasoning_engine_id}/sessions"
if user_id:
parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe="")
path = path + f"?filter=user_id={parsed_user_id}"
api_client = _get_api_client(self.project, self.location)
api_response = await api_client.async_request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}',
path=path,
request_dict={},
)

View File

@ -111,7 +111,7 @@ MOCK_SESSION = Session(
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=([^/]+)$'
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$' # %22 represents double-quotes in a URL-encoded string
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
LRO_REGEX = r'^operations/([^/]+)$'
@ -127,7 +127,7 @@ class MockApiClient:
async def async_request(
self, http_method: str, path: str, request_dict: dict[str, Any]
):
"""Mocks the API Client request method."""
"""Mocks the API Client request method"""
if http_method == 'GET':
if re.match(SESSION_REGEX, path):
match = re.match(SESSION_REGEX, path)
@ -149,20 +149,14 @@ class MockApiClient:
elif re.match(EVENTS_REGEX, path):
match = re.match(EVENTS_REGEX, path)
if match:
return {
'sessionEvents': (
self.event_dict[match.group(2)]
if match.group(2) in self.event_dict
else []
)
}
session_id = match.group(2)
return {'sessionEvents': self.event_dict.get(session_id, [])}
elif re.match(LRO_REGEX, path):
# Mock long-running operation as completed
return {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/4'
),
'name': path,
'done': True,
'response': self.session_dict['4'] # Return the created session
}
else:
raise ValueError(f'Unsupported path: {path}')
@ -225,10 +219,10 @@ def mock_get_api_client():
async def test_get_empty_session():
session_service = mock_vertex_ai_session_service()
with pytest.raises(ValueError) as excinfo:
assert await session_service.get_session(
await session_service.get_session(
app_name='123', user_id='user', session_id='0'
)
assert str(excinfo.value) == 'Session not found: 0'
assert str(excinfo.value) == 'Session not found: 0'
@pytest.mark.asyncio
@ -247,10 +241,10 @@ async def test_get_and_delete_session():
app_name='123', user_id='user', session_id='1'
)
with pytest.raises(ValueError) as excinfo:
assert await session_service.get_session(
await session_service.get_session(
app_name='123', user_id='user', session_id='1'
)
assert str(excinfo.value) == 'Session not found: 1'
assert str(excinfo.value) == 'Session not found: 1'
@pytest.mark.asyncio
@ -292,6 +286,6 @@ async def test_create_session_with_custom_session_id():
await session_service.create_session(
app_name='123', user_id='user', session_id='1'
)
assert str(excinfo.value) == (
'User-provided Session id is not supported for VertexAISessionService.'
)
assert str(excinfo.value) == (
'User-provided Session id is not supported for VertexAISessionService.'
)