Agent Development Kit(ADK)

An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
hangfei
2025-04-08 17:22:09 +00:00
parent f92478bd5c
commit 9827820143
299 changed files with 44398 additions and 2 deletions
+15
View File
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cli_tools_click import main
+18
View File
@@ -0,0 +1,18 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cli_tools_click import main
if __name__ == '__main__':
main()
+122
View File
@@ -0,0 +1,122 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import graphviz
from ..agents import BaseAgent
from ..agents.llm_agent import LlmAgent
from ..tools.agent_tool import AgentTool
from ..tools.base_tool import BaseTool
from ..tools.function_tool import FunctionTool
from ..tools.retrieval.base_retrieval_tool import BaseRetrievalTool
def build_graph(graph, agent: BaseAgent, highlight_pairs):
dark_green = '#0F5223'
light_green = '#69CB87'
light_gray = '#cccccc'
def get_node_name(tool_or_agent: Union[BaseAgent, BaseTool]):
if isinstance(tool_or_agent, BaseAgent):
return tool_or_agent.name
elif isinstance(tool_or_agent, BaseTool):
return tool_or_agent.name
else:
raise ValueError(f'Unsupported tool type: {tool_or_agent}')
def get_node_caption(tool_or_agent: Union[BaseAgent, BaseTool]):
if isinstance(tool_or_agent, BaseAgent):
return '🤖 ' + tool_or_agent.name
elif isinstance(tool_or_agent, BaseRetrievalTool):
return '🔎 ' + tool_or_agent.name
elif isinstance(tool_or_agent, FunctionTool):
return '🔧 ' + tool_or_agent.name
elif isinstance(tool_or_agent, AgentTool):
return '🤖 ' + tool_or_agent.name
elif isinstance(tool_or_agent, BaseTool):
return '🔧 ' + tool_or_agent.name
else:
raise ValueError(f'Unsupported tool type: {type(tool)}')
def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]):
if isinstance(tool_or_agent, BaseAgent):
return 'ellipse'
elif isinstance(tool_or_agent, BaseRetrievalTool):
return 'cylinder'
elif isinstance(tool_or_agent, FunctionTool):
return 'box'
elif isinstance(tool_or_agent, BaseTool):
return 'box'
else:
raise ValueError(f'Unsupported tool type: {type(tool_or_agent)}')
def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
name = get_node_name(tool_or_agent)
shape = get_node_shape(tool_or_agent)
caption = get_node_caption(tool_or_agent)
if highlight_pairs:
for highlight_tuple in highlight_pairs:
if name in highlight_tuple:
graph.node(
name,
caption,
style='filled,rounded',
fillcolor=dark_green,
color=dark_green,
shape=shape,
fontcolor=light_gray,
)
return
# if not in highlight, draw non-highliht node
graph.node(
name,
caption,
shape=shape,
style='rounded',
color=light_gray,
fontcolor=light_gray,
)
def draw_edge(from_name, to_name):
if highlight_pairs:
for highlight_from, highlight_to in highlight_pairs:
if from_name == highlight_from and to_name == highlight_to:
graph.edge(from_name, to_name, color=light_green)
return
elif from_name == highlight_to and to_name == highlight_from:
graph.edge(from_name, to_name, color=light_green, dir='back')
return
# if no need to highlight, color gray
graph.edge(from_name, to_name, arrowhead='none', color=light_gray)
draw_node(agent)
for sub_agent in agent.sub_agents:
build_graph(graph, sub_agent, highlight_pairs)
draw_edge(agent.name, sub_agent.name)
if isinstance(agent, LlmAgent):
for tool in agent.canonical_tools:
draw_node(tool)
draw_edge(agent.name, get_node_name(tool))
def get_agent_graph(root_agent, highlights_pairs, image=False):
print('build graph')
graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'})
build_graph(graph, root_agent, highlights_pairs)
if image:
return graph.pipe(format='png')
else:
return graph
+181
View File
@@ -0,0 +1,181 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime
import importlib
import os
import sys
from typing import Optional
import click
from google.genai import types
from pydantic import BaseModel
from ..agents.llm_agent import LlmAgent
from ..artifacts import BaseArtifactService
from ..artifacts import InMemoryArtifactService
from ..runners import Runner
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from .utils import envs
class InputFile(BaseModel):
state: dict[str, object]
queries: list[str]
async def run_input_file(
app_name: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
input_path: str,
) -> None:
runner = Runner(
app_name=app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
with open(input_path, 'r', encoding='utf-8') as f:
input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now()
session.state = input_file.state
for query in input_file.queries:
click.echo(f'user: {query}')
content = types.Content(role='user', parts=[types.Part(text=query)])
async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
):
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
async def run_interactively(
app_name: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
) -> None:
runner = Runner(
app_name=app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
while True:
query = input('user: ')
if query == 'exit':
break
async for event in runner.run_async(
user_id=session.user_id,
session_id=session.id,
new_message=types.Content(role='user', parts=[types.Part(text=query)]),
):
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
async def run_cli(
*,
agent_parent_dir: str,
agent_folder_name: str,
json_file_path: Optional[str] = None,
save_session: bool,
) -> None:
"""Runs an interactive CLI for a certain agent.
Args:
agent_parent_dir: str, the absolute path of the parent folder of the agent
folder.
agent_folder_name: str, the name of the agent folder.
json_file_path: Optional[str], the absolute path to the json file, either
*.input.json or *.session.json.
save_session: bool, whether to save the session on exit.
"""
if agent_parent_dir not in sys.path:
sys.path.append(agent_parent_dir)
artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService()
session = session_service.create_session(
app_name=agent_folder_name, user_id='test_user'
)
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name)
root_agent = agent_module.agent.root_agent
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
if json_file_path:
if json_file_path.endswith('.input.json'):
await run_input_file(
app_name=agent_folder_name,
root_agent=root_agent,
artifact_service=artifact_service,
session=session,
session_service=session_service,
input_path=json_file_path,
)
elif json_file_path.endswith('.session.json'):
with open(json_file_path, 'r') as f:
session = Session.model_validate_json(f.read())
for content in session.get_contents():
if content.role == 'user':
print('user: ', content.parts[0].text)
else:
print(content.parts[0].text)
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,
session_service,
)
else:
print(f'Unsupported file type: {json_file_path}')
exit(1)
else:
print(f'Running agent {root_agent.name}, type exit to exit.')
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,
session_service,
)
if save_session:
if json_file_path:
session_path = json_file_path.replace('.input.json', '.session.json')
else:
session_id = input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json'
with open(session_path, 'w') as f:
f.write(session.model_dump_json(indent=2, exclude_none=True))
# TODO: Save from opentelemetry.
# logs_path = session_path.replace('.session.json', '.logs.json')
# with open(logs_path, 'w') as f:
# f.write(
# session.model_dump_json(
# indent=2, exclude_none=True, include='event_logs'
# )
# )
print('Session saved to', session_path)
+181
View File
@@ -0,0 +1,181 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import subprocess
from typing import Optional
import click
_DOCKERFILE_TEMPLATE = """
FROM python:3.11-slim
WORKDIR /app
# Create a non-root user
RUN adduser --disabled-password --gecos "" myuser
# Change ownership of /app to myuser
RUN chown -R myuser:myuser /app
# Switch to the non-root user
USER myuser
# Set up environment variables - Start
ENV PATH="/home/myuser/.local/bin:$PATH"
ENV GOOGLE_GENAI_USE_VERTEXAI=1
# TODO: use passed-in value
ENV GOOGLE_CLOUD_PROJECT={gcp_project_id}
ENV GOOGLE_CLOUD_LOCATION={gcp_region}
ENV ADK_TRACE_TO_CLOUD={with_cloud_trace}
# Set up environment variables - End
# Install ADK - Start
RUN pip install google-adk
# Install ADK - End
# Copy agent - Start
COPY "agents/{app_name}/" "/app/agents/{app_name}/"
{install_agent_deps}
# Copy agent - End
EXPOSE {port}
CMD adk {command} --port={port} "/app/agents"
"""
def _resolve_project(project_in_option: Optional[str]) -> str:
if project_in_option:
return project_in_option
result = subprocess.run(
['gcloud', 'config', 'get-value', 'project'],
check=True,
capture_output=True,
text=True,
)
project = result.stdout.strip()
click.echo(f'Use default project: {project}')
return project
def to_cloud_run(
*,
agent_folder: str,
project: Optional[str],
region: Optional[str],
service_name: str,
app_name: str,
temp_folder: str,
port: int,
with_cloud_trace: bool,
with_ui: bool,
):
"""Deploys an agent to Google Cloud Run.
`agent_folder` should contain the following files:
- __init__.py
- agent.py
- requirements.txt (optional, for additional dependencies)
- ... (other required source files)
The folder structure of temp_folder will be
* dist/[google_adk wheel file]
* agents/[app_name]/
* agent source code from `agent_folder`
Args:
agent_folder: The folder (absolute path) containing the agent source code.
project: Google Cloud project id.
region: Google Cloud region.
service_name: The service name in Cloud Run.
app_name: The name of the app, by default, it's basename of `agent_folder`.
temp_folder: The temp folder for the generated Cloud Run source files.
port: The port of the ADK api server.
with_cloud_trace: Whether to enable Cloud Trace.
with_ui: Whether to deploy with UI.
"""
app_name = app_name or os.path.basename(agent_folder)
click.echo(f'Start generating Cloud Run source files in {temp_folder}')
# remove temp_folder if exists
if os.path.exists(temp_folder):
click.echo('Removing existing files')
shutil.rmtree(temp_folder)
try:
# copy agent source code
click.echo('Copying agent source code...')
agent_src_path = os.path.join(temp_folder, 'agents', app_name)
shutil.copytree(agent_folder, agent_src_path)
requirements_txt_path = os.path.join(agent_src_path, 'requirements.txt')
install_agent_deps = (
f'RUN pip install -r "/app/agents/{app_name}/requirements.txt"'
if os.path.exists(requirements_txt_path)
else ''
)
click.echo('Copying agent source code complete.')
# create Dockerfile
click.echo('Creating Dockerfile...')
dockerfile_content = _DOCKERFILE_TEMPLATE.format(
gcp_project_id=project,
gcp_region=region,
app_name=app_name,
port=port,
command='web' if with_ui else 'api_server',
install_agent_deps=install_agent_deps,
with_cloud_trace='1' if with_cloud_trace else '0',
)
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
os.makedirs(temp_folder, exist_ok=True)
with open(dockerfile_path, 'w', encoding='utf-8') as f:
f.write(
dockerfile_content,
)
click.echo(f'Creating Dockerfile complete: {dockerfile_path}')
# Deploy to Cloud Run
click.echo('Deploying to Cloud Run...')
region_options = ['--region', region] if region else []
project = _resolve_project(project)
subprocess.run(
[
'gcloud',
'run',
'deploy',
service_name,
'--source',
temp_folder,
'--project',
project,
*region_options,
'--port',
str(port),
'--labels',
'created-by=adk',
],
check=True,
)
finally:
click.echo(f'Cleaning up the temp folder: {temp_folder}')
shutil.rmtree(temp_folder)
+282
View File
@@ -0,0 +1,282 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
import importlib.util
import json
import logging
import os
import sys
import traceback
from typing import Any
from typing import Generator
from typing import Optional
import uuid
from pydantic import BaseModel
from ..agents import Agent
logger = logging.getLogger(__name__)
class EvalStatus(Enum):
PASSED = 1
FAILED = 2
NOT_EVALUATED = 3
class EvalMetric(BaseModel):
metric_name: str
threshold: float
class EvalMetricResult(BaseModel):
score: Optional[float]
eval_status: EvalStatus
class EvalResult(BaseModel):
eval_set_file: str
eval_id: str
final_eval_status: EvalStatus
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
session_id: str
MISSING_EVAL_DEPENDENCIES_MESSAGE = (
"Eval module is not installed, please install via `pip install"
" google-adk[eval]`."
)
TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score"
RESPONSE_MATCH_SCORE_KEY = "response_match_score"
# This evaluation is not very stable.
# This is always optional unless explicitly specified.
RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"
EVAL_SESSION_ID_PREFIX = "___eval___session___"
DEFAULT_CRITERIA = {
TOOL_TRAJECTORY_SCORE_KEY: 1.0, # 1-point scale; 1.0 is perfect.
RESPONSE_MATCH_SCORE_KEY: 0.8,
}
def _import_from_path(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def _get_agent_module(agent_module_file_path: str):
file_path = os.path.join(agent_module_file_path, "__init__.py")
module_name = "agent"
return _import_from_path(module_name, file_path)
def get_evaluation_criteria_or_default(
eval_config_file_path: str,
) -> dict[str, float]:
"""Returns evaluation criteria from the config file, if present.
Otherwise a default one is returned.
"""
if eval_config_file_path:
with open(eval_config_file_path, "r", encoding="utf-8") as f:
config_data = json.load(f)
if "criteria" in config_data and isinstance(config_data["criteria"], dict):
evaluation_criteria = config_data["criteria"]
else:
raise ValueError(
f"Invalid format for test_config.json at {eval_config_file_path}."
" Expected a 'criteria' dictionary."
)
else:
logger.info("No config file supplied. Using default criteria.")
evaluation_criteria = DEFAULT_CRITERIA
return evaluation_criteria
def get_root_agent(agent_module_file_path: str) -> Agent:
"""Returns root agent given the agetn module."""
agent_module = _get_agent_module(agent_module_file_path)
root_agent = agent_module.agent.root_agent
return root_agent
def try_get_reset_func(agent_module_file_path: str) -> Any:
"""Returns reset function for the agent, if present, given the agetn module."""
agent_module = _get_agent_module(agent_module_file_path)
reset_func = getattr(agent_module.agent, "reset_data", None)
return reset_func
def parse_and_get_evals_to_run(
eval_set_file_path: tuple[str],
) -> dict[str, list[str]]:
"""Returns a dictionary of eval sets to evals that should be run."""
eval_set_to_evals = {}
for input_eval_set in eval_set_file_path:
evals = []
if ":" not in input_eval_set:
eval_set_file = input_eval_set
else:
eval_set_file = input_eval_set.split(":")[0]
evals = input_eval_set.split(":")[1].split(",")
if eval_set_file not in eval_set_to_evals:
eval_set_to_evals[eval_set_file] = []
eval_set_to_evals[eval_set_file].extend(evals)
return eval_set_to_evals
def run_evals(
eval_set_to_evals: dict[str, list[str]],
root_agent: Agent,
reset_func: Optional[Any],
eval_metrics: list[EvalMetric],
session_service=None,
artifact_service=None,
print_detailed_results=False,
) -> Generator[EvalResult, None, None]:
try:
from ..evaluation.agent_evaluator import EvaluationGenerator
from ..evaluation.response_evaluator import ResponseEvaluator
from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
except ModuleNotFoundError as e:
raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
"""Returns a summary of eval runs."""
for eval_set_file, evals_to_run in eval_set_to_evals.items():
with open(eval_set_file, "r", encoding="utf-8") as file:
eval_items = json.load(file) # Load JSON into a list
assert eval_items, f"No eval data found in eval set file: {eval_set_file}"
for eval_item in eval_items:
eval_name = eval_item["name"]
eval_data = eval_item["data"]
initial_session = eval_item.get("initial_session", {})
if evals_to_run and eval_name not in evals_to_run:
continue
try:
print(f"Running Eval: {eval_set_file}:{eval_name}")
session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"
scrape_result = EvaluationGenerator._process_query_with_root_agent(
data=eval_data,
root_agent=root_agent,
reset_func=reset_func,
initial_session=initial_session,
session_id=session_id,
session_service=session_service,
artifact_service=artifact_service,
)
eval_metric_results = []
for eval_metric in eval_metrics:
eval_metric_result = None
if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
score = TrajectoryEvaluator.evaluate(
[scrape_result], print_detailed_results=print_detailed_results
)
eval_metric_result = _get_eval_metric_result(eval_metric, score)
elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
score = ResponseEvaluator.evaluate(
[scrape_result],
[RESPONSE_MATCH_SCORE_KEY],
print_detailed_results=print_detailed_results,
)
eval_metric_result = _get_eval_metric_result(
eval_metric, score["rouge_1/mean"].item()
)
elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
score = ResponseEvaluator.evaluate(
[scrape_result],
[RESPONSE_EVALUATION_SCORE_KEY],
print_detailed_results=print_detailed_results,
)
eval_metric_result = _get_eval_metric_result(
eval_metric, score["coherence/mean"].item()
)
else:
logger.warning("`%s` is not supported.", eval_metric.metric_name)
eval_metric_results.append((
eval_metric,
EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED),
))
eval_metric_results.append((
eval_metric,
eval_metric_result,
))
_print_eval_metric_result(eval_metric, eval_metric_result)
final_eval_status = EvalStatus.NOT_EVALUATED
# Go over the all the eval statuses and mark the final eval status as
# passed if all of them pass, otherwise mark the final eval status to
# failed.
for eval_metric_result in eval_metric_results:
eval_status = eval_metric_result[1].eval_status
if eval_status == EvalStatus.PASSED:
final_eval_status = EvalStatus.PASSED
elif eval_status == EvalStatus.NOT_EVALUATED:
continue
elif eval_status == EvalStatus.FAILED:
final_eval_status = EvalStatus.FAILED
break
else:
raise ValueError("Unknown eval status.")
yield EvalResult(
eval_set_file=eval_set_file,
eval_id=eval_name,
final_eval_status=final_eval_status,
eval_metric_results=eval_metric_results,
session_id=session_id,
)
if final_eval_status == EvalStatus.PASSED:
result = "✅ Passsed"
else:
result = "❌ Failed"
print(f"Result: {result}\n")
except Exception as e:
print(f"Error: {e}")
logger.info("Error: %s", str(traceback.format_exc()))
def _get_eval_metric_result(eval_metric, score):
eval_status = (
EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
)
return EvalMetricResult(score=score, eval_status=eval_status)
def _print_eval_metric_result(eval_metric, eval_metric_result):
print(
f"Metric: {eval_metric.metric_name}\tStatus:"
f" {eval_metric_result.eval_status}\tScore:"
f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
)
+479
View File
@@ -0,0 +1,479 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from datetime import datetime
import logging
import os
import tempfile
from typing import Optional
import click
import uvicorn
from . import cli_deploy
from .cli import run_cli
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
from .fast_api import get_fast_api_app
from .utils import envs
from .utils import logs
logger = logging.getLogger(__name__)
@click.group(context_settings={"max_content_width": 240})
def main():
"""Agent Development Kit CLI tools."""
pass
@main.group()
def deploy():
"""Deploy Agent."""
pass
@main.command("run")
@click.option(
"--save_session",
type=bool,
is_flag=True,
show_default=True,
default=False,
help="Optional. Whether to save the session to a json file on exit.",
)
@click.argument(
"agent",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
)
def cli_run(agent: str, save_session: bool):
"""Run an interactive CLI for a certain agent.
AGENT: The path to the agent source code folder.
Example:
adk run path/to/my_agent
"""
logs.log_to_tmp_folder()
agent_parent_folder = os.path.dirname(agent)
agent_folder_name = os.path.basename(agent)
asyncio.run(
run_cli(
agent_parent_dir=agent_parent_folder,
agent_folder_name=agent_folder_name,
save_session=save_session,
)
)
@main.command("eval")
@click.argument(
"agent_module_file_path",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
)
@click.argument("eval_set_file_path", nargs=-1)
@click.option("--config_file_path", help="Optional. The path to config file.")
@click.option(
"--print_detailed_results",
is_flag=True,
show_default=True,
default=False,
help="Optional. Whether to print detailed results on console or not.",
)
def eval_command(
agent_module_file_path: str,
eval_set_file_path: tuple[str],
config_file_path: str,
print_detailed_results: bool,
):
"""Evaluates an agent given the eval sets.
AGENT_MODULE_FILE_PATH: The path to the __init__.py file that contains a
module by the name "agent". "agent" module contains a root_agent.
EVAL_SET_FILE_PATH: You can specify one or more eval set file paths.
For each file, all evals will be run by default.
If you want to run only specific evals from a eval set, first create a comma
separated list of eval names and then add that as a suffix to the eval set
file name, demarcated by a `:`.
For example,
sample_eval_set_file.json:eval_1,eval_2,eval_3
This will only run eval_1, eval_2 and eval_3 from sample_eval_set_file.json.
CONFIG_FILE_PATH: The path to config file.
PRINT_DETAILED_RESULTS: Prints detailed results on the console.
"""
envs.load_dotenv_for_agent(agent_module_file_path, ".")
try:
from .cli_eval import EvalMetric
from .cli_eval import EvalResult
from .cli_eval import EvalStatus
from .cli_eval import get_evaluation_criteria_or_default
from .cli_eval import get_root_agent
from .cli_eval import parse_and_get_evals_to_run
from .cli_eval import run_evals
from .cli_eval import try_get_reset_func
except ModuleNotFoundError:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
evaluation_criteria = get_evaluation_criteria_or_default(config_file_path)
eval_metrics = []
for metric_name, threshold in evaluation_criteria.items():
eval_metrics.append(
EvalMetric(metric_name=metric_name, threshold=threshold)
)
print(f"Using evaluation creiteria: {evaluation_criteria}")
root_agent = get_root_agent(agent_module_file_path)
reset_func = try_get_reset_func(agent_module_file_path)
eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
try:
eval_results = list(
run_evals(
eval_set_to_evals,
root_agent,
reset_func,
eval_metrics,
print_detailed_results=print_detailed_results,
)
)
except ModuleNotFoundError:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
print("*********************************************************************")
eval_run_summary = {}
for eval_result in eval_results:
eval_result: EvalResult
if eval_result.eval_set_file not in eval_run_summary:
eval_run_summary[eval_result.eval_set_file] = [0, 0]
if eval_result.final_eval_status == EvalStatus.PASSED:
eval_run_summary[eval_result.eval_set_file][0] += 1
else:
eval_run_summary[eval_result.eval_set_file][1] += 1
print("Eval Run Summary")
for eval_set_file, pass_fail_count in eval_run_summary.items():
print(
f"{eval_set_file}:\n Tests passed: {pass_fail_count[0]}\n Tests"
f" failed: {pass_fail_count[1]}"
)
@main.command("web")
@click.option(
"--session_db_url",
help=(
"Optional. The database URL to store the session.\n\n - Use"
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
" to connect to a SQLite DB.\n\n - See"
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
" for more details on supported DB URLs."
),
)
@click.option(
"--port",
type=int,
help="Optional. The port of the server",
default=8000,
)
@click.option(
"--allow_origins",
help="Optional. Any additional origins to allow for CORS.",
multiple=True,
)
@click.option(
"--log_level",
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
default="INFO",
help="Optional. Set the logging level",
)
@click.option(
"--log_to_tmp",
is_flag=True,
show_default=True,
default=False,
help=(
"Optional. Whether to log to system temp folder instead of console."
" This is useful for local debugging."
),
)
@click.argument(
"agents_dir",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
default=os.getcwd(),
)
def web(
agents_dir: str,
log_to_tmp: bool,
session_db_url: str = "",
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
port: int = 8000,
):
"""Start a FastAPI server with web UI for a certain agent.
AGENTS_DIR: The directory of agents, where each sub-directory is a single
agent, containing at least `__init__.py` and `agent.py` files.
Example:
adk web --session_db_url=[db_url] --port=[port] path/to/agents_dir
"""
if log_to_tmp:
logs.log_to_tmp_folder()
else:
logs.log_to_stderr()
logging.getLogger().setLevel(log_level)
config = uvicorn.Config(
get_fast_api_app(
agent_dir=agents_dir,
session_db_url=session_db_url,
allow_origins=allow_origins,
web=True,
),
host="0.0.0.0",
port=port,
reload=True,
)
server = uvicorn.Server(config)
server.run()
@main.command("api_server")
@click.option(
"--session_db_url",
help=(
"Optional. The database URL to store the session.\n\n - Use"
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
" to connect to a SQLite DB.\n\n - See"
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
" for more details on supported DB URLs."
),
)
@click.option(
"--port",
type=int,
help="Optional. The port of the server",
default=8000,
)
@click.option(
"--allow_origins",
help="Optional. Any additional origins to allow for CORS.",
multiple=True,
)
@click.option(
"--log_level",
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
default="INFO",
help="Optional. Set the logging level",
)
@click.option(
"--log_to_tmp",
is_flag=True,
show_default=True,
default=False,
help=(
"Optional. Whether to log to system temp folder instead of console."
" This is useful for local debugging."
),
)
# The directory of agents, where each sub-directory is a single agent.
# By default, it is the current working directory
@click.argument(
"agents_dir",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
default=os.getcwd(),
)
def cli_api_server(
agents_dir: str,
log_to_tmp: bool,
session_db_url: str = "",
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
port: int = 8000,
):
"""Start an api server for a certain agent.
AGENTS_DIR: The directory of agents, where each sub-directory is a single
agent, containing at least `__init__.py` and `agent.py` files.
Example:
adk api_server --session_db_url=[db_url] --port=[port] path/to/agents_dir
"""
if log_to_tmp:
logs.log_to_tmp_folder()
else:
logs.log_to_stderr()
logging.getLogger().setLevel(log_level)
config = uvicorn.Config(
get_fast_api_app(
agent_dir=agents_dir,
session_db_url=session_db_url,
allow_origins=allow_origins,
web=False,
),
host="0.0.0.0",
port=port,
reload=True,
)
server = uvicorn.Server(config)
server.run()
@deploy.command("cloud_run")
@click.option(
"--project",
type=str,
help=(
"Required. Google Cloud project to deploy the agent. When absent,"
" default project from gcloud config is used."
),
)
@click.option(
"--region",
type=str,
help=(
"Required. Google Cloud region to deploy the agent. When absent,"
" gcloud run deploy will prompt later."
),
)
@click.option(
"--service_name",
type=str,
default="adk-default-service-name",
help=(
"Optional. The service name to use in Cloud Run (default:"
" 'adk-default-service-name')."
),
)
@click.option(
"--app_name",
type=str,
default="",
help=(
"Optional. App name of the ADK API server (default: the folder name"
" of the AGENT source code)."
),
)
@click.option(
"--port",
type=int,
default=8000,
help="Optional. The port of the ADK API server (default: 8000).",
)
@click.option(
"--with_cloud_trace",
type=bool,
is_flag=True,
show_default=True,
default=False,
help="Optional. Whether to enable Cloud Trace for cloud run.",
)
@click.option(
"--with_ui",
type=bool,
is_flag=True,
show_default=True,
default=False,
help=(
"Optional. Deploy ADK Web UI if set. (default: deploy ADK API server"
" only)"
),
)
@click.option(
"--temp_folder",
type=str,
default=os.path.join(
tempfile.gettempdir(),
"cloud_run_deploy_src",
datetime.now().strftime("%Y%m%d_%H%M%S"),
),
help=(
"Optional. Temp folder for the generated Cloud Run source files"
" (default: a timestamped folder in the system temp directory)."
),
)
@click.argument(
"agent",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
)
def deploy_to_cloud_run(
agent: str,
project: Optional[str],
region: Optional[str],
service_name: str,
app_name: str,
temp_folder: str,
port: int,
with_cloud_trace: bool,
with_ui: bool,
):
"""Deploys agent to Cloud Run.
AGENT: The path to the agent source code folder.
Example:
adk deploy cloud_run --project=[project] --region=[region] path/to/my_agent
"""
try:
cli_deploy.to_cloud_run(
agent_folder=agent,
project=project,
region=region,
service_name=service_name,
app_name=app_name,
temp_folder=temp_folder,
port=port,
with_cloud_trace=with_cloud_trace,
with_ui=with_ui,
)
except Exception as e:
click.secho(f"Deploy failed: {e}", fg="red", err=True)
+765
View File
@@ -0,0 +1,765 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import importlib
import json
import logging
import os
from pathlib import Path
import re
import sys
import traceback
import typing
from typing import Any
from typing import List
from typing import Literal
from typing import Optional
import click
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Query
from fastapi import Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.websockets import WebSocket
from fastapi.websockets import WebSocketDisconnect
from google.genai import types
import graphviz
from opentelemetry import trace
from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
from opentelemetry.sdk.trace import export
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace import TracerProvider
from pydantic import BaseModel
from pydantic import ValidationError
from ..agents import RunConfig
from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent
from ..agents.run_config import StreamingMode
from ..artifacts import InMemoryArtifactService
from ..events.event import Event
from ..runners import Runner
from ..sessions.database_session_service import DatabaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from ..sessions.vertex_ai_session_service import VertexAiSessionService
from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalMetric
from .cli_eval import EvalMetricResult
from .cli_eval import EvalStatus
from .utils import create_empty_state
from .utils import envs
from .utils import evals
logger = logging.getLogger(__name__)
_EVAL_SET_FILE_EXTENSION = ".evalset.json"
class ApiServerSpanExporter(export.SpanExporter):
def __init__(self, trace_dict):
self.trace_dict = trace_dict
def export(
self, spans: typing.Sequence[ReadableSpan]
) -> export.SpanExportResult:
for span in spans:
if span.name == "call_llm" or span.name == "send_data":
attributes = dict(span.attributes)
attributes["trace_id"] = span.get_span_context().trace_id
attributes["span_id"] = span.get_span_context().span_id
if attributes.get("gcp.vertex.agent.event_id", None):
self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes
return export.SpanExportResult.SUCCESS
def force_flush(self, timeout_millis: int = 30000) -> bool:
return True
class AgentRunRequest(BaseModel):
app_name: str
user_id: str
session_id: str
new_message: types.Content
streaming: bool = False
class AddSessionToEvalSetRequest(BaseModel):
eval_id: str
session_id: str
user_id: str
class RunEvalRequest(BaseModel):
eval_ids: list[str] # if empty, then all evals in the eval set are run.
eval_metrics: list[EvalMetric]
class RunEvalResult(BaseModel):
eval_set_id: str
eval_id: str
final_eval_status: EvalStatus
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
session_id: str
def get_fast_api_app(
*,
agent_dir: str,
session_db_url: str = "",
allow_origins: Optional[list[str]] = None,
web: bool,
) -> FastAPI:
# InMemory tracing dict.
trace_dict: dict[str, Any] = {}
# Set up tracing in the FastAPI server.
provider = TracerProvider()
provider.add_span_processor(
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
)
if os.environ.get("ADK_TRACE_TO_CLOUD", "0") == "1":
processor = export.BatchSpanProcessor(
CloudTraceSpanExporter(
project_id=os.environ.get("GOOGLE_CLOUD_PROJECT", "")
)
)
provider.add_span_processor(processor)
trace.set_tracer_provider(provider)
# Run the FastAPI server.
app = FastAPI()
if allow_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=allow_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
if agent_dir not in sys.path:
sys.path.append(agent_dir)
runner_dict = {}
root_agent_dict = {}
# Build the Artifact service
artifact_service = InMemoryArtifactService()
# Build the Session service
agent_engine_id = ""
if session_db_url:
if session_db_url.startswith("agentengine://"):
# Create vertex session service
agent_engine_id = session_db_url.split("://")[1]
if not agent_engine_id:
raise click.ClickException("Agent engine id can not be empty.")
envs.load_dotenv_for_agent("", agent_dir)
session_service = VertexAiSessionService(
os.environ["GOOGLE_CLOUD_PROJECT"],
os.environ["GOOGLE_CLOUD_LOCATION"],
)
else:
session_service = DatabaseSessionService(db_url=session_db_url)
else:
session_service = InMemorySessionService()
@app.get("/list-apps")
def list_apps() -> list[str]:
base_path = Path.cwd() / agent_dir
if not base_path.exists():
raise HTTPException(status_code=404, detail="Path not found")
if not base_path.is_dir():
raise HTTPException(status_code=400, detail="Not a directory")
agent_names = [
x
for x in os.listdir(base_path)
if os.path.isdir(os.path.join(base_path, x))
and not x.startswith(".")
and x != "__pycache__"
]
agent_names.sort()
return agent_names
@app.get("/debug/trace/{event_id}")
def get_trace_dict(event_id: str) -> Any:
event_dict = trace_dict.get(event_id, None)
if event_dict is None:
raise HTTPException(status_code=404, detail="Trace not found")
return event_dict
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
def get_session(app_name: str, user_id: str, session_id: str) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
return session
@app.get(
"/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True,
)
def list_sessions(app_name: str, user_id: str) -> list[Session]:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
return [
session
for session in session_service.list_sessions(
app_name=app_name, user_id=user_id
).sessions
# Remove sessions that were generated as a part of Eval.
if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
]
@app.post(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
def create_session_with_id(
app_name: str,
user_id: str,
session_id: str,
state: Optional[dict[str, Any]] = None,
) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
if (
session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
is not None
):
logger.warning("Session already exists: %s", session_id)
raise HTTPException(
status_code=400, detail=f"Session already exists: {session_id}"
)
logger.info("New session created: %s", session_id)
return session_service.create_session(
app_name=app_name, user_id=user_id, state=state, session_id=session_id
)
@app.post(
"/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True,
)
def create_session(
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
logger.info("New session created")
return session_service.create_session(
app_name=app_name, user_id=user_id, state=state
)
def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
return os.path.join(
agent_dir,
app_name,
eval_set_id + _EVAL_SET_FILE_EXTENSION,
)
@app.post(
"/apps/{app_name}/eval_sets/{eval_set_id}",
response_model_exclude_none=True,
)
def create_eval_set(
app_name: str,
eval_set_id: str,
):
"""Creates an eval set, given the id."""
pattern = r"^[a-zA-Z0-9_]+$"
if not bool(re.fullmatch(pattern, eval_set_id)):
raise HTTPException(
status_code=400,
detail=(
f"Invalid eval set id. Eval set id should have the `{pattern}`"
" format"
),
)
# Define the file path
new_eval_set_path = _get_eval_set_file_path(
app_name, agent_dir, eval_set_id
)
logger.info("Creating eval set file `%s`", new_eval_set_path)
if not os.path.exists(new_eval_set_path):
# Write the JSON string to the file
logger.info("Eval set file doesn't exist, we will create a new one.")
with open(new_eval_set_path, "w") as f:
empty_content = json.dumps([], indent=2)
f.write(empty_content)
@app.get(
"/apps/{app_name}/eval_sets",
response_model_exclude_none=True,
)
def list_eval_sets(app_name: str) -> list[str]:
"""Lists all eval sets for the given app."""
eval_set_file_path = os.path.join(agent_dir, app_name)
eval_sets = []
for file in os.listdir(eval_set_file_path):
if file.endswith(_EVAL_SET_FILE_EXTENSION):
eval_sets.append(
os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
)
return sorted(eval_sets)
@app.post(
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
response_model_exclude_none=True,
)
def add_session_to_eval_set(
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
):
pattern = r"^[a-zA-Z0-9_]+$"
if not bool(re.fullmatch(pattern, req.eval_id)):
raise HTTPException(
status_code=400,
detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
)
# Get the session
session = session_service.get_session(
app_name=app_name, user_id=req.user_id, session_id=req.session_id
)
assert session, "Session not found."
# Load the eval set file data
eval_set_file_path = _get_eval_set_file_path(
app_name, agent_dir, eval_set_id
)
with open(eval_set_file_path, "r") as file:
eval_set_data = json.load(file) # Load JSON into a list
if [x for x in eval_set_data if x["name"] == req.eval_id]:
raise HTTPException(
status_code=400,
detail=(
f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
" eval set."
),
)
# Convert the session data to evaluation format
test_data = evals.convert_session_to_eval_format(session)
# Populate the session with initial session state.
initial_session_state = create_empty_state(_get_root_agent(app_name))
eval_set_data.append({
"name": req.eval_id,
"data": test_data,
"initial_session": {
"state": initial_session_state,
"app_name": app_name,
"user_id": req.user_id,
},
})
# Serialize the test data to JSON and write to the eval set file.
with open(eval_set_file_path, "w") as f:
f.write(json.dumps(eval_set_data, indent=2))
@app.get(
"/apps/{app_name}/eval_sets/{eval_set_id}/evals",
response_model_exclude_none=True,
)
def list_evals_in_eval_set(
app_name: str,
eval_set_id: str,
) -> list[str]:
"""Lists all evals in an eval set."""
# Load the eval set file data
eval_set_file_path = _get_eval_set_file_path(
app_name, agent_dir, eval_set_id
)
with open(eval_set_file_path, "r") as file:
eval_set_data = json.load(file) # Load JSON into a list
return sorted([x["name"] for x in eval_set_data])
@app.post(
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
response_model_exclude_none=True,
)
def run_eval(
app_name: str, eval_set_id: str, req: RunEvalRequest
) -> list[RunEvalResult]:
from .cli_eval import run_evals
"""Runs an eval given the details in the eval request."""
# Create a mapping from eval set file to all the evals that needed to be
# run.
eval_set_file_path = _get_eval_set_file_path(
app_name, agent_dir, eval_set_id
)
eval_set_to_evals = {eval_set_file_path: req.eval_ids}
if not req.eval_ids:
logger.info(
"Eval ids to run list is empty. We will all evals in the eval set."
)
root_agent = _get_root_agent(app_name)
eval_results = list(
run_evals(
eval_set_to_evals,
root_agent,
getattr(root_agent, "reset_data", None),
req.eval_metrics,
session_service=session_service,
artifact_service=artifact_service,
)
)
run_eval_results = []
for eval_result in eval_results:
run_eval_results.append(
RunEvalResult(
app_name=app_name,
eval_set_id=eval_set_id,
eval_id=eval_result.eval_id,
final_eval_status=eval_result.final_eval_status,
eval_metric_results=eval_result.eval_metric_results,
session_id=eval_result.session_id,
)
)
return run_eval_results
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
def delete_session(app_name: str, user_id: str, session_id: str):
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
response_model_exclude_none=True,
)
def load_artifact(
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
version: Optional[int] = Query(None),
) -> Optional[types.Part]:
artifact = artifact_service.load_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
version=version,
)
if not artifact:
raise HTTPException(status_code=404, detail="Artifact not found")
return artifact
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
response_model_exclude_none=True,
)
def load_artifact_version(
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
version_id: int,
) -> Optional[types.Part]:
artifact = artifact_service.load_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
version=version_id,
)
if not artifact:
raise HTTPException(status_code=404, detail="Artifact not found")
return artifact
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
response_model_exclude_none=True,
)
def list_artifact_names(
app_name: str, user_id: str, session_id: str
) -> list[str]:
return artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id, session_id=session_id
)
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
response_model_exclude_none=True,
)
def list_artifact_versions(
app_name: str, user_id: str, session_id: str, artifact_name: str
) -> list[int]:
return artifact_service.list_versions(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
)
@app.delete(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
)
def delete_artifact(
app_name: str, user_id: str, session_id: str, artifact_name: str
):
artifact_service.delete_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
)
@app.post("/run", response_model_exclude_none=True)
async def agent_run(req: AgentRunRequest) -> list[Event]:
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
session = session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
runner = _get_runner(req.app_name)
events = [
event
async for event in runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
)
]
logger.info("Generated %s events in agent run: %s", len(events), events)
return events
@app.post("/run_sse")
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
# SSE endpoint
session = session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
# Convert the events to properly formatted SSE
async def event_generator():
try:
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
runner = _get_runner(req.app_name)
async for event in runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
run_config=RunConfig(streaming_mode=stream_mode),
):
# Format as SSE data
sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
logger.info("Generated event in agent run streaming: %s", sse_event)
yield f"data: {sse_event}\n\n"
except Exception as e:
logger.exception("Error in event_generator: %s", e)
# You might want to yield an error event here
yield f'data: {{"error": "{str(e)}"}}\n\n'
# Returns a streaming response with the proper media type for SSE
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
)
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
response_model_exclude_none=True,
)
def get_event_graph(
app_name: str, user_id: str, session_id: str, event_id: str
):
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
)
session_events = session.events if session else []
event = next((x for x in session_events if x.id == event_id), None)
if not event:
return {}
from . import agent_graph
function_calls = event.get_function_calls()
function_responses = event.get_function_responses()
root_agent = _get_root_agent(app_name)
dot_graph = None
if function_calls:
function_call_highlights = []
for function_call in function_calls:
from_name = event.author
to_name = function_call.name
function_call_highlights.append((from_name, to_name))
dot_graph = agent_graph.get_agent_graph(
root_agent, function_call_highlights
)
elif function_responses:
function_responses_highlights = []
for function_response in function_responses:
from_name = function_response.name
to_name = event.author
function_responses_highlights.append((from_name, to_name))
dot_graph = agent_graph.get_agent_graph(
root_agent, function_responses_highlights
)
else:
from_name = event.author
to_name = ""
dot_graph = agent_graph.get_agent_graph(
root_agent, [(from_name, to_name)]
)
if dot_graph and isinstance(dot_graph, graphviz.Digraph):
return {"dot_src": dot_graph.source}
else:
return {}
@app.websocket("/run_live")
async def agent_live_run(
websocket: WebSocket,
app_name: str,
user_id: str,
session_id: str,
modalities: List[Literal["TEXT", "AUDIO"]] = Query(
default=["TEXT", "AUDIO"]
), # Only allows "TEXT" or "AUDIO"
) -> None:
await websocket.accept()
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
)
if not session:
# Accept first so that the client is aware of connection establishment,
# then close with a specific code.
await websocket.close(code=1002, reason="Session not found")
return
live_request_queue = LiveRequestQueue()
async def forward_events():
runner = _get_runner(app_name)
async for event in runner.run_live(
session=session, live_request_queue=live_request_queue
):
await websocket.send_text(
event.model_dump_json(exclude_none=True, by_alias=True)
)
async def process_messages():
try:
while True:
data = await websocket.receive_text()
# Validate and send the received message to the live queue.
live_request_queue.send(LiveRequest.model_validate_json(data))
except ValidationError as ve:
logger.error("Validation error in process_messages: %s", ve)
# Run both tasks concurrently and cancel all if one fails.
tasks = [
asyncio.create_task(forward_events()),
asyncio.create_task(process_messages()),
]
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_EXCEPTION
)
try:
# This will re-raise any exception from the completed tasks.
for task in done:
task.result()
except WebSocketDisconnect:
logger.info("Client disconnected during process_messages.")
except Exception as e:
logger.exception("Error during live websocket communication: %s", e)
traceback.print_exc()
finally:
for task in pending:
task.cancel()
def _get_root_agent(app_name: str) -> Agent:
"""Returns the root agent for the given app."""
if app_name in root_agent_dict:
return root_agent_dict[app_name]
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
agent_module = importlib.import_module(app_name)
root_agent: Agent = agent_module.agent.root_agent
root_agent_dict[app_name] = root_agent
return root_agent
def _get_runner(app_name: str) -> Runner:
"""Returns the runner for the given app."""
if app_name in runner_dict:
return runner_dict[app_name]
root_agent = _get_root_agent(app_name)
runner = Runner(
app_name=agent_engine_id if agent_engine_id else app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
runner_dict[app_name] = runner
return runner
if web:
BASE_DIR = Path(__file__).parent.resolve()
ANGULAR_DIST_PATH = BASE_DIR / "browser"
@app.get("/")
async def redirect_to_dev_ui():
return RedirectResponse("/dev-ui")
@app.get("/dev-ui")
async def dev_ui():
return FileResponse(BASE_DIR / "browser/index.html")
app.mount(
"/", StaticFiles(directory=ANGULAR_DIST_PATH, html=True), name="static"
)
return app
+49
View File
@@ -0,0 +1,49 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Any
from typing import Optional
from ...agents.base_agent import BaseAgent
from ...agents.llm_agent import LlmAgent
__all__ = [
'create_empty_state',
]
def _create_empty_state(agent: BaseAgent, all_state: dict[str, Any]):
for sub_agent in agent.sub_agents:
_create_empty_state(sub_agent, all_state)
if (
isinstance(agent, LlmAgent)
and agent.instruction
and isinstance(agent.instruction, str)
):
for key in re.findall(r'{([\w]+)}', agent.instruction):
all_state[key] = ''
def create_empty_state(
agent: BaseAgent, initialized_states: Optional[dict[str, Any]] = None
) -> dict[str, Any]:
"""Creates empty str for non-initialized states."""
non_initialized_states = {}
_create_empty_state(agent, non_initialized_states)
for key in initialized_states or {}:
if key in non_initialized_states:
del non_initialized_states[key]
return non_initialized_states
+57
View File
@@ -0,0 +1,57 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from dotenv import load_dotenv
logger = logging.getLogger(__file__)
def _walk_to_root_until_found(folder, filename) -> str:
checkpath = os.path.join(folder, filename)
if os.path.exists(checkpath) and os.path.isfile(checkpath):
return checkpath
parent_folder = os.path.dirname(folder)
if parent_folder == folder: # reached the root
return ''
return _walk_to_root_until_found(parent_folder, filename)
def load_dotenv_for_agent(
agent_name: str, agent_parent_folder: str, filename: str = '.env'
):
"""Lods the .env file for the agent module."""
# Gets the folder of agent_module as starting_folder
starting_folder = os.path.abspath(
os.path.join(agent_parent_folder, agent_name)
)
dotenv_file_path = _walk_to_root_until_found(starting_folder, filename)
if dotenv_file_path:
load_dotenv(dotenv_file_path, override=True, verbose=True)
logger.info(
'Loaded %s file for %s at %s',
filename,
agent_name,
dotenv_file_path,
)
logger.info(
'Reloaded %s file for %s at %s', filename, agent_name, dotenv_file_path
)
else:
logger.info('No %s file found for %s', filename, agent_name)
+93
View File
@@ -0,0 +1,93 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from ...sessions.session import Session
def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
"""Converts a session data into eval format.
Args:
session: The session that should be converted.
Returns:
list: A single evaluation dataset in the required format.
"""
eval_case = []
events = session.events if session and session.events else []
for event in events:
if event.author == 'user':
if not event.content or not event.content.parts:
continue
# Extract user query
content = event.content
parts = content.parts
query = parts[0].text or ''
# Find the corresponding tool usage or response for the query
expected_tool_use = []
intermediate_agent_responses = []
# Check subsequent events to extract tool uses or responses for this turn.
for subsequent_event in events[events.index(event) + 1 :]:
event_author = subsequent_event.author or 'agent'
if event_author == 'user':
# We found an event where the author was the user. This means that a
# new turn has started. So close this turn here.
break
if not subsequent_event.content or not subsequent_event.content.parts:
continue
for subsequent_part in subsequent_event.content.parts:
# Some events have both function call and reference
if subsequent_part.function_call:
tool_name = subsequent_part.function_call.name or ''
tool_input = subsequent_part.function_call.args or {}
expected_tool_use.append({
'tool_name': tool_name,
'tool_input': tool_input,
})
elif subsequent_part.text:
# Also keep track of all the natural langauge responses that
# agent (or sub agents) generated.
intermediate_agent_responses.append(
{'author': event_author, 'text': subsequent_part.text}
)
# If we are here then either we are done reading all the events or we
# encountered an event that had content authored by the end-user.
# This, basically means an end of turn.
# We assume that the last natural langauge intermediate response is the
# final response from the agent/model. We treat that as a reference.
eval_case.append({
'query': query,
'expected_tool_use': expected_tool_use,
'expected_intermediate_agent_responses': intermediate_agent_responses[
:-1
],
'reference': (
intermediate_agent_responses[-1]['text']
if intermediate_agent_responses
else ''
),
})
return eval_case
+72
View File
@@ -0,0 +1,72 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import tempfile
import time
LOGGING_FORMAT = (
'%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
)
def log_to_stderr(level=logging.INFO):
logging.basicConfig(
level=level,
format=LOGGING_FORMAT,
)
def log_to_tmp_folder(
level=logging.INFO,
*,
sub_folder: str = 'agents_log',
log_file_prefix: str = 'agent',
log_file_timestamp: str = time.strftime('%Y%m%d_%H%M%S'),
):
"""Logs to system temp folder, instead of logging to stderr.
Args
sub_folder: str = 'agents_log',
log_file_prefix: str = 'agent',
log_file_timestamp: str = time.strftime('%Y%m%d_%H%M%S'),
Returns
the log file path.
"""
log_dir = os.path.join(tempfile.gettempdir(), sub_folder)
log_filename = f'{log_file_prefix}.{log_file_timestamp}.log'
log_filepath = os.path.join(log_dir, log_filename)
os.makedirs(log_dir, exist_ok=True)
file_handler = logging.FileHandler(log_filepath, mode='w')
file_handler.setLevel(level)
file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT))
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.handlers = [] # Clear handles to disable logging to stderr
root_logger.addHandler(file_handler)
print(f'Log setup complete: {log_filepath}')
latest_log_link = os.path.join(log_dir, f'{log_file_prefix}.latest.log')
if os.path.islink(latest_log_link):
os.unlink(latest_log_link)
os.symlink(log_filepath, latest_log_link)
print(f'To access latest log: tail -F {latest_log_link}')
return log_filepath