From 65063023a5a7cb6cd5db43db14a411213dc8acf5 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Thu, 29 May 2025 19:42:17 -0700 Subject: [PATCH] fix: Continue fetching events if there are multiple pages. Fixes https://github.com/google/adk-python/issues/920 PiperOrigin-RevId: 764985371 --- .../adk/sessions/vertex_ai_session_service.py | 16 +++- .../test_vertex_ai_session_service.py | 76 ++++++++++++++++++- 2 files changed, 87 insertions(+), 5 deletions(-) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 7174967..2cff001 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -159,15 +159,29 @@ class VertexAiSessionService(BaseSessionService): if list_events_api_response.get('httpHeaders', None): return session - session.events = [ + session.events += [ _from_api_event(event) for event in list_events_api_response['sessionEvents'] ] + + while list_events_api_response.get('nextPageToken', None): + page_token = list_events_api_response.get('nextPageToken', None) + list_events_api_response = await api_client.async_request( + http_method='GET', + path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events?pageToken={page_token}', + request_dict={}, + ) + session.events += [ + _from_api_event(event) + for event in list_events_api_response['sessionEvents'] + ] + session.events = [ event for event in session.events if event.timestamp <= update_timestamp ] session.events.sort(key=lambda event: event.timestamp) + # Filter events based on config if config: if config.num_recent_events: session.events = session.events[-config.num_recent_events :] diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index dc34079..92f6a29 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -15,6 +15,9 @@ import re import this from typing import Any +from typing import List +from typing import Optional +from typing import Tuple from unittest import mock from dateutil.parser import isoparse @@ -82,6 +85,28 @@ MOCK_EVENT_JSON = [ }, }, ] +MOCK_EVENT_JSON_2 = [ + { + 'name': ( + 'projects/test-project/locations/test-location/' + 'reasoningEngines/123/sessions/2/events/123' + ), + 'invocationId': '222', + 'author': 'user', + 'timestamp': '2024-12-12T12:12:12.123456Z', + }, +] +MOCK_EVENT_JSON_3 = [ + { + 'name': ( + 'projects/test-project/locations/test-location/' + 'reasoningEngines/123/sessions/2/events/456' + ), + 'invocationId': '333', + 'author': 'user', + 'timestamp': '2024-12-12T12:12:12.123456Z', + }, +] MOCK_SESSION = Session( app_name='123', @@ -109,12 +134,35 @@ MOCK_SESSION = Session( ], ) +MOCK_SESSION_2 = Session( + app_name='123', + user_id='user', + id='2', + last_update_time=isoparse(MOCK_SESSION_JSON_2['updateTime']).timestamp(), + events=[ + Event( + id='123', + invocation_id='222', + author='user', + timestamp=isoparse(MOCK_EVENT_JSON_2[0]['timestamp']).timestamp(), + ), + Event( + id='456', + invocation_id='333', + author='user', + timestamp=isoparse(MOCK_EVENT_JSON_3[0]['timestamp']).timestamp(), + ), + ], +) + SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$' SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$' ) -EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$' +EVENTS_REGEX = ( + r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\?pageToken=([^/]+))?' +) LRO_REGEX = r'^operations/([^/]+)$' @@ -124,7 +172,7 @@ class MockApiClient: def __init__(self) -> None: """Initializes MockClient.""" this.session_dict: dict[str, Any] = {} - this.event_dict: dict[str, list[Any]] = {} + this.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {} async def async_request( self, http_method: str, path: str, request_dict: dict[str, Any] @@ -152,7 +200,13 @@ class MockApiClient: match = re.match(EVENTS_REGEX, path) if match: session_id = match.group(2) - return {'sessionEvents': self.event_dict.get(session_id, [])} + if match.group(3): + return {'sessionEvents': MOCK_EVENT_JSON_3} + events_tuple = self.event_dict.get(session_id, ([], None)) + response = {'sessionEvents': events_tuple[0]} + if events_tuple[1]: + response['nextPageToken'] = events_tuple[1] + return response elif re.match(LRO_REGEX, path): # Mock long-running operation as completed return { @@ -207,7 +261,8 @@ def mock_get_api_client(): '3': MOCK_SESSION_JSON_3, } api_client.event_dict = { - '1': MOCK_EVENT_JSON, + '1': (MOCK_EVENT_JSON, None), + '2': (MOCK_EVENT_JSON_2, 'my_token'), } with mock.patch( 'google.adk.sessions.vertex_ai_session_service._get_api_client', @@ -249,6 +304,19 @@ async def test_get_and_delete_session(): assert str(excinfo.value) == 'Session not found: 1' +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_get_session_with_page_token(): + session_service = mock_vertex_ai_session_service() + + assert ( + await session_service.get_session( + app_name='123', user_id='user', session_id='2' + ) + == MOCK_SESSION_2 + ) + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_list_sessions():