Mirror of https://github.com/EvolutionAPI/adk-python.git, synced 2025-12-24 06:07:44 -06:00
Update Response Evaluators to use the new eval schema.
PiperOrigin-RevId: 758929683
Committed by: Copybara-Service
Parent: ee674ce0ef
Commit: ada24d7171
@@ -12,18 +12,122 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any
+from typing import Any, Optional
 
+from deprecated import deprecated
+from google.genai import types as genai_types
 import pandas as pd
 from tabulate import tabulate
+from typing_extensions import override
 from vertexai.preview.evaluation import EvalTask
 from vertexai.preview.evaluation import MetricPromptTemplateExamples
+
+from .eval_case import IntermediateData
+from .eval_case import Invocation
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .evaluator import Evaluator
+from .evaluator import PerInvocationResult
 
 
-class ResponseEvaluator:
+class ResponseEvaluator(Evaluator):
   """Runs response evaluation for agents."""
 
+  def __init__(self, threshold: float, metric_name: str):
+    if "response_evaluation_score" == metric_name:
+      self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
+    elif "response_match_score" == metric_name:
+      self._metric_name = "rouge_1"
+    else:
+      raise ValueError(f"`{metric_name}` is not supported.")
+
+    self._threshold = threshold
+
+  @override
+  def evaluate_invocations(
+      self,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    total_score = 0.0
+    num_invocations = 0
+    per_invocation_results = []
+    for actual, expected in zip(actual_invocations, expected_invocations):
+      prompt = self._get_text(expected.user_content)
+      reference = self._get_text(expected.final_response)
+      response = self._get_text(actual.final_response)
+      actual_tool_use = self._get_tool_use_trajectory(actual.intermediate_data)
+      reference_trajectory = self._get_tool_use_trajectory(
+          expected.intermediate_data
+      )
+
+      eval_case = {
+          "prompt": prompt,
+          "reference": reference,
+          "response": response,
+          "actual_tool_user": actual_tool_use,
+          "reference_trajectory": reference_trajectory,
+      }
+
+      eval_case_result = ResponseEvaluator._perform_eval(
+          pd.DataFrame([eval_case]), [self._metric_name]
+      )
+      score = self._get_score(eval_case_result)
+      per_invocation_results.append(
+          PerInvocationResult(
+              actual_invocation=actual,
+              expected_invocation=expected,
+              score=score,
+              eval_status=self._get_eval_status(score),
+          )
+      )
+      total_score += score
+      num_invocations += 1
+
+    if per_invocation_results:
+      overall_score = total_score / num_invocations
+      return EvaluationResult(
+          overall_score=overall_score,
+          overall_eval_status=self._get_eval_status(overall_score),
+          per_invocation_results=per_invocation_results,
+      )
+
+    return EvaluationResult()
+
+  def _get_text(self, content: Optional[genai_types.Content]) -> str:
+    if content and content.parts:
+      return "\n".join([p.text for p in content.parts if p.text])
+
+    return ""
+
+  def _get_tool_use_trajectory(
+      self, intermediate_data: Optional[IntermediateData]
+  ) -> list[dict[str, Any]]:
+    tool_use_trajectory = []
+    if not intermediate_data:
+      return tool_use_trajectory
+
+    for function_call in intermediate_data.tool_uses:
+      tool_use_trajectory.append({
+          "tool_name": function_call.name,
+          "tool_input": function_call.args or {},
+      })
+
+    return tool_use_trajectory
+
+  def _get_score(self, eval_result) -> float:
+    return eval_result.summary_metrics[f"{self._metric_name}/mean"].item()
+
+  def _get_eval_status(self, score: float):
+    return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
+
   @staticmethod
+  @deprecated(
+      reason=(
+          "This method has been deprecated and will be removed soon. Please use"
+          " evaluate_invocations instead."
+      )
+  )
   def evaluate(
       raw_eval_dataset: list[list[dict[str, Any]]],
       evaluation_criteria: list[str],
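
The new constructor accepts exactly two metric names: response_evaluation_score selects the Vertex AI pointwise COHERENCE metric (an LLM-judged rubric), and response_match_score selects rouge_1 (unigram overlap between the agent response and the reference). The pass/fail decision is a plain threshold on the metric mean, as _get_eval_status shows. A minimal illustration of that thresholding, with made-up scores and an assumed import path (the path is not part of this diff):

from google.adk.evaluation.response_evaluator import ResponseEvaluator  # assumed path

# response_match_score maps to rouge_1; threshold is the pass bar on the mean score.
evaluator = ResponseEvaluator(threshold=0.8, metric_name="response_match_score")
print(evaluator._get_eval_status(0.85))  # EvalStatus.PASSED, since 0.85 >= 0.8
print(evaluator._get_eval_status(0.79))  # EvalStatus.FAILED, since 0.79 < 0.8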
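evaluate_invocations delegates the actual scoring to ResponseEvaluator._perform_eval, which sits outside this hunk. Judging from the EvalTask import kept at the top of the file and the summary_metrics[f"{metric}/mean"] lookup in _get_score, it presumably wraps Vertex AI's EvalTask. A rough sketch under that assumption (not the committed implementation; it needs an initialized Vertex AI environment to run):

import pandas as pd
from vertexai.preview.evaluation import EvalTask


def _perform_eval(dataset: pd.DataFrame, metrics: list):
  # Each row of `dataset` is one eval_case dict built in evaluate_invocations.
  # EvalTask.evaluate() returns a result whose summary_metrics dict carries
  # per-metric aggregates such as "rouge_1/mean", which _get_score then reads.
  eval_task = EvalTask(dataset=dataset, metrics=metrics)
  return eval_task.evaluate()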
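Putting it together, an end-to-end call of the new interface could look like the sketch below. Everything here is hypothetical: the import paths are assumed, Invocation is assumed to accept the user_content and final_response fields that evaluate_invocations reads (with intermediate_data left unset), and a real run needs an initialized Vertex AI environment because the scores come from the Vertex AI evaluation service.

from google.genai import types as genai_types

# Assumed import paths; the module itself uses relative imports above.
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.response_evaluator import ResponseEvaluator


def text_content(text: str) -> genai_types.Content:
  # Sketch-only helper: wrap plain text in a single-part Content object.
  return genai_types.Content(parts=[genai_types.Part(text=text)])


expected = Invocation(
    user_content=text_content("What is the capital of France?"),
    final_response=text_content("The capital of France is Paris."),
)
actual = Invocation(
    user_content=text_content("What is the capital of France?"),
    final_response=text_content("Paris is the capital of France."),
)

# response_match_score -> rouge_1; the run passes if the mean score is >= 0.8.
evaluator = ResponseEvaluator(threshold=0.8, metric_name="response_match_score")
result = evaluator.evaluate_invocations(
    actual_invocations=[actual],
    expected_invocations=[expected],
)
print(result.overall_score, result.overall_eval_status)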