mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
No public description
PiperOrigin-RevId: 748777998
This commit is contained in:
parent
290058eb05
commit
61d4be2d76
13
README.md
13
README.md
@ -5,9 +5,9 @@
|
|||||||
[](https://www.reddit.com/r/agentdevelopmentkit/)
|
[](https://www.reddit.com/r/agentdevelopmentkit/)
|
||||||
|
|
||||||
<html>
|
<html>
|
||||||
<h1 align="center">
|
<h2 align="center">
|
||||||
<img src="assets/agent-development-kit.png" width="256"/>
|
<img src="https://raw.githubusercontent.com/google/adk-python/main/assets/agent-development-kit.png" width="256"/>
|
||||||
</h1>
|
</h2>
|
||||||
<h3 align="center">
|
<h3 align="center">
|
||||||
An open-source, code-first Python toolkit for building, evaluating, and deploying sophisticated AI agents with flexibility and control.
|
An open-source, code-first Python toolkit for building, evaluating, and deploying sophisticated AI agents with flexibility and control.
|
||||||
</h3>
|
</h3>
|
||||||
@ -50,6 +50,7 @@ You can install the ADK using `pip`:
|
|||||||
```bash
|
```bash
|
||||||
pip install google-adk
|
pip install google-adk
|
||||||
```
|
```
|
||||||
|
|
||||||
## 📚 Documentation
|
## 📚 Documentation
|
||||||
|
|
||||||
Explore the full documentation for detailed guides on building, evaluating, and
|
Explore the full documentation for detailed guides on building, evaluating, and
|
||||||
@ -60,6 +61,7 @@ deploying agents:
|
|||||||
## 🏁 Feature Highlight
|
## 🏁 Feature Highlight
|
||||||
|
|
||||||
### Define a single agent:
|
### Define a single agent:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from google.adk.agents import Agent
|
from google.adk.agents import Agent
|
||||||
from google.adk.tools import google_search
|
from google.adk.tools import google_search
|
||||||
@ -74,7 +76,9 @@ root_agent = Agent(
|
|||||||
```
|
```
|
||||||
|
|
||||||
### Define a multi-agent system:
|
### Define a multi-agent system:
|
||||||
|
|
||||||
Define a multi-agent system with coordinator agent, greeter agent, and task execution agent. Then ADK engine and the model will guide the agents works together to accomplish the task.
|
Define a multi-agent system with coordinator agent, greeter agent, and task execution agent. Then ADK engine and the model will guide the agents works together to accomplish the task.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from google.adk.agents import LlmAgent, BaseAgent
|
from google.adk.agents import LlmAgent, BaseAgent
|
||||||
|
|
||||||
@ -92,14 +96,13 @@ coordinator = LlmAgent(
|
|||||||
task_executor
|
task_executor
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Development UI
|
### Development UI
|
||||||
|
|
||||||
A built-in development UI to help you test, evaluate, debug, and showcase your agent(s).
|
A built-in development UI to help you test, evaluate, debug, and showcase your agent(s).
|
||||||
|
|
||||||
<img src="assets/adk-web-dev-ui-function-call.png"/>
|
<img src="https://raw.githubusercontent.com/google/adk-python/main/assets/adk-web-dev-ui-function-call.png"/>
|
||||||
|
|
||||||
### Evaluate Agents
|
### Evaluate Agents
|
||||||
|
|
||||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -82,6 +82,8 @@ async def run_interactively(
|
|||||||
)
|
)
|
||||||
while True:
|
while True:
|
||||||
query = input('user: ')
|
query = input('user: ')
|
||||||
|
if not query or not query.strip():
|
||||||
|
continue
|
||||||
if query == 'exit':
|
if query == 'exit':
|
||||||
break
|
break
|
||||||
async for event in runner.run_async(
|
async for event in runner.run_async(
|
||||||
|
279
src/google/adk/cli/cli_create.py
Normal file
279
src/google/adk/cli/cli_create.py
Normal file
@ -0,0 +1,279 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
_INIT_PY_TEMPLATE = """\
|
||||||
|
from . import agent
|
||||||
|
"""
|
||||||
|
|
||||||
|
_AGENT_PY_TEMPLATE = """\
|
||||||
|
from google.adk.agents import Agent
|
||||||
|
|
||||||
|
root_agent = Agent(
|
||||||
|
model='{model_name}',
|
||||||
|
name='root_agent',
|
||||||
|
description='A helpful assistant for user questions.',
|
||||||
|
instruction='Answer user questions to the best of your knowledge',
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
_GOOGLE_API_MSG = """
|
||||||
|
Don't have API Key? Create one in AI Studio: https://aistudio.google.com/apikey
|
||||||
|
"""
|
||||||
|
|
||||||
|
_GOOGLE_CLOUD_SETUP_MSG = """
|
||||||
|
You need an existing Google Cloud account and project, check out this link for details:
|
||||||
|
https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai
|
||||||
|
"""
|
||||||
|
|
||||||
|
_OTHER_MODEL_MSG = """
|
||||||
|
Please see below guide to configure other models:
|
||||||
|
https://google.github.io/adk-docs/agents/models
|
||||||
|
"""
|
||||||
|
|
||||||
|
_SUCCESS_MSG = """
|
||||||
|
Agent created in {agent_folder}:
|
||||||
|
- .env
|
||||||
|
- __init__.py
|
||||||
|
- agent.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_gcp_project_from_gcloud() -> str:
|
||||||
|
"""Uses gcloud to get default project."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["gcloud", "config", "get-value", "project"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
return result.stdout.strip()
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_gcp_region_from_gcloud() -> str:
|
||||||
|
"""Uses gcloud to get default region."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["gcloud", "config", "get-value", "compute/region"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
return result.stdout.strip()
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_str(
|
||||||
|
prompt_prefix: str,
|
||||||
|
*,
|
||||||
|
prior_msg: Optional[str] = None,
|
||||||
|
default_value: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
if prior_msg:
|
||||||
|
click.secho(prior_msg, fg="green")
|
||||||
|
while True:
|
||||||
|
value: str = click.prompt(
|
||||||
|
prompt_prefix, default=default_value or None, type=str
|
||||||
|
)
|
||||||
|
if value and value.strip():
|
||||||
|
return value.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_for_google_cloud(
|
||||||
|
google_cloud_project: Optional[str],
|
||||||
|
) -> str:
|
||||||
|
"""Prompts user for Google Cloud project ID."""
|
||||||
|
google_cloud_project = (
|
||||||
|
google_cloud_project
|
||||||
|
or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
|
||||||
|
or _get_gcp_project_from_gcloud()
|
||||||
|
)
|
||||||
|
|
||||||
|
google_cloud_project = _prompt_str(
|
||||||
|
"Enter Google Cloud project ID", default_value=google_cloud_project
|
||||||
|
)
|
||||||
|
|
||||||
|
return google_cloud_project
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_for_google_cloud_region(
|
||||||
|
google_cloud_region: Optional[str],
|
||||||
|
) -> str:
|
||||||
|
"""Prompts user for Google Cloud region."""
|
||||||
|
google_cloud_region = (
|
||||||
|
google_cloud_region
|
||||||
|
or os.environ.get("GOOGLE_CLOUD_LOCATION", None)
|
||||||
|
or _get_gcp_region_from_gcloud()
|
||||||
|
)
|
||||||
|
|
||||||
|
google_cloud_region = _prompt_str(
|
||||||
|
"Enter Google Cloud region",
|
||||||
|
default_value=google_cloud_region or "us-central1",
|
||||||
|
)
|
||||||
|
return google_cloud_region
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_for_google_api_key(
|
||||||
|
google_api_key: Optional[str],
|
||||||
|
) -> str:
|
||||||
|
"""Prompts user for Google API key."""
|
||||||
|
google_api_key = google_api_key or os.environ.get("GOOGLE_API_KEY", None)
|
||||||
|
|
||||||
|
google_api_key = _prompt_str(
|
||||||
|
"Enter Google API key",
|
||||||
|
prior_msg=_GOOGLE_API_MSG,
|
||||||
|
default_value=google_api_key,
|
||||||
|
)
|
||||||
|
return google_api_key
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_files(
|
||||||
|
agent_folder: str,
|
||||||
|
*,
|
||||||
|
google_api_key: Optional[str] = None,
|
||||||
|
google_cloud_project: Optional[str] = None,
|
||||||
|
google_cloud_region: Optional[str] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Generates a folder name for the agent."""
|
||||||
|
os.makedirs(agent_folder, exist_ok=True)
|
||||||
|
|
||||||
|
dotenv_file_path = os.path.join(agent_folder, ".env")
|
||||||
|
init_file_path = os.path.join(agent_folder, "__init__.py")
|
||||||
|
agent_file_path = os.path.join(agent_folder, "agent.py")
|
||||||
|
|
||||||
|
with open(dotenv_file_path, "w", encoding="utf-8") as f:
|
||||||
|
lines = []
|
||||||
|
if google_api_key:
|
||||||
|
lines.append("GOOGLE_GENAI_USE_VERTEXAI=0")
|
||||||
|
elif google_cloud_project and google_cloud_region:
|
||||||
|
lines.append("GOOGLE_GENAI_USE_VERTEXAI=1")
|
||||||
|
if google_api_key:
|
||||||
|
lines.append(f"GOOGLE_API_KEY={google_api_key}")
|
||||||
|
if google_cloud_project:
|
||||||
|
lines.append(f"GOOGLE_CLOUD_PROJECT={google_cloud_project}")
|
||||||
|
if google_cloud_region:
|
||||||
|
lines.append(f"GOOGLE_CLOUD_LOCATION={google_cloud_region}")
|
||||||
|
f.write("\n".join(lines))
|
||||||
|
|
||||||
|
with open(init_file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(_INIT_PY_TEMPLATE)
|
||||||
|
|
||||||
|
with open(agent_file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(_AGENT_PY_TEMPLATE.format(model_name=model))
|
||||||
|
|
||||||
|
click.secho(
|
||||||
|
_SUCCESS_MSG.format(agent_folder=agent_folder),
|
||||||
|
fg="green",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_for_model() -> str:
|
||||||
|
model_choice = click.prompt(
|
||||||
|
"""\
|
||||||
|
Choose a model for the root agent:
|
||||||
|
1. gemini-2.0-flash-001
|
||||||
|
2. Other models (fill later)
|
||||||
|
Choose model""",
|
||||||
|
type=click.Choice(["1", "2"]),
|
||||||
|
)
|
||||||
|
if model_choice == "1":
|
||||||
|
return "gemini-2.0-flash-001"
|
||||||
|
else:
|
||||||
|
click.secho(_OTHER_MODEL_MSG, fg="green")
|
||||||
|
return "<FILL_IN_MODEL>"
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_to_choose_backend(
|
||||||
|
google_api_key: Optional[str],
|
||||||
|
google_cloud_project: Optional[str],
|
||||||
|
google_cloud_region: Optional[str],
|
||||||
|
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||||
|
"""Prompts user to choose backend.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (google_api_key, google_cloud_project, google_cloud_region).
|
||||||
|
"""
|
||||||
|
backend_choice = click.prompt(
|
||||||
|
"1. Google AI\n2. Vertex AI\nChoose a backend",
|
||||||
|
type=click.Choice(["1", "2"]),
|
||||||
|
)
|
||||||
|
if backend_choice == "1":
|
||||||
|
google_api_key = _prompt_for_google_api_key(google_api_key)
|
||||||
|
elif backend_choice == "2":
|
||||||
|
click.secho(_GOOGLE_CLOUD_SETUP_MSG, fg="green")
|
||||||
|
google_cloud_project = _prompt_for_google_cloud(google_cloud_project)
|
||||||
|
google_cloud_region = _prompt_for_google_cloud_region(google_cloud_region)
|
||||||
|
return google_api_key, google_cloud_project, google_cloud_region
|
||||||
|
|
||||||
|
|
||||||
|
def run_cmd(
|
||||||
|
agent_name: str,
|
||||||
|
*,
|
||||||
|
model: Optional[str],
|
||||||
|
google_api_key: Optional[str],
|
||||||
|
google_cloud_project: Optional[str],
|
||||||
|
google_cloud_region: Optional[str],
|
||||||
|
):
|
||||||
|
"""Runs `adk create` command to create agent template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_name: str, The name of the agent.
|
||||||
|
google_api_key: Optional[str], The Google API key for using Google AI as
|
||||||
|
backend.
|
||||||
|
google_cloud_project: Optional[str], The Google Cloud project for using
|
||||||
|
VertexAI as backend.
|
||||||
|
google_cloud_region: Optional[str], The Google Cloud region for using
|
||||||
|
VertexAI as backend.
|
||||||
|
"""
|
||||||
|
agent_folder = os.path.join(os.getcwd(), agent_name)
|
||||||
|
# check folder doesn't exist or it's empty. Otherwise, throw
|
||||||
|
if os.path.exists(agent_folder) and os.listdir(agent_folder):
|
||||||
|
# Prompt user whether to override existing files using click
|
||||||
|
if not click.confirm(
|
||||||
|
f"Non-empty folder already exist: '{agent_folder}'\n"
|
||||||
|
"Override existing content?",
|
||||||
|
default=False,
|
||||||
|
):
|
||||||
|
raise click.Abort()
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
model = _prompt_for_model()
|
||||||
|
|
||||||
|
if not google_api_key and not (google_cloud_project and google_cloud_region):
|
||||||
|
if model.startswith("gemini"):
|
||||||
|
google_api_key, google_cloud_project, google_cloud_region = (
|
||||||
|
_prompt_to_choose_backend(
|
||||||
|
google_api_key, google_cloud_project, google_cloud_region
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
_generate_files(
|
||||||
|
agent_folder,
|
||||||
|
google_api_key=google_api_key,
|
||||||
|
google_cloud_project=google_cloud_project,
|
||||||
|
google_cloud_region=google_cloud_region,
|
||||||
|
model=model,
|
||||||
|
)
|
@ -82,8 +82,9 @@ def to_cloud_run(
|
|||||||
app_name: str,
|
app_name: str,
|
||||||
temp_folder: str,
|
temp_folder: str,
|
||||||
port: int,
|
port: int,
|
||||||
with_cloud_trace: bool,
|
trace_to_cloud: bool,
|
||||||
with_ui: bool,
|
with_ui: bool,
|
||||||
|
verbosity: str,
|
||||||
):
|
):
|
||||||
"""Deploys an agent to Google Cloud Run.
|
"""Deploys an agent to Google Cloud Run.
|
||||||
|
|
||||||
@ -108,8 +109,9 @@ def to_cloud_run(
|
|||||||
app_name: The name of the app, by default, it's basename of `agent_folder`.
|
app_name: The name of the app, by default, it's basename of `agent_folder`.
|
||||||
temp_folder: The temp folder for the generated Cloud Run source files.
|
temp_folder: The temp folder for the generated Cloud Run source files.
|
||||||
port: The port of the ADK api server.
|
port: The port of the ADK api server.
|
||||||
with_cloud_trace: Whether to enable Cloud Trace.
|
trace_to_cloud: Whether to enable Cloud Trace.
|
||||||
with_ui: Whether to deploy with UI.
|
with_ui: Whether to deploy with UI.
|
||||||
|
verbosity: The verbosity level of the CLI.
|
||||||
"""
|
"""
|
||||||
app_name = app_name or os.path.basename(agent_folder)
|
app_name = app_name or os.path.basename(agent_folder)
|
||||||
|
|
||||||
@ -142,7 +144,7 @@ def to_cloud_run(
|
|||||||
port=port,
|
port=port,
|
||||||
command='web' if with_ui else 'api_server',
|
command='web' if with_ui else 'api_server',
|
||||||
install_agent_deps=install_agent_deps,
|
install_agent_deps=install_agent_deps,
|
||||||
trace_to_cloud_option='--trace_to_cloud' if with_cloud_trace else '',
|
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
|
||||||
)
|
)
|
||||||
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
|
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
|
||||||
os.makedirs(temp_folder, exist_ok=True)
|
os.makedirs(temp_folder, exist_ok=True)
|
||||||
@ -169,6 +171,8 @@ def to_cloud_run(
|
|||||||
*region_options,
|
*region_options,
|
||||||
'--port',
|
'--port',
|
||||||
str(port),
|
str(port),
|
||||||
|
'--verbosity',
|
||||||
|
verbosity,
|
||||||
'--labels',
|
'--labels',
|
||||||
'created-by=adk',
|
'created-by=adk',
|
||||||
],
|
],
|
||||||
|
@ -24,6 +24,7 @@ import click
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
from . import cli_create
|
||||||
from . import cli_deploy
|
from . import cli_deploy
|
||||||
from .cli import run_cli
|
from .cli import run_cli
|
||||||
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
|
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
|
||||||
@ -42,10 +43,59 @@ def main():
|
|||||||
|
|
||||||
@main.group()
|
@main.group()
|
||||||
def deploy():
|
def deploy():
|
||||||
"""Deploy Agent."""
|
"""Deploys agent to hosted environments."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@main.command("create")
|
||||||
|
@click.option(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
help="Optional. The model used for the root agent.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--api_key",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"Optional. The API Key needed to access the model, e.g. Google AI API"
|
||||||
|
" Key."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--project",
|
||||||
|
type=str,
|
||||||
|
help="Optional. The Google Cloud Project for using VertexAI as backend.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--region",
|
||||||
|
type=str,
|
||||||
|
help="Optional. The Google Cloud Region for using VertexAI as backend.",
|
||||||
|
)
|
||||||
|
@click.argument("app_name", type=str, required=True)
|
||||||
|
def cli_create_cmd(
|
||||||
|
app_name: str,
|
||||||
|
model: Optional[str],
|
||||||
|
api_key: Optional[str],
|
||||||
|
project: Optional[str],
|
||||||
|
region: Optional[str],
|
||||||
|
):
|
||||||
|
"""Creates a new app in the current folder with prepopulated agent template.
|
||||||
|
|
||||||
|
APP_NAME: required, the folder of the agent source code.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
adk create path/to/my_app
|
||||||
|
"""
|
||||||
|
cli_create.run_cmd(
|
||||||
|
app_name,
|
||||||
|
model=model,
|
||||||
|
google_api_key=api_key,
|
||||||
|
google_cloud_project=project,
|
||||||
|
google_cloud_region=region,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@main.command("run")
|
@main.command("run")
|
||||||
@click.option(
|
@click.option(
|
||||||
"--save_session",
|
"--save_session",
|
||||||
@ -62,7 +112,7 @@ def deploy():
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
def cli_run(agent: str, save_session: bool):
|
def cli_run(agent: str, save_session: bool):
|
||||||
"""Run an interactive CLI for a certain agent.
|
"""Runs an interactive CLI for a certain agent.
|
||||||
|
|
||||||
AGENT: The path to the agent source code folder.
|
AGENT: The path to the agent source code folder.
|
||||||
|
|
||||||
@ -150,7 +200,7 @@ def cli_eval(
|
|||||||
EvalMetric(metric_name=metric_name, threshold=threshold)
|
EvalMetric(metric_name=metric_name, threshold=threshold)
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Using evaluation criteria: {evaluation_criteria}")
|
print(f"Using evaluation creiteria: {evaluation_criteria}")
|
||||||
|
|
||||||
root_agent = get_root_agent(agent_module_file_path)
|
root_agent = get_root_agent(agent_module_file_path)
|
||||||
reset_func = try_get_reset_func(agent_module_file_path)
|
reset_func = try_get_reset_func(agent_module_file_path)
|
||||||
@ -244,7 +294,7 @@ def cli_eval(
|
|||||||
type=click.Path(
|
type=click.Path(
|
||||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||||
),
|
),
|
||||||
default=os.getcwd(),
|
default=os.getcwd,
|
||||||
)
|
)
|
||||||
def cli_web(
|
def cli_web(
|
||||||
agents_dir: str,
|
agents_dir: str,
|
||||||
@ -255,7 +305,7 @@ def cli_web(
|
|||||||
port: int = 8000,
|
port: int = 8000,
|
||||||
trace_to_cloud: bool = False,
|
trace_to_cloud: bool = False,
|
||||||
):
|
):
|
||||||
"""Start a FastAPI server with Web UI for agents.
|
"""Starts a FastAPI server with Web UI for agents.
|
||||||
|
|
||||||
AGENTS_DIR: The directory of agents, where each sub-directory is a single
|
AGENTS_DIR: The directory of agents, where each sub-directory is a single
|
||||||
agent, containing at least `__init__.py` and `agent.py` files.
|
agent, containing at least `__init__.py` and `agent.py` files.
|
||||||
@ -274,7 +324,7 @@ def cli_web(
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def _lifespan(app: FastAPI):
|
async def _lifespan(app: FastAPI):
|
||||||
click.secho(
|
click.secho(
|
||||||
f"""\
|
f"""
|
||||||
+-----------------------------------------------------------------------------+
|
+-----------------------------------------------------------------------------+
|
||||||
| ADK Web Server started |
|
| ADK Web Server started |
|
||||||
| |
|
| |
|
||||||
@ -285,7 +335,7 @@ def cli_web(
|
|||||||
)
|
)
|
||||||
yield # Startup is done, now app is running
|
yield # Startup is done, now app is running
|
||||||
click.secho(
|
click.secho(
|
||||||
"""\
|
"""
|
||||||
+-----------------------------------------------------------------------------+
|
+-----------------------------------------------------------------------------+
|
||||||
| ADK Web Server shutting down... |
|
| ADK Web Server shutting down... |
|
||||||
+-----------------------------------------------------------------------------+
|
+-----------------------------------------------------------------------------+
|
||||||
@ -378,7 +428,7 @@ def cli_api_server(
|
|||||||
port: int = 8000,
|
port: int = 8000,
|
||||||
trace_to_cloud: bool = False,
|
trace_to_cloud: bool = False,
|
||||||
):
|
):
|
||||||
"""Start a FastAPI server for agents.
|
"""Starts a FastAPI server for agents.
|
||||||
|
|
||||||
AGENTS_DIR: The directory of agents, where each sub-directory is a single
|
AGENTS_DIR: The directory of agents, where each sub-directory is a single
|
||||||
agent, containing at least `__init__.py` and `agent.py` files.
|
agent, containing at least `__init__.py` and `agent.py` files.
|
||||||
@ -452,7 +502,7 @@ def cli_api_server(
|
|||||||
help="Optional. The port of the ADK API server (default: 8000).",
|
help="Optional. The port of the ADK API server (default: 8000).",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--with_cloud_trace",
|
"--trace_to_cloud",
|
||||||
type=bool,
|
type=bool,
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
show_default=True,
|
show_default=True,
|
||||||
@ -483,6 +533,14 @@ def cli_api_server(
|
|||||||
" (default: a timestamped folder in the system temp directory)."
|
" (default: a timestamped folder in the system temp directory)."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--verbosity",
|
||||||
|
type=click.Choice(
|
||||||
|
["debug", "info", "warning", "error", "critical"], case_sensitive=False
|
||||||
|
),
|
||||||
|
default="WARNING",
|
||||||
|
help="Optional. Override the default verbosity level.",
|
||||||
|
)
|
||||||
@click.argument(
|
@click.argument(
|
||||||
"agent",
|
"agent",
|
||||||
type=click.Path(
|
type=click.Path(
|
||||||
@ -497,8 +555,9 @@ def cli_deploy_cloud_run(
|
|||||||
app_name: str,
|
app_name: str,
|
||||||
temp_folder: str,
|
temp_folder: str,
|
||||||
port: int,
|
port: int,
|
||||||
with_cloud_trace: bool,
|
trace_to_cloud: bool,
|
||||||
with_ui: bool,
|
with_ui: bool,
|
||||||
|
verbosity: str,
|
||||||
):
|
):
|
||||||
"""Deploys an agent to Cloud Run.
|
"""Deploys an agent to Cloud Run.
|
||||||
|
|
||||||
@ -517,8 +576,9 @@ def cli_deploy_cloud_run(
|
|||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
temp_folder=temp_folder,
|
temp_folder=temp_folder,
|
||||||
port=port,
|
port=port,
|
||||||
with_cloud_trace=with_cloud_trace,
|
trace_to_cloud=trace_to_cloud,
|
||||||
with_ui=with_ui,
|
with_ui=with_ui,
|
||||||
|
verbosity=verbosity,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.secho(f"Deploy failed: {e}", fg="red", err=True)
|
click.secho(f"Deploy failed: {e}", fg="red", err=True)
|
||||||
|
@ -13,7 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
import importlib
|
import importlib
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@ -28,6 +30,7 @@ from typing import Literal
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
from click import Tuple
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from fastapi import Query
|
from fastapi import Query
|
||||||
@ -56,6 +59,7 @@ from ..agents.llm_agent import Agent
|
|||||||
from ..agents.run_config import StreamingMode
|
from ..agents.run_config import StreamingMode
|
||||||
from ..artifacts import InMemoryArtifactService
|
from ..artifacts import InMemoryArtifactService
|
||||||
from ..events.event import Event
|
from ..events.event import Event
|
||||||
|
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
||||||
from ..runners import Runner
|
from ..runners import Runner
|
||||||
from ..sessions.database_session_service import DatabaseSessionService
|
from ..sessions.database_session_service import DatabaseSessionService
|
||||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||||
@ -143,11 +147,8 @@ def get_fast_api_app(
|
|||||||
provider.add_span_processor(
|
provider.add_span_processor(
|
||||||
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
|
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
|
||||||
)
|
)
|
||||||
envs.load_dotenv()
|
if trace_to_cloud:
|
||||||
enable_cloud_tracing = trace_to_cloud or os.environ.get(
|
envs.load_dotenv_for_agent("", agent_dir)
|
||||||
"ADK_TRACE_TO_CLOUD", "0"
|
|
||||||
).lower() in ["1", "true"]
|
|
||||||
if enable_cloud_tracing:
|
|
||||||
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
||||||
processor = export.BatchSpanProcessor(
|
processor = export.BatchSpanProcessor(
|
||||||
CloudTraceSpanExporter(project_id=project_id)
|
CloudTraceSpanExporter(project_id=project_id)
|
||||||
@ -161,8 +162,22 @@ def get_fast_api_app(
|
|||||||
|
|
||||||
trace.set_tracer_provider(provider)
|
trace.set_tracer_provider(provider)
|
||||||
|
|
||||||
|
exit_stacks = []
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def internal_lifespan(app: FastAPI):
|
||||||
|
if lifespan:
|
||||||
|
async with lifespan(app) as lifespan_context:
|
||||||
|
yield
|
||||||
|
|
||||||
|
if exit_stacks:
|
||||||
|
for stack in exit_stacks:
|
||||||
|
await stack.aclose()
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|
||||||
# Run the FastAPI server.
|
# Run the FastAPI server.
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=internal_lifespan)
|
||||||
|
|
||||||
if allow_origins:
|
if allow_origins:
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
@ -181,6 +196,7 @@ def get_fast_api_app(
|
|||||||
|
|
||||||
# Build the Artifact service
|
# Build the Artifact service
|
||||||
artifact_service = InMemoryArtifactService()
|
artifact_service = InMemoryArtifactService()
|
||||||
|
memory_service = InMemoryMemoryService()
|
||||||
|
|
||||||
# Build the Session service
|
# Build the Session service
|
||||||
agent_engine_id = ""
|
agent_engine_id = ""
|
||||||
@ -358,7 +374,7 @@ def get_fast_api_app(
|
|||||||
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
||||||
response_model_exclude_none=True,
|
response_model_exclude_none=True,
|
||||||
)
|
)
|
||||||
def add_session_to_eval_set(
|
async def add_session_to_eval_set(
|
||||||
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
||||||
):
|
):
|
||||||
pattern = r"^[a-zA-Z0-9_]+$"
|
pattern = r"^[a-zA-Z0-9_]+$"
|
||||||
@ -393,7 +409,9 @@ def get_fast_api_app(
|
|||||||
test_data = evals.convert_session_to_eval_format(session)
|
test_data = evals.convert_session_to_eval_format(session)
|
||||||
|
|
||||||
# Populate the session with initial session state.
|
# Populate the session with initial session state.
|
||||||
initial_session_state = create_empty_state(_get_root_agent(app_name))
|
initial_session_state = create_empty_state(
|
||||||
|
await _get_root_agent_async(app_name)
|
||||||
|
)
|
||||||
|
|
||||||
eval_set_data.append({
|
eval_set_data.append({
|
||||||
"name": req.eval_id,
|
"name": req.eval_id,
|
||||||
@ -430,7 +448,7 @@ def get_fast_api_app(
|
|||||||
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
|
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
|
||||||
response_model_exclude_none=True,
|
response_model_exclude_none=True,
|
||||||
)
|
)
|
||||||
def run_eval(
|
async def run_eval(
|
||||||
app_name: str, eval_set_id: str, req: RunEvalRequest
|
app_name: str, eval_set_id: str, req: RunEvalRequest
|
||||||
) -> list[RunEvalResult]:
|
) -> list[RunEvalResult]:
|
||||||
from .cli_eval import run_evals
|
from .cli_eval import run_evals
|
||||||
@ -447,7 +465,7 @@ def get_fast_api_app(
|
|||||||
logger.info(
|
logger.info(
|
||||||
"Eval ids to run list is empty. We will all evals in the eval set."
|
"Eval ids to run list is empty. We will all evals in the eval set."
|
||||||
)
|
)
|
||||||
root_agent = _get_root_agent(app_name)
|
root_agent = await _get_root_agent_async(app_name)
|
||||||
eval_results = list(
|
eval_results = list(
|
||||||
run_evals(
|
run_evals(
|
||||||
eval_set_to_evals,
|
eval_set_to_evals,
|
||||||
@ -577,7 +595,7 @@ def get_fast_api_app(
|
|||||||
)
|
)
|
||||||
if not session:
|
if not session:
|
||||||
raise HTTPException(status_code=404, detail="Session not found")
|
raise HTTPException(status_code=404, detail="Session not found")
|
||||||
runner = _get_runner(req.app_name)
|
runner = await _get_runner_async(req.app_name)
|
||||||
events = [
|
events = [
|
||||||
event
|
event
|
||||||
async for event in runner.run_async(
|
async for event in runner.run_async(
|
||||||
@ -604,7 +622,7 @@ def get_fast_api_app(
|
|||||||
async def event_generator():
|
async def event_generator():
|
||||||
try:
|
try:
|
||||||
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
|
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
|
||||||
runner = _get_runner(req.app_name)
|
runner = await _get_runner_async(req.app_name)
|
||||||
async for event in runner.run_async(
|
async for event in runner.run_async(
|
||||||
user_id=req.user_id,
|
user_id=req.user_id,
|
||||||
session_id=req.session_id,
|
session_id=req.session_id,
|
||||||
@ -630,7 +648,7 @@ def get_fast_api_app(
|
|||||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
|
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
|
||||||
response_model_exclude_none=True,
|
response_model_exclude_none=True,
|
||||||
)
|
)
|
||||||
def get_event_graph(
|
async def get_event_graph(
|
||||||
app_name: str, user_id: str, session_id: str, event_id: str
|
app_name: str, user_id: str, session_id: str, event_id: str
|
||||||
):
|
):
|
||||||
# Connect to managed session if agent_engine_id is set.
|
# Connect to managed session if agent_engine_id is set.
|
||||||
@ -647,7 +665,7 @@ def get_fast_api_app(
|
|||||||
|
|
||||||
function_calls = event.get_function_calls()
|
function_calls = event.get_function_calls()
|
||||||
function_responses = event.get_function_responses()
|
function_responses = event.get_function_responses()
|
||||||
root_agent = _get_root_agent(app_name)
|
root_agent = await _get_root_agent_async(app_name)
|
||||||
dot_graph = None
|
dot_graph = None
|
||||||
if function_calls:
|
if function_calls:
|
||||||
function_call_highlights = []
|
function_call_highlights = []
|
||||||
@ -704,7 +722,7 @@ def get_fast_api_app(
|
|||||||
live_request_queue = LiveRequestQueue()
|
live_request_queue = LiveRequestQueue()
|
||||||
|
|
||||||
async def forward_events():
|
async def forward_events():
|
||||||
runner = _get_runner(app_name)
|
runner = await _get_runner_async(app_name)
|
||||||
async for event in runner.run_live(
|
async for event in runner.run_live(
|
||||||
session=session, live_request_queue=live_request_queue
|
session=session, live_request_queue=live_request_queue
|
||||||
):
|
):
|
||||||
@ -742,26 +760,40 @@ def get_fast_api_app(
|
|||||||
for task in pending:
|
for task in pending:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
def _get_root_agent(app_name: str) -> Agent:
|
async def _get_root_agent_async(app_name: str) -> Agent:
|
||||||
"""Returns the root agent for the given app."""
|
"""Returns the root agent for the given app."""
|
||||||
if app_name in root_agent_dict:
|
if app_name in root_agent_dict:
|
||||||
return root_agent_dict[app_name]
|
return root_agent_dict[app_name]
|
||||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
|
||||||
agent_module = importlib.import_module(app_name)
|
agent_module = importlib.import_module(app_name)
|
||||||
root_agent: Agent = agent_module.agent.root_agent
|
if getattr(agent_module.agent, "root_agent"):
|
||||||
|
root_agent = agent_module.agent.root_agent
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
|
||||||
|
|
||||||
|
# Handle an awaitable root agent and await for the actual agent.
|
||||||
|
if inspect.isawaitable(root_agent):
|
||||||
|
try:
|
||||||
|
agent, exit_stack = await root_agent
|
||||||
|
exit_stacks.append(exit_stack)
|
||||||
|
root_agent = agent
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"error getting root agent, {e}") from e
|
||||||
|
|
||||||
root_agent_dict[app_name] = root_agent
|
root_agent_dict[app_name] = root_agent
|
||||||
return root_agent
|
return root_agent
|
||||||
|
|
||||||
def _get_runner(app_name: str) -> Runner:
|
async def _get_runner_async(app_name: str) -> Runner:
|
||||||
"""Returns the runner for the given app."""
|
"""Returns the runner for the given app."""
|
||||||
|
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
||||||
if app_name in runner_dict:
|
if app_name in runner_dict:
|
||||||
return runner_dict[app_name]
|
return runner_dict[app_name]
|
||||||
root_agent = _get_root_agent(app_name)
|
root_agent = await _get_root_agent_async(app_name)
|
||||||
runner = Runner(
|
runner = Runner(
|
||||||
app_name=agent_engine_id if agent_engine_id else app_name,
|
app_name=agent_engine_id if agent_engine_id else app_name,
|
||||||
agent=root_agent,
|
agent=root_agent,
|
||||||
artifact_service=artifact_service,
|
artifact_service=artifact_service,
|
||||||
session_service=session_service,
|
session_service=session_service,
|
||||||
|
memory_service=memory_service,
|
||||||
)
|
)
|
||||||
runner_dict[app_name] = runner
|
runner_dict[app_name] = runner
|
||||||
return runner
|
return runner
|
||||||
|
@ -50,8 +50,5 @@ def load_dotenv_for_agent(
|
|||||||
agent_name,
|
agent_name,
|
||||||
dotenv_file_path,
|
dotenv_file_path,
|
||||||
)
|
)
|
||||||
logger.info(
|
|
||||||
'Reloaded %s file for %s at %s', filename, agent_name, dotenv_file_path
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info('No %s file found for %s', filename, agent_name)
|
logger.info('No %s file found for %s', filename, agent_name)
|
||||||
|
@ -106,9 +106,11 @@ class ResponseEvaluator:
|
|||||||
eval_dataset = pd.DataFrame(flattened_queries).rename(
|
eval_dataset = pd.DataFrame(flattened_queries).rename(
|
||||||
columns={"query": "prompt", "expected_tool_use": "reference_trajectory"}
|
columns={"query": "prompt", "expected_tool_use": "reference_trajectory"}
|
||||||
)
|
)
|
||||||
eval_task = EvalTask(dataset=eval_dataset, metrics=metrics)
|
|
||||||
|
|
||||||
eval_result = eval_task.evaluate()
|
eval_result = ResponseEvaluator._perform_eval(
|
||||||
|
dataset=eval_dataset, metrics=metrics
|
||||||
|
)
|
||||||
|
|
||||||
if print_detailed_results:
|
if print_detailed_results:
|
||||||
ResponseEvaluator._print_results(eval_result)
|
ResponseEvaluator._print_results(eval_result)
|
||||||
return eval_result.summary_metrics
|
return eval_result.summary_metrics
|
||||||
@ -129,6 +131,16 @@ class ResponseEvaluator:
|
|||||||
metrics.append("rouge_1")
|
metrics.append("rouge_1")
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _perform_eval(dataset, metrics):
|
||||||
|
"""This method hides away the call to external service.
|
||||||
|
|
||||||
|
Primarily helps with unit testing.
|
||||||
|
"""
|
||||||
|
eval_task = EvalTask(dataset=dataset, metrics=metrics)
|
||||||
|
|
||||||
|
return eval_task.evaluate()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _print_results(eval_result):
|
def _print_results(eval_result):
|
||||||
print("Evaluation Summary Metrics:", eval_result.summary_metrics)
|
print("Evaluation Summary Metrics:", eval_result.summary_metrics)
|
||||||
|
@ -87,15 +87,21 @@ class _NlPlanningResponse(BaseLlmResponseProcessor):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Postprocess the LLM response.
|
# Postprocess the LLM response.
|
||||||
|
callback_context = CallbackContext(invocation_context)
|
||||||
processed_parts = planner.process_planning_response(
|
processed_parts = planner.process_planning_response(
|
||||||
CallbackContext(invocation_context), llm_response.content.parts
|
callback_context, llm_response.content.parts
|
||||||
)
|
)
|
||||||
if processed_parts:
|
if processed_parts:
|
||||||
llm_response.content.parts = processed_parts
|
llm_response.content.parts = processed_parts
|
||||||
|
|
||||||
# Maintain async generator behavior
|
if callback_context.state.has_delta():
|
||||||
if False: # Ensures it behaves as a generator
|
state_update_event = Event(
|
||||||
yield # This is a no-op but maintains generator structure
|
invocation_id=invocation_context.invocation_id,
|
||||||
|
author=invocation_context.agent.name,
|
||||||
|
branch=invocation_context.branch,
|
||||||
|
actions=callback_context._event_actions,
|
||||||
|
)
|
||||||
|
yield state_update_event
|
||||||
|
|
||||||
|
|
||||||
response_processor = _NlPlanningResponse()
|
response_processor = _NlPlanningResponse()
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import base64
|
||||||
import copy
|
import copy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
@ -20,17 +21,17 @@ from typing import Any
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean
|
||||||
from sqlalchemy import delete
|
from sqlalchemy import delete
|
||||||
from sqlalchemy import Dialect
|
from sqlalchemy import Dialect
|
||||||
from sqlalchemy import ForeignKeyConstraint
|
from sqlalchemy import ForeignKeyConstraint
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy import Text
|
from sqlalchemy import Text
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
from sqlalchemy.engine import create_engine
|
from sqlalchemy.engine import create_engine
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.ext.mutable import MutableDict
|
|
||||||
from sqlalchemy.exc import ArgumentError
|
from sqlalchemy.exc import ArgumentError
|
||||||
|
from sqlalchemy.ext.mutable import MutableDict
|
||||||
from sqlalchemy.inspection import inspect
|
from sqlalchemy.inspection import inspect
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
from sqlalchemy.orm import Mapped
|
from sqlalchemy.orm import Mapped
|
||||||
@ -54,6 +55,7 @@ from .base_session_service import ListSessionsResponse
|
|||||||
from .session import Session
|
from .session import Session
|
||||||
from .state import State
|
from .state import State
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -103,7 +105,7 @@ class StorageSession(Base):
|
|||||||
String, primary_key=True, default=lambda: str(uuid.uuid4())
|
String, primary_key=True, default=lambda: str(uuid.uuid4())
|
||||||
)
|
)
|
||||||
|
|
||||||
state: Mapped[dict] = mapped_column(
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
MutableDict.as_mutable(DynamicJSON), default={}
|
MutableDict.as_mutable(DynamicJSON), default={}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -134,8 +136,20 @@ class StorageEvent(Base):
|
|||||||
author: Mapped[str] = mapped_column(String)
|
author: Mapped[str] = mapped_column(String)
|
||||||
branch: Mapped[str] = mapped_column(String, nullable=True)
|
branch: Mapped[str] = mapped_column(String, nullable=True)
|
||||||
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
||||||
content: Mapped[dict] = mapped_column(DynamicJSON)
|
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
|
||||||
actions: Mapped[dict] = mapped_column(PickleType)
|
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
|
||||||
|
|
||||||
|
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
|
||||||
|
Text, nullable=True
|
||||||
|
)
|
||||||
|
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||||
|
DynamicJSON, nullable=True
|
||||||
|
)
|
||||||
|
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||||
|
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||||
|
error_code: Mapped[str] = mapped_column(String, nullable=True)
|
||||||
|
error_message: Mapped[str] = mapped_column(String, nullable=True)
|
||||||
|
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||||
|
|
||||||
storage_session: Mapped[StorageSession] = relationship(
|
storage_session: Mapped[StorageSession] = relationship(
|
||||||
"StorageSession",
|
"StorageSession",
|
||||||
@ -150,13 +164,28 @@ class StorageEvent(Base):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def long_running_tool_ids(self) -> set[str]:
|
||||||
|
return (
|
||||||
|
set(json.loads(self.long_running_tool_ids_json))
|
||||||
|
if self.long_running_tool_ids_json
|
||||||
|
else set()
|
||||||
|
)
|
||||||
|
|
||||||
|
@long_running_tool_ids.setter
|
||||||
|
def long_running_tool_ids(self, value: set[str]):
|
||||||
|
if value is None:
|
||||||
|
self.long_running_tool_ids_json = None
|
||||||
|
else:
|
||||||
|
self.long_running_tool_ids_json = json.dumps(list(value))
|
||||||
|
|
||||||
|
|
||||||
class StorageAppState(Base):
|
class StorageAppState(Base):
|
||||||
"""Represents an app state stored in the database."""
|
"""Represents an app state stored in the database."""
|
||||||
__tablename__ = "app_states"
|
__tablename__ = "app_states"
|
||||||
|
|
||||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||||
state: Mapped[dict] = mapped_column(
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
MutableDict.as_mutable(DynamicJSON), default={}
|
MutableDict.as_mutable(DynamicJSON), default={}
|
||||||
)
|
)
|
||||||
update_time: Mapped[DateTime] = mapped_column(
|
update_time: Mapped[DateTime] = mapped_column(
|
||||||
@ -170,7 +199,7 @@ class StorageUserState(Base):
|
|||||||
|
|
||||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||||
state: Mapped[dict] = mapped_column(
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
MutableDict.as_mutable(DynamicJSON), default={}
|
MutableDict.as_mutable(DynamicJSON), default={}
|
||||||
)
|
)
|
||||||
update_time: Mapped[DateTime] = mapped_column(
|
update_time: Mapped[DateTime] = mapped_column(
|
||||||
@ -295,7 +324,6 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
last_update_time=storage_session.update_time.timestamp(),
|
last_update_time=storage_session.update_time.timestamp(),
|
||||||
)
|
)
|
||||||
return session
|
return session
|
||||||
return None
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def get_session(
|
def get_session(
|
||||||
@ -309,7 +337,6 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
# 1. Get the storage session entry from session table
|
# 1. Get the storage session entry from session table
|
||||||
# 2. Get all the events based on session id and filtering config
|
# 2. Get all the events based on session id and filtering config
|
||||||
# 3. Convert and return the session
|
# 3. Convert and return the session
|
||||||
session: Session = None
|
|
||||||
with self.DatabaseSessionFactory() as sessionFactory:
|
with self.DatabaseSessionFactory() as sessionFactory:
|
||||||
storage_session = sessionFactory.get(
|
storage_session = sessionFactory.get(
|
||||||
StorageSession, (app_name, user_id, session_id)
|
StorageSession, (app_name, user_id, session_id)
|
||||||
@ -356,13 +383,19 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
author=e.author,
|
author=e.author,
|
||||||
branch=e.branch,
|
branch=e.branch,
|
||||||
invocation_id=e.invocation_id,
|
invocation_id=e.invocation_id,
|
||||||
content=e.content,
|
content=_decode_content(e.content),
|
||||||
actions=e.actions,
|
actions=e.actions,
|
||||||
timestamp=e.timestamp.timestamp(),
|
timestamp=e.timestamp.timestamp(),
|
||||||
|
long_running_tool_ids=e.long_running_tool_ids,
|
||||||
|
grounding_metadata=e.grounding_metadata,
|
||||||
|
partial=e.partial,
|
||||||
|
turn_complete=e.turn_complete,
|
||||||
|
error_code=e.error_code,
|
||||||
|
error_message=e.error_message,
|
||||||
|
interrupted=e.interrupted,
|
||||||
)
|
)
|
||||||
for e in storage_events
|
for e in storage_events
|
||||||
]
|
]
|
||||||
|
|
||||||
return session
|
return session
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -387,7 +420,6 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
)
|
)
|
||||||
sessions.append(session)
|
sessions.append(session)
|
||||||
return ListSessionsResponse(sessions=sessions)
|
return ListSessionsResponse(sessions=sessions)
|
||||||
raise ValueError("Failed to retrieve sessions.")
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def delete_session(
|
def delete_session(
|
||||||
@ -406,7 +438,7 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
def append_event(self, session: Session, event: Event) -> Event:
|
def append_event(self, session: Session, event: Event) -> Event:
|
||||||
logger.info(f"Append event: {event} to session {session.id}")
|
logger.info(f"Append event: {event} to session {session.id}")
|
||||||
|
|
||||||
if event.partial and not event.content:
|
if event.partial:
|
||||||
return event
|
return event
|
||||||
|
|
||||||
# 1. Check if timestamp is stale
|
# 1. Check if timestamp is stale
|
||||||
@ -455,19 +487,34 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
storage_user_state.state = user_state
|
storage_user_state.state = user_state
|
||||||
storage_session.state = session_state
|
storage_session.state = session_state
|
||||||
|
|
||||||
encoded_content = event.content.model_dump(exclude_none=True)
|
|
||||||
storage_event = StorageEvent(
|
storage_event = StorageEvent(
|
||||||
id=event.id,
|
id=event.id,
|
||||||
invocation_id=event.invocation_id,
|
invocation_id=event.invocation_id,
|
||||||
author=event.author,
|
author=event.author,
|
||||||
branch=event.branch,
|
branch=event.branch,
|
||||||
content=encoded_content,
|
|
||||||
actions=event.actions,
|
actions=event.actions,
|
||||||
session_id=session.id,
|
session_id=session.id,
|
||||||
app_name=session.app_name,
|
app_name=session.app_name,
|
||||||
user_id=session.user_id,
|
user_id=session.user_id,
|
||||||
timestamp=datetime.fromtimestamp(event.timestamp),
|
timestamp=datetime.fromtimestamp(event.timestamp),
|
||||||
|
long_running_tool_ids=event.long_running_tool_ids,
|
||||||
|
grounding_metadata=event.grounding_metadata,
|
||||||
|
partial=event.partial,
|
||||||
|
turn_complete=event.turn_complete,
|
||||||
|
error_code=event.error_code,
|
||||||
|
error_message=event.error_message,
|
||||||
|
interrupted=event.interrupted,
|
||||||
)
|
)
|
||||||
|
if event.content:
|
||||||
|
encoded_content = event.content.model_dump(exclude_none=True)
|
||||||
|
# Workaround for multimodal Content throwing JSON not serializable
|
||||||
|
# error with SQLAlchemy.
|
||||||
|
for p in encoded_content["parts"]:
|
||||||
|
if "inline_data" in p:
|
||||||
|
p["inline_data"]["data"] = (
|
||||||
|
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
|
||||||
|
)
|
||||||
|
storage_event.content = encoded_content
|
||||||
|
|
||||||
sessionFactory.add(storage_event)
|
sessionFactory.add(storage_event)
|
||||||
|
|
||||||
@ -489,8 +536,7 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
) -> ListEventsResponse:
|
) -> ListEventsResponse:
|
||||||
pass
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def convert_event(event: StorageEvent) -> Event:
|
def convert_event(event: StorageEvent) -> Event:
|
||||||
"""Converts a storage event to an event."""
|
"""Converts a storage event to an event."""
|
||||||
@ -505,7 +551,7 @@ def convert_event(event: StorageEvent) -> Event:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _extract_state_delta(state: dict):
|
def _extract_state_delta(state: dict[str, Any]):
|
||||||
app_state_delta = {}
|
app_state_delta = {}
|
||||||
user_state_delta = {}
|
user_state_delta = {}
|
||||||
session_state_delta = {}
|
session_state_delta = {}
|
||||||
@ -528,3 +574,10 @@ def _merge_state(app_state, user_state, session_state):
|
|||||||
for key in user_state.keys():
|
for key in user_state.keys():
|
||||||
merged_state[State.USER_PREFIX + key] = user_state[key]
|
merged_state[State.USER_PREFIX + key] = user_state[key]
|
||||||
return merged_state
|
return merged_state
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_content(content: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
for p in content["parts"]:
|
||||||
|
if "inline_data" in p:
|
||||||
|
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
|
||||||
|
return content
|
||||||
|
@ -196,11 +196,12 @@ class IntegrationClient:
|
|||||||
action_details = connections_client.get_action_schema(action)
|
action_details = connections_client.get_action_schema(action)
|
||||||
input_schema = action_details["inputSchema"]
|
input_schema = action_details["inputSchema"]
|
||||||
output_schema = action_details["outputSchema"]
|
output_schema = action_details["outputSchema"]
|
||||||
action_display_name = action_details["displayName"]
|
# Remove spaces from the display name to generate valid spec
|
||||||
|
action_display_name = action_details["displayName"].replace(" ", "")
|
||||||
operation = "EXECUTE_ACTION"
|
operation = "EXECUTE_ACTION"
|
||||||
if action == "ExecuteCustomQuery":
|
if action == "ExecuteCustomQuery":
|
||||||
connector_spec["components"]["schemas"][
|
connector_spec["components"]["schemas"][
|
||||||
f"{action}_Request"
|
f"{action_display_name}_Request"
|
||||||
] = connections_client.execute_custom_query_request()
|
] = connections_client.execute_custom_query_request()
|
||||||
operation = "EXECUTE_QUERY"
|
operation = "EXECUTE_QUERY"
|
||||||
else:
|
else:
|
||||||
|
@ -291,7 +291,7 @@ def _parse_schema_from_parameter(
|
|||||||
return schema
|
return schema
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'Failed to parse the parameter {param} of function {func_name} for'
|
f'Failed to parse the parameter {param} of function {func_name} for'
|
||||||
' automatic function calling.Automatic function calling works best with'
|
' automatic function calling. Automatic function calling works best with'
|
||||||
' simpler function signature schema,consider manually parse your'
|
' simpler function signature schema,consider manually parse your'
|
||||||
f' function declaration for function {func_name}.'
|
f' function declaration for function {func_name}.'
|
||||||
)
|
)
|
||||||
|
@ -11,4 +11,77 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from .google_api_tool_sets import calendar_tool_set
|
__all__ = [
|
||||||
|
'bigquery_tool_set',
|
||||||
|
'calendar_tool_set',
|
||||||
|
'gmail_tool_set',
|
||||||
|
'youtube_tool_set',
|
||||||
|
'slides_tool_set',
|
||||||
|
'sheets_tool_set',
|
||||||
|
'docs_tool_set',
|
||||||
|
]
|
||||||
|
|
||||||
|
# Nothing is imported here automatically
|
||||||
|
# Each tool set will only be imported when accessed
|
||||||
|
|
||||||
|
_bigquery_tool_set = None
|
||||||
|
_calendar_tool_set = None
|
||||||
|
_gmail_tool_set = None
|
||||||
|
_youtube_tool_set = None
|
||||||
|
_slides_tool_set = None
|
||||||
|
_sheets_tool_set = None
|
||||||
|
_docs_tool_set = None
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name):
|
||||||
|
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
|
||||||
|
|
||||||
|
match name:
|
||||||
|
case 'bigquery_tool_set':
|
||||||
|
if _bigquery_tool_set is None:
|
||||||
|
from .google_api_tool_sets import bigquery_tool_set as bigquery
|
||||||
|
|
||||||
|
_bigquery_tool_set = bigquery
|
||||||
|
return _bigquery_tool_set
|
||||||
|
|
||||||
|
case 'calendar_tool_set':
|
||||||
|
if _calendar_tool_set is None:
|
||||||
|
from .google_api_tool_sets import calendar_tool_set as calendar
|
||||||
|
|
||||||
|
_calendar_tool_set = calendar
|
||||||
|
return _calendar_tool_set
|
||||||
|
|
||||||
|
case 'gmail_tool_set':
|
||||||
|
if _gmail_tool_set is None:
|
||||||
|
from .google_api_tool_sets import gmail_tool_set as gmail
|
||||||
|
|
||||||
|
_gmail_tool_set = gmail
|
||||||
|
return _gmail_tool_set
|
||||||
|
|
||||||
|
case 'youtube_tool_set':
|
||||||
|
if _youtube_tool_set is None:
|
||||||
|
from .google_api_tool_sets import youtube_tool_set as youtube
|
||||||
|
|
||||||
|
_youtube_tool_set = youtube
|
||||||
|
return _youtube_tool_set
|
||||||
|
|
||||||
|
case 'slides_tool_set':
|
||||||
|
if _slides_tool_set is None:
|
||||||
|
from .google_api_tool_sets import slides_tool_set as slides
|
||||||
|
|
||||||
|
_slides_tool_set = slides
|
||||||
|
return _slides_tool_set
|
||||||
|
|
||||||
|
case 'sheets_tool_set':
|
||||||
|
if _sheets_tool_set is None:
|
||||||
|
from .google_api_tool_sets import sheets_tool_set as sheets
|
||||||
|
|
||||||
|
_sheets_tool_set = sheets
|
||||||
|
return _sheets_tool_set
|
||||||
|
|
||||||
|
case 'docs_tool_set':
|
||||||
|
if _docs_tool_set is None:
|
||||||
|
from .google_api_tool_sets import docs_tool_set as docs
|
||||||
|
|
||||||
|
_docs_tool_set = docs
|
||||||
|
return _docs_tool_set
|
||||||
|
@ -19,37 +19,94 @@ from .google_api_tool_set import GoogleApiToolSet
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
calendar_tool_set = GoogleApiToolSet.load_tool_set(
|
_bigquery_tool_set = None
|
||||||
api_name="calendar",
|
_calendar_tool_set = None
|
||||||
api_version="v3",
|
_gmail_tool_set = None
|
||||||
)
|
_youtube_tool_set = None
|
||||||
|
_slides_tool_set = None
|
||||||
|
_sheets_tool_set = None
|
||||||
|
_docs_tool_set = None
|
||||||
|
|
||||||
bigquery_tool_set = GoogleApiToolSet.load_tool_set(
|
|
||||||
|
def __getattr__(name):
|
||||||
|
"""This method dynamically loads and returns GoogleApiToolSet instances for
|
||||||
|
|
||||||
|
various Google APIs. It uses a lazy loading approach, initializing each
|
||||||
|
tool set only when it is first requested. This avoids unnecessary loading
|
||||||
|
of tool sets that are not used in a given session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The name of the tool set to retrieve (e.g.,
|
||||||
|
"bigquery_tool_set").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GoogleApiToolSet: The requested tool set instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AttributeError: If the requested tool set name is not recognized.
|
||||||
|
"""
|
||||||
|
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
|
||||||
|
|
||||||
|
match name:
|
||||||
|
case "bigquery_tool_set":
|
||||||
|
if _bigquery_tool_set is None:
|
||||||
|
_bigquery_tool_set = GoogleApiToolSet.load_tool_set(
|
||||||
api_name="bigquery",
|
api_name="bigquery",
|
||||||
api_version="v2",
|
api_version="v2",
|
||||||
)
|
)
|
||||||
|
|
||||||
gmail_tool_set = GoogleApiToolSet.load_tool_set(
|
return _bigquery_tool_set
|
||||||
|
|
||||||
|
case "calendar_tool_set":
|
||||||
|
if _calendar_tool_set is None:
|
||||||
|
_calendar_tool_set = GoogleApiToolSet.load_tool_set(
|
||||||
|
api_name="calendar",
|
||||||
|
api_version="v3",
|
||||||
|
)
|
||||||
|
|
||||||
|
return _calendar_tool_set
|
||||||
|
|
||||||
|
case "gmail_tool_set":
|
||||||
|
if _gmail_tool_set is None:
|
||||||
|
_gmail_tool_set = GoogleApiToolSet.load_tool_set(
|
||||||
api_name="gmail",
|
api_name="gmail",
|
||||||
api_version="v1",
|
api_version="v1",
|
||||||
)
|
)
|
||||||
|
|
||||||
youtube_tool_set = GoogleApiToolSet.load_tool_set(
|
return _gmail_tool_set
|
||||||
|
|
||||||
|
case "youtube_tool_set":
|
||||||
|
if _youtube_tool_set is None:
|
||||||
|
_youtube_tool_set = GoogleApiToolSet.load_tool_set(
|
||||||
api_name="youtube",
|
api_name="youtube",
|
||||||
api_version="v3",
|
api_version="v3",
|
||||||
)
|
)
|
||||||
|
|
||||||
slides_tool_set = GoogleApiToolSet.load_tool_set(
|
return _youtube_tool_set
|
||||||
|
|
||||||
|
case "slides_tool_set":
|
||||||
|
if _slides_tool_set is None:
|
||||||
|
_slides_tool_set = GoogleApiToolSet.load_tool_set(
|
||||||
api_name="slides",
|
api_name="slides",
|
||||||
api_version="v1",
|
api_version="v1",
|
||||||
)
|
)
|
||||||
|
|
||||||
sheets_tool_set = GoogleApiToolSet.load_tool_set(
|
return _slides_tool_set
|
||||||
|
|
||||||
|
case "sheets_tool_set":
|
||||||
|
if _sheets_tool_set is None:
|
||||||
|
_sheets_tool_set = GoogleApiToolSet.load_tool_set(
|
||||||
api_name="sheets",
|
api_name="sheets",
|
||||||
api_version="v4",
|
api_version="v4",
|
||||||
)
|
)
|
||||||
|
|
||||||
docs_tool_set = GoogleApiToolSet.load_tool_set(
|
return _sheets_tool_set
|
||||||
|
|
||||||
|
case "docs_tool_set":
|
||||||
|
if _docs_tool_set is None:
|
||||||
|
_docs_tool_set = GoogleApiToolSet.load_tool_set(
|
||||||
api_name="docs",
|
api_name="docs",
|
||||||
api_version="v1",
|
api_version="v1",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return _docs_tool_set
|
||||||
|
@ -311,7 +311,9 @@ class GoogleApiToOpenApiConverter:
|
|||||||
|
|
||||||
# Determine the actual endpoint path
|
# Determine the actual endpoint path
|
||||||
# Google often has the format something like 'users.messages.list'
|
# Google often has the format something like 'users.messages.list'
|
||||||
rest_path = method_data.get("path", "/")
|
# flatPath is preferred as it provides the actual path, while path
|
||||||
|
# might contain variables like {+projectId}
|
||||||
|
rest_path = method_data.get("flatPath", method_data.get("path", "/"))
|
||||||
if not rest_path.startswith("/"):
|
if not rest_path.startswith("/"):
|
||||||
rest_path = "/" + rest_path
|
rest_path = "/" + rest_path
|
||||||
|
|
||||||
|
@ -16,18 +16,26 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from google.genai import types
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from .function_tool import FunctionTool
|
from .function_tool import FunctionTool
|
||||||
from .tool_context import ToolContext
|
from .tool_context import ToolContext
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..models import LlmRequest
|
|
||||||
from ..memory.base_memory_service import MemoryResult
|
from ..memory.base_memory_service import MemoryResult
|
||||||
|
from ..models import LlmRequest
|
||||||
|
|
||||||
|
|
||||||
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
|
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
|
||||||
"""Loads the memory for the current user."""
|
"""Loads the memory for the current user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The query to load the memory for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of memory results.
|
||||||
|
"""
|
||||||
response = tool_context.search_memory(query)
|
response = tool_context.search_memory(query)
|
||||||
return response.memories
|
return response.memories
|
||||||
|
|
||||||
@ -38,6 +46,21 @@ class LoadMemoryTool(FunctionTool):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(load_memory)
|
super().__init__(load_memory)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _get_declaration(self) -> types.FunctionDeclaration | None:
|
||||||
|
return types.FunctionDeclaration(
|
||||||
|
name=self.name,
|
||||||
|
description=self.description,
|
||||||
|
parameters=types.Schema(
|
||||||
|
type=types.Type.OBJECT,
|
||||||
|
properties={
|
||||||
|
'query': types.Schema(
|
||||||
|
type=types.Type.STRING,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def process_llm_request(
|
async def process_llm_request(
|
||||||
self,
|
self,
|
||||||
|
176
src/google/adk/tools/mcp_tool/mcp_session_manager.py
Normal file
176
src/google/adk/tools/mcp_tool/mcp_session_manager.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
from contextlib import AsyncExitStack
|
||||||
|
import functools
|
||||||
|
import sys
|
||||||
|
from typing import Any, TextIO
|
||||||
|
import anyio
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mcp import ClientSession, StdioServerParameters
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
from mcp.client.stdio import stdio_client
|
||||||
|
except ImportError as e:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if sys.version_info < (3, 10):
|
||||||
|
raise ImportError(
|
||||||
|
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
|
||||||
|
' version.'
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
class SseServerParams(BaseModel):
|
||||||
|
"""Parameters for the MCP SSE connection.
|
||||||
|
|
||||||
|
See MCP SSE Client documentation for more details.
|
||||||
|
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
url: str
|
||||||
|
headers: dict[str, Any] | None = None
|
||||||
|
timeout: float = 5
|
||||||
|
sse_read_timeout: float = 60 * 5
|
||||||
|
|
||||||
|
|
||||||
|
def retry_on_closed_resource(async_reinit_func_name: str):
|
||||||
|
"""Decorator to automatically reinitialize session and retry action.
|
||||||
|
|
||||||
|
When MCP session was closed, the decorator will automatically recreate the
|
||||||
|
session and retry the action with the same parameters.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
1. async_reinit_func_name is the name of the class member function that
|
||||||
|
reinitializes the MCP session.
|
||||||
|
2. Both the decorated function and the async_reinit_func_name must be async
|
||||||
|
functions.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
class MCPTool:
|
||||||
|
...
|
||||||
|
async def create_session(self):
|
||||||
|
self.session = ...
|
||||||
|
|
||||||
|
@retry_on_closed_resource('create_session')
|
||||||
|
async def use_session(self):
|
||||||
|
await self.session.call_tool()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
async_reinit_func_name: The name of the async function to recreate session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The decorated function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
@functools.wraps(
|
||||||
|
func
|
||||||
|
) # Preserves original function metadata (name, docstring)
|
||||||
|
async def wrapper(self, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
return await func(self, *args, **kwargs)
|
||||||
|
except anyio.ClosedResourceError:
|
||||||
|
try:
|
||||||
|
if hasattr(self, async_reinit_func_name) and callable(
|
||||||
|
getattr(self, async_reinit_func_name)
|
||||||
|
):
|
||||||
|
async_init_fn = getattr(self, async_reinit_func_name)
|
||||||
|
await async_init_fn()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f'Function {async_reinit_func_name} does not exist in decorated'
|
||||||
|
' class. Please check the function name in'
|
||||||
|
' retry_on_closed_resource decorator.'
|
||||||
|
)
|
||||||
|
except Exception as reinit_err:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'Error reinitializing: {reinit_err}'
|
||||||
|
) from reinit_err
|
||||||
|
return await func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class MCPSessionManager:
|
||||||
|
"""Manages MCP client sessions.
|
||||||
|
|
||||||
|
This class provides methods for creating and initializing MCP client sessions,
|
||||||
|
handling different connection parameters (Stdio and SSE).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection_params: StdioServerParameters | SseServerParams,
|
||||||
|
exit_stack: AsyncExitStack,
|
||||||
|
errlog: TextIO = sys.stderr,
|
||||||
|
) -> ClientSession:
|
||||||
|
"""Initializes the MCP session manager.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```
|
||||||
|
mcp_session_manager = MCPSessionManager(
|
||||||
|
connection_params=connection_params,
|
||||||
|
exit_stack=exit_stack,
|
||||||
|
)
|
||||||
|
session = await mcp_session_manager.create_session()
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
||||||
|
exit_stack: AsyncExitStack to manage the session lifecycle.
|
||||||
|
errlog: (Optional) TextIO stream for error logging. Use only for
|
||||||
|
initializing a local stdio MCP session.
|
||||||
|
"""
|
||||||
|
self.connection_params = connection_params
|
||||||
|
self.exit_stack = exit_stack
|
||||||
|
self.errlog = errlog
|
||||||
|
|
||||||
|
async def create_session(self) -> ClientSession:
|
||||||
|
return await MCPSessionManager.initialize_session(
|
||||||
|
connection_params=self.connection_params,
|
||||||
|
exit_stack=self.exit_stack,
|
||||||
|
errlog=self.errlog,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def initialize_session(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
connection_params: StdioServerParameters | SseServerParams,
|
||||||
|
exit_stack: AsyncExitStack,
|
||||||
|
errlog: TextIO = sys.stderr,
|
||||||
|
) -> ClientSession:
|
||||||
|
"""Initializes an MCP client session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
||||||
|
exit_stack: AsyncExitStack to manage the session lifecycle.
|
||||||
|
errlog: (Optional) TextIO stream for error logging. Use only for
|
||||||
|
initializing a local stdio MCP session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ClientSession: The initialized MCP client session.
|
||||||
|
"""
|
||||||
|
if isinstance(connection_params, StdioServerParameters):
|
||||||
|
client = stdio_client(server=connection_params, errlog=errlog)
|
||||||
|
elif isinstance(connection_params, SseServerParams):
|
||||||
|
client = sse_client(
|
||||||
|
url=connection_params.url,
|
||||||
|
headers=connection_params.headers,
|
||||||
|
timeout=connection_params.timeout,
|
||||||
|
sse_read_timeout=connection_params.sse_read_timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
'Unable to initialize connection. Connection should be'
|
||||||
|
' StdioServerParameters or SseServerParams, but got'
|
||||||
|
f' {connection_params}'
|
||||||
|
)
|
||||||
|
|
||||||
|
transports = await exit_stack.enter_async_context(client)
|
||||||
|
session = await exit_stack.enter_async_context(ClientSession(*transports))
|
||||||
|
await session.initialize()
|
||||||
|
return session
|
@ -17,6 +17,8 @@ from typing import Optional
|
|||||||
from google.genai.types import FunctionDeclaration
|
from google.genai.types import FunctionDeclaration
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from .mcp_session_manager import MCPSessionManager, retry_on_closed_resource
|
||||||
|
|
||||||
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
||||||
# their Python version to 3.10 if it fails.
|
# their Python version to 3.10 if it fails.
|
||||||
try:
|
try:
|
||||||
@ -33,6 +35,7 @@ except ImportError as e:
|
|||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
from ..base_tool import BaseTool
|
from ..base_tool import BaseTool
|
||||||
from ...auth.auth_credential import AuthCredential
|
from ...auth.auth_credential import AuthCredential
|
||||||
from ...auth.auth_schemes import AuthScheme
|
from ...auth.auth_schemes import AuthScheme
|
||||||
@ -51,6 +54,7 @@ class MCPTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
mcp_tool: McpBaseTool,
|
mcp_tool: McpBaseTool,
|
||||||
mcp_session: ClientSession,
|
mcp_session: ClientSession,
|
||||||
|
mcp_session_manager: MCPSessionManager,
|
||||||
auth_scheme: Optional[AuthScheme] = None,
|
auth_scheme: Optional[AuthScheme] = None,
|
||||||
auth_credential: Optional[AuthCredential] | None = None,
|
auth_credential: Optional[AuthCredential] | None = None,
|
||||||
):
|
):
|
||||||
@ -79,10 +83,14 @@ class MCPTool(BaseTool):
|
|||||||
self.description = mcp_tool.description if mcp_tool.description else ""
|
self.description = mcp_tool.description if mcp_tool.description else ""
|
||||||
self.mcp_tool = mcp_tool
|
self.mcp_tool = mcp_tool
|
||||||
self.mcp_session = mcp_session
|
self.mcp_session = mcp_session
|
||||||
|
self.mcp_session_manager = mcp_session_manager
|
||||||
# TODO(cheliu): Support passing auth to MCP Server.
|
# TODO(cheliu): Support passing auth to MCP Server.
|
||||||
self.auth_scheme = auth_scheme
|
self.auth_scheme = auth_scheme
|
||||||
self.auth_credential = auth_credential
|
self.auth_credential = auth_credential
|
||||||
|
|
||||||
|
async def _reinitialize_session(self):
|
||||||
|
self.mcp_session = await self.mcp_session_manager.create_session()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _get_declaration(self) -> FunctionDeclaration:
|
def _get_declaration(self) -> FunctionDeclaration:
|
||||||
"""Gets the function declaration for the tool.
|
"""Gets the function declaration for the tool.
|
||||||
@ -98,6 +106,7 @@ class MCPTool(BaseTool):
|
|||||||
return function_decl
|
return function_decl
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
@retry_on_closed_resource("_reinitialize_session")
|
||||||
async def run_async(self, *, args, tool_context: ToolContext):
|
async def run_async(self, *, args, tool_context: ToolContext):
|
||||||
"""Runs the tool asynchronously.
|
"""Runs the tool asynchronously.
|
||||||
|
|
||||||
@ -109,5 +118,9 @@ class MCPTool(BaseTool):
|
|||||||
Any: The response from the tool.
|
Any: The response from the tool.
|
||||||
"""
|
"""
|
||||||
# TODO(cheliu): Support passing tool context to MCP Server.
|
# TODO(cheliu): Support passing tool context to MCP Server.
|
||||||
|
try:
|
||||||
response = await self.mcp_session.call_tool(self.name, arguments=args)
|
response = await self.mcp_session.call_tool(self.name, arguments=args)
|
||||||
return response
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
raise e
|
||||||
|
@ -13,15 +13,16 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
|
import sys
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, List, Optional, Tuple, Type
|
from typing import List, Optional, TextIO, Tuple, Type
|
||||||
|
|
||||||
|
from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource
|
||||||
|
|
||||||
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
||||||
# their Python version to 3.10 if it fails.
|
# their Python version to 3.10 if it fails.
|
||||||
try:
|
try:
|
||||||
from mcp import ClientSession, StdioServerParameters
|
from mcp import ClientSession, StdioServerParameters
|
||||||
from mcp.client.sse import sse_client
|
|
||||||
from mcp.client.stdio import stdio_client
|
|
||||||
from mcp.types import ListToolsResult
|
from mcp.types import ListToolsResult
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
import sys
|
import sys
|
||||||
@ -34,18 +35,9 @@ except ImportError as e:
|
|||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from .mcp_tool import MCPTool
|
from .mcp_tool import MCPTool
|
||||||
|
|
||||||
|
|
||||||
class SseServerParams(BaseModel):
|
|
||||||
url: str
|
|
||||||
headers: dict[str, Any] | None = None
|
|
||||||
timeout: float = 5
|
|
||||||
sse_read_timeout: float = 60 * 5
|
|
||||||
|
|
||||||
|
|
||||||
class MCPToolset:
|
class MCPToolset:
|
||||||
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
|
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
|
||||||
|
|
||||||
@ -110,7 +102,11 @@ class MCPToolset:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, *, connection_params: StdioServerParameters | SseServerParams
|
self,
|
||||||
|
*,
|
||||||
|
connection_params: StdioServerParameters | SseServerParams,
|
||||||
|
errlog: TextIO = sys.stderr,
|
||||||
|
exit_stack=AsyncExitStack(),
|
||||||
):
|
):
|
||||||
"""Initializes the MCPToolset.
|
"""Initializes the MCPToolset.
|
||||||
|
|
||||||
@ -175,7 +171,14 @@ class MCPToolset:
|
|||||||
if not connection_params:
|
if not connection_params:
|
||||||
raise ValueError('Missing connection params in MCPToolset.')
|
raise ValueError('Missing connection params in MCPToolset.')
|
||||||
self.connection_params = connection_params
|
self.connection_params = connection_params
|
||||||
self.exit_stack = AsyncExitStack()
|
self.errlog = errlog
|
||||||
|
self.exit_stack = exit_stack
|
||||||
|
|
||||||
|
self.session_manager = MCPSessionManager(
|
||||||
|
connection_params=self.connection_params,
|
||||||
|
exit_stack=self.exit_stack,
|
||||||
|
errlog=self.errlog,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_server(
|
async def from_server(
|
||||||
@ -183,6 +186,7 @@ class MCPToolset:
|
|||||||
*,
|
*,
|
||||||
connection_params: StdioServerParameters | SseServerParams,
|
connection_params: StdioServerParameters | SseServerParams,
|
||||||
async_exit_stack: Optional[AsyncExitStack] = None,
|
async_exit_stack: Optional[AsyncExitStack] = None,
|
||||||
|
errlog: TextIO = sys.stderr,
|
||||||
) -> Tuple[List[MCPTool], AsyncExitStack]:
|
) -> Tuple[List[MCPTool], AsyncExitStack]:
|
||||||
"""Retrieve all tools from the MCP connection.
|
"""Retrieve all tools from the MCP connection.
|
||||||
|
|
||||||
@ -209,41 +213,27 @@ class MCPToolset:
|
|||||||
the MCP server. Use `await async_exit_stack.aclose()` to close the
|
the MCP server. Use `await async_exit_stack.aclose()` to close the
|
||||||
connection when server shuts down.
|
connection when server shuts down.
|
||||||
"""
|
"""
|
||||||
toolset = cls(connection_params=connection_params)
|
|
||||||
async_exit_stack = async_exit_stack or AsyncExitStack()
|
async_exit_stack = async_exit_stack or AsyncExitStack()
|
||||||
|
toolset = cls(
|
||||||
|
connection_params=connection_params,
|
||||||
|
exit_stack=async_exit_stack,
|
||||||
|
errlog=errlog,
|
||||||
|
)
|
||||||
|
|
||||||
await async_exit_stack.enter_async_context(toolset)
|
await async_exit_stack.enter_async_context(toolset)
|
||||||
tools = await toolset.load_tools()
|
tools = await toolset.load_tools()
|
||||||
return (tools, async_exit_stack)
|
return (tools, async_exit_stack)
|
||||||
|
|
||||||
async def _initialize(self) -> ClientSession:
|
async def _initialize(self) -> ClientSession:
|
||||||
"""Connects to the MCP Server and initializes the ClientSession."""
|
"""Connects to the MCP Server and initializes the ClientSession."""
|
||||||
if isinstance(self.connection_params, StdioServerParameters):
|
self.session = await self.session_manager.create_session()
|
||||||
client = stdio_client(self.connection_params)
|
|
||||||
elif isinstance(self.connection_params, SseServerParams):
|
|
||||||
client = sse_client(
|
|
||||||
url=self.connection_params.url,
|
|
||||||
headers=self.connection_params.headers,
|
|
||||||
timeout=self.connection_params.timeout,
|
|
||||||
sse_read_timeout=self.connection_params.sse_read_timeout,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
'Unable to initialize connection. Connection should be'
|
|
||||||
' StdioServerParameters or SseServerParams, but got'
|
|
||||||
f' {self.connection_params}'
|
|
||||||
)
|
|
||||||
|
|
||||||
transports = await self.exit_stack.enter_async_context(client)
|
|
||||||
self.session = await self.exit_stack.enter_async_context(
|
|
||||||
ClientSession(*transports)
|
|
||||||
)
|
|
||||||
await self.session.initialize()
|
|
||||||
return self.session
|
return self.session
|
||||||
|
|
||||||
async def _exit(self):
|
async def _exit(self):
|
||||||
"""Closes the connection to MCP Server."""
|
"""Closes the connection to MCP Server."""
|
||||||
await self.exit_stack.aclose()
|
await self.exit_stack.aclose()
|
||||||
|
|
||||||
|
@retry_on_closed_resource('_initialize')
|
||||||
async def load_tools(self) -> List[MCPTool]:
|
async def load_tools(self) -> List[MCPTool]:
|
||||||
"""Loads all tools from the MCP Server.
|
"""Loads all tools from the MCP Server.
|
||||||
|
|
||||||
@ -252,7 +242,11 @@ class MCPToolset:
|
|||||||
"""
|
"""
|
||||||
tools_response: ListToolsResult = await self.session.list_tools()
|
tools_response: ListToolsResult = await self.session.list_tools()
|
||||||
return [
|
return [
|
||||||
MCPTool(mcp_tool=tool, mcp_session=self.session)
|
MCPTool(
|
||||||
|
mcp_tool=tool,
|
||||||
|
mcp_session=self.session,
|
||||||
|
mcp_session_manager=self.session_manager,
|
||||||
|
)
|
||||||
for tool in tools_response.tools
|
for tool in tools_response.tools
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from ....auth.auth_credential import AuthCredential
|
from ....auth.auth_credential import AuthCredential
|
||||||
from ....auth.auth_schemes import AuthScheme
|
from ....auth.auth_schemes import AuthScheme
|
||||||
from ....tools import BaseTool
|
from ....tools.base_tool import BaseTool
|
||||||
from ...tool_context import ToolContext
|
from ...tool_context import ToolContext
|
||||||
from ..auth.auth_helpers import credential_to_param
|
from ..auth.auth_helpers import credential_to_param
|
||||||
from ..auth.auth_helpers import dict_to_auth_scheme
|
from ..auth.auth_helpers import dict_to_auth_scheme
|
||||||
|
@ -13,4 +13,4 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# version: date+base_cl
|
# version: date+base_cl
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.1"
|
||||||
|
13
tests/unittests/cli/__init__.py
Normal file
13
tests/unittests/cli/__init__.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
13
tests/unittests/cli/utils/__init__.py
Normal file
13
tests/unittests/cli/utils/__init__.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
434
tests/unittests/cli/utils/test_evals.py
Normal file
434
tests/unittests/cli/utils/test_evals.py
Normal file
@ -0,0 +1,434 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Tests for utilities in eval."""
|
||||||
|
|
||||||
|
|
||||||
|
from google.adk.cli.utils.evals import convert_session_to_eval_format
|
||||||
|
from google.adk.events.event import Event
|
||||||
|
from google.adk.sessions.session import Session
|
||||||
|
from google.genai import types
|
||||||
|
|
||||||
|
|
||||||
|
def build_event(author: str, parts_content: list[dict]) -> Event:
|
||||||
|
"""Builds an Event object with specified parts."""
|
||||||
|
parts = []
|
||||||
|
for p_data in parts_content:
|
||||||
|
part_args = {}
|
||||||
|
if "text" in p_data:
|
||||||
|
part_args["text"] = p_data["text"]
|
||||||
|
if "func_name" in p_data:
|
||||||
|
part_args["function_call"] = types.FunctionCall(
|
||||||
|
name=p_data.get("func_name"), args=p_data.get("func_args")
|
||||||
|
)
|
||||||
|
# Add other part types here if needed for future tests
|
||||||
|
parts.append(types.Part(**part_args))
|
||||||
|
return Event(author=author, content=types.Content(parts=parts))
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_empty_session():
|
||||||
|
"""Test conversion function with empty events list in Session."""
|
||||||
|
# Pydantic models require mandatory fields for instantiation
|
||||||
|
session_empty_events = Session(
|
||||||
|
id="s1", app_name="app", user_id="u1", events=[]
|
||||||
|
)
|
||||||
|
assert not convert_session_to_eval_format(session_empty_events)
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_none_session():
|
||||||
|
"""Test conversion function with None Session."""
|
||||||
|
assert not convert_session_to_eval_format(None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_session_skips_initial_non_user_events():
|
||||||
|
"""Test conversion function with only user events."""
|
||||||
|
events = [
|
||||||
|
build_event("model", [{"text": "Hello"}]),
|
||||||
|
build_event("user", [{"text": "How are you?"}]),
|
||||||
|
]
|
||||||
|
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||||
|
expected = [
|
||||||
|
{
|
||||||
|
"query": "How are you?",
|
||||||
|
"expected_tool_use": [],
|
||||||
|
"expected_intermediate_agent_responses": [],
|
||||||
|
"reference": "",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
assert convert_session_to_eval_format(session) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_single_turn_text_only():
|
||||||
|
"""Test a single user query followed by a single agent text response."""
|
||||||
|
events = [
|
||||||
|
build_event("user", [{"text": "What is the time?"}]),
|
||||||
|
build_event("root_agent", [{"text": "It is 3 PM."}]),
|
||||||
|
]
|
||||||
|
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||||
|
expected = [{
|
||||||
|
"query": "What is the time?",
|
||||||
|
"expected_tool_use": [],
|
||||||
|
"expected_intermediate_agent_responses": [],
|
||||||
|
"reference": "It is 3 PM.",
|
||||||
|
}]
|
||||||
|
assert convert_session_to_eval_format(session) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_single_turn_tool_only():
|
||||||
|
"""Test a single user query followed by a single agent tool call."""
|
||||||
|
events = [
|
||||||
|
build_event("user", [{"text": "Get weather for Seattle"}]),
|
||||||
|
build_event(
|
||||||
|
"root_agent",
|
||||||
|
[{"func_name": "get_weather", "func_args": {"city": "Seattle"}}],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||||
|
expected = [{
|
||||||
|
"query": "Get weather for Seattle",
|
||||||
|
"expected_tool_use": [
|
||||||
|
{"tool_name": "get_weather", "tool_input": {"city": "Seattle"}}
|
||||||
|
],
|
||||||
|
"expected_intermediate_agent_responses": [],
|
||||||
|
"reference": "",
|
||||||
|
}]
|
||||||
|
assert convert_session_to_eval_format(session) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_single_turn_multiple_tools_and_texts():
|
||||||
|
"""Test a turn with multiple agent responses (tools and text)."""
|
||||||
|
events = [
|
||||||
|
build_event("user", [{"text": "Do task A then task B"}]),
|
||||||
|
build_event(
|
||||||
|
"root_agent", [{"text": "Okay, starting task A."}]
|
||||||
|
), # Intermediate Text 1
|
||||||
|
build_event(
|
||||||
|
"root_agent", [{"func_name": "task_A", "func_args": {"param": 1}}]
|
||||||
|
), # Tool 1
|
||||||
|
build_event(
|
||||||
|
"root_agent", [{"text": "Task A done. Now starting task B."}]
|
||||||
|
), # Intermediate Text 2
|
||||||
|
build_event(
|
||||||
|
"another_agent", [{"func_name": "task_B", "func_args": {}}]
|
||||||
|
), # Tool 2
|
||||||
|
build_event(
|
||||||
|
"root_agent", [{"text": "All tasks completed."}]
|
||||||
|
), # Final Text (Reference)
|
||||||
|
]
|
||||||
|
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||||
|
expected = [{
|
||||||
|
"query": "Do task A then task B",
|
||||||
|
"expected_tool_use": [
|
||||||
|
{"tool_name": "task_A", "tool_input": {"param": 1}},
|
||||||
|
{"tool_name": "task_B", "tool_input": {}},
|
||||||
|
],
|
||||||
|
"expected_intermediate_agent_responses": [
|
||||||
|
{"author": "root_agent", "text": "Okay, starting task A."},
|
||||||
|
{
|
||||||
|
"author": "root_agent",
|
||||||
|
"text": "Task A done. Now starting task B.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"reference": "All tasks completed.",
|
||||||
|
}]
|
||||||
|
assert convert_session_to_eval_format(session) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_multi_turn_session():
|
||||||
|
"""Test a session with multiple user/agent turns."""
|
||||||
|
events = [
|
||||||
|
build_event("user", [{"text": "Query 1"}]),
|
||||||
|
build_event("agent", [{"text": "Response 1"}]),
|
||||||
|
build_event("user", [{"text": "Query 2"}]),
|
||||||
|
build_event("agent", [{"func_name": "tool_X", "func_args": {}}]),
|
||||||
|
build_event("agent", [{"text": "Response 2"}]),
|
||||||
|
]
|
||||||
|
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||||
|
expected = [
|
||||||
|
{ # Turn 1
|
||||||
|
"query": "Query 1",
|
||||||
|
"expected_tool_use": [],
|
||||||
|
"expected_intermediate_agent_responses": [],
|
||||||
|
"reference": "Response 1",
|
||||||
|
},
|
||||||
|
{ # Turn 2
|
||||||
|
"query": "Query 2",
|
||||||
|
"expected_tool_use": [{"tool_name": "tool_X", "tool_input": {}}],
|
||||||
|
"expected_intermediate_agent_responses": [],
|
||||||
|
"reference": "Response 2",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
assert convert_session_to_eval_format(session) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_agent_event_multiple_parts():
|
||||||
|
"""Test an agent event with both text and tool call parts."""
|
||||||
|
events = [
|
||||||
|
build_event("user", [{"text": "Do something complex"}]),
|
||||||
|
# Build event with multiple dicts in parts_content list
|
||||||
|
build_event(
|
||||||
|
"agent",
|
||||||
|
[
|
||||||
|
{"text": "Okay, doing it."},
|
||||||
|
{"func_name": "complex_tool", "func_args": {"value": True}},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
build_event("agent", [{"text": "Finished."}]),
|
||||||
|
]
|
||||||
|
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||||
|
expected = [{
|
||||||
|
"query": "Do something complex",
|
||||||
|
"expected_tool_use": [
|
||||||
|
{"tool_name": "complex_tool", "tool_input": {"value": True}}
|
||||||
|
],
|
||||||
|
"expected_intermediate_agent_responses": [{
|
||||||
|
"author": "agent",
|
||||||
|
"text": "Okay, doing it.",
|
||||||
|
}], # Text from first part of agent event
|
||||||
|
"reference": "Finished.", # Text from second agent event
|
||||||
|
}]
|
||||||
|
assert convert_session_to_eval_format(session) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_handles_missing_content_or_parts():
|
||||||
|
"""Test that events missing content or parts are skipped gracefully."""
|
||||||
|
events = [
|
||||||
|
build_event("user", [{"text": "Query 1"}]),
|
||||||
|
Event(author="agent", content=None), # Agent event missing content
|
||||||
|
build_event("agent", [{"text": "Response 1"}]),
|
||||||
|
Event(author="user", content=None), # User event missing content
|
||||||
|
build_event("user", [{"text": "Query 2"}]),
|
||||||
|
Event(
|
||||||
|
author="agent", content=types.Content(parts=[])
|
||||||
|
), # Agent event with empty parts list
|
||||||
|
build_event("agent", [{"text": "Response 2"}]),
|
||||||
|
# User event with content but no parts (or None parts)
|
||||||
|
Event(author="user", content=types.Content(parts=None)),
|
||||||
|
build_event("user", [{"text": "Query 3"}]),
|
||||||
|
build_event("agent", [{"text": "Response 3"}]),
|
||||||
|
]
|
||||||
|
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||||
|
expected = [
|
||||||
|
{ # Turn 1 (from Query 1)
|
||||||
|
"query": "Query 1",
|
||||||
|
"expected_tool_use": [],
|
||||||
|
"expected_intermediate_agent_responses": [],
|
||||||
|
"reference": "Response 1",
|
||||||
|
},
|
||||||
|
{ # Turn 2 (from Query 2 - user event with None content was skipped)
|
||||||
|
"query": "Query 2",
|
||||||
|
"expected_tool_use": [],
|
||||||
|
"expected_intermediate_agent_responses": [],
|
||||||
|
"reference": "Response 2",
|
||||||
|
},
|
||||||
|
{ # Turn 3 (from Query 3 - user event with None parts was skipped)
|
||||||
|
"query": "Query 3",
|
||||||
|
"expected_tool_use": [],
|
||||||
|
"expected_intermediate_agent_responses": [],
|
||||||
|
"reference": "Response 3",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
assert convert_session_to_eval_format(session) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_handles_missing_tool_name_or_args():
|
||||||
|
"""Test tool calls with missing name or args."""
|
||||||
|
events = [
|
||||||
|
build_event("user", [{"text": "Call tools"}]),
|
||||||
|
# Event where FunctionCall has name=None
|
||||||
|
Event(
|
||||||
|
author="agent",
|
||||||
|
content=types.Content(
|
||||||
|
parts=[
|
||||||
|
types.Part(
|
||||||
|
function_call=types.FunctionCall(name=None, args={"a": 1})
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
# Event where FunctionCall has args=None
|
||||||
|
Event(
|
||||||
|
author="agent",
|
||||||
|
content=types.Content(
|
||||||
|
parts=[
|
||||||
|
types.Part(
|
||||||
|
function_call=types.FunctionCall(name="tool_B", args=None)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
# Event where FunctionCall part exists but FunctionCall object is None
|
||||||
|
# (should skip)
|
||||||
|
Event(
|
||||||
|
author="agent",
|
||||||
|
content=types.Content(
|
||||||
|
parts=[types.Part(function_call=None, text="some text")]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
build_event("agent", [{"text": "Done"}]),
|
||||||
|
]
|
||||||
|
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||||
|
expected = [{
|
||||||
|
"query": "Call tools",
|
||||||
|
"expected_tool_use": [
|
||||||
|
{"tool_name": "", "tool_input": {"a": 1}}, # Defaults name to ""
|
||||||
|
{"tool_name": "tool_B", "tool_input": {}}, # Defaults args to {}
|
||||||
|
],
|
||||||
|
"expected_intermediate_agent_responses": [{
|
||||||
|
"author": "agent",
|
||||||
|
"text": "some text",
|
||||||
|
}], # Text part from the event where function_call was None
|
||||||
|
"reference": "Done",
|
||||||
|
}]
|
||||||
|
assert convert_session_to_eval_format(session) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_handles_missing_user_query_text():
|
||||||
|
"""Test user event where the first part has no text."""
|
||||||
|
events = [
|
||||||
|
# Event where user part has text=None
|
||||||
|
Event(
|
||||||
|
author="user", content=types.Content(parts=[types.Part(text=None)])
|
||||||
|
),
|
||||||
|
build_event("agent", [{"text": "Response 1"}]),
|
||||||
|
# Event where user part has text=""
|
||||||
|
build_event("user", [{"text": ""}]),
|
||||||
|
build_event("agent", [{"text": "Response 2"}]),
|
||||||
|
]
|
||||||
|
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||||
|
expected = [
|
||||||
|
{
|
||||||
|
"query": "", # Defaults to "" if text is None
|
||||||
|
"expected_tool_use": [],
|
||||||
|
"expected_intermediate_agent_responses": [],
|
||||||
|
"reference": "Response 1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "", # Defaults to "" if text is ""
|
||||||
|
"expected_tool_use": [],
|
||||||
|
"expected_intermediate_agent_responses": [],
|
||||||
|
"reference": "Response 2",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
assert convert_session_to_eval_format(session) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_handles_empty_agent_text():
  """Test agent responses with empty string text."""
  events = [
      build_event("user", [{"text": "Query"}]),
      build_event("agent", [{"text": "Okay"}]),
      build_event("agent", [{"text": ""}]),  # Empty text should be dropped.
      build_event("agent", [{"text": "Done"}]),
  ]
  session = Session(id="s1", app_name="app", user_id="u1", events=events)

  actual = convert_session_to_eval_format(session)

  # The empty-text agent event is skipped: only "Okay" survives as an
  # intermediate response, and the final "Done" becomes the reference.
  assert actual == [{
      "query": "Query",
      "expected_tool_use": [],
      "expected_intermediate_agent_responses": [
          {"author": "agent", "text": "Okay"},
      ],
      "reference": "Done",
  }]
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_complex_sample_session():
  """Test using the complex sample session provided earlier."""
  events = [
      build_event("user", [{"text": "What can you do?"}]),
      build_event(
          "root_agent",
          [{"text": "I can roll dice and check if numbers are prime. \n"}],
      ),
      build_event(
          "user",
          [{
              "text": (
                  "Roll a 8 sided dice and then check if 90 is a prime number"
                  " or not."
              )
          }],
      ),
      build_event(
          "root_agent",
          [{
              "func_name": "transfer_to_agent",
              "func_args": {"agent_name": "roll_agent"},
          }],
      ),
      # FunctionResponse events are omitted throughout: the converter only
      # consumes text and functionCall parts.
      build_event(
          "roll_agent", [{"func_name": "roll_die", "func_args": {"sides": 8}}]
      ),
      build_event(
          "roll_agent",
          [
              {"text": "I rolled a 2. Now, I'll check if 90 is prime. \n\n"},
              {
                  "func_name": "transfer_to_agent",
                  "func_args": {"agent_name": "prime_agent"},
              },
          ],
      ),
      build_event(
          "prime_agent",
          [{"func_name": "check_prime", "func_args": {"nums": [90]}}],
      ),
      build_event("prime_agent", [{"text": "90 is not a prime number. \n"}]),
  ]
  session = Session(
      id="some_id",
      app_name="hello_world_ma",
      user_id="user",
      events=events,
  )

  expected = [
      {
          "query": "What can you do?",
          "expected_tool_use": [],
          "expected_intermediate_agent_responses": [],
          "reference": "I can roll dice and check if numbers are prime. \n",
      },
      {
          "query": (
              "Roll a 8 sided dice and then check if 90 is a prime number or"
              " not."
          ),
          # All four tool calls across the turn, in event order.
          "expected_tool_use": [
              {
                  "tool_name": "transfer_to_agent",
                  "tool_input": {"agent_name": "roll_agent"},
              },
              {"tool_name": "roll_die", "tool_input": {"sides": 8}},
              {
                  "tool_name": "transfer_to_agent",
                  "tool_input": {"agent_name": "prime_agent"},
              },
              {"tool_name": "check_prime", "tool_input": {"nums": [90]}},
          ],
          # Text part taken from the event that mixed text + function call.
          "expected_intermediate_agent_responses": [{
              "author": "roll_agent",
              "text": "I rolled a 2. Now, I'll check if 90 is prime. \n\n",
          }],
          "reference": "90 is not a prime number. \n",
      },
  ]

  assert convert_session_to_eval_format(session) == expected
|
13
tests/unittests/evaluation/__init__.py
Normal file
13
tests/unittests/evaluation/__init__.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
259
tests/unittests/evaluation/test_response_evaluator.py
Normal file
259
tests/unittests/evaluation/test_response_evaluator.py
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Tests for the Response Evaluator."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from google.adk.evaluation.response_evaluator import ResponseEvaluator
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
from vertexai.preview.evaluation import MetricPromptTemplateExamples
|
||||||
|
|
||||||
|
# Stand-in for the result object normally produced by _perform_eval.
MOCK_EVAL_RESULT = MagicMock()
MOCK_EVAL_RESULT.summary_metrics = {"mock_metric": 0.75, "another_mock": 3.5}
# metrics_table is consulted when exercising the _print_results path.
MOCK_EVAL_RESULT.metrics_table = pd.DataFrame({
    "prompt": ["mock_query1"],
    "response": ["mock_resp1"],
    "mock_metric": [0.75],
})

# Turn fixtures. The first carries every key; the others each drop one or
# more keys so metric-selection behavior can be driven from the data.
SAMPLE_TURN_1_ALL_KEYS = {
    "query": "query1",
    "response": "response1",
    "actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
    "expected_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
    "reference": "reference1",
}
SAMPLE_TURN_2_MISSING_REF = {
    # No "reference" key.
    "query": "query2",
    "response": "response2",
    "actual_tool_use": [],
    "expected_tool_use": [],
}
SAMPLE_TURN_3_MISSING_EXP_TOOLS = {
    # No "expected_tool_use" key.
    "query": "query3",
    "response": "response3",
    "actual_tool_use": [{"tool_name": "tool_b", "tool_input": {}}],
    "reference": "reference3",
}
SAMPLE_TURN_4_MINIMAL = {
    # Only the minimal keys; everything else missing.
    "query": "query4",
    "response": "response4",
}
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
    "google.adk.evaluation.response_evaluator.ResponseEvaluator._perform_eval"
)
class TestResponseEvaluator:
  """Groups tests that all share the _perform_eval patch."""

  def test_evaluate_none_dataset_raises_value_error(self, mock_perform_eval):
    """Evaluate raises ValueError when the dataset is None."""
    with pytest.raises(ValueError, match="The evaluation dataset is empty."):
      ResponseEvaluator.evaluate(None, ["response_evaluation_score"])
    # _perform_eval must never run when validation fails.
    mock_perform_eval.assert_not_called()

  def test_evaluate_empty_dataset_raises_value_error(self, mock_perform_eval):
    """Evaluate raises ValueError for an empty list."""
    with pytest.raises(ValueError, match="The evaluation dataset is empty."):
      ResponseEvaluator.evaluate([], ["response_evaluation_score"])
    mock_perform_eval.assert_not_called()

  def test_evaluate_determines_metrics_correctly_for_perform_eval(
      self, mock_perform_eval
  ):
    """Metrics handed to _perform_eval follow the criteria and data keys."""
    mock_perform_eval.return_value = MOCK_EVAL_RESULT

    def metrics_passed(dataset, criteria):
      # Runs evaluate and reports the metrics kwarg _perform_eval received.
      ResponseEvaluator.evaluate(dataset, criteria)
      _, kwargs = mock_perform_eval.call_args
      mock_perform_eval.reset_mock()
      return kwargs["metrics"]

    # Coherence only.
    assert metrics_passed(
        [[SAMPLE_TURN_1_ALL_KEYS]], ["response_evaluation_score"]
    ) == [MetricPromptTemplateExamples.Pointwise.COHERENCE]
    # Rouge only.
    assert metrics_passed(
        [[SAMPLE_TURN_1_ALL_KEYS]], ["response_match_score"]
    ) == ["rouge_1"]
    # No metrics when the first turn lacks the keys the criteria need.
    assert not metrics_passed(
        [[SAMPLE_TURN_4_MINIMAL, SAMPLE_TURN_1_ALL_KEYS]],
        ["response_evaluation_score", "response_match_score"],
    )
    # No metrics when the criteria list is empty.
    assert not metrics_passed([[SAMPLE_TURN_1_ALL_KEYS]], [])

  def test_evaluate_calls_perform_eval_correctly_all_metrics(
      self, mock_perform_eval
  ):
    """Evaluate passes the expected dataset/metrics when everything is present."""
    mock_perform_eval.return_value = MOCK_EVAL_RESULT

    summary = ResponseEvaluator.evaluate(
        [[SAMPLE_TURN_1_ALL_KEYS]],
        ["response_evaluation_score", "response_match_score"],
    )

    mock_perform_eval.assert_called_once()
    _, kwargs = mock_perform_eval.call_args
    # Flattened, renamed single-turn DataFrame.
    pd.testing.assert_frame_equal(
        kwargs["dataset"],
        pd.DataFrame([{
            "prompt": "query1",
            "response": "response1",
            "actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
            "reference_trajectory": [
                {"tool_name": "tool_a", "tool_input": {}}
            ],
            "reference": "reference1",
        }]),
    )
    # Both criteria map onto metrics.
    assert kwargs["metrics"] == [
        MetricPromptTemplateExamples.Pointwise.COHERENCE,
        "rouge_1",
    ]
    # The summary is taken straight from _perform_eval's result.
    assert summary == MOCK_EVAL_RESULT.summary_metrics

  def test_evaluate_prepares_dataframe_correctly_for_perform_eval(
      self, mock_perform_eval
  ):
    """The dataset is flattened and columns renamed before _perform_eval."""
    mock_perform_eval.return_value = MOCK_EVAL_RESULT

    ResponseEvaluator.evaluate(
        [
            [SAMPLE_TURN_1_ALL_KEYS],  # Conversation 1
            [
                SAMPLE_TURN_2_MISSING_REF,
                SAMPLE_TURN_3_MISSING_EXP_TOOLS,
            ],  # Conversation 2
        ],
        ["response_match_score"],  # Criteria do not shape the DataFrame.
    )

    # Missing keys become NaN when pandas builds the frame from dicts.
    expected_df = pd.DataFrame([
        {
            "prompt": "query1",
            "response": "response1",
            "actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
            "reference_trajectory": [
                {"tool_name": "tool_a", "tool_input": {}}
            ],
            "reference": "reference1",
        },
        {
            "prompt": "query2",
            "response": "response2",
            "actual_tool_use": [],
            "reference_trajectory": [],
        },
        {
            "prompt": "query3",
            "response": "response3",
            "actual_tool_use": [{"tool_name": "tool_b", "tool_input": {}}],
            "reference": "reference3",
        },
    ])

    mock_perform_eval.assert_called_once()
    _, kwargs = mock_perform_eval.call_args
    pd.testing.assert_frame_equal(kwargs["dataset"], expected_df)

  @patch(
      "google.adk.evaluation.response_evaluator.ResponseEvaluator._print_results"
  )
  def test_evaluate_print_detailed_results(
      self, mock_print_results, mock_perform_eval
  ):
    """_print_results runs when print_detailed_results=True."""
    mock_perform_eval.return_value = MOCK_EVAL_RESULT

    ResponseEvaluator.evaluate(
        [[SAMPLE_TURN_1_ALL_KEYS]],
        ["response_match_score"],
        print_detailed_results=True,
    )

    mock_perform_eval.assert_called_once()
    # _print_results receives the object _perform_eval produced.
    mock_print_results.assert_called_once_with(MOCK_EVAL_RESULT)

  @patch(
      "google.adk.evaluation.response_evaluator.ResponseEvaluator._print_results"
  )
  def test_evaluate_no_print_detailed_results(
      self, mock_print_results, mock_perform_eval
  ):
    """_print_results is skipped when print_detailed_results=False (default)."""
    mock_perform_eval.return_value = MOCK_EVAL_RESULT

    ResponseEvaluator.evaluate(
        [[SAMPLE_TURN_1_ALL_KEYS]],
        ["response_match_score"],
        print_detailed_results=False,
    )

    mock_perform_eval.assert_called_once()
    mock_print_results.assert_not_called()
|
271
tests/unittests/evaluation/test_trajectory_evaluator.py
Normal file
271
tests/unittests/evaluation/test_trajectory_evaluator.py
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Testings for the Trajectory Evaluator."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Reusable tool-call payloads shared by the turn fixtures below.
TOOL_ROLL_DICE_16 = {"tool_name": "roll_die", "tool_input": {"sides": 16}}
TOOL_ROLL_DICE_6 = {"tool_name": "roll_die", "tool_input": {"sides": 6}}
TOOL_GET_WEATHER = {
    "tool_name": "get_weather",
    "tool_input": {"location": "Paris"},
}
TOOL_GET_WEATHER_SF = {
    "tool_name": "get_weather",
    "tool_input": {"location": "SF"},
}

# Per-turn fixtures covering the match/mismatch combinations.
TURN_MATCH = {
    "query": "Q1",
    "response": "R1",
    "actual_tool_use": [TOOL_ROLL_DICE_16],
    "expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MISMATCH_INPUT = {
    # Same tool, different tool_input.
    "query": "Q2",
    "response": "R2",
    "actual_tool_use": [TOOL_ROLL_DICE_6],
    "expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MISMATCH_NAME = {
    # Different tool_name entirely.
    "query": "Q3",
    "response": "R3",
    "actual_tool_use": [TOOL_GET_WEATHER],
    "expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MATCH_MULTIPLE = {
    "query": "Q4",
    "response": "R4",
    "actual_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
    "expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MISMATCH_ORDER = {
    # Same tools but reversed order.
    "query": "Q5",
    "response": "R5",
    "actual_tool_use": [TOOL_ROLL_DICE_6, TOOL_GET_WEATHER],
    "expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MISMATCH_LENGTH_ACTUAL_LONGER = {
    "query": "Q6",
    "response": "R6",
    "actual_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
    "expected_tool_use": [TOOL_GET_WEATHER],
}
TURN_MISMATCH_LENGTH_EXPECTED_LONGER = {
    "query": "Q7",
    "response": "R7",
    "actual_tool_use": [TOOL_GET_WEATHER],
    "expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MATCH_WITH_MOCK_OUTPUT = {
    # mock_tool_output on the expected side must be ignored by comparison.
    "query": "Q8",
    "response": "R8",
    "actual_tool_use": [TOOL_GET_WEATHER_SF],
    "expected_tool_use": [
        {**TOOL_GET_WEATHER_SF, "mock_tool_output": "Sunny"}
    ],
}
TURN_MATCH_EMPTY_TOOLS = {
    "query": "Q9",
    "response": "R9",
    "actual_tool_use": [],
    "expected_tool_use": [],
}
TURN_MISMATCH_EMPTY_VS_NONEMPTY = {
    "query": "Q10",
    "response": "R10",
    "actual_tool_use": [],
    "expected_tool_use": [TOOL_GET_WEATHER],
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_none_dataset_raises_value_error():
  """Tests evaluate function raises ValueError when the dataset is None."""
  expected_error = "The evaluation dataset is empty."
  with pytest.raises(ValueError, match=expected_error):
    TrajectoryEvaluator.evaluate(None)


def test_evaluate_empty_dataset_raises_value_error():
  """Tests evaluate function raises ValueError for an empty list."""
  expected_error = "The evaluation dataset is empty."
  with pytest.raises(ValueError, match=expected_error):
    TrajectoryEvaluator.evaluate([])
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_single_turn_match():
  """Tests evaluate function with one conversation, one turn, perfect match."""
  dataset = [[TURN_MATCH]]
  # A single matching turn scores a perfect 1.0.
  assert TrajectoryEvaluator.evaluate(dataset) == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_single_turn_mismatch():
  """Tests evaluate function with one conversation, one turn, mismatch."""
  dataset = [[TURN_MISMATCH_INPUT]]
  # A single mismatching turn scores 0.0.
  assert TrajectoryEvaluator.evaluate(dataset) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_multiple_turns_all_match():
  """Tests evaluate with one conversation where every turn matches."""
  dataset = [[TURN_MATCH, TURN_MATCH_MULTIPLE, TURN_MATCH_EMPTY_TOOLS]]
  # Three matches out of three turns.
  assert TrajectoryEvaluator.evaluate(dataset) == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_multiple_turns_mixed():
  """Tests evaluate with one conversation of mixed match/mismatch turns."""
  dataset = [[
      TURN_MATCH,
      TURN_MISMATCH_NAME,
      TURN_MATCH_MULTIPLE,
      TURN_MISMATCH_ORDER,
  ]]
  # (1.0 + 0.0 + 1.0 + 0.0) / 4 turns = 0.5.
  assert TrajectoryEvaluator.evaluate(dataset) == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_multiple_conversations_mixed():
  """Tests evaluate function with multiple conversations, mixed turns."""
  dataset = [
      # Conversation 1: 1.0, 0.0.
      [TURN_MATCH, TURN_MISMATCH_INPUT],
      # Conversation 2: 1.0.
      [TURN_MATCH_MULTIPLE],
      # Conversation 3: 0.0, 0.0, 1.0.
      [TURN_MISMATCH_ORDER, TURN_MISMATCH_LENGTH_ACTUAL_LONGER, TURN_MATCH],
  ]
  # Overall mean across all six turns: 3.0 / 6 = 0.5.
  assert TrajectoryEvaluator.evaluate(dataset) == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_ignores_mock_tool_output_in_expected():
  """Tests that comparison still matches when expected has mock_tool_output."""
  dataset = [[TURN_MATCH_WITH_MOCK_OUTPUT]]
  # The extra mock_tool_output key on the expected side is ignored.
  assert TrajectoryEvaluator.evaluate(dataset) == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_match_empty_tool_lists():
  """Tests evaluate function correctly matches empty tool lists."""
  dataset = [[TURN_MATCH_EMPTY_TOOLS]]
  # Empty-vs-empty counts as a match.
  assert TrajectoryEvaluator.evaluate(dataset) == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_mismatch_empty_vs_nonempty():
  """Tests evaluate mismatches empty vs non-empty tool lists (both ways)."""
  assert (
      TrajectoryEvaluator.evaluate([[TURN_MISMATCH_EMPTY_VS_NONEMPTY]]) == 0.0
  )

  # Swap actual/expected: still a mismatch, just in the other direction.
  reversed_turn = {
      **TURN_MISMATCH_EMPTY_VS_NONEMPTY,
      "actual_tool_use": [TOOL_GET_WEATHER],
      "expected_tool_use": [],
  }
  assert TrajectoryEvaluator.evaluate([[reversed_turn]]) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_dataset_with_empty_conversation():
  """Tests evaluate handles a dataset containing an empty conversation."""
  dataset = [[TURN_MATCH], []]  # Second conversation contributes no rows.
  # Only the single matching turn is scored: 1.0 / 1 = 1.0.
  assert TrajectoryEvaluator.evaluate(dataset) == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_dataset_only_empty_conversation():
  """Tests evaluate handles a dataset with only an empty conversation."""
  # No rows are evaluated, so the mean of the empty series is NaN — the
  # current implementation returns NaN rather than 0.0.
  assert math.isnan(TrajectoryEvaluator.evaluate([[]]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_print_detailed_results(capsys):
  """Tests evaluate runs with print_detailed_results=True and prints output."""
  TrajectoryEvaluator.evaluate(
      [[TURN_MATCH, TURN_MISMATCH_INPUT]], print_detailed_results=True
  )
  printed = capsys.readouterr().out
  assert "query" in printed  # Results table header is present.
  assert "R1" in printed  # Some row data is present.
  assert "Failures:" in printed  # Failures section is present.
  assert "Q2" in printed  # The failing query is listed.
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluate_no_failures_print(capsys):
  """Tests evaluate omits the Failures section when every turn matches."""
  TrajectoryEvaluator.evaluate([[TURN_MATCH]], print_detailed_results=True)
  printed = capsys.readouterr().out
  assert "query" in printed  # Results table should still print.
  assert "Failures:" not in printed  # But no failures section.
|
||||||
|
|
||||||
|
|
||||||
|
def test_are_tools_equal_identical():
  """Tests are_tools_equal function with identical lists."""
  calls = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
  assert TrajectoryEvaluator.are_tools_equal(calls, list(calls))


def test_are_tools_equal_empty():
  """Tests are_tools_equal function with empty lists."""
  assert TrajectoryEvaluator.are_tools_equal([], [])


def test_are_tools_equal_different_order():
  """Tests are_tools_equal function with same tools, different order."""
  forward = [TOOL_ROLL_DICE_6, TOOL_GET_WEATHER]
  backward = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
  # Order matters: reordered lists are not equal.
  assert not TrajectoryEvaluator.are_tools_equal(forward, backward)


def test_are_tools_equal_different_length():
  """Tests are_tools_equal function with lists of different lengths."""
  assert not TrajectoryEvaluator.are_tools_equal(
      [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6], [TOOL_GET_WEATHER]
  )


def test_are_tools_equal_different_input_values():
  """Tests are_tools_equal function with different input values."""
  assert not TrajectoryEvaluator.are_tools_equal(
      [TOOL_ROLL_DICE_16], [TOOL_ROLL_DICE_6]
  )


def test_are_tools_equal_different_tool_names():
  """Tests are_tools_equal function with different tool names."""
  assert not TrajectoryEvaluator.are_tools_equal(
      [TOOL_ROLL_DICE_16], [TOOL_GET_WEATHER]
  )


def test_are_tools_equal_ignores_extra_keys():
  """Tests are_tools_equal ignores keys other than tool_name/tool_input."""
  with_extra = [{**TOOL_GET_WEATHER, "extra_key": "abc"}]
  with_other = [{**TOOL_GET_WEATHER, "other_key": 123}]
  assert TrajectoryEvaluator.are_tools_equal(with_extra, with_other)


def test_are_tools_equal_one_empty_one_not():
  """Tests are_tools_equal with one empty list and one non-empty list."""
  assert not TrajectoryEvaluator.are_tools_equal([], [TOOL_GET_WEATHER])
|
@ -225,3 +225,76 @@ def test_create_new_session_will_merge_states(service_type):
|
|||||||
assert session_2.state.get('user:key1') == 'value1'
|
assert session_2.state.get('user:key1') == 'value1'
|
||||||
assert not session_2.state.get('key1')
|
assert not session_2.state.get('key1')
|
||||||
assert not session_2.state.get('temp:key')
|
assert not session_2.state.get('temp:key')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_append_event_bytes(service_type):
  """Appends an event carrying raw bytes and reads it back intact."""
  session_service = get_session_service(service_type)
  app_name = 'my_app'
  user_id = 'user'

  image_part = types.Part.from_bytes(
      data=b'test_image_data', mime_type='image/png'
  )
  session = session_service.create_session(app_name=app_name, user_id=user_id)
  session_service.append_event(
      session=session,
      event=Event(
          invocation_id='invocation',
          author='user',
          content=types.Content(role='user', parts=[image_part]),
      ),
  )

  # The in-memory session object holds the bytes part unchanged.
  assert session.events[0].content.parts[0] == image_part

  # A round trip through the service preserves it as well.
  stored_events = session_service.get_session(
      app_name=app_name, user_id=user_id, session_id=session.id
  ).events
  assert len(stored_events) == 1
  assert stored_events[0].content.parts[0] == image_part
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_append_event_complete(service_type):
  """Appends a fully-populated event; the stored session must match."""
  session_service = get_session_service(service_type)
  app_name = 'my_app'
  user_id = 'user'

  session = session_service.create_session(app_name=app_name, user_id=user_id)
  # Exercise every optional Event field in a single append.
  event = Event(
      invocation_id='invocation',
      author='user',
      content=types.Content(role='user', parts=[types.Part(text='test_text')]),
      turn_complete=True,
      partial=False,
      actions=EventActions(
          artifact_delta={'file': 0},
          transfer_to_agent='agent',
          escalate=True,
      ),
      long_running_tool_ids={'tool1'},
      error_code='error_code',
      error_message='error_message',
      interrupted=True,
  )
  session_service.append_event(session=session, event=event)

  stored_session = session_service.get_session(
      app_name=app_name, user_id=user_id, session_id=session.id
  )
  assert stored_session == session
|
||||||
|
@ -57,7 +57,7 @@ MOCK_EVENT_JSON = [
|
|||||||
{
|
{
|
||||||
'name': (
|
'name': (
|
||||||
'projects/test-project/locations/test-location/'
|
'projects/test-project/locations/test-location/'
|
||||||
'reasoningEngines/test_engine/sessions/1/events/123'
|
'reasoningEngines/123/sessions/1/events/123'
|
||||||
),
|
),
|
||||||
'invocationId': '123',
|
'invocationId': '123',
|
||||||
'author': 'user',
|
'author': 'user',
|
||||||
@ -111,7 +111,7 @@ MOCK_SESSION = Session(
|
|||||||
|
|
||||||
|
|
||||||
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
|
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
|
||||||
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions$'
|
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=([^/]+)$'
|
||||||
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
|
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
|
||||||
LRO_REGEX = r'^operations/([^/]+)$'
|
LRO_REGEX = r'^operations/([^/]+)$'
|
||||||
|
|
||||||
@ -136,39 +136,52 @@ class MockApiClient:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f'Session not found: {session_id}')
|
raise ValueError(f'Session not found: {session_id}')
|
||||||
elif re.match(SESSIONS_REGEX, path):
|
elif re.match(SESSIONS_REGEX, path):
|
||||||
|
match = re.match(SESSIONS_REGEX, path)
|
||||||
return {
|
return {
|
||||||
'sessions': self.session_dict.values(),
|
'sessions': [
|
||||||
|
session
|
||||||
|
for session in self.session_dict.values()
|
||||||
|
if session['userId'] == match.group(2)
|
||||||
|
],
|
||||||
}
|
}
|
||||||
elif re.match(EVENTS_REGEX, path):
|
elif re.match(EVENTS_REGEX, path):
|
||||||
match = re.match(EVENTS_REGEX, path)
|
match = re.match(EVENTS_REGEX, path)
|
||||||
if match:
|
if match:
|
||||||
return {'sessionEvents': self.event_dict[match.group(2)]}
|
return {
|
||||||
|
'sessionEvents': (
|
||||||
|
self.event_dict[match.group(2)]
|
||||||
|
if match.group(2) in self.event_dict
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
}
|
||||||
elif re.match(LRO_REGEX, path):
|
elif re.match(LRO_REGEX, path):
|
||||||
return {
|
return {
|
||||||
'name': (
|
'name': (
|
||||||
'projects/test-project/locations/test-location/'
|
'projects/test-project/locations/test-location/'
|
||||||
'reasoningEngines/123/sessions/123'
|
'reasoningEngines/123/sessions/4'
|
||||||
),
|
),
|
||||||
'done': True,
|
'done': True,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported path: {path}')
|
raise ValueError(f'Unsupported path: {path}')
|
||||||
elif http_method == 'POST':
|
elif http_method == 'POST':
|
||||||
id = str(uuid.uuid4())
|
new_session_id = '4'
|
||||||
self.session_dict[id] = {
|
self.session_dict[new_session_id] = {
|
||||||
'name': (
|
'name': (
|
||||||
'projects/test-project/locations/test-location/'
|
'projects/test-project/locations/test-location/'
|
||||||
'reasoningEngines/123/sessions/'
|
'reasoningEngines/123/sessions/'
|
||||||
+ id
|
+ new_session_id
|
||||||
),
|
),
|
||||||
'userId': request_dict['user_id'],
|
'userId': request_dict['user_id'],
|
||||||
'sessionState': request_dict.get('sessionState', {}),
|
'sessionState': request_dict.get('session_state', {}),
|
||||||
'updateTime': '2024-12-12T12:12:12.123456Z',
|
'updateTime': '2024-12-12T12:12:12.123456Z',
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
'name': (
|
'name': (
|
||||||
'projects/test_project/locations/test_location/'
|
'projects/test_project/locations/test_location/'
|
||||||
'reasoningEngines/test_engine/sessions/123'
|
'reasoningEngines/123/sessions/'
|
||||||
|
+ new_session_id
|
||||||
|
+ '/operations/111'
|
||||||
),
|
),
|
||||||
'done': False,
|
'done': False,
|
||||||
}
|
}
|
||||||
@ -223,19 +236,23 @@ def test_get_and_delete_session():
|
|||||||
)
|
)
|
||||||
assert str(excinfo.value) == 'Session not found: 1'
|
assert str(excinfo.value) == 'Session not found: 1'
|
||||||
|
|
||||||
def test_list_sessions():
|
|
||||||
|
def test_list_sessions():
|
||||||
session_service = mock_vertex_ai_session_service()
|
session_service = mock_vertex_ai_session_service()
|
||||||
sessions = session_service.list_sessions(app_name='123', user_id='user')
|
sessions = session_service.list_sessions(app_name='123', user_id='user')
|
||||||
assert len(sessions.sessions) == 2
|
assert len(sessions.sessions) == 2
|
||||||
assert sessions.sessions[0].id == '1'
|
assert sessions.sessions[0].id == '1'
|
||||||
assert sessions.sessions[1].id == '2'
|
assert sessions.sessions[1].id == '2'
|
||||||
|
|
||||||
def test_create_session():
|
|
||||||
|
def test_create_session():
|
||||||
session_service = mock_vertex_ai_session_service()
|
session_service = mock_vertex_ai_session_service()
|
||||||
|
|
||||||
|
state = {'key': 'value'}
|
||||||
session = session_service.create_session(
|
session = session_service.create_session(
|
||||||
app_name='123', user_id='user', state={'key': 'value'}
|
app_name='123', user_id='user', state=state
|
||||||
)
|
)
|
||||||
assert session.state == {'key': 'value'}
|
assert session.state == state
|
||||||
assert session.app_name == '123'
|
assert session.app_name == '123'
|
||||||
assert session.user_id == 'user'
|
assert session.user_id == 'user'
|
||||||
assert session.last_update_time is not None
|
assert session.last_update_time is not None
|
||||||
|
@ -119,7 +119,7 @@ def calendar_api_spec():
|
|||||||
"methods": {
|
"methods": {
|
||||||
"get": {
|
"get": {
|
||||||
"id": "calendar.calendars.get",
|
"id": "calendar.calendars.get",
|
||||||
"path": "calendars/{calendarId}",
|
"flatPath": "calendars/{calendarId}",
|
||||||
"httpMethod": "GET",
|
"httpMethod": "GET",
|
||||||
"description": "Returns metadata for a calendar.",
|
"description": "Returns metadata for a calendar.",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
@ -151,7 +151,7 @@ def calendar_api_spec():
|
|||||||
"methods": {
|
"methods": {
|
||||||
"list": {
|
"list": {
|
||||||
"id": "calendar.events.list",
|
"id": "calendar.events.list",
|
||||||
"path": "calendars/{calendarId}/events",
|
"flatPath": "calendars/{calendarId}/events",
|
||||||
"httpMethod": "GET",
|
"httpMethod": "GET",
|
||||||
"description": (
|
"description": (
|
||||||
"Returns events on the specified calendar."
|
"Returns events on the specified calendar."
|
||||||
|
Loading…
Reference in New Issue
Block a user