Agent Development Kit(ADK)

An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
hangfei
2025-04-08 17:22:09 +00:00
parent f92478bd5c
commit 9827820143
299 changed files with 44398 additions and 2 deletions

View File

@@ -0,0 +1,31 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
logger = logging.getLogger(__name__)
__all__ = []
try:
from .agent_evaluator import AgentEvaluator
__all__.append('AgentEvaluator')
except ImportError:
logger.debug(
'The Vertex[eval] sdk is not installed. If you want to use the Vertex'
' Evaluation with agents, please install it(pip install'
' "google-cloud-aiplatform[evaluation]). If not, you can ignore this'
' warning.'
)

View File

@@ -0,0 +1,329 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from os import path
from typing import Dict
from typing import List
from typing import Union
from .evaluation_generator import EvaluationGenerator
from .response_evaluator import ResponseEvaluator
from .trajectory_evaluator import TrajectoryEvaluator
# Constants for default runs and evaluation criteria
NUM_RUNS = 2
TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score"
# This evaluation is not very stable.
# This is always optional unless explicitly specified.
RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"
RESPONSE_MATCH_SCORE_KEY = "response_match_score"
ALLOWED_CRITERIA = [
TOOL_TRAJECTORY_SCORE_KEY,
RESPONSE_EVALUATION_SCORE_KEY,
RESPONSE_MATCH_SCORE_KEY,
]
QUERY_COLUMN = "query"
REFERENCE_COLUMN = "reference"
EXPECTED_TOOL_USE_COLUMN = "expected_tool_use"
DEFAULT_CRITERIA = {
TOOL_TRAJECTORY_SCORE_KEY: 1.0, # 1-point scale; 1.0 is perfect.
RESPONSE_MATCH_SCORE_KEY: 0.8, # Rouge-1 text match; 0.8 is default.
}
def load_json(file_path: str) -> Union[Dict, List]:
with open(file_path, "r") as f:
return json.load(f)
class AgentEvaluator:
"""An evaluator for Agents, mainly intented for helping with test cases."""
@staticmethod
def find_config_for_test_file(test_file: str):
"""Find the test_config.json file in the same folder as the test file."""
test_folder = os.path.dirname(test_file)
config_path = os.path.join(test_folder, "test_config.json")
if os.path.exists(config_path):
config_data = load_json(config_path)
if "criteria" in config_data and isinstance(
config_data["criteria"], dict
):
return config_data["criteria"]
else:
raise ValueError(
f"Invalid format for test_config.json at {config_path}. Expected a"
" 'criteria' dictionary."
)
return DEFAULT_CRITERIA
@staticmethod
def evaluate(
agent_module,
eval_dataset_file_path_or_dir,
num_runs=NUM_RUNS,
agent_name=None,
initial_session_file=None,
):
"""Evaluates an Agent given eval data.
Args:
agent_module: The path to python module that contains the definition of
the agent. There is convention in place here, where the code is going to
look for 'root_agent' in the loaded module.
eval_dataset: The eval data set. This can be either a string representing
full path to the file containing eval dataset, or a directory that is
recusively explored for all files that have a `.test.json` suffix.
num_runs: Number of times all entries in the eval dataset should be
assessed.
agent_name: The name of the agent.
initial_session_file: File that contains initial session state that is
needed by all the evals in the eval dataset.
"""
test_files = []
if isinstance(eval_dataset_file_path_or_dir, str) and os.path.isdir(
eval_dataset_file_path_or_dir
):
for root, _, files in os.walk(eval_dataset_file_path_or_dir):
for file in files:
if file.endswith(".test.json"):
test_files.append(path.join(root, file))
else:
test_files = [eval_dataset_file_path_or_dir]
initial_session_state = {}
if initial_session_file:
with open(initial_session_file, "r") as f:
initial_session_state = json.loads(f.read())["state"]
for test_file in test_files:
dataset = AgentEvaluator._load_dataset(test_file)[0]
criteria = AgentEvaluator.find_config_for_test_file(test_file)
AgentEvaluator._validate_input([dataset], criteria)
evaluation_response = AgentEvaluator._generate_responses(
agent_module,
[dataset],
num_runs,
agent_name=agent_name,
initial_session={"state": initial_session_state},
)
if AgentEvaluator._response_evaluation_required(criteria, [dataset]):
AgentEvaluator._evaluate_response_scores(
agent_module, evaluation_response, criteria
)
if AgentEvaluator._trajectory_evaluation_required(criteria, [dataset]):
AgentEvaluator._evaluate_tool_trajectory(
agent_module, evaluation_response, criteria
)
@staticmethod
def _load_dataset(
input_data: Union[str, List[str], List[Dict], List[List[Dict]]],
) -> List[List[Dict]]:
def load_json_file(file_path: str) -> List[Dict]:
data = load_json(file_path)
if not isinstance(data, list) or not all(
isinstance(d, dict) for d in data
):
raise ValueError(f"{file_path} must contain a list of dictionaries.")
return data
if isinstance(input_data, str):
if os.path.isdir(input_data):
test_files = []
for root, _, files in os.walk(input_data):
for file in files:
if file.endswith(".test.json"):
test_files.append(os.path.join(root, file))
return [load_json_file(f) for f in test_files]
elif os.path.isfile(input_data):
return [load_json_file(input_data)]
else:
raise ValueError(f"Input path {input_data} is invalid.")
elif isinstance(input_data, list):
if all(isinstance(i, str) and os.path.isfile(i) for i in input_data):
return [load_json_file(i) for i in input_data]
raise TypeError("Input list must contain valid file paths.")
raise TypeError("Invalid input type for dataset loading.")
@staticmethod
def _validate_input(eval_dataset, criteria):
"""Validates that the evaluation criteria align with the provided dataset.
For efficiency, we only use first row to validate input.
"""
if not eval_dataset:
raise ValueError("The evaluation dataset is None or empty.")
for key in criteria:
if key not in ALLOWED_CRITERIA:
raise ValueError(
f"Invalid criteria key: {key}. Expected one of {ALLOWED_CRITERIA}."
)
if not eval_dataset:
raise ValueError("The evaluation dataset is empty.")
sample = eval_dataset[0]
first_query = sample[0]
if not isinstance(sample, list) and not isinstance(first_query, dict):
raise ValueError(
"Each evaluation dataset sample must be list of dictionary. But it's"
f" {eval_dataset}"
)
if TOOL_TRAJECTORY_SCORE_KEY in criteria:
if (
QUERY_COLUMN not in first_query
or EXPECTED_TOOL_USE_COLUMN not in first_query
):
raise ValueError(
f"Samples for {TOOL_TRAJECTORY_SCORE_KEY} must include"
f" '{QUERY_COLUMN}' and '{EXPECTED_TOOL_USE_COLUMN}' keys. The"
f" sample is {sample}."
)
if RESPONSE_EVALUATION_SCORE_KEY in criteria:
if QUERY_COLUMN not in first_query:
raise ValueError(
f"Samples for {RESPONSE_EVALUATION_SCORE_KEY} must include"
f" '{QUERY_COLUMN}' key. The sample is {sample}."
)
if RESPONSE_MATCH_SCORE_KEY in criteria:
if QUERY_COLUMN not in first_query or REFERENCE_COLUMN not in first_query:
raise ValueError(
f"Samples for {RESPONSE_MATCH_SCORE_KEY} must include"
f" '{QUERY_COLUMN}' and '{REFERENCE_COLUMN}' keys. The sample is"
f" {sample}."
)
@staticmethod
def _get_infer_criteria(eval_dataset):
"""Infers evaluation criteria based on the provided dataset.
Args:
eval_dataset (list): A list of evaluation samples.
Returns:
dict: Inferred evaluation criteria based on dataset fields.
"""
inferred_criteria = {}
sample = eval_dataset[0][0]
if QUERY_COLUMN in sample and EXPECTED_TOOL_USE_COLUMN in sample:
inferred_criteria[TOOL_TRAJECTORY_SCORE_KEY] = DEFAULT_CRITERIA[
TOOL_TRAJECTORY_SCORE_KEY
]
if QUERY_COLUMN in sample and REFERENCE_COLUMN in sample:
inferred_criteria[RESPONSE_MATCH_SCORE_KEY] = DEFAULT_CRITERIA[
RESPONSE_MATCH_SCORE_KEY
]
return inferred_criteria
@staticmethod
def _generate_responses(
agent_module, eval_dataset, num_runs, agent_name=None, initial_session={}
):
"""Generates evaluation responses by running the agent module multiple times."""
return EvaluationGenerator.generate_responses(
eval_dataset,
agent_module,
repeat_num=num_runs,
agent_name=agent_name,
initial_session=initial_session,
)
@staticmethod
def _generate_responses_from_session(eval_dataset, session_path):
"""Generates evaluation responses by running the agent module multiple times."""
return EvaluationGenerator.generate_responses_from_session(
session_path, eval_dataset
)
@staticmethod
def _response_evaluation_required(criteria, eval_dataset):
"""Checks if response evaluation are needed."""
return REFERENCE_COLUMN in eval_dataset[0][0] and any(
key in criteria
for key in [RESPONSE_EVALUATION_SCORE_KEY, RESPONSE_MATCH_SCORE_KEY]
)
@staticmethod
def _trajectory_evaluation_required(evaluation_criteria, eval_dataset):
"""Checks if response evaluation are needed."""
return (
EXPECTED_TOOL_USE_COLUMN in eval_dataset[0][0]
and TOOL_TRAJECTORY_SCORE_KEY in evaluation_criteria
)
@staticmethod
def _evaluate_response_scores(agent_module, evaluation_response, criteria):
"""Evaluates response scores and raises an assertion error if they don't meet the criteria."""
metrics = ResponseEvaluator.evaluate(
evaluation_response, criteria, print_detailed_results=True
)
AgentEvaluator._assert_score(
metrics,
"coherence/mean",
criteria.get(RESPONSE_EVALUATION_SCORE_KEY),
"Average response evaluation score",
agent_module,
)
AgentEvaluator._assert_score(
metrics,
"rouge_1/mean",
criteria.get(RESPONSE_MATCH_SCORE_KEY),
"Average response match score",
agent_module,
)
@staticmethod
def _evaluate_tool_trajectory(agent_module, evaluation_response, criteria):
"""Evaluates tool trajectory scores and raises an assertion error if they don't meet the criteria."""
score = TrajectoryEvaluator.evaluate(
evaluation_response, print_detailed_results=True
)
AgentEvaluator._assert_score(
{TOOL_TRAJECTORY_SCORE_KEY: score},
TOOL_TRAJECTORY_SCORE_KEY,
criteria[TOOL_TRAJECTORY_SCORE_KEY],
"Average tool trajectory evaluation score",
agent_module,
)
@staticmethod
def _assert_score(metrics, metric_key, threshold, description, agent_module):
"""Asserts that a metric meets the specified threshold."""
if metric_key in metrics:
actual_score = metrics[metric_key]
assert actual_score >= threshold, (
f"{description} for {agent_module} is lower than expected. "
f"Expected >= {threshold}, but got {actual_score}."
)

View File

@@ -0,0 +1,24 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class EvalConstants:
"""Holds constants for evaluation file constants."""
QUERY = "query"
EXPECTED_TOOL_USE = "expected_tool_use"
RESPONSE = "response"
REFERENCE = "reference"
TOOL_NAME = "tool_name"
TOOL_INPUT = "tool_input"
MOCK_TOOL_OUTPUT = "mock_tool_output"

View File

@@ -0,0 +1,270 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import uuid
from google.genai import types
from ..agents.base_agent import BaseAgent
from ..agents.llm_agent import Agent
from ..agents.llm_agent import BeforeToolCallback
from ..agents.llm_agent import LlmAgent
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..runners import Runner
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from .evaluation_constants import EvalConstants
class EvaluationGenerator:
"""Generates evaluation responses for agents."""
@staticmethod
def generate_responses(
eval_dataset,
agent_module_path,
repeat_num=3,
agent_name=None,
initial_session={},
):
"""Returns evaluation responses for the given dataset and agent.
Args:
eval_dataset: The dataset that needs to be scraped for resposnes.
agent_module_path: Path to the module that contains the root agent.
repeat_num: Number of time the eval dataset should be repeated. This is
usually done to remove uncertainity that a single run may bring.
agent_name: The name of the agent that should be evaluated. This is
usually the sub-agent.
initial_session: Initial session for the eval data.
"""
results = []
for _ in range(repeat_num):
for data in eval_dataset:
results.append(
EvaluationGenerator._process_query(
data, agent_module_path, agent_name, initial_session
)
)
return results
@staticmethod
def generate_responses_from_session(session_path, eval_dataset):
"""Returns evaluation responses by combining session data with eval data.
Args:
session_path: Path to a json file that contains session data.
eval_dataset: The eval data set that should be combined with the session
data.
"""
results = []
with open(session_path, "r") as f:
session_data = Session.model_validate_json(f.read())
print("loaded session", session_path)
for data in eval_dataset:
# load session data from session_path
results.append(
EvaluationGenerator._process_query_with_session(
session_data,
data,
)
)
return results
@staticmethod
def _process_query(data, module_name, agent_name=None, initial_session={}):
"""Process a query using the agent and evaluation dataset."""
module_path = f"{module_name}"
agent_module = importlib.import_module(module_path)
root_agent = agent_module.agent.root_agent
reset_func = getattr(agent_module.agent, "reset_data", None)
agent_to_evaluate = root_agent
if agent_name:
agent_to_evaluate = root_agent.find_agent(agent_name)
assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."
return EvaluationGenerator._process_query_with_root_agent(
data, agent_to_evaluate, reset_func, initial_session
)
@staticmethod
def _process_query_with_root_agent(
data,
root_agent,
reset_func,
initial_session={},
session_id=None,
session_service=None,
artifact_service=None,
):
"""Process a query using the agent and evaluation dataset."""
# we don't know which tools belong to which agent
# so we just apply to any agents that has certain tool outputs
all_mock_tools = set()
for eval_entry in data:
expected_tool_use = eval_entry.get(EvalConstants.EXPECTED_TOOL_USE, [])
for expected in expected_tool_use:
if EvalConstants.MOCK_TOOL_OUTPUT in expected:
all_mock_tools.add(expected[EvalConstants.TOOL_NAME])
eval_data_copy = data.copy()
EvaluationGenerator.apply_before_tool_callback(
root_agent,
lambda *args: EvaluationGenerator.before_tool_callback(
*args, eval_dataset=eval_data_copy
),
all_mock_tools,
)
if not session_service:
session_service = InMemorySessionService()
app_name = initial_session.get("app_name", "EvaluationGenerator")
user_id = initial_session.get("user_id", "test_user_id")
session_id = session_id if session_id else str(uuid.uuid4())
_ = session_service.create_session(
app_name=app_name,
user_id=user_id,
state=initial_session.get("state", {}),
session_id=session_id,
)
if not artifact_service:
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name=app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
# Reset agent state for each query
if callable(reset_func):
reset_func()
responses = data.copy()
for index, eval_entry in enumerate(responses):
response = None
query = eval_entry["query"]
content = types.Content(role="user", parts=[types.Part(text=query)])
turn_actual_tool_uses = []
for event in runner.run(
user_id=user_id, session_id=session_id, new_message=content
):
if event.is_final_response() and event.content and event.content.parts:
response = event.content.parts[0].text
elif event.get_function_calls():
for call in event.get_function_calls():
turn_actual_tool_uses.append({
EvalConstants.TOOL_NAME: call.name,
EvalConstants.TOOL_INPUT: call.args,
})
responses[index]["actual_tool_use"] = turn_actual_tool_uses
responses[index]["response"] = response
return responses
@staticmethod
def _process_query_with_session(session_data, data):
"""Process the queries using the existing session data without invoking the runner."""
responses = data.copy()
# Iterate through the provided queries and align them with the session events
for index, eval_entry in enumerate(responses):
query = eval_entry["query"]
actual_tool_uses = []
response = None
# Search for the corresponding session events
for event in session_data.events:
# Match the query to a user event
if (
event.author == "user"
and event.content
and event.content.parts
and event.content.parts[0].text == query
):
# Look for subsequent tool usage or model responses
for subsequent_event in session_data.events:
if subsequent_event.invocation_id == event.invocation_id:
# Extract tool usage
if subsequent_event.content.parts[0].function_call:
call = subsequent_event.content.parts[0].function_call
actual_tool_uses.append(
{"tool_name": call.name, "tool_input": call.args}
)
# Extract final response
elif subsequent_event.author != "user":
response = subsequent_event.content.parts[0].text
# Update the results for the current query
responses[index]["actual_tool_use"] = actual_tool_uses
responses[index]["response"] = response
return responses
@staticmethod
def before_tool_callback(tool, args, tool_context, eval_dataset):
"""Intercept specific tool calls and return predefined outputs
from eval_dataset.
"""
for index, eval_entry in enumerate(eval_dataset):
expected_tool_use = eval_entry.get("expected_tool_use", [])
for expected in expected_tool_use:
if (
EvalConstants.MOCK_TOOL_OUTPUT in expected
and tool.name == expected[EvalConstants.TOOL_NAME]
and args == expected.get(EvalConstants.TOOL_INPUT, {})
):
# pop the matched entry so we don't rematch again
eval_dataset.pop(index)
return {"result": expected[EvalConstants.MOCK_TOOL_OUTPUT]}
return None
@staticmethod
def apply_before_tool_callback(
agent: BaseAgent,
callback: BeforeToolCallback,
all_mock_tools: set[str],
):
"""Recursively apply the before_tool_callback to the root agent and all its subagents."""
# check if the agent has tools that defined by evalset
# We use function name to check if tools match
if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent):
return
for tool in agent.canonical_tools:
tool_name = tool.name
if tool_name in all_mock_tools:
agent.before_tool_callback = callback
# Apply recursively to subagents if they exist
for sub_agent in agent.sub_agents:
EvaluationGenerator.apply_before_tool_callback(
sub_agent, callback, all_mock_tools
)

View File

@@ -0,0 +1,135 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import pandas as pd
from tabulate import tabulate
from vertexai.preview.evaluation import EvalTask
from vertexai.preview.evaluation import MetricPromptTemplateExamples
class ResponseEvaluator:
"""Runs response evaluation for agents."""
@staticmethod
def evaluate(
raw_eval_dataset: list[list[dict[str, Any]]],
evaluation_criteria: list[str],
*,
print_detailed_results: bool = False
):
r"""Returns the value of requested evaluation metrics.
Args:
raw_eval_dataset: The dataset that will be evaluated.
evaluation_criteria: The evaluation criteria to be used. This method
support two criterias, `response_evaluation_score` and
`response_match_score`.
print_detailed_results: Prints detailed results on the console. This is
usually helpful during debugging.
A note on evaluation_criteria:
`response_match_score`: This metric compares the agents final natural
language reponse with the expected final response, stored in the
"reference" field in test/eval files. We use Rouge metric to compare the
two responses.
Value Range: [0, 1]. A score closer to 0 means poor similarity between
response and reference. A score closer to 1 means strong similarity
between response and reference.
`response_evaluation_score`: Uses LLM to evalaute coherence of the
response, including tool use. This is pointwise metric.
Value range: [0, 5], where 0 means that the agent's response is not
coherent, while 5 means it is . High values are good.
A note on raw_eval_dataset:
The dataset should be a list session, where each sesssion is represented
as a list of interaction that need evaluation. Each evaluation is
represented as a dictionary that is expected to have values for the
following keys:
1) query
2) response
3) acutal_tool_use
4) expected_tool_use
5) reference
Here is a sample eval_dataset value with one entry:
[
[
{
"query": "roll a die for me",
"response": "I rolled a 16 sided die and got 13.\n",
"expected_tool_use": [
{
"tool_name": "roll_die",
"tool_input": {
"sides": 16
}
}
],
"acutal_tool_use": [
{
"tool_name": "roll_die",
"tool_input": {
"sides": 16
}
}
],
"reference": "I rolled a 16 sided die and got 13.\n"
}
]
]
"""
if not raw_eval_dataset:
raise ValueError("The evaluation dataset is empty.")
metrics = ResponseEvaluator._get_metrics(
raw_eval_dataset, evaluation_criteria
)
flattened_queries = [
item for sublist in raw_eval_dataset for item in sublist
]
eval_dataset = pd.DataFrame(flattened_queries).rename(
columns={"query": "prompt", "expected_tool_use": "reference_trajectory"}
)
eval_task = EvalTask(dataset=eval_dataset, metrics=metrics)
eval_result = eval_task.evaluate()
if print_detailed_results:
ResponseEvaluator._print_results(eval_result)
return eval_result.summary_metrics
@staticmethod
def _get_metrics(raw_eval_dataset, criteria):
metrics = []
if (
"response_evaluation_score" in criteria
and "query" in raw_eval_dataset[0][0]
and "expected_tool_use" in raw_eval_dataset[0][0]
):
metrics.append(MetricPromptTemplateExamples.Pointwise.COHERENCE)
if (
"response_match_score" in criteria
and "reference" in raw_eval_dataset[0][0]
):
metrics.append("rouge_1")
return metrics
@staticmethod
def _print_results(eval_result):
print("Evaluation Summary Metrics:", eval_result.summary_metrics)
print(tabulate(eval_result.metrics_table, headers="keys", tablefmt="grid"))

View File

@@ -0,0 +1,184 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import pandas as pd
from tabulate import tabulate
from .evaluation_constants import EvalConstants
class TrajectoryEvaluator:
"""Evaluates tool use trajectories for accuracy."""
@staticmethod
def evaluate(
eval_dataset: list[list[dict[str, Any]]],
*,
print_detailed_results: bool = False,
):
r"""Returns the mean tool use accuracy of the eval dataset.
Tool use accuracy is calculated by comparing the expected and actuall tool
use trajectories. An exact match scores a 1, 0 otherwise. The final number
is an
average of these individual scores.
Value range: [0, 1], where 0 is means none of the too use entries aligned,
and 1 would mean all of them aligned. Higher value is good.
Args:
eval_dataset: The dataset that will be evaluated.
print_detailed_results: Prints detailed results on the console. This is
usually helpful during debugging.
A note on eval_dataset:
The dataset should be a list session, where each sesssion is represented
as a list of interaction that need evaluation. Each evaluation is
represented as a dictionary that is expected to have values for the
following keys:
1) query
2) response
3) acutal_tool_use
4) expected_tool_use
Here is a sample eval_dataset value with one entry:
[
[
{
"query": "Roll a 16 sided dice for me",
"response": "I rolled a 16 sided die and got 13.\n",
"expected_tool_use": [
{
"tool_name": "roll_die",
"tool_input": {
"sides": 16
}
}
],
"acutal_tool_use": [
{
"tool_name": "roll_die",
"tool_input": {
"sides": 16
}
}
]
}
]
]
"""
if not eval_dataset:
raise ValueError("The evaluation dataset is empty.")
results_df = pd.DataFrame(
columns=[
"query",
"response",
"actual_tool_use",
"expected_tool_use",
"tool_use_accuracy",
]
)
failures = []
for conversation in eval_dataset:
for index, row in enumerate(conversation):
new_row, failure = TrajectoryEvaluator._evaluate_row(row)
results_df = pd.concat(
[results_df, pd.DataFrame([new_row])], ignore_index=True
)
if failure:
failure["turn"] = index + 1
failures.append(failure)
TrajectoryEvaluator._report_failures(failures)
if print_detailed_results:
TrajectoryEvaluator._print_results(results_df)
return results_df["tool_use_accuracy"].mean()
@staticmethod
def _evaluate_row(row):
# We don't evaluate the mock tool outputs.
expected = TrajectoryEvaluator._remove_tool_outputs(
row["expected_tool_use"]
)
actual = row["actual_tool_use"]
tool_use_accuracy = (
1.0 if TrajectoryEvaluator.are_tools_equal(actual, expected) else 0.0
)
new_row = {
"query": row["query"],
"response": row["response"],
"actual_tool_use": actual,
"expected_tool_use": expected,
"tool_use_accuracy": tool_use_accuracy,
}
failure = (
None
if tool_use_accuracy == 1.0
else {"query": row["query"], "actual": actual, "expected": expected}
)
return new_row, failure
@staticmethod
def are_tools_equal(list_a_original, list_b_original):
# Remove other entries that we don't want to evaluate
list_a = [
{"tool_name": tool["tool_name"], "tool_input": tool["tool_input"]}
for tool in list_a_original
]
list_b = [
{"tool_name": tool["tool_name"], "tool_input": tool["tool_input"]}
for tool in list_b_original
]
return list_a == list_b
@staticmethod
def _remove_tool_outputs(tool_use_list):
"""Removes 'mock_tool_output' from each dictionary in the list."""
result = []
for tool_use in tool_use_list:
new_tool_use = (
tool_use.copy()
) # Create a copy to avoid modifying the original
new_tool_use.pop(
EvalConstants.MOCK_TOOL_OUTPUT, None
) # Remove 'tool_output' if it exists
result.append(new_tool_use)
return result
@staticmethod
def _report_failures(failures):
if failures:
print("Failures:")
for failure in failures:
print(f"""{{
"turn": {failure["turn"]},
"query": '{failure["query"]}',
"actual": {failure["actual"]},
"expected_tool_use": {failure["expected"]},
}}
""")
@staticmethod
def _print_results(results_df):
print(tabulate(results_df, headers="keys", tablefmt="grid"))