Merge branch 'main' into #247-OpenAPIToolSet-Considering-Required-parameters

Wei Sun (Jack) 2025-05-01 18:47:28 -07:00 committed by GitHub
commit f12300113d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 1579 additions and 422 deletions

.github/workflows/pyink.yml vendored Normal file
View File

@ -0,0 +1,59 @@
name: Check Pyink Formatting
on:
pull_request:
paths:
- 'src/**/*.py'
- 'tests/**/*.py'
- 'pyproject.toml'
jobs:
pyink-check:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install pyink
run: |
pip install pyink
- name: Detect changed Python files
id: detect_changes
run: |
git fetch origin ${{ github.base_ref }}
CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true)
echo "CHANGED_FILES=${CHANGED_FILES}" >> $GITHUB_ENV
- name: Run pyink on changed files
if: env.CHANGED_FILES != ''
run: |
echo "Changed Python files:"
echo "$CHANGED_FILES"
# Run pyink --check
set +e
pyink --check --config pyproject.toml $CHANGED_FILES
RESULT=$?
set -e
if [ $RESULT -ne 0 ]; then
echo ""
echo "❌ Pyink formatting check failed!"
echo "👉 To fix formatting, run locally:"
echo ""
echo " pyink --config pyproject.toml $CHANGED_FILES"
echo ""
exit $RESULT
fi
- name: No changed Python files detected
if: env.CHANGED_FILES == ''
run: |
echo "No Python files changed. Skipping pyink check."

View File

@ -29,11 +29,13 @@ jobs:
run: |
uv venv .venv
source .venv/bin/activate
uv sync --extra test
uv sync --extra test --extra eval
- name: Run unit tests with pytest
run: |
source .venv/bin/activate
pytest tests/unittests \
--ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py \
--ignore=tests/unittests/artifacts/test_artifact_service.py
--ignore=tests/unittests/artifacts/test_artifact_service.py \
--ignore=tests/unittests/tools/application_integration_tool/clients/test_connections_client.py \
--ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py

View File

@ -1,17 +1,89 @@
# Changelog
## 0.3.0
### ⚠ BREAKING CHANGES
* Auth: expose `access_token` and `refresh_token` at top level of auth
credentials, instead of a `dict`
([commit](https://github.com/google/adk-python/commit/956fb912e8851b139668b1ccb8db10fd252a6990)).
### Features
* Added support for running agents with MCPToolset easily on `adk web`.
* Added `custom_metadata` field to `LlmResponse`, which can be used to tag
LlmResponse via `after_model_callback`.
* Added `--session_db_url` option to `adk deploy cloud_run`.
* Many Dev UI improvements:
* Better google search result rendering.
* Show websocket close reason in Dev UI.
* Better error message showing for audio/video.
### Bug Fixes
* Fixed MCP tool json schema parsing issue.
* Fixed issues in DatabaseSessionService that leads to crash.
* Fixed functions.py.
* Fixed `skip_summarization` behavior in `AgentTool`.
### Miscellaneous Chores
* README.md improvements.
* Various code improvements.
* Various typo fixes.
* Bump min version of google-genai to 1.11.0.
## 0.2.0
### ⚠ BREAKING CHANGES
* Fix typo in method name in `Event`: has_trailing_code_exeuction_result --> has_trailing_code_execution_result.
### Features
* `adk` CLI:
* Introduce `adk create` cli tool to help creating agents.
* Adds `--verbosity` option to `adk deploy cloud_run` to show detailed cloud
run deploy logging.
* Improve the initialization error message for `DatabaseSessionService`.
* Lazy loading for Google 1P tools to minimize the initial latency.
* Support emitting state-change-only events from planners.
* Lots of Dev UI updates, including:
* Show planner thoughts and actions in the Dev UI.
* Support MCP tools in Dev UI.
(NOTE: `agent.py` interface is temp solution and is subject to change)
* Auto-select the only app if only one app is available.
* Show grounding links generated by Google Search Tool.
* `.env` file is reloaded on every agent run.
### Bug Fixes
* `LiteLlm`: arg parsing error and python 3.9 compatibility.
* `DatabaseSessionService`: adds the missing fields; fixes event with empty
content not being persisted.
* Google API Discovery response parsing issue.
* `load_memory_tool` rendering issue in Dev UI.
* Markdown text overflows in Dev UI.
### Miscellaneous Chores
* Adds unit tests in Github action.
* Improves test coverage.
* Various typo fixes.
## 0.1.0
### Features
* Initial release of the Agent Development Kit (ADK).
* Multi-agent, agent-as-workflow, and custom agent support
* Tool authentication support
* Rich tool support, e.g. bult-in tools, google-cloud tools, thir-party tools, and MCP tools
* Rich tool support, e.g. built-in tools, google-cloud tools, third-party tools, and MCP tools
* Rich callback support
* Built-in code execution capability
* Asynchronous runtime and execution
* Session, and memory support
* Built-in evaluation support
* Development UI that makes local devlopment easy
* Development UI that makes local development easy
* Deploy to Google Cloud Run, Agent Engine
* (Experimental) Live (Bidi) audio/video agent support and Compositional Function Calling (CFC) support

View File

@ -25,6 +25,19 @@ This project follows
## Contribution process
### Requirement for PRs
- All PRs, other than small documentation or typo fixes, should have an Issue associated. If not, please create one.
- Small, focused PRs. Keep changes minimal—one concern per PR.
- For bug fixes or features, please provide logs or screenshot after the fix is applied to help reviewers better understand the fix.
- Please add corresponding testing for your code change if it's not covered by existing tests.
### Large or Complex Changes
For substantial features or architectural revisions:
- Open an Issue First: Outline your proposal, including design considerations and impact.
- Gather Feedback: Discuss with maintainers and the community to ensure alignment and avoid duplicate work.
### Code reviews
All submissions, including submissions by project members, require review. We

View File

@ -18,11 +18,7 @@
</h3>
</html>
Agent Development Kit (ADK) is designed for developers seeking fine-grained
control and flexibility when building advanced AI agents that are tightly
integrated with services in Google Cloud. It allows you to define agent
behavior, orchestration, and tool use directly in code, enabling robust
debugging, versioning, and deployment anywhere from your laptop to the cloud.
Agent Development Kit (ADK) is a flexible and modular framework for developing and deploying AI agents. While optimized for Gemini and the Google ecosystem, ADK is model-agnostic, deployment-agnostic, and is built for compatibility with other frameworks. ADK was designed to make agent development feel more like software development, to make it easier for developers to create, deploy, and orchestrate agentic architectures that range from simple tasks to complex workflows.
---
@ -45,12 +41,27 @@ debugging, versioning, and deployment anywhere from your laptop to the cloud
## 🚀 Installation
You can install the ADK using `pip`:
### Stable Release (Recommended)
You can install the latest stable version of ADK using `pip`:
```bash
pip install google-adk
```
The release cadence is weekly.
This version is recommended for most users as it represents the most recent official release.
### Development Version
Bug fixes and new features are merged into the main branch on GitHub first. If you need access to changes that haven't been included in an official PyPI release yet, you can install directly from the main branch:
```bash
pip install git+https://github.com/google/adk-python.git@main
```
Note: The development version is built directly from the latest code commits. While it includes the newest fixes and features, it may also contain experimental changes or bugs not present in the stable release. Use it primarily for testing upcoming changes or accessing critical fixes before they are officially released.
## 📚 Documentation
Explore the full documentation for detailed guides on building, evaluating, and
@ -112,10 +123,18 @@ adk eval \
samples_for_testing/hello_world/hello_world_eval_set_001.evalset.json
```
## 🤖 A2A and ADK integration
For remote agent-to-agent communication, ADK integrates with the
[A2A protocol](https://github.com/google/A2A/).
See this [example](https://github.com/google/A2A/tree/main/samples/python/agents/google_adk)
for how they can work together.
## 🤝 Contributing
We welcome contributions from the community! Whether it's bug reports, feature requests, documentation improvements, or code contributions, please see our [**Contributing Guidelines**](./CONTRIBUTING.md) to get started.
We welcome contributions from the community! Whether it's bug reports, feature requests, documentation improvements, or code contributions, please see our
- [General contribution guideline and flow](https://google.github.io/adk-docs/contributing-guide/#questions).
- Then if you want to contribute code, please read [Code Contributing Guidelines](./CONTRIBUTING.md) to get started.
## 📄 License

View File

@ -45,7 +45,7 @@ confidence=
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes

View File

@ -33,7 +33,7 @@ dependencies = [
"google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool
"google-cloud-speech>=2.30.0", # For Audo Transcription
"google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
"google-genai>=1.9.0", # Google GenAI SDK
"google-genai>=1.11.0", # Google GenAI SDK
"graphviz>=0.20.2", # Graphviz for graph rendering
"mcp>=1.5.0;python_version>='3.10'", # For MCP Toolset
"opentelemetry-api>=1.31.0", # OpenTelemetry
@ -119,6 +119,15 @@ line-length = 80
unstable = true
pyink-indentation = 2
pyink-use-majority-quotes = true
pyink-annotation-pragmas = [
"noqa",
"pylint:",
"type: ignore",
"pytype:",
"mypy:",
"pyright:",
"pyre-",
]
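For reference, code formatted under these settings uses two-space indentation (per `pyink-indentation = 2`) within the 80-column limit, e.g. this minimal sketch:

```python
def greet(name: str) -> str:
  """Two-space indentation, as configured by [tool.pyink] above."""
  return f"Hello, {name}!"
```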
[build-system]
@ -135,15 +144,10 @@ exclude = ['src/**/*.sh']
name = "google.adk"
[tool.isort]
# Organize imports following Google style-guide
force_single_line = true
force_sort_within_sections = true
honor_case_in_force_sorted_sections = true
order_by_type = false
sort_relative_in_force_sorted_sections = true
multi_line_output = 3
line_length = 200
profile = "google"
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_default_fixture_loop_scope = "function"
asyncio_mode = "auto"

View File

@ -44,7 +44,7 @@ Args:
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
The content to return to the user. When set, the agent run will be skipped and
the provided content will be returned to user.
"""
@ -55,8 +55,8 @@ Args:
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be appended to event history as agent response.
The content to return to the user. When set, the provided content will be
appended to event history as agent response.
"""
@ -101,8 +101,8 @@ class BaseAgent(BaseModel):
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be returned to user.
The content to return to the user. When set, the agent run will be skipped
and the provided content will be returned to user.
"""
after_agent_callback: Optional[AfterAgentCallback] = None
"""Callback signature that is invoked after the agent run.
@ -111,8 +111,8 @@ class BaseAgent(BaseModel):
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be appended to event history as agent response.
The content to return to the user. When set, the provided content will be
appended to event history as agent response.
"""
@final

View File

@ -23,7 +23,6 @@ from .readonly_context import ReadonlyContext
if TYPE_CHECKING:
from google.genai import types
from ..events.event import Event
from ..events.event_actions import EventActions
from ..sessions.state import State
from .invocation_context import InvocationContext

View File

@ -15,12 +15,7 @@
from __future__ import annotations
import logging
from typing import Any
from typing import AsyncGenerator
from typing import Callable
from typing import Literal
from typing import Optional
from typing import Union
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union
from google.genai import types
from pydantic import BaseModel
@ -62,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[
]
BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext],
Optional[dict],
Union[Awaitable[Optional[dict]], Optional[dict]],
]
AfterToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext, dict],
Optional[dict],
Union[Awaitable[Optional[dict]], Optional[dict]],
]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
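With the `Awaitable` variants added above, tool callbacks may now be either sync or async. A minimal sketch of an async `before_tool_callback` matching the new alias (the tool name is hypothetical and the import paths are assumptions):

```python
from typing import Any, Optional

from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext


async def guard_tools(
    tool: BaseTool, args: dict[str, Any], tool_context: ToolContext
) -> Optional[dict]:
  # Returning a dict short-circuits the tool call; returning None lets
  # the framework invoke the tool normally.
  if tool.name == "delete_everything":  # hypothetical tool name
    return {"error": "This tool is disabled."}
  return None
```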

View File

@ -66,7 +66,8 @@ class OAuth2Auth(BaseModelWithConfig):
redirect_uri: Optional[str] = None
auth_response_uri: Optional[str] = None
auth_code: Optional[str] = None
token: Optional[Dict[str, Any]] = None
access_token: Optional[str] = None
refresh_token: Optional[str] = None
class ServiceAccountCredential(BaseModelWithConfig):
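Given the new top-level fields, a populated OAuth2 credential now looks roughly like this (a sketch; the token strings are placeholders and the import path is assumed):

```python
from google.adk.auth.auth_credential import (
    AuthCredential,
    AuthCredentialTypes,
    OAuth2Auth,
)

credential = AuthCredential(
    auth_type=AuthCredentialTypes.OAUTH2,
    oauth2=OAuth2Auth(
        access_token="ya29.placeholder-access-token",
        refresh_token="placeholder-refresh-token",
    ),
)
```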

View File

@ -82,7 +82,8 @@ class AuthHandler:
or not auth_credential.oauth2
or not auth_credential.oauth2.client_id
or not auth_credential.oauth2.client_secret
or auth_credential.oauth2.token
or auth_credential.oauth2.access_token
or auth_credential.oauth2.refresh_token
):
return self.auth_config.exchanged_auth_credential
@ -93,7 +94,7 @@ class AuthHandler:
redirect_uri=auth_credential.oauth2.redirect_uri,
state=auth_credential.oauth2.state,
)
token = client.fetch_token(
tokens = client.fetch_token(
token_endpoint,
authorization_response=auth_credential.oauth2.auth_response_uri,
code=auth_credential.oauth2.auth_code,
@ -102,7 +103,10 @@ class AuthHandler:
updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(token=dict(token)),
oauth2=OAuth2Auth(
access_token=tokens.get("access_token"),
refresh_token=tokens.get("refresh_token"),
),
)
return updated_credential

View File

@ -29,5 +29,5 @@
<style>html{color-scheme:dark}html{--mat-sys-background:light-dark(#fcf9f8, #131314);--mat-sys-error:light-dark(#ba1a1a, #ffb4ab);--mat-sys-error-container:light-dark(#ffdad6, #93000a);--mat-sys-inverse-on-surface:light-dark(#f3f0f0, #313030);--mat-sys-inverse-primary:light-dark(#c1c7cd, #595f65);--mat-sys-inverse-surface:light-dark(#313030, #e5e2e2);--mat-sys-on-background:light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-error:light-dark(#ffffff, #690005);--mat-sys-on-error-container:light-dark(#410002, #ffdad6);--mat-sys-on-primary:light-dark(#ffffff, #2b3136);--mat-sys-on-primary-container:light-dark(#161c21, #dde3e9);--mat-sys-on-primary-fixed:light-dark(#161c21, #161c21);--mat-sys-on-primary-fixed-variant:light-dark(#41474d, #41474d);--mat-sys-on-secondary:light-dark(#ffffff, #003061);--mat-sys-on-secondary-container:light-dark(#001b3c, #d5e3ff);--mat-sys-on-secondary-fixed:light-dark(#001b3c, #001b3c);--mat-sys-on-secondary-fixed-variant:light-dark(#0f4784, #0f4784);--mat-sys-on-surface:light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-surface-variant:light-dark(#44474a, #e1e2e6);--mat-sys-on-tertiary:light-dark(#ffffff, #2b3136);--mat-sys-on-tertiary-container:light-dark(#161c21, #dde3e9);--mat-sys-on-tertiary-fixed:light-dark(#161c21, #161c21);--mat-sys-on-tertiary-fixed-variant:light-dark(#41474d, #41474d);--mat-sys-outline:light-dark(#74777b, #8e9194);--mat-sys-outline-variant:light-dark(#c4c7ca, #44474a);--mat-sys-primary:light-dark(#595f65, #c1c7cd);--mat-sys-primary-container:light-dark(#dde3e9, #41474d);--mat-sys-primary-fixed:light-dark(#dde3e9, #dde3e9);--mat-sys-primary-fixed-dim:light-dark(#c1c7cd, #c1c7cd);--mat-sys-scrim:light-dark(#000000, #000000);--mat-sys-secondary:light-dark(#305f9d, #a7c8ff);--mat-sys-secondary-container:light-dark(#d5e3ff, #0f4784);--mat-sys-secondary-fixed:light-dark(#d5e3ff, #d5e3ff);--mat-sys-secondary-fixed-dim:light-dark(#a7c8ff, #a7c8ff);--mat-sys-shadow:light-dark(#000000, #000000);--mat-sys-surface:light-dark(#fcf9f8, #131314);--mat-sys-surface-bright:light-dark(#fcf9f8, #393939);--mat-sys-surface-container:light-dark(#f0eded, #201f20);--mat-sys-surface-container-high:light-dark(#eae7e7, #2a2a2a);--mat-sys-surface-container-highest:light-dark(#e5e2e2, #393939);--mat-sys-surface-container-low:light-dark(#f6f3f3, #1c1b1c);--mat-sys-surface-container-lowest:light-dark(#ffffff, #0e0e0e);--mat-sys-surface-dim:light-dark(#dcd9d9, #131314);--mat-sys-surface-tint:light-dark(#595f65, #c1c7cd);--mat-sys-surface-variant:light-dark(#e1e2e6, #44474a);--mat-sys-tertiary:light-dark(#595f65, #c1c7cd);--mat-sys-tertiary-container:light-dark(#dde3e9, #41474d);--mat-sys-tertiary-fixed:light-dark(#dde3e9, #dde3e9);--mat-sys-tertiary-fixed-dim:light-dark(#c1c7cd, #c1c7cd);--mat-sys-neutral-variant20:#2d3134;--mat-sys-neutral10:#1c1b1c}html{--mat-sys-level0:0px 0px 0px 0px rgba(0, 0, 0, .2), 0px 0px 0px 0px rgba(0, 0, 0, .14), 0px 0px 0px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level1:0px 2px 1px -1px rgba(0, 0, 0, .2), 0px 1px 1px 0px rgba(0, 0, 0, .14), 0px 1px 3px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level2:0px 3px 3px -2px rgba(0, 0, 0, .2), 0px 3px 4px 0px rgba(0, 0, 0, .14), 0px 1px 8px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level3:0px 3px 5px -1px rgba(0, 0, 0, .2), 0px 6px 10px 0px rgba(0, 0, 0, .14), 0px 1px 18px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level4:0px 5px 5px -3px rgba(0, 0, 0, .2), 0px 8px 10px 1px rgba(0, 0, 0, .14), 0px 3px 14px 2px rgba(0, 0, 0, .12)}html{--mat-sys-level5:0px 7px 8px -4px rgba(0, 0, 0, .2), 0px 12px 17px 2px rgba(0, 0, 0, .14), 0px 5px 
22px 4px rgba(0, 0, 0, .12)}html{--mat-sys-corner-extra-large:28px;--mat-sys-corner-extra-large-top:28px 28px 0 0;--mat-sys-corner-extra-small:4px;--mat-sys-corner-extra-small-top:4px 4px 0 0;--mat-sys-corner-full:9999px;--mat-sys-corner-large:16px;--mat-sys-corner-large-end:0 16px 16px 0;--mat-sys-corner-large-start:16px 0 0 16px;--mat-sys-corner-large-top:16px 16px 0 0;--mat-sys-corner-medium:12px;--mat-sys-corner-none:0;--mat-sys-corner-small:8px}html{--mat-sys-dragged-state-layer-opacity:.16;--mat-sys-focus-state-layer-opacity:.12;--mat-sys-hover-state-layer-opacity:.08;--mat-sys-pressed-state-layer-opacity:.12}html{font-family:Google Sans,Helvetica Neue,sans-serif!important}body{height:100vh;margin:0}:root{--mat-sys-primary:black;--mdc-checkbox-selected-icon-color:white;--mat-sys-background:#131314;--mat-tab-header-active-label-text-color:#8AB4F8;--mat-tab-header-active-hover-label-text-color:#8AB4F8;--mat-tab-header-active-focus-label-text-color:#8AB4F8;--mat-tab-header-label-text-weight:500;--mdc-text-button-label-text-color:#89b4f8}:root{--mdc-dialog-container-color:#2b2b2f}:root{--mdc-dialog-subhead-color:white}:root{--mdc-circular-progress-active-indicator-color:#a8c7fa}:root{--mdc-circular-progress-size:80}</style><link rel="stylesheet" href="styles-4VDSPQ37.css" media="print" onload="this.media='all'"><noscript><link rel="stylesheet" href="styles-4VDSPQ37.css"></noscript></head>
<body>
<app-root></app-root>
<script src="polyfills-FFHMD2TL.js" type="module"></script><script src="main-ZBO76GRM.js" type="module"></script></body>
<script src="polyfills-FFHMD2TL.js" type="module"></script><script src="main-HWIBUY2R.js" type="module"></script></body>
</html>

File diff suppressed because one or more lines are too long

View File

@ -39,12 +39,12 @@ class InputFile(BaseModel):
async def run_input_file(
app_name: str,
user_id: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
input_path: str,
) -> None:
) -> Session:
runner = Runner(
app_name=app_name,
agent=root_agent,
@ -55,9 +55,11 @@ async def run_input_file(
input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now()
session.state = input_file.state
session = session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
)
for query in input_file.queries:
click.echo(f'user: {query}')
click.echo(f'[user]: {query}')
content = types.Content(role='user', parts=[types.Part(text=query)])
async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
@ -65,23 +67,23 @@ async def run_input_file(
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
return session
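For reference, a replay input file consumed by `run_input_file` carries the two `InputFile` fields used above, `state` and `queries`. A hypothetical example (all values made up):

```python
import json

# Hypothetical replay file: `state` seeds the new session and each
# entry in `queries` is sent as a user message.
replay = {
    "state": {"user_name": "Ada"},
    "queries": ["What can you do?", "Summarize our chat."],
}
with open("hello.input.json", "w") as f:
  json.dump(replay, f)
```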
async def run_interactively(
app_name: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
) -> None:
runner = Runner(
app_name=app_name,
app_name=session.app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
while True:
query = input('user: ')
query = input('[user]: ')
if not query or not query.strip():
continue
if query == 'exit':
@ -100,7 +102,8 @@ async def run_cli(
*,
agent_parent_dir: str,
agent_folder_name: str,
json_file_path: Optional[str] = None,
input_file: Optional[str] = None,
saved_session_file: Optional[str] = None,
save_session: bool,
) -> None:
"""Runs an interactive CLI for a certain agent.
@ -109,8 +112,11 @@ async def run_cli(
agent_parent_dir: str, the absolute path of the parent folder of the agent
folder.
agent_folder_name: str, the name of the agent folder.
json_file_path: Optional[str], the absolute path to the json file, either
*.input.json or *.session.json.
input_file: Optional[str], the absolute path to the json file that contains
the initial session state and user queries, exclusive with
saved_session_file.
saved_session_file: Optional[str], the absolute path to the json file that
contains a previously saved session, exclusive with input_file.
save_session: bool, whether to save the session on exit.
"""
if agent_parent_dir not in sys.path:
@ -118,46 +124,50 @@ async def run_cli(
artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService()
session = session_service.create_session(
app_name=agent_folder_name, user_id='test_user'
)
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user'
session = session_service.create_session(
app_name=agent_folder_name, user_id=user_id
)
root_agent = agent_module.agent.root_agent
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
if json_file_path:
if json_file_path.endswith('.input.json'):
await run_input_file(
if input_file:
session = await run_input_file(
app_name=agent_folder_name,
user_id=user_id,
root_agent=root_agent,
artifact_service=artifact_service,
session=session,
session_service=session_service,
input_path=json_file_path,
input_path=input_file,
)
elif json_file_path.endswith('.session.json'):
with open(json_file_path, 'r') as f:
session = Session.model_validate_json(f.read())
for content in session.get_contents():
if content.role == 'user':
print('user: ', content.parts[0].text)
elif saved_session_file:
loaded_session = None
with open(saved_session_file, 'r') as f:
loaded_session = Session.model_validate_json(f.read())
if loaded_session:
for event in loaded_session.events:
session_service.append_event(session, event)
content = event.content
if not content or not content.parts or not content.parts[0].text:
continue
if event.author == 'user':
click.echo(f'[user]: {content.parts[0].text}')
else:
print(content.parts[0].text)
click.echo(f'[{event.author}]: {content.parts[0].text}')
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,
session_service,
)
else:
print(f'Unsupported file type: {json_file_path}')
exit(1)
else:
print(f'Running agent {root_agent.name}, type exit to exit.')
click.echo(f'Running agent {root_agent.name}, type exit to exit.')
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,
@ -165,9 +175,6 @@ async def run_cli(
)
if save_session:
if json_file_path:
session_path = json_file_path.replace('.input.json', '.session.json')
else:
session_id = input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json'

View File

@ -54,7 +54,7 @@ COPY "agents/{app_name}/" "/app/agents/{app_name}/"
EXPOSE {port}
CMD adk {command} --port={port} {trace_to_cloud_option} "/app/agents"
CMD adk {command} --port={port} {session_db_option} {trace_to_cloud_option} "/app/agents"
"""
@ -85,6 +85,7 @@ def to_cloud_run(
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
session_db_url: str,
):
"""Deploys an agent to Google Cloud Run.
@ -112,6 +113,7 @@ def to_cloud_run(
trace_to_cloud: Whether to enable Cloud Trace.
with_ui: Whether to deploy with UI.
verbosity: The verbosity level of the CLI.
session_db_url: The database URL to connect the session.
"""
app_name = app_name or os.path.basename(agent_folder)
@ -144,6 +146,9 @@ def to_cloud_run(
port=port,
command='web' if with_ui else 'api_server',
install_agent_deps=install_agent_deps,
session_db_option=f'--session_db_url={session_db_url}'
if session_db_url
else '',
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
)
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')

View File

@ -256,7 +256,7 @@ def run_evals(
)
if final_eval_status == EvalStatus.PASSED:
result = "✅ Passsed"
result = "✅ Passed"
else:
result = "❌ Failed"

View File

@ -96,6 +96,23 @@ def cli_create_cmd(
)
def validate_exclusive(ctx, param, value):
# Store the validated parameters in the context
if not hasattr(ctx, "exclusive_opts"):
ctx.exclusive_opts = {}
# If this option has a value and we've already seen another exclusive option
if value is not None and any(ctx.exclusive_opts.values()):
exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
raise click.UsageError(
f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
)
# Record this option's value
ctx.exclusive_opts[param.name] = value is not None
return value
@main.command("run")
@click.option(
"--save_session",
@ -105,13 +122,43 @@ def cli_create_cmd(
default=False,
help="Optional. Whether to save the session to a json file on exit.",
)
@click.option(
"--replay",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains the initial state of the session and user"
" queries. A new session will be created using this state. And user"
" queries are run againt the newly created session. Users cannot"
" continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.option(
"--resume",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains a previously saved session (by"
"--save_session option). The previous session will be re-displayed. And"
" user can continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.argument(
"agent",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
)
def cli_run(agent: str, save_session: bool):
def cli_run(
agent: str,
save_session: bool,
replay: Optional[str],
resume: Optional[str],
):
"""Runs an interactive CLI for a certain agent.
AGENT: The path to the agent source code folder.
@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
run_cli(
agent_parent_dir=agent_parent_folder,
agent_folder_name=agent_folder_name,
input_file=replay,
saved_session_file=resume,
save_session=save_session,
)
)
@ -245,12 +294,13 @@ def cli_eval(
@click.option(
"--session_db_url",
help=(
"Optional. The database URL to store the session.\n\n - Use"
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
" to connect to a SQLite DB.\n\n - See"
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
" for more details on supported DB URLs."
"""Optional. The database URL to store the session.
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.option(
@ -366,12 +416,13 @@ def cli_web(
@click.option(
"--session_db_url",
help=(
"Optional. The database URL to store the session.\n\n - Use"
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
" to connect to a SQLite DB.\n\n - See"
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
" for more details on supported DB URLs."
"""Optional. The database URL to store the session.
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.option(
@ -541,6 +592,18 @@ def cli_api_server(
default="WARNING",
help="Optional. Override the default verbosity level.",
)
@click.option(
"--session_db_url",
help=(
"""Optional. The database URL to store the session.
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.argument(
"agent",
type=click.Path(
@ -558,6 +621,7 @@ def cli_deploy_cloud_run(
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
session_db_url: str,
):
"""Deploys an agent to Cloud Run.
@ -579,6 +643,7 @@ def cli_deploy_cloud_run(
trace_to_cloud=trace_to_cloud,
with_ui=with_ui,
verbosity=verbosity,
session_db_url=session_db_url,
)
except Exception as e:
click.secho(f"Deploy failed: {e}", fg="red", err=True)

View File

@ -756,6 +756,12 @@ def get_fast_api_app(
except Exception as e:
logger.exception("Error during live websocket communication: %s", e)
traceback.print_exc()
WEBSOCKET_INTERNAL_ERROR_CODE = 1011
WEBSOCKET_MAX_BYTES_FOR_REASON = 123
await websocket.close(
code=WEBSOCKET_INTERNAL_ERROR_CODE,
reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
)
finally:
for task in pending:
task.cancel()

View File

@ -55,7 +55,7 @@ def load_json(file_path: str) -> Union[Dict, List]:
class AgentEvaluator:
"""An evaluator for Agents, mainly intented for helping with test cases."""
"""An evaluator for Agents, mainly intended for helping with test cases."""
@staticmethod
def find_config_for_test_file(test_file: str):
@ -91,7 +91,7 @@ class AgentEvaluator:
look for 'root_agent' in the loaded module.
eval_dataset: The eval data set. This can be either a string representing
full path to the file containing eval dataset, or a directory that is
recusively explored for all files that have a `.test.json` suffix.
recursively explored for all files that have a `.test.json` suffix.
num_runs: Number of times all entries in the eval dataset should be
assessed.
agent_name: The name of the agent.

View File

@ -35,7 +35,7 @@ class ResponseEvaluator:
Args:
raw_eval_dataset: The dataset that will be evaluated.
evaluation_criteria: The evaluation criteria to be used. This method
support two criterias, `response_evaluation_score` and
supports two criteria, `response_evaluation_score` and
`response_match_score`.
print_detailed_results: Prints detailed results on the console. This is
usually helpful during debugging.
@ -56,7 +56,7 @@ class ResponseEvaluator:
Value range: [0, 5], where 0 means that the agent's response is not
coherent, while 5 means it is. High values are good.
A note on raw_eval_dataset:
The dataset should be a list session, where each sesssion is represented
The dataset should be a list of sessions, where each session is represented
as a list of interactions that need evaluation. Each evaluation is
represented as a dictionary that is expected to have values for the
following keys:

View File

@ -31,10 +31,9 @@ class TrajectoryEvaluator:
):
r"""Returns the mean tool use accuracy of the eval dataset.
Tool use accuracy is calculated by comparing the expected and actuall tool
use trajectories. An exact match scores a 1, 0 otherwise. The final number
is an
average of these individual scores.
Tool use accuracy is calculated by comparing the expected and the actual
tool use trajectories. An exact match scores a 1, 0 otherwise. The final
number is an average of these individual scores.
Value range: [0, 1], where 0 means none of the tool use entries aligned,
and 1 would mean all of them aligned. Higher value is good.
@ -45,7 +44,7 @@ class TrajectoryEvaluator:
usually helpful during debugging.
A note on eval_dataset:
The dataset should be a list session, where each sesssion is represented
The dataset should be a list of sessions, where each session is represented
as a list of interactions that need evaluation. Each evaluation is
represented as a dictionary that is expected to have values for the
following keys:

View File

@ -48,8 +48,13 @@ class EventActions(BaseModel):
"""The agent is escalating to a higher level agent."""
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
"""Will only be set by a tool response indicating tool request euc.
dict key is the function call id since one function call response (from model)
could correspond to multiple function calls.
dict value is the required auth config.
"""Authentication configurations requested by tool responses.
This field will only be set by a tool response event indicating that the
tool requested auth credentials.
- Keys: The function call id. One function response event could contain
multiple function responses that correspond to multiple function calls, and
each function call could request a different auth config; the id identifies
which function call the config belongs to.
- Values: The requested auth config.
"""

View File

@ -94,7 +94,7 @@ can answer it.
If another agent is better for answering the question according to its
description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
question to that agent. When transfering, do not generate any text other than
question to that agent. When transferring, do not generate any text other than
the function call.
"""

View File

@ -115,7 +115,7 @@ class BaseLlmFlow(ABC):
yield event
# send back the function response
if event.get_function_responses():
logger.debug('Sending back last function resonse event: %s', event)
logger.debug('Sending back last function response event: %s', event)
invocation_context.live_request_queue.send_content(event.content)
if (
event.content

View File

@ -111,7 +111,7 @@ def _rearrange_events_for_latest_function_response(
"""Rearrange the events for the latest function_response.
If the latest function_response is for an async function_call, all events
bewteen the initial function_call and the latest function_response will be
between the initial function_call and the latest function_response will be
removed.
Args:

View File

@ -151,28 +151,33 @@ async def handle_function_calls_async(
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
function_response = None
# Calls the tool if before_tool_callback does not exist or returns None.
function_response: Optional[dict] = None
# before_tool_callback (sync or async)
if agent.before_tool_callback:
function_response = agent.before_tool_callback(
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response
if not function_response:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
# Calls after_tool_callback if it exists.
# after_tool_callback (sync or async)
if agent.after_tool_callback:
new_response = agent.after_tool_callback(
altered_function_response = agent.after_tool_callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if new_response:
function_response = new_response
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
if tool.is_long_running:
# Allow long running function to return None to not provide function response.
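The callback handling above relies on one pattern: call the callback, then await the result only if it is awaitable, so a single code path serves both sync and async callbacks. In isolation:

```python
import asyncio
import inspect


async def maybe_await(value):
  # Await coroutines/awaitables; pass plain values through unchanged.
  if inspect.isawaitable(value):
    return await value
  return value


def sync_cb():
  return {"ok": True}


async def async_cb():
  return {"ok": True}


async def main():
  assert await maybe_await(sync_cb()) == {"ok": True}
  assert await maybe_await(async_cb()) == {"ok": True}


asyncio.run(main())
```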
@ -223,11 +228,17 @@ async def handle_function_calls_live(
# in python debugger.
function_args = function_call.args or {}
function_response = None
# Calls the tool if before_tool_callback does not exist or returns None.
if agent.before_tool_callback:
function_response = agent.before_tool_callback(
tool, function_args, tool_context
tool=tool, args=function_args, tool_context=tool_context
)
if inspect.isawaitable(function_response):
function_response = await function_response
if not function_response:
function_response = await _process_function_live_helper(
@ -235,15 +246,26 @@ async def handle_function_calls_live(
)
# Calls after_tool_callback if it exists.
if agent.after_tool_callback:
new_response = agent.after_tool_callback(
tool,
function_args,
tool_context,
function_response,
altered_function_response = agent.after_tool_callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if new_response:
function_response = new_response
if inspect.isawaitable(altered_function_response):
altered_function_response = await altered_function_response
if altered_function_response is not None:
function_response = altered_function_response
if tool.is_long_running:
# Allow async function to return None to not provide function response.
@ -310,9 +332,7 @@ async def _process_function_live_helper(
function_response = {
'status': f'No active streaming function named {function_name} found'
}
elif inspect.isasyncgenfunction(tool.func):
print('is async')
elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
# for streaming tool use case
# we require the function to be a async generator function
async def run_tool_and_update_queue(tool, function_args, tool_context):
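The branch above requires a streaming tool's underlying function to be an async generator. A minimal sketch of such a function (the name and yielded payloads are made up):

```python
import asyncio
from typing import AsyncGenerator


async def monitor_stock_price(symbol: str) -> AsyncGenerator[dict, None]:
  """Hypothetical streaming tool: yields intermediate results over time."""
  for price in (100.0, 101.5, 99.8):
    await asyncio.sleep(1)  # simulate a slow data feed
    yield {"symbol": symbol, "price": price}
```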

View File

@ -52,7 +52,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
# Appends global instructions if set.
if (
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
): # not emtpy str
): # not empty str
raw_si = root_agent.canonical_global_instruction(
ReadonlyContext(invocation_context)
)
@ -60,7 +60,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
llm_request.append_instructions([si])
# Appends agent instructions if set.
if agent.instruction: # not emtpy str
if agent.instruction: # not empty str
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
si = _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si])

View File

@ -152,7 +152,7 @@ class GeminiLlmConnection(BaseLlmConnection):
):
# TODO: Right now, we just support output_transcription without
# changing interface and data protocol. Later, we can consider to
# support output_transcription as a separete field in LlmResponse.
# support output_transcription as a separate field in LlmResponse.
# Transcription is always considered as partial event
# We rely on other control signals to determine when to yield the
@ -179,7 +179,7 @@ class GeminiLlmConnection(BaseLlmConnection):
# in case of empty content or parts, we still surface it
# in case it's an interrupted message, we merge the previous partial
# text. Otherwise we don't merge, because content can be none when the
# safty threshold is triggered
# safety threshold is triggered
if message.server_content.interrupted and text:
yield self.__build_full_text_response(text)
text = ''

View File

@ -14,7 +14,7 @@
from __future__ import annotations
from typing import Optional
from typing import Any, Optional
from google.genai import types
from pydantic import BaseModel
@ -37,6 +37,7 @@ class LlmResponse(BaseModel):
error_message: Error message if the response is an error.
interrupted: Flag indicating that LLM was interrupted when generating the
content. Usually it's due to user interruption during a bidi streaming.
custom_metadata: The custom metadata of the LlmResponse.
"""
model_config = ConfigDict(extra='forbid')
@ -71,6 +72,14 @@ class LlmResponse(BaseModel):
Usually it's due to user interruption during a bidi streaming.
"""
custom_metadata: Optional[dict[str, Any]] = None
"""The custom metadata of the LlmResponse.
An optional key-value pair to label an LlmResponse.
NOTE: the entire dict must be JSON serializable.
"""
@staticmethod
def create(
generate_content_response: types.GenerateContentResponse,

View File

@ -56,6 +56,7 @@ class BuiltInPlanner(BasePlanner):
llm_request: The LLM request to apply the thinking config to.
"""
if self.thinking_config:
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.thinking_config = self.thinking_config
@override
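Constructing the planner with a thinking config might look like this (a sketch; `BuiltInPlanner` is assumed to be exported from `google.adk.planners`, and `ThinkingConfig` comes from the google-genai SDK):

```python
from google.adk.planners import BuiltInPlanner
from google.genai import types

planner = BuiltInPlanner(
    thinking_config=types.ThinkingConfig(include_thoughts=True)
)
```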

View File

@ -0,0 +1,29 @@
"""Utility functions for session service."""
import base64
from typing import Any, Optional
from google.genai import types
def encode_content(content: types.Content):
"""Encodes a content object to a JSON dictionary."""
encoded_content = content.model_dump(exclude_none=True)
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64encode(
p["inline_data"]["data"]
).decode("utf-8")
return encoded_content
def decode_content(
content: Optional[dict[str, Any]],
) -> Optional[types.Content]:
"""Decodes a content object from a JSON dictionary."""
if not content:
return None
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
return types.Content.model_validate(content)
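The point of the pair is a lossless round trip: raw `inline_data` bytes are not JSON-serializable, so they are base64-encoded on the way into storage and decoded on the way out. A sketch using the helpers defined above (the byte payload is made up):

```python
from google.genai import types

original = types.Content(
    role="user",
    parts=[types.Part.from_bytes(data=b"\x89PNG...", mime_type="image/png")],
)

encoded = encode_content(original)  # bytes -> base64 str, JSON-safe
restored = decode_content(encoded)  # base64 str -> bytes
assert restored == original  # should round-trip losslessly
```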

View File

@ -11,14 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import copy
from datetime import datetime
import json
import logging
from typing import Any
from typing import Optional
from typing import Any, Optional
import uuid
from sqlalchemy import Boolean
@ -27,6 +24,7 @@ from sqlalchemy import Dialect
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import Text
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
@ -48,6 +46,7 @@ from typing_extensions import override
from tzlocal import get_localzone
from ..events.event import Event
from . import _session_util
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
@ -58,6 +57,9 @@ from .state import State
logger = logging.getLogger(__name__)
DEFAULT_MAX_KEY_LENGTH = 128
DEFAULT_MAX_VARCHAR_LENGTH = 256
class DynamicJSON(TypeDecorator):
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
@ -70,14 +72,15 @@ class DynamicJSON(TypeDecorator):
def load_dialect_impl(self, dialect: Dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(postgresql.JSONB)
else:
if dialect.name == "mysql":
# Use LONGTEXT for MySQL to address the data too long issue
return dialect.type_descriptor(mysql.LONGTEXT)
return dialect.type_descriptor(Text) # Default to Text for other dialects
def process_bind_param(self, value, dialect: Dialect):
if value is not None:
if dialect.name == "postgresql":
return value # JSONB handles dict directly
else:
return json.dumps(value) # Serialize to JSON string for TEXT
return value
@ -92,17 +95,25 @@ class DynamicJSON(TypeDecorator):
class Base(DeclarativeBase):
"""Base class for database tables."""
pass
class StorageSession(Base):
"""Represents a session stored in the database."""
__tablename__ = "sessions"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
id: Mapped[str] = mapped_column(
String, primary_key=True, default=lambda: str(uuid.uuid4())
String(DEFAULT_MAX_KEY_LENGTH),
primary_key=True,
default=lambda: str(uuid.uuid4()),
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
@ -125,18 +136,29 @@ class StorageSession(Base):
class StorageEvent(Base):
"""Represents an event stored in the database."""
__tablename__ = "events"
id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
session_id: Mapped[str] = mapped_column(String, primary_key=True)
id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
session_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
invocation_id: Mapped[str] = mapped_column(String)
author: Mapped[str] = mapped_column(String)
branch: Mapped[str] = mapped_column(String, nullable=True)
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
branch: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
@ -147,8 +169,10 @@ class StorageEvent(Base):
)
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
error_code: Mapped[str] = mapped_column(String, nullable=True)
error_message: Mapped[str] = mapped_column(String, nullable=True)
error_code: Mapped[str] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
storage_session: Mapped[StorageSession] = relationship(
@ -182,9 +206,12 @@ class StorageEvent(Base):
class StorageAppState(Base):
"""Represents an app state stored in the database."""
__tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
@ -192,13 +219,17 @@ class StorageAppState(Base):
DateTime(), default=func.now(), onupdate=func.now()
)
class StorageUserState(Base):
"""Represents a user state stored in the database."""
__tablename__ = "user_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
user_id: Mapped[str] = mapped_column(
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
)
state: Mapped[MutableDict[str, Any]] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
@ -217,7 +248,7 @@ class DatabaseSessionService(BaseSessionService):
"""
# 1. Create DB engine for db connection
# 2. Create all tables based on schema
# 3. Initialize all properies
# 3. Initialize all properties
try:
db_engine = create_engine(db_url)
@ -353,6 +384,7 @@ class DatabaseSessionService(BaseSessionService):
else True
)
.limit(config.num_recent_events if config else None)
.order_by(StorageEvent.timestamp.asc())
.all()
)
@ -383,7 +415,7 @@ class DatabaseSessionService(BaseSessionService):
author=e.author,
branch=e.branch,
invocation_id=e.invocation_id,
content=_decode_content(e.content),
content=_session_util.decode_content(e.content),
actions=e.actions,
timestamp=e.timestamp.timestamp(),
long_running_tool_ids=e.long_running_tool_ids,
@ -506,15 +538,7 @@ class DatabaseSessionService(BaseSessionService):
interrupted=event.interrupted,
)
if event.content:
encoded_content = event.content.model_dump(exclude_none=True)
# Workaround for multimodal Content throwing JSON not serializable
# error with SQLAlchemy.
for p in encoded_content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = (
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
)
storage_event.content = encoded_content
storage_event.content = _session_util.encode_content(event.content)
sessionFactory.add(storage_event)
@ -574,10 +598,3 @@ def _merge_state(app_state, user_state, session_state):
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state
def _decode_content(content: dict[str, Any]) -> dict[str, Any]:
for p in content["parts"]:
if "inline_data" in p:
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
return content
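Tying this back to the `--session_db_url` flag added earlier, wiring the service up against a local SQLite file might look like this (a sketch; the path is made up and the export from `google.adk.sessions` is assumed):

```python
from google.adk.sessions import DatabaseSessionService

session_service = DatabaseSessionService(db_url="sqlite:///./sessions.db")
session = session_service.create_session(
    app_name="my_app", user_id="user_123"
)
```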

View File

@ -26,7 +26,7 @@ class State:
"""
Args:
value: The current value of the state dict.
delta: The delta change to the current value that hasn't been commited.
delta: The delta change to the current value that hasn't been committed.
"""
self._value = value
self._delta = delta

View File

@ -14,21 +14,23 @@
import logging
import re
import time
from typing import Any
from typing import Optional
from typing import Any, Optional
from dateutil.parser import isoparse
from dateutil import parser
from google import genai
from typing_extensions import override
from ..events.event import Event
from ..events.event_actions import EventActions
from . import _session_util
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
from .base_session_service import ListSessionsResponse
from .session import Session
isoparse = parser.isoparse
logger = logging.getLogger(__name__)
@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event):
}
event_json['actions'] = actions_json
if event.content:
event_json['content'] = event.content.model_dump(exclude_none=True)
event_json['content'] = _session_util.encode_content(event.content)
if event.error_code:
event_json['error_code'] = event.error_code
if event.error_message:
@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event:
invocation_id=api_event['invocationId'],
author=api_event['author'],
actions=event_actions,
content=api_event.get('content', None),
content=_session_util.decode_content(api_event.get('content', None)),
timestamp=isoparse(api_event['timestamp']).timestamp(),
error_code=api_event.get('errorCode', None),
error_message=api_event.get('errorMessage', None),

View File

@ -45,10 +45,9 @@ class AgentTool(BaseTool):
skip_summarization: Whether to skip summarization of the agent output.
"""
def __init__(self, agent: BaseAgent):
def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
self.agent = agent
self.skip_summarization: bool = False
"""Whether to skip summarization of the agent output."""
self.skip_summarization: bool = skip_summarization
super().__init__(name=agent.name, description=agent.description)
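With the new constructor parameter, skipping summarization is set at construction time. A sketch (the model id is a placeholder and the import paths are assumptions):

```python
from google.adk.agents import Agent
from google.adk.tools.agent_tool import AgentTool

helper = Agent(
    name="summarizer",
    model="gemini-2.0-flash",  # placeholder model id
    instruction="Summarize the given text in one sentence.",
)
tool = AgentTool(agent=helper, skip_summarization=True)
```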

View File

@ -13,7 +13,9 @@
# limitations under the License.
from .application_integration_toolset import ApplicationIntegrationToolset
from .integration_connector_tool import IntegrationConnectorTool
__all__ = [
'ApplicationIntegrationToolset',
'IntegrationConnectorTool',
]

View File

@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import List
from typing import Optional
from typing import Dict, List, Optional
from fastapi.openapi.models import HTTPBearer
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from ...auth.auth_credential import AuthCredential
from ...auth.auth_credential import AuthCredentialTypes
from ...auth.auth_credential import ServiceAccount
from ...auth.auth_credential import ServiceAccountCredential
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
from ..openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from .clients.connections_client import ConnectionsClient
from .clients.integration_client import IntegrationClient
from .integration_connector_tool import IntegrationConnectorTool
# TODO(cheliu): Apply a common toolset interface
@ -168,6 +168,7 @@ class ApplicationIntegrationToolset:
actions,
service_account_json,
)
connection_details = {}
if integration and trigger:
spec = integration_client.get_openapi_spec_for_integration()
elif connection and (entity_operations or actions):
@ -175,16 +176,6 @@ class ApplicationIntegrationToolset:
project, location, connection, service_account_json
)
connection_details = connections_client.get_connection_details()
tool_instructions += (
"ALWAYS use serviceName = "
+ connection_details["serviceName"]
+ ", host = "
+ connection_details["host"]
+ " and the connection name = "
+ f"projects/{project}/locations/{location}/connections/{connection} when"
" using this tool"
+ ". DONOT ask the user for these values as you already have those."
)
spec = integration_client.get_openapi_spec_for_connection(
tool_name,
tool_instructions,
@ -194,9 +185,9 @@ class ApplicationIntegrationToolset:
"Either (integration and trigger) or (connection and"
" (entity_operations or actions)) should be provided."
)
self._parse_spec_to_tools(spec)
self._parse_spec_to_tools(spec, connection_details)
def _parse_spec_to_tools(self, spec_dict):
def _parse_spec_to_tools(self, spec_dict, connection_details):
"""Parses the spec dict to a list of RestApiTool."""
if self.service_account_json:
sa_credential = ServiceAccountCredential.model_validate_json(
@ -218,6 +209,8 @@ class ApplicationIntegrationToolset:
),
)
auth_scheme = HTTPBearer(bearerFormat="JWT")
if self.integration and self.trigger:
tools = OpenAPIToolset(
spec_dict=spec_dict,
auth_credential=auth_credential,
@ -225,6 +218,35 @@ class ApplicationIntegrationToolset:
).get_tools()
for tool in tools:
self.generated_tools[tool.name] = tool
return
operations = OpenApiSpecParser().parse(spec_dict)
for open_api_operation in operations:
operation = getattr(open_api_operation.operation, "x-operation")
entity = None
action = None
if hasattr(open_api_operation.operation, "x-entity"):
entity = getattr(open_api_operation.operation, "x-entity")
elif hasattr(open_api_operation.operation, "x-action"):
action = getattr(open_api_operation.operation, "x-action")
rest_api_tool = RestApiTool.from_parsed_operation(open_api_operation)
if auth_scheme:
rest_api_tool.configure_auth_scheme(auth_scheme)
if auth_credential:
rest_api_tool.configure_auth_credential(auth_credential)
tool = IntegrationConnectorTool(
name=rest_api_tool.name,
description=rest_api_tool.description,
connection_name=connection_details["name"],
connection_host=connection_details["host"],
connection_service_name=connection_details["serviceName"],
entity=entity,
action=action,
operation=operation,
rest_api_tool=rest_api_tool,
)
self.generated_tools[tool.name] = tool
def get_tools(self) -> List[RestApiTool]:
return list(self.generated_tools.values())
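The branching above is the core of this change: integration triggers still go through `OpenAPIToolset`, while connector specs are now parsed operation by operation and each one is wrapped in an `IntegrationConnectorTool`, keyed off the custom `x-operation`/`x-entity`/`x-action` extension attributes. A minimal sketch of that wrapping, assuming the parsed `Operation` objects expose those attributes (`build_connector_tools` is an illustrative name, not part of the source):

```python
# A minimal sketch, assuming the parsed Operation objects expose the custom
# "x-" extension attributes; build_connector_tools is an illustrative name.
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool


def build_connector_tools(spec_dict, connection_details):
  tools = {}
  for parsed_op in OpenApiSpecParser().parse(spec_dict):
    op = parsed_op.operation
    rest_tool = RestApiTool.from_parsed_operation(parsed_op)
    tools[rest_tool.name] = IntegrationConnectorTool(
        name=rest_tool.name,
        description=rest_tool.description,
        connection_name=connection_details["name"],
        connection_host=connection_details["host"],
        connection_service_name=connection_details["serviceName"],
        # Entity operations carry "x-entity"; actions carry "x-action".
        entity=getattr(op, "x-entity", None),
        action=getattr(op, "x-action", None),
        operation=getattr(op, "x-operation"),
        rest_api_tool=rest_tool,
    )
  return tools
```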

View File

@ -68,12 +68,14 @@ class ConnectionsClient:
response = self._execute_api_call(url)
connection_data = response.json()
connection_name = connection_data.get("name", "")
service_name = connection_data.get("serviceDirectory", "")
host = connection_data.get("host", "")
if host:
service_name = connection_data.get("tlsServiceDirectory", "")
auth_override_enabled = connection_data.get("authOverrideEnabled", False)
return {
"name": connection_name,
"serviceName": service_name,
"host": host,
"authOverrideEnabled": auth_override_enabled,
@ -291,13 +293,9 @@ class ConnectionsClient:
tool_name: str = "",
tool_instructions: str = "",
) -> Dict[str, Any]:
description = (
f"Use this tool with" f' action = "{action}" and'
) + f' operation = "{operation}" only. Dont ask these values from user.'
description = f"Use this tool to execute {action}"
if operation == "EXECUTE_QUERY":
description = (
(f"Use this tool with" f' action = "{action}" and')
+ f' operation = "{operation}" only. Dont ask these values from user.'
description += (
" Use pageSize = 50 and timeout = 120 until user specifies a"
" different value otherwise. If user provides a query in natural"
" language, convert it to SQL query and then execute it using the"
@ -308,6 +306,8 @@ class ConnectionsClient:
"summary": f"{action_display_name}",
"description": f"{description} {tool_instructions}",
"operationId": f"{tool_name}_{action_display_name}",
"x-action": f"{action}",
"x-operation": f"{operation}",
"requestBody": {
"content": {
"application/json": {
@ -347,16 +347,12 @@ class ConnectionsClient:
"post": {
"summary": f"List {entity}",
"description": (
f"Returns all entities of type {entity}. Use this tool with"
+ f' entity = "{entity}" and'
+ ' operation = "LIST_ENTITIES" only. Dont ask these values'
" from"
+ ' user. Always use ""'
+ ' as filter clause and ""'
+ " as page token and 50 as page size until user specifies a"
" different value otherwise. Use single quotes for strings in"
f" filter clause. {tool_instructions}"
f"""Returns the list of {entity} data. If the page token was available in the response, let users know there are more records available. Ask if the user wants to fetch the next page of results. When passing filter use the
following format: `field_name1='value1' AND field_name2='value2'
`. {tool_instructions}"""
),
"x-operation": "LIST_ENTITIES",
"x-entity": f"{entity}",
"operationId": f"{tool_name}_list_{entity}",
"requestBody": {
"content": {
@ -401,14 +397,11 @@ class ConnectionsClient:
"post": {
"summary": f"Get {entity}",
"description": (
(
f"Returns the details of the {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "GET_ENTITY" only. Dont ask these values from'
f" user. {tool_instructions}"
f"Returns the details of the {entity}. {tool_instructions}"
),
"operationId": f"{tool_name}_get_{entity}",
"x-operation": "GET_ENTITY",
"x-entity": f"{entity}",
"requestBody": {
"content": {
"application/json": {
@ -445,17 +438,10 @@ class ConnectionsClient:
) -> Dict[str, Any]:
return {
"post": {
"summary": f"Create {entity}",
"description": (
(
f"Creates a new entity of type {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "CREATE_ENTITY" only. Dont ask these values'
" from"
+ " user. Follow the schema of the entity provided in the"
f" instructions to create {entity}. {tool_instructions}"
),
"summary": f"Creates a new {entity}",
"description": f"Creates a new {entity}. {tool_instructions}",
"x-operation": "CREATE_ENTITY",
"x-entity": f"{entity}",
"operationId": f"{tool_name}_create_{entity}",
"requestBody": {
"content": {
@ -491,18 +477,10 @@ class ConnectionsClient:
) -> Dict[str, Any]:
return {
"post": {
"summary": f"Update {entity}",
"description": (
(
f"Updates an entity of type {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "UPDATE_ENTITY" only. Dont ask these values'
" from"
+ " user. Use entityId to uniquely identify the entity to"
" update. Follow the schema of the entity provided in the"
f" instructions to update {entity}. {tool_instructions}"
),
"summary": f"Updates the {entity}",
"description": f"Updates the {entity}. {tool_instructions}",
"x-operation": "UPDATE_ENTITY",
"x-entity": f"{entity}",
"operationId": f"{tool_name}_update_{entity}",
"requestBody": {
"content": {
@ -538,16 +516,10 @@ class ConnectionsClient:
) -> Dict[str, Any]:
return {
"post": {
"summary": f"Delete {entity}",
"description": (
(
f"Deletes an entity of type {entity}. Use this tool with"
f' entity = "{entity}" and'
)
+ ' operation = "DELETE_ENTITY" only. Dont ask these values'
" from"
f" user. {tool_instructions}"
),
"summary": f"Delete the {entity}",
"description": f"Deletes the {entity}. {tool_instructions}",
"x-operation": "DELETE_ENTITY",
"x-entity": f"{entity}",
"operationId": f"{tool_name}_delete_{entity}",
"requestBody": {
"content": {

View File

@ -0,0 +1,159 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from typing import Dict
from typing import Optional
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from .. import BaseTool
from ..tool_context import ToolContext
logger = logging.getLogger(__name__)
class IntegrationConnectorTool(BaseTool):
"""A tool that wraps a RestApiTool to interact with a specific Application Integration endpoint.
This tool adds Application Integration specific context like connection
details, entity, operation, and action to the underlying REST API call
handled by RestApiTool. It prepares the arguments and then delegates the
actual API call execution to the contained RestApiTool instance.
* Generates request params and body
* Attaches auth credentials to API call.
Example:
```
# Each API operation in the spec will be turned into its own tool
# Name of the tool is the operationId of that operation, in snake case
operations = OperationGenerator().parse(openapi_spec_dict)
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
```
"""
EXCLUDE_FIELDS = [
'connection_name',
'service_name',
'host',
'entity',
'operation',
'action',
]
OPTIONAL_FIELDS = [
'page_size',
'page_token',
'filter',
]
def __init__(
self,
name: str,
description: str,
connection_name: str,
connection_host: str,
connection_service_name: str,
entity: str,
operation: str,
action: str,
rest_api_tool: RestApiTool,
):
"""Initializes the ApplicationIntegrationTool.
Args:
name: The name of the tool, typically derived from the API operation.
Should be unique and adhere to Gemini function naming conventions
(e.g., less than 64 characters).
description: A description of what the tool does, usually based on the
API operation's summary or description.
connection_name: The name of the Integration Connector connection.
connection_host: The hostname or IP address for the connection.
connection_service_name: The specific service name within the host.
entity: The Integration Connector entity being targeted.
operation: The specific operation being performed on the entity.
action: The action associated with the operation (e.g., 'execute').
rest_api_tool: An initialized RestApiTool instance that handles the
underlying REST API communication based on an OpenAPI specification
operation. This tool will be called by IntegrationConnectorTool with
added connection and context arguments.
"""
# Gemini restricts the length of function names to fewer than 64 characters.
super().__init__(
name=name,
description=description,
)
self.connection_name = connection_name
self.connection_host = connection_host
self.connection_service_name = connection_service_name
self.entity = entity
self.operation = operation
self.action = action
self.rest_api_tool = rest_api_tool
@override
def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self.rest_api_tool._operation_parser.get_json_schema()
for field in self.EXCLUDE_FIELDS:
if field in schema_dict['properties']:
del schema_dict['properties'][field]
for field in self.OPTIONAL_FIELDS + self.EXCLUDE_FIELDS:
if field in schema_dict['required']:
schema_dict['required'].remove(field)
parameters = to_gemini_schema(schema_dict)
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
return function_decl
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
) -> Dict[str, Any]:
args['connection_name'] = self.connection_name
args['service_name'] = self.connection_service_name
args['host'] = self.connection_host
args['entity'] = self.entity
args['operation'] = self.operation
args['action'] = self.action
logger.info('Running tool: %s with args: %s', self.name, args)
return self.rest_api_tool.call(args=args, tool_context=tool_context)
def __str__(self):
return (
f'ApplicationIntegrationTool(name="{self.name}",'
f' description="{self.description}",'
f' connection_name="{self.connection_name}", entity="{self.entity}",'
f' operation="{self.operation}", action="{self.action}")'
)
def __repr__(self):
return (
f'ApplicationIntegrationTool(name="{self.name}",'
f' description="{self.description}",'
f' connection_name="{self.connection_name}",'
f' connection_host="{self.connection_host}",'
f' connection_service_name="{self.connection_service_name}",'
f' entity="{self.entity}", operation="{self.operation}",'
f' action="{self.action}", rest_api_tool={repr(self.rest_api_tool)})'
)
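Taken together, `_get_declaration` strips the connection plumbing (`EXCLUDE_FIELDS`) out of the schema the model sees and demotes `OPTIONAL_FIELDS` from `required`, while `run_async` re-injects that plumbing before delegating the call. A hedged usage sketch, assuming `my_rest_tool` is a `RestApiTool` built from a parsed operation:

```python
# Hedged usage sketch; my_rest_tool is an assumed RestApiTool instance.
tool = IntegrationConnectorTool(
    name="list_issues",
    description="Lists Issues entities.",
    connection_name="projects/p/locations/l/connections/c",
    connection_host="connectors.example.com",
    connection_service_name="svc.example.com",
    entity="Issues",
    operation="LIST_ENTITIES",
    action="",
    rest_api_tool=my_rest_tool,
)
declaration = tool._get_declaration()  # no connection_name/host/... parameters
# Connection context is injected at call time, so the model never supplies it:
# result = await tool.run_async(args={"filter": ""}, tool_context=None)
```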

View File

@ -59,6 +59,23 @@ class FunctionTool(BaseTool):
if 'tool_context' in signature.parameters:
args_to_call['tool_context'] = tool_context
# Before invoking the function, we check whether the list of args passed in
# contains all the mandatory arguments.
# If the check fails, we don't invoke the tool and instead let the Agent know
# that an input parameter was missing. This helps the underlying model fix
# the issue and retry.
mandatory_args = self._get_mandatory_args()
missing_mandatory_args = [
arg for arg in mandatory_args if arg not in args_to_call
]
if missing_mandatory_args:
missing_mandatory_args_str = '\n'.join(missing_mandatory_args)
error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present:
{missing_mandatory_args_str}
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
return {'error': error_str}
if inspect.iscoroutinefunction(self.func):
return await self.func(**args_to_call) or {}
else:
@ -85,3 +102,28 @@ class FunctionTool(BaseTool):
args_to_call['tool_context'] = tool_context
async for item in self.func(**args_to_call):
yield item
def _get_mandatory_args(
self,
) -> list[str]:
"""Identifies mandatory parameters (those without default values) for a function.
Returns:
A list of strings, where each string is the name of a mandatory parameter.
"""
signature = inspect.signature(self.func)
mandatory_params = []
for name, param in signature.parameters.items():
# A parameter is mandatory if:
# 1. It has no default value (param.default is inspect.Parameter.empty)
# 2. It's not a variable positional (*args) or variable keyword (**kwargs) parameter
#
# For more refer to: https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind
if param.default == inspect.Parameter.empty and param.kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
mandatory_params.append(name)
return mandatory_params
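The detection rule is plain `inspect` machinery: a parameter is mandatory when it has no default and is not `*args`/`**kwargs`. A self-contained illustration (`greet` is a hypothetical tool function):

```python
import inspect


def greet(name, greeting="hello", *args, **kwargs):
  """Hypothetical tool function used only for this illustration."""
  return f"{greeting}, {name}"


signature = inspect.signature(greet)
mandatory = [
    n
    for n, p in signature.parameters.items()
    if p.default is inspect.Parameter.empty
    and p.kind
    not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
]
assert mandatory == ["name"]  # greeting has a default; *args/**kwargs excluded
```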

View File

@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
import os
from typing import Any
from typing import Dict
from typing import Final
from typing import List
from typing import Optional
@ -28,6 +30,7 @@ from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
class GoogleApiToolSet:
"""Google API Tool Set."""
def __init__(self, tools: List[RestApiTool]):
self.tools: Final[List[GoogleApiTool]] = [
@ -45,10 +48,10 @@ class GoogleApiToolSet:
@staticmethod
def _load_tool_set_with_oidc_auth(
spec_file: str = None,
spec_dict: Dict[str, Any] = None,
scopes: list[str] = None,
) -> Optional[OpenAPIToolset]:
spec_file: Optional[str] = None,
spec_dict: Optional[dict[str, Any]] = None,
scopes: Optional[list[str]] = None,
) -> OpenAPIToolset:
spec_str = None
if spec_file:
# Get the frame of the caller
@ -90,18 +93,18 @@ class GoogleApiToolSet:
@classmethod
def load_tool_set(
cl: Type['GoogleApiToolSet'],
cls: Type[GoogleApiToolSet],
api_name: str,
api_version: str,
) -> 'GoogleApiToolSet':
) -> GoogleApiToolSet:
spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
scope = list(
spec_dict['components']['securitySchemes']['oauth2']['flows'][
'authorizationCode'
]['scopes'].keys()
)[0]
return cl(
cl._load_tool_set_with_oidc_auth(
return cls(
cls._load_tool_set_with_oidc_auth(
spec_dict=spec_dict, scopes=[scope]
).get_tools()
)
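With the annotations tightened and the classmethod parameter conventionally spelled `cls`, loading a tool set remains a one-liner. A hedged usage sketch ("calendar" and "v3" are example arguments, not from this change):

```python
# "calendar" / "v3" are example values for illustration only.
toolset = GoogleApiToolSet.load_tool_set(api_name="calendar", api_version="v3")
print(len(toolset.tools))  # list of GoogleApiTool, per __init__ above
```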

View File

@ -89,7 +89,7 @@ class LoadArtifactsTool(BaseTool):
than the function call.
"""])
# Attache the content of the artifacts if the model requests them.
# Attach the content of the artifacts if the model requests them.
# This only adds the content to the model request, instead of the session.
if llm_request.contents and llm_request.contents[-1].parts:
function_response = llm_request.contents[-1].parts[0].function_response

View File

@ -66,10 +66,10 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
Returns:
An AuthCredential object containing the HTTP bearer access token. If the
HTTO bearer token cannot be generated, return the origianl credential
HTTP bearer token cannot be generated, return the original credential.
"""
if "access_token" not in auth_credential.oauth2.token:
if not auth_credential.oauth2.access_token:
return auth_credential
# Return the access token as a bearer token.
@ -78,7 +78,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(
token=auth_credential.oauth2.token["access_token"]
token=auth_credential.oauth2.access_token
),
),
)
@ -111,7 +111,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
return auth_credential
# If access token is exchanged, exchange a HTTPBearer token.
if auth_credential.oauth2.token:
if auth_credential.oauth2.access_token:
return self.generate_auth_token(auth_credential)
return None
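The exchanger now reads `access_token` directly off `OAuth2Auth` instead of digging into a `token` dict. A minimal sketch of the flattened credential this expects (placeholder values):

```python
# Placeholder values; the flat fields replace token={"access_token": ...}.
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth

credential = AuthCredential(
    auth_type=AuthCredentialTypes.OAUTH2,
    oauth2=OAuth2Auth(
        client_id="client-id",
        client_secret="client-secret",
        access_token="example-access-token",
        refresh_token="example-refresh-token",
    ),
)
```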

View File

@ -124,7 +124,7 @@ class OpenAPIToolset:
def _load_spec(
self, spec_str: str, spec_type: Literal["json", "yaml"]
) -> Dict[str, Any]:
"""Loads the OpenAPI spec string into adictionary."""
"""Loads the OpenAPI spec string into a dictionary."""
if spec_type == "json":
return json.loads(spec_str)
elif spec_type == "yaml":

View File

@ -14,20 +14,12 @@
import inspect
from textwrap import dedent
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from typing import Any, Dict, List, Optional, Union
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter
from fastapi.openapi.models import Schema
from fastapi.openapi.models import Operation, Parameter, Schema
from ..common.common import ApiParameter
from ..common.common import PydocHelper
from ..common.common import to_snake_case
from ..common.common import ApiParameter, PydocHelper, to_snake_case
class OperationParser:
@ -113,7 +105,8 @@ class OperationParser:
description = request_body.description or ''
if schema and schema.type == 'object':
for prop_name, prop_details in schema.properties.items():
properties = schema.properties or {}
for prop_name, prop_details in properties.items():
self.params.append(
ApiParameter(
original_name=prop_name,
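The one-line guard above covers request bodies declared as a bare `type: object`: pydantic leaves `Schema.properties` as `None` in that case, which previously crashed the iteration. A tiny self-contained illustration:

```python
from fastapi.openapi.models import Schema

schema = Schema(type="object")  # no properties declared -> properties is None
properties = schema.properties or {}  # the new guard makes this iterable
assert list(properties.items()) == []
```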

View File

@ -17,6 +17,7 @@ from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
@ -59,6 +60,40 @@ def snake_to_lower_camel(snake_case_string: str):
])
# TODO: Switch to Gemini `from_json_schema` util when it is released
# in Gemini SDK.
def normalize_json_schema_type(
json_schema_type: Optional[Union[str, Sequence[str]]],
) -> tuple[Optional[str], bool]:
"""Converts a JSON Schema Type into Gemini Schema type.
Adopted and modified from the Gemini SDK. This gets the first available schema
type from the JSON Schema and uses it as the Gemini schema type. If the JSON
Schema contains a list of types, the first non-null type is used.
Remove this after switching to Gemini `from_json_schema`.
"""
if json_schema_type is None:
return None, False
if isinstance(json_schema_type, str):
if json_schema_type == "null":
return None, True
return json_schema_type, False
non_null_types = []
nullable = False
# If the JSON schema type is an array, pick the first non-null type.
for type_value in json_schema_type:
if type_value == "null":
nullable = True
else:
non_null_types.append(type_value)
non_null_type = non_null_types[0] if non_null_types else None
return non_null_type, nullable
# TODO: Switch to Gemini `from_json_schema` util when it is released
# in Gemini SDK.
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
"""Converts an OpenAPI schema dictionary to a Gemini Schema object.
@ -82,13 +117,6 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
if not openapi_schema.get("type"):
openapi_schema["type"] = "object"
# Adding this to avoid "properties: should be non-empty for OBJECT type" error
# See b/385165182
if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
"properties"
):
openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}
for key, value in openapi_schema.items():
snake_case_key = to_snake_case(key)
# Check if the snake_case_key exists in the Schema model's fields.
@ -99,7 +127,17 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
# Format: properties[expiration].format: only 'enum' and 'date-time' are
# supported for STRING type
continue
if snake_case_key == "properties" and isinstance(value, dict):
elif snake_case_key == "type":
schema_type, nullable = normalize_json_schema_type(
openapi_schema.get("type", None)
)
# Adding this to force adding a type to an empty dict
# This avoid "... one_of or any_of must specify a type" error
pydantic_schema_data["type"] = schema_type if schema_type else "object"
pydantic_schema_data["type"] = pydantic_schema_data["type"].upper()
if nullable:
pydantic_schema_data["nullable"] = True
elif snake_case_key == "properties" and isinstance(value, dict):
pydantic_schema_data[snake_case_key] = {
k: to_gemini_schema(v) for k, v in value.items()
}
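Together, `normalize_json_schema_type` and the new `type` branch in `to_gemini_schema` collapse JSON Schema type arrays into one uppercase Gemini type plus a `nullable` flag, and default typeless schemas to `OBJECT` instead of injecting a dummy property. A few behavioral checks, assuming both helpers are importable from the `rest_api_tool` module:

```python
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import normalize_json_schema_type
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from google.genai.types import Type

assert normalize_json_schema_type(None) == (None, False)
assert normalize_json_schema_type("null") == (None, True)
assert normalize_json_schema_type(["string", "null"]) == ("string", True)
assert normalize_json_schema_type(["null", "integer"]) == ("integer", True)

schema = to_gemini_schema({"type": ["string", "null"]})
assert schema.type == Type.STRING and schema.nullable
```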

View File

@ -13,4 +13,4 @@
# limitations under the License.
# version: date+base_cl
__version__ = "0.1.1"
__version__ = "0.3.0"

View File

@ -241,7 +241,7 @@ def test_langchain_tool_success(agent_runner: TestRunner):
def test_crewai_tool_success(agent_runner: TestRunner):
_call_function_and_assert(
agent_runner,
"direcotry_read_tool",
"directory_read_tool",
"Find all the file paths",
"file",
)

View File

@ -126,12 +126,8 @@ def oauth2_credentials_with_token():
client_id="mock_client_id",
client_secret="mock_client_secret",
redirect_uri="https://example.com/callback",
token={
"access_token": "mock_access_token",
"token_type": "bearer",
"expires_in": 3600,
"refresh_token": "mock_refresh_token",
},
access_token="mock_access_token",
refresh_token="mock_refresh_token",
),
)
@ -458,7 +454,7 @@ class TestParseAndStoreAuthResponse:
"""Test with an OAuth auth scheme."""
mock_exchange_token.return_value = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(token={"access_token": "exchanged_token"}),
oauth2=OAuth2Auth(access_token="exchanged_token"),
)
handler = AuthHandler(auth_config_with_exchanged)
@ -573,6 +569,6 @@ class TestExchangeAuthToken:
handler = AuthHandler(auth_config_with_auth_code)
result = handler.exchange_auth_token()
assert result.oauth2.token["access_token"] == "mock_access_token"
assert result.oauth2.token["refresh_token"] == "mock_refresh_token"
assert result.oauth2.access_token == "mock_access_token"
assert result.oauth2.refresh_token == "mock_refresh_token"
assert result.auth_type == AuthCredentialTypes.OAUTH2

View File

@ -0,0 +1,109 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
import pytest
from google.adk.agents import Agent
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_context import ToolContext
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.events.event import Event
from google.genai import types
from ... import utils
class AsyncBeforeToolCallback:
def __init__(self, mock_response: Dict[str, Any]):
self.mock_response = mock_response
async def __call__(
self,
tool: FunctionTool,
args: Dict[str, Any],
tool_context: ToolContext,
) -> Optional[Dict[str, Any]]:
return self.mock_response
class AsyncAfterToolCallback:
def __init__(self, mock_response: Dict[str, Any]):
self.mock_response = mock_response
async def __call__(
self,
tool: FunctionTool,
args: Dict[str, Any],
tool_context: ToolContext,
tool_response: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
return self.mock_response
async def invoke_tool_with_callbacks(
before_cb=None, after_cb=None
) -> Optional[Event]:
def simple_fn(**kwargs) -> Dict[str, Any]:
return {"initial": "response"}
tool = FunctionTool(simple_fn)
model = utils.MockModel.create(responses=[])
agent = Agent(
name="agent",
model=model,
tools=[tool],
before_tool_callback=before_cb,
after_tool_callback=after_cb,
)
invocation_context = utils.create_invocation_context(
agent=agent, user_content=""
)
# Build function call event
function_call = types.FunctionCall(name=tool.name, args={})
content = types.Content(parts=[types.Part(function_call=function_call)])
event = Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
content=content,
)
tools_dict = {tool.name: tool}
return await handle_function_calls_async(
invocation_context,
event,
tools_dict,
)
@pytest.mark.asyncio
async def test_async_before_tool_callback():
mock_resp = {"test": "before_tool_callback"}
before_cb = AsyncBeforeToolCallback(mock_resp)
result_event = await invoke_tool_with_callbacks(before_cb=before_cb)
assert result_event is not None
part = result_event.content.parts[0]
assert part.function_response.response == mock_resp
@pytest.mark.asyncio
async def test_async_after_tool_callback():
mock_resp = {"test": "after_tool_callback"}
after_cb = AsyncAfterToolCallback(mock_resp)
result_event = await invoke_tool_with_callbacks(after_cb=after_cb)
assert result_event is not None
part = result_event.content.parts[0]
assert part.function_response.response == mock_resp

View File

@ -246,7 +246,7 @@ def test_function_get_auth_response():
oauth2=OAuth2Auth(
client_id='oauth_client_id_1',
client_secret='oauth_client_secret1',
token={'access_token': 'token1'},
access_token='token1',
),
),
)
@ -277,7 +277,7 @@ def test_function_get_auth_response():
oauth2=OAuth2Auth(
client_id='oauth_client_id_2',
client_secret='oauth_client_secret2',
token={'access_token': 'token2'},
access_token='token2',
),
),
)

View File

@ -14,10 +14,12 @@
import json
from unittest import mock
from fastapi.openapi.models import Operation
from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
from google.adk.tools.openapi_tool.openapi_spec_parser import rest_api_tool
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
import pytest
@ -50,6 +52,59 @@ def mock_openapi_toolset():
yield mock_toolset
def get_mocked_parsed_operation(operation_id, attributes):
mock_openapi_spec_parser_instance = mock.MagicMock()
mock_parsed_operation = mock.MagicMock(spec=ParsedOperation)
mock_parsed_operation.name = "list_issues"
mock_parsed_operation.description = "list_issues_description"
mock_parsed_operation.endpoint = OperationEndpoint(
base_url="http://localhost:8080",
path="/v1/issues",
method="GET",
)
mock_parsed_operation.auth_scheme = None
mock_parsed_operation.auth_credential = None
mock_parsed_operation.additional_context = {}
mock_parsed_operation.parameters = []
mock_operation = mock.MagicMock(spec=Operation)
mock_operation.operationId = operation_id
mock_operation.description = "list_issues_description"
mock_operation.parameters = []
mock_operation.requestBody = None
mock_operation.responses = {}
mock_operation.callbacks = {}
for key, value in attributes.items():
setattr(mock_operation, key, value)
mock_parsed_operation.operation = mock_operation
mock_openapi_spec_parser_instance.parse.return_value = [mock_parsed_operation]
return mock_openapi_spec_parser_instance
@pytest.fixture
def mock_openapi_entity_spec_parser():
with mock.patch(
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
) as mock_spec_parser:
mock_openapi_spec_parser_instance = get_mocked_parsed_operation(
"list_issues", {"x-entity": "Issues", "x-operation": "LIST_ENTITIES"}
)
mock_spec_parser.return_value = mock_openapi_spec_parser_instance
yield mock_spec_parser
@pytest.fixture
def mock_openapi_action_spec_parser():
with mock.patch(
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
) as mock_spec_parser:
mock_openapi_action_spec_parser_instance = get_mocked_parsed_operation(
"list_issues_operation",
{"x-action": "CustomAction", "x-operation": "EXECUTE_ACTION"},
)
mock_spec_parser.return_value = mock_openapi_action_spec_parser_instance
yield mock_spec_parser
@pytest.fixture
def project():
return "test-project"
@ -72,7 +127,11 @@ def connection_spec():
@pytest.fixture
def connection_details():
return {"serviceName": "test-service", "host": "test.host"}
return {
"serviceName": "test-service",
"host": "test.host",
"name": "test-connection",
}
def test_initialization_with_integration_and_trigger(
@ -102,7 +161,7 @@ def test_initialization_with_connection_and_entity_operations(
location,
mock_integration_client,
mock_connections_client,
mock_openapi_toolset,
mock_openapi_entity_spec_parser,
connection_details,
):
connection_name = "test-connection"
@ -133,19 +192,17 @@ def test_initialization_with_connection_and_entity_operations(
mock_connections_client.assert_called_once_with(
project, location, connection_name, None
)
mock_openapi_entity_spec_parser.return_value.parse.assert_called_once()
mock_connections_client.return_value.get_connection_details.assert_called_once()
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name,
tool_instructions
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
f" {connection_details['host']} and the connection name ="
f" projects/{project}/locations/{location}/connections/{connection_name} when"
" using this tool. DONOT ask the user for these values as you already"
" have those.",
tool_instructions,
)
mock_openapi_toolset.assert_called_once()
assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "Test Tool"
assert toolset.get_tools()[0].name == "list_issues"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].entity == "Issues"
assert toolset.get_tools()[0].operation == "LIST_ENTITIES"
def test_initialization_with_connection_and_actions(
@ -153,7 +210,7 @@ def test_initialization_with_connection_and_actions(
location,
mock_integration_client,
mock_connections_client,
mock_openapi_toolset,
mock_openapi_action_spec_parser,
connection_details,
):
connection_name = "test-connection"
@ -181,15 +238,13 @@ def test_initialization_with_connection_and_actions(
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name,
tool_instructions
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
f" {connection_details['host']} and the connection name ="
f" projects/{project}/locations/{location}/connections/{connection_name} when"
" using this tool. DONOT ask the user for these values as you already"
" have those.",
)
mock_openapi_toolset.assert_called_once()
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
assert len(toolset.get_tools()) == 1
assert toolset.get_tools()[0].name == "Test Tool"
assert toolset.get_tools()[0].name == "list_issues_operation"
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
assert toolset.get_tools()[0].action == "CustomAction"
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
def test_initialization_without_required_params(project, location):
@ -337,9 +392,4 @@ def test_initialization_with_connection_details(
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
tool_name,
tool_instructions
+ "ALWAYS use serviceName = custom-service, host = custom.host and the"
" connection name ="
" projects/test-project/locations/us-central1/connections/test-connection"
" when using this tool. DONOT ask the user for these values as you"
" already have those.",
)

View File

@ -0,0 +1,125 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
from google.genai.types import Tool
from google.genai.types import Type
import pytest
@pytest.fixture
def mock_rest_api_tool():
"""Fixture for a mocked RestApiTool."""
mock_tool = mock.MagicMock(spec=RestApiTool)
mock_tool.name = "mock_rest_tool"
mock_tool.description = "Mock REST tool description."
# Mock the internal parser needed for _get_declaration
mock_parser = mock.MagicMock()
mock_parser.get_json_schema.return_value = {
"type": "object",
"properties": {
"user_id": {"type": "string", "description": "User ID"},
"connection_name": {"type": "string"},
"host": {"type": "string"},
"service_name": {"type": "string"},
"entity": {"type": "string"},
"operation": {"type": "string"},
"action": {"type": "string"},
"page_size": {"type": "integer"},
"filter": {"type": "string"},
},
"required": ["user_id", "page_size", "filter", "connection_name"],
}
mock_tool._operation_parser = mock_parser
mock_tool.call.return_value = {"status": "success", "data": "mock_data"}
return mock_tool
@pytest.fixture
def integration_tool(mock_rest_api_tool):
"""Fixture for an IntegrationConnectorTool instance."""
return IntegrationConnectorTool(
name="test_integration_tool",
description="Test integration tool description.",
connection_name="test-conn",
connection_host="test.example.com",
connection_service_name="test-service",
entity="TestEntity",
operation="LIST",
action="TestAction",
rest_api_tool=mock_rest_api_tool,
)
def test_get_declaration(integration_tool):
"""Tests the generation of the function declaration."""
declaration = integration_tool._get_declaration()
assert isinstance(declaration, FunctionDeclaration)
assert declaration.name == "test_integration_tool"
assert declaration.description == "Test integration tool description."
# Check parameters schema
params = declaration.parameters
assert isinstance(params, Schema)
print(f"params: {params}")
assert params.type == Type.OBJECT
# Check properties (excluded fields should not be present)
assert "user_id" in params.properties
assert "connection_name" not in params.properties
assert "host" not in params.properties
assert "service_name" not in params.properties
assert "entity" not in params.properties
assert "operation" not in params.properties
assert "action" not in params.properties
assert "page_size" in params.properties
assert "filter" in params.properties
# Check required fields (optional and excluded fields should not be required)
assert "user_id" in params.required
assert "page_size" not in params.required
assert "filter" not in params.required
assert "connection_name" not in params.required
@pytest.mark.asyncio
async def test_run_async(integration_tool, mock_rest_api_tool):
"""Tests the async execution delegates correctly to the RestApiTool."""
input_args = {"user_id": "user123", "page_size": 10}
expected_call_args = {
"user_id": "user123",
"page_size": 10,
"connection_name": "test-conn",
"host": "test.example.com",
"service_name": "test-service",
"entity": "TestEntity",
"operation": "LIST",
"action": "TestAction",
}
result = await integration_tool.run_async(args=input_args, tool_context=None)
# Assert the underlying rest_api_tool.call was called correctly
mock_rest_api_tool.call.assert_called_once_with(
args=expected_call_args, tool_context=None
)
# Assert the result is what the mocked call returned
assert result == {"status": "success", "data": "mock_data"}

View File

@ -110,7 +110,7 @@ def test_generate_auth_token_success(
client_secret="test_secret",
redirect_uri="http://localhost:8080",
auth_response_uri="https://example.com/callback?code=test_code",
token={"access_token": "test_access_token"},
access_token="test_access_token",
),
)
updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
@ -131,7 +131,7 @@ def test_exchange_credential_generate_auth_token(
client_secret="test_secret",
redirect_uri="http://localhost:8080",
auth_response_uri="https://example.com/callback?code=test_code",
token={"access_token": "test_access_token"},
access_token="test_access_token",
),
)

View File

@ -164,6 +164,18 @@ def test_process_request_body_no_name():
assert parser.params[0].param_location == 'body'
def test_process_request_body_empty_object():
"""Test _process_request_body with a schema that is of type object but with no properties."""
operation = Operation(
requestBody=RequestBody(
content={'application/json': MediaType(schema=Schema(type='object'))}
)
)
parser = OperationParser(operation, should_parse=False)
parser._process_request_body()
assert len(parser.params) == 0
def test_dedupe_param_names(sample_operation):
"""Test _dedupe_param_names method."""
parser = OperationParser(sample_operation, should_parse=False)

View File

@ -14,11 +14,9 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from fastapi.openapi.models import MediaType
from fastapi.openapi.models import Operation
from fastapi.openapi.models import MediaType, Operation
from fastapi.openapi.models import Parameter as OpenAPIParameter
from fastapi.openapi.models import RequestBody
from fastapi.openapi.models import Schema as OpenAPISchema
@ -27,13 +25,13 @@ from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_cred
from google.adk.tools.openapi_tool.common.common import ApiParameter
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import snake_to_lower_camel
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import (
RestApiTool,
snake_to_lower_camel,
to_gemini_schema,
)
from google.adk.tools.tool_context import ToolContext
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
from google.genai.types import Type
from google.genai.types import FunctionDeclaration, Schema, Type
import pytest
@ -790,13 +788,13 @@ class TestToGeminiSchema:
result = to_gemini_schema({})
assert isinstance(result, Schema)
assert result.type == Type.OBJECT
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
assert result.properties is None
def test_to_gemini_schema_dict_with_only_object_type(self):
result = to_gemini_schema({"type": "object"})
assert isinstance(result, Schema)
assert result.type == Type.OBJECT
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
assert result.properties is None
def test_to_gemini_schema_basic_types(self):
openapi_schema = {
@ -814,6 +812,42 @@ class TestToGeminiSchema:
assert gemini_schema.properties["age"].type == Type.INTEGER
assert gemini_schema.properties["is_active"].type == Type.BOOLEAN
def test_to_gemini_schema_array_string_types(self):
openapi_schema = {
"type": "object",
"properties": {
"boolean_field": {"type": "boolean"},
"nonnullable_string": {"type": ["string"]},
"nullable_string": {"type": ["string", "null"]},
"nullable_number": {"type": ["null", "integer"]},
"object_nullable": {"type": "null"},
"multi_types_nullable": {"type": ["string", "null", "integer"]},
"empty_default_object": {},
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert isinstance(gemini_schema, Schema)
assert gemini_schema.type == Type.OBJECT
assert gemini_schema.properties["boolean_field"].type == Type.BOOLEAN
assert gemini_schema.properties["nonnullable_string"].type == Type.STRING
assert not gemini_schema.properties["nonnullable_string"].nullable
assert gemini_schema.properties["nullable_string"].type == Type.STRING
assert gemini_schema.properties["nullable_string"].nullable
assert gemini_schema.properties["nullable_number"].type == Type.INTEGER
assert gemini_schema.properties["nullable_number"].nullable
assert gemini_schema.properties["object_nullable"].type == Type.OBJECT
assert gemini_schema.properties["object_nullable"].nullable
assert gemini_schema.properties["multi_types_nullable"].type == Type.STRING
assert gemini_schema.properties["multi_types_nullable"].nullable
assert gemini_schema.properties["empty_default_object"].type == Type.OBJECT
assert not gemini_schema.properties["empty_default_object"].nullable
def test_to_gemini_schema_nested_objects(self):
openapi_schema = {
"type": "object",
@ -895,7 +929,15 @@ class TestToGeminiSchema:
def test_to_gemini_schema_nested_dict(self):
openapi_schema = {
"type": "object",
"properties": {"metadata": {"key1": "value1", "key2": 123}},
"properties": {
"metadata": {
"type": "object",
"properties": {
"key1": {"type": "object"},
"key2": {"type": "string"},
},
}
},
}
gemini_schema = to_gemini_schema(openapi_schema)
# Since metadata is not properties nor item, it will call to_gemini_schema recursively.
@ -903,9 +945,15 @@ class TestToGeminiSchema:
assert (
gemini_schema.properties["metadata"].type == Type.OBJECT
) # add object type by default
assert gemini_schema.properties["metadata"].properties == {
"dummy_DO_NOT_GENERATE": Schema(type="string")
}
assert len(gemini_schema.properties["metadata"].properties) == 2
assert (
gemini_schema.properties["metadata"].properties["key1"].type
== Type.OBJECT
)
assert (
gemini_schema.properties["metadata"].properties["key2"].type
== Type.STRING
)
def test_to_gemini_schema_ignore_title_default_format(self):
openapi_schema = {

View File

@ -0,0 +1,238 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
from google.adk.tools.function_tool import FunctionTool
import pytest
def function_for_testing_with_no_args():
"""Function for testing with no args."""
pass
async def async_function_for_testing_with_1_arg_and_tool_context(
arg1, tool_context
):
"""Async function for testing with 1 arge and tool context."""
assert arg1
assert tool_context
return arg1
async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
"""Async function for testing with 2 arge and no tool context."""
assert arg1
assert arg2
return arg1
def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
"""Function for testing with 1 arge and tool context."""
assert arg1
assert tool_context
return arg1
def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
"""Function for testing with 2 arge and no tool context."""
assert arg1
assert arg2
return arg1
async def async_function_for_testing_with_4_arg_and_no_tool_context(
arg1, arg2, arg3, arg4
):
"""Async function for testing with 4 args."""
pass
def function_for_testing_with_4_arg_and_no_tool_context(arg1, arg2, arg3, arg4):
"""Function for testing with 4 args."""
pass
def test_init():
"""Test that the FunctionTool is initialized correctly."""
tool = FunctionTool(function_for_testing_with_no_args)
assert tool.name == "function_for_testing_with_no_args"
assert tool.description == "Function for testing with no args."
assert tool.func == function_for_testing_with_no_args
@pytest.mark.asyncio
async def test_run_async_with_tool_context_async_func():
"""Test that run_async calls the function with tool_context when tool_context is in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_1_arg_and_tool_context)
args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_without_tool_context_async_func():
"""Test that run_async calls the function without tool_context when tool_context is not in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_with_tool_context_sync_func():
"""Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_1_arg_and_tool_context)
args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_without_tool_context_sync_func():
"""Test that run_async calls the function without tool_context when tool_context is not in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_1_missing_arg_sync_func():
"""Test that run_async calls the function with 1 missing arg in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg2
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_1_missing_arg_async_func():
"""Test that run_async calls the function with 1 missing arg in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg2": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_3_missing_arg_sync_func():
"""Test that run_async calls the function with 3 missing args in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
args = {"arg2": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_3_missing_arg_async_func():
"""Test that run_async calls the function with 3 missing args in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
args = {"arg3": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_missing_all_arg_sync_func():
"""Test that run_async calls the function with all missing args in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
args = {}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_missing_all_arg_async_func():
"""Test that run_async calls the function with all missing args in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
args = {}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_with_optional_args_not_set_sync_func():
"""Test that run_async calls the function for sync funciton with optional args not set."""
def func_with_optional_args(arg1, arg2=None, *, arg3, arg4=None, **kwargs):
return f"{arg1},{arg3}"
tool = FunctionTool(func_with_optional_args)
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1,test_value_3"
@pytest.mark.asyncio
async def test_run_async_with_optional_args_not_set_async_func():
"""Test that run_async calls the function for async funciton with optional args not set."""
async def async_func_with_optional_args(
arg1, arg2=None, *, arg3, arg4=None, **kwargs
):
return f"{arg1},{arg3}"
tool = FunctionTool(async_func_with_optional_args)
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1,test_value_3"