Update Response Evaluators to use the new eval schema.

PiperOrigin-RevId: 758929683
Author: Ankur Sharma
Date: 2025-05-14 19:25:41 -07:00
Committed by: Copybara-Service
Parent: ee674ce0ef
Commit: ada24d7171
2 changed files with 149 additions and 54 deletions
+43 -52
@@ -32,6 +32,7 @@ from ..artifacts.base_artifact_service import BaseArtifactService
 from ..evaluation.eval_case import EvalCase
 from ..evaluation.eval_case import Invocation
 from ..evaluation.evaluator import EvalStatus
+from ..evaluation.evaluator import Evaluator
 from ..sessions.base_session_service import BaseSessionService
 from ..sessions.session import Session
 from .utils import common
@@ -271,55 +272,32 @@ async def run_evals(
       overall_eval_metric_results = []
       for eval_metric in eval_metrics:
-        if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
-          evaluation_result = TrajectoryEvaluator(
-              eval_metric.threshold
-          ).evaluate_invocations(
-              actual_invocations=inference_result,
-              expected_invocations=eval_case.conversation,
-          )
-          overall_eval_metric_results.append(
-              EvalMetricResult(
-                  metric_name=eval_metric.metric_name,
-                  threshold=eval_metric.threshold,
-                  score=evaluation_result.overall_score,
-                  eval_status=evaluation_result.overall_eval_status,
-              )
-          )
-          for index, per_invocation_result in enumerate(
-              evaluation_result.per_invocation_results
-          ):
-            eval_metric_result_per_invocation[index].eval_metric_results.append(
-                EvalMetricResult(
-                    metric_name=eval_metric.metric_name,
-                    threshold=eval_metric.threshold,
-                    score=per_invocation_result.score,
-                    eval_status=per_invocation_result.eval_status,
-                )
-            )
-        # elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
-        #   score = ResponseEvaluator.evaluate(
-        #       [inference_result],
-        #       [RESPONSE_MATCH_SCORE_KEY],
-        #       print_detailed_results=print_detailed_results,
-        #   )
-        #   eval_metric_result = _get_eval_metric_result(
-        #       eval_metric, score["rouge_1/mean"].item()
-        #   )
-        # elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
-        #   score = ResponseEvaluator.evaluate(
-        #       [inference_result],
-        #       [RESPONSE_EVALUATION_SCORE_KEY],
-        #       print_detailed_results=print_detailed_results,
-        #   )
-        #   eval_metric_result = _get_eval_metric_result(
-        #       eval_metric, score["coherence/mean"].item()
-        #   )
-        else:
-          logger.warning("`%s` is not supported.", eval_metric.metric_name)
-          final_eval_status = EvalStatus.NOT_EVALUATED
+        metric_evaluator = _get_evaluator(eval_metric)
+        evaluation_result = metric_evaluator.evaluate_invocations(
+            actual_invocations=inference_result,
+            expected_invocations=eval_case.conversation,
+        )
+        overall_eval_metric_results.append(
+            EvalMetricResult(
+                metric_name=eval_metric.metric_name,
+                threshold=eval_metric.threshold,
+                score=evaluation_result.overall_score,
+                eval_status=evaluation_result.overall_eval_status,
+            )
+        )
+        for index, per_invocation_result in enumerate(
+            evaluation_result.per_invocation_results
+        ):
+          eval_metric_result_per_invocation[
+              index
+          ].eval_metric_results.append(
+              EvalMetricResult(
+                  metric_name=eval_metric.metric_name,
+                  threshold=eval_metric.threshold,
+                  score=per_invocation_result.score,
+                  eval_status=per_invocation_result.eval_status,
+              )
+          )
       # Go over the all the eval statuses and mark the final eval status as
@@ -356,13 +334,26 @@ async def run_evals(
           print(f"Result: {result}\n")
-      except Exception as e:
-        print(f"Error: {e}")
-        logger.info("Error: %s", str(traceback.format_exc()))
+      except Exception:
+        # Catching the general exception, so that we don't block other eval
+        # cases.
+        logger.exception(f"Eval failed for `{eval_set_id}:{eval_name}`")
 
 
-def _get_eval_metric_result(eval_metric, score):
-  eval_status = (
-      EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
-  )
-  return EvalMetricResult(score=score, eval_status=eval_status)
+def _get_evaluator(eval_metric: EvalMetric) -> Evaluator:
+  try:
+    from ..evaluation.response_evaluator import ResponseEvaluator
+    from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
+  except ModuleNotFoundError as e:
+    raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
+
+  if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
+    return TrajectoryEvaluator(threshold=eval_metric.threshold)
+  elif (
+      eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY
+      or eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY
+  ):
+    return ResponseEvaluator(
+        threshold=eval_metric.threshold, metric_name=eval_metric.metric_name
+    )
+
+  raise ValueError(f"Unsupported eval metric: {eval_metric}")
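
Note: below is a minimal usage sketch of the `_get_evaluator` dispatch this commit introduces, not part of the diff. The `evaluate_invocations()` call shape and the result fields (`overall_score`, `overall_eval_status`, `per_invocation_results`) come from the hunks above; the import paths, the literal metric-key value, and the placeholder invocation lists are assumptions for illustration.

# Hedged sketch of the new Evaluator dispatch; names marked "assumed"
# are not verified by this diff.
from google.adk.cli.cli_eval import _get_evaluator  # assumed module path
from google.adk.evaluation.eval_metrics import EvalMetric  # assumed module path

# TOOL_TRAJECTORY_SCORE_KEY is a string constant; its value is assumed here.
metric = EvalMetric(metric_name="tool_trajectory_avg_score", threshold=1.0)

evaluator = _get_evaluator(metric)  # returns TrajectoryEvaluator(threshold=1.0)

# Real callers pass list[Invocation]: the inference output and the eval
# case's recorded conversation. Empty placeholders keep the sketch runnable.
actual_invocations, expected_invocations = [], []

evaluation_result = evaluator.evaluate_invocations(
    actual_invocations=actual_invocations,
    expected_invocations=expected_invocations,
)

print(evaluation_result.overall_score, evaluation_result.overall_eval_status)
for per_invocation_result in evaluation_result.per_invocation_results:
  print(per_invocation_result.score, per_invocation_result.eval_status)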