mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
No public description
PiperOrigin-RevId: 748777998
This commit is contained in:
committed by
hangfei
parent
290058eb05
commit
61d4be2d76
@@ -225,3 +225,76 @@ def test_create_new_session_will_merge_states(service_type):
|
||||
assert session_2.state.get('user:key1') == 'value1'
|
||||
assert not session_2.state.get('key1')
|
||||
assert not session_2.state.get('temp:key')
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_append_event_bytes(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
session = session_service.create_session(app_name=app_name, user_id=user_id)
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
content=types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
session_service.append_event(session=session, event=event)
|
||||
|
||||
assert session.events[0].content.parts[0] == types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
)
|
||||
|
||||
events = session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
).events
|
||||
assert len(events) == 1
|
||||
assert events[0].content.parts[0] == types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_append_event_complete(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
session = session_service.create_session(app_name=app_name, user_id=user_id)
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
content=types.Content(role='user', parts=[types.Part(text='test_text')]),
|
||||
turn_complete=True,
|
||||
partial=False,
|
||||
actions=EventActions(
|
||||
artifact_delta={
|
||||
'file': 0,
|
||||
},
|
||||
transfer_to_agent='agent',
|
||||
escalate=True,
|
||||
),
|
||||
long_running_tool_ids={'tool1'},
|
||||
error_code='error_code',
|
||||
error_message='error_message',
|
||||
interrupted=True,
|
||||
)
|
||||
session_service.append_event(session=session, event=event)
|
||||
|
||||
assert (
|
||||
session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
== session
|
||||
)
|
||||
|
||||
@@ -57,7 +57,7 @@ MOCK_EVENT_JSON = [
|
||||
{
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/test_engine/sessions/1/events/123'
|
||||
'reasoningEngines/123/sessions/1/events/123'
|
||||
),
|
||||
'invocationId': '123',
|
||||
'author': 'user',
|
||||
@@ -111,7 +111,7 @@ MOCK_SESSION = Session(
|
||||
|
||||
|
||||
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
|
||||
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions$'
|
||||
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=([^/]+)$'
|
||||
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
|
||||
LRO_REGEX = r'^operations/([^/]+)$'
|
||||
|
||||
@@ -136,39 +136,52 @@ class MockApiClient:
|
||||
else:
|
||||
raise ValueError(f'Session not found: {session_id}')
|
||||
elif re.match(SESSIONS_REGEX, path):
|
||||
match = re.match(SESSIONS_REGEX, path)
|
||||
return {
|
||||
'sessions': self.session_dict.values(),
|
||||
'sessions': [
|
||||
session
|
||||
for session in self.session_dict.values()
|
||||
if session['userId'] == match.group(2)
|
||||
],
|
||||
}
|
||||
elif re.match(EVENTS_REGEX, path):
|
||||
match = re.match(EVENTS_REGEX, path)
|
||||
if match:
|
||||
return {'sessionEvents': self.event_dict[match.group(2)]}
|
||||
return {
|
||||
'sessionEvents': (
|
||||
self.event_dict[match.group(2)]
|
||||
if match.group(2) in self.event_dict
|
||||
else []
|
||||
)
|
||||
}
|
||||
elif re.match(LRO_REGEX, path):
|
||||
return {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/123'
|
||||
'reasoningEngines/123/sessions/4'
|
||||
),
|
||||
'done': True,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f'Unsupported path: {path}')
|
||||
elif http_method == 'POST':
|
||||
id = str(uuid.uuid4())
|
||||
self.session_dict[id] = {
|
||||
new_session_id = '4'
|
||||
self.session_dict[new_session_id] = {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/'
|
||||
+ id
|
||||
+ new_session_id
|
||||
),
|
||||
'userId': request_dict['user_id'],
|
||||
'sessionState': request_dict.get('sessionState', {}),
|
||||
'sessionState': request_dict.get('session_state', {}),
|
||||
'updateTime': '2024-12-12T12:12:12.123456Z',
|
||||
}
|
||||
return {
|
||||
'name': (
|
||||
'projects/test_project/locations/test_location/'
|
||||
'reasoningEngines/test_engine/sessions/123'
|
||||
'reasoningEngines/123/sessions/'
|
||||
+ new_session_id
|
||||
+ '/operations/111'
|
||||
),
|
||||
'done': False,
|
||||
}
|
||||
@@ -223,24 +236,28 @@ def test_get_and_delete_session():
|
||||
)
|
||||
assert str(excinfo.value) == 'Session not found: 1'
|
||||
|
||||
def test_list_sessions():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
sessions = session_service.list_sessions(app_name='123', user_id='user')
|
||||
assert len(sessions.sessions) == 2
|
||||
assert sessions.sessions[0].id == '1'
|
||||
assert sessions.sessions[1].id == '2'
|
||||
|
||||
def test_create_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
session = session_service.create_session(
|
||||
app_name='123', user_id='user', state={'key': 'value'}
|
||||
)
|
||||
assert session.state == {'key': 'value'}
|
||||
assert session.app_name == '123'
|
||||
assert session.user_id == 'user'
|
||||
assert session.last_update_time is not None
|
||||
def test_list_sessions():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
sessions = session_service.list_sessions(app_name='123', user_id='user')
|
||||
assert len(sessions.sessions) == 2
|
||||
assert sessions.sessions[0].id == '1'
|
||||
assert sessions.sessions[1].id == '2'
|
||||
|
||||
session_id = session.id
|
||||
assert session == session_service.get_session(
|
||||
app_name='123', user_id='user', session_id=session_id
|
||||
)
|
||||
|
||||
def test_create_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
|
||||
state = {'key': 'value'}
|
||||
session = session_service.create_session(
|
||||
app_name='123', user_id='user', state=state
|
||||
)
|
||||
assert session.state == state
|
||||
assert session.app_name == '123'
|
||||
assert session.user_id == 'user'
|
||||
assert session.last_update_time is not None
|
||||
|
||||
session_id = session.id
|
||||
assert session == session_service.get_session(
|
||||
app_name='123', user_id='user', session_id=session_id
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user