Mirror of https://github.com/EvolutionAPI/adk-python.git (synced 2025-07-14 01:41:25 -06:00)
commit e7d9cf359a (parent cc1ef3f2ad)

make eval functions async

PiperOrigin-RevId: 756106627
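The commit converts the eval pipeline from synchronous functions to coroutines end to end: run_evals becomes an async generator, the evaluator and response-generation helpers become async, and the integration tests switch to pytest-asyncio. A minimal sketch of the before/after pattern, with illustrative names rather than the real ADK signatures:

import asyncio
from typing import AsyncGenerator

# Before: def run_evals(...) -> Generator[EvalResult, None, None]
# After:  async def run_evals(...) -> AsyncGenerator[EvalResult, None]
async def run_evals(evals: list[str]) -> AsyncGenerator[str, None]:
    for name in evals:
        await asyncio.sleep(0)  # stand-in for awaited agent/eval calls
        yield f"result for {name}"

async def main() -> None:
    # Async generators are consumed with "async for", not a plain loop.
    async for result in run_evals(["eval_1", "eval_2"]):
        print(result)

asyncio.run(main())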
@@ -20,7 +20,7 @@ import os
 import sys
 import traceback
 from typing import Any
-from typing import Generator
+from typing import AsyncGenerator
 from typing import Optional
 import uuid
 
@@ -146,7 +146,7 @@ def parse_and_get_evals_to_run(
   return eval_set_to_evals
 
 
-def run_evals(
+async def run_evals(
     eval_set_to_evals: dict[str, list[str]],
     root_agent: Agent,
     reset_func: Optional[Any],
@@ -154,7 +154,7 @@ def run_evals(
     session_service=None,
     artifact_service=None,
     print_detailed_results=False,
-) -> Generator[EvalResult, None, None]:
+) -> AsyncGenerator[EvalResult, None]:
   try:
     from ..evaluation.agent_evaluator import EvaluationGenerator
     from ..evaluation.response_evaluator import ResponseEvaluator
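One detail behind the annotation swap: typing.Generator takes three type parameters (yield, send, return), while typing.AsyncGenerator takes only two (yield and send), because async generators cannot return a value. A small illustration:

from typing import AsyncGenerator, Generator

def sync_items() -> Generator[int, None, None]:  # yield, send, return
    yield 1

async def async_items() -> AsyncGenerator[int, None]:  # yield, send only
    yield 1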
@@ -181,14 +181,16 @@ def run_evals(
       print(f"Running Eval: {eval_set_file}:{eval_name}")
       session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"
 
-      scrape_result = EvaluationGenerator._process_query_with_root_agent(
-          data=eval_data,
-          root_agent=root_agent,
-          reset_func=reset_func,
-          initial_session=initial_session,
-          session_id=session_id,
-          session_service=session_service,
-          artifact_service=artifact_service,
+      scrape_result = (
+          await EvaluationGenerator._process_query_with_root_agent(
+              data=eval_data,
+              root_agent=root_agent,
+              reset_func=reset_func,
+              initial_session=initial_session,
+              session_id=session_id,
+              session_service=session_service,
+              artifact_service=artifact_service,
+          )
       )
 
       eval_metric_results = []
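Note that awaiting inside the loop does not add concurrency by itself: each eval is still awaited to completion before the next one starts. A small contrast sketch (stub names, not ADK code):

import asyncio

async def process(name: str) -> str:
    await asyncio.sleep(0)  # stand-in for the awaited per-eval agent run
    return f"result:{name}"

async def sequential(names: list[str]) -> list[str]:
    # Mirrors the commit's shape: evals run one at a time.
    return [await process(n) for n in names]

async def concurrent(names: list[str]) -> list[str]:
    # What the commit does NOT do: overlap evals with gather.
    return await asyncio.gather(*(process(n) for n in names))

print(asyncio.run(sequential(["a", "b"])))
print(asyncio.run(concurrent(["a", "b"])))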
@@ -258,12 +258,14 @@ def cli_eval(
 
   try:
     eval_results = list(
-        run_evals(
-            eval_set_to_evals,
-            root_agent,
-            reset_func,
-            eval_metrics,
-            print_detailed_results=print_detailed_results,
+        asyncio.run(
+            run_evals(
+                eval_set_to_evals,
+                root_agent,
+                reset_func,
+                eval_metrics,
+                print_detailed_results=print_detailed_results,
+            )
         )
     )
   except ModuleNotFoundError:
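This hunk bridges the synchronous click command into the async world with asyncio.run. One caveat worth knowing: asyncio.run expects a coroutine, and calling an async generator function returns an async generator object, not a coroutine, so in practice the results are drained inside a small wrapper coroutine. A hedged sketch of that pattern (the _collect helper is hypothetical, not part of this commit):

import asyncio
from typing import AsyncGenerator

async def run_evals(names: list[str]) -> AsyncGenerator[str, None]:
    for name in names:
        yield name

async def _collect(names: list[str]) -> list[str]:
    # Hypothetical helper: drain the async generator into a list so a
    # synchronous entry point (e.g. a click command) can consume it.
    return [result async for result in run_evals(names)]

def cli_eval() -> None:
    eval_results = asyncio.run(_collect(["eval_1", "eval_2"]))
    print(eval_results)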
@@ -467,7 +467,7 @@ def get_fast_api_app(
     )
     root_agent = await _get_root_agent_async(app_name)
     eval_results = list(
-        run_evals(
+        await run_evals(
             eval_set_to_evals,
             root_agent,
             getattr(root_agent, "reset_data", None),
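Inside get_fast_api_app the enclosing endpoint is already a coroutine, so no asyncio.run is needed; the generator can be drained in place. A reduced sketch of the in-coroutine equivalent (FastAPI wiring and real arguments omitted):

from typing import AsyncGenerator

async def run_evals(names: list[str]) -> AsyncGenerator[str, None]:
    for name in names:
        yield name

async def run_eval_endpoint() -> list[str]:
    # An async list comprehension plays the role that list(...) plays
    # over a synchronous generator.
    return [result async for result in run_evals(["eval_1"])]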
@@ -76,7 +76,7 @@ class AgentEvaluator:
     return DEFAULT_CRITERIA
 
   @staticmethod
-  def evaluate(
+  async def evaluate(
       agent_module,
       eval_dataset_file_path_or_dir,
       num_runs=NUM_RUNS,
@@ -120,7 +120,7 @@ class AgentEvaluator:
 
     AgentEvaluator._validate_input([dataset], criteria)
 
-    evaluation_response = AgentEvaluator._generate_responses(
+    evaluation_response = await AgentEvaluator._generate_responses(
         agent_module,
         [dataset],
         num_runs,
@@ -246,7 +246,7 @@ class AgentEvaluator:
     return inferred_criteria
 
   @staticmethod
-  def _generate_responses(
+  async def _generate_responses(
       agent_module, eval_dataset, num_runs, agent_name=None, initial_session={}
   ):
     """Generates evaluation responses by running the agent module multiple times."""
@@ -32,7 +32,7 @@ class EvaluationGenerator:
   """Generates evaluation responses for agents."""
 
   @staticmethod
-  def generate_responses(
+  async def generate_responses(
       eval_dataset,
       agent_module_path,
       repeat_num=3,
@@ -107,7 +107,7 @@ class EvaluationGenerator:
     )
 
   @staticmethod
-  def _process_query_with_root_agent(
+  async def _process_query_with_root_agent(
       data,
       root_agent,
       reset_func,
@@ -128,7 +128,7 @@ class EvaluationGenerator:
         all_mock_tools.add(expected[EvalConstants.TOOL_NAME])
 
     eval_data_copy = data.copy()
-    EvaluationGenerator.apply_before_tool_callback(
+    await EvaluationGenerator.apply_before_tool_callback(
         root_agent,
         lambda *args: EvaluationGenerator.before_tool_callback(
            *args, eval_dataset=eval_data_copy
@@ -247,7 +247,7 @@ class EvaluationGenerator:
     return None
 
   @staticmethod
-  def apply_before_tool_callback(
+  async def apply_before_tool_callback(
       agent: BaseAgent,
       callback: BeforeToolCallback,
       all_mock_tools: set[str],
@@ -265,6 +265,6 @@ class EvaluationGenerator:
 
     # Apply recursively to subagents if they exist
     for sub_agent in agent.sub_agents:
-      EvaluationGenerator.apply_before_tool_callback(
+      await EvaluationGenerator.apply_before_tool_callback(
           sub_agent, callback, all_mock_tools
       )
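Because apply_before_tool_callback now awaits itself for every sub-agent, the whole agent tree is visited depth-first, one node at a time. A self-contained sketch of that recursion (the Agent type here is a stand-in, not the ADK BaseAgent):

import asyncio
from dataclasses import dataclass, field

@dataclass
class Agent:  # stand-in for the ADK BaseAgent tree
    name: str
    sub_agents: list["Agent"] = field(default_factory=list)

async def apply_callback(agent: Agent, seen: list[str]) -> None:
    seen.append(agent.name)  # stand-in for attaching the tool callback
    # Apply recursively to sub-agents, awaiting each in turn.
    for sub_agent in agent.sub_agents:
        await apply_callback(sub_agent, seen)

root = Agent("root", [Agent("a", [Agent("a1")]), Agent("b")])
seen: list[str] = []
asyncio.run(apply_callback(root, seen))
print(seen)  # ['root', 'a', 'a1', 'b']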
@@ -51,12 +51,13 @@ def agent_eval_artifacts_in_fixture():
   return agent_eval_artifacts
 
 
+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     'agent_name, evalfile, initial_session_file',
     agent_eval_artifacts_in_fixture(),
     ids=[agent_name for agent_name, _, _ in agent_eval_artifacts_in_fixture()],
 )
-def test_evaluate_agents_long_running_4_runs_per_eval_item(
+async def test_evaluate_agents_long_running_4_runs_per_eval_item(
     agent_name, evalfile, initial_session_file
 ):
   """Test agents evaluation in fixture folder.
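The tests are converted with the pytest-asyncio plugin: the asyncio marker tells pytest to run the coroutine test on an event loop, and it composes with parametrize. A minimal sketch (requires pytest-asyncio; the evaluate helper is a stub):

import pytest

async def evaluate(agent_module: str) -> None:
    # Stand-in for AgentEvaluator.evaluate.
    assert agent_module

@pytest.mark.asyncio
@pytest.mark.parametrize("agent_module", ["agent_a", "agent_b"])
async def test_eval_agent(agent_module):
    await evaluate(agent_module)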
@@ -66,7 +67,7 @@ def test_evaluate_agents_long_running_4_runs_per_eval_item(
 
   A single eval item is a session that can have multiple queries in it.
   """
-  AgentEvaluator.evaluate(
+  await AgentEvaluator.evaluate(
       agent_module=agent_name,
       eval_dataset_file_path_or_dir=evalfile,
       initial_session_file=initial_session_file,
@@ -15,7 +15,8 @@
 from google.adk.evaluation import AgentEvaluator
 
 
-def test_eval_agent():
+@pytest.mark.asyncio
+async def test_eval_agent():
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.trip_planner_agent",
       eval_dataset_file_path_or_dir=(
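In this hunk the AgentEvaluator.evaluate( call appears as an unchanged context line. One pitfall when converting tests this way: a coroutine function called without await never executes its body; CPython only emits a "coroutine ... was never awaited" RuntimeWarning when the orphaned coroutine is garbage collected. A minimal demonstration (stub names, not ADK code):

import asyncio

async def evaluate() -> None:
    print("evaluation ran")

def broken() -> None:
    # Creates a coroutine object but never runs it; the body is skipped
    # and a RuntimeWarning is reported when the object is collected.
    evaluate()

async def fixed() -> None:
    await evaluate()  # prints "evaluation ran"

broken()
asyncio.run(fixed())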
@@ -15,7 +15,8 @@
 from google.adk.evaluation import AgentEvaluator
 
 
-def test_simple_multi_turn_conversation():
+@pytest.mark.asyncio
+async def test_simple_multi_turn_conversation():
   """Test a simple multi-turn conversation."""
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",
@@ -24,7 +25,8 @@ def test_simple_multi_turn_conversation():
   )
 
 
-def test_dependent_tool_calls():
+@pytest.mark.asyncio
+async def test_dependent_tool_calls():
   """Test subsequent tool calls that are dependent on previous tool calls."""
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",
@@ -33,8 +35,10 @@ def test_dependent_tool_calls():
   )
 
 
-def test_memorizing_past_events():
+@pytest.mark.asyncio
+async def test_memorizing_past_events():
   """Test memorizing past events."""
 
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",
       eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json",
@@ -15,7 +15,8 @@
 from google.adk.evaluation import AgentEvaluator
 
 
-def test_eval_agent():
+@pytest.mark.asyncio
+async def test_eval_agent():
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",
       eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/simple_test.test.json",
@@ -15,7 +15,8 @@
 from google.adk.evaluation import AgentEvaluator
 
 
-def test_eval_agent():
+@pytest.mark.asyncio
+async def test_eval_agent():
   """Test hotel sub agent in a multi-agent system."""
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.trip_planner_agent",
@@ -15,7 +15,8 @@
 from google.adk.evaluation import AgentEvaluator
 
 
-def test_with_single_test_file():
+@pytest.mark.asyncio
+async def test_with_single_test_file():
   """Test the agent's basic ability via session file."""
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",
@@ -23,7 +24,8 @@ def test_with_single_test_file():
   )
 
 
-def test_with_folder_of_test_files_long_running():
+@pytest.mark.asyncio
+async def test_with_folder_of_test_files_long_running():
   """Test the agent's basic ability via a folder of session files."""
   AgentEvaluator.evaluate(
       agent_module="tests.integration.fixture.home_automation_agent",