Update Eval Run and TrajectoryEvaluator to use the new schema.
PiperOrigin-RevId: 758927160
committed by Copybara-Service
parent 2cb74dd20e
commit ee674ce0ef
+151 -83
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
from enum import Enum
import importlib.util
import json
import logging
@@ -22,6 +20,7 @@ import sys
import traceback
from typing import Any
from typing import AsyncGenerator
from typing import cast
from typing import Optional
import uuid

@@ -29,36 +28,84 @@ from pydantic import BaseModel
from pydantic import Field

from ..agents import Agent
from ..artifacts.base_artifact_service import BaseArtifactService
from ..evaluation.eval_case import EvalCase
from ..evaluation.eval_case import Invocation
from ..evaluation.evaluator import EvalStatus
from ..sessions.base_session_service import BaseSessionService
from ..sessions.session import Session
from .utils import common

logger = logging.getLogger(__name__)


class EvalStatus(Enum):
PASSED = 1
FAILED = 2
NOT_EVALUATED = 3


class EvalMetric(BaseModel):
"""A metric used to evaluate a particular aspect of an eval case."""

metric_name: str
"""The name of the metric."""

threshold: float
"""A threshold value. Each metric decides how to interpret this threshold."""


class EvalMetricResult(BaseModel):
class EvalMetricResult(EvalMetric):
"""The actual computed score/value of a particular EvalMetric."""

score: Optional[float] = None
eval_status: EvalStatus


class EvalMetricResultPerInvocation(BaseModel):
"""Eval metric results per invocation."""

actual_invocation: Invocation
"""The actual invocation, usually obtained by inferencing the agent."""

expected_invocation: Invocation
"""The expected invocation, usually the reference or golden invocation."""

eval_metric_results: list[EvalMetricResult] = []
"""Eval resutls for each applicable metric."""
|
||||
|
||||
|
||||
class EvalCaseResult(common.BaseModel):
|
||||
eval_set_file: str
|
||||
eval_id: str
|
||||
"""Case-level evaluation results."""
|
||||
|
||||
eval_set_file: str = Field(
|
||||
deprecated=True,
|
||||
description="This field is deprecated, use eval_set_id instead.",
|
||||
)
|
||||
eval_set_id: str = ""
|
||||
"""The eval set id."""
|
||||
|
||||
eval_id: str = ""
|
||||
"""The eval case id."""
|
||||
|
||||
final_eval_status: EvalStatus
|
||||
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
|
||||
"""Final evalu status for this eval case."""
|
||||
|
||||
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
|
||||
deprecated=True,
|
||||
description=(
|
||||
"This field is deprecated, use overall_eval_metric_results instead."
|
||||
),
|
||||
)
|
||||
|
||||
overall_eval_metric_results: list[EvalMetricResult]
|
||||
"""Overall result for each metric for the entire eval case."""
|
||||
|
||||
eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
|
||||
"""Result for each metric on a per invocation basis."""
|
||||
|
||||
session_id: str
|
||||
"""Session id of the session generated as result of inferencing/scraping stage of the eval."""
|
||||
|
||||
session_details: Optional[Session] = None
|
||||
"""Session generated as result of inferencing/scraping stage of the eval."""
|
||||
|
||||
user_id: Optional[str] = None
|
||||
"""User id used during inferencing/scraping stage of the eval."""
|
||||
|
||||
|
||||
class EvalSetResult(common.BaseModel):
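As a rough, self-contained sketch of the new result schema in the hunk above (the class bodies here are simplified stand-ins rather than the actual ADK classes, the invocation fields are plain strings instead of Invocation objects, and the metric name is a placeholder), a per-invocation metric result can be assembled like this:

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class EvalStatus(Enum):
  PASSED = 1
  FAILED = 2
  NOT_EVALUATED = 3


class EvalMetric(BaseModel):
  metric_name: str
  threshold: float


class EvalMetricResult(EvalMetric):
  # EvalMetricResult now extends EvalMetric, so each result also carries the
  # metric name and threshold it was scored against.
  score: Optional[float] = None
  eval_status: EvalStatus


class EvalMetricResultPerInvocation(BaseModel):
  # In the real schema these two fields are Invocation objects; strings keep
  # the sketch runnable.
  actual_invocation: str
  expected_invocation: str
  eval_metric_results: list[EvalMetricResult] = []


per_invocation = EvalMetricResultPerInvocation(
    actual_invocation="agent called tool_a",
    expected_invocation="agent called tool_a",
    eval_metric_results=[
        EvalMetricResult(
            metric_name="example_metric",  # placeholder name
            threshold=1.0,
            score=1.0,
            eval_status=EvalStatus.PASSED,
        )
    ],
)
print(per_invocation.model_dump())
```

The same nesting is what EvalCaseResult.eval_metric_result_per_invocation stores for every invocation in a case, alongside the overall_eval_metric_results list.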
@@ -161,14 +208,25 @@ def parse_and_get_evals_to_run(


async def run_evals(
eval_set_to_evals: dict[str, list[str]],
eval_cases_by_eval_set_id: dict[str, list[EvalCase]],
root_agent: Agent,
reset_func: Optional[Any],
eval_metrics: list[EvalMetric],
session_service=None,
artifact_service=None,
print_detailed_results=False,
session_service: Optional[BaseSessionService] = None,
artifact_service: Optional[BaseArtifactService] = None,
) -> AsyncGenerator[EvalCaseResult, None]:
"""Returns a stream of EvalCaseResult for each eval case that was evaluated.

Args:
eval_cases_by_eval_set_id: Eval cases categorized by eval set id to which
they belong.
root_agent: Agent to use for inferencing.
reset_func: If present, this will be called before the agent is invoked for
every inferencing step.
eval_metrics: A list of metrics that should be used during evaluation.
session_service: The session service to use during inferencing.
artifact_service: The artifact service to use during inferencing.
"""
try:
from ..evaluation.agent_evaluator import EvaluationGenerator
from ..evaluation.response_evaluator import ResponseEvaluator
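Since run_evals is now an async generator keyed by eval set id rather than by eval set file path, callers are expected to drive it with async for. A hedged consumption sketch follows; fake_run_evals is a stand-in for the real call, whose arguments (eval_cases_by_eval_set_id, root_agent, reset_func, eval_metrics) come from the surrounding CLI code and are not defined here:

```python
import asyncio
from typing import AsyncGenerator


async def fake_run_evals() -> AsyncGenerator[str, None]:
  # Stand-in for:
  #   run_evals(eval_cases_by_eval_set_id, root_agent, reset_func, eval_metrics)
  # which yields one EvalCaseResult per eval case; here we just yield ids.
  for eval_id in ("eval_case_1", "eval_case_2"):
    yield eval_id


async def main() -> None:
  async for eval_case_result in fake_run_evals():
    # The real loop would inspect eval_case_result.final_eval_status, etc.
    print("finished:", eval_case_result)


asyncio.run(main())
```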
@@ -176,29 +234,19 @@ async def run_evals(
except ModuleNotFoundError as e:
raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e

"""Returns a summary of eval runs."""
for eval_set_file, evals_to_run in eval_set_to_evals.items():
with open(eval_set_file, "r", encoding="utf-8") as file:
eval_items = json.load(file) # Load JSON into a list

assert eval_items, f"No eval data found in eval set file: {eval_set_file}"

for eval_item in eval_items:
eval_name = eval_item["name"]
eval_data = eval_item["data"]
initial_session = eval_item.get("initial_session", {})
user_id = initial_session.get("user_id", "test_user_id")

if evals_to_run and eval_name not in evals_to_run:
continue
for eval_set_id, eval_cases in eval_cases_by_eval_set_id.items():
for eval_case in eval_cases:
eval_name = eval_case.eval_id
initial_session = eval_case.session_input
user_id = initial_session.user_id if initial_session else "test_user_id"

try:
print(f"Running Eval: {eval_set_file}:{eval_name}")
print(f"Running Eval: {eval_set_id}:{eval_name}")
session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"

scrape_result = (
await EvaluationGenerator._process_query_with_root_agent(
data=eval_data,
inference_result = (
await EvaluationGenerator._generate_inferences_from_root_agent(
invocations=eval_case.conversation,
root_agent=root_agent,
reset_func=reset_func,
initial_session=initial_session,
@@ -208,67 +256,95 @@ async def run_evals(
)
)

eval_metric_results = []
# Initialize the per-invocation metric results to an empty list.
# We will fill this as we evaluate each metric.
eval_metric_result_per_invocation = []
for actual, expected in zip(inference_result, eval_case.conversation):
eval_metric_result_per_invocation.append(
EvalMetricResultPerInvocation(
actual_invocation=actual,
expected_invocation=expected,
eval_metric_results=[],
)
)

overall_eval_metric_results = []

for eval_metric in eval_metrics:
eval_metric_result = None
if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
score = TrajectoryEvaluator.evaluate(
[scrape_result], print_detailed_results=print_detailed_results
evaluation_result = TrajectoryEvaluator(
eval_metric.threshold
).evaluate_invocations(
actual_invocations=inference_result,
expected_invocations=eval_case.conversation,
)
eval_metric_result = _get_eval_metric_result(eval_metric, score)
elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
score = ResponseEvaluator.evaluate(
[scrape_result],
[RESPONSE_MATCH_SCORE_KEY],
print_detailed_results=print_detailed_results,
)
eval_metric_result = _get_eval_metric_result(
eval_metric, score["rouge_1/mean"].item()
)
elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
score = ResponseEvaluator.evaluate(
[scrape_result],
[RESPONSE_EVALUATION_SCORE_KEY],
print_detailed_results=print_detailed_results,
)
eval_metric_result = _get_eval_metric_result(
eval_metric, score["coherence/mean"].item()
overall_eval_metric_results.append(
EvalMetricResult(
metric_name=eval_metric.metric_name,
threshold=eval_metric.threshold,
score=evaluation_result.overall_score,
eval_status=evaluation_result.overall_eval_status,
)
)
for index, per_invocation_result in enumerate(
evaluation_result.per_invocation_results
):
eval_metric_result_per_invocation[
index
].eval_metric_results.append(
EvalMetricResult(
metric_name=eval_metric.metric_name,
threshold=eval_metric.threshold,
score=per_invocation_result.score,
eval_status=per_invocation_result.eval_status,
)
)

# elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
# score = ResponseEvaluator.evaluate(
# [inference_result],
# [RESPONSE_MATCH_SCORE_KEY],
# print_detailed_results=print_detailed_results,
# )
# eval_metric_result = _get_eval_metric_result(
# eval_metric, score["rouge_1/mean"].item()
# )
# elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
# score = ResponseEvaluator.evaluate(
# [inference_result],
# [RESPONSE_EVALUATION_SCORE_KEY],
# print_detailed_results=print_detailed_results,
# )
# eval_metric_result = _get_eval_metric_result(
# eval_metric, score["coherence/mean"].item()
# )
else:
logger.warning("`%s` is not supported.", eval_metric.metric_name)
eval_metric_results.append((
eval_metric,
EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED),
))

eval_metric_results.append((
eval_metric,
eval_metric_result,
))
_print_eval_metric_result(eval_metric, eval_metric_result)

final_eval_status = EvalStatus.NOT_EVALUATED

# Go over all the eval statuses and mark the final eval status as
# passed if all of them pass, otherwise mark the final eval status as
# failed.
for eval_metric_result in eval_metric_results:
eval_status = eval_metric_result[1].eval_status
if eval_status == EvalStatus.PASSED:
for overall_eval_metric_result in overall_eval_metric_results:
overall_eval_status = overall_eval_metric_result.eval_status
if overall_eval_status == EvalStatus.PASSED:
final_eval_status = EvalStatus.PASSED
elif eval_status == EvalStatus.NOT_EVALUATED:
elif overall_eval_status == EvalStatus.NOT_EVALUATED:
continue
elif eval_status == EvalStatus.FAILED:
elif overall_eval_status == EvalStatus.FAILED:
final_eval_status = EvalStatus.FAILED
break
else:
raise ValueError("Unknown eval status.")

yield EvalCaseResult(
eval_set_file=eval_set_file,
eval_set_file=eval_set_id,
eval_set_id=eval_set_id,
eval_id=eval_name,
final_eval_status=final_eval_status,
eval_metric_results=eval_metric_results,
eval_metric_results=[],
overall_eval_metric_results=overall_eval_metric_results,
eval_metric_result_per_invocation=eval_metric_result_per_invocation,
session_id=session_id,
user_id=user_id,
)
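The hunk above changes the per-metric flow: each evaluator now returns an overall score plus per-invocation results, and the final case status is derived only from the overall metric statuses. A condensed, runnable sketch of that aggregation rule (EvalStatus here is a local stand-in for the enum imported from ..evaluation.evaluator):

```python
from enum import Enum


class EvalStatus(Enum):
  PASSED = 1
  FAILED = 2
  NOT_EVALUATED = 3


def final_status(overall_statuses: list[EvalStatus]) -> EvalStatus:
  # Mirrors the loop above: any FAILED metric fails the case, NOT_EVALUATED
  # metrics are skipped, and the case passes only if at least one metric
  # passed and none failed.
  final = EvalStatus.NOT_EVALUATED
  for status in overall_statuses:
    if status == EvalStatus.FAILED:
      return EvalStatus.FAILED
    if status == EvalStatus.PASSED:
      final = EvalStatus.PASSED
  return final


assert final_status([EvalStatus.PASSED, EvalStatus.NOT_EVALUATED]) == EvalStatus.PASSED
assert final_status([EvalStatus.PASSED, EvalStatus.FAILED]) == EvalStatus.FAILED
```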
@@ -290,11 +366,3 @@ def _get_eval_metric_result(eval_metric, score):
EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
)
return EvalMetricResult(score=score, eval_status=eval_status)


def _print_eval_metric_result(eval_metric, eval_metric_result):
print(
f"Metric: {eval_metric.metric_name}\tStatus:"
f" {eval_metric_result.eval_status}\tScore:"
f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
)
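For reference, the threshold semantics used by _get_eval_metric_result above are simply "pass when the score meets or exceeds the metric's threshold". A minimal illustration of that comparison, kept separate from the real helpers:

```python
def passes(score: float, threshold: float) -> bool:
  # Same comparison as _get_eval_metric_result: score >= threshold passes.
  return score >= threshold


assert passes(0.9, 0.8)
assert passes(0.8, 0.8)
assert not passes(0.7, 0.8)
```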

@@ -296,6 +296,7 @@ def cli_eval(
from .cli_eval import parse_and_get_evals_to_run
from .cli_eval import run_evals
from .cli_eval import try_get_reset_func
from ..evaluation.local_eval_sets_manager import load_eval_set_from_file
except ModuleNotFoundError:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)

@@ -311,17 +312,27 @@ def cli_eval(
root_agent = get_root_agent(agent_module_file_path)
reset_func = try_get_reset_func(agent_module_file_path)

eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
eval_set_id_to_eval_cases = {}

# Read the eval_set files and get the cases.
for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items():
eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path)
eval_cases = eval_set.eval_cases

if eval_case_ids:
# There are eval_ids that we should select.
eval_cases = [
e for e in eval_set.eval_cases if e.eval_id in eval_case_ids
]

eval_set_id_to_eval_cases[eval_set_file_path] = eval_cases

async def _collect_eval_results() -> list[EvalCaseResult]:
return [
result
async for result in run_evals(
eval_set_to_evals,
root_agent,
reset_func,
eval_metrics,
print_detailed_results=print_detailed_results,
eval_set_id_to_eval_cases, root_agent, reset_func, eval_metrics
)
]
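The selection rule earlier in this hunk is worth spelling out: an empty eval_case_ids list means every case in the eval set runs, otherwise only the listed ids are kept. A small, hedged sketch of that filter (the case ids below are made up for illustration):

```python
def select_case_ids(all_case_ids: list[str], requested_ids: list[str]) -> list[str]:
  # Empty request -> run everything; otherwise keep only the requested ids.
  if requested_ids:
    return [case_id for case_id in all_case_ids if case_id in requested_ids]
  return list(all_case_ids)


assert select_case_ids(["home_automation", "weather"], []) == ["home_automation", "weather"]
assert select_case_ids(["home_automation", "weather"], ["weather"]) == ["weather"]
```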

@@ -336,20 +347,28 @@ def cli_eval(
for eval_result in eval_results:
eval_result: EvalCaseResult

if eval_result.eval_set_file not in eval_run_summary:
eval_run_summary[eval_result.eval_set_file] = [0, 0]
if eval_result.eval_set_id not in eval_run_summary:
eval_run_summary[eval_result.eval_set_id] = [0, 0]

if eval_result.final_eval_status == EvalStatus.PASSED:
eval_run_summary[eval_result.eval_set_file][0] += 1
eval_run_summary[eval_result.eval_set_id][0] += 1
else:
eval_run_summary[eval_result.eval_set_file][1] += 1
eval_run_summary[eval_result.eval_set_id][1] += 1
print("Eval Run Summary")
for eval_set_file, pass_fail_count in eval_run_summary.items():
for eval_set_id, pass_fail_count in eval_run_summary.items():
print(
f"{eval_set_file}:\n Tests passed: {pass_fail_count[0]}\n Tests"
f"{eval_set_id}:\n Tests passed: {pass_fail_count[0]}\n Tests"
f" failed: {pass_fail_count[1]}"
)

if print_detailed_results:
for eval_result in eval_results:
eval_result: EvalCaseResult
print(
"*********************************************************************"
)
print(eval_result.model_dump_json(indent=2))


@main.command("web")
@click.option(

@@ -48,6 +48,7 @@ from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
from opentelemetry.sdk.trace import export
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace import TracerProvider
from pydantic import Field
from pydantic import ValidationError
from starlette.types import Lifespan
from typing_extensions import override
@@ -75,6 +76,7 @@ from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalCaseResult
from .cli_eval import EvalMetric
from .cli_eval import EvalMetricResult
from .cli_eval import EvalMetricResultPerInvocation
from .cli_eval import EvalSetResult
from .cli_eval import EvalStatus
from .utils import common
@@ -175,7 +177,14 @@ class RunEvalResult(common.BaseModel):
eval_set_id: str
eval_id: str
final_eval_status: EvalStatus
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
deprecated=True,
description=(
"This field is deprecated, use overall_eval_metric_results instead."
),
)
overall_eval_metric_results: list[EvalMetricResult]
eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
user_id: str
session_id: str
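Both RunEvalResult and EvalCaseResult keep the legacy eval_metric_results field but flag it with Pydantic's Field(deprecated=..., description=...). Assuming a recent Pydantic v2 release (the deprecated parameter arrived in 2.7), reading such a field emits a DeprecationWarning; the sketch below demonstrates the pattern with simplified stand-in fields, not the real models:

```python
import warnings

from pydantic import BaseModel
from pydantic import Field


class ResultSketch(BaseModel):
  # Legacy field kept for backwards compatibility, as in the diff above.
  eval_metric_results: list[str] = Field(
      default_factory=list,
      deprecated=True,
      description="This field is deprecated, use overall_eval_metric_results instead.",
  )
  overall_eval_metric_results: list[str] = []


with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  _ = ResultSketch().eval_metric_results  # Accessing the deprecated field warns.

print(any(issubclass(w.category, DeprecationWarning) for w in caught))  # True
```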

@@ -480,25 +489,26 @@ def get_fast_api_app(
async def run_eval(
app_name: str, eval_set_id: str, req: RunEvalRequest
) -> list[RunEvalResult]:
"""Runs an eval given the details in the eval request."""
from .cli_eval import run_evals

"""Runs an eval given the details in the eval request."""
# Create a mapping from eval set file to all the evals that needed to be
# run.
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
eval_set_file_path = _get_eval_set_file_path(
app_name, agent_dir, eval_set_id
)
eval_set_to_evals = {eval_set_file_path: req.eval_ids}

if not req.eval_ids:
logger.info(
"Eval ids to run list is empty. We will all evals in the eval set."
)
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)

if req.eval_ids:
eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids]
eval_set_to_evals = {eval_set_id: eval_cases}
else:
logger.info("Eval ids to run list is empty. We will run all eval cases.")
eval_set_to_evals = {eval_set_id: eval_set.eval_cases}

root_agent = await _get_root_agent_async(app_name)
run_eval_results = []
eval_case_results = []
async for eval_result in run_evals(
async for eval_case_result in run_evals(
eval_set_to_evals,
root_agent,
getattr(root_agent, "reset_data", None),
@@ -509,31 +519,23 @@ def get_fast_api_app(
run_eval_results.append(
RunEvalResult(
app_name=app_name,
eval_set_file=eval_result.eval_set_file,
eval_set_file=eval_case_result.eval_set_file,
eval_set_id=eval_set_id,
eval_id=eval_result.eval_id,
final_eval_status=eval_result.final_eval_status,
eval_metric_results=eval_result.eval_metric_results,
user_id=eval_result.user_id,
session_id=eval_result.session_id,
eval_id=eval_case_result.eval_id,
final_eval_status=eval_case_result.final_eval_status,
eval_metric_results=eval_case_result.eval_metric_results,
overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
user_id=eval_case_result.user_id,
session_id=eval_case_result.session_id,
)
)
session = session_service.get_session(
eval_case_result.session_details = session_service.get_session(
app_name=app_name,
user_id=eval_result.user_id,
session_id=eval_result.session_id,
)
eval_case_results.append(
EvalCaseResult(
eval_set_file=eval_result.eval_set_file,
eval_id=eval_result.eval_id,
final_eval_status=eval_result.final_eval_status,
eval_metric_results=eval_result.eval_metric_results,
session_id=eval_result.session_id,
session_details=session,
user_id=eval_result.user_id,
)
user_id=eval_case_result.user_id,
session_id=eval_case_result.session_id,
)
eval_case_results.append(eval_case_result)

timestamp = time.time()
eval_set_result_name = app_name + "_" + eval_set_id + "_" + str(timestamp)