mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
chore: Fixes test_fast_api.py (part I for circular deps).
It still fails due to signal used not in main thread. It will be fixed later. PiperOrigin-RevId: 760050504
This commit is contained in:
parent
f592de4cc0
commit
9324801b75
@ -12,23 +12,29 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
from vertexai.preview import rag
|
||||
|
||||
from ..events.event import Event
|
||||
from ..sessions.session import Session
|
||||
from . import _utils
|
||||
from .base_memory_service import BaseMemoryService
|
||||
from .base_memory_service import SearchMemoryResponse
|
||||
from .memory_entry import MemoryEntry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..events.event import Event
|
||||
from ..sessions.session import Session
|
||||
|
||||
|
||||
class VertexAiRagMemoryService(BaseMemoryService):
|
||||
"""A memory service that uses Vertex AI RAG for storage and retrieval."""
|
||||
@ -103,6 +109,8 @@ class VertexAiRagMemoryService(BaseMemoryService):
|
||||
self, *, app_name: str, user_id: str, query: str
|
||||
) -> SearchMemoryResponse:
|
||||
"""Searches for sessions that match the query using rag.retrieval_query."""
|
||||
from ..events.event import Event
|
||||
|
||||
response = rag.retrieval_query(
|
||||
text=query,
|
||||
rag_resources=self._vertex_rag_store.rag_resources,
|
||||
|
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
@ -19,14 +21,14 @@ import threading
|
||||
import time
|
||||
import types as ptypes
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.adk.agents import BaseAgent
|
||||
from google.adk.agents import LiveRequest
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.live_request_queue import LiveRequest
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.cli.fast_api import AgentRunRequest
|
||||
from google.adk.cli.fast_api import get_fast_api_app
|
||||
from google.adk.cli.utils import envs
|
||||
from google.adk.events import Event
|
||||
from google.adk.runners import Runner
|
||||
from google.genai import types
|
||||
import httpx
|
||||
@ -34,6 +36,9 @@ import pytest
|
||||
from uvicorn.main import run as uvicorn_run
|
||||
import websockets
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.adk.events import Event
|
||||
|
||||
|
||||
# Here we “fake” the agent module that get_fast_api_app expects.
|
||||
# The server code does: `agent_module = importlib.import_module(agent_name)`
|
||||
@ -49,33 +54,45 @@ dummy_module.agent = ptypes.SimpleNamespace(
|
||||
sys.modules["test_app"] = dummy_module
|
||||
envs.load_dotenv_for_agent("test_app", ".")
|
||||
|
||||
event1 = Event(
|
||||
author="dummy agent",
|
||||
invocation_id="invocation_id",
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(text="LLM reply", inline_data=None)]
|
||||
),
|
||||
)
|
||||
|
||||
event2 = Event(
|
||||
author="dummy agent",
|
||||
invocation_id="invocation_id",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part(
|
||||
text=None,
|
||||
inline_data=types.Blob(
|
||||
mime_type="audio/pcm;rate=24000", data=b"\x00\xFF"
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
def _event_1():
|
||||
from google.adk.events import Event
|
||||
|
||||
event3 = Event(
|
||||
author="dummy agent", invocation_id="invocation_id", interrupted=True
|
||||
)
|
||||
return Event(
|
||||
author="dummy agent",
|
||||
invocation_id="invocation_id",
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(text="LLM reply", inline_data=None)]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _event_2():
|
||||
from google.adk.events import Event
|
||||
|
||||
return Event(
|
||||
author="dummy agent",
|
||||
invocation_id="invocation_id",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[
|
||||
types.Part(
|
||||
text=None,
|
||||
inline_data=types.Blob(
|
||||
mime_type="audio/pcm;rate=24000", data=b"\x00\xFF"
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _event_3():
|
||||
from google.adk.events import Event
|
||||
|
||||
return Event(
|
||||
author="dummy agent", invocation_id="invocation_id", interrupted=True
|
||||
)
|
||||
|
||||
|
||||
# For simplicity, we patch Runner.run_live to yield dummy events.
|
||||
@ -84,13 +101,13 @@ async def dummy_run_live(
|
||||
self, session, live_request_queue
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
# Immediately yield a dummy event with a text reply.
|
||||
yield event1
|
||||
yield _event_1()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield event2
|
||||
yield _event_2()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield event3
|
||||
yield _event_3()
|
||||
|
||||
raise Exception()
|
||||
|
||||
@ -103,13 +120,13 @@ async def dummy_run_async(
|
||||
run_config: RunConfig = RunConfig(),
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
# Immediately yield a dummy event with a text reply.
|
||||
yield event1
|
||||
yield _event_1()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield event2
|
||||
yield _event_2()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
yield event3
|
||||
yield _event_3()
|
||||
|
||||
return
|
||||
|
||||
@ -199,15 +216,15 @@ async def test_sse_endpoint():
|
||||
if event_data:
|
||||
event_count += 1
|
||||
if event_count == 1:
|
||||
assert event_data == event1.model_dump_json(
|
||||
assert event_data == _event_1().model_dump_json(
|
||||
exclude_none=True, by_alias=True
|
||||
)
|
||||
elif event_count == 2:
|
||||
assert event_data == event2.model_dump_json(
|
||||
assert event_data == _event_2().model_dump_json(
|
||||
exclude_none=True, by_alias=True
|
||||
)
|
||||
elif event_count == 3:
|
||||
assert event_data == event3.model_dump_json(
|
||||
assert event_data == _event_3().model_dump_json(
|
||||
exclude_none=True, by_alias=True
|
||||
)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user