chore: fix ut for fast api server

PiperOrigin-RevId: 761350248
This commit is contained in:
Xiang (Sean) Zhou 2025-05-20 21:17:03 -07:00 committed by Copybara-Service
parent 98727b4698
commit cbdb5fc507
2 changed files with 374 additions and 169 deletions

View File

@ -61,7 +61,7 @@ from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent from ..agents.llm_agent import Agent
from ..agents.llm_agent import LlmAgent from ..agents.llm_agent import LlmAgent
from ..agents.run_config import StreamingMode from ..agents.run_config import StreamingMode
from ..artifacts import InMemoryArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..evaluation.eval_case import EvalCase from ..evaluation.eval_case import EvalCase
from ..evaluation.eval_case import SessionInput from ..evaluation.eval_case import SessionInput
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager

View File

@ -12,52 +12,62 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
import asyncio import asyncio
import json import logging
import os
import sys import sys
import threading
import time import time
import types as ptypes import types as ptypes
from typing import AsyncGenerator from unittest.mock import MagicMock, patch
from typing import TYPE_CHECKING
from fastapi.testclient import TestClient
from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.live_request_queue import LiveRequest
from google.adk.agents.run_config import RunConfig from google.adk.agents.run_config import RunConfig
from google.adk.cli.fast_api import AgentRunRequest
from google.adk.cli.fast_api import get_fast_api_app from google.adk.cli.fast_api import get_fast_api_app
from google.adk.cli.utils import envs from google.adk.cli.utils import envs
from google.adk.events import Event
from google.adk.runners import Runner from google.adk.runners import Runner
from google.adk.sessions.base_session_service import ListSessionsResponse
from google.genai import types from google.genai import types
import httpx
import pytest import pytest
from uvicorn.main import run as uvicorn_run
import websockets
if TYPE_CHECKING:
from google.adk.events import Event
# Here we “fake” the agent module that get_fast_api_app expects. # Configure logging to help diagnose server startup issues
# The server code does: `agent_module = importlib.import_module(agent_name)` logging.basicConfig(
# and then accesses: agent_module.agent.root_agent. level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Here we create a dummy agent module that get_fast_api_app expects
class DummyAgent(BaseAgent): class DummyAgent(BaseAgent):
pass
def __init__(self, name):
super().__init__(name=name)
self.sub_agents = []
# Set up dummy module and add to sys.modules
dummy_module = ptypes.ModuleType("test_agent") dummy_module = ptypes.ModuleType("test_agent")
dummy_module.agent = ptypes.SimpleNamespace( dummy_module.agent = ptypes.SimpleNamespace(
root_agent=DummyAgent(name="dummy_agent") root_agent=DummyAgent(name="dummy_agent")
) )
sys.modules["test_app"] = dummy_module sys.modules["test_app"] = dummy_module
envs.load_dotenv_for_agent("test_app", ".")
# Try to load environment variables, with a fallback for testing
try:
envs.load_dotenv_for_agent("test_app", ".")
except Exception as e:
logger.warning(f"Could not load environment variables: {e}")
# Create a basic .env file if needed
if not os.path.exists(".env"):
with open(".env", "w") as f:
f.write("# Test environment variables\n")
# Create sample events that our mocked runner will return
def _event_1(): def _event_1():
from google.adk.events import Event
return Event( return Event(
author="dummy agent", author="dummy agent",
invocation_id="invocation_id", invocation_id="invocation_id",
@ -68,8 +78,6 @@ def _event_1():
def _event_2(): def _event_2():
from google.adk.events import Event
return Event( return Event(
author="dummy agent", author="dummy agent",
invocation_id="invocation_id", invocation_id="invocation_id",
@ -88,19 +96,13 @@ def _event_2():
def _event_3(): def _event_3():
from google.adk.events import Event
return Event( return Event(
author="dummy agent", invocation_id="invocation_id", interrupted=True author="dummy agent", invocation_id="invocation_id", interrupted=True
) )
# For simplicity, we patch Runner.run_live to yield dummy events. # Define mocked async generator functions for the Runner
# We use SimpleNamespace to mimic attribute-access (i.e. event.content.parts). async def dummy_run_live(self, session, live_request_queue):
async def dummy_run_live(
self, session, live_request_queue
) -> AsyncGenerator[Event, None]:
# Immediately yield a dummy event with a text reply.
yield _event_1() yield _event_1()
await asyncio.sleep(0) await asyncio.sleep(0)
@ -109,8 +111,6 @@ async def dummy_run_live(
yield _event_3() yield _event_3()
raise Exception()
async def dummy_run_async( async def dummy_run_async(
self, self,
@ -118,8 +118,7 @@ async def dummy_run_async(
session_id, session_id,
new_message, new_message,
run_config: RunConfig = RunConfig(), run_config: RunConfig = RunConfig(),
) -> AsyncGenerator[Event, None]: ):
# Immediately yield a dummy event with a text reply.
yield _event_1() yield _event_1()
await asyncio.sleep(0) await asyncio.sleep(0)
@ -128,159 +127,365 @@ async def dummy_run_async(
yield _event_3() yield _event_3()
return
#################################################
############################################################################### # Test Fixtures
# Pytest fixtures to patch methods and start the server #################################################
###############################################################################
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def patch_runner(monkeypatch): def patch_runner(monkeypatch):
# Patch the Runner methods to use our dummy implementations. """Patch the Runner methods to use our dummy implementations."""
monkeypatch.setattr(Runner, "run_live", dummy_run_live) monkeypatch.setattr(Runner, "run_live", dummy_run_live)
monkeypatch.setattr(Runner, "run_async", dummy_run_async) monkeypatch.setattr(Runner, "run_async", dummy_run_async)
@pytest.fixture(scope="module", autouse=True) @pytest.fixture
def start_server(): def test_session_info():
"""Start the FastAPI server in a background thread.""" """Return test user and session IDs for testing."""
return {
"app_name": "test_app",
"user_id": "test_user",
"session_id": "test_session",
}
def run_server():
uvicorn_run( @pytest.fixture
get_fast_api_app(agent_dir=".", web=True), def mock_session_service():
host="0.0.0.0", """Create a mock session service that uses an in-memory dictionary."""
log_config=None,
# In-memory database to store sessions during testing
session_data = {
"test_app": {
"test_user": {
"test_session": {
"id": "test_session",
"app_name": "test_app",
"user_id": "test_user",
"events": [],
"state": {},
"created_at": time.time(),
}
}
}
}
# Mock session service class that operates on the in-memory database
class MockSessionService:
async def get_session(self, app_name, user_id, session_id):
"""Retrieve a session by ID."""
if (
app_name in session_data
and user_id in session_data[app_name]
and session_id in session_data[app_name][user_id]
):
return session_data[app_name][user_id][session_id]
return None
async def create_session(
self, app_name, user_id, state=None, session_id=None
):
"""Create a new session."""
if session_id is None:
session_id = f"session_{int(time.time())}"
# Initialize app_name and user_id if they don't exist
if app_name not in session_data:
session_data[app_name] = {}
if user_id not in session_data[app_name]:
session_data[app_name][user_id] = {}
# Create the session
session = {
"id": session_id,
"app_name": app_name,
"user_id": user_id,
"events": [],
"state": state or {},
}
session_data[app_name][user_id][session_id] = session
return session
async def list_sessions(self, app_name, user_id):
"""List all sessions for a user."""
if app_name not in session_data or user_id not in session_data[app_name]:
return {"sessions": []}
return ListSessionsResponse(
sessions=list(session_data[app_name][user_id].values())
) )
server_thread = threading.Thread(target=run_server, daemon=True) async def delete_session(self, app_name, user_id, session_id):
server_thread.start() """Delete a session."""
# Wait a moment to ensure the server is up. if (
time.sleep(2) app_name in session_data
yield and user_id in session_data[app_name]
# The daemon thread will be terminated when tests complete. and session_id in session_data[app_name][user_id]
):
del session_data[app_name][user_id][session_id]
# Return an instance of our mock service
return MockSessionService()
@pytest.mark.asyncio @pytest.fixture
async def test_sse_endpoint(): def mock_artifact_service():
base_http_url = "http://127.0.0.1:8000" """Create a mock artifact service."""
user_id = "test_user"
session_id = "test_session"
# Ensure that the session exists (create if necessary). # Storage for artifacts
url_create = ( artifacts = {}
f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}"
class MockArtifactService:
async def load_artifact(
self, app_name, user_id, session_id, filename, version=None
):
"""Load an artifact by filename."""
key = f"{app_name}:{user_id}:{session_id}:{filename}"
if key not in artifacts:
return None
if version is not None:
# Get a specific version
for v in artifacts[key]:
if v["version"] == version:
return v["artifact"]
return None
# Get the latest version
return sorted(artifacts[key], key=lambda x: x["version"])[-1]["artifact"]
async def list_artifact_keys(self, app_name, user_id, session_id):
"""List artifact names for a session."""
prefix = f"{app_name}:{user_id}:{session_id}:"
return [
k.split(":")[-1] for k in artifacts.keys() if k.startswith(prefix)
]
async def list_versions(self, app_name, user_id, session_id, filename):
"""List versions of an artifact."""
key = f"{app_name}:{user_id}:{session_id}:{filename}"
if key not in artifacts:
return []
return [a["version"] for a in artifacts[key]]
async def delete_artifact(self, app_name, user_id, session_id, filename):
"""Delete an artifact."""
key = f"{app_name}:{user_id}:{session_id}:{filename}"
if key in artifacts:
del artifacts[key]
return MockArtifactService()
@pytest.fixture
def mock_memory_service():
"""Create a mock memory service."""
return MagicMock()
@pytest.fixture
def test_app(mock_session_service, mock_artifact_service, mock_memory_service):
"""Create a TestClient for the FastAPI app without starting a server."""
# Patch multiple services and signal handlers
with (
patch("signal.signal", return_value=None),
patch(
"google.adk.cli.fast_api.InMemorySessionService", # Changed this line
return_value=mock_session_service,
),
patch(
"google.adk.cli.fast_api.InMemoryArtifactService", # Make consistent
return_value=mock_artifact_service,
),
patch(
"google.adk.cli.fast_api.InMemoryMemoryService", # Make consistent
return_value=mock_memory_service,
),
):
# Get the FastAPI app, but don't actually run it
app = get_fast_api_app(
agent_dir=".", web=True, session_db_url="", allow_origins=["*"]
) )
httpx.post(url_create, json={"state": {}})
async with httpx.AsyncClient() as client: # Create a TestClient that doesn't start a real server
# Make a POST request to the SSE endpoint. client = TestClient(app)
async with client.stream(
"POST", return client
f"{base_http_url}/run_sse",
json=json.loads(
AgentRunRequest( @pytest.fixture
app_name="test_app", async def create_test_session(
user_id=user_id, test_app, test_session_info, mock_session_service
session_id=session_id, ):
new_message=types.Content( """Create a test session using the mocked session service."""
parts=[types.Part(text="Hello via SSE", inline_data=None)]
), # Create the session directly through the mock service
streaming=False, session = await mock_session_service.create_session(
).model_dump_json(exclude_none=True) app_name=test_session_info["app_name"],
), user_id=test_session_info["user_id"],
) as response: session_id=test_session_info["session_id"],
# Ensure the status code and header are as expected. state={},
)
logger.info(f"Created test session: {session['id']}")
return test_session_info
#################################################
# Test Cases
#################################################
def test_list_apps(test_app):
"""Test listing available applications."""
# Use the TestClient to make a request
response = test_app.get("/list-apps")
# Verify the response
assert response.status_code == 200 assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
logger.info(f"Listed apps: {data}")
def test_create_session_with_id(test_app, test_session_info):
"""Test creating a session with a specific ID."""
new_session_id = "new_session_id"
url = f"/apps/{test_session_info['app_name']}/users/{test_session_info['user_id']}/sessions/{new_session_id}"
response = test_app.post(url, json={"state": {}})
# Verify the response
assert response.status_code == 200
data = response.json()
assert data["id"] == new_session_id
assert data["appName"] == test_session_info["app_name"]
assert data["userId"] == test_session_info["user_id"]
logger.info(f"Created session with ID: {data['id']}")
def test_create_session_without_id(test_app, test_session_info):
"""Test creating a session with a generated ID."""
url = f"/apps/{test_session_info['app_name']}/users/{test_session_info['user_id']}/sessions"
response = test_app.post(url, json={"state": {}})
# Verify the response
assert response.status_code == 200
data = response.json()
assert "id" in data
assert data["appName"] == test_session_info["app_name"]
assert data["userId"] == test_session_info["user_id"]
logger.info(f"Created session with generated ID: {data['id']}")
def test_get_session(test_app, create_test_session):
"""Test retrieving a session by ID."""
info = create_test_session
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}"
response = test_app.get(url)
# Verify the response
assert response.status_code == 200
data = response.json()
assert data["id"] == info["session_id"]
assert data["appName"] == info["app_name"]
assert data["userId"] == info["user_id"]
logger.info(f"Retrieved session: {data['id']}")
def test_list_sessions(test_app, create_test_session):
"""Test listing all sessions for a user."""
info = create_test_session
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions"
response = test_app.get(url)
# Verify the response
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
# At least our test session should be present
assert any(session["id"] == info["session_id"] for session in data)
logger.info(f"Listed {len(data)} sessions")
def test_delete_session(test_app, create_test_session):
"""Test deleting a session."""
info = create_test_session
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}"
response = test_app.delete(url)
# Verify the response
assert response.status_code == 200
# Verify the session is deleted
response = test_app.get(url)
assert response.status_code == 404
logger.info("Session deleted successfully")
def test_agent_run(test_app, create_test_session):
"""Test running an agent with a message."""
info = create_test_session
url = "/run"
payload = {
"app_name": info["app_name"],
"user_id": info["user_id"],
"session_id": info["session_id"],
"new_message": {"role": "user", "parts": [{"text": "Hello agent"}]},
"streaming": False,
}
response = test_app.post(url, json=payload)
# Verify the response
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 3 # We expect 3 events from our dummy_run_async
# Verify we got the expected events
assert data[0]["author"] == "dummy agent"
assert data[0]["content"]["parts"][0]["text"] == "LLM reply"
# Second event should have binary data
assert ( assert (
response.headers.get("content-type") data[1]["content"]["parts"][0]["inlineData"]["mimeType"]
== "text/event-stream; charset=utf-8" == "audio/pcm;rate=24000"
) )
# Iterate over events from the stream. # Third event should have interrupted flag
event_count = 0 assert data[2]["interrupted"] == True
event_buffer = ""
async for line in response.aiter_lines(): logger.info("Agent run test completed successfully")
event_buffer += line + "\n"
# An SSE event is terminated by an empty line (double newline)
if line == "" and event_buffer.strip():
# Process the complete event
event_data = None
for event_line in event_buffer.split("\n"):
if event_line.startswith("data: "):
event_data = event_line[6:] # Remove "data: " prefix
if event_data:
event_count += 1
if event_count == 1:
assert event_data == _event_1().model_dump_json(
exclude_none=True, by_alias=True
)
elif event_count == 2:
assert event_data == _event_2().model_dump_json(
exclude_none=True, by_alias=True
)
elif event_count == 3:
assert event_data == _event_3().model_dump_json(
exclude_none=True, by_alias=True
)
else:
pass
# Reset buffer for next event
event_buffer = ""
assert event_count == 3 # Expecting 3 events from dummy_run_async
@pytest.mark.asyncio def test_list_artifact_names(test_app, create_test_session):
async def test_websocket_endpoint(): """Test listing artifact names for a session."""
base_http_url = "http://127.0.0.1:8000" info = create_test_session
base_ws_url = "ws://127.0.0.1:8000" url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}/artifacts"
user_id = "test_user" response = test_app.get(url)
session_id = "test_session"
# Ensure that the session exists (create if necessary). # Verify the response
url_create = ( assert response.status_code == 200
f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}" data = response.json()
) assert isinstance(data, list)
httpx.post(url_create, json={"state": {}}) logger.info(f"Listed {len(data)} artifacts")
ws_url = f"{base_ws_url}/run_live?app_name=test_app&user_id={user_id}&session_id={session_id}"
async with websockets.connect(ws_url) as ws:
# --- Test sending text data ---
text_payload = LiveRequest(
content=types.Content(
parts=[types.Part(text="Hello via WebSocket", inline_data=None)]
)
)
await ws.send(text_payload.model_dump_json())
# Wait for a reply from our dummy_run_live.
reply = await ws.recv()
event = Event.model_validate_json(reply)
assert event.content.parts[0].text == "LLM reply"
# --- Test sending binary data (allowed mime type "audio/pcm") --- def test_debug_trace(test_app):
sample_audio = b"\x00\xFF" """Test the debug trace endpoint."""
binary_payload = LiveRequest( # This test will likely return 404 since we haven't set up trace data,
blob=types.Blob( # but it tests that the endpoint exists and handles missing traces correctly.
mime_type="audio/pcm", url = "/debug/trace/nonexistent-event"
data=sample_audio, response = test_app.get(url)
)
)
await ws.send(binary_payload.model_dump_json())
# Wait for a reply.
reply = await ws.recv()
event = Event.model_validate_json(reply)
assert (
event.content.parts[0].inline_data.mime_type == "audio/pcm;rate=24000"
)
assert event.content.parts[0].inline_data.data == b"\x00\xFF"
reply = await ws.recv() # Verify we get a 404 for a nonexistent trace
event = Event.model_validate_json(reply) assert response.status_code == 404
assert event.interrupted is True logger.info("Debug trace test completed successfully")
assert event.content is None
if __name__ == "__main__":
pytest.main(["-xvs", __file__])