diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py
index 0574dbd..0d62191 100644
--- a/src/google/adk/cli/cli_eval.py
+++ b/src/google/adk/cli/cli_eval.py
@@ -32,6 +32,7 @@ from ..artifacts.base_artifact_service import BaseArtifactService
 from ..evaluation.eval_case import EvalCase
 from ..evaluation.eval_case import Invocation
 from ..evaluation.evaluator import EvalStatus
+from ..evaluation.evaluator import Evaluator
 from ..sessions.base_session_service import BaseSessionService
 from ..sessions.session import Session
 from .utils import common
@@ -271,55 +272,32 @@ async def run_evals(
         overall_eval_metric_results = []
 
         for eval_metric in eval_metrics:
-          if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
-            evaluation_result = TrajectoryEvaluator(
-                eval_metric.threshold
-            ).evaluate_invocations(
-                actual_invocations=inference_result,
-                expected_invocations=eval_case.conversation,
-            )
-            overall_eval_metric_results.append(
+          metric_evaluator = _get_evaluator(eval_metric)
+
+          evaluation_result = metric_evaluator.evaluate_invocations(
+              actual_invocations=inference_result,
+              expected_invocations=eval_case.conversation,
+          )
+
+          overall_eval_metric_results.append(
+              EvalMetricResult(
+                  metric_name=eval_metric.metric_name,
+                  threshold=eval_metric.threshold,
+                  score=evaluation_result.overall_score,
+                  eval_status=evaluation_result.overall_eval_status,
+              )
+          )
+          for index, per_invocation_result in enumerate(
+              evaluation_result.per_invocation_results
+          ):
+            eval_metric_result_per_invocation[index].eval_metric_results.append(
                 EvalMetricResult(
                     metric_name=eval_metric.metric_name,
                     threshold=eval_metric.threshold,
-                    score=evaluation_result.overall_score,
-                    eval_status=evaluation_result.overall_eval_status,
+                    score=per_invocation_result.score,
+                    eval_status=per_invocation_result.eval_status,
                 )
             )
-            for index, per_invocation_result in enumerate(
-                evaluation_result.per_invocation_results
-            ):
-              eval_metric_result_per_invocation[
-                  index
-              ].eval_metric_results.append(
-                  EvalMetricResult(
-                      metric_name=eval_metric.metric_name,
-                      threshold=eval_metric.threshold,
-                      score=per_invocation_result.score,
-                      eval_status=per_invocation_result.eval_status,
-                  )
-              )
-
-          # elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
-          #   score = ResponseEvaluator.evaluate(
-          #       [inference_result],
-          #       [RESPONSE_MATCH_SCORE_KEY],
-          #       print_detailed_results=print_detailed_results,
-          #   )
-          #   eval_metric_result = _get_eval_metric_result(
-          #       eval_metric, score["rouge_1/mean"].item()
-          #   )
-          # elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
-          #   score = ResponseEvaluator.evaluate(
-          #       [inference_result],
-          #       [RESPONSE_EVALUATION_SCORE_KEY],
-          #       print_detailed_results=print_detailed_results,
-          #   )
-          #   eval_metric_result = _get_eval_metric_result(
-          #       eval_metric, score["coherence/mean"].item()
-          #   )
-          else:
-            logger.warning("`%s` is not supported.", eval_metric.metric_name)
         final_eval_status = EvalStatus.NOT_EVALUATED
 
         # Go over the all the eval statuses and mark the final eval status as
@@ -356,13 +334,26 @@ async def run_evals(
 
           print(f"Result: {result}\n")
 
-      except Exception as e:
-        print(f"Error: {e}")
-        logger.info("Error: %s", str(traceback.format_exc()))
+      except Exception:
+        # Catching the general exception, so that we don't block other eval
+        # cases.
+ logger.exception(f"Eval failed for `{eval_set_id}:{eval_name}`") -def _get_eval_metric_result(eval_metric, score): - eval_status = ( - EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED - ) - return EvalMetricResult(score=score, eval_status=eval_status) +def _get_evaluator(eval_metric: EvalMetric) -> Evaluator: + try: + from ..evaluation.response_evaluator import ResponseEvaluator + from ..evaluation.trajectory_evaluator import TrajectoryEvaluator + except ModuleNotFoundError as e: + raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e + if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY: + return TrajectoryEvaluator(threshold=eval_metric.threshold) + elif ( + eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY + or eval_metric == RESPONSE_EVALUATION_SCORE_KEY + ): + return ResponseEvaluator( + threshold=eval_metric.threshold, metric_name=eval_metric.metric_name + ) + + raise ValueError(f"Unsupported eval metric: {eval_metric}") diff --git a/src/google/adk/evaluation/response_evaluator.py b/src/google/adk/evaluation/response_evaluator.py index ba25b3f..c444785 100644 --- a/src/google/adk/evaluation/response_evaluator.py +++ b/src/google/adk/evaluation/response_evaluator.py @@ -12,18 +12,122 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional +from deprecated import deprecated +from google.genai import types as genai_types import pandas as pd from tabulate import tabulate +from typing_extensions import override from vertexai.preview.evaluation import EvalTask from vertexai.preview.evaluation import MetricPromptTemplateExamples +from .eval_case import IntermediateData +from .eval_case import Invocation +from .evaluator import EvalStatus +from .evaluator import EvaluationResult +from .evaluator import Evaluator +from .evaluator import PerInvocationResult -class ResponseEvaluator: + +class ResponseEvaluator(Evaluator): """Runs response evaluation for agents.""" + def __init__(self, threshold: float, metric_name: str): + if "response_evaluation_score" == metric_name: + self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE + elif "response_match_score" == metric_name: + self._metric_name = "rouge_1" + else: + raise ValueError(f"`{metric_name}` is not supported.") + + self._threshold = threshold + + @override + def evaluate_invocations( + self, + actual_invocations: list[Invocation], + expected_invocations: list[Invocation], + ) -> EvaluationResult: + total_score = 0.0 + num_invocations = 0 + per_invocation_results = [] + for actual, expected in zip(actual_invocations, expected_invocations): + prompt = self._get_text(expected.user_content) + reference = self._get_text(expected.final_response) + response = self._get_text(actual.final_response) + actual_tool_use = self._get_tool_use_trajectory(actual.intermediate_data) + reference_trajectory = self._get_tool_use_trajectory( + expected.intermediate_data + ) + + eval_case = { + "prompt": prompt, + "reference": reference, + "response": response, + "actual_tool_user": actual_tool_use, + "reference_trajectory": reference_trajectory, + } + + eval_case_result = ResponseEvaluator._perform_eval( + pd.DataFrame([eval_case]), [self._metric_name] + ) + score = self._get_score(eval_case_result) + per_invocation_results.append( + PerInvocationResult( + actual_invocation=actual, + expected_invocation=expected, + score=score, + eval_status=self._get_eval_status(score), + ) + ) + total_score += 
score + num_invocations += 1 + + if per_invocation_results: + overall_score = total_score / num_invocations + return EvaluationResult( + overall_score=overall_score, + overall_eval_status=self._get_eval_status(overall_score), + per_invocation_results=per_invocation_results, + ) + + return EvaluationResult() + + def _get_text(self, content: Optional[genai_types.Content]) -> str: + if content and content.parts: + return "\n".join([p.text for p in content.parts if p.text]) + + return "" + + def _get_tool_use_trajectory( + self, intermediate_data: Optional[IntermediateData] + ) -> list[dict[str, Any]]: + tool_use_trajectory = [] + if not intermediate_data: + return tool_use_trajectory + + for function_call in intermediate_data.tool_uses: + tool_use_trajectory.append({ + "tool_name": function_call.name, + "tool_input": function_call.args or {}, + }) + + return tool_use_trajectory + + def _get_score(self, eval_result) -> float: + return eval_result.summary_metrics[f"{self._metric_name}/mean"].item() + + def _get_eval_status(self, score: float): + return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED + @staticmethod + @deprecated( + reason=( + "This method has been deprecated and will be removed soon. Please use" + " evaluate_invocations instead." + ) + ) def evaluate( raw_eval_dataset: list[list[dict[str, Any]]], evaluation_criteria: list[str],