mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
ADK changes
PiperOrigin-RevId: 759259620
This commit is contained in:
committed by
Copybara-Service
parent
1804ca39a6
commit
05917cabbd
@@ -37,26 +37,28 @@ def get_session_service(
|
||||
return InMemorySessionService()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_get_empty_session(service_type):
|
||||
async def test_get_empty_session(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
assert not session_service.get_session(
|
||||
assert not await session_service.get_session(
|
||||
app_name='my_app', user_id='test_user', session_id='123'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_create_get_session(service_type):
|
||||
async def test_create_get_session(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'test_user'
|
||||
state = {'key': 'value'}
|
||||
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state
|
||||
)
|
||||
assert session.app_name == app_name
|
||||
@@ -64,50 +66,53 @@ def test_create_get_session(service_type):
|
||||
assert session.id
|
||||
assert session.state == state
|
||||
assert (
|
||||
session_service.get_session(
|
||||
await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
== session
|
||||
)
|
||||
|
||||
session_id = session.id
|
||||
session_service.delete_session(
|
||||
await session_service.delete_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
assert (
|
||||
not session_service.get_session(
|
||||
await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
== session
|
||||
!= session
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_create_and_list_sessions(service_type):
|
||||
async def test_create_and_list_sessions(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'test_user'
|
||||
|
||||
session_ids = ['session' + str(i) for i in range(5)]
|
||||
for session_id in session_ids:
|
||||
session_service.create_session(
|
||||
await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
sessions = session_service.list_sessions(
|
||||
list_sessions_response = await session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
).sessions
|
||||
)
|
||||
sessions = list_sessions_response.sessions
|
||||
for i in range(len(sessions)):
|
||||
assert sessions[i].id == session_ids[i]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_session_state(service_type):
|
||||
async def test_session_state(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id_1 = 'user1'
|
||||
@@ -118,19 +123,19 @@ def test_session_state(service_type):
|
||||
state_11 = {'key11': 'value11'}
|
||||
state_12 = {'key12': 'value12'}
|
||||
|
||||
session_11 = session_service.create_session(
|
||||
session_11 = await session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id_1,
|
||||
state=state_11,
|
||||
session_id=session_id_11,
|
||||
)
|
||||
session_service.create_session(
|
||||
await session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id_1,
|
||||
state=state_12,
|
||||
session_id=session_id_12,
|
||||
)
|
||||
session_service.create_session(
|
||||
await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id_2, session_id=session_id_2
|
||||
)
|
||||
|
||||
@@ -149,7 +154,7 @@ def test_session_state(service_type):
|
||||
}
|
||||
),
|
||||
)
|
||||
session_service.append_event(session=session_11, event=event)
|
||||
await session_service.append_event(session=session_11, event=event)
|
||||
|
||||
# User and app state is stored, temp state is filtered.
|
||||
assert session_11.state.get('app:key') == 'value'
|
||||
@@ -157,7 +162,7 @@ def test_session_state(service_type):
|
||||
assert session_11.state.get('user:key1') == 'value1'
|
||||
assert not session_11.state.get('temp:key')
|
||||
|
||||
session_12 = session_service.get_session(
|
||||
session_12 = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_1, session_id=session_id_12
|
||||
)
|
||||
# After getting a new instance, the session_12 got the user and app state,
|
||||
@@ -166,7 +171,7 @@ def test_session_state(service_type):
|
||||
assert not session_12.state.get('temp:key')
|
||||
|
||||
# The user1's state is not visible to user2, app state is visible
|
||||
session_2 = session_service.get_session(
|
||||
session_2 = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_2, session_id=session_id_2
|
||||
)
|
||||
assert session_2.state.get('app:key') == 'value'
|
||||
@@ -175,7 +180,7 @@ def test_session_state(service_type):
|
||||
assert not session_2.state.get('user:key1')
|
||||
|
||||
# The change to session_11 is persisted
|
||||
session_11 = session_service.get_session(
|
||||
session_11 = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_1, session_id=session_id_11
|
||||
)
|
||||
assert session_11.state.get('key11') == 'value11_new'
|
||||
@@ -183,10 +188,11 @@ def test_session_state(service_type):
|
||||
assert not session_11.state.get('temp:key')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_create_new_session_will_merge_states(service_type):
|
||||
async def test_create_new_session_will_merge_states(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
@@ -194,7 +200,7 @@ def test_create_new_session_will_merge_states(service_type):
|
||||
session_id_2 = 'session2'
|
||||
state_1 = {'key1': 'value1'}
|
||||
|
||||
session_1 = session_service.create_session(
|
||||
session_1 = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
|
||||
)
|
||||
|
||||
@@ -210,7 +216,7 @@ def test_create_new_session_will_merge_states(service_type):
|
||||
}
|
||||
),
|
||||
)
|
||||
session_service.append_event(session=session_1, event=event)
|
||||
await session_service.append_event(session=session_1, event=event)
|
||||
|
||||
# User and app state is stored, temp state is filtered.
|
||||
assert session_1.state.get('app:key') == 'value'
|
||||
@@ -218,7 +224,7 @@ def test_create_new_session_will_merge_states(service_type):
|
||||
assert session_1.state.get('user:key1') == 'value1'
|
||||
assert not session_1.state.get('temp:key')
|
||||
|
||||
session_2 = session_service.create_session(
|
||||
session_2 = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
|
||||
)
|
||||
# Session 2 has the persisted states
|
||||
@@ -228,15 +234,18 @@ def test_create_new_session_will_merge_states(service_type):
|
||||
assert not session_2.state.get('temp:key')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_append_event_bytes(service_type):
|
||||
async def test_append_event_bytes(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
session = session_service.create_session(app_name=app_name, user_id=user_id)
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
@@ -249,30 +258,34 @@ def test_append_event_bytes(service_type):
|
||||
],
|
||||
),
|
||||
)
|
||||
session_service.append_event(session=session, event=event)
|
||||
await session_service.append_event(session=session, event=event)
|
||||
|
||||
assert session.events[0].content.parts[0] == types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
)
|
||||
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
).events
|
||||
)
|
||||
events = session.events
|
||||
assert len(events) == 1
|
||||
assert events[0].content.parts[0] == types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_append_event_complete(service_type):
|
||||
async def test_append_event_complete(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
session = session_service.create_session(app_name=app_name, user_id=user_id)
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
@@ -291,65 +304,73 @@ def test_append_event_complete(service_type):
|
||||
error_message='error_message',
|
||||
interrupted=True,
|
||||
)
|
||||
session_service.append_event(session=session, event=event)
|
||||
await session_service.append_event(session=session, event=event)
|
||||
|
||||
assert (
|
||||
session_service.get_session(
|
||||
await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
== session
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY])
|
||||
def test_get_session_with_config(service_type):
|
||||
async def test_get_session_with_config(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
num_test_events = 5
|
||||
session = session_service.create_session(app_name=app_name, user_id=user_id)
|
||||
session = await session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
for i in range(1, num_test_events + 1):
|
||||
event = Event(author='user', timestamp=i)
|
||||
session_service.append_event(session, event)
|
||||
await session_service.append_event(session, event)
|
||||
|
||||
# No config, expect all events to be returned.
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
).events
|
||||
)
|
||||
events = session.events
|
||||
assert len(events) == num_test_events
|
||||
|
||||
# Only expect the most recent 3 events.
|
||||
num_recent_events = 3
|
||||
config = GetSessionConfig(num_recent_events=num_recent_events)
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id, config=config
|
||||
).events
|
||||
)
|
||||
events = session.events
|
||||
assert len(events) == num_recent_events
|
||||
assert events[0].timestamp == num_test_events - num_recent_events + 1
|
||||
|
||||
# Only expect events after timestamp 4.0 (inclusive), i.e., 2 events.
|
||||
after_timestamp = 4.0
|
||||
config = GetSessionConfig(after_timestamp=after_timestamp)
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id, config=config
|
||||
).events
|
||||
)
|
||||
events = session.events
|
||||
assert len(events) == num_test_events - after_timestamp + 1
|
||||
assert events[0].timestamp == after_timestamp
|
||||
|
||||
# Expect no events if none are > after_timestamp.
|
||||
way_after_timestamp = num_test_events * 10
|
||||
config = GetSessionConfig(after_timestamp=way_after_timestamp)
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id, config=config
|
||||
).events
|
||||
assert len(events) == 0
|
||||
)
|
||||
assert not session.events
|
||||
|
||||
# Both filters applied, i.e., of 3 most recent events, only 2 are after
|
||||
# timestamp 4.0, so expect 2 events.
|
||||
config = GetSessionConfig(
|
||||
after_timestamp=after_timestamp, num_recent_events=num_recent_events
|
||||
)
|
||||
events = session_service.get_session(
|
||||
session = await session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id, config=config
|
||||
).events
|
||||
)
|
||||
events = session.events
|
||||
assert len(events) == num_test_events - after_timestamp + 1
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import re
|
||||
import this
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from google.adk.events import Event
|
||||
from google.adk.events import EventActions
|
||||
@@ -124,7 +124,9 @@ class MockApiClient:
|
||||
this.session_dict: dict[str, Any] = {}
|
||||
this.event_dict: dict[str, list[Any]] = {}
|
||||
|
||||
def request(self, http_method: str, path: str, request_dict: dict[str, Any]):
|
||||
async def async_request(
|
||||
self, http_method: str, path: str, request_dict: dict[str, Any]
|
||||
):
|
||||
"""Mocks the API Client request method."""
|
||||
if http_method == 'GET':
|
||||
if re.match(SESSION_REGEX, path):
|
||||
@@ -210,46 +212,52 @@ def mock_vertex_ai_session_service():
|
||||
return service
|
||||
|
||||
|
||||
def test_get_empty_session():
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_empty_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
assert session_service.get_session(
|
||||
assert await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='0'
|
||||
)
|
||||
assert str(excinfo.value) == 'Session not found: 0'
|
||||
|
||||
|
||||
def test_get_and_delete_session():
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_and_delete_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
|
||||
assert (
|
||||
session_service.get_session(
|
||||
await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
== MOCK_SESSION
|
||||
)
|
||||
|
||||
session_service.delete_session(app_name='123', user_id='user', session_id='1')
|
||||
await session_service.delete_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
assert session_service.get_session(
|
||||
assert await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
assert str(excinfo.value) == 'Session not found: 1'
|
||||
|
||||
|
||||
def test_list_sessions():
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
sessions = session_service.list_sessions(app_name='123', user_id='user')
|
||||
sessions = await session_service.list_sessions(app_name='123', user_id='user')
|
||||
assert len(sessions.sessions) == 2
|
||||
assert sessions.sessions[0].id == '1'
|
||||
assert sessions.sessions[1].id == '2'
|
||||
|
||||
|
||||
def test_create_session():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
|
||||
state = {'key': 'value'}
|
||||
session = session_service.create_session(
|
||||
session = await session_service.create_session(
|
||||
app_name='123', user_id='user', state=state
|
||||
)
|
||||
assert session.state == state
|
||||
@@ -258,16 +266,17 @@ def test_create_session():
|
||||
assert session.last_update_time is not None
|
||||
|
||||
session_id = session.id
|
||||
assert session == session_service.get_session(
|
||||
assert session == await session_service.get_session(
|
||||
app_name='123', user_id='user', session_id=session_id
|
||||
)
|
||||
|
||||
|
||||
def test_create_session_with_custom_session_id():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_with_custom_session_id():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
session_service.create_session(
|
||||
await session_service.create_session(
|
||||
app_name='123', user_id='user', session_id='1'
|
||||
)
|
||||
assert str(excinfo.value) == (
|
||||
|
||||
Reference in New Issue
Block a user