Mirror of https://github.com/EvolutionAPI/adk-python.git (synced 2025-12-24 06:07:44 -06:00)
Update Eval Run and TrajectoryEvaluator to use the new schema.
PiperOrigin-RevId: 758927160
Committed by: Copybara-Service
Parent: 2cb74dd20e
Commit: ee674ce0ef
@@ -258,13 +258,6 @@ class AgentEvaluator:
          initial_session=initial_session,
      )

  @staticmethod
  def _generate_responses_from_session(eval_dataset, session_path):
    """Generates evaluation responses by running the agent module multiple times."""
    return EvaluationGenerator.generate_responses_from_session(
        session_path, eval_dataset
    )

  @staticmethod
  def _response_evaluation_required(criteria, eval_dataset):
    """Checks if response evaluation are needed."""
@@ -23,10 +23,10 @@ from pydantic import Field
class IntermediateData(BaseModel):
  """Container for intermediate data that an agent would generate as it responds with a final answer."""

  tool_uses: list[genai_types.FunctionCall]
  tool_uses: list[genai_types.FunctionCall] = []
  """Tool use trajectory in chronological order."""

  intermediate_responses: list[Tuple[str, list[genai_types.Part]]]
  intermediate_responses: list[Tuple[str, list[genai_types.Part]]] = []
  """Intermediate responses generated by sub-agents to convey progress or status
  in a multi-agent system, distinct from the final response.
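The hunk above gives both IntermediateData fields empty-list defaults, so the model can now be constructed without arguments. A minimal sketch of what that means for callers, assuming the package layout shown in this commit (google.adk.evaluation.eval_case) and the google-genai types module:

from google.genai import types as genai_types

from google.adk.evaluation.eval_case import IntermediateData

# With the new `= []` defaults an empty container validates; previously both
# fields had to be supplied explicitly.
empty = IntermediateData()
assert empty.tool_uses == [] and empty.intermediate_responses == []

# Tool-use trajectory recorded in chronological order.
with_tools = IntermediateData(
    tool_uses=[genai_types.FunctionCall(name="get_weather", args={"city": "Kirkland"})]
)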
@@ -13,19 +13,19 @@
# limitations under the License.

import importlib
from typing import Any, Optional
import uuid

from google.genai import types

from ..agents.base_agent import BaseAgent
from ..agents.llm_agent import Agent
from ..agents.llm_agent import BeforeToolCallback
from ..agents.llm_agent import LlmAgent
from ..artifacts.base_artifact_service import BaseArtifactService
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..runners import Runner
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from .evaluation_constants import EvalConstants
from .eval_case import IntermediateData
from .eval_case import Invocation
from .eval_case import SessionInput


class EvaluationGenerator:
@@ -102,56 +102,40 @@ class EvaluationGenerator:
      agent_to_evaluate = root_agent.find_agent(agent_name)
      assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."

    return EvaluationGenerator._process_query_with_root_agent(
    return EvaluationGenerator._generate_inferences_from_root_agent(
        data, agent_to_evaluate, reset_func, initial_session
    )

  @staticmethod
  async def _process_query_with_root_agent(
      data,
      root_agent,
      reset_func,
      initial_session={},
      session_id=None,
      session_service=None,
      artifact_service=None,
  ):
    """Process a query using the agent and evaluation dataset."""

    # we don't know which tools belong to which agent
    # so we just apply to any agents that has certain tool outputs
    all_mock_tools = set()
    for eval_entry in data:
      expected_tool_use = eval_entry.get(EvalConstants.EXPECTED_TOOL_USE, [])
      for expected in expected_tool_use:
        if EvalConstants.MOCK_TOOL_OUTPUT in expected:
          all_mock_tools.add(expected[EvalConstants.TOOL_NAME])

    eval_data_copy = data.copy()
    await EvaluationGenerator.apply_before_tool_callback(
        root_agent,
        lambda *args: EvaluationGenerator.before_tool_callback(
            *args, eval_dataset=eval_data_copy
        ),
        all_mock_tools,
    )

  async def _generate_inferences_from_root_agent(
      invocations: list[Invocation],
      root_agent: Agent,
      reset_func: Any,
      initial_session: Optional[SessionInput] = None,
      session_id: Optional[str] = None,
      session_service: Optional[BaseSessionService] = None,
      artifact_service: Optional[BaseArtifactService] = None,
  ) -> list[Invocation]:
    """Scrapes the root agent given the list of Invocations."""
    if not session_service:
      session_service = InMemorySessionService()

    app_name = initial_session.get("app_name", "EvaluationGenerator")
    user_id = initial_session.get("user_id", "test_user_id")
    app_name = (
        initial_session.app_name if initial_session else "EvaluationGenerator"
    )
    user_id = initial_session.user_id if initial_session else "test_user_id"
    session_id = session_id if session_id else str(uuid.uuid4())

    _ = session_service.create_session(
        app_name=app_name,
        user_id=user_id,
        state=initial_session.get("state", {}),
        state=initial_session.state if initial_session else {},
        session_id=session_id,
    )

    if not artifact_service:
      artifact_service = InMemoryArtifactService()

    runner = Runner(
        app_name=app_name,
        agent=root_agent,
@@ -163,30 +147,37 @@ class EvaluationGenerator:
    if callable(reset_func):
      reset_func()

    responses = data.copy()
    response_invocations = []

    for index, eval_entry in enumerate(responses):
      response = None
      query = eval_entry["query"]
      content = types.Content(role="user", parts=[types.Part(text=query)])
      turn_actual_tool_uses = []
    for invocation in invocations:
      final_response = None
      user_content = invocation.user_content
      tool_uses = []
      invocation_id = ""

      for event in runner.run(
          user_id=user_id, session_id=session_id, new_message=content
          user_id=user_id, session_id=session_id, new_message=user_content
      ):
        invocation_id = (
            event.invocation_id if not invocation_id else invocation_id
        )

        if event.is_final_response() and event.content and event.content.parts:
          response = event.content.parts[0].text
          final_response = event.content
        elif event.get_function_calls():
          for call in event.get_function_calls():
            turn_actual_tool_uses.append({
                EvalConstants.TOOL_NAME: call.name,
                EvalConstants.TOOL_INPUT: call.args,
            })
            tool_uses.append(call)

      responses[index]["actual_tool_use"] = turn_actual_tool_uses
      responses[index]["response"] = response
      response_invocations.append(
          Invocation(
              invocation_id=invocation_id,
              user_content=user_content,
              final_response=final_response,
              intermediate_data=IntermediateData(tool_uses=tool_uses),
          )
      )

    return responses
    return response_invocations

  @staticmethod
  def _process_query_with_session(session_data, data):
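The rewritten loop above consumes Invocation objects (via user_content) instead of raw query dicts and appends a new Invocation per turn. A sketch, mirroring the constructor call in the diff, of the object each iteration produces; the import paths and the example field values are assumptions, not taken from this commit:

from google.genai import types as genai_types

from google.adk.evaluation.eval_case import IntermediateData, Invocation

# Shape of the Invocation appended to response_invocations for one turn.
response_invocation = Invocation(
    invocation_id="inv-123",  # copied from the first event emitted by the runner
    user_content=genai_types.Content(
        role="user", parts=[genai_types.Part(text="What is the weather in Kirkland?")]
    ),
    final_response=genai_types.Content(
        role="model", parts=[genai_types.Part(text="It is sunny in Kirkland.")]
    ),
    intermediate_data=IntermediateData(
        tool_uses=[genai_types.FunctionCall(name="get_weather", args={"city": "Kirkland"})]
    ),
)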
@@ -225,46 +216,3 @@ class EvaluationGenerator:
      responses[index]["actual_tool_use"] = actual_tool_uses
      responses[index]["response"] = response
    return responses

  @staticmethod
  def before_tool_callback(tool, args, tool_context, eval_dataset):
    """Intercept specific tool calls and return predefined outputs

    from eval_dataset.
    """
    for index, eval_entry in enumerate(eval_dataset):
      expected_tool_use = eval_entry.get("expected_tool_use", [])
      for expected in expected_tool_use:
        if (
            EvalConstants.MOCK_TOOL_OUTPUT in expected
            and tool.name == expected[EvalConstants.TOOL_NAME]
            and args == expected.get(EvalConstants.TOOL_INPUT, {})
        ):
          # pop the matched entry so we don't rematch again
          eval_dataset.pop(index)
          return {"result": expected[EvalConstants.MOCK_TOOL_OUTPUT]}

    return None

  @staticmethod
  async def apply_before_tool_callback(
      agent: BaseAgent,
      callback: BeforeToolCallback,
      all_mock_tools: set[str],
  ):
    """Recursively apply the before_tool_callback to the root agent and all its subagents."""
    # Check if the agent has tools that are defined by evalset.
    # We use function names to check if tools match
    if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent):
      return

    for tool in await agent.canonical_tools():
      tool_name = tool.name
      if tool_name in all_mock_tools:
        agent.before_tool_callback = callback

    # Apply recursively to subagents if they exist
    for sub_agent in agent.sub_agents:
      await EvaluationGenerator.apply_before_tool_callback(
          sub_agent, callback, all_mock_tools
      )
src/google/adk/evaluation/evaluator.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from .eval_case import Invocation


class EvalStatus(Enum):
  PASSED = 1
  FAILED = 2
  NOT_EVALUATED = 3


class PerInvocationResult(BaseModel):
  """Metric evaluation score per invocation."""

  actual_invocation: Invocation
  expected_invocation: Invocation
  score: Optional[float] = None
  eval_status: EvalStatus = EvalStatus.NOT_EVALUATED


class EvaluationResult(BaseModel):
  overall_score: Optional[float] = None
  """Overall score, based on each invocation."""

  overall_eval_status: EvalStatus = EvalStatus.NOT_EVALUATED
  """Overall status, based on each invocation."""

  per_invocation_results: list[PerInvocationResult] = []


class Evaluator(ABC):
  """A merics evaluator interface."""

  def evaluate_invocations(
      self,
      actual_invocations: list[Invocation],
      expected_invocations: list[Invocation],
  ) -> EvaluationResult:
    """Returns EvaluationResult after performing evaluations using actual and expected invocations."""
    raise NotImplementedError()
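For orientation, a minimal concrete implementation against the interface added above; it is illustrative only, and the absolute import paths are assumed from this commit's file layout:

from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.evaluator import (
    EvalStatus,
    EvaluationResult,
    Evaluator,
    PerInvocationResult,
)


class AlwaysPassEvaluator(Evaluator):
  """Toy metric that marks every invocation as passed."""

  def evaluate_invocations(
      self,
      actual_invocations: list[Invocation],
      expected_invocations: list[Invocation],
  ) -> EvaluationResult:
    # One PerInvocationResult per actual/expected pair, all passing.
    results = [
        PerInvocationResult(
            actual_invocation=actual,
            expected_invocation=expected,
            score=1.0,
            eval_status=EvalStatus.PASSED,
        )
        for actual, expected in zip(actual_invocations, expected_invocations)
    ]
    if not results:
      return EvaluationResult()  # defaults: no score, NOT_EVALUATED
    return EvaluationResult(
        overall_score=1.0,
        overall_eval_status=EvalStatus.PASSED,
        per_invocation_results=results,
    )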
@@ -154,6 +154,22 @@ def convert_eval_set_to_pydanctic_schema(
  )


def load_eval_set_from_file(
    eval_set_file_path: str, eval_set_id: str
) -> EvalSet:
  """Returns an EvalSet that is read from the given file."""
  with open(eval_set_file_path, "r", encoding="utf-8") as f:
    content = f.read()
    try:
      return EvalSet.model_validate_json(content)
    except ValidationError:
      # We assume that the eval data was specified in the old format and try
      # to convert it to the new format.
      return convert_eval_set_to_pydanctic_schema(
          eval_set_id, json.loads(content)
      )


class LocalEvalSetsManager(EvalSetsManager):
  """An EvalSets manager that stores eval sets locally on disk."""
@@ -165,16 +181,7 @@ class LocalEvalSetsManager(EvalSetsManager):
    """Returns an EvalSet identified by an app_name and eval_set_id."""
    # Load the eval set file data
    eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id)
    with open(eval_set_file_path, "r", encoding="utf-8") as f:
      content = f.read()
      try:
        return EvalSet.model_validate_json(content)
      except ValidationError:
        # We assume that the eval data was specified in the old format and try
        # to convert it to the new format.
        return convert_eval_set_to_pydanctic_schema(
            eval_set_id, json.loads(content)
        )
    return load_eval_set_from_file(eval_set_file_path, eval_set_id)

  @override
  def create_eval_set(self, app_name: str, eval_set_id: str):
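A small usage sketch of the new helper; the module path google.adk.evaluation.local_eval_sets_manager and the example file path are assumptions based on the class shown in this hunk:

from google.adk.evaluation.local_eval_sets_manager import load_eval_set_from_file

# Parses the file as the new pydantic EvalSet schema first; on a ValidationError it
# falls back to converting eval data stored in the old JSON format.
eval_set = load_eval_set_from_file("/tmp/my_app.evalset.json", "my_eval_set_id")
print(eval_set.model_dump_json(indent=2))  # EvalSet is a pydantic model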
@@ -12,18 +12,98 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from typing import Any, cast

from deprecated import deprecated
from google.genai import types as genai_types
import pandas as pd
from tabulate import tabulate
from typing_extensions import override

from .eval_case import Invocation
from .evaluation_constants import EvalConstants
from .evaluator import EvalStatus
from .evaluator import EvaluationResult
from .evaluator import Evaluator
from .evaluator import PerInvocationResult


class TrajectoryEvaluator:
class TrajectoryEvaluator(Evaluator):
  """Evaluates tool use trajectories for accuracy."""

  def __init__(self, threshold: float):
    self._threshold = threshold

  @override
  def evaluate_invocations(
      self,
      actual_invocations: list[Invocation],
      expected_invocations: list[Invocation],
  ) -> EvaluationResult:
    """Returns EvaluationResult after performing evaluations using actual and expected invocations."""
    total_tool_use_accuracy = 0.0
    num_invocations = 0
    per_invocation_results = []

    for actual, expected in zip(actual_invocations, expected_invocations):
      actual_tool_uses = (
          actual.intermediate_data.tool_uses if actual.intermediate_data else []
      )
      expected_tool_uses = (
          expected.intermediate_data.tool_uses
          if expected.intermediate_data
          else []
      )
      tool_use_accuracy = (
          1.0
          if self._are_tool_calls_equal(actual_tool_uses, expected_tool_uses)
          else 0.0
      )
      per_invocation_results.append(
          PerInvocationResult(
              actual_invocation=actual,
              expected_invocation=expected,
              score=tool_use_accuracy,
              eval_status=self._get_eval_status(tool_use_accuracy),
          )
      )
      total_tool_use_accuracy += tool_use_accuracy
      num_invocations += 1

    if per_invocation_results:
      overall_score = total_tool_use_accuracy / num_invocations
      return EvaluationResult(
          overall_score=overall_score,
          overall_eval_status=self._get_eval_status(overall_score),
          per_invocation_results=per_invocation_results,
      )

    return EvaluationResult()

  def _are_tool_calls_equal(
      self,
      actual_tool_calls: list[genai_types.FunctionCall],
      expected_tool_calls: list[genai_types.FunctionCall],
  ) -> bool:
    if len(actual_tool_calls) != len(expected_tool_calls):
      return False

    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
      if actual.name != expected.name or actual.args != expected.args:
        return False

    return True

  def _get_eval_status(self, score: float):
    return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED

  @staticmethod
  @deprecated(
      reason=(
          "This method has been deprecated and will be removed soon. Please use"
          " evaluate_invocations instead."
      )
  )
  def evaluate(
      eval_dataset: list[list[dict[str, Any]]],
      *,
@@ -137,6 +217,7 @@ class TrajectoryEvaluator:
    return new_row, failure

  @staticmethod
  @deprecated()
  def are_tools_equal(list_a_original, list_b_original):
    # Remove other entries that we don't want to evaluate
    list_a = [
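To close the loop, a sketch of exercising the new evaluate_invocations API end to end. Import paths and the Invocation constructor arguments (taken from the evaluation_generator hunk earlier in this commit) are assumptions rather than guarantees:

from google.genai import types as genai_types

from google.adk.evaluation.eval_case import IntermediateData, Invocation
from google.adk.evaluation.evaluator import EvalStatus
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator


def _invocation(text: str, calls: list[genai_types.FunctionCall]) -> Invocation:
  # Helper for this sketch only.
  return Invocation(
      invocation_id="inv-1",
      user_content=genai_types.Content(role="user", parts=[genai_types.Part(text=text)]),
      final_response=genai_types.Content(role="model", parts=[genai_types.Part(text="done")]),
      intermediate_data=IntermediateData(tool_uses=calls),
  )


call = genai_types.FunctionCall(name="get_weather", args={"city": "Kirkland"})
actual = _invocation("weather in Kirkland?", [call])
expected = _invocation("weather in Kirkland?", [call])

# Identical tool-call trajectories score 1.0, which meets the threshold, so the
# overall status comes back PASSED.
evaluator = TrajectoryEvaluator(threshold=1.0)
result = evaluator.evaluate_invocations([actual], [expected])
assert result.overall_score == 1.0
assert result.overall_eval_status == EvalStatus.PASSED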