adk-python/tests/unittests/fast_api/test_fast_api.py
Wei Sun (Jack) 9324801b75 chore: Fixes test_fast_api.py (part I for circular deps).
It still fails due to signal used not in main thread. It will be fixed later.

PiperOrigin-RevId: 760050504
2025-05-17 12:22:04 -07:00

287 lines
8.2 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import json
import sys
import threading
import time
import types as ptypes
from typing import AsyncGenerator
from typing import TYPE_CHECKING
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.live_request_queue import LiveRequest
from google.adk.agents.run_config import RunConfig
from google.adk.cli.fast_api import AgentRunRequest
from google.adk.cli.fast_api import get_fast_api_app
from google.adk.cli.utils import envs
from google.adk.runners import Runner
from google.genai import types
import httpx
import pytest
from uvicorn.main import run as uvicorn_run
import websockets
if TYPE_CHECKING:
from google.adk.events import Event
# Here we "fake" the agent module that get_fast_api_app expects.
# The server code does `agent_module = importlib.import_module(agent_name)`
# and then accesses `agent_module.agent.root_agent`.
class DummyAgent(BaseAgent):
    """Bare ``BaseAgent`` subclass used as the root agent of the fake app."""
# Build a stub module and register it in sys.modules under the app name the
# tests use ("test_app"), so importing that name resolves to this stub rather
# than to a package on disk. The module is created with the same name as its
# sys.modules key so that its __name__ agrees with how it is looked up.
dummy_module = ptypes.ModuleType("test_app")
# The stub only needs an `agent` attribute that exposes `root_agent`.
dummy_module.agent = ptypes.SimpleNamespace(
    root_agent=DummyAgent(name="dummy_agent")
)
sys.modules["test_app"] = dummy_module
# NOTE(review): presumably loads a .env file for the app relative to the
# current directory -- confirm against envs.load_dotenv_for_agent.
envs.load_dotenv_for_agent("test_app", ".")
def _event_1():
    """Build the first dummy event: a model-role text part "LLM reply"."""
    # Imported lazily; at module level Event is only imported for type checking.
    from google.adk.events import Event

    text_part = types.Part(text="LLM reply", inline_data=None)
    reply_content = types.Content(role="model", parts=[text_part])
    return Event(
        author="dummy agent",
        invocation_id="invocation_id",
        content=reply_content,
    )
def _event_2():
    """Build the second dummy event: a model-role inline audio blob."""
    # Imported lazily; at module level Event is only imported for type checking.
    from google.adk.events import Event

    audio_blob = types.Blob(mime_type="audio/pcm;rate=24000", data=b"\x00\xFF")
    audio_part = types.Part(text=None, inline_data=audio_blob)
    audio_content = types.Content(role="model", parts=[audio_part])
    return Event(
        author="dummy agent",
        invocation_id="invocation_id",
        content=audio_content,
    )
def _event_3():
    """Build the third dummy event: ``interrupted=True`` and no content given."""
    # Imported lazily; at module level Event is only imported for type checking.
    from google.adk.events import Event

    event_fields = {
        "author": "dummy agent",
        "invocation_id": "invocation_id",
        "interrupted": True,
    }
    return Event(**event_fields)
# For simplicity, we patch Runner.run_live and Runner.run_async to yield dummy
# events. The events are real Event objects built by the _event_N helpers above.
async def dummy_run_live(
    self, session, live_request_queue
) -> AsyncGenerator[Event, None]:
    """Replacement for ``Runner.run_live`` that emits three canned events.

    Yields the events built by ``_event_1``, ``_event_2`` and ``_event_3`` in
    that order, then raises a bare ``Exception``. None of the arguments are
    read.
    """
    for build_event in (_event_1, _event_2):
        yield build_event()
        # Zero-length sleep hands control back to the event loop between events.
        await asyncio.sleep(0)
    yield _event_3()
    # NOTE(review): presumably ends the live stream abruptly on purpose to
    # exercise the server's error path -- confirm.
    raise Exception()
async def dummy_run_async(
    self,
    user_id,
    session_id,
    new_message,
    run_config: RunConfig = RunConfig(),
) -> AsyncGenerator[Event, None]:
    """Replacement for ``Runner.run_async`` that emits three canned events.

    Yields the events built by ``_event_1``, ``_event_2`` and ``_event_3`` in
    that order and then finishes normally. None of the arguments are read.
    """
    for build_event in (_event_1, _event_2):
        yield build_event()
        # Zero-length sleep hands control back to the event loop between events.
        await asyncio.sleep(0)
    yield _event_3()
###############################################################################
# Pytest fixtures to patch methods and start the server
###############################################################################
@pytest.fixture(autouse=True)
def patch_runner(monkeypatch):
    """Replace ``Runner.run_live`` and ``Runner.run_async`` with the dummies.

    Auto-used, so it applies to every test in this module without being
    requested explicitly.
    """
    replacements = (
        ("run_live", dummy_run_live),
        ("run_async", dummy_run_async),
    )
    for method_name, stub in replacements:
        monkeypatch.setattr(Runner, method_name, stub)
@pytest.fixture(scope="module", autouse=True)
def start_server():
    """Start the FastAPI server in a background thread for this module's tests.

    Instead of sleeping for a fixed interval, the fixture polls the port the
    tests connect to (127.0.0.1:8000) until it accepts a TCP connection, the
    server thread dies, or a timeout elapses. This avoids both a race on slow
    machines and a needless delay on fast ones. If the server never becomes
    reachable the fixture still yields, so the tests fail with their own
    connection errors, as they did with the fixed sleep.
    """
    import socket

    def run_server():
        uvicorn_run(
            get_fast_api_app(agent_dir=".", web=True),
            host="0.0.0.0",
            log_config=None,
        )

    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()

    # Wait until the server accepts connections, bounded by a deadline.
    deadline = time.monotonic() + 10.0
    while server_thread.is_alive() and time.monotonic() < deadline:
        try:
            with socket.create_connection(("127.0.0.1", 8000), timeout=0.5):
                break
        except OSError:
            # Not listening yet; back off briefly before probing again.
            time.sleep(0.1)

    yield
    # The daemon thread will be terminated when tests complete.
@pytest.mark.asyncio
async def test_sse_endpoint():
    """POST to ``/run_sse`` and check that the three dummy events come back.

    Each server-sent event is compared with the JSON dump of the matching
    ``_event_N`` helper, and exactly three events are expected in total.
    """
    base_http_url = "http://127.0.0.1:8000"
    user_id = "test_user"
    session_id = "test_session"
    data_prefix = "data: "
    # Factories for the expected events, in the order they should arrive.
    expected_factories = (_event_1, _event_2, _event_3)

    # Ensure that the session exists (create if necessary).
    url_create = (
        f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}"
    )
    httpx.post(url_create, json={"state": {}})

    async with httpx.AsyncClient() as client:
        run_request = AgentRunRequest(
            app_name="test_app",
            user_id=user_id,
            session_id=session_id,
            new_message=types.Content(
                parts=[types.Part(text="Hello via SSE", inline_data=None)]
            ),
            streaming=False,
        )
        # Round-trip through JSON so the client is handed plain JSON data.
        request_body = json.loads(
            run_request.model_dump_json(exclude_none=True)
        )

        # Make a POST request to the SSE endpoint.
        async with client.stream(
            "POST", f"{base_http_url}/run_sse", json=request_body
        ) as response:
            # Ensure the status code and header are as expected.
            assert response.status_code == 200
            assert (
                response.headers.get("content-type")
                == "text/event-stream; charset=utf-8"
            )

            event_count = 0
            event_buffer = ""
            async for line in response.aiter_lines():
                event_buffer += line + "\n"

                # An SSE event ends with an empty line; keep buffering until
                # one closes a buffer that holds something.
                if line != "" or not event_buffer.strip():
                    continue

                # When an event has several "data: " lines, the last one wins.
                data_lines = [
                    buffered[len(data_prefix):]
                    for buffered in event_buffer.split("\n")
                    if buffered.startswith(data_prefix)
                ]
                event_data = data_lines[-1] if data_lines else None

                if event_data:
                    event_count += 1
                    # Events beyond the third are only counted, not compared.
                    if event_count <= len(expected_factories):
                        expected = expected_factories[event_count - 1]()
                        assert event_data == expected.model_dump_json(
                            exclude_none=True, by_alias=True
                        )

                # Reset buffer for next event
                event_buffer = ""

            assert event_count == 3  # Expecting 3 events from dummy_run_async
@pytest.mark.asyncio
async def test_websocket_endpoint():
    """Exercise ``/run_live`` over a WebSocket and check the three dummy events.

    Sends one text request and one audio-blob request, and expects to receive
    the text reply, the audio reply and the interruption event, in that order.
    """
    # Event is imported at module level only under TYPE_CHECKING, so it must
    # be imported here to be usable at runtime (same as the _event_N helpers).
    from google.adk.events import Event

    base_http_url = "http://127.0.0.1:8000"
    base_ws_url = "ws://127.0.0.1:8000"
    user_id = "test_user"
    session_id = "test_session"

    # Ensure that the session exists (create if necessary).
    url_create = (
        f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}"
    )
    httpx.post(url_create, json={"state": {}})

    ws_url = (
        f"{base_ws_url}/run_live"
        f"?app_name=test_app&user_id={user_id}&session_id={session_id}"
    )
    async with websockets.connect(ws_url) as ws:
        # --- Test sending text data ---
        text_payload = LiveRequest(
            content=types.Content(
                parts=[types.Part(text="Hello via WebSocket", inline_data=None)]
            )
        )
        await ws.send(text_payload.model_dump_json())

        # Wait for a reply from our dummy_run_live.
        reply = await ws.recv()
        event = Event.model_validate_json(reply)
        assert event.content.parts[0].text == "LLM reply"

        # --- Test sending binary data (allowed mime type "audio/pcm") ---
        sample_audio = b"\x00\xFF"
        binary_payload = LiveRequest(
            blob=types.Blob(
                mime_type="audio/pcm",
                data=sample_audio,
            )
        )
        await ws.send(binary_payload.model_dump_json())

        # Wait for a reply.
        reply = await ws.recv()
        event = Event.model_validate_json(reply)
        assert (
            event.content.parts[0].inline_data.mime_type
            == "audio/pcm;rate=24000"
        )
        assert event.content.parts[0].inline_data.data == b"\x00\xFF"

        # The third event carries only the interruption flag.
        reply = await ws.recv()
        event = Event.model_validate_json(reply)
        assert event.interrupted is True
        assert event.content is None