make eval functions async

PiperOrigin-RevId: 756106627
This commit is contained in:
Xiang (Sean) Zhou 2025-05-07 19:52:47 -07:00 committed by Copybara-Service
parent cc1ef3f2ad
commit e7d9cf359a
11 changed files with 50 additions and 36 deletions

View File

@ -20,7 +20,7 @@ import os
import sys
import traceback
from typing import Any
from typing import Generator
from typing import AsyncGenerator
from typing import Optional
import uuid
@ -146,7 +146,7 @@ def parse_and_get_evals_to_run(
return eval_set_to_evals
def run_evals(
async def run_evals(
eval_set_to_evals: dict[str, list[str]],
root_agent: Agent,
reset_func: Optional[Any],
@ -154,7 +154,7 @@ def run_evals(
session_service=None,
artifact_service=None,
print_detailed_results=False,
) -> Generator[EvalResult, None, None]:
) -> AsyncGenerator[EvalResult, None]:
try:
from ..evaluation.agent_evaluator import EvaluationGenerator
from ..evaluation.response_evaluator import ResponseEvaluator
@ -181,7 +181,8 @@ def run_evals(
print(f"Running Eval: {eval_set_file}:{eval_name}")
session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"
scrape_result = EvaluationGenerator._process_query_with_root_agent(
scrape_result = (
await EvaluationGenerator._process_query_with_root_agent(
data=eval_data,
root_agent=root_agent,
reset_func=reset_func,
@ -190,6 +191,7 @@ def run_evals(
session_service=session_service,
artifact_service=artifact_service,
)
)
eval_metric_results = []
for eval_metric in eval_metrics:

View File

@ -258,6 +258,7 @@ def cli_eval(
try:
eval_results = list(
asyncio.run(
run_evals(
eval_set_to_evals,
root_agent,
@ -266,6 +267,7 @@ def cli_eval(
print_detailed_results=print_detailed_results,
)
)
)
except ModuleNotFoundError:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)

View File

@ -467,7 +467,7 @@ def get_fast_api_app(
)
root_agent = await _get_root_agent_async(app_name)
eval_results = list(
run_evals(
await run_evals(
eval_set_to_evals,
root_agent,
getattr(root_agent, "reset_data", None),

View File

@ -76,7 +76,7 @@ class AgentEvaluator:
return DEFAULT_CRITERIA
@staticmethod
def evaluate(
async def evaluate(
agent_module,
eval_dataset_file_path_or_dir,
num_runs=NUM_RUNS,
@ -120,7 +120,7 @@ class AgentEvaluator:
AgentEvaluator._validate_input([dataset], criteria)
evaluation_response = AgentEvaluator._generate_responses(
evaluation_response = await AgentEvaluator._generate_responses(
agent_module,
[dataset],
num_runs,
@ -246,7 +246,7 @@ class AgentEvaluator:
return inferred_criteria
@staticmethod
def _generate_responses(
async def _generate_responses(
agent_module, eval_dataset, num_runs, agent_name=None, initial_session={}
):
"""Generates evaluation responses by running the agent module multiple times."""

View File

@ -32,7 +32,7 @@ class EvaluationGenerator:
"""Generates evaluation responses for agents."""
@staticmethod
def generate_responses(
async def generate_responses(
eval_dataset,
agent_module_path,
repeat_num=3,
@ -107,7 +107,7 @@ class EvaluationGenerator:
)
@staticmethod
def _process_query_with_root_agent(
async def _process_query_with_root_agent(
data,
root_agent,
reset_func,
@ -128,7 +128,7 @@ class EvaluationGenerator:
all_mock_tools.add(expected[EvalConstants.TOOL_NAME])
eval_data_copy = data.copy()
EvaluationGenerator.apply_before_tool_callback(
await EvaluationGenerator.apply_before_tool_callback(
root_agent,
lambda *args: EvaluationGenerator.before_tool_callback(
*args, eval_dataset=eval_data_copy
@ -247,7 +247,7 @@ class EvaluationGenerator:
return None
@staticmethod
def apply_before_tool_callback(
async def apply_before_tool_callback(
agent: BaseAgent,
callback: BeforeToolCallback,
all_mock_tools: set[str],
@ -265,6 +265,6 @@ class EvaluationGenerator:
# Apply recursively to subagents if they exist
for sub_agent in agent.sub_agents:
EvaluationGenerator.apply_before_tool_callback(
await EvaluationGenerator.apply_before_tool_callback(
sub_agent, callback, all_mock_tools
)

View File

@ -51,12 +51,13 @@ def agent_eval_artifacts_in_fixture():
return agent_eval_artifacts
@pytest.mark.asyncio
@pytest.mark.parametrize(
'agent_name, evalfile, initial_session_file',
agent_eval_artifacts_in_fixture(),
ids=[agent_name for agent_name, _, _ in agent_eval_artifacts_in_fixture()],
)
def test_evaluate_agents_long_running_4_runs_per_eval_item(
async def test_evaluate_agents_long_running_4_runs_per_eval_item(
agent_name, evalfile, initial_session_file
):
"""Test agents evaluation in fixture folder.
@ -66,7 +67,7 @@ def test_evaluate_agents_long_running_4_runs_per_eval_item(
A single eval item is a session that can have multiple queries in it.
"""
AgentEvaluator.evaluate(
await AgentEvaluator.evaluate(
agent_module=agent_name,
eval_dataset_file_path_or_dir=evalfile,
initial_session_file=initial_session_file,

View File

@ -15,7 +15,8 @@
from google.adk.evaluation import AgentEvaluator
def test_eval_agent():
@pytest.mark.asyncio
async def test_eval_agent():
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.trip_planner_agent",
eval_dataset_file_path_or_dir=(

View File

@ -15,7 +15,8 @@
from google.adk.evaluation import AgentEvaluator
def test_simple_multi_turn_conversation():
@pytest.mark.asyncio
async def test_simple_multi_turn_conversation():
"""Test a simple multi-turn conversation."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
@ -24,7 +25,8 @@ def test_simple_multi_turn_conversation():
)
def test_dependent_tool_calls():
@pytest.mark.asyncio
async def test_dependent_tool_calls():
"""Test subsequent tool calls that are dependent on previous tool calls."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
@ -33,8 +35,10 @@ def test_dependent_tool_calls():
)
def test_memorizing_past_events():
@pytest.mark.asyncio
async def test_memorizing_past_events():
"""Test memorizing past events."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json",

View File

@ -15,7 +15,8 @@
from google.adk.evaluation import AgentEvaluator
def test_eval_agent():
@pytest.mark.asyncio
async def test_eval_agent():
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/simple_test.test.json",

View File

@ -15,7 +15,8 @@
from google.adk.evaluation import AgentEvaluator
def test_eval_agent():
@pytest.mark.asyncio
async def test_eval_agent():
"""Test hotel sub agent in a multi-agent system."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.trip_planner_agent",

View File

@ -15,7 +15,8 @@
from google.adk.evaluation import AgentEvaluator
def test_with_single_test_file():
@pytest.mark.asyncio
async def test_with_single_test_file():
"""Test the agent's basic ability via session file."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
@ -23,7 +24,8 @@ def test_with_single_test_file():
)
def test_with_folder_of_test_files_long_running():
@pytest.mark.asyncio
async def test_with_folder_of_test_files_long_running():
"""Test the agent's basic ability via a folder of session files."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",