Moves unittests to root folder and adds github action to run unit tests. (#72)

* Move unit tests to root package.

* Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py

* Adds github workflow

* minor fix in lite_llm.py for python 3.9.

* format pyproject.toml
This commit is contained in:
Jack Sun
2025-04-11 08:25:59 -07:00
committed by GitHub
parent 59117b9b96
commit 05142a07cc
66 changed files with 50 additions and 2 deletions

View File

@@ -0,0 +1,14 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,227 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import pytest
from google.adk.events import Event
from google.adk.events import EventActions
from google.adk.sessions import DatabaseSessionService
from google.adk.sessions import InMemorySessionService
from google.genai import types
class SessionServiceType(enum.Enum):
IN_MEMORY = 'IN_MEMORY'
DATABASE = 'DATABASE'
def get_session_service(
service_type: SessionServiceType = SessionServiceType.IN_MEMORY,
):
"""Creates a session service for testing."""
if service_type == SessionServiceType.DATABASE:
return DatabaseSessionService('sqlite:///:memory:')
return InMemorySessionService()
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_get_empty_session(service_type):
session_service = get_session_service(service_type)
assert not session_service.get_session(
app_name='my_app', user_id='test_user', session_id='123'
)
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_create_get_session(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'test_user'
state = {'key': 'value'}
session = session_service.create_session(
app_name=app_name, user_id=user_id, state=state
)
assert session.app_name == app_name
assert session.user_id == user_id
assert session.id
assert session.state == state
assert (
session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
== session
)
session_id = session.id
session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
assert (
not session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
== session
)
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_create_and_list_sessions(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'test_user'
session_ids = ['session' + str(i) for i in range(5)]
for session_id in session_ids:
session_service.create_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
sessions = session_service.list_sessions(
app_name=app_name, user_id=user_id
).sessions
for i in range(len(sessions)):
assert sessions[i].id == session_ids[i]
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_session_state(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id_1 = 'user1'
user_id_2 = 'user2'
session_id_11 = 'session11'
session_id_12 = 'session12'
session_id_2 = 'session2'
state_11 = {'key11': 'value11'}
state_12 = {'key12': 'value12'}
session_11 = session_service.create_session(
app_name=app_name,
user_id=user_id_1,
state=state_11,
session_id=session_id_11,
)
session_service.create_session(
app_name=app_name,
user_id=user_id_1,
state=state_12,
session_id=session_id_12,
)
session_service.create_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2
)
assert session_11.state.get('key11') == 'value11'
event = Event(
invocation_id='invocation',
author='user',
content=types.Content(role='user', parts=[types.Part(text='text')]),
actions=EventActions(
state_delta={
'app:key': 'value',
'user:key1': 'value1',
'temp:key': 'temp',
'key11': 'value11_new',
}
),
)
session_service.append_event(session=session_11, event=event)
# User and app state is stored, temp state is filtered.
assert session_11.state.get('app:key') == 'value'
assert session_11.state.get('key11') == 'value11_new'
assert session_11.state.get('user:key1') == 'value1'
assert not session_11.state.get('temp:key')
session_12 = session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_12
)
# After getting a new instance, the session_12 got the user and app state,
# even append_event is not applied to it, temp state has no effect
assert session_12.state.get('key12') == 'value12'
assert not session_12.state.get('temp:key')
# The user1's state is not visible to user2, app state is visible
session_2 = session_service.get_session(
app_name=app_name, user_id=user_id_2, session_id=session_id_2
)
assert session_2.state.get('app:key') == 'value'
assert not session_2.state.get('user:key1')
assert not session_2.state.get('user:key1')
# The change to session_11 is persisted
session_11 = session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_id_11
)
assert session_11.state.get('key11') == 'value11_new'
assert session_11.state.get('user:key1') == 'value1'
assert not session_11.state.get('temp:key')
@pytest.mark.parametrize(
"service_type", [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_create_new_session_will_merge_states(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
session_id_1 = 'session1'
session_id_2 = 'session2'
state_1 = {'key1': 'value1'}
session_1 = session_service.create_session(
app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
)
event = Event(
invocation_id='invocation',
author='user',
content=types.Content(role='user', parts=[types.Part(text='text')]),
actions=EventActions(
state_delta={
'app:key': 'value',
'user:key1': 'value1',
'temp:key': 'temp',
}
),
)
session_service.append_event(session=session_1, event=event)
# User and app state is stored, temp state is filtered.
assert session_1.state.get('app:key') == 'value'
assert session_1.state.get('key1') == 'value1'
assert session_1.state.get('user:key1') == 'value1'
assert not session_1.state.get('temp:key')
session_2 = session_service.create_session(
app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
)
# Session 2 has the persisted states
assert session_2.state.get('app:key') == 'value'
assert session_2.state.get('user:key1') == 'value1'
assert not session_2.state.get('key1')
assert not session_2.state.get('temp:key')

View File

@@ -0,0 +1,246 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import this
from typing import Any
import uuid
from dateutil.parser import isoparse
from google.adk.events import Event
from google.adk.events import EventActions
from google.adk.sessions import Session
from google.adk.sessions import VertexAiSessionService
from google.genai import types
import pytest
MOCK_SESSION_JSON_1 = {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/1'
),
'createTime': '2024-12-12T12:12:12.123456Z',
'updateTime': '2024-12-12T12:12:12.123456Z',
'sessionState': {
'key': {'value': 'test_value'},
},
'userId': 'user',
}
MOCK_SESSION_JSON_2 = {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/2'
),
'updateTime': '2024-12-13T12:12:12.123456Z',
'userId': 'user',
}
MOCK_SESSION_JSON_3 = {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/3'
),
'updateTime': '2024-12-14T12:12:12.123456Z',
'userId': 'user2',
}
MOCK_EVENT_JSON = [
{
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/test_engine/sessions/1/events/123'
),
'invocationId': '123',
'author': 'user',
'timestamp': '2024-12-12T12:12:12.123456Z',
'content': {
'parts': [
{'text': 'test_content'},
],
},
'actions': {
'stateDelta': {
'key': {'value': 'test_value'},
},
'transferAgent': 'agent',
},
'eventMetadata': {
'partial': False,
'turnComplete': True,
'interrupted': False,
'branch': '',
'longRunningToolIds': ['tool1'],
},
},
]
MOCK_SESSION = Session(
app_name='123',
user_id='user',
id='1',
state=MOCK_SESSION_JSON_1['sessionState'],
last_update_time=isoparse(MOCK_SESSION_JSON_1['updateTime']).timestamp(),
events=[
Event(
id='123',
invocation_id='123',
author='user',
timestamp=isoparse(MOCK_EVENT_JSON[0]['timestamp']).timestamp(),
content=types.Content(parts=[types.Part(text='test_content')]),
actions=EventActions(
transfer_to_agent='agent',
state_delta={'key': {'value': 'test_value'}},
),
partial=False,
turn_complete=True,
interrupted=False,
branch='',
long_running_tool_ids={'tool1'},
),
],
)
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions$'
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
LRO_REGEX = r'^operations/([^/]+)$'
class MockApiClient:
"""Mocks the API Client."""
def __init__(self) -> None:
"""Initializes MockClient."""
this.session_dict: dict[str, Any] = {}
this.event_dict: dict[str, list[Any]] = {}
def request(self, http_method: str, path: str, request_dict: dict[str, Any]):
"""Mocks the API Client request method."""
if http_method == 'GET':
if re.match(SESSION_REGEX, path):
match = re.match(SESSION_REGEX, path)
if match:
session_id = match.group(2)
if session_id in self.session_dict:
return self.session_dict[session_id]
else:
raise ValueError(f'Session not found: {session_id}')
elif re.match(SESSIONS_REGEX, path):
return {
'sessions': self.session_dict.values(),
}
elif re.match(EVENTS_REGEX, path):
match = re.match(EVENTS_REGEX, path)
if match:
return {'sessionEvents': self.event_dict[match.group(2)]}
elif re.match(LRO_REGEX, path):
return {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/123'
),
'done': True,
}
else:
raise ValueError(f'Unsupported path: {path}')
elif http_method == 'POST':
id = str(uuid.uuid4())
self.session_dict[id] = {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/'
+ id
),
'userId': request_dict['user_id'],
'sessionState': request_dict.get('sessionState', {}),
'updateTime': '2024-12-12T12:12:12.123456Z',
}
return {
'name': (
'projects/test_project/locations/test_location/'
'reasoningEngines/test_engine/sessions/123'
),
'done': False,
}
elif http_method == 'DELETE':
match = re.match(SESSION_REGEX, path)
if match:
self.session_dict.pop(match.group(2))
else:
raise ValueError(f'Unsupported http method: {http_method}')
def mock_vertex_ai_session_service():
"""Creates a mock Vertex AI Session service for testing."""
service = VertexAiSessionService(
project='test-project', location='test-location'
)
service.api_client = MockApiClient()
service.api_client.session_dict = {
'1': MOCK_SESSION_JSON_1,
'2': MOCK_SESSION_JSON_2,
'3': MOCK_SESSION_JSON_3,
}
service.api_client.event_dict = {
'1': MOCK_EVENT_JSON,
}
return service
def test_get_empty_session():
session_service = mock_vertex_ai_session_service()
with pytest.raises(ValueError) as excinfo:
assert session_service.get_session(
app_name='123', user_id='user', session_id='0'
)
assert str(excinfo.value) == 'Session not found: 0'
def test_get_and_delete_session():
session_service = mock_vertex_ai_session_service()
assert (
session_service.get_session(
app_name='123', user_id='user', session_id='1'
)
== MOCK_SESSION
)
session_service.delete_session(app_name='123', user_id='user', session_id='1')
with pytest.raises(ValueError) as excinfo:
assert session_service.get_session(
app_name='123', user_id='user', session_id='1'
)
assert str(excinfo.value) == 'Session not found: 1'
def test_list_sessions():
session_service = mock_vertex_ai_session_service()
sessions = session_service.list_sessions(app_name='123', user_id='user')
assert len(sessions.sessions) == 2
assert sessions.sessions[0].id == '1'
assert sessions.sessions[1].id == '2'
def test_create_session():
session_service = mock_vertex_ai_session_service()
session = session_service.create_session(
app_name='123', user_id='user', state={'key': 'value'}
)
assert session.state == {'key': 'value'}
assert session.app_name == '123'
assert session.user_id == 'user'
assert session.last_update_time is not None
session_id = session.id
assert session == session_service.get_session(
app_name='123', user_id='user', session_id=session_id
)