chore: Fixes test_fast_api.py (part I for circular deps).

It still fails due to signal used not in main thread. It will be fixed later.

PiperOrigin-RevId: 760050504
This commit is contained in:
Wei Sun (Jack) 2025-05-17 12:21:22 -07:00 committed by Copybara-Service
parent f592de4cc0
commit 9324801b75
2 changed files with 64 additions and 39 deletions

View File

@ -12,23 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import OrderedDict
import json
import os
import tempfile
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from vertexai.preview import rag
from ..events.event import Event
from ..sessions.session import Session
from . import _utils
from .base_memory_service import BaseMemoryService
from .base_memory_service import SearchMemoryResponse
from .memory_entry import MemoryEntry
if TYPE_CHECKING:
from ..events.event import Event
from ..sessions.session import Session
class VertexAiRagMemoryService(BaseMemoryService):
"""A memory service that uses Vertex AI RAG for storage and retrieval."""
@ -103,6 +109,8 @@ class VertexAiRagMemoryService(BaseMemoryService):
self, *, app_name: str, user_id: str, query: str
) -> SearchMemoryResponse:
"""Searches for sessions that match the query using rag.retrieval_query."""
from ..events.event import Event
response = rag.retrieval_query(
text=query,
rag_resources=self._vertex_rag_store.rag_resources,

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import json
import sys
@ -19,14 +21,14 @@ import threading
import time
import types as ptypes
from typing import AsyncGenerator
from typing import TYPE_CHECKING
from google.adk.agents import BaseAgent
from google.adk.agents import LiveRequest
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.live_request_queue import LiveRequest
from google.adk.agents.run_config import RunConfig
from google.adk.cli.fast_api import AgentRunRequest
from google.adk.cli.fast_api import get_fast_api_app
from google.adk.cli.utils import envs
from google.adk.events import Event
from google.adk.runners import Runner
from google.genai import types
import httpx
@ -34,6 +36,9 @@ import pytest
from uvicorn.main import run as uvicorn_run
import websockets
if TYPE_CHECKING:
from google.adk.events import Event
# Here we “fake” the agent module that get_fast_api_app expects.
# The server code does: `agent_module = importlib.import_module(agent_name)`
@ -49,33 +54,45 @@ dummy_module.agent = ptypes.SimpleNamespace(
sys.modules["test_app"] = dummy_module
envs.load_dotenv_for_agent("test_app", ".")
event1 = Event(
author="dummy agent",
invocation_id="invocation_id",
content=types.Content(
role="model", parts=[types.Part(text="LLM reply", inline_data=None)]
),
)
event2 = Event(
author="dummy agent",
invocation_id="invocation_id",
content=types.Content(
role="model",
parts=[
types.Part(
text=None,
inline_data=types.Blob(
mime_type="audio/pcm;rate=24000", data=b"\x00\xFF"
),
)
],
),
)
def _event_1():
from google.adk.events import Event
event3 = Event(
author="dummy agent", invocation_id="invocation_id", interrupted=True
)
return Event(
author="dummy agent",
invocation_id="invocation_id",
content=types.Content(
role="model", parts=[types.Part(text="LLM reply", inline_data=None)]
),
)
def _event_2():
from google.adk.events import Event
return Event(
author="dummy agent",
invocation_id="invocation_id",
content=types.Content(
role="model",
parts=[
types.Part(
text=None,
inline_data=types.Blob(
mime_type="audio/pcm;rate=24000", data=b"\x00\xFF"
),
)
],
),
)
def _event_3():
from google.adk.events import Event
return Event(
author="dummy agent", invocation_id="invocation_id", interrupted=True
)
# For simplicity, we patch Runner.run_live to yield dummy events.
@ -84,13 +101,13 @@ async def dummy_run_live(
self, session, live_request_queue
) -> AsyncGenerator[Event, None]:
# Immediately yield a dummy event with a text reply.
yield event1
yield _event_1()
await asyncio.sleep(0)
yield event2
yield _event_2()
await asyncio.sleep(0)
yield event3
yield _event_3()
raise Exception()
@ -103,13 +120,13 @@ async def dummy_run_async(
run_config: RunConfig = RunConfig(),
) -> AsyncGenerator[Event, None]:
# Immediately yield a dummy event with a text reply.
yield event1
yield _event_1()
await asyncio.sleep(0)
yield event2
yield _event_2()
await asyncio.sleep(0)
yield event3
yield _event_3()
return
@ -199,15 +216,15 @@ async def test_sse_endpoint():
if event_data:
event_count += 1
if event_count == 1:
assert event_data == event1.model_dump_json(
assert event_data == _event_1().model_dump_json(
exclude_none=True, by_alias=True
)
elif event_count == 2:
assert event_data == event2.model_dump_json(
assert event_data == _event_2().model_dump_json(
exclude_none=True, by_alias=True
)
elif event_count == 3:
assert event_data == event3.model_dump_json(
assert event_data == _event_3().model_dump_json(
exclude_none=True, by_alias=True
)
else: