From cbdb5fc507dd941acb08a99440436d0f21463041 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 20 May 2025 21:17:03 -0700 Subject: [PATCH] chore: fix ut for fast api server PiperOrigin-RevId: 761350248 --- src/google/adk/cli/fast_api.py | 2 +- tests/unittests/fast_api/test_fast_api.py | 541 +++++++++++++++------- 2 files changed, 374 insertions(+), 169 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 8729d6b..9ba608e 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -61,7 +61,7 @@ from ..agents.live_request_queue import LiveRequestQueue from ..agents.llm_agent import Agent from ..agents.llm_agent import LlmAgent from ..agents.run_config import StreamingMode -from ..artifacts import InMemoryArtifactService +from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..evaluation.eval_case import EvalCase from ..evaluation.eval_case import SessionInput from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager diff --git a/tests/unittests/fast_api/test_fast_api.py b/tests/unittests/fast_api/test_fast_api.py index 62c7e79..b285908 100644 --- a/tests/unittests/fast_api/test_fast_api.py +++ b/tests/unittests/fast_api/test_fast_api.py @@ -12,52 +12,62 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - import asyncio -import json +import logging +import os import sys -import threading import time import types as ptypes -from typing import AsyncGenerator -from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch +from fastapi.testclient import TestClient from google.adk.agents.base_agent import BaseAgent -from google.adk.agents.live_request_queue import LiveRequest from google.adk.agents.run_config import RunConfig -from google.adk.cli.fast_api import AgentRunRequest from google.adk.cli.fast_api import get_fast_api_app from google.adk.cli.utils import envs +from google.adk.events import Event from google.adk.runners import Runner +from google.adk.sessions.base_session_service import ListSessionsResponse from google.genai import types -import httpx import pytest -from uvicorn.main import run as uvicorn_run -import websockets - -if TYPE_CHECKING: - from google.adk.events import Event -# Here we “fake” the agent module that get_fast_api_app expects. -# The server code does: `agent_module = importlib.import_module(agent_name)` -# and then accesses: agent_module.agent.root_agent. +# Configure logging to help diagnose server startup issues +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +# Here we create a dummy agent module that get_fast_api_app expects class DummyAgent(BaseAgent): - pass + + def __init__(self, name): + super().__init__(name=name) + self.sub_agents = [] +# Set up dummy module and add to sys.modules dummy_module = ptypes.ModuleType("test_agent") dummy_module.agent = ptypes.SimpleNamespace( root_agent=DummyAgent(name="dummy_agent") ) sys.modules["test_app"] = dummy_module -envs.load_dotenv_for_agent("test_app", ".") + +# Try to load environment variables, with a fallback for testing +try: + envs.load_dotenv_for_agent("test_app", ".") +except Exception as e: + logger.warning(f"Could not load environment variables: {e}") + # Create a basic .env file if needed + if not os.path.exists(".env"): + with open(".env", "w") as f: + f.write("# Test environment variables\n") +# Create sample events that our mocked runner will return def _event_1(): - from google.adk.events import Event - return Event( author="dummy agent", invocation_id="invocation_id", @@ -68,8 +78,6 @@ def _event_1(): def _event_2(): - from google.adk.events import Event - return Event( author="dummy agent", invocation_id="invocation_id", @@ -88,19 +96,13 @@ def _event_2(): def _event_3(): - from google.adk.events import Event - return Event( author="dummy agent", invocation_id="invocation_id", interrupted=True ) -# For simplicity, we patch Runner.run_live to yield dummy events. -# We use SimpleNamespace to mimic attribute-access (i.e. event.content.parts). -async def dummy_run_live( - self, session, live_request_queue -) -> AsyncGenerator[Event, None]: - # Immediately yield a dummy event with a text reply. +# Define mocked async generator functions for the Runner +async def dummy_run_live(self, session, live_request_queue): yield _event_1() await asyncio.sleep(0) @@ -109,8 +111,6 @@ async def dummy_run_live( yield _event_3() - raise Exception() - async def dummy_run_async( self, @@ -118,8 +118,7 @@ async def dummy_run_async( session_id, new_message, run_config: RunConfig = RunConfig(), -) -> AsyncGenerator[Event, None]: - # Immediately yield a dummy event with a text reply. +): yield _event_1() await asyncio.sleep(0) @@ -128,159 +127,365 @@ async def dummy_run_async( yield _event_3() - return - -############################################################################### -# Pytest fixtures to patch methods and start the server -############################################################################### +################################################# +# Test Fixtures +################################################# @pytest.fixture(autouse=True) def patch_runner(monkeypatch): - # Patch the Runner methods to use our dummy implementations. + """Patch the Runner methods to use our dummy implementations.""" monkeypatch.setattr(Runner, "run_live", dummy_run_live) monkeypatch.setattr(Runner, "run_async", dummy_run_async) -@pytest.fixture(scope="module", autouse=True) -def start_server(): - """Start the FastAPI server in a background thread.""" - - def run_server(): - uvicorn_run( - get_fast_api_app(agent_dir=".", web=True), - host="0.0.0.0", - log_config=None, - ) - - server_thread = threading.Thread(target=run_server, daemon=True) - server_thread.start() - # Wait a moment to ensure the server is up. - time.sleep(2) - yield - # The daemon thread will be terminated when tests complete. +@pytest.fixture +def test_session_info(): + """Return test user and session IDs for testing.""" + return { + "app_name": "test_app", + "user_id": "test_user", + "session_id": "test_session", + } -@pytest.mark.asyncio -async def test_sse_endpoint(): - base_http_url = "http://127.0.0.1:8000" - user_id = "test_user" - session_id = "test_session" +@pytest.fixture +def mock_session_service(): + """Create a mock session service that uses an in-memory dictionary.""" - # Ensure that the session exists (create if necessary). - url_create = ( - f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}" - ) - httpx.post(url_create, json={"state": {}}) + # In-memory database to store sessions during testing + session_data = { + "test_app": { + "test_user": { + "test_session": { + "id": "test_session", + "app_name": "test_app", + "user_id": "test_user", + "events": [], + "state": {}, + "created_at": time.time(), + } + } + } + } - async with httpx.AsyncClient() as client: - # Make a POST request to the SSE endpoint. - async with client.stream( - "POST", - f"{base_http_url}/run_sse", - json=json.loads( - AgentRunRequest( - app_name="test_app", - user_id=user_id, - session_id=session_id, - new_message=types.Content( - parts=[types.Part(text="Hello via SSE", inline_data=None)] - ), - streaming=False, - ).model_dump_json(exclude_none=True) - ), - ) as response: - # Ensure the status code and header are as expected. - assert response.status_code == 200 - assert ( - response.headers.get("content-type") - == "text/event-stream; charset=utf-8" + # Mock session service class that operates on the in-memory database + class MockSessionService: + + async def get_session(self, app_name, user_id, session_id): + """Retrieve a session by ID.""" + if ( + app_name in session_data + and user_id in session_data[app_name] + and session_id in session_data[app_name][user_id] + ): + return session_data[app_name][user_id][session_id] + return None + + async def create_session( + self, app_name, user_id, state=None, session_id=None + ): + """Create a new session.""" + if session_id is None: + session_id = f"session_{int(time.time())}" + + # Initialize app_name and user_id if they don't exist + if app_name not in session_data: + session_data[app_name] = {} + if user_id not in session_data[app_name]: + session_data[app_name][user_id] = {} + + # Create the session + session = { + "id": session_id, + "app_name": app_name, + "user_id": user_id, + "events": [], + "state": state or {}, + } + + session_data[app_name][user_id][session_id] = session + return session + + async def list_sessions(self, app_name, user_id): + """List all sessions for a user.""" + if app_name not in session_data or user_id not in session_data[app_name]: + return {"sessions": []} + + return ListSessionsResponse( + sessions=list(session_data[app_name][user_id].values()) ) - # Iterate over events from the stream. - event_count = 0 - event_buffer = "" + async def delete_session(self, app_name, user_id, session_id): + """Delete a session.""" + if ( + app_name in session_data + and user_id in session_data[app_name] + and session_id in session_data[app_name][user_id] + ): + del session_data[app_name][user_id][session_id] - async for line in response.aiter_lines(): - event_buffer += line + "\n" - - # An SSE event is terminated by an empty line (double newline) - if line == "" and event_buffer.strip(): - # Process the complete event - event_data = None - for event_line in event_buffer.split("\n"): - if event_line.startswith("data: "): - event_data = event_line[6:] # Remove "data: " prefix - - if event_data: - event_count += 1 - if event_count == 1: - assert event_data == _event_1().model_dump_json( - exclude_none=True, by_alias=True - ) - elif event_count == 2: - assert event_data == _event_2().model_dump_json( - exclude_none=True, by_alias=True - ) - elif event_count == 3: - assert event_data == _event_3().model_dump_json( - exclude_none=True, by_alias=True - ) - else: - pass - - # Reset buffer for next event - event_buffer = "" - - assert event_count == 3 # Expecting 3 events from dummy_run_async + # Return an instance of our mock service + return MockSessionService() -@pytest.mark.asyncio -async def test_websocket_endpoint(): - base_http_url = "http://127.0.0.1:8000" - base_ws_url = "ws://127.0.0.1:8000" - user_id = "test_user" - session_id = "test_session" +@pytest.fixture +def mock_artifact_service(): + """Create a mock artifact service.""" - # Ensure that the session exists (create if necessary). - url_create = ( - f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}" + # Storage for artifacts + artifacts = {} + + class MockArtifactService: + + async def load_artifact( + self, app_name, user_id, session_id, filename, version=None + ): + """Load an artifact by filename.""" + key = f"{app_name}:{user_id}:{session_id}:{filename}" + if key not in artifacts: + return None + + if version is not None: + # Get a specific version + for v in artifacts[key]: + if v["version"] == version: + return v["artifact"] + return None + + # Get the latest version + return sorted(artifacts[key], key=lambda x: x["version"])[-1]["artifact"] + + async def list_artifact_keys(self, app_name, user_id, session_id): + """List artifact names for a session.""" + prefix = f"{app_name}:{user_id}:{session_id}:" + return [ + k.split(":")[-1] for k in artifacts.keys() if k.startswith(prefix) + ] + + async def list_versions(self, app_name, user_id, session_id, filename): + """List versions of an artifact.""" + key = f"{app_name}:{user_id}:{session_id}:{filename}" + if key not in artifacts: + return [] + return [a["version"] for a in artifacts[key]] + + async def delete_artifact(self, app_name, user_id, session_id, filename): + """Delete an artifact.""" + key = f"{app_name}:{user_id}:{session_id}:{filename}" + if key in artifacts: + del artifacts[key] + + return MockArtifactService() + + +@pytest.fixture +def mock_memory_service(): + """Create a mock memory service.""" + return MagicMock() + + +@pytest.fixture +def test_app(mock_session_service, mock_artifact_service, mock_memory_service): + """Create a TestClient for the FastAPI app without starting a server.""" + + # Patch multiple services and signal handlers + with ( + patch("signal.signal", return_value=None), + patch( + "google.adk.cli.fast_api.InMemorySessionService", # Changed this line + return_value=mock_session_service, + ), + patch( + "google.adk.cli.fast_api.InMemoryArtifactService", # Make consistent + return_value=mock_artifact_service, + ), + patch( + "google.adk.cli.fast_api.InMemoryMemoryService", # Make consistent + return_value=mock_memory_service, + ), + ): + # Get the FastAPI app, but don't actually run it + app = get_fast_api_app( + agent_dir=".", web=True, session_db_url="", allow_origins=["*"] + ) + + # Create a TestClient that doesn't start a real server + client = TestClient(app) + + return client + + +@pytest.fixture +async def create_test_session( + test_app, test_session_info, mock_session_service +): + """Create a test session using the mocked session service.""" + + # Create the session directly through the mock service + session = await mock_session_service.create_session( + app_name=test_session_info["app_name"], + user_id=test_session_info["user_id"], + session_id=test_session_info["session_id"], + state={}, ) - httpx.post(url_create, json={"state": {}}) - ws_url = f"{base_ws_url}/run_live?app_name=test_app&user_id={user_id}&session_id={session_id}" - async with websockets.connect(ws_url) as ws: - # --- Test sending text data --- - text_payload = LiveRequest( - content=types.Content( - parts=[types.Part(text="Hello via WebSocket", inline_data=None)] - ) - ) - await ws.send(text_payload.model_dump_json()) - # Wait for a reply from our dummy_run_live. - reply = await ws.recv() - event = Event.model_validate_json(reply) - assert event.content.parts[0].text == "LLM reply" + logger.info(f"Created test session: {session['id']}") + return test_session_info - # --- Test sending binary data (allowed mime type "audio/pcm") --- - sample_audio = b"\x00\xFF" - binary_payload = LiveRequest( - blob=types.Blob( - mime_type="audio/pcm", - data=sample_audio, - ) - ) - await ws.send(binary_payload.model_dump_json()) - # Wait for a reply. - reply = await ws.recv() - event = Event.model_validate_json(reply) - assert ( - event.content.parts[0].inline_data.mime_type == "audio/pcm;rate=24000" - ) - assert event.content.parts[0].inline_data.data == b"\x00\xFF" - reply = await ws.recv() - event = Event.model_validate_json(reply) - assert event.interrupted is True - assert event.content is None +################################################# +# Test Cases +################################################# + + +def test_list_apps(test_app): + """Test listing available applications.""" + # Use the TestClient to make a request + response = test_app.get("/list-apps") + + # Verify the response + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + logger.info(f"Listed apps: {data}") + + +def test_create_session_with_id(test_app, test_session_info): + """Test creating a session with a specific ID.""" + new_session_id = "new_session_id" + url = f"/apps/{test_session_info['app_name']}/users/{test_session_info['user_id']}/sessions/{new_session_id}" + response = test_app.post(url, json={"state": {}}) + + # Verify the response + assert response.status_code == 200 + data = response.json() + assert data["id"] == new_session_id + assert data["appName"] == test_session_info["app_name"] + assert data["userId"] == test_session_info["user_id"] + logger.info(f"Created session with ID: {data['id']}") + + +def test_create_session_without_id(test_app, test_session_info): + """Test creating a session with a generated ID.""" + url = f"/apps/{test_session_info['app_name']}/users/{test_session_info['user_id']}/sessions" + response = test_app.post(url, json={"state": {}}) + + # Verify the response + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert data["appName"] == test_session_info["app_name"] + assert data["userId"] == test_session_info["user_id"] + logger.info(f"Created session with generated ID: {data['id']}") + + +def test_get_session(test_app, create_test_session): + """Test retrieving a session by ID.""" + info = create_test_session + url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}" + response = test_app.get(url) + + # Verify the response + assert response.status_code == 200 + data = response.json() + assert data["id"] == info["session_id"] + assert data["appName"] == info["app_name"] + assert data["userId"] == info["user_id"] + logger.info(f"Retrieved session: {data['id']}") + + +def test_list_sessions(test_app, create_test_session): + """Test listing all sessions for a user.""" + info = create_test_session + url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions" + response = test_app.get(url) + + # Verify the response + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + # At least our test session should be present + assert any(session["id"] == info["session_id"] for session in data) + logger.info(f"Listed {len(data)} sessions") + + +def test_delete_session(test_app, create_test_session): + """Test deleting a session.""" + info = create_test_session + url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}" + response = test_app.delete(url) + + # Verify the response + assert response.status_code == 200 + + # Verify the session is deleted + response = test_app.get(url) + assert response.status_code == 404 + logger.info("Session deleted successfully") + + +def test_agent_run(test_app, create_test_session): + """Test running an agent with a message.""" + info = create_test_session + url = "/run" + payload = { + "app_name": info["app_name"], + "user_id": info["user_id"], + "session_id": info["session_id"], + "new_message": {"role": "user", "parts": [{"text": "Hello agent"}]}, + "streaming": False, + } + + response = test_app.post(url, json=payload) + + # Verify the response + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) == 3 # We expect 3 events from our dummy_run_async + + # Verify we got the expected events + assert data[0]["author"] == "dummy agent" + assert data[0]["content"]["parts"][0]["text"] == "LLM reply" + + # Second event should have binary data + assert ( + data[1]["content"]["parts"][0]["inlineData"]["mimeType"] + == "audio/pcm;rate=24000" + ) + + # Third event should have interrupted flag + assert data[2]["interrupted"] == True + + logger.info("Agent run test completed successfully") + + +def test_list_artifact_names(test_app, create_test_session): + """Test listing artifact names for a session.""" + info = create_test_session + url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}/artifacts" + response = test_app.get(url) + + # Verify the response + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + logger.info(f"Listed {len(data)} artifacts") + + +def test_debug_trace(test_app): + """Test the debug trace endpoint.""" + # This test will likely return 404 since we haven't set up trace data, + # but it tests that the endpoint exists and handles missing traces correctly. + url = "/debug/trace/nonexistent-event" + response = test_app.get(url) + + # Verify we get a 404 for a nonexistent trace + assert response.status_code == 404 + logger.info("Debug trace test completed successfully") + + +if __name__ == "__main__": + pytest.main(["-xvs", __file__])