chore: Fixes test_fast_api.py (part I for circular deps).

It still fails due to signal used not in main thread. It will be fixed later.

PiperOrigin-RevId: 760050504
This commit is contained in:
Wei Sun (Jack)
2025-05-17 12:21:22 -07:00
committed by Copybara-Service
parent f592de4cc0
commit 9324801b75
2 changed files with 64 additions and 39 deletions

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import json
import sys
@@ -19,14 +21,14 @@ import threading
import time
import types as ptypes
from typing import AsyncGenerator
from typing import TYPE_CHECKING
from google.adk.agents import BaseAgent
from google.adk.agents import LiveRequest
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.live_request_queue import LiveRequest
from google.adk.agents.run_config import RunConfig
from google.adk.cli.fast_api import AgentRunRequest
from google.adk.cli.fast_api import get_fast_api_app
from google.adk.cli.utils import envs
from google.adk.events import Event
from google.adk.runners import Runner
from google.genai import types
import httpx
@@ -34,6 +36,9 @@ import pytest
from uvicorn.main import run as uvicorn_run
import websockets
if TYPE_CHECKING:
from google.adk.events import Event
# Here we “fake” the agent module that get_fast_api_app expects.
# The server code does: `agent_module = importlib.import_module(agent_name)`
@@ -49,33 +54,45 @@ dummy_module.agent = ptypes.SimpleNamespace(
sys.modules["test_app"] = dummy_module
envs.load_dotenv_for_agent("test_app", ".")
event1 = Event(
author="dummy agent",
invocation_id="invocation_id",
content=types.Content(
role="model", parts=[types.Part(text="LLM reply", inline_data=None)]
),
)
event2 = Event(
author="dummy agent",
invocation_id="invocation_id",
content=types.Content(
role="model",
parts=[
types.Part(
text=None,
inline_data=types.Blob(
mime_type="audio/pcm;rate=24000", data=b"\x00\xFF"
),
)
],
),
)
def _event_1():
from google.adk.events import Event
event3 = Event(
author="dummy agent", invocation_id="invocation_id", interrupted=True
)
return Event(
author="dummy agent",
invocation_id="invocation_id",
content=types.Content(
role="model", parts=[types.Part(text="LLM reply", inline_data=None)]
),
)
def _event_2():
from google.adk.events import Event
return Event(
author="dummy agent",
invocation_id="invocation_id",
content=types.Content(
role="model",
parts=[
types.Part(
text=None,
inline_data=types.Blob(
mime_type="audio/pcm;rate=24000", data=b"\x00\xFF"
),
)
],
),
)
def _event_3():
from google.adk.events import Event
return Event(
author="dummy agent", invocation_id="invocation_id", interrupted=True
)
# For simplicity, we patch Runner.run_live to yield dummy events.
@@ -84,13 +101,13 @@ async def dummy_run_live(
self, session, live_request_queue
) -> AsyncGenerator[Event, None]:
# Immediately yield a dummy event with a text reply.
yield event1
yield _event_1()
await asyncio.sleep(0)
yield event2
yield _event_2()
await asyncio.sleep(0)
yield event3
yield _event_3()
raise Exception()
@@ -103,13 +120,13 @@ async def dummy_run_async(
run_config: RunConfig = RunConfig(),
) -> AsyncGenerator[Event, None]:
# Immediately yield a dummy event with a text reply.
yield event1
yield _event_1()
await asyncio.sleep(0)
yield event2
yield _event_2()
await asyncio.sleep(0)
yield event3
yield _event_3()
return
@@ -199,15 +216,15 @@ async def test_sse_endpoint():
if event_data:
event_count += 1
if event_count == 1:
assert event_data == event1.model_dump_json(
assert event_data == _event_1().model_dump_json(
exclude_none=True, by_alias=True
)
elif event_count == 2:
assert event_data == event2.model_dump_json(
assert event_data == _event_2().model_dump_json(
exclude_none=True, by_alias=True
)
elif event_count == 3:
assert event_data == event3.model_dump_json(
assert event_data == _event_3().model_dump_json(
exclude_none=True, by_alias=True
)
else: