Update Eval Run and TrajectoryEvaluator to use the new schema.

PiperOrigin-RevId: 758927160
Ankur Sharma
2025-05-14 19:15:52 -07:00
committed by Copybara-Service
parent 2cb74dd20e
commit ee674ce0ef
9 changed files with 418 additions and 244 deletions
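For orientation before the per-file diffs, here is a minimal, hypothetical sketch of how a caller might read results under the new schema: an EvalCaseResult now reports an overall result per metric plus a per-invocation breakdown. Only the field names come from this commit; the `results` iterable and the printing are illustrative.

def print_results(results):
    # `results` is assumed to be an iterable of EvalCaseResult (see the first
    # file diff below); only the field names come from this commit.
    for case_result in results:
        print(f"{case_result.eval_set_id}:{case_result.eval_id} -> "
              f"{case_result.final_eval_status}")
        # Overall score for each metric across the whole eval case.
        for metric_result in case_result.overall_eval_metric_results:
            print(f"  {metric_result.metric_name}: score={metric_result.score}"
                  f" threshold={metric_result.threshold}")
        # Per-invocation breakdown: one entry per (actual, expected) pair.
        for per_invocation in case_result.eval_metric_result_per_invocation:
            for metric_result in per_invocation.eval_metric_results:
                print(f"    {metric_result.metric_name}: "
                      f"{metric_result.eval_status}")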
+151 -83
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from enum import Enum
import importlib.util
import json
import logging
@@ -22,6 +20,7 @@ import sys
import traceback
from typing import Any
from typing import AsyncGenerator
from typing import cast
from typing import Optional
import uuid
@@ -29,36 +28,84 @@ from pydantic import BaseModel
from pydantic import Field
from ..agents import Agent
from ..artifacts.base_artifact_service import BaseArtifactService
from ..evaluation.eval_case import EvalCase
from ..evaluation.eval_case import Invocation
from ..evaluation.evaluator import EvalStatus
from ..sessions.base_session_service import BaseSessionService
from ..sessions.session import Session
from .utils import common
logger = logging.getLogger(__name__)
class EvalStatus(Enum):
PASSED = 1
FAILED = 2
NOT_EVALUATED = 3
class EvalMetric(BaseModel):
"""A metric used to evaluate a particular aspect of an eval case."""
metric_name: str
"""The name of the metric."""
threshold: float
"""A threshold value. Each metric decides how to interpret this threshold."""
class EvalMetricResult(BaseModel):
class EvalMetricResult(EvalMetric):
"""The actual computed score/value of a particular EvalMetric."""
score: Optional[float] = None
eval_status: EvalStatus
class EvalMetricResultPerInvocation(BaseModel):
"""Eval metric results per invocation."""
actual_invocation: Invocation
"""The actual invocation, usually obtained by inferencing the agent."""
expected_invocation: Invocation
"""The expected invocation, usually the reference or golden invocation."""
eval_metric_results: list[EvalMetricResult] = []
"""Eval resutls for each applicable metric."""
class EvalCaseResult(common.BaseModel):
eval_set_file: str
eval_id: str
"""Case-level evaluation results."""
eval_set_file: str = Field(
deprecated=True,
description="This field is deprecated, use eval_set_id instead.",
)
eval_set_id: str = ""
"""The eval set id."""
eval_id: str = ""
"""The eval case id."""
final_eval_status: EvalStatus
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
"""Final evalu status for this eval case."""
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
deprecated=True,
description=(
"This field is deprecated, use overall_eval_metric_results instead."
),
)
overall_eval_metric_results: list[EvalMetricResult]
"""Overall result for each metric for the entire eval case."""
eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
"""Result for each metric on a per invocation basis."""
session_id: str
"""Session id of the session generated as result of inferencing/scraping stage of the eval."""
session_details: Optional[Session] = None
"""Session generated as result of inferencing/scraping stage of the eval."""
user_id: Optional[str] = None
"""User id used during inferencing/scraping stage of the eval."""
class EvalSetResult(common.BaseModel):
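As a small illustration of how the models above compose (not part of the commit; `actual` and `expected` stand for Invocation objects produced by the inference step, and names are as defined in this module):

# Illustrative only: one per-invocation result for a single metric.
eval_metric = EvalMetric(metric_name=TOOL_TRAJECTORY_SCORE_KEY, threshold=1.0)
metric_result = EvalMetricResult(
    metric_name=eval_metric.metric_name,
    threshold=eval_metric.threshold,
    score=1.0,
    eval_status=EvalStatus.PASSED,
)
per_invocation_result = EvalMetricResultPerInvocation(
    actual_invocation=actual,      # assumed Invocation instance
    expected_invocation=expected,  # assumed Invocation instance
    eval_metric_results=[metric_result],
)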
@@ -161,14 +208,25 @@ def parse_and_get_evals_to_run(
async def run_evals(
eval_set_to_evals: dict[str, list[str]],
eval_cases_by_eval_set_id: dict[str, list[EvalCase]],
root_agent: Agent,
reset_func: Optional[Any],
eval_metrics: list[EvalMetric],
session_service=None,
artifact_service=None,
print_detailed_results=False,
session_service: Optional[BaseSessionService] = None,
artifact_service: Optional[BaseArtifactService] = None,
) -> AsyncGenerator[EvalCaseResult, None]:
"""Returns a stream of EvalCaseResult for each eval case that was evaluated.
Args:
eval_cases_by_eval_set_id: Eval cases categorized by eval set id to which
they belong.
root_agent: Agent to use for inferencing.
reset_func: If present, this will be called before invoking the agent for
each inferencing step.
eval_metrics: A list of metrics that should be used during evaluation.
session_service: The session service to use during inferencing.
artifact_service: The artifact service to use during inferencing.
"""
try:
from ..evaluation.agent_evaluator import EvaluationGenerator
from ..evaluation.response_evaluator import ResponseEvaluator
@@ -176,29 +234,19 @@ async def run_evals(
except ModuleNotFoundError as e:
raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
"""Returns a summary of eval runs."""
for eval_set_file, evals_to_run in eval_set_to_evals.items():
with open(eval_set_file, "r", encoding="utf-8") as file:
eval_items = json.load(file) # Load JSON into a list
assert eval_items, f"No eval data found in eval set file: {eval_set_file}"
for eval_item in eval_items:
eval_name = eval_item["name"]
eval_data = eval_item["data"]
initial_session = eval_item.get("initial_session", {})
user_id = initial_session.get("user_id", "test_user_id")
if evals_to_run and eval_name not in evals_to_run:
continue
for eval_set_id, eval_cases in eval_cases_by_eval_set_id.items():
for eval_case in eval_cases:
eval_name = eval_case.eval_id
initial_session = eval_case.session_input
user_id = initial_session.user_id if initial_session else "test_user_id"
try:
print(f"Running Eval: {eval_set_file}:{eval_name}")
print(f"Running Eval: {eval_set_id}:{eval_name}")
session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"
scrape_result = (
await EvaluationGenerator._process_query_with_root_agent(
data=eval_data,
inference_result = (
await EvaluationGenerator._generate_inferences_from_root_agent(
invocations=eval_case.conversation,
root_agent=root_agent,
reset_func=reset_func,
initial_session=initial_session,
@@ -208,67 +256,95 @@ async def run_evals(
)
)
eval_metric_results = []
# Initialize the per-invocation metric results to an empty list.
# We will fill this as we evaluate each metric.
eval_metric_result_per_invocation = []
for actual, expected in zip(inference_result, eval_case.conversation):
eval_metric_result_per_invocation.append(
EvalMetricResultPerInvocation(
actual_invocation=actual,
expected_invocation=expected,
eval_metric_results=[],
)
)
overall_eval_metric_results = []
for eval_metric in eval_metrics:
eval_metric_result = None
if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
score = TrajectoryEvaluator.evaluate(
[scrape_result], print_detailed_results=print_detailed_results
evaluation_result = TrajectoryEvaluator(
eval_metric.threshold
).evaluate_invocations(
actual_invocations=inference_result,
expected_invocations=eval_case.conversation,
)
eval_metric_result = _get_eval_metric_result(eval_metric, score)
elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
score = ResponseEvaluator.evaluate(
[scrape_result],
[RESPONSE_MATCH_SCORE_KEY],
print_detailed_results=print_detailed_results,
)
eval_metric_result = _get_eval_metric_result(
eval_metric, score["rouge_1/mean"].item()
)
elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
score = ResponseEvaluator.evaluate(
[scrape_result],
[RESPONSE_EVALUATION_SCORE_KEY],
print_detailed_results=print_detailed_results,
)
eval_metric_result = _get_eval_metric_result(
eval_metric, score["coherence/mean"].item()
overall_eval_metric_results.append(
EvalMetricResult(
metric_name=eval_metric.metric_name,
threshold=eval_metric.threshold,
score=evaluation_result.overall_score,
eval_status=evaluation_result.overall_eval_status,
)
)
for index, per_invocation_result in enumerate(
evaluation_result.per_invocation_results
):
eval_metric_result_per_invocation[
index
].eval_metric_results.append(
EvalMetricResult(
metric_name=eval_metric.metric_name,
threshold=eval_metric.threshold,
score=per_invocation_result.score,
eval_status=per_invocation_result.eval_status,
)
)
# elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
# score = ResponseEvaluator.evaluate(
# [inference_result],
# [RESPONSE_MATCH_SCORE_KEY],
# print_detailed_results=print_detailed_results,
# )
# eval_metric_result = _get_eval_metric_result(
# eval_metric, score["rouge_1/mean"].item()
# )
# elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
# score = ResponseEvaluator.evaluate(
# [inference_result],
# [RESPONSE_EVALUATION_SCORE_KEY],
# print_detailed_results=print_detailed_results,
# )
# eval_metric_result = _get_eval_metric_result(
# eval_metric, score["coherence/mean"].item()
# )
else:
logger.warning("`%s` is not supported.", eval_metric.metric_name)
eval_metric_results.append((
eval_metric,
EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED),
))
eval_metric_results.append((
eval_metric,
eval_metric_result,
))
_print_eval_metric_result(eval_metric, eval_metric_result)
final_eval_status = EvalStatus.NOT_EVALUATED
# Go over all the eval statuses and mark the final eval status as
# passed if all of them pass; otherwise mark the final eval status as
# failed.
for eval_metric_result in eval_metric_results:
eval_status = eval_metric_result[1].eval_status
if eval_status == EvalStatus.PASSED:
for overall_eval_metric_result in overall_eval_metric_results:
overall_eval_status = overall_eval_metric_result.eval_status
if overall_eval_status == EvalStatus.PASSED:
final_eval_status = EvalStatus.PASSED
elif eval_status == EvalStatus.NOT_EVALUATED:
elif overall_eval_status == EvalStatus.NOT_EVALUATED:
continue
elif eval_status == EvalStatus.FAILED:
elif overall_eval_status == EvalStatus.FAILED:
final_eval_status = EvalStatus.FAILED
break
else:
raise ValueError("Unknown eval status.")
yield EvalCaseResult(
eval_set_file=eval_set_file,
eval_set_file=eval_set_id,
eval_set_id=eval_set_id,
eval_id=eval_name,
final_eval_status=final_eval_status,
eval_metric_results=eval_metric_results,
eval_metric_results=[],
overall_eval_metric_results=overall_eval_metric_results,
eval_metric_result_per_invocation=eval_metric_result_per_invocation,
session_id=session_id,
user_id=user_id,
)
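The status-aggregation loop above can be read as a small helper. The sketch below is not in the commit; it captures the same rule minus the unknown-status guard: any failed metric fails the case, and the case passes only if at least one metric passed and none failed.

def _aggregate_final_status(
    overall_eval_metric_results: list[EvalMetricResult],
) -> EvalStatus:
    # Equivalent reading of the loop above, for illustration only.
    final_eval_status = EvalStatus.NOT_EVALUATED
    for metric_result in overall_eval_metric_results:
        if metric_result.eval_status == EvalStatus.FAILED:
            return EvalStatus.FAILED  # A single failed metric fails the case.
        if metric_result.eval_status == EvalStatus.PASSED:
            final_eval_status = EvalStatus.PASSED
    return final_eval_status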
@@ -290,11 +366,3 @@ def _get_eval_metric_result(eval_metric, score):
EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
)
return EvalMetricResult(score=score, eval_status=eval_status)
def _print_eval_metric_result(eval_metric, eval_metric_result):
print(
f"Metric: {eval_metric.metric_name}\tStatus:"
f" {eval_metric_result.eval_status}\tScore:"
f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
)
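A minimal driver for the updated run_evals signature might look like the sketch below. The eval set id, `my_eval_set`, and `root_agent` are placeholders, error handling is omitted, and the other names are as defined in the module above.

import asyncio

async def _run_my_evals():
    # Placeholders: `my_eval_set` and `root_agent` are assumed to exist.
    eval_cases_by_eval_set_id = {"my_eval_set_id": my_eval_set.eval_cases}
    eval_metrics = [
        EvalMetric(metric_name=TOOL_TRAJECTORY_SCORE_KEY, threshold=1.0)
    ]
    return [
        result
        async for result in run_evals(
            eval_cases_by_eval_set_id,
            root_agent,
            None,  # reset_func
            eval_metrics,
        )
    ]

results = asyncio.run(_run_my_evals())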
+31 -12
@@ -296,6 +296,7 @@ def cli_eval(
from .cli_eval import parse_and_get_evals_to_run
from .cli_eval import run_evals
from .cli_eval import try_get_reset_func
from ..evaluation.local_eval_sets_manager import load_eval_set_from_file
except ModuleNotFoundError:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
@@ -311,17 +312,27 @@ def cli_eval(
root_agent = get_root_agent(agent_module_file_path)
reset_func = try_get_reset_func(agent_module_file_path)
eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
eval_set_id_to_eval_cases = {}
# Read the eval_set files and get the cases.
for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items():
eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path)
eval_cases = eval_set.eval_cases
if eval_case_ids:
# There are eval_ids that we should select.
eval_cases = [
e for e in eval_set.eval_cases if e.eval_id in eval_case_ids
]
eval_set_id_to_eval_cases[eval_set_file_path] = eval_cases
async def _collect_eval_results() -> list[EvalCaseResult]:
return [
result
async for result in run_evals(
eval_set_to_evals,
root_agent,
reset_func,
eval_metrics,
print_detailed_results=print_detailed_results,
eval_set_id_to_eval_cases, root_agent, reset_func, eval_metrics
)
]
@@ -336,20 +347,28 @@ def cli_eval(
for eval_result in eval_results:
eval_result: EvalCaseResult
if eval_result.eval_set_file not in eval_run_summary:
eval_run_summary[eval_result.eval_set_file] = [0, 0]
if eval_result.eval_set_id not in eval_run_summary:
eval_run_summary[eval_result.eval_set_id] = [0, 0]
if eval_result.final_eval_status == EvalStatus.PASSED:
eval_run_summary[eval_result.eval_set_file][0] += 1
eval_run_summary[eval_result.eval_set_id][0] += 1
else:
eval_run_summary[eval_result.eval_set_file][1] += 1
eval_run_summary[eval_result.eval_set_id][1] += 1
print("Eval Run Summary")
for eval_set_file, pass_fail_count in eval_run_summary.items():
for eval_set_id, pass_fail_count in eval_run_summary.items():
print(
f"{eval_set_file}:\n Tests passed: {pass_fail_count[0]}\n Tests"
f"{eval_set_id}:\n Tests passed: {pass_fail_count[0]}\n Tests"
f" failed: {pass_fail_count[1]}"
)
if print_detailed_results:
for eval_result in eval_results:
eval_result: EvalCaseResult
print(
"*********************************************************************"
)
print(eval_result.model_dump_json(indent=2))
@main.command("web")
@click.option(
+33 -31
@@ -48,6 +48,7 @@ from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
from opentelemetry.sdk.trace import export
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace import TracerProvider
from pydantic import Field
from pydantic import ValidationError
from starlette.types import Lifespan
from typing_extensions import override
@@ -75,6 +76,7 @@ from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalCaseResult
from .cli_eval import EvalMetric
from .cli_eval import EvalMetricResult
from .cli_eval import EvalMetricResultPerInvocation
from .cli_eval import EvalSetResult
from .cli_eval import EvalStatus
from .utils import common
@@ -175,7 +177,14 @@ class RunEvalResult(common.BaseModel):
eval_set_id: str
eval_id: str
final_eval_status: EvalStatus
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
deprecated=True,
description=(
"This field is deprecated, use overall_eval_metric_results instead."
),
)
overall_eval_metric_results: list[EvalMetricResult]
eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
user_id: str
session_id: str
@@ -480,25 +489,26 @@ def get_fast_api_app(
async def run_eval(
app_name: str, eval_set_id: str, req: RunEvalRequest
) -> list[RunEvalResult]:
"""Runs an eval given the details in the eval request."""
from .cli_eval import run_evals
"""Runs an eval given the details in the eval request."""
# Create a mapping from eval set id to all the eval cases that need to
# be run.
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
eval_set_file_path = _get_eval_set_file_path(
app_name, agent_dir, eval_set_id
)
eval_set_to_evals = {eval_set_file_path: req.eval_ids}
if not req.eval_ids:
logger.info(
"Eval ids to run list is empty. We will all evals in the eval set."
)
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
if req.eval_ids:
eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids]
eval_set_to_evals = {eval_set_id: eval_cases}
else:
logger.info("Eval ids to run list is empty. We will run all eval cases.")
eval_set_to_evals = {eval_set_id: eval_set.eval_cases}
root_agent = await _get_root_agent_async(app_name)
run_eval_results = []
eval_case_results = []
async for eval_result in run_evals(
async for eval_case_result in run_evals(
eval_set_to_evals,
root_agent,
getattr(root_agent, "reset_data", None),
@@ -509,31 +519,23 @@ def get_fast_api_app(
run_eval_results.append(
RunEvalResult(
app_name=app_name,
eval_set_file=eval_result.eval_set_file,
eval_set_file=eval_case_result.eval_set_file,
eval_set_id=eval_set_id,
eval_id=eval_result.eval_id,
final_eval_status=eval_result.final_eval_status,
eval_metric_results=eval_result.eval_metric_results,
user_id=eval_result.user_id,
session_id=eval_result.session_id,
eval_id=eval_case_result.eval_id,
final_eval_status=eval_case_result.final_eval_status,
eval_metric_results=eval_case_result.eval_metric_results,
overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
user_id=eval_case_result.user_id,
session_id=eval_case_result.session_id,
)
)
session = session_service.get_session(
eval_case_result.session_details = session_service.get_session(
app_name=app_name,
user_id=eval_result.user_id,
session_id=eval_result.session_id,
)
eval_case_results.append(
EvalCaseResult(
eval_set_file=eval_result.eval_set_file,
eval_id=eval_result.eval_id,
final_eval_status=eval_result.final_eval_status,
eval_metric_results=eval_result.eval_metric_results,
session_id=eval_result.session_id,
session_details=session,
user_id=eval_result.user_id,
)
user_id=eval_case_result.user_id,
session_id=eval_case_result.session_id,
)
eval_case_results.append(eval_case_result)
timestamp = time.time()
eval_set_result_name = app_name + "_" + eval_set_id + "_" + str(timestamp)
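The diff ends mid-flow here. Purely as an illustration of what could follow (not taken from this commit), the collected EvalCaseResult objects might be serialized under that generated name using standard pydantic APIs:

# Assumption: one way the collected results could be persisted; only the
# eval_set_result_name convention comes from the diff above.
import json

payload = [r.model_dump(mode="json") for r in eval_case_results]
with open(f"{eval_set_result_name}.json", "w", encoding="utf-8") as f:
    json.dump(payload, f, indent=2)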