mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
fix: Continue fetching events if there are multiple pages.
Fixes https://github.com/google/adk-python/issues/920 PiperOrigin-RevId: 764985371
This commit is contained in:
parent
18fbe3cbfc
commit
65063023a5
@ -159,15 +159,29 @@ class VertexAiSessionService(BaseSessionService):
|
||||
if list_events_api_response.get('httpHeaders', None):
|
||||
return session
|
||||
|
||||
session.events = [
|
||||
session.events += [
|
||||
_from_api_event(event)
|
||||
for event in list_events_api_response['sessionEvents']
|
||||
]
|
||||
|
||||
while list_events_api_response.get('nextPageToken', None):
|
||||
page_token = list_events_api_response.get('nextPageToken', None)
|
||||
list_events_api_response = await api_client.async_request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events?pageToken={page_token}',
|
||||
request_dict={},
|
||||
)
|
||||
session.events += [
|
||||
_from_api_event(event)
|
||||
for event in list_events_api_response['sessionEvents']
|
||||
]
|
||||
|
||||
session.events = [
|
||||
event for event in session.events if event.timestamp <= update_timestamp
|
||||
]
|
||||
session.events.sort(key=lambda event: event.timestamp)
|
||||
|
||||
# Filter events based on config
|
||||
if config:
|
||||
if config.num_recent_events:
|
||||
session.events = session.events[-config.num_recent_events :]
|
||||
|
@ -15,6 +15,9 @@
|
||||
import re
|
||||
import this
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from unittest import mock
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
@ -82,6 +85,28 @@ MOCK_EVENT_JSON = [
|
||||
},
|
||||
},
|
||||
]
|
||||
MOCK_EVENT_JSON_2 = [
|
||||
{
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/2/events/123'
|
||||
),
|
||||
'invocationId': '222',
|
||||
'author': 'user',
|
||||
'timestamp': '2024-12-12T12:12:12.123456Z',
|
||||
},
|
||||
]
|
||||
MOCK_EVENT_JSON_3 = [
|
||||
{
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/2/events/456'
|
||||
),
|
||||
'invocationId': '333',
|
||||
'author': 'user',
|
||||
'timestamp': '2024-12-12T12:12:12.123456Z',
|
||||
},
|
||||
]
|
||||
|
||||
MOCK_SESSION = Session(
|
||||
app_name='123',
|
||||
@ -109,12 +134,35 @@ MOCK_SESSION = Session(
|
||||
],
|
||||
)
|
||||
|
||||
MOCK_SESSION_2 = Session(
|
||||
app_name='123',
|
||||
user_id='user',
|
||||
id='2',
|
||||
last_update_time=isoparse(MOCK_SESSION_JSON_2['updateTime']).timestamp(),
|
||||
events=[
|
||||
Event(
|
||||
id='123',
|
||||
invocation_id='222',
|
||||
author='user',
|
||||
timestamp=isoparse(MOCK_EVENT_JSON_2[0]['timestamp']).timestamp(),
|
||||
),
|
||||
Event(
|
||||
id='456',
|
||||
invocation_id='333',
|
||||
author='user',
|
||||
timestamp=isoparse(MOCK_EVENT_JSON_3[0]['timestamp']).timestamp(),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
|
||||
SESSIONS_REGEX = ( # %22 represents double-quotes in a URL-encoded string
|
||||
r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=%22([^%]+)%22.*$'
|
||||
)
|
||||
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
|
||||
EVENTS_REGEX = (
|
||||
r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events(?:\?pageToken=([^/]+))?'
|
||||
)
|
||||
LRO_REGEX = r'^operations/([^/]+)$'
|
||||
|
||||
|
||||
@ -124,7 +172,7 @@ class MockApiClient:
|
||||
def __init__(self) -> None:
|
||||
"""Initializes MockClient."""
|
||||
this.session_dict: dict[str, Any] = {}
|
||||
this.event_dict: dict[str, list[Any]] = {}
|
||||
this.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {}
|
||||
|
||||
async def async_request(
|
||||
self, http_method: str, path: str, request_dict: dict[str, Any]
|
||||
@ -152,7 +200,13 @@ class MockApiClient:
|
||||
match = re.match(EVENTS_REGEX, path)
|
||||
if match:
|
||||
session_id = match.group(2)
|
||||
return {'sessionEvents': self.event_dict.get(session_id, [])}
|
||||
if match.group(3):
|
||||
return {'sessionEvents': MOCK_EVENT_JSON_3}
|
||||
events_tuple = self.event_dict.get(session_id, ([], None))
|
||||
response = {'sessionEvents': events_tuple[0]}
|
||||
if events_tuple[1]:
|
||||
response['nextPageToken'] = events_tuple[1]
|
||||
return response
|
||||
elif re.match(LRO_REGEX, path):
|
||||
# Mock long-running operation as completed
|
||||
return {
|
||||
@ -207,7 +261,8 @@ def mock_get_api_client():
|
||||
'3': MOCK_SESSION_JSON_3,
|
||||
}
|
||||
api_client.event_dict = {
|
||||
'1': MOCK_EVENT_JSON,
|
||||
'1': (MOCK_EVENT_JSON, None),
|
||||
'2': (MOCK_EVENT_JSON_2, 'my_token'),
|
||||
}
|
||||
with mock.patch(
|
||||
'google.adk.sessions.vertex_ai_session_service._get_api_client',
|
||||
@ -249,6 +304,19 @@ async def test_get_and_delete_session():
|
||||
assert str(excinfo.value) == 'Session not found: 1'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures('mock_get_api_client')
|
||||
async def test_get_session_with_page_token():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
|
||||
assert (
|
||||
await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='2'
|
||||
)
|
||||
== MOCK_SESSION_2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures('mock_get_api_client')
|
||||
async def test_list_sessions():
|
||||
|
Loading…
Reference in New Issue
Block a user