make eval functions async

PiperOrigin-RevId: 756106627
Authored by Xiang (Sean) Zhou on 2025-05-07 19:52:47 -07:00, committed by Copybara-Service
parent cc1ef3f2ad
commit e7d9cf359a
11 changed files with 50 additions and 36 deletions

View File

@@ -20,7 +20,7 @@ import os
 import sys
 import traceback
 from typing import Any
-from typing import Generator
+from typing import AsyncGenerator
 from typing import Optional
 import uuid
@@ -146,7 +146,7 @@ def parse_and_get_evals_to_run(
   return eval_set_to_evals


-def run_evals(
+async def run_evals(
     eval_set_to_evals: dict[str, list[str]],
     root_agent: Agent,
     reset_func: Optional[Any],
@@ -154,7 +154,7 @@ def run_evals(
     session_service=None,
     artifact_service=None,
     print_detailed_results=False,
-) -> Generator[EvalResult, None, None]:
+) -> AsyncGenerator[EvalResult, None]:
   try:
     from ..evaluation.agent_evaluator import EvaluationGenerator
     from ..evaluation.response_evaluator import ResponseEvaluator
@@ -181,14 +181,16 @@ def run_evals(
         print(f"Running Eval: {eval_set_file}:{eval_name}")
         session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"

-        scrape_result = EvaluationGenerator._process_query_with_root_agent(
-            data=eval_data,
-            root_agent=root_agent,
-            reset_func=reset_func,
-            initial_session=initial_session,
-            session_id=session_id,
-            session_service=session_service,
-            artifact_service=artifact_service,
+        scrape_result = (
+            await EvaluationGenerator._process_query_with_root_agent(
+                data=eval_data,
+                root_agent=root_agent,
+                reset_func=reset_func,
+                initial_session=initial_session,
+                session_id=session_id,
+                session_service=session_service,
+                artifact_service=artifact_service,
+            )
         )

         eval_metric_results = []
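With run_evals now annotated as AsyncGenerator[EvalResult, None], callers consume it with async for rather than iterating a plain generator. A minimal sketch of the consuming side, assuming run_evals is imported from this module; the collect_eval_results wrapper is hypothetical and not part of this commit:

async def collect_eval_results(
    eval_set_to_evals, root_agent, reset_func, eval_metrics
):
  # Drain the async generator; each iteration may suspend while the
  # underlying eval session runs.
  results = []
  async for eval_result in run_evals(
      eval_set_to_evals,
      root_agent,
      reset_func,
      eval_metrics,
      print_detailed_results=False,
  ):
    results.append(eval_result)
  return results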

View File

@@ -258,12 +258,14 @@ def cli_eval(
   try:
     eval_results = list(
-        run_evals(
-            eval_set_to_evals,
-            root_agent,
-            reset_func,
-            eval_metrics,
-            print_detailed_results=print_detailed_results,
+        asyncio.run(
+            run_evals(
+                eval_set_to_evals,
+                root_agent,
+                reset_func,
+                eval_metrics,
+                print_detailed_results=print_detailed_results,
+            )
         )
     )
   except ModuleNotFoundError:
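A note on this pattern: asyncio.run expects a coroutine object and raises ValueError if handed an async generator directly, so draining the AsyncGenerator from synchronous Click code typically goes through a small collecting coroutine. A sketch of such a bridge; the _collect helper is hypothetical and not part of this commit:

import asyncio

async def _collect(agen):
  # Exhaust an async generator into a list so a single asyncio.run()
  # call can drive it from synchronous code.
  return [item async for item in agen]

# eval_results = asyncio.run(_collect(run_evals(...)))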

View File

@@ -467,7 +467,7 @@ def get_fast_api_app(
     )
     root_agent = await _get_root_agent_async(app_name)
     eval_results = list(
-        run_evals(
+        await run_evals(
             eval_set_to_evals,
             root_agent,
             getattr(root_agent, "reset_data", None),

View File

@@ -76,7 +76,7 @@ class AgentEvaluator:
     return DEFAULT_CRITERIA

   @staticmethod
-  def evaluate(
+  async def evaluate(
       agent_module,
       eval_dataset_file_path_or_dir,
       num_runs=NUM_RUNS,
@@ -120,7 +120,7 @@
     AgentEvaluator._validate_input([dataset], criteria)

-    evaluation_response = AgentEvaluator._generate_responses(
+    evaluation_response = await AgentEvaluator._generate_responses(
         agent_module,
         [dataset],
         num_runs,
@@ -246,7 +246,7 @@
     return inferred_criteria

   @staticmethod
-  def _generate_responses(
+  async def _generate_responses(
       agent_module, eval_dataset, num_runs, agent_name=None, initial_session={}
   ):
     """Generates evaluation responses by running the agent module multiple times."""

View File

@@ -32,7 +32,7 @@ class EvaluationGenerator:
   """Generates evaluation responses for agents."""

   @staticmethod
-  def generate_responses(
+  async def generate_responses(
       eval_dataset,
       agent_module_path,
       repeat_num=3,
@@ -107,7 +107,7 @@
     )

   @staticmethod
-  def _process_query_with_root_agent(
+  async def _process_query_with_root_agent(
       data,
       root_agent,
       reset_func,
@@ -128,7 +128,7 @@
           all_mock_tools.add(expected[EvalConstants.TOOL_NAME])

     eval_data_copy = data.copy()
-    EvaluationGenerator.apply_before_tool_callback(
+    await EvaluationGenerator.apply_before_tool_callback(
         root_agent,
         lambda *args: EvaluationGenerator.before_tool_callback(
             *args, eval_dataset=eval_data_copy
@@ -247,7 +247,7 @@
     return None

   @staticmethod
-  def apply_before_tool_callback(
+  async def apply_before_tool_callback(
       agent: BaseAgent,
       callback: BeforeToolCallback,
       all_mock_tools: set[str],
@@ -265,6 +265,6 @@
     # Apply recursively to subagents if they exist
     for sub_agent in agent.sub_agents:
-      EvaluationGenerator.apply_before_tool_callback(
+      await EvaluationGenerator.apply_before_tool_callback(
           sub_agent, callback, all_mock_tools
       )
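The final hunk awaits the recursive call, so the before-tool callback is applied depth-first across the whole agent tree before apply_before_tool_callback returns. The same pattern in a self-contained sketch; Node and visit are illustrative, not the ADK API:

import asyncio

class Node:
  def __init__(self, name, children=()):
    self.name = name
    self.children = list(children)

async def visit(node):
  print(node.name)  # stand-in for attaching the callback

async def apply(node):
  await visit(node)
  for child in node.children:
    # Awaiting the recursion mirrors the diff above: sub-nodes are
    # processed sequentially, depth-first.
    await apply(child)

asyncio.run(apply(Node("root", [Node("a"), Node("b", [Node("c")])])))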

View File

@@ -51,12 +51,13 @@ def agent_eval_artifacts_in_fixture():
   return agent_eval_artifacts


+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     'agent_name, evalfile, initial_session_file',
     agent_eval_artifacts_in_fixture(),
     ids=[agent_name for agent_name, _, _ in agent_eval_artifacts_in_fixture()],
 )
-def test_evaluate_agents_long_running_4_runs_per_eval_item(
+async def test_evaluate_agents_long_running_4_runs_per_eval_item(
     agent_name, evalfile, initial_session_file
 ):
   """Test agents evaluation in fixture folder.
@@ -66,7 +67,7 @@ def test_evaluate_agents_long_running_4_runs_per_eval_item(
   A single eval item is a session that can have multiple queries in it.
   """
-  AgentEvaluator.evaluate(
+  await AgentEvaluator.evaluate(
       agent_module=agent_name,
       eval_dataset_file_path_or_dir=evalfile,
       initial_session_file=initial_session_file,
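@pytest.mark.asyncio comes from the pytest-asyncio plugin: it supplies an event loop and awaits the coroutine test. Without the marker (under the plugin's default strict mode), an async def test is not executed as a coroutine. A minimal sketch, assuming pytest-asyncio is installed; test_marker_sketch is illustrative:

import asyncio
import pytest

@pytest.mark.asyncio
async def test_marker_sketch():
  # pytest-asyncio runs this coroutine on its managed event loop.
  await asyncio.sleep(0)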

View File

@@ -15,7 +15,8 @@
 from google.adk.evaluation import AgentEvaluator


-def test_eval_agent():
+@pytest.mark.asyncio
+async def test_eval_agent():
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.trip_planner_agent",
       eval_dataset_file_path_or_dir=(

View File

@@ -15,7 +15,8 @@
 from google.adk.evaluation import AgentEvaluator


-def test_simple_multi_turn_conversation():
+@pytest.mark.asyncio
+async def test_simple_multi_turn_conversation():
   """Test a simple multi-turn conversation."""
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",
@@ -24,7 +25,8 @@ def test_simple_multi_turn_conversation():
   )


-def test_dependent_tool_calls():
+@pytest.mark.asyncio
+async def test_dependent_tool_calls():
   """Test subsequent tool calls that are dependent on previous tool calls."""
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",
@@ -33,8 +35,10 @@ def test_dependent_tool_calls():
   )


-def test_memorizing_past_events():
+@pytest.mark.asyncio
+async def test_memorizing_past_events():
   """Test memorizing past events."""
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",
       eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json",

View File

@@ -15,7 +15,8 @@
 from google.adk.evaluation import AgentEvaluator


-def test_eval_agent():
+@pytest.mark.asyncio
+async def test_eval_agent():
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",
       eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/simple_test.test.json",

View File

@@ -15,7 +15,8 @@
 from google.adk.evaluation import AgentEvaluator


-def test_eval_agent():
+@pytest.mark.asyncio
+async def test_eval_agent():
   """Test hotel sub agent in a multi-agent system."""
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.trip_planner_agent",

View File

@@ -15,7 +15,8 @@
 from google.adk.evaluation import AgentEvaluator


-def test_with_single_test_file():
+@pytest.mark.asyncio
+async def test_with_single_test_file():
   """Test the agent's basic ability via session file."""
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",
@@ -23,7 +24,8 @@ def test_with_single_test_file():
   )


-def test_with_folder_of_test_files_long_running():
+@pytest.mark.asyncio
+async def test_with_folder_of_test_files_long_running():
   """Test the agent's basic ability via a folder of session files."""
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",