feat! Update session service interface to be async.

Also keep the sync version in the InMemorySessionService as create_session_sync() as a temporary migration option.

PiperOrigin-RevId: 759252188
This commit is contained in:
Google Team Member
2025-05-15 12:23:33 -07:00
committed by Copybara-Service
parent 5b3204c356
commit 1804ca39a6
23 changed files with 268 additions and 264 deletions

View File

@@ -37,28 +37,26 @@ def get_session_service(
return InMemorySessionService()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_get_empty_session(service_type):
def test_get_empty_session(service_type):
session_service = get_session_service(service_type)
assert not await session_service.get_session(
assert not session_service.get_session(
app_name='my_app', user_id='test_user', session_id='123'
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_create_get_session(service_type):
def test_create_get_session(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'test_user'
state = {'key': 'value'}
session = await session_service.create_session(
session = session_service.create_session(
app_name=app_name, user_id=user_id, state=state
)
assert session.app_name == app_name
@@ -66,53 +64,50 @@ async def test_create_get_session(service_type):
assert session.id
assert session.state == state
assert (
await session_service.get_session(
session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
== session
)
session_id = session.id
await session_service.delete_session(
session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
assert (
await session_service.get_session(
not session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
!= session
== session
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_create_and_list_sessions(service_type):
def test_create_and_list_sessions(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'test_user'
session_ids = ['session' + str(i) for i in range(5)]
for session_id in session_ids:
await session_service.create_session(
session_service.create_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
list_sessions_response = await session_service.list_sessions(
sessions = session_service.list_sessions(
app_name=app_name, user_id=user_id
)
sessions = list_sessions_response.sessions
).sessions
for i in range(len(sessions)):
assert sessions[i].id == session_ids[i]
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_session_state(service_type):
def test_session_state(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id_1 = 'user1'
@@ -123,19 +118,19 @@ async def test_session_state(service_type):
state_11 = {'key11': 'value11'}
state_12 = {'key12': 'value12'}
session_11 = await session_service.create_session(
session_11 = session_service.create_session(
app_name=app_name,
user_id=user_id_1,
state=state_11,
session_id=session_id_11,
)
await session_service.create_session(
session_service.create_session(
app_name=app_name,
user_id=user_id_1,
state=state_12,
session_id=session_id_12,
)
await session_service.create_session(
session_service.create_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2
)
@@ -154,7 +149,7 @@ async def test_session_state(service_type):
}
),
)
await session_service.append_event(session=session_11, event=event)
session_service.append_event(session=session_11, event=event)
# User and app state is stored, temp state is filtered.
assert session_11.state.get('app:key') == 'value'
@@ -162,7 +157,7 @@ async def test_session_state(service_type):
assert session_11.state.get('user:key1') == 'value1'
assert not session_11.state.get('temp:key')
session_12 = await session_service.get_session(
session_12 = session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_12
)
# After getting a new instance, the session_12 got the user and app state,
@@ -171,7 +166,7 @@ async def test_session_state(service_type):
assert not session_12.state.get('temp:key')
# The user1's state is not visible to user2, app state is visible
session_2 = await session_service.get_session(
session_2 = session_service.get_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2
)
assert session_2.state.get('app:key') == 'value'
@@ -180,7 +175,7 @@ async def test_session_state(service_type):
assert not session_2.state.get('user:key1')
# The change to session_11 is persisted
session_11 = await session_service.get_session(
session_11 = session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_11
)
assert session_11.state.get('key11') == 'value11_new'
@@ -188,11 +183,10 @@ async def test_session_state(service_type):
assert not session_11.state.get('temp:key')
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_create_new_session_will_merge_states(service_type):
def test_create_new_session_will_merge_states(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
@@ -200,7 +194,7 @@ async def test_create_new_session_will_merge_states(service_type):
session_id_2 = 'session2'
state_1 = {'key1': 'value1'}
session_1 = await session_service.create_session(
session_1 = session_service.create_session(
app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
)
@@ -216,7 +210,7 @@ async def test_create_new_session_will_merge_states(service_type):
}
),
)
await session_service.append_event(session=session_1, event=event)
session_service.append_event(session=session_1, event=event)
# User and app state is stored, temp state is filtered.
assert session_1.state.get('app:key') == 'value'
@@ -224,7 +218,7 @@ async def test_create_new_session_will_merge_states(service_type):
assert session_1.state.get('user:key1') == 'value1'
assert not session_1.state.get('temp:key')
session_2 = await session_service.create_session(
session_2 = session_service.create_session(
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
)
# Session 2 has the persisted states
@@ -234,18 +228,15 @@ async def test_create_new_session_will_merge_states(service_type):
assert not session_2.state.get('temp:key')
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_append_event_bytes(service_type):
def test_append_event_bytes(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
session = session_service.create_session(app_name=app_name, user_id=user_id)
event = Event(
invocation_id='invocation',
author='user',
@@ -258,34 +249,30 @@ async def test_append_event_bytes(service_type):
],
),
)
await session_service.append_event(session=session, event=event)
session_service.append_event(session=session, event=event)
assert session.events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
)
session = await session_service.get_session(
events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
events = session.events
).events
assert len(events) == 1
assert events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_append_event_complete(service_type):
def test_append_event_complete(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
session = session_service.create_session(app_name=app_name, user_id=user_id)
event = Event(
invocation_id='invocation',
author='user',
@@ -304,73 +291,65 @@ async def test_append_event_complete(service_type):
error_message='error_message',
interrupted=True,
)
await session_service.append_event(session=session, event=event)
session_service.append_event(session=session, event=event)
assert (
await session_service.get_session(
session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
== session
)
@pytest.mark.asyncio
@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY])
async def test_get_session_with_config(service_type):
def test_get_session_with_config(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
num_test_events = 5
session = await session_service.create_session(
app_name=app_name, user_id=user_id
)
session = session_service.create_session(app_name=app_name, user_id=user_id)
for i in range(1, num_test_events + 1):
event = Event(author='user', timestamp=i)
await session_service.append_event(session, event)
session_service.append_event(session, event)
# No config, expect all events to be returned.
session = await session_service.get_session(
events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
events = session.events
).events
assert len(events) == num_test_events
# Only expect the most recent 3 events.
num_recent_events = 3
config = GetSessionConfig(num_recent_events=num_recent_events)
session = await session_service.get_session(
events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config
)
events = session.events
).events
assert len(events) == num_recent_events
assert events[0].timestamp == num_test_events - num_recent_events + 1
# Only expect events after timestamp 4.0 (inclusive), i.e., 2 events.
after_timestamp = 4.0
config = GetSessionConfig(after_timestamp=after_timestamp)
session = await session_service.get_session(
events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config
)
events = session.events
).events
assert len(events) == num_test_events - after_timestamp + 1
assert events[0].timestamp == after_timestamp
# Expect no events if none are > after_timestamp.
way_after_timestamp = num_test_events * 10
config = GetSessionConfig(after_timestamp=way_after_timestamp)
session = await session_service.get_session(
events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config
)
assert not session.events
).events
assert len(events) == 0
# Both filters applied, i.e., of 3 most recent events, only 2 are after
# timestamp 4.0, so expect 2 events.
config = GetSessionConfig(
after_timestamp=after_timestamp, num_recent_events=num_recent_events
)
session = await session_service.get_session(
events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id, config=config
)
events = session.events
).events
assert len(events) == num_test_events - after_timestamp + 1