From 1393965720a8d364d8d6b979fc70bf1360696fed Mon Sep 17 00:00:00 2001 From: Allen Date: Tue, 6 May 2025 13:09:28 -0700 Subject: [PATCH] Fix: config.after_timestamp behavior in InMemorySessionService.get_session() and add a test Copybara import of the project: -- c1d0d649b5aae1322a02dbaa586822d69b8546f6 by allengour : fix: fix and test `config.after_timestamp` behavior in `InMemorySessionService.get_session()` COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/438 from allengour:fix/issue-437-after_timestamp-behavior 4b49a5e6509b5ad9dd9103a6dc357fd44c101f31 PiperOrigin-RevId: 755492201 --- .../adk/sessions/in_memory_session_service.py | 6 +- .../sessions/test_session_service.py | 57 ++++++++++++++++++- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index bcb659a..69767f2 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -95,14 +95,14 @@ class InMemorySessionService(BaseSessionService): copied_session.events = copied_session.events[ -config.num_recent_events : ] - elif config.after_timestamp: - i = len(session.events) - 1 + if config.after_timestamp: + i = len(copied_session.events) - 1 while i >= 0: if copied_session.events[i].timestamp < config.after_timestamp: break i -= 1 if i >= 0: - copied_session.events = copied_session.events[i:] + copied_session.events = copied_session.events[i + 1:] return self._merge_state(app_name, user_id, copied_session) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 5cf5e1d..158bf5e 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -19,6 +19,7 @@ from google.adk.events import Event from google.adk.events import EventActions from google.adk.sessions import DatabaseSessionService from google.adk.sessions import InMemorySessionService +from google.adk.sessions.base_session_service import GetSessionConfig from google.genai import types @@ -183,7 +184,7 @@ def test_session_state(service_type): @pytest.mark.parametrize( - "service_type", [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] + 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] ) def test_create_new_session_will_merge_states(service_type): session_service = get_session_service(service_type) @@ -298,3 +299,57 @@ def test_append_event_complete(service_type): ) == session ) + +@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY]) +def test_get_session_with_config(service_type): + session_service = get_session_service(service_type) + app_name = 'my_app' + user_id = 'user' + + num_test_events = 5 + session = session_service.create_session(app_name=app_name, user_id=user_id) + for i in range(1, num_test_events + 1): + event = Event(author='user', timestamp=i) + session_service.append_event(session, event) + + # No config, expect all events to be returned. + events = session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ).events + assert len(events) == num_test_events + + # Only expect the most recent 3 events. + num_recent_events = 3 + config = GetSessionConfig(num_recent_events=num_recent_events) + events = session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id, config=config + ).events + assert len(events) == num_recent_events + assert events[0].timestamp == num_test_events - num_recent_events + 1 + + # Only expect events after timestamp 4.0 (inclusive), i.e., 2 events. + after_timestamp = 4.0 + config = GetSessionConfig(after_timestamp=after_timestamp) + events = session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id, config=config + ).events + assert len(events) == num_test_events - after_timestamp + 1 + assert events[0].timestamp == after_timestamp + + # Expect no events if none are > after_timestamp. + way_after_timestamp = num_test_events * 10 + config = GetSessionConfig(after_timestamp=way_after_timestamp) + events = session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id, config=config + ).events + assert len(events) == 0 + + # Both filters applied, i.e., of 3 most recent events, only 2 are after + # timestamp 4.0, so expect 2 events. + config = GetSessionConfig( + after_timestamp=after_timestamp, num_recent_events=num_recent_events + ) + events = session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id, config=config + ).events + assert len(events) == num_test_events - after_timestamp + 1