Write eval results locally from adk eval cli.

PiperOrigin-RevId: 762499588
This commit is contained in:
Google Team Member 2025-05-23 11:15:16 -07:00 committed by Copybara-Service
parent 33921d524f
commit 79681e3513
2 changed files with 46 additions and 10 deletions

View File

@ -13,11 +13,14 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import collections
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
import logging import logging
import os import os
import tempfile import tempfile
from typing import AsyncGenerator
from typing import Coroutine
from typing import Optional from typing import Optional
import click import click
@ -27,6 +30,8 @@ import uvicorn
from . import cli_create from . import cli_create
from . import cli_deploy from . import cli_deploy
from .. import version from .. import version
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
from ..sessions.in_memory_session_service import InMemorySessionService
from .cli import run_cli from .cli import run_cli
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
from .fast_api import get_fast_api_app from .fast_api import get_fast_api_app
@ -306,7 +311,7 @@ def cli_eval(
EvalMetric(metric_name=metric_name, threshold=threshold) EvalMetric(metric_name=metric_name, threshold=threshold)
) )
print(f"Using evaluation creiteria: {evaluation_criteria}") print(f"Using evaluation criteria: {evaluation_criteria}")
root_agent = get_root_agent(agent_module_file_path) root_agent = get_root_agent(agent_module_file_path)
reset_func = try_get_reset_func(agent_module_file_path) reset_func = try_get_reset_func(agent_module_file_path)
@ -325,21 +330,47 @@ def cli_eval(
e for e in eval_set.eval_cases if e.eval_id in eval_case_ids e for e in eval_set.eval_cases if e.eval_id in eval_case_ids
] ]
eval_set_id_to_eval_cases[eval_set_file_path] = eval_cases eval_set_id_to_eval_cases[eval_set.eval_set_id] = eval_cases
async def _collect_eval_results() -> list[EvalCaseResult]: async def _collect_eval_results() -> list[EvalCaseResult]:
return [ session_service = InMemorySessionService()
result eval_case_results = []
async for result in run_evals( async for eval_case_result in run_evals(
eval_set_id_to_eval_cases, root_agent, reset_func, eval_metrics eval_set_id_to_eval_cases,
) root_agent,
] reset_func,
eval_metrics,
session_service=session_service,
):
eval_case_result.session_details = await session_service.get_session(
app_name=os.path.basename(agent_module_file_path),
user_id=eval_case_result.user_id,
session_id=eval_case_result.session_id,
)
eval_case_results.append(eval_case_result)
return eval_case_results
try: try:
eval_results = asyncio.run(_collect_eval_results()) eval_results = asyncio.run(_collect_eval_results())
except ModuleNotFoundError: except ModuleNotFoundError:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
# Write eval set results.
local_eval_set_results_manager = LocalEvalSetResultsManager(
agent_dir=os.path.dirname(agent_module_file_path)
)
eval_set_id_to_eval_results = collections.defaultdict(list)
for eval_case_result in eval_results:
eval_set_id = eval_case_result.eval_set_id
eval_set_id_to_eval_results[eval_set_id].append(eval_case_result)
for eval_set_id, eval_case_results in eval_set_id_to_eval_results.items():
local_eval_set_results_manager.save_eval_set_result(
app_name=os.path.basename(agent_module_file_path),
eval_set_id=eval_set_id,
eval_case_results=eval_case_results,
)
print("*********************************************************************") print("*********************************************************************")
eval_run_summary = {} eval_run_summary = {}

View File

@ -29,6 +29,10 @@ _ADK_EVAL_HISTORY_DIR = ".adk/eval_history"
_EVAL_SET_RESULT_FILE_EXTENSION = ".evalset_result.json" _EVAL_SET_RESULT_FILE_EXTENSION = ".evalset_result.json"
def _sanitize_eval_set_result_name(eval_set_result_name: str) -> str:
return eval_set_result_name.replace("/", "_")
class LocalEvalSetResultsManager(EvalSetResultsManager): class LocalEvalSetResultsManager(EvalSetResultsManager):
"""An EvalSetResult manager that stores eval set results locally on disk.""" """An EvalSetResult manager that stores eval set results locally on disk."""
@ -44,9 +48,10 @@ class LocalEvalSetResultsManager(EvalSetResultsManager):
) -> None: ) -> None:
"""Creates and saves a new EvalSetResult given eval_case_results.""" """Creates and saves a new EvalSetResult given eval_case_results."""
timestamp = time.time() timestamp = time.time()
eval_set_result_name = app_name + "_" + eval_set_id + "_" + str(timestamp) eval_set_result_id = app_name + "_" + eval_set_id + "_" + str(timestamp)
eval_set_result_name = _sanitize_eval_set_result_name(eval_set_result_id)
eval_set_result = EvalSetResult( eval_set_result = EvalSetResult(
eval_set_result_id=eval_set_result_name, eval_set_result_id=eval_set_result_id,
eval_set_result_name=eval_set_result_name, eval_set_result_name=eval_set_result_name,
eval_set_id=eval_set_id, eval_set_id=eval_set_id,
eval_case_results=eval_case_results, eval_case_results=eval_case_results,