No public description

PiperOrigin-RevId: 748777998
This commit is contained in:
Google ADK Member 2025-04-17 19:50:22 +00:00 committed by hangfei
parent 290058eb05
commit 61d4be2d76
99 changed files with 2120 additions and 256 deletions

View File

@ -5,9 +5,9 @@
[![r/agentdevelopmentkit](https://img.shields.io/badge/Reddit-r%2Fagentdevelopmentkit-FF4500?style=flat&logo=reddit&logoColor=white)](https://www.reddit.com/r/agentdevelopmentkit/)
<html>
<h1 align="center">
<img src="assets/agent-development-kit.png" width="256"/>
</h1>
<h2 align="center">
<img src="https://raw.githubusercontent.com/google/adk-python/main/assets/agent-development-kit.png" width="256"/>
</h2>
<h3 align="center">
An open-source, code-first Python toolkit for building, evaluating, and deploying sophisticated AI agents with flexibility and control.
</h3>
@ -50,6 +50,7 @@ You can install the ADK using `pip`:
```bash
pip install google-adk
```
## 📚 Documentation
Explore the full documentation for detailed guides on building, evaluating, and
@ -60,6 +61,7 @@ deploying agents:
## 🏁 Feature Highlight
### Define a single agent:
```python
from google.adk.agents import Agent
from google.adk.tools import google_search
@ -74,7 +76,9 @@ root_agent = Agent(
```
### Define a multi-agent system:
Define a multi-agent system with a coordinator agent, a greeter agent, and a task-execution agent. The ADK engine and the model then guide the agents to work together to accomplish the task.
```python
from google.adk.agents import LlmAgent, BaseAgent
@ -92,14 +96,13 @@ coordinator = LlmAgent(
task_executor
]
)
```
### Development UI
A built-in development UI to help you test, evaluate, debug, and showcase your agent(s).
<img src="assets/adk-web-dev-ui-function-call.png"/>
<img src="https://raw.githubusercontent.com/google/adk-python/main/assets/adk-web-dev-ui-function-call.png"/>
### Evaluate Agents

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -82,6 +82,8 @@ async def run_interactively(
)
while True:
query = input('user: ')
if not query or not query.strip():
continue
if query == 'exit':
break
async for event in runner.run_async(

View File

@ -0,0 +1,279 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from typing import Optional
from typing import Tuple
import click
_INIT_PY_TEMPLATE = """\
from . import agent
"""
_AGENT_PY_TEMPLATE = """\
from google.adk.agents import Agent
root_agent = Agent(
model='{model_name}',
name='root_agent',
description='A helpful assistant for user questions.',
instruction='Answer user questions to the best of your knowledge',
)
"""
_GOOGLE_API_MSG = """
Don't have API Key? Create one in AI Studio: https://aistudio.google.com/apikey
"""
_GOOGLE_CLOUD_SETUP_MSG = """
You need an existing Google Cloud account and project. Check out this link for details:
https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai
"""
_OTHER_MODEL_MSG = """
Please see the guide below to configure other models:
https://google.github.io/adk-docs/agents/models
"""
_SUCCESS_MSG = """
Agent created in {agent_folder}:
- .env
- __init__.py
- agent.py
"""
def _get_gcp_project_from_gcloud() -> str:
"""Uses gcloud to get default project."""
try:
result = subprocess.run(
["gcloud", "config", "get-value", "project"],
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip()
except (subprocess.CalledProcessError, FileNotFoundError):
return ""
def _get_gcp_region_from_gcloud() -> str:
"""Uses gcloud to get default region."""
try:
result = subprocess.run(
["gcloud", "config", "get-value", "compute/region"],
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip()
except (subprocess.CalledProcessError, FileNotFoundError):
return ""
def _prompt_str(
prompt_prefix: str,
*,
prior_msg: Optional[str] = None,
default_value: Optional[str] = None,
) -> str:
if prior_msg:
click.secho(prior_msg, fg="green")
while True:
value: str = click.prompt(
prompt_prefix, default=default_value or None, type=str
)
if value and value.strip():
return value.strip()
def _prompt_for_google_cloud(
google_cloud_project: Optional[str],
) -> str:
"""Prompts user for Google Cloud project ID."""
google_cloud_project = (
google_cloud_project
or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
or _get_gcp_project_from_gcloud()
)
google_cloud_project = _prompt_str(
"Enter Google Cloud project ID", default_value=google_cloud_project
)
return google_cloud_project
def _prompt_for_google_cloud_region(
google_cloud_region: Optional[str],
) -> str:
"""Prompts user for Google Cloud region."""
google_cloud_region = (
google_cloud_region
or os.environ.get("GOOGLE_CLOUD_LOCATION", None)
or _get_gcp_region_from_gcloud()
)
google_cloud_region = _prompt_str(
"Enter Google Cloud region",
default_value=google_cloud_region or "us-central1",
)
return google_cloud_region
def _prompt_for_google_api_key(
google_api_key: Optional[str],
) -> str:
"""Prompts user for Google API key."""
google_api_key = google_api_key or os.environ.get("GOOGLE_API_KEY", None)
google_api_key = _prompt_str(
"Enter Google API key",
prior_msg=_GOOGLE_API_MSG,
default_value=google_api_key,
)
return google_api_key
def _generate_files(
agent_folder: str,
*,
google_api_key: Optional[str] = None,
google_cloud_project: Optional[str] = None,
google_cloud_region: Optional[str] = None,
model: Optional[str] = None,
):
"""Generates a folder name for the agent."""
os.makedirs(agent_folder, exist_ok=True)
dotenv_file_path = os.path.join(agent_folder, ".env")
init_file_path = os.path.join(agent_folder, "__init__.py")
agent_file_path = os.path.join(agent_folder, "agent.py")
with open(dotenv_file_path, "w", encoding="utf-8") as f:
lines = []
if google_api_key:
lines.append("GOOGLE_GENAI_USE_VERTEXAI=0")
elif google_cloud_project and google_cloud_region:
lines.append("GOOGLE_GENAI_USE_VERTEXAI=1")
if google_api_key:
lines.append(f"GOOGLE_API_KEY={google_api_key}")
if google_cloud_project:
lines.append(f"GOOGLE_CLOUD_PROJECT={google_cloud_project}")
if google_cloud_region:
lines.append(f"GOOGLE_CLOUD_LOCATION={google_cloud_region}")
f.write("\n".join(lines))
with open(init_file_path, "w", encoding="utf-8") as f:
f.write(_INIT_PY_TEMPLATE)
with open(agent_file_path, "w", encoding="utf-8") as f:
f.write(_AGENT_PY_TEMPLATE.format(model_name=model))
click.secho(
_SUCCESS_MSG.format(agent_folder=agent_folder),
fg="green",
)
def _prompt_for_model() -> str:
model_choice = click.prompt(
"""\
Choose a model for the root agent:
1. gemini-2.0-flash-001
2. Other models (fill later)
Choose model""",
type=click.Choice(["1", "2"]),
)
if model_choice == "1":
return "gemini-2.0-flash-001"
else:
click.secho(_OTHER_MODEL_MSG, fg="green")
return "<FILL_IN_MODEL>"
def _prompt_to_choose_backend(
google_api_key: Optional[str],
google_cloud_project: Optional[str],
google_cloud_region: Optional[str],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""Prompts user to choose backend.
Returns:
A tuple of (google_api_key, google_cloud_project, google_cloud_region).
"""
backend_choice = click.prompt(
"1. Google AI\n2. Vertex AI\nChoose a backend",
type=click.Choice(["1", "2"]),
)
if backend_choice == "1":
google_api_key = _prompt_for_google_api_key(google_api_key)
elif backend_choice == "2":
click.secho(_GOOGLE_CLOUD_SETUP_MSG, fg="green")
google_cloud_project = _prompt_for_google_cloud(google_cloud_project)
google_cloud_region = _prompt_for_google_cloud_region(google_cloud_region)
return google_api_key, google_cloud_project, google_cloud_region
def run_cmd(
agent_name: str,
*,
model: Optional[str],
google_api_key: Optional[str],
google_cloud_project: Optional[str],
google_cloud_region: Optional[str],
):
"""Runs `adk create` command to create agent template.
Args:
agent_name: str, The name of the agent.
google_api_key: Optional[str], The Google API key for using Google AI as
backend.
google_cloud_project: Optional[str], The Google Cloud project for using
VertexAI as backend.
google_cloud_region: Optional[str], The Google Cloud region for using
VertexAI as backend.
"""
agent_folder = os.path.join(os.getcwd(), agent_name)
# Check that the folder doesn't exist or is empty; otherwise, ask before overriding.
if os.path.exists(agent_folder) and os.listdir(agent_folder):
# Prompt user whether to override existing files using click
if not click.confirm(
f"Non-empty folder already exist: '{agent_folder}'\n"
"Override existing content?",
default=False,
):
raise click.Abort()
if not model:
model = _prompt_for_model()
if not google_api_key and not (google_cloud_project and google_cloud_region):
if model.startswith("gemini"):
google_api_key, google_cloud_project, google_cloud_region = (
_prompt_to_choose_backend(
google_api_key, google_cloud_project, google_cloud_region
)
)
_generate_files(
agent_folder,
google_api_key=google_api_key,
google_cloud_project=google_cloud_project,
google_cloud_region=google_cloud_region,
model=model,
)
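For reference, a hedged usage sketch of the `adk create` flow this module implements; the flag names mirror the `cli_create_cmd` options added in `cli_tools_click.py` later in this diff, and the project/region values are placeholders:
```bash
# Interactive: prompts for a model and a backend (Google AI or Vertex AI),
# then writes .env, __init__.py and agent.py into ./my_agent.
adk create my_agent

# Non-interactive: model and Vertex AI backend supplied up front, so no prompts.
adk create my_agent \
  --model gemini-2.0-flash-001 \
  --project my-gcp-project \
  --region us-central1
```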

View File

@ -82,8 +82,9 @@ def to_cloud_run(
app_name: str,
temp_folder: str,
port: int,
with_cloud_trace: bool,
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
):
"""Deploys an agent to Google Cloud Run.
@ -108,8 +109,9 @@ def to_cloud_run(
app_name: The name of the app, by default, it's basename of `agent_folder`.
temp_folder: The temp folder for the generated Cloud Run source files.
port: The port of the ADK api server.
with_cloud_trace: Whether to enable Cloud Trace.
trace_to_cloud: Whether to enable Cloud Trace.
with_ui: Whether to deploy with UI.
verbosity: The verbosity level of the CLI.
"""
app_name = app_name or os.path.basename(agent_folder)
@ -142,7 +144,7 @@ def to_cloud_run(
port=port,
command='web' if with_ui else 'api_server',
install_agent_deps=install_agent_deps,
trace_to_cloud_option='--trace_to_cloud' if with_cloud_trace else '',
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
)
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
os.makedirs(temp_folder, exist_ok=True)
@ -169,6 +171,8 @@ def to_cloud_run(
*region_options,
'--port',
str(port),
'--verbosity',
verbosity,
'--labels',
'created-by=adk',
],

View File

@ -24,6 +24,7 @@ import click
from fastapi import FastAPI
import uvicorn
from . import cli_create
from . import cli_deploy
from .cli import run_cli
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
@ -42,10 +43,59 @@ def main():
@main.group()
def deploy():
"""Deploy Agent."""
"""Deploys agent to hosted environments."""
pass
@main.command("create")
@click.option(
"--model",
type=str,
help="Optional. The model used for the root agent.",
)
@click.option(
"--api_key",
type=str,
help=(
"Optional. The API Key needed to access the model, e.g. Google AI API"
" Key."
),
)
@click.option(
"--project",
type=str,
help="Optional. The Google Cloud Project for using VertexAI as backend.",
)
@click.option(
"--region",
type=str,
help="Optional. The Google Cloud Region for using VertexAI as backend.",
)
@click.argument("app_name", type=str, required=True)
def cli_create_cmd(
app_name: str,
model: Optional[str],
api_key: Optional[str],
project: Optional[str],
region: Optional[str],
):
"""Creates a new app in the current folder with prepopulated agent template.
APP_NAME: required, the folder of the agent source code.
Example:
adk create path/to/my_app
"""
cli_create.run_cmd(
app_name,
model=model,
google_api_key=api_key,
google_cloud_project=project,
google_cloud_region=region,
)
@main.command("run")
@click.option(
"--save_session",
@ -62,7 +112,7 @@ def deploy():
),
)
def cli_run(agent: str, save_session: bool):
"""Run an interactive CLI for a certain agent.
"""Runs an interactive CLI for a certain agent.
AGENT: The path to the agent source code folder.
@ -150,7 +200,7 @@ def cli_eval(
EvalMetric(metric_name=metric_name, threshold=threshold)
)
print(f"Using evaluation criteria: {evaluation_criteria}")
print(f"Using evaluation creiteria: {evaluation_criteria}")
root_agent = get_root_agent(agent_module_file_path)
reset_func = try_get_reset_func(agent_module_file_path)
@ -244,7 +294,7 @@ def cli_eval(
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
default=os.getcwd(),
default=os.getcwd,
)
def cli_web(
agents_dir: str,
@ -255,7 +305,7 @@ def cli_web(
port: int = 8000,
trace_to_cloud: bool = False,
):
"""Start a FastAPI server with Web UI for agents.
"""Starts a FastAPI server with Web UI for agents.
AGENTS_DIR: The directory of agents, where each sub-directory is a single
agent, containing at least `__init__.py` and `agent.py` files.
@ -274,7 +324,7 @@ def cli_web(
@asynccontextmanager
async def _lifespan(app: FastAPI):
click.secho(
f"""\
f"""
+-----------------------------------------------------------------------------+
| ADK Web Server started |
| |
@ -285,7 +335,7 @@ def cli_web(
)
yield # Startup is done, now app is running
click.secho(
"""\
"""
+-----------------------------------------------------------------------------+
| ADK Web Server shutting down... |
+-----------------------------------------------------------------------------+
@ -378,7 +428,7 @@ def cli_api_server(
port: int = 8000,
trace_to_cloud: bool = False,
):
"""Start a FastAPI server for agents.
"""Starts a FastAPI server for agents.
AGENTS_DIR: The directory of agents, where each sub-directory is a single
agent, containing at least `__init__.py` and `agent.py` files.
@ -452,7 +502,7 @@ def cli_api_server(
help="Optional. The port of the ADK API server (default: 8000).",
)
@click.option(
"--with_cloud_trace",
"--trace_to_cloud",
type=bool,
is_flag=True,
show_default=True,
@ -483,6 +533,14 @@ def cli_api_server(
" (default: a timestamped folder in the system temp directory)."
),
)
@click.option(
"--verbosity",
type=click.Choice(
["debug", "info", "warning", "error", "critical"], case_sensitive=False
),
default="WARNING",
help="Optional. Override the default verbosity level.",
)
@click.argument(
"agent",
type=click.Path(
@ -497,8 +555,9 @@ def cli_deploy_cloud_run(
app_name: str,
temp_folder: str,
port: int,
with_cloud_trace: bool,
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
):
"""Deploys an agent to Cloud Run.
@ -517,8 +576,9 @@ def cli_deploy_cloud_run(
app_name=app_name,
temp_folder=temp_folder,
port=port,
with_cloud_trace=with_cloud_trace,
trace_to_cloud=trace_to_cloud,
with_ui=with_ui,
verbosity=verbosity,
)
except Exception as e:
click.secho(f"Deploy failed: {e}", fg="red", err=True)

View File

@ -13,7 +13,9 @@
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
import importlib
import inspect
import json
import logging
import os
@ -28,6 +30,7 @@ from typing import Literal
from typing import Optional
import click
from click import Tuple
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Query
@ -56,6 +59,7 @@ from ..agents.llm_agent import Agent
from ..agents.run_config import StreamingMode
from ..artifacts import InMemoryArtifactService
from ..events.event import Event
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..runners import Runner
from ..sessions.database_session_service import DatabaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
@ -143,11 +147,8 @@ def get_fast_api_app(
provider.add_span_processor(
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
)
envs.load_dotenv()
enable_cloud_tracing = trace_to_cloud or os.environ.get(
"ADK_TRACE_TO_CLOUD", "0"
).lower() in ["1", "true"]
if enable_cloud_tracing:
if trace_to_cloud:
envs.load_dotenv_for_agent("", agent_dir)
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
processor = export.BatchSpanProcessor(
CloudTraceSpanExporter(project_id=project_id)
@ -161,8 +162,22 @@ def get_fast_api_app(
trace.set_tracer_provider(provider)
exit_stacks = []
@asynccontextmanager
async def internal_lifespan(app: FastAPI):
if lifespan:
async with lifespan(app) as lifespan_context:
yield
if exit_stacks:
for stack in exit_stacks:
await stack.aclose()
else:
yield
# Run the FastAPI server.
app = FastAPI(lifespan=lifespan)
app = FastAPI(lifespan=internal_lifespan)
if allow_origins:
app.add_middleware(
@ -181,6 +196,7 @@ def get_fast_api_app(
# Build the Artifact service
artifact_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()
# Build the Session service
agent_engine_id = ""
@ -358,7 +374,7 @@ def get_fast_api_app(
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
response_model_exclude_none=True,
)
def add_session_to_eval_set(
async def add_session_to_eval_set(
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
):
pattern = r"^[a-zA-Z0-9_]+$"
@ -393,7 +409,9 @@ def get_fast_api_app(
test_data = evals.convert_session_to_eval_format(session)
# Populate the session with initial session state.
initial_session_state = create_empty_state(_get_root_agent(app_name))
initial_session_state = create_empty_state(
await _get_root_agent_async(app_name)
)
eval_set_data.append({
"name": req.eval_id,
@ -430,7 +448,7 @@ def get_fast_api_app(
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
response_model_exclude_none=True,
)
def run_eval(
async def run_eval(
app_name: str, eval_set_id: str, req: RunEvalRequest
) -> list[RunEvalResult]:
from .cli_eval import run_evals
@ -447,7 +465,7 @@ def get_fast_api_app(
logger.info(
"Eval ids to run list is empty. We will all evals in the eval set."
)
root_agent = _get_root_agent(app_name)
root_agent = await _get_root_agent_async(app_name)
eval_results = list(
run_evals(
eval_set_to_evals,
@ -577,7 +595,7 @@ def get_fast_api_app(
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
runner = _get_runner(req.app_name)
runner = await _get_runner_async(req.app_name)
events = [
event
async for event in runner.run_async(
@ -604,7 +622,7 @@ def get_fast_api_app(
async def event_generator():
try:
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
runner = _get_runner(req.app_name)
runner = await _get_runner_async(req.app_name)
async for event in runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
@ -630,7 +648,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
response_model_exclude_none=True,
)
def get_event_graph(
async def get_event_graph(
app_name: str, user_id: str, session_id: str, event_id: str
):
# Connect to managed session if agent_engine_id is set.
@ -647,7 +665,7 @@ def get_fast_api_app(
function_calls = event.get_function_calls()
function_responses = event.get_function_responses()
root_agent = _get_root_agent(app_name)
root_agent = await _get_root_agent_async(app_name)
dot_graph = None
if function_calls:
function_call_highlights = []
@ -704,7 +722,7 @@ def get_fast_api_app(
live_request_queue = LiveRequestQueue()
async def forward_events():
runner = _get_runner(app_name)
runner = await _get_runner_async(app_name)
async for event in runner.run_live(
session=session, live_request_queue=live_request_queue
):
@ -742,26 +760,40 @@ def get_fast_api_app(
for task in pending:
task.cancel()
def _get_root_agent(app_name: str) -> Agent:
async def _get_root_agent_async(app_name: str) -> Agent:
"""Returns the root agent for the given app."""
if app_name in root_agent_dict:
return root_agent_dict[app_name]
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
agent_module = importlib.import_module(app_name)
root_agent: Agent = agent_module.agent.root_agent
if getattr(agent_module.agent, "root_agent"):
root_agent = agent_module.agent.root_agent
else:
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
# Handle an awaitable root agent and await for the actual agent.
if inspect.isawaitable(root_agent):
try:
agent, exit_stack = await root_agent
exit_stacks.append(exit_stack)
root_agent = agent
except Exception as e:
raise RuntimeError(f"error getting root agent, {e}") from e
root_agent_dict[app_name] = root_agent
return root_agent
def _get_runner(app_name: str) -> Runner:
async def _get_runner_async(app_name: str) -> Runner:
"""Returns the runner for the given app."""
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
if app_name in runner_dict:
return runner_dict[app_name]
root_agent = _get_root_agent(app_name)
root_agent = await _get_root_agent_async(app_name)
runner = Runner(
app_name=agent_engine_id if agent_engine_id else app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
memory_service=memory_service,
)
runner_dict[app_name] = runner
return runner

View File

@ -50,8 +50,5 @@ def load_dotenv_for_agent(
agent_name,
dotenv_file_path,
)
logger.info(
'Reloaded %s file for %s at %s', filename, agent_name, dotenv_file_path
)
else:
logger.info('No %s file found for %s', filename, agent_name)

View File

@ -106,9 +106,11 @@ class ResponseEvaluator:
eval_dataset = pd.DataFrame(flattened_queries).rename(
columns={"query": "prompt", "expected_tool_use": "reference_trajectory"}
)
eval_task = EvalTask(dataset=eval_dataset, metrics=metrics)
eval_result = eval_task.evaluate()
eval_result = ResponseEvaluator._perform_eval(
dataset=eval_dataset, metrics=metrics
)
if print_detailed_results:
ResponseEvaluator._print_results(eval_result)
return eval_result.summary_metrics
@ -129,6 +131,16 @@ class ResponseEvaluator:
metrics.append("rouge_1")
return metrics
@staticmethod
def _perform_eval(dataset, metrics):
"""This method hides away the call to external service.
Primarily helps with unit testing.
"""
eval_task = EvalTask(dataset=dataset, metrics=metrics)
return eval_task.evaluate()
@staticmethod
def _print_results(eval_result):
print("Evaluation Summary Metrics:", eval_result.summary_metrics)

View File

@ -87,15 +87,21 @@ class _NlPlanningResponse(BaseLlmResponseProcessor):
return
# Postprocess the LLM response.
callback_context = CallbackContext(invocation_context)
processed_parts = planner.process_planning_response(
CallbackContext(invocation_context), llm_response.content.parts
callback_context, llm_response.content.parts
)
if processed_parts:
llm_response.content.parts = processed_parts
# Maintain async generator behavior
if False: # Ensures it behaves as a generator
yield # This is a no-op but maintains generator structure
if callback_context.state.has_delta():
state_update_event = Event(
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
branch=invocation_context.branch,
actions=callback_context._event_actions,
)
yield state_update_event
response_processor = _NlPlanningResponse()

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import copy
from datetime import datetime
import json
@ -20,17 +21,17 @@ from typing import Any
from typing import Optional
import uuid
from sqlalchemy import Boolean
from sqlalchemy import delete
from sqlalchemy import Dialect
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Text
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
@ -54,6 +55,7 @@ from .base_session_service import ListSessionsResponse
from .session import Session
from .state import State
logger = logging.getLogger(__name__)
@ -103,7 +105,7 @@ class StorageSession(Base):
String, primary_key=True, default=lambda: str(uuid.uuid4())
)
state: Mapped[dict] = mapped_column(
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
@ -134,8 +136,20 @@ class StorageEvent(Base):
author: Mapped[str] = mapped_column(String)
branch: Mapped[str] = mapped_column(String, nullable=True)
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
content: Mapped[dict] = mapped_column(DynamicJSON)
actions: Mapped[dict] = mapped_column(PickleType)
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
Text, nullable=True
)
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(String, nullable=True)
error_message: Mapped[str] = mapped_column(String, nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
storage_session: Mapped[StorageSession] = relationship(
"StorageSession",
@ -150,13 +164,28 @@ class StorageEvent(Base):
),
)
@property
def long_running_tool_ids(self) -> set[str]:
return (
set(json.loads(self.long_running_tool_ids_json))
if self.long_running_tool_ids_json
else set()
)
@long_running_tool_ids.setter
def long_running_tool_ids(self, value: set[str]):
if value is None:
self.long_running_tool_ids_json = None
else:
self.long_running_tool_ids_json = json.dumps(list(value))
class StorageAppState(Base):
"""Represents an app state stored in the database."""
__tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[dict] = mapped_column(
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
@ -170,7 +199,7 @@ class StorageUserState(Base):
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[dict] = mapped_column(
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
@ -295,7 +324,6 @@ class DatabaseSessionService(BaseSessionService):
last_update_time=storage_session.update_time.timestamp(),
)
return session
return None
@override
def get_session(
@ -309,7 +337,6 @@ class DatabaseSessionService(BaseSessionService):
# 1. Get the storage session entry from session table
# 2. Get all the events based on session id and filtering config
# 3. Convert and return the session
session: Session = None
with self.DatabaseSessionFactory() as sessionFactory:
storage_session = sessionFactory.get(
StorageSession, (app_name, user_id, session_id)
@ -356,13 +383,19 @@ class DatabaseSessionService(BaseSessionService):
author=e.author,
branch=e.branch,
invocation_id=e.invocation_id,
content=e.content,
content=_decode_content(e.content),
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,
grounding_metadata=e.grounding_metadata,
partial=e.partial,
turn_complete=e.turn_complete,
error_code=e.error_code,
error_message=e.error_message,
interrupted=e.interrupted,
)
for e in storage_events
]
return session
@override
@ -387,7 +420,6 @@ class DatabaseSessionService(BaseSessionService):
)
sessions.append(session)
return ListSessionsResponse(sessions=sessions)
raise ValueError("Failed to retrieve sessions.")
@override
def delete_session(
@ -406,7 +438,7 @@ class DatabaseSessionService(BaseSessionService):
def append_event(self, session: Session, event: Event) -> Event:
logger.info(f"Append event: {event} to session {session.id}")
if event.partial and not event.content:
if event.partial:
return event
# 1. Check if timestamp is stale
@ -455,19 +487,34 @@ class DatabaseSessionService(BaseSessionService):
storage_user_state.state = user_state
storage_session.state = session_state
encoded_content = event.content.model_dump(exclude_none=True)
storage_event = StorageEvent(
id=event.id,
invocation_id=event.invocation_id,
author=event.author,
branch=event.branch,
content=encoded_content,
actions=event.actions,
session_id=session.id,
app_name=session.app_name,
user_id=session.user_id,
timestamp=datetime.fromtimestamp(event.timestamp),
long_running_tool_ids=event.long_running_tool_ids,
grounding_metadata=event.grounding_metadata,
partial=event.partial,
turn_complete=event.turn_complete,
error_code=event.error_code,
error_message=event.error_message,
interrupted=event.interrupted,
)
if event.content:
encoded_content = event.content.model_dump(exclude_none=True)
# Workaround for multimodal Content throwing JSON not serializable
# error with SQLAlchemy.
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = (
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
)
storage_event.content = encoded_content
sessionFactory.add(storage_event)
@ -489,8 +536,7 @@ class DatabaseSessionService(BaseSessionService):
user_id: str,
session_id: str,
) -> ListEventsResponse:
pass
raise NotImplementedError()
def convert_event(event: StorageEvent) -> Event:
"""Converts a storage event to an event."""
@ -505,7 +551,7 @@ def convert_event(event: StorageEvent) -> Event:
)
def _extract_state_delta(state: dict):
def _extract_state_delta(state: dict[str, Any]):
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
@ -528,3 +574,10 @@ def _merge_state(app_state, user_state, session_state):
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state
def _decode_content(content: dict[str, Any]) -> dict[str, Any]:
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
return content

View File

@ -196,11 +196,12 @@ class IntegrationClient:
action_details = connections_client.get_action_schema(action)
input_schema = action_details["inputSchema"]
output_schema = action_details["outputSchema"]
action_display_name = action_details["displayName"]
# Remove spaces from the display name to generate valid spec
action_display_name = action_details["displayName"].replace(" ", "")
operation = "EXECUTE_ACTION"
if action == "ExecuteCustomQuery":
connector_spec["components"]["schemas"][
f"{action}_Request"
f"{action_display_name}_Request"
] = connections_client.execute_custom_query_request()
operation = "EXECUTE_QUERY"
else:

View File

@ -291,7 +291,7 @@ def _parse_schema_from_parameter(
return schema
raise ValueError(
f'Failed to parse the parameter {param} of function {func_name} for'
' automatic function calling.Automatic function calling works best with'
' automatic function calling. Automatic function calling works best with'
' simpler function signature schema, consider manually parsing your'
f' function declaration for function {func_name}.'
)

View File

@ -11,4 +11,77 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .google_api_tool_sets import calendar_tool_set
__all__ = [
'bigquery_tool_set',
'calendar_tool_set',
'gmail_tool_set',
'youtube_tool_set',
'slides_tool_set',
'sheets_tool_set',
'docs_tool_set',
]
# Nothing is imported here automatically
# Each tool set will only be imported when accessed
_bigquery_tool_set = None
_calendar_tool_set = None
_gmail_tool_set = None
_youtube_tool_set = None
_slides_tool_set = None
_sheets_tool_set = None
_docs_tool_set = None
def __getattr__(name):
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
match name:
case 'bigquery_tool_set':
if _bigquery_tool_set is None:
from .google_api_tool_sets import bigquery_tool_set as bigquery
_bigquery_tool_set = bigquery
return _bigquery_tool_set
case 'calendar_tool_set':
if _calendar_tool_set is None:
from .google_api_tool_sets import calendar_tool_set as calendar
_calendar_tool_set = calendar
return _calendar_tool_set
case 'gmail_tool_set':
if _gmail_tool_set is None:
from .google_api_tool_sets import gmail_tool_set as gmail
_gmail_tool_set = gmail
return _gmail_tool_set
case 'youtube_tool_set':
if _youtube_tool_set is None:
from .google_api_tool_sets import youtube_tool_set as youtube
_youtube_tool_set = youtube
return _youtube_tool_set
case 'slides_tool_set':
if _slides_tool_set is None:
from .google_api_tool_sets import slides_tool_set as slides
_slides_tool_set = slides
return _slides_tool_set
case 'sheets_tool_set':
if _sheets_tool_set is None:
from .google_api_tool_sets import sheets_tool_set as sheets
_sheets_tool_set = sheets
return _sheets_tool_set
case 'docs_tool_set':
if _docs_tool_set is None:
from .google_api_tool_sets import docs_tool_set as docs
_docs_tool_set = docs
return _docs_tool_set

View File

@ -19,37 +19,94 @@ from .google_api_tool_set import GoogleApiToolSet
logger = logging.getLogger(__name__)
calendar_tool_set = GoogleApiToolSet.load_tool_set(
api_name="calendar",
api_version="v3",
)
_bigquery_tool_set = None
_calendar_tool_set = None
_gmail_tool_set = None
_youtube_tool_set = None
_slides_tool_set = None
_sheets_tool_set = None
_docs_tool_set = None
bigquery_tool_set = GoogleApiToolSet.load_tool_set(
def __getattr__(name):
"""This method dynamically loads and returns GoogleApiToolSet instances for
various Google APIs. It uses a lazy loading approach, initializing each
tool set only when it is first requested. This avoids unnecessary loading
of tool sets that are not used in a given session.
Args:
name (str): The name of the tool set to retrieve (e.g.,
"bigquery_tool_set").
Returns:
GoogleApiToolSet: The requested tool set instance.
Raises:
AttributeError: If the requested tool set name is not recognized.
"""
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
match name:
case "bigquery_tool_set":
if _bigquery_tool_set is None:
_bigquery_tool_set = GoogleApiToolSet.load_tool_set(
api_name="bigquery",
api_version="v2",
)
)
gmail_tool_set = GoogleApiToolSet.load_tool_set(
return _bigquery_tool_set
case "calendar_tool_set":
if _calendar_tool_set is None:
_calendar_tool_set = GoogleApiToolSet.load_tool_set(
api_name="calendar",
api_version="v3",
)
return _calendar_tool_set
case "gmail_tool_set":
if _gmail_tool_set is None:
_gmail_tool_set = GoogleApiToolSet.load_tool_set(
api_name="gmail",
api_version="v1",
)
)
youtube_tool_set = GoogleApiToolSet.load_tool_set(
return _gmail_tool_set
case "youtube_tool_set":
if _youtube_tool_set is None:
_youtube_tool_set = GoogleApiToolSet.load_tool_set(
api_name="youtube",
api_version="v3",
)
)
slides_tool_set = GoogleApiToolSet.load_tool_set(
return _youtube_tool_set
case "slides_tool_set":
if _slides_tool_set is None:
_slides_tool_set = GoogleApiToolSet.load_tool_set(
api_name="slides",
api_version="v1",
)
)
sheets_tool_set = GoogleApiToolSet.load_tool_set(
return _slides_tool_set
case "sheets_tool_set":
if _sheets_tool_set is None:
_sheets_tool_set = GoogleApiToolSet.load_tool_set(
api_name="sheets",
api_version="v4",
)
)
docs_tool_set = GoogleApiToolSet.load_tool_set(
return _sheets_tool_set
case "docs_tool_set":
if _docs_tool_set is None:
_docs_tool_set = GoogleApiToolSet.load_tool_set(
api_name="docs",
api_version="v1",
)
)
return _docs_tool_set

View File

@ -311,7 +311,9 @@ class GoogleApiToOpenApiConverter:
# Determine the actual endpoint path
# Google often has the format something like 'users.messages.list'
rest_path = method_data.get("path", "/")
# flatPath is preferred as it provides the actual path, while path
# might contain variables like {+projectId}
rest_path = method_data.get("flatPath", method_data.get("path", "/"))
if not rest_path.startswith("/"):
rest_path = "/" + rest_path

View File

@ -16,18 +16,26 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from .function_tool import FunctionTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..models import LlmRequest
from ..memory.base_memory_service import MemoryResult
from ..models import LlmRequest
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
"""Loads the memory for the current user."""
"""Loads the memory for the current user.
Args:
query: The query to load the memory for.
Returns:
A list of memory results.
"""
response = tool_context.search_memory(query)
return response.memories
@ -38,6 +46,21 @@ class LoadMemoryTool(FunctionTool):
def __init__(self):
super().__init__(load_memory)
@override
def _get_declaration(self) -> types.FunctionDeclaration | None:
return types.FunctionDeclaration(
name=self.name,
description=self.description,
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
'query': types.Schema(
type=types.Type.STRING,
)
},
),
)
@override
async def process_llm_request(
self,

View File

@ -0,0 +1,176 @@
from contextlib import AsyncExitStack
import functools
import sys
from typing import Any, TextIO
import anyio
from pydantic import BaseModel
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
except ImportError as e:
import sys
if sys.version_info < (3, 10):
raise ImportError(
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
' version.'
) from e
else:
raise e
class SseServerParams(BaseModel):
"""Parameters for the MCP SSE connection.
See MCP SSE Client documentation for more details.
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
"""
url: str
headers: dict[str, Any] | None = None
timeout: float = 5
sse_read_timeout: float = 60 * 5
def retry_on_closed_resource(async_reinit_func_name: str):
"""Decorator to automatically reinitialize session and retry action.
When the MCP session is closed, the decorator automatically recreates the
session and retries the action with the same parameters.
Note:
1. async_reinit_func_name is the name of the class member function that
reinitializes the MCP session.
2. Both the decorated function and the async_reinit_func_name must be async
functions.
Usage:
class MCPTool:
...
async def create_session(self):
self.session = ...
@retry_on_closed_resource('create_session')
async def use_session(self):
await self.session.call_tool()
Args:
async_reinit_func_name: The name of the async function to recreate session.
Returns:
The decorated function.
"""
def decorator(func):
@functools.wraps(
func
) # Preserves original function metadata (name, docstring)
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except anyio.ClosedResourceError:
try:
if hasattr(self, async_reinit_func_name) and callable(
getattr(self, async_reinit_func_name)
):
async_init_fn = getattr(self, async_reinit_func_name)
await async_init_fn()
else:
raise ValueError(
f'Function {async_reinit_func_name} does not exist in decorated'
' class. Please check the function name in'
' retry_on_closed_resource decorator.'
)
except Exception as reinit_err:
raise RuntimeError(
f'Error reinitializing: {reinit_err}'
) from reinit_err
return await func(self, *args, **kwargs)
return wrapper
return decorator
class MCPSessionManager:
"""Manages MCP client sessions.
This class provides methods for creating and initializing MCP client sessions,
handling different connection parameters (Stdio and SSE).
"""
def __init__(
self,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
) -> ClientSession:
"""Initializes the MCP session manager.
Example usage:
```
mcp_session_manager = MCPSessionManager(
connection_params=connection_params,
exit_stack=exit_stack,
)
session = await mcp_session_manager.create_session()
```
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
"""
self.connection_params = connection_params
self.exit_stack = exit_stack
self.errlog = errlog
async def create_session(self) -> ClientSession:
return await MCPSessionManager.initialize_session(
connection_params=self.connection_params,
exit_stack=self.exit_stack,
errlog=self.errlog,
)
@classmethod
async def initialize_session(
cls,
*,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
) -> ClientSession:
"""Initializes an MCP client session.
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
Returns:
ClientSession: The initialized MCP client session.
"""
if isinstance(connection_params, StdioServerParameters):
client = stdio_client(server=connection_params, errlog=errlog)
elif isinstance(connection_params, SseServerParams):
client = sse_client(
url=connection_params.url,
headers=connection_params.headers,
timeout=connection_params.timeout,
sse_read_timeout=connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {connection_params}'
)
transports = await exit_stack.enter_async_context(client)
session = await exit_stack.enter_async_context(ClientSession(*transports))
await session.initialize()
return session

View File

@ -17,6 +17,8 @@ from typing import Optional
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from .mcp_session_manager import MCPSessionManager, retry_on_closed_resource
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
@ -33,6 +35,7 @@ except ImportError as e:
else:
raise e
from ..base_tool import BaseTool
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
@ -51,6 +54,7 @@ class MCPTool(BaseTool):
self,
mcp_tool: McpBaseTool,
mcp_session: ClientSession,
mcp_session_manager: MCPSessionManager,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] | None = None,
):
@ -79,10 +83,14 @@ class MCPTool(BaseTool):
self.description = mcp_tool.description if mcp_tool.description else ""
self.mcp_tool = mcp_tool
self.mcp_session = mcp_session
self.mcp_session_manager = mcp_session_manager
# TODO(cheliu): Support passing auth to MCP Server.
self.auth_scheme = auth_scheme
self.auth_credential = auth_credential
async def _reinitialize_session(self):
self.mcp_session = await self.mcp_session_manager.create_session()
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Gets the function declaration for the tool.
@ -98,6 +106,7 @@ class MCPTool(BaseTool):
return function_decl
@override
@retry_on_closed_resource("_reinitialize_session")
async def run_async(self, *, args, tool_context: ToolContext):
"""Runs the tool asynchronously.
@ -109,5 +118,9 @@ class MCPTool(BaseTool):
Any: The response from the tool.
"""
# TODO(cheliu): Support passing tool context to MCP Server.
try:
response = await self.mcp_session.call_tool(self.name, arguments=args)
return response
except Exception as e:
print(e)
raise e

View File

@ -13,15 +13,16 @@
# limitations under the License.
from contextlib import AsyncExitStack
import sys
from types import TracebackType
from typing import Any, List, Optional, Tuple, Type
from typing import List, Optional, TextIO, Tuple, Type
from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import ListToolsResult
except ImportError as e:
import sys
@ -34,18 +35,9 @@ except ImportError as e:
else:
raise e
from pydantic import BaseModel
from .mcp_tool import MCPTool
class SseServerParams(BaseModel):
url: str
headers: dict[str, Any] | None = None
timeout: float = 5
sse_read_timeout: float = 60 * 5
class MCPToolset:
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
@ -110,7 +102,11 @@ class MCPToolset:
"""
def __init__(
self, *, connection_params: StdioServerParameters | SseServerParams
self,
*,
connection_params: StdioServerParameters | SseServerParams,
errlog: TextIO = sys.stderr,
exit_stack=AsyncExitStack(),
):
"""Initializes the MCPToolset.
@ -175,7 +171,14 @@ class MCPToolset:
if not connection_params:
raise ValueError('Missing connection params in MCPToolset.')
self.connection_params = connection_params
self.exit_stack = AsyncExitStack()
self.errlog = errlog
self.exit_stack = exit_stack
self.session_manager = MCPSessionManager(
connection_params=self.connection_params,
exit_stack=self.exit_stack,
errlog=self.errlog,
)
@classmethod
async def from_server(
@ -183,6 +186,7 @@ class MCPToolset:
*,
connection_params: StdioServerParameters | SseServerParams,
async_exit_stack: Optional[AsyncExitStack] = None,
errlog: TextIO = sys.stderr,
) -> Tuple[List[MCPTool], AsyncExitStack]:
"""Retrieve all tools from the MCP connection.
@ -209,41 +213,27 @@ class MCPToolset:
the MCP server. Use `await async_exit_stack.aclose()` to close the
connection when server shuts down.
"""
toolset = cls(connection_params=connection_params)
async_exit_stack = async_exit_stack or AsyncExitStack()
toolset = cls(
connection_params=connection_params,
exit_stack=async_exit_stack,
errlog=errlog,
)
await async_exit_stack.enter_async_context(toolset)
tools = await toolset.load_tools()
return (tools, async_exit_stack)
async def _initialize(self) -> ClientSession:
"""Connects to the MCP Server and initializes the ClientSession."""
if isinstance(self.connection_params, StdioServerParameters):
client = stdio_client(self.connection_params)
elif isinstance(self.connection_params, SseServerParams):
client = sse_client(
url=self.connection_params.url,
headers=self.connection_params.headers,
timeout=self.connection_params.timeout,
sse_read_timeout=self.connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {self.connection_params}'
)
transports = await self.exit_stack.enter_async_context(client)
self.session = await self.exit_stack.enter_async_context(
ClientSession(*transports)
)
await self.session.initialize()
self.session = await self.session_manager.create_session()
return self.session
async def _exit(self):
"""Closes the connection to MCP Server."""
await self.exit_stack.aclose()
@retry_on_closed_resource('_initialize')
async def load_tools(self) -> List[MCPTool]:
"""Loads all tools from the MCP Server.
@ -252,7 +242,11 @@ class MCPToolset:
"""
tools_response: ListToolsResult = await self.session.list_tools()
return [
MCPTool(mcp_tool=tool, mcp_session=self.session)
MCPTool(
mcp_tool=tool,
mcp_session=self.session,
mcp_session_manager=self.session_manager,
)
for tool in tools_response.tools
]

View File

@ -28,7 +28,7 @@ from typing_extensions import override
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ....tools import BaseTool
from ....tools.base_tool import BaseTool
from ...tool_context import ToolContext
from ..auth.auth_helpers import credential_to_param
from ..auth.auth_helpers import dict_to_auth_scheme

View File

@ -13,4 +13,4 @@
# limitations under the License.
# version: date+base_cl
__version__ = "0.1.0"
__version__ = "0.1.1"

View File

@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,434 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for utilities in eval."""
from google.adk.cli.utils.evals import convert_session_to_eval_format
from google.adk.events.event import Event
from google.adk.sessions.session import Session
from google.genai import types
def build_event(author: str, parts_content: list[dict]) -> Event:
"""Builds an Event object with specified parts."""
parts = []
for p_data in parts_content:
part_args = {}
if "text" in p_data:
part_args["text"] = p_data["text"]
if "func_name" in p_data:
part_args["function_call"] = types.FunctionCall(
name=p_data.get("func_name"), args=p_data.get("func_args")
)
# Add other part types here if needed for future tests
parts.append(types.Part(**part_args))
return Event(author=author, content=types.Content(parts=parts))
def test_convert_empty_session():
"""Test conversion function with empty events list in Session."""
# Pydantic models require mandatory fields for instantiation
session_empty_events = Session(
id="s1", app_name="app", user_id="u1", events=[]
)
assert not convert_session_to_eval_format(session_empty_events)
def test_convert_none_session():
"""Test conversion function with None Session."""
assert not convert_session_to_eval_format(None)
def test_convert_session_skips_initial_non_user_events():
"""Test conversion function with only user events."""
events = [
build_event("model", [{"text": "Hello"}]),
build_event("user", [{"text": "How are you?"}]),
]
session = Session(id="s1", app_name="app", user_id="u1", events=events)
expected = [
{
"query": "How are you?",
"expected_tool_use": [],
"expected_intermediate_agent_responses": [],
"reference": "",
},
]
assert convert_session_to_eval_format(session) == expected
def test_convert_single_turn_text_only():
"""Test a single user query followed by a single agent text response."""
events = [
build_event("user", [{"text": "What is the time?"}]),
build_event("root_agent", [{"text": "It is 3 PM."}]),
]
session = Session(id="s1", app_name="app", user_id="u1", events=events)
expected = [{
"query": "What is the time?",
"expected_tool_use": [],
"expected_intermediate_agent_responses": [],
"reference": "It is 3 PM.",
}]
assert convert_session_to_eval_format(session) == expected
def test_convert_single_turn_tool_only():
"""Test a single user query followed by a single agent tool call."""
events = [
build_event("user", [{"text": "Get weather for Seattle"}]),
build_event(
"root_agent",
[{"func_name": "get_weather", "func_args": {"city": "Seattle"}}],
),
]
session = Session(id="s1", app_name="app", user_id="u1", events=events)
expected = [{
"query": "Get weather for Seattle",
"expected_tool_use": [
{"tool_name": "get_weather", "tool_input": {"city": "Seattle"}}
],
"expected_intermediate_agent_responses": [],
"reference": "",
}]
assert convert_session_to_eval_format(session) == expected
def test_convert_single_turn_multiple_tools_and_texts():
"""Test a turn with multiple agent responses (tools and text)."""
events = [
build_event("user", [{"text": "Do task A then task B"}]),
build_event(
"root_agent", [{"text": "Okay, starting task A."}]
), # Intermediate Text 1
build_event(
"root_agent", [{"func_name": "task_A", "func_args": {"param": 1}}]
), # Tool 1
build_event(
"root_agent", [{"text": "Task A done. Now starting task B."}]
), # Intermediate Text 2
build_event(
"another_agent", [{"func_name": "task_B", "func_args": {}}]
), # Tool 2
build_event(
"root_agent", [{"text": "All tasks completed."}]
), # Final Text (Reference)
]
session = Session(id="s1", app_name="app", user_id="u1", events=events)
expected = [{
"query": "Do task A then task B",
"expected_tool_use": [
{"tool_name": "task_A", "tool_input": {"param": 1}},
{"tool_name": "task_B", "tool_input": {}},
],
"expected_intermediate_agent_responses": [
{"author": "root_agent", "text": "Okay, starting task A."},
{
"author": "root_agent",
"text": "Task A done. Now starting task B.",
},
],
"reference": "All tasks completed.",
}]
assert convert_session_to_eval_format(session) == expected
def test_convert_multi_turn_session():
"""Test a session with multiple user/agent turns."""
events = [
build_event("user", [{"text": "Query 1"}]),
build_event("agent", [{"text": "Response 1"}]),
build_event("user", [{"text": "Query 2"}]),
build_event("agent", [{"func_name": "tool_X", "func_args": {}}]),
build_event("agent", [{"text": "Response 2"}]),
]
session = Session(id="s1", app_name="app", user_id="u1", events=events)
expected = [
{ # Turn 1
"query": "Query 1",
"expected_tool_use": [],
"expected_intermediate_agent_responses": [],
"reference": "Response 1",
},
{ # Turn 2
"query": "Query 2",
"expected_tool_use": [{"tool_name": "tool_X", "tool_input": {}}],
"expected_intermediate_agent_responses": [],
"reference": "Response 2",
},
]
assert convert_session_to_eval_format(session) == expected
def test_convert_agent_event_multiple_parts():
"""Test an agent event with both text and tool call parts."""
events = [
build_event("user", [{"text": "Do something complex"}]),
# Build event with multiple dicts in parts_content list
build_event(
"agent",
[
{"text": "Okay, doing it."},
{"func_name": "complex_tool", "func_args": {"value": True}},
],
),
build_event("agent", [{"text": "Finished."}]),
]
session = Session(id="s1", app_name="app", user_id="u1", events=events)
expected = [{
"query": "Do something complex",
"expected_tool_use": [
{"tool_name": "complex_tool", "tool_input": {"value": True}}
],
"expected_intermediate_agent_responses": [{
"author": "agent",
"text": "Okay, doing it.",
}], # Text from first part of agent event
"reference": "Finished.", # Text from second agent event
}]
assert convert_session_to_eval_format(session) == expected
def test_convert_handles_missing_content_or_parts():
"""Test that events missing content or parts are skipped gracefully."""
events = [
build_event("user", [{"text": "Query 1"}]),
Event(author="agent", content=None), # Agent event missing content
build_event("agent", [{"text": "Response 1"}]),
Event(author="user", content=None), # User event missing content
build_event("user", [{"text": "Query 2"}]),
Event(
author="agent", content=types.Content(parts=[])
), # Agent event with empty parts list
build_event("agent", [{"text": "Response 2"}]),
# User event with content but no parts (or None parts)
Event(author="user", content=types.Content(parts=None)),
build_event("user", [{"text": "Query 3"}]),
build_event("agent", [{"text": "Response 3"}]),
]
session = Session(id="s1", app_name="app", user_id="u1", events=events)
expected = [
{ # Turn 1 (from Query 1)
"query": "Query 1",
"expected_tool_use": [],
"expected_intermediate_agent_responses": [],
"reference": "Response 1",
},
{ # Turn 2 (from Query 2 - user event with None content was skipped)
"query": "Query 2",
"expected_tool_use": [],
"expected_intermediate_agent_responses": [],
"reference": "Response 2",
},
{ # Turn 3 (from Query 3 - user event with None parts was skipped)
"query": "Query 3",
"expected_tool_use": [],
"expected_intermediate_agent_responses": [],
"reference": "Response 3",
},
]
assert convert_session_to_eval_format(session) == expected
def test_convert_handles_missing_tool_name_or_args():
"""Test tool calls with missing name or args."""
events = [
build_event("user", [{"text": "Call tools"}]),
# Event where FunctionCall has name=None
Event(
author="agent",
content=types.Content(
parts=[
types.Part(
function_call=types.FunctionCall(name=None, args={"a": 1})
)
]
),
),
# Event where FunctionCall has args=None
Event(
author="agent",
content=types.Content(
parts=[
types.Part(
function_call=types.FunctionCall(name="tool_B", args=None)
)
]
),
),
# Event where FunctionCall part exists but FunctionCall object is None
# (should skip)
Event(
author="agent",
content=types.Content(
parts=[types.Part(function_call=None, text="some text")]
),
),
build_event("agent", [{"text": "Done"}]),
]
session = Session(id="s1", app_name="app", user_id="u1", events=events)
expected = [{
"query": "Call tools",
"expected_tool_use": [
{"tool_name": "", "tool_input": {"a": 1}}, # Defaults name to ""
{"tool_name": "tool_B", "tool_input": {}}, # Defaults args to {}
],
"expected_intermediate_agent_responses": [{
"author": "agent",
"text": "some text",
}], # Text part from the event where function_call was None
"reference": "Done",
}]
assert convert_session_to_eval_format(session) == expected
def test_convert_handles_missing_user_query_text():
"""Test user event where the first part has no text."""
events = [
# Event where user part has text=None
Event(
author="user", content=types.Content(parts=[types.Part(text=None)])
),
build_event("agent", [{"text": "Response 1"}]),
# Event where user part has text=""
build_event("user", [{"text": ""}]),
build_event("agent", [{"text": "Response 2"}]),
]
session = Session(id="s1", app_name="app", user_id="u1", events=events)
expected = [
{
"query": "", # Defaults to "" if text is None
"expected_tool_use": [],
"expected_intermediate_agent_responses": [],
"reference": "Response 1",
},
{
"query": "", # Defaults to "" if text is ""
"expected_tool_use": [],
"expected_intermediate_agent_responses": [],
"reference": "Response 2",
},
]
assert convert_session_to_eval_format(session) == expected
def test_convert_handles_empty_agent_text():
"""Test agent responses with empty string text."""
events = [
build_event("user", [{"text": "Query"}]),
build_event("agent", [{"text": "Okay"}]),
build_event("agent", [{"text": ""}]), # Empty text
build_event("agent", [{"text": "Done"}]),
]
session = Session(id="s1", app_name="app", user_id="u1", events=events)
expected = [{
"query": "Query",
"expected_tool_use": [],
"expected_intermediate_agent_responses": [
{"author": "agent", "text": "Okay"},
],
"reference": "Done",
}]
assert convert_session_to_eval_format(session) == expected
def test_convert_complex_sample_session():
"""Test using the complex sample session provided earlier."""
events = [
build_event("user", [{"text": "What can you do?"}]),
build_event(
"root_agent",
[{"text": "I can roll dice and check if numbers are prime. \n"}],
),
build_event(
"user",
[{
"text": (
"Roll a 8 sided dice and then check if 90 is a prime number"
" or not."
)
}],
),
build_event(
"root_agent",
[{
"func_name": "transfer_to_agent",
"func_args": {"agent_name": "roll_agent"},
}],
),
# Skipping FunctionResponse events as they don't have text/functionCall
# parts used by converter
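# For illustration (an assumed shape, not part of this fixture): a skipped
# function-response part would look roughly like
#   types.Part(function_response=types.FunctionResponse(
#       name="roll_die", response={"result": 2}))
# and exposes neither .text nor .function_call, so the converter has nothing
# to extract from it.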
build_event(
"roll_agent", [{"func_name": "roll_die", "func_args": {"sides": 8}}]
),
# Skipping FunctionResponse
build_event(
"roll_agent",
[
{"text": "I rolled a 2. Now, I'll check if 90 is prime. \n\n"},
{
"func_name": "transfer_to_agent",
"func_args": {"agent_name": "prime_agent"},
},
],
),
# Skipping FunctionResponse
build_event(
"prime_agent",
[{"func_name": "check_prime", "func_args": {"nums": [90]}}],
),
# Skipping FunctionResponse
build_event("prime_agent", [{"text": "90 is not a prime number. \n"}]),
]
session = Session(
id="some_id",
app_name="hello_world_ma",
user_id="user",
events=events,
)
expected = [
{
"query": "What can you do?",
"expected_tool_use": [],
"expected_intermediate_agent_responses": [],
"reference": "I can roll dice and check if numbers are prime. \n",
},
{
"query": (
"Roll a 8 sided dice and then check if 90 is a prime number or"
" not."
),
"expected_tool_use": [
{
"tool_name": "transfer_to_agent",
"tool_input": {"agent_name": "roll_agent"},
},
{"tool_name": "roll_die", "tool_input": {"sides": 8}},
{
"tool_name": "transfer_to_agent",
"tool_input": {"agent_name": "prime_agent"},
}, # From combined event
{"tool_name": "check_prime", "tool_input": {"nums": [90]}},
],
"expected_intermediate_agent_responses": [{
"author": "roll_agent",
"text": "I rolled a 2. Now, I'll check if 90 is prime. \n\n",
}], # Text from combined event
"reference": "90 is not a prime number. \n",
},
]
actual = convert_session_to_eval_format(session)
assert actual == expected

View File

@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,259 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Response Evaluator."""
from unittest.mock import MagicMock
from unittest.mock import patch
from google.adk.evaluation.response_evaluator import ResponseEvaluator
import pandas as pd
import pytest
from vertexai.preview.evaluation import MetricPromptTemplateExamples
# Mock object for the result normally returned by _perform_eval
MOCK_EVAL_RESULT = MagicMock()
MOCK_EVAL_RESULT.summary_metrics = {"mock_metric": 0.75, "another_mock": 3.5}
# Add a metrics_table for testing _print_results interaction
MOCK_EVAL_RESULT.metrics_table = pd.DataFrame({
"prompt": ["mock_query1"],
"response": ["mock_resp1"],
"mock_metric": [0.75],
})
SAMPLE_TURN_1_ALL_KEYS = {
"query": "query1",
"response": "response1",
"actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
"expected_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
"reference": "reference1",
}
SAMPLE_TURN_2_MISSING_REF = {
"query": "query2",
"response": "response2",
"actual_tool_use": [],
"expected_tool_use": [],
# "reference": "reference2" # Missing
}
SAMPLE_TURN_3_MISSING_EXP_TOOLS = {
"query": "query3",
"response": "response3",
"actual_tool_use": [{"tool_name": "tool_b", "tool_input": {}}],
# "expected_tool_use": [], # Missing
"reference": "reference3",
}
SAMPLE_TURN_4_MINIMAL = {
"query": "query4",
"response": "response4",
# Minimal keys, others missing
}
@patch(
"google.adk.evaluation.response_evaluator.ResponseEvaluator._perform_eval"
)
class TestResponseEvaluator:
"""A class to help organize "patch" that are applicabple to all tests."""
def test_evaluate_none_dataset_raises_value_error(self, mock_perform_eval):
"""Test evaluate function raises ValueError for an empty list."""
with pytest.raises(ValueError, match="The evaluation dataset is empty."):
ResponseEvaluator.evaluate(None, ["response_evaluation_score"])
mock_perform_eval.assert_not_called() # Ensure _perform_eval was not called
def test_evaluate_empty_dataset_raises_value_error(self, mock_perform_eval):
"""Test evaluate function raises ValueError for an empty list."""
with pytest.raises(ValueError, match="The evaluation dataset is empty."):
ResponseEvaluator.evaluate([], ["response_evaluation_score"])
mock_perform_eval.assert_not_called() # Ensure _perform_eval was not called
def test_evaluate_determines_metrics_correctly_for_perform_eval(
self, mock_perform_eval
):
"""Test that the correct metrics list is passed to _perform_eval based on criteria/keys."""
mock_perform_eval.return_value = MOCK_EVAL_RESULT
# Test case 1: Only Coherence
raw_data_1 = [[SAMPLE_TURN_1_ALL_KEYS]]
criteria_1 = ["response_evaluation_score"]
ResponseEvaluator.evaluate(raw_data_1, criteria_1)
_, kwargs = mock_perform_eval.call_args
assert kwargs["metrics"] == [
MetricPromptTemplateExamples.Pointwise.COHERENCE
]
mock_perform_eval.reset_mock() # Reset mock for next call
# Test case 2: Only Rouge
raw_data_2 = [[SAMPLE_TURN_1_ALL_KEYS]]
criteria_2 = ["response_match_score"]
ResponseEvaluator.evaluate(raw_data_2, criteria_2)
_, kwargs = mock_perform_eval.call_args
assert kwargs["metrics"] == ["rouge_1"]
mock_perform_eval.reset_mock()
# Test case 3: No metrics if keys missing in first turn
raw_data_3 = [[SAMPLE_TURN_4_MINIMAL, SAMPLE_TURN_1_ALL_KEYS]]
criteria_3 = ["response_evaluation_score", "response_match_score"]
ResponseEvaluator.evaluate(raw_data_3, criteria_3)
_, kwargs = mock_perform_eval.call_args
assert kwargs["metrics"] == []
mock_perform_eval.reset_mock()
# Test case 4: No metrics if criteria empty
raw_data_4 = [[SAMPLE_TURN_1_ALL_KEYS]]
criteria_4 = []
ResponseEvaluator.evaluate(raw_data_4, criteria_4)
_, kwargs = mock_perform_eval.call_args
assert kwargs["metrics"] == []
mock_perform_eval.reset_mock()
def test_evaluate_calls_perform_eval_correctly_all_metrics(
self, mock_perform_eval
):
"""Test evaluate function calls _perform_eval with expected args when all criteria/keys are present."""
# Arrange
mock_perform_eval.return_value = (
MOCK_EVAL_RESULT # Configure the mock return value
)
raw_data = [[SAMPLE_TURN_1_ALL_KEYS]]
criteria = ["response_evaluation_score", "response_match_score"]
# Act
summary = ResponseEvaluator.evaluate(raw_data, criteria)
# Assert
# 1. Check metrics determined by _get_metrics (passed to _perform_eval)
expected_metrics_list = [
MetricPromptTemplateExamples.Pointwise.COHERENCE,
"rouge_1",
]
# 2. Check DataFrame prepared (passed to _perform_eval)
expected_df_data = [{
"prompt": "query1",
"response": "response1",
"actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
"reference_trajectory": [{"tool_name": "tool_a", "tool_input": {}}],
"reference": "reference1",
}]
expected_df = pd.DataFrame(expected_df_data)
# Assert _perform_eval was called once
mock_perform_eval.assert_called_once()
# Get the arguments passed to the mocked _perform_eval
_, kwargs = mock_perform_eval.call_args
# Check the 'dataset' keyword argument
pd.testing.assert_frame_equal(kwargs["dataset"], expected_df)
# Check the 'metrics' keyword argument
assert kwargs["metrics"] == expected_metrics_list
# 3. Check the correct summary metrics are returned
# (from mock_perform_eval's return value)
assert summary == MOCK_EVAL_RESULT.summary_metrics
def test_evaluate_prepares_dataframe_correctly_for_perform_eval(
self, mock_perform_eval
):
"""Test that the DataFrame is correctly flattened and renamed before passing to _perform_eval."""
mock_perform_eval.return_value = MOCK_EVAL_RESULT
raw_data = [
[SAMPLE_TURN_1_ALL_KEYS], # Conversation 1
[
SAMPLE_TURN_2_MISSING_REF,
SAMPLE_TURN_3_MISSING_EXP_TOOLS,
], # Conversation 2
]
criteria = [
"response_match_score"
] # Doesn't affect the DataFrame structure
ResponseEvaluator.evaluate(raw_data, criteria)
# Expected flattened and renamed data
expected_df_data = [
# Turn 1 (from SAMPLE_TURN_1_ALL_KEYS)
{
"prompt": "query1",
"response": "response1",
"actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
"reference_trajectory": [{"tool_name": "tool_a", "tool_input": {}}],
"reference": "reference1",
},
# Turn 2 (from SAMPLE_TURN_2_MISSING_REF)
{
"prompt": "query2",
"response": "response2",
"actual_tool_use": [],
"reference_trajectory": [],
# "reference": None # Missing key results in NaN in DataFrame
# usually
},
# Turn 3 (from SAMPLE_TURN_3_MISSING_EXP_TOOLS)
{
"prompt": "query3",
"response": "response3",
"actual_tool_use": [{"tool_name": "tool_b", "tool_input": {}}],
# "reference_trajectory": None, # Missing key results in NaN
"reference": "reference3",
},
]
# Need to be careful with missing keys -> NaN when creating DataFrame
# Pandas handles this automatically when creating from list of dicts
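# A minimal sketch of that pandas behavior, for reference only:
#   pd.DataFrame([{"a": 1, "b": 2}, {"a": 3}])
#   #    a    b
#   # 0  1  2.0
#   # 1  3  NaN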
expected_df = pd.DataFrame(expected_df_data)
mock_perform_eval.assert_called_once()
_, kwargs = mock_perform_eval.call_args
# Compare the DataFrame passed to the mock
pd.testing.assert_frame_equal(kwargs["dataset"], expected_df)
@patch(
"google.adk.evaluation.response_evaluator.ResponseEvaluator._print_results"
) # Mock the private print method
def test_evaluate_print_detailed_results(
self, mock_print_results, mock_perform_eval
):
"""Test _print_results function is called when print_detailed_results=True."""
mock_perform_eval.return_value = (
MOCK_EVAL_RESULT # Ensure _perform_eval returns our mock result
)
raw_data = [[SAMPLE_TURN_1_ALL_KEYS]]
criteria = ["response_match_score"]
ResponseEvaluator.evaluate(raw_data, criteria, print_detailed_results=True)
# Assert _perform_eval was called
mock_perform_eval.assert_called_once()
# Assert _print_results was called once with the result object
# from _perform_eval
mock_print_results.assert_called_once_with(MOCK_EVAL_RESULT)
@patch(
"google.adk.evaluation.response_evaluator.ResponseEvaluator._print_results"
)
def test_evaluate_no_print_detailed_results(
self, mock_print_results, mock_perform_eval
):
"""Test _print_results function is NOT called when print_detailed_results=False (default)."""
mock_perform_eval.return_value = MOCK_EVAL_RESULT
raw_data = [[SAMPLE_TURN_1_ALL_KEYS]]
criteria = ["response_match_score"]
ResponseEvaluator.evaluate(raw_data, criteria, print_detailed_results=False)
# Assert _perform_eval was called
mock_perform_eval.assert_called_once()
# Assert _print_results was NOT called
mock_print_results.assert_not_called()

View File

@ -0,0 +1,271 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testings for the Trajectory Evaluator."""
import math
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator
import pytest
# Define reusable tool call structures
TOOL_ROLL_DICE_16 = {"tool_name": "roll_die", "tool_input": {"sides": 16}}
TOOL_ROLL_DICE_6 = {"tool_name": "roll_die", "tool_input": {"sides": 6}}
TOOL_GET_WEATHER = {
"tool_name": "get_weather",
"tool_input": {"location": "Paris"},
}
TOOL_GET_WEATHER_SF = {
"tool_name": "get_weather",
"tool_input": {"location": "SF"},
}
# Sample data for turns
TURN_MATCH = {
"query": "Q1",
"response": "R1",
"actual_tool_use": [TOOL_ROLL_DICE_16],
"expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MISMATCH_INPUT = {
"query": "Q2",
"response": "R2",
"actual_tool_use": [TOOL_ROLL_DICE_6],
"expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MISMATCH_NAME = {
"query": "Q3",
"response": "R3",
"actual_tool_use": [TOOL_GET_WEATHER],
"expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MATCH_MULTIPLE = {
"query": "Q4",
"response": "R4",
"actual_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
"expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MISMATCH_ORDER = {
"query": "Q5",
"response": "R5",
"actual_tool_use": [TOOL_ROLL_DICE_6, TOOL_GET_WEATHER],
"expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MISMATCH_LENGTH_ACTUAL_LONGER = {
"query": "Q6",
"response": "R6",
"actual_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
"expected_tool_use": [TOOL_GET_WEATHER],
}
TURN_MISMATCH_LENGTH_EXPECTED_LONGER = {
"query": "Q7",
"response": "R7",
"actual_tool_use": [TOOL_GET_WEATHER],
"expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MATCH_WITH_MOCK_OUTPUT = {
"query": "Q8",
"response": "R8",
"actual_tool_use": [TOOL_GET_WEATHER_SF],
"expected_tool_use": [
{**TOOL_GET_WEATHER_SF, "mock_tool_output": "Sunny"}
], # Add mock output to expected
}
TURN_MATCH_EMPTY_TOOLS = {
"query": "Q9",
"response": "R9",
"actual_tool_use": [],
"expected_tool_use": [],
}
TURN_MISMATCH_EMPTY_VS_NONEMPTY = {
"query": "Q10",
"response": "R10",
"actual_tool_use": [],
"expected_tool_use": [TOOL_GET_WEATHER],
}
def test_evaluate_none_dataset_raises_value_error():
"""Tests evaluate function raises ValueError for an empty list."""
with pytest.raises(ValueError, match="The evaluation dataset is empty."):
TrajectoryEvaluator.evaluate(None)
def test_evaluate_empty_dataset_raises_value_error():
"""Tests evaluate function raises ValueError for an empty list."""
with pytest.raises(ValueError, match="The evaluation dataset is empty."):
TrajectoryEvaluator.evaluate([])
def test_evaluate_single_turn_match():
"""Tests evaluate function with one conversation, one turn, perfect match."""
eval_dataset = [[TURN_MATCH]]
assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0
def test_evaluate_single_turn_mismatch():
"""Tests evaluate function with one conversation, one turn, mismatch."""
eval_dataset = [[TURN_MISMATCH_INPUT]]
assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.0
def test_evaluate_multiple_turns_all_match():
"""Tests evaluate function with one conversation, multiple turns, all match."""
eval_dataset = [[TURN_MATCH, TURN_MATCH_MULTIPLE, TURN_MATCH_EMPTY_TOOLS]]
assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0
def test_evaluate_multiple_turns_mixed():
"""Tests evaluate function with one conversation, mixed match/mismatch turns."""
eval_dataset = [
[TURN_MATCH, TURN_MISMATCH_NAME, TURN_MATCH_MULTIPLE, TURN_MISMATCH_ORDER]
]
# Expected: (1.0 + 0.0 + 1.0 + 0.0) / 4 = 0.5
assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.5
def test_evaluate_multiple_conversations_mixed():
"""Tests evaluate function with multiple conversations, mixed turns."""
eval_dataset = [
[TURN_MATCH, TURN_MISMATCH_INPUT], # Conv 1: 1.0, 0.0 -> Avg 0.5
[TURN_MATCH_MULTIPLE], # Conv 2: 1.0 -> Avg 1.0
[
TURN_MISMATCH_ORDER,
TURN_MISMATCH_LENGTH_ACTUAL_LONGER,
TURN_MATCH,
], # Conv 3: 0.0, 0.0, 1.0 -> Avg 1/3
]
# Expected: (1.0 + 0.0 + 1.0 + 0.0 + 0.0 + 1.0) / 6 = 3.0 / 6 = 0.5
assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.5
def test_evaluate_ignores_mock_tool_output_in_expected():
"""Tests evaluate function correctly compares even if expected has mock_tool_output."""
eval_dataset = [[TURN_MATCH_WITH_MOCK_OUTPUT]]
assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0
def test_evaluate_match_empty_tool_lists():
"""Tests evaluate function correctly matches empty tool lists."""
eval_dataset = [[TURN_MATCH_EMPTY_TOOLS]]
assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0
def test_evaluate_mismatch_empty_vs_nonempty():
"""Tests evaluate function correctly mismatches empty vs non-empty tool lists."""
eval_dataset = [[TURN_MISMATCH_EMPTY_VS_NONEMPTY]]
assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.0
eval_dataset_rev = [[{
**TURN_MISMATCH_EMPTY_VS_NONEMPTY, # Swap actual/expected
"actual_tool_use": [TOOL_GET_WEATHER],
"expected_tool_use": [],
}]]
assert TrajectoryEvaluator.evaluate(eval_dataset_rev) == 0.0
def test_evaluate_dataset_with_empty_conversation():
"""Tests evaluate function handles dataset containing an empty conversation list."""
eval_dataset = [[TURN_MATCH], []] # One valid conversation, one empty
# Should only evaluate the first conversation -> 1.0 / 1 turn = 1.0
assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0
def test_evaluate_dataset_only_empty_conversation():
"""Tests evaluate function handles dataset with only an empty conversation."""
eval_dataset = [[]]
# No rows evaluated, mean of empty series is NaN
# Depending on desired behavior, this could be 0.0 or NaN. The code returns
# NaN.
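# For reference, the assumed pandas behavior behind that NaN:
#   pd.Series([], dtype=float).mean()  # -> nan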
assert math.isnan(TrajectoryEvaluator.evaluate(eval_dataset))
def test_evaluate_print_detailed_results(capsys):
"""Tests evaluate function runs with print_detailed_results=True and prints something."""
eval_dataset = [[TURN_MATCH, TURN_MISMATCH_INPUT]]
TrajectoryEvaluator.evaluate(eval_dataset, print_detailed_results=True)
captured = capsys.readouterr()
assert "query" in captured.out # Check if the results table header is printed
assert "R1" in captured.out # Check if some data is printed
assert "Failures:" in captured.out # Check if failures header is printed
assert "Q2" in captured.out # Check if the failing query is printed
def test_evaluate_no_failures_print(capsys):
"""Tests evaluate function does not print Failures section when all turns match."""
eval_dataset = [[TURN_MATCH]]
TrajectoryEvaluator.evaluate(eval_dataset, print_detailed_results=True)
captured = capsys.readouterr()
assert "query" in captured.out # Results table should still print
assert "Failures:" not in captured.out # Failures section should NOT print
def test_are_tools_equal_identical():
"""Tests are_tools_equal function with identical lists."""
list_a = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
list_b = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
assert TrajectoryEvaluator.are_tools_equal(list_a, list_b)
def test_are_tools_equal_empty():
"""Tests are_tools_equal function with empty lists."""
assert TrajectoryEvaluator.are_tools_equal([], [])
def test_are_tools_equal_different_order():
"""Tests are_tools_equal function with same tools, different order."""
list_a = [TOOL_ROLL_DICE_6, TOOL_GET_WEATHER]
list_b = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)
def test_are_tools_equal_different_length():
"""Tests are_tools_equal function with lists of different lengths."""
list_a = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
list_b = [TOOL_GET_WEATHER]
assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)
def test_are_tools_equal_different_input_values():
"""Tests are_tools_equal function with different input values."""
list_a = [TOOL_ROLL_DICE_16]
list_b = [TOOL_ROLL_DICE_6]
assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)
def test_are_tools_equal_different_tool_names():
"""Tests are_tools_equal function with different tool names."""
list_a = [TOOL_ROLL_DICE_16]
list_b = [TOOL_GET_WEATHER]
assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)
def test_are_tools_equal_ignores_extra_keys():
"""Tests are_tools_equal function ignores keys other than tool_name/tool_input."""
list_a = [{
"tool_name": "get_weather",
"tool_input": {"location": "Paris"},
"extra_key": "abc",
}]
list_b = [{
"tool_name": "get_weather",
"tool_input": {"location": "Paris"},
"other_key": 123,
}]
assert TrajectoryEvaluator.are_tools_equal(list_a, list_b)
def test_are_tools_equal_one_empty_one_not():
"""Tests are_tools_equal function with one empty list and one non-empty list."""
list_a = []
list_b = [TOOL_GET_WEATHER]
assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)

View File

@ -225,3 +225,76 @@ def test_create_new_session_will_merge_states(service_type):
assert session_2.state.get('user:key1') == 'value1'
assert not session_2.state.get('key1')
assert not session_2.state.get('temp:key')
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_append_event_bytes(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
session = session_service.create_session(app_name=app_name, user_id=user_id)
event = Event(
invocation_id='invocation',
author='user',
content=types.Content(
role='user',
parts=[
types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
),
],
),
)
session_service.append_event(session=session, event=event)
assert session.events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
)
events = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
).events
assert len(events) == 1
assert events[0].content.parts[0] == types.Part.from_bytes(
data=b'test_image_data', mime_type='image/png'
)
@pytest.mark.parametrize(
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_append_event_complete(service_type):
session_service = get_session_service(service_type)
app_name = 'my_app'
user_id = 'user'
session = session_service.create_session(app_name=app_name, user_id=user_id)
event = Event(
invocation_id='invocation',
author='user',
content=types.Content(role='user', parts=[types.Part(text='test_text')]),
turn_complete=True,
partial=False,
actions=EventActions(
artifact_delta={
'file': 0,
},
transfer_to_agent='agent',
escalate=True,
),
long_running_tool_ids={'tool1'},
error_code='error_code',
error_message='error_message',
interrupted=True,
)
session_service.append_event(session=session, event=event)
assert (
session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
== session
)

View File

@ -57,7 +57,7 @@ MOCK_EVENT_JSON = [
{
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/test_engine/sessions/1/events/123'
'reasoningEngines/123/sessions/1/events/123'
),
'invocationId': '123',
'author': 'user',
@ -111,7 +111,7 @@ MOCK_SESSION = Session(
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions$'
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=([^/]+)$'
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
LRO_REGEX = r'^operations/([^/]+)$'
@ -136,39 +136,52 @@ class MockApiClient:
else:
raise ValueError(f'Session not found: {session_id}')
elif re.match(SESSIONS_REGEX, path):
match = re.match(SESSIONS_REGEX, path)
return {
'sessions': self.session_dict.values(),
'sessions': [
session
for session in self.session_dict.values()
if session['userId'] == match.group(2)
],
}
elif re.match(EVENTS_REGEX, path):
match = re.match(EVENTS_REGEX, path)
if match:
return {'sessionEvents': self.event_dict[match.group(2)]}
return {
'sessionEvents': (
self.event_dict[match.group(2)]
if match.group(2) in self.event_dict
else []
)
}
elif re.match(LRO_REGEX, path):
return {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/123'
'reasoningEngines/123/sessions/4'
),
'done': True,
}
else:
raise ValueError(f'Unsupported path: {path}')
elif http_method == 'POST':
id = str(uuid.uuid4())
self.session_dict[id] = {
new_session_id = '4'
self.session_dict[new_session_id] = {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/'
+ id
+ new_session_id
),
'userId': request_dict['user_id'],
'sessionState': request_dict.get('sessionState', {}),
'sessionState': request_dict.get('session_state', {}),
'updateTime': '2024-12-12T12:12:12.123456Z',
}
return {
'name': (
'projects/test_project/locations/test_location/'
'reasoningEngines/test_engine/sessions/123'
'reasoningEngines/123/sessions/'
+ new_session_id
+ '/operations/111'
),
'done': False,
}
@ -223,19 +236,23 @@ def test_get_and_delete_session():
)
assert str(excinfo.value) == 'Session not found: 1'
def test_list_sessions():
session_service = mock_vertex_ai_session_service()
sessions = session_service.list_sessions(app_name='123', user_id='user')
assert len(sessions.sessions) == 2
assert sessions.sessions[0].id == '1'
assert sessions.sessions[1].id == '2'
def test_create_session():
session_service = mock_vertex_ai_session_service()
state = {'key': 'value'}
session = session_service.create_session(
app_name='123', user_id='user', state={'key': 'value'}
app_name='123', user_id='user', state=state
)
assert session.state == {'key': 'value'}
assert session.state == state
assert session.app_name == '123'
assert session.user_id == 'user'
assert session.last_update_time is not None

View File

@ -119,7 +119,7 @@ def calendar_api_spec():
"methods": {
"get": {
"id": "calendar.calendars.get",
"path": "calendars/{calendarId}",
"flatPath": "calendars/{calendarId}",
"httpMethod": "GET",
"description": "Returns metadata for a calendar.",
"parameters": {
@ -151,7 +151,7 @@ def calendar_api_spec():
"methods": {
"list": {
"id": "calendar.events.list",
"path": "calendars/{calendarId}/events",
"flatPath": "calendars/{calendarId}/events",
"httpMethod": "GET",
"description": (
"Returns events on the specified calendar."