Update Eval Run and TrajectoryEvaluator to use the new schema.

PiperOrigin-RevId: 758927160
Ankur Sharma 2025-05-14 19:15:52 -07:00 committed by Copybara-Service
parent 2cb74dd20e
commit ee674ce0ef
9 changed files with 418 additions and 244 deletions

View File

@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import datetime
-from enum import Enum
 import importlib.util
 import json
 import logging
@@ -22,6 +20,7 @@ import sys
 import traceback
 from typing import Any
 from typing import AsyncGenerator
+from typing import cast
 from typing import Optional
 import uuid
@@ -29,36 +28,84 @@ from pydantic import BaseModel
 from pydantic import Field

 from ..agents import Agent
+from ..artifacts.base_artifact_service import BaseArtifactService
+from ..evaluation.eval_case import EvalCase
+from ..evaluation.eval_case import Invocation
+from ..evaluation.evaluator import EvalStatus
+from ..sessions.base_session_service import BaseSessionService
 from ..sessions.session import Session
 from .utils import common

 logger = logging.getLogger(__name__)

-class EvalStatus(Enum):
-  PASSED = 1
-  FAILED = 2
-  NOT_EVALUATED = 3

 class EvalMetric(BaseModel):
+  """A metric used to evaluate a particular aspect of an eval case."""
   metric_name: str
+  """The name of the metric."""
   threshold: float
+  """A threshold value. Each metric decides how to interpret this threshold."""

-class EvalMetricResult(BaseModel):
+class EvalMetricResult(EvalMetric):
+  """The actual computed score/value of a particular EvalMetric."""
   score: Optional[float] = None
   eval_status: EvalStatus

+class EvalMetricResultPerInvocation(BaseModel):
+  """Eval metric results per invocation."""
+  actual_invocation: Invocation
+  """The actual invocation, usually obtained by inferencing the agent."""
+  expected_invocation: Invocation
+  """The expected invocation, usually the reference or golden invocation."""
+  eval_metric_results: list[EvalMetricResult] = []
+  """Eval results for each applicable metric."""

 class EvalCaseResult(common.BaseModel):
-  eval_set_file: str
-  eval_id: str
+  """Case-level evaluation results."""
+  eval_set_file: str = Field(
+      deprecated=True,
+      description="This field is deprecated, use eval_set_id instead.",
+  )
+  eval_set_id: str = ""
+  """The eval set id."""
+  eval_id: str = ""
+  """The eval case id."""
   final_eval_status: EvalStatus
-  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
+  """Final eval status for this eval case."""
+  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
+      deprecated=True,
+      description=(
+          "This field is deprecated, use overall_eval_metric_results instead."
+      ),
+  )
+  overall_eval_metric_results: list[EvalMetricResult]
+  """Overall result for each metric for the entire eval case."""
+  eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
+  """Result for each metric on a per invocation basis."""
   session_id: str
+  """Session id of the session generated as result of inferencing/scraping stage of the eval."""
   session_details: Optional[Session] = None
+  """Session generated as result of inferencing/scraping stage of the eval."""
   user_id: Optional[str] = None
+  """User id used during inferencing/scraping stage of the eval."""

 class EvalSetResult(common.BaseModel):
@@ -161,14 +208,25 @@ def parse_and_get_evals_to_run(
 async def run_evals(
-    eval_set_to_evals: dict[str, list[str]],
+    eval_cases_by_eval_set_id: dict[str, list[EvalCase]],
     root_agent: Agent,
     reset_func: Optional[Any],
     eval_metrics: list[EvalMetric],
-    session_service=None,
-    artifact_service=None,
-    print_detailed_results=False,
+    session_service: Optional[BaseSessionService] = None,
+    artifact_service: Optional[BaseArtifactService] = None,
 ) -> AsyncGenerator[EvalCaseResult, None]:
+  """Returns a stream of EvalCaseResult for each eval case that was evaluated.
+
+  Args:
+    eval_cases_by_eval_set_id: Eval cases grouped by the id of the eval set to
+      which they belong.
+    root_agent: Agent to use for inferencing.
+    reset_func: If present, this is called before the agent is invoked for each
+      inferencing step.
+    eval_metrics: A list of metrics that should be used during evaluation.
+    session_service: The session service to use during inferencing.
+    artifact_service: The artifact service to use during inferencing.
+  """
   try:
     from ..evaluation.agent_evaluator import EvaluationGenerator
     from ..evaluation.response_evaluator import ResponseEvaluator
@@ -176,29 +234,19 @@ async def run_evals(
   except ModuleNotFoundError as e:
     raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e

-  """Returns a summary of eval runs."""
-  for eval_set_file, evals_to_run in eval_set_to_evals.items():
-    with open(eval_set_file, "r", encoding="utf-8") as file:
-      eval_items = json.load(file)  # Load JSON into a list
-
-    assert eval_items, f"No eval data found in eval set file: {eval_set_file}"
-
-    for eval_item in eval_items:
-      eval_name = eval_item["name"]
-      eval_data = eval_item["data"]
-      initial_session = eval_item.get("initial_session", {})
-      user_id = initial_session.get("user_id", "test_user_id")
-
-      if evals_to_run and eval_name not in evals_to_run:
-        continue
+  for eval_set_id, eval_cases in eval_cases_by_eval_set_id.items():
+    for eval_case in eval_cases:
+      eval_name = eval_case.eval_id
+      initial_session = eval_case.session_input
+      user_id = initial_session.user_id if initial_session else "test_user_id"

       try:
-        print(f"Running Eval: {eval_set_file}:{eval_name}")
+        print(f"Running Eval: {eval_set_id}:{eval_name}")
         session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"

-        scrape_result = (
-            await EvaluationGenerator._process_query_with_root_agent(
-                data=eval_data,
+        inference_result = (
+            await EvaluationGenerator._generate_inferences_from_root_agent(
+                invocations=eval_case.conversation,
                 root_agent=root_agent,
                 reset_func=reset_func,
                 initial_session=initial_session,
@@ -208,67 +256,95 @@ async def run_evals(
             )
         )

-        eval_metric_results = []
+        # Initialize the per-invocation metric results to an empty list.
+        # We will fill this as we evaluate each metric.
+        eval_metric_result_per_invocation = []
+        for actual, expected in zip(inference_result, eval_case.conversation):
+          eval_metric_result_per_invocation.append(
+              EvalMetricResultPerInvocation(
+                  actual_invocation=actual,
+                  expected_invocation=expected,
+                  eval_metric_results=[],
+              )
+          )
+
+        overall_eval_metric_results = []
+
         for eval_metric in eval_metrics:
-          eval_metric_result = None
           if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
-            score = TrajectoryEvaluator.evaluate(
-                [scrape_result], print_detailed_results=print_detailed_results
-            )
-            eval_metric_result = _get_eval_metric_result(eval_metric, score)
-          elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
-            score = ResponseEvaluator.evaluate(
-                [scrape_result],
-                [RESPONSE_MATCH_SCORE_KEY],
-                print_detailed_results=print_detailed_results,
-            )
-            eval_metric_result = _get_eval_metric_result(
-                eval_metric, score["rouge_1/mean"].item()
-            )
-          elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
-            score = ResponseEvaluator.evaluate(
-                [scrape_result],
-                [RESPONSE_EVALUATION_SCORE_KEY],
-                print_detailed_results=print_detailed_results,
-            )
-            eval_metric_result = _get_eval_metric_result(
-                eval_metric, score["coherence/mean"].item()
-            )
+            evaluation_result = TrajectoryEvaluator(
+                eval_metric.threshold
+            ).evaluate_invocations(
+                actual_invocations=inference_result,
+                expected_invocations=eval_case.conversation,
+            )
+            overall_eval_metric_results.append(
+                EvalMetricResult(
+                    metric_name=eval_metric.metric_name,
+                    threshold=eval_metric.threshold,
+                    score=evaluation_result.overall_score,
+                    eval_status=evaluation_result.overall_eval_status,
+                )
+            )
+            for index, per_invocation_result in enumerate(
+                evaluation_result.per_invocation_results
+            ):
+              eval_metric_result_per_invocation[
+                  index
+              ].eval_metric_results.append(
+                  EvalMetricResult(
+                      metric_name=eval_metric.metric_name,
+                      threshold=eval_metric.threshold,
+                      score=per_invocation_result.score,
+                      eval_status=per_invocation_result.eval_status,
+                  )
+              )
+          # elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
+          #   score = ResponseEvaluator.evaluate(
+          #       [inference_result],
+          #       [RESPONSE_MATCH_SCORE_KEY],
+          #       print_detailed_results=print_detailed_results,
+          #   )
+          #   eval_metric_result = _get_eval_metric_result(
+          #       eval_metric, score["rouge_1/mean"].item()
+          #   )
+          # elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
+          #   score = ResponseEvaluator.evaluate(
+          #       [inference_result],
+          #       [RESPONSE_EVALUATION_SCORE_KEY],
+          #       print_detailed_results=print_detailed_results,
+          #   )
+          #   eval_metric_result = _get_eval_metric_result(
+          #       eval_metric, score["coherence/mean"].item()
+          #   )
           else:
             logger.warning("`%s` is not supported.", eval_metric.metric_name)
-            eval_metric_results.append((
-                eval_metric,
-                EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED),
-            ))
-
-          eval_metric_results.append((
-              eval_metric,
-              eval_metric_result,
-          ))
-          _print_eval_metric_result(eval_metric, eval_metric_result)

         final_eval_status = EvalStatus.NOT_EVALUATED
         # Go over all the eval statuses and mark the final eval status as
         # passed if all of them pass, otherwise mark the final eval status as
         # failed.
-        for eval_metric_result in eval_metric_results:
-          eval_status = eval_metric_result[1].eval_status
-          if eval_status == EvalStatus.PASSED:
+        for overall_eval_metric_result in overall_eval_metric_results:
+          overall_eval_status = overall_eval_metric_result.eval_status
+          if overall_eval_status == EvalStatus.PASSED:
             final_eval_status = EvalStatus.PASSED
-          elif eval_status == EvalStatus.NOT_EVALUATED:
+          elif overall_eval_status == EvalStatus.NOT_EVALUATED:
             continue
-          elif eval_status == EvalStatus.FAILED:
+          elif overall_eval_status == EvalStatus.FAILED:
             final_eval_status = EvalStatus.FAILED
             break
           else:
             raise ValueError("Unknown eval status.")

         yield EvalCaseResult(
-            eval_set_file=eval_set_file,
+            eval_set_file=eval_set_id,
+            eval_set_id=eval_set_id,
             eval_id=eval_name,
             final_eval_status=final_eval_status,
-            eval_metric_results=eval_metric_results,
+            eval_metric_results=[],
+            overall_eval_metric_results=overall_eval_metric_results,
+            eval_metric_result_per_invocation=eval_metric_result_per_invocation,
             session_id=session_id,
             user_id=user_id,
         )
@@ -290,11 +366,3 @@ def _get_eval_metric_result(eval_metric, score):
       EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
   )
   return EvalMetricResult(score=score, eval_status=eval_status)
-
-
-def _print_eval_metric_result(eval_metric, eval_metric_result):
-  print(
-      f"Metric: {eval_metric.metric_name}\tStatus:"
-      f" {eval_metric_result.eval_status}\tScore:"
-      f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
-  )
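For illustration only (not part of this commit), a minimal sketch of how a caller might consume the new run_evals stream. It uses the names defined in the file above (run_evals, EvalMetric, TOOL_TRAJECTORY_SCORE_KEY, EvalStatus); the eval cases and agent are assumed to be built elsewhere, and summarize_eval_run is a hypothetical helper name.

import asyncio

async def summarize_eval_run(eval_cases_by_eval_set_id, root_agent):
  # Both arguments are assumed to be assembled elsewhere, e.g. eval cases via
  # load_eval_set_from_file and the agent from the user's agent module.
  results = [
      result
      async for result in run_evals(
          eval_cases_by_eval_set_id,
          root_agent,
          reset_func=None,
          eval_metrics=[
              EvalMetric(metric_name=TOOL_TRAJECTORY_SCORE_KEY, threshold=1.0)
          ],
      )
  ]
  passed = sum(1 for r in results if r.final_eval_status == EvalStatus.PASSED)
  print(f"{passed}/{len(results)} eval cases passed")

# asyncio.run(summarize_eval_run(eval_cases_by_eval_set_id, root_agent))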

View File

@@ -296,6 +296,7 @@ def cli_eval(
     from .cli_eval import parse_and_get_evals_to_run
     from .cli_eval import run_evals
     from .cli_eval import try_get_reset_func
+    from ..evaluation.local_eval_sets_manager import load_eval_set_from_file
   except ModuleNotFoundError:
     raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
@@ -311,17 +312,27 @@ def cli_eval(
   root_agent = get_root_agent(agent_module_file_path)
   reset_func = try_get_reset_func(agent_module_file_path)

-  eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
+  eval_set_file_path_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
+  eval_set_id_to_eval_cases = {}
+
+  # Read the eval_set files and get the cases.
+  for eval_set_file_path, eval_case_ids in eval_set_file_path_to_evals.items():
+    eval_set = load_eval_set_from_file(eval_set_file_path, eval_set_file_path)
+    eval_cases = eval_set.eval_cases
+
+    if eval_case_ids:
+      # There are eval_ids that we should select.
+      eval_cases = [
+          e for e in eval_set.eval_cases if e.eval_id in eval_case_ids
+      ]
+
+    eval_set_id_to_eval_cases[eval_set_file_path] = eval_cases

   async def _collect_eval_results() -> list[EvalCaseResult]:
     return [
         result
         async for result in run_evals(
-            eval_set_to_evals,
-            root_agent,
-            reset_func,
-            eval_metrics,
-            print_detailed_results=print_detailed_results,
+            eval_set_id_to_eval_cases, root_agent, reset_func, eval_metrics
         )
     ]
@@ -336,20 +347,28 @@ def cli_eval(
   for eval_result in eval_results:
     eval_result: EvalCaseResult

-    if eval_result.eval_set_file not in eval_run_summary:
-      eval_run_summary[eval_result.eval_set_file] = [0, 0]
+    if eval_result.eval_set_id not in eval_run_summary:
+      eval_run_summary[eval_result.eval_set_id] = [0, 0]

     if eval_result.final_eval_status == EvalStatus.PASSED:
-      eval_run_summary[eval_result.eval_set_file][0] += 1
+      eval_run_summary[eval_result.eval_set_id][0] += 1
     else:
-      eval_run_summary[eval_result.eval_set_file][1] += 1
+      eval_run_summary[eval_result.eval_set_id][1] += 1
   print("Eval Run Summary")
-  for eval_set_file, pass_fail_count in eval_run_summary.items():
+  for eval_set_id, pass_fail_count in eval_run_summary.items():
     print(
-        f"{eval_set_file}:\n Tests passed: {pass_fail_count[0]}\n Tests"
+        f"{eval_set_id}:\n Tests passed: {pass_fail_count[0]}\n Tests"
         f" failed: {pass_fail_count[1]}"
     )
+
+  if print_detailed_results:
+    for eval_result in eval_results:
+      eval_result: EvalCaseResult
+      print(
+          "*********************************************************************"
+      )
+      print(eval_result.model_dump_json(indent=2))

 @main.command("web")
 @click.option(

View File

@@ -48,6 +48,7 @@ from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
 from opentelemetry.sdk.trace import export
 from opentelemetry.sdk.trace import ReadableSpan
 from opentelemetry.sdk.trace import TracerProvider
+from pydantic import Field
 from pydantic import ValidationError
 from starlette.types import Lifespan
 from typing_extensions import override
@@ -75,6 +76,7 @@ from .cli_eval import EVAL_SESSION_ID_PREFIX
 from .cli_eval import EvalCaseResult
 from .cli_eval import EvalMetric
 from .cli_eval import EvalMetricResult
+from .cli_eval import EvalMetricResultPerInvocation
 from .cli_eval import EvalSetResult
 from .cli_eval import EvalStatus
 from .utils import common
@@ -175,7 +177,14 @@ class RunEvalResult(common.BaseModel):
   eval_set_id: str
   eval_id: str
   final_eval_status: EvalStatus
-  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
+  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
+      deprecated=True,
+      description=(
+          "This field is deprecated, use overall_eval_metric_results instead."
+      ),
+  )
+  overall_eval_metric_results: list[EvalMetricResult]
+  eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
   user_id: str
   session_id: str
@@ -480,25 +489,26 @@ def get_fast_api_app(
   async def run_eval(
       app_name: str, eval_set_id: str, req: RunEvalRequest
   ) -> list[RunEvalResult]:
+    """Runs an eval given the details in the eval request."""
     from .cli_eval import run_evals

-    """Runs an eval given the details in the eval request."""
     # Create a mapping from eval set file to all the evals that needed to be
     # run.
     envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)

-    eval_set_file_path = _get_eval_set_file_path(
-        app_name, agent_dir, eval_set_id
-    )
-    eval_set_to_evals = {eval_set_file_path: req.eval_ids}
-    if not req.eval_ids:
-      logger.info(
-          "Eval ids to run list is empty. We will all evals in the eval set."
-      )
+    eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
+
+    if req.eval_ids:
+      eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids]
+      eval_set_to_evals = {eval_set_id: eval_cases}
+    else:
+      logger.info("Eval ids to run list is empty. We will run all eval cases.")
+      eval_set_to_evals = {eval_set_id: eval_set.eval_cases}

     root_agent = await _get_root_agent_async(app_name)
     run_eval_results = []
     eval_case_results = []
-    async for eval_result in run_evals(
+    async for eval_case_result in run_evals(
         eval_set_to_evals,
         root_agent,
         getattr(root_agent, "reset_data", None),
@@ -509,31 +519,23 @@ def get_fast_api_app(
       run_eval_results.append(
          RunEvalResult(
              app_name=app_name,
-              eval_set_file=eval_result.eval_set_file,
+              eval_set_file=eval_case_result.eval_set_file,
              eval_set_id=eval_set_id,
-              eval_id=eval_result.eval_id,
-              final_eval_status=eval_result.final_eval_status,
-              eval_metric_results=eval_result.eval_metric_results,
-              user_id=eval_result.user_id,
-              session_id=eval_result.session_id,
+              eval_id=eval_case_result.eval_id,
+              final_eval_status=eval_case_result.final_eval_status,
+              eval_metric_results=eval_case_result.eval_metric_results,
+              overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
+              eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
+              user_id=eval_case_result.user_id,
+              session_id=eval_case_result.session_id,
          )
      )
-      session = session_service.get_session(
+      eval_case_result.session_details = session_service.get_session(
          app_name=app_name,
-          user_id=eval_result.user_id,
-          session_id=eval_result.session_id,
-      )
-      eval_case_results.append(
-          EvalCaseResult(
-              eval_set_file=eval_result.eval_set_file,
-              eval_id=eval_result.eval_id,
-              final_eval_status=eval_result.final_eval_status,
-              eval_metric_results=eval_result.eval_metric_results,
-              session_id=eval_result.session_id,
-              session_details=session,
-              user_id=eval_result.user_id,
-          )
+          user_id=eval_case_result.user_id,
+          session_id=eval_case_result.session_id,
      )
+      eval_case_results.append(eval_case_result)

     timestamp = time.time()
     eval_set_result_name = app_name + "_" + eval_set_id + "_" + str(timestamp)

View File

@@ -258,13 +258,6 @@ class AgentEvaluator:
         initial_session=initial_session,
     )

-  @staticmethod
-  def _generate_responses_from_session(eval_dataset, session_path):
-    """Generates evaluation responses by running the agent module multiple times."""
-    return EvaluationGenerator.generate_responses_from_session(
-        session_path, eval_dataset
-    )
-
   @staticmethod
   def _response_evaluation_required(criteria, eval_dataset):
     """Checks if response evaluation are needed."""

View File

@@ -23,10 +23,10 @@ from pydantic import Field
 class IntermediateData(BaseModel):
   """Container for intermediate data that an agent would generate as it responds with a final answer."""

-  tool_uses: list[genai_types.FunctionCall]
+  tool_uses: list[genai_types.FunctionCall] = []
   """Tool use trajectory in chronological order."""

-  intermediate_responses: list[Tuple[str, list[genai_types.Part]]]
+  intermediate_responses: list[Tuple[str, list[genai_types.Part]]] = []
   """Intermediate responses generated by sub-agents to convey progress or status
   in a multi-agent system, distinct from the final response.

View File

@@ -13,19 +13,19 @@
 # limitations under the License.

 import importlib
+from typing import Any, Optional
 import uuid

-from google.genai import types
-
-from ..agents.base_agent import BaseAgent
 from ..agents.llm_agent import Agent
-from ..agents.llm_agent import BeforeToolCallback
-from ..agents.llm_agent import LlmAgent
+from ..artifacts.base_artifact_service import BaseArtifactService
 from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
 from ..runners import Runner
+from ..sessions.base_session_service import BaseSessionService
 from ..sessions.in_memory_session_service import InMemorySessionService
 from ..sessions.session import Session
-from .evaluation_constants import EvalConstants
+from .eval_case import IntermediateData
+from .eval_case import Invocation
+from .eval_case import SessionInput

 class EvaluationGenerator:
@@ -102,56 +102,40 @@ class EvaluationGenerator:
       agent_to_evaluate = root_agent.find_agent(agent_name)
       assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."

-    return EvaluationGenerator._process_query_with_root_agent(
+    return EvaluationGenerator._generate_inferences_from_root_agent(
         data, agent_to_evaluate, reset_func, initial_session
     )

   @staticmethod
-  async def _process_query_with_root_agent(
-      data,
-      root_agent,
-      reset_func,
-      initial_session={},
-      session_id=None,
-      session_service=None,
-      artifact_service=None,
-  ):
-    """Process a query using the agent and evaluation dataset."""
-
-    # we don't know which tools belong to which agent
-    # so we just apply to any agents that has certain tool outputs
-    all_mock_tools = set()
-    for eval_entry in data:
-      expected_tool_use = eval_entry.get(EvalConstants.EXPECTED_TOOL_USE, [])
-      for expected in expected_tool_use:
-        if EvalConstants.MOCK_TOOL_OUTPUT in expected:
-          all_mock_tools.add(expected[EvalConstants.TOOL_NAME])
-
-    eval_data_copy = data.copy()
-    await EvaluationGenerator.apply_before_tool_callback(
-        root_agent,
-        lambda *args: EvaluationGenerator.before_tool_callback(
-            *args, eval_dataset=eval_data_copy
-        ),
-        all_mock_tools,
-    )
+  async def _generate_inferences_from_root_agent(
+      invocations: list[Invocation],
+      root_agent: Agent,
+      reset_func: Any,
+      initial_session: Optional[SessionInput] = None,
+      session_id: Optional[str] = None,
+      session_service: Optional[BaseSessionService] = None,
+      artifact_service: Optional[BaseArtifactService] = None,
+  ) -> list[Invocation]:
+    """Scrapes the root agent given the list of Invocations."""

     if not session_service:
       session_service = InMemorySessionService()

-    app_name = initial_session.get("app_name", "EvaluationGenerator")
-    user_id = initial_session.get("user_id", "test_user_id")
+    app_name = (
+        initial_session.app_name if initial_session else "EvaluationGenerator"
+    )
+    user_id = initial_session.user_id if initial_session else "test_user_id"
     session_id = session_id if session_id else str(uuid.uuid4())

     _ = session_service.create_session(
         app_name=app_name,
         user_id=user_id,
-        state=initial_session.get("state", {}),
+        state=initial_session.state if initial_session else {},
         session_id=session_id,
     )

     if not artifact_service:
       artifact_service = InMemoryArtifactService()

     runner = Runner(
         app_name=app_name,
         agent=root_agent,
@@ -163,30 +147,37 @@ class EvaluationGenerator:
     if callable(reset_func):
       reset_func()

-    responses = data.copy()
-
-    for index, eval_entry in enumerate(responses):
-      response = None
-      query = eval_entry["query"]
-      content = types.Content(role="user", parts=[types.Part(text=query)])
-      turn_actual_tool_uses = []
+    response_invocations = []
+
+    for invocation in invocations:
+      final_response = None
+      user_content = invocation.user_content
+      tool_uses = []
+      invocation_id = ""

       for event in runner.run(
-          user_id=user_id, session_id=session_id, new_message=content
+          user_id=user_id, session_id=session_id, new_message=user_content
       ):
+        invocation_id = (
+            event.invocation_id if not invocation_id else invocation_id
+        )
+
         if event.is_final_response() and event.content and event.content.parts:
-          response = event.content.parts[0].text
+          final_response = event.content
         elif event.get_function_calls():
           for call in event.get_function_calls():
-            turn_actual_tool_uses.append({
-                EvalConstants.TOOL_NAME: call.name,
-                EvalConstants.TOOL_INPUT: call.args,
-            })
-
-      responses[index]["actual_tool_use"] = turn_actual_tool_uses
-      responses[index]["response"] = response
-
-    return responses
+            tool_uses.append(call)
+
+      response_invocations.append(
+          Invocation(
+              invocation_id=invocation_id,
+              user_content=user_content,
+              final_response=final_response,
+              intermediate_data=IntermediateData(tool_uses=tool_uses),
+          )
+      )
+
+    return response_invocations

   @staticmethod
   def _process_query_with_session(session_data, data):
@@ -225,46 +216,3 @@ class EvaluationGenerator:
       responses[index]["actual_tool_use"] = actual_tool_uses
       responses[index]["response"] = response
     return responses
-
-  @staticmethod
-  def before_tool_callback(tool, args, tool_context, eval_dataset):
-    """Intercept specific tool calls and return predefined outputs
-    from eval_dataset.
-    """
-    for index, eval_entry in enumerate(eval_dataset):
-      expected_tool_use = eval_entry.get("expected_tool_use", [])
-      for expected in expected_tool_use:
-        if (
-            EvalConstants.MOCK_TOOL_OUTPUT in expected
-            and tool.name == expected[EvalConstants.TOOL_NAME]
-            and args == expected.get(EvalConstants.TOOL_INPUT, {})
-        ):
-          # pop the matched entry so we don't rematch again
-          eval_dataset.pop(index)
-          return {"result": expected[EvalConstants.MOCK_TOOL_OUTPUT]}
-
-    return None
-
-  @staticmethod
-  async def apply_before_tool_callback(
-      agent: BaseAgent,
-      callback: BeforeToolCallback,
-      all_mock_tools: set[str],
-  ):
-    """Recursively apply the before_tool_callback to the root agent and all its subagents."""
-    # Check if the agent has tools that are defined by evalset.
-    # We use function names to check if tools match
-    if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent):
-      return
-
-    for tool in await agent.canonical_tools():
-      tool_name = tool.name
-      if tool_name in all_mock_tools:
-        agent.before_tool_callback = callback
-
-    # Apply recursively to subagents if they exist
-    for sub_agent in agent.sub_agents:
-      await EvaluationGenerator.apply_before_tool_callback(
-          sub_agent, callback, all_mock_tools
-      )

View File

@@ -0,0 +1,56 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+from .eval_case import Invocation
+
+
+class EvalStatus(Enum):
+  PASSED = 1
+  FAILED = 2
+  NOT_EVALUATED = 3
+
+
+class PerInvocationResult(BaseModel):
+  """Metric evaluation score per invocation."""
+
+  actual_invocation: Invocation
+  expected_invocation: Invocation
+  score: Optional[float] = None
+  eval_status: EvalStatus = EvalStatus.NOT_EVALUATED
+
+
+class EvaluationResult(BaseModel):
+  overall_score: Optional[float] = None
+  """Overall score, based on each invocation."""
+
+  overall_eval_status: EvalStatus = EvalStatus.NOT_EVALUATED
+  """Overall status, based on each invocation."""
+
+  per_invocation_results: list[PerInvocationResult] = []
+
+
+class Evaluator(ABC):
+  """A metrics evaluator interface."""
+
+  def evaluate_invocations(
+      self,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    """Returns EvaluationResult after performing evaluations using actual and expected invocations."""
+    raise NotImplementedError()
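As an illustrative aside (not part of this commit), here is a minimal sketch of a custom metric built on the new interface. The class name and the exact-match rule are invented for the example; only Evaluator, EvaluationResult, PerInvocationResult, and EvalStatus come from the new module, and the import paths assume the google.adk package layout.

from typing_extensions import override

# Assumed package layout; adjust to relative imports inside the package.
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.evaluator import EvalStatus
from google.adk.evaluation.evaluator import EvaluationResult
from google.adk.evaluation.evaluator import Evaluator
from google.adk.evaluation.evaluator import PerInvocationResult


class ExactFinalResponseEvaluator(Evaluator):
  """Illustrative metric: passes only when final responses match exactly."""

  @override
  def evaluate_invocations(
      self,
      actual_invocations: list[Invocation],
      expected_invocations: list[Invocation],
  ) -> EvaluationResult:
    per_invocation_results = []
    for actual, expected in zip(actual_invocations, expected_invocations):
      score = 1.0 if actual.final_response == expected.final_response else 0.0
      per_invocation_results.append(
          PerInvocationResult(
              actual_invocation=actual,
              expected_invocation=expected,
              score=score,
              eval_status=(
                  EvalStatus.PASSED if score == 1.0 else EvalStatus.FAILED
              ),
          )
      )
    if not per_invocation_results:
      return EvaluationResult()
    overall_score = sum(r.score for r in per_invocation_results) / len(
        per_invocation_results
    )
    return EvaluationResult(
        overall_score=overall_score,
        overall_eval_status=(
            EvalStatus.PASSED if overall_score == 1.0 else EvalStatus.FAILED
        ),
        per_invocation_results=per_invocation_results,
    )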

View File

@@ -154,6 +154,22 @@ def convert_eval_set_to_pydanctic_schema(
   )


+def load_eval_set_from_file(
+    eval_set_file_path: str, eval_set_id: str
+) -> EvalSet:
+  """Returns an EvalSet that is read from the given file."""
+  with open(eval_set_file_path, "r", encoding="utf-8") as f:
+    content = f.read()
+    try:
+      return EvalSet.model_validate_json(content)
+    except ValidationError:
+      # We assume that the eval data was specified in the old format and try
+      # to convert it to the new format.
+      return convert_eval_set_to_pydanctic_schema(
+          eval_set_id, json.loads(content)
+      )
+
+
 class LocalEvalSetsManager(EvalSetsManager):
   """An EvalSets manager that stores eval sets locally on disk."""
@@ -165,16 +181,7 @@ class LocalEvalSetsManager(EvalSetsManager):
     """Returns an EvalSet identified by an app_name and eval_set_id."""
     # Load the eval set file data
     eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id)
-    with open(eval_set_file_path, "r", encoding="utf-8") as f:
-      content = f.read()
-      try:
-        return EvalSet.model_validate_json(content)
-      except ValidationError:
-        # We assume that the eval data was specified in the old format and try
-        # to convert it to the new format.
-        return convert_eval_set_to_pydanctic_schema(
-            eval_set_id, json.loads(content)
-        )
+    return load_eval_set_from_file(eval_set_file_path, eval_set_id)

   @override
   def create_eval_set(self, app_name: str, eval_set_id: str):
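A quick usage sketch of the new helper, for illustration only. The file name is hypothetical, the import path assumes the google.adk package layout, and the fallback from the new Pydantic EvalSet JSON to the legacy format happens transparently inside load_eval_set_from_file.

from google.adk.evaluation.local_eval_sets_manager import load_eval_set_from_file

# "home_automation.evalset.json" is a made-up file; it may contain either the
# new EvalSet JSON or the old eval format, which is converted on the fly.
eval_set = load_eval_set_from_file(
    "home_automation.evalset.json", "home_automation"
)
for eval_case in eval_set.eval_cases:
  # eval_id and conversation follow the new schema used by this commit.
  print(eval_case.eval_id, len(eval_case.conversation))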

View File

@@ -12,18 +12,98 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any
+from typing import Any, cast

+from deprecated import deprecated
+from google.genai import types as genai_types
 import pandas as pd
 from tabulate import tabulate
+from typing_extensions import override

+from .eval_case import Invocation
 from .evaluation_constants import EvalConstants
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .evaluator import Evaluator
+from .evaluator import PerInvocationResult


-class TrajectoryEvaluator:
+class TrajectoryEvaluator(Evaluator):
   """Evaluates tool use trajectories for accuracy."""

+  def __init__(self, threshold: float):
+    self._threshold = threshold
+
+  @override
+  def evaluate_invocations(
+      self,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    """Returns EvaluationResult after performing evaluations using actual and expected invocations."""
+    total_tool_use_accuracy = 0.0
+    num_invocations = 0
+    per_invocation_results = []
+
+    for actual, expected in zip(actual_invocations, expected_invocations):
+      actual_tool_uses = (
+          actual.intermediate_data.tool_uses if actual.intermediate_data else []
+      )
+      expected_tool_uses = (
+          expected.intermediate_data.tool_uses
+          if expected.intermediate_data
+          else []
+      )
+      tool_use_accuracy = (
+          1.0
+          if self._are_tool_calls_equal(actual_tool_uses, expected_tool_uses)
+          else 0.0
+      )
+      per_invocation_results.append(
+          PerInvocationResult(
+              actual_invocation=actual,
+              expected_invocation=expected,
+              score=tool_use_accuracy,
+              eval_status=self._get_eval_status(tool_use_accuracy),
+          )
+      )
+      total_tool_use_accuracy += tool_use_accuracy
+      num_invocations += 1
+
+    if per_invocation_results:
+      overall_score = total_tool_use_accuracy / num_invocations
+      return EvaluationResult(
+          overall_score=overall_score,
+          overall_eval_status=self._get_eval_status(overall_score),
+          per_invocation_results=per_invocation_results,
+      )
+
+    return EvaluationResult()
+
+  def _are_tool_calls_equal(
+      self,
+      actual_tool_calls: list[genai_types.FunctionCall],
+      expected_tool_calls: list[genai_types.FunctionCall],
+  ) -> bool:
+    if len(actual_tool_calls) != len(expected_tool_calls):
+      return False
+
+    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
+      if actual.name != expected.name or actual.args != expected.args:
+        return False
+
+    return True
+
+  def _get_eval_status(self, score: float):
+    return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
+
   @staticmethod
+  @deprecated(
+      reason=(
+          "This method has been deprecated and will be removed soon. Please use"
+          " evaluate_invocations instead."
+      )
+  )
   def evaluate(
       eval_dataset: list[list[dict[str, Any]]],
       *,
@@ -137,6 +217,7 @@ class TrajectoryEvaluator:
     return new_row, failure

   @staticmethod
+  @deprecated()
   def are_tools_equal(list_a_original, list_b_original):
     # Remove other entries that we don't want to evaluate
     list_a = [
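To round out the change, a hedged sketch of the new invocation-based entry point. The Invocation fields follow the schema shown earlier in this commit, the import paths assume the google.adk package layout, and the tool call values are invented for illustration; Invocation may have additional required fields not visible in this diff.

from google.genai import types as genai_types

# Assumed package layout; not part of the diff above.
from google.adk.evaluation.eval_case import IntermediateData, Invocation
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator

expected = Invocation(
    invocation_id="expected-1",
    user_content=genai_types.Content(
        role="user", parts=[genai_types.Part(text="Turn on the lights")]
    ),
    intermediate_data=IntermediateData(
        tool_uses=[genai_types.FunctionCall(name="set_light", args={"on": True})]
    ),
)
# A perfect inference result reuses the same tool trajectory.
actual = expected.model_copy(deep=True)

evaluator = TrajectoryEvaluator(threshold=1.0)
result = evaluator.evaluate_invocations(
    actual_invocations=[actual], expected_invocations=[expected]
)
print(result.overall_score)        # 1.0
print(result.overall_eval_status)  # EvalStatus.PASSED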