chore: Fixes test_fast_api.py (part I for circular deps).

It still fails due to signal used not in main thread. It will be fixed later.

PiperOrigin-RevId: 760050504
This commit is contained in:
Wei Sun (Jack) 2025-05-17 12:21:22 -07:00 committed by Copybara-Service
parent f592de4cc0
commit 9324801b75
2 changed files with 64 additions and 39 deletions

View File

@ -12,23 +12,29 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
import json import json
import os import os
import tempfile import tempfile
from typing import Optional from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types from google.genai import types
from typing_extensions import override from typing_extensions import override
from vertexai.preview import rag from vertexai.preview import rag
from ..events.event import Event
from ..sessions.session import Session
from . import _utils from . import _utils
from .base_memory_service import BaseMemoryService from .base_memory_service import BaseMemoryService
from .base_memory_service import SearchMemoryResponse from .base_memory_service import SearchMemoryResponse
from .memory_entry import MemoryEntry from .memory_entry import MemoryEntry
if TYPE_CHECKING:
from ..events.event import Event
from ..sessions.session import Session
class VertexAiRagMemoryService(BaseMemoryService): class VertexAiRagMemoryService(BaseMemoryService):
"""A memory service that uses Vertex AI RAG for storage and retrieval.""" """A memory service that uses Vertex AI RAG for storage and retrieval."""
@ -103,6 +109,8 @@ class VertexAiRagMemoryService(BaseMemoryService):
self, *, app_name: str, user_id: str, query: str self, *, app_name: str, user_id: str, query: str
) -> SearchMemoryResponse: ) -> SearchMemoryResponse:
"""Searches for sessions that match the query using rag.retrieval_query.""" """Searches for sessions that match the query using rag.retrieval_query."""
from ..events.event import Event
response = rag.retrieval_query( response = rag.retrieval_query(
text=query, text=query,
rag_resources=self._vertex_rag_store.rag_resources, rag_resources=self._vertex_rag_store.rag_resources,

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
import asyncio import asyncio
import json import json
import sys import sys
@ -19,14 +21,14 @@ import threading
import time import time
import types as ptypes import types as ptypes
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import TYPE_CHECKING
from google.adk.agents import BaseAgent from google.adk.agents.base_agent import BaseAgent
from google.adk.agents import LiveRequest from google.adk.agents.live_request_queue import LiveRequest
from google.adk.agents.run_config import RunConfig from google.adk.agents.run_config import RunConfig
from google.adk.cli.fast_api import AgentRunRequest from google.adk.cli.fast_api import AgentRunRequest
from google.adk.cli.fast_api import get_fast_api_app from google.adk.cli.fast_api import get_fast_api_app
from google.adk.cli.utils import envs from google.adk.cli.utils import envs
from google.adk.events import Event
from google.adk.runners import Runner from google.adk.runners import Runner
from google.genai import types from google.genai import types
import httpx import httpx
@ -34,6 +36,9 @@ import pytest
from uvicorn.main import run as uvicorn_run from uvicorn.main import run as uvicorn_run
import websockets import websockets
if TYPE_CHECKING:
from google.adk.events import Event
# Here we “fake” the agent module that get_fast_api_app expects. # Here we “fake” the agent module that get_fast_api_app expects.
# The server code does: `agent_module = importlib.import_module(agent_name)` # The server code does: `agent_module = importlib.import_module(agent_name)`
@ -49,7 +54,11 @@ dummy_module.agent = ptypes.SimpleNamespace(
sys.modules["test_app"] = dummy_module sys.modules["test_app"] = dummy_module
envs.load_dotenv_for_agent("test_app", ".") envs.load_dotenv_for_agent("test_app", ".")
event1 = Event(
def _event_1():
from google.adk.events import Event
return Event(
author="dummy agent", author="dummy agent",
invocation_id="invocation_id", invocation_id="invocation_id",
content=types.Content( content=types.Content(
@ -57,7 +66,11 @@ event1 = Event(
), ),
) )
event2 = Event(
def _event_2():
from google.adk.events import Event
return Event(
author="dummy agent", author="dummy agent",
invocation_id="invocation_id", invocation_id="invocation_id",
content=types.Content( content=types.Content(
@ -73,7 +86,11 @@ event2 = Event(
), ),
) )
event3 = Event(
def _event_3():
from google.adk.events import Event
return Event(
author="dummy agent", invocation_id="invocation_id", interrupted=True author="dummy agent", invocation_id="invocation_id", interrupted=True
) )
@ -84,13 +101,13 @@ async def dummy_run_live(
self, session, live_request_queue self, session, live_request_queue
) -> AsyncGenerator[Event, None]: ) -> AsyncGenerator[Event, None]:
# Immediately yield a dummy event with a text reply. # Immediately yield a dummy event with a text reply.
yield event1 yield _event_1()
await asyncio.sleep(0) await asyncio.sleep(0)
yield event2 yield _event_2()
await asyncio.sleep(0) await asyncio.sleep(0)
yield event3 yield _event_3()
raise Exception() raise Exception()
@ -103,13 +120,13 @@ async def dummy_run_async(
run_config: RunConfig = RunConfig(), run_config: RunConfig = RunConfig(),
) -> AsyncGenerator[Event, None]: ) -> AsyncGenerator[Event, None]:
# Immediately yield a dummy event with a text reply. # Immediately yield a dummy event with a text reply.
yield event1 yield _event_1()
await asyncio.sleep(0) await asyncio.sleep(0)
yield event2 yield _event_2()
await asyncio.sleep(0) await asyncio.sleep(0)
yield event3 yield _event_3()
return return
@ -199,15 +216,15 @@ async def test_sse_endpoint():
if event_data: if event_data:
event_count += 1 event_count += 1
if event_count == 1: if event_count == 1:
assert event_data == event1.model_dump_json( assert event_data == _event_1().model_dump_json(
exclude_none=True, by_alias=True exclude_none=True, by_alias=True
) )
elif event_count == 2: elif event_count == 2:
assert event_data == event2.model_dump_json( assert event_data == _event_2().model_dump_json(
exclude_none=True, by_alias=True exclude_none=True, by_alias=True
) )
elif event_count == 3: elif event_count == 3:
assert event_data == event3.model_dump_json( assert event_data == _event_3().model_dump_json(
exclude_none=True, by_alias=True exclude_none=True, by_alias=True
) )
else: else: