Mirror of https://github.com/EvolutionAPI/adk-python.git
Agent Development Kit (ADK)
An easy-to-use and powerful framework to build AI agents.
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .cli_tools_click import main
@@ -0,0 +1,18 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .cli_tools_click import main


if __name__ == '__main__':
  main()
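Assuming these two files are the CLI package's `__init__.py` and `__main__.py` (file paths are not shown in this diff), the `main()` Click group can also be launched with Python's `-m` switch in addition to the installed `adk` entry point:

  python -m google.adk.cli run path/to/my_agent  # package path assumed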
@@ -0,0 +1,122 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import graphviz

from ..agents import BaseAgent
from ..agents.llm_agent import LlmAgent
from ..tools.agent_tool import AgentTool
from ..tools.base_tool import BaseTool
from ..tools.function_tool import FunctionTool
from ..tools.retrieval.base_retrieval_tool import BaseRetrievalTool


def build_graph(graph, agent: BaseAgent, highlight_pairs):
  dark_green = '#0F5223'
  light_green = '#69CB87'
  light_gray = '#cccccc'

  def get_node_name(tool_or_agent: Union[BaseAgent, BaseTool]):
    if isinstance(tool_or_agent, BaseAgent):
      return tool_or_agent.name
    elif isinstance(tool_or_agent, BaseTool):
      return tool_or_agent.name
    else:
      raise ValueError(f'Unsupported tool type: {tool_or_agent}')

  def get_node_caption(tool_or_agent: Union[BaseAgent, BaseTool]):
    if isinstance(tool_or_agent, BaseAgent):
      return '🤖 ' + tool_or_agent.name
    elif isinstance(tool_or_agent, BaseRetrievalTool):
      return '🔎 ' + tool_or_agent.name
    elif isinstance(tool_or_agent, FunctionTool):
      return '🔧 ' + tool_or_agent.name
    elif isinstance(tool_or_agent, AgentTool):
      return '🤖 ' + tool_or_agent.name
    elif isinstance(tool_or_agent, BaseTool):
      return '🔧 ' + tool_or_agent.name
    else:
      raise ValueError(f'Unsupported tool type: {type(tool_or_agent)}')

  def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]):
    if isinstance(tool_or_agent, BaseAgent):
      return 'ellipse'
    elif isinstance(tool_or_agent, BaseRetrievalTool):
      return 'cylinder'
    elif isinstance(tool_or_agent, FunctionTool):
      return 'box'
    elif isinstance(tool_or_agent, BaseTool):
      return 'box'
    else:
      raise ValueError(f'Unsupported tool type: {type(tool_or_agent)}')

  def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
    name = get_node_name(tool_or_agent)
    shape = get_node_shape(tool_or_agent)
    caption = get_node_caption(tool_or_agent)
    if highlight_pairs:
      for highlight_tuple in highlight_pairs:
        if name in highlight_tuple:
          graph.node(
              name,
              caption,
              style='filled,rounded',
              fillcolor=dark_green,
              color=dark_green,
              shape=shape,
              fontcolor=light_gray,
          )
          return
    # if not highlighted, draw a non-highlight node
    graph.node(
        name,
        caption,
        shape=shape,
        style='rounded',
        color=light_gray,
        fontcolor=light_gray,
    )

  def draw_edge(from_name, to_name):
    if highlight_pairs:
      for highlight_from, highlight_to in highlight_pairs:
        if from_name == highlight_from and to_name == highlight_to:
          graph.edge(from_name, to_name, color=light_green)
          return
        elif from_name == highlight_to and to_name == highlight_from:
          graph.edge(from_name, to_name, color=light_green, dir='back')
          return
    # if no need to highlight, color gray
    graph.edge(from_name, to_name, arrowhead='none', color=light_gray)

  draw_node(agent)
  for sub_agent in agent.sub_agents:
    build_graph(graph, sub_agent, highlight_pairs)
    draw_edge(agent.name, sub_agent.name)
  if isinstance(agent, LlmAgent):
    for tool in agent.canonical_tools:
      draw_node(tool)
      draw_edge(agent.name, get_node_name(tool))


def get_agent_graph(root_agent, highlights_pairs, image=False):
  print('build graph')
  graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'})
  build_graph(graph, root_agent, highlights_pairs)
  if image:
    return graph.pipe(format='png')
  else:
    return graph
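A minimal usage sketch of the module above; the import path, agent variable, and node names are assumptions, not part of the diff:

  # Hypothetical: render an agent tree to PNG bytes, highlighting one edge.
  from google.adk.cli.agent_graph import get_agent_graph  # assumed path

  png_bytes = get_agent_graph(
      root_agent,  # assumed: a BaseAgent built elsewhere
      highlights_pairs=[('root_agent', 'search_tool')],  # hypothetical names
      image=True,
  )
  with open('agent_graph.png', 'wb') as f:
    f.write(png_bytes)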
@@ -0,0 +1,181 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
import importlib
import os
import sys
from typing import Optional

import click
from google.genai import types
from pydantic import BaseModel

from ..agents.llm_agent import LlmAgent
from ..artifacts import BaseArtifactService
from ..artifacts import InMemoryArtifactService
from ..runners import Runner
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from .utils import envs


class InputFile(BaseModel):
  state: dict[str, object]
  queries: list[str]


async def run_input_file(
    app_name: str,
    root_agent: LlmAgent,
    artifact_service: BaseArtifactService,
    session: Session,
    session_service: BaseSessionService,
    input_path: str,
) -> None:
  runner = Runner(
      app_name=app_name,
      agent=root_agent,
      artifact_service=artifact_service,
      session_service=session_service,
  )
  with open(input_path, 'r', encoding='utf-8') as f:
    input_file = InputFile.model_validate_json(f.read())
  input_file.state['_time'] = datetime.now()

  session.state = input_file.state
  for query in input_file.queries:
    click.echo(f'user: {query}')
    content = types.Content(role='user', parts=[types.Part(text=query)])
    async for event in runner.run_async(
        user_id=session.user_id, session_id=session.id, new_message=content
    ):
      if event.content and event.content.parts:
        if text := ''.join(part.text or '' for part in event.content.parts):
          click.echo(f'[{event.author}]: {text}')


async def run_interactively(
    app_name: str,
    root_agent: LlmAgent,
    artifact_service: BaseArtifactService,
    session: Session,
    session_service: BaseSessionService,
) -> None:
  runner = Runner(
      app_name=app_name,
      agent=root_agent,
      artifact_service=artifact_service,
      session_service=session_service,
  )
  while True:
    query = input('user: ')
    if query == 'exit':
      break
    async for event in runner.run_async(
        user_id=session.user_id,
        session_id=session.id,
        new_message=types.Content(role='user', parts=[types.Part(text=query)]),
    ):
      if event.content and event.content.parts:
        if text := ''.join(part.text or '' for part in event.content.parts):
          click.echo(f'[{event.author}]: {text}')


async def run_cli(
    *,
    agent_parent_dir: str,
    agent_folder_name: str,
    json_file_path: Optional[str] = None,
    save_session: bool,
) -> None:
  """Runs an interactive CLI for a certain agent.

  Args:
    agent_parent_dir: str, the absolute path of the parent folder of the agent
      folder.
    agent_folder_name: str, the name of the agent folder.
    json_file_path: Optional[str], the absolute path to the json file, either
      *.input.json or *.session.json.
    save_session: bool, whether to save the session on exit.
  """
  if agent_parent_dir not in sys.path:
    sys.path.append(agent_parent_dir)

  artifact_service = InMemoryArtifactService()
  session_service = InMemorySessionService()
  session = session_service.create_session(
      app_name=agent_folder_name, user_id='test_user'
  )

  agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
  agent_module = importlib.import_module(agent_folder_name)
  root_agent = agent_module.agent.root_agent
  envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
  if json_file_path:
    if json_file_path.endswith('.input.json'):
      await run_input_file(
          app_name=agent_folder_name,
          root_agent=root_agent,
          artifact_service=artifact_service,
          session=session,
          session_service=session_service,
          input_path=json_file_path,
      )
    elif json_file_path.endswith('.session.json'):
      with open(json_file_path, 'r') as f:
        session = Session.model_validate_json(f.read())
      for content in session.get_contents():
        if content.role == 'user':
          print('user: ', content.parts[0].text)
        else:
          print(content.parts[0].text)
      await run_interactively(
          agent_folder_name,
          root_agent,
          artifact_service,
          session,
          session_service,
      )
    else:
      print(f'Unsupported file type: {json_file_path}')
      exit(1)
  else:
    print(f'Running agent {root_agent.name}, type exit to exit.')
    await run_interactively(
        agent_folder_name,
        root_agent,
        artifact_service,
        session,
        session_service,
    )

  if save_session:
    if json_file_path:
      session_path = json_file_path.replace('.input.json', '.session.json')
    else:
      session_id = input('Session ID to save: ')
      session_path = f'{agent_module_path}/{session_id}.session.json'
    with open(session_path, 'w') as f:
      f.write(session.model_dump_json(indent=2, exclude_none=True))
    # TODO: Save from opentelemetry.
    # logs_path = session_path.replace('.session.json', '.logs.json')
    # with open(logs_path, 'w') as f:
    #   f.write(
    #       session.model_dump_json(
    #           indent=2, exclude_none=True, include='event_logs'
    #       )
    #   )
    print('Session saved to', session_path)
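Based on the InputFile model above, a *.input.json file would look roughly like this; the state keys and queries are hypothetical:

  {
    "state": {"user_name": "Alice"},
    "queries": [
      "Hi, what can you do?",
      "Summarize our conversation so far."
    ]
  }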
@@ -0,0 +1,181 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import subprocess
from typing import Optional

import click

_DOCKERFILE_TEMPLATE = """
FROM python:3.11-slim
WORKDIR /app

# Create a non-root user
RUN adduser --disabled-password --gecos "" myuser

# Change ownership of /app to myuser
RUN chown -R myuser:myuser /app

# Switch to the non-root user
USER myuser

# Set up environment variables - Start
ENV PATH="/home/myuser/.local/bin:$PATH"

ENV GOOGLE_GENAI_USE_VERTEXAI=1
# TODO: use passed-in value
ENV GOOGLE_CLOUD_PROJECT={gcp_project_id}
ENV GOOGLE_CLOUD_LOCATION={gcp_region}
ENV ADK_TRACE_TO_CLOUD={with_cloud_trace}

# Set up environment variables - End

# Install ADK - Start
RUN pip install google-adk
# Install ADK - End

# Copy agent - Start

COPY "agents/{app_name}/" "/app/agents/{app_name}/"
{install_agent_deps}

# Copy agent - End

EXPOSE {port}

CMD adk {command} --port={port} "/app/agents"
"""


def _resolve_project(project_in_option: Optional[str]) -> str:
  if project_in_option:
    return project_in_option

  result = subprocess.run(
      ['gcloud', 'config', 'get-value', 'project'],
      check=True,
      capture_output=True,
      text=True,
  )
  project = result.stdout.strip()
  click.echo(f'Use default project: {project}')
  return project


def to_cloud_run(
    *,
    agent_folder: str,
    project: Optional[str],
    region: Optional[str],
    service_name: str,
    app_name: str,
    temp_folder: str,
    port: int,
    with_cloud_trace: bool,
    with_ui: bool,
):
  """Deploys an agent to Google Cloud Run.

  `agent_folder` should contain the following files:

  - __init__.py
  - agent.py
  - requirements.txt (optional, for additional dependencies)
  - ... (other required source files)

  The folder structure of temp_folder will be:

  * dist/[google_adk wheel file]
  * agents/[app_name]/
    * agent source code from `agent_folder`

  Args:
    agent_folder: The folder (absolute path) containing the agent source code.
    project: Google Cloud project id.
    region: Google Cloud region.
    service_name: The service name in Cloud Run.
    app_name: The name of the app. By default, it's the basename of
      `agent_folder`.
    temp_folder: The temp folder for the generated Cloud Run source files.
    port: The port of the ADK API server.
    with_cloud_trace: Whether to enable Cloud Trace.
    with_ui: Whether to deploy with UI.
  """
  app_name = app_name or os.path.basename(agent_folder)

  click.echo(f'Start generating Cloud Run source files in {temp_folder}')

  # remove temp_folder if it exists
  if os.path.exists(temp_folder):
    click.echo('Removing existing files')
    shutil.rmtree(temp_folder)

  try:
    # copy agent source code
    click.echo('Copying agent source code...')
    agent_src_path = os.path.join(temp_folder, 'agents', app_name)
    shutil.copytree(agent_folder, agent_src_path)
    requirements_txt_path = os.path.join(agent_src_path, 'requirements.txt')
    install_agent_deps = (
        f'RUN pip install -r "/app/agents/{app_name}/requirements.txt"'
        if os.path.exists(requirements_txt_path)
        else ''
    )
    click.echo('Copying agent source code complete.')

    # create Dockerfile
    click.echo('Creating Dockerfile...')
    dockerfile_content = _DOCKERFILE_TEMPLATE.format(
        gcp_project_id=project,
        gcp_region=region,
        app_name=app_name,
        port=port,
        command='web' if with_ui else 'api_server',
        install_agent_deps=install_agent_deps,
        with_cloud_trace='1' if with_cloud_trace else '0',
    )
    dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
    os.makedirs(temp_folder, exist_ok=True)
    with open(dockerfile_path, 'w', encoding='utf-8') as f:
      f.write(dockerfile_content)
    click.echo(f'Creating Dockerfile complete: {dockerfile_path}')

    # Deploy to Cloud Run
    click.echo('Deploying to Cloud Run...')
    region_options = ['--region', region] if region else []
    project = _resolve_project(project)
    subprocess.run(
        [
            'gcloud',
            'run',
            'deploy',
            service_name,
            '--source',
            temp_folder,
            '--project',
            project,
            *region_options,
            '--port',
            str(port),
            '--labels',
            'created-by=adk',
        ],
        check=True,
    )
  finally:
    click.echo(f'Cleaning up the temp folder: {temp_folder}')
    shutil.rmtree(temp_folder)
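For reference, with hypothetical values (project 'my-project', region 'us-central1', the default service name, port 8000, and a timestamped temp folder), the subprocess call above composes roughly this command:

  gcloud run deploy adk-default-service-name \
    --source /tmp/cloud_run_deploy_src/20250101_000000 \
    --project my-project \
    --region us-central1 \
    --port 8000 \
    --labels created-by=adk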
@@ -0,0 +1,282 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
import importlib.util
import json
import logging
import os
import sys
import traceback
from typing import Any
from typing import Generator
from typing import Optional
import uuid

from pydantic import BaseModel

from ..agents import Agent

logger = logging.getLogger(__name__)


class EvalStatus(Enum):
  PASSED = 1
  FAILED = 2
  NOT_EVALUATED = 3


class EvalMetric(BaseModel):
  metric_name: str
  threshold: float


class EvalMetricResult(BaseModel):
  score: Optional[float] = None
  eval_status: EvalStatus


class EvalResult(BaseModel):
  eval_set_file: str
  eval_id: str
  final_eval_status: EvalStatus
  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
  session_id: str


MISSING_EVAL_DEPENDENCIES_MESSAGE = (
    "Eval module is not installed, please install via `pip install"
    " google-adk[eval]`."
)
TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score"
RESPONSE_MATCH_SCORE_KEY = "response_match_score"
# This evaluation is not very stable.
# This is always optional unless explicitly specified.
RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"

EVAL_SESSION_ID_PREFIX = "___eval___session___"
DEFAULT_CRITERIA = {
    TOOL_TRAJECTORY_SCORE_KEY: 1.0,  # 1-point scale; 1.0 is perfect.
    RESPONSE_MATCH_SCORE_KEY: 0.8,
}


def _import_from_path(module_name, file_path):
  spec = importlib.util.spec_from_file_location(module_name, file_path)
  module = importlib.util.module_from_spec(spec)
  sys.modules[module_name] = module
  spec.loader.exec_module(module)
  return module


def _get_agent_module(agent_module_file_path: str):
  file_path = os.path.join(agent_module_file_path, "__init__.py")
  module_name = "agent"
  return _import_from_path(module_name, file_path)


def get_evaluation_criteria_or_default(
    eval_config_file_path: str,
) -> dict[str, float]:
  """Returns evaluation criteria from the config file, if present.

  Otherwise a default one is returned.
  """
  if eval_config_file_path:
    with open(eval_config_file_path, "r", encoding="utf-8") as f:
      config_data = json.load(f)

    if "criteria" in config_data and isinstance(config_data["criteria"], dict):
      evaluation_criteria = config_data["criteria"]
    else:
      raise ValueError(
          f"Invalid format for test_config.json at {eval_config_file_path}."
          " Expected a 'criteria' dictionary."
      )
  else:
    logger.info("No config file supplied. Using default criteria.")
    evaluation_criteria = DEFAULT_CRITERIA

  return evaluation_criteria


def get_root_agent(agent_module_file_path: str) -> Agent:
  """Returns the root agent, given the agent module."""
  agent_module = _get_agent_module(agent_module_file_path)
  root_agent = agent_module.agent.root_agent
  return root_agent


def try_get_reset_func(agent_module_file_path: str) -> Any:
  """Returns the agent's reset function, if present, given the agent module."""
  agent_module = _get_agent_module(agent_module_file_path)
  reset_func = getattr(agent_module.agent, "reset_data", None)
  return reset_func


def parse_and_get_evals_to_run(
    eval_set_file_path: tuple[str],
) -> dict[str, list[str]]:
  """Returns a dictionary of eval sets to evals that should be run."""
  eval_set_to_evals = {}
  for input_eval_set in eval_set_file_path:
    evals = []
    if ":" not in input_eval_set:
      eval_set_file = input_eval_set
    else:
      eval_set_file = input_eval_set.split(":")[0]
      evals = input_eval_set.split(":")[1].split(",")

    if eval_set_file not in eval_set_to_evals:
      eval_set_to_evals[eval_set_file] = []

    eval_set_to_evals[eval_set_file].extend(evals)

  return eval_set_to_evals


def run_evals(
    eval_set_to_evals: dict[str, list[str]],
    root_agent: Agent,
    reset_func: Optional[Any],
    eval_metrics: list[EvalMetric],
    session_service=None,
    artifact_service=None,
    print_detailed_results=False,
) -> Generator[EvalResult, None, None]:
  """Runs the evals and yields an EvalResult summary for each eval run."""
  try:
    from ..evaluation.agent_evaluator import EvaluationGenerator
    from ..evaluation.response_evaluator import ResponseEvaluator
    from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
  except ModuleNotFoundError as e:
    raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e

  for eval_set_file, evals_to_run in eval_set_to_evals.items():
    with open(eval_set_file, "r", encoding="utf-8") as file:
      eval_items = json.load(file)  # Load JSON into a list

    assert eval_items, f"No eval data found in eval set file: {eval_set_file}"

    for eval_item in eval_items:
      eval_name = eval_item["name"]
      eval_data = eval_item["data"]
      initial_session = eval_item.get("initial_session", {})

      if evals_to_run and eval_name not in evals_to_run:
        continue

      try:
        print(f"Running Eval: {eval_set_file}:{eval_name}")
        session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"

        scrape_result = EvaluationGenerator._process_query_with_root_agent(
            data=eval_data,
            root_agent=root_agent,
            reset_func=reset_func,
            initial_session=initial_session,
            session_id=session_id,
            session_service=session_service,
            artifact_service=artifact_service,
        )

        eval_metric_results = []
        for eval_metric in eval_metrics:
          eval_metric_result = None
          if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
            score = TrajectoryEvaluator.evaluate(
                [scrape_result], print_detailed_results=print_detailed_results
            )
            eval_metric_result = _get_eval_metric_result(eval_metric, score)
          elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
            score = ResponseEvaluator.evaluate(
                [scrape_result],
                [RESPONSE_MATCH_SCORE_KEY],
                print_detailed_results=print_detailed_results,
            )
            eval_metric_result = _get_eval_metric_result(
                eval_metric, score["rouge_1/mean"].item()
            )
          elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
            score = ResponseEvaluator.evaluate(
                [scrape_result],
                [RESPONSE_EVALUATION_SCORE_KEY],
                print_detailed_results=print_detailed_results,
            )
            eval_metric_result = _get_eval_metric_result(
                eval_metric, score["coherence/mean"].item()
            )
          else:
            logger.warning("`%s` is not supported.", eval_metric.metric_name)
            eval_metric_results.append((
                eval_metric,
                EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED),
            ))
            continue

          eval_metric_results.append((
              eval_metric,
              eval_metric_result,
          ))
          _print_eval_metric_result(eval_metric, eval_metric_result)

        final_eval_status = EvalStatus.NOT_EVALUATED

        # Go over all the eval statuses and mark the final eval status as
        # passed if all of them pass, otherwise mark the final eval status as
        # failed.
        for eval_metric_result in eval_metric_results:
          eval_status = eval_metric_result[1].eval_status
          if eval_status == EvalStatus.PASSED:
            final_eval_status = EvalStatus.PASSED
          elif eval_status == EvalStatus.NOT_EVALUATED:
            continue
          elif eval_status == EvalStatus.FAILED:
            final_eval_status = EvalStatus.FAILED
            break
          else:
            raise ValueError("Unknown eval status.")

        yield EvalResult(
            eval_set_file=eval_set_file,
            eval_id=eval_name,
            final_eval_status=final_eval_status,
            eval_metric_results=eval_metric_results,
            session_id=session_id,
        )

        if final_eval_status == EvalStatus.PASSED:
          result = "✅ Passed"
        else:
          result = "❌ Failed"

        print(f"Result: {result}\n")

      except Exception as e:
        print(f"Error: {e}")
        logger.info("Error: %s", str(traceback.format_exc()))


def _get_eval_metric_result(eval_metric, score):
  eval_status = (
      EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
  )
  return EvalMetricResult(score=score, eval_status=eval_status)


def _print_eval_metric_result(eval_metric, eval_metric_result):
  print(
      f"Metric: {eval_metric.metric_name}\tStatus:"
      f" {eval_metric_result.eval_status}\tScore:"
      f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
  )
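For reference, a config file accepted by get_evaluation_criteria_or_default mirrors DEFAULT_CRITERIA; the thresholds here are hypothetical:

  {
    "criteria": {
      "tool_trajectory_avg_score": 1.0,
      "response_match_score": 0.8
    }
  }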
@@ -0,0 +1,479 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from datetime import datetime
import logging
import os
import tempfile
from typing import Optional

import click
import uvicorn

from . import cli_deploy
from .cli import run_cli
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
from .fast_api import get_fast_api_app
from .utils import envs
from .utils import logs

logger = logging.getLogger(__name__)


@click.group(context_settings={"max_content_width": 240})
def main():
  """Agent Development Kit CLI tools."""
  pass


@main.group()
def deploy():
  """Deploy Agent."""
  pass


@main.command("run")
@click.option(
    "--save_session",
    type=bool,
    is_flag=True,
    show_default=True,
    default=False,
    help="Optional. Whether to save the session to a json file on exit.",
)
@click.argument(
    "agent",
    type=click.Path(
        exists=True, dir_okay=True, file_okay=False, resolve_path=True
    ),
)
def cli_run(agent: str, save_session: bool):
  """Run an interactive CLI for a certain agent.

  AGENT: The path to the agent source code folder.

  Example:

    adk run path/to/my_agent
  """
  logs.log_to_tmp_folder()

  agent_parent_folder = os.path.dirname(agent)
  agent_folder_name = os.path.basename(agent)

  asyncio.run(
      run_cli(
          agent_parent_dir=agent_parent_folder,
          agent_folder_name=agent_folder_name,
          save_session=save_session,
      )
  )


@main.command("eval")
@click.argument(
    "agent_module_file_path",
    type=click.Path(
        exists=True, dir_okay=True, file_okay=False, resolve_path=True
    ),
)
@click.argument("eval_set_file_path", nargs=-1)
@click.option("--config_file_path", help="Optional. The path to config file.")
@click.option(
    "--print_detailed_results",
    is_flag=True,
    show_default=True,
    default=False,
    help="Optional. Whether to print detailed results on console or not.",
)
def eval_command(
    agent_module_file_path: str,
    eval_set_file_path: tuple[str],
    config_file_path: str,
    print_detailed_results: bool,
):
  """Evaluates an agent given the eval sets.

  AGENT_MODULE_FILE_PATH: The path to the __init__.py file that contains a
  module by the name "agent". The "agent" module contains a root_agent.

  EVAL_SET_FILE_PATH: You can specify one or more eval set file paths.

  For each file, all evals will be run by default.

  If you want to run only specific evals from an eval set, first create a
  comma-separated list of eval names and then add that as a suffix to the
  eval set file name, demarcated by a `:`.

  For example,

    sample_eval_set_file.json:eval_1,eval_2,eval_3

  This will only run eval_1, eval_2 and eval_3 from sample_eval_set_file.json.

  CONFIG_FILE_PATH: The path to config file.

  PRINT_DETAILED_RESULTS: Prints detailed results on the console.
  """
  envs.load_dotenv_for_agent(agent_module_file_path, ".")

  try:
    from .cli_eval import EvalMetric
    from .cli_eval import EvalResult
    from .cli_eval import EvalStatus
    from .cli_eval import get_evaluation_criteria_or_default
    from .cli_eval import get_root_agent
    from .cli_eval import parse_and_get_evals_to_run
    from .cli_eval import run_evals
    from .cli_eval import try_get_reset_func
  except ModuleNotFoundError:
    raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)

  evaluation_criteria = get_evaluation_criteria_or_default(config_file_path)
  eval_metrics = []
  for metric_name, threshold in evaluation_criteria.items():
    eval_metrics.append(
        EvalMetric(metric_name=metric_name, threshold=threshold)
    )

  print(f"Using evaluation criteria: {evaluation_criteria}")

  root_agent = get_root_agent(agent_module_file_path)
  reset_func = try_get_reset_func(agent_module_file_path)

  eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)

  try:
    eval_results = list(
        run_evals(
            eval_set_to_evals,
            root_agent,
            reset_func,
            eval_metrics,
            print_detailed_results=print_detailed_results,
        )
    )
  except ModuleNotFoundError:
    raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)

  print("*********************************************************************")
  eval_run_summary = {}

  for eval_result in eval_results:
    eval_result: EvalResult

    if eval_result.eval_set_file not in eval_run_summary:
      eval_run_summary[eval_result.eval_set_file] = [0, 0]

    if eval_result.final_eval_status == EvalStatus.PASSED:
      eval_run_summary[eval_result.eval_set_file][0] += 1
    else:
      eval_run_summary[eval_result.eval_set_file][1] += 1
  print("Eval Run Summary")
  for eval_set_file, pass_fail_count in eval_run_summary.items():
    print(
        f"{eval_set_file}:\n  Tests passed: {pass_fail_count[0]}\n  Tests"
        f" failed: {pass_fail_count[1]}"
    )


@main.command("web")
@click.option(
    "--session_db_url",
    help=(
        "Optional. The database URL to store the session.\n\n  - Use"
        " 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
        " managed session service.\n\n  - Use 'sqlite://<path_to_sqlite_file>'"
        " to connect to a SQLite DB.\n\n  - See"
        " https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
        " for more details on supported DB URLs."
    ),
)
@click.option(
    "--port",
    type=int,
    help="Optional. The port of the server",
    default=8000,
)
@click.option(
    "--allow_origins",
    help="Optional. Any additional origins to allow for CORS.",
    multiple=True,
)
@click.option(
    "--log_level",
    type=click.Choice(
        ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
    ),
    default="INFO",
    help="Optional. Set the logging level",
)
@click.option(
    "--log_to_tmp",
    is_flag=True,
    show_default=True,
    default=False,
    help=(
        "Optional. Whether to log to system temp folder instead of console."
        " This is useful for local debugging."
    ),
)
@click.argument(
    "agents_dir",
    type=click.Path(
        exists=True, dir_okay=True, file_okay=False, resolve_path=True
    ),
    default=os.getcwd(),
)
def web(
    agents_dir: str,
    log_to_tmp: bool,
    session_db_url: str = "",
    log_level: str = "INFO",
    allow_origins: Optional[list[str]] = None,
    port: int = 8000,
):
  """Start a FastAPI server with web UI for a certain agent.

  AGENTS_DIR: The directory of agents, where each sub-directory is a single
  agent, containing at least `__init__.py` and `agent.py` files.

  Example:

    adk web --session_db_url=[db_url] --port=[port] path/to/agents_dir
  """
  if log_to_tmp:
    logs.log_to_tmp_folder()
  else:
    logs.log_to_stderr()

  logging.getLogger().setLevel(log_level)

  config = uvicorn.Config(
      get_fast_api_app(
          agent_dir=agents_dir,
          session_db_url=session_db_url,
          allow_origins=allow_origins,
          web=True,
      ),
      host="0.0.0.0",
      port=port,
      reload=True,
  )
  server = uvicorn.Server(config)
  server.run()


@main.command("api_server")
@click.option(
    "--session_db_url",
    help=(
        "Optional. The database URL to store the session.\n\n  - Use"
        " 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
        " managed session service.\n\n  - Use 'sqlite://<path_to_sqlite_file>'"
        " to connect to a SQLite DB.\n\n  - See"
        " https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
        " for more details on supported DB URLs."
    ),
)
@click.option(
    "--port",
    type=int,
    help="Optional. The port of the server",
    default=8000,
)
@click.option(
    "--allow_origins",
    help="Optional. Any additional origins to allow for CORS.",
    multiple=True,
)
@click.option(
    "--log_level",
    type=click.Choice(
        ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
    ),
    default="INFO",
    help="Optional. Set the logging level",
)
@click.option(
    "--log_to_tmp",
    is_flag=True,
    show_default=True,
    default=False,
    help=(
        "Optional. Whether to log to system temp folder instead of console."
        " This is useful for local debugging."
    ),
)
# The directory of agents, where each sub-directory is a single agent.
# By default, it is the current working directory.
@click.argument(
    "agents_dir",
    type=click.Path(
        exists=True, dir_okay=True, file_okay=False, resolve_path=True
    ),
    default=os.getcwd(),
)
def cli_api_server(
    agents_dir: str,
    log_to_tmp: bool,
    session_db_url: str = "",
    log_level: str = "INFO",
    allow_origins: Optional[list[str]] = None,
    port: int = 8000,
):
  """Start an api server for a certain agent.

  AGENTS_DIR: The directory of agents, where each sub-directory is a single
  agent, containing at least `__init__.py` and `agent.py` files.

  Example:

    adk api_server --session_db_url=[db_url] --port=[port] path/to/agents_dir
  """
  if log_to_tmp:
    logs.log_to_tmp_folder()
  else:
    logs.log_to_stderr()

  logging.getLogger().setLevel(log_level)

  config = uvicorn.Config(
      get_fast_api_app(
          agent_dir=agents_dir,
          session_db_url=session_db_url,
          allow_origins=allow_origins,
          web=False,
      ),
      host="0.0.0.0",
      port=port,
      reload=True,
  )
  server = uvicorn.Server(config)
  server.run()


@deploy.command("cloud_run")
@click.option(
    "--project",
    type=str,
    help=(
        "Required. Google Cloud project to deploy the agent. When absent,"
        " default project from gcloud config is used."
    ),
)
@click.option(
    "--region",
    type=str,
    help=(
        "Required. Google Cloud region to deploy the agent. When absent,"
        " gcloud run deploy will prompt later."
    ),
)
@click.option(
    "--service_name",
    type=str,
    default="adk-default-service-name",
    help=(
        "Optional. The service name to use in Cloud Run (default:"
        " 'adk-default-service-name')."
    ),
)
@click.option(
    "--app_name",
    type=str,
    default="",
    help=(
        "Optional. App name of the ADK API server (default: the folder name"
        " of the AGENT source code)."
    ),
)
@click.option(
    "--port",
    type=int,
    default=8000,
    help="Optional. The port of the ADK API server (default: 8000).",
)
@click.option(
    "--with_cloud_trace",
    type=bool,
    is_flag=True,
    show_default=True,
    default=False,
    help="Optional. Whether to enable Cloud Trace for cloud run.",
)
@click.option(
    "--with_ui",
    type=bool,
    is_flag=True,
    show_default=True,
    default=False,
    help=(
        "Optional. Deploy ADK Web UI if set. (default: deploy ADK API server"
        " only)"
    ),
)
@click.option(
    "--temp_folder",
    type=str,
    default=os.path.join(
        tempfile.gettempdir(),
        "cloud_run_deploy_src",
        datetime.now().strftime("%Y%m%d_%H%M%S"),
    ),
    help=(
        "Optional. Temp folder for the generated Cloud Run source files"
        " (default: a timestamped folder in the system temp directory)."
    ),
)
@click.argument(
    "agent",
    type=click.Path(
        exists=True, dir_okay=True, file_okay=False, resolve_path=True
    ),
)
def deploy_to_cloud_run(
    agent: str,
    project: Optional[str],
    region: Optional[str],
    service_name: str,
    app_name: str,
    temp_folder: str,
    port: int,
    with_cloud_trace: bool,
    with_ui: bool,
):
  """Deploys agent to Cloud Run.

  AGENT: The path to the agent source code folder.

  Example:

    adk deploy cloud_run --project=[project] --region=[region] path/to/my_agent
  """
  try:
    cli_deploy.to_cloud_run(
        agent_folder=agent,
        project=project,
        region=region,
        service_name=service_name,
        app_name=app_name,
        temp_folder=temp_folder,
        port=port,
        with_cloud_trace=with_cloud_trace,
        with_ui=with_ui,
    )
  except Exception as e:
    click.secho(f"Deploy failed: {e}", fg="red", err=True)
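Putting eval_command together, a complete invocation might look like this; the paths and eval names are hypothetical:

  adk eval \
    path/to/my_agent \
    path/to/sample_eval_set_file.json:eval_1,eval_2 \
    --config_file_path path/to/test_config.json \
    --print_detailed_results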
@@ -0,0 +1,765 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import importlib
import json
import logging
import os
from pathlib import Path
import re
import sys
import traceback
import typing
from typing import Any
from typing import List
from typing import Literal
from typing import Optional

import click
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Query
from fastapi import Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.websockets import WebSocket
from fastapi.websockets import WebSocketDisconnect
from google.genai import types
import graphviz
from opentelemetry import trace
from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
from opentelemetry.sdk.trace import export
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace import TracerProvider
from pydantic import BaseModel
from pydantic import ValidationError

from ..agents import RunConfig
from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent
from ..agents.run_config import StreamingMode
from ..artifacts import InMemoryArtifactService
from ..events.event import Event
from ..runners import Runner
from ..sessions.database_session_service import DatabaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from ..sessions.vertex_ai_session_service import VertexAiSessionService
from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalMetric
from .cli_eval import EvalMetricResult
from .cli_eval import EvalStatus
from .utils import create_empty_state
from .utils import envs
from .utils import evals

logger = logging.getLogger(__name__)

_EVAL_SET_FILE_EXTENSION = ".evalset.json"


class ApiServerSpanExporter(export.SpanExporter):

  def __init__(self, trace_dict):
    self.trace_dict = trace_dict

  def export(
      self, spans: typing.Sequence[ReadableSpan]
  ) -> export.SpanExportResult:
    for span in spans:
      if span.name == "call_llm" or span.name == "send_data":
        attributes = dict(span.attributes)
        attributes["trace_id"] = span.get_span_context().trace_id
        attributes["span_id"] = span.get_span_context().span_id
        if attributes.get("gcp.vertex.agent.event_id", None):
          self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes
    return export.SpanExportResult.SUCCESS

  def force_flush(self, timeout_millis: int = 30000) -> bool:
    return True


class AgentRunRequest(BaseModel):
  app_name: str
  user_id: str
  session_id: str
  new_message: types.Content
  streaming: bool = False


class AddSessionToEvalSetRequest(BaseModel):
  eval_id: str
  session_id: str
  user_id: str


class RunEvalRequest(BaseModel):
  eval_ids: list[str]  # if empty, then all evals in the eval set are run.
  eval_metrics: list[EvalMetric]


class RunEvalResult(BaseModel):
  app_name: str
  eval_set_id: str
  eval_id: str
  final_eval_status: EvalStatus
  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
  session_id: str
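As a sketch, a POST body that RunEvalRequest (used by the run_eval endpoint further down) would parse looks like this; the eval id and threshold are hypothetical:

  {
    "eval_ids": ["eval_1"],
    "eval_metrics": [
      {"metric_name": "tool_trajectory_avg_score", "threshold": 1.0}
    ]
  }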
def get_fast_api_app(
    *,
    agent_dir: str,
    session_db_url: str = "",
    allow_origins: Optional[list[str]] = None,
    web: bool,
) -> FastAPI:
  # InMemory tracing dict.
  trace_dict: dict[str, Any] = {}

  # Set up tracing in the FastAPI server.
  provider = TracerProvider()
  provider.add_span_processor(
      export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
  )
  if os.environ.get("ADK_TRACE_TO_CLOUD", "0") == "1":
    processor = export.BatchSpanProcessor(
        CloudTraceSpanExporter(
            project_id=os.environ.get("GOOGLE_CLOUD_PROJECT", "")
        )
    )
    provider.add_span_processor(processor)

  trace.set_tracer_provider(provider)

  # Run the FastAPI server.
  app = FastAPI()

  if allow_origins:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allow_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

  if agent_dir not in sys.path:
    sys.path.append(agent_dir)

  runner_dict = {}
  root_agent_dict = {}

  # Build the Artifact service
  artifact_service = InMemoryArtifactService()

  # Build the Session service
  agent_engine_id = ""
  if session_db_url:
    if session_db_url.startswith("agentengine://"):
      # Create vertex session service
      agent_engine_id = session_db_url.split("://")[1]
      if not agent_engine_id:
        raise click.ClickException("Agent engine id can not be empty.")
      envs.load_dotenv_for_agent("", agent_dir)
      session_service = VertexAiSessionService(
          os.environ["GOOGLE_CLOUD_PROJECT"],
          os.environ["GOOGLE_CLOUD_LOCATION"],
      )
    else:
      session_service = DatabaseSessionService(db_url=session_db_url)
  else:
    session_service = InMemorySessionService()
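A minimal sketch of serving this app directly, outside the `adk web`/`adk api_server` commands shown earlier; the agents directory is hypothetical:

  import uvicorn

  app = get_fast_api_app(
      agent_dir='./agents',  # assumed layout: one sub-directory per agent
      session_db_url='',  # empty -> InMemorySessionService, per the branch above
      allow_origins=None,
      web=False,
  )
  uvicorn.run(app, host='0.0.0.0', port=8000)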
@app.get("/list-apps")
|
||||
def list_apps() -> list[str]:
|
||||
base_path = Path.cwd() / agent_dir
|
||||
if not base_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Path not found")
|
||||
if not base_path.is_dir():
|
||||
raise HTTPException(status_code=400, detail="Not a directory")
|
||||
agent_names = [
|
||||
x
|
||||
for x in os.listdir(base_path)
|
||||
if os.path.isdir(os.path.join(base_path, x))
|
||||
and not x.startswith(".")
|
||||
and x != "__pycache__"
|
||||
]
|
||||
agent_names.sort()
|
||||
return agent_names
|
||||
|
||||
@app.get("/debug/trace/{event_id}")
|
||||
def get_trace_dict(event_id: str) -> Any:
|
||||
event_dict = trace_dict.get(event_id, None)
|
||||
if event_dict is None:
|
||||
raise HTTPException(status_code=404, detail="Trace not found")
|
||||
return event_dict
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def get_session(app_name: str, user_id: str, session_id: str) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
session = session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return session
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_sessions(app_name: str, user_id: str) -> list[Session]:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
return [
|
||||
session
|
||||
for session in session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
).sessions
|
||||
# Remove sessions that were generated as a part of Eval.
|
||||
if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
|
||||
]
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def create_session_with_id(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
if (
|
||||
session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
is not None
|
||||
):
|
||||
logger.warning("Session already exists: %s", session_id)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Session already exists: {session_id}"
|
||||
)
|
||||
|
||||
logger.info("New session created: %s", session_id)
|
||||
return session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state, session_id=session_id
|
||||
)
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/users/{user_id}/sessions",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def create_session(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
|
||||
logger.info("New session created")
|
||||
return session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state
|
||||
)
|
||||
|
||||
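For example, creating and then fetching a session via the two endpoints above (app, user, and session ids hypothetical, server assumed on localhost:8000):

  curl -X POST http://localhost:8000/apps/my_app/users/user_1/sessions/session_1
  curl http://localhost:8000/apps/my_app/users/user_1/sessions/session_1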
def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
|
||||
return os.path.join(
|
||||
agent_dir,
|
||||
app_name,
|
||||
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
||||
)
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def create_eval_set(
|
||||
app_name: str,
|
||||
eval_set_id: str,
|
||||
):
|
||||
"""Creates an eval set, given the id."""
|
||||
pattern = r"^[a-zA-Z0-9_]+$"
|
||||
if not bool(re.fullmatch(pattern, eval_set_id)):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Invalid eval set id. Eval set id should have the `{pattern}`"
|
||||
" format"
|
||||
),
|
||||
)
|
||||
# Define the file path
|
||||
new_eval_set_path = _get_eval_set_file_path(
|
||||
app_name, agent_dir, eval_set_id
|
||||
)
|
||||
|
||||
logger.info("Creating eval set file `%s`", new_eval_set_path)
|
||||
|
||||
if not os.path.exists(new_eval_set_path):
|
||||
# Write the JSON string to the file
|
||||
logger.info("Eval set file doesn't exist, we will create a new one.")
|
||||
with open(new_eval_set_path, "w") as f:
|
||||
empty_content = json.dumps([], indent=2)
|
||||
f.write(empty_content)
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/eval_sets",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_eval_sets(app_name: str) -> list[str]:
|
||||
"""Lists all eval sets for the given app."""
|
||||
eval_set_file_path = os.path.join(agent_dir, app_name)
|
||||
eval_sets = []
|
||||
for file in os.listdir(eval_set_file_path):
|
||||
if file.endswith(_EVAL_SET_FILE_EXTENSION):
|
||||
eval_sets.append(
|
||||
os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
|
||||
)
|
||||
|
||||
return sorted(eval_sets)
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def add_session_to_eval_set(
|
||||
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
||||
):
|
||||
pattern = r"^[a-zA-Z0-9_]+$"
|
||||
if not bool(re.fullmatch(pattern, req.eval_id)):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
|
||||
)
|
||||
|
||||
# Get the session
|
||||
session = session_service.get_session(
|
||||
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
assert session, "Session not found."
|
||||
# Load the eval set file data
|
||||
eval_set_file_path = _get_eval_set_file_path(
|
||||
app_name, agent_dir, eval_set_id
|
||||
)
|
||||
with open(eval_set_file_path, "r") as file:
|
||||
eval_set_data = json.load(file) # Load JSON into a list
|
||||
|
||||
if [x for x in eval_set_data if x["name"] == req.eval_id]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
|
||||
" eval set."
|
||||
),
|
||||
)
|
||||
|
||||
# Convert the session data to evaluation format
|
||||
test_data = evals.convert_session_to_eval_format(session)
|
||||
|
||||
# Populate the session with initial session state.
|
||||
initial_session_state = create_empty_state(_get_root_agent(app_name))
|
||||
|
||||
eval_set_data.append({
|
||||
"name": req.eval_id,
|
||||
"data": test_data,
|
||||
"initial_session": {
|
||||
"state": initial_session_state,
|
||||
"app_name": app_name,
|
||||
"user_id": req.user_id,
|
||||
},
|
||||
})
|
||||
# Serialize the test data to JSON and write to the eval set file.
|
||||
with open(eval_set_file_path, "w") as f:
|
||||
f.write(json.dumps(eval_set_data, indent=2))
|
||||
|
||||
  @app.get(
      "/apps/{app_name}/eval_sets/{eval_set_id}/evals",
      response_model_exclude_none=True,
  )
  def list_evals_in_eval_set(
      app_name: str,
      eval_set_id: str,
  ) -> list[str]:
    """Lists all evals in an eval set."""
    # Load the eval set file data
    eval_set_file_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )
    with open(eval_set_file_path, "r") as file:
      eval_set_data = json.load(file)  # Load JSON into a list

    return sorted([x["name"] for x in eval_set_data])

  @app.post(
      "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
      response_model_exclude_none=True,
  )
  def run_eval(
      app_name: str, eval_set_id: str, req: RunEvalRequest
  ) -> list[RunEvalResult]:
    """Runs an eval given the details in the eval request."""
    from .cli_eval import run_evals

    # Create a mapping from eval set file to all the evals that need to be
    # run.
    eval_set_file_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )
    eval_set_to_evals = {eval_set_file_path: req.eval_ids}

    if not req.eval_ids:
      logger.info(
          "Eval ids to run list is empty. We will run all evals in the eval"
          " set."
      )
    root_agent = _get_root_agent(app_name)
    eval_results = list(
        run_evals(
            eval_set_to_evals,
            root_agent,
            getattr(root_agent, "reset_data", None),
            req.eval_metrics,
            session_service=session_service,
            artifact_service=artifact_service,
        )
    )

    run_eval_results = []
    for eval_result in eval_results:
      run_eval_results.append(
          RunEvalResult(
              app_name=app_name,
              eval_set_id=eval_set_id,
              eval_id=eval_result.eval_id,
              final_eval_status=eval_result.final_eval_status,
              eval_metric_results=eval_result.eval_metric_results,
              session_id=eval_result.session_id,
          )
      )
    return run_eval_results

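  # Example request (illustrative), assuming a local dev server; an empty
  # `eval_ids` list runs every eval in the set, and `eval_metrics` follows
  # whatever shape the RunEvalRequest model defines:
  #
  #   import requests
  #
  #   resp = requests.post(
  #       "http://localhost:8000/apps/my_agent/eval_sets/smoke_tests/run_eval",
  #       json={"eval_ids": [], "eval_metrics": []},
  #   )
  #   print(resp.json())  # list of RunEvalResult objects
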
  @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
  def delete_session(app_name: str, user_id: str, session_id: str):
    # Connect to managed session if agent_engine_id is set.
    app_name = agent_engine_id if agent_engine_id else app_name
    session_service.delete_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
      response_model_exclude_none=True,
  )
  def load_artifact(
      app_name: str,
      user_id: str,
      session_id: str,
      artifact_name: str,
      version: Optional[int] = Query(None),
  ) -> Optional[types.Part]:
    artifact = artifact_service.load_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
        version=version,
    )
    if not artifact:
      raise HTTPException(status_code=404, detail="Artifact not found")
    return artifact

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
      response_model_exclude_none=True,
  )
  def load_artifact_version(
      app_name: str,
      user_id: str,
      session_id: str,
      artifact_name: str,
      version_id: int,
  ) -> Optional[types.Part]:
    artifact = artifact_service.load_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
        version=version_id,
    )
    if not artifact:
      raise HTTPException(status_code=404, detail="Artifact not found")
    return artifact

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
      response_model_exclude_none=True,
  )
  def list_artifact_names(
      app_name: str, user_id: str, session_id: str
  ) -> list[str]:
    return artifact_service.list_artifact_keys(
        app_name=app_name, user_id=user_id, session_id=session_id
    )

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
      response_model_exclude_none=True,
  )
  def list_artifact_versions(
      app_name: str, user_id: str, session_id: str, artifact_name: str
  ) -> list[int]:
    return artifact_service.list_versions(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
    )

  @app.delete(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
  )
  def delete_artifact(
      app_name: str, user_id: str, session_id: str, artifact_name: str
  ):
    artifact_service.delete_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
    )

@app.post("/run", response_model_exclude_none=True)
|
||||
async def agent_run(req: AgentRunRequest) -> list[Event]:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
runner = _get_runner(req.app_name)
|
||||
events = [
|
||||
event
|
||||
async for event in runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
)
|
||||
]
|
||||
logger.info("Generated %s events in agent run: %s", len(events), events)
|
||||
return events
|
||||
|
||||
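  # Example client call (illustrative), assuming the session already exists
  # and that `new_message` is a JSON-serialized Content object:
  #
  #   import requests
  #
  #   body = {
  #       "app_name": "my_agent",
  #       "user_id": "user_123",
  #       "session_id": "session_abc",
  #       "new_message": {"role": "user", "parts": [{"text": "Hello"}]},
  #   }
  #   events = requests.post("http://localhost:8000/run", json=body).json()
  #   for event in events:
  #       print(event.get("author"))
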
@app.post("/run_sse")
|
||||
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
||||
# SSE endpoint
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Convert the events to properly formatted SSE
|
||||
async def event_generator():
|
||||
try:
|
||||
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
|
||||
runner = _get_runner(req.app_name)
|
||||
async for event in runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
run_config=RunConfig(streaming_mode=stream_mode),
|
||||
):
|
||||
# Format as SSE data
|
||||
sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
|
||||
logger.info("Generated event in agent run streaming: %s", sse_event)
|
||||
yield f"data: {sse_event}\n\n"
|
||||
except Exception as e:
|
||||
logger.exception("Error in event_generator: %s", e)
|
||||
# You might want to yield an error event here
|
||||
yield f'data: {{"error": "{str(e)}"}}\n\n'
|
||||
|
||||
# Returns a streaming response with the proper media type for SSE
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
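  # Example consumption sketch (illustrative): the same body as /run plus
  # `streaming`, reading `data:` lines off the chunked response:
  #
  #   import json, requests
  #
  #   body = {
  #       "app_name": "my_agent",
  #       "user_id": "user_123",
  #       "session_id": "session_abc",
  #       "streaming": True,
  #       "new_message": {"role": "user", "parts": [{"text": "Hello"}]},
  #   }
  #   with requests.post(
  #       "http://localhost:8000/run_sse", json=body, stream=True
  #   ) as resp:
  #       for line in resp.iter_lines():
  #           if line.startswith(b"data: "):
  #               print(json.loads(line[len(b"data: "):]))
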
  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
      response_model_exclude_none=True,
  )
  def get_event_graph(
      app_name: str, user_id: str, session_id: str, event_id: str
  ):
    # Connect to managed session if agent_engine_id is set.
    app_id = agent_engine_id if agent_engine_id else app_name
    session = session_service.get_session(
        app_name=app_id, user_id=user_id, session_id=session_id
    )
    session_events = session.events if session else []
    event = next((x for x in session_events if x.id == event_id), None)
    if not event:
      return {}

    from . import agent_graph

    function_calls = event.get_function_calls()
    function_responses = event.get_function_responses()
    root_agent = _get_root_agent(app_name)
    dot_graph = None
    if function_calls:
      function_call_highlights = []
      for function_call in function_calls:
        from_name = event.author
        to_name = function_call.name
        function_call_highlights.append((from_name, to_name))
      dot_graph = agent_graph.get_agent_graph(
          root_agent, function_call_highlights
      )
    elif function_responses:
      function_responses_highlights = []
      for function_response in function_responses:
        from_name = function_response.name
        to_name = event.author
        function_responses_highlights.append((from_name, to_name))
      dot_graph = agent_graph.get_agent_graph(
          root_agent, function_responses_highlights
      )
    else:
      from_name = event.author
      to_name = ""
      dot_graph = agent_graph.get_agent_graph(
          root_agent, [(from_name, to_name)]
      )
    if dot_graph and isinstance(dot_graph, graphviz.Digraph):
      return {"dot_src": dot_graph.source}
    else:
      return {}

  @app.websocket("/run_live")
  async def agent_live_run(
      websocket: WebSocket,
      app_name: str,
      user_id: str,
      session_id: str,
      modalities: List[Literal["TEXT", "AUDIO"]] = Query(
          default=["TEXT", "AUDIO"]
      ),  # Only allows "TEXT" or "AUDIO"
  ) -> None:
    await websocket.accept()

    # Connect to managed session if agent_engine_id is set.
    app_id = agent_engine_id if agent_engine_id else app_name
    session = session_service.get_session(
        app_name=app_id, user_id=user_id, session_id=session_id
    )
    if not session:
      # Accept first so that the client is aware of connection establishment,
      # then close with a specific code.
      await websocket.close(code=1002, reason="Session not found")
      return

    live_request_queue = LiveRequestQueue()

    async def forward_events():
      runner = _get_runner(app_name)
      async for event in runner.run_live(
          session=session, live_request_queue=live_request_queue
      ):
        await websocket.send_text(
            event.model_dump_json(exclude_none=True, by_alias=True)
        )

    async def process_messages():
      try:
        while True:
          data = await websocket.receive_text()
          # Validate and send the received message to the live queue.
          live_request_queue.send(LiveRequest.model_validate_json(data))
      except ValidationError as ve:
        logger.error("Validation error in process_messages: %s", ve)

    # Run both tasks concurrently and cancel all if one fails.
    tasks = [
        asyncio.create_task(forward_events()),
        asyncio.create_task(process_messages()),
    ]
    done, pending = await asyncio.wait(
        tasks, return_when=asyncio.FIRST_EXCEPTION
    )
    try:
      # This will re-raise any exception from the completed tasks.
      for task in done:
        task.result()
    except WebSocketDisconnect:
      logger.info("Client disconnected during process_messages.")
    except Exception as e:
      logger.exception("Error during live websocket communication: %s", e)
      traceback.print_exc()
    finally:
      for task in pending:
        task.cancel()

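  # Example client sketch (illustrative), assuming the third-party
  # `websockets` package; the exact LiveRequest JSON schema is defined by the
  # LiveRequest model, so the payload below is only indicative:
  #
  #   import asyncio, json, websockets
  #
  #   async def main():
  #       uri = (
  #           "ws://localhost:8000/run_live"
  #           "?app_name=my_agent&user_id=user_123&session_id=session_abc"
  #       )
  #       async with websockets.connect(uri) as ws:
  #           await ws.send(json.dumps(
  #               {"content": {"role": "user", "parts": [{"text": "Hi"}]}}
  #           ))
  #           print(await ws.recv())  # first serialized Event from the server
  #
  #   asyncio.run(main())
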
  def _get_root_agent(app_name: str) -> Agent:
    """Returns the root agent for the given app."""
    if app_name in root_agent_dict:
      return root_agent_dict[app_name]
    envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
    agent_module = importlib.import_module(app_name)
    root_agent: Agent = agent_module.agent.root_agent
    root_agent_dict[app_name] = root_agent
    return root_agent

  def _get_runner(app_name: str) -> Runner:
    """Returns the runner for the given app."""
    if app_name in runner_dict:
      return runner_dict[app_name]
    root_agent = _get_root_agent(app_name)
    runner = Runner(
        app_name=agent_engine_id if agent_engine_id else app_name,
        agent=root_agent,
        artifact_service=artifact_service,
        session_service=session_service,
    )
    runner_dict[app_name] = runner
    return runner

  if web:
    BASE_DIR = Path(__file__).parent.resolve()
    ANGULAR_DIST_PATH = BASE_DIR / "browser"

    @app.get("/")
    async def redirect_to_dev_ui():
      return RedirectResponse("/dev-ui")

    @app.get("/dev-ui")
    async def dev_ui():
      return FileResponse(BASE_DIR / "browser/index.html")

    app.mount(
        "/", StaticFiles(directory=ANGULAR_DIST_PATH, html=True), name="static"
    )
  return app
@@ -0,0 +1,49 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any
from typing import Optional

from ...agents.base_agent import BaseAgent
from ...agents.llm_agent import LlmAgent

__all__ = [
    'create_empty_state',
]


def _create_empty_state(agent: BaseAgent, all_state: dict[str, Any]):
  for sub_agent in agent.sub_agents:
    _create_empty_state(sub_agent, all_state)

  if (
      isinstance(agent, LlmAgent)
      and agent.instruction
      and isinstance(agent.instruction, str)
  ):
    for key in re.findall(r'{([\w]+)}', agent.instruction):
      all_state[key] = ''


def create_empty_state(
    agent: BaseAgent, initialized_states: Optional[dict[str, Any]] = None
) -> dict[str, Any]:
  """Creates empty string values for non-initialized states.

  State keys are collected from `{placeholder}` references in the
  instructions across the agent tree; keys already present in
  `initialized_states` are excluded.
  """
  non_initialized_states = {}
  _create_empty_state(agent, non_initialized_states)
  for key in initialized_states or {}:
    if key in non_initialized_states:
      del non_initialized_states[key]
  return non_initialized_states
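
# Example (illustrative sketch): for an LlmAgent whose instruction references
# two placeholders, the helpers above behave like this:
#
#   agent = LlmAgent(
#       name='weather_agent',
#       instruction='Report the weather in {city} using {units}.',
#   )
#   create_empty_state(agent)                     # -> {'city': '', 'units': ''}
#   create_empty_state(agent, {'city': 'Paris'})  # -> {'units': ''}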
@@ -0,0 +1,57 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

from dotenv import load_dotenv

logger = logging.getLogger(__file__)


def _walk_to_root_until_found(folder, filename) -> str:
  checkpath = os.path.join(folder, filename)
  if os.path.exists(checkpath) and os.path.isfile(checkpath):
    return checkpath

  parent_folder = os.path.dirname(folder)
  if parent_folder == folder:  # reached the root
    return ''

  return _walk_to_root_until_found(parent_folder, filename)


def load_dotenv_for_agent(
    agent_name: str, agent_parent_folder: str, filename: str = '.env'
):
  """Loads the .env file for the agent module."""

  # Gets the folder of agent_module as starting_folder
  starting_folder = os.path.abspath(
      os.path.join(agent_parent_folder, agent_name)
  )
  dotenv_file_path = _walk_to_root_until_found(starting_folder, filename)
  if dotenv_file_path:
    load_dotenv(dotenv_file_path, override=True, verbose=True)
    logger.info(
        'Loaded %s file for %s at %s',
        filename,
        agent_name,
        dotenv_file_path,
    )
  else:
    logger.info('No %s file found for %s', filename, agent_name)
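
# Example (illustrative): given this layout, the walk starts at the agent
# folder and climbs toward the filesystem root until a `.env` is found:
#
#   agents/
#     .env          <- picked up when loading `my_agent`
#     my_agent/
#       agent.py
#
#   load_dotenv_for_agent('my_agent', '/path/to/agents')
#   # Variables from agents/.env are now in os.environ (override=True).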
@@ -0,0 +1,93 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from ...sessions.session import Session


def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
  """Converts session data into the eval format.

  Args:
    session: The session that should be converted.

  Returns:
    list: A single evaluation dataset in the required format.
  """
  eval_case = []
  events = session.events if session and session.events else []

  for event in events:
    if event.author == 'user':
      if not event.content or not event.content.parts:
        continue

      # Extract user query
      content = event.content
      parts = content.parts

      query = parts[0].text or ''

      # Find the corresponding tool usage or response for the query
      expected_tool_use = []
      intermediate_agent_responses = []

      # Check subsequent events to extract tool uses or responses for this
      # turn.
      for subsequent_event in events[events.index(event) + 1 :]:
        event_author = subsequent_event.author or 'agent'
        if event_author == 'user':
          # We found an event where the author was the user. This means that a
          # new turn has started. So close this turn here.
          break

        if not subsequent_event.content or not subsequent_event.content.parts:
          continue

        for subsequent_part in subsequent_event.content.parts:
          # Some events have both a function call and a text part.

          if subsequent_part.function_call:
            tool_name = subsequent_part.function_call.name or ''
            tool_input = subsequent_part.function_call.args or {}
            expected_tool_use.append({
                'tool_name': tool_name,
                'tool_input': tool_input,
            })
          elif subsequent_part.text:
            # Also keep track of all the natural language responses that
            # the agent (or sub agents) generated.
            intermediate_agent_responses.append(
                {'author': event_author, 'text': subsequent_part.text}
            )

      # If we are here, then either we are done reading all the events or we
      # encountered an event that had content authored by the end-user.
      # Either way, this marks the end of the turn.
      # We assume that the last natural language intermediate response is the
      # final response from the agent/model. We treat that as the reference.
      eval_case.append({
          'query': query,
          'expected_tool_use': expected_tool_use,
          'expected_intermediate_agent_responses': intermediate_agent_responses[
              :-1
          ],
          'reference': (
              intermediate_agent_responses[-1]['text']
              if intermediate_agent_responses
              else ''
          ),
      })

  return eval_case
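
# Example (illustrative): a single user turn produces a record shaped like:
#
#   {
#       'query': 'What is the weather in Paris?',
#       'expected_tool_use': [
#           {'tool_name': 'get_weather', 'tool_input': {'city': 'Paris'}},
#       ],
#       'expected_intermediate_agent_responses': [
#           {'author': 'planner_agent', 'text': 'Let me look that up.'},
#       ],
#       'reference': 'It is 18C and sunny in Paris.',
#   }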
@@ -0,0 +1,72 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import tempfile
import time

LOGGING_FORMAT = (
    '%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
)


def log_to_stderr(level=logging.INFO):
  logging.basicConfig(
      level=level,
      format=LOGGING_FORMAT,
  )


def log_to_tmp_folder(
    level=logging.INFO,
    *,
    sub_folder: str = 'agents_log',
    log_file_prefix: str = 'agent',
    log_file_timestamp: str = time.strftime('%Y%m%d_%H%M%S'),
):
  """Logs to the system temp folder, instead of logging to stderr.

  Args:
    level: Logging level for both the file handler and the root logger.
    sub_folder: Subfolder of the system temp folder to write logs into.
    log_file_prefix: Prefix for the log file name.
    log_file_timestamp: Timestamp embedded in the log file name.

  Returns:
    The log file path.
  """
  log_dir = os.path.join(tempfile.gettempdir(), sub_folder)
  log_filename = f'{log_file_prefix}.{log_file_timestamp}.log'
  log_filepath = os.path.join(log_dir, log_filename)

  os.makedirs(log_dir, exist_ok=True)

  file_handler = logging.FileHandler(log_filepath, mode='w')
  file_handler.setLevel(level)
  file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT))

  root_logger = logging.getLogger()
  root_logger.setLevel(level)
  root_logger.handlers = []  # Clear handlers to disable logging to stderr
  root_logger.addHandler(file_handler)

  print(f'Log setup complete: {log_filepath}')

  latest_log_link = os.path.join(log_dir, f'{log_file_prefix}.latest.log')
  if os.path.islink(latest_log_link):
    os.unlink(latest_log_link)
  os.symlink(log_filepath, latest_log_link)

  print(f'To access latest log: tail -F {latest_log_link}')
  return log_filepath
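
# Example usage: route all root-logger output to a timestamped file under the
# system temp folder (on a typical Linux host, /tmp/agents_log/...):
#
#   log_path = log_to_tmp_folder(level=logging.DEBUG)
#   logging.getLogger(__name__).info('agent started')
#   # tail -F /tmp/agents_log/agent.latest.log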