From ee674ce0efef136a160cb4ed7fc62417aa259380 Mon Sep 17 00:00:00 2001 From: Ankur Sharma Date: Wed, 14 May 2025 19:15:52 -0700 Subject: [PATCH] Update Eval Run and TrajectoryEvaluator to use the new schema. PiperOrigin-RevId: 758927160 --- src/google/adk/cli/cli_eval.py | 234 +++++++++++------- src/google/adk/cli/cli_tools_click.py | 43 +++- src/google/adk/cli/fast_api.py | 64 ++--- src/google/adk/evaluation/agent_evaluator.py | 7 - src/google/adk/evaluation/eval_case.py | 4 +- .../adk/evaluation/evaluation_generator.py | 142 ++++------- src/google/adk/evaluation/evaluator.py | 56 +++++ .../adk/evaluation/local_eval_sets_manager.py | 27 +- .../adk/evaluation/trajectory_evaluator.py | 85 ++++++- 9 files changed, 418 insertions(+), 244 deletions(-) create mode 100644 src/google/adk/evaluation/evaluator.py diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py index f2f2586..0574dbd 100644 --- a/src/google/adk/cli/cli_eval.py +++ b/src/google/adk/cli/cli_eval.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime -from enum import Enum import importlib.util import json import logging @@ -22,6 +20,7 @@ import sys import traceback from typing import Any from typing import AsyncGenerator +from typing import cast from typing import Optional import uuid @@ -29,36 +28,84 @@ from pydantic import BaseModel from pydantic import Field from ..agents import Agent +from ..artifacts.base_artifact_service import BaseArtifactService +from ..evaluation.eval_case import EvalCase +from ..evaluation.eval_case import Invocation +from ..evaluation.evaluator import EvalStatus +from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session from .utils import common logger = logging.getLogger(__name__) -class EvalStatus(Enum): - PASSED = 1 - FAILED = 2 - NOT_EVALUATED = 3 - - class EvalMetric(BaseModel): + """A metric used to evaluate a particular aspect of an eval case.""" + metric_name: str + """The name of the metric.""" + threshold: float + """A threshold value. Each metric decides how to interpret this threshold.""" -class EvalMetricResult(BaseModel): +class EvalMetricResult(EvalMetric): + """The actual computed score/value of a particular EvalMetric.""" + score: Optional[float] = None eval_status: EvalStatus +class EvalMetricResultPerInvocation(BaseModel): + """Eval metric results per invocation.""" + + actual_invocation: Invocation + """The actual invocation, usually obtained by inferencing the agent.""" + + expected_invocation: Invocation + """The expected invocation, usually the reference or golden invocation.""" + + eval_metric_results: list[EvalMetricResult] = [] + """Eval resutls for each applicable metric.""" + + class EvalCaseResult(common.BaseModel): - eval_set_file: str - eval_id: str + """Case-level evaluation results.""" + + eval_set_file: str = Field( + deprecated=True, + description="This field is deprecated, use eval_set_id instead.", + ) + eval_set_id: str = "" + """The eval set id.""" + + eval_id: str = "" + """The eval case id.""" + final_eval_status: EvalStatus - eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] + """Final evalu status for this eval case.""" + + eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field( + deprecated=True, + description=( + "This field is deprecated, use overall_eval_metric_results instead." 
+ ), + ) + + overall_eval_metric_results: list[EvalMetricResult] + """Overall result for each metric for the entire eval case.""" + + eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation] + """Result for each metric on a per invocation basis.""" + session_id: str + """Session id of the session generated as result of inferencing/scraping stage of the eval.""" + session_details: Optional[Session] = None + """Session generated as result of inferencing/scraping stage of the eval.""" + user_id: Optional[str] = None + """User id used during inferencing/scraping stage of the eval.""" class EvalSetResult(common.BaseModel): @@ -161,14 +208,25 @@ def parse_and_get_evals_to_run( async def run_evals( - eval_set_to_evals: dict[str, list[str]], + eval_cases_by_eval_set_id: dict[str, list[EvalCase]], root_agent: Agent, reset_func: Optional[Any], eval_metrics: list[EvalMetric], - session_service=None, - artifact_service=None, - print_detailed_results=False, + session_service: Optional[BaseSessionService] = None, + artifact_service: Optional[BaseArtifactService] = None, ) -> AsyncGenerator[EvalCaseResult, None]: + """Returns a stream of EvalCaseResult for each eval case that was evaluated. + + Args: + eval_cases_by_eval_set_id: Eval cases categorized by eval set id to which + they belong. + root_agent: Agent to use for inferencing. + reset_func: If present, this will be called before invoking the agent before + every inferencing step. + eval_metrics: A list of metrics that should be used during evaluation. + session_service: The session service to use during inferencing. + artifact_service: The artifact service to use during inferencing. + """ try: from ..evaluation.agent_evaluator import EvaluationGenerator from ..evaluation.response_evaluator import ResponseEvaluator @@ -176,29 +234,19 @@ async def run_evals( except ModuleNotFoundError as e: raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e - """Returns a summary of eval runs.""" - for eval_set_file, evals_to_run in eval_set_to_evals.items(): - with open(eval_set_file, "r", encoding="utf-8") as file: - eval_items = json.load(file) # Load JSON into a list - - assert eval_items, f"No eval data found in eval set file: {eval_set_file}" - - for eval_item in eval_items: - eval_name = eval_item["name"] - eval_data = eval_item["data"] - initial_session = eval_item.get("initial_session", {}) - user_id = initial_session.get("user_id", "test_user_id") - - if evals_to_run and eval_name not in evals_to_run: - continue + for eval_set_id, eval_cases in eval_cases_by_eval_set_id.items(): + for eval_case in eval_cases: + eval_name = eval_case.eval_id + initial_session = eval_case.session_input + user_id = initial_session.user_id if initial_session else "test_user_id" try: - print(f"Running Eval: {eval_set_file}:{eval_name}") + print(f"Running Eval: {eval_set_id}:{eval_name}") session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}" - scrape_result = ( - await EvaluationGenerator._process_query_with_root_agent( - data=eval_data, + inference_result = ( + await EvaluationGenerator._generate_inferences_from_root_agent( + invocations=eval_case.conversation, root_agent=root_agent, reset_func=reset_func, initial_session=initial_session, @@ -208,67 +256,95 @@ async def run_evals( ) ) - eval_metric_results = [] + # Initialize the per-invocation metric results to an empty list. + # We will fill this as we evaluate each metric. 
+ eval_metric_result_per_invocation = [] + for actual, expected in zip(inference_result, eval_case.conversation): + eval_metric_result_per_invocation.append( + EvalMetricResultPerInvocation( + actual_invocation=actual, + expected_invocation=expected, + eval_metric_results=[], + ) + ) + + overall_eval_metric_results = [] + for eval_metric in eval_metrics: - eval_metric_result = None if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY: - score = TrajectoryEvaluator.evaluate( - [scrape_result], print_detailed_results=print_detailed_results + evaluation_result = TrajectoryEvaluator( + eval_metric.threshold + ).evaluate_invocations( + actual_invocations=inference_result, + expected_invocations=eval_case.conversation, ) - eval_metric_result = _get_eval_metric_result(eval_metric, score) - elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY: - score = ResponseEvaluator.evaluate( - [scrape_result], - [RESPONSE_MATCH_SCORE_KEY], - print_detailed_results=print_detailed_results, - ) - eval_metric_result = _get_eval_metric_result( - eval_metric, score["rouge_1/mean"].item() - ) - elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY: - score = ResponseEvaluator.evaluate( - [scrape_result], - [RESPONSE_EVALUATION_SCORE_KEY], - print_detailed_results=print_detailed_results, - ) - eval_metric_result = _get_eval_metric_result( - eval_metric, score["coherence/mean"].item() + overall_eval_metric_results.append( + EvalMetricResult( + metric_name=eval_metric.metric_name, + threshold=eval_metric.threshold, + score=evaluation_result.overall_score, + eval_status=evaluation_result.overall_eval_status, + ) ) + for index, per_invocation_result in enumerate( + evaluation_result.per_invocation_results + ): + eval_metric_result_per_invocation[ + index + ].eval_metric_results.append( + EvalMetricResult( + metric_name=eval_metric.metric_name, + threshold=eval_metric.threshold, + score=per_invocation_result.score, + eval_status=per_invocation_result.eval_status, + ) + ) + + # elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY: + # score = ResponseEvaluator.evaluate( + # [inference_result], + # [RESPONSE_MATCH_SCORE_KEY], + # print_detailed_results=print_detailed_results, + # ) + # eval_metric_result = _get_eval_metric_result( + # eval_metric, score["rouge_1/mean"].item() + # ) + # elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY: + # score = ResponseEvaluator.evaluate( + # [inference_result], + # [RESPONSE_EVALUATION_SCORE_KEY], + # print_detailed_results=print_detailed_results, + # ) + # eval_metric_result = _get_eval_metric_result( + # eval_metric, score["coherence/mean"].item() + # ) else: logger.warning("`%s` is not supported.", eval_metric.metric_name) - eval_metric_results.append(( - eval_metric, - EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED), - )) - - eval_metric_results.append(( - eval_metric, - eval_metric_result, - )) - _print_eval_metric_result(eval_metric, eval_metric_result) final_eval_status = EvalStatus.NOT_EVALUATED - # Go over the all the eval statuses and mark the final eval status as # passed if all of them pass, otherwise mark the final eval status to # failed. 
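# Self-contained sketch of the final-status rule described in the comment above
# and implemented by the loop that follows: a case passes only if at least one
# metric passed and none failed; NOT_EVALUATED metrics are skipped; any FAILED
# metric fails the whole case. EvalStatus is mirrored locally here purely so
# the sketch runs on its own -- the patch imports it from
# ..evaluation.evaluator.
from enum import Enum


class EvalStatus(Enum):
  PASSED = 1
  FAILED = 2
  NOT_EVALUATED = 3


def aggregate_final_status(statuses: list[EvalStatus]) -> EvalStatus:
  final_eval_status = EvalStatus.NOT_EVALUATED
  for eval_status in statuses:
    if eval_status == EvalStatus.PASSED:
      final_eval_status = EvalStatus.PASSED
    elif eval_status == EvalStatus.NOT_EVALUATED:
      continue
    elif eval_status == EvalStatus.FAILED:
      return EvalStatus.FAILED
  return final_eval_status


assert (
    aggregate_final_status([EvalStatus.PASSED, EvalStatus.NOT_EVALUATED])
    == EvalStatus.PASSED
)
assert (
    aggregate_final_status([EvalStatus.PASSED, EvalStatus.FAILED])
    == EvalStatus.FAILED
)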
- for eval_metric_result in eval_metric_results: - eval_status = eval_metric_result[1].eval_status - if eval_status == EvalStatus.PASSED: + for overall_eval_metric_result in overall_eval_metric_results: + overall_eval_status = overall_eval_metric_result.eval_status + if overall_eval_status == EvalStatus.PASSED: final_eval_status = EvalStatus.PASSED - elif eval_status == EvalStatus.NOT_EVALUATED: + elif overall_eval_status == EvalStatus.NOT_EVALUATED: continue - elif eval_status == EvalStatus.FAILED: + elif overall_eval_status == EvalStatus.FAILED: final_eval_status = EvalStatus.FAILED break else: raise ValueError("Unknown eval status.") yield EvalCaseResult( - eval_set_file=eval_set_file, + eval_set_file=eval_set_id, + eval_set_id=eval_set_id, eval_id=eval_name, final_eval_status=final_eval_status, - eval_metric_results=eval_metric_results, + eval_metric_results=[], + overall_eval_metric_results=overall_eval_metric_results, + eval_metric_result_per_invocation=eval_metric_result_per_invocation, session_id=session_id, user_id=user_id, ) @@ -290,11 +366,3 @@ def _get_eval_metric_result(eval_metric, score): EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED ) return EvalMetricResult(score=score, eval_status=eval_status) - - -def _print_eval_metric_result(eval_metric, eval_metric_result): - print( - f"Metric: {eval_metric.metric_name}\tStatus:" - f" {eval_metric_result.eval_status}\tScore:" - f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}" - ) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index d06d058..753a1a4 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -296,6 +296,7 @@ def cli_eval( from .cli_eval import parse_and_get_evals_to_run from .cli_eval import run_evals from .cli_eval import try_get_reset_func + from ..evaluation.local_eval_sets_manager import load_eval_set_from_file except ModuleNotFoundError: raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE) @@ -311,17 +312,27 @@ def cli_eval( root_agent = get_root_agent(agent_module_file_path) reset_func = try_get_reset_func(agent_module_file_path) - eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path) + eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path) + eval_set_id_to_eval_cases = {} + + # Read the eval_set files and get the cases. + for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items(): + eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path) + eval_cases = eval_set.eval_cases + + if eval_case_ids: + # There are eval_ids that we should select. 
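# A rough end-to-end sketch of driving the new run_evals API programmatically,
# mirroring what this CLI command does. The agent module, eval set path, and
# eval set id below are hypothetical; the imports follow the module layout
# shown elsewhere in this patch, and TOOL_TRAJECTORY_SCORE_KEY is assumed to be
# importable from cli_eval, where it is referenced above.
import asyncio

from google.adk.cli.cli_eval import EvalMetric
from google.adk.cli.cli_eval import run_evals
from google.adk.cli.cli_eval import TOOL_TRAJECTORY_SCORE_KEY
from google.adk.evaluation.local_eval_sets_manager import load_eval_set_from_file

from my_agent.agent import root_agent  # hypothetical agent module


async def main():
  eval_set = load_eval_set_from_file(
      "my_agent/simple.evalset.json",  # hypothetical eval set file
      "simple_eval_set",
  )
  eval_metrics = [
      EvalMetric(metric_name=TOOL_TRAJECTORY_SCORE_KEY, threshold=1.0)
  ]
  async for eval_case_result in run_evals(
      {"simple_eval_set": eval_set.eval_cases},
      root_agent,
      None,  # reset_func
      eval_metrics,
  ):
    print(eval_case_result.eval_id, eval_case_result.final_eval_status)


if __name__ == "__main__":
  asyncio.run(main())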
+ eval_cases = [ + e for e in eval_set.eval_cases if e.eval_id in eval_case_ids + ] + + eval_set_id_to_eval_cases[eval_set_file_path] = eval_cases async def _collect_eval_results() -> list[EvalCaseResult]: return [ result async for result in run_evals( - eval_set_to_evals, - root_agent, - reset_func, - eval_metrics, - print_detailed_results=print_detailed_results, + eval_set_id_to_eval_cases, root_agent, reset_func, eval_metrics ) ] @@ -336,20 +347,28 @@ def cli_eval( for eval_result in eval_results: eval_result: EvalCaseResult - if eval_result.eval_set_file not in eval_run_summary: - eval_run_summary[eval_result.eval_set_file] = [0, 0] + if eval_result.eval_set_id not in eval_run_summary: + eval_run_summary[eval_result.eval_set_id] = [0, 0] if eval_result.final_eval_status == EvalStatus.PASSED: - eval_run_summary[eval_result.eval_set_file][0] += 1 + eval_run_summary[eval_result.eval_set_id][0] += 1 else: - eval_run_summary[eval_result.eval_set_file][1] += 1 + eval_run_summary[eval_result.eval_set_id][1] += 1 print("Eval Run Summary") - for eval_set_file, pass_fail_count in eval_run_summary.items(): + for eval_set_id, pass_fail_count in eval_run_summary.items(): print( - f"{eval_set_file}:\n Tests passed: {pass_fail_count[0]}\n Tests" + f"{eval_set_id}:\n Tests passed: {pass_fail_count[0]}\n Tests" f" failed: {pass_fail_count[1]}" ) + if print_detailed_results: + for eval_result in eval_results: + eval_result: EvalCaseResult + print( + "*********************************************************************" + ) + print(eval_result.model_dump_json(indent=2)) + @main.command("web") @click.option( diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 2783946..ea49143 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -48,6 +48,7 @@ from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter from opentelemetry.sdk.trace import export from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace import TracerProvider +from pydantic import Field from pydantic import ValidationError from starlette.types import Lifespan from typing_extensions import override @@ -75,6 +76,7 @@ from .cli_eval import EVAL_SESSION_ID_PREFIX from .cli_eval import EvalCaseResult from .cli_eval import EvalMetric from .cli_eval import EvalMetricResult +from .cli_eval import EvalMetricResultPerInvocation from .cli_eval import EvalSetResult from .cli_eval import EvalStatus from .utils import common @@ -175,7 +177,14 @@ class RunEvalResult(common.BaseModel): eval_set_id: str eval_id: str final_eval_status: EvalStatus - eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] + eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field( + deprecated=True, + description=( + "This field is deprecated, use overall_eval_metric_results instead." + ), + ) + overall_eval_metric_results: list[EvalMetricResult] + eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation] user_id: str session_id: str @@ -480,25 +489,26 @@ def get_fast_api_app( async def run_eval( app_name: str, eval_set_id: str, req: RunEvalRequest ) -> list[RunEvalResult]: + """Runs an eval given the details in the eval request.""" from .cli_eval import run_evals - """Runs an eval given the details in the eval request.""" # Create a mapping from eval set file to all the evals that needed to be # run. 
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir) - eval_set_file_path = _get_eval_set_file_path( - app_name, agent_dir, eval_set_id - ) - eval_set_to_evals = {eval_set_file_path: req.eval_ids} - if not req.eval_ids: - logger.info( - "Eval ids to run list is empty. We will all evals in the eval set." - ) + eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id) + + if req.eval_ids: + eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids] + eval_set_to_evals = {eval_set_id: eval_cases} + else: + logger.info("Eval ids to run list is empty. We will run all eval cases.") + eval_set_to_evals = {eval_set_id: eval_set.eval_cases} + root_agent = await _get_root_agent_async(app_name) run_eval_results = [] eval_case_results = [] - async for eval_result in run_evals( + async for eval_case_result in run_evals( eval_set_to_evals, root_agent, getattr(root_agent, "reset_data", None), @@ -509,31 +519,23 @@ def get_fast_api_app( run_eval_results.append( RunEvalResult( app_name=app_name, - eval_set_file=eval_result.eval_set_file, + eval_set_file=eval_case_result.eval_set_file, eval_set_id=eval_set_id, - eval_id=eval_result.eval_id, - final_eval_status=eval_result.final_eval_status, - eval_metric_results=eval_result.eval_metric_results, - user_id=eval_result.user_id, - session_id=eval_result.session_id, + eval_id=eval_case_result.eval_id, + final_eval_status=eval_case_result.final_eval_status, + eval_metric_results=eval_case_result.eval_metric_results, + overall_eval_metric_results=eval_case_result.overall_eval_metric_results, + eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation, + user_id=eval_case_result.user_id, + session_id=eval_case_result.session_id, ) ) - session = session_service.get_session( + eval_case_result.session_details = session_service.get_session( app_name=app_name, - user_id=eval_result.user_id, - session_id=eval_result.session_id, - ) - eval_case_results.append( - EvalCaseResult( - eval_set_file=eval_result.eval_set_file, - eval_id=eval_result.eval_id, - final_eval_status=eval_result.final_eval_status, - eval_metric_results=eval_result.eval_metric_results, - session_id=eval_result.session_id, - session_details=session, - user_id=eval_result.user_id, - ) + user_id=eval_case_result.user_id, + session_id=eval_case_result.session_id, ) + eval_case_results.append(eval_case_result) timestamp = time.time() eval_set_result_name = app_name + "_" + eval_set_id + "_" + str(timestamp) diff --git a/src/google/adk/evaluation/agent_evaluator.py b/src/google/adk/evaluation/agent_evaluator.py index d97cd1f..b7303d6 100644 --- a/src/google/adk/evaluation/agent_evaluator.py +++ b/src/google/adk/evaluation/agent_evaluator.py @@ -258,13 +258,6 @@ class AgentEvaluator: initial_session=initial_session, ) - @staticmethod - def _generate_responses_from_session(eval_dataset, session_path): - """Generates evaluation responses by running the agent module multiple times.""" - return EvaluationGenerator.generate_responses_from_session( - session_path, eval_dataset - ) - @staticmethod def _response_evaluation_required(criteria, eval_dataset): """Checks if response evaluation are needed.""" diff --git a/src/google/adk/evaluation/eval_case.py b/src/google/adk/evaluation/eval_case.py index d815a61..58a738d 100644 --- a/src/google/adk/evaluation/eval_case.py +++ b/src/google/adk/evaluation/eval_case.py @@ -23,10 +23,10 @@ from pydantic import Field class IntermediateData(BaseModel): """Container for intermediate data that an agent would generate 
as it responds with a final answer.""" - tool_uses: list[genai_types.FunctionCall] + tool_uses: list[genai_types.FunctionCall] = [] """Tool use trajectory in chronological order.""" - intermediate_responses: list[Tuple[str, list[genai_types.Part]]] + intermediate_responses: list[Tuple[str, list[genai_types.Part]]] = [] """Intermediate responses generated by sub-agents to convey progress or status in a multi-agent system, distinct from the final response. diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index 09fcf26..c59868e 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -13,19 +13,19 @@ # limitations under the License. import importlib +from typing import Any, Optional import uuid -from google.genai import types - -from ..agents.base_agent import BaseAgent from ..agents.llm_agent import Agent -from ..agents.llm_agent import BeforeToolCallback -from ..agents.llm_agent import LlmAgent +from ..artifacts.base_artifact_service import BaseArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..runners import Runner +from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.session import Session -from .evaluation_constants import EvalConstants +from .eval_case import IntermediateData +from .eval_case import Invocation +from .eval_case import SessionInput class EvaluationGenerator: @@ -102,56 +102,40 @@ class EvaluationGenerator: agent_to_evaluate = root_agent.find_agent(agent_name) assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found." - return EvaluationGenerator._process_query_with_root_agent( + return EvaluationGenerator._generate_inferences_from_root_agent( data, agent_to_evaluate, reset_func, initial_session ) @staticmethod - async def _process_query_with_root_agent( - data, - root_agent, - reset_func, - initial_session={}, - session_id=None, - session_service=None, - artifact_service=None, - ): - """Process a query using the agent and evaluation dataset.""" - - # we don't know which tools belong to which agent - # so we just apply to any agents that has certain tool outputs - all_mock_tools = set() - for eval_entry in data: - expected_tool_use = eval_entry.get(EvalConstants.EXPECTED_TOOL_USE, []) - for expected in expected_tool_use: - if EvalConstants.MOCK_TOOL_OUTPUT in expected: - all_mock_tools.add(expected[EvalConstants.TOOL_NAME]) - - eval_data_copy = data.copy() - await EvaluationGenerator.apply_before_tool_callback( - root_agent, - lambda *args: EvaluationGenerator.before_tool_callback( - *args, eval_dataset=eval_data_copy - ), - all_mock_tools, - ) - + async def _generate_inferences_from_root_agent( + invocations: list[Invocation], + root_agent: Agent, + reset_func: Any, + initial_session: Optional[SessionInput] = None, + session_id: Optional[str] = None, + session_service: Optional[BaseSessionService] = None, + artifact_service: Optional[BaseArtifactService] = None, + ) -> list[Invocation]: + """Scrapes the root agent given the list of Invocations.""" if not session_service: session_service = InMemorySessionService() - app_name = initial_session.get("app_name", "EvaluationGenerator") - user_id = initial_session.get("user_id", "test_user_id") + app_name = ( + initial_session.app_name if initial_session else "EvaluationGenerator" + ) + user_id = initial_session.user_id if initial_session else 
"test_user_id" session_id = session_id if session_id else str(uuid.uuid4()) _ = session_service.create_session( app_name=app_name, user_id=user_id, - state=initial_session.get("state", {}), + state=initial_session.state if initial_session else {}, session_id=session_id, ) if not artifact_service: artifact_service = InMemoryArtifactService() + runner = Runner( app_name=app_name, agent=root_agent, @@ -163,30 +147,37 @@ class EvaluationGenerator: if callable(reset_func): reset_func() - responses = data.copy() + response_invocations = [] - for index, eval_entry in enumerate(responses): - response = None - query = eval_entry["query"] - content = types.Content(role="user", parts=[types.Part(text=query)]) - turn_actual_tool_uses = [] + for invocation in invocations: + final_response = None + user_content = invocation.user_content + tool_uses = [] + invocation_id = "" for event in runner.run( - user_id=user_id, session_id=session_id, new_message=content + user_id=user_id, session_id=session_id, new_message=user_content ): + invocation_id = ( + event.invocation_id if not invocation_id else invocation_id + ) + if event.is_final_response() and event.content and event.content.parts: - response = event.content.parts[0].text + final_response = event.content elif event.get_function_calls(): for call in event.get_function_calls(): - turn_actual_tool_uses.append({ - EvalConstants.TOOL_NAME: call.name, - EvalConstants.TOOL_INPUT: call.args, - }) + tool_uses.append(call) - responses[index]["actual_tool_use"] = turn_actual_tool_uses - responses[index]["response"] = response + response_invocations.append( + Invocation( + invocation_id=invocation_id, + user_content=user_content, + final_response=final_response, + intermediate_data=IntermediateData(tool_uses=tool_uses), + ) + ) - return responses + return response_invocations @staticmethod def _process_query_with_session(session_data, data): @@ -225,46 +216,3 @@ class EvaluationGenerator: responses[index]["actual_tool_use"] = actual_tool_uses responses[index]["response"] = response return responses - - @staticmethod - def before_tool_callback(tool, args, tool_context, eval_dataset): - """Intercept specific tool calls and return predefined outputs - - from eval_dataset. - """ - for index, eval_entry in enumerate(eval_dataset): - expected_tool_use = eval_entry.get("expected_tool_use", []) - for expected in expected_tool_use: - if ( - EvalConstants.MOCK_TOOL_OUTPUT in expected - and tool.name == expected[EvalConstants.TOOL_NAME] - and args == expected.get(EvalConstants.TOOL_INPUT, {}) - ): - # pop the matched entry so we don't rematch again - eval_dataset.pop(index) - return {"result": expected[EvalConstants.MOCK_TOOL_OUTPUT]} - - return None - - @staticmethod - async def apply_before_tool_callback( - agent: BaseAgent, - callback: BeforeToolCallback, - all_mock_tools: set[str], - ): - """Recursively apply the before_tool_callback to the root agent and all its subagents.""" - # Check if the agent has tools that are defined by evalset. 
- # We use function names to check if tools match - if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent): - return - - for tool in await agent.canonical_tools(): - tool_name = tool.name - if tool_name in all_mock_tools: - agent.before_tool_callback = callback - - # Apply recursively to subagents if they exist - for sub_agent in agent.sub_agents: - await EvaluationGenerator.apply_before_tool_callback( - sub_agent, callback, all_mock_tools - ) diff --git a/src/google/adk/evaluation/evaluator.py b/src/google/adk/evaluation/evaluator.py new file mode 100644 index 0000000..5b7bc98 --- /dev/null +++ b/src/google/adk/evaluation/evaluator.py @@ -0,0 +1,56 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from enum import Enum +from typing import Optional +from pydantic import BaseModel +from .eval_case import Invocation + + +class EvalStatus(Enum): + PASSED = 1 + FAILED = 2 + NOT_EVALUATED = 3 + + +class PerInvocationResult(BaseModel): + """Metric evaluation score per invocation.""" + + actual_invocation: Invocation + expected_invocation: Invocation + score: Optional[float] = None + eval_status: EvalStatus = EvalStatus.NOT_EVALUATED + + +class EvaluationResult(BaseModel): + overall_score: Optional[float] = None + """Overall score, based on each invocation.""" + + overall_eval_status: EvalStatus = EvalStatus.NOT_EVALUATED + """Overall status, based on each invocation.""" + + per_invocation_results: list[PerInvocationResult] = [] + + +class Evaluator(ABC): + """A merics evaluator interface.""" + + def evaluate_invocations( + self, + actual_invocations: list[Invocation], + expected_invocations: list[Invocation], + ) -> EvaluationResult: + """Returns EvaluationResult after performing evaluations using actual and expected invocations.""" + raise NotImplementedError() diff --git a/src/google/adk/evaluation/local_eval_sets_manager.py b/src/google/adk/evaluation/local_eval_sets_manager.py index 4e5b776..9c1b509 100644 --- a/src/google/adk/evaluation/local_eval_sets_manager.py +++ b/src/google/adk/evaluation/local_eval_sets_manager.py @@ -154,6 +154,22 @@ def convert_eval_set_to_pydanctic_schema( ) +def load_eval_set_from_file( + eval_set_file_path: str, eval_set_id: str +) -> EvalSet: + """Returns an EvalSet that is read from the given file.""" + with open(eval_set_file_path, "r", encoding="utf-8") as f: + content = f.read() + try: + return EvalSet.model_validate_json(content) + except ValidationError: + # We assume that the eval data was specified in the old format and try + # to convert it to the new format. 
+ return convert_eval_set_to_pydanctic_schema( + eval_set_id, json.loads(content) + ) + + class LocalEvalSetsManager(EvalSetsManager): """An EvalSets manager that stores eval sets locally on disk.""" @@ -165,16 +181,7 @@ class LocalEvalSetsManager(EvalSetsManager): """Returns an EvalSet identified by an app_name and eval_set_id.""" # Load the eval set file data eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) - with open(eval_set_file_path, "r", encoding="utf-8") as f: - content = f.read() - try: - return EvalSet.model_validate_json(content) - except ValidationError: - # We assume that the eval data was specified in the old format and try - # to convert it to the new format. - return convert_eval_set_to_pydanctic_schema( - eval_set_id, json.loads(content) - ) + return load_eval_set_from_file(eval_set_file_path, eval_set_id) @override def create_eval_set(self, app_name: str, eval_set_id: str): diff --git a/src/google/adk/evaluation/trajectory_evaluator.py b/src/google/adk/evaluation/trajectory_evaluator.py index 29b4069..1291045 100644 --- a/src/google/adk/evaluation/trajectory_evaluator.py +++ b/src/google/adk/evaluation/trajectory_evaluator.py @@ -12,18 +12,98 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, cast +from deprecated import deprecated +from google.genai import types as genai_types import pandas as pd from tabulate import tabulate +from typing_extensions import override +from .eval_case import Invocation from .evaluation_constants import EvalConstants +from .evaluator import EvalStatus +from .evaluator import EvaluationResult +from .evaluator import Evaluator +from .evaluator import PerInvocationResult -class TrajectoryEvaluator: +class TrajectoryEvaluator(Evaluator): """Evaluates tool use trajectories for accuracy.""" + def __init__(self, threshold: float): + self._threshold = threshold + + @override + def evaluate_invocations( + self, + actual_invocations: list[Invocation], + expected_invocations: list[Invocation], + ) -> EvaluationResult: + """Returns EvaluationResult after performing evaluations using actual and expected invocations.""" + total_tool_use_accuracy = 0.0 + num_invocations = 0 + per_invocation_results = [] + + for actual, expected in zip(actual_invocations, expected_invocations): + actual_tool_uses = ( + actual.intermediate_data.tool_uses if actual.intermediate_data else [] + ) + expected_tool_uses = ( + expected.intermediate_data.tool_uses + if expected.intermediate_data + else [] + ) + tool_use_accuracy = ( + 1.0 + if self._are_tool_calls_equal(actual_tool_uses, expected_tool_uses) + else 0.0 + ) + per_invocation_results.append( + PerInvocationResult( + actual_invocation=actual, + expected_invocation=expected, + score=tool_use_accuracy, + eval_status=self._get_eval_status(tool_use_accuracy), + ) + ) + total_tool_use_accuracy += tool_use_accuracy + num_invocations += 1 + + if per_invocation_results: + overall_score = total_tool_use_accuracy / num_invocations + return EvaluationResult( + overall_score=overall_score, + overall_eval_status=self._get_eval_status(overall_score), + per_invocation_results=per_invocation_results, + ) + + return EvaluationResult() + + def _are_tool_calls_equal( + self, + actual_tool_calls: list[genai_types.FunctionCall], + expected_tool_calls: list[genai_types.FunctionCall], + ) -> bool: + if len(actual_tool_calls) != len(expected_tool_calls): + return False + + for actual, expected in 
zip(actual_tool_calls, expected_tool_calls): + if actual.name != expected.name or actual.args != expected.args: + return False + + return True + + def _get_eval_status(self, score: float): + return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED + @staticmethod + @deprecated( + reason=( + "This method has been deprecated and will be removed soon. Please use" + " evaluate_invocations instead." + ) + ) def evaluate( eval_dataset: list[list[dict[str, Any]]], *, @@ -137,6 +217,7 @@ class TrajectoryEvaluator: return new_row, failure @staticmethod + @deprecated() def are_tools_equal(list_a_original, list_b_original): # Remove other entries that we don't want to evaluate list_a = [
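# A minimal usage sketch of the new TrajectoryEvaluator API introduced above.
# The invocation id, query text, and tool call are made up for illustration;
# the Invocation construction mirrors how _generate_inferences_from_root_agent
# builds invocations elsewhere in this patch, and class/module names come from
# the patch itself.
from google.adk.evaluation.eval_case import IntermediateData
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator
from google.genai import types as genai_types


def build_invocation(tool_name: str, tool_args: dict) -> Invocation:
  # Hypothetical helper that builds an Invocation carrying a single tool call.
  return Invocation(
      invocation_id="inv-1",
      user_content=genai_types.Content(
          role="user", parts=[genai_types.Part(text="What is the weather?")]
      ),
      final_response=genai_types.Content(
          role="model", parts=[genai_types.Part(text="It is sunny.")]
      ),
      intermediate_data=IntermediateData(
          tool_uses=[genai_types.FunctionCall(name=tool_name, args=tool_args)]
      ),
  )


actual = [build_invocation("get_weather", {"city": "Kirkland"})]
expected = [build_invocation("get_weather", {"city": "Kirkland"})]

evaluator = TrajectoryEvaluator(threshold=1.0)
result = evaluator.evaluate_invocations(
    actual_invocations=actual, expected_invocations=expected
)

# The tool calls match by name and args, so the per-invocation score and the
# overall score are both 1.0 and the statuses are EvalStatus.PASSED.
print(result.overall_score, result.overall_eval_status)
for per_invocation_result in result.per_invocation_results:
  print(per_invocation_result.score, per_invocation_result.eval_status)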