chore: fix unit tests

PiperOrigin-RevId: 764107186
This commit is contained in:
Xiang (Sean) Zhou 2025-05-27 23:06:46 -07:00 committed by Copybara-Service
parent a66f12273c
commit d40df2edf2
3 changed files with 62 additions and 53 deletions

View File

@ -14,10 +14,7 @@
import asyncio import asyncio
import logging import logging
import os
import sys
import time import time
import types as ptypes
from unittest.mock import MagicMock from unittest.mock import MagicMock
from unittest.mock import patch from unittest.mock import patch
@ -25,7 +22,6 @@ from fastapi.testclient import TestClient
from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.run_config import RunConfig from google.adk.agents.run_config import RunConfig
from google.adk.cli.fast_api import get_fast_api_app from google.adk.cli.fast_api import get_fast_api_app
from google.adk.cli.utils import envs
from google.adk.events import Event from google.adk.events import Event
from google.adk.runners import Runner from google.adk.runners import Runner
from google.adk.sessions.base_session_service import ListSessionsResponse from google.adk.sessions.base_session_service import ListSessionsResponse
@ -48,22 +44,7 @@ class DummyAgent(BaseAgent):
self.sub_agents = [] self.sub_agents = []
# Set up dummy module and add to sys.modules root_agent = DummyAgent(name="dummy_agent")
dummy_module = ptypes.ModuleType("test_agent")
dummy_module.agent = ptypes.SimpleNamespace(
root_agent=DummyAgent(name="dummy_agent")
)
sys.modules["test_app"] = dummy_module
# Try to load environment variables, with a fallback for testing
try:
envs.load_dotenv_for_agent("test_app", ".")
except Exception as e:
logger.warning(f"Could not load environment variables: {e}")
# Create a basic .env file if needed
if not os.path.exists(".env"):
with open(".env", "w") as f:
f.write("# Test environment variables\n")
# Create sample events that our mocked runner will return # Create sample events that our mocked runner will return
@ -150,6 +131,20 @@ def test_session_info():
} }
@pytest.fixture
def mock_agent_loader():
class MockAgentLoader:
def __init__(self, agents_dir: str):
pass
def load_agent(self, app_name):
return root_agent
return MockAgentLoader(".")
@pytest.fixture @pytest.fixture
def mock_session_service(): def mock_session_service():
"""Create a mock session service that uses an in-memory dictionary.""" """Create a mock session service that uses an in-memory dictionary."""
@ -287,24 +282,33 @@ def mock_memory_service():
@pytest.fixture @pytest.fixture
def test_app(mock_session_service, mock_artifact_service, mock_memory_service): def test_app(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
):
"""Create a TestClient for the FastAPI app without starting a server.""" """Create a TestClient for the FastAPI app without starting a server."""
# Patch multiple services and signal handlers # Patch multiple services and signal handlers
with ( with (
patch("signal.signal", return_value=None), patch("signal.signal", return_value=None),
patch( patch(
"google.adk.cli.fast_api.InMemorySessionService", # Changed this line "google.adk.cli.fast_api.InMemorySessionService",
return_value=mock_session_service, return_value=mock_session_service,
), ),
patch( patch(
"google.adk.cli.fast_api.InMemoryArtifactService", # Make consistent "google.adk.cli.fast_api.InMemoryArtifactService",
return_value=mock_artifact_service, return_value=mock_artifact_service,
), ),
patch( patch(
"google.adk.cli.fast_api.InMemoryMemoryService", # Make consistent "google.adk.cli.fast_api.InMemoryMemoryService",
return_value=mock_memory_service, return_value=mock_memory_service,
), ),
patch(
"google.adk.cli.fast_api.AgentLoader",
return_value=mock_agent_loader,
),
): ):
# Get the FastAPI app, but don't actually run it # Get the FastAPI app, but don't actually run it
app = get_fast_api_app( app = get_fast_api_app(

View File

@ -315,7 +315,9 @@ async def test_append_event_complete(service_type):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize('service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]) @pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
async def test_get_session_with_config(service_type): async def test_get_session_with_config(service_type):
session_service = get_session_service(service_type) session_service = get_session_service(service_type)
app_name = 'my_app' app_name = 'my_app'

View File

@ -1,11 +1,11 @@
from typing import Any from typing import Any
from typing import Optional from typing import Optional
from google.adk.sessions import InMemorySessionService
from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.llm_agent import LlmAgent
from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse from google.adk.models.llm_response import LlmResponse
from google.adk.sessions import InMemorySessionService
from google.adk.telemetry import trace_call_llm from google.adk.telemetry import trace_call_llm
from google.genai import types from google.genai import types
import pytest import pytest
@ -16,10 +16,10 @@ async def _create_invocation_context(
) -> InvocationContext: ) -> InvocationContext:
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = await session_service.create_session( session = await session_service.create_session(
app_name='test_app', user_id='test_user', state=state app_name="test_app", user_id="test_user", state=state
) )
invocation_context = InvocationContext( invocation_context = InvocationContext(
invocation_id='test_id', invocation_id="test_id",
agent=agent, agent=agent,
session=session, session=session,
session_service=session_service, session_service=session_service,
@ -29,34 +29,37 @@ async def _create_invocation_context(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trace_call_llm_function_response_includes_part_from_bytes(): async def test_trace_call_llm_function_response_includes_part_from_bytes():
agent = LlmAgent(name='test_agent') agent = LlmAgent(name="test_agent")
invocation_context = await _create_invocation_context(agent) invocation_context = await _create_invocation_context(agent)
llm_request = LlmRequest( llm_request = LlmRequest(
contents=[ contents=[
types.Content( types.Content(
role="user", role="user",
parts=[ parts=[
types.Part.from_function_response( types.Part.from_function_response(
name="test_function_1", name="test_function_1",
response={ response={
"result": b"test_data", "result": b"test_data",
}, },
),
],
), ),
], types.Content(
), role="user",
types.Content( parts=[
role="user", types.Part.from_function_response(
parts=[ name="test_function_2",
types.Part.from_function_response( response={
name="test_function_2", "result": types.Part.from_bytes(
response={ data=b"test_data",
"result": types.Part.from_bytes(data=b"test_data", mime_type="application/octet-stream"), mime_type="application/octet-stream",
}, ),
},
),
],
), ),
], ],
), config=types.GenerateContentConfig(system_instruction=""),
],
config=types.GenerateContentConfig(system_instruction=""),
) )
llm_response = LlmResponse(turn_complete=True) llm_response = LlmResponse(turn_complete=True)
trace_call_llm(invocation_context, 'test_event_id', llm_request, llm_response) trace_call_llm(invocation_context, "test_event_id", llm_request, llm_response)