Update Response Evaluators to use the new eval schema.

PiperOrigin-RevId: 758929683
Authored by Ankur Sharma on 2025-05-14 19:25:41 -07:00; committed by Copybara-Service
parent ee674ce0ef
commit ada24d7171
2 changed files with 149 additions and 54 deletions
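
The "new eval schema" named in the title is consumed here but never defined in this diff. For orientation only, the sketch below reconstructs what that schema plausibly looks like, based purely on how the imported names are used in the changes that follow; pydantic, the enum values, and the Optional/default choices are assumptions, and the real definitions in the evaluation package may differ.

# Reconstruction for orientation only: field names are inferred from usage in
# the diff below; pydantic, the enum values, and the defaults are assumptions.
from enum import Enum
from typing import Any, Optional

from pydantic import BaseModel, Field


class EvalStatus(Enum):
  PASSED = 1
  FAILED = 2
  NOT_EVALUATED = 3


class EvalMetric(BaseModel):
  """A metric to compute plus the score it must reach to pass."""

  metric_name: str
  threshold: float


class Invocation(BaseModel):
  """One conversation turn; only the fields this diff touches are listed."""

  user_content: Optional[Any] = None  # genai_types.Content in the real code
  final_response: Optional[Any] = None  # genai_types.Content in the real code
  intermediate_data: Optional[Any] = None  # carries tool_uses in the real code


class EvalMetricResult(BaseModel):
  metric_name: Optional[str] = None
  threshold: Optional[float] = None
  score: Optional[float] = None
  eval_status: EvalStatus = EvalStatus.NOT_EVALUATED


class PerInvocationResult(BaseModel):
  actual_invocation: Invocation
  expected_invocation: Invocation
  score: Optional[float] = None
  eval_status: EvalStatus = EvalStatus.NOT_EVALUATED


class EvaluationResult(BaseModel):
  overall_score: Optional[float] = None
  overall_eval_status: EvalStatus = EvalStatus.NOT_EVALUATED
  per_invocation_results: list[PerInvocationResult] = Field(default_factory=list)


class Evaluator:
  """Interface that TrajectoryEvaluator and ResponseEvaluator now share."""

  def evaluate_invocations(
      self,
      actual_invocations: list[Invocation],
      expected_invocations: list[Invocation],
  ) -> EvaluationResult:
    raise NotImplementedError()

With that shape in mind, the two diffs below read as one change: the CLI stops special-casing each metric name and instead asks _get_evaluator for an Evaluator, and ResponseEvaluator is rewritten to be one of those Evaluator implementations.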

Changed file 1 of 2:

@@ -32,6 +32,7 @@ from ..artifacts.base_artifact_service import BaseArtifactService
 from ..evaluation.eval_case import EvalCase
 from ..evaluation.eval_case import Invocation
 from ..evaluation.evaluator import EvalStatus
+from ..evaluation.evaluator import Evaluator
 from ..sessions.base_session_service import BaseSessionService
 from ..sessions.session import Session
 from .utils import common
@@ -271,13 +272,13 @@ async def run_evals(
       overall_eval_metric_results = []

       for eval_metric in eval_metrics:
-        if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
-          evaluation_result = TrajectoryEvaluator(
-              eval_metric.threshold
-          ).evaluate_invocations(
-              actual_invocations=inference_result,
-              expected_invocations=eval_case.conversation,
-          )
+        metric_evaluator = _get_evaluator(eval_metric)
+
+        evaluation_result = metric_evaluator.evaluate_invocations(
+            actual_invocations=inference_result,
+            expected_invocations=eval_case.conversation,
+        )
+
         overall_eval_metric_results.append(
             EvalMetricResult(
                 metric_name=eval_metric.metric_name,
@@ -289,9 +290,7 @@ async def run_evals(
         for index, per_invocation_result in enumerate(
             evaluation_result.per_invocation_results
         ):
-          eval_metric_result_per_invocation[
-              index
-          ].eval_metric_results.append(
+          eval_metric_result_per_invocation[index].eval_metric_results.append(
               EvalMetricResult(
                   metric_name=eval_metric.metric_name,
                   threshold=eval_metric.threshold,
@@ -300,27 +299,6 @@ async def run_evals(
               )
           )
-        # elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
-        #   score = ResponseEvaluator.evaluate(
-        #       [inference_result],
-        #       [RESPONSE_MATCH_SCORE_KEY],
-        #       print_detailed_results=print_detailed_results,
-        #   )
-        #   eval_metric_result = _get_eval_metric_result(
-        #       eval_metric, score["rouge_1/mean"].item()
-        #   )
-        # elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
-        #   score = ResponseEvaluator.evaluate(
-        #       [inference_result],
-        #       [RESPONSE_EVALUATION_SCORE_KEY],
-        #       print_detailed_results=print_detailed_results,
-        #   )
-        #   eval_metric_result = _get_eval_metric_result(
-        #       eval_metric, score["coherence/mean"].item()
-        #   )
-        else:
-          logger.warning("`%s` is not supported.", eval_metric.metric_name)
-
       final_eval_status = EvalStatus.NOT_EVALUATED
       # Go over the all the eval statuses and mark the final eval status as
       # passed if all of them pass, otherwise mark the final eval status to
@@ -356,13 +334,26 @@ async def run_evals(
         print(f"Result: {result}\n")

-      except Exception as e:
-        print(f"Error: {e}")
-        logger.info("Error: %s", str(traceback.format_exc()))
+      except Exception:
+        # Catching the general exception, so that we don't block other eval
+        # cases.
+        logger.exception(f"Eval failed for `{eval_set_id}:{eval_name}`")


-def _get_eval_metric_result(eval_metric, score):
-  eval_status = (
-      EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
-  )
-  return EvalMetricResult(score=score, eval_status=eval_status)
+def _get_evaluator(eval_metric: EvalMetric) -> Evaluator:
+  try:
+    from ..evaluation.response_evaluator import ResponseEvaluator
+    from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
+  except ModuleNotFoundError as e:
+    raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
+
+  if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
+    return TrajectoryEvaluator(threshold=eval_metric.threshold)
+  elif (
+      eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY
+      or eval_metric == RESPONSE_EVALUATION_SCORE_KEY
+  ):
+    return ResponseEvaluator(
+        threshold=eval_metric.threshold, metric_name=eval_metric.metric_name
+    )
+
+  raise ValueError(f"Unsupported eval metric: {eval_metric}")

Changed file 2 of 2:

@@ -12,18 +12,122 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any
+from typing import Any, Optional

+from deprecated import deprecated
+from google.genai import types as genai_types
 import pandas as pd
 from tabulate import tabulate
+from typing_extensions import override
 from vertexai.preview.evaluation import EvalTask
 from vertexai.preview.evaluation import MetricPromptTemplateExamples

+from .eval_case import IntermediateData
+from .eval_case import Invocation
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .evaluator import Evaluator
+from .evaluator import PerInvocationResult

-class ResponseEvaluator:
+
+class ResponseEvaluator(Evaluator):
   """Runs response evaluation for agents."""

+  def __init__(self, threshold: float, metric_name: str):
+    if "response_evaluation_score" == metric_name:
+      self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
+    elif "response_match_score" == metric_name:
+      self._metric_name = "rouge_1"
+    else:
+      raise ValueError(f"`{metric_name}` is not supported.")
+
+    self._threshold = threshold
+
+  @override
+  def evaluate_invocations(
+      self,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    total_score = 0.0
+    num_invocations = 0
+    per_invocation_results = []
+    for actual, expected in zip(actual_invocations, expected_invocations):
+      prompt = self._get_text(expected.user_content)
+      reference = self._get_text(expected.final_response)
+      response = self._get_text(actual.final_response)
+      actual_tool_use = self._get_tool_use_trajectory(actual.intermediate_data)
+      reference_trajectory = self._get_tool_use_trajectory(
+          expected.intermediate_data
+      )
+
+      eval_case = {
+          "prompt": prompt,
+          "reference": reference,
+          "response": response,
+          "actual_tool_user": actual_tool_use,
+          "reference_trajectory": reference_trajectory,
+      }
+
+      eval_case_result = ResponseEvaluator._perform_eval(
+          pd.DataFrame([eval_case]), [self._metric_name]
+      )
+      score = self._get_score(eval_case_result)
+      per_invocation_results.append(
+          PerInvocationResult(
+              actual_invocation=actual,
+              expected_invocation=expected,
+              score=score,
+              eval_status=self._get_eval_status(score),
+          )
+      )
+      total_score += score
+      num_invocations += 1
+
+    if per_invocation_results:
+      overall_score = total_score / num_invocations
+      return EvaluationResult(
+          overall_score=overall_score,
+          overall_eval_status=self._get_eval_status(overall_score),
+          per_invocation_results=per_invocation_results,
+      )
+
+    return EvaluationResult()
+
+  def _get_text(self, content: Optional[genai_types.Content]) -> str:
+    if content and content.parts:
+      return "\n".join([p.text for p in content.parts if p.text])
+
+    return ""
+
+  def _get_tool_use_trajectory(
+      self, intermediate_data: Optional[IntermediateData]
+  ) -> list[dict[str, Any]]:
+    tool_use_trajectory = []
+    if not intermediate_data:
+      return tool_use_trajectory
+
+    for function_call in intermediate_data.tool_uses:
+      tool_use_trajectory.append({
+          "tool_name": function_call.name,
+          "tool_input": function_call.args or {},
+      })
+
+    return tool_use_trajectory
+
+  def _get_score(self, eval_result) -> float:
+    return eval_result.summary_metrics[f"{self._metric_name}/mean"].item()
+
+  def _get_eval_status(self, score: float):
+    return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
+
   @staticmethod
+  @deprecated(
+      reason=(
+          "This method has been deprecated and will be removed soon. Please use"
+          " evaluate_invocations instead."
+      )
+  )
   def evaluate(
       raw_eval_dataset: list[list[dict[str, Any]]],
       evaluation_criteria: list[str],
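
For reference, a minimal end-to-end use of the rewritten ResponseEvaluator might look like the sketch below. The absolute google.adk import paths, the Invocation constructor arguments, and the example threshold are assumptions inferred from the relative imports and field names above, and actually computing the rouge_1 score requires a configured Vertex AI project, since _perform_eval goes through vertexai.preview.evaluation.

# Sketch only: the import paths and the Invocation keyword arguments are
# inferred from the diff, not shown in it verbatim.
from google.genai import types as genai_types

from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.response_evaluator import ResponseEvaluator


def text_content(text: str) -> genai_types.Content:
  # Wrap plain text in a single-part Content payload, which is what
  # ResponseEvaluator._get_text unpacks.
  return genai_types.Content(parts=[genai_types.Part(text=text)])


expected = Invocation(
    user_content=text_content("What is the capital of France?"),
    final_response=text_content("The capital of France is Paris."),
)
actual = Invocation(
    user_content=text_content("What is the capital of France?"),
    final_response=text_content("Paris is the capital of France."),
)

# "response_match_score" maps to rouge_1 and "response_evaluation_score" maps
# to the coherence pointwise metric; any other name raises ValueError.
evaluator = ResponseEvaluator(threshold=0.8, metric_name="response_match_score")
result = evaluator.evaluate_invocations(
    actual_invocations=[actual], expected_invocations=[expected]
)
print(result.overall_score, result.overall_eval_status)

The same call shape works for TrajectoryEvaluator, which is what lets _get_evaluator in the first file treat both metrics uniformly.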