No public description

PiperOrigin-RevId: 748777998
This commit is contained in:
Google ADK Member
2025-04-17 19:50:22 +00:00
committed by hangfei
parent 290058eb05
commit 61d4be2d76
99 changed files with 2120 additions and 256 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -82,6 +82,8 @@ async def run_interactively(
)
while True:
query = input('user: ')
if not query or not query.strip():
continue
if query == 'exit':
break
async for event in runner.run_async(

View File

@@ -0,0 +1,279 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from typing import Optional
from typing import Tuple
import click
_INIT_PY_TEMPLATE = """\
from . import agent
"""
_AGENT_PY_TEMPLATE = """\
from google.adk.agents import Agent
root_agent = Agent(
model='{model_name}',
name='root_agent',
description='A helpful assistant for user questions.',
instruction='Answer user questions to the best of your knowledge',
)
"""
_GOOGLE_API_MSG = """
Don't have API Key? Create one in AI Studio: https://aistudio.google.com/apikey
"""
_GOOGLE_CLOUD_SETUP_MSG = """
You need an existing Google Cloud account and project, check out this link for details:
https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai
"""
_OTHER_MODEL_MSG = """
Please see below guide to configure other models:
https://google.github.io/adk-docs/agents/models
"""
_SUCCESS_MSG = """
Agent created in {agent_folder}:
- .env
- __init__.py
- agent.py
"""
def _get_gcp_project_from_gcloud() -> str:
"""Uses gcloud to get default project."""
try:
result = subprocess.run(
["gcloud", "config", "get-value", "project"],
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip()
except (subprocess.CalledProcessError, FileNotFoundError):
return ""
def _get_gcp_region_from_gcloud() -> str:
"""Uses gcloud to get default region."""
try:
result = subprocess.run(
["gcloud", "config", "get-value", "compute/region"],
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip()
except (subprocess.CalledProcessError, FileNotFoundError):
return ""
def _prompt_str(
prompt_prefix: str,
*,
prior_msg: Optional[str] = None,
default_value: Optional[str] = None,
) -> str:
if prior_msg:
click.secho(prior_msg, fg="green")
while True:
value: str = click.prompt(
prompt_prefix, default=default_value or None, type=str
)
if value and value.strip():
return value.strip()
def _prompt_for_google_cloud(
google_cloud_project: Optional[str],
) -> str:
"""Prompts user for Google Cloud project ID."""
google_cloud_project = (
google_cloud_project
or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
or _get_gcp_project_from_gcloud()
)
google_cloud_project = _prompt_str(
"Enter Google Cloud project ID", default_value=google_cloud_project
)
return google_cloud_project
def _prompt_for_google_cloud_region(
google_cloud_region: Optional[str],
) -> str:
"""Prompts user for Google Cloud region."""
google_cloud_region = (
google_cloud_region
or os.environ.get("GOOGLE_CLOUD_LOCATION", None)
or _get_gcp_region_from_gcloud()
)
google_cloud_region = _prompt_str(
"Enter Google Cloud region",
default_value=google_cloud_region or "us-central1",
)
return google_cloud_region
def _prompt_for_google_api_key(
google_api_key: Optional[str],
) -> str:
"""Prompts user for Google API key."""
google_api_key = google_api_key or os.environ.get("GOOGLE_API_KEY", None)
google_api_key = _prompt_str(
"Enter Google API key",
prior_msg=_GOOGLE_API_MSG,
default_value=google_api_key,
)
return google_api_key
def _generate_files(
agent_folder: str,
*,
google_api_key: Optional[str] = None,
google_cloud_project: Optional[str] = None,
google_cloud_region: Optional[str] = None,
model: Optional[str] = None,
):
"""Generates a folder name for the agent."""
os.makedirs(agent_folder, exist_ok=True)
dotenv_file_path = os.path.join(agent_folder, ".env")
init_file_path = os.path.join(agent_folder, "__init__.py")
agent_file_path = os.path.join(agent_folder, "agent.py")
with open(dotenv_file_path, "w", encoding="utf-8") as f:
lines = []
if google_api_key:
lines.append("GOOGLE_GENAI_USE_VERTEXAI=0")
elif google_cloud_project and google_cloud_region:
lines.append("GOOGLE_GENAI_USE_VERTEXAI=1")
if google_api_key:
lines.append(f"GOOGLE_API_KEY={google_api_key}")
if google_cloud_project:
lines.append(f"GOOGLE_CLOUD_PROJECT={google_cloud_project}")
if google_cloud_region:
lines.append(f"GOOGLE_CLOUD_LOCATION={google_cloud_region}")
f.write("\n".join(lines))
with open(init_file_path, "w", encoding="utf-8") as f:
f.write(_INIT_PY_TEMPLATE)
with open(agent_file_path, "w", encoding="utf-8") as f:
f.write(_AGENT_PY_TEMPLATE.format(model_name=model))
click.secho(
_SUCCESS_MSG.format(agent_folder=agent_folder),
fg="green",
)
def _prompt_for_model() -> str:
model_choice = click.prompt(
"""\
Choose a model for the root agent:
1. gemini-2.0-flash-001
2. Other models (fill later)
Choose model""",
type=click.Choice(["1", "2"]),
)
if model_choice == "1":
return "gemini-2.0-flash-001"
else:
click.secho(_OTHER_MODEL_MSG, fg="green")
return "<FILL_IN_MODEL>"
def _prompt_to_choose_backend(
google_api_key: Optional[str],
google_cloud_project: Optional[str],
google_cloud_region: Optional[str],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""Prompts user to choose backend.
Returns:
A tuple of (google_api_key, google_cloud_project, google_cloud_region).
"""
backend_choice = click.prompt(
"1. Google AI\n2. Vertex AI\nChoose a backend",
type=click.Choice(["1", "2"]),
)
if backend_choice == "1":
google_api_key = _prompt_for_google_api_key(google_api_key)
elif backend_choice == "2":
click.secho(_GOOGLE_CLOUD_SETUP_MSG, fg="green")
google_cloud_project = _prompt_for_google_cloud(google_cloud_project)
google_cloud_region = _prompt_for_google_cloud_region(google_cloud_region)
return google_api_key, google_cloud_project, google_cloud_region
def run_cmd(
agent_name: str,
*,
model: Optional[str],
google_api_key: Optional[str],
google_cloud_project: Optional[str],
google_cloud_region: Optional[str],
):
"""Runs `adk create` command to create agent template.
Args:
agent_name: str, The name of the agent.
google_api_key: Optional[str], The Google API key for using Google AI as
backend.
google_cloud_project: Optional[str], The Google Cloud project for using
VertexAI as backend.
google_cloud_region: Optional[str], The Google Cloud region for using
VertexAI as backend.
"""
agent_folder = os.path.join(os.getcwd(), agent_name)
# check folder doesn't exist or it's empty. Otherwise, throw
if os.path.exists(agent_folder) and os.listdir(agent_folder):
# Prompt user whether to override existing files using click
if not click.confirm(
f"Non-empty folder already exist: '{agent_folder}'\n"
"Override existing content?",
default=False,
):
raise click.Abort()
if not model:
model = _prompt_for_model()
if not google_api_key and not (google_cloud_project and google_cloud_region):
if model.startswith("gemini"):
google_api_key, google_cloud_project, google_cloud_region = (
_prompt_to_choose_backend(
google_api_key, google_cloud_project, google_cloud_region
)
)
_generate_files(
agent_folder,
google_api_key=google_api_key,
google_cloud_project=google_cloud_project,
google_cloud_region=google_cloud_region,
model=model,
)

View File

@@ -82,8 +82,9 @@ def to_cloud_run(
app_name: str,
temp_folder: str,
port: int,
with_cloud_trace: bool,
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
):
"""Deploys an agent to Google Cloud Run.
@@ -108,8 +109,9 @@ def to_cloud_run(
app_name: The name of the app, by default, it's basename of `agent_folder`.
temp_folder: The temp folder for the generated Cloud Run source files.
port: The port of the ADK api server.
with_cloud_trace: Whether to enable Cloud Trace.
trace_to_cloud: Whether to enable Cloud Trace.
with_ui: Whether to deploy with UI.
verbosity: The verbosity level of the CLI.
"""
app_name = app_name or os.path.basename(agent_folder)
@@ -142,7 +144,7 @@ def to_cloud_run(
port=port,
command='web' if with_ui else 'api_server',
install_agent_deps=install_agent_deps,
trace_to_cloud_option='--trace_to_cloud' if with_cloud_trace else '',
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
)
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
os.makedirs(temp_folder, exist_ok=True)
@@ -169,6 +171,8 @@ def to_cloud_run(
*region_options,
'--port',
str(port),
'--verbosity',
verbosity,
'--labels',
'created-by=adk',
],

View File

@@ -24,6 +24,7 @@ import click
from fastapi import FastAPI
import uvicorn
from . import cli_create
from . import cli_deploy
from .cli import run_cli
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
@@ -42,10 +43,59 @@ def main():
@main.group()
def deploy():
"""Deploy Agent."""
"""Deploys agent to hosted environments."""
pass
@main.command("create")
@click.option(
"--model",
type=str,
help="Optional. The model used for the root agent.",
)
@click.option(
"--api_key",
type=str,
help=(
"Optional. The API Key needed to access the model, e.g. Google AI API"
" Key."
),
)
@click.option(
"--project",
type=str,
help="Optional. The Google Cloud Project for using VertexAI as backend.",
)
@click.option(
"--region",
type=str,
help="Optional. The Google Cloud Region for using VertexAI as backend.",
)
@click.argument("app_name", type=str, required=True)
def cli_create_cmd(
app_name: str,
model: Optional[str],
api_key: Optional[str],
project: Optional[str],
region: Optional[str],
):
"""Creates a new app in the current folder with prepopulated agent template.
APP_NAME: required, the folder of the agent source code.
Example:
adk create path/to/my_app
"""
cli_create.run_cmd(
app_name,
model=model,
google_api_key=api_key,
google_cloud_project=project,
google_cloud_region=region,
)
@main.command("run")
@click.option(
"--save_session",
@@ -62,7 +112,7 @@ def deploy():
),
)
def cli_run(agent: str, save_session: bool):
"""Run an interactive CLI for a certain agent.
"""Runs an interactive CLI for a certain agent.
AGENT: The path to the agent source code folder.
@@ -150,7 +200,7 @@ def cli_eval(
EvalMetric(metric_name=metric_name, threshold=threshold)
)
print(f"Using evaluation criteria: {evaluation_criteria}")
print(f"Using evaluation creiteria: {evaluation_criteria}")
root_agent = get_root_agent(agent_module_file_path)
reset_func = try_get_reset_func(agent_module_file_path)
@@ -244,7 +294,7 @@ def cli_eval(
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
default=os.getcwd(),
default=os.getcwd,
)
def cli_web(
agents_dir: str,
@@ -255,7 +305,7 @@ def cli_web(
port: int = 8000,
trace_to_cloud: bool = False,
):
"""Start a FastAPI server with Web UI for agents.
"""Starts a FastAPI server with Web UI for agents.
AGENTS_DIR: The directory of agents, where each sub-directory is a single
agent, containing at least `__init__.py` and `agent.py` files.
@@ -274,7 +324,7 @@ def cli_web(
@asynccontextmanager
async def _lifespan(app: FastAPI):
click.secho(
f"""\
f"""
+-----------------------------------------------------------------------------+
| ADK Web Server started |
| |
@@ -285,7 +335,7 @@ def cli_web(
)
yield # Startup is done, now app is running
click.secho(
"""\
"""
+-----------------------------------------------------------------------------+
| ADK Web Server shutting down... |
+-----------------------------------------------------------------------------+
@@ -378,7 +428,7 @@ def cli_api_server(
port: int = 8000,
trace_to_cloud: bool = False,
):
"""Start a FastAPI server for agents.
"""Starts a FastAPI server for agents.
AGENTS_DIR: The directory of agents, where each sub-directory is a single
agent, containing at least `__init__.py` and `agent.py` files.
@@ -452,7 +502,7 @@ def cli_api_server(
help="Optional. The port of the ADK API server (default: 8000).",
)
@click.option(
"--with_cloud_trace",
"--trace_to_cloud",
type=bool,
is_flag=True,
show_default=True,
@@ -483,6 +533,14 @@ def cli_api_server(
" (default: a timestamped folder in the system temp directory)."
),
)
@click.option(
"--verbosity",
type=click.Choice(
["debug", "info", "warning", "error", "critical"], case_sensitive=False
),
default="WARNING",
help="Optional. Override the default verbosity level.",
)
@click.argument(
"agent",
type=click.Path(
@@ -497,8 +555,9 @@ def cli_deploy_cloud_run(
app_name: str,
temp_folder: str,
port: int,
with_cloud_trace: bool,
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
):
"""Deploys an agent to Cloud Run.
@@ -517,8 +576,9 @@ def cli_deploy_cloud_run(
app_name=app_name,
temp_folder=temp_folder,
port=port,
with_cloud_trace=with_cloud_trace,
trace_to_cloud=trace_to_cloud,
with_ui=with_ui,
verbosity=verbosity,
)
except Exception as e:
click.secho(f"Deploy failed: {e}", fg="red", err=True)

View File

@@ -13,7 +13,9 @@
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
import importlib
import inspect
import json
import logging
import os
@@ -28,6 +30,7 @@ from typing import Literal
from typing import Optional
import click
from click import Tuple
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Query
@@ -56,6 +59,7 @@ from ..agents.llm_agent import Agent
from ..agents.run_config import StreamingMode
from ..artifacts import InMemoryArtifactService
from ..events.event import Event
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..runners import Runner
from ..sessions.database_session_service import DatabaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
@@ -143,11 +147,8 @@ def get_fast_api_app(
provider.add_span_processor(
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
)
envs.load_dotenv()
enable_cloud_tracing = trace_to_cloud or os.environ.get(
"ADK_TRACE_TO_CLOUD", "0"
).lower() in ["1", "true"]
if enable_cloud_tracing:
if trace_to_cloud:
envs.load_dotenv_for_agent("", agent_dir)
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
processor = export.BatchSpanProcessor(
CloudTraceSpanExporter(project_id=project_id)
@@ -161,8 +162,22 @@ def get_fast_api_app(
trace.set_tracer_provider(provider)
exit_stacks = []
@asynccontextmanager
async def internal_lifespan(app: FastAPI):
if lifespan:
async with lifespan(app) as lifespan_context:
yield
if exit_stacks:
for stack in exit_stacks:
await stack.aclose()
else:
yield
# Run the FastAPI server.
app = FastAPI(lifespan=lifespan)
app = FastAPI(lifespan=internal_lifespan)
if allow_origins:
app.add_middleware(
@@ -181,6 +196,7 @@ def get_fast_api_app(
# Build the Artifact service
artifact_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()
# Build the Session service
agent_engine_id = ""
@@ -358,7 +374,7 @@ def get_fast_api_app(
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
response_model_exclude_none=True,
)
def add_session_to_eval_set(
async def add_session_to_eval_set(
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
):
pattern = r"^[a-zA-Z0-9_]+$"
@@ -393,7 +409,9 @@ def get_fast_api_app(
test_data = evals.convert_session_to_eval_format(session)
# Populate the session with initial session state.
initial_session_state = create_empty_state(_get_root_agent(app_name))
initial_session_state = create_empty_state(
await _get_root_agent_async(app_name)
)
eval_set_data.append({
"name": req.eval_id,
@@ -430,7 +448,7 @@ def get_fast_api_app(
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
response_model_exclude_none=True,
)
def run_eval(
async def run_eval(
app_name: str, eval_set_id: str, req: RunEvalRequest
) -> list[RunEvalResult]:
from .cli_eval import run_evals
@@ -447,7 +465,7 @@ def get_fast_api_app(
logger.info(
"Eval ids to run list is empty. We will all evals in the eval set."
)
root_agent = _get_root_agent(app_name)
root_agent = await _get_root_agent_async(app_name)
eval_results = list(
run_evals(
eval_set_to_evals,
@@ -577,7 +595,7 @@ def get_fast_api_app(
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
runner = _get_runner(req.app_name)
runner = await _get_runner_async(req.app_name)
events = [
event
async for event in runner.run_async(
@@ -604,7 +622,7 @@ def get_fast_api_app(
async def event_generator():
try:
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
runner = _get_runner(req.app_name)
runner = await _get_runner_async(req.app_name)
async for event in runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
@@ -630,7 +648,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
response_model_exclude_none=True,
)
def get_event_graph(
async def get_event_graph(
app_name: str, user_id: str, session_id: str, event_id: str
):
# Connect to managed session if agent_engine_id is set.
@@ -647,7 +665,7 @@ def get_fast_api_app(
function_calls = event.get_function_calls()
function_responses = event.get_function_responses()
root_agent = _get_root_agent(app_name)
root_agent = await _get_root_agent_async(app_name)
dot_graph = None
if function_calls:
function_call_highlights = []
@@ -704,7 +722,7 @@ def get_fast_api_app(
live_request_queue = LiveRequestQueue()
async def forward_events():
runner = _get_runner(app_name)
runner = await _get_runner_async(app_name)
async for event in runner.run_live(
session=session, live_request_queue=live_request_queue
):
@@ -742,26 +760,40 @@ def get_fast_api_app(
for task in pending:
task.cancel()
def _get_root_agent(app_name: str) -> Agent:
async def _get_root_agent_async(app_name: str) -> Agent:
"""Returns the root agent for the given app."""
if app_name in root_agent_dict:
return root_agent_dict[app_name]
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
agent_module = importlib.import_module(app_name)
root_agent: Agent = agent_module.agent.root_agent
if getattr(agent_module.agent, "root_agent"):
root_agent = agent_module.agent.root_agent
else:
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
# Handle an awaitable root agent and await for the actual agent.
if inspect.isawaitable(root_agent):
try:
agent, exit_stack = await root_agent
exit_stacks.append(exit_stack)
root_agent = agent
except Exception as e:
raise RuntimeError(f"error getting root agent, {e}") from e
root_agent_dict[app_name] = root_agent
return root_agent
def _get_runner(app_name: str) -> Runner:
async def _get_runner_async(app_name: str) -> Runner:
"""Returns the runner for the given app."""
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
if app_name in runner_dict:
return runner_dict[app_name]
root_agent = _get_root_agent(app_name)
root_agent = await _get_root_agent_async(app_name)
runner = Runner(
app_name=agent_engine_id if agent_engine_id else app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
memory_service=memory_service,
)
runner_dict[app_name] = runner
return runner

View File

@@ -50,8 +50,5 @@ def load_dotenv_for_agent(
agent_name,
dotenv_file_path,
)
logger.info(
'Reloaded %s file for %s at %s', filename, agent_name, dotenv_file_path
)
else:
logger.info('No %s file found for %s', filename, agent_name)

View File

@@ -106,9 +106,11 @@ class ResponseEvaluator:
eval_dataset = pd.DataFrame(flattened_queries).rename(
columns={"query": "prompt", "expected_tool_use": "reference_trajectory"}
)
eval_task = EvalTask(dataset=eval_dataset, metrics=metrics)
eval_result = eval_task.evaluate()
eval_result = ResponseEvaluator._perform_eval(
dataset=eval_dataset, metrics=metrics
)
if print_detailed_results:
ResponseEvaluator._print_results(eval_result)
return eval_result.summary_metrics
@@ -129,6 +131,16 @@ class ResponseEvaluator:
metrics.append("rouge_1")
return metrics
@staticmethod
def _perform_eval(dataset, metrics):
"""This method hides away the call to external service.
Primarily helps with unit testing.
"""
eval_task = EvalTask(dataset=dataset, metrics=metrics)
return eval_task.evaluate()
@staticmethod
def _print_results(eval_result):
print("Evaluation Summary Metrics:", eval_result.summary_metrics)

View File

@@ -87,15 +87,21 @@ class _NlPlanningResponse(BaseLlmResponseProcessor):
return
# Postprocess the LLM response.
callback_context = CallbackContext(invocation_context)
processed_parts = planner.process_planning_response(
CallbackContext(invocation_context), llm_response.content.parts
callback_context, llm_response.content.parts
)
if processed_parts:
llm_response.content.parts = processed_parts
# Maintain async generator behavior
if False: # Ensures it behaves as a generator
yield # This is a no-op but maintains generator structure
if callback_context.state.has_delta():
state_update_event = Event(
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
branch=invocation_context.branch,
actions=callback_context._event_actions,
)
yield state_update_event
response_processor = _NlPlanningResponse()

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import copy
from datetime import datetime
import json
@@ -20,17 +21,17 @@ from typing import Any
from typing import Optional
import uuid
from sqlalchemy import Boolean
from sqlalchemy import delete
from sqlalchemy import Dialect
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Text
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
@@ -54,6 +55,7 @@ from .base_session_service import ListSessionsResponse
from .session import Session
from .state import State
logger = logging.getLogger(__name__)
@@ -103,7 +105,7 @@ class StorageSession(Base):
String, primary_key=True, default=lambda: str(uuid.uuid4())
)
state: Mapped[dict] = mapped_column(
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
@@ -134,8 +136,20 @@ class StorageEvent(Base):
author: Mapped[str] = mapped_column(String)
branch: Mapped[str] = mapped_column(String, nullable=True)
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
content: Mapped[dict] = mapped_column(DynamicJSON)
actions: Mapped[dict] = mapped_column(PickleType)
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
Text, nullable=True
)
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
DynamicJSON, nullable=True
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(String, nullable=True)
error_message: Mapped[str] = mapped_column(String, nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
storage_session: Mapped[StorageSession] = relationship(
"StorageSession",
@@ -150,13 +164,28 @@ class StorageEvent(Base):
),
)
@property
def long_running_tool_ids(self) -> set[str]:
return (
set(json.loads(self.long_running_tool_ids_json))
if self.long_running_tool_ids_json
else set()
)
@long_running_tool_ids.setter
def long_running_tool_ids(self, value: set[str]):
if value is None:
self.long_running_tool_ids_json = None
else:
self.long_running_tool_ids_json = json.dumps(list(value))
class StorageAppState(Base):
"""Represents an app state stored in the database."""
__tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[dict] = mapped_column(
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
@@ -170,7 +199,7 @@ class StorageUserState(Base):
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[dict] = mapped_column(
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
@@ -295,7 +324,6 @@ class DatabaseSessionService(BaseSessionService):
last_update_time=storage_session.update_time.timestamp(),
)
return session
return None
@override
def get_session(
@@ -309,7 +337,6 @@ class DatabaseSessionService(BaseSessionService):
# 1. Get the storage session entry from session table
# 2. Get all the events based on session id and filtering config
# 3. Convert and return the session
session: Session = None
with self.DatabaseSessionFactory() as sessionFactory:
storage_session = sessionFactory.get(
StorageSession, (app_name, user_id, session_id)
@@ -356,13 +383,19 @@ class DatabaseSessionService(BaseSessionService):
author=e.author,
branch=e.branch,
invocation_id=e.invocation_id,
content=e.content,
content=_decode_content(e.content),
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,
grounding_metadata=e.grounding_metadata,
partial=e.partial,
turn_complete=e.turn_complete,
error_code=e.error_code,
error_message=e.error_message,
interrupted=e.interrupted,
)
for e in storage_events
]
return session
@override
@@ -387,7 +420,6 @@ class DatabaseSessionService(BaseSessionService):
)
sessions.append(session)
return ListSessionsResponse(sessions=sessions)
raise ValueError("Failed to retrieve sessions.")
@override
def delete_session(
@@ -406,7 +438,7 @@ class DatabaseSessionService(BaseSessionService):
def append_event(self, session: Session, event: Event) -> Event:
logger.info(f"Append event: {event} to session {session.id}")
if event.partial and not event.content:
if event.partial:
return event
# 1. Check if timestamp is stale
@@ -455,19 +487,34 @@ class DatabaseSessionService(BaseSessionService):
storage_user_state.state = user_state
storage_session.state = session_state
encoded_content = event.content.model_dump(exclude_none=True)
storage_event = StorageEvent(
id=event.id,
invocation_id=event.invocation_id,
author=event.author,
branch=event.branch,
content=encoded_content,
actions=event.actions,
session_id=session.id,
app_name=session.app_name,
user_id=session.user_id,
timestamp=datetime.fromtimestamp(event.timestamp),
long_running_tool_ids=event.long_running_tool_ids,
grounding_metadata=event.grounding_metadata,
partial=event.partial,
turn_complete=event.turn_complete,
error_code=event.error_code,
error_message=event.error_message,
interrupted=event.interrupted,
)
if event.content:
encoded_content = event.content.model_dump(exclude_none=True)
# Workaround for multimodal Content throwing JSON not serializable
# error with SQLAlchemy.
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = (
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
)
storage_event.content = encoded_content
sessionFactory.add(storage_event)
@@ -489,8 +536,7 @@ class DatabaseSessionService(BaseSessionService):
user_id: str,
session_id: str,
) -> ListEventsResponse:
pass
raise NotImplementedError()
def convert_event(event: StorageEvent) -> Event:
"""Converts a storage event to an event."""
@@ -505,7 +551,7 @@ def convert_event(event: StorageEvent) -> Event:
)
def _extract_state_delta(state: dict):
def _extract_state_delta(state: dict[str, Any]):
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
@@ -528,3 +574,10 @@ def _merge_state(app_state, user_state, session_state):
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state
def _decode_content(content: dict[str, Any]) -> dict[str, Any]:
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
return content

View File

@@ -1,14 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -1,10 +0,0 @@
# Copy as .env file and fill your values below to run integration tests.
# Choose Backend: GOOGLE_AI_ONLY | VERTEX_ONLY | BOTH (default)
TEST_BACKEND=BOTH
# ML Dev backend config
GOOGLE_API_KEY=YOUR_VALUE_HERE
# Vertex backend config
GOOGLE_CLOUD_PROJECT=YOUR_VALUE_HERE
GOOGLE_CLOUD_LOCATION=YOUR_VALUE_HERE

View File

@@ -1,18 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
# This allows pytest to show the values of the asserts.
pytest.register_assert_rewrite('tests.integration.utils')

View File

@@ -1,119 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Literal
import warnings
from dotenv import load_dotenv
from google.adk import Agent
from pytest import fixture
from pytest import FixtureRequest
from pytest import hookimpl
from pytest import Metafunc
from .utils import TestRunner
logger = logging.getLogger(__name__)
def load_env_for_tests():
dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
if not os.path.exists(dotenv_path):
warnings.warn(
f'Missing .env file at {dotenv_path}. See dotenv.sample for an example.'
)
else:
load_dotenv(dotenv_path, override=True, verbose=True)
if 'GOOGLE_API_KEY' not in os.environ:
warnings.warn(
'Missing GOOGLE_API_KEY in the environment variables. GOOGLE_AI backend'
' integration tests will fail.'
)
for env_var in [
'GOOGLE_CLOUD_PROJECT',
'GOOGLE_CLOUD_LOCATION',
]:
if env_var not in os.environ:
warnings.warn(
f'Missing {env_var} in the environment variables. Vertex backend'
' integration tests will fail.'
)
load_env_for_tests()
BackendType = Literal['GOOGLE_AI', 'VERTEX']
@fixture
def agent_runner(request: FixtureRequest) -> TestRunner:
assert isinstance(request.param, dict)
if 'agent' in request.param:
assert isinstance(request.param['agent'], Agent)
return TestRunner(request.param['agent'])
elif 'agent_name' in request.param:
assert isinstance(request.param['agent_name'], str)
return TestRunner.from_agent_name(request.param['agent_name'])
raise NotImplementedError('Must provide agent or agent_name.')
@fixture(autouse=True)
def llm_backend(request: FixtureRequest):
# Set backend environment value.
original_val = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI')
backend_type = request.param
if backend_type == 'GOOGLE_AI':
os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = '0'
else:
os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = '1'
yield # Run the test
# Restore the environment
if original_val is None:
os.environ.pop('GOOGLE_GENAI_USE_VERTEXAI', None)
else:
os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = original_val
@hookimpl(tryfirst=True)
def pytest_generate_tests(metafunc: Metafunc):
if llm_backend.__name__ in metafunc.fixturenames:
if not _is_explicitly_marked(llm_backend.__name__, metafunc):
test_backend = os.environ.get('TEST_BACKEND', 'BOTH')
if test_backend == 'GOOGLE_AI_ONLY':
metafunc.parametrize(llm_backend.__name__, ['GOOGLE_AI'], indirect=True)
elif test_backend == 'VERTEX_ONLY':
metafunc.parametrize(llm_backend.__name__, ['VERTEX'], indirect=True)
elif test_backend == 'BOTH':
metafunc.parametrize(
llm_backend.__name__, ['GOOGLE_AI', 'VERTEX'], indirect=True
)
else:
raise ValueError(
f'Invalid TEST_BACKEND value: {test_backend}, should be one of'
' [GOOGLE_AI_ONLY, VERTEX_ONLY, BOTH]'
)
def _is_explicitly_marked(mark_name: str, metafunc: Metafunc) -> bool:
if hasattr(metafunc.function, 'pytestmark'):
for mark in metafunc.function.pytestmark:
if mark.name == 'parametrize' and mark.args[0] == mark_name:
return True
return False

View File

@@ -1,14 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -1,88 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk import Agent
from google.genai import types
new_message = types.Content(
role="user",
parts=[types.Part.from_text(text="Count a number")],
)
google_agent_1 = Agent(
model="gemini-1.5-flash",
name="agent_1",
description="The first agent in the team.",
instruction="Just say 1",
generate_content_config=types.GenerateContentConfig(
temperature=0.1,
),
)
google_agent_2 = Agent(
model="gemini-1.5-flash",
name="agent_2",
description="The second agent in the team.",
instruction="Just say 2",
generate_content_config=types.GenerateContentConfig(
temperature=0.2,
safety_settings=[{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_ONLY_HIGH",
}],
),
)
google_agent_3 = Agent(
model="gemini-1.5-flash",
name="agent_3",
description="The third agent in the team.",
instruction="Just say 3",
generate_content_config=types.GenerateContentConfig(
temperature=0.5,
safety_settings=[{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
}],
),
)
google_agent_with_instruction_in_config = Agent(
model="gemini-1.5-flash",
name="agent",
generate_content_config=types.GenerateContentConfig(
temperature=0.5, system_instruction="Count 1"
),
)
def function():
pass
google_agent_with_tools_in_config = Agent(
model="gemini-1.5-flash",
name="agent",
generate_content_config=types.GenerateContentConfig(
temperature=0.5, tools=[function]
),
)
google_agent_with_response_schema_in_config = Agent(
model="gemini-1.5-flash",
name="agent",
generate_content_config=types.GenerateContentConfig(
temperature=0.5, response_schema={"key": "value"}
),
)

View File

@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -1,105 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from google.adk import Agent
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.genai import types
def before_agent_call_end_invocation(
callback_context: CallbackContext,
) -> types.Content:
return types.Content(
role='model',
parts=[types.Part(text='End invocation event before agent call.')],
)
def before_agent_call(
invocation_context: InvocationContext,
) -> types.Content:
return types.Content(
role='model',
parts=[types.Part.from_text(text='Plain text event before agent call.')],
)
def before_model_call_end_invocation(
callback_context: CallbackContext, llm_request: LlmRequest
) -> LlmResponse:
return LlmResponse(
content=types.Content(
role='model',
parts=[
types.Part.from_text(
text='End invocation event before model call.'
)
],
)
)
def before_model_call(
invocation_context: InvocationContext, request: LlmRequest
) -> LlmResponse:
request.config.system_instruction = 'Just return 999 as response.'
return LlmResponse(
content=types.Content(
role='model',
parts=[
types.Part.from_text(
text='Update request event before model call.'
)
],
)
)
def after_model_call(
callback_context: CallbackContext,
llm_response: LlmResponse,
) -> Optional[LlmResponse]:
content = llm_response.content
if not content or not content.parts or not content.parts[0].text:
return
content.parts[0].text += 'Update response event after model call.'
return llm_response
before_agent_callback_agent = Agent(
model='gemini-1.5-flash',
name='before_agent_callback_agent',
instruction='echo 1',
before_agent_callback=before_agent_call_end_invocation,
)
before_model_callback_agent = Agent(
model='gemini-1.5-flash',
name='before_model_callback_agent',
instruction='echo 2',
before_model_callback=before_model_call_end_invocation,
)
after_model_callback_agent = Agent(
model='gemini-1.5-flash',
name='after_model_callback_agent',
instruction='Say hello',
after_model_callback=after_model_call,
)

View File

@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -1,43 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Union
from google.adk import Agent
from google.adk.tools import ToolContext
from pydantic import BaseModel
def update_fc(
data_one: str,
data_two: Union[int, float, str],
data_three: list[str],
data_four: List[Union[int, float, str]],
tool_context: ToolContext,
):
"""Simply ask to update these variables in the context"""
tool_context.actions.update_state("data_one", data_one)
tool_context.actions.update_state("data_two", data_two)
tool_context.actions.update_state("data_three", data_three)
tool_context.actions.update_state("data_four", data_four)
root_agent = Agent(
model="gemini-1.5-flash",
name="root_agent",
instruction="Call tools",
flow="auto",
tools=[update_fc],
)

View File

@@ -1,582 +0,0 @@
{
"id": "ead43200-b575-4241-9248-233b4be4f29a",
"context": {
"_time": "2024-12-01 09:02:43.531503",
"data_one": "RRRR",
"data_two": "3.141529",
"data_three": [
"apple",
"banana"
],
"data_four": [
"1",
"hello",
"3.14"
]
},
"events": [
{
"invocation_id": "6BGrtKJu",
"author": "user",
"content": {
"parts": [
{
"text": "hi"
}
],
"role": "user"
},
"options": {},
"id": "ltzQTqR4",
"timestamp": 1733043686.8428597
},
{
"invocation_id": "6BGrtKJu",
"author": "root_agent",
"content": {
"parts": [
{
"text": "Hello! 👋 How can I help you today? \n"
}
],
"role": "model"
},
"options": {
"partial": false
},
"id": "ClSROx8b",
"timestamp": 1733043688.1030986
},
{
"invocation_id": "M3dUcVa8",
"author": "user",
"content": {
"parts": [
{
"text": "update data_one to be RRRR, data_two to be 3.141529, data_three to be apple and banana, data_four to be 1, hello, and 3.14"
}
],
"role": "user"
},
"options": {},
"id": "yxigGwIZ",
"timestamp": 1733043745.9900541
},
{
"invocation_id": "M3dUcVa8",
"author": "root_agent",
"content": {
"parts": [
{
"function_call": {
"args": {
"data_four": [
"1",
"hello",
"3.14"
],
"data_two": "3.141529",
"data_three": [
"apple",
"banana"
],
"data_one": "RRRR"
},
"name": "update_fc"
}
}
],
"role": "model"
},
"options": {
"partial": false
},
"id": "8V6de8th",
"timestamp": 1733043747.4545543
},
{
"invocation_id": "M3dUcVa8",
"author": "root_agent",
"content": {
"parts": [
{
"function_response": {
"name": "update_fc",
"response": {}
}
}
],
"role": "user"
},
"options": {
"update_context": {
"data_one": "RRRR",
"data_two": "3.141529",
"data_three": [
"apple",
"banana"
],
"data_four": [
"1",
"hello",
"3.14"
]
},
"function_call_event_id": "8V6de8th"
},
"id": "dkTj5v8B",
"timestamp": 1733043747.457031
},
{
"invocation_id": "M3dUcVa8",
"author": "root_agent",
"content": {
"parts": [
{
"text": "OK. I've updated the data. Anything else? \n"
}
],
"role": "model"
},
"options": {
"partial": false
},
"id": "OZ77XR41",
"timestamp": 1733043748.7901294
}
],
"past_events": [],
"pending_events": {},
"artifacts": {},
"event_logs": [
{
"invocation_id": "6BGrtKJu",
"event_id": "ClSROx8b",
"model_request": {
"model": "gemini-1.5-flash",
"contents": [
{
"parts": [
{
"text": "hi"
}
],
"role": "user"
}
],
"config": {
"system_instruction": "You are an agent. Your name is root_agent.\nCall tools",
"tools": [
{
"function_declarations": [
{
"description": "Hello",
"name": "update_fc",
"parameters": {
"type": "OBJECT",
"properties": {
"data_one": {
"type": "STRING"
},
"data_two": {
"type": "STRING"
},
"data_three": {
"type": "ARRAY",
"items": {
"type": "STRING"
}
},
"data_four": {
"type": "ARRAY",
"items": {
"any_of": [
{
"type": "INTEGER"
},
{
"type": "NUMBER"
},
{
"type": "STRING"
}
],
"type": "STRING"
}
}
}
}
}
]
}
]
}
},
"model_response": {
"candidates": [
{
"content": {
"parts": [
{
"text": "Hello! 👋 How can I help you today? \n"
}
],
"role": "model"
},
"avg_logprobs": -0.15831730915949896,
"finish_reason": "STOP",
"safety_ratings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probability_score": 0.071777344,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severity_score": 0.07080078
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probability_score": 0.16308594,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severity_score": 0.14160156
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probability_score": 0.09423828,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severity_score": 0.037841797
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probability_score": 0.059326172,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severity_score": 0.02368164
}
]
}
],
"model_version": "gemini-1.5-flash-001",
"usage_metadata": {
"candidates_token_count": 13,
"prompt_token_count": 32,
"total_token_count": 45
}
}
},
{
"invocation_id": "M3dUcVa8",
"event_id": "8V6de8th",
"model_request": {
"model": "gemini-1.5-flash",
"contents": [
{
"parts": [
{
"text": "hi"
}
],
"role": "user"
},
{
"parts": [
{
"text": "Hello! 👋 How can I help you today? \n"
}
],
"role": "model"
},
{
"parts": [
{
"text": "update data_one to be RRRR, data_two to be 3.141529, data_three to be apple and banana, data_four to be 1, hello, and 3.14"
}
],
"role": "user"
}
],
"config": {
"system_instruction": "You are an agent. Your name is root_agent.\nCall tools",
"tools": [
{
"function_declarations": [
{
"description": "Hello",
"name": "update_fc",
"parameters": {
"type": "OBJECT",
"properties": {
"data_one": {
"type": "STRING"
},
"data_two": {
"type": "STRING"
},
"data_three": {
"type": "ARRAY",
"items": {
"type": "STRING"
}
},
"data_four": {
"type": "ARRAY",
"items": {
"any_of": [
{
"type": "INTEGER"
},
{
"type": "NUMBER"
},
{
"type": "STRING"
}
],
"type": "STRING"
}
}
}
}
}
]
}
]
}
},
"model_response": {
"candidates": [
{
"content": {
"parts": [
{
"function_call": {
"args": {
"data_four": [
"1",
"hello",
"3.14"
],
"data_two": "3.141529",
"data_three": [
"apple",
"banana"
],
"data_one": "RRRR"
},
"name": "update_fc"
}
}
],
"role": "model"
},
"avg_logprobs": -2.100960955431219e-6,
"finish_reason": "STOP",
"safety_ratings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probability_score": 0.12158203,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severity_score": 0.13671875
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probability_score": 0.421875,
"severity": "HARM_SEVERITY_LOW",
"severity_score": 0.24511719
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probability_score": 0.15722656,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severity_score": 0.072753906
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probability_score": 0.083984375,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severity_score": 0.03564453
}
]
}
],
"model_version": "gemini-1.5-flash-001",
"usage_metadata": {
"candidates_token_count": 32,
"prompt_token_count": 94,
"total_token_count": 126
}
}
},
{
"invocation_id": "M3dUcVa8",
"event_id": "OZ77XR41",
"model_request": {
"model": "gemini-1.5-flash",
"contents": [
{
"parts": [
{
"text": "hi"
}
],
"role": "user"
},
{
"parts": [
{
"text": "Hello! 👋 How can I help you today? \n"
}
],
"role": "model"
},
{
"parts": [
{
"text": "update data_one to be RRRR, data_two to be 3.141529, data_three to be apple and banana, data_four to be 1, hello, and 3.14"
}
],
"role": "user"
},
{
"parts": [
{
"function_call": {
"args": {
"data_four": [
"1",
"hello",
"3.14"
],
"data_two": "3.141529",
"data_three": [
"apple",
"banana"
],
"data_one": "RRRR"
},
"name": "update_fc"
}
}
],
"role": "model"
},
{
"parts": [
{
"function_response": {
"name": "update_fc",
"response": {}
}
}
],
"role": "user"
}
],
"config": {
"system_instruction": "You are an agent. Your name is root_agent.\nCall tools",
"tools": [
{
"function_declarations": [
{
"description": "Hello",
"name": "update_fc",
"parameters": {
"type": "OBJECT",
"properties": {
"data_one": {
"type": "STRING"
},
"data_two": {
"type": "STRING"
},
"data_three": {
"type": "ARRAY",
"items": {
"type": "STRING"
}
},
"data_four": {
"type": "ARRAY",
"items": {
"any_of": [
{
"type": "INTEGER"
},
{
"type": "NUMBER"
},
{
"type": "STRING"
}
],
"type": "STRING"
}
}
}
}
}
]
}
]
}
},
"model_response": {
"candidates": [
{
"content": {
"parts": [
{
"text": "OK. I've updated the data. Anything else? \n"
}
],
"role": "model"
},
"avg_logprobs": -0.22089435373033797,
"finish_reason": "STOP",
"safety_ratings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probability_score": 0.04663086,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severity_score": 0.09423828
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probability_score": 0.18554688,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severity_score": 0.111328125
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probability_score": 0.071777344,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severity_score": 0.03112793
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probability_score": 0.043945313,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severity_score": 0.057373047
}
]
}
],
"model_version": "gemini-1.5-flash-001",
"usage_metadata": {
"candidates_token_count": 14,
"prompt_token_count": 129,
"total_token_count": 143
}
}
}
]
}

View File

@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -1,115 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Union
from google.adk import Agent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.planners import PlanReActPlanner
from google.adk.tools import ToolContext
def update_fc(
data_one: str,
data_two: Union[int, float, str],
data_three: list[str],
data_four: List[Union[int, float, str]],
tool_context: ToolContext,
) -> str:
"""Simply ask to update these variables in the context"""
tool_context.actions.update_state('data_one', data_one)
tool_context.actions.update_state('data_two', data_two)
tool_context.actions.update_state('data_three', data_three)
tool_context.actions.update_state('data_four', data_four)
return 'The function `update_fc` executed successfully'
def echo_info(customer_id: str) -> str:
"""Echo the context variable"""
return customer_id
def build_global_instruction(invocation_context: InvocationContext) -> str:
return (
'This is the gloabl agent instruction for invocation:'
f' {invocation_context.invocation_id}.'
)
def build_sub_agent_instruction(invocation_context: InvocationContext) -> str:
return 'This is the plain text sub agent instruction.'
context_variable_echo_agent = Agent(
model='gemini-1.5-flash',
name='context_variable_echo_agent',
instruction=(
'Use the echo_info tool to echo {customerId}, {customerInt},'
' {customerFloat}, and {customerJson}. Ask for it if you need to.'
),
flow='auto',
tools=[echo_info],
)
context_variable_with_complicated_format_agent = Agent(
model='gemini-1.5-flash',
name='context_variable_echo_agent',
instruction=(
'Use the echo_info tool to echo { customerId }, {{customer_int }, { '
" non-identifier-float}}, {artifact.fileName}, {'key1': 'value1'} and"
" {{'key2': 'value2'}}. Ask for it if you need to."
),
flow='auto',
tools=[echo_info],
)
context_variable_with_nl_planner_agent = Agent(
model='gemini-1.5-flash',
name='context_variable_with_nl_planner_agent',
instruction=(
'Use the echo_info tool to echo {customerId}. Ask for it if you'
' need to.'
),
flow='auto',
planner=PlanReActPlanner(),
tools=[echo_info],
)
context_variable_with_function_instruction_agent = Agent(
model='gemini-1.5-flash',
name='context_variable_with_function_instruction_agent',
instruction=build_sub_agent_instruction,
flow='auto',
)
context_variable_update_agent = Agent(
model='gemini-1.5-flash',
name='context_variable_update_agent',
instruction='Call tools',
flow='auto',
tools=[update_fc],
)
root_agent = Agent(
model='gemini-1.5-flash',
name='root_agent',
description='The root agent.',
flow='auto',
global_instruction=build_global_instruction,
sub_agents=[
context_variable_with_nl_planner_agent,
context_variable_update_agent,
],
)

View File

@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -1,172 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from google.adk import Agent
from google.adk.agents import RemoteAgent
from google.adk.examples import Example
from google.adk.sessions import Session
from google.genai import types
def reset_data():
pass
def fetch_user_flight_information(customer_email: str) -> str:
"""Fetch user flight information."""
return """
[{"ticket_no": "7240005432906569", "book_ref": "C46E9F", "flight_id": 19250, "flight_no": "LX0112", "departure_airport": "CDG", "arrival_airport": "BSL", "scheduled_departure": "2024-12-30 12:09:03.561731-04:00", "scheduled_arrival": "2024-12-30 13:39:03.561731-04:00", "seat_no": "18E", "fare_conditions": "Economy"}]
"""
def list_customer_flights(customer_email: str) -> str:
return "{'flights': [{'book_ref': 'C46E9F'}]}"
def update_ticket_to_new_flight(ticket_no: str, new_flight_id: str) -> str:
return 'OK, your ticket has been updated.'
def lookup_company_policy(topic: str) -> str:
"""Lookup policies for flight cancelation and rebooking."""
return """
1. How can I change my booking?
* The ticket number must start with 724 (SWISS ticket no./plate).
* The ticket was not paid for by barter or voucher (there are exceptions to voucher payments; if the ticket was paid for in full by voucher, then it may be possible to rebook online under certain circumstances. If it is not possible to rebook online because of the payment method, then you will be informed accordingly during the rebooking process).
* There must be an active flight booking for your ticket. It is not possible to rebook open tickets or tickets without the corresponding flight segments online at the moment.
* It is currently only possible to rebook outbound (one-way) tickets or return tickets with single flight routes (point-to-point).
"""
def search_flights(
departure_airport: str = None,
arrival_airport: str = None,
start_time: str = None,
end_time: str = None,
) -> list[dict]:
return """
[{"flight_id": 19238, "flight_no": "LX0112", "scheduled_departure": "2024-05-08 12:09:03.561731-04:00", "scheduled_arrival": "2024-05-08 13:39:03.561731-04:00", "departure_airport": "CDG", "arrival_airport": "BSL", "status": "Scheduled", "aircraft_code": "SU9", "actual_departure": null, "actual_arrival": null}, {"flight_id": 19242, "flight_no": "LX0112", "scheduled_departure": "2024-05-09 12:09:03.561731-04:00", "scheduled_arrival": "2024-05-09 13:39:03.561731-04:00", "departure_airport": "CDG", "arrival_airport": "BSL", "status": "Scheduled", "aircraft_code": "SU9", "actual_departure": null, "actual_arrival": null}]"""
def search_hotels(
location: str = None,
price_tier: str = None,
checkin_date: str = None,
checkout_date: str = None,
) -> list[dict]:
return """
[{"id": 1, "name": "Hilton Basel", "location": "Basel", "price_tier": "Luxury"}, {"id": 3, "name": "Hyatt Regency Basel", "location": "Basel", "price_tier": "Upper Upscale"}, {"id": 8, "name": "Holiday Inn Basel", "location": "Basel", "price_tier": "Upper Midscale"}]
"""
def book_hotel(hotel_name: str) -> str:
return 'OK, your hotel has been booked.'
def before_model_call(agent: Agent, session: Session, user_message):
if 'expedia' in user_message.lower():
response = types.Content(
role='model',
parts=[types.Part(text="Sorry, I can't answer this question.")],
)
return response
return None
def after_model_call(
agent: Agent, session: Session, content: types.Content
) -> bool:
model_message = content.parts[0].text
if 'expedia' in model_message.lower():
response = types.Content(
role='model',
parts=[types.Part(text="Sorry, I can't answer this question.")],
)
return response
return None
flight_agent = Agent(
model='gemini-1.5-pro',
name='flight_agent',
description='Handles flight information, policy and updates',
instruction="""
You are a specialized assistant for handling flight updates.
The primary assistant delegates work to you whenever the user needs help updating their bookings.
Confirm the updated flight details with the customer and inform them of any additional fees.
When searching, be persistent. Expand your query bounds if the first search returns no results.
Remember that a booking isn't completed until after the relevant tool has successfully been used.
Do not waste the user's time. Do not make up invalid tools or functions.
""",
tools=[
list_customer_flights,
lookup_company_policy,
fetch_user_flight_information,
search_flights,
update_ticket_to_new_flight,
],
)
hotel_agent = Agent(
model='gemini-1.5-pro',
name='hotel_agent',
description='Handles hotel information and booking',
instruction="""
You are a specialized assistant for handling hotel bookings.
The primary assistant delegates work to you whenever the user needs help booking a hotel.
Search for available hotels based on the user's preferences and confirm the booking details with the customer.
When searching, be persistent. Expand your query bounds if the first search returns no results.
""",
tools=[search_hotels, book_hotel],
)
idea_agent = RemoteAgent(
model='gemini-1.5-pro',
name='idea_agent',
description='Provide travel ideas base on the destination.',
url='http://localhost:8000/agent/run',
)
root_agent = Agent(
model='gemini-1.5-pro',
name='root_agent',
instruction="""
You are a helpful customer support assistant for Swiss Airlines.
""",
sub_agents=[flight_agent, hotel_agent, idea_agent],
flow='auto',
examples=[
Example(
input=types.Content(
role='user',
parts=[types.Part(text='How were you built?')],
),
output=[
types.Content(
role='model',
parts=[
types.Part(
text='I was built with the best agent framework.'
)
],
)
],
),
],
)

View File

@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -1,338 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk import Agent
# A lightweight in-memory mock database
ORDER_DB = {
"1": "FINISHED",
"2": "CANCELED",
"3": "PENDING",
"4": "PENDING",
} # Order id to status mapping. Available states: 'FINISHED', 'PENDING', and 'CANCELED'
USER_TO_ORDER_DB = {
"user_a": ["1", "4"],
"user_b": ["2"],
"user_c": ["3"],
} # User id to Order id mapping
TICKET_DB = [{
"ticket_id": "1",
"user_id": "user_a",
"issue_type": "LOGIN_ISSUE",
"status": "OPEN",
}] # Available states: 'OPEN', 'CLOSED', 'ESCALATED'
USER_INFO_DB = {
"user_a": {"name": "Alice", "email": "alice@example.com"},
"user_b": {"name": "Bob", "email": "bob@example.com"},
}
def reset_data():
global ORDER_DB
global USER_TO_ORDER_DB
global TICKET_DB
global USER_INFO_DB
ORDER_DB = {
"1": "FINISHED",
"2": "CANCELED",
"3": "PENDING",
"4": "PENDING",
}
USER_TO_ORDER_DB = {
"user_a": ["1", "4"],
"user_b": ["2"],
"user_c": ["3"],
}
TICKET_DB = [{
"ticket_id": "1",
"user_id": "user_a",
"issue_type": "LOGIN_ISSUE",
"status": "OPEN",
}]
USER_INFO_DB = {
"user_a": {"name": "Alice", "email": "alice@example.com"},
"user_b": {"name": "Bob", "email": "bob@example.com"},
}
def get_order_status(order_id: str) -> str:
"""Get the status of an order.
Args:
order_id (str): The unique identifier of the order.
Returns:
str: The status of the order (e.g., 'FINISHED', 'CANCELED', 'PENDING'),
or 'Order not found' if the order_id does not exist.
"""
return ORDER_DB.get(order_id, "Order not found")
def get_order_ids_for_user(user_id: str) -> list:
"""Get the list of order IDs assigned to a specific transaction associated with a user.
Args:
user_id (str): The unique identifier of the user.
Returns:
List[str]: A list of order IDs associated with the user, or an empty list
if no orders are found.
"""
return USER_TO_ORDER_DB.get(user_id, [])
def cancel_order(order_id: str) -> str:
"""Cancel an order if it is in a 'PENDING' state.
You should call "get_order_status" to check the status first, before calling
this tool.
Args:
order_id (str): The unique identifier of the order to be canceled.
Returns:
str: A message indicating whether the order was successfully canceled or
not.
"""
if order_id in ORDER_DB and ORDER_DB[order_id] == "PENDING":
ORDER_DB[order_id] = "CANCELED"
return f"Order {order_id} has been canceled."
return f"Order {order_id} cannot be canceled."
def refund_order(order_id: str) -> str:
"""Process a refund for an order if it is in a 'CANCELED' state.
You should call "get_order_status" to check if status first, before calling
this tool.
Args:
order_id (str): The unique identifier of the order to be refunded.
Returns:
str: A message indicating whether the order was successfully refunded or
not.
"""
if order_id in ORDER_DB and ORDER_DB[order_id] == "CANCELED":
return f"Order {order_id} has been refunded."
return f"Order {order_id} cannot be refunded."
def create_ticket(user_id: str, issue_type: str) -> str:
"""Create a new support ticket for a user.
Args:
user_id (str): The unique identifier of the user creating the ticket.
issue_type (str): An issue type the user is facing. Available types:
'LOGIN_ISSUE', 'ORDER_ISSUE', 'OTHER'.
Returns:
str: A message indicating that the ticket was created successfully,
including the ticket ID.
"""
ticket_id = str(len(TICKET_DB) + 1)
TICKET_DB.append({
"ticket_id": ticket_id,
"user_id": user_id,
"issue_type": issue_type,
"status": "OPEN",
})
return f"Ticket {ticket_id} created successfully."
def get_ticket_info(ticket_id: str) -> str:
"""Retrieve the information of a support ticket.
current status of a support ticket.
Args:
ticket_id (str): The unique identifier of the ticket.
Returns:
A dictionary contains the following fields, or 'Ticket not found' if the
ticket_id does not exist:
- "ticket_id": str, the current ticket id
- "user_id": str, the associated user id
- "issue": str, the issue type
- "status": The current status of the ticket (e.g., 'OPEN', 'CLOSED',
'ESCALATED')
Example: {"ticket_id": "1", "user_id": "user_a", "issue": "Login issue",
"status": "OPEN"}
"""
for ticket in TICKET_DB:
if ticket["ticket_id"] == ticket_id:
return ticket
return "Ticket not found"
def get_tickets_for_user(user_id: str) -> list:
"""Get all the ticket IDs associated with a user.
Args:
user_id (str): The unique identifier of the user.
Returns:
List[str]: A list of ticket IDs associated with the user.
If no tickets are found, returns an empty list.
"""
return [
ticket["ticket_id"]
for ticket in TICKET_DB
if ticket["user_id"] == user_id
]
def update_ticket_status(ticket_id: str, status: str) -> str:
"""Update the status of a support ticket.
Args:
ticket_id (str): The unique identifier of the ticket.
status (str): The new status to assign to the ticket (e.g., 'OPEN',
'CLOSED', 'ESCALATED').
Returns:
str: A message indicating whether the ticket status was successfully
updated.
"""
for ticket in TICKET_DB:
if ticket["ticket_id"] == ticket_id:
ticket["status"] = status
return f"Ticket {ticket_id} status updated to {status}."
return "Ticket not found"
def get_user_info(user_id: str) -> dict:
"""Retrieve information (name, email) about a user.
Args:
user_id (str): The unique identifier of the user.
Returns:
dict or str: A dictionary containing user information of the following
fields, or 'User not found' if the user_id does not exist:
- name: The name of the user
- email: The email address of the user
For example, {"name": "Chelsea", "email": "123@example.com"}
"""
return USER_INFO_DB.get(user_id, "User not found")
def send_email(user_id: str, email: str) -> list:
"""Send email to user for notification.
Args:
user_id (str): The unique identifier of the user.
email (str): The email address of the user.
Returns:
str: A message indicating whether the email was successfully sent.
"""
if user_id in USER_INFO_DB:
return f"Email sent to {email} for user id {user_id}"
return "Cannot find this user"
# def update_user_info(user_id: str, new_info: dict[str, str]) -> str:
def update_user_info(user_id: str, email: str, name: str) -> str:
"""Update a user's information.
Args:
user_id (str): The unique identifier of the user.
new_info (dict): A dictionary containing the fields to be updated (e.g.,
{'email': 'new_email@example.com'}). Available field keys: 'email' and
'name'.
Returns:
str: A message indicating whether the user's information was successfully
updated or not.
"""
if user_id in USER_INFO_DB:
# USER_INFO_DB[user_id].update(new_info)
if email and name:
USER_INFO_DB[user_id].update({"email": email, "name": name})
elif email:
USER_INFO_DB[user_id].update({"email": email})
elif name:
USER_INFO_DB[user_id].update({"name": name})
else:
raise ValueError("this should not happen.")
return f"User {user_id} information updated."
return "User not found"
def get_user_id_from_cookie() -> str:
"""Get user ID(username) from the cookie.
Only use this function when you do not know user ID(username).
Args: None
Returns:
str: The user ID.
"""
return "user_a"
root_agent = Agent(
model="gemini-2.0-flash-001",
name="Ecommerce_Customer_Service",
instruction="""
You are an intelligent customer service assistant for an e-commerce platform. Your goal is to accurately understand user queries and use the appropriate tools to fulfill requests. Follow these guidelines:
1. **Understand the Query**:
- Identify actions and conditions (e.g., create a ticket only for pending orders).
- Extract necessary details (e.g., user ID, order ID) from the query or infer them from the context.
2. **Plan Multi-Step Workflows**:
- Break down complex queries into sequential steps. For example
- typical workflow:
- Retrieve IDs or references first (e.g., orders for a user).
- Evaluate conditions (e.g., check order status).
- Perform actions (e.g., create a ticket) only when conditions are met.
- another typical workflows - order cancellation and refund:
- Retrieve all orders for the user (`get_order_ids_for_user`).
- Cancel pending orders (`cancel_order`).
- Refund canceled orders (`refund_order`).
- Notify the user (`send_email`).
- another typical workflows - send user report:
- Get user id.
- Get user info(like emails)
- Send email to user.
3. **Avoid Skipping Steps**:
- Ensure each intermediate step is completed before moving to the next.
- Do not create tickets or take other actions without verifying the conditions specified in the query.
4. **Provide Clear Responses**:
- Confirm the actions performed, including details like ticket ID or pending orders.
- Ensure the response aligns with the steps taken and query intent.
""",
tools=[
get_order_status,
cancel_order,
get_order_ids_for_user,
refund_order,
create_ticket,
update_ticket_status,
get_tickets_for_user,
get_ticket_info,
get_user_info,
send_email,
update_user_info,
get_user_id_from_cookie,
],
)

View File

@@ -1,69 +0,0 @@
[
{
"query": "Send an email to user user_a whose email address is alice@example.com",
"expected_tool_use": [
{
"tool_name": "send_email",
"tool_input": {
"email": "alice@example.com",
"user_id": "user_a"
}
}
],
"reference": "Email sent to alice@example.com for user id user_a."
},
{
"query": "Can you tell me the status of my order with ID 1?",
"expected_tool_use": [
{
"tool_name": "get_order_status",
"tool_input": {
"order_id": "1"
}
}
],
"reference": "Your order with ID 1 is FINISHED."
},
{
"query": "Cancel all pending order for the user with user id user_a",
"expected_tool_use": [
{
"tool_name": "get_order_ids_for_user",
"tool_input": {
"user_id": "user_a"
}
},
{
"tool_name": "get_order_status",
"tool_input": {
"order_id": "1"
}
},
{
"tool_name": "get_order_status",
"tool_input": {
"order_id": "4"
}
},
{
"tool_name": "cancel_order",
"tool_input": {
"order_id": "4"
}
}
],
"reference": "I have checked your orders and order 4 was in pending status, so I have cancelled it. Order 1 was already finished and couldn't be cancelled.\n"
},
{
"query": "What orders have I placed under the username user_b?",
"expected_tool_use": [
{
"tool_name": "get_order_ids_for_user",
"tool_input": {
"user_id": "user_b"
}
}
],
"reference": "User user_b has placed one order with order ID 2.\n"
}
]

View File

@@ -1,6 +0,0 @@
{
"criteria": {
"tool_trajectory_avg_score": 0.7,
"response_match_score": 0.5
}
}

View File

@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -1,182 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk import Agent
from google.genai import types
research_plan_agent = Agent(
model="gemini-1.5-flash",
name="research_plan_agent",
description="I can help generate research plan.",
instruction="""\
Your task is to create a research plan according to the user's query.
# Here are the instructions for creating the research plan:
+ Focus on finding specific things, e.g. products, data, etc.
+ Have the personality of a work colleague that is very helpful and explains things very nicely.
+ Don't mention your name unless you are asked.
+ Think about the most common things that you would need to research.
+ Think about possible answers when creating the plan.
+ Your task is to create the sections that should be researched. You will output high level headers, preceded by ##
+ Underneath each header, write a short sentence on what we want to find there.
+ The headers will follow the logical analysis pattern, as well as logical exploration pattern.
+ The headers should be a statement, not be in the form of questions.
+ The header will not include roman numerals or anything of the sort, e.g. ":", etc
+ Do not include things that you cannot possibly know about from using Google Search: e.g. sales forecasting, competitors, profitability analysis, etc.
+ Do not have an executive summary
+ In each section describe specifically what will be researched.
+ Never use "we will", but rather "I will".
+ Don't ask for clarifications from the user.
+ Do not ask the user for clarifications or if they have any other questions.
+ All headers should be bolded.
+ If you have steps in the plan that depend on other information, make sure they are 2 diferent sections in the plan.
+ At the end mention that you will start researching.
# Instruction on replying format
+ Start with your name as "[research_plan_agent]: ".
+ Output the content you want to say.
Output summary:
""",
flow="single",
sub_agents=[],
generate_content_config=types.GenerateContentConfig(
temperature=0.1,
),
)
question_generation_agent = Agent(
model="gemini-1.5-flash",
name="question_generation_agent",
description="I can help generate questions related to user's question.",
instruction="""\
Generate questions related to the research plan generated by research_plan_agent.
# Instruction on replying format
Your reply should be a numbered lsit.
For each question, reply in the following format: "[question_generation_agent]: [generated questions]"
Here is an example of the generated question list:
1. [question_generation_agent]: which state is San Jose in?
2. [question_generation_agent]: how google website is designed?
""",
flow="single",
sub_agents=[],
generate_content_config=types.GenerateContentConfig(
temperature=0.1,
),
)
information_retrieval_agent = Agent(
model="gemini-1.5-flash",
name="information_retrieval_agent",
description=(
"I can help retrieve information related to question_generation_agent's"
" question."
),
instruction="""\
Inspect all the questions after "[question_generation_agent]: " and asnwer them.
# Instruction on replying format
Always start with "[information_retrieval_agent]: "
For the answer of one question:
- Start with a title with one line summary of the reply.
- The title line should be bolded and starts with No.x of the corresponding question.
- Have a paragraph of detailed explain.
# Instruction on exiting the loop
- If you see there are less than 20 questions by "question_generation_agent", do not say "[exit]".
- If you see there are already great or equal to 20 questions asked by "question_generation_agent", say "[exit]" at last to exit the loop.
""",
flow="single",
sub_agents=[],
generate_content_config=types.GenerateContentConfig(
temperature=0.1,
),
)
question_sources_generation_agent = Agent(
model="gemini-1.5-flash",
name="question_sources_generation_agent",
description=(
"I can help generate questions and retrieve related information."
),
instruction="Generate questions and retrieve information.",
flow="loop",
sub_agents=[
question_generation_agent,
information_retrieval_agent,
],
generate_content_config=types.GenerateContentConfig(
temperature=0.1,
),
)
summary_agent = Agent(
model="gemini-1.5-flash",
name="summary_agent",
description="I can help summarize information of previous content.",
instruction="""\
Summarize information in all historical messages that were replied by "question_generation_agent" and "information_retrieval_agent".
# Instruction on replying format
- The output should be like an essay that has a title, an abstract, multiple paragraphs for each topic and a conclusion.
- Each paragraph should maps to one or more question in historical content.
""",
flow="single",
generate_content_config=types.GenerateContentConfig(
temperature=0.8,
),
)
research_assistant = Agent(
model="gemini-1.5-flash",
name="research_assistant",
description="I can help with research question.",
instruction="Help customers with their need.",
flow="sequential",
sub_agents=[
research_plan_agent,
question_sources_generation_agent,
summary_agent,
],
generate_content_config=types.GenerateContentConfig(
temperature=0.1,
),
)
spark_agent = Agent(
model="gemini-1.5-flash",
name="spark_assistant",
description="I can help with non-research question.",
instruction="Help customers with their need.",
flow="auto",
sub_agents=[research_assistant],
generate_content_config=types.GenerateContentConfig(
temperature=0.1,
),
)
root_agent = spark_agent

File diff suppressed because one or more lines are too long

View File

@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -1,95 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Hello world agent from agent 1.0 - https://colab.sandbox.google.com/drive/1Zq-nqmgK0nCERCv8jKIaoeTTgbNn6oSo?resourcekey=0-GYaz9pFT4wY8CI8Cvjy5GA#scrollTo=u3X3XwDOaCv9
import random
from google.adk import Agent
from google.genai import types
def roll_die(sides: int) -> int:
"""Roll a die and return the rolled result.
Args:
sides: The integer number of sides the die has.
Returns:
An integer of the result of rolling the die.
"""
return random.randint(1, sides)
def check_prime(nums: list[int]) -> list[str]:
"""Check if a given list of numbers are prime.
Args:
nums: The list of numbers to check.
Returns:
A str indicating which number is prime.
"""
primes = set()
for number in nums:
number = int(number)
if number <= 1:
continue
is_prime = True
for i in range(2, int(number**0.5) + 1):
if number % i == 0:
is_prime = False
break
if is_prime:
primes.add(number)
return (
'No prime numbers found.'
if not primes
else f"{', '.join(str(num) for num in primes)} are prime numbers."
)
root_agent = Agent(
model='gemini-2.0-flash-001',
name='data_processing_agent',
instruction="""
You roll dice and answer questions about the outcome of the dice rolls.
You can roll dice of different sizes.
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
The only things you do are roll dice for the user and discuss the outcomes.
It is ok to discuss previous dice roles, and comment on the dice rolls.
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
You should never roll a die on your own.
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
You should not check prime numbers before calling the tool.
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
3. When you respond, you must include the roll_die result from step 1.
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
You should not rely on the previous history on prime results.
""",
tools=[
roll_die,
check_prime,
],
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting( # avoid false alarm about rolling dice.
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
)

View File

@@ -1,24 +0,0 @@
[
{
"query": "Hi who are you?",
"expected_tool_use": [],
"reference": "I am a data processing agent. I can roll dice and check if the results are prime numbers. What would you like me to do? \n"
},
{
"query": "What can you do?",
"expected_tool_use": [],
"reference": "I can roll dice for you of different sizes, and I can check if the results are prime numbers. I can also remember previous rolls if you'd like to check those for primes as well. What would you like me to do? \n"
},
{
"query": "Can you roll a die with 6 sides",
"expected_tool_use": [
{
"tool_name": "roll_die",
"tool_input": {
"sides": 6
}
}
],
"reference": null
}
]

View File

@@ -1,6 +0,0 @@
{
"criteria": {
"tool_trajectory_avg_score": 1.0,
"response_match_score": 0.5
}
}

View File

@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -1,304 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from google.adk import Agent
DEVICE_DB = {
"device_1": {"status": "ON", "location": "Living Room"},
"device_2": {"status": "OFF", "location": "Bedroom"},
"device_3": {"status": "OFF", "location": "Kitchen"},
}
TEMPERATURE_DB = {
"Living Room": 22,
"Bedroom": 20,
"Kitchen": 24,
}
SCHEDULE_DB = {
"device_1": {"time": "18:00", "status": "ON"},
"device_2": {"time": "22:00", "status": "OFF"},
}
USER_PREFERENCES_DB = {
"user_x": {"preferred_temp": 21, "location": "Bedroom"},
"user_x": {"preferred_temp": 21, "location": "Living Room"},
"user_y": {"preferred_temp": 23, "location": "Living Room"},
}
def reset_data():
global DEVICE_DB
global TEMPERATURE_DB
global SCHEDULE_DB
global USER_PREFERENCES_DB
DEVICE_DB = {
"device_1": {"status": "ON", "location": "Living Room"},
"device_2": {"status": "OFF", "location": "Bedroom"},
"device_3": {"status": "OFF", "location": "Kitchen"},
}
TEMPERATURE_DB = {
"Living Room": 22,
"Bedroom": 20,
"Kitchen": 24,
}
SCHEDULE_DB = {
"device_1": {"time": "18:00", "status": "ON"},
"device_2": {"time": "22:00", "status": "OFF"},
}
USER_PREFERENCES_DB = {
"user_x": {"preferred_temp": 21, "location": "Bedroom"},
"user_x": {"preferred_temp": 21, "location": "Living Room"},
"user_y": {"preferred_temp": 23, "location": "Living Room"},
}
def get_device_info(device_id: str) -> dict:
"""Get the current status and location of a AC device.
Args:
device_id (str): The unique identifier of the device.
Returns:
dict: A dictionary containing the following fields, or 'Device not found'
if the device_id does not exist:
- status: The current status of the device (e.g., 'ON', 'OFF')
- location: The location where the device is installed (e.g., 'Living
Room', 'Bedroom', ''Kitchen')
"""
return DEVICE_DB.get(device_id, "Device not found")
# def set_device_info(device_id: str, updates: dict) -> str:
# """Update the information of a AC device, specifically its status and/or location.
# Args:
# device_id (str): Required. The unique identifier of the device.
# updates (dict): Required. A dictionary containing the fields to be
# updated. Supported keys: - "status" (str): The new status to set for the
# device. Accepted values: 'ON', 'OFF'. **Only these values are allowed.**
# - "location" (str): The new location to set for the device. Accepted
# values: 'Living Room', 'Bedroom', 'Kitchen'. **Only these values are
# allowed.**
# Returns:
# str: A message indicating whether the device information was successfully
# updated.
# """
# if device_id in DEVICE_DB:
# if "status" in updates:
# DEVICE_DB[device_id]["status"] = updates["status"]
# if "location" in updates:
# DEVICE_DB[device_id]["location"] = updates["location"]
# return f"Device {device_id} information updated: {updates}."
# return "Device not found"
def set_device_info(
device_id: str, status: str = "", location: str = ""
) -> str:
"""Update the information of a AC device, specifically its status and/or location.
Args:
device_id (str): Required. The unique identifier of the device.
status (str): The new status to set for the
device. Accepted values: 'ON', 'OFF'. **Only these values are allowed.**
location (str): The new location to set for the device. Accepted
values: 'Living Room', 'Bedroom', 'Kitchen'. **Only these values are
allowed.**
Returns:
str: A message indicating whether the device information was successfully
updated.
"""
if device_id in DEVICE_DB:
if status:
DEVICE_DB[device_id]["status"] = status
return f"Device {device_id} information updated: status -> {status}."
if location:
DEVICE_DB[device_id]["location"] = location
return f"Device {device_id} information updated: location -> {location}."
return "Device not found"
def get_temperature(location: str) -> int:
"""Get the current temperature in celsius of a location (e.g., 'Living Room', 'Bedroom', 'Kitchen').
Args:
location (str): The location for which to retrieve the temperature (e.g.,
'Living Room', 'Bedroom', 'Kitchen').
Returns:
int: The current temperature in celsius in the specified location, or
'Location not found' if the location does not exist.
"""
return TEMPERATURE_DB.get(location, "Location not found")
def set_temperature(location: str, temperature: int) -> str:
"""Set the desired temperature in celsius for a location.
Acceptable range of temperature: 18-30 celsius. If it's out of the range, do
not call this tool.
Args:
location (str): The location where the temperature should be set.
temperature (int): The desired temperature as integer to set in celsius.
Acceptable range: 18-30 celsius.
Returns:
str: A message indicating whether the temperature was successfully set.
"""
if location in TEMPERATURE_DB:
TEMPERATURE_DB[location] = temperature
return f"Temperature in {location} set to {temperature}°C."
return "Location not found"
def get_user_preferences(user_id: str) -> dict:
"""Get the temperature preferences and preferred location of a user_id.
user_id must be provided.
Args:
user_id (str): The unique identifier of the user.
Returns:
dict: A dictionary containing the following fields, or 'User not found' if
the user_id does not exist:
- preferred_temp: The user's preferred temperature.
- location: The location where the user prefers to be.
"""
return USER_PREFERENCES_DB.get(user_id, "User not found")
def set_device_schedule(device_id: str, time: str, status: str) -> str:
"""Schedule a device to change its status at a specific time.
Args:
device_id (str): The unique identifier of the device.
time (str): The time at which the device should change its status (format:
'HH:MM').
status (str): The status to set for the device at the specified time
(e.g., 'ON', 'OFF').
Returns:
str: A message indicating whether the schedule was successfully set.
"""
if device_id in DEVICE_DB:
SCHEDULE_DB[device_id] = {"time": time, "status": status}
return f"Device {device_id} scheduled to turn {status} at {time}."
return "Device not found"
def get_device_schedule(device_id: str) -> dict:
"""Retrieve the schedule of a device.
Args:
device_id (str): The unique identifier of the device.
Returns:
dict: A dictionary containing the following fields, or 'Schedule not
found' if the device_id does not exist:
- time: The scheduled time for the device to change its status (format:
'HH:MM').
- status: The status that will be set at the scheduled time (e.g., 'ON',
'OFF').
"""
return SCHEDULE_DB.get(device_id, "Schedule not found")
def celsius_to_fahrenheit(celsius: int) -> float:
"""Convert Celsius to Fahrenheit.
You must call this to do the conversion of temperature, so you can get the
precise number in required format.
Args:
celsius (int): Temperature in Celsius.
Returns:
float: Temperature in Fahrenheit.
"""
return (celsius * 9 / 5) + 32
def fahrenheit_to_celsius(fahrenheit: float) -> int:
"""Convert Fahrenheit to Celsius.
You must call this to do the conversion of temperature, so you can get the
precise number in required format.
Args:
fahrenheit (float): Temperature in Fahrenheit.
Returns:
int: Temperature in Celsius.
"""
return int((fahrenheit - 32) * 5 / 9)
def list_devices(status: str = "", location: str = "") -> list:
"""Retrieve a list of AC devices, filtered by status and/or location when provided.
For cost efficiency, always apply as many filters (status and location) as
available in the input arguments.
Args:
status (str, optional): The status to filter devices by (e.g., 'ON',
'OFF'). Defaults to None.
location (str, optional): The location to filter devices by (e.g., 'Living
Room', 'Bedroom', ''Kitchen'). Defaults to None.
Returns:
list: A list of dictionaries, each containing the device ID, status, and
location, or an empty list if no devices match the criteria.
"""
devices = []
for device_id, info in DEVICE_DB.items():
if ((not status) or info["status"] == status) and (
(not location) or info["location"] == location
):
devices.append({
"device_id": device_id,
"status": info["status"],
"location": info["location"],
})
return devices if devices else "No devices found matching the criteria."
root_agent = Agent(
model="gemini-2.0-flash-001",
name="Home_automation_agent",
instruction="""
You are Home Automation Agent. You are responsible for controlling the devices in the home.
""",
tools=[
get_device_info,
set_device_info,
get_temperature,
set_temperature,
get_user_preferences,
set_device_schedule,
get_device_schedule,
celsius_to_fahrenheit,
fahrenheit_to_celsius,
list_devices,
],
)

View File

@@ -1,5 +0,0 @@
[{
"query": "Turn off device_2 in the Bedroom.",
"expected_tool_use": [{"tool_name": "set_device_info", "tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}}],
"reference": "I have set the device_2 status to off."
}]

View File

@@ -1,5 +0,0 @@
[{
"query": "Turn off device_3 in the Bedroom.",
"expected_tool_use": [{"tool_name": "set_device_info", "tool_input": {"location": "Bedroom", "device_id": "device_3", "status": "OFF"}}],
"reference": "I have set the device_3 status to off."
}]

View File

@@ -1,5 +0,0 @@
{
"criteria": {
"tool_trajectory_avg_score": 1.0
}
}

View File

@@ -1,18 +0,0 @@
[
{
"query": "Turn off device_2 in the Bedroom.",
"expected_tool_use": [{
"tool_name": "set_device_info",
"tool_input": {"location": "Bedroom", "status": "OFF", "device_id": "device_2"}
}],
"reference": "I have set the device 2 status to off."
},
{
"query": "What's the status of device_2 in the Bedroom?",
"expected_tool_use": [{
"tool_name": "get_device_info",
"tool_input": {"device_id": "device_2"}
}],
"reference": "Status of device_2 is off."
}
]

View File

@@ -1,17 +0,0 @@
[
{
"query": "Turn off device_2 in the Bedroom.",
"expected_tool_use": [
{
"tool_name": "set_device_info",
"tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}
}
],
"reference": "OK. I've turned off device_2 in the Bedroom. Anything else?\n"
},
{
"query": "What's the command I just issued?",
"expected_tool_use": [],
"reference": "You asked me to turn off device_2 in the Bedroom.\n"
}
]

View File

@@ -1,6 +0,0 @@
{
"criteria": {
"tool_trajectory_avg_score": 1.0,
"response_match_score": 0.5
}
}

View File

@@ -1,18 +0,0 @@
[
{
"query": "Turn off device_2 in the Bedroom.",
"expected_tool_use": [{
"tool_name": "set_device_info",
"tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}
}],
"reference": "I have set the device 2 status to off."
},
{
"query": "Turn on device_2 in the Bedroom.",
"expected_tool_use": [{
"tool_name": "set_device_info",
"tool_input": {"location": "Bedroom", "status": "ON", "device_id": "device_2"}
}],
"reference": "I have set the device 2 status to on."
}
]

View File

@@ -1,17 +0,0 @@
[
{
"query": "Turn off device_2 in the Bedroom.",
"expected_tool_use": [
{
"tool_name": "set_device_info",
"tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}
}
],
"reference": "OK. I've turned off device_2 in the Bedroom. Anything else?\n"
},
{
"query": "What's the command I just issued?",
"expected_tool_use": [],
"reference": "You asked me to turn off device_2 in the Bedroom.\n"
}
]

View File

@@ -1,5 +0,0 @@
[{
"query": "Turn off device_3 in the Bedroom.",
"expected_tool_use": [{"tool_name": "set_device_info", "tool_input": {"location": "Bedroom", "device_id": "device_3", "status": "OFF"}}],
"reference": "I have set the device_3 status to off."
}]

View File

@@ -1,5 +0,0 @@
{
"criteria": {
"tool_trajectory_avg_score": 1.0
}
}

View File

@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -1,218 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any
from crewai_tools import DirectoryReadTool
from google.adk import Agent
from google.adk.tools.agent_tool import AgentTool
from google.adk.tools.crewai_tool import CrewaiTool
from google.adk.tools.langchain_tool import LangchainTool
from google.adk.tools.retrieval.files_retrieval import FilesRetrieval
from google.adk.tools.retrieval.vertex_ai_rag_retrieval import VertexAiRagRetrieval
from langchain_community.tools import ShellTool
from pydantic import BaseModel
class TestCase(BaseModel):
case: str
class Test(BaseModel):
test_title: list[str]
def simple_function(param: str) -> str:
if isinstance(param, str):
return "Called simple function successfully"
return "Called simple function with wrong param type"
def no_param_function() -> str:
return "Called no param function successfully"
def no_output_function(param: str):
return
def multiple_param_types_function(
param1: str, param2: int, param3: float, param4: bool
) -> str:
if (
isinstance(param1, str)
and isinstance(param2, int)
and isinstance(param3, float)
and isinstance(param4, bool)
):
return "Called multiple param types function successfully"
return "Called multiple param types function with wrong param types"
def throw_error_function(param: str) -> str:
raise ValueError("Error thrown by throw_error_function")
def list_str_param_function(param: list[str]) -> str:
if isinstance(param, list) and all(isinstance(item, str) for item in param):
return "Called list str param function successfully"
return "Called list str param function with wrong param type"
def return_list_str_function(param: str) -> list[str]:
return ["Called return list str function successfully"]
def complex_function_list_dict(
param1: dict[str, Any], param2: list[dict[str, Any]]
) -> list[Test]:
if (
isinstance(param1, dict)
and isinstance(param2, list)
and all(isinstance(item, dict) for item in param2)
):
return [
Test(test_title=["function test 1", "function test 2"]),
Test(test_title=["retrieval test"]),
]
raise ValueError("Wrong param")
def repetive_call_1(param: str):
return f"Call repetive_call_2 tool with param {param + '_repetive'}"
def repetive_call_2(param: str):
return param
test_case_retrieval = FilesRetrieval(
name="test_case_retrieval",
description="General guidence for agent test cases",
input_dir=os.path.join(os.path.dirname(__file__), "files"),
)
valid_rag_retrieval = VertexAiRagRetrieval(
name="valid_rag_retrieval",
rag_corpora=[
"projects/1096655024998/locations/us-central1/ragCorpora/4985766262475849728"
],
description="General guidence for agent test cases",
)
invalid_rag_retrieval = VertexAiRagRetrieval(
name="invalid_rag_retrieval",
rag_corpora=[
"projects/1096655024998/locations/us-central1/InValidRagCorporas/4985766262475849728"
],
description="Invalid rag retrieval resource name",
)
non_exist_rag_retrieval = VertexAiRagRetrieval(
name="non_exist_rag_retrieval",
rag_corpora=[
"projects/1096655024998/locations/us-central1/RagCorpora/1234567"
],
description="Non exist rag retrieval resource name",
)
shell_tool = LangchainTool(ShellTool())
docs_tool = CrewaiTool(
name="direcotry_read_tool",
description="use this to find files for you.",
tool=DirectoryReadTool(directory="."),
)
no_schema_agent = Agent(
model="gemini-1.5-flash",
name="no_schema_agent",
instruction="""Just say 'Hi'
""",
)
schema_agent = Agent(
model="gemini-1.5-flash",
name="schema_agent",
instruction="""
You will be given a test case.
Return a list of the received test case appended with '_success' and '_failure' as test_titles
""",
input_schema=TestCase,
output_schema=Test,
)
no_input_schema_agent = Agent(
model="gemini-1.5-flash",
name="no_input_schema_agent",
instruction="""
Just return ['Tools_success, Tools_failure']
""",
output_schema=Test,
)
no_output_schema_agent = Agent(
model="gemini-1.5-flash",
name="no_output_schema_agent",
instruction="""
Just say 'Hi'
""",
input_schema=TestCase,
)
single_function_agent = Agent(
model="gemini-1.5-flash",
name="single_function_agent",
description="An agent that calls a single function",
instruction="When calling tools, just return what the tool returns.",
tools=[simple_function],
)
root_agent = Agent(
model="gemini-1.5-flash",
name="tool_agent",
description="An agent that can call other tools",
instruction="When calling tools, just return what the tool returns.",
tools=[
simple_function,
no_param_function,
no_output_function,
multiple_param_types_function,
throw_error_function,
list_str_param_function,
return_list_str_function,
# complex_function_list_dict,
repetive_call_1,
repetive_call_2,
test_case_retrieval,
valid_rag_retrieval,
invalid_rag_retrieval,
non_exist_rag_retrieval,
shell_tool,
docs_tool,
AgentTool(
agent=no_schema_agent,
),
AgentTool(
agent=schema_agent,
),
AgentTool(
agent=no_input_schema_agent,
),
AgentTool(
agent=no_output_schema_agent,
),
],
)

View File

@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent

View File

@@ -1,110 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# https://github.com/crewAIInc/crewAI-examples/tree/main/trip_planner
from google.adk import Agent
# Agent that selects the best city for the trip.
identify_agent = Agent(
name='identify_agent',
description='Select the best city based on weather, season, and prices.',
instruction="""
Analyze and select the best city for the trip based
on specific criteria such as weather patterns, seasonal
events, and travel costs. This task involves comparing
multiple cities, considering factors like current weather
conditions, upcoming cultural or seasonal events, and
overall travel expenses.
Your final answer must be a detailed
report on the chosen city, and everything you found out
about it, including the actual flight costs, weather
forecast and attractions.
Traveling from: {origin}
City Options: {cities}
Trip Date: {range}
Traveler Interests: {interests}
""",
)
# Agent that gathers information about the city.
gather_agent = Agent(
name='gather_agent',
description='Provide the BEST insights about the selected city',
instruction="""
As a local expert on this city you must compile an
in-depth guide for someone traveling there and wanting
to have THE BEST trip ever!
Gather information about key attractions, local customs,
special events, and daily activity recommendations.
Find the best spots to go to, the kind of place only a
local would know.
This guide should provide a thorough overview of what
the city has to offer, including hidden gems, cultural
hotspots, must-visit landmarks, weather forecasts, and
high level costs.
The final answer must be a comprehensive city guide,
rich in cultural insights and practical tips,
tailored to enhance the travel experience.
Trip Date: {range}
Traveling from: {origin}
Traveler Interests: {interests}
""",
)
# Agent that plans the trip.
plan_agent = Agent(
name='plan_agent',
description="""Create the most amazing travel itineraries with budget and
packing suggestions for the city""",
instruction="""
Expand this guide into a full 7-day travel
itinerary with detailed per-day plans, including
weather forecasts, places to eat, packing suggestions,
and a budget breakdown.
You MUST suggest actual places to visit, actual hotels
to stay and actual restaurants to go to.
This itinerary should cover all aspects of the trip,
from arrival to departure, integrating the city guide
information with practical travel logistics.
Your final answer MUST be a complete expanded travel plan,
formatted as markdown, encompassing a daily schedule,
anticipated weather conditions, recommended clothing and
items to pack, and a detailed budget, ensuring THE BEST
TRIP EVER. Be specific and give it a reason why you picked
each place, what makes them special!
Trip Date: {range}
Traveling from: {origin}
Traveler Interests: {interests}
""",
)
root_agent = Agent(
model='gemini-2.0-flash-001',
name='trip_planner',
description='Plan the best trip ever',
instruction="""
Your goal is to plan the best trip according to information listed above.
You describe why did you choose the city, list top 3
attactions and provide a detailed itinerary for each day.""",
sub_agents=[identify_agent, gather_agent, plan_agent],
)

View File

@@ -1,13 +0,0 @@
{
"id": "test_id",
"app_name": "trip_planner_agent",
"user_id": "test_user",
"state": {
"origin": "San Francisco",
"interests": "Food, Shopping, Museums",
"range": "1000 miles",
"cities": ""
},
"events": [],
"last_update_time": 1741218714.258285
}

View File

@@ -1,5 +0,0 @@
{
"criteria": {
"response_match_score": 0.5
}
}

View File

@@ -1,13 +0,0 @@
{
"id": "test_id",
"app_name": "trip_planner_agent",
"user_id": "test_user",
"state": {
"origin": "San Francisco",
"interests": "Food, Shopping, Museums",
"range": "1000 miles",
"cities": ""
},
"events": [],
"last_update_time": 1741218714.258285
}

View File

@@ -1,5 +0,0 @@
{
"criteria": {
"response_match_score": 0.5
}
}

View File

@@ -1,7 +0,0 @@
[
{
"query": "Based on my interests, where should I go, Yosemite national park or Los Angeles?",
"expected_tool_use": [],
"reference": "Given your interests in food, shopping, and museums, Los Angeles would be a better choice than Yosemite National Park. Yosemite is primarily focused on outdoor activities and natural landscapes, while Los Angeles offers a diverse range of culinary experiences, shopping districts, and world-class museums. I will now gather information to create an in-depth guide for your trip to Los Angeles.\n"
}
]

View File

@@ -1,19 +0,0 @@
[
{
"query": "Hi, who are you? What can you do?",
"expected_tool_use": [],
"reference": "I am trip_planner, and my goal is to plan the best trip ever. I can describe why a city was chosen, list its top attractions, and provide a detailed itinerary for each day of the trip.\n"
},
{
"query": "I want to travel from San Francisco to an European country in fall next year. I am considering London and Paris. What is your advice?",
"expected_tool_use": [
{
"tool_name": "transfer_to_agent",
"tool_input": {
"agent_name": "indentify_agent"
}
}
],
"reference": "Okay, I can help you analyze London and Paris to determine which city is better for your trip next fall. I will consider weather patterns, seasonal events, travel costs (including flights from San Francisco), and your interests (food, shopping, and museums). After gathering this information, I'll provide a detailed report on my chosen city.\n"
}
]

View File

@@ -1,14 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -1,65 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.models import LlmRequest
from google.adk.models import LlmResponse
from google.adk.models.google_llm import Gemini
from google.genai import types
from google.genai.types import Content
from google.genai.types import Part
import pytest
@pytest.fixture
def gemini_llm():
return Gemini(model="gemini-1.5-flash")
@pytest.fixture
def llm_request():
return LlmRequest(
model="gemini-1.5-flash",
contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
config=types.GenerateContentConfig(
temperature=0.1,
response_modalities=[types.Modality.TEXT],
system_instruction="You are a helpful assistant",
),
)
@pytest.mark.asyncio
async def test_generate_content_async(gemini_llm, llm_request):
async for response in gemini_llm.generate_content_async(llm_request):
assert isinstance(response, LlmResponse)
assert response.content.parts[0].text
@pytest.mark.asyncio
async def test_generate_content_async_stream(gemini_llm, llm_request):
responses = [
resp
async for resp in gemini_llm.generate_content_async(
llm_request, stream=True
)
]
text = ""
for i in range(len(responses) - 1):
assert responses[i].partial is True
assert responses[i].content.parts[0].text
text += responses[i].content.parts[0].text
# Last message should be accumulated text
assert responses[-1].content.parts[0].text == text
assert not responses[-1].partial

View File

@@ -1,70 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytest import mark
from ..unittests.utils import simplify_events
from .fixture import callback_agent
from .utils import assert_agent_says
from .utils import TestRunner
@mark.parametrize(
"agent_runner",
[{"agent": callback_agent.agent.before_agent_callback_agent}],
indirect=True,
)
def test_before_agent_call(agent_runner: TestRunner):
agent_runner.run("Hi.")
# Assert the response content
assert_agent_says(
"End invocation event before agent call.",
agent_name="before_agent_callback_agent",
agent_runner=agent_runner,
)
@mark.parametrize(
"agent_runner",
[{"agent": callback_agent.agent.before_model_callback_agent}],
indirect=True,
)
def test_before_model_call(agent_runner: TestRunner):
agent_runner.run("Hi.")
# Assert the response content
assert_agent_says(
"End invocation event before model call.",
agent_name="before_model_callback_agent",
agent_runner=agent_runner,
)
# TODO: re-enable vertex by removing below line after fixing.
@mark.parametrize("llm_backend", ["GOOGLE_AI"], indirect=True)
@mark.parametrize(
"agent_runner",
[{"agent": callback_agent.agent.after_model_callback_agent}],
indirect=True,
)
def test_after_model_call(agent_runner: TestRunner):
events = agent_runner.run("Hi.")
# Assert the response content
simplified_events = simplify_events(events)
assert simplified_events[0][0] == "after_model_callback_agent"
assert simplified_events[0][1].endswith(
"Update response event after model call."
)

View File

@@ -1,67 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pytest
# Skip until fixed.
pytest.skip(allow_module_level=True)
from .fixture import context_variable_agent
from .utils import TestRunner
@pytest.mark.parametrize(
"agent_runner",
[{"agent": context_variable_agent.agent.state_variable_echo_agent}],
indirect=True,
)
def test_context_variable_missing(agent_runner: TestRunner):
with pytest.raises(KeyError) as e_info:
agent_runner.run("Hi echo my customer id.")
assert "customerId" in str(e_info.value)
@pytest.mark.parametrize(
"agent_runner",
[{"agent": context_variable_agent.agent.state_variable_update_agent}],
indirect=True,
)
def test_context_variable_update(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"update_fc",
["RRRR", "3.141529", ["apple", "banana"], [1, 3.14, "hello"]],
"successfully",
)
def _call_function_and_assert(
agent_runner: TestRunner, function_name: str, params, expected
):
param_section = (
" with params"
f" {params if isinstance(params, str) else json.dumps(params)}"
if params is not None
else ""
)
agent_runner.run(
f"Call {function_name}{param_section} and show me the result"
)
model_response_event = agent_runner.get_events()[-1]
assert model_response_event.author == "context_variable_update_agent"
assert model_response_event.content.role == "model"
assert expected in model_response_event.content.parts[0].text.strip()

View File

@@ -1,76 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluate all agents in fixture folder if evaluation test files exist."""
import os
from google.adk.evaluation import AgentEvaluator
import pytest
def agent_eval_artifacts_in_fixture():
"""Get all agents from fixture folder."""
agent_eval_artifacts = []
fixture_dir = os.path.join(os.path.dirname(__file__), 'fixture')
for agent_name in os.listdir(fixture_dir):
agent_dir = os.path.join(fixture_dir, agent_name)
if not os.path.isdir(agent_dir):
continue
for filename in os.listdir(agent_dir):
# Evaluation test files end with test.json
if not filename.endswith('test.json'):
continue
initial_session_file = (
f'tests/integration/fixture/{agent_name}/initial.session.json'
)
agent_eval_artifacts.append((
f'tests.integration.fixture.{agent_name}',
f'tests/integration/fixture/{agent_name}/{filename}',
initial_session_file
if os.path.exists(initial_session_file)
else None,
))
# This method gets invoked twice, sorting helps ensure that both the
# invocations have the same view.
agent_eval_artifacts = sorted(
agent_eval_artifacts, key=lambda item: f'{item[0]}|{item[1]}'
)
return agent_eval_artifacts
@pytest.mark.parametrize(
'agent_name, evalfile, initial_session_file',
agent_eval_artifacts_in_fixture(),
ids=[agent_name for agent_name, _, _ in agent_eval_artifacts_in_fixture()],
)
def test_evaluate_agents_long_running_4_runs_per_eval_item(
agent_name, evalfile, initial_session_file
):
"""Test agents evaluation in fixture folder.
After querying the fixture folder, we have 5 eval items. For each eval item
we use 4 runs.
A single eval item is a session that can have multiple queries in it.
"""
AgentEvaluator.evaluate(
agent_module=agent_name,
eval_dataset_file_path_or_dir=evalfile,
initial_session_file=initial_session_file,
# Using a slightly higher value helps us manange the variances that may
# happen in each eval.
# This, of course, comes at a cost of incrased test run times.
num_runs=4,
)

View File

@@ -1,28 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.evaluation import AgentEvaluator
def test_eval_agent():
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.trip_planner_agent",
eval_dataset_file_path_or_dir=(
"tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json"
),
initial_session_file=(
"tests/integration/fixture/trip_planner_agent/initial.session.json"
),
num_runs=4,
)

View File

@@ -1,42 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.evaluation import AgentEvaluator
def test_simple_multi_turn_conversation():
"""Test a simple multi-turn conversation."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/simple_multi_turn_conversation.test.json",
num_runs=4,
)
def test_dependent_tool_calls():
"""Test subsequent tool calls that are dependent on previous tool calls."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/dependent_tool_calls.test.json",
num_runs=4,
)
def test_memorizing_past_events():
"""Test memorizing past events."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json",
num_runs=4,
)

View File

@@ -1,23 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.evaluation import AgentEvaluator
def test_eval_agent():
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/simple_test.test.json",
num_runs=4,
)

View File

@@ -1,26 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.evaluation import AgentEvaluator
def test_eval_agent():
"""Test hotel sub agent in a multi-agent system."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.trip_planner_agent",
eval_dataset_file_path_or_dir="tests/integration/fixture/trip_planner_agent/test_files/trip_inquiry_sub_agent.test.json",
initial_session_file="tests/integration/fixture/trip_planner_agent/test_files/initial.session.json",
agent_name="identify_agent",
num_runs=4,
)

View File

@@ -1,177 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
# Skip until fixed.
pytest.skip(allow_module_level=True)
from google.adk.agents import InvocationContext
from google.adk.sessions import Session
from google.genai import types
from .fixture import context_variable_agent
from .utils import TestRunner
nl_planner_si = """
You are an intelligent tool use agent built upon the Gemini large language model. When answering the question, try to leverage the available tools to gather the information instead of your memorized knowledge.
Follow this process when answering the question: (1) first come up with a plan in natural language text format; (2) Then use tools to execute the plan and provide reasoning between tool code snippets to make a summary of current state and next step. Tool code snippets and reasoning should be interleaved with each other. (3) In the end, return one final answer.
Follow this format when answering the question: (1) The planning part should be under /*PLANNING*/. (2) The tool code snippets should be under /*ACTION*/, and the reasoning parts should be under /*REASONING*/. (3) The final answer part should be under /*FINAL_ANSWER*/.
Below are the requirements for the planning:
The plan is made to answer the user query if following the plan. The plan is coherent and covers all aspects of information from user query, and only involves the tools that are accessible by the agent. The plan contains the decomposed steps as a numbered list where each step should use one or multiple available tools. By reading the plan, you can intuitively know which tools to trigger or what actions to take.
If the initial plan cannot be successfully executed, you should learn from previous execution results and revise your plan. The revised plan should be be under /*REPLANNING*/. Then use tools to follow the new plan.
Below are the requirements for the reasoning:
The reasoning makes a summary of the current trajectory based on the user query and tool outputs. Based on the tool outputs and plan, the reasoning also comes up with instructions to the next steps, making the trajectory closer to the final answer.
Below are the requirements for the final answer:
The final answer should be precise and follow query formatting requirements. Some queries may not be answerable with the available tools and information. In those cases, inform the user why you cannot process their query and ask for more information.
Below are the requirements for the tool code:
**Custom Tools:** The available tools are described in the context and can be directly used.
- Code must be valid self-contained Python snippets with no imports and no references to tools or Python libraries that are not in the context.
- You cannot use any parameters or fields that are not explicitly defined in the APIs in the context.
- Use "print" to output execution results for the next step or final answer that you need for responding to the user. Never generate ```tool_outputs yourself.
- The code snippets should be readable, efficient, and directly relevant to the user query and reasoning steps.
- When using the tools, you should use the library name together with the function name, e.g., vertex_search.search().
- If Python libraries are not provided in the context, NEVER write your own code other than the function calls using the provided tools.
VERY IMPORTANT instruction that you MUST follow in addition to the above instructions:
You should ask for clarification if you need more information to answer the question.
You should prefer using the information available in the context instead of repeated tool use.
You should ONLY generate code snippets prefixed with "```tool_code" if you need to use the tools to answer the question.
If you are asked to write code by user specifically,
- you should ALWAYS use "```python" to format the code.
- you should NEVER put "tool_code" to format the code.
- Good example:
```python
print('hello')
```
- Bad example:
```tool_code
print('hello')
```
"""
@pytest.mark.parametrize(
"agent_runner",
[{"agent": context_variable_agent.agent.state_variable_echo_agent}],
indirect=True,
)
def test_context_variable(agent_runner: TestRunner):
session = Session(
context={
"customerId": "1234567890",
"customerInt": 30,
"customerFloat": 12.34,
"customerJson": {"name": "John Doe", "age": 30, "count": 11.1},
}
)
si = UnitFlow()._build_system_instruction(
InvocationContext(
invocation_id="1234567890", agent=agent_runner.agent, session=session
)
)
assert (
"Use the echo_info tool to echo 1234567890, 30, 12.34, and {'name': 'John"
" Doe', 'age': 30, 'count': 11.1}. Ask for it if you need to."
in si
)
@pytest.mark.parametrize(
"agent_runner",
[{
"agent": (
context_variable_agent.agent.state_variable_with_complicated_format_agent
)
}],
indirect=True,
)
def test_context_variable_with_complicated_format(agent_runner: TestRunner):
session = Session(
context={"customerId": "1234567890", "customer_int": 30},
artifacts={"fileName": [types.Part(text="test artifact")]},
)
si = _context_formatter.populate_context_and_artifact_variable_values(
agent_runner.agent.instruction,
session.get_state(),
session.get_artifact_dict(),
)
assert (
si
== "Use the echo_info tool to echo 1234567890, 30, { "
" non-identifier-float}}, test artifact, {'key1': 'value1'} and"
" {{'key2': 'value2'}}. Ask for it if you need to."
)
@pytest.mark.parametrize(
"agent_runner",
[{
"agent": (
context_variable_agent.agent.state_variable_with_nl_planner_agent
)
}],
indirect=True,
)
def test_nl_planner(agent_runner: TestRunner):
session = Session(context={"customerId": "1234567890"})
si = UnitFlow()._build_system_instruction(
InvocationContext(
invocation_id="1234567890",
agent=agent_runner.agent,
session=session,
)
)
for line in nl_planner_si.splitlines():
assert line in si
@pytest.mark.parametrize(
"agent_runner",
[{
"agent": (
context_variable_agent.agent.state_variable_with_function_instruction_agent
)
}],
indirect=True,
)
def test_function_instruction(agent_runner: TestRunner):
session = Session(context={"customerId": "1234567890"})
si = UnitFlow()._build_system_instruction(
InvocationContext(
invocation_id="1234567890", agent=agent_runner.agent, session=session
)
)
assert "This is the plain text sub agent instruction." in si

View File

@@ -1,287 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pytest
# Skip until fixed.
pytest.skip(allow_module_level=True)
from .fixture import tool_agent
from .utils import TestRunner
@pytest.mark.parametrize(
"agent_runner",
[{"agent": tool_agent.agent.single_function_agent}],
indirect=True,
)
def test_single_function_calls_success(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"simple_function",
"test",
"success",
)
@pytest.mark.parametrize(
"agent_runner",
[{"agent": tool_agent.agent.root_agent}],
indirect=True,
)
def test_multiple_function_calls_success(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"simple_function",
"test",
"success",
)
_call_function_and_assert(
agent_runner,
"no_param_function",
None,
"Called no param function successfully",
)
_call_function_and_assert(
agent_runner,
"no_output_function",
"test",
"",
)
_call_function_and_assert(
agent_runner,
"multiple_param_types_function",
["test", 1, 2.34, True],
"success",
)
_call_function_and_assert(
agent_runner,
"return_list_str_function",
"test",
"success",
)
_call_function_and_assert(
agent_runner,
"list_str_param_function",
["test", "test2", "test3", "test4"],
"success",
)
@pytest.mark.skip(reason="Currently failing with 400 on MLDev.")
@pytest.mark.parametrize(
"agent_runner",
[{"agent": tool_agent.agent.root_agent}],
indirect=True,
)
def test_complex_function_calls_success(agent_runner: TestRunner):
param1 = {"name": "Test", "count": 3}
param2 = [
{"name": "Function", "count": 2},
{"name": "Retrieval", "count": 1},
]
_call_function_and_assert(
agent_runner,
"complex_function_list_dict",
[param1, param2],
"test",
)
@pytest.mark.parametrize(
"agent_runner",
[{"agent": tool_agent.agent.root_agent}],
indirect=True,
)
def test_repetive_call_success(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"repetive_call_1",
"test",
"test_repetive",
)
@pytest.mark.parametrize(
"agent_runner",
[{"agent": tool_agent.agent.root_agent}],
indirect=True,
)
def test_function_calls_fail(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"throw_error_function",
"test",
None,
ValueError,
)
@pytest.mark.parametrize(
"agent_runner",
[{"agent": tool_agent.agent.root_agent}],
indirect=True,
)
def test_agent_tools_success(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"no_schema_agent",
"Hi",
"Hi",
)
_call_function_and_assert(
agent_runner,
"schema_agent",
"Agent_tools",
"Agent_tools_success",
)
_call_function_and_assert(
agent_runner, "no_input_schema_agent", "Tools", "Tools_success"
)
_call_function_and_assert(agent_runner, "no_output_schema_agent", "Hi", "Hi")
@pytest.mark.parametrize(
"agent_runner",
[{"agent": tool_agent.agent.root_agent}],
indirect=True,
)
def test_files_retrieval_success(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"test_case_retrieval",
"What is the testing strategy of agent 2.0?",
"test",
)
# For non relevant query, the agent should still be running fine, just return
# response might be different for different calls, so we don't compare the
# response here.
_call_function_and_assert(
agent_runner,
"test_case_retrieval",
"What is the whether in bay area?",
"",
)
@pytest.mark.parametrize(
"agent_runner",
[{"agent": tool_agent.agent.root_agent}],
indirect=True,
)
def test_rag_retrieval_success(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"valid_rag_retrieval",
"What is the testing strategy of agent 2.0?",
"test",
)
_call_function_and_assert(
agent_runner,
"valid_rag_retrieval",
"What is the whether in bay area?",
"No",
)
@pytest.mark.parametrize(
"agent_runner",
[{"agent": tool_agent.agent.root_agent}],
indirect=True,
)
def test_rag_retrieval_fail(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"invalid_rag_retrieval",
"What is the testing strategy of agent 2.0?",
None,
ValueError,
)
_call_function_and_assert(
agent_runner,
"non_exist_rag_retrieval",
"What is the whether in bay area?",
None,
ValueError,
)
@pytest.mark.parametrize(
"agent_runner",
[{"agent": tool_agent.agent.root_agent}],
indirect=True,
)
def test_langchain_tool_success(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"terminal",
"Run the following shell command 'echo test!'",
"test",
)
@pytest.mark.parametrize(
"agent_runner",
[{"agent": tool_agent.agent.root_agent}],
indirect=True,
)
def test_crewai_tool_success(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"direcotry_read_tool",
"Find all the file paths",
"file",
)
def _call_function_and_assert(
agent_runner: TestRunner,
function_name: str,
params,
expected=None,
exception: Exception = None,
):
param_section = (
" with params"
f" {params if isinstance(params, str) else json.dumps(params)}"
if params is not None
else ""
)
query = f"Call {function_name}{param_section} and show me the result"
if exception:
_assert_raises(agent_runner, query, exception)
return
_assert_function_output(agent_runner, query, expected)
def _assert_raises(agent_runner: TestRunner, query: str, exception: Exception):
with pytest.raises(exception):
agent_runner.run(query)
def _assert_function_output(agent_runner: TestRunner, query: str, expected):
agent_runner.run(query)
# Retrieve the latest model response event
model_response_event = agent_runner.get_events()[-1]
# Assert the response content
assert model_response_event.content.role == "model"
assert (
expected.lower()
in model_response_event.content.parts[0].text.strip().lower()
)

View File

@@ -1,34 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.evaluation import AgentEvaluator
def test_with_single_test_file():
"""Test the agent's basic ability via session file."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/simple_test.test.json",
)
def test_with_folder_of_test_files_long_running():
"""Test the agent's basic ability via a folder of session files."""
AgentEvaluator.evaluate(
agent_module="tests.integration.fixture.home_automation_agent",
eval_dataset_file_path_or_dir=(
"tests/integration/fixture/home_automation_agent/test_files"
),
num_runs=4,
)

View File

@@ -1,14 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -1,16 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .asserts import *
from .test_runner import TestRunner

View File

@@ -1,75 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TypedDict
from .test_runner import TestRunner
class Message(TypedDict):
agent_name: str
expected_text: str
def assert_current_agent_is(agent_name: str, *, agent_runner: TestRunner):
assert agent_runner.get_current_agent_name() == agent_name
def assert_agent_says(
expected_text: str, *, agent_name: str, agent_runner: TestRunner
):
for event in reversed(agent_runner.get_events()):
if event.author == agent_name and event.content.parts[0].text:
assert event.content.parts[0].text.strip() == expected_text
return
def assert_agent_says_in_order(
expected_conversation: list[Message], agent_runner: TestRunner
):
expected_conversation_idx = len(expected_conversation) - 1
for event in reversed(agent_runner.get_events()):
if event.content.parts and event.content.parts[0].text:
assert (
event.author
== expected_conversation[expected_conversation_idx]['agent_name']
)
assert (
event.content.parts[0].text.strip()
== expected_conversation[expected_conversation_idx]['expected_text']
)
expected_conversation_idx -= 1
if expected_conversation_idx < 0:
return
def assert_agent_transfer_path(
expected_path: list[str], *, agent_runner: TestRunner
):
events = agent_runner.get_events()
idx_in_expected_path = len(expected_path) - 1
# iterate events in reverse order
for event in reversed(events):
function_calls = event.get_function_calls()
if (
len(function_calls) == 1
and function_calls[0].name == 'transfer_to_agent'
):
assert (
function_calls[0].args['agent_name']
== expected_path[idx_in_expected_path]
)
idx_in_expected_path -= 1
if idx_in_expected_path < 0:
return

View File

@@ -1,97 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Optional
from google.adk import Agent
from google.adk import Runner
from google.adk.artifacts import BaseArtifactService
from google.adk.artifacts import InMemoryArtifactService
from google.adk.events import Event
from google.adk.sessions import BaseSessionService
from google.adk.sessions import InMemorySessionService
from google.adk.sessions import Session
from google.genai import types
class TestRunner:
"""Agents runner for testing."""
app_name = "test_app"
user_id = "test_user"
def __init__(
self,
agent: Agent,
artifact_service: BaseArtifactService = InMemoryArtifactService(),
session_service: BaseSessionService = InMemorySessionService(),
) -> None:
self.agent = agent
self.agent_client = Runner(
app_name=self.app_name,
agent=agent,
artifact_service=artifact_service,
session_service=session_service,
)
self.session_service = session_service
self.current_session_id = session_service.create_session(
app_name=self.app_name, user_id=self.user_id
).id
def new_session(self, session_id: Optional[str] = None) -> None:
self.current_session_id = self.session_service.create_session(
app_name=self.app_name, user_id=self.user_id, session_id=session_id
).id
def run(self, prompt: str) -> list[Event]:
current_session = self.session_service.get_session(
app_name=self.app_name,
user_id=self.user_id,
session_id=self.current_session_id,
)
assert current_session is not None
return list(
self.agent_client.run(
user_id=current_session.user_id,
session_id=current_session.id,
new_message=types.Content(
role="user",
parts=[types.Part.from_text(text=prompt)],
),
)
)
def get_current_session(self) -> Optional[Session]:
return self.session_service.get_session(
app_name=self.app_name,
user_id=self.user_id,
session_id=self.current_session_id,
)
def get_events(self) -> list[Event]:
return self.get_current_session().events
@classmethod
def from_agent_name(cls, agent_name: str):
agent_module_path = f"tests.integration.fixture.{agent_name}"
agent_module = importlib.import_module(agent_module_path)
agent: Agent = agent_module.agent.root_agent
return cls(agent)
def get_current_agent_name(self) -> str:
return self.agent_client._find_agent_to_run(
self.get_current_session(), self.agent
).name

View File

@@ -196,11 +196,12 @@ class IntegrationClient:
action_details = connections_client.get_action_schema(action)
input_schema = action_details["inputSchema"]
output_schema = action_details["outputSchema"]
action_display_name = action_details["displayName"]
# Remove spaces from the display name to generate valid spec
action_display_name = action_details["displayName"].replace(" ", "")
operation = "EXECUTE_ACTION"
if action == "ExecuteCustomQuery":
connector_spec["components"]["schemas"][
f"{action}_Request"
f"{action_display_name}_Request"
] = connections_client.execute_custom_query_request()
operation = "EXECUTE_QUERY"
else:

View File

@@ -291,7 +291,7 @@ def _parse_schema_from_parameter(
return schema
raise ValueError(
f'Failed to parse the parameter {param} of function {func_name} for'
' automatic function calling.Automatic function calling works best with'
' automatic function calling. Automatic function calling works best with'
' simpler function signature schema,consider manually parse your'
f' function declaration for function {func_name}.'
)

View File

@@ -11,4 +11,77 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .google_api_tool_sets import calendar_tool_set
__all__ = [
'bigquery_tool_set',
'calendar_tool_set',
'gmail_tool_set',
'youtube_tool_set',
'slides_tool_set',
'sheets_tool_set',
'docs_tool_set',
]
# Nothing is imported here automatically
# Each tool set will only be imported when accessed
_bigquery_tool_set = None
_calendar_tool_set = None
_gmail_tool_set = None
_youtube_tool_set = None
_slides_tool_set = None
_sheets_tool_set = None
_docs_tool_set = None
def __getattr__(name):
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
match name:
case 'bigquery_tool_set':
if _bigquery_tool_set is None:
from .google_api_tool_sets import bigquery_tool_set as bigquery
_bigquery_tool_set = bigquery
return _bigquery_tool_set
case 'calendar_tool_set':
if _calendar_tool_set is None:
from .google_api_tool_sets import calendar_tool_set as calendar
_calendar_tool_set = calendar
return _calendar_tool_set
case 'gmail_tool_set':
if _gmail_tool_set is None:
from .google_api_tool_sets import gmail_tool_set as gmail
_gmail_tool_set = gmail
return _gmail_tool_set
case 'youtube_tool_set':
if _youtube_tool_set is None:
from .google_api_tool_sets import youtube_tool_set as youtube
_youtube_tool_set = youtube
return _youtube_tool_set
case 'slides_tool_set':
if _slides_tool_set is None:
from .google_api_tool_sets import slides_tool_set as slides
_slides_tool_set = slides
return _slides_tool_set
case 'sheets_tool_set':
if _sheets_tool_set is None:
from .google_api_tool_sets import sheets_tool_set as sheets
_sheets_tool_set = sheets
return _sheets_tool_set
case 'docs_tool_set':
if _docs_tool_set is None:
from .google_api_tool_sets import docs_tool_set as docs
_docs_tool_set = docs
return _docs_tool_set

View File

@@ -19,37 +19,94 @@ from .google_api_tool_set import GoogleApiToolSet
logger = logging.getLogger(__name__)
calendar_tool_set = GoogleApiToolSet.load_tool_set(
api_name="calendar",
api_version="v3",
)
_bigquery_tool_set = None
_calendar_tool_set = None
_gmail_tool_set = None
_youtube_tool_set = None
_slides_tool_set = None
_sheets_tool_set = None
_docs_tool_set = None
bigquery_tool_set = GoogleApiToolSet.load_tool_set(
api_name="bigquery",
api_version="v2",
)
gmail_tool_set = GoogleApiToolSet.load_tool_set(
api_name="gmail",
api_version="v1",
)
def __getattr__(name):
"""This method dynamically loads and returns GoogleApiToolSet instances for
youtube_tool_set = GoogleApiToolSet.load_tool_set(
api_name="youtube",
api_version="v3",
)
various Google APIs. It uses a lazy loading approach, initializing each
tool set only when it is first requested. This avoids unnecessary loading
of tool sets that are not used in a given session.
slides_tool_set = GoogleApiToolSet.load_tool_set(
api_name="slides",
api_version="v1",
)
Args:
name (str): The name of the tool set to retrieve (e.g.,
"bigquery_tool_set").
sheets_tool_set = GoogleApiToolSet.load_tool_set(
api_name="sheets",
api_version="v4",
)
Returns:
GoogleApiToolSet: The requested tool set instance.
docs_tool_set = GoogleApiToolSet.load_tool_set(
api_name="docs",
api_version="v1",
)
Raises:
AttributeError: If the requested tool set name is not recognized.
"""
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
match name:
case "bigquery_tool_set":
if _bigquery_tool_set is None:
_bigquery_tool_set = GoogleApiToolSet.load_tool_set(
api_name="bigquery",
api_version="v2",
)
return _bigquery_tool_set
case "calendar_tool_set":
if _calendar_tool_set is None:
_calendar_tool_set = GoogleApiToolSet.load_tool_set(
api_name="calendar",
api_version="v3",
)
return _calendar_tool_set
case "gmail_tool_set":
if _gmail_tool_set is None:
_gmail_tool_set = GoogleApiToolSet.load_tool_set(
api_name="gmail",
api_version="v1",
)
return _gmail_tool_set
case "youtube_tool_set":
if _youtube_tool_set is None:
_youtube_tool_set = GoogleApiToolSet.load_tool_set(
api_name="youtube",
api_version="v3",
)
return _youtube_tool_set
case "slides_tool_set":
if _slides_tool_set is None:
_slides_tool_set = GoogleApiToolSet.load_tool_set(
api_name="slides",
api_version="v1",
)
return _slides_tool_set
case "sheets_tool_set":
if _sheets_tool_set is None:
_sheets_tool_set = GoogleApiToolSet.load_tool_set(
api_name="sheets",
api_version="v4",
)
return _sheets_tool_set
case "docs_tool_set":
if _docs_tool_set is None:
_docs_tool_set = GoogleApiToolSet.load_tool_set(
api_name="docs",
api_version="v1",
)
return _docs_tool_set

View File

@@ -311,7 +311,9 @@ class GoogleApiToOpenApiConverter:
# Determine the actual endpoint path
# Google often has the format something like 'users.messages.list'
rest_path = method_data.get("path", "/")
# flatPath is preferred as it provides the actual path, while path
# might contain variables like {+projectId}
rest_path = method_data.get("flatPath", method_data.get("path", "/"))
if not rest_path.startswith("/"):
rest_path = "/" + rest_path

View File

@@ -16,18 +16,26 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from .function_tool import FunctionTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..models import LlmRequest
from ..memory.base_memory_service import MemoryResult
from ..models import LlmRequest
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
"""Loads the memory for the current user."""
"""Loads the memory for the current user.
Args:
query: The query to load the memory for.
Returns:
A list of memory results.
"""
response = tool_context.search_memory(query)
return response.memories
@@ -38,6 +46,21 @@ class LoadMemoryTool(FunctionTool):
def __init__(self):
super().__init__(load_memory)
@override
def _get_declaration(self) -> types.FunctionDeclaration | None:
return types.FunctionDeclaration(
name=self.name,
description=self.description,
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
'query': types.Schema(
type=types.Type.STRING,
)
},
),
)
@override
async def process_llm_request(
self,

View File

@@ -0,0 +1,176 @@
from contextlib import AsyncExitStack
import functools
import sys
from typing import Any, TextIO
import anyio
from pydantic import BaseModel
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
except ImportError as e:
import sys
if sys.version_info < (3, 10):
raise ImportError(
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
' version.'
) from e
else:
raise e
class SseServerParams(BaseModel):
"""Parameters for the MCP SSE connection.
See MCP SSE Client documentation for more details.
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
"""
url: str
headers: dict[str, Any] | None = None
timeout: float = 5
sse_read_timeout: float = 60 * 5
def retry_on_closed_resource(async_reinit_func_name: str):
"""Decorator to automatically reinitialize session and retry action.
When MCP session was closed, the decorator will automatically recreate the
session and retry the action with the same parameters.
Note:
1. async_reinit_func_name is the name of the class member function that
reinitializes the MCP session.
2. Both the decorated function and the async_reinit_func_name must be async
functions.
Usage:
class MCPTool:
...
async def create_session(self):
self.session = ...
@retry_on_closed_resource('create_session')
async def use_session(self):
await self.session.call_tool()
Args:
async_reinit_func_name: The name of the async function to recreate session.
Returns:
The decorated function.
"""
def decorator(func):
@functools.wraps(
func
) # Preserves original function metadata (name, docstring)
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except anyio.ClosedResourceError:
try:
if hasattr(self, async_reinit_func_name) and callable(
getattr(self, async_reinit_func_name)
):
async_init_fn = getattr(self, async_reinit_func_name)
await async_init_fn()
else:
raise ValueError(
f'Function {async_reinit_func_name} does not exist in decorated'
' class. Please check the function name in'
' retry_on_closed_resource decorator.'
)
except Exception as reinit_err:
raise RuntimeError(
f'Error reinitializing: {reinit_err}'
) from reinit_err
return await func(self, *args, **kwargs)
return wrapper
return decorator
class MCPSessionManager:
"""Manages MCP client sessions.
This class provides methods for creating and initializing MCP client sessions,
handling different connection parameters (Stdio and SSE).
"""
def __init__(
self,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
) -> ClientSession:
"""Initializes the MCP session manager.
Example usage:
```
mcp_session_manager = MCPSessionManager(
connection_params=connection_params,
exit_stack=exit_stack,
)
session = await mcp_session_manager.create_session()
```
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
"""
self.connection_params = connection_params
self.exit_stack = exit_stack
self.errlog = errlog
async def create_session(self) -> ClientSession:
return await MCPSessionManager.initialize_session(
connection_params=self.connection_params,
exit_stack=self.exit_stack,
errlog=self.errlog,
)
@classmethod
async def initialize_session(
cls,
*,
connection_params: StdioServerParameters | SseServerParams,
exit_stack: AsyncExitStack,
errlog: TextIO = sys.stderr,
) -> ClientSession:
"""Initializes an MCP client session.
Args:
connection_params: Parameters for the MCP connection (Stdio or SSE).
exit_stack: AsyncExitStack to manage the session lifecycle.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
Returns:
ClientSession: The initialized MCP client session.
"""
if isinstance(connection_params, StdioServerParameters):
client = stdio_client(server=connection_params, errlog=errlog)
elif isinstance(connection_params, SseServerParams):
client = sse_client(
url=connection_params.url,
headers=connection_params.headers,
timeout=connection_params.timeout,
sse_read_timeout=connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {connection_params}'
)
transports = await exit_stack.enter_async_context(client)
session = await exit_stack.enter_async_context(ClientSession(*transports))
await session.initialize()
return session

View File

@@ -17,6 +17,8 @@ from typing import Optional
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from .mcp_session_manager import MCPSessionManager, retry_on_closed_resource
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
@@ -33,6 +35,7 @@ except ImportError as e:
else:
raise e
from ..base_tool import BaseTool
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
@@ -51,6 +54,7 @@ class MCPTool(BaseTool):
self,
mcp_tool: McpBaseTool,
mcp_session: ClientSession,
mcp_session_manager: MCPSessionManager,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] | None = None,
):
@@ -79,10 +83,14 @@ class MCPTool(BaseTool):
self.description = mcp_tool.description if mcp_tool.description else ""
self.mcp_tool = mcp_tool
self.mcp_session = mcp_session
self.mcp_session_manager = mcp_session_manager
# TODO(cheliu): Support passing auth to MCP Server.
self.auth_scheme = auth_scheme
self.auth_credential = auth_credential
async def _reinitialize_session(self):
self.mcp_session = await self.mcp_session_manager.create_session()
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Gets the function declaration for the tool.
@@ -98,6 +106,7 @@ class MCPTool(BaseTool):
return function_decl
@override
@retry_on_closed_resource("_reinitialize_session")
async def run_async(self, *, args, tool_context: ToolContext):
"""Runs the tool asynchronously.
@@ -109,5 +118,9 @@ class MCPTool(BaseTool):
Any: The response from the tool.
"""
# TODO(cheliu): Support passing tool context to MCP Server.
response = await self.mcp_session.call_tool(self.name, arguments=args)
return response
try:
response = await self.mcp_session.call_tool(self.name, arguments=args)
return response
except Exception as e:
print(e)
raise e

View File

@@ -13,15 +13,16 @@
# limitations under the License.
from contextlib import AsyncExitStack
import sys
from types import TracebackType
from typing import Any, List, Optional, Tuple, Type
from typing import List, Optional, TextIO, Tuple, Type
from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.types import ListToolsResult
except ImportError as e:
import sys
@@ -34,18 +35,9 @@ except ImportError as e:
else:
raise e
from pydantic import BaseModel
from .mcp_tool import MCPTool
class SseServerParams(BaseModel):
url: str
headers: dict[str, Any] | None = None
timeout: float = 5
sse_read_timeout: float = 60 * 5
class MCPToolset:
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
@@ -110,7 +102,11 @@ class MCPToolset:
"""
def __init__(
self, *, connection_params: StdioServerParameters | SseServerParams
self,
*,
connection_params: StdioServerParameters | SseServerParams,
errlog: TextIO = sys.stderr,
exit_stack=AsyncExitStack(),
):
"""Initializes the MCPToolset.
@@ -175,7 +171,14 @@ class MCPToolset:
if not connection_params:
raise ValueError('Missing connection params in MCPToolset.')
self.connection_params = connection_params
self.exit_stack = AsyncExitStack()
self.errlog = errlog
self.exit_stack = exit_stack
self.session_manager = MCPSessionManager(
connection_params=self.connection_params,
exit_stack=self.exit_stack,
errlog=self.errlog,
)
@classmethod
async def from_server(
@@ -183,6 +186,7 @@ class MCPToolset:
*,
connection_params: StdioServerParameters | SseServerParams,
async_exit_stack: Optional[AsyncExitStack] = None,
errlog: TextIO = sys.stderr,
) -> Tuple[List[MCPTool], AsyncExitStack]:
"""Retrieve all tools from the MCP connection.
@@ -209,41 +213,27 @@ class MCPToolset:
the MCP server. Use `await async_exit_stack.aclose()` to close the
connection when server shuts down.
"""
toolset = cls(connection_params=connection_params)
async_exit_stack = async_exit_stack or AsyncExitStack()
toolset = cls(
connection_params=connection_params,
exit_stack=async_exit_stack,
errlog=errlog,
)
await async_exit_stack.enter_async_context(toolset)
tools = await toolset.load_tools()
return (tools, async_exit_stack)
async def _initialize(self) -> ClientSession:
"""Connects to the MCP Server and initializes the ClientSession."""
if isinstance(self.connection_params, StdioServerParameters):
client = stdio_client(self.connection_params)
elif isinstance(self.connection_params, SseServerParams):
client = sse_client(
url=self.connection_params.url,
headers=self.connection_params.headers,
timeout=self.connection_params.timeout,
sse_read_timeout=self.connection_params.sse_read_timeout,
)
else:
raise ValueError(
'Unable to initialize connection. Connection should be'
' StdioServerParameters or SseServerParams, but got'
f' {self.connection_params}'
)
transports = await self.exit_stack.enter_async_context(client)
self.session = await self.exit_stack.enter_async_context(
ClientSession(*transports)
)
await self.session.initialize()
self.session = await self.session_manager.create_session()
return self.session
async def _exit(self):
"""Closes the connection to MCP Server."""
await self.exit_stack.aclose()
@retry_on_closed_resource('_initialize')
async def load_tools(self) -> List[MCPTool]:
"""Loads all tools from the MCP Server.
@@ -252,7 +242,11 @@ class MCPToolset:
"""
tools_response: ListToolsResult = await self.session.list_tools()
return [
MCPTool(mcp_tool=tool, mcp_session=self.session)
MCPTool(
mcp_tool=tool,
mcp_session=self.session,
mcp_session_manager=self.session_manager,
)
for tool in tools_response.tools
]

View File

@@ -28,7 +28,7 @@ from typing_extensions import override
from ....auth.auth_credential import AuthCredential
from ....auth.auth_schemes import AuthScheme
from ....tools import BaseTool
from ....tools.base_tool import BaseTool
from ...tool_context import ToolContext
from ..auth.auth_helpers import credential_to_param
from ..auth.auth_helpers import dict_to_auth_scheme

View File

@@ -13,4 +13,4 @@
# limitations under the License.
# version: date+base_cl
__version__ = "0.1.0"
__version__ = "0.1.1"