chore: fix fast api ut

PiperOrigin-RevId: 764935253
This commit is contained in:
Xiang (Sean) Zhou 2025-05-29 16:40:01 -07:00 committed by Copybara-Service
parent 41ba2d1c8a
commit 2b41824e46
2 changed files with 94 additions and 38 deletions

View File

@ -23,21 +23,23 @@ from types import SimpleNamespace
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Optional
from typing import Tuple from typing import Tuple
import click import click
from click.testing import CliRunner from click.testing import CliRunner
from google.adk.cli import cli_tools_click from google.adk.cli import cli_tools_click
from google.adk.evaluation import local_eval_set_results_manager from google.adk.evaluation import local_eval_set_results_manager
from google.adk.sessions import Session
from pydantic import BaseModel
import pytest import pytest
# Helpers # Helpers
class _Recorder: class _Recorder(BaseModel):
"""Callable that records every invocation.""" """Callable that records every invocation."""
def __init__(self) -> None: calls: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
self.calls: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
def __call__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401 def __call__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401
self.calls.append((args, kwargs)) self.calls.append((args, kwargs))
@ -254,30 +256,23 @@ def test_cli_eval_success_path(
def __init__(self, metric_name: str, threshold: float) -> None: def __init__(self, metric_name: str, threshold: float) -> None:
... ...
class _EvalCaseResult: class _EvalCaseResult(BaseModel):
eval_set_id: str
eval_id: str
final_eval_status: Any
user_id: str
session_id: str
session_details: Optional[Session] = None
eval_metric_results: list = {}
overall_eval_metric_results: list = {}
eval_metric_result_per_invocation: list = {}
def __init__( class EvalCase(BaseModel):
self, eval_id: str
eval_set_id: str,
final_eval_status: str,
user_id: str,
session_id: str,
) -> None:
self.eval_set_id = eval_set_id
self.final_eval_status = final_eval_status
self.user_id = user_id
self.session_id = session_id
class EvalCase: class EvalSet(BaseModel):
eval_set_id: str
def __init__(self, eval_id: str): eval_cases: list[EvalCase]
self.eval_id = eval_id
class EvalSet:
def __init__(self, eval_set_id: str, eval_cases: list[EvalCase]):
self.eval_set_id = eval_set_id
self.eval_cases = eval_cases
def mock_save_eval_set_result(cls, *args, **kwargs): def mock_save_eval_set_result(cls, *args, **kwargs):
return None return None
@ -302,13 +297,38 @@ def test_cli_eval_success_path(
stub.try_get_reset_func = lambda _p: None stub.try_get_reset_func = lambda _p: None
stub.parse_and_get_evals_to_run = lambda _paths: {"set1.json": ["e1", "e2"]} stub.parse_and_get_evals_to_run = lambda _paths: {"set1.json": ["e1", "e2"]}
eval_sets_manager_stub.load_eval_set_from_file = lambda x, y: EvalSet( eval_sets_manager_stub.load_eval_set_from_file = lambda x, y: EvalSet(
"test_eval_set_id", [EvalCase("e1"), EvalCase("e2")] eval_set_id="test_eval_set_id",
eval_cases=[EvalCase(eval_id="e1"), EvalCase(eval_id="e2")],
) )
# Create an async generator function for run_evals # Create an async generator function for run_evals
async def mock_run_evals(*_a, **_k): async def mock_run_evals(*_a, **_k):
yield _EvalCaseResult("set1.json", "PASSED", "user", "session1") yield _EvalCaseResult(
yield _EvalCaseResult("set1.json", "FAILED", "user", "session2") eval_set_id="set1.json",
eval_id="e1",
final_eval_status=_EvalStatus.PASSED,
user_id="user",
session_id="session1",
overall_eval_metric_results=[{
"metricName": "some_metric",
"threshold": 0.0,
"score": 1.0,
"evalStatus": _EvalStatus.PASSED,
}],
)
yield _EvalCaseResult(
eval_set_id="set1.json",
eval_id="e2",
final_eval_status=_EvalStatus.FAILED,
user_id="user",
session_id="session2",
overall_eval_metric_results=[{
"metricName": "some_metric",
"threshold": 0.0,
"score": 0.0,
"evalStatus": _EvalStatus.FAILED,
}],
)
stub.run_evals = mock_run_evals stub.run_evals = mock_run_evals
@ -324,9 +344,11 @@ def test_cli_eval_success_path(
monkeypatch.setattr(cli_tools_click.asyncio, "run", mock_asyncio_run) monkeypatch.setattr(cli_tools_click.asyncio, "run", mock_asyncio_run)
# inject stub # inject stub
sys.modules["google.adk.cli.cli_eval"] = stub monkeypatch.setitem(sys.modules, "google.adk.cli.cli_eval", stub)
sys.modules["google.adk.evaluation.local_eval_sets_manager"] = ( monkeypatch.setitem(
eval_sets_manager_stub sys.modules,
"google.adk.evaluation.local_eval_sets_manager",
eval_sets_manager_stub,
) )
# create dummy agent directory # create dummy agent directory

View File

@ -15,6 +15,8 @@
import asyncio import asyncio
import logging import logging
import time import time
from typing import Any
from typing import Optional
from unittest.mock import MagicMock from unittest.mock import MagicMock
from unittest.mock import patch from unittest.mock import patch
@ -30,6 +32,7 @@ from google.adk.events import Event
from google.adk.runners import Runner from google.adk.runners import Runner
from google.adk.sessions.base_session_service import ListSessionsResponse from google.adk.sessions.base_session_service import ListSessionsResponse
from google.genai import types from google.genai import types
from pydantic import BaseModel
import pytest import pytest
# Configure logging to help diagnose server startup issues # Configure logging to help diagnose server startup issues
@ -113,6 +116,40 @@ async def dummy_run_async(
yield _event_3() yield _event_3()
# Define a local mock for EvalCaseResult specific to fast_api tests
class _MockEvalCaseResult(BaseModel):
eval_set_id: str
eval_id: str
final_eval_status: Any
user_id: str
session_id: str
eval_set_file: str
eval_metric_results: list = {}
overall_eval_metric_results: list = ({},)
eval_metric_result_per_invocation: list = {}
# Mock for the run_evals function, tailored for test_run_eval
async def mock_run_evals_for_fast_api(*args, **kwargs):
# This is what the test_run_eval expects for its assertions
yield _MockEvalCaseResult(
eval_set_id="test_eval_set_id", # Matches expected in verify_eval_case_result
eval_id="test_eval_case_id", # Matches expected
final_eval_status=1, # Matches expected (assuming 1 is PASSED)
user_id="test_user", # Placeholder, adapt if needed
session_id="test_session_for_eval_case", # Placeholder
overall_eval_metric_results=[{ # Matches expected
"metricName": "tool_trajectory_avg_score",
"threshold": 0.5,
"score": 1.0,
"evalStatus": 1,
}],
# Provide other fields if RunEvalResult or subsequent processing needs them
eval_metric_results=[],
eval_metric_result_per_invocation=[],
)
################################################# #################################################
# Test Fixtures # Test Fixtures
################################################# #################################################
@ -414,6 +451,10 @@ def test_app(
"google.adk.cli.fast_api.LocalEvalSetResultsManager", "google.adk.cli.fast_api.LocalEvalSetResultsManager",
return_value=mock_eval_set_results_manager, return_value=mock_eval_set_results_manager,
), ),
patch(
"google.adk.cli.cli_eval.run_evals", # Patch where it's imported in fast_api.py
new=mock_run_evals_for_fast_api,
),
): ):
# Get the FastAPI app, but don't actually run it # Get the FastAPI app, but don't actually run it
app = get_fast_api_app( app = get_fast_api_app(
@ -613,13 +654,6 @@ def test_list_artifact_names(test_app, create_test_session):
logger.info(f"Listed {len(data)} artifacts") logger.info(f"Listed {len(data)} artifacts")
def test_get_eval_set_not_found(test_app):
"""Test getting an eval set that doesn't exist."""
url = "/apps/test_app_name/eval_sets/test_eval_set_id_not_found"
response = test_app.get(url)
assert response.status_code == 404
def test_create_eval_set(test_app, test_session_info): def test_create_eval_set(test_app, test_session_info):
"""Test creating an eval set.""" """Test creating an eval set."""
url = f"/apps/{test_session_info['app_name']}/eval_sets/test_eval_set_id" url = f"/apps/{test_session_info['app_name']}/eval_sets/test_eval_set_id"