mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
fix: Fix filtering by user_id for vertex ai session service listing
When the user id contains special characters (i.e. an email), we have added in extra url parsing to address those characters. We have also added an if statement to use the correct url when there is no user_id supplied. Copybara import of the project: -- ef8499001afaea40bd037c4e9946b883e23a5854 by Danny Park <dpark@calicolabs.com>: -- 773cd2b50d15b9b056b47b6155df492b0ca8034c by Danny Park <dpark@calicolabs.com>: COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/996 from dpark27:fix/list_vertex_ai_sessions d351d7f6017c03165129adc7d0212f21d1340d88 PiperOrigin-RevId: 764522026
This commit is contained in:
parent
fc3e374c86
commit
9d4ca4ed44
@ -16,8 +16,10 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
import urllib.parse
|
||||
|
||||
from dateutil import parser
|
||||
from google.genai import types
|
||||
@ -186,10 +188,15 @@ class VertexAiSessionService(BaseSessionService):
|
||||
) -> ListSessionsResponse:
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
|
||||
path = f"reasoningEngines/{reasoning_engine_id}/sessions"
|
||||
if user_id:
|
||||
parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe="")
|
||||
path = path + f"?filter=user_id={parsed_user_id}"
|
||||
|
||||
api_client = _get_api_client(self.project, self.location)
|
||||
api_response = await api_client.async_request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}',
|
||||
path=path,
|
||||
request_dict={},
|
||||
)
|
||||
|
||||
|
@ -111,7 +111,7 @@ MOCK_SESSION = Session(
|
||||
|
||||
|
||||
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
|
||||
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=([^/]+)$'
|
||||
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$' # %22 represents double-quotes in a URL-encoded string
|
||||
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
|
||||
LRO_REGEX = r'^operations/([^/]+)$'
|
||||
|
||||
@ -127,7 +127,7 @@ class MockApiClient:
|
||||
async def async_request(
|
||||
self, http_method: str, path: str, request_dict: dict[str, Any]
|
||||
):
|
||||
"""Mocks the API Client request method."""
|
||||
"""Mocks the API Client request method"""
|
||||
if http_method == 'GET':
|
||||
if re.match(SESSION_REGEX, path):
|
||||
match = re.match(SESSION_REGEX, path)
|
||||
@ -149,20 +149,14 @@ class MockApiClient:
|
||||
elif re.match(EVENTS_REGEX, path):
|
||||
match = re.match(EVENTS_REGEX, path)
|
||||
if match:
|
||||
return {
|
||||
'sessionEvents': (
|
||||
self.event_dict[match.group(2)]
|
||||
if match.group(2) in self.event_dict
|
||||
else []
|
||||
)
|
||||
}
|
||||
session_id = match.group(2)
|
||||
return {'sessionEvents': self.event_dict.get(session_id, [])}
|
||||
elif re.match(LRO_REGEX, path):
|
||||
# Mock long-running operation as completed
|
||||
return {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/4'
|
||||
),
|
||||
'name': path,
|
||||
'done': True,
|
||||
'response': self.session_dict['4'] # Return the created session
|
||||
}
|
||||
else:
|
||||
raise ValueError(f'Unsupported path: {path}')
|
||||
@ -225,10 +219,10 @@ def mock_get_api_client():
|
||||
async def test_get_empty_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
assert await session_service.get_session(
|
||||
await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='0'
|
||||
)
|
||||
assert str(excinfo.value) == 'Session not found: 0'
|
||||
assert str(excinfo.value) == 'Session not found: 0'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -247,10 +241,10 @@ async def test_get_and_delete_session():
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
assert await session_service.get_session(
|
||||
await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
assert str(excinfo.value) == 'Session not found: 1'
|
||||
assert str(excinfo.value) == 'Session not found: 1'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -292,6 +286,6 @@ async def test_create_session_with_custom_session_id():
|
||||
await session_service.create_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
assert str(excinfo.value) == (
|
||||
'User-provided Session id is not supported for VertexAISessionService.'
|
||||
)
|
||||
assert str(excinfo.value) == (
|
||||
'User-provided Session id is not supported for VertexAISessionService.'
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user