fix: Continue fetching events if there are multiple pages.

Fixes https://github.com/google/adk-python/issues/920

PiperOrigin-RevId: 764985371
This commit is contained in:
Shangjie Chen 2025-05-29 19:42:17 -07:00 committed by Copybara-Service
parent 18fbe3cbfc
commit 65063023a5
2 changed files with 87 additions and 5 deletions

View File

@ -159,15 +159,29 @@ class VertexAiSessionService(BaseSessionService):
if list_events_api_response.get('httpHeaders', None): if list_events_api_response.get('httpHeaders', None):
return session return session
session.events = [ session.events += [
_from_api_event(event) _from_api_event(event)
for event in list_events_api_response['sessionEvents'] for event in list_events_api_response['sessionEvents']
] ]
while list_events_api_response.get('nextPageToken', None):
page_token = list_events_api_response.get('nextPageToken', None)
list_events_api_response = await api_client.async_request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events?pageToken={page_token}',
request_dict={},
)
session.events += [
_from_api_event(event)
for event in list_events_api_response['sessionEvents']
]
session.events = [ session.events = [
event for event in session.events if event.timestamp <= update_timestamp event for event in session.events if event.timestamp <= update_timestamp
] ]
session.events.sort(key=lambda event: event.timestamp) session.events.sort(key=lambda event: event.timestamp)
# Filter events based on config
if config: if config:
if config.num_recent_events: if config.num_recent_events:
session.events = session.events[-config.num_recent_events :] session.events = session.events[-config.num_recent_events :]

View File

@ -15,6 +15,9 @@
import re import re
import this import this
from typing import Any from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
from unittest import mock from unittest import mock
from dateutil.parser import isoparse from dateutil.parser import isoparse
@ -82,6 +85,28 @@ MOCK_EVENT_JSON = [
}, },
}, },
] ]
MOCK_EVENT_JSON_2 = [
{
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/2/events/123'
),
'invocationId': '222',
'author': 'user',
'timestamp': '2024-12-12T12:12:12.123456Z',
},
]
MOCK_EVENT_JSON_3 = [
{
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/2/events/456'
),
'invocationId': '333',
'author': 'user',
'timestamp': '2024-12-12T12:12:12.123456Z',
},
]
MOCK_SESSION = Session( MOCK_SESSION = Session(
app_name='123', app_name='123',
@ -109,12 +134,35 @@ MOCK_SESSION = Session(
], ],
) )
MOCK_SESSION_2 = Session(
app_name='123',
user_id='user',
id='2',
last_update_time=isoparse(MOCK_SESSION_JSON_2['updateTime']).timestamp(),
events=[
Event(
id='123',
invocation_id='222',
author='user',
timestamp=isoparse(MOCK_EVENT_JSON_2[0]['timestamp']).timestamp(),
),
Event(
id='456',
invocation_id='333',
author='user',
timestamp=isoparse(MOCK_EVENT_JSON_3[0]['timestamp']).timestamp(),
),
],
)
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$' SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string
r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$' r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$'
) )
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$' EVENTS_REGEX = (
r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\?pageToken=([^/]+))?'
)
LRO_REGEX = r'^operations/([^/]+)$' LRO_REGEX = r'^operations/([^/]+)$'
@ -124,7 +172,7 @@ class MockApiClient:
def __init__(self) -> None: def __init__(self) -> None:
"""Initializes MockClient.""" """Initializes MockClient."""
this.session_dict: dict[str, Any] = {} this.session_dict: dict[str, Any] = {}
this.event_dict: dict[str, list[Any]] = {} this.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {}
async def async_request( async def async_request(
self, http_method: str, path: str, request_dict: dict[str, Any] self, http_method: str, path: str, request_dict: dict[str, Any]
@ -152,7 +200,13 @@ class MockApiClient:
match = re.match(EVENTS_REGEX, path) match = re.match(EVENTS_REGEX, path)
if match: if match:
session_id = match.group(2) session_id = match.group(2)
return {'sessionEvents': self.event_dict.get(session_id, [])} if match.group(3):
return {'sessionEvents': MOCK_EVENT_JSON_3}
events_tuple = self.event_dict.get(session_id, ([], None))
response = {'sessionEvents': events_tuple[0]}
if events_tuple[1]:
response['nextPageToken'] = events_tuple[1]
return response
elif re.match(LRO_REGEX, path): elif re.match(LRO_REGEX, path):
# Mock long-running operation as completed # Mock long-running operation as completed
return { return {
@ -207,7 +261,8 @@ def mock_get_api_client():
'3': MOCK_SESSION_JSON_3, '3': MOCK_SESSION_JSON_3,
} }
api_client.event_dict = { api_client.event_dict = {
'1': MOCK_EVENT_JSON, '1': (MOCK_EVENT_JSON, None),
'2': (MOCK_EVENT_JSON_2, 'my_token'),
} }
with mock.patch( with mock.patch(
'google.adk.sessions.vertex_ai_session_service._get_api_client', 'google.adk.sessions.vertex_ai_session_service._get_api_client',
@ -249,6 +304,19 @@ async def test_get_and_delete_session():
assert str(excinfo.value) == 'Session not found: 1' assert str(excinfo.value) == 'Session not found: 1'
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_get_session_with_page_token():
session_service = mock_vertex_ai_session_service()
assert (
await session_service.get_session(
app_name='123', user_id='user', session_id='2'
)
== MOCK_SESSION_2
)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client') @pytest.mark.usefixtures('mock_get_api_client')
async def test_list_sessions(): async def test_list_sessions():