chore: fix FastAPI unit tests

PiperOrigin-RevId: 764935253
This commit is contained in:
Xiang (Sean) Zhou
2025-05-29 16:40:01 -07:00
committed by Copybara-Service
parent 41ba2d1c8a
commit 2b41824e46
2 changed files with 94 additions and 38 deletions

View File

@@ -23,21 +23,23 @@ from types import SimpleNamespace
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import click
from click.testing import CliRunner
from google.adk.cli import cli_tools_click
from google.adk.evaluation import local_eval_set_results_manager
from google.adk.sessions import Session
from pydantic import BaseModel
import pytest
# Helpers
class _Recorder:
class _Recorder(BaseModel):
"""Callable that records every invocation."""
def __init__(self) -> None:
self.calls: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
calls: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
def __call__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401
self.calls.append((args, kwargs))
@@ -254,30 +256,23 @@ def test_cli_eval_success_path(
def __init__(self, metric_name: str, threshold: float) -> None:
...
class _EvalCaseResult:
class _EvalCaseResult(BaseModel):
eval_set_id: str
eval_id: str
final_eval_status: Any
user_id: str
session_id: str
session_details: Optional[Session] = None
eval_metric_results: list = {}
overall_eval_metric_results: list = {}
eval_metric_result_per_invocation: list = {}
def __init__(
self,
eval_set_id: str,
final_eval_status: str,
user_id: str,
session_id: str,
) -> None:
self.eval_set_id = eval_set_id
self.final_eval_status = final_eval_status
self.user_id = user_id
self.session_id = session_id
class EvalCase(BaseModel):
eval_id: str
class EvalCase:
def __init__(self, eval_id: str):
self.eval_id = eval_id
class EvalSet:
def __init__(self, eval_set_id: str, eval_cases: list[EvalCase]):
self.eval_set_id = eval_set_id
self.eval_cases = eval_cases
class EvalSet(BaseModel):
eval_set_id: str
eval_cases: list[EvalCase]
def mock_save_eval_set_result(cls, *args, **kwargs):
return None
@@ -302,13 +297,38 @@ def test_cli_eval_success_path(
stub.try_get_reset_func = lambda _p: None
stub.parse_and_get_evals_to_run = lambda _paths: {"set1.json": ["e1", "e2"]}
eval_sets_manager_stub.load_eval_set_from_file = lambda x, y: EvalSet(
"test_eval_set_id", [EvalCase("e1"), EvalCase("e2")]
eval_set_id="test_eval_set_id",
eval_cases=[EvalCase(eval_id="e1"), EvalCase(eval_id="e2")],
)
# Create an async generator function for run_evals
async def mock_run_evals(*_a, **_k):
yield _EvalCaseResult("set1.json", "PASSED", "user", "session1")
yield _EvalCaseResult("set1.json", "FAILED", "user", "session2")
yield _EvalCaseResult(
eval_set_id="set1.json",
eval_id="e1",
final_eval_status=_EvalStatus.PASSED,
user_id="user",
session_id="session1",
overall_eval_metric_results=[{
"metricName": "some_metric",
"threshold": 0.0,
"score": 1.0,
"evalStatus": _EvalStatus.PASSED,
}],
)
yield _EvalCaseResult(
eval_set_id="set1.json",
eval_id="e2",
final_eval_status=_EvalStatus.FAILED,
user_id="user",
session_id="session2",
overall_eval_metric_results=[{
"metricName": "some_metric",
"threshold": 0.0,
"score": 0.0,
"evalStatus": _EvalStatus.FAILED,
}],
)
stub.run_evals = mock_run_evals
@@ -324,9 +344,11 @@ def test_cli_eval_success_path(
monkeypatch.setattr(cli_tools_click.asyncio, "run", mock_asyncio_run)
# inject stub
sys.modules["google.adk.cli.cli_eval"] = stub
sys.modules["google.adk.evaluation.local_eval_sets_manager"] = (
eval_sets_manager_stub
monkeypatch.setitem(sys.modules, "google.adk.cli.cli_eval", stub)
monkeypatch.setitem(
sys.modules,
"google.adk.evaluation.local_eval_sets_manager",
eval_sets_manager_stub,
)
# create dummy agent directory