Merge branch 'main' into #247-OpenAPIToolSet-Considering-Required-parameters

commit f12300113d
Wei Sun (Jack) authored 2025-05-01 18:47:28 -07:00, committed by GitHub
58 changed files with 1579 additions and 422 deletions

.github/workflows/pyink.yml (new file)

@@ -0,0 +1,59 @@
name: Check Pyink Formatting

on:
  pull_request:
    paths:
      - 'src/**/*.py'
      - 'tests/**/*.py'
      - 'pyproject.toml'

jobs:
  pyink-check:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.x'

      - name: Install pyink
        run: |
          pip install pyink

      - name: Detect changed Python files
        id: detect_changes
        run: |
          git fetch origin ${{ github.base_ref }}
          CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true)
          echo "CHANGED_FILES=${CHANGED_FILES}" >> $GITHUB_ENV

      - name: Run pyink on changed files
        if: env.CHANGED_FILES != ''
        run: |
          echo "Changed Python files:"
          echo "$CHANGED_FILES"
          # Run pyink --check
          set +e
          pyink --check --config pyproject.toml $CHANGED_FILES
          RESULT=$?
          set -e
          if [ $RESULT -ne 0 ]; then
            echo ""
            echo "❌ Pyink formatting check failed!"
            echo "👉 To fix formatting, run locally:"
            echo ""
            echo "  pyink --config pyproject.toml $CHANGED_FILES"
            echo ""
            exit $RESULT
          fi

      - name: No changed Python files detected
        if: env.CHANGED_FILES == ''
        run: |
          echo "No Python files changed. Skipping pyink check."


@@ -29,11 +29,13 @@ jobs:
      run: |
        uv venv .venv
        source .venv/bin/activate
-        uv sync --extra test
+        uv sync --extra test --extra eval
    - name: Run unit tests with pytest
      run: |
        source .venv/bin/activate
        pytest tests/unittests \
-          --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py \
-          --ignore=tests/unittests/artifacts/test_artifact_service.py
+          --ignore=tests/unittests/artifacts/test_artifact_service.py \
+          --ignore=tests/unittests/tools/application_integration_tool/clients/test_connections_client.py \
+          --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py


@@ -1,17 +1,89 @@
# Changelog
## 0.3.0
### ⚠ BREAKING CHANGES
* Auth: expose `access_token` and `refresh_token` at top level of auth
  credentials, instead of a `dict`
  ([commit](https://github.com/google/adk-python/commit/956fb912e8851b139668b1ccb8db10fd252a6990)).
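  A minimal migration sketch (the `OAuth2Auth` field names come from this commit's diff; the surrounding function is illustrative):

```python
from typing import Optional, Tuple

from google.adk.auth import AuthCredential  # assumed import path


def extract_tokens(credential: AuthCredential) -> Tuple[Optional[str], Optional[str]]:
  # Before 0.3.0 the tokens lived in one dict: credential.oauth2.token["access_token"].
  # From 0.3.0 on they are typed, top-level fields on OAuth2Auth.
  return credential.oauth2.access_token, credential.oauth2.refresh_token
```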
### Features
* Added support for running agents with MCPToolset easily on `adk web`.
* Added `custom_metadata` field to `LlmResponse`, which can be used to tag
LlmResponse via `after_model_callback`.
* Added `--session_db_url` option to `adk deploy cloud_run`.
* Many Dev UI improvements:
* Better google search result rendering.
* Show websocket close reason in Dev UI.
* Better error message showing for audio/video.
### Bug Fixes
* Fixed MCP tool json schema parsing issue.
* Fixed issues in DatabaseSessionService that led to crashes.
* Fixed functions.py.
* Fixed `skip_summarization` behavior in `AgentTool`.
### Miscellaneous Chores
* README.md improvements.
* Various code improvements.
* Various typo fixes.
* Bump min version of google-genai to 1.11.0.
## 0.2.0
### ⚠ BREAKING CHANGES
* Fix typo in method name in `Event`: has_trailing_code_exeuction_result --> has_trailing_code_execution_result.
### Features
* `adk` CLI:
* Introduce `adk create` cli tool to help creating agents.
* Adds `--verbosity` option to `adk deploy cloud_run` to show detailed cloud
run deploy logging.
* Improve the initialization error message for `DatabaseSessionService`.
* Lazy loading for Google 1P tools to minimize the initial latency.
* Support emitting state-change-only events from planners.
* Lots of Dev UI updates, including:
* Show planner thoughts and actions in the Dev UI.
* Support MCP tools in Dev UI.
(NOTE: the `agent.py` interface is a temporary solution and is subject to change)
* Auto-select the only app if only one app is available.
* Show grounding links generated by Google Search Tool.
* `.env` file is reloaded on every agent run.
### Bug Fixes
* `LiteLlm`: arg parsing error and Python 3.9 compatibility.
* `DatabaseSessionService`: adds the missing fields; fixes event with empty
content not being persisted.
* Google API Discovery response parsing issue.
* `load_memory_tool` rendering issue in Dev UI.
* Markdown text overflows in Dev UI.
### Miscellaneous Chores
* Adds unit tests in GitHub action.
* Improves test coverage.
* Various typo fixes.
## 0.1.0
### Features
* Initial release of the Agent Development Kit (ADK).
* Multi-agent, agent-as-workflow, and custom agent support
* Tool authentication support
-* Rich tool support, e.g. bult-in tools, google-cloud tools, thir-party tools, and MCP tools
+* Rich tool support, e.g. built-in tools, google-cloud tools, third-party tools, and MCP tools
* Rich callback support
* Built-in code execution capability
* Asynchronous runtime and execution
* Session and memory support
* Built-in evaluation support
-* Development UI that makes local devlopment easy
+* Development UI that makes local development easy
* Deploy to Google Cloud Run, Agent Engine
* (Experimental) Live (Bidi) audio/video agent support and Compositional Function Calling (CFC) support


@@ -25,6 +25,19 @@ This project follows
## Contribution process

+### Requirements for PRs
+
+- All PRs, other than small documentation or typo fixes, should have an associated Issue. If not, please create one.
+- Small, focused PRs. Keep changes minimal: one concern per PR.
+- For bug fixes or features, please provide logs or screenshots after the fix is applied to help reviewers better understand the fix.
+- Please add corresponding tests for your code change if it's not covered by existing tests.
+
+### Large or Complex Changes
+
+For substantial features or architectural revisions:
+
+- Open an Issue First: Outline your proposal, including design considerations and impact.
+- Gather Feedback: Discuss with maintainers and the community to ensure alignment and avoid duplicate work.

### Code reviews

All submissions, including submissions by project members, require review. We


@@ -18,11 +18,7 @@
</h3>
</html>

-Agent Development Kit (ADK) is designed for developers seeking fine-grained
-control and flexibility when building advanced AI agents that are tightly
-integrated with services in Google Cloud. It allows you to define agent
-behavior, orchestration, and tool use directly in code, enabling robust
-debugging, versioning, and deployment anywhere from your laptop to the cloud.
+Agent Development Kit (ADK) is a flexible and modular framework for developing and deploying AI agents. While optimized for Gemini and the Google ecosystem, ADK is model-agnostic, deployment-agnostic, and is built for compatibility with other frameworks. ADK was designed to make agent development feel more like software development, to make it easier for developers to create, deploy, and orchestrate agentic architectures that range from simple tasks to complex workflows.

---

@@ -45,12 +41,27 @@ debugging, versioning, and deployment anywhere from your laptop to the cloud
## 🚀 Installation

-You can install the ADK using `pip`:
+### Stable Release (Recommended)
+
+You can install the latest stable version of ADK using `pip`:

```bash
pip install google-adk
```

+The release cadence is weekly.
+
+This version is recommended for most users as it represents the most recent official release.
+
+### Development Version
+
+Bug fixes and new features are merged into the main branch on GitHub first. If you need access to changes that haven't been included in an official PyPI release yet, you can install directly from the main branch:
+
+```bash
+pip install git+https://github.com/google/adk-python.git@main
+```
+
+Note: The development version is built directly from the latest code commits. While it includes the newest fixes and features, it may also contain experimental changes or bugs not present in the stable release. Use it primarily for testing upcoming changes or accessing critical fixes before they are officially released.

## 📚 Documentation

Explore the full documentation for detailed guides on building, evaluating, and

@@ -112,10 +123,18 @@ adk eval \
    samples_for_testing/hello_world/hello_world_eval_set_001.evalset.json
```

+## 🤖 A2A and ADK integration
+
+For remote agent-to-agent communication, ADK integrates with the
+[A2A protocol](https://github.com/google/A2A/).
+See this [example](https://github.com/google/A2A/tree/main/samples/python/agents/google_adk)
+for how they can work together.

## 🤝 Contributing

-We welcome contributions from the community! Whether it's bug reports, feature requests, documentation improvements, or code contributions, please see our [**Contributing Guidelines**](./CONTRIBUTING.md) to get started.
+We welcome contributions from the community! Whether it's bug reports, feature requests, documentation improvements, or code contributions, please see our
+- [General contribution guideline and flow](https://google.github.io/adk-docs/contributing-guide/#questions).
+- Then if you want to contribute code, please read [Code Contributing Guidelines](./CONTRIBUTING.md) to get started.

## 📄 License


@@ -45,7 +45,7 @@ confidence=
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
-# disable everything first and then reenable specific checks. For example, if
+# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes


@@ -33,7 +33,7 @@ dependencies = [
  "google-cloud-secret-manager>=2.22.0",  # Fetching secrets in RestAPI Tool
  "google-cloud-speech>=2.30.0",  # For Audio Transcription
  "google-cloud-storage>=2.18.0, <3.0.0",  # For GCS Artifact service
-  "google-genai>=1.9.0",  # Google GenAI SDK
+  "google-genai>=1.11.0",  # Google GenAI SDK
  "graphviz>=0.20.2",  # Graphviz for graph rendering
  "mcp>=1.5.0;python_version>='3.10'",  # For MCP Toolset
  "opentelemetry-api>=1.31.0",  # OpenTelemetry

@@ -119,6 +119,15 @@ line-length = 80
unstable = true
pyink-indentation = 2
pyink-use-majority-quotes = true
+pyink-annotation-pragmas = [
+    "noqa",
+    "pylint:",
+    "type: ignore",
+    "pytype:",
+    "mypy:",
+    "pyright:",
+    "pyre-",
+]

[build-system]

@@ -135,15 +144,10 @@ exclude = ['src/**/*.sh']
name = "google.adk"

[tool.isort]
-# Organize imports following Google style-guide
-force_single_line = true
-force_sort_within_sections = true
-honor_case_in_force_sorted_sections = true
-order_by_type = false
-sort_relative_in_force_sorted_sections = true
-multi_line_output = 3
-line_length = 200
+profile = "google"

[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_default_fixture_loop_scope = "function"
+asyncio_mode = "auto"


@@ -44,7 +44,7 @@ Args:
    callback_context: MUST be named 'callback_context' (enforced).

  Returns:
-    The content to return to the user. When set, the agent run will skipped and
+    The content to return to the user. When set, the agent run will be skipped and
    the provided content will be returned to user.
  """
@@ -55,8 +55,8 @@ Args:
    callback_context: MUST be named 'callback_context' (enforced).

  Returns:
-    The content to return to the user. When set, the agent run will skipped and
-    the provided content will be appended to event history as agent response.
+    The content to return to the user. When set, the provided content will be
+    appended to event history as agent response.
  """
@@ -101,8 +101,8 @@ class BaseAgent(BaseModel):
    callback_context: MUST be named 'callback_context' (enforced).

  Returns:
-    The content to return to the user. When set, the agent run will skipped and
-    the provided content will be returned to user.
+    The content to return to the user. When set, the agent run will be skipped
+    and the provided content will be returned to user.
  """

  after_agent_callback: Optional[AfterAgentCallback] = None
  """Callback signature that is invoked after the agent run.
@@ -111,8 +111,8 @@ class BaseAgent(BaseModel):
    callback_context: MUST be named 'callback_context' (enforced).

  Returns:
-    The content to return to the user. When set, the agent run will skipped and
-    the provided content will be appended to event history as agent response.
+    The content to return to the user. When set, the provided content will be
+    appended to event history as agent response.
  """

  @final


@@ -23,7 +23,6 @@ from .readonly_context import ReadonlyContext
if TYPE_CHECKING:
  from google.genai import types

-  from ..events.event import Event
  from ..events.event_actions import EventActions
  from ..sessions.state import State
  from .invocation_context import InvocationContext


@@ -15,12 +15,7 @@
from __future__ import annotations

import logging
-from typing import Any
-from typing import AsyncGenerator
-from typing import Callable
-from typing import Literal
-from typing import Optional
-from typing import Union
+from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union

from google.genai import types
from pydantic import BaseModel
@@ -62,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[
]

BeforeToolCallback: TypeAlias = Callable[
    [BaseTool, dict[str, Any], ToolContext],
-    Optional[dict],
+    Union[Awaitable[Optional[dict]], Optional[dict]],
]
AfterToolCallback: TypeAlias = Callable[
    [BaseTool, dict[str, Any], ToolContext, dict],
-    Optional[dict],
+    Union[Awaitable[Optional[dict]], Optional[dict]],
]

InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
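
Since both tool callbacks may now return an `Awaitable`, an `async def` callback can be passed directly. A minimal sketch, with the import path and agent wiring assumed rather than taken from this diff:

```python
from typing import Any, Optional

from google.adk.tools import BaseTool, ToolContext  # assumed import path


async def log_before_tool(
    tool: BaseTool, args: dict[str, Any], tool_context: ToolContext
) -> Optional[dict]:
  # Returning None lets the tool run; returning a dict short-circuits the
  # tool and is used as the tool response instead.
  print(f'about to call {tool.name} with {args}')
  return None

# agent = LlmAgent(..., before_tool_callback=log_before_tool)  # sync callbacks still work too
```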


@@ -66,7 +66,8 @@ class OAuth2Auth(BaseModelWithConfig):
  redirect_uri: Optional[str] = None
  auth_response_uri: Optional[str] = None
  auth_code: Optional[str] = None
-  token: Optional[Dict[str, Any]] = None
+  access_token: Optional[str] = None
+  refresh_token: Optional[str] = None


class ServiceAccountCredential(BaseModelWithConfig):


@@ -82,7 +82,8 @@ class AuthHandler:
        or not auth_credential.oauth2
        or not auth_credential.oauth2.client_id
        or not auth_credential.oauth2.client_secret
-        or auth_credential.oauth2.token
+        or auth_credential.oauth2.access_token
+        or auth_credential.oauth2.refresh_token
    ):
      return self.auth_config.exchanged_auth_credential
@@ -93,7 +94,7 @@ class AuthHandler:
        redirect_uri=auth_credential.oauth2.redirect_uri,
        state=auth_credential.oauth2.state,
    )
-    token = client.fetch_token(
+    tokens = client.fetch_token(
        token_endpoint,
        authorization_response=auth_credential.oauth2.auth_response_uri,
        code=auth_credential.oauth2.auth_code,
@@ -102,7 +103,10 @@ class AuthHandler:
    updated_credential = AuthCredential(
        auth_type=AuthCredentialTypes.OAUTH2,
-        oauth2=OAuth2Auth(token=dict(token)),
+        oauth2=OAuth2Auth(
+            access_token=tokens.get("access_token"),
+            refresh_token=tokens.get("refresh_token"),
+        ),
    )
    return updated_credential


@@ -29,5 +29,5 @@
<style>…minified Material theme CSS, unchanged by this commit…</style><link rel="stylesheet" href="styles-4VDSPQ37.css" media="print" onload="this.media='all'"><noscript><link rel="stylesheet" href="styles-4VDSPQ37.css"></noscript></head>
<body>
<app-root></app-root>
-<script src="polyfills-FFHMD2TL.js" type="module"></script><script src="main-ZBO76GRM.js" type="module"></script></body>
+<script src="polyfills-FFHMD2TL.js" type="module"></script><script src="main-HWIBUY2R.js" type="module"></script></body>
</html>

File diff suppressed because one or more lines are too long


@@ -39,12 +39,12 @@ class InputFile(BaseModel):
async def run_input_file(
    app_name: str,
+    user_id: str,
    root_agent: LlmAgent,
    artifact_service: BaseArtifactService,
-    session: Session,
    session_service: BaseSessionService,
    input_path: str,
-) -> None:
+) -> Session:
  runner = Runner(
      app_name=app_name,
      agent=root_agent,
@@ -55,9 +55,11 @@ async def run_input_file(
    input_file = InputFile.model_validate_json(f.read())
  input_file.state['_time'] = datetime.now()

-  session.state = input_file.state
+  session = session_service.create_session(
+      app_name=app_name, user_id=user_id, state=input_file.state
+  )
  for query in input_file.queries:
-    click.echo(f'user: {query}')
+    click.echo(f'[user]: {query}')
    content = types.Content(role='user', parts=[types.Part(text=query)])
    async for event in runner.run_async(
        user_id=session.user_id, session_id=session.id, new_message=content
@@ -65,23 +67,23 @@ async def run_input_file(
      if event.content and event.content.parts:
        if text := ''.join(part.text or '' for part in event.content.parts):
          click.echo(f'[{event.author}]: {text}')
+  return session


async def run_interactively(
-    app_name: str,
    root_agent: LlmAgent,
    artifact_service: BaseArtifactService,
    session: Session,
    session_service: BaseSessionService,
) -> None:
  runner = Runner(
-      app_name=app_name,
+      app_name=session.app_name,
      agent=root_agent,
      artifact_service=artifact_service,
      session_service=session_service,
  )
  while True:
-    query = input('user: ')
+    query = input('[user]: ')
    if not query or not query.strip():
      continue
    if query == 'exit':
@@ -100,7 +102,8 @@ async def run_cli(
    *,
    agent_parent_dir: str,
    agent_folder_name: str,
-    json_file_path: Optional[str] = None,
+    input_file: Optional[str] = None,
+    saved_session_file: Optional[str] = None,
    save_session: bool,
) -> None:
  """Runs an interactive CLI for a certain agent.
@@ -109,8 +112,11 @@ async def run_cli(
    agent_parent_dir: str, the absolute path of the parent folder of the agent
      folder.
    agent_folder_name: str, the name of the agent folder.
-    json_file_path: Optional[str], the absolute path to the json file, either
-      *.input.json or *.session.json.
+    input_file: Optional[str], the absolute path to the json file that contains
+      the initial session state and user queries, exclusive with
+      saved_session_file.
+    saved_session_file: Optional[str], the absolute path to the json file that
+      contains a previously saved session, exclusive with input_file.
    save_session: bool, whether to save the session on exit.
  """
  if agent_parent_dir not in sys.path:
@@ -118,46 +124,50 @@ async def run_cli(
  artifact_service = InMemoryArtifactService()
  session_service = InMemorySessionService()
-  session = session_service.create_session(
-      app_name=agent_folder_name, user_id='test_user'
-  )
  agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
  agent_module = importlib.import_module(agent_folder_name)
+  user_id = 'test_user'
+  session = session_service.create_session(
+      app_name=agent_folder_name, user_id=user_id
+  )
  root_agent = agent_module.agent.root_agent
  envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
-  if json_file_path:
-    if json_file_path.endswith('.input.json'):
-      await run_input_file(
-          app_name=agent_folder_name,
-          root_agent=root_agent,
-          artifact_service=artifact_service,
-          session=session,
-          session_service=session_service,
-          input_path=json_file_path,
-      )
-    elif json_file_path.endswith('.session.json'):
-      with open(json_file_path, 'r') as f:
-        session = Session.model_validate_json(f.read())
-      for content in session.get_contents():
-        if content.role == 'user':
-          print('user: ', content.parts[0].text)
-        else:
-          print(content.parts[0].text)
-      await run_interactively(
-          agent_folder_name,
-          root_agent,
-          artifact_service,
-          session,
-          session_service,
-      )
-    else:
-      print(f'Unsupported file type: {json_file_path}')
-      exit(1)
-  else:
-    print(f'Running agent {root_agent.name}, type exit to exit.')
-    await run_interactively(
-        agent_folder_name,
-        root_agent,
-        artifact_service,
-        session,
+  if input_file:
+    session = await run_input_file(
+        app_name=agent_folder_name,
+        user_id=user_id,
+        root_agent=root_agent,
+        artifact_service=artifact_service,
+        session_service=session_service,
+        input_path=input_file,
+    )
+  elif saved_session_file:
+    loaded_session = None
+    with open(saved_session_file, 'r') as f:
+      loaded_session = Session.model_validate_json(f.read())
+    if loaded_session:
+      for event in loaded_session.events:
+        session_service.append_event(session, event)
+        content = event.content
+        if not content or not content.parts or not content.parts[0].text:
+          continue
+        if event.author == 'user':
+          click.echo(f'[user]: {content.parts[0].text}')
+        else:
+          click.echo(f'[{event.author}]: {content.parts[0].text}')
+    await run_interactively(
+        root_agent,
+        artifact_service,
+        session,
+        session_service,
+    )
+  else:
+    click.echo(f'Running agent {root_agent.name}, type exit to exit.')
+    await run_interactively(
+        root_agent,
+        artifact_service,
+        session,
@@ -165,11 +175,8 @@ async def run_cli(
    )

  if save_session:
-    if json_file_path:
-      session_path = json_file_path.replace('.input.json', '.session.json')
-    else:
-      session_id = input('Session ID to save: ')
-      session_path = f'{agent_module_path}/{session_id}.session.json'
+    session_id = input('Session ID to save: ')
+    session_path = f'{agent_module_path}/{session_id}.session.json'
    # Fetch the session again to get all the details.
    session = session_service.get_session(


@@ -54,7 +54,7 @@ COPY "agents/{app_name}/" "/app/agents/{app_name}/"
EXPOSE {port}

-CMD adk {command} --port={port} {trace_to_cloud_option} "/app/agents"
+CMD adk {command} --port={port} {session_db_option} {trace_to_cloud_option} "/app/agents"
"""
@@ -85,6 +85,7 @@ def to_cloud_run(
    trace_to_cloud: bool,
    with_ui: bool,
    verbosity: str,
+    session_db_url: str,
):
  """Deploys an agent to Google Cloud Run.
@@ -112,6 +113,7 @@ def to_cloud_run(
    trace_to_cloud: Whether to enable Cloud Trace.
    with_ui: Whether to deploy with UI.
    verbosity: The verbosity level of the CLI.
+    session_db_url: The database URL to connect the session.
  """
  app_name = app_name or os.path.basename(agent_folder)
@@ -144,6 +146,9 @@ def to_cloud_run(
        port=port,
        command='web' if with_ui else 'api_server',
        install_agent_deps=install_agent_deps,
+        session_db_option=f'--session_db_url={session_db_url}'
+        if session_db_url
+        else '',
        trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
    )

  dockerfile_path = os.path.join(temp_folder, 'Dockerfile')


@@ -256,7 +256,7 @@ def run_evals(
        )

        if final_eval_status == EvalStatus.PASSED:
-          result = "✅ Passsed"
+          result = "✅ Passed"
        else:
          result = "❌ Failed"


@@ -96,6 +96,23 @@ def cli_create_cmd(
)


def validate_exclusive(ctx, param, value):
  # Store the validated parameters in the context
  if not hasattr(ctx, "exclusive_opts"):
    ctx.exclusive_opts = {}
  # If this option has a value and we've already seen another exclusive option
  if value is not None and any(ctx.exclusive_opts.values()):
    exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
    raise click.UsageError(
        f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
    )
  # Record this option's value
  ctx.exclusive_opts[param.name] = value is not None
  return value
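
In effect, the `--replay` and `--resume` options defined below become mutually exclusive. A hypothetical invocation, with placeholder paths:

```bash
# OK: replay recorded queries against a fresh session.
adk run --replay ./my_agent/demo.input.json ./my_agent

# OK: re-display a previously saved session and continue it.
adk run --resume ./my_agent/demo.session.json ./my_agent

# Rejected by validate_exclusive with an error like:
#   "Options 'resume' and 'replay' cannot be set together."
adk run --replay ./my_agent/demo.input.json --resume ./my_agent/demo.session.json ./my_agent
```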
@main.command("run")
@click.option(
    "--save_session",
@@ -105,13 +122,43 @@ def cli_create_cmd(
    default=False,
    help="Optional. Whether to save the session to a json file on exit.",
)
@click.option(
    "--replay",
    type=click.Path(
        exists=True, dir_okay=False, file_okay=True, resolve_path=True
    ),
    help=(
        "The json file that contains the initial state of the session and user"
        " queries. A new session will be created using this state. And user"
        " queries are run against the newly created session. Users cannot"
        " continue to interact with the agent."
    ),
    callback=validate_exclusive,
)
@click.option(
    "--resume",
    type=click.Path(
        exists=True, dir_okay=False, file_okay=True, resolve_path=True
    ),
    help=(
        "The json file that contains a previously saved session (by the"
        " --save_session option). The previous session will be re-displayed."
        " And the user can continue to interact with the agent."
    ),
    callback=validate_exclusive,
)
@click.argument(
    "agent",
    type=click.Path(
        exists=True, dir_okay=True, file_okay=False, resolve_path=True
    ),
)
-def cli_run(agent: str, save_session: bool):
+def cli_run(
+    agent: str,
+    save_session: bool,
+    replay: Optional[str],
+    resume: Optional[str],
+):
  """Runs an interactive CLI for a certain agent.

  AGENT: The path to the agent source code folder.
@@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
      run_cli(
          agent_parent_dir=agent_parent_folder,
          agent_folder_name=agent_folder_name,
+          input_file=replay,
+          saved_session_file=resume,
          save_session=save_session,
      )
  )
@@ -245,12 +294,13 @@ def cli_eval(
@click.option(
    "--session_db_url",
    help=(
-        "Optional. The database URL to store the session.\n\n - Use"
-        " 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
-        " managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
-        " to connect to a SQLite DB.\n\n - See"
-        " https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
-        " for more details on supported DB URLs."
+        """Optional. The database URL to store the session.
+
+        - Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
+
+        - Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
+
+        - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
    ),
)
@@ -366,12 +416,13 @@ def cli_web(
@click.option(
    "--session_db_url",
    help=(
-        "Optional. The database URL to store the session.\n\n - Use"
-        " 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
-        " managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
-        " to connect to a SQLite DB.\n\n - See"
-        " https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
-        " for more details on supported DB URLs."
+        """Optional. The database URL to store the session.
+
+        - Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
+
+        - Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
+
+        - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
    ),
)
@click.option(
@@ -541,6 +592,18 @@ def cli_api_server(
    default="WARNING",
    help="Optional. Override the default verbosity level.",
)
+@click.option(
+    "--session_db_url",
+    help=(
+        """Optional. The database URL to store the session.
+
+        - Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
+
+        - Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
+
+        - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
+    ),
+)
@click.argument(
    "agent",
    type=click.Path(
@@ -558,6 +621,7 @@ def cli_deploy_cloud_run(
    trace_to_cloud: bool,
    with_ui: bool,
    verbosity: str,
+    session_db_url: str,
):
  """Deploys an agent to Cloud Run.
@@ -579,6 +643,7 @@ def cli_deploy_cloud_run(
        trace_to_cloud=trace_to_cloud,
        with_ui=with_ui,
        verbosity=verbosity,
+        session_db_url=session_db_url,
    )
  except Exception as e:
    click.secho(f"Deploy failed: {e}", fg="red", err=True)


@@ -756,6 +756,12 @@ def get_fast_api_app(
      except Exception as e:
        logger.exception("Error during live websocket communication: %s", e)
        traceback.print_exc()
+        WEBSOCKET_INTERNAL_ERROR_CODE = 1011
+        WEBSOCKET_MAX_BYTES_FOR_REASON = 123
+        await websocket.close(
+            code=WEBSOCKET_INTERNAL_ERROR_CODE,
+            reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
+        )
    finally:
      for task in pending:
        task.cancel()


@@ -55,7 +55,7 @@ def load_json(file_path: str) -> Union[Dict, List]:
class AgentEvaluator:
-  """An evaluator for Agents, mainly intented for helping with test cases."""
+  """An evaluator for Agents, mainly intended for helping with test cases."""

  @staticmethod
  def find_config_for_test_file(test_file: str):
@@ -91,7 +91,7 @@ class AgentEvaluator:
        look for 'root_agent' in the loaded module.
      eval_dataset: The eval data set. This can be either a string representing
        full path to the file containing eval dataset, or a directory that is
-        recusively explored for all files that have a `.test.json` suffix.
+        recursively explored for all files that have a `.test.json` suffix.
      num_runs: Number of times all entries in the eval dataset should be
        assessed.
      agent_name: The name of the agent.


@@ -35,7 +35,7 @@ class ResponseEvaluator:
    Args:
      raw_eval_dataset: The dataset that will be evaluated.
      evaluation_criteria: The evaluation criteria to be used. This method
-        support two criterias, `response_evaluation_score` and
+        support two criteria, `response_evaluation_score` and
        `response_match_score`.
      print_detailed_results: Prints detailed results on the console. This is
        usually helpful during debugging.
@@ -56,7 +56,7 @@ class ResponseEvaluator:
      Value range: [0, 5], where 0 means that the agent's response is not
      coherent, while 5 means it is. High values are good.
    A note on raw_eval_dataset:
-      The dataset should be a list session, where each sesssion is represented
+      The dataset should be a list session, where each session is represented
      as a list of interaction that need evaluation. Each evaluation is
      represented as a dictionary that is expected to have values for the
      following keys:


@@ -31,10 +31,9 @@ class TrajectoryEvaluator:
  ):
    r"""Returns the mean tool use accuracy of the eval dataset.

-    Tool use accuracy is calculated by comparing the expected and actuall tool
-    use trajectories. An exact match scores a 1, 0 otherwise. The final number
-    is an
-    average of these individual scores.
+    Tool use accuracy is calculated by comparing the expected and the actual
+    tool use trajectories. An exact match scores a 1, 0 otherwise. The final
+    number is an average of these individual scores.

    Value range: [0, 1], where 0 means none of the tool use entries aligned,
    and 1 would mean all of them aligned. Higher value is good.
@@ -45,7 +44,7 @@ class TrajectoryEvaluator:
      usually helpful during debugging.

    A note on eval_dataset:
-      The dataset should be a list session, where each sesssion is represented
+      The dataset should be a list session, where each session is represented
      as a list of interaction that need evaluation. Each evaluation is
      represented as a dictionary that is expected to have values for the
      following keys:


@@ -48,8 +48,13 @@ class EventActions(BaseModel):
  """The agent is escalating to a higher level agent."""

  requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
-  """Will only be set by a tool response indicating tool request euc.
-  dict key is the function call id since one function call response (from model)
-  could correspond to multiple function calls.
-  dict value is the required auth config.
+  """Authentication configurations requested by tool responses.
+
+  This field will only be set by a tool response event indicating tool request
+  auth credential.
+  - Keys: The function call id. Since one function response event could contain
+    multiple function responses that correspond to multiple function calls. Each
+    function call could request different auth configs. This id is used to
+    identify the function call.
+  - Values: The requested auth config.
  """


@@ -94,7 +94,7 @@ can answer it.
If another agent is better for answering the question according to its
description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
-question to that agent. When transfering, do not generate any text other than
+question to that agent. When transferring, do not generate any text other than
the function call.
"""


@@ -115,7 +115,7 @@ class BaseLlmFlow(ABC):
        yield event
      # send back the function response
      if event.get_function_responses():
-        logger.debug('Sending back last function resonse event: %s', event)
+        logger.debug('Sending back last function response event: %s', event)
        invocation_context.live_request_queue.send_content(event.content)
      if (
          event.content


@@ -111,7 +111,7 @@ def _rearrange_events_for_latest_function_response(
  """Rearrange the events for the latest function_response.

  If the latest function_response is for an async function_call, all events
-  bewteen the initial function_call and the latest function_response will be
+  between the initial function_call and the latest function_response will be
  removed.

  Args:


@@ -151,28 +151,33 @@ async def handle_function_calls_async(
    # do not use "args" as the variable name, because it is a reserved keyword
    # in python debugger.
    function_args = function_call.args or {}
-    function_response = None
-    # Calls the tool if before_tool_callback does not exist or returns None.
+    function_response: Optional[dict] = None
+
+    # before_tool_callback (sync or async)
    if agent.before_tool_callback:
      function_response = agent.before_tool_callback(
          tool=tool, args=function_args, tool_context=tool_context
      )
+      if inspect.isawaitable(function_response):
+        function_response = await function_response

    if not function_response:
      function_response = await __call_tool_async(
          tool, args=function_args, tool_context=tool_context
      )

-    # Calls after_tool_callback if it exists.
+    # after_tool_callback (sync or async)
    if agent.after_tool_callback:
-      new_response = agent.after_tool_callback(
+      altered_function_response = agent.after_tool_callback(
          tool=tool,
          args=function_args,
          tool_context=tool_context,
          tool_response=function_response,
      )
-      if new_response:
-        function_response = new_response
+      if inspect.isawaitable(altered_function_response):
+        altered_function_response = await altered_function_response
+      if altered_function_response is not None:
+        function_response = altered_function_response

    if tool.is_long_running:
      # Allow long running function to return None to not provide function response.
@@ -223,11 +228,17 @@ async def handle_function_calls_live(
    # in python debugger.
    function_args = function_call.args or {}
    function_response = None
-    # Calls the tool if before_tool_callback does not exist or returns None.
+    # # Calls the tool if before_tool_callback does not exist or returns None.
+    # if agent.before_tool_callback:
+    #   function_response = agent.before_tool_callback(
+    #       tool, function_args, tool_context
+    #   )
    if agent.before_tool_callback:
      function_response = agent.before_tool_callback(
-          tool, function_args, tool_context
+          tool=tool, args=function_args, tool_context=tool_context
      )
+      if inspect.isawaitable(function_response):
+        function_response = await function_response

    if not function_response:
      function_response = await _process_function_live_helper(
@@ -235,15 +246,26 @@ async def handle_function_calls_live(
      )

    # Calls after_tool_callback if it exists.
+    # if agent.after_tool_callback:
+    #   new_response = agent.after_tool_callback(
+    #     tool,
+    #     function_args,
+    #     tool_context,
+    #     function_response,
+    #   )
+    #   if new_response:
+    #     function_response = new_response
    if agent.after_tool_callback:
-      new_response = agent.after_tool_callback(
-          tool,
-          function_args,
-          tool_context,
-          function_response,
-      )
-      if new_response:
-        function_response = new_response
+      altered_function_response = agent.after_tool_callback(
+          tool=tool,
+          args=function_args,
+          tool_context=tool_context,
+          tool_response=function_response,
+      )
+      if inspect.isawaitable(altered_function_response):
+        altered_function_response = await altered_function_response
+      if altered_function_response is not None:
+        function_response = altered_function_response

    if tool.is_long_running:
      # Allow async function to return None to not provide function response.
@@ -310,9 +332,7 @@ async def _process_function_live_helper(
      function_response = {
          'status': f'No active streaming function named {function_name} found'
      }
-    elif inspect.isasyncgenfunction(tool.func):
-      print('is async')
+    elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
      # for streaming tool use case
      # we require the function to be a async generator function
      async def run_tool_and_update_queue(tool, function_args, tool_context):
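
For reference, the shape of tool function this branch expects: an async generator, which `inspect.isasyncgenfunction` detects. A minimal sketch with an invented monitoring task (the name and payloads are illustrative):

```python
import asyncio
from typing import AsyncGenerator


async def monitor_stock_price(stock_symbol: str) -> AsyncGenerator[str, None]:
  # Each yielded value is an intermediate result that run_tool_and_update_queue
  # forwards while the tool keeps running.
  for price in (300, 400, 900):
    await asyncio.sleep(1)
    yield f'{stock_symbol} price: {price}'
```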


@@ -52,7 +52,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
    # Appends global instructions if set.
    if (
        isinstance(root_agent, LlmAgent) and root_agent.global_instruction
-    ):  # not emtpy str
+    ):  # not empty str
      raw_si = root_agent.canonical_global_instruction(
          ReadonlyContext(invocation_context)
      )
@@ -60,7 +60,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
      llm_request.append_instructions([si])

    # Appends agent instructions if set.
-    if agent.instruction:  # not emtpy str
+    if agent.instruction:  # not empty str
      raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
      si = _populate_values(raw_si, invocation_context)
      llm_request.append_instructions([si])


@@ -152,7 +152,7 @@ class GeminiLlmConnection(BaseLlmConnection):
    ):
      # TODO: Right now, we just support output_transcription without
      # changing interface and data protocol. Later, we can consider to
-      # support output_transcription as a separete field in LlmResponse.
+      # support output_transcription as a separate field in LlmResponse.

      # Transcription is always considered as partial event
      # We rely on other control signals to determine when to yield the
@@ -179,7 +179,7 @@ class GeminiLlmConnection(BaseLlmConnection):
      # in case of empty content or parts, we still surface it
      # in case it's an interrupted message, we merge the previous partial
      # text. Otherwise we don't merge, because content can be none when model
-      # safty threshold is triggered
+      # safety threshold is triggered
      if message.server_content.interrupted and text:
        yield self.__build_full_text_response(text)
        text = ''
@@ -14,7 +14,7 @@
from __future__ import annotations

-from typing import Optional
+from typing import Any, Optional

from google.genai import types
from pydantic import BaseModel
@@ -37,6 +37,7 @@ class LlmResponse(BaseModel):
    error_message: Error message if the response is an error.
    interrupted: Flag indicating that LLM was interrupted when generating the
      content. Usually it's due to user interruption during a bidi streaming.
+    custom_metadata: The custom metadata of the LlmResponse.
  """

  model_config = ConfigDict(extra='forbid')
@@ -71,6 +72,14 @@ class LlmResponse(BaseModel):
  Usually it's due to user interruption during a bidi streaming.
  """

+  custom_metadata: Optional[dict[str, Any]] = None
+  """The custom metadata of the LlmResponse.
+
+  An optional key-value pair to label an LlmResponse.
+
+  NOTE: the entire dict must be JSON serializable.
+  """
+
  @staticmethod
  def create(
      generate_content_response: types.GenerateContentResponse,
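Editor's note: the new field is just an optional, JSON-serializable dict on the response model. One natural place to set it is a model callback; a hedged sketch, assuming a callback that receives the `LlmResponse` (the exact callback signature here is illustrative):

```python
import time


def tag_response(callback_context, llm_response) -> None:
  # Attach a label to this response; keep the dict JSON serializable.
  llm_response.custom_metadata = {
      'source_agent': 'router',
      'labeled_at': int(time.time()),
  }
```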
@@ -56,6 +56,7 @@ class BuiltInPlanner(BasePlanner):
      llm_request: The LLM request to apply the thinking config to.
    """
    if self.thinking_config:
+      llm_request.config = llm_request.config or types.GenerateContentConfig()
      llm_request.config.thinking_config = self.thinking_config

  @override
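Editor's note: the added line is the usual create-on-demand guard, so `thinking_config` can be assigned even when the request arrived without a config. The same idiom in isolation, assuming `google.genai` types:

```python
from google.genai import types

config = None  # a request's config may legitimately be unset
# The or-guard builds a default config instead of raising AttributeError.
config = config or types.GenerateContentConfig()
config.thinking_config = types.ThinkingConfig(include_thoughts=True)
```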
@@ -0,0 +1,29 @@
"""Utility functions for session service."""

import base64
from typing import Any, Optional

from google.genai import types


def encode_content(content: types.Content):
  """Encodes a content object to a JSON dictionary."""
  encoded_content = content.model_dump(exclude_none=True)
  for p in encoded_content["parts"]:
    if "inline_data" in p:
      p["inline_data"]["data"] = base64.b64encode(
          p["inline_data"]["data"]
      ).decode("utf-8")
  return encoded_content


def decode_content(
    content: Optional[dict[str, Any]],
) -> Optional[types.Content]:
  """Decodes a content object from a JSON dictionary."""
  if not content:
    return None
  for p in content["parts"]:
    if "inline_data" in p:
      p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
  return types.Content.model_validate(content)
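Editor's note: a round-trip sketch of the two helpers above. Binary `inline_data` becomes base64 text on the way into JSON storage and raw bytes again on the way out; this assumes `types.Part.from_bytes` from the `google.genai` SDK:

```python
from google.genai import types

content = types.Content(
    role="user",
    parts=[types.Part.from_bytes(data=b"\x89PNG...", mime_type="image/png")],
)
encoded = encode_content(content)   # JSON-safe dict; bytes -> base64 string
restored = decode_content(encoded)  # types.Content; base64 string -> bytes
assert restored.parts[0].inline_data.data == content.parts[0].inline_data.data
```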
@@ -11,14 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import base64
import copy
from datetime import datetime
import json
import logging
-from typing import Any
-from typing import Optional
+from typing import Any, Optional
import uuid

from sqlalchemy import Boolean
@@ -27,6 +24,7 @@ from sqlalchemy import Dialect
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import Text
+from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
@@ -48,6 +46,7 @@ from typing_extensions import override
from tzlocal import get_localzone

from ..events.event import Event
+from . import _session_util
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
@@ -58,6 +57,9 @@ from .state import State
logger = logging.getLogger(__name__)

+DEFAULT_MAX_KEY_LENGTH = 128
+DEFAULT_MAX_VARCHAR_LENGTH = 256
+

class DynamicJSON(TypeDecorator):
  """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
@@ -70,15 +72,16 @@ class DynamicJSON(TypeDecorator):
  def load_dialect_impl(self, dialect: Dialect):
    if dialect.name == "postgresql":
      return dialect.type_descriptor(postgresql.JSONB)
-    else:
-      return dialect.type_descriptor(Text)  # Default to Text for other dialects
+    if dialect.name == "mysql":
+      # Use LONGTEXT for MySQL to address the data too long issue
+      return dialect.type_descriptor(mysql.LONGTEXT)
+    return dialect.type_descriptor(Text)  # Default to Text for other dialects

  def process_bind_param(self, value, dialect: Dialect):
    if value is not None:
      if dialect.name == "postgresql":
        return value  # JSONB handles dict directly
-      else:
-        return json.dumps(value)  # Serialize to JSON string for TEXT
+      return json.dumps(value)  # Serialize to JSON string for TEXT
    return value

  def process_result_value(self, value, dialect: Dialect):
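Editor's note: the net effect of `DynamicJSON` after this change is one logical JSON column with three physical representations: JSONB on PostgreSQL, LONGTEXT on MySQL (sidestepping "data too long" failures on large payloads), and TEXT elsewhere. A standalone sketch of the same `TypeDecorator` pattern, independent of this file:

```python
import json

from sqlalchemy import Text
from sqlalchemy.dialects import mysql, postgresql
from sqlalchemy.types import TypeDecorator


class JSONPerDialect(TypeDecorator):
  """Dict column stored as JSONB (PostgreSQL), LONGTEXT (MySQL), TEXT (rest)."""

  impl = Text
  cache_ok = True

  def load_dialect_impl(self, dialect):
    if dialect.name == "postgresql":
      return dialect.type_descriptor(postgresql.JSONB)
    if dialect.name == "mysql":
      return dialect.type_descriptor(mysql.LONGTEXT)
    return dialect.type_descriptor(Text)

  def process_bind_param(self, value, dialect):
    # Text-based dialects need an explicit JSON serialization step.
    if value is not None and dialect.name != "postgresql":
      return json.dumps(value)
    return value

  def process_result_value(self, value, dialect):
    if value is not None and dialect.name != "postgresql":
      return json.loads(value)
    return value
```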
@@ -92,17 +95,25 @@ class DynamicJSON(TypeDecorator):
class Base(DeclarativeBase):
  """Base class for database tables."""

  pass


class StorageSession(Base):
  """Represents a session stored in the database."""

  __tablename__ = "sessions"

-  app_name: Mapped[str] = mapped_column(String, primary_key=True)
-  user_id: Mapped[str] = mapped_column(String, primary_key=True)
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  user_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
  id: Mapped[str] = mapped_column(
-      String, primary_key=True, default=lambda: str(uuid.uuid4())
+      String(DEFAULT_MAX_KEY_LENGTH),
+      primary_key=True,
+      default=lambda: str(uuid.uuid4()),
  )

  state: Mapped[MutableDict[str, Any]] = mapped_column(
@@ -125,18 +136,29 @@ class StorageSession(Base):
class StorageEvent(Base):
  """Represents an event stored in the database."""

  __tablename__ = "events"

-  id: Mapped[str] = mapped_column(String, primary_key=True)
-  app_name: Mapped[str] = mapped_column(String, primary_key=True)
-  user_id: Mapped[str] = mapped_column(String, primary_key=True)
-  session_id: Mapped[str] = mapped_column(String, primary_key=True)
-  invocation_id: Mapped[str] = mapped_column(String)
-  author: Mapped[str] = mapped_column(String)
-  branch: Mapped[str] = mapped_column(String, nullable=True)
+  id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  user_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  session_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
+  author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
+  branch: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
+  )
  timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
-  content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
+  content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
  actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)

  long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
@@ -147,8 +169,10 @@ class StorageEvent(Base):
  )
  partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
  turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
-  error_code: Mapped[str] = mapped_column(String, nullable=True)
-  error_message: Mapped[str] = mapped_column(String, nullable=True)
+  error_code: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
+  )
+  error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
  interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)

  storage_session: Mapped[StorageSession] = relationship(
@@ -182,9 +206,12 @@ class StorageEvent(Base):
class StorageAppState(Base):
  """Represents an app state stored in the database."""

  __tablename__ = "app_states"

-  app_name: Mapped[str] = mapped_column(String, primary_key=True)
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
  state: Mapped[MutableDict[str, Any]] = mapped_column(
      MutableDict.as_mutable(DynamicJSON), default={}
  )
@@ -192,13 +219,17 @@ class StorageAppState(Base):
      DateTime(), default=func.now(), onupdate=func.now()
  )


class StorageUserState(Base):
  """Represents a user state stored in the database."""

  __tablename__ = "user_states"

-  app_name: Mapped[str] = mapped_column(String, primary_key=True)
-  user_id: Mapped[str] = mapped_column(String, primary_key=True)
+  app_name: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
+  user_id: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
+  )
  state: Mapped[MutableDict[str, Any]] = mapped_column(
      MutableDict.as_mutable(DynamicJSON), default={}
  )
@@ -217,7 +248,7 @@ class DatabaseSessionService(BaseSessionService):
    """
    # 1. Create DB engine for db connection
    # 2. Create all tables based on schema
-    # 3. Initialize all properies
+    # 3. Initialize all properties
    try:
      db_engine = create_engine(db_url)
@@ -353,6 +384,7 @@ class DatabaseSessionService(BaseSessionService):
              else True
          )
          .limit(config.num_recent_events if config else None)
+          .order_by(StorageEvent.timestamp.asc())
          .all()
      )
@@ -383,7 +415,7 @@ class DatabaseSessionService(BaseSessionService):
              author=e.author,
              branch=e.branch,
              invocation_id=e.invocation_id,
-              content=_decode_content(e.content),
+              content=_session_util.decode_content(e.content),
              actions=e.actions,
              timestamp=e.timestamp.timestamp(),
              long_running_tool_ids=e.long_running_tool_ids,
@@ -506,15 +538,7 @@ class DatabaseSessionService(BaseSessionService):
        interrupted=event.interrupted,
    )
    if event.content:
-      encoded_content = event.content.model_dump(exclude_none=True)
-      # Workaround for multimodal Content throwing JSON not serializable
-      # error with SQLAlchemy.
-      for p in encoded_content["parts"]:
-        if "inline_data" in p:
-          p["inline_data"]["data"] = (
-              base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
-          )
-      storage_event.content = encoded_content
+      storage_event.content = _session_util.encode_content(event.content)

    sessionFactory.add(storage_event)
@@ -574,10 +598,3 @@ def _merge_state(app_state, user_state, session_state):
  for key in user_state.keys():
    merged_state[State.USER_PREFIX + key] = user_state[key]
  return merged_state
-
-
-def _decode_content(content: dict[str, Any]) -> dict[str, Any]:
-  for p in content["parts"]:
-    if "inline_data" in p:
-      p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
-  return content
@@ -26,7 +26,7 @@ class State:
    """
    Args:
      value: The current value of the state dict.
-      delta: The delta change to the current value that hasn't been commited.
+      delta: The delta change to the current value that hasn't been committed.
    """
    self._value = value
    self._delta = delta
@@ -14,21 +14,23 @@
import logging
import re
import time
-from typing import Any
-from typing import Optional
+from typing import Any, Optional

-from dateutil.parser import isoparse
+from dateutil import parser
from google import genai
from typing_extensions import override

from ..events.event import Event
from ..events.event_actions import EventActions
+from . import _session_util
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
from .base_session_service import ListSessionsResponse
from .session import Session

+isoparse = parser.isoparse
logger = logging.getLogger(__name__)
@@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event):
  }
  event_json['actions'] = actions_json
  if event.content:
-    event_json['content'] = event.content.model_dump(exclude_none=True)
+    event_json['content'] = _session_util.encode_content(event.content)
  if event.error_code:
    event_json['error_code'] = event.error_code
  if event.error_message:
@@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event:
      invocation_id=api_event['invocationId'],
      author=api_event['author'],
      actions=event_actions,
-      content=api_event.get('content', None),
+      content=_session_util.decode_content(api_event.get('content', None)),
      timestamp=isoparse(api_event['timestamp']).timestamp(),
      error_code=api_event.get('errorCode', None),
      error_message=api_event.get('errorMessage', None),
@@ -45,10 +45,9 @@ class AgentTool(BaseTool):
    skip_summarization: Whether to skip summarization of the agent output.
  """

-  def __init__(self, agent: BaseAgent):
+  def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
    self.agent = agent
-    self.skip_summarization: bool = False
-    """Whether to skip summarization of the agent output."""
+    self.skip_summarization: bool = skip_summarization

    super().__init__(name=agent.name, description=agent.description)
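Editor's note: `skip_summarization` moves into the constructor, so it can be set at tool-creation time. A usage sketch with an illustrative agent (the model name is a placeholder):

```python
from google.adk.agents import Agent
from google.adk.tools.agent_tool import AgentTool

summarizer = Agent(
    name="summarizer",
    model="gemini-2.0-flash",
    instruction="Summarize the given text.",
)
# Surface the wrapped agent's output as-is, without a summarization pass.
tool = AgentTool(agent=summarizer, skip_summarization=True)
```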
@@ -13,7 +13,9 @@
# limitations under the License.

from .application_integration_toolset import ApplicationIntegrationToolset
+from .integration_connector_tool import IntegrationConnectorTool

__all__ = [
    'ApplicationIntegrationToolset',
+    'IntegrationConnectorTool',
]
@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Dict
-from typing import List
-from typing import Optional
+from typing import Dict, List, Optional

from fastapi.openapi.models import HTTPBearer
-from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
-from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
-from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
-from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
-from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool

from ...auth.auth_credential import AuthCredential
from ...auth.auth_credential import AuthCredentialTypes
from ...auth.auth_credential import ServiceAccount
from ...auth.auth_credential import ServiceAccountCredential
+from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
+from ..openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
+from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
+from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
+from .clients.connections_client import ConnectionsClient
+from .clients.integration_client import IntegrationClient
+from .integration_connector_tool import IntegrationConnectorTool

# TODO(cheliu): Apply a common toolset interface
@@ -168,6 +168,7 @@ class ApplicationIntegrationToolset:
          actions,
          service_account_json,
      )
+      connection_details = {}
      if integration and trigger:
        spec = integration_client.get_openapi_spec_for_integration()
      elif connection and (entity_operations or actions):
@@ -175,16 +176,6 @@ class ApplicationIntegrationToolset:
            project, location, connection, service_account_json
        )
        connection_details = connections_client.get_connection_details()
-        tool_instructions += (
-            "ALWAYS use serviceName = "
-            + connection_details["serviceName"]
-            + ", host = "
-            + connection_details["host"]
-            + " and the connection name = "
-            + f"projects/{project}/locations/{location}/connections/{connection} when"
-            " using this tool"
-            + ". DONOT ask the user for these values as you already have those."
-        )
        spec = integration_client.get_openapi_spec_for_connection(
            tool_name,
            tool_instructions,
@@ -194,9 +185,9 @@ class ApplicationIntegrationToolset:
          "Either (integration and trigger) or (connection and"
          " (entity_operations or actions)) should be provided."
      )
-    self._parse_spec_to_tools(spec)
+    self._parse_spec_to_tools(spec, connection_details)

-  def _parse_spec_to_tools(self, spec_dict):
+  def _parse_spec_to_tools(self, spec_dict, connection_details):
    """Parses the spec dict to a list of RestApiTool."""
    if self.service_account_json:
      sa_credential = ServiceAccountCredential.model_validate_json(
@@ -218,12 +209,43 @@ class ApplicationIntegrationToolset:
        ),
    )
    auth_scheme = HTTPBearer(bearerFormat="JWT")
-    tools = OpenAPIToolset(
-        spec_dict=spec_dict,
-        auth_credential=auth_credential,
-        auth_scheme=auth_scheme,
-    ).get_tools()
-    for tool in tools:
+
+    if self.integration and self.trigger:
+      tools = OpenAPIToolset(
+          spec_dict=spec_dict,
+          auth_credential=auth_credential,
+          auth_scheme=auth_scheme,
+      ).get_tools()
+      for tool in tools:
+        self.generated_tools[tool.name] = tool
+      return
+
+    operations = OpenApiSpecParser().parse(spec_dict)
+
+    for open_api_operation in operations:
+      operation = getattr(open_api_operation.operation, "x-operation")
+      entity = None
+      action = None
+      if hasattr(open_api_operation.operation, "x-entity"):
+        entity = getattr(open_api_operation.operation, "x-entity")
+      elif hasattr(open_api_operation.operation, "x-action"):
+        action = getattr(open_api_operation.operation, "x-action")
+      rest_api_tool = RestApiTool.from_parsed_operation(open_api_operation)
+      if auth_scheme:
+        rest_api_tool.configure_auth_scheme(auth_scheme)
+      if auth_credential:
+        rest_api_tool.configure_auth_credential(auth_credential)
+      tool = IntegrationConnectorTool(
+          name=rest_api_tool.name,
+          description=rest_api_tool.description,
+          connection_name=connection_details["name"],
+          connection_host=connection_details["host"],
+          connection_service_name=connection_details["serviceName"],
+          entity=entity,
+          action=action,
+          operation=operation,
+          rest_api_tool=rest_api_tool,
+      )
      self.generated_tools[tool.name] = tool

  def get_tools(self) -> List[RestApiTool]:
@@ -68,12 +68,14 @@ class ConnectionsClient:
    response = self._execute_api_call(url)

    connection_data = response.json()
+    connection_name = connection_data.get("name", "")
    service_name = connection_data.get("serviceDirectory", "")
    host = connection_data.get("host", "")
    if host:
      service_name = connection_data.get("tlsServiceDirectory", "")
    auth_override_enabled = connection_data.get("authOverrideEnabled", False)
    return {
+        "name": connection_name,
        "serviceName": service_name,
        "host": host,
        "authOverrideEnabled": auth_override_enabled,
@@ -291,13 +293,9 @@ class ConnectionsClient:
      tool_name: str = "",
      tool_instructions: str = "",
  ) -> Dict[str, Any]:
-    description = (
-        f"Use this tool with" f' action = "{action}" and'
-    ) + f' operation = "{operation}" only. Dont ask these values from user.'
+    description = f"Use this tool to execute {action}"
    if operation == "EXECUTE_QUERY":
-      description = (
-          (f"Use this tool with" f' action = "{action}" and')
-          + f' operation = "{operation}" only. Dont ask these values from user.'
+      description += (
          " Use pageSize = 50 and timeout = 120 until user specifies a"
          " different value otherwise. If user provides a query in natural"
          " language, convert it to SQL query and then execute it using the"
@@ -308,6 +306,8 @@ class ConnectionsClient:
            "summary": f"{action_display_name}",
            "description": f"{description} {tool_instructions}",
            "operationId": f"{tool_name}_{action_display_name}",
+            "x-action": f"{action}",
+            "x-operation": f"{operation}",
            "requestBody": {
                "content": {
                    "application/json": {
@@ -347,16 +347,12 @@ class ConnectionsClient:
        "post": {
            "summary": f"List {entity}",
            "description": (
-                f"Returns all entities of type {entity}. Use this tool with"
-                + f' entity = "{entity}" and'
-                + ' operation = "LIST_ENTITIES" only. Dont ask these values'
-                " from"
-                + ' user. Always use ""'
-                + ' as filter clause and ""'
-                + " as page token and 50 as page size until user specifies a"
-                " different value otherwise. Use single quotes for strings in"
-                f" filter clause. {tool_instructions}"
+                f"""Returns the list of {entity} data. If the page token was available in the response, let users know there are more records available. Ask if the user wants to fetch the next page of results. When passing filter use the
+following format: `field_name1='value1' AND field_name2='value2'
+`. {tool_instructions}"""
            ),
+            "x-operation": "LIST_ENTITIES",
+            "x-entity": f"{entity}",
            "operationId": f"{tool_name}_list_{entity}",
            "requestBody": {
                "content": {
@@ -401,14 +397,11 @@ class ConnectionsClient:
        "post": {
            "summary": f"Get {entity}",
            "description": (
-                (
-                    f"Returns the details of the {entity}. Use this tool with"
-                    f' entity = "{entity}" and'
-                )
-                + ' operation = "GET_ENTITY" only. Dont ask these values from'
-                f" user. {tool_instructions}"
+                f"Returns the details of the {entity}. {tool_instructions}"
            ),
            "operationId": f"{tool_name}_get_{entity}",
+            "x-operation": "GET_ENTITY",
+            "x-entity": f"{entity}",
            "requestBody": {
                "content": {
                    "application/json": {
@@ -445,17 +438,10 @@ class ConnectionsClient:
  ) -> Dict[str, Any]:
    return {
        "post": {
-            "summary": f"Create {entity}",
-            "description": (
-                (
-                    f"Creates a new entity of type {entity}. Use this tool with"
-                    f' entity = "{entity}" and'
-                )
-                + ' operation = "CREATE_ENTITY" only. Dont ask these values'
-                " from"
-                + " user. Follow the schema of the entity provided in the"
-                f" instructions to create {entity}. {tool_instructions}"
-            ),
+            "summary": f"Creates a new {entity}",
+            "description": f"Creates a new {entity}. {tool_instructions}",
+            "x-operation": "CREATE_ENTITY",
+            "x-entity": f"{entity}",
            "operationId": f"{tool_name}_create_{entity}",
            "requestBody": {
                "content": {
@@ -491,18 +477,10 @@ class ConnectionsClient:
  ) -> Dict[str, Any]:
    return {
        "post": {
-            "summary": f"Update {entity}",
-            "description": (
-                (
-                    f"Updates an entity of type {entity}. Use this tool with"
-                    f' entity = "{entity}" and'
-                )
-                + ' operation = "UPDATE_ENTITY" only. Dont ask these values'
-                " from"
-                + " user. Use entityId to uniquely identify the entity to"
-                " update. Follow the schema of the entity provided in the"
-                f" instructions to update {entity}. {tool_instructions}"
-            ),
+            "summary": f"Updates the {entity}",
+            "description": f"Updates the {entity}. {tool_instructions}",
+            "x-operation": "UPDATE_ENTITY",
+            "x-entity": f"{entity}",
            "operationId": f"{tool_name}_update_{entity}",
            "requestBody": {
                "content": {
@@ -538,16 +516,10 @@ class ConnectionsClient:
  ) -> Dict[str, Any]:
    return {
        "post": {
-            "summary": f"Delete {entity}",
-            "description": (
-                (
-                    f"Deletes an entity of type {entity}. Use this tool with"
-                    f' entity = "{entity}" and'
-                )
-                + ' operation = "DELETE_ENTITY" only. Dont ask these values'
-                " from"
-                f" user. {tool_instructions}"
-            ),
+            "summary": f"Delete the {entity}",
+            "description": f"Deletes the {entity}. {tool_instructions}",
+            "x-operation": "DELETE_ENTITY",
+            "x-entity": f"{entity}",
            "operationId": f"{tool_name}_delete_{entity}",
            "requestBody": {
                "content": {
@@ -0,0 +1,159 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any
from typing import Dict
from typing import Optional

from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
from google.genai.types import FunctionDeclaration
from typing_extensions import override

from .. import BaseTool
from ..tool_context import ToolContext

logger = logging.getLogger(__name__)


class IntegrationConnectorTool(BaseTool):
  """A tool that wraps a RestApiTool to interact with a specific Application Integration endpoint.

  This tool adds Application Integration specific context like connection
  details, entity, operation, and action to the underlying REST API call
  handled by RestApiTool. It prepares the arguments and then delegates the
  actual API call execution to the contained RestApiTool instance.

  * Generates request params and body
  * Attaches auth credentials to API call.

  Example:
  ```
    # Each API operation in the spec will be turned into its own tool
    # Name of the tool is the operationId of that operation, in snake case
    operations = OperationGenerator().parse(openapi_spec_dict)
    tool = [RestApiTool.from_parsed_operation(o) for o in operations]
  ```
  """

  EXCLUDE_FIELDS = [
      'connection_name',
      'service_name',
      'host',
      'entity',
      'operation',
      'action',
  ]

  OPTIONAL_FIELDS = [
      'page_size',
      'page_token',
      'filter',
  ]

  def __init__(
      self,
      name: str,
      description: str,
      connection_name: str,
      connection_host: str,
      connection_service_name: str,
      entity: str,
      operation: str,
      action: str,
      rest_api_tool: RestApiTool,
  ):
    """Initializes the ApplicationIntegrationTool.

    Args:
      name: The name of the tool, typically derived from the API operation.
        Should be unique and adhere to Gemini function naming conventions
        (e.g., less than 64 characters).
      description: A description of what the tool does, usually based on the
        API operation's summary or description.
      connection_name: The name of the Integration Connector connection.
      connection_host: The hostname or IP address for the connection.
      connection_service_name: The specific service name within the host.
      entity: The Integration Connector entity being targeted.
      operation: The specific operation being performed on the entity.
      action: The action associated with the operation (e.g., 'execute').
      rest_api_tool: An initialized RestApiTool instance that handles the
        underlying REST API communication based on an OpenAPI specification
        operation. This tool will be called by ApplicationIntegrationTool with
        added connection and context arguments.
    """
    # Gemini restricts function names to less than 64 characters.
    super().__init__(
        name=name,
        description=description,
    )
    self.connection_name = connection_name
    self.connection_host = connection_host
    self.connection_service_name = connection_service_name
    self.entity = entity
    self.operation = operation
    self.action = action
    self.rest_api_tool = rest_api_tool

  @override
  def _get_declaration(self) -> FunctionDeclaration:
    """Returns the function declaration in the Gemini Schema format."""
    schema_dict = self.rest_api_tool._operation_parser.get_json_schema()
    for field in self.EXCLUDE_FIELDS:
      if field in schema_dict['properties']:
        del schema_dict['properties'][field]
    for field in self.OPTIONAL_FIELDS + self.EXCLUDE_FIELDS:
      if field in schema_dict['required']:
        schema_dict['required'].remove(field)

    parameters = to_gemini_schema(schema_dict)
    function_decl = FunctionDeclaration(
        name=self.name, description=self.description, parameters=parameters
    )
    return function_decl

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
  ) -> Dict[str, Any]:
    args['connection_name'] = self.connection_name
    args['service_name'] = self.connection_service_name
    args['host'] = self.connection_host
    args['entity'] = self.entity
    args['operation'] = self.operation
    args['action'] = self.action
    logger.info('Running tool: %s with args: %s', self.name, args)
    return self.rest_api_tool.call(args=args, tool_context=tool_context)

  def __str__(self):
    return (
        f'ApplicationIntegrationTool(name="{self.name}",'
        f' description="{self.description}",'
        f' connection_name="{self.connection_name}", entity="{self.entity}",'
        f' operation="{self.operation}", action="{self.action}")'
    )

  def __repr__(self):
    return (
        f'ApplicationIntegrationTool(name="{self.name}",'
        f' description="{self.description}",'
        f' connection_name="{self.connection_name}",'
        f' connection_host="{self.connection_host}",'
        f' connection_service_name="{self.connection_service_name}",'
        f' entity="{self.entity}", operation="{self.operation}",'
        f' action="{self.action}", rest_api_tool={repr(self.rest_api_tool)})'
    )
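Editor's note: at call time the tool injects the connection context into whatever arguments the model supplied, then hands the merged dict to the wrapped `RestApiTool`. A conceptual sketch of that argument flow (all values illustrative):

```python
# What run_async effectively does before delegating:
model_args = {"filter": "status='OPEN'", "page_size": 10}

merged = dict(model_args)
merged.update({
    "connection_name": "projects/p/locations/l/connections/c",
    "service_name": "svc-directory-name",
    "host": "connectors.example.com",
    "entity": "Issues",
    "operation": "LIST_ENTITIES",
    "action": None,
})
# rest_api_tool.call(args=merged, tool_context=...) then issues the HTTP call.
```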
@@ -59,6 +59,23 @@ class FunctionTool(BaseTool):
    if 'tool_context' in signature.parameters:
      args_to_call['tool_context'] = tool_context

+    # Before invoking the function, check whether the args passed in include
+    # all of its mandatory arguments.
+    # If the check fails, don't invoke the tool; instead let the Agent know
+    # that an input parameter was missing. This helps the underlying model
+    # fix the issue and retry.
+    mandatory_args = self._get_mandatory_args()
+    missing_mandatory_args = [
+        arg for arg in mandatory_args if arg not in args_to_call
+    ]
+
+    if missing_mandatory_args:
+      missing_mandatory_args_str = '\n'.join(missing_mandatory_args)
+      error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present:
+{missing_mandatory_args_str}
+You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
+      return {'error': error_str}
+
    if inspect.iscoroutinefunction(self.func):
      return await self.func(**args_to_call) or {}
    else:
@@ -85,3 +102,28 @@ class FunctionTool(BaseTool):
      args_to_call['tool_context'] = tool_context
    async for item in self.func(**args_to_call):
      yield item
+
+  def _get_mandatory_args(
+      self,
+  ) -> list[str]:
+    """Identifies mandatory parameters (those without default values) for a function.
+
+    Returns:
+      A list of strings, where each string is the name of a mandatory parameter.
+    """
+    signature = inspect.signature(self.func)
+    mandatory_params = []
+
+    for name, param in signature.parameters.items():
+      # A parameter is mandatory if:
+      # 1. It has no default value (param.default is inspect.Parameter.empty)
+      # 2. It's not a variable positional (*args) or variable keyword
+      #    (**kwargs) parameter
+      #
+      # For more refer to: https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind
+      if param.default == inspect.Parameter.empty and param.kind not in (
+          inspect.Parameter.VAR_POSITIONAL,
+          inspect.Parameter.VAR_KEYWORD,
+      ):
+        mandatory_params.append(name)
+
+    return mandatory_params
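Editor's note: the mandatory-parameter check follows directly from `inspect.signature` semantics; a self-contained sketch of the same logic against an illustrative tool function:

```python
import inspect


def get_weather(city: str, units: str = "metric", **kwargs):
  """Illustrative tool function."""


signature = inspect.signature(get_weather)
mandatory = [
    name
    for name, param in signature.parameters.items()
    if param.default is inspect.Parameter.empty
    and param.kind not in (inspect.Parameter.VAR_POSITIONAL,
                           inspect.Parameter.VAR_KEYWORD)
]
assert mandatory == ["city"]  # 'units' has a default; '**kwargs' is excluded
```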
@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
+
import inspect
import os
from typing import Any
-from typing import Dict
from typing import Final
from typing import List
from typing import Optional
@@ -28,6 +30,7 @@ from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter

class GoogleApiToolSet:
+  """Google API Tool Set."""

  def __init__(self, tools: List[RestApiTool]):
    self.tools: Final[List[GoogleApiTool]] = [
@@ -45,10 +48,10 @@ class GoogleApiToolSet:
  @staticmethod
  def _load_tool_set_with_oidc_auth(
-      spec_file: str = None,
-      spec_dict: Dict[str, Any] = None,
-      scopes: list[str] = None,
-  ) -> Optional[OpenAPIToolset]:
+      spec_file: Optional[str] = None,
+      spec_dict: Optional[dict[str, Any]] = None,
+      scopes: Optional[list[str]] = None,
+  ) -> OpenAPIToolset:
    spec_str = None
    if spec_file:
      # Get the frame of the caller
@@ -90,18 +93,18 @@ class GoogleApiToolSet:
  @classmethod
  def load_tool_set(
-      cl: Type['GoogleApiToolSet'],
+      cls: Type[GoogleApiToolSet],
      api_name: str,
      api_version: str,
-  ) -> 'GoogleApiToolSet':
+  ) -> GoogleApiToolSet:
    spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
    scope = list(
        spec_dict['components']['securitySchemes']['oauth2']['flows'][
            'authorizationCode'
        ]['scopes'].keys()
    )[0]
-    return cl(
-        cl._load_tool_set_with_oidc_auth(
+    return cls(
+        cls._load_tool_set_with_oidc_auth(
            spec_dict=spec_dict, scopes=[scope]
        ).get_tools()
    )
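Editor's note: with `cls` corrected, `load_tool_set` reads as an ordinary classmethod constructor. Hypothetical usage (the API name and version are illustrative; `tools` is the attribute populated in `__init__` above):

```python
calendar_tool_set = GoogleApiToolSet.load_tool_set(
    api_name="calendar",
    api_version="v3",
)
tools = calendar_tool_set.tools  # list of GoogleApiTool instances
```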
@@ -89,7 +89,7 @@ class LoadArtifactsTool(BaseTool):
        than the function call.
      """])

-    # Attache the content of the artifacts if the model requests them.
+    # Attach the content of the artifacts if the model requests them.
    # This only adds the content to the model request, instead of the session.
    if llm_request.contents and llm_request.contents[-1].parts:
      function_response = llm_request.contents[-1].parts[0].function_response
@@ -66,10 +66,10 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
    Returns:
      An AuthCredential object containing the HTTP bearer access token. If the
-      HTTO bearer token cannot be generated, return the origianl credential
+      HTTP bearer token cannot be generated, return the original credential.
    """
-    if "access_token" not in auth_credential.oauth2.token:
+    if not auth_credential.oauth2.access_token:
      return auth_credential

    # Return the access token as a bearer token.
@@ -78,7 +78,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
        http=HttpAuth(
            scheme="bearer",
            credentials=HttpCredentials(
-                token=auth_credential.oauth2.token["access_token"]
+                token=auth_credential.oauth2.access_token
            ),
        ),
    )
@@ -111,7 +111,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
      return auth_credential

    # If access token is exchanged, exchange a HTTPBearer token.
-    if auth_credential.oauth2.token:
+    if auth_credential.oauth2.access_token:
      return self.generate_auth_token(auth_credential)

    return None
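Editor's note: the checks now read `access_token` directly off `OAuth2Auth` instead of digging into a token dict. A construction sketch matching the updated code paths (all values are placeholders):

```python
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth

credential = AuthCredential(
    auth_type=AuthCredentialTypes.OAUTH2,
    oauth2=OAuth2Auth(
        client_id="client-id",
        client_secret="client-secret",
        access_token="access-token-value",   # was token={"access_token": ...}
        refresh_token="refresh-token-value",
    ),
)
assert credential.oauth2.access_token  # what generate_auth_token now checks
```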
@@ -124,7 +124,7 @@ class OpenAPIToolset:
  def _load_spec(
      self, spec_str: str, spec_type: Literal["json", "yaml"]
  ) -> Dict[str, Any]:
-    """Loads the OpenAPI spec string into adictionary."""
+    """Loads the OpenAPI spec string into a dictionary."""
    if spec_type == "json":
      return json.loads(spec_str)
    elif spec_type == "yaml":
@@ -14,20 +14,12 @@
import inspect
from textwrap import dedent
-from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Union
+from typing import Any, Dict, List, Optional, Union

from fastapi.encoders import jsonable_encoder
-from fastapi.openapi.models import Operation
-from fastapi.openapi.models import Parameter
-from fastapi.openapi.models import Schema
+from fastapi.openapi.models import Operation, Parameter, Schema

-from ..common.common import ApiParameter
-from ..common.common import PydocHelper
-from ..common.common import to_snake_case
+from ..common.common import ApiParameter, PydocHelper, to_snake_case


class OperationParser:
@@ -113,7 +105,8 @@ class OperationParser:
    description = request_body.description or ''

    if schema and schema.type == 'object':
-      for prop_name, prop_details in schema.properties.items():
+      properties = schema.properties or {}
+      for prop_name, prop_details in properties.items():
        self.params.append(
            ApiParameter(
                original_name=prop_name,
@@ -17,6 +17,7 @@ from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
+from typing import Sequence
from typing import Tuple
from typing import Union
@@ -59,6 +60,40 @@ def snake_to_lower_camel(snake_case_string: str):
  ])


+# TODO: Switch to Gemini `from_json_schema` util when it is released
+# in Gemini SDK.
+def normalize_json_schema_type(
+    json_schema_type: Optional[Union[str, Sequence[str]]],
+) -> tuple[Optional[str], bool]:
+  """Converts a JSON Schema Type into Gemini Schema type.
+
+  Adopted and modified from Gemini SDK. This gets the first available schema
+  type from JSON Schema, and use it to mark Gemini schema type. If JSON Schema
+  contains a list of types, the first non null type is used.
+
+  Remove this after switching to Gemini `from_json_schema`.
+  """
+  if json_schema_type is None:
+    return None, False
+
+  if isinstance(json_schema_type, str):
+    if json_schema_type == "null":
+      return None, True
+    return json_schema_type, False
+
+  non_null_types = []
+  nullable = False
+  # If json schema type is an array, pick the first non null type.
+  for type_value in json_schema_type:
+    if type_value == "null":
+      nullable = True
+    else:
+      non_null_types.append(type_value)
+
+  non_null_type = non_null_types[0] if non_null_types else None
+  return non_null_type, nullable
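Editor's note: quick checks showing what the helper returns for each input shape it accepts; these follow directly from the logic above:

```python
assert normalize_json_schema_type(None) == (None, False)
assert normalize_json_schema_type("string") == ("string", False)
assert normalize_json_schema_type("null") == (None, True)
# For a list of types, the first non-null entry wins; nullability is recorded.
assert normalize_json_schema_type(["null", "integer"]) == ("integer", True)
```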
+
+
+# TODO: Switch to Gemini `from_json_schema` util when it is released
+# in Gemini SDK.
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
  """Converts an OpenAPI schema dictionary to a Gemini Schema object.
@@ -82,13 +117,6 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
  if not openapi_schema.get("type"):
    openapi_schema["type"] = "object"

-  # Adding this to avoid "properties: should be non-empty for OBJECT type" error
-  # See b/385165182
-  if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
-      "properties"
-  ):
-    openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}
-
  for key, value in openapi_schema.items():
    snake_case_key = to_snake_case(key)
    # Check if the snake_case_key exists in the Schema model's fields.
@@ -99,7 +127,17 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
      # Format: properties[expiration].format: only 'enum' and 'date-time' are
      # supported for STRING type
      continue
-    if snake_case_key == "properties" and isinstance(value, dict):
+    elif snake_case_key == "type":
+      schema_type, nullable = normalize_json_schema_type(
+          openapi_schema.get("type", None)
+      )
+      # Adding this to force adding a type to an empty dict
+      # This avoids the "... one_of or any_of must specify a type" error
+      pydantic_schema_data["type"] = schema_type if schema_type else "object"
+      pydantic_schema_data["type"] = pydantic_schema_data["type"].upper()
+      if nullable:
+        pydantic_schema_data["nullable"] = True
+    elif snake_case_key == "properties" and isinstance(value, dict):
      pydantic_schema_data[snake_case_key] = {
          k: to_gemini_schema(v) for k, v in value.items()
      }
@@ -13,4 +13,4 @@
# limitations under the License.

# version: date+base_cl
-__version__ = "0.1.1"
+__version__ = "0.3.0"
@@ -241,7 +241,7 @@ def test_langchain_tool_success(agent_runner: TestRunner):
def test_crewai_tool_success(agent_runner: TestRunner):
  _call_function_and_assert(
      agent_runner,
-      "direcotry_read_tool",
+      "directory_read_tool",
      "Find all the file paths",
      "file",
  )
@@ -126,12 +126,8 @@ def oauth2_credentials_with_token():
          client_id="mock_client_id",
          client_secret="mock_client_secret",
          redirect_uri="https://example.com/callback",
-          token={
-              "access_token": "mock_access_token",
-              "token_type": "bearer",
-              "expires_in": 3600,
-              "refresh_token": "mock_refresh_token",
-          },
+          access_token="mock_access_token",
+          refresh_token="mock_refresh_token",
      ),
  )
@@ -458,7 +454,7 @@ class TestParseAndStoreAuthResponse:
    """Test with an OAuth auth scheme."""
    mock_exchange_token.return_value = AuthCredential(
        auth_type=AuthCredentialTypes.OAUTH2,
-        oauth2=OAuth2Auth(token={"access_token": "exchanged_token"}),
+        oauth2=OAuth2Auth(access_token="exchanged_token"),
    )

    handler = AuthHandler(auth_config_with_exchanged)
@@ -573,6 +569,6 @@ class TestExchangeAuthToken:
    handler = AuthHandler(auth_config_with_auth_code)
    result = handler.exchange_auth_token()

-    assert result.oauth2.token["access_token"] == "mock_access_token"
-    assert result.oauth2.token["refresh_token"] == "mock_refresh_token"
+    assert result.oauth2.access_token == "mock_access_token"
+    assert result.oauth2.refresh_token == "mock_refresh_token"
    assert result.auth_type == AuthCredentialTypes.OAUTH2
@@ -0,0 +1,109 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional

import pytest
from google.adk.agents import Agent
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_context import ToolContext
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.events.event import Event
from google.genai import types

from ... import utils


class AsyncBeforeToolCallback:

  def __init__(self, mock_response: Dict[str, Any]):
    self.mock_response = mock_response

  async def __call__(
      self,
      tool: FunctionTool,
      args: Dict[str, Any],
      tool_context: ToolContext,
  ) -> Optional[Dict[str, Any]]:
    return self.mock_response


class AsyncAfterToolCallback:

  def __init__(self, mock_response: Dict[str, Any]):
    self.mock_response = mock_response

  async def __call__(
      self,
      tool: FunctionTool,
      args: Dict[str, Any],
      tool_context: ToolContext,
      tool_response: Dict[str, Any],
  ) -> Optional[Dict[str, Any]]:
    return self.mock_response


async def invoke_tool_with_callbacks(
    before_cb=None, after_cb=None
) -> Optional[Event]:
  def simple_fn(**kwargs) -> Dict[str, Any]:
    return {"initial": "response"}

  tool = FunctionTool(simple_fn)
  model = utils.MockModel.create(responses=[])
  agent = Agent(
      name="agent",
      model=model,
      tools=[tool],
      before_tool_callback=before_cb,
      after_tool_callback=after_cb,
  )
  invocation_context = utils.create_invocation_context(
      agent=agent, user_content=""
  )
  # Build function call event
  function_call = types.FunctionCall(name=tool.name, args={})
  content = types.Content(parts=[types.Part(function_call=function_call)])
  event = Event(
      invocation_id=invocation_context.invocation_id,
      author=agent.name,
      content=content,
  )
  tools_dict = {tool.name: tool}
  return await handle_function_calls_async(
      invocation_context,
      event,
      tools_dict,
  )


@pytest.mark.asyncio
async def test_async_before_tool_callback():
  mock_resp = {"test": "before_tool_callback"}
  before_cb = AsyncBeforeToolCallback(mock_resp)
  result_event = await invoke_tool_with_callbacks(before_cb=before_cb)
  assert result_event is not None
  part = result_event.content.parts[0]
  assert part.function_response.response == mock_resp


@pytest.mark.asyncio
async def test_async_after_tool_callback():
  mock_resp = {"test": "after_tool_callback"}
  after_cb = AsyncAfterToolCallback(mock_resp)
  result_event = await invoke_tool_with_callbacks(after_cb=after_cb)
  assert result_event is not None
  part = result_event.content.parts[0]
  assert part.function_response.response == mock_resp
@@ -246,7 +246,7 @@ def test_function_get_auth_response():
          oauth2=OAuth2Auth(
              client_id='oauth_client_id_1',
              client_secret='oauth_client_secret1',
-              token={'access_token': 'token1'},
+              access_token='token1',
          ),
      ),
  )
@@ -277,7 +277,7 @@ def test_function_get_auth_response():
          oauth2=OAuth2Auth(
              client_id='oauth_client_id_2',
              client_secret='oauth_client_secret2',
-              token={'access_token': 'token2'},
+              access_token='token2',
          ),
      ),
  )
@@ -14,10 +14,12 @@
import json
from unittest import mock

+from fastapi.openapi.models import Operation
from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
+from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
-from google.adk.tools.openapi_tool.openapi_spec_parser import rest_api_tool
+from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
+from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
import pytest
@@ -50,6 +52,59 @@ def mock_openapi_toolset():
    yield mock_toolset

+def get_mocked_parsed_operation(operation_id, attributes):
+  mock_openapi_spec_parser_instance = mock.MagicMock()
+  mock_parsed_operation = mock.MagicMock(spec=ParsedOperation)
+  mock_parsed_operation.name = "list_issues"
+  mock_parsed_operation.description = "list_issues_description"
+  mock_parsed_operation.endpoint = OperationEndpoint(
+      base_url="http://localhost:8080",
+      path="/v1/issues",
+      method="GET",
+  )
+  mock_parsed_operation.auth_scheme = None
+  mock_parsed_operation.auth_credential = None
+  mock_parsed_operation.additional_context = {}
+  mock_parsed_operation.parameters = []
+  mock_operation = mock.MagicMock(spec=Operation)
+  mock_operation.operationId = operation_id
+  mock_operation.description = "list_issues_description"
+  mock_operation.parameters = []
+  mock_operation.requestBody = None
+  mock_operation.responses = {}
+  mock_operation.callbacks = {}
+  for key, value in attributes.items():
+    setattr(mock_operation, key, value)
+  mock_parsed_operation.operation = mock_operation
+  mock_openapi_spec_parser_instance.parse.return_value = [mock_parsed_operation]
+  return mock_openapi_spec_parser_instance
+
+
+@pytest.fixture
+def mock_openapi_entity_spec_parser():
+  with mock.patch(
+      "google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
+  ) as mock_spec_parser:
+    mock_openapi_spec_parser_instance = get_mocked_parsed_operation(
+        "list_issues", {"x-entity": "Issues", "x-operation": "LIST_ENTITIES"}
+    )
+    mock_spec_parser.return_value = mock_openapi_spec_parser_instance
+    yield mock_spec_parser
+
+
+@pytest.fixture
+def mock_openapi_action_spec_parser():
+  with mock.patch(
+      "google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
+  ) as mock_spec_parser:
+    mock_openapi_action_spec_parser_instance = get_mocked_parsed_operation(
+        "list_issues_operation",
+        {"x-action": "CustomAction", "x-operation": "EXECUTE_ACTION"},
+    )
+    mock_spec_parser.return_value = mock_openapi_action_spec_parser_instance
+    yield mock_spec_parser
+
@pytest.fixture
def project():
  return "test-project"
@@ -72,7 +127,11 @@ def connection_spec():
@pytest.fixture
def connection_details():
-  return {"serviceName": "test-service", "host": "test.host"}
+  return {
+      "serviceName": "test-service",
+      "host": "test.host",
+      "name": "test-connection",
+  }
def test_initialization_with_integration_and_trigger(
@@ -102,7 +161,7 @@ def test_initialization_with_connection_and_entity_operations(
    location,
    mock_integration_client,
    mock_connections_client,
-   mock_openapi_toolset,
+   mock_openapi_entity_spec_parser,
    connection_details,
  ):
    connection_name = "test-connection"
@@ -133,19 +192,17 @@ def test_initialization_with_connection_and_entity_operations(
  mock_connections_client.assert_called_once_with(
      project, location, connection_name, None
  )
+ mock_openapi_entity_spec_parser.return_value.parse.assert_called_once()
  mock_connections_client.return_value.get_connection_details.assert_called_once()
  mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
      tool_name,
-     tool_instructions
-         + f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
-         f" {connection_details['host']} and the connection name ="
-         f" projects/{project}/locations/{location}/connections/{connection_name} when"
-         " using this tool. DONOT ask the user for these values as you already"
-         " have those.",
+     tool_instructions,
  )
- mock_openapi_toolset.assert_called_once()
  assert len(toolset.get_tools()) == 1
- assert toolset.get_tools()[0].name == "Test Tool"
+ assert toolset.get_tools()[0].name == "list_issues"
+ assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
+ assert toolset.get_tools()[0].entity == "Issues"
+ assert toolset.get_tools()[0].operation == "LIST_ENTITIES"
def test_initialization_with_connection_and_actions(
@@ -153,7 +210,7 @@ def test_initialization_with_connection_and_actions(
    location,
    mock_integration_client,
    mock_connections_client,
-   mock_openapi_toolset,
+   mock_openapi_action_spec_parser,
    connection_details,
  ):
    connection_name = "test-connection"
@@ -181,15 +238,13 @@ def test_initialization_with_connection_and_actions(
  mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
      tool_name,
-     tool_instructions
-         + f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
-         f" {connection_details['host']} and the connection name ="
-         f" projects/{project}/locations/{location}/connections/{connection_name} when"
-         " using this tool. DONOT ask the user for these values as you already"
-         " have those.",
+     tool_instructions,
  )
- mock_openapi_toolset.assert_called_once()
+ mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
  assert len(toolset.get_tools()) == 1
- assert toolset.get_tools()[0].name == "Test Tool"
+ assert toolset.get_tools()[0].name == "list_issues_operation"
+ assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
+ assert toolset.get_tools()[0].action == "CustomAction"
+ assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
def test_initialization_without_required_params(project, location):
@@ -337,9 +392,4 @@ def test_initialization_with_connection_details(
  mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
      tool_name,
-     tool_instructions
-         + "ALWAYS use serviceName = custom-service, host = custom.host and the"
-         " connection name ="
-         " projects/test-project/locations/us-central1/connections/test-connection"
-         " when using this tool. DONOT ask the user for these values as you"
-         " already have those.",
+     tool_instructions,
  )

View File

@@ -0,0 +1,125 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
from google.genai.types import FunctionDeclaration
from google.genai.types import Schema
from google.genai.types import Tool
from google.genai.types import Type
import pytest
@pytest.fixture
def mock_rest_api_tool():
"""Fixture for a mocked RestApiTool."""
mock_tool = mock.MagicMock(spec=RestApiTool)
mock_tool.name = "mock_rest_tool"
mock_tool.description = "Mock REST tool description."
# Mock the internal parser needed for _get_declaration
mock_parser = mock.MagicMock()
mock_parser.get_json_schema.return_value = {
"type": "object",
"properties": {
"user_id": {"type": "string", "description": "User ID"},
"connection_name": {"type": "string"},
"host": {"type": "string"},
"service_name": {"type": "string"},
"entity": {"type": "string"},
"operation": {"type": "string"},
"action": {"type": "string"},
"page_size": {"type": "integer"},
"filter": {"type": "string"},
},
"required": ["user_id", "page_size", "filter", "connection_name"],
}
mock_tool._operation_parser = mock_parser
mock_tool.call.return_value = {"status": "success", "data": "mock_data"}
return mock_tool
@pytest.fixture
def integration_tool(mock_rest_api_tool):
"""Fixture for an IntegrationConnectorTool instance."""
return IntegrationConnectorTool(
name="test_integration_tool",
description="Test integration tool description.",
connection_name="test-conn",
connection_host="test.example.com",
connection_service_name="test-service",
entity="TestEntity",
operation="LIST",
action="TestAction",
rest_api_tool=mock_rest_api_tool,
)
def test_get_declaration(integration_tool):
"""Tests the generation of the function declaration."""
declaration = integration_tool._get_declaration()
assert isinstance(declaration, FunctionDeclaration)
assert declaration.name == "test_integration_tool"
assert declaration.description == "Test integration tool description."
# Check parameters schema
params = declaration.parameters
assert isinstance(params, Schema)
print(f"params: {params}")
assert params.type == Type.OBJECT
# Check properties (excluded fields should not be present)
assert "user_id" in params.properties
assert "connection_name" not in params.properties
assert "host" not in params.properties
assert "service_name" not in params.properties
assert "entity" not in params.properties
assert "operation" not in params.properties
assert "action" not in params.properties
assert "page_size" in params.properties
assert "filter" in params.properties
# Check required fields (optional and excluded fields should not be required)
assert "user_id" in params.required
assert "page_size" not in params.required
assert "filter" not in params.required
assert "connection_name" not in params.required
@pytest.mark.asyncio
async def test_run_async(integration_tool, mock_rest_api_tool):
"""Tests the async execution delegates correctly to the RestApiTool."""
input_args = {"user_id": "user123", "page_size": 10}
expected_call_args = {
"user_id": "user123",
"page_size": 10,
"connection_name": "test-conn",
"host": "test.example.com",
"service_name": "test-service",
"entity": "TestEntity",
"operation": "LIST",
"action": "TestAction",
}
result = await integration_tool.run_async(args=input_args, tool_context=None)
# Assert the underlying rest_api_tool.call was called correctly
mock_rest_api_tool.call.assert_called_once_with(
args=expected_call_args, tool_context=None
)
# Assert the result is what the mocked call returned
assert result == {"status": "success", "data": "mock_data"}
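Reviewer note: the test above pins the delegation contract — IntegrationConnectorTool merges its fixed connection metadata into the model-supplied arguments before calling the wrapped RestApiTool. A rough sketch of that step under the names used in the test; the internal attribute names are assumptions, not the library's actual implementation:

    # Hypothetical sketch; attribute names mirror the constructor arguments.
    async def run_async(self, *, args, tool_context):
        merged_args = {
            **args,  # model-supplied values, e.g. user_id, page_size
            'connection_name': self._connection_name,
            'host': self._connection_host,
            'service_name': self._connection_service_name,
            'entity': self._entity,
            'operation': self._operation,
            'action': self._action,
        }
        # Delegate the actual HTTP call to the wrapped RestApiTool.
        return self._rest_api_tool.call(args=merged_args, tool_context=tool_context)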

View File

@@ -110,7 +110,7 @@ def test_generate_auth_token_success(
      client_secret="test_secret",
      redirect_uri="http://localhost:8080",
      auth_response_uri="https://example.com/callback?code=test_code",
-     token={"access_token": "test_access_token"},
+     access_token="test_access_token",
  ),
)
updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
@@ -131,7 +131,7 @@ def test_exchange_credential_generate_auth_token(
      client_secret="test_secret",
      redirect_uri="http://localhost:8080",
      auth_response_uri="https://example.com/callback?code=test_code",
-     token={"access_token": "test_access_token"},
+     access_token="test_access_token",
  ),
)

View File

@@ -164,6 +164,18 @@ def test_process_request_body_no_name():
  assert parser.params[0].param_location == 'body'
def test_process_request_body_empty_object():
"""Test _process_request_body with a schema that is of type object but with no properties."""
operation = Operation(
requestBody=RequestBody(
content={'application/json': MediaType(schema=Schema(type='object'))}
)
)
parser = OperationParser(operation, should_parse=False)
parser._process_request_body()
assert len(parser.params) == 0
def test_dedupe_param_names(sample_operation):
  """Test _dedupe_param_names method."""
  parser = OperationParser(sample_operation, should_parse=False)

View File

@@ -14,11 +14,9 @@
  import json
- from unittest.mock import MagicMock
- from unittest.mock import patch
+ from unittest.mock import MagicMock, patch
- from fastapi.openapi.models import MediaType
- from fastapi.openapi.models import Operation
+ from fastapi.openapi.models import MediaType, Operation
  from fastapi.openapi.models import Parameter as OpenAPIParameter
  from fastapi.openapi.models import RequestBody
  from fastapi.openapi.models import Schema as OpenAPISchema
@@ -27,13 +25,13 @@ from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_cred
  from google.adk.tools.openapi_tool.common.common import ApiParameter
  from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
  from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
- from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
- from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import snake_to_lower_camel
- from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
+ from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import (
+     RestApiTool,
+     snake_to_lower_camel,
+     to_gemini_schema,
+ )
  from google.adk.tools.tool_context import ToolContext
- from google.genai.types import FunctionDeclaration
- from google.genai.types import Schema
- from google.genai.types import Type
+ from google.genai.types import FunctionDeclaration, Schema, Type
  import pytest
@@ -790,13 +788,13 @@ class TestToGeminiSchema:
  result = to_gemini_schema({})
  assert isinstance(result, Schema)
  assert result.type == Type.OBJECT
- assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
+ assert result.properties is None

def test_to_gemini_schema_dict_with_only_object_type(self):
  result = to_gemini_schema({"type": "object"})
  assert isinstance(result, Schema)
  assert result.type == Type.OBJECT
- assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
+ assert result.properties is None
def test_to_gemini_schema_basic_types(self):
  openapi_schema = {
@@ -814,6 +812,42 @@ class TestToGeminiSchema:
  assert gemini_schema.properties["age"].type == Type.INTEGER
  assert gemini_schema.properties["is_active"].type == Type.BOOLEAN
def test_to_gemini_schema_array_string_types(self):
openapi_schema = {
"type": "object",
"properties": {
"boolean_field": {"type": "boolean"},
"nonnullable_string": {"type": ["string"]},
"nullable_string": {"type": ["string", "null"]},
"nullable_number": {"type": ["null", "integer"]},
"object_nullable": {"type": "null"},
"multi_types_nullable": {"type": ["string", "null", "integer"]},
"empty_default_object": {},
},
}
gemini_schema = to_gemini_schema(openapi_schema)
assert isinstance(gemini_schema, Schema)
assert gemini_schema.type == Type.OBJECT
assert gemini_schema.properties["boolean_field"].type == Type.BOOLEAN
assert gemini_schema.properties["nonnullable_string"].type == Type.STRING
assert not gemini_schema.properties["nonnullable_string"].nullable
assert gemini_schema.properties["nullable_string"].type == Type.STRING
assert gemini_schema.properties["nullable_string"].nullable
assert gemini_schema.properties["nullable_number"].type == Type.INTEGER
assert gemini_schema.properties["nullable_number"].nullable
assert gemini_schema.properties["object_nullable"].type == Type.OBJECT
assert gemini_schema.properties["object_nullable"].nullable
assert gemini_schema.properties["multi_types_nullable"].type == Type.STRING
assert gemini_schema.properties["multi_types_nullable"].nullable
assert gemini_schema.properties["empty_default_object"].type == Type.OBJECT
assert not gemini_schema.properties["empty_default_object"].nullable
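Reviewer note: the new test encodes the JSON Schema list-typed `type` convention — a "null" entry marks the field nullable, and the first non-null entry selects the Gemini type, with object as the fallback. A self-contained sketch of that rule in isolation (the helper name is hypothetical, not the library's API):

    def split_nullable_type(type_value):
        # Mirrors the mapping the assertions above exercise.
        types = type_value if isinstance(type_value, list) else [type_value]
        nullable = 'null' in types
        non_null = [t for t in types if t != 'null']
        primary = non_null[0] if non_null else 'object'  # bare "null" falls back to object
        return primary, nullable

    assert split_nullable_type(['string', 'null']) == ('string', True)
    assert split_nullable_type(['null', 'integer']) == ('integer', True)
    assert split_nullable_type('null') == ('object', True)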
def test_to_gemini_schema_nested_objects(self):
  openapi_schema = {
      "type": "object",
@@ -895,7 +929,15 @@ class TestToGeminiSchema:
def test_to_gemini_schema_nested_dict(self):
  openapi_schema = {
      "type": "object",
-     "properties": {"metadata": {"key1": "value1", "key2": 123}},
+     "properties": {
+         "metadata": {
+             "type": "object",
+             "properties": {
+                 "key1": {"type": "object"},
+                 "key2": {"type": "string"},
+             },
+         }
+     },
  }
  gemini_schema = to_gemini_schema(openapi_schema)
  # Since metadata is not properties nor item, it will call to_gemini_schema recursively.
@@ -903,9 +945,15 @@ class TestToGeminiSchema:
  assert (
      gemini_schema.properties["metadata"].type == Type.OBJECT
  )  # add object type by default
- assert gemini_schema.properties["metadata"].properties == {
-     "dummy_DO_NOT_GENERATE": Schema(type="string")
- }
+ assert len(gemini_schema.properties["metadata"].properties) == 2
+ assert (
+     gemini_schema.properties["metadata"].properties["key1"].type
+     == Type.OBJECT
+ )
+ assert (
+     gemini_schema.properties["metadata"].properties["key2"].type
+     == Type.STRING
+ )
def test_to_gemini_schema_ignore_title_default_format(self):
  openapi_schema = {

View File

@@ -0,0 +1,238 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
from google.adk.tools.function_tool import FunctionTool
import pytest
def function_for_testing_with_no_args():
"""Function for testing with no args."""
pass
async def async_function_for_testing_with_1_arg_and_tool_context(
arg1, tool_context
):
"""Async function for testing with 1 arge and tool context."""
assert arg1
assert tool_context
return arg1
async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
"""Async function for testing with 2 arge and no tool context."""
assert arg1
assert arg2
return arg1
def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
"""Function for testing with 1 arge and tool context."""
assert arg1
assert tool_context
return arg1
def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
"""Function for testing with 2 arge and no tool context."""
assert arg1
assert arg2
return arg1
async def async_function_for_testing_with_4_arg_and_no_tool_context(
arg1, arg2, arg3, arg4
):
"""Async function for testing with 4 args."""
pass
def function_for_testing_with_4_arg_and_no_tool_context(arg1, arg2, arg3, arg4):
"""Function for testing with 4 args."""
pass
def test_init():
"""Test that the FunctionTool is initialized correctly."""
tool = FunctionTool(function_for_testing_with_no_args)
assert tool.name == "function_for_testing_with_no_args"
assert tool.description == "Function for testing with no args."
assert tool.func == function_for_testing_with_no_args
@pytest.mark.asyncio
async def test_run_async_with_tool_context_async_func():
"""Test that run_async calls the function with tool_context when tool_context is in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_1_arg_and_tool_context)
args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_without_tool_context_async_func():
"""Test that run_async calls the function without tool_context when tool_context is not in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_with_tool_context_sync_func():
"""Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_1_arg_and_tool_context)
args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_without_tool_context_sync_func():
"""Test that run_async calls the function without tool_context when tool_context is not in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1"
@pytest.mark.asyncio
async def test_run_async_1_missing_arg_sync_func():
"""Test that run_async calls the function with 1 missing arg in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg1": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg2
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_1_missing_arg_async_func():
"""Test that run_async calls the function with 1 missing arg in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
args = {"arg2": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_3_missing_arg_sync_func():
"""Test that run_async calls the function with 3 missing args in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
args = {"arg2": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_3_missing_arg_async_func():
"""Test that run_async calls the function with 3 missing args in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
args = {"arg3": "test_value_1"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_missing_all_arg_sync_func():
"""Test that run_async calls the function with all missing args in signature (synchronous function)."""
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
args = {}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_missing_all_arg_async_func():
"""Test that run_async calls the function with all missing args in signature (async function)."""
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
args = {}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == {
"error": (
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
arg1
arg2
arg3
arg4
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
)
}
@pytest.mark.asyncio
async def test_run_async_with_optional_args_not_set_sync_func():
"""Test that run_async calls the function for sync funciton with optional args not set."""
def func_with_optional_args(arg1, arg2=None, *, arg3, arg4=None, **kwargs):
return f"{arg1},{arg3}"
tool = FunctionTool(func_with_optional_args)
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1,test_value_3"
@pytest.mark.asyncio
async def test_run_async_with_optional_args_not_set_async_func():
"""Test that run_async calls the function for async funciton with optional args not set."""
async def async_func_with_optional_args(
arg1, arg2=None, *, arg3, arg4=None, **kwargs
):
return f"{arg1},{arg3}"
tool = FunctionTool(async_func_with_optional_args)
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
result = await tool.run_async(args=args, tool_context=MagicMock())
assert result == "test_value_1,test_value_3"
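Reviewer note: every missing-argument test above hinges on the same rule — a parameter counts as mandatory when it has no default and is not variadic, and tool_context is injected by the framework rather than requested from the model. A self-contained sketch of that check (the helper name is hypothetical, not the library's internal API):

    import inspect

    def find_missing_mandatory_args(func, args: dict) -> list[str]:
        missing = []
        for name, param in inspect.signature(func).parameters.items():
            if name == 'tool_context':
                continue  # supplied by the framework, never by the model
            if param.kind in (inspect.Parameter.VAR_POSITIONAL,
                              inspect.Parameter.VAR_KEYWORD):
                continue  # *args / **kwargs are never mandatory
            if param.default is inspect.Parameter.empty and name not in args:
                missing.append(name)
        return missing

    def func_with_optional_args(arg1, arg2=None, *, arg3, arg4=None, **kwargs):
        return f"{arg1},{arg3}"

    # arg1 is provided, arg2/arg4 have defaults, kwargs is variadic -> only arg3 is missing.
    assert find_missing_mandatory_args(func_with_optional_args, {'arg1': 'x'}) == ['arg3']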