Make eval functions async

PiperOrigin-RevId: 756106627
This commit is contained in:
Xiang (Sean) Zhou
2025-05-07 19:52:47 -07:00
committed by Copybara-Service
parent cc1ef3f2ad
commit e7d9cf359a
11 changed files with 50 additions and 36 deletions
+13 -11
View File
@@ -20,7 +20,7 @@ import os
import sys
import traceback
from typing import Any
from typing import Generator
from typing import AsyncGenerator
from typing import Optional
import uuid
@@ -146,7 +146,7 @@ def parse_and_get_evals_to_run(
return eval_set_to_evals
def run_evals(
async def run_evals(
eval_set_to_evals: dict[str, list[str]],
root_agent: Agent,
reset_func: Optional[Any],
@@ -154,7 +154,7 @@ def run_evals(
session_service=None,
artifact_service=None,
print_detailed_results=False,
) -> Generator[EvalResult, None, None]:
) -> AsyncGenerator[EvalResult, None]:
try:
from ..evaluation.agent_evaluator import EvaluationGenerator
from ..evaluation.response_evaluator import ResponseEvaluator
@@ -181,14 +181,16 @@ def run_evals(
print(f"Running Eval: {eval_set_file}:{eval_name}")
session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"
scrape_result = EvaluationGenerator._process_query_with_root_agent(
data=eval_data,
root_agent=root_agent,
reset_func=reset_func,
initial_session=initial_session,
session_id=session_id,
session_service=session_service,
artifact_service=artifact_service,
scrape_result = (
await EvaluationGenerator._process_query_with_root_agent(
data=eval_data,
root_agent=root_agent,
reset_func=reset_func,
initial_session=initial_session,
session_id=session_id,
session_service=session_service,
artifact_service=artifact_service,
)
)
eval_metric_results = []
+8 -6
View File
@@ -258,12 +258,14 @@ def cli_eval(
try:
eval_results = list(
run_evals(
eval_set_to_evals,
root_agent,
reset_func,
eval_metrics,
print_detailed_results=print_detailed_results,
asyncio.run(
run_evals(
eval_set_to_evals,
root_agent,
reset_func,
eval_metrics,
print_detailed_results=print_detailed_results,
)
)
)
except ModuleNotFoundError:
+1 -1
View File
@@ -467,7 +467,7 @@ def get_fast_api_app(
)
root_agent = await _get_root_agent_async(app_name)
eval_results = list(
run_evals(
await run_evals(
eval_set_to_evals,
root_agent,
getattr(root_agent, "reset_data", None),