From d40df2edf23bed52da383ea78dad315bab34d57b Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 27 May 2025 23:06:46 -0700 Subject: [PATCH] chore: fix unit tests PiperOrigin-RevId: 764107186 --- tests/unittests/fast_api/test_fast_api.py | 52 ++++++++-------- .../sessions/test_session_service.py | 4 +- tests/unittests/test_telemetry.py | 59 ++++++++++--------- 3 files changed, 62 insertions(+), 53 deletions(-) diff --git a/tests/unittests/fast_api/test_fast_api.py b/tests/unittests/fast_api/test_fast_api.py index 9729098..ad12c68 100644 --- a/tests/unittests/fast_api/test_fast_api.py +++ b/tests/unittests/fast_api/test_fast_api.py @@ -14,10 +14,7 @@ import asyncio import logging -import os -import sys import time -import types as ptypes from unittest.mock import MagicMock from unittest.mock import patch @@ -25,7 +22,6 @@ from fastapi.testclient import TestClient from google.adk.agents.base_agent import BaseAgent from google.adk.agents.run_config import RunConfig from google.adk.cli.fast_api import get_fast_api_app -from google.adk.cli.utils import envs from google.adk.events import Event from google.adk.runners import Runner from google.adk.sessions.base_session_service import ListSessionsResponse @@ -48,22 +44,7 @@ class DummyAgent(BaseAgent): self.sub_agents = [] -# Set up dummy module and add to sys.modules -dummy_module = ptypes.ModuleType("test_agent") -dummy_module.agent = ptypes.SimpleNamespace( - root_agent=DummyAgent(name="dummy_agent") -) -sys.modules["test_app"] = dummy_module - -# Try to load environment variables, with a fallback for testing -try: - envs.load_dotenv_for_agent("test_app", ".") -except Exception as e: - logger.warning(f"Could not load environment variables: {e}") - # Create a basic .env file if needed - if not os.path.exists(".env"): - with open(".env", "w") as f: - f.write("# Test environment variables\n") +root_agent = DummyAgent(name="dummy_agent") # Create sample events that our mocked runner will return @@ -150,6 +131,20 @@ def test_session_info(): } +@pytest.fixture +def mock_agent_loader(): + + class MockAgentLoader: + + def __init__(self, agents_dir: str): + pass + + def load_agent(self, app_name): + return root_agent + + return MockAgentLoader(".") + + @pytest.fixture def mock_session_service(): """Create a mock session service that uses an in-memory dictionary.""" @@ -287,24 +282,33 @@ def mock_memory_service(): @pytest.fixture -def test_app(mock_session_service, mock_artifact_service, mock_memory_service): +def test_app( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, +): """Create a TestClient for the FastAPI app without starting a server.""" # Patch multiple services and signal handlers with ( patch("signal.signal", return_value=None), patch( - "google.adk.cli.fast_api.InMemorySessionService", # Changed this line + "google.adk.cli.fast_api.InMemorySessionService", return_value=mock_session_service, ), patch( - "google.adk.cli.fast_api.InMemoryArtifactService", # Make consistent + "google.adk.cli.fast_api.InMemoryArtifactService", return_value=mock_artifact_service, ), patch( - "google.adk.cli.fast_api.InMemoryMemoryService", # Make consistent + "google.adk.cli.fast_api.InMemoryMemoryService", return_value=mock_memory_service, ), + patch( + "google.adk.cli.fast_api.AgentLoader", + return_value=mock_agent_loader, + ), ): # Get the FastAPI app, but don't actually run it app = get_fast_api_app( diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index e28f7ff..676fb7d 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -315,7 +315,9 @@ async def test_append_event_complete(service_type): @pytest.mark.asyncio -@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]) +@pytest.mark.parametrize( + 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE] +) async def test_get_session_with_config(service_type): session_service = get_session_service(service_type) app_name = 'my_app' diff --git a/tests/unittests/test_telemetry.py b/tests/unittests/test_telemetry.py index 5a77584..64da250 100644 --- a/tests/unittests/test_telemetry.py +++ b/tests/unittests/test_telemetry.py @@ -1,11 +1,11 @@ from typing import Any from typing import Optional -from google.adk.sessions import InMemorySessionService from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse +from google.adk.sessions import InMemorySessionService from google.adk.telemetry import trace_call_llm from google.genai import types import pytest @@ -16,10 +16,10 @@ async def _create_invocation_context( ) -> InvocationContext: session_service = InMemorySessionService() session = await session_service.create_session( - app_name='test_app', user_id='test_user', state=state + app_name="test_app", user_id="test_user", state=state ) invocation_context = InvocationContext( - invocation_id='test_id', + invocation_id="test_id", agent=agent, session=session, session_service=session_service, @@ -29,34 +29,37 @@ async def _create_invocation_context( @pytest.mark.asyncio async def test_trace_call_llm_function_response_includes_part_from_bytes(): - agent = LlmAgent(name='test_agent') + agent = LlmAgent(name="test_agent") invocation_context = await _create_invocation_context(agent) llm_request = LlmRequest( - contents=[ - types.Content( - role="user", - parts=[ - types.Part.from_function_response( - name="test_function_1", - response={ - "result": b"test_data", - }, + contents=[ + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name="test_function_1", + response={ + "result": b"test_data", + }, + ), + ], ), - ], - ), - types.Content( - role="user", - parts=[ - types.Part.from_function_response( - name="test_function_2", - response={ - "result": types.Part.from_bytes(data=b"test_data", mime_type="application/octet-stream"), - }, + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name="test_function_2", + response={ + "result": types.Part.from_bytes( + data=b"test_data", + mime_type="application/octet-stream", + ), + }, + ), + ], ), - ], - ), - ], - config=types.GenerateContentConfig(system_instruction=""), + ], + config=types.GenerateContentConfig(system_instruction=""), ) llm_response = LlmResponse(turn_complete=True) - trace_call_llm(invocation_context, 'test_event_id', llm_request, llm_response) + trace_call_llm(invocation_context, "test_event_id", llm_request, llm_response)