mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
chore: fix ut for fast api server
PiperOrigin-RevId: 761350248
This commit is contained in:
parent
98727b4698
commit
cbdb5fc507
@ -61,7 +61,7 @@ from ..agents.live_request_queue import LiveRequestQueue
|
|||||||
from ..agents.llm_agent import Agent
|
from ..agents.llm_agent import Agent
|
||||||
from ..agents.llm_agent import LlmAgent
|
from ..agents.llm_agent import LlmAgent
|
||||||
from ..agents.run_config import StreamingMode
|
from ..agents.run_config import StreamingMode
|
||||||
from ..artifacts import InMemoryArtifactService
|
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||||
from ..evaluation.eval_case import EvalCase
|
from ..evaluation.eval_case import EvalCase
|
||||||
from ..evaluation.eval_case import SessionInput
|
from ..evaluation.eval_case import SessionInput
|
||||||
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
||||||
|
@ -12,52 +12,62 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import logging
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
import types as ptypes
|
import types as ptypes
|
||||||
from typing import AsyncGenerator
|
from unittest.mock import MagicMock, patch
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
from google.adk.agents.base_agent import BaseAgent
|
from google.adk.agents.base_agent import BaseAgent
|
||||||
from google.adk.agents.live_request_queue import LiveRequest
|
|
||||||
from google.adk.agents.run_config import RunConfig
|
from google.adk.agents.run_config import RunConfig
|
||||||
from google.adk.cli.fast_api import AgentRunRequest
|
|
||||||
from google.adk.cli.fast_api import get_fast_api_app
|
from google.adk.cli.fast_api import get_fast_api_app
|
||||||
from google.adk.cli.utils import envs
|
from google.adk.cli.utils import envs
|
||||||
from google.adk.runners import Runner
|
|
||||||
from google.genai import types
|
|
||||||
import httpx
|
|
||||||
import pytest
|
|
||||||
from uvicorn.main import run as uvicorn_run
|
|
||||||
import websockets
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from google.adk.events import Event
|
from google.adk.events import Event
|
||||||
|
from google.adk.runners import Runner
|
||||||
|
from google.adk.sessions.base_session_service import ListSessionsResponse
|
||||||
|
from google.genai import types
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
# Here we “fake” the agent module that get_fast_api_app expects.
|
# Configure logging to help diagnose server startup issues
|
||||||
# The server code does: `agent_module = importlib.import_module(agent_name)`
|
logging.basicConfig(
|
||||||
# and then accesses: agent_module.agent.root_agent.
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Here we create a dummy agent module that get_fast_api_app expects
|
||||||
class DummyAgent(BaseAgent):
|
class DummyAgent(BaseAgent):
|
||||||
pass
|
|
||||||
|
def __init__(self, name):
|
||||||
|
super().__init__(name=name)
|
||||||
|
self.sub_agents = []
|
||||||
|
|
||||||
|
|
||||||
|
# Set up dummy module and add to sys.modules
|
||||||
dummy_module = ptypes.ModuleType("test_agent")
|
dummy_module = ptypes.ModuleType("test_agent")
|
||||||
dummy_module.agent = ptypes.SimpleNamespace(
|
dummy_module.agent = ptypes.SimpleNamespace(
|
||||||
root_agent=DummyAgent(name="dummy_agent")
|
root_agent=DummyAgent(name="dummy_agent")
|
||||||
)
|
)
|
||||||
sys.modules["test_app"] = dummy_module
|
sys.modules["test_app"] = dummy_module
|
||||||
|
|
||||||
|
# Try to load environment variables, with a fallback for testing
|
||||||
|
try:
|
||||||
envs.load_dotenv_for_agent("test_app", ".")
|
envs.load_dotenv_for_agent("test_app", ".")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not load environment variables: {e}")
|
||||||
|
# Create a basic .env file if needed
|
||||||
|
if not os.path.exists(".env"):
|
||||||
|
with open(".env", "w") as f:
|
||||||
|
f.write("# Test environment variables\n")
|
||||||
|
|
||||||
|
|
||||||
|
# Create sample events that our mocked runner will return
|
||||||
def _event_1():
|
def _event_1():
|
||||||
from google.adk.events import Event
|
|
||||||
|
|
||||||
return Event(
|
return Event(
|
||||||
author="dummy agent",
|
author="dummy agent",
|
||||||
invocation_id="invocation_id",
|
invocation_id="invocation_id",
|
||||||
@ -68,8 +78,6 @@ def _event_1():
|
|||||||
|
|
||||||
|
|
||||||
def _event_2():
|
def _event_2():
|
||||||
from google.adk.events import Event
|
|
||||||
|
|
||||||
return Event(
|
return Event(
|
||||||
author="dummy agent",
|
author="dummy agent",
|
||||||
invocation_id="invocation_id",
|
invocation_id="invocation_id",
|
||||||
@ -88,19 +96,13 @@ def _event_2():
|
|||||||
|
|
||||||
|
|
||||||
def _event_3():
|
def _event_3():
|
||||||
from google.adk.events import Event
|
|
||||||
|
|
||||||
return Event(
|
return Event(
|
||||||
author="dummy agent", invocation_id="invocation_id", interrupted=True
|
author="dummy agent", invocation_id="invocation_id", interrupted=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# For simplicity, we patch Runner.run_live to yield dummy events.
|
# Define mocked async generator functions for the Runner
|
||||||
# We use SimpleNamespace to mimic attribute-access (i.e. event.content.parts).
|
async def dummy_run_live(self, session, live_request_queue):
|
||||||
async def dummy_run_live(
|
|
||||||
self, session, live_request_queue
|
|
||||||
) -> AsyncGenerator[Event, None]:
|
|
||||||
# Immediately yield a dummy event with a text reply.
|
|
||||||
yield _event_1()
|
yield _event_1()
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
@ -109,8 +111,6 @@ async def dummy_run_live(
|
|||||||
|
|
||||||
yield _event_3()
|
yield _event_3()
|
||||||
|
|
||||||
raise Exception()
|
|
||||||
|
|
||||||
|
|
||||||
async def dummy_run_async(
|
async def dummy_run_async(
|
||||||
self,
|
self,
|
||||||
@ -118,8 +118,7 @@ async def dummy_run_async(
|
|||||||
session_id,
|
session_id,
|
||||||
new_message,
|
new_message,
|
||||||
run_config: RunConfig = RunConfig(),
|
run_config: RunConfig = RunConfig(),
|
||||||
) -> AsyncGenerator[Event, None]:
|
):
|
||||||
# Immediately yield a dummy event with a text reply.
|
|
||||||
yield _event_1()
|
yield _event_1()
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
@ -128,159 +127,365 @@ async def dummy_run_async(
|
|||||||
|
|
||||||
yield _event_3()
|
yield _event_3()
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
|
#################################################
|
||||||
###############################################################################
|
# Test Fixtures
|
||||||
# Pytest fixtures to patch methods and start the server
|
#################################################
|
||||||
###############################################################################
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def patch_runner(monkeypatch):
|
def patch_runner(monkeypatch):
|
||||||
# Patch the Runner methods to use our dummy implementations.
|
"""Patch the Runner methods to use our dummy implementations."""
|
||||||
monkeypatch.setattr(Runner, "run_live", dummy_run_live)
|
monkeypatch.setattr(Runner, "run_live", dummy_run_live)
|
||||||
monkeypatch.setattr(Runner, "run_async", dummy_run_async)
|
monkeypatch.setattr(Runner, "run_async", dummy_run_async)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", autouse=True)
|
@pytest.fixture
|
||||||
def start_server():
|
def test_session_info():
|
||||||
"""Start the FastAPI server in a background thread."""
|
"""Return test user and session IDs for testing."""
|
||||||
|
return {
|
||||||
|
"app_name": "test_app",
|
||||||
|
"user_id": "test_user",
|
||||||
|
"session_id": "test_session",
|
||||||
|
}
|
||||||
|
|
||||||
def run_server():
|
|
||||||
uvicorn_run(
|
@pytest.fixture
|
||||||
get_fast_api_app(agent_dir=".", web=True),
|
def mock_session_service():
|
||||||
host="0.0.0.0",
|
"""Create a mock session service that uses an in-memory dictionary."""
|
||||||
log_config=None,
|
|
||||||
|
# In-memory database to store sessions during testing
|
||||||
|
session_data = {
|
||||||
|
"test_app": {
|
||||||
|
"test_user": {
|
||||||
|
"test_session": {
|
||||||
|
"id": "test_session",
|
||||||
|
"app_name": "test_app",
|
||||||
|
"user_id": "test_user",
|
||||||
|
"events": [],
|
||||||
|
"state": {},
|
||||||
|
"created_at": time.time(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Mock session service class that operates on the in-memory database
|
||||||
|
class MockSessionService:
|
||||||
|
|
||||||
|
async def get_session(self, app_name, user_id, session_id):
|
||||||
|
"""Retrieve a session by ID."""
|
||||||
|
if (
|
||||||
|
app_name in session_data
|
||||||
|
and user_id in session_data[app_name]
|
||||||
|
and session_id in session_data[app_name][user_id]
|
||||||
|
):
|
||||||
|
return session_data[app_name][user_id][session_id]
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def create_session(
|
||||||
|
self, app_name, user_id, state=None, session_id=None
|
||||||
|
):
|
||||||
|
"""Create a new session."""
|
||||||
|
if session_id is None:
|
||||||
|
session_id = f"session_{int(time.time())}"
|
||||||
|
|
||||||
|
# Initialize app_name and user_id if they don't exist
|
||||||
|
if app_name not in session_data:
|
||||||
|
session_data[app_name] = {}
|
||||||
|
if user_id not in session_data[app_name]:
|
||||||
|
session_data[app_name][user_id] = {}
|
||||||
|
|
||||||
|
# Create the session
|
||||||
|
session = {
|
||||||
|
"id": session_id,
|
||||||
|
"app_name": app_name,
|
||||||
|
"user_id": user_id,
|
||||||
|
"events": [],
|
||||||
|
"state": state or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
session_data[app_name][user_id][session_id] = session
|
||||||
|
return session
|
||||||
|
|
||||||
|
async def list_sessions(self, app_name, user_id):
|
||||||
|
"""List all sessions for a user."""
|
||||||
|
if app_name not in session_data or user_id not in session_data[app_name]:
|
||||||
|
return {"sessions": []}
|
||||||
|
|
||||||
|
return ListSessionsResponse(
|
||||||
|
sessions=list(session_data[app_name][user_id].values())
|
||||||
)
|
)
|
||||||
|
|
||||||
server_thread = threading.Thread(target=run_server, daemon=True)
|
async def delete_session(self, app_name, user_id, session_id):
|
||||||
server_thread.start()
|
"""Delete a session."""
|
||||||
# Wait a moment to ensure the server is up.
|
if (
|
||||||
time.sleep(2)
|
app_name in session_data
|
||||||
yield
|
and user_id in session_data[app_name]
|
||||||
# The daemon thread will be terminated when tests complete.
|
and session_id in session_data[app_name][user_id]
|
||||||
|
):
|
||||||
|
del session_data[app_name][user_id][session_id]
|
||||||
|
|
||||||
|
# Return an instance of our mock service
|
||||||
|
return MockSessionService()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.fixture
|
||||||
async def test_sse_endpoint():
|
def mock_artifact_service():
|
||||||
base_http_url = "http://127.0.0.1:8000"
|
"""Create a mock artifact service."""
|
||||||
user_id = "test_user"
|
|
||||||
session_id = "test_session"
|
|
||||||
|
|
||||||
# Ensure that the session exists (create if necessary).
|
# Storage for artifacts
|
||||||
url_create = (
|
artifacts = {}
|
||||||
f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}"
|
|
||||||
|
class MockArtifactService:
|
||||||
|
|
||||||
|
async def load_artifact(
|
||||||
|
self, app_name, user_id, session_id, filename, version=None
|
||||||
|
):
|
||||||
|
"""Load an artifact by filename."""
|
||||||
|
key = f"{app_name}:{user_id}:{session_id}:{filename}"
|
||||||
|
if key not in artifacts:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if version is not None:
|
||||||
|
# Get a specific version
|
||||||
|
for v in artifacts[key]:
|
||||||
|
if v["version"] == version:
|
||||||
|
return v["artifact"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get the latest version
|
||||||
|
return sorted(artifacts[key], key=lambda x: x["version"])[-1]["artifact"]
|
||||||
|
|
||||||
|
async def list_artifact_keys(self, app_name, user_id, session_id):
|
||||||
|
"""List artifact names for a session."""
|
||||||
|
prefix = f"{app_name}:{user_id}:{session_id}:"
|
||||||
|
return [
|
||||||
|
k.split(":")[-1] for k in artifacts.keys() if k.startswith(prefix)
|
||||||
|
]
|
||||||
|
|
||||||
|
async def list_versions(self, app_name, user_id, session_id, filename):
|
||||||
|
"""List versions of an artifact."""
|
||||||
|
key = f"{app_name}:{user_id}:{session_id}:{filename}"
|
||||||
|
if key not in artifacts:
|
||||||
|
return []
|
||||||
|
return [a["version"] for a in artifacts[key]]
|
||||||
|
|
||||||
|
async def delete_artifact(self, app_name, user_id, session_id, filename):
|
||||||
|
"""Delete an artifact."""
|
||||||
|
key = f"{app_name}:{user_id}:{session_id}:{filename}"
|
||||||
|
if key in artifacts:
|
||||||
|
del artifacts[key]
|
||||||
|
|
||||||
|
return MockArtifactService()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_memory_service():
|
||||||
|
"""Create a mock memory service."""
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_app(mock_session_service, mock_artifact_service, mock_memory_service):
|
||||||
|
"""Create a TestClient for the FastAPI app without starting a server."""
|
||||||
|
|
||||||
|
# Patch multiple services and signal handlers
|
||||||
|
with (
|
||||||
|
patch("signal.signal", return_value=None),
|
||||||
|
patch(
|
||||||
|
"google.adk.cli.fast_api.InMemorySessionService", # Changed this line
|
||||||
|
return_value=mock_session_service,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"google.adk.cli.fast_api.InMemoryArtifactService", # Make consistent
|
||||||
|
return_value=mock_artifact_service,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"google.adk.cli.fast_api.InMemoryMemoryService", # Make consistent
|
||||||
|
return_value=mock_memory_service,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
# Get the FastAPI app, but don't actually run it
|
||||||
|
app = get_fast_api_app(
|
||||||
|
agent_dir=".", web=True, session_db_url="", allow_origins=["*"]
|
||||||
)
|
)
|
||||||
httpx.post(url_create, json={"state": {}})
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
# Create a TestClient that doesn't start a real server
|
||||||
# Make a POST request to the SSE endpoint.
|
client = TestClient(app)
|
||||||
async with client.stream(
|
|
||||||
"POST",
|
return client
|
||||||
f"{base_http_url}/run_sse",
|
|
||||||
json=json.loads(
|
|
||||||
AgentRunRequest(
|
@pytest.fixture
|
||||||
app_name="test_app",
|
async def create_test_session(
|
||||||
user_id=user_id,
|
test_app, test_session_info, mock_session_service
|
||||||
session_id=session_id,
|
):
|
||||||
new_message=types.Content(
|
"""Create a test session using the mocked session service."""
|
||||||
parts=[types.Part(text="Hello via SSE", inline_data=None)]
|
|
||||||
),
|
# Create the session directly through the mock service
|
||||||
streaming=False,
|
session = await mock_session_service.create_session(
|
||||||
).model_dump_json(exclude_none=True)
|
app_name=test_session_info["app_name"],
|
||||||
),
|
user_id=test_session_info["user_id"],
|
||||||
) as response:
|
session_id=test_session_info["session_id"],
|
||||||
# Ensure the status code and header are as expected.
|
state={},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Created test session: {session['id']}")
|
||||||
|
return test_session_info
|
||||||
|
|
||||||
|
|
||||||
|
#################################################
|
||||||
|
# Test Cases
|
||||||
|
#################################################
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_apps(test_app):
|
||||||
|
"""Test listing available applications."""
|
||||||
|
# Use the TestClient to make a request
|
||||||
|
response = test_app.get("/list-apps")
|
||||||
|
|
||||||
|
# Verify the response
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
logger.info(f"Listed apps: {data}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_session_with_id(test_app, test_session_info):
|
||||||
|
"""Test creating a session with a specific ID."""
|
||||||
|
new_session_id = "new_session_id"
|
||||||
|
url = f"/apps/{test_session_info['app_name']}/users/{test_session_info['user_id']}/sessions/{new_session_id}"
|
||||||
|
response = test_app.post(url, json={"state": {}})
|
||||||
|
|
||||||
|
# Verify the response
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == new_session_id
|
||||||
|
assert data["appName"] == test_session_info["app_name"]
|
||||||
|
assert data["userId"] == test_session_info["user_id"]
|
||||||
|
logger.info(f"Created session with ID: {data['id']}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_session_without_id(test_app, test_session_info):
|
||||||
|
"""Test creating a session with a generated ID."""
|
||||||
|
url = f"/apps/{test_session_info['app_name']}/users/{test_session_info['user_id']}/sessions"
|
||||||
|
response = test_app.post(url, json={"state": {}})
|
||||||
|
|
||||||
|
# Verify the response
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "id" in data
|
||||||
|
assert data["appName"] == test_session_info["app_name"]
|
||||||
|
assert data["userId"] == test_session_info["user_id"]
|
||||||
|
logger.info(f"Created session with generated ID: {data['id']}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_session(test_app, create_test_session):
|
||||||
|
"""Test retrieving a session by ID."""
|
||||||
|
info = create_test_session
|
||||||
|
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}"
|
||||||
|
response = test_app.get(url)
|
||||||
|
|
||||||
|
# Verify the response
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["id"] == info["session_id"]
|
||||||
|
assert data["appName"] == info["app_name"]
|
||||||
|
assert data["userId"] == info["user_id"]
|
||||||
|
logger.info(f"Retrieved session: {data['id']}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_sessions(test_app, create_test_session):
|
||||||
|
"""Test listing all sessions for a user."""
|
||||||
|
info = create_test_session
|
||||||
|
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions"
|
||||||
|
response = test_app.get(url)
|
||||||
|
|
||||||
|
# Verify the response
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
# At least our test session should be present
|
||||||
|
assert any(session["id"] == info["session_id"] for session in data)
|
||||||
|
logger.info(f"Listed {len(data)} sessions")
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_session(test_app, create_test_session):
|
||||||
|
"""Test deleting a session."""
|
||||||
|
info = create_test_session
|
||||||
|
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}"
|
||||||
|
response = test_app.delete(url)
|
||||||
|
|
||||||
|
# Verify the response
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Verify the session is deleted
|
||||||
|
response = test_app.get(url)
|
||||||
|
assert response.status_code == 404
|
||||||
|
logger.info("Session deleted successfully")
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_run(test_app, create_test_session):
|
||||||
|
"""Test running an agent with a message."""
|
||||||
|
info = create_test_session
|
||||||
|
url = "/run"
|
||||||
|
payload = {
|
||||||
|
"app_name": info["app_name"],
|
||||||
|
"user_id": info["user_id"],
|
||||||
|
"session_id": info["session_id"],
|
||||||
|
"new_message": {"role": "user", "parts": [{"text": "Hello agent"}]},
|
||||||
|
"streaming": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = test_app.post(url, json=payload)
|
||||||
|
|
||||||
|
# Verify the response
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
assert len(data) == 3 # We expect 3 events from our dummy_run_async
|
||||||
|
|
||||||
|
# Verify we got the expected events
|
||||||
|
assert data[0]["author"] == "dummy agent"
|
||||||
|
assert data[0]["content"]["parts"][0]["text"] == "LLM reply"
|
||||||
|
|
||||||
|
# Second event should have binary data
|
||||||
assert (
|
assert (
|
||||||
response.headers.get("content-type")
|
data[1]["content"]["parts"][0]["inlineData"]["mimeType"]
|
||||||
== "text/event-stream; charset=utf-8"
|
== "audio/pcm;rate=24000"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Iterate over events from the stream.
|
# Third event should have interrupted flag
|
||||||
event_count = 0
|
assert data[2]["interrupted"] == True
|
||||||
event_buffer = ""
|
|
||||||
|
|
||||||
async for line in response.aiter_lines():
|
logger.info("Agent run test completed successfully")
|
||||||
event_buffer += line + "\n"
|
|
||||||
|
|
||||||
# An SSE event is terminated by an empty line (double newline)
|
|
||||||
if line == "" and event_buffer.strip():
|
|
||||||
# Process the complete event
|
|
||||||
event_data = None
|
|
||||||
for event_line in event_buffer.split("\n"):
|
|
||||||
if event_line.startswith("data: "):
|
|
||||||
event_data = event_line[6:] # Remove "data: " prefix
|
|
||||||
|
|
||||||
if event_data:
|
|
||||||
event_count += 1
|
|
||||||
if event_count == 1:
|
|
||||||
assert event_data == _event_1().model_dump_json(
|
|
||||||
exclude_none=True, by_alias=True
|
|
||||||
)
|
|
||||||
elif event_count == 2:
|
|
||||||
assert event_data == _event_2().model_dump_json(
|
|
||||||
exclude_none=True, by_alias=True
|
|
||||||
)
|
|
||||||
elif event_count == 3:
|
|
||||||
assert event_data == _event_3().model_dump_json(
|
|
||||||
exclude_none=True, by_alias=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Reset buffer for next event
|
|
||||||
event_buffer = ""
|
|
||||||
|
|
||||||
assert event_count == 3 # Expecting 3 events from dummy_run_async
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_list_artifact_names(test_app, create_test_session):
|
||||||
async def test_websocket_endpoint():
|
"""Test listing artifact names for a session."""
|
||||||
base_http_url = "http://127.0.0.1:8000"
|
info = create_test_session
|
||||||
base_ws_url = "ws://127.0.0.1:8000"
|
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}/artifacts"
|
||||||
user_id = "test_user"
|
response = test_app.get(url)
|
||||||
session_id = "test_session"
|
|
||||||
|
|
||||||
# Ensure that the session exists (create if necessary).
|
# Verify the response
|
||||||
url_create = (
|
assert response.status_code == 200
|
||||||
f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}"
|
data = response.json()
|
||||||
)
|
assert isinstance(data, list)
|
||||||
httpx.post(url_create, json={"state": {}})
|
logger.info(f"Listed {len(data)} artifacts")
|
||||||
|
|
||||||
ws_url = f"{base_ws_url}/run_live?app_name=test_app&user_id={user_id}&session_id={session_id}"
|
|
||||||
async with websockets.connect(ws_url) as ws:
|
|
||||||
# --- Test sending text data ---
|
|
||||||
text_payload = LiveRequest(
|
|
||||||
content=types.Content(
|
|
||||||
parts=[types.Part(text="Hello via WebSocket", inline_data=None)]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await ws.send(text_payload.model_dump_json())
|
|
||||||
# Wait for a reply from our dummy_run_live.
|
|
||||||
reply = await ws.recv()
|
|
||||||
event = Event.model_validate_json(reply)
|
|
||||||
assert event.content.parts[0].text == "LLM reply"
|
|
||||||
|
|
||||||
# --- Test sending binary data (allowed mime type "audio/pcm") ---
|
def test_debug_trace(test_app):
|
||||||
sample_audio = b"\x00\xFF"
|
"""Test the debug trace endpoint."""
|
||||||
binary_payload = LiveRequest(
|
# This test will likely return 404 since we haven't set up trace data,
|
||||||
blob=types.Blob(
|
# but it tests that the endpoint exists and handles missing traces correctly.
|
||||||
mime_type="audio/pcm",
|
url = "/debug/trace/nonexistent-event"
|
||||||
data=sample_audio,
|
response = test_app.get(url)
|
||||||
)
|
|
||||||
)
|
|
||||||
await ws.send(binary_payload.model_dump_json())
|
|
||||||
# Wait for a reply.
|
|
||||||
reply = await ws.recv()
|
|
||||||
event = Event.model_validate_json(reply)
|
|
||||||
assert (
|
|
||||||
event.content.parts[0].inline_data.mime_type == "audio/pcm;rate=24000"
|
|
||||||
)
|
|
||||||
assert event.content.parts[0].inline_data.data == b"\x00\xFF"
|
|
||||||
|
|
||||||
reply = await ws.recv()
|
# Verify we get a 404 for a nonexistent trace
|
||||||
event = Event.model_validate_json(reply)
|
assert response.status_code == 404
|
||||||
assert event.interrupted is True
|
logger.info("Debug trace test completed successfully")
|
||||||
assert event.content is None
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main(["-xvs", __file__])
|
||||||
|
Loading…
Reference in New Issue
Block a user