From 9324801b754d63e79a6b7035ee441b324784ce20 Mon Sep 17 00:00:00 2001 From: "Wei Sun (Jack)" Date: Sat, 17 May 2025 12:21:22 -0700 Subject: [PATCH] chore: Fixes test_fast_api.py (part I for circular deps). It still fails due to signal used not in main thread. It will be fixed later. PiperOrigin-RevId: 760050504 --- .../memory/vertex_ai_rag_memory_service.py | 12 ++- tests/unittests/fast_api/test_fast_api.py | 91 +++++++++++-------- 2 files changed, 64 insertions(+), 39 deletions(-) diff --git a/src/google/adk/memory/vertex_ai_rag_memory_service.py b/src/google/adk/memory/vertex_ai_rag_memory_service.py index 2322071..1b163a9 100644 --- a/src/google/adk/memory/vertex_ai_rag_memory_service.py +++ b/src/google/adk/memory/vertex_ai_rag_memory_service.py @@ -12,23 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations + from collections import OrderedDict import json import os import tempfile from typing import Optional +from typing import TYPE_CHECKING from google.genai import types from typing_extensions import override from vertexai.preview import rag -from ..events.event import Event -from ..sessions.session import Session from . import _utils from .base_memory_service import BaseMemoryService from .base_memory_service import SearchMemoryResponse from .memory_entry import MemoryEntry +if TYPE_CHECKING: + from ..events.event import Event + from ..sessions.session import Session + class VertexAiRagMemoryService(BaseMemoryService): """A memory service that uses Vertex AI RAG for storage and retrieval.""" @@ -103,6 +109,8 @@ class VertexAiRagMemoryService(BaseMemoryService): self, *, app_name: str, user_id: str, query: str ) -> SearchMemoryResponse: """Searches for sessions that match the query using rag.retrieval_query.""" + from ..events.event import Event + response = rag.retrieval_query( text=query, rag_resources=self._vertex_rag_store.rag_resources, diff --git a/tests/unittests/fast_api/test_fast_api.py b/tests/unittests/fast_api/test_fast_api.py index 1f7fd17..62c7e79 100644 --- a/tests/unittests/fast_api/test_fast_api.py +++ b/tests/unittests/fast_api/test_fast_api.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import asyncio import json import sys @@ -19,14 +21,14 @@ import threading import time import types as ptypes from typing import AsyncGenerator +from typing import TYPE_CHECKING -from google.adk.agents import BaseAgent -from google.adk.agents import LiveRequest +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.live_request_queue import LiveRequest from google.adk.agents.run_config import RunConfig from google.adk.cli.fast_api import AgentRunRequest from google.adk.cli.fast_api import get_fast_api_app from google.adk.cli.utils import envs -from google.adk.events import Event from google.adk.runners import Runner from google.genai import types import httpx @@ -34,6 +36,9 @@ import pytest from uvicorn.main import run as uvicorn_run import websockets +if TYPE_CHECKING: + from google.adk.events import Event + # Here we “fake” the agent module that get_fast_api_app expects. # The server code does: `agent_module = importlib.import_module(agent_name)` @@ -49,33 +54,45 @@ dummy_module.agent = ptypes.SimpleNamespace( sys.modules["test_app"] = dummy_module envs.load_dotenv_for_agent("test_app", ".") -event1 = Event( - author="dummy agent", - invocation_id="invocation_id", - content=types.Content( - role="model", parts=[types.Part(text="LLM reply", inline_data=None)] - ), -) -event2 = Event( - author="dummy agent", - invocation_id="invocation_id", - content=types.Content( - role="model", - parts=[ - types.Part( - text=None, - inline_data=types.Blob( - mime_type="audio/pcm;rate=24000", data=b"\x00\xFF" - ), - ) - ], - ), -) +def _event_1(): + from google.adk.events import Event -event3 = Event( - author="dummy agent", invocation_id="invocation_id", interrupted=True -) + return Event( + author="dummy agent", + invocation_id="invocation_id", + content=types.Content( + role="model", parts=[types.Part(text="LLM reply", inline_data=None)] + ), + ) + + +def _event_2(): + from google.adk.events import Event + + return Event( + author="dummy agent", + invocation_id="invocation_id", + content=types.Content( + role="model", + parts=[ + types.Part( + text=None, + inline_data=types.Blob( + mime_type="audio/pcm;rate=24000", data=b"\x00\xFF" + ), + ) + ], + ), + ) + + +def _event_3(): + from google.adk.events import Event + + return Event( + author="dummy agent", invocation_id="invocation_id", interrupted=True + ) # For simplicity, we patch Runner.run_live to yield dummy events. @@ -84,13 +101,13 @@ async def dummy_run_live( self, session, live_request_queue ) -> AsyncGenerator[Event, None]: # Immediately yield a dummy event with a text reply. - yield event1 + yield _event_1() await asyncio.sleep(0) - yield event2 + yield _event_2() await asyncio.sleep(0) - yield event3 + yield _event_3() raise Exception() @@ -103,13 +120,13 @@ async def dummy_run_async( run_config: RunConfig = RunConfig(), ) -> AsyncGenerator[Event, None]: # Immediately yield a dummy event with a text reply. - yield event1 + yield _event_1() await asyncio.sleep(0) - yield event2 + yield _event_2() await asyncio.sleep(0) - yield event3 + yield _event_3() return @@ -199,15 +216,15 @@ async def test_sse_endpoint(): if event_data: event_count += 1 if event_count == 1: - assert event_data == event1.model_dump_json( + assert event_data == _event_1().model_dump_json( exclude_none=True, by_alias=True ) elif event_count == 2: - assert event_data == event2.model_dump_json( + assert event_data == _event_2().model_dump_json( exclude_none=True, by_alias=True ) elif event_count == 3: - assert event_data == event3.model_dump_json( + assert event_data == _event_3().model_dump_json( exclude_none=True, by_alias=True ) else: