diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py
index 7a21cf8..8fa6ea2 100644
--- a/src/google/adk/cli/cli_eval.py
+++ b/src/google/adk/cli/cli_eval.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datetime
 from enum import Enum
 import importlib.util
 import json
@@ -25,8 +26,10 @@ from typing import Optional
 import uuid
 
 from pydantic import BaseModel
+from pydantic import Field
 
 from ..agents import Agent
+from ..sessions.session import Session
 
 logger = logging.getLogger(__name__)
 
@@ -43,16 +46,25 @@ class EvalMetric(BaseModel):
 
 
 class EvalMetricResult(BaseModel):
-  score: Optional[float]
+  score: Optional[float] = None
   eval_status: EvalStatus
 
 
-class EvalResult(BaseModel):
+class EvalCaseResult(BaseModel):
   eval_set_file: str
   eval_id: str
   final_eval_status: EvalStatus
   eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
   session_id: str
+  session_details: Optional[Session] = None
+
+
+class EvalSetResult(BaseModel):
+  eval_set_result_id: str
+  eval_set_result_name: str
+  eval_set_id: str
+  eval_case_results: list[EvalCaseResult] = Field(default_factory=list)
+  creation_timestamp: float = 0.0
 
 
 MISSING_EVAL_DEPENDENCIES_MESSAGE = (
@@ -154,7 +166,7 @@ async def run_evals(
     session_service=None,
     artifact_service=None,
     print_detailed_results=False,
-) -> AsyncGenerator[EvalResult, None]:
+) -> AsyncGenerator[EvalCaseResult, None]:
   try:
     from ..evaluation.agent_evaluator import EvaluationGenerator
     from ..evaluation.response_evaluator import ResponseEvaluator
@@ -249,7 +261,7 @@ async def run_evals(
       else:
         raise ValueError("Unknown eval status.")
 
-      yield EvalResult(
+      yield EvalCaseResult(
           eval_set_file=eval_set_file,
           eval_id=eval_name,
           final_eval_status=final_eval_status,
diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py
index 88e3026..29476f1 100644
--- a/src/google/adk/cli/cli_tools_click.py
+++ b/src/google/adk/cli/cli_tools_click.py
@@ -245,7 +245,7 @@ def cli_eval(
 
   try:
     from .cli_eval import EvalMetric
-    from .cli_eval import EvalResult
+    from .cli_eval import EvalCaseResult
     from .cli_eval import EvalStatus
     from .cli_eval import get_evaluation_criteria_or_default
     from .cli_eval import get_root_agent
@@ -269,7 +269,7 @@ def cli_eval(
 
   eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
 
-  async def _collect_eval_results() -> list[EvalResult]:
+  async def _collect_eval_results() -> list[EvalCaseResult]:
     return [
         result
         async for result in run_evals(
@@ -290,7 +290,7 @@ def cli_eval(
   eval_run_summary = {}
 
   for eval_result in eval_results:
-    eval_result: EvalResult
+    eval_result: EvalCaseResult
 
     if eval_result.eval_set_file not in eval_run_summary:
       eval_run_summary[eval_result.eval_set_file] = [0, 0]
diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py
index f52f424..0fe2958 100644
--- a/tests/unittests/cli/utils/test_cli_tools_click.py
+++ b/tests/unittests/cli/utils/test_cli_tools_click.py
@@ -250,7 +250,7 @@ def test_cli_eval_success_path(
 
     def __init__(self, metric_name: str, threshold: float) -> None: ...
 
-  class _EvalResult:
+  class _EvalCaseResult:
 
     def __init__(self, eval_set_file: str, final_eval_status: str) -> None:
       self.eval_set_file = eval_set_file
@@ -261,7 +261,7 @@ def test_cli_eval_success_path(
 
   # helper funcs
   stub.EvalMetric = _EvalMetric
-  stub.EvalResult = _EvalResult
+  stub.EvalCaseResult = _EvalCaseResult
   stub.EvalStatus = _EvalStatus
   stub.MISSING_EVAL_DEPENDENCIES_MESSAGE = "stub msg"
 
@@ -272,8 +272,8 @@ def test_cli_eval_success_path(
 
   # Create an async generator function for run_evals
   async def mock_run_evals(*_a, **_k):
-    yield _EvalResult("set1.json", "PASSED")
-    yield _EvalResult("set1.json", "FAILED")
+    yield _EvalCaseResult("set1.json", "PASSED")
+    yield _EvalCaseResult("set1.json", "FAILED")
 
   stub.run_evals = mock_run_evals
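
Note: the renamed EvalCaseResult holds the outcome of a single eval case, while the new EvalSetResult groups those per-case results into one set-level record. Below is a minimal sketch of how the two models compose, assuming the module path shown in the diff; the summarize_eval_set helper and the uuid/time usage are illustrative assumptions, not APIs introduced by this change.

# Illustrative sketch only -- `summarize_eval_set` is a hypothetical helper,
# not part of the diff above; it uses only fields defined on the new models.
import time
import uuid

from google.adk.cli.cli_eval import EvalCaseResult
from google.adk.cli.cli_eval import EvalSetResult


def summarize_eval_set(
    eval_set_id: str, case_results: list[EvalCaseResult]
) -> EvalSetResult:
  """Bundles per-case results into a single set-level result."""
  result_id = str(uuid.uuid4())
  return EvalSetResult(
      eval_set_result_id=result_id,
      eval_set_result_name=f"{eval_set_id}_{result_id}",
      eval_set_id=eval_set_id,
      eval_case_results=case_results,
      creation_timestamp=time.time(),
  )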