Files
evo-ai/.venv/lib/python3.10/site-packages/google/adk/cli/cli_eval.py
2025-04-25 15:30:54 -03:00

283 lines
9.0 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
import importlib.util
import json
import logging
import os
import sys
import traceback
from typing import Any
from typing import Generator
from typing import Optional
import uuid
from pydantic import BaseModel
from ..agents import Agent
# Module-level logger: used for the default-criteria notice, the
# unsupported-metric warning and per-eval error tracebacks below.
logger = logging.getLogger(__name__)
# Outcome of a single metric, and of a whole eval once its metrics are
# aggregated in run_evals.
class EvalStatus(Enum):
  PASSED = 1
  FAILED = 2
  # Used for metrics run_evals does not support, and as the aggregate status
  # when no metric produced a pass/fail verdict.
  NOT_EVALUATED = 3
# A metric to compute for every eval, together with its pass threshold.
class EvalMetric(BaseModel):
  # Metric key, e.g. TOOL_TRAJECTORY_SCORE_KEY; names run_evals does not
  # recognize are logged and reported as NOT_EVALUATED.
  metric_name: str
  # The metric passes when its score is >= this value.
  threshold: float
# Outcome of evaluating one metric for one eval.
class EvalMetricResult(BaseModel):
  # Metric score, or None when the metric was not evaluated. The explicit
  # `= None` default matters: under Pydantic v2 an `Optional[...]` annotation
  # alone still makes the field required, so building a NOT_EVALUATED result
  # without a score would raise a ValidationError.
  score: Optional[float] = None
  # PASSED / FAILED against the metric threshold, or NOT_EVALUATED.
  eval_status: EvalStatus
# Result of running one eval from an eval set file.
class EvalResult(BaseModel):
  # Path of the eval set file the eval was read from.
  eval_set_file: str
  # The eval's "name" field in the eval set file.
  eval_id: str
  # Status aggregated over all entries in eval_metric_results.
  final_eval_status: EvalStatus
  # (metric, result) pairs for the requested metrics, in the order computed.
  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
  # Session id used for this eval run (starts with EVAL_SESSION_ID_PREFIX).
  session_id: str
# Message for the ModuleNotFoundError raised when the optional evaluation
# modules cannot be imported.
MISSING_EVAL_DEPENDENCIES_MESSAGE = (
    "Eval module is not installed, please install via `pip install"
    " google-adk[eval]`."
)
# Metric names recognized by run_evals.
TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score"
# Scored in run_evals from the "rouge_1/mean" value of ResponseEvaluator.
RESPONSE_MATCH_SCORE_KEY = "response_match_score"
# This evaluation is not very stable.
# This is always optional unless explicitly specified.
# Scored in run_evals from the "coherence/mean" value of ResponseEvaluator.
RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"
# Prefix of the per-eval session ids generated in run_evals.
EVAL_SESSION_ID_PREFIX = "___eval___session___"
# Criteria used when no config file is supplied (metric name -> threshold).
# RESPONSE_EVALUATION_SCORE_KEY is left out on purpose (see note above).
DEFAULT_CRITERIA = {
    TOOL_TRAJECTORY_SCORE_KEY: 1.0,  # 1-point scale; 1.0 is perfect.
    RESPONSE_MATCH_SCORE_KEY: 0.8,
}
def _import_from_path(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def _get_agent_module(agent_module_file_path: str):
  """Loads the package at ``agent_module_file_path`` as module ``agent``.

  The package's ``__init__.py`` is executed on every call via
  ``_import_from_path``.

  Args:
    agent_module_file_path: Directory containing the agent package's
      ``__init__.py``.
  """
  return _import_from_path(
      "agent", os.path.join(agent_module_file_path, "__init__.py")
  )
def get_evaluation_criteria_or_default(
    eval_config_file_path: str,
) -> dict[str, float]:
  """Returns evaluation criteria from the config file, if present.

  Otherwise a fresh copy of the default criteria is returned.

  Args:
    eval_config_file_path: Path to a JSON file whose top-level object holds a
      ``"criteria"`` dict (metric name -> threshold). A falsy value such as
      ``None`` or ``""`` selects the default criteria.

  Returns:
    A dict mapping metric name to threshold.

  Raises:
    ValueError: If the file's top level is not an object containing a
      ``"criteria"`` dict.
  """
  if not eval_config_file_path:
    logger.info("No config file supplied. Using default criteria.")
    # Copy so a caller mutating the result cannot change the module-level
    # defaults seen by later calls.
    return dict(DEFAULT_CRITERIA)

  with open(eval_config_file_path, "r", encoding="utf-8") as f:
    config_data = json.load(f)

  # Require a JSON object first: `"criteria" in <str>` would be a substring
  # test and `"criteria" in <list>` an element test.
  criteria = (
      config_data.get("criteria") if isinstance(config_data, dict) else None
  )
  if not isinstance(criteria, dict):
    raise ValueError(
        f"Invalid format for test_config.json at {eval_config_file_path}."
        " Expected a 'criteria' dictionary."
    )
  return criteria
def get_root_agent(agent_module_file_path: str) -> Agent:
  """Loads the agent package and returns the ``root_agent`` it exposes.

  The value is read from ``<package>.agent.root_agent``. The package is
  re-executed on every call.

  Args:
    agent_module_file_path: Directory containing the agent package's
      ``__init__.py``.
  """
  return _get_agent_module(agent_module_file_path).agent.root_agent
def try_get_reset_func(agent_module_file_path: str) -> Any:
  """Looks up the optional ``reset_data`` hook of an agent package.

  Args:
    agent_module_file_path: Directory containing the agent package's
      ``__init__.py``.

  Returns:
    ``<package>.agent.reset_data`` when defined, otherwise None.
  """
  # The `agent` attribute is presumably the package's `agent` submodule.
  agent_submodule = _get_agent_module(agent_module_file_path).agent
  return getattr(agent_submodule, "reset_data", None)
def parse_and_get_evals_to_run(
    eval_set_file_path: tuple[str, ...],
) -> dict[str, list[str]]:
  """Returns a dictionary of eval sets to evals that should be run.

  Each input is either ``path`` (run every eval in the file) or
  ``path:name1,name2`` (run only the listed evals). Only the text between the
  first and second ``:`` is read as the eval list.

  NOTE: paths containing ``:`` themselves, such as Windows drive letters, are
  not supported by this syntax.

  Args:
    eval_set_file_path: Eval set specifiers as described above.

  Returns:
    Maps each eval set file to the eval names to run. An empty list means
    "run all evals in the file". If a file is given once without an eval list,
    every eval in it is run, even when other entries for the same file list
    specific evals.
  """
  eval_set_to_evals: dict[str, list[str]] = {}
  run_all_files: set[str] = set()
  for input_eval_set in eval_set_file_path:
    parts = input_eval_set.split(":")
    eval_set_file = parts[0]
    evals = []
    if len(parts) > 1:
      # Drop empty names (e.g. "file:" or "a,,b") and surrounding whitespace,
      # which would otherwise never match any eval.
      evals = [name.strip() for name in parts[1].split(",") if name.strip()]
    eval_set_to_evals.setdefault(eval_set_file, [])
    if evals:
      eval_set_to_evals[eval_set_file].extend(evals)
    else:
      run_all_files.add(eval_set_file)
  # "Run everything" subsumes any narrower selection for the same file.
  for eval_set_file in run_all_files:
    eval_set_to_evals[eval_set_file] = []
  return eval_set_to_evals
def run_evals(
    eval_set_to_evals: dict[str, list[str]],
    root_agent: Agent,
    reset_func: Optional[Any],
    eval_metrics: list[EvalMetric],
    session_service=None,
    artifact_service=None,
    print_detailed_results=False,
) -> Generator[EvalResult, None, None]:
  """Runs the selected evals and yields one EvalResult per completed eval.

  This is a generator. Nothing happens until it is iterated, including the
  lazy import of the eval dependencies.

  Args:
    eval_set_to_evals: Maps an eval set file path to the eval names to run
      from it. An empty list runs every eval in the file.
    root_agent: Agent under evaluation.
    reset_func: Optional reset callable forwarded to the eval generator.
    eval_metrics: Metrics (with thresholds) computed for every eval.
    session_service: Optional session service forwarded to the generator.
    artifact_service: Optional artifact service forwarded to the generator.
    print_detailed_results: Forwarded to the evaluators.

  Yields:
    An EvalResult for each eval that completed. An eval that raises while
    running or scoring is reported on stdout and the logger, then skipped.

  Raises:
    ModuleNotFoundError: If the optional eval dependencies are not installed.
    AssertionError: If an eval set file contains no eval items.
    Errors opening/parsing an eval set file, or a missing "name"/"data" key
    in an eval item, propagate to the caller.
  """
  # Imported lazily so the CLI works without the optional eval extras.
  try:
    from ..evaluation.agent_evaluator import EvaluationGenerator
    from ..evaluation.response_evaluator import ResponseEvaluator
    from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
  except ModuleNotFoundError as e:
    raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e

  for eval_set_file, evals_to_run in eval_set_to_evals.items():
    # Set for O(1) membership tests; empty still means "run every eval".
    selected_evals = set(evals_to_run)

    with open(eval_set_file, "r", encoding="utf-8") as file:
      eval_items = json.load(file)  # Expected to be a list of eval items.

    # Explicit raise instead of `assert` so the check survives `python -O`;
    # AssertionError is kept so callers see the same exception type.
    if not eval_items:
      raise AssertionError(
          f"No eval data found in eval set file: {eval_set_file}"
      )

    for eval_item in eval_items:
      eval_name = eval_item["name"]
      eval_data = eval_item["data"]
      initial_session = eval_item.get("initial_session", {})

      if selected_evals and eval_name not in selected_evals:
        continue

      try:
        print(f"Running Eval: {eval_set_file}:{eval_name}")
        session_id = f"{EVAL_SESSION_ID_PREFIX}{uuid.uuid4()}"

        scrape_result = EvaluationGenerator._process_query_with_root_agent(
            data=eval_data,
            root_agent=root_agent,
            reset_func=reset_func,
            initial_session=initial_session,
            session_id=session_id,
            session_service=session_service,
            artifact_service=artifact_service,
        )

        eval_metric_results = []
        for eval_metric in eval_metrics:
          if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
            score = TrajectoryEvaluator.evaluate(
                [scrape_result], print_detailed_results=print_detailed_results
            )
            eval_metric_result = _get_eval_metric_result(eval_metric, score)
          elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
            score = ResponseEvaluator.evaluate(
                [scrape_result],
                [RESPONSE_MATCH_SCORE_KEY],
                print_detailed_results=print_detailed_results,
            )
            eval_metric_result = _get_eval_metric_result(
                eval_metric, score["rouge_1/mean"].item()
            )
          elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
            score = ResponseEvaluator.evaluate(
                [scrape_result],
                [RESPONSE_EVALUATION_SCORE_KEY],
                print_detailed_results=print_detailed_results,
            )
            eval_metric_result = _get_eval_metric_result(
                eval_metric, score["coherence/mean"].item()
            )
          else:
            logger.warning("`%s` is not supported.", eval_metric.metric_name)
            # Record the unsupported metric exactly once as NOT_EVALUATED.
            # Falling through without a result would hand None to the
            # printing and aggregation steps and abort the whole eval.
            eval_metric_result = EvalMetricResult(
                score=None, eval_status=EvalStatus.NOT_EVALUATED
            )
          eval_metric_results.append((eval_metric, eval_metric_result))
          _print_eval_metric_result(eval_metric, eval_metric_result)

        final_eval_status = _get_final_eval_status(eval_metric_results)
        yield EvalResult(
            eval_set_file=eval_set_file,
            eval_id=eval_name,
            final_eval_status=final_eval_status,
            eval_metric_results=eval_metric_results,
            session_id=session_id,
        )

        # NOTE: NOT_EVALUATED is shown as a failure in this summary line.
        result = (
            "✅ Passed" if final_eval_status == EvalStatus.PASSED else "❌ Failed"
        )
        print(f"Result: {result}\n")
      except Exception as e:
        # Best-effort: report the failure and continue with the next eval.
        print(f"Error: {e}")
        logger.info("Error: %s", traceback.format_exc())


def _get_final_eval_status(
    eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]],
) -> EvalStatus:
  """Aggregates per-metric statuses into one eval status.

  The result is determined as follows:
  - FAILED as soon as any metric failed.
  - PASSED if at least one metric passed and none failed.
  - NOT_EVALUATED if no metric was evaluated.

  Raises:
    ValueError: On a status outside the EvalStatus members handled here.
  """
  final_eval_status = EvalStatus.NOT_EVALUATED
  for _, eval_metric_result in eval_metric_results:
    eval_status = eval_metric_result.eval_status
    if eval_status == EvalStatus.PASSED:
      final_eval_status = EvalStatus.PASSED
    elif eval_status == EvalStatus.NOT_EVALUATED:
      continue
    elif eval_status == EvalStatus.FAILED:
      return EvalStatus.FAILED
    else:
      raise ValueError("Unknown eval status.")
  return final_eval_status
def _get_eval_metric_result(eval_metric, score):
  """Wraps ``score`` in an EvalMetricResult judged against the threshold.

  The result is PASSED when ``score >= eval_metric.threshold`` and FAILED
  otherwise.
  """
  meets_threshold = score >= eval_metric.threshold
  return EvalMetricResult(
      score=score,
      eval_status=EvalStatus.PASSED if meets_threshold else EvalStatus.FAILED,
  )
def _print_eval_metric_result(eval_metric, eval_metric_result):
print(
f"Metric: {eval_metric.metric_name}\tStatus:"
f" {eval_metric_result.eval_status}\tScore:"
f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
)