mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2026-02-05 06:16:24 -06:00
Update Eval Run and TrajectoryEvaluator to use the new schema.
PiperOrigin-RevId: 758927160
This commit is contained in:
committed by
Copybara-Service
parent
2cb74dd20e
commit
ee674ce0ef
@@ -12,18 +12,98 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from deprecated import deprecated
|
||||
from google.genai import types as genai_types
|
||||
import pandas as pd
|
||||
from tabulate import tabulate
|
||||
from typing_extensions import override
|
||||
|
||||
from .eval_case import Invocation
|
||||
from .evaluation_constants import EvalConstants
|
||||
from .evaluator import EvalStatus
|
||||
from .evaluator import EvaluationResult
|
||||
from .evaluator import Evaluator
|
||||
from .evaluator import PerInvocationResult
|
||||
|
||||
|
||||
class TrajectoryEvaluator:
|
||||
class TrajectoryEvaluator(Evaluator):
|
||||
"""Evaluates tool use trajectories for accuracy."""
|
||||
|
||||
def __init__(self, threshold: float):
|
||||
self._threshold = threshold
|
||||
|
||||
@override
|
||||
def evaluate_invocations(
|
||||
self,
|
||||
actual_invocations: list[Invocation],
|
||||
expected_invocations: list[Invocation],
|
||||
) -> EvaluationResult:
|
||||
"""Returns EvaluationResult after performing evaluations using actual and expected invocations."""
|
||||
total_tool_use_accuracy = 0.0
|
||||
num_invocations = 0
|
||||
per_invocation_results = []
|
||||
|
||||
for actual, expected in zip(actual_invocations, expected_invocations):
|
||||
actual_tool_uses = (
|
||||
actual.intermediate_data.tool_uses if actual.intermediate_data else []
|
||||
)
|
||||
expected_tool_uses = (
|
||||
expected.intermediate_data.tool_uses
|
||||
if expected.intermediate_data
|
||||
else []
|
||||
)
|
||||
tool_use_accuracy = (
|
||||
1.0
|
||||
if self._are_tool_calls_equal(actual_tool_uses, expected_tool_uses)
|
||||
else 0.0
|
||||
)
|
||||
per_invocation_results.append(
|
||||
PerInvocationResult(
|
||||
actual_invocation=actual,
|
||||
expected_invocation=expected,
|
||||
score=tool_use_accuracy,
|
||||
eval_status=self._get_eval_status(tool_use_accuracy),
|
||||
)
|
||||
)
|
||||
total_tool_use_accuracy += tool_use_accuracy
|
||||
num_invocations += 1
|
||||
|
||||
if per_invocation_results:
|
||||
overall_score = total_tool_use_accuracy / num_invocations
|
||||
return EvaluationResult(
|
||||
overall_score=overall_score,
|
||||
overall_eval_status=self._get_eval_status(overall_score),
|
||||
per_invocation_results=per_invocation_results,
|
||||
)
|
||||
|
||||
return EvaluationResult()
|
||||
|
||||
def _are_tool_calls_equal(
|
||||
self,
|
||||
actual_tool_calls: list[genai_types.FunctionCall],
|
||||
expected_tool_calls: list[genai_types.FunctionCall],
|
||||
) -> bool:
|
||||
if len(actual_tool_calls) != len(expected_tool_calls):
|
||||
return False
|
||||
|
||||
for actual, expected in zip(actual_tool_calls, expected_tool_calls):
|
||||
if actual.name != expected.name or actual.args != expected.args:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_eval_status(self, score: float):
|
||||
return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
|
||||
|
||||
@staticmethod
|
||||
@deprecated(
|
||||
reason=(
|
||||
"This method has been deprecated and will be removed soon. Please use"
|
||||
" evaluate_invocations instead."
|
||||
)
|
||||
)
|
||||
def evaluate(
|
||||
eval_dataset: list[list[dict[str, Any]]],
|
||||
*,
|
||||
@@ -137,6 +217,7 @@ class TrajectoryEvaluator:
|
||||
return new_row, failure
|
||||
|
||||
@staticmethod
|
||||
@deprecated()
|
||||
def are_tools_equal(list_a_original, list_b_original):
|
||||
# Remove other entries that we don't want to evaluate
|
||||
list_a = [
|
||||
|
||||
Reference in New Issue
Block a user