Update Eval Run and TrajectoryEvaluator to use the new schema.

PiperOrigin-RevId: 758927160
Author: Ankur Sharma
Date: 2025-05-14 19:15:52 -07:00
Committed by: Copybara-Service
Parent: 2cb74dd20e
Commit: ee674ce0ef
9 changed files with 418 additions and 244 deletions
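
For orientation before the diff: a minimal sketch of driving the updated TrajectoryEvaluator through the new invocation-based schema. Only TrajectoryEvaluator(threshold=...), evaluate_invocations, Invocation, and genai_types.FunctionCall appear in the diff below; the IntermediateData constructor, the user_content field, and the absolute module paths are assumptions about the surrounding package, not something this commit shows.

# Hypothetical usage sketch. IntermediateData, user_content, and the module
# paths are assumed; the diff itself only shows the evaluator-side reads.
from google.genai import types as genai_types

from google.adk.evaluation.eval_case import IntermediateData, Invocation
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator

call = genai_types.FunctionCall(name="get_weather", args={"city": "Paris"})
prompt = genai_types.Content(parts=[genai_types.Part(text="Weather in Paris?")])

# Actual and expected trajectories carry their tool calls in
# intermediate_data.tool_uses, which evaluate_invocations compares pairwise.
actual = Invocation(
    user_content=prompt, intermediate_data=IntermediateData(tool_uses=[call])
)
expected = Invocation(
    user_content=prompt, intermediate_data=IntermediateData(tool_uses=[call])
)

evaluator = TrajectoryEvaluator(threshold=1.0)
result = evaluator.evaluate_invocations([actual], [expected])
print(result.overall_score)        # 1.0 -- trajectories match exactly
print(result.overall_eval_status)  # EvalStatus.PASSED (1.0 >= threshold)
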
@@ -12,18 +12,98 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, cast
 
+from deprecated import deprecated
+from google.genai import types as genai_types
 import pandas as pd
 from tabulate import tabulate
+from typing_extensions import override
 
+from .eval_case import Invocation
 from .evaluation_constants import EvalConstants
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .evaluator import Evaluator
+from .evaluator import PerInvocationResult
 
 
-class TrajectoryEvaluator:
+class TrajectoryEvaluator(Evaluator):
   """Evaluates tool use trajectories for accuracy."""
 
+  def __init__(self, threshold: float):
+    self._threshold = threshold
+
+  @override
+  def evaluate_invocations(
+      self,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    """Returns EvaluationResult after performing evaluations using actual and expected invocations."""
+    total_tool_use_accuracy = 0.0
+    num_invocations = 0
+    per_invocation_results = []
+
+    for actual, expected in zip(actual_invocations, expected_invocations):
+      actual_tool_uses = (
+          actual.intermediate_data.tool_uses if actual.intermediate_data else []
+      )
+      expected_tool_uses = (
+          expected.intermediate_data.tool_uses
+          if expected.intermediate_data
+          else []
+      )
+      tool_use_accuracy = (
+          1.0
+          if self._are_tool_calls_equal(actual_tool_uses, expected_tool_uses)
+          else 0.0
+      )
+      per_invocation_results.append(
+          PerInvocationResult(
+              actual_invocation=actual,
+              expected_invocation=expected,
+              score=tool_use_accuracy,
+              eval_status=self._get_eval_status(tool_use_accuracy),
+          )
+      )
+      total_tool_use_accuracy += tool_use_accuracy
+      num_invocations += 1
+
+    if per_invocation_results:
+      overall_score = total_tool_use_accuracy / num_invocations
+      return EvaluationResult(
+          overall_score=overall_score,
+          overall_eval_status=self._get_eval_status(overall_score),
+          per_invocation_results=per_invocation_results,
+      )
+
+    return EvaluationResult()
+
+  def _are_tool_calls_equal(
+      self,
+      actual_tool_calls: list[genai_types.FunctionCall],
+      expected_tool_calls: list[genai_types.FunctionCall],
+  ) -> bool:
+    if len(actual_tool_calls) != len(expected_tool_calls):
+      return False
+
+    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
+      if actual.name != expected.name or actual.args != expected.args:
+        return False
+
+    return True
+
+  def _get_eval_status(self, score: float):
+    return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
+
   @staticmethod
+  @deprecated(
+      reason=(
+          "This method has been deprecated and will be removed soon. Please use"
+          " evaluate_invocations instead."
+      )
+  )
   def evaluate(
       eval_dataset: list[list[dict[str, Any]]],
       *,
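
A note between the hunks: the new comparison is strict and order-sensitive. _are_tool_calls_equal first rejects length mismatches, then requires each pairwise name and args to be exactly equal, so a single extra argument key zeroes that invocation's score. Also, evaluate_invocations pairs invocations with zip, which silently drops trailing entries when the actual and expected lists differ in length. Continuing the sketch above (values hypothetical; _are_tool_calls_equal is private and called here only to illustrate the semantics):

a = [genai_types.FunctionCall(name="get_weather", args={"city": "Paris"})]
b = [genai_types.FunctionCall(
    name="get_weather", args={"city": "Paris", "units": "C"}
)]

evaluator = TrajectoryEvaluator(threshold=0.5)
assert evaluator._are_tool_calls_equal(a, a)      # exact match -> True
assert not evaluator._are_tool_calls_equal(a, b)  # extra "units" key -> False
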
@@ -137,6 +217,7 @@ class TrajectoryEvaluator:
     return new_row, failure
 
   @staticmethod
+  @deprecated()
   def are_tools_equal(list_a_original, list_b_original):
     # Remove other entries that we don't want to evaluate
     list_a = [