adk-python/tests/unittests/sessions/test_session_service.py
Yongsul Kim c024ac5762 Align event filtering and ordering logic
Copybara import of the project:

--
d01a8fd5f079bc4fca9e4b71796dbe65312ce9ff by Leo Yongsul Kim <ystory84@gmail.com>:

fix(DatabaseSessionService): Align event filtering and ordering logic

This commit addresses inconsistencies in how DatabaseSessionService
handles config.after_timestamp and config.num_recent_events
parameters, aligning its behavior with InMemorySessionService and
VertexAiSessionService.

Key changes:
- Made after_timestamp filtering inclusive
- Corrected num_recent_events behavior to fetch the N most recent events
- Refined timezone handling for after_timestamp
- Updated the unit test test_get_session_with_config to include SessionServiceType.DATABASE, allowing verification of these fixes.

Fixes #911

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/915 from ystory:fix/database-session-timestamp-recency 5cc8cf5f5a5c0cb3e87f6ab178a5725d3f696c88
PiperOrigin-RevId: 763874840
2025-05-27 11:22:04 -07:00

377 lines
12 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
from google.adk.events import Event
from google.adk.events import EventActions
from google.adk.sessions import DatabaseSessionService
from google.adk.sessions import InMemorySessionService
from google.adk.sessions.base_session_service import GetSessionConfig
from google.genai import types
import pytest
@enum.unique
class SessionServiceType(enum.Enum):
  """Session service backends the tests in this module can run against.

  Each member's value is simply its own name.
  """

  IN_MEMORY = 'IN_MEMORY'  # backed by InMemorySessionService
  DATABASE = 'DATABASE'  # backed by DatabaseSessionService
def get_session_service(
    service_type: SessionServiceType = SessionServiceType.IN_MEMORY,
):
  """Builds a fresh session service of the requested kind for a test.

  Args:
    service_type: Which backend to build. ``SessionServiceType.DATABASE``
      yields a ``DatabaseSessionService`` pointed at the
      ``sqlite:///:memory:`` URL; any other value falls back to an
      ``InMemorySessionService``.

  Returns:
    A newly constructed session service instance.
  """
  use_database = service_type == SessionServiceType.DATABASE
  return (
      DatabaseSessionService('sqlite:///:memory:')
      if use_database
      else InMemorySessionService()
  )
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_get_empty_session(service_type):
  """Fetching a session that was never created yields a falsy result."""
  service = get_session_service(service_type)
  missing = await service.get_session(
      app_name='my_app', user_id='test_user', session_id='123'
  )
  assert not missing
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_create_get_session(service_type):
  """A created session can be fetched back unchanged and is gone once deleted."""
  session_service = get_session_service(service_type)
  app_name = 'my_app'
  user_id = 'test_user'
  state = {'key': 'value'}

  session = await session_service.create_session(
      app_name=app_name, user_id=user_id, state=state
  )
  assert session.app_name == app_name
  assert session.user_id == user_id
  assert session.id
  assert session.state == state
  assert (
      await session_service.get_session(
          app_name=app_name, user_id=user_id, session_id=session.id
      )
      == session
  )

  await session_service.delete_session(
      app_name=app_name, user_id=user_id, session_id=session.id
  )
  # After deletion the lookup must come back empty. A plain `!= session`
  # comparison would also pass if some other (e.g. stale) session object
  # were returned, which is not the intended contract.
  assert not await session_service.get_session(
      app_name=app_name, user_id=user_id, session_id=session.id
  )
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_create_and_list_sessions(service_type):
  """list_sessions returns exactly the sessions created for the user, in order."""
  session_service = get_session_service(service_type)
  app_name = 'my_app'
  user_id = 'test_user'
  session_ids = [f'session{i}' for i in range(5)]
  for session_id in session_ids:
    await session_service.create_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )

  list_sessions_response = await session_service.list_sessions(
      app_name=app_name, user_id=user_id
  )
  listed_ids = [session.id for session in list_sessions_response.sessions]
  # Compare whole lists: a per-index loop over the *returned* sessions passes
  # vacuously when fewer (or zero) sessions come back.
  assert listed_ids == session_ids
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_session_state(service_type):
  """State deltas are scoped by key prefix across sessions and users.

  One event appended to session11 (user1) carries app-, user-, temp- and
  session-scoped keys. The test then checks which of those keys are visible
  from session11 itself, from session12 (same user), and from session2
  (a different user).
  """
  session_service = get_session_service(service_type)
  app_name = 'my_app'
  user_id_1 = 'user1'
  user_id_2 = 'user2'
  session_id_11 = 'session11'
  session_id_12 = 'session12'
  session_id_2 = 'session2'
  state_11 = {'key11': 'value11'}
  state_12 = {'key12': 'value12'}

  session_11 = await session_service.create_session(
      app_name=app_name,
      user_id=user_id_1,
      state=state_11,
      session_id=session_id_11,
  )
  await session_service.create_session(
      app_name=app_name,
      user_id=user_id_1,
      state=state_12,
      session_id=session_id_12,
  )
  await session_service.create_session(
      app_name=app_name, user_id=user_id_2, session_id=session_id_2
  )

  assert session_11.state.get('key11') == 'value11'

  event = Event(
      invocation_id='invocation',
      author='user',
      content=types.Content(role='user', parts=[types.Part(text='text')]),
      actions=EventActions(
          state_delta={
              'app:key': 'value',
              'user:key1': 'value1',
              'temp:key': 'temp',
              'key11': 'value11_new',
          }
      ),
  )
  await session_service.append_event(session=session_11, event=event)

  # User and app state is stored, temp state is filtered.
  assert session_11.state.get('app:key') == 'value'
  assert session_11.state.get('key11') == 'value11_new'
  assert session_11.state.get('user:key1') == 'value1'
  assert not session_11.state.get('temp:key')

  session_12 = await session_service.get_session(
      app_name=app_name, user_id=user_id_1, session_id=session_id_12
  )
  # After getting a new instance, session_12 sees the app and user state that
  # was written through session_11 (even though append_event was never called
  # on session_12), keeps its own session state, and sees neither
  # session_11's session-scoped key nor the temp key.
  assert session_12.state.get('key12') == 'value12'
  assert session_12.state.get('app:key') == 'value'
  assert session_12.state.get('user:key1') == 'value1'
  assert not session_12.state.get('key11')
  assert not session_12.state.get('temp:key')

  # user1's user state and session state are not visible to user2; app state
  # is shared.
  session_2 = await session_service.get_session(
      app_name=app_name, user_id=user_id_2, session_id=session_id_2
  )
  assert session_2.state.get('app:key') == 'value'
  assert not session_2.state.get('user:key1')
  assert not session_2.state.get('key11')

  # The change to session_11 is persisted.
  session_11 = await session_service.get_session(
      app_name=app_name, user_id=user_id_1, session_id=session_id_11
  )
  assert session_11.state.get('key11') == 'value11_new'
  assert session_11.state.get('user:key1') == 'value1'
  assert not session_11.state.get('temp:key')
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_create_new_session_will_merge_states(service_type):
  """App- and user-scoped state written via one session seeds later sessions.

  An event appended to the first session carries ``app:``, ``user:`` and
  ``temp:`` deltas. A second session created afterwards for the same user
  must inherit the persisted ``app:``/``user:`` keys, but neither the first
  session's own keys nor the ``temp:`` key.
  """
  service = get_session_service(service_type)
  app, user = 'my_app', 'user'

  first = await service.create_session(
      app_name=app, user_id=user, state={'key1': 'value1'}, session_id='session1'
  )
  delta = {
      'app:key': 'value',
      'user:key1': 'value1',
      'temp:key': 'temp',
  }
  await service.append_event(
      session=first,
      event=Event(
          invocation_id='invocation',
          author='user',
          content=types.Content(role='user', parts=[types.Part(text='text')]),
          actions=EventActions(state_delta=delta),
      ),
  )

  # The local session reflects the delta; temp: keys are dropped.
  expected_first = {
      'app:key': 'value',
      'key1': 'value1',
      'user:key1': 'value1',
  }
  for key, value in expected_first.items():
    assert first.state.get(key) == value
  assert not first.state.get('temp:key')

  second = await service.create_session(
      app_name=app, user_id=user, state={}, session_id='session2'
  )
  # Only the persisted app:/user: state carries over to the new session.
  assert second.state.get('app:key') == 'value'
  assert second.state.get('user:key1') == 'value1'
  for absent_key in ('key1', 'temp:key'):
    assert not second.state.get(absent_key)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_append_event_bytes(service_type):
  """Inline binary parts survive append_event and a re-read from the service."""
  service = get_session_service(service_type)
  app, user = 'my_app', 'user'
  session = await service.create_session(app_name=app, user_id=user)

  def image_part():
    # A fresh copy of the same inline PNG payload each time it is needed.
    return types.Part.from_bytes(data=b'test_image_data', mime_type='image/png')

  await service.append_event(
      session=session,
      event=Event(
          invocation_id='invocation',
          author='user',
          content=types.Content(role='user', parts=[image_part()]),
      ),
  )
  # The locally held session already carries the binary part.
  assert session.events[0].content.parts[0] == image_part()

  # So does a copy re-read from the service.
  reloaded = await service.get_session(
      app_name=app, user_id=user, session_id=session.id
  )
  stored_events = reloaded.events
  assert len(stored_events) == 1
  assert stored_events[0].content.parts[0] == image_part()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_append_event_complete(service_type):
  """An event with many optional fields populated round-trips unchanged.

  After append_event, re-reading the session from the service must produce
  an object equal to the locally updated session.
  """
  service = get_session_service(service_type)
  app, user = 'my_app', 'user'
  session = await service.create_session(app_name=app, user_id=user)

  event_fields = dict(
      invocation_id='invocation',
      author='user',
      content=types.Content(role='user', parts=[types.Part(text='test_text')]),
      turn_complete=True,
      partial=False,
      actions=EventActions(
          artifact_delta={'file': 0},
          transfer_to_agent='agent',
          escalate=True,
      ),
      long_running_tool_ids={'tool1'},
      error_code='error_code',
      error_message='error_message',
      interrupted=True,
  )
  await service.append_event(session=session, event=Event(**event_fields))

  reloaded = await service.get_session(
      app_name=app, user_id=user, session_id=session.id
  )
  assert reloaded == session
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_get_session_with_config(service_type):
  """GetSessionConfig filters and trims the events returned by get_session.

  Events with timestamps 1..num_test_events are appended. The test then
  checks ``num_recent_events``, the inclusive ``after_timestamp`` filter, and
  both combined. Each check compares the exact list of timestamps, so both
  which events are selected and their chronological order are verified.
  """
  session_service = get_session_service(service_type)
  app_name = 'my_app'
  user_id = 'user'
  num_test_events = 5
  session = await session_service.create_session(
      app_name=app_name, user_id=user_id
  )
  for i in range(1, num_test_events + 1):
    event = Event(author='user', timestamp=i)
    await session_service.append_event(session, event)

  async def fetch_timestamps(config=None):
    """Returns the timestamps of the events get_session yields for `config`."""
    fetched = await session_service.get_session(
        app_name=app_name,
        user_id=user_id,
        session_id=session.id,
        config=config,
    )
    return [fetched_event.timestamp for fetched_event in fetched.events]

  all_timestamps = list(range(1, num_test_events + 1))

  # No config: every event is returned, oldest first.
  timestamps = await fetch_timestamps()
  assert timestamps == all_timestamps

  # Only expect the 3 most recent events, still oldest first: [3, 4, 5].
  num_recent_events = 3
  timestamps = await fetch_timestamps(
      GetSessionConfig(num_recent_events=num_recent_events)
  )
  assert timestamps == all_timestamps[-num_recent_events:]

  # after_timestamp is inclusive: expect events at or after 4.0, i.e. [4, 5].
  after_timestamp = 4.0
  at_or_after = [t for t in all_timestamps if t >= after_timestamp]
  timestamps = await fetch_timestamps(
      GetSessionConfig(after_timestamp=after_timestamp)
  )
  assert timestamps == at_or_after

  # Expect no events if none are >= after_timestamp.
  way_after_timestamp = num_test_events * 10
  timestamps = await fetch_timestamps(
      GetSessionConfig(after_timestamp=way_after_timestamp)
  )
  assert not timestamps

  # Both filters applied: of the 3 most recent events [3, 4, 5], only [4, 5]
  # are at or after timestamp 4.0.
  timestamps = await fetch_timestamps(
      GetSessionConfig(
          after_timestamp=after_timestamp, num_recent_events=num_recent_events
      )
  )
  assert timestamps == [
      t for t in all_timestamps[-num_recent_events:] if t >= after_timestamp
  ]