Mirror of https://github.com/EvolutionAPI/adk-python.git (synced 2025-07-14 01:41:25 -06:00)

Commit ee674ce0ef (parent 2cb74dd20e)
Update Eval Run and TrajectoryEvaluator to use the new schema.
PiperOrigin-RevId: 758927160
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import datetime
-from enum import Enum
 import importlib.util
 import json
 import logging
@@ -22,6 +20,7 @@ import sys
 import traceback
 from typing import Any
 from typing import AsyncGenerator
+from typing import cast
 from typing import Optional
 import uuid
 
@@ -29,36 +28,84 @@ from pydantic import BaseModel
 from pydantic import Field
 
 from ..agents import Agent
+from ..artifacts.base_artifact_service import BaseArtifactService
+from ..evaluation.eval_case import EvalCase
+from ..evaluation.eval_case import Invocation
+from ..evaluation.evaluator import EvalStatus
+from ..sessions.base_session_service import BaseSessionService
 from ..sessions.session import Session
 from .utils import common
 
 logger = logging.getLogger(__name__)
 
 
-class EvalStatus(Enum):
-  PASSED = 1
-  FAILED = 2
-  NOT_EVALUATED = 3
-
-
 class EvalMetric(BaseModel):
+  """A metric used to evaluate a particular aspect of an eval case."""
+
   metric_name: str
+  """The name of the metric."""
+
   threshold: float
+  """A threshold value. Each metric decides how to interpret this threshold."""
 
 
-class EvalMetricResult(BaseModel):
+class EvalMetricResult(EvalMetric):
+  """The actual computed score/value of a particular EvalMetric."""
+
   score: Optional[float] = None
   eval_status: EvalStatus
 
 
+class EvalMetricResultPerInvocation(BaseModel):
+  """Eval metric results per invocation."""
+
+  actual_invocation: Invocation
+  """The actual invocation, usually obtained by inferencing the agent."""
+
+  expected_invocation: Invocation
+  """The expected invocation, usually the reference or golden invocation."""
+
+  eval_metric_results: list[EvalMetricResult] = []
+  """Eval results for each applicable metric."""
+
+
 class EvalCaseResult(common.BaseModel):
-  eval_set_file: str
-  eval_id: str
+  """Case-level evaluation results."""
+
+  eval_set_file: str = Field(
+      deprecated=True,
+      description="This field is deprecated, use eval_set_id instead.",
+  )
+  eval_set_id: str = ""
+  """The eval set id."""
+
+  eval_id: str = ""
+  """The eval case id."""
+
   final_eval_status: EvalStatus
-  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
+  """Final eval status for this eval case."""
+
+  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
+      deprecated=True,
+      description=(
+          "This field is deprecated, use overall_eval_metric_results instead."
+      ),
+  )
+
+  overall_eval_metric_results: list[EvalMetricResult]
+  """Overall result for each metric for the entire eval case."""
+
+  eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
+  """Result for each metric on a per invocation basis."""
+
   session_id: str
+  """Session id of the session generated as result of inferencing/scraping stage of the eval."""
+
   session_details: Optional[Session] = None
+  """Session generated as result of inferencing/scraping stage of the eval."""
+
   user_id: Optional[str] = None
+  """User id used during inferencing/scraping stage of the eval."""
 
 
 class EvalSetResult(common.BaseModel):
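As a quick orientation to how the new result models above nest together, here is a minimal illustrative sketch (not code from this commit); `actual` and `expected` stand in for Invocation objects produced by the inference stage, and the metric name, ids, and scores are placeholders:

    metric = EvalMetric(metric_name="tool_trajectory_avg_score", threshold=1.0)

    per_invocation = EvalMetricResultPerInvocation(
        actual_invocation=actual,      # an Invocation from the inference run
        expected_invocation=expected,  # the golden Invocation from the eval set
        eval_metric_results=[
            EvalMetricResult(
                metric_name=metric.metric_name,
                threshold=metric.threshold,
                score=1.0,
                eval_status=EvalStatus.PASSED,
            )
        ],
    )

    case_result = EvalCaseResult(
        eval_set_file="my_eval_set",  # deprecated alias, kept for compatibility
        eval_set_id="my_eval_set",
        eval_id="my_eval_case",
        final_eval_status=EvalStatus.PASSED,
        eval_metric_results=[],  # deprecated field, left empty
        overall_eval_metric_results=per_invocation.eval_metric_results,
        eval_metric_result_per_invocation=[per_invocation],
        session_id="session-1234",
        user_id="test_user_id",
    )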
@@ -161,14 +208,25 @@ def parse_and_get_evals_to_run(
 
 
 async def run_evals(
-    eval_set_to_evals: dict[str, list[str]],
+    eval_cases_by_eval_set_id: dict[str, list[EvalCase]],
     root_agent: Agent,
     reset_func: Optional[Any],
     eval_metrics: list[EvalMetric],
-    session_service=None,
-    artifact_service=None,
-    print_detailed_results=False,
+    session_service: Optional[BaseSessionService] = None,
+    artifact_service: Optional[BaseArtifactService] = None,
 ) -> AsyncGenerator[EvalCaseResult, None]:
+  """Returns a stream of EvalCaseResult for each eval case that was evaluated.
+
+  Args:
+    eval_cases_by_eval_set_id: Eval cases categorized by eval set id to which
+      they belong.
+    root_agent: Agent to use for inferencing.
+    reset_func: If present, this will be called before invoking the agent for
+      every inferencing step.
+    eval_metrics: A list of metrics that should be used during evaluation.
+    session_service: The session service to use during inferencing.
+    artifact_service: The artifact service to use during inferencing.
+  """
   try:
     from ..evaluation.agent_evaluator import EvaluationGenerator
     from ..evaluation.response_evaluator import ResponseEvaluator
@@ -176,29 +234,19 @@ async def run_evals(
   except ModuleNotFoundError as e:
     raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
 
-  """Returns a summary of eval runs."""
-  for eval_set_file, evals_to_run in eval_set_to_evals.items():
-    with open(eval_set_file, "r", encoding="utf-8") as file:
-      eval_items = json.load(file)  # Load JSON into a list
-
-    assert eval_items, f"No eval data found in eval set file: {eval_set_file}"
-
-    for eval_item in eval_items:
-      eval_name = eval_item["name"]
-      eval_data = eval_item["data"]
-      initial_session = eval_item.get("initial_session", {})
-      user_id = initial_session.get("user_id", "test_user_id")
-
-      if evals_to_run and eval_name not in evals_to_run:
-        continue
+  for eval_set_id, eval_cases in eval_cases_by_eval_set_id.items():
+    for eval_case in eval_cases:
+      eval_name = eval_case.eval_id
+      initial_session = eval_case.session_input
+      user_id = initial_session.user_id if initial_session else "test_user_id"
 
       try:
-        print(f"Running Eval: {eval_set_file}:{eval_name}")
+        print(f"Running Eval: {eval_set_id}:{eval_name}")
         session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"
 
-        scrape_result = (
-            await EvaluationGenerator._process_query_with_root_agent(
-                data=eval_data,
+        inference_result = (
+            await EvaluationGenerator._generate_inferences_from_root_agent(
+                invocations=eval_case.conversation,
                 root_agent=root_agent,
                 reset_func=reset_func,
                 initial_session=initial_session,
@@ -208,67 +256,95 @@ async def run_evals(
             )
         )
 
-        eval_metric_results = []
+        # Initialize the per-invocation metric results to an empty list.
+        # We will fill this as we evaluate each metric.
+        eval_metric_result_per_invocation = []
+        for actual, expected in zip(inference_result, eval_case.conversation):
+          eval_metric_result_per_invocation.append(
+              EvalMetricResultPerInvocation(
+                  actual_invocation=actual,
+                  expected_invocation=expected,
+                  eval_metric_results=[],
+              )
+          )
+
+        overall_eval_metric_results = []
+
         for eval_metric in eval_metrics:
-          eval_metric_result = None
           if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
-            score = TrajectoryEvaluator.evaluate(
-                [scrape_result], print_detailed_results=print_detailed_results
+            evaluation_result = TrajectoryEvaluator(
+                eval_metric.threshold
+            ).evaluate_invocations(
+                actual_invocations=inference_result,
+                expected_invocations=eval_case.conversation,
             )
-            eval_metric_result = _get_eval_metric_result(eval_metric, score)
-          elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
-            score = ResponseEvaluator.evaluate(
-                [scrape_result],
-                [RESPONSE_MATCH_SCORE_KEY],
-                print_detailed_results=print_detailed_results,
-            )
-            eval_metric_result = _get_eval_metric_result(
-                eval_metric, score["rouge_1/mean"].item()
-            )
-          elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
-            score = ResponseEvaluator.evaluate(
-                [scrape_result],
-                [RESPONSE_EVALUATION_SCORE_KEY],
-                print_detailed_results=print_detailed_results,
-            )
-            eval_metric_result = _get_eval_metric_result(
-                eval_metric, score["coherence/mean"].item()
-            )
+            overall_eval_metric_results.append(
+                EvalMetricResult(
+                    metric_name=eval_metric.metric_name,
+                    threshold=eval_metric.threshold,
+                    score=evaluation_result.overall_score,
+                    eval_status=evaluation_result.overall_eval_status,
+                )
+            )
+            for index, per_invocation_result in enumerate(
+                evaluation_result.per_invocation_results
+            ):
+              eval_metric_result_per_invocation[
+                  index
+              ].eval_metric_results.append(
+                  EvalMetricResult(
+                      metric_name=eval_metric.metric_name,
+                      threshold=eval_metric.threshold,
+                      score=per_invocation_result.score,
+                      eval_status=per_invocation_result.eval_status,
+                  )
+              )
+
+          # elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
+          #   score = ResponseEvaluator.evaluate(
+          #       [inference_result],
+          #       [RESPONSE_MATCH_SCORE_KEY],
+          #       print_detailed_results=print_detailed_results,
+          #   )
+          #   eval_metric_result = _get_eval_metric_result(
+          #       eval_metric, score["rouge_1/mean"].item()
+          #   )
+          # elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
+          #   score = ResponseEvaluator.evaluate(
+          #       [inference_result],
+          #       [RESPONSE_EVALUATION_SCORE_KEY],
+          #       print_detailed_results=print_detailed_results,
+          #   )
+          #   eval_metric_result = _get_eval_metric_result(
+          #       eval_metric, score["coherence/mean"].item()
+          #   )
           else:
             logger.warning("`%s` is not supported.", eval_metric.metric_name)
-            eval_metric_results.append((
-                eval_metric,
-                EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED),
-            ))
-
-          eval_metric_results.append((
-              eval_metric,
-              eval_metric_result,
-          ))
-          _print_eval_metric_result(eval_metric, eval_metric_result)
 
         final_eval_status = EvalStatus.NOT_EVALUATED
 
         # Go over the all the eval statuses and mark the final eval status as
         # passed if all of them pass, otherwise mark the final eval status to
         # failed.
-        for eval_metric_result in eval_metric_results:
-          eval_status = eval_metric_result[1].eval_status
-          if eval_status == EvalStatus.PASSED:
+        for overall_eval_metric_result in overall_eval_metric_results:
+          overall_eval_status = overall_eval_metric_result.eval_status
+          if overall_eval_status == EvalStatus.PASSED:
            final_eval_status = EvalStatus.PASSED
-          elif eval_status == EvalStatus.NOT_EVALUATED:
+          elif overall_eval_status == EvalStatus.NOT_EVALUATED:
            continue
-          elif eval_status == EvalStatus.FAILED:
+          elif overall_eval_status == EvalStatus.FAILED:
            final_eval_status = EvalStatus.FAILED
            break
          else:
            raise ValueError("Unknown eval status.")
 
        yield EvalCaseResult(
-            eval_set_file=eval_set_file,
+            eval_set_file=eval_set_id,
+            eval_set_id=eval_set_id,
            eval_id=eval_name,
            final_eval_status=final_eval_status,
-            eval_metric_results=eval_metric_results,
+            eval_metric_results=[],
+            overall_eval_metric_results=overall_eval_metric_results,
+            eval_metric_result_per_invocation=eval_metric_result_per_invocation,
            session_id=session_id,
            user_id=user_id,
        )
@@ -290,11 +366,3 @@ def _get_eval_metric_result(eval_metric, score):
       EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
   )
   return EvalMetricResult(score=score, eval_status=eval_status)
-
-
-def _print_eval_metric_result(eval_metric, eval_metric_result):
-  print(
-      f"Metric: {eval_metric.metric_name}\tStatus:"
-      f" {eval_metric_result.eval_status}\tScore:"
-      f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
-  )
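A hedged sketch of how a caller might drive the new run_evals() signature end to end; the agent definition, the eval set file name, and the metric name string are illustrative assumptions, not values taken from this commit:

    import asyncio

    from google.adk.agents import Agent  # assumed: any root agent works here
    from google.adk.cli.cli_eval import EvalMetric, run_evals
    from google.adk.evaluation.local_eval_sets_manager import load_eval_set_from_file

    root_agent = Agent(name="root_agent", model="gemini-2.0-flash")


    async def main() -> None:
      # Eval cases are now grouped by eval set id instead of by eval set file path.
      eval_set = load_eval_set_from_file("my_evals.evalset.json", "my_evals")
      eval_cases_by_eval_set_id = {"my_evals": eval_set.eval_cases}

      eval_metrics = [
          EvalMetric(metric_name="tool_trajectory_avg_score", threshold=1.0)
      ]

      async for eval_case_result in run_evals(
          eval_cases_by_eval_set_id, root_agent, None, eval_metrics
      ):
        print(eval_case_result.eval_id, eval_case_result.final_eval_status)


    if __name__ == "__main__":
      asyncio.run(main())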
@@ -296,6 +296,7 @@ def cli_eval(
     from .cli_eval import parse_and_get_evals_to_run
     from .cli_eval import run_evals
     from .cli_eval import try_get_reset_func
+    from ..evaluation.local_eval_sets_manager import load_eval_set_from_file
   except ModuleNotFoundError:
     raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
 
@@ -311,17 +312,27 @@ def cli_eval(
   root_agent = get_root_agent(agent_module_file_path)
   reset_func = try_get_reset_func(agent_module_file_path)
 
-  eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
+  eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
+  eval_set_id_to_eval_cases = {}
+
+  # Read the eval_set files and get the cases.
+  for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items():
+    eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path)
+    eval_cases = eval_set.eval_cases
+
+    if eval_case_ids:
+      # There are eval_ids that we should select.
+      eval_cases = [
+          e for e in eval_set.eval_cases if e.eval_id in eval_case_ids
+      ]
+
+    eval_set_id_to_eval_cases[eval_set_file_path] = eval_cases
 
   async def _collect_eval_results() -> list[EvalCaseResult]:
     return [
         result
         async for result in run_evals(
-            eval_set_to_evals,
-            root_agent,
-            reset_func,
-            eval_metrics,
-            print_detailed_results=print_detailed_results,
+            eval_set_id_to_eval_cases, root_agent, reset_func, eval_metrics
         )
     ]
 
@@ -336,20 +347,28 @@ def cli_eval(
     for eval_result in eval_results:
       eval_result: EvalCaseResult
 
-      if eval_result.eval_set_file not in eval_run_summary:
-        eval_run_summary[eval_result.eval_set_file] = [0, 0]
+      if eval_result.eval_set_id not in eval_run_summary:
+        eval_run_summary[eval_result.eval_set_id] = [0, 0]
 
       if eval_result.final_eval_status == EvalStatus.PASSED:
-        eval_run_summary[eval_result.eval_set_file][0] += 1
+        eval_run_summary[eval_result.eval_set_id][0] += 1
       else:
-        eval_run_summary[eval_result.eval_set_file][1] += 1
+        eval_run_summary[eval_result.eval_set_id][1] += 1
     print("Eval Run Summary")
-    for eval_set_file, pass_fail_count in eval_run_summary.items():
+    for eval_set_id, pass_fail_count in eval_run_summary.items():
      print(
-          f"{eval_set_file}:\n  Tests passed: {pass_fail_count[0]}\n  Tests"
+          f"{eval_set_id}:\n  Tests passed: {pass_fail_count[0]}\n  Tests"
          f" failed: {pass_fail_count[1]}"
      )
 
+    if print_detailed_results:
+      for eval_result in eval_results:
+        eval_result: EvalCaseResult
+        print(
+            "*********************************************************************"
+        )
+        print(eval_result.model_dump_json(indent=2))
+
+
 @main.command("web")
 @click.option(
@@ -48,6 +48,7 @@ from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
 from opentelemetry.sdk.trace import export
 from opentelemetry.sdk.trace import ReadableSpan
 from opentelemetry.sdk.trace import TracerProvider
+from pydantic import Field
 from pydantic import ValidationError
 from starlette.types import Lifespan
 from typing_extensions import override
@@ -75,6 +76,7 @@ from .cli_eval import EVAL_SESSION_ID_PREFIX
 from .cli_eval import EvalCaseResult
 from .cli_eval import EvalMetric
 from .cli_eval import EvalMetricResult
+from .cli_eval import EvalMetricResultPerInvocation
 from .cli_eval import EvalSetResult
 from .cli_eval import EvalStatus
 from .utils import common
@@ -175,7 +177,14 @@ class RunEvalResult(common.BaseModel):
   eval_set_id: str
   eval_id: str
   final_eval_status: EvalStatus
-  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
+  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
+      deprecated=True,
+      description=(
+          "This field is deprecated, use overall_eval_metric_results instead."
+      ),
+  )
+  overall_eval_metric_results: list[EvalMetricResult]
+  eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
   user_id: str
   session_id: str
 
@@ -480,25 +489,26 @@ def get_fast_api_app(
   async def run_eval(
       app_name: str, eval_set_id: str, req: RunEvalRequest
   ) -> list[RunEvalResult]:
+    """Runs an eval given the details in the eval request."""
     from .cli_eval import run_evals
 
-    """Runs an eval given the details in the eval request."""
     # Create a mapping from eval set file to all the evals that needed to be
     # run.
     envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
-    eval_set_file_path = _get_eval_set_file_path(
-        app_name, agent_dir, eval_set_id
-    )
-    eval_set_to_evals = {eval_set_file_path: req.eval_ids}
 
-    if not req.eval_ids:
-      logger.info(
-          "Eval ids to run list is empty. We will all evals in the eval set."
-      )
+    eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
+
+    if req.eval_ids:
+      eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids]
+      eval_set_to_evals = {eval_set_id: eval_cases}
+    else:
+      logger.info("Eval ids to run list is empty. We will run all eval cases.")
+      eval_set_to_evals = {eval_set_id: eval_set.eval_cases}
 
     root_agent = await _get_root_agent_async(app_name)
     run_eval_results = []
     eval_case_results = []
-    async for eval_result in run_evals(
+    async for eval_case_result in run_evals(
         eval_set_to_evals,
         root_agent,
         getattr(root_agent, "reset_data", None),
@@ -509,31 +519,23 @@ def get_fast_api_app(
      run_eval_results.append(
          RunEvalResult(
              app_name=app_name,
-              eval_set_file=eval_result.eval_set_file,
+              eval_set_file=eval_case_result.eval_set_file,
              eval_set_id=eval_set_id,
-              eval_id=eval_result.eval_id,
-              final_eval_status=eval_result.final_eval_status,
-              eval_metric_results=eval_result.eval_metric_results,
-              user_id=eval_result.user_id,
-              session_id=eval_result.session_id,
+              eval_id=eval_case_result.eval_id,
+              final_eval_status=eval_case_result.final_eval_status,
+              eval_metric_results=eval_case_result.eval_metric_results,
+              overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
+              eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
+              user_id=eval_case_result.user_id,
+              session_id=eval_case_result.session_id,
          )
      )
-      session = session_service.get_session(
+      eval_case_result.session_details = session_service.get_session(
          app_name=app_name,
-          user_id=eval_result.user_id,
-          session_id=eval_result.session_id,
+          user_id=eval_case_result.user_id,
+          session_id=eval_case_result.session_id,
      )
-      eval_case_results.append(
-          EvalCaseResult(
-              eval_set_file=eval_result.eval_set_file,
-              eval_id=eval_result.eval_id,
-              final_eval_status=eval_result.final_eval_status,
-              eval_metric_results=eval_result.eval_metric_results,
-              session_id=eval_result.session_id,
-              session_details=session,
-              user_id=eval_result.user_id,
-          )
-      )
+      eval_case_results.append(eval_case_result)
 
    timestamp = time.time()
    eval_set_result_name = app_name + "_" + eval_set_id + "_" + str(timestamp)
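Downstream consumers of this endpoint now receive considerably more detail per case. A small sketch of reading the enriched RunEvalResult objects built above (field names follow the model changes in this commit; the surrounding variable is assumed to be the handler's return value):

    # `run_eval_results` is the list[RunEvalResult] assembled in the handler above.
    for result in run_eval_results:
      print(f"{result.eval_set_id}/{result.eval_id}: {result.final_eval_status}")
      for overall in result.overall_eval_metric_results:
        print(f"  {overall.metric_name}: score={overall.score} ({overall.eval_status})")
      for invocation_result in result.eval_metric_result_per_invocation:
        statuses = [m.eval_status for m in invocation_result.eval_metric_results]
        print(f"  per-invocation statuses: {statuses}")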
@@ -258,13 +258,6 @@ class AgentEvaluator:
         initial_session=initial_session,
     )
 
-  @staticmethod
-  def _generate_responses_from_session(eval_dataset, session_path):
-    """Generates evaluation responses by running the agent module multiple times."""
-    return EvaluationGenerator.generate_responses_from_session(
-        session_path, eval_dataset
-    )
-
   @staticmethod
   def _response_evaluation_required(criteria, eval_dataset):
     """Checks if response evaluation are needed."""
@@ -23,10 +23,10 @@ from pydantic import Field
 class IntermediateData(BaseModel):
   """Container for intermediate data that an agent would generate as it responds with a final answer."""
 
-  tool_uses: list[genai_types.FunctionCall]
+  tool_uses: list[genai_types.FunctionCall] = []
   """Tool use trajectory in chronological order."""
 
-  intermediate_responses: list[Tuple[str, list[genai_types.Part]]]
+  intermediate_responses: list[Tuple[str, list[genai_types.Part]]] = []
   """Intermediate responses generated by sub-agents to convey progress or status
   in a multi-agent system, distinct from the final response.
 
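For context, a hedged sketch of the objects these new defaults apply to; it assumes that the Invocation fields exercised elsewhere in this commit (invocation_id, user_content, final_response, intermediate_data) are enough to construct one, and the tool call shown is purely illustrative:

    from google.genai import types as genai_types

    from google.adk.evaluation.eval_case import IntermediateData
    from google.adk.evaluation.eval_case import Invocation

    expected = Invocation(
        invocation_id="inv-1",
        user_content=genai_types.Content(
            role="user", parts=[genai_types.Part(text="Roll a six-sided die.")]
        ),
        final_response=genai_types.Content(
            role="model", parts=[genai_types.Part(text="You rolled a 4.")]
        ),
        intermediate_data=IntermediateData(
            # tool_uses and intermediate_responses now default to [], so either
            # can be omitted when a case has no trajectory to check.
            tool_uses=[genai_types.FunctionCall(name="roll_die", args={"sides": 6})],
        ),
    )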
@@ -13,19 +13,19 @@
 # limitations under the License.
 
 import importlib
+from typing import Any, Optional
 import uuid
 
-from google.genai import types
 
-from ..agents.base_agent import BaseAgent
 from ..agents.llm_agent import Agent
-from ..agents.llm_agent import BeforeToolCallback
-from ..agents.llm_agent import LlmAgent
+from ..artifacts.base_artifact_service import BaseArtifactService
 from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
 from ..runners import Runner
+from ..sessions.base_session_service import BaseSessionService
 from ..sessions.in_memory_session_service import InMemorySessionService
 from ..sessions.session import Session
-from .evaluation_constants import EvalConstants
+from .eval_case import IntermediateData
+from .eval_case import Invocation
+from .eval_case import SessionInput
 
 
 class EvaluationGenerator:
@@ -102,56 +102,40 @@ class EvaluationGenerator:
     agent_to_evaluate = root_agent.find_agent(agent_name)
     assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."
 
-    return EvaluationGenerator._process_query_with_root_agent(
+    return EvaluationGenerator._generate_inferences_from_root_agent(
         data, agent_to_evaluate, reset_func, initial_session
     )
 
   @staticmethod
-  async def _process_query_with_root_agent(
-      data,
-      root_agent,
-      reset_func,
-      initial_session={},
-      session_id=None,
-      session_service=None,
-      artifact_service=None,
-  ):
-    """Process a query using the agent and evaluation dataset."""
-
-    # we don't know which tools belong to which agent
-    # so we just apply to any agents that has certain tool outputs
-    all_mock_tools = set()
-    for eval_entry in data:
-      expected_tool_use = eval_entry.get(EvalConstants.EXPECTED_TOOL_USE, [])
-      for expected in expected_tool_use:
-        if EvalConstants.MOCK_TOOL_OUTPUT in expected:
-          all_mock_tools.add(expected[EvalConstants.TOOL_NAME])
-
-    eval_data_copy = data.copy()
-    await EvaluationGenerator.apply_before_tool_callback(
-        root_agent,
-        lambda *args: EvaluationGenerator.before_tool_callback(
-            *args, eval_dataset=eval_data_copy
-        ),
-        all_mock_tools,
-    )
-
+  async def _generate_inferences_from_root_agent(
+      invocations: list[Invocation],
+      root_agent: Agent,
+      reset_func: Any,
+      initial_session: Optional[SessionInput] = None,
+      session_id: Optional[str] = None,
+      session_service: Optional[BaseSessionService] = None,
+      artifact_service: Optional[BaseArtifactService] = None,
+  ) -> list[Invocation]:
+    """Scrapes the root agent given the list of Invocations."""
     if not session_service:
       session_service = InMemorySessionService()
 
-    app_name = initial_session.get("app_name", "EvaluationGenerator")
-    user_id = initial_session.get("user_id", "test_user_id")
+    app_name = (
+        initial_session.app_name if initial_session else "EvaluationGenerator"
+    )
+    user_id = initial_session.user_id if initial_session else "test_user_id"
     session_id = session_id if session_id else str(uuid.uuid4())
 
     _ = session_service.create_session(
         app_name=app_name,
         user_id=user_id,
-        state=initial_session.get("state", {}),
+        state=initial_session.state if initial_session else {},
         session_id=session_id,
     )
 
     if not artifact_service:
       artifact_service = InMemoryArtifactService()
 
     runner = Runner(
         app_name=app_name,
         agent=root_agent,
@@ -163,30 +147,37 @@ class EvaluationGenerator:
     if callable(reset_func):
       reset_func()
 
-    responses = data.copy()
+    response_invocations = []
 
-    for index, eval_entry in enumerate(responses):
-      response = None
-      query = eval_entry["query"]
-      content = types.Content(role="user", parts=[types.Part(text=query)])
-      turn_actual_tool_uses = []
+    for invocation in invocations:
+      final_response = None
+      user_content = invocation.user_content
+      tool_uses = []
+      invocation_id = ""
 
       for event in runner.run(
-          user_id=user_id, session_id=session_id, new_message=content
+          user_id=user_id, session_id=session_id, new_message=user_content
      ):
+        invocation_id = (
+            event.invocation_id if not invocation_id else invocation_id
+        )
+
        if event.is_final_response() and event.content and event.content.parts:
-          response = event.content.parts[0].text
+          final_response = event.content
        elif event.get_function_calls():
          for call in event.get_function_calls():
-            turn_actual_tool_uses.append({
-                EvalConstants.TOOL_NAME: call.name,
-                EvalConstants.TOOL_INPUT: call.args,
-            })
+            tool_uses.append(call)
 
-      responses[index]["actual_tool_use"] = turn_actual_tool_uses
-      responses[index]["response"] = response
+      response_invocations.append(
+          Invocation(
+              invocation_id=invocation_id,
+              user_content=user_content,
+              final_response=final_response,
+              intermediate_data=IntermediateData(tool_uses=tool_uses),
+          )
+      )
 
-    return responses
+    return response_invocations
 
   @staticmethod
   def _process_query_with_session(session_data, data):
@@ -225,46 +216,3 @@ class EvaluationGenerator:
       responses[index]["actual_tool_use"] = actual_tool_uses
       responses[index]["response"] = response
     return responses
-
-  @staticmethod
-  def before_tool_callback(tool, args, tool_context, eval_dataset):
-    """Intercept specific tool calls and return predefined outputs
-
-    from eval_dataset.
-    """
-    for index, eval_entry in enumerate(eval_dataset):
-      expected_tool_use = eval_entry.get("expected_tool_use", [])
-      for expected in expected_tool_use:
-        if (
-            EvalConstants.MOCK_TOOL_OUTPUT in expected
-            and tool.name == expected[EvalConstants.TOOL_NAME]
-            and args == expected.get(EvalConstants.TOOL_INPUT, {})
-        ):
-          # pop the matched entry so we don't rematch again
-          eval_dataset.pop(index)
-          return {"result": expected[EvalConstants.MOCK_TOOL_OUTPUT]}
-
-    return None
-
-  @staticmethod
-  async def apply_before_tool_callback(
-      agent: BaseAgent,
-      callback: BeforeToolCallback,
-      all_mock_tools: set[str],
-  ):
-    """Recursively apply the before_tool_callback to the root agent and all its subagents."""
-    # Check if the agent has tools that are defined by evalset.
-    # We use function names to check if tools match
-    if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent):
-      return
-
-    for tool in await agent.canonical_tools():
-      tool_name = tool.name
-      if tool_name in all_mock_tools:
-        agent.before_tool_callback = callback
-
-    # Apply recursively to subagents if they exist
-    for sub_agent in agent.sub_agents:
-      await EvaluationGenerator.apply_before_tool_callback(
-          sub_agent, callback, all_mock_tools
-      )
src/google/adk/evaluation/evaluator.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC
+from enum import Enum
+from typing import Optional
+from pydantic import BaseModel
+from .eval_case import Invocation
+
+
+class EvalStatus(Enum):
+  PASSED = 1
+  FAILED = 2
+  NOT_EVALUATED = 3
+
+
+class PerInvocationResult(BaseModel):
+  """Metric evaluation score per invocation."""
+
+  actual_invocation: Invocation
+  expected_invocation: Invocation
+  score: Optional[float] = None
+  eval_status: EvalStatus = EvalStatus.NOT_EVALUATED
+
+
+class EvaluationResult(BaseModel):
+  overall_score: Optional[float] = None
+  """Overall score, based on each invocation."""
+
+  overall_eval_status: EvalStatus = EvalStatus.NOT_EVALUATED
+  """Overall status, based on each invocation."""
+
+  per_invocation_results: list[PerInvocationResult] = []
+
+
+class Evaluator(ABC):
+  """A metrics evaluator interface."""
+
+  def evaluate_invocations(
+      self,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    """Returns EvaluationResult after performing evaluations using actual and expected invocations."""
+    raise NotImplementedError()
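To illustrate the contract of this new interface, here is a toy Evaluator implementation (not part of this commit) that passes an invocation only when the final response text matches exactly; the helper function is hypothetical:

    from google.adk.evaluation.eval_case import Invocation
    from google.adk.evaluation.evaluator import EvalStatus
    from google.adk.evaluation.evaluator import EvaluationResult
    from google.adk.evaluation.evaluator import Evaluator
    from google.adk.evaluation.evaluator import PerInvocationResult


    def _final_text(invocation: Invocation) -> str:
      # Hypothetical helper: concatenates the text parts of the final response.
      if not invocation.final_response or not invocation.final_response.parts:
        return ""
      return "".join(part.text or "" for part in invocation.final_response.parts)


    class ExactResponseEvaluator(Evaluator):
      """Toy metric: 1.0 when final response texts match exactly, else 0.0."""

      def evaluate_invocations(
          self,
          actual_invocations: list[Invocation],
          expected_invocations: list[Invocation],
      ) -> EvaluationResult:
        results = []
        for actual, expected in zip(actual_invocations, expected_invocations):
          score = 1.0 if _final_text(actual) == _final_text(expected) else 0.0
          results.append(
              PerInvocationResult(
                  actual_invocation=actual,
                  expected_invocation=expected,
                  score=score,
                  eval_status=(
                      EvalStatus.PASSED if score else EvalStatus.FAILED
                  ),
              )
          )
        if not results:
          return EvaluationResult()
        overall = sum(r.score for r in results) / len(results)
        return EvaluationResult(
            overall_score=overall,
            overall_eval_status=(
                EvalStatus.PASSED if overall == 1.0 else EvalStatus.FAILED
            ),
            per_invocation_results=results,
        )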
@@ -154,6 +154,22 @@ def convert_eval_set_to_pydanctic_schema(
   )
 
 
+def load_eval_set_from_file(
+    eval_set_file_path: str, eval_set_id: str
+) -> EvalSet:
+  """Returns an EvalSet that is read from the given file."""
+  with open(eval_set_file_path, "r", encoding="utf-8") as f:
+    content = f.read()
+    try:
+      return EvalSet.model_validate_json(content)
+    except ValidationError:
+      # We assume that the eval data was specified in the old format and try
+      # to convert it to the new format.
+      return convert_eval_set_to_pydanctic_schema(
+          eval_set_id, json.loads(content)
+      )
+
+
 class LocalEvalSetsManager(EvalSetsManager):
   """An EvalSets manager that stores eval sets locally on disk."""
 
@@ -165,16 +181,7 @@ class LocalEvalSetsManager(EvalSetsManager):
     """Returns an EvalSet identified by an app_name and eval_set_id."""
     # Load the eval set file data
     eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id)
-    with open(eval_set_file_path, "r", encoding="utf-8") as f:
-      content = f.read()
-      try:
-        return EvalSet.model_validate_json(content)
-      except ValidationError:
-        # We assume that the eval data was specified in the old format and try
-        # to convert it to the new format.
-        return convert_eval_set_to_pydanctic_schema(
-            eval_set_id, json.loads(content)
-        )
+    return load_eval_set_from_file(eval_set_file_path, eval_set_id)
 
   @override
   def create_eval_set(self, app_name: str, eval_set_id: str):
@@ -12,18 +12,98 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any
+from typing import Any, cast
+
+from deprecated import deprecated
+from google.genai import types as genai_types
 import pandas as pd
 from tabulate import tabulate
+from typing_extensions import override
+
+from .eval_case import Invocation
 from .evaluation_constants import EvalConstants
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .evaluator import Evaluator
+from .evaluator import PerInvocationResult
 
 
-class TrajectoryEvaluator:
+class TrajectoryEvaluator(Evaluator):
   """Evaluates tool use trajectories for accuracy."""
 
+  def __init__(self, threshold: float):
+    self._threshold = threshold
+
+  @override
+  def evaluate_invocations(
+      self,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    """Returns EvaluationResult after performing evaluations using actual and expected invocations."""
+    total_tool_use_accuracy = 0.0
+    num_invocations = 0
+    per_invocation_results = []
+
+    for actual, expected in zip(actual_invocations, expected_invocations):
+      actual_tool_uses = (
+          actual.intermediate_data.tool_uses if actual.intermediate_data else []
+      )
+      expected_tool_uses = (
+          expected.intermediate_data.tool_uses
+          if expected.intermediate_data
+          else []
+      )
+      tool_use_accuracy = (
+          1.0
+          if self._are_tool_calls_equal(actual_tool_uses, expected_tool_uses)
+          else 0.0
+      )
+      per_invocation_results.append(
+          PerInvocationResult(
+              actual_invocation=actual,
+              expected_invocation=expected,
+              score=tool_use_accuracy,
+              eval_status=self._get_eval_status(tool_use_accuracy),
+          )
+      )
+      total_tool_use_accuracy += tool_use_accuracy
+      num_invocations += 1
+
+    if per_invocation_results:
+      overall_score = total_tool_use_accuracy / num_invocations
+      return EvaluationResult(
+          overall_score=overall_score,
+          overall_eval_status=self._get_eval_status(overall_score),
+          per_invocation_results=per_invocation_results,
+      )
+
+    return EvaluationResult()
+
+  def _are_tool_calls_equal(
+      self,
+      actual_tool_calls: list[genai_types.FunctionCall],
+      expected_tool_calls: list[genai_types.FunctionCall],
+  ) -> bool:
+    if len(actual_tool_calls) != len(expected_tool_calls):
+      return False
+
+    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
+      if actual.name != expected.name or actual.args != expected.args:
+        return False
+
+    return True
+
+  def _get_eval_status(self, score: float):
+    return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
+
   @staticmethod
+  @deprecated(
+      reason=(
+          "This method has been deprecated and will be removed soon. Please use"
+          " evaluate_invocations instead."
+      )
+  )
   def evaluate(
       eval_dataset: list[list[dict[str, Any]]],
       *,
@@ -137,6 +217,7 @@ class TrajectoryEvaluator:
     return new_row, failure
 
   @staticmethod
+  @deprecated()
   def are_tools_equal(list_a_original, list_b_original):
     # Remove other entries that we don't want to evaluate
     list_a = [
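A sketch of the migration this commit encourages: construct TrajectoryEvaluator with a threshold and call evaluate_invocations, instead of the now-deprecated static evaluate()/are_tools_equal() helpers. Here `actual` and `expected` are assumed to be lists of Invocation objects from an inference run:

    from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator

    evaluator = TrajectoryEvaluator(threshold=1.0)
    evaluation_result = evaluator.evaluate_invocations(
        actual_invocations=actual,      # assumed: list[Invocation] from inferencing
        expected_invocations=expected,  # assumed: list[Invocation] from the eval set
    )

    print(evaluation_result.overall_score, evaluation_result.overall_eval_status)
    for per_invocation in evaluation_result.per_invocation_results:
      print(per_invocation.score, per_invocation.eval_status)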