Mirror of https://github.com/EvolutionAPI/adk-python.git, synced 2025-07-14 01:41:25 -06:00
Update Response Evaluators to use the new eval schema.
PiperOrigin-RevId: 758929683
parent ee674ce0ef
commit ada24d7171
@@ -32,6 +32,7 @@ from ..artifacts.base_artifact_service import BaseArtifactService
 from ..evaluation.eval_case import EvalCase
 from ..evaluation.eval_case import Invocation
 from ..evaluation.evaluator import EvalStatus
+from ..evaluation.evaluator import Evaluator
 from ..sessions.base_session_service import BaseSessionService
 from ..sessions.session import Session
 from .utils import common
@@ -271,55 +272,32 @@ async def run_evals(
         overall_eval_metric_results = []
 
         for eval_metric in eval_metrics:
-          if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
-            evaluation_result = TrajectoryEvaluator(
-                eval_metric.threshold
-            ).evaluate_invocations(
-                actual_invocations=inference_result,
-                expected_invocations=eval_case.conversation,
-            )
-            overall_eval_metric_results.append(
-                EvalMetricResult(
-                    metric_name=eval_metric.metric_name,
-                    threshold=eval_metric.threshold,
-                    score=evaluation_result.overall_score,
-                    eval_status=evaluation_result.overall_eval_status,
-                )
-            )
-            for index, per_invocation_result in enumerate(
-                evaluation_result.per_invocation_results
-            ):
-              eval_metric_result_per_invocation[
-                  index
-              ].eval_metric_results.append(
-                  EvalMetricResult(
-                      metric_name=eval_metric.metric_name,
-                      threshold=eval_metric.threshold,
-                      score=per_invocation_result.score,
-                      eval_status=per_invocation_result.eval_status,
-                  )
-              )
-
-          # elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
-          #   score = ResponseEvaluator.evaluate(
-          #       [inference_result],
-          #       [RESPONSE_MATCH_SCORE_KEY],
-          #       print_detailed_results=print_detailed_results,
-          #   )
-          #   eval_metric_result = _get_eval_metric_result(
-          #       eval_metric, score["rouge_1/mean"].item()
-          #   )
-          # elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
-          #   score = ResponseEvaluator.evaluate(
-          #       [inference_result],
-          #       [RESPONSE_EVALUATION_SCORE_KEY],
-          #       print_detailed_results=print_detailed_results,
-          #   )
-          #   eval_metric_result = _get_eval_metric_result(
-          #       eval_metric, score["coherence/mean"].item()
-          #   )
-          else:
-            logger.warning("`%s` is not supported.", eval_metric.metric_name)
+          metric_evaluator = _get_evaluator(eval_metric)
+
+          evaluation_result = metric_evaluator.evaluate_invocations(
+              actual_invocations=inference_result,
+              expected_invocations=eval_case.conversation,
+          )
+
+          overall_eval_metric_results.append(
+              EvalMetricResult(
+                  metric_name=eval_metric.metric_name,
+                  threshold=eval_metric.threshold,
+                  score=evaluation_result.overall_score,
+                  eval_status=evaluation_result.overall_eval_status,
+              )
+          )
+          for index, per_invocation_result in enumerate(
+              evaluation_result.per_invocation_results
+          ):
+            eval_metric_result_per_invocation[index].eval_metric_results.append(
+                EvalMetricResult(
+                    metric_name=eval_metric.metric_name,
+                    threshold=eval_metric.threshold,
+                    score=per_invocation_result.score,
+                    eval_status=per_invocation_result.eval_status,
+                )
+            )
 
         final_eval_status = EvalStatus.NOT_EVALUATED
         # Go over the all the eval statuses and mark the final eval status as
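The refactored loop above relies only on the shared evaluator result schema: an overall score and status, plus one result per invocation. Below is a minimal sketch of that contract using simplified dataclass stand-ins; the real EvaluationResult/PerInvocationResult classes live in ..evaluation.evaluator and may be Pydantic models with additional fields, so treat this as illustration only.

# Illustrative stand-ins only; field names mirror those used in the diff above.
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


class EvalStatus(Enum):
  PASSED = 1
  FAILED = 2
  NOT_EVALUATED = 3


@dataclass
class PerInvocationResult:
  # Score and status for a single (actual, expected) invocation pair.
  score: Optional[float] = None
  eval_status: EvalStatus = EvalStatus.NOT_EVALUATED


@dataclass
class EvaluationResult:
  # Aggregate score/status over the whole conversation, plus per-invocation detail.
  overall_score: Optional[float] = None
  overall_eval_status: EvalStatus = EvalStatus.NOT_EVALUATED
  per_invocation_results: list[PerInvocationResult] = field(default_factory=list)

Because both the trajectory and response metrics now return this shape, run_evals no longer needs metric-specific branches to record overall and per-invocation results.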
@@ -356,13 +334,26 @@ async def run_evals(
 
         print(f"Result: {result}\n")
 
-      except Exception as e:
-        print(f"Error: {e}")
-        logger.info("Error: %s", str(traceback.format_exc()))
+      except Exception:
+        # Catching the general exception, so that we don't block other eval
+        # cases.
+        logger.exception(f"Eval failed for `{eval_set_id}:{eval_name}`")
 
 
-def _get_eval_metric_result(eval_metric, score):
-  eval_status = (
-      EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
-  )
-  return EvalMetricResult(score=score, eval_status=eval_status)
+def _get_evaluator(eval_metric: EvalMetric) -> Evaluator:
+  try:
+    from ..evaluation.response_evaluator import ResponseEvaluator
+    from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
+  except ModuleNotFoundError as e:
+    raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
+  if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
+    return TrajectoryEvaluator(threshold=eval_metric.threshold)
+  elif (
+      eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY
+      or eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY
+  ):
+    return ResponseEvaluator(
+        threshold=eval_metric.threshold, metric_name=eval_metric.metric_name
+    )
+
+  raise ValueError(f"Unsupported eval metric: {eval_metric}")
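A rough usage sketch of the new dispatcher follows. Only _get_evaluator, EvalMetric, and the metric-key constants come from the diff; the inference_result and eval_case variables, and constructing EvalMetric by keyword, are assumptions for illustration.

# Hypothetical example: both metric families now flow through the same Evaluator interface.
eval_metrics = [
    EvalMetric(metric_name=TOOL_TRAJECTORY_SCORE_KEY, threshold=1.0),
    EvalMetric(metric_name=RESPONSE_MATCH_SCORE_KEY, threshold=0.8),
]

for eval_metric in eval_metrics:
  evaluator = _get_evaluator(eval_metric)  # TrajectoryEvaluator or ResponseEvaluator
  evaluation_result = evaluator.evaluate_invocations(
      actual_invocations=inference_result,          # invocations produced by the agent
      expected_invocations=eval_case.conversation,  # golden invocations from the eval set
  )
  print(eval_metric.metric_name, evaluation_result.overall_eval_status)

Note the lazy imports inside _get_evaluator: TrajectoryEvaluator and ResponseEvaluator are only imported when an evaluator is actually requested, so the optional eval dependencies surface as MISSING_EVAL_DEPENDENCIES_MESSAGE rather than as an import failure at CLI startup.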
@@ -12,18 +12,122 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any
+from typing import Any, Optional
 
+from deprecated import deprecated
+from google.genai import types as genai_types
 import pandas as pd
 from tabulate import tabulate
+from typing_extensions import override
 from vertexai.preview.evaluation import EvalTask
 from vertexai.preview.evaluation import MetricPromptTemplateExamples
 
+from .eval_case import IntermediateData
+from .eval_case import Invocation
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .evaluator import Evaluator
+from .evaluator import PerInvocationResult
+
 
-class ResponseEvaluator:
+class ResponseEvaluator(Evaluator):
   """Runs response evaluation for agents."""
 
+  def __init__(self, threshold: float, metric_name: str):
+    if "response_evaluation_score" == metric_name:
+      self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
+    elif "response_match_score" == metric_name:
+      self._metric_name = "rouge_1"
+    else:
+      raise ValueError(f"`{metric_name}` is not supported.")
+
+    self._threshold = threshold
+
+  @override
+  def evaluate_invocations(
+      self,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    total_score = 0.0
+    num_invocations = 0
+    per_invocation_results = []
+    for actual, expected in zip(actual_invocations, expected_invocations):
+      prompt = self._get_text(expected.user_content)
+      reference = self._get_text(expected.final_response)
+      response = self._get_text(actual.final_response)
+      actual_tool_use = self._get_tool_use_trajectory(actual.intermediate_data)
+      reference_trajectory = self._get_tool_use_trajectory(
+          expected.intermediate_data
+      )
+
+      eval_case = {
+          "prompt": prompt,
+          "reference": reference,
+          "response": response,
+          "actual_tool_user": actual_tool_use,
+          "reference_trajectory": reference_trajectory,
+      }
+
+      eval_case_result = ResponseEvaluator._perform_eval(
+          pd.DataFrame([eval_case]), [self._metric_name]
+      )
+      score = self._get_score(eval_case_result)
+      per_invocation_results.append(
+          PerInvocationResult(
+              actual_invocation=actual,
+              expected_invocation=expected,
+              score=score,
+              eval_status=self._get_eval_status(score),
+          )
+      )
+      total_score += score
+      num_invocations += 1
+
+    if per_invocation_results:
+      overall_score = total_score / num_invocations
+      return EvaluationResult(
+          overall_score=overall_score,
+          overall_eval_status=self._get_eval_status(overall_score),
+          per_invocation_results=per_invocation_results,
+      )
+
+    return EvaluationResult()
+
+  def _get_text(self, content: Optional[genai_types.Content]) -> str:
+    if content and content.parts:
+      return "\n".join([p.text for p in content.parts if p.text])
+
+    return ""
+
+  def _get_tool_use_trajectory(
+      self, intermediate_data: Optional[IntermediateData]
+  ) -> list[dict[str, Any]]:
+    tool_use_trajectory = []
+    if not intermediate_data:
+      return tool_use_trajectory
+
+    for function_call in intermediate_data.tool_uses:
+      tool_use_trajectory.append({
+          "tool_name": function_call.name,
+          "tool_input": function_call.args or {},
+      })
+
+    return tool_use_trajectory
+
+  def _get_score(self, eval_result) -> float:
+    return eval_result.summary_metrics[f"{self._metric_name}/mean"].item()
+
+  def _get_eval_status(self, score: float):
+    return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
+
   @staticmethod
+  @deprecated(
+      reason=(
+          "This method has been deprecated and will be removed soon. Please use"
+          " evaluate_invocations instead."
+      )
+  )
   def evaluate(
       raw_eval_dataset: list[list[dict[str, Any]]],
       evaluation_criteria: list[str],
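For context, a hedged end-to-end sketch of the new invocation-based API. The Invocation and Content construction below is an assumption based only on the fields the diff reads (user_content, final_response, intermediate_data); the real Invocation model may require more fields, and scoring goes through the existing _perform_eval helper, which presumably calls the Vertex AI evaluation service via the EvalTask import shown above.

from google.genai import types as genai_types

# Hypothetical invocations; only the field names are taken from the diff above.
expected = Invocation(
    user_content=genai_types.Content(
        parts=[genai_types.Part(text="What is the capital of France?")]
    ),
    final_response=genai_types.Content(parts=[genai_types.Part(text="Paris.")]),
)
actual = Invocation(
    user_content=expected.user_content,
    final_response=genai_types.Content(
        parts=[genai_types.Part(text="The capital of France is Paris.")]
    ),
)

# "response_match_score" maps to ROUGE-1 against the reference response;
# "response_evaluation_score" maps to the pointwise coherence metric.
evaluator = ResponseEvaluator(threshold=0.8, metric_name="response_match_score")
result = evaluator.evaluate_invocations(
    actual_invocations=[actual], expected_invocations=[expected]
)
print(result.overall_score, result.overall_eval_status)

The old static evaluate() entry point remains but is now marked @deprecated in favor of this per-invocation interface.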