No public description

PiperOrigin-RevId: 748777998
This commit is contained in:
Google ADK Member
2025-04-17 19:50:22 +00:00
committed by hangfei
parent 290058eb05
commit 61d4be2d76
99 changed files with 2120 additions and 256 deletions

View File

@@ -225,3 +225,76 @@ def test_create_new_session_will_merge_states(service_type):
assert session_2.state.get('user:key1') == 'value1'
assert not session_2.state.get('key1')
assert not session_2.state.get('temp:key')
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_append_event_bytes(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
session = session_service.create_session(app_name=app_name, user_id=user_id)
event = Event(
invocation_id='invocation',
author='user',
content=types.Content(
role='user',
parts=[
types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
),
],
),
)
session_service.append_event(session=session, event=event)
assert session.events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
)
events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
).events
assert len(events) == 1
assert events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
)
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_append_event_complete(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
session = session_service.create_session(app_name=app_name, user_id=user_id)
event = Event(
invocation_id='invocation',
author='user',
content=types.Content(role='user', parts=[types.Part(text='test_text')]),
turn_complete=True,
partial=False,
actions=EventActions(
artifact_delta={
'file': 0,
},
transfer_to_agent='agent',
escalate=True,
),
long_running_tool_ids={'tool1'},
error_code='error_code',
error_message='error_message',
interrupted=True,
)
session_service.append_event(session=session, event=event)
assert (
session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
== session
)

View File

@@ -57,7 +57,7 @@ MOCK_EVENT_JSON = [
{
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/test_engine/sessions/1/events/123'
'reasoningEngines/123/sessions/1/events/123'
),
'invocationId': '123',
'author': 'user',
@@ -111,7 +111,7 @@ MOCK_SESSION = Session(
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions$'
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=([^/]+)$'
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
LRO_REGEX = r'^operations/([^/]+)$'
@@ -136,39 +136,52 @@ class MockApiClient:
else:
raise ValueError(f'Session not found: {session_id}')
elif re.match(SESSIONS_REGEX, path):
match = re.match(SESSIONS_REGEX, path)
return {
'sessions': self.session_dict.values(),
'sessions': [
session
for session in self.session_dict.values()
if session['userId'] == match.group(2)
],
}
elif re.match(EVENTS_REGEX, path):
match = re.match(EVENTS_REGEX, path)
if match:
return {'sessionEvents': self.event_dict[match.group(2)]}
return {
'sessionEvents': (
self.event_dict[match.group(2)]
if match.group(2) in self.event_dict
else []
)
}
elif re.match(LRO_REGEX, path):
return {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/123'
'reasoningEngines/123/sessions/4'
),
'done': True,
}
else:
raise ValueError(f'Unsupported path: {path}')
elif http_method == 'POST':
id = str(uuid.uuid4())
self.session_dict[id] = {
new_session_id = '4'
self.session_dict[new_session_id] = {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/'
+ id
+ new_session_id
),
'userId': request_dict['user_id'],
'sessionState': request_dict.get('sessionState', {}),
'sessionState': request_dict.get('session_state', {}),
'updateTime': '2024-12-12T12:12:12.123456Z',
}
return {
'name': (
'projects/test_project/locations/test_location/'
'reasoningEngines/test_engine/sessions/123'
'reasoningEngines/123/sessions/'
+ new_session_id
+ '/operations/111'
),
'done': False,
}
@@ -223,24 +236,28 @@ def test_get_and_delete_session():
)
assert str(excinfo.value) == 'Session not found: 1'
def test_list_sessions():
session_service = mock_vertex_ai_session_service()
sessions = session_service.list_sessions(app_name='123', user_id='user')
assert len(sessions.sessions) == 2
assert sessions.sessions[0].id == '1'
assert sessions.sessions[1].id == '2'
def test_create_session():
session_service = mock_vertex_ai_session_service()
session = session_service.create_session(
app_name='123', user_id='user', state={'key': 'value'}
)
assert session.state == {'key': 'value'}
assert session.app_name == '123'
assert session.user_id == 'user'
assert session.last_update_time is not None
def test_list_sessions():
session_service = mock_vertex_ai_session_service()
sessions = session_service.list_sessions(app_name='123', user_id='user')
assert len(sessions.sessions) == 2
assert sessions.sessions[0].id == '1'
assert sessions.sessions[1].id == '2'
session_id = session.id
assert session == session_service.get_session(
app_name='123', user_id='user', session_id=session_id
)
def test_create_session():
session_service = mock_vertex_ai_session_service()
state = {'key': 'value'}
session = session_service.create_session(
app_name='123', user_id='user', state=state
)
assert session.state == state
assert session.app_name == '123'
assert session.user_id == 'user'
assert session.last_update_time is not None
session_id = session.id
assert session == session_service.get_session(
app_name='123', user_id='user', session_id=session_id
)