Mirror of https://github.com/EvolutionAPI/adk-python.git (synced 2025-12-18 19:32:21 -06:00)

commit e7d9cf359a (parent cc1ef3f2ad), committed by Copybara-Service

    make eval functions async

    PiperOrigin-RevId: 756106627
@@ -20,7 +20,7 @@ import os
 import sys
 import traceback
 from typing import Any
-from typing import Generator
+from typing import AsyncGenerator
 from typing import Optional
 import uuid
 
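The typing import swap above matches the signature change to run_evals further down: a synchronous generator is annotated Generator[Yield, Send, Return], while an async generator is annotated AsyncGenerator[Yield, Send]. A minimal sketch of the difference (illustrative only, not taken from this commit):

from typing import AsyncGenerator, Generator

def sync_results() -> Generator[int, None, None]:
    # Sync generator: Generator[YieldType, SendType, ReturnType].
    yield 1

async def async_results() -> AsyncGenerator[int, None]:
    # Async generator: AsyncGenerator[YieldType, SendType]; async generators
    # cannot return a value, so there is no ReturnType parameter.
    yield 1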
@@ -146,7 +146,7 @@ def parse_and_get_evals_to_run(
   return eval_set_to_evals
 
 
-def run_evals(
+async def run_evals(
     eval_set_to_evals: dict[str, list[str]],
     root_agent: Agent,
     reset_func: Optional[Any],
@@ -154,7 +154,7 @@ def run_evals(
     session_service=None,
     artifact_service=None,
     print_detailed_results=False,
-) -> Generator[EvalResult, None, None]:
+) -> AsyncGenerator[EvalResult, None]:
   try:
     from ..evaluation.agent_evaluator import EvaluationGenerator
     from ..evaluation.response_evaluator import ResponseEvaluator
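With run_evals now an async def returning AsyncGenerator[EvalResult, None], callers iterate it with async for instead of a plain for loop. A minimal sketch of streaming consumption, assuming run_evals and suitably constructed arguments are in scope (the values shown are placeholders):

import asyncio

async def main() -> None:
    # Each async-for step resumes run_evals, which may await agent calls
    # internally before yielding the next EvalResult.
    async for eval_result in run_evals(
        eval_set_to_evals,  # placeholder: dict of eval set file -> eval names
        root_agent,         # placeholder: the Agent under evaluation
        None,               # reset_func
        eval_metrics,       # placeholder: list of EvalMetric
    ):
        print(eval_result)

asyncio.run(main())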
@@ -181,14 +181,16 @@ def run_evals(
       print(f"Running Eval: {eval_set_file}:{eval_name}")
       session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"
 
-      scrape_result = EvaluationGenerator._process_query_with_root_agent(
-          data=eval_data,
-          root_agent=root_agent,
-          reset_func=reset_func,
-          initial_session=initial_session,
-          session_id=session_id,
-          session_service=session_service,
-          artifact_service=artifact_service,
+      scrape_result = (
+          await EvaluationGenerator._process_query_with_root_agent(
+              data=eval_data,
+              root_agent=root_agent,
+              reset_func=reset_func,
+              initial_session=initial_session,
+              session_id=session_id,
+              session_service=session_service,
+              artifact_service=artifact_service,
+          )
       )
 
       eval_metric_results = []
@@ -258,12 +258,14 @@ def cli_eval(
 
   try:
     eval_results = list(
-        run_evals(
-            eval_set_to_evals,
-            root_agent,
-            reset_func,
-            eval_metrics,
-            print_detailed_results=print_detailed_results,
+        asyncio.run(
+            run_evals(
+                eval_set_to_evals,
+                root_agent,
+                reset_func,
+                eval_metrics,
+                print_detailed_results=print_detailed_results,
+            )
         )
     )
   except ModuleNotFoundError:
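A note on the synchronous entry point: asyncio.run expects a coroutine, so an async generator such as run_evals is commonly drained inside a small wrapper coroutine before being handed to it. A minimal sketch of that pattern (illustrative only; _drain is a hypothetical helper, not part of this commit):

import asyncio
from typing import AsyncGenerator, TypeVar

T = TypeVar("T")

async def _drain(agen: AsyncGenerator[T, None]) -> list[T]:
    # Exhaust the async generator and return everything it yielded.
    return [item async for item in agen]

# Hypothetical usage with the run_evals generator from the hunks above:
# eval_results = asyncio.run(
#     _drain(run_evals(eval_set_to_evals, root_agent, reset_func, eval_metrics))
# )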
@@ -467,7 +467,7 @@ def get_fast_api_app(
     )
     root_agent = await _get_root_agent_async(app_name)
     eval_results = list(
-        run_evals(
+        await run_evals(
            eval_set_to_evals,
            root_agent,
            getattr(root_agent, "reset_data", None),
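Inside a handler that is already a coroutine (note the await on _get_root_agent_async), the async generator can also be consumed with an async comprehension. A sketch, reusing the names from the hunk above as placeholders:

# Illustrative only; assumes this runs inside the async handler where
# run_evals, eval_set_to_evals and root_agent are already defined.
eval_results = [
    result
    async for result in run_evals(
        eval_set_to_evals,
        root_agent,
        getattr(root_agent, "reset_data", None),
        eval_metrics,  # placeholder for whatever metrics the handler builds
    )
]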
@@ -76,7 +76,7 @@ class AgentEvaluator:
     return DEFAULT_CRITERIA
 
   @staticmethod
-  def evaluate(
+  async def evaluate(
       agent_module,
       eval_dataset_file_path_or_dir,
       num_runs=NUM_RUNS,
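Because AgentEvaluator.evaluate is now a coroutine, synchronous callers such as plain test functions have to await it or wrap it in asyncio.run. A minimal sketch, assuming AgentEvaluator is importable from the evaluation package referenced in the imports above; the module name and dataset path are placeholders:

import asyncio

# from ...evaluation.agent_evaluator import AgentEvaluator  # import path assumed

asyncio.run(
    AgentEvaluator.evaluate(
        agent_module="my_agent_module",           # placeholder
        eval_dataset_file_path_or_dir="evals/",   # placeholder
    )
)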
@@ -120,7 +120,7 @@ class AgentEvaluator:
 
       AgentEvaluator._validate_input([dataset], criteria)
 
-      evaluation_response = AgentEvaluator._generate_responses(
+      evaluation_response = await AgentEvaluator._generate_responses(
          agent_module,
          [dataset],
          num_runs,
@@ -246,7 +246,7 @@ class AgentEvaluator:
     return inferred_criteria
 
   @staticmethod
-  def _generate_responses(
+  async def _generate_responses(
      agent_module, eval_dataset, num_runs, agent_name=None, initial_session={}
   ):
     """Generates evaluation responses by running the agent module multiple times."""
@@ -32,7 +32,7 @@ class EvaluationGenerator:
   """Generates evaluation responses for agents."""
 
   @staticmethod
-  def generate_responses(
+  async def generate_responses(
       eval_dataset,
       agent_module_path,
       repeat_num=3,
@@ -107,7 +107,7 @@ class EvaluationGenerator:
     )
 
   @staticmethod
-  def _process_query_with_root_agent(
+  async def _process_query_with_root_agent(
       data,
       root_agent,
       reset_func,
@@ -128,7 +128,7 @@ class EvaluationGenerator:
           all_mock_tools.add(expected[EvalConstants.TOOL_NAME])
 
     eval_data_copy = data.copy()
-    EvaluationGenerator.apply_before_tool_callback(
+    await EvaluationGenerator.apply_before_tool_callback(
         root_agent,
         lambda *args: EvaluationGenerator.before_tool_callback(
            *args, eval_dataset=eval_data_copy
@@ -247,7 +247,7 @@ class EvaluationGenerator:
     return None
 
   @staticmethod
-  def apply_before_tool_callback(
+  async def apply_before_tool_callback(
       agent: BaseAgent,
       callback: BeforeToolCallback,
       all_mock_tools: set[str],
@@ -265,6 +265,6 @@ class EvaluationGenerator:
 
     # Apply recursively to subagents if they exist
     for sub_agent in agent.sub_agents:
-      EvaluationGenerator.apply_before_tool_callback(
+      await EvaluationGenerator.apply_before_tool_callback(
           sub_agent, callback, all_mock_tools
       )
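Once apply_before_tool_callback is a coroutine, every level of the recursion over sub_agents has to be awaited, as the last hunk shows. A generic sketch of the same pattern, with a hypothetical Node type standing in for BaseAgent:

import asyncio
from dataclasses import dataclass, field

@dataclass
class Node:
    # Hypothetical stand-in for an agent that owns sub_agents.
    name: str
    sub_agents: list["Node"] = field(default_factory=list)

async def apply_recursively(node: Node) -> None:
    await asyncio.sleep(0)  # placeholder for the real async callback wiring
    for sub_agent in node.sub_agents:
        # The recursive call must be awaited, or the coroutine never runs.
        await apply_recursively(sub_agent)

asyncio.run(apply_recursively(Node("root", [Node("child")])))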