fix: Continue fetching events if there are multiple pages.

Fixes https://github.com/google/adk-python/issues/920

PiperOrigin-RevId: 764985371
This commit is contained in:
Shangjie Chen 2025-05-29 19:42:17 -07:00 committed by Copybara-Service
parent 18fbe3cbfc
commit 65063023a5
2 changed files with 87 additions and 5 deletions

View File

@ -159,15 +159,29 @@ class VertexAiSessionService(BaseSessionService):
if list_events_api_response.get('httpHeaders', None):
return session
session.events = [
session.events += [
_from_api_event(event)
for event in list_events_api_response['sessionEvents']
]
while list_events_api_response.get('nextPageToken', None):
page_token = list_events_api_response.get('nextPageToken', None)
list_events_api_response = await api_client.async_request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events?pageToken={page_token}',
request_dict={},
)
session.events += [
_from_api_event(event)
for event in list_events_api_response['sessionEvents']
]
session.events = [
event for event in session.events if event.timestamp <= update_timestamp
]
session.events.sort(key=lambda event: event.timestamp)
# Filter events based on config
if config:
if config.num_recent_events:
session.events = session.events[-config.num_recent_events :]

View File

@ -15,6 +15,9 @@
import re
import this
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
from unittest import mock
from dateutil.parser import isoparse
@ -82,6 +85,28 @@ MOCK_EVENT_JSON = [
},
},
]
MOCK_EVENT_JSON_2 = [
{
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/2/events/123'
),
'invocationId': '222',
'author': 'user',
'timestamp': '2024-12-12T12:12:12.123456Z',
},
]
MOCK_EVENT_JSON_3 = [
{
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/2/events/456'
),
'invocationId': '333',
'author': 'user',
'timestamp': '2024-12-12T12:12:12.123456Z',
},
]
MOCK_SESSION = Session(
app_name='123',
@ -109,12 +134,35 @@ MOCK_SESSION = Session(
],
)
MOCK_SESSION_2 = Session(
app_name='123',
user_id='user',
id='2',
last_update_time=isoparse(MOCK_SESSION_JSON_2['updateTime']).timestamp(),
events=[
Event(
id='123',
invocation_id='222',
author='user',
timestamp=isoparse(MOCK_EVENT_JSON_2[0]['timestamp']).timestamp(),
),
Event(
id='456',
invocation_id='333',
author='user',
timestamp=isoparse(MOCK_EVENT_JSON_3[0]['timestamp']).timestamp(),
),
],
)
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string
r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$'
)
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
EVENTS_REGEX = (
r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\?pageToken=([^/]+))?'
)
LRO_REGEX = r'^operations/([^/]+)$'
@ -124,7 +172,7 @@ class MockApiClient:
def __init__(self) -> None:
"""Initializes MockClient."""
this.session_dict: dict[str, Any] = {}
this.event_dict: dict[str, list[Any]] = {}
this.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {}
async def async_request(
self, http_method: str, path: str, request_dict: dict[str, Any]
@ -152,7 +200,13 @@ class MockApiClient:
match = re.match(EVENTS_REGEX, path)
if match:
session_id = match.group(2)
return {'sessionEvents': self.event_dict.get(session_id, [])}
if match.group(3):
return {'sessionEvents': MOCK_EVENT_JSON_3}
events_tuple = self.event_dict.get(session_id, ([], None))
response = {'sessionEvents': events_tuple[0]}
if events_tuple[1]:
response['nextPageToken'] = events_tuple[1]
return response
elif re.match(LRO_REGEX, path):
# Mock long-running operation as completed
return {
@ -207,7 +261,8 @@ def mock_get_api_client():
'3': MOCK_SESSION_JSON_3,
}
api_client.event_dict = {
'1': MOCK_EVENT_JSON,
'1': (MOCK_EVENT_JSON, None),
'2': (MOCK_EVENT_JSON_2, 'my_token'),
}
with mock.patch(
'google.adk.sessions.vertex_ai_session_service._get_api_client',
@ -249,6 +304,19 @@ async def test_get_and_delete_session():
assert str(excinfo.value) == 'Session not found: 1'
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_get_session_with_page_token():
session_service = mock_vertex_ai_session_service()
assert (
await session_service.get_session(
app_name='123', user_id='user', session_id='2'
)
== MOCK_SESSION_2
)
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_list_sessions():