mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 09:51:25 -06:00
Merge branch 'main' into #247-OpenAPIToolSet-Considering-Required-parameters
This commit is contained in:
commit
f12300113d
59
.github/workflows/pyink.yml
vendored
Normal file
59
.github/workflows/pyink.yml
vendored
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
name: Check Pyink Formatting
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'src/**/*.py'
|
||||||
|
- 'tests/**/*.py'
|
||||||
|
- 'pyproject.toml'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
pyink-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.x'
|
||||||
|
|
||||||
|
- name: Install pyink
|
||||||
|
run: |
|
||||||
|
pip install pyink
|
||||||
|
|
||||||
|
- name: Detect changed Python files
|
||||||
|
id: detect_changes
|
||||||
|
run: |
|
||||||
|
git fetch origin ${{ github.base_ref }}
|
||||||
|
CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${{ github.base_ref }}...HEAD | grep -E '\.py$' || true)
|
||||||
|
echo "CHANGED_FILES=${CHANGED_FILES}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Run pyink on changed files
|
||||||
|
if: env.CHANGED_FILES != ''
|
||||||
|
run: |
|
||||||
|
echo "Changed Python files:"
|
||||||
|
echo "$CHANGED_FILES"
|
||||||
|
|
||||||
|
# Run pyink --check
|
||||||
|
set +e
|
||||||
|
pyink --check --config pyproject.toml $CHANGED_FILES
|
||||||
|
RESULT=$?
|
||||||
|
set -e
|
||||||
|
|
||||||
|
if [ $RESULT -ne 0 ]; then
|
||||||
|
echo ""
|
||||||
|
echo "❌ Pyink formatting check failed!"
|
||||||
|
echo "👉 To fix formatting, run locally:"
|
||||||
|
echo ""
|
||||||
|
echo " pyink --config pyproject.toml $CHANGED_FILES"
|
||||||
|
echo ""
|
||||||
|
exit $RESULT
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: No changed Python files detected
|
||||||
|
if: env.CHANGED_FILES == ''
|
||||||
|
run: |
|
||||||
|
echo "No Python files changed. Skipping pyink check."
|
8
.github/workflows/python-unit-tests.yml
vendored
8
.github/workflows/python-unit-tests.yml
vendored
@ -29,11 +29,13 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
uv venv .venv
|
uv venv .venv
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
uv sync --extra test
|
uv sync --extra test --extra eval
|
||||||
|
|
||||||
- name: Run unit tests with pytest
|
- name: Run unit tests with pytest
|
||||||
run: |
|
run: |
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
pytest tests/unittests \
|
pytest tests/unittests \
|
||||||
--ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py \
|
--ignore=tests/unittests/artifacts/test_artifact_service.py \
|
||||||
--ignore=tests/unittests/artifacts/test_artifact_service.py
|
--ignore=tests/unittests/tools/application_integration_tool/clients/test_connections_client.py \
|
||||||
|
--ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py
|
||||||
|
|
||||||
|
76
CHANGELOG.md
76
CHANGELOG.md
@ -1,17 +1,89 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## 0.3.0
|
||||||
|
|
||||||
|
### ⚠ BREAKING CHANGES
|
||||||
|
|
||||||
|
* Auth: expose `access_token` and `refresh_token` at top level of auth
|
||||||
|
credentails, instead of a `dict`
|
||||||
|
([commit](https://github.com/google/adk-python/commit/956fb912e8851b139668b1ccb8db10fd252a6990)).
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* Added support for running agents with MCPToolset easily on `adk web`.
|
||||||
|
* Added `custom_metadata` field to `LlmResponse`, which can be used to tag
|
||||||
|
LlmResponse via `after_model_callback`.
|
||||||
|
* Added `--session_db_url` to `adk deploy cloud_run` option.
|
||||||
|
* Many Dev UI improvements:
|
||||||
|
* Better google search result rendering.
|
||||||
|
* Show websocket close reason in Dev UI.
|
||||||
|
* Better error message showing for audio/video.
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
* Fixed MCP tool json schema parsing issue.
|
||||||
|
* Fixed issues in DatabaseSessionService that leads to crash.
|
||||||
|
* Fixed functions.py.
|
||||||
|
* Fixed `skip_summarization` behavior in `AgentTool`.
|
||||||
|
|
||||||
|
### Miscellaneous Chores
|
||||||
|
|
||||||
|
* README.md impprovements.
|
||||||
|
* Various code improvements.
|
||||||
|
* Various typo fixes.
|
||||||
|
* Bump min version of google-genai to 1.11.0.
|
||||||
|
|
||||||
|
## 0.2.0
|
||||||
|
|
||||||
|
### ⚠ BREAKING CHANGES
|
||||||
|
|
||||||
|
* Fix typo in method name in `Event`: has_trailing_code_exeuction_result --> has_trailing_code_execution_result.
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* `adk` CLI:
|
||||||
|
* Introduce `adk create` cli tool to help creating agents.
|
||||||
|
* Adds `--verbosity` option to `adk deploy cloud_run` to show detailed cloud
|
||||||
|
run deploy logging.
|
||||||
|
* Improve the initialization error message for `DatabaseSessionService`.
|
||||||
|
* Lazy loading for Google 1P tools to minimize the initial latency.
|
||||||
|
* Support emitting state-change-only events from planners.
|
||||||
|
* Lots of Dev UI updates, including:
|
||||||
|
* Show planner thoughts and actions in the Dev UI.
|
||||||
|
* Support MCP tools in Dev UI.
|
||||||
|
(NOTE: `agent.py` interface is temp solution and is subject to change)
|
||||||
|
* Auto-select the only app if only one app is available.
|
||||||
|
* Show grounding links generated by Google Search Tool.
|
||||||
|
* `.env` file is reloaded on every agent run.
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
* `LiteLlm`: arg parsing error and python 3.9 compatibility.
|
||||||
|
* `DatabaseSessionService`: adds the missing fields; fixes event with empty
|
||||||
|
content not being persisted.
|
||||||
|
* Google API Discovery response parsing issue.
|
||||||
|
* `load_memory_tool` rendering issue in Dev UI.
|
||||||
|
* Markdown text overflows in Dev UI.
|
||||||
|
|
||||||
|
### Miscellaneous Chores
|
||||||
|
|
||||||
|
* Adds unit tests in Github action.
|
||||||
|
* Improves test coverage.
|
||||||
|
* Various typo fixes.
|
||||||
|
|
||||||
## 0.1.0
|
## 0.1.0
|
||||||
|
|
||||||
### Features
|
### Features
|
||||||
|
|
||||||
* Initial release of the Agent Development Kit (ADK).
|
* Initial release of the Agent Development Kit (ADK).
|
||||||
* Multi-agent, agent-as-workflow, and custom agent support
|
* Multi-agent, agent-as-workflow, and custom agent support
|
||||||
* Tool authentication support
|
* Tool authentication support
|
||||||
* Rich tool support, e.g. bult-in tools, google-cloud tools, thir-party tools, and MCP tools
|
* Rich tool support, e.g. built-in tools, google-cloud tools, third-party tools, and MCP tools
|
||||||
* Rich callback support
|
* Rich callback support
|
||||||
* Built-in code execution capability
|
* Built-in code execution capability
|
||||||
* Asynchronous runtime and execution
|
* Asynchronous runtime and execution
|
||||||
* Session, and memory support
|
* Session, and memory support
|
||||||
* Built-in evaluation support
|
* Built-in evaluation support
|
||||||
* Development UI that makes local devlopment easy
|
* Development UI that makes local development easy
|
||||||
* Deploy to Google Cloud Run, Agent Engine
|
* Deploy to Google Cloud Run, Agent Engine
|
||||||
* (Experimental) Live(Bidi) auido/video agent support and Compositional Function Calling(CFC) support
|
* (Experimental) Live(Bidi) auido/video agent support and Compositional Function Calling(CFC) support
|
||||||
|
@ -25,6 +25,19 @@ This project follows
|
|||||||
|
|
||||||
## Contribution process
|
## Contribution process
|
||||||
|
|
||||||
|
### Requirement for PRs
|
||||||
|
|
||||||
|
- All PRs, other than small documentation or typo fixes, should have a Issue assoicated. If not, please create one.
|
||||||
|
- Small, focused PRs. Keep changes minimal—one concern per PR.
|
||||||
|
- For bug fixes or features, please provide logs or screenshot after the fix is applied to help reviewers better understand the fix.
|
||||||
|
- Please add corresponding testing for your code change if it's not covered by existing tests.
|
||||||
|
|
||||||
|
### Large or Complex Changes
|
||||||
|
For substantial features or architectural revisions:
|
||||||
|
|
||||||
|
- Open an Issue First: Outline your proposal, including design considerations and impact.
|
||||||
|
- Gather Feedback: Discuss with maintainers and the community to ensure alignment and avoid duplicate work
|
||||||
|
|
||||||
### Code reviews
|
### Code reviews
|
||||||
|
|
||||||
All submissions, including submissions by project members, require review. We
|
All submissions, including submissions by project members, require review. We
|
||||||
|
33
README.md
33
README.md
@ -18,11 +18,7 @@
|
|||||||
</h3>
|
</h3>
|
||||||
</html>
|
</html>
|
||||||
|
|
||||||
Agent Development Kit (ADK) is designed for developers seeking fine-grained
|
Agent Development Kit (ADK) is a flexible and modular framework for developing and deploying AI agents. While optimized for Gemini and the Google ecosystem, ADK is model-agnostic, deployment-agnostic, and is built for compatibility with other frameworks. ADK was designed to make agent development feel more like software development, to make it easier for developers to create, deploy, and orchestrate agentic architectures that range from simple tasks to complex workflows.
|
||||||
control and flexibility when building advanced AI agents that are tightly
|
|
||||||
integrated with services in Google Cloud. It allows you to define agent
|
|
||||||
behavior, orchestration, and tool use directly in code, enabling robust
|
|
||||||
debugging, versioning, and deployment anywhere – from your laptop to the cloud.
|
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
@ -45,12 +41,27 @@ debugging, versioning, and deployment anywhere – from your laptop to the cloud
|
|||||||
|
|
||||||
## 🚀 Installation
|
## 🚀 Installation
|
||||||
|
|
||||||
You can install the ADK using `pip`:
|
### Stable Release (Recommended)
|
||||||
|
|
||||||
|
You can install the latest stable version of ADK using `pip`:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install google-adk
|
pip install google-adk
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The release cadence is weekly.
|
||||||
|
|
||||||
|
This version is recommended for most users as it represents the most recent official release.
|
||||||
|
|
||||||
|
### Development Version
|
||||||
|
Bug fixes and new features are merged into the main branch on GitHub first. If you need access to changes that haven't been included in an official PyPI release yet, you can install directly from the main branch:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install git+https://github.com/google/adk-python.git@main
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: The development version is built directly from the latest code commits. While it includes the newest fixes and features, it may also contain experimental changes or bugs not present in the stable release. Use it primarily for testing upcoming changes or accessing critical fixes before they are officially released.
|
||||||
|
|
||||||
## 📚 Documentation
|
## 📚 Documentation
|
||||||
|
|
||||||
Explore the full documentation for detailed guides on building, evaluating, and
|
Explore the full documentation for detailed guides on building, evaluating, and
|
||||||
@ -112,10 +123,18 @@ adk eval \
|
|||||||
samples_for_testing/hello_world/hello_world_eval_set_001.evalset.json
|
samples_for_testing/hello_world/hello_world_eval_set_001.evalset.json
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 🤖 A2A and ADK integration
|
||||||
|
|
||||||
|
For remote agent-to-agent communication, ADK integrates with the
|
||||||
|
[A2A protocol](https://github.com/google/A2A/).
|
||||||
|
See this [example](https://github.com/google/A2A/tree/main/samples/python/agents/google_adk)
|
||||||
|
for how they can work together.
|
||||||
|
|
||||||
## 🤝 Contributing
|
## 🤝 Contributing
|
||||||
|
|
||||||
We welcome contributions from the community! Whether it's bug reports, feature requests, documentation improvements, or code contributions, please see our [**Contributing Guidelines**](./CONTRIBUTING.md) to get started.
|
We welcome contributions from the community! Whether it's bug reports, feature requests, documentation improvements, or code contributions, please see our
|
||||||
|
- [General contribution guideline and flow](https://google.github.io/adk-docs/contributing-guide/#questions).
|
||||||
|
- Then if you want to contribute code, please read [Code Contributing Guidelines](./CONTRIBUTING.md) to get started.
|
||||||
|
|
||||||
## 📄 License
|
## 📄 License
|
||||||
|
|
||||||
|
2
pylintrc
2
pylintrc
@ -45,7 +45,7 @@ confidence=
|
|||||||
# can either give multiple identifiers separated by comma (,) or put this
|
# can either give multiple identifiers separated by comma (,) or put this
|
||||||
# option multiple times (only on the command line, not in the configuration
|
# option multiple times (only on the command line, not in the configuration
|
||||||
# file where it should appear only once).You can also use "--disable=all" to
|
# file where it should appear only once).You can also use "--disable=all" to
|
||||||
# disable everything first and then reenable specific checks. For example, if
|
# disable everything first and then re-enable specific checks. For example, if
|
||||||
# you want to run only the similarities checker, you can use "--disable=all
|
# you want to run only the similarities checker, you can use "--disable=all
|
||||||
# --enable=similarities". If you want to run only the classes checker, but have
|
# --enable=similarities". If you want to run only the classes checker, but have
|
||||||
# no Warning level messages displayed, use"--disable=all --enable=classes
|
# no Warning level messages displayed, use"--disable=all --enable=classes
|
||||||
|
@ -33,7 +33,7 @@ dependencies = [
|
|||||||
"google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool
|
"google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool
|
||||||
"google-cloud-speech>=2.30.0", # For Audo Transcription
|
"google-cloud-speech>=2.30.0", # For Audo Transcription
|
||||||
"google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
|
"google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
|
||||||
"google-genai>=1.9.0", # Google GenAI SDK
|
"google-genai>=1.11.0", # Google GenAI SDK
|
||||||
"graphviz>=0.20.2", # Graphviz for graph rendering
|
"graphviz>=0.20.2", # Graphviz for graph rendering
|
||||||
"mcp>=1.5.0;python_version>='3.10'", # For MCP Toolset
|
"mcp>=1.5.0;python_version>='3.10'", # For MCP Toolset
|
||||||
"opentelemetry-api>=1.31.0", # OpenTelemetry
|
"opentelemetry-api>=1.31.0", # OpenTelemetry
|
||||||
@ -119,6 +119,15 @@ line-length = 80
|
|||||||
unstable = true
|
unstable = true
|
||||||
pyink-indentation = 2
|
pyink-indentation = 2
|
||||||
pyink-use-majority-quotes = true
|
pyink-use-majority-quotes = true
|
||||||
|
pyink-annotation-pragmas = [
|
||||||
|
"noqa",
|
||||||
|
"pylint:",
|
||||||
|
"type: ignore",
|
||||||
|
"pytype:",
|
||||||
|
"mypy:",
|
||||||
|
"pyright:",
|
||||||
|
"pyre-",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
@ -135,15 +144,10 @@ exclude = ['src/**/*.sh']
|
|||||||
name = "google.adk"
|
name = "google.adk"
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
# Organize imports following Google style-guide
|
profile = "google"
|
||||||
force_single_line = true
|
|
||||||
force_sort_within_sections = true
|
|
||||||
honor_case_in_force_sorted_sections = true
|
|
||||||
order_by_type = false
|
|
||||||
sort_relative_in_force_sorted_sections = true
|
|
||||||
multi_line_output = 3
|
|
||||||
line_length = 200
|
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
asyncio_default_fixture_loop_scope = "function"
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
@ -44,7 +44,7 @@ Args:
|
|||||||
callback_context: MUST be named 'callback_context' (enforced).
|
callback_context: MUST be named 'callback_context' (enforced).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The content to return to the user. When set, the agent run will skipped and
|
The content to return to the user. When set, the agent run will be skipped and
|
||||||
the provided content will be returned to user.
|
the provided content will be returned to user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -55,8 +55,8 @@ Args:
|
|||||||
callback_context: MUST be named 'callback_context' (enforced).
|
callback_context: MUST be named 'callback_context' (enforced).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The content to return to the user. When set, the agent run will skipped and
|
The content to return to the user. When set, the provided content will be
|
||||||
the provided content will be appended to event history as agent response.
|
appended to event history as agent response.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -101,8 +101,8 @@ class BaseAgent(BaseModel):
|
|||||||
callback_context: MUST be named 'callback_context' (enforced).
|
callback_context: MUST be named 'callback_context' (enforced).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The content to return to the user. When set, the agent run will skipped and
|
The content to return to the user. When set, the agent run will be skipped
|
||||||
the provided content will be returned to user.
|
and the provided content will be returned to user.
|
||||||
"""
|
"""
|
||||||
after_agent_callback: Optional[AfterAgentCallback] = None
|
after_agent_callback: Optional[AfterAgentCallback] = None
|
||||||
"""Callback signature that is invoked after the agent run.
|
"""Callback signature that is invoked after the agent run.
|
||||||
@ -111,8 +111,8 @@ class BaseAgent(BaseModel):
|
|||||||
callback_context: MUST be named 'callback_context' (enforced).
|
callback_context: MUST be named 'callback_context' (enforced).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The content to return to the user. When set, the agent run will skipped and
|
The content to return to the user. When set, the provided content will be
|
||||||
the provided content will be appended to event history as agent response.
|
appended to event history as agent response.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@final
|
@final
|
||||||
|
@ -23,7 +23,6 @@ from .readonly_context import ReadonlyContext
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
from ..events.event import Event
|
|
||||||
from ..events.event_actions import EventActions
|
from ..events.event_actions import EventActions
|
||||||
from ..sessions.state import State
|
from ..sessions.state import State
|
||||||
from .invocation_context import InvocationContext
|
from .invocation_context import InvocationContext
|
||||||
|
@ -15,12 +15,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union
|
||||||
from typing import AsyncGenerator
|
|
||||||
from typing import Callable
|
|
||||||
from typing import Literal
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -62,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[
|
|||||||
]
|
]
|
||||||
BeforeToolCallback: TypeAlias = Callable[
|
BeforeToolCallback: TypeAlias = Callable[
|
||||||
[BaseTool, dict[str, Any], ToolContext],
|
[BaseTool, dict[str, Any], ToolContext],
|
||||||
Optional[dict],
|
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||||
]
|
]
|
||||||
AfterToolCallback: TypeAlias = Callable[
|
AfterToolCallback: TypeAlias = Callable[
|
||||||
[BaseTool, dict[str, Any], ToolContext, dict],
|
[BaseTool, dict[str, Any], ToolContext, dict],
|
||||||
Optional[dict],
|
Union[Awaitable[Optional[dict]], Optional[dict]],
|
||||||
]
|
]
|
||||||
|
|
||||||
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
||||||
|
@ -66,7 +66,8 @@ class OAuth2Auth(BaseModelWithConfig):
|
|||||||
redirect_uri: Optional[str] = None
|
redirect_uri: Optional[str] = None
|
||||||
auth_response_uri: Optional[str] = None
|
auth_response_uri: Optional[str] = None
|
||||||
auth_code: Optional[str] = None
|
auth_code: Optional[str] = None
|
||||||
token: Optional[Dict[str, Any]] = None
|
access_token: Optional[str] = None
|
||||||
|
refresh_token: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class ServiceAccountCredential(BaseModelWithConfig):
|
class ServiceAccountCredential(BaseModelWithConfig):
|
||||||
|
@ -82,7 +82,8 @@ class AuthHandler:
|
|||||||
or not auth_credential.oauth2
|
or not auth_credential.oauth2
|
||||||
or not auth_credential.oauth2.client_id
|
or not auth_credential.oauth2.client_id
|
||||||
or not auth_credential.oauth2.client_secret
|
or not auth_credential.oauth2.client_secret
|
||||||
or auth_credential.oauth2.token
|
or auth_credential.oauth2.access_token
|
||||||
|
or auth_credential.oauth2.refresh_token
|
||||||
):
|
):
|
||||||
return self.auth_config.exchanged_auth_credential
|
return self.auth_config.exchanged_auth_credential
|
||||||
|
|
||||||
@ -93,7 +94,7 @@ class AuthHandler:
|
|||||||
redirect_uri=auth_credential.oauth2.redirect_uri,
|
redirect_uri=auth_credential.oauth2.redirect_uri,
|
||||||
state=auth_credential.oauth2.state,
|
state=auth_credential.oauth2.state,
|
||||||
)
|
)
|
||||||
token = client.fetch_token(
|
tokens = client.fetch_token(
|
||||||
token_endpoint,
|
token_endpoint,
|
||||||
authorization_response=auth_credential.oauth2.auth_response_uri,
|
authorization_response=auth_credential.oauth2.auth_response_uri,
|
||||||
code=auth_credential.oauth2.auth_code,
|
code=auth_credential.oauth2.auth_code,
|
||||||
@ -102,7 +103,10 @@ class AuthHandler:
|
|||||||
|
|
||||||
updated_credential = AuthCredential(
|
updated_credential = AuthCredential(
|
||||||
auth_type=AuthCredentialTypes.OAUTH2,
|
auth_type=AuthCredentialTypes.OAUTH2,
|
||||||
oauth2=OAuth2Auth(token=dict(token)),
|
oauth2=OAuth2Auth(
|
||||||
|
access_token=tokens.get("access_token"),
|
||||||
|
refresh_token=tokens.get("refresh_token"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return updated_credential
|
return updated_credential
|
||||||
|
|
||||||
|
@ -29,5 +29,5 @@
|
|||||||
<style>html{color-scheme:dark}html{--mat-sys-background:light-dark(#fcf9f8, #131314);--mat-sys-error:light-dark(#ba1a1a, #ffb4ab);--mat-sys-error-container:light-dark(#ffdad6, #93000a);--mat-sys-inverse-on-surface:light-dark(#f3f0f0, #313030);--mat-sys-inverse-primary:light-dark(#c1c7cd, #595f65);--mat-sys-inverse-surface:light-dark(#313030, #e5e2e2);--mat-sys-on-background:light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-error:light-dark(#ffffff, #690005);--mat-sys-on-error-container:light-dark(#410002, #ffdad6);--mat-sys-on-primary:light-dark(#ffffff, #2b3136);--mat-sys-on-primary-container:light-dark(#161c21, #dde3e9);--mat-sys-on-primary-fixed:light-dark(#161c21, #161c21);--mat-sys-on-primary-fixed-variant:light-dark(#41474d, #41474d);--mat-sys-on-secondary:light-dark(#ffffff, #003061);--mat-sys-on-secondary-container:light-dark(#001b3c, #d5e3ff);--mat-sys-on-secondary-fixed:light-dark(#001b3c, #001b3c);--mat-sys-on-secondary-fixed-variant:light-dark(#0f4784, #0f4784);--mat-sys-on-surface:light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-surface-variant:light-dark(#44474a, #e1e2e6);--mat-sys-on-tertiary:light-dark(#ffffff, #2b3136);--mat-sys-on-tertiary-container:light-dark(#161c21, #dde3e9);--mat-sys-on-tertiary-fixed:light-dark(#161c21, #161c21);--mat-sys-on-tertiary-fixed-variant:light-dark(#41474d, #41474d);--mat-sys-outline:light-dark(#74777b, #8e9194);--mat-sys-outline-variant:light-dark(#c4c7ca, #44474a);--mat-sys-primary:light-dark(#595f65, #c1c7cd);--mat-sys-primary-container:light-dark(#dde3e9, #41474d);--mat-sys-primary-fixed:light-dark(#dde3e9, #dde3e9);--mat-sys-primary-fixed-dim:light-dark(#c1c7cd, #c1c7cd);--mat-sys-scrim:light-dark(#000000, #000000);--mat-sys-secondary:light-dark(#305f9d, #a7c8ff);--mat-sys-secondary-container:light-dark(#d5e3ff, #0f4784);--mat-sys-secondary-fixed:light-dark(#d5e3ff, #d5e3ff);--mat-sys-secondary-fixed-dim:light-dark(#a7c8ff, #a7c8ff);--mat-sys-shadow:light-dark(#000000, #000000);--mat-sys-surface:light-dark(#fcf9f8, #131314);--mat-sys-surface-bright:light-dark(#fcf9f8, #393939);--mat-sys-surface-container:light-dark(#f0eded, #201f20);--mat-sys-surface-container-high:light-dark(#eae7e7, #2a2a2a);--mat-sys-surface-container-highest:light-dark(#e5e2e2, #393939);--mat-sys-surface-container-low:light-dark(#f6f3f3, #1c1b1c);--mat-sys-surface-container-lowest:light-dark(#ffffff, #0e0e0e);--mat-sys-surface-dim:light-dark(#dcd9d9, #131314);--mat-sys-surface-tint:light-dark(#595f65, #c1c7cd);--mat-sys-surface-variant:light-dark(#e1e2e6, #44474a);--mat-sys-tertiary:light-dark(#595f65, #c1c7cd);--mat-sys-tertiary-container:light-dark(#dde3e9, #41474d);--mat-sys-tertiary-fixed:light-dark(#dde3e9, #dde3e9);--mat-sys-tertiary-fixed-dim:light-dark(#c1c7cd, #c1c7cd);--mat-sys-neutral-variant20:#2d3134;--mat-sys-neutral10:#1c1b1c}html{--mat-sys-level0:0px 0px 0px 0px rgba(0, 0, 0, .2), 0px 0px 0px 0px rgba(0, 0, 0, .14), 0px 0px 0px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level1:0px 2px 1px -1px rgba(0, 0, 0, .2), 0px 1px 1px 0px rgba(0, 0, 0, .14), 0px 1px 3px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level2:0px 3px 3px -2px rgba(0, 0, 0, .2), 0px 3px 4px 0px rgba(0, 0, 0, .14), 0px 1px 8px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level3:0px 3px 5px -1px rgba(0, 0, 0, .2), 0px 6px 10px 0px rgba(0, 0, 0, .14), 0px 1px 18px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level4:0px 5px 5px -3px rgba(0, 0, 0, .2), 0px 8px 10px 1px rgba(0, 0, 0, .14), 0px 3px 14px 2px rgba(0, 0, 0, .12)}html{--mat-sys-level5:0px 7px 8px -4px rgba(0, 0, 0, .2), 0px 12px 17px 2px rgba(0, 0, 0, .14), 0px 5px 22px 4px rgba(0, 0, 0, .12)}html{--mat-sys-corner-extra-large:28px;--mat-sys-corner-extra-large-top:28px 28px 0 0;--mat-sys-corner-extra-small:4px;--mat-sys-corner-extra-small-top:4px 4px 0 0;--mat-sys-corner-full:9999px;--mat-sys-corner-large:16px;--mat-sys-corner-large-end:0 16px 16px 0;--mat-sys-corner-large-start:16px 0 0 16px;--mat-sys-corner-large-top:16px 16px 0 0;--mat-sys-corner-medium:12px;--mat-sys-corner-none:0;--mat-sys-corner-small:8px}html{--mat-sys-dragged-state-layer-opacity:.16;--mat-sys-focus-state-layer-opacity:.12;--mat-sys-hover-state-layer-opacity:.08;--mat-sys-pressed-state-layer-opacity:.12}html{font-family:Google Sans,Helvetica Neue,sans-serif!important}body{height:100vh;margin:0}:root{--mat-sys-primary:black;--mdc-checkbox-selected-icon-color:white;--mat-sys-background:#131314;--mat-tab-header-active-label-text-color:#8AB4F8;--mat-tab-header-active-hover-label-text-color:#8AB4F8;--mat-tab-header-active-focus-label-text-color:#8AB4F8;--mat-tab-header-label-text-weight:500;--mdc-text-button-label-text-color:#89b4f8}:root{--mdc-dialog-container-color:#2b2b2f}:root{--mdc-dialog-subhead-color:white}:root{--mdc-circular-progress-active-indicator-color:#a8c7fa}:root{--mdc-circular-progress-size:80}</style><link rel="stylesheet" href="styles-4VDSPQ37.css" media="print" onload="this.media='all'"><noscript><link rel="stylesheet" href="styles-4VDSPQ37.css"></noscript></head>
|
<style>html{color-scheme:dark}html{--mat-sys-background:light-dark(#fcf9f8, #131314);--mat-sys-error:light-dark(#ba1a1a, #ffb4ab);--mat-sys-error-container:light-dark(#ffdad6, #93000a);--mat-sys-inverse-on-surface:light-dark(#f3f0f0, #313030);--mat-sys-inverse-primary:light-dark(#c1c7cd, #595f65);--mat-sys-inverse-surface:light-dark(#313030, #e5e2e2);--mat-sys-on-background:light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-error:light-dark(#ffffff, #690005);--mat-sys-on-error-container:light-dark(#410002, #ffdad6);--mat-sys-on-primary:light-dark(#ffffff, #2b3136);--mat-sys-on-primary-container:light-dark(#161c21, #dde3e9);--mat-sys-on-primary-fixed:light-dark(#161c21, #161c21);--mat-sys-on-primary-fixed-variant:light-dark(#41474d, #41474d);--mat-sys-on-secondary:light-dark(#ffffff, #003061);--mat-sys-on-secondary-container:light-dark(#001b3c, #d5e3ff);--mat-sys-on-secondary-fixed:light-dark(#001b3c, #001b3c);--mat-sys-on-secondary-fixed-variant:light-dark(#0f4784, #0f4784);--mat-sys-on-surface:light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-surface-variant:light-dark(#44474a, #e1e2e6);--mat-sys-on-tertiary:light-dark(#ffffff, #2b3136);--mat-sys-on-tertiary-container:light-dark(#161c21, #dde3e9);--mat-sys-on-tertiary-fixed:light-dark(#161c21, #161c21);--mat-sys-on-tertiary-fixed-variant:light-dark(#41474d, #41474d);--mat-sys-outline:light-dark(#74777b, #8e9194);--mat-sys-outline-variant:light-dark(#c4c7ca, #44474a);--mat-sys-primary:light-dark(#595f65, #c1c7cd);--mat-sys-primary-container:light-dark(#dde3e9, #41474d);--mat-sys-primary-fixed:light-dark(#dde3e9, #dde3e9);--mat-sys-primary-fixed-dim:light-dark(#c1c7cd, #c1c7cd);--mat-sys-scrim:light-dark(#000000, #000000);--mat-sys-secondary:light-dark(#305f9d, #a7c8ff);--mat-sys-secondary-container:light-dark(#d5e3ff, #0f4784);--mat-sys-secondary-fixed:light-dark(#d5e3ff, #d5e3ff);--mat-sys-secondary-fixed-dim:light-dark(#a7c8ff, #a7c8ff);--mat-sys-shadow:light-dark(#000000, #000000);--mat-sys-surface:light-dark(#fcf9f8, #131314);--mat-sys-surface-bright:light-dark(#fcf9f8, #393939);--mat-sys-surface-container:light-dark(#f0eded, #201f20);--mat-sys-surface-container-high:light-dark(#eae7e7, #2a2a2a);--mat-sys-surface-container-highest:light-dark(#e5e2e2, #393939);--mat-sys-surface-container-low:light-dark(#f6f3f3, #1c1b1c);--mat-sys-surface-container-lowest:light-dark(#ffffff, #0e0e0e);--mat-sys-surface-dim:light-dark(#dcd9d9, #131314);--mat-sys-surface-tint:light-dark(#595f65, #c1c7cd);--mat-sys-surface-variant:light-dark(#e1e2e6, #44474a);--mat-sys-tertiary:light-dark(#595f65, #c1c7cd);--mat-sys-tertiary-container:light-dark(#dde3e9, #41474d);--mat-sys-tertiary-fixed:light-dark(#dde3e9, #dde3e9);--mat-sys-tertiary-fixed-dim:light-dark(#c1c7cd, #c1c7cd);--mat-sys-neutral-variant20:#2d3134;--mat-sys-neutral10:#1c1b1c}html{--mat-sys-level0:0px 0px 0px 0px rgba(0, 0, 0, .2), 0px 0px 0px 0px rgba(0, 0, 0, .14), 0px 0px 0px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level1:0px 2px 1px -1px rgba(0, 0, 0, .2), 0px 1px 1px 0px rgba(0, 0, 0, .14), 0px 1px 3px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level2:0px 3px 3px -2px rgba(0, 0, 0, .2), 0px 3px 4px 0px rgba(0, 0, 0, .14), 0px 1px 8px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level3:0px 3px 5px -1px rgba(0, 0, 0, .2), 0px 6px 10px 0px rgba(0, 0, 0, .14), 0px 1px 18px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level4:0px 5px 5px -3px rgba(0, 0, 0, .2), 0px 8px 10px 1px rgba(0, 0, 0, .14), 0px 3px 14px 2px rgba(0, 0, 0, .12)}html{--mat-sys-level5:0px 7px 8px -4px rgba(0, 0, 0, .2), 0px 12px 17px 2px rgba(0, 0, 0, .14), 0px 5px 22px 4px rgba(0, 0, 0, .12)}html{--mat-sys-corner-extra-large:28px;--mat-sys-corner-extra-large-top:28px 28px 0 0;--mat-sys-corner-extra-small:4px;--mat-sys-corner-extra-small-top:4px 4px 0 0;--mat-sys-corner-full:9999px;--mat-sys-corner-large:16px;--mat-sys-corner-large-end:0 16px 16px 0;--mat-sys-corner-large-start:16px 0 0 16px;--mat-sys-corner-large-top:16px 16px 0 0;--mat-sys-corner-medium:12px;--mat-sys-corner-none:0;--mat-sys-corner-small:8px}html{--mat-sys-dragged-state-layer-opacity:.16;--mat-sys-focus-state-layer-opacity:.12;--mat-sys-hover-state-layer-opacity:.08;--mat-sys-pressed-state-layer-opacity:.12}html{font-family:Google Sans,Helvetica Neue,sans-serif!important}body{height:100vh;margin:0}:root{--mat-sys-primary:black;--mdc-checkbox-selected-icon-color:white;--mat-sys-background:#131314;--mat-tab-header-active-label-text-color:#8AB4F8;--mat-tab-header-active-hover-label-text-color:#8AB4F8;--mat-tab-header-active-focus-label-text-color:#8AB4F8;--mat-tab-header-label-text-weight:500;--mdc-text-button-label-text-color:#89b4f8}:root{--mdc-dialog-container-color:#2b2b2f}:root{--mdc-dialog-subhead-color:white}:root{--mdc-circular-progress-active-indicator-color:#a8c7fa}:root{--mdc-circular-progress-size:80}</style><link rel="stylesheet" href="styles-4VDSPQ37.css" media="print" onload="this.media='all'"><noscript><link rel="stylesheet" href="styles-4VDSPQ37.css"></noscript></head>
|
||||||
<body>
|
<body>
|
||||||
<app-root></app-root>
|
<app-root></app-root>
|
||||||
<script src="polyfills-FFHMD2TL.js" type="module"></script><script src="main-ZBO76GRM.js" type="module"></script></body>
|
<script src="polyfills-FFHMD2TL.js" type="module"></script><script src="main-HWIBUY2R.js" type="module"></script></body>
|
||||||
</html>
|
</html>
|
||||||
|
File diff suppressed because one or more lines are too long
@ -39,12 +39,12 @@ class InputFile(BaseModel):
|
|||||||
|
|
||||||
async def run_input_file(
|
async def run_input_file(
|
||||||
app_name: str,
|
app_name: str,
|
||||||
|
user_id: str,
|
||||||
root_agent: LlmAgent,
|
root_agent: LlmAgent,
|
||||||
artifact_service: BaseArtifactService,
|
artifact_service: BaseArtifactService,
|
||||||
session: Session,
|
|
||||||
session_service: BaseSessionService,
|
session_service: BaseSessionService,
|
||||||
input_path: str,
|
input_path: str,
|
||||||
) -> None:
|
) -> Session:
|
||||||
runner = Runner(
|
runner = Runner(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
agent=root_agent,
|
agent=root_agent,
|
||||||
@ -55,9 +55,11 @@ async def run_input_file(
|
|||||||
input_file = InputFile.model_validate_json(f.read())
|
input_file = InputFile.model_validate_json(f.read())
|
||||||
input_file.state['_time'] = datetime.now()
|
input_file.state['_time'] = datetime.now()
|
||||||
|
|
||||||
session.state = input_file.state
|
session = session_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id, state=input_file.state
|
||||||
|
)
|
||||||
for query in input_file.queries:
|
for query in input_file.queries:
|
||||||
click.echo(f'user: {query}')
|
click.echo(f'[user]: {query}')
|
||||||
content = types.Content(role='user', parts=[types.Part(text=query)])
|
content = types.Content(role='user', parts=[types.Part(text=query)])
|
||||||
async for event in runner.run_async(
|
async for event in runner.run_async(
|
||||||
user_id=session.user_id, session_id=session.id, new_message=content
|
user_id=session.user_id, session_id=session.id, new_message=content
|
||||||
@ -65,23 +67,23 @@ async def run_input_file(
|
|||||||
if event.content and event.content.parts:
|
if event.content and event.content.parts:
|
||||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||||
click.echo(f'[{event.author}]: {text}')
|
click.echo(f'[{event.author}]: {text}')
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
async def run_interactively(
|
async def run_interactively(
|
||||||
app_name: str,
|
|
||||||
root_agent: LlmAgent,
|
root_agent: LlmAgent,
|
||||||
artifact_service: BaseArtifactService,
|
artifact_service: BaseArtifactService,
|
||||||
session: Session,
|
session: Session,
|
||||||
session_service: BaseSessionService,
|
session_service: BaseSessionService,
|
||||||
) -> None:
|
) -> None:
|
||||||
runner = Runner(
|
runner = Runner(
|
||||||
app_name=app_name,
|
app_name=session.app_name,
|
||||||
agent=root_agent,
|
agent=root_agent,
|
||||||
artifact_service=artifact_service,
|
artifact_service=artifact_service,
|
||||||
session_service=session_service,
|
session_service=session_service,
|
||||||
)
|
)
|
||||||
while True:
|
while True:
|
||||||
query = input('user: ')
|
query = input('[user]: ')
|
||||||
if not query or not query.strip():
|
if not query or not query.strip():
|
||||||
continue
|
continue
|
||||||
if query == 'exit':
|
if query == 'exit':
|
||||||
@ -100,7 +102,8 @@ async def run_cli(
|
|||||||
*,
|
*,
|
||||||
agent_parent_dir: str,
|
agent_parent_dir: str,
|
||||||
agent_folder_name: str,
|
agent_folder_name: str,
|
||||||
json_file_path: Optional[str] = None,
|
input_file: Optional[str] = None,
|
||||||
|
saved_session_file: Optional[str] = None,
|
||||||
save_session: bool,
|
save_session: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Runs an interactive CLI for a certain agent.
|
"""Runs an interactive CLI for a certain agent.
|
||||||
@ -109,8 +112,11 @@ async def run_cli(
|
|||||||
agent_parent_dir: str, the absolute path of the parent folder of the agent
|
agent_parent_dir: str, the absolute path of the parent folder of the agent
|
||||||
folder.
|
folder.
|
||||||
agent_folder_name: str, the name of the agent folder.
|
agent_folder_name: str, the name of the agent folder.
|
||||||
json_file_path: Optional[str], the absolute path to the json file, either
|
input_file: Optional[str], the absolute path to the json file that contains
|
||||||
*.input.json or *.session.json.
|
the initial session state and user queries, exclusive with
|
||||||
|
saved_session_file.
|
||||||
|
saved_session_file: Optional[str], the absolute path to the json file that
|
||||||
|
contains a previously saved session, exclusive with input_file.
|
||||||
save_session: bool, whether to save the session on exit.
|
save_session: bool, whether to save the session on exit.
|
||||||
"""
|
"""
|
||||||
if agent_parent_dir not in sys.path:
|
if agent_parent_dir not in sys.path:
|
||||||
@ -118,46 +124,50 @@ async def run_cli(
|
|||||||
|
|
||||||
artifact_service = InMemoryArtifactService()
|
artifact_service = InMemoryArtifactService()
|
||||||
session_service = InMemorySessionService()
|
session_service = InMemorySessionService()
|
||||||
session = session_service.create_session(
|
|
||||||
app_name=agent_folder_name, user_id='test_user'
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
|
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
|
||||||
agent_module = importlib.import_module(agent_folder_name)
|
agent_module = importlib.import_module(agent_folder_name)
|
||||||
|
user_id = 'test_user'
|
||||||
|
session = session_service.create_session(
|
||||||
|
app_name=agent_folder_name, user_id=user_id
|
||||||
|
)
|
||||||
root_agent = agent_module.agent.root_agent
|
root_agent = agent_module.agent.root_agent
|
||||||
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
|
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
|
||||||
if json_file_path:
|
if input_file:
|
||||||
if json_file_path.endswith('.input.json'):
|
session = await run_input_file(
|
||||||
await run_input_file(
|
app_name=agent_folder_name,
|
||||||
app_name=agent_folder_name,
|
user_id=user_id,
|
||||||
root_agent=root_agent,
|
root_agent=root_agent,
|
||||||
artifact_service=artifact_service,
|
artifact_service=artifact_service,
|
||||||
session=session,
|
session_service=session_service,
|
||||||
session_service=session_service,
|
input_path=input_file,
|
||||||
input_path=json_file_path,
|
)
|
||||||
)
|
elif saved_session_file:
|
||||||
elif json_file_path.endswith('.session.json'):
|
|
||||||
with open(json_file_path, 'r') as f:
|
loaded_session = None
|
||||||
session = Session.model_validate_json(f.read())
|
with open(saved_session_file, 'r') as f:
|
||||||
for content in session.get_contents():
|
loaded_session = Session.model_validate_json(f.read())
|
||||||
if content.role == 'user':
|
|
||||||
print('user: ', content.parts[0].text)
|
if loaded_session:
|
||||||
|
for event in loaded_session.events:
|
||||||
|
session_service.append_event(session, event)
|
||||||
|
content = event.content
|
||||||
|
if not content or not content.parts or not content.parts[0].text:
|
||||||
|
continue
|
||||||
|
if event.author == 'user':
|
||||||
|
click.echo(f'[user]: {content.parts[0].text}')
|
||||||
else:
|
else:
|
||||||
print(content.parts[0].text)
|
click.echo(f'[{event.author}]: {content.parts[0].text}')
|
||||||
await run_interactively(
|
|
||||||
agent_folder_name,
|
await run_interactively(
|
||||||
root_agent,
|
root_agent,
|
||||||
artifact_service,
|
artifact_service,
|
||||||
session,
|
session,
|
||||||
session_service,
|
session_service,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(f'Unsupported file type: {json_file_path}')
|
click.echo(f'Running agent {root_agent.name}, type exit to exit.')
|
||||||
exit(1)
|
|
||||||
else:
|
|
||||||
print(f'Running agent {root_agent.name}, type exit to exit.')
|
|
||||||
await run_interactively(
|
await run_interactively(
|
||||||
agent_folder_name,
|
|
||||||
root_agent,
|
root_agent,
|
||||||
artifact_service,
|
artifact_service,
|
||||||
session,
|
session,
|
||||||
@ -165,11 +175,8 @@ async def run_cli(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if save_session:
|
if save_session:
|
||||||
if json_file_path:
|
session_id = input('Session ID to save: ')
|
||||||
session_path = json_file_path.replace('.input.json', '.session.json')
|
session_path = f'{agent_module_path}/{session_id}.session.json'
|
||||||
else:
|
|
||||||
session_id = input('Session ID to save: ')
|
|
||||||
session_path = f'{agent_module_path}/{session_id}.session.json'
|
|
||||||
|
|
||||||
# Fetch the session again to get all the details.
|
# Fetch the session again to get all the details.
|
||||||
session = session_service.get_session(
|
session = session_service.get_session(
|
||||||
|
@ -54,7 +54,7 @@ COPY "agents/{app_name}/" "/app/agents/{app_name}/"
|
|||||||
|
|
||||||
EXPOSE {port}
|
EXPOSE {port}
|
||||||
|
|
||||||
CMD adk {command} --port={port} {trace_to_cloud_option} "/app/agents"
|
CMD adk {command} --port={port} {session_db_option} {trace_to_cloud_option} "/app/agents"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -85,6 +85,7 @@ def to_cloud_run(
|
|||||||
trace_to_cloud: bool,
|
trace_to_cloud: bool,
|
||||||
with_ui: bool,
|
with_ui: bool,
|
||||||
verbosity: str,
|
verbosity: str,
|
||||||
|
session_db_url: str,
|
||||||
):
|
):
|
||||||
"""Deploys an agent to Google Cloud Run.
|
"""Deploys an agent to Google Cloud Run.
|
||||||
|
|
||||||
@ -112,6 +113,7 @@ def to_cloud_run(
|
|||||||
trace_to_cloud: Whether to enable Cloud Trace.
|
trace_to_cloud: Whether to enable Cloud Trace.
|
||||||
with_ui: Whether to deploy with UI.
|
with_ui: Whether to deploy with UI.
|
||||||
verbosity: The verbosity level of the CLI.
|
verbosity: The verbosity level of the CLI.
|
||||||
|
session_db_url: The database URL to connect the session.
|
||||||
"""
|
"""
|
||||||
app_name = app_name or os.path.basename(agent_folder)
|
app_name = app_name or os.path.basename(agent_folder)
|
||||||
|
|
||||||
@ -144,6 +146,9 @@ def to_cloud_run(
|
|||||||
port=port,
|
port=port,
|
||||||
command='web' if with_ui else 'api_server',
|
command='web' if with_ui else 'api_server',
|
||||||
install_agent_deps=install_agent_deps,
|
install_agent_deps=install_agent_deps,
|
||||||
|
session_db_option=f'--session_db_url={session_db_url}'
|
||||||
|
if session_db_url
|
||||||
|
else '',
|
||||||
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
|
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
|
||||||
)
|
)
|
||||||
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
|
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
|
||||||
|
@ -256,7 +256,7 @@ def run_evals(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if final_eval_status == EvalStatus.PASSED:
|
if final_eval_status == EvalStatus.PASSED:
|
||||||
result = "✅ Passsed"
|
result = "✅ Passed"
|
||||||
else:
|
else:
|
||||||
result = "❌ Failed"
|
result = "❌ Failed"
|
||||||
|
|
||||||
|
@ -96,6 +96,23 @@ def cli_create_cmd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_exclusive(ctx, param, value):
|
||||||
|
# Store the validated parameters in the context
|
||||||
|
if not hasattr(ctx, "exclusive_opts"):
|
||||||
|
ctx.exclusive_opts = {}
|
||||||
|
|
||||||
|
# If this option has a value and we've already seen another exclusive option
|
||||||
|
if value is not None and any(ctx.exclusive_opts.values()):
|
||||||
|
exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
|
||||||
|
raise click.UsageError(
|
||||||
|
f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record this option's value
|
||||||
|
ctx.exclusive_opts[param.name] = value is not None
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
@main.command("run")
|
@main.command("run")
|
||||||
@click.option(
|
@click.option(
|
||||||
"--save_session",
|
"--save_session",
|
||||||
@ -105,13 +122,43 @@ def cli_create_cmd(
|
|||||||
default=False,
|
default=False,
|
||||||
help="Optional. Whether to save the session to a json file on exit.",
|
help="Optional. Whether to save the session to a json file on exit.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--replay",
|
||||||
|
type=click.Path(
|
||||||
|
exists=True, dir_okay=False, file_okay=True, resolve_path=True
|
||||||
|
),
|
||||||
|
help=(
|
||||||
|
"The json file that contains the initial state of the session and user"
|
||||||
|
" queries. A new session will be created using this state. And user"
|
||||||
|
" queries are run againt the newly created session. Users cannot"
|
||||||
|
" continue to interact with the agent."
|
||||||
|
),
|
||||||
|
callback=validate_exclusive,
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--resume",
|
||||||
|
type=click.Path(
|
||||||
|
exists=True, dir_okay=False, file_okay=True, resolve_path=True
|
||||||
|
),
|
||||||
|
help=(
|
||||||
|
"The json file that contains a previously saved session (by"
|
||||||
|
"--save_session option). The previous session will be re-displayed. And"
|
||||||
|
" user can continue to interact with the agent."
|
||||||
|
),
|
||||||
|
callback=validate_exclusive,
|
||||||
|
)
|
||||||
@click.argument(
|
@click.argument(
|
||||||
"agent",
|
"agent",
|
||||||
type=click.Path(
|
type=click.Path(
|
||||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def cli_run(agent: str, save_session: bool):
|
def cli_run(
|
||||||
|
agent: str,
|
||||||
|
save_session: bool,
|
||||||
|
replay: Optional[str],
|
||||||
|
resume: Optional[str],
|
||||||
|
):
|
||||||
"""Runs an interactive CLI for a certain agent.
|
"""Runs an interactive CLI for a certain agent.
|
||||||
|
|
||||||
AGENT: The path to the agent source code folder.
|
AGENT: The path to the agent source code folder.
|
||||||
@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
|
|||||||
run_cli(
|
run_cli(
|
||||||
agent_parent_dir=agent_parent_folder,
|
agent_parent_dir=agent_parent_folder,
|
||||||
agent_folder_name=agent_folder_name,
|
agent_folder_name=agent_folder_name,
|
||||||
|
input_file=replay,
|
||||||
|
saved_session_file=resume,
|
||||||
save_session=save_session,
|
save_session=save_session,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -245,12 +294,13 @@ def cli_eval(
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--session_db_url",
|
"--session_db_url",
|
||||||
help=(
|
help=(
|
||||||
"Optional. The database URL to store the session.\n\n - Use"
|
"""Optional. The database URL to store the session.
|
||||||
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
|
|
||||||
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
|
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
|
||||||
" to connect to a SQLite DB.\n\n - See"
|
|
||||||
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
|
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
|
||||||
" for more details on supported DB URLs."
|
|
||||||
|
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
@ -366,12 +416,13 @@ def cli_web(
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--session_db_url",
|
"--session_db_url",
|
||||||
help=(
|
help=(
|
||||||
"Optional. The database URL to store the session.\n\n - Use"
|
"""Optional. The database URL to store the session.
|
||||||
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
|
|
||||||
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
|
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
|
||||||
" to connect to a SQLite DB.\n\n - See"
|
|
||||||
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
|
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
|
||||||
" for more details on supported DB URLs."
|
|
||||||
|
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
@ -541,6 +592,18 @@ def cli_api_server(
|
|||||||
default="WARNING",
|
default="WARNING",
|
||||||
help="Optional. Override the default verbosity level.",
|
help="Optional. Override the default verbosity level.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--session_db_url",
|
||||||
|
help=(
|
||||||
|
"""Optional. The database URL to store the session.
|
||||||
|
|
||||||
|
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
|
||||||
|
|
||||||
|
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
|
||||||
|
|
||||||
|
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
|
||||||
|
),
|
||||||
|
)
|
||||||
@click.argument(
|
@click.argument(
|
||||||
"agent",
|
"agent",
|
||||||
type=click.Path(
|
type=click.Path(
|
||||||
@ -558,6 +621,7 @@ def cli_deploy_cloud_run(
|
|||||||
trace_to_cloud: bool,
|
trace_to_cloud: bool,
|
||||||
with_ui: bool,
|
with_ui: bool,
|
||||||
verbosity: str,
|
verbosity: str,
|
||||||
|
session_db_url: str,
|
||||||
):
|
):
|
||||||
"""Deploys an agent to Cloud Run.
|
"""Deploys an agent to Cloud Run.
|
||||||
|
|
||||||
@ -579,6 +643,7 @@ def cli_deploy_cloud_run(
|
|||||||
trace_to_cloud=trace_to_cloud,
|
trace_to_cloud=trace_to_cloud,
|
||||||
with_ui=with_ui,
|
with_ui=with_ui,
|
||||||
verbosity=verbosity,
|
verbosity=verbosity,
|
||||||
|
session_db_url=session_db_url,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.secho(f"Deploy failed: {e}", fg="red", err=True)
|
click.secho(f"Deploy failed: {e}", fg="red", err=True)
|
||||||
|
@ -756,6 +756,12 @@ def get_fast_api_app(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error during live websocket communication: %s", e)
|
logger.exception("Error during live websocket communication: %s", e)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
WEBSOCKET_INTERNAL_ERROR_CODE = 1011
|
||||||
|
WEBSOCKET_MAX_BYTES_FOR_REASON = 123
|
||||||
|
await websocket.close(
|
||||||
|
code=WEBSOCKET_INTERNAL_ERROR_CODE,
|
||||||
|
reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
for task in pending:
|
for task in pending:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
@ -55,7 +55,7 @@ def load_json(file_path: str) -> Union[Dict, List]:
|
|||||||
|
|
||||||
|
|
||||||
class AgentEvaluator:
|
class AgentEvaluator:
|
||||||
"""An evaluator for Agents, mainly intented for helping with test cases."""
|
"""An evaluator for Agents, mainly intended for helping with test cases."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_config_for_test_file(test_file: str):
|
def find_config_for_test_file(test_file: str):
|
||||||
@ -91,7 +91,7 @@ class AgentEvaluator:
|
|||||||
look for 'root_agent' in the loaded module.
|
look for 'root_agent' in the loaded module.
|
||||||
eval_dataset: The eval data set. This can be either a string representing
|
eval_dataset: The eval data set. This can be either a string representing
|
||||||
full path to the file containing eval dataset, or a directory that is
|
full path to the file containing eval dataset, or a directory that is
|
||||||
recusively explored for all files that have a `.test.json` suffix.
|
recursively explored for all files that have a `.test.json` suffix.
|
||||||
num_runs: Number of times all entries in the eval dataset should be
|
num_runs: Number of times all entries in the eval dataset should be
|
||||||
assessed.
|
assessed.
|
||||||
agent_name: The name of the agent.
|
agent_name: The name of the agent.
|
||||||
|
@ -35,7 +35,7 @@ class ResponseEvaluator:
|
|||||||
Args:
|
Args:
|
||||||
raw_eval_dataset: The dataset that will be evaluated.
|
raw_eval_dataset: The dataset that will be evaluated.
|
||||||
evaluation_criteria: The evaluation criteria to be used. This method
|
evaluation_criteria: The evaluation criteria to be used. This method
|
||||||
support two criterias, `response_evaluation_score` and
|
support two criteria, `response_evaluation_score` and
|
||||||
`response_match_score`.
|
`response_match_score`.
|
||||||
print_detailed_results: Prints detailed results on the console. This is
|
print_detailed_results: Prints detailed results on the console. This is
|
||||||
usually helpful during debugging.
|
usually helpful during debugging.
|
||||||
@ -56,7 +56,7 @@ class ResponseEvaluator:
|
|||||||
Value range: [0, 5], where 0 means that the agent's response is not
|
Value range: [0, 5], where 0 means that the agent's response is not
|
||||||
coherent, while 5 means it is . High values are good.
|
coherent, while 5 means it is . High values are good.
|
||||||
A note on raw_eval_dataset:
|
A note on raw_eval_dataset:
|
||||||
The dataset should be a list session, where each sesssion is represented
|
The dataset should be a list session, where each session is represented
|
||||||
as a list of interaction that need evaluation. Each evaluation is
|
as a list of interaction that need evaluation. Each evaluation is
|
||||||
represented as a dictionary that is expected to have values for the
|
represented as a dictionary that is expected to have values for the
|
||||||
following keys:
|
following keys:
|
||||||
|
@ -31,10 +31,9 @@ class TrajectoryEvaluator:
|
|||||||
):
|
):
|
||||||
r"""Returns the mean tool use accuracy of the eval dataset.
|
r"""Returns the mean tool use accuracy of the eval dataset.
|
||||||
|
|
||||||
Tool use accuracy is calculated by comparing the expected and actuall tool
|
Tool use accuracy is calculated by comparing the expected and the actual
|
||||||
use trajectories. An exact match scores a 1, 0 otherwise. The final number
|
tool use trajectories. An exact match scores a 1, 0 otherwise. The final
|
||||||
is an
|
number is an average of these individual scores.
|
||||||
average of these individual scores.
|
|
||||||
|
|
||||||
Value range: [0, 1], where 0 is means none of the too use entries aligned,
|
Value range: [0, 1], where 0 is means none of the too use entries aligned,
|
||||||
and 1 would mean all of them aligned. Higher value is good.
|
and 1 would mean all of them aligned. Higher value is good.
|
||||||
@ -45,7 +44,7 @@ class TrajectoryEvaluator:
|
|||||||
usually helpful during debugging.
|
usually helpful during debugging.
|
||||||
|
|
||||||
A note on eval_dataset:
|
A note on eval_dataset:
|
||||||
The dataset should be a list session, where each sesssion is represented
|
The dataset should be a list session, where each session is represented
|
||||||
as a list of interaction that need evaluation. Each evaluation is
|
as a list of interaction that need evaluation. Each evaluation is
|
||||||
represented as a dictionary that is expected to have values for the
|
represented as a dictionary that is expected to have values for the
|
||||||
following keys:
|
following keys:
|
||||||
|
@ -48,8 +48,13 @@ class EventActions(BaseModel):
|
|||||||
"""The agent is escalating to a higher level agent."""
|
"""The agent is escalating to a higher level agent."""
|
||||||
|
|
||||||
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
|
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
|
||||||
"""Will only be set by a tool response indicating tool request euc.
|
"""Authentication configurations requested by tool responses.
|
||||||
dict key is the function call id since one function call response (from model)
|
|
||||||
could correspond to multiple function calls.
|
This field will only be set by a tool response event indicating tool request
|
||||||
dict value is the required auth config.
|
auth credential.
|
||||||
|
- Keys: The function call id. Since one function response event could contain
|
||||||
|
multiple function responses that correspond to multiple function calls. Each
|
||||||
|
function call could request different auth configs. This id is used to
|
||||||
|
identify the function call.
|
||||||
|
- Values: The requested auth config.
|
||||||
"""
|
"""
|
||||||
|
@ -94,7 +94,7 @@ can answer it.
|
|||||||
|
|
||||||
If another agent is better for answering the question according to its
|
If another agent is better for answering the question according to its
|
||||||
description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
|
description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
|
||||||
question to that agent. When transfering, do not generate any text other than
|
question to that agent. When transferring, do not generate any text other than
|
||||||
the function call.
|
the function call.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -115,7 +115,7 @@ class BaseLlmFlow(ABC):
|
|||||||
yield event
|
yield event
|
||||||
# send back the function response
|
# send back the function response
|
||||||
if event.get_function_responses():
|
if event.get_function_responses():
|
||||||
logger.debug('Sending back last function resonse event: %s', event)
|
logger.debug('Sending back last function response event: %s', event)
|
||||||
invocation_context.live_request_queue.send_content(event.content)
|
invocation_context.live_request_queue.send_content(event.content)
|
||||||
if (
|
if (
|
||||||
event.content
|
event.content
|
||||||
|
@ -111,7 +111,7 @@ def _rearrange_events_for_latest_function_response(
|
|||||||
"""Rearrange the events for the latest function_response.
|
"""Rearrange the events for the latest function_response.
|
||||||
|
|
||||||
If the latest function_response is for an async function_call, all events
|
If the latest function_response is for an async function_call, all events
|
||||||
bewteen the initial function_call and the latest function_response will be
|
between the initial function_call and the latest function_response will be
|
||||||
removed.
|
removed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -151,28 +151,33 @@ async def handle_function_calls_async(
|
|||||||
# do not use "args" as the variable name, because it is a reserved keyword
|
# do not use "args" as the variable name, because it is a reserved keyword
|
||||||
# in python debugger.
|
# in python debugger.
|
||||||
function_args = function_call.args or {}
|
function_args = function_call.args or {}
|
||||||
function_response = None
|
function_response: Optional[dict] = None
|
||||||
# Calls the tool if before_tool_callback does not exist or returns None.
|
|
||||||
|
# before_tool_callback (sync or async)
|
||||||
if agent.before_tool_callback:
|
if agent.before_tool_callback:
|
||||||
function_response = agent.before_tool_callback(
|
function_response = agent.before_tool_callback(
|
||||||
tool=tool, args=function_args, tool_context=tool_context
|
tool=tool, args=function_args, tool_context=tool_context
|
||||||
)
|
)
|
||||||
|
if inspect.isawaitable(function_response):
|
||||||
|
function_response = await function_response
|
||||||
|
|
||||||
if not function_response:
|
if not function_response:
|
||||||
function_response = await __call_tool_async(
|
function_response = await __call_tool_async(
|
||||||
tool, args=function_args, tool_context=tool_context
|
tool, args=function_args, tool_context=tool_context
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calls after_tool_callback if it exists.
|
# after_tool_callback (sync or async)
|
||||||
if agent.after_tool_callback:
|
if agent.after_tool_callback:
|
||||||
new_response = agent.after_tool_callback(
|
altered_function_response = agent.after_tool_callback(
|
||||||
tool=tool,
|
tool=tool,
|
||||||
args=function_args,
|
args=function_args,
|
||||||
tool_context=tool_context,
|
tool_context=tool_context,
|
||||||
tool_response=function_response,
|
tool_response=function_response,
|
||||||
)
|
)
|
||||||
if new_response:
|
if inspect.isawaitable(altered_function_response):
|
||||||
function_response = new_response
|
altered_function_response = await altered_function_response
|
||||||
|
if altered_function_response is not None:
|
||||||
|
function_response = altered_function_response
|
||||||
|
|
||||||
if tool.is_long_running:
|
if tool.is_long_running:
|
||||||
# Allow long running function to return None to not provide function response.
|
# Allow long running function to return None to not provide function response.
|
||||||
@ -223,11 +228,17 @@ async def handle_function_calls_live(
|
|||||||
# in python debugger.
|
# in python debugger.
|
||||||
function_args = function_call.args or {}
|
function_args = function_call.args or {}
|
||||||
function_response = None
|
function_response = None
|
||||||
# Calls the tool if before_tool_callback does not exist or returns None.
|
# # Calls the tool if before_tool_callback does not exist or returns None.
|
||||||
|
# if agent.before_tool_callback:
|
||||||
|
# function_response = agent.before_tool_callback(
|
||||||
|
# tool, function_args, tool_context
|
||||||
|
# )
|
||||||
if agent.before_tool_callback:
|
if agent.before_tool_callback:
|
||||||
function_response = agent.before_tool_callback(
|
function_response = agent.before_tool_callback(
|
||||||
tool, function_args, tool_context
|
tool=tool, args=function_args, tool_context=tool_context
|
||||||
)
|
)
|
||||||
|
if inspect.isawaitable(function_response):
|
||||||
|
function_response = await function_response
|
||||||
|
|
||||||
if not function_response:
|
if not function_response:
|
||||||
function_response = await _process_function_live_helper(
|
function_response = await _process_function_live_helper(
|
||||||
@ -235,15 +246,26 @@ async def handle_function_calls_live(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Calls after_tool_callback if it exists.
|
# Calls after_tool_callback if it exists.
|
||||||
|
# if agent.after_tool_callback:
|
||||||
|
# new_response = agent.after_tool_callback(
|
||||||
|
# tool,
|
||||||
|
# function_args,
|
||||||
|
# tool_context,
|
||||||
|
# function_response,
|
||||||
|
# )
|
||||||
|
# if new_response:
|
||||||
|
# function_response = new_response
|
||||||
if agent.after_tool_callback:
|
if agent.after_tool_callback:
|
||||||
new_response = agent.after_tool_callback(
|
altered_function_response = agent.after_tool_callback(
|
||||||
tool,
|
tool=tool,
|
||||||
function_args,
|
args=function_args,
|
||||||
tool_context,
|
tool_context=tool_context,
|
||||||
function_response,
|
tool_response=function_response,
|
||||||
)
|
)
|
||||||
if new_response:
|
if inspect.isawaitable(altered_function_response):
|
||||||
function_response = new_response
|
altered_function_response = await altered_function_response
|
||||||
|
if altered_function_response is not None:
|
||||||
|
function_response = altered_function_response
|
||||||
|
|
||||||
if tool.is_long_running:
|
if tool.is_long_running:
|
||||||
# Allow async function to return None to not provide function response.
|
# Allow async function to return None to not provide function response.
|
||||||
@ -310,9 +332,7 @@ async def _process_function_live_helper(
|
|||||||
function_response = {
|
function_response = {
|
||||||
'status': f'No active streaming function named {function_name} found'
|
'status': f'No active streaming function named {function_name} found'
|
||||||
}
|
}
|
||||||
elif inspect.isasyncgenfunction(tool.func):
|
elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
|
||||||
print('is async')
|
|
||||||
|
|
||||||
# for streaming tool use case
|
# for streaming tool use case
|
||||||
# we require the function to be a async generator function
|
# we require the function to be a async generator function
|
||||||
async def run_tool_and_update_queue(tool, function_args, tool_context):
|
async def run_tool_and_update_queue(tool, function_args, tool_context):
|
||||||
|
@ -52,7 +52,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|||||||
# Appends global instructions if set.
|
# Appends global instructions if set.
|
||||||
if (
|
if (
|
||||||
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
|
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
|
||||||
): # not emtpy str
|
): # not empty str
|
||||||
raw_si = root_agent.canonical_global_instruction(
|
raw_si = root_agent.canonical_global_instruction(
|
||||||
ReadonlyContext(invocation_context)
|
ReadonlyContext(invocation_context)
|
||||||
)
|
)
|
||||||
@ -60,7 +60,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|||||||
llm_request.append_instructions([si])
|
llm_request.append_instructions([si])
|
||||||
|
|
||||||
# Appends agent instructions if set.
|
# Appends agent instructions if set.
|
||||||
if agent.instruction: # not emtpy str
|
if agent.instruction: # not empty str
|
||||||
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
|
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
|
||||||
si = _populate_values(raw_si, invocation_context)
|
si = _populate_values(raw_si, invocation_context)
|
||||||
llm_request.append_instructions([si])
|
llm_request.append_instructions([si])
|
||||||
|
@ -152,7 +152,7 @@ class GeminiLlmConnection(BaseLlmConnection):
|
|||||||
):
|
):
|
||||||
# TODO: Right now, we just support output_transcription without
|
# TODO: Right now, we just support output_transcription without
|
||||||
# changing interface and data protocol. Later, we can consider to
|
# changing interface and data protocol. Later, we can consider to
|
||||||
# support output_transcription as a separete field in LlmResponse.
|
# support output_transcription as a separate field in LlmResponse.
|
||||||
|
|
||||||
# Transcription is always considered as partial event
|
# Transcription is always considered as partial event
|
||||||
# We rely on other control signals to determine when to yield the
|
# We rely on other control signals to determine when to yield the
|
||||||
@ -179,7 +179,7 @@ class GeminiLlmConnection(BaseLlmConnection):
|
|||||||
# in case of empty content or parts, we sill surface it
|
# in case of empty content or parts, we sill surface it
|
||||||
# in case it's an interrupted message, we merge the previous partial
|
# in case it's an interrupted message, we merge the previous partial
|
||||||
# text. Other we don't merge. because content can be none when model
|
# text. Other we don't merge. because content can be none when model
|
||||||
# safty threshold is triggered
|
# safety threshold is triggered
|
||||||
if message.server_content.interrupted and text:
|
if message.server_content.interrupted and text:
|
||||||
yield self.__build_full_text_response(text)
|
yield self.__build_full_text_response(text)
|
||||||
text = ''
|
text = ''
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -37,6 +37,7 @@ class LlmResponse(BaseModel):
|
|||||||
error_message: Error message if the response is an error.
|
error_message: Error message if the response is an error.
|
||||||
interrupted: Flag indicating that LLM was interrupted when generating the
|
interrupted: Flag indicating that LLM was interrupted when generating the
|
||||||
content. Usually it's due to user interruption during a bidi streaming.
|
content. Usually it's due to user interruption during a bidi streaming.
|
||||||
|
custom_metadata: The custom metadata of the LlmResponse.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra='forbid')
|
model_config = ConfigDict(extra='forbid')
|
||||||
@ -71,6 +72,14 @@ class LlmResponse(BaseModel):
|
|||||||
Usually it's due to user interruption during a bidi streaming.
|
Usually it's due to user interruption during a bidi streaming.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
custom_metadata: Optional[dict[str, Any]] = None
|
||||||
|
"""The custom metadata of the LlmResponse.
|
||||||
|
|
||||||
|
An optional key-value pair to label an LlmResponse.
|
||||||
|
|
||||||
|
NOTE: the entire dict must be JSON serializable.
|
||||||
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(
|
def create(
|
||||||
generate_content_response: types.GenerateContentResponse,
|
generate_content_response: types.GenerateContentResponse,
|
||||||
|
@ -56,6 +56,7 @@ class BuiltInPlanner(BasePlanner):
|
|||||||
llm_request: The LLM request to apply the thinking config to.
|
llm_request: The LLM request to apply the thinking config to.
|
||||||
"""
|
"""
|
||||||
if self.thinking_config:
|
if self.thinking_config:
|
||||||
|
llm_request.config = llm_request.config or types.GenerateContentConfig()
|
||||||
llm_request.config.thinking_config = self.thinking_config
|
llm_request.config.thinking_config = self.thinking_config
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
29
src/google/adk/sessions/_session_util.py
Normal file
29
src/google/adk/sessions/_session_util.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
"""Utility functions for session service."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from google.genai import types
|
||||||
|
|
||||||
|
|
||||||
|
def encode_content(content: types.Content):
|
||||||
|
"""Encodes a content object to a JSON dictionary."""
|
||||||
|
encoded_content = content.model_dump(exclude_none=True)
|
||||||
|
for p in encoded_content["parts"]:
|
||||||
|
if "inline_data" in p:
|
||||||
|
p["inline_data"]["data"] = base64.b64encode(
|
||||||
|
p["inline_data"]["data"]
|
||||||
|
).decode("utf-8")
|
||||||
|
return encoded_content
|
||||||
|
|
||||||
|
|
||||||
|
def decode_content(
|
||||||
|
content: Optional[dict[str, Any]],
|
||||||
|
) -> Optional[types.Content]:
|
||||||
|
"""Decodes a content object from a JSON dictionary."""
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
for p in content["parts"]:
|
||||||
|
if "inline_data" in p:
|
||||||
|
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
|
||||||
|
return types.Content.model_validate(content)
|
@ -11,14 +11,11 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import base64
|
|
||||||
import copy
|
import copy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
from typing import Optional
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from sqlalchemy import Boolean
|
from sqlalchemy import Boolean
|
||||||
@ -27,6 +24,7 @@ from sqlalchemy import Dialect
|
|||||||
from sqlalchemy import ForeignKeyConstraint
|
from sqlalchemy import ForeignKeyConstraint
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy import Text
|
from sqlalchemy import Text
|
||||||
|
from sqlalchemy.dialects import mysql
|
||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
from sqlalchemy.engine import create_engine
|
from sqlalchemy.engine import create_engine
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
@ -48,6 +46,7 @@ from typing_extensions import override
|
|||||||
from tzlocal import get_localzone
|
from tzlocal import get_localzone
|
||||||
|
|
||||||
from ..events.event import Event
|
from ..events.event import Event
|
||||||
|
from . import _session_util
|
||||||
from .base_session_service import BaseSessionService
|
from .base_session_service import BaseSessionService
|
||||||
from .base_session_service import GetSessionConfig
|
from .base_session_service import GetSessionConfig
|
||||||
from .base_session_service import ListEventsResponse
|
from .base_session_service import ListEventsResponse
|
||||||
@ -58,6 +57,9 @@ from .state import State
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MAX_KEY_LENGTH = 128
|
||||||
|
DEFAULT_MAX_VARCHAR_LENGTH = 256
|
||||||
|
|
||||||
|
|
||||||
class DynamicJSON(TypeDecorator):
|
class DynamicJSON(TypeDecorator):
|
||||||
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
|
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
|
||||||
@ -70,15 +72,16 @@ class DynamicJSON(TypeDecorator):
|
|||||||
def load_dialect_impl(self, dialect: Dialect):
|
def load_dialect_impl(self, dialect: Dialect):
|
||||||
if dialect.name == "postgresql":
|
if dialect.name == "postgresql":
|
||||||
return dialect.type_descriptor(postgresql.JSONB)
|
return dialect.type_descriptor(postgresql.JSONB)
|
||||||
else:
|
if dialect.name == "mysql":
|
||||||
return dialect.type_descriptor(Text) # Default to Text for other dialects
|
# Use LONGTEXT for MySQL to address the data too long issue
|
||||||
|
return dialect.type_descriptor(mysql.LONGTEXT)
|
||||||
|
return dialect.type_descriptor(Text) # Default to Text for other dialects
|
||||||
|
|
||||||
def process_bind_param(self, value, dialect: Dialect):
|
def process_bind_param(self, value, dialect: Dialect):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
if dialect.name == "postgresql":
|
if dialect.name == "postgresql":
|
||||||
return value # JSONB handles dict directly
|
return value # JSONB handles dict directly
|
||||||
else:
|
return json.dumps(value) # Serialize to JSON string for TEXT
|
||||||
return json.dumps(value) # Serialize to JSON string for TEXT
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def process_result_value(self, value, dialect: Dialect):
|
def process_result_value(self, value, dialect: Dialect):
|
||||||
@ -92,17 +95,25 @@ class DynamicJSON(TypeDecorator):
|
|||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
"""Base class for database tables."""
|
"""Base class for database tables."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class StorageSession(Base):
|
class StorageSession(Base):
|
||||||
"""Represents a session stored in the database."""
|
"""Represents a session stored in the database."""
|
||||||
|
|
||||||
__tablename__ = "sessions"
|
__tablename__ = "sessions"
|
||||||
|
|
||||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
app_name: Mapped[str] = mapped_column(
|
||||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
id: Mapped[str] = mapped_column(
|
id: Mapped[str] = mapped_column(
|
||||||
String, primary_key=True, default=lambda: str(uuid.uuid4())
|
String(DEFAULT_MAX_KEY_LENGTH),
|
||||||
|
primary_key=True,
|
||||||
|
default=lambda: str(uuid.uuid4()),
|
||||||
)
|
)
|
||||||
|
|
||||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
@ -125,18 +136,29 @@ class StorageSession(Base):
|
|||||||
|
|
||||||
class StorageEvent(Base):
|
class StorageEvent(Base):
|
||||||
"""Represents an event stored in the database."""
|
"""Represents an event stored in the database."""
|
||||||
|
|
||||||
__tablename__ = "events"
|
__tablename__ = "events"
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
id: Mapped[str] = mapped_column(
|
||||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
)
|
||||||
session_id: Mapped[str] = mapped_column(String, primary_key=True)
|
app_name: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
|
session_id: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
|
|
||||||
invocation_id: Mapped[str] = mapped_column(String)
|
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||||
author: Mapped[str] = mapped_column(String)
|
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
||||||
branch: Mapped[str] = mapped_column(String, nullable=True)
|
branch: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
|
||||||
|
)
|
||||||
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
||||||
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
|
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
|
||||||
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
|
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
|
||||||
|
|
||||||
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
|
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
|
||||||
@ -147,8 +169,10 @@ class StorageEvent(Base):
|
|||||||
)
|
)
|
||||||
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||||
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||||
error_code: Mapped[str] = mapped_column(String, nullable=True)
|
error_code: Mapped[str] = mapped_column(
|
||||||
error_message: Mapped[str] = mapped_column(String, nullable=True)
|
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
|
||||||
|
)
|
||||||
|
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
|
||||||
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||||
|
|
||||||
storage_session: Mapped[StorageSession] = relationship(
|
storage_session: Mapped[StorageSession] = relationship(
|
||||||
@ -182,9 +206,12 @@ class StorageEvent(Base):
|
|||||||
|
|
||||||
class StorageAppState(Base):
|
class StorageAppState(Base):
|
||||||
"""Represents an app state stored in the database."""
|
"""Represents an app state stored in the database."""
|
||||||
|
|
||||||
__tablename__ = "app_states"
|
__tablename__ = "app_states"
|
||||||
|
|
||||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
app_name: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
MutableDict.as_mutable(DynamicJSON), default={}
|
MutableDict.as_mutable(DynamicJSON), default={}
|
||||||
)
|
)
|
||||||
@ -192,13 +219,17 @@ class StorageAppState(Base):
|
|||||||
DateTime(), default=func.now(), onupdate=func.now()
|
DateTime(), default=func.now(), onupdate=func.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class StorageUserState(Base):
|
class StorageUserState(Base):
|
||||||
"""Represents a user state stored in the database."""
|
"""Represents a user state stored in the database."""
|
||||||
|
|
||||||
__tablename__ = "user_states"
|
__tablename__ = "user_states"
|
||||||
|
|
||||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
app_name: Mapped[str] = mapped_column(
|
||||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
||||||
|
)
|
||||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||||
MutableDict.as_mutable(DynamicJSON), default={}
|
MutableDict.as_mutable(DynamicJSON), default={}
|
||||||
)
|
)
|
||||||
@ -217,7 +248,7 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
"""
|
"""
|
||||||
# 1. Create DB engine for db connection
|
# 1. Create DB engine for db connection
|
||||||
# 2. Create all tables based on schema
|
# 2. Create all tables based on schema
|
||||||
# 3. Initialize all properies
|
# 3. Initialize all properties
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db_engine = create_engine(db_url)
|
db_engine = create_engine(db_url)
|
||||||
@ -353,6 +384,7 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
else True
|
else True
|
||||||
)
|
)
|
||||||
.limit(config.num_recent_events if config else None)
|
.limit(config.num_recent_events if config else None)
|
||||||
|
.order_by(StorageEvent.timestamp.asc())
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -383,7 +415,7 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
author=e.author,
|
author=e.author,
|
||||||
branch=e.branch,
|
branch=e.branch,
|
||||||
invocation_id=e.invocation_id,
|
invocation_id=e.invocation_id,
|
||||||
content=_decode_content(e.content),
|
content=_session_util.decode_content(e.content),
|
||||||
actions=e.actions,
|
actions=e.actions,
|
||||||
timestamp=e.timestamp.timestamp(),
|
timestamp=e.timestamp.timestamp(),
|
||||||
long_running_tool_ids=e.long_running_tool_ids,
|
long_running_tool_ids=e.long_running_tool_ids,
|
||||||
@ -506,15 +538,7 @@ class DatabaseSessionService(BaseSessionService):
|
|||||||
interrupted=event.interrupted,
|
interrupted=event.interrupted,
|
||||||
)
|
)
|
||||||
if event.content:
|
if event.content:
|
||||||
encoded_content = event.content.model_dump(exclude_none=True)
|
storage_event.content = _session_util.encode_content(event.content)
|
||||||
# Workaround for multimodal Content throwing JSON not serializable
|
|
||||||
# error with SQLAlchemy.
|
|
||||||
for p in encoded_content["parts"]:
|
|
||||||
if "inline_data" in p:
|
|
||||||
p["inline_data"]["data"] = (
|
|
||||||
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
|
|
||||||
)
|
|
||||||
storage_event.content = encoded_content
|
|
||||||
|
|
||||||
sessionFactory.add(storage_event)
|
sessionFactory.add(storage_event)
|
||||||
|
|
||||||
@ -574,10 +598,3 @@ def _merge_state(app_state, user_state, session_state):
|
|||||||
for key in user_state.keys():
|
for key in user_state.keys():
|
||||||
merged_state[State.USER_PREFIX + key] = user_state[key]
|
merged_state[State.USER_PREFIX + key] = user_state[key]
|
||||||
return merged_state
|
return merged_state
|
||||||
|
|
||||||
|
|
||||||
def _decode_content(content: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
for p in content["parts"]:
|
|
||||||
if "inline_data" in p:
|
|
||||||
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
|
|
||||||
return content
|
|
||||||
|
@ -26,7 +26,7 @@ class State:
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
value: The current value of the state dict.
|
value: The current value of the state dict.
|
||||||
delta: The delta change to the current value that hasn't been commited.
|
delta: The delta change to the current value that hasn't been committed.
|
||||||
"""
|
"""
|
||||||
self._value = value
|
self._value = value
|
||||||
self._delta = delta
|
self._delta = delta
|
||||||
|
@ -14,21 +14,23 @@
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from dateutil.parser import isoparse
|
from dateutil import parser
|
||||||
from google import genai
|
from google import genai
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..events.event import Event
|
from ..events.event import Event
|
||||||
from ..events.event_actions import EventActions
|
from ..events.event_actions import EventActions
|
||||||
|
from . import _session_util
|
||||||
from .base_session_service import BaseSessionService
|
from .base_session_service import BaseSessionService
|
||||||
from .base_session_service import GetSessionConfig
|
from .base_session_service import GetSessionConfig
|
||||||
from .base_session_service import ListEventsResponse
|
from .base_session_service import ListEventsResponse
|
||||||
from .base_session_service import ListSessionsResponse
|
from .base_session_service import ListSessionsResponse
|
||||||
from .session import Session
|
from .session import Session
|
||||||
|
|
||||||
|
|
||||||
|
isoparse = parser.isoparse
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event):
|
|||||||
}
|
}
|
||||||
event_json['actions'] = actions_json
|
event_json['actions'] = actions_json
|
||||||
if event.content:
|
if event.content:
|
||||||
event_json['content'] = event.content.model_dump(exclude_none=True)
|
event_json['content'] = _session_util.encode_content(event.content)
|
||||||
if event.error_code:
|
if event.error_code:
|
||||||
event_json['error_code'] = event.error_code
|
event_json['error_code'] = event.error_code
|
||||||
if event.error_message:
|
if event.error_message:
|
||||||
@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event:
|
|||||||
invocation_id=api_event['invocationId'],
|
invocation_id=api_event['invocationId'],
|
||||||
author=api_event['author'],
|
author=api_event['author'],
|
||||||
actions=event_actions,
|
actions=event_actions,
|
||||||
content=api_event.get('content', None),
|
content=_session_util.decode_content(api_event.get('content', None)),
|
||||||
timestamp=isoparse(api_event['timestamp']).timestamp(),
|
timestamp=isoparse(api_event['timestamp']).timestamp(),
|
||||||
error_code=api_event.get('errorCode', None),
|
error_code=api_event.get('errorCode', None),
|
||||||
error_message=api_event.get('errorMessage', None),
|
error_message=api_event.get('errorMessage', None),
|
||||||
|
@ -45,10 +45,9 @@ class AgentTool(BaseTool):
|
|||||||
skip_summarization: Whether to skip summarization of the agent output.
|
skip_summarization: Whether to skip summarization of the agent output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, agent: BaseAgent):
|
def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.skip_summarization: bool = False
|
self.skip_summarization: bool = skip_summarization
|
||||||
"""Whether to skip summarization of the agent output."""
|
|
||||||
|
|
||||||
super().__init__(name=agent.name, description=agent.description)
|
super().__init__(name=agent.name, description=agent.description)
|
||||||
|
|
||||||
|
@ -13,7 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .application_integration_toolset import ApplicationIntegrationToolset
|
from .application_integration_toolset import ApplicationIntegrationToolset
|
||||||
|
from .integration_connector_tool import IntegrationConnectorTool
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ApplicationIntegrationToolset',
|
'ApplicationIntegrationToolset',
|
||||||
|
'IntegrationConnectorTool',
|
||||||
]
|
]
|
||||||
|
@ -12,21 +12,21 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict, List, Optional
|
||||||
from typing import List
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi.openapi.models import HTTPBearer
|
from fastapi.openapi.models import HTTPBearer
|
||||||
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
|
||||||
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
|
|
||||||
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
|
||||||
|
|
||||||
from ...auth.auth_credential import AuthCredential
|
from ...auth.auth_credential import AuthCredential
|
||||||
from ...auth.auth_credential import AuthCredentialTypes
|
from ...auth.auth_credential import AuthCredentialTypes
|
||||||
from ...auth.auth_credential import ServiceAccount
|
from ...auth.auth_credential import ServiceAccount
|
||||||
from ...auth.auth_credential import ServiceAccountCredential
|
from ...auth.auth_credential import ServiceAccountCredential
|
||||||
|
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
||||||
|
from ..openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
|
||||||
|
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
||||||
|
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||||
|
from .clients.connections_client import ConnectionsClient
|
||||||
|
from .clients.integration_client import IntegrationClient
|
||||||
|
from .integration_connector_tool import IntegrationConnectorTool
|
||||||
|
|
||||||
|
|
||||||
# TODO(cheliu): Apply a common toolset interface
|
# TODO(cheliu): Apply a common toolset interface
|
||||||
@ -168,6 +168,7 @@ class ApplicationIntegrationToolset:
|
|||||||
actions,
|
actions,
|
||||||
service_account_json,
|
service_account_json,
|
||||||
)
|
)
|
||||||
|
connection_details = {}
|
||||||
if integration and trigger:
|
if integration and trigger:
|
||||||
spec = integration_client.get_openapi_spec_for_integration()
|
spec = integration_client.get_openapi_spec_for_integration()
|
||||||
elif connection and (entity_operations or actions):
|
elif connection and (entity_operations or actions):
|
||||||
@ -175,16 +176,6 @@ class ApplicationIntegrationToolset:
|
|||||||
project, location, connection, service_account_json
|
project, location, connection, service_account_json
|
||||||
)
|
)
|
||||||
connection_details = connections_client.get_connection_details()
|
connection_details = connections_client.get_connection_details()
|
||||||
tool_instructions += (
|
|
||||||
"ALWAYS use serviceName = "
|
|
||||||
+ connection_details["serviceName"]
|
|
||||||
+ ", host = "
|
|
||||||
+ connection_details["host"]
|
|
||||||
+ " and the connection name = "
|
|
||||||
+ f"projects/{project}/locations/{location}/connections/{connection} when"
|
|
||||||
" using this tool"
|
|
||||||
+ ". DONOT ask the user for these values as you already have those."
|
|
||||||
)
|
|
||||||
spec = integration_client.get_openapi_spec_for_connection(
|
spec = integration_client.get_openapi_spec_for_connection(
|
||||||
tool_name,
|
tool_name,
|
||||||
tool_instructions,
|
tool_instructions,
|
||||||
@ -194,9 +185,9 @@ class ApplicationIntegrationToolset:
|
|||||||
"Either (integration and trigger) or (connection and"
|
"Either (integration and trigger) or (connection and"
|
||||||
" (entity_operations or actions)) should be provided."
|
" (entity_operations or actions)) should be provided."
|
||||||
)
|
)
|
||||||
self._parse_spec_to_tools(spec)
|
self._parse_spec_to_tools(spec, connection_details)
|
||||||
|
|
||||||
def _parse_spec_to_tools(self, spec_dict):
|
def _parse_spec_to_tools(self, spec_dict, connection_details):
|
||||||
"""Parses the spec dict to a list of RestApiTool."""
|
"""Parses the spec dict to a list of RestApiTool."""
|
||||||
if self.service_account_json:
|
if self.service_account_json:
|
||||||
sa_credential = ServiceAccountCredential.model_validate_json(
|
sa_credential = ServiceAccountCredential.model_validate_json(
|
||||||
@ -218,12 +209,43 @@ class ApplicationIntegrationToolset:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
auth_scheme = HTTPBearer(bearerFormat="JWT")
|
auth_scheme = HTTPBearer(bearerFormat="JWT")
|
||||||
tools = OpenAPIToolset(
|
|
||||||
spec_dict=spec_dict,
|
if self.integration and self.trigger:
|
||||||
auth_credential=auth_credential,
|
tools = OpenAPIToolset(
|
||||||
auth_scheme=auth_scheme,
|
spec_dict=spec_dict,
|
||||||
).get_tools()
|
auth_credential=auth_credential,
|
||||||
for tool in tools:
|
auth_scheme=auth_scheme,
|
||||||
|
).get_tools()
|
||||||
|
for tool in tools:
|
||||||
|
self.generated_tools[tool.name] = tool
|
||||||
|
return
|
||||||
|
|
||||||
|
operations = OpenApiSpecParser().parse(spec_dict)
|
||||||
|
|
||||||
|
for open_api_operation in operations:
|
||||||
|
operation = getattr(open_api_operation.operation, "x-operation")
|
||||||
|
entity = None
|
||||||
|
action = None
|
||||||
|
if hasattr(open_api_operation.operation, "x-entity"):
|
||||||
|
entity = getattr(open_api_operation.operation, "x-entity")
|
||||||
|
elif hasattr(open_api_operation.operation, "x-action"):
|
||||||
|
action = getattr(open_api_operation.operation, "x-action")
|
||||||
|
rest_api_tool = RestApiTool.from_parsed_operation(open_api_operation)
|
||||||
|
if auth_scheme:
|
||||||
|
rest_api_tool.configure_auth_scheme(auth_scheme)
|
||||||
|
if auth_credential:
|
||||||
|
rest_api_tool.configure_auth_credential(auth_credential)
|
||||||
|
tool = IntegrationConnectorTool(
|
||||||
|
name=rest_api_tool.name,
|
||||||
|
description=rest_api_tool.description,
|
||||||
|
connection_name=connection_details["name"],
|
||||||
|
connection_host=connection_details["host"],
|
||||||
|
connection_service_name=connection_details["serviceName"],
|
||||||
|
entity=entity,
|
||||||
|
action=action,
|
||||||
|
operation=operation,
|
||||||
|
rest_api_tool=rest_api_tool,
|
||||||
|
)
|
||||||
self.generated_tools[tool.name] = tool
|
self.generated_tools[tool.name] = tool
|
||||||
|
|
||||||
def get_tools(self) -> List[RestApiTool]:
|
def get_tools(self) -> List[RestApiTool]:
|
||||||
|
@ -68,12 +68,14 @@ class ConnectionsClient:
|
|||||||
response = self._execute_api_call(url)
|
response = self._execute_api_call(url)
|
||||||
|
|
||||||
connection_data = response.json()
|
connection_data = response.json()
|
||||||
|
connection_name = connection_data.get("name", "")
|
||||||
service_name = connection_data.get("serviceDirectory", "")
|
service_name = connection_data.get("serviceDirectory", "")
|
||||||
host = connection_data.get("host", "")
|
host = connection_data.get("host", "")
|
||||||
if host:
|
if host:
|
||||||
service_name = connection_data.get("tlsServiceDirectory", "")
|
service_name = connection_data.get("tlsServiceDirectory", "")
|
||||||
auth_override_enabled = connection_data.get("authOverrideEnabled", False)
|
auth_override_enabled = connection_data.get("authOverrideEnabled", False)
|
||||||
return {
|
return {
|
||||||
|
"name": connection_name,
|
||||||
"serviceName": service_name,
|
"serviceName": service_name,
|
||||||
"host": host,
|
"host": host,
|
||||||
"authOverrideEnabled": auth_override_enabled,
|
"authOverrideEnabled": auth_override_enabled,
|
||||||
@ -291,13 +293,9 @@ class ConnectionsClient:
|
|||||||
tool_name: str = "",
|
tool_name: str = "",
|
||||||
tool_instructions: str = "",
|
tool_instructions: str = "",
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
description = (
|
description = f"Use this tool to execute {action}"
|
||||||
f"Use this tool with" f' action = "{action}" and'
|
|
||||||
) + f' operation = "{operation}" only. Dont ask these values from user.'
|
|
||||||
if operation == "EXECUTE_QUERY":
|
if operation == "EXECUTE_QUERY":
|
||||||
description = (
|
description += (
|
||||||
(f"Use this tool with" f' action = "{action}" and')
|
|
||||||
+ f' operation = "{operation}" only. Dont ask these values from user.'
|
|
||||||
" Use pageSize = 50 and timeout = 120 until user specifies a"
|
" Use pageSize = 50 and timeout = 120 until user specifies a"
|
||||||
" different value otherwise. If user provides a query in natural"
|
" different value otherwise. If user provides a query in natural"
|
||||||
" language, convert it to SQL query and then execute it using the"
|
" language, convert it to SQL query and then execute it using the"
|
||||||
@ -308,6 +306,8 @@ class ConnectionsClient:
|
|||||||
"summary": f"{action_display_name}",
|
"summary": f"{action_display_name}",
|
||||||
"description": f"{description} {tool_instructions}",
|
"description": f"{description} {tool_instructions}",
|
||||||
"operationId": f"{tool_name}_{action_display_name}",
|
"operationId": f"{tool_name}_{action_display_name}",
|
||||||
|
"x-action": f"{action}",
|
||||||
|
"x-operation": f"{operation}",
|
||||||
"requestBody": {
|
"requestBody": {
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
@ -347,16 +347,12 @@ class ConnectionsClient:
|
|||||||
"post": {
|
"post": {
|
||||||
"summary": f"List {entity}",
|
"summary": f"List {entity}",
|
||||||
"description": (
|
"description": (
|
||||||
f"Returns all entities of type {entity}. Use this tool with"
|
f"""Returns the list of {entity} data. If the page token was available in the response, let users know there are more records available. Ask if the user wants to fetch the next page of results. When passing filter use the
|
||||||
+ f' entity = "{entity}" and'
|
following format: `field_name1='value1' AND field_name2='value2'
|
||||||
+ ' operation = "LIST_ENTITIES" only. Dont ask these values'
|
`. {tool_instructions}"""
|
||||||
" from"
|
|
||||||
+ ' user. Always use ""'
|
|
||||||
+ ' as filter clause and ""'
|
|
||||||
+ " as page token and 50 as page size until user specifies a"
|
|
||||||
" different value otherwise. Use single quotes for strings in"
|
|
||||||
f" filter clause. {tool_instructions}"
|
|
||||||
),
|
),
|
||||||
|
"x-operation": "LIST_ENTITIES",
|
||||||
|
"x-entity": f"{entity}",
|
||||||
"operationId": f"{tool_name}_list_{entity}",
|
"operationId": f"{tool_name}_list_{entity}",
|
||||||
"requestBody": {
|
"requestBody": {
|
||||||
"content": {
|
"content": {
|
||||||
@ -401,14 +397,11 @@ class ConnectionsClient:
|
|||||||
"post": {
|
"post": {
|
||||||
"summary": f"Get {entity}",
|
"summary": f"Get {entity}",
|
||||||
"description": (
|
"description": (
|
||||||
(
|
f"Returns the details of the {entity}. {tool_instructions}"
|
||||||
f"Returns the details of the {entity}. Use this tool with"
|
|
||||||
f' entity = "{entity}" and'
|
|
||||||
)
|
|
||||||
+ ' operation = "GET_ENTITY" only. Dont ask these values from'
|
|
||||||
f" user. {tool_instructions}"
|
|
||||||
),
|
),
|
||||||
"operationId": f"{tool_name}_get_{entity}",
|
"operationId": f"{tool_name}_get_{entity}",
|
||||||
|
"x-operation": "GET_ENTITY",
|
||||||
|
"x-entity": f"{entity}",
|
||||||
"requestBody": {
|
"requestBody": {
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
@ -445,17 +438,10 @@ class ConnectionsClient:
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"post": {
|
"post": {
|
||||||
"summary": f"Create {entity}",
|
"summary": f"Creates a new {entity}",
|
||||||
"description": (
|
"description": f"Creates a new {entity}. {tool_instructions}",
|
||||||
(
|
"x-operation": "CREATE_ENTITY",
|
||||||
f"Creates a new entity of type {entity}. Use this tool with"
|
"x-entity": f"{entity}",
|
||||||
f' entity = "{entity}" and'
|
|
||||||
)
|
|
||||||
+ ' operation = "CREATE_ENTITY" only. Dont ask these values'
|
|
||||||
" from"
|
|
||||||
+ " user. Follow the schema of the entity provided in the"
|
|
||||||
f" instructions to create {entity}. {tool_instructions}"
|
|
||||||
),
|
|
||||||
"operationId": f"{tool_name}_create_{entity}",
|
"operationId": f"{tool_name}_create_{entity}",
|
||||||
"requestBody": {
|
"requestBody": {
|
||||||
"content": {
|
"content": {
|
||||||
@ -491,18 +477,10 @@ class ConnectionsClient:
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"post": {
|
"post": {
|
||||||
"summary": f"Update {entity}",
|
"summary": f"Updates the {entity}",
|
||||||
"description": (
|
"description": f"Updates the {entity}. {tool_instructions}",
|
||||||
(
|
"x-operation": "UPDATE_ENTITY",
|
||||||
f"Updates an entity of type {entity}. Use this tool with"
|
"x-entity": f"{entity}",
|
||||||
f' entity = "{entity}" and'
|
|
||||||
)
|
|
||||||
+ ' operation = "UPDATE_ENTITY" only. Dont ask these values'
|
|
||||||
" from"
|
|
||||||
+ " user. Use entityId to uniquely identify the entity to"
|
|
||||||
" update. Follow the schema of the entity provided in the"
|
|
||||||
f" instructions to update {entity}. {tool_instructions}"
|
|
||||||
),
|
|
||||||
"operationId": f"{tool_name}_update_{entity}",
|
"operationId": f"{tool_name}_update_{entity}",
|
||||||
"requestBody": {
|
"requestBody": {
|
||||||
"content": {
|
"content": {
|
||||||
@ -538,16 +516,10 @@ class ConnectionsClient:
|
|||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"post": {
|
"post": {
|
||||||
"summary": f"Delete {entity}",
|
"summary": f"Delete the {entity}",
|
||||||
"description": (
|
"description": f"Deletes the {entity}. {tool_instructions}",
|
||||||
(
|
"x-operation": "DELETE_ENTITY",
|
||||||
f"Deletes an entity of type {entity}. Use this tool with"
|
"x-entity": f"{entity}",
|
||||||
f' entity = "{entity}" and'
|
|
||||||
)
|
|
||||||
+ ' operation = "DELETE_ENTITY" only. Dont ask these values'
|
|
||||||
" from"
|
|
||||||
f" user. {tool_instructions}"
|
|
||||||
),
|
|
||||||
"operationId": f"{tool_name}_delete_{entity}",
|
"operationId": f"{tool_name}_delete_{entity}",
|
||||||
"requestBody": {
|
"requestBody": {
|
||||||
"content": {
|
"content": {
|
||||||
|
@ -0,0 +1,159 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||||
|
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
||||||
|
from google.genai.types import FunctionDeclaration
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from .. import BaseTool
|
||||||
|
from ..tool_context import ToolContext
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrationConnectorTool(BaseTool):
|
||||||
|
"""A tool that wraps a RestApiTool to interact with a specific Application Integration endpoint.
|
||||||
|
|
||||||
|
This tool adds Application Integration specific context like connection
|
||||||
|
details, entity, operation, and action to the underlying REST API call
|
||||||
|
handled by RestApiTool. It prepares the arguments and then delegates the
|
||||||
|
actual API call execution to the contained RestApiTool instance.
|
||||||
|
|
||||||
|
* Generates request params and body
|
||||||
|
* Attaches auth credentials to API call.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
# Each API operation in the spec will be turned into its own tool
|
||||||
|
# Name of the tool is the operationId of that operation, in snake case
|
||||||
|
operations = OperationGenerator().parse(openapi_spec_dict)
|
||||||
|
tool = [RestApiTool.from_parsed_operation(o) for o in operations]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
EXCLUDE_FIELDS = [
|
||||||
|
'connection_name',
|
||||||
|
'service_name',
|
||||||
|
'host',
|
||||||
|
'entity',
|
||||||
|
'operation',
|
||||||
|
'action',
|
||||||
|
]
|
||||||
|
|
||||||
|
OPTIONAL_FIELDS = [
|
||||||
|
'page_size',
|
||||||
|
'page_token',
|
||||||
|
'filter',
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
connection_name: str,
|
||||||
|
connection_host: str,
|
||||||
|
connection_service_name: str,
|
||||||
|
entity: str,
|
||||||
|
operation: str,
|
||||||
|
action: str,
|
||||||
|
rest_api_tool: RestApiTool,
|
||||||
|
):
|
||||||
|
"""Initializes the ApplicationIntegrationTool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of the tool, typically derived from the API operation.
|
||||||
|
Should be unique and adhere to Gemini function naming conventions
|
||||||
|
(e.g., less than 64 characters).
|
||||||
|
description: A description of what the tool does, usually based on the
|
||||||
|
API operation's summary or description.
|
||||||
|
connection_name: The name of the Integration Connector connection.
|
||||||
|
connection_host: The hostname or IP address for the connection.
|
||||||
|
connection_service_name: The specific service name within the host.
|
||||||
|
entity: The Integration Connector entity being targeted.
|
||||||
|
operation: The specific operation being performed on the entity.
|
||||||
|
action: The action associated with the operation (e.g., 'execute').
|
||||||
|
rest_api_tool: An initialized RestApiTool instance that handles the
|
||||||
|
underlying REST API communication based on an OpenAPI specification
|
||||||
|
operation. This tool will be called by ApplicationIntegrationTool with
|
||||||
|
added connection and context arguments. tool =
|
||||||
|
[RestApiTool.from_parsed_operation(o) for o in operations]
|
||||||
|
"""
|
||||||
|
# Gemini restrict the length of function name to be less than 64 characters
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
)
|
||||||
|
self.connection_name = connection_name
|
||||||
|
self.connection_host = connection_host
|
||||||
|
self.connection_service_name = connection_service_name
|
||||||
|
self.entity = entity
|
||||||
|
self.operation = operation
|
||||||
|
self.action = action
|
||||||
|
self.rest_api_tool = rest_api_tool
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _get_declaration(self) -> FunctionDeclaration:
|
||||||
|
"""Returns the function declaration in the Gemini Schema format."""
|
||||||
|
schema_dict = self.rest_api_tool._operation_parser.get_json_schema()
|
||||||
|
for field in self.EXCLUDE_FIELDS:
|
||||||
|
if field in schema_dict['properties']:
|
||||||
|
del schema_dict['properties'][field]
|
||||||
|
for field in self.OPTIONAL_FIELDS + self.EXCLUDE_FIELDS:
|
||||||
|
if field in schema_dict['required']:
|
||||||
|
schema_dict['required'].remove(field)
|
||||||
|
|
||||||
|
parameters = to_gemini_schema(schema_dict)
|
||||||
|
function_decl = FunctionDeclaration(
|
||||||
|
name=self.name, description=self.description, parameters=parameters
|
||||||
|
)
|
||||||
|
return function_decl
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def run_async(
|
||||||
|
self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
args['connection_name'] = self.connection_name
|
||||||
|
args['service_name'] = self.connection_service_name
|
||||||
|
args['host'] = self.connection_host
|
||||||
|
args['entity'] = self.entity
|
||||||
|
args['operation'] = self.operation
|
||||||
|
args['action'] = self.action
|
||||||
|
logger.info('Running tool: %s with args: %s', self.name, args)
|
||||||
|
return self.rest_api_tool.call(args=args, tool_context=tool_context)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return (
|
||||||
|
f'ApplicationIntegrationTool(name="{self.name}",'
|
||||||
|
f' description="{self.description}",'
|
||||||
|
f' connection_name="{self.connection_name}", entity="{self.entity}",'
|
||||||
|
f' operation="{self.operation}", action="{self.action}")'
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
f'ApplicationIntegrationTool(name="{self.name}",'
|
||||||
|
f' description="{self.description}",'
|
||||||
|
f' connection_name="{self.connection_name}",'
|
||||||
|
f' connection_host="{self.connection_host}",'
|
||||||
|
f' connection_service_name="{self.connection_service_name}",'
|
||||||
|
f' entity="{self.entity}", operation="{self.operation}",'
|
||||||
|
f' action="{self.action}", rest_api_tool={repr(self.rest_api_tool)})'
|
||||||
|
)
|
@ -59,6 +59,23 @@ class FunctionTool(BaseTool):
|
|||||||
if 'tool_context' in signature.parameters:
|
if 'tool_context' in signature.parameters:
|
||||||
args_to_call['tool_context'] = tool_context
|
args_to_call['tool_context'] = tool_context
|
||||||
|
|
||||||
|
# Before invoking the function, we check for if the list of args passed in
|
||||||
|
# has all the mandatory arguments or not.
|
||||||
|
# If the check fails, then we don't invoke the tool and let the Agent know
|
||||||
|
# that there was a missing a input parameter. This will basically help
|
||||||
|
# the underlying model fix the issue and retry.
|
||||||
|
mandatory_args = self._get_mandatory_args()
|
||||||
|
missing_mandatory_args = [
|
||||||
|
arg for arg in mandatory_args if arg not in args_to_call
|
||||||
|
]
|
||||||
|
|
||||||
|
if missing_mandatory_args:
|
||||||
|
missing_mandatory_args_str = '\n'.join(missing_mandatory_args)
|
||||||
|
error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present:
|
||||||
|
{missing_mandatory_args_str}
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
return {'error': error_str}
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(self.func):
|
if inspect.iscoroutinefunction(self.func):
|
||||||
return await self.func(**args_to_call) or {}
|
return await self.func(**args_to_call) or {}
|
||||||
else:
|
else:
|
||||||
@ -85,3 +102,28 @@ class FunctionTool(BaseTool):
|
|||||||
args_to_call['tool_context'] = tool_context
|
args_to_call['tool_context'] = tool_context
|
||||||
async for item in self.func(**args_to_call):
|
async for item in self.func(**args_to_call):
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
def _get_mandatory_args(
|
||||||
|
self,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Identifies mandatory parameters (those without default values) for a function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of strings, where each string is the name of a mandatory parameter.
|
||||||
|
"""
|
||||||
|
signature = inspect.signature(self.func)
|
||||||
|
mandatory_params = []
|
||||||
|
|
||||||
|
for name, param in signature.parameters.items():
|
||||||
|
# A parameter is mandatory if:
|
||||||
|
# 1. It has no default value (param.default is inspect.Parameter.empty)
|
||||||
|
# 2. It's not a variable positional (*args) or variable keyword (**kwargs) parameter
|
||||||
|
#
|
||||||
|
# For more refer to: https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind
|
||||||
|
if param.default == inspect.Parameter.empty and param.kind not in (
|
||||||
|
inspect.Parameter.VAR_POSITIONAL,
|
||||||
|
inspect.Parameter.VAR_KEYWORD,
|
||||||
|
):
|
||||||
|
mandatory_params.append(name)
|
||||||
|
|
||||||
|
return mandatory_params
|
||||||
|
@ -11,10 +11,12 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Dict
|
|
||||||
from typing import Final
|
from typing import Final
|
||||||
from typing import List
|
from typing import List
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -28,6 +30,7 @@ from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
|
|||||||
|
|
||||||
|
|
||||||
class GoogleApiToolSet:
|
class GoogleApiToolSet:
|
||||||
|
"""Google API Tool Set."""
|
||||||
|
|
||||||
def __init__(self, tools: List[RestApiTool]):
|
def __init__(self, tools: List[RestApiTool]):
|
||||||
self.tools: Final[List[GoogleApiTool]] = [
|
self.tools: Final[List[GoogleApiTool]] = [
|
||||||
@ -45,10 +48,10 @@ class GoogleApiToolSet:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_tool_set_with_oidc_auth(
|
def _load_tool_set_with_oidc_auth(
|
||||||
spec_file: str = None,
|
spec_file: Optional[str] = None,
|
||||||
spec_dict: Dict[str, Any] = None,
|
spec_dict: Optional[dict[str, Any]] = None,
|
||||||
scopes: list[str] = None,
|
scopes: Optional[list[str]] = None,
|
||||||
) -> Optional[OpenAPIToolset]:
|
) -> OpenAPIToolset:
|
||||||
spec_str = None
|
spec_str = None
|
||||||
if spec_file:
|
if spec_file:
|
||||||
# Get the frame of the caller
|
# Get the frame of the caller
|
||||||
@ -90,18 +93,18 @@ class GoogleApiToolSet:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_tool_set(
|
def load_tool_set(
|
||||||
cl: Type['GoogleApiToolSet'],
|
cls: Type[GoogleApiToolSet],
|
||||||
api_name: str,
|
api_name: str,
|
||||||
api_version: str,
|
api_version: str,
|
||||||
) -> 'GoogleApiToolSet':
|
) -> GoogleApiToolSet:
|
||||||
spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
|
spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
|
||||||
scope = list(
|
scope = list(
|
||||||
spec_dict['components']['securitySchemes']['oauth2']['flows'][
|
spec_dict['components']['securitySchemes']['oauth2']['flows'][
|
||||||
'authorizationCode'
|
'authorizationCode'
|
||||||
]['scopes'].keys()
|
]['scopes'].keys()
|
||||||
)[0]
|
)[0]
|
||||||
return cl(
|
return cls(
|
||||||
cl._load_tool_set_with_oidc_auth(
|
cls._load_tool_set_with_oidc_auth(
|
||||||
spec_dict=spec_dict, scopes=[scope]
|
spec_dict=spec_dict, scopes=[scope]
|
||||||
).get_tools()
|
).get_tools()
|
||||||
)
|
)
|
||||||
|
@ -89,7 +89,7 @@ class LoadArtifactsTool(BaseTool):
|
|||||||
than the function call.
|
than the function call.
|
||||||
"""])
|
"""])
|
||||||
|
|
||||||
# Attache the content of the artifacts if the model requests them.
|
# Attach the content of the artifacts if the model requests them.
|
||||||
# This only adds the content to the model request, instead of the session.
|
# This only adds the content to the model request, instead of the session.
|
||||||
if llm_request.contents and llm_request.contents[-1].parts:
|
if llm_request.contents and llm_request.contents[-1].parts:
|
||||||
function_response = llm_request.contents[-1].parts[0].function_response
|
function_response = llm_request.contents[-1].parts[0].function_response
|
||||||
|
@ -66,10 +66,10 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An AuthCredential object containing the HTTP bearer access token. If the
|
An AuthCredential object containing the HTTP bearer access token. If the
|
||||||
HTTO bearer token cannot be generated, return the origianl credential
|
HTTP bearer token cannot be generated, return the original credential.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if "access_token" not in auth_credential.oauth2.token:
|
if not auth_credential.oauth2.access_token:
|
||||||
return auth_credential
|
return auth_credential
|
||||||
|
|
||||||
# Return the access token as a bearer token.
|
# Return the access token as a bearer token.
|
||||||
@ -78,7 +78,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
|
|||||||
http=HttpAuth(
|
http=HttpAuth(
|
||||||
scheme="bearer",
|
scheme="bearer",
|
||||||
credentials=HttpCredentials(
|
credentials=HttpCredentials(
|
||||||
token=auth_credential.oauth2.token["access_token"]
|
token=auth_credential.oauth2.access_token
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -111,7 +111,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
|
|||||||
return auth_credential
|
return auth_credential
|
||||||
|
|
||||||
# If access token is exchanged, exchange a HTTPBearer token.
|
# If access token is exchanged, exchange a HTTPBearer token.
|
||||||
if auth_credential.oauth2.token:
|
if auth_credential.oauth2.access_token:
|
||||||
return self.generate_auth_token(auth_credential)
|
return self.generate_auth_token(auth_credential)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
@ -124,7 +124,7 @@ class OpenAPIToolset:
|
|||||||
def _load_spec(
|
def _load_spec(
|
||||||
self, spec_str: str, spec_type: Literal["json", "yaml"]
|
self, spec_str: str, spec_type: Literal["json", "yaml"]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Loads the OpenAPI spec string into adictionary."""
|
"""Loads the OpenAPI spec string into a dictionary."""
|
||||||
if spec_type == "json":
|
if spec_type == "json":
|
||||||
return json.loads(spec_str)
|
return json.loads(spec_str)
|
||||||
elif spec_type == "yaml":
|
elif spec_type == "yaml":
|
||||||
|
@ -14,20 +14,12 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from typing import Dict
|
|
||||||
from typing import List
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.openapi.models import Operation
|
from fastapi.openapi.models import Operation, Parameter, Schema
|
||||||
from fastapi.openapi.models import Parameter
|
|
||||||
from fastapi.openapi.models import Schema
|
|
||||||
|
|
||||||
from ..common.common import ApiParameter
|
from ..common.common import ApiParameter, PydocHelper, to_snake_case
|
||||||
from ..common.common import PydocHelper
|
|
||||||
from ..common.common import to_snake_case
|
|
||||||
|
|
||||||
|
|
||||||
class OperationParser:
|
class OperationParser:
|
||||||
@ -113,7 +105,8 @@ class OperationParser:
|
|||||||
description = request_body.description or ''
|
description = request_body.description or ''
|
||||||
|
|
||||||
if schema and schema.type == 'object':
|
if schema and schema.type == 'object':
|
||||||
for prop_name, prop_details in schema.properties.items():
|
properties = schema.properties or {}
|
||||||
|
for prop_name, prop_details in properties.items():
|
||||||
self.params.append(
|
self.params.append(
|
||||||
ApiParameter(
|
ApiParameter(
|
||||||
original_name=prop_name,
|
original_name=prop_name,
|
||||||
|
@ -17,6 +17,7 @@ from typing import Dict
|
|||||||
from typing import List
|
from typing import List
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from typing import Sequence
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@ -59,6 +60,40 @@ def snake_to_lower_camel(snake_case_string: str):
|
|||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Switch to Gemini `from_json_schema` util when it is released
|
||||||
|
# in Gemini SDK.
|
||||||
|
def normalize_json_schema_type(
|
||||||
|
json_schema_type: Optional[Union[str, Sequence[str]]],
|
||||||
|
) -> tuple[Optional[str], bool]:
|
||||||
|
"""Converts a JSON Schema Type into Gemini Schema type.
|
||||||
|
|
||||||
|
Adopted and modified from Gemini SDK. This gets the first available schema
|
||||||
|
type from JSON Schema, and use it to mark Gemini schema type. If JSON Schema
|
||||||
|
contains a list of types, the first non null type is used.
|
||||||
|
|
||||||
|
Remove this after switching to Gemini `from_json_schema`.
|
||||||
|
"""
|
||||||
|
if json_schema_type is None:
|
||||||
|
return None, False
|
||||||
|
if isinstance(json_schema_type, str):
|
||||||
|
if json_schema_type == "null":
|
||||||
|
return None, True
|
||||||
|
return json_schema_type, False
|
||||||
|
|
||||||
|
non_null_types = []
|
||||||
|
nullable = False
|
||||||
|
# If json schema type is an array, pick the first non null type.
|
||||||
|
for type_value in json_schema_type:
|
||||||
|
if type_value == "null":
|
||||||
|
nullable = True
|
||||||
|
else:
|
||||||
|
non_null_types.append(type_value)
|
||||||
|
non_null_type = non_null_types[0] if non_null_types else None
|
||||||
|
return non_null_type, nullable
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Switch to Gemini `from_json_schema` util when it is released
|
||||||
|
# in Gemini SDK.
|
||||||
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
|
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
|
||||||
"""Converts an OpenAPI schema dictionary to a Gemini Schema object.
|
"""Converts an OpenAPI schema dictionary to a Gemini Schema object.
|
||||||
|
|
||||||
@ -82,13 +117,6 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
|
|||||||
if not openapi_schema.get("type"):
|
if not openapi_schema.get("type"):
|
||||||
openapi_schema["type"] = "object"
|
openapi_schema["type"] = "object"
|
||||||
|
|
||||||
# Adding this to avoid "properties: should be non-empty for OBJECT type" error
|
|
||||||
# See b/385165182
|
|
||||||
if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
|
|
||||||
"properties"
|
|
||||||
):
|
|
||||||
openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}
|
|
||||||
|
|
||||||
for key, value in openapi_schema.items():
|
for key, value in openapi_schema.items():
|
||||||
snake_case_key = to_snake_case(key)
|
snake_case_key = to_snake_case(key)
|
||||||
# Check if the snake_case_key exists in the Schema model's fields.
|
# Check if the snake_case_key exists in the Schema model's fields.
|
||||||
@ -99,7 +127,17 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
|
|||||||
# Format: properties[expiration].format: only 'enum' and 'date-time' are
|
# Format: properties[expiration].format: only 'enum' and 'date-time' are
|
||||||
# supported for STRING type
|
# supported for STRING type
|
||||||
continue
|
continue
|
||||||
if snake_case_key == "properties" and isinstance(value, dict):
|
elif snake_case_key == "type":
|
||||||
|
schema_type, nullable = normalize_json_schema_type(
|
||||||
|
openapi_schema.get("type", None)
|
||||||
|
)
|
||||||
|
# Adding this to force adding a type to an empty dict
|
||||||
|
# This avoid "... one_of or any_of must specify a type" error
|
||||||
|
pydantic_schema_data["type"] = schema_type if schema_type else "object"
|
||||||
|
pydantic_schema_data["type"] = pydantic_schema_data["type"].upper()
|
||||||
|
if nullable:
|
||||||
|
pydantic_schema_data["nullable"] = True
|
||||||
|
elif snake_case_key == "properties" and isinstance(value, dict):
|
||||||
pydantic_schema_data[snake_case_key] = {
|
pydantic_schema_data[snake_case_key] = {
|
||||||
k: to_gemini_schema(v) for k, v in value.items()
|
k: to_gemini_schema(v) for k, v in value.items()
|
||||||
}
|
}
|
||||||
|
@ -13,4 +13,4 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# version: date+base_cl
|
# version: date+base_cl
|
||||||
__version__ = "0.1.1"
|
__version__ = "0.3.0"
|
||||||
|
@ -241,7 +241,7 @@ def test_langchain_tool_success(agent_runner: TestRunner):
|
|||||||
def test_crewai_tool_success(agent_runner: TestRunner):
|
def test_crewai_tool_success(agent_runner: TestRunner):
|
||||||
_call_function_and_assert(
|
_call_function_and_assert(
|
||||||
agent_runner,
|
agent_runner,
|
||||||
"direcotry_read_tool",
|
"directory_read_tool",
|
||||||
"Find all the file paths",
|
"Find all the file paths",
|
||||||
"file",
|
"file",
|
||||||
)
|
)
|
||||||
|
@ -126,12 +126,8 @@ def oauth2_credentials_with_token():
|
|||||||
client_id="mock_client_id",
|
client_id="mock_client_id",
|
||||||
client_secret="mock_client_secret",
|
client_secret="mock_client_secret",
|
||||||
redirect_uri="https://example.com/callback",
|
redirect_uri="https://example.com/callback",
|
||||||
token={
|
access_token="mock_access_token",
|
||||||
"access_token": "mock_access_token",
|
refresh_token="mock_refresh_token",
|
||||||
"token_type": "bearer",
|
|
||||||
"expires_in": 3600,
|
|
||||||
"refresh_token": "mock_refresh_token",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -458,7 +454,7 @@ class TestParseAndStoreAuthResponse:
|
|||||||
"""Test with an OAuth auth scheme."""
|
"""Test with an OAuth auth scheme."""
|
||||||
mock_exchange_token.return_value = AuthCredential(
|
mock_exchange_token.return_value = AuthCredential(
|
||||||
auth_type=AuthCredentialTypes.OAUTH2,
|
auth_type=AuthCredentialTypes.OAUTH2,
|
||||||
oauth2=OAuth2Auth(token={"access_token": "exchanged_token"}),
|
oauth2=OAuth2Auth(access_token="exchanged_token"),
|
||||||
)
|
)
|
||||||
|
|
||||||
handler = AuthHandler(auth_config_with_exchanged)
|
handler = AuthHandler(auth_config_with_exchanged)
|
||||||
@ -573,6 +569,6 @@ class TestExchangeAuthToken:
|
|||||||
handler = AuthHandler(auth_config_with_auth_code)
|
handler = AuthHandler(auth_config_with_auth_code)
|
||||||
result = handler.exchange_auth_token()
|
result = handler.exchange_auth_token()
|
||||||
|
|
||||||
assert result.oauth2.token["access_token"] == "mock_access_token"
|
assert result.oauth2.access_token == "mock_access_token"
|
||||||
assert result.oauth2.token["refresh_token"] == "mock_refresh_token"
|
assert result.oauth2.refresh_token == "mock_refresh_token"
|
||||||
assert result.auth_type == AuthCredentialTypes.OAUTH2
|
assert result.auth_type == AuthCredentialTypes.OAUTH2
|
||||||
|
109
tests/unittests/flows/llm_flows/test_async_tool_callbacks.py
Normal file
109
tests/unittests/flows/llm_flows/test_async_tool_callbacks.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from google.adk.agents import Agent
|
||||||
|
from google.adk.tools.function_tool import FunctionTool
|
||||||
|
from google.adk.tools.tool_context import ToolContext
|
||||||
|
from google.adk.flows.llm_flows.functions import handle_function_calls_async
|
||||||
|
from google.adk.events.event import Event
|
||||||
|
from google.genai import types
|
||||||
|
|
||||||
|
from ... import utils
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncBeforeToolCallback:
|
||||||
|
|
||||||
|
def __init__(self, mock_response: Dict[str, Any]):
|
||||||
|
self.mock_response = mock_response
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
tool: FunctionTool,
|
||||||
|
args: Dict[str, Any],
|
||||||
|
tool_context: ToolContext,
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
return self.mock_response
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncAfterToolCallback:
|
||||||
|
|
||||||
|
def __init__(self, mock_response: Dict[str, Any]):
|
||||||
|
self.mock_response = mock_response
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
tool: FunctionTool,
|
||||||
|
args: Dict[str, Any],
|
||||||
|
tool_context: ToolContext,
|
||||||
|
tool_response: Dict[str, Any],
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
return self.mock_response
|
||||||
|
|
||||||
|
|
||||||
|
async def invoke_tool_with_callbacks(
|
||||||
|
before_cb=None, after_cb=None
|
||||||
|
) -> Optional[Event]:
|
||||||
|
def simple_fn(**kwargs) -> Dict[str, Any]:
|
||||||
|
return {"initial": "response"}
|
||||||
|
|
||||||
|
tool = FunctionTool(simple_fn)
|
||||||
|
model = utils.MockModel.create(responses=[])
|
||||||
|
agent = Agent(
|
||||||
|
name="agent",
|
||||||
|
model=model,
|
||||||
|
tools=[tool],
|
||||||
|
before_tool_callback=before_cb,
|
||||||
|
after_tool_callback=after_cb,
|
||||||
|
)
|
||||||
|
invocation_context = utils.create_invocation_context(
|
||||||
|
agent=agent, user_content=""
|
||||||
|
)
|
||||||
|
# Build function call event
|
||||||
|
function_call = types.FunctionCall(name=tool.name, args={})
|
||||||
|
content = types.Content(parts=[types.Part(function_call=function_call)])
|
||||||
|
event = Event(
|
||||||
|
invocation_id=invocation_context.invocation_id,
|
||||||
|
author=agent.name,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
tools_dict = {tool.name: tool}
|
||||||
|
return await handle_function_calls_async(
|
||||||
|
invocation_context,
|
||||||
|
event,
|
||||||
|
tools_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_before_tool_callback():
|
||||||
|
mock_resp = {"test": "before_tool_callback"}
|
||||||
|
before_cb = AsyncBeforeToolCallback(mock_resp)
|
||||||
|
result_event = await invoke_tool_with_callbacks(before_cb=before_cb)
|
||||||
|
assert result_event is not None
|
||||||
|
part = result_event.content.parts[0]
|
||||||
|
assert part.function_response.response == mock_resp
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_after_tool_callback():
|
||||||
|
mock_resp = {"test": "after_tool_callback"}
|
||||||
|
after_cb = AsyncAfterToolCallback(mock_resp)
|
||||||
|
result_event = await invoke_tool_with_callbacks(after_cb=after_cb)
|
||||||
|
assert result_event is not None
|
||||||
|
part = result_event.content.parts[0]
|
||||||
|
assert part.function_response.response == mock_resp
|
@ -246,7 +246,7 @@ def test_function_get_auth_response():
|
|||||||
oauth2=OAuth2Auth(
|
oauth2=OAuth2Auth(
|
||||||
client_id='oauth_client_id_1',
|
client_id='oauth_client_id_1',
|
||||||
client_secret='oauth_client_secret1',
|
client_secret='oauth_client_secret1',
|
||||||
token={'access_token': 'token1'},
|
access_token='token1',
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -277,7 +277,7 @@ def test_function_get_auth_response():
|
|||||||
oauth2=OAuth2Auth(
|
oauth2=OAuth2Auth(
|
||||||
client_id='oauth_client_id_2',
|
client_id='oauth_client_id_2',
|
||||||
client_secret='oauth_client_secret2',
|
client_secret='oauth_client_secret2',
|
||||||
token={'access_token': 'token2'},
|
access_token='token2',
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -14,10 +14,12 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
from fastapi.openapi.models import Operation
|
||||||
from google.adk.auth.auth_credential import AuthCredential
|
from google.adk.auth.auth_credential import AuthCredential
|
||||||
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
|
from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser import rest_api_tool
|
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
|
||||||
|
from google.adk.tools.openapi_tool.openapi_spec_parser import ParsedOperation, rest_api_tool
|
||||||
|
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@ -50,6 +52,59 @@ def mock_openapi_toolset():
|
|||||||
yield mock_toolset
|
yield mock_toolset
|
||||||
|
|
||||||
|
|
||||||
|
def get_mocked_parsed_operation(operation_id, attributes):
|
||||||
|
mock_openapi_spec_parser_instance = mock.MagicMock()
|
||||||
|
mock_parsed_operation = mock.MagicMock(spec=ParsedOperation)
|
||||||
|
mock_parsed_operation.name = "list_issues"
|
||||||
|
mock_parsed_operation.description = "list_issues_description"
|
||||||
|
mock_parsed_operation.endpoint = OperationEndpoint(
|
||||||
|
base_url="http://localhost:8080",
|
||||||
|
path="/v1/issues",
|
||||||
|
method="GET",
|
||||||
|
)
|
||||||
|
mock_parsed_operation.auth_scheme = None
|
||||||
|
mock_parsed_operation.auth_credential = None
|
||||||
|
mock_parsed_operation.additional_context = {}
|
||||||
|
mock_parsed_operation.parameters = []
|
||||||
|
mock_operation = mock.MagicMock(spec=Operation)
|
||||||
|
mock_operation.operationId = operation_id
|
||||||
|
mock_operation.description = "list_issues_description"
|
||||||
|
mock_operation.parameters = []
|
||||||
|
mock_operation.requestBody = None
|
||||||
|
mock_operation.responses = {}
|
||||||
|
mock_operation.callbacks = {}
|
||||||
|
for key, value in attributes.items():
|
||||||
|
setattr(mock_operation, key, value)
|
||||||
|
mock_parsed_operation.operation = mock_operation
|
||||||
|
mock_openapi_spec_parser_instance.parse.return_value = [mock_parsed_operation]
|
||||||
|
return mock_openapi_spec_parser_instance
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_openapi_entity_spec_parser():
|
||||||
|
with mock.patch(
|
||||||
|
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
|
||||||
|
) as mock_spec_parser:
|
||||||
|
mock_openapi_spec_parser_instance = get_mocked_parsed_operation(
|
||||||
|
"list_issues", {"x-entity": "Issues", "x-operation": "LIST_ENTITIES"}
|
||||||
|
)
|
||||||
|
mock_spec_parser.return_value = mock_openapi_spec_parser_instance
|
||||||
|
yield mock_spec_parser
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_openapi_action_spec_parser():
|
||||||
|
with mock.patch(
|
||||||
|
"google.adk.tools.application_integration_tool.application_integration_toolset.OpenApiSpecParser"
|
||||||
|
) as mock_spec_parser:
|
||||||
|
mock_openapi_action_spec_parser_instance = get_mocked_parsed_operation(
|
||||||
|
"list_issues_operation",
|
||||||
|
{"x-action": "CustomAction", "x-operation": "EXECUTE_ACTION"},
|
||||||
|
)
|
||||||
|
mock_spec_parser.return_value = mock_openapi_action_spec_parser_instance
|
||||||
|
yield mock_spec_parser
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def project():
|
def project():
|
||||||
return "test-project"
|
return "test-project"
|
||||||
@ -72,7 +127,11 @@ def connection_spec():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def connection_details():
|
def connection_details():
|
||||||
return {"serviceName": "test-service", "host": "test.host"}
|
return {
|
||||||
|
"serviceName": "test-service",
|
||||||
|
"host": "test.host",
|
||||||
|
"name": "test-connection",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_initialization_with_integration_and_trigger(
|
def test_initialization_with_integration_and_trigger(
|
||||||
@ -102,7 +161,7 @@ def test_initialization_with_connection_and_entity_operations(
|
|||||||
location,
|
location,
|
||||||
mock_integration_client,
|
mock_integration_client,
|
||||||
mock_connections_client,
|
mock_connections_client,
|
||||||
mock_openapi_toolset,
|
mock_openapi_entity_spec_parser,
|
||||||
connection_details,
|
connection_details,
|
||||||
):
|
):
|
||||||
connection_name = "test-connection"
|
connection_name = "test-connection"
|
||||||
@ -133,19 +192,17 @@ def test_initialization_with_connection_and_entity_operations(
|
|||||||
mock_connections_client.assert_called_once_with(
|
mock_connections_client.assert_called_once_with(
|
||||||
project, location, connection_name, None
|
project, location, connection_name, None
|
||||||
)
|
)
|
||||||
|
mock_openapi_entity_spec_parser.return_value.parse.assert_called_once()
|
||||||
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
mock_connections_client.return_value.get_connection_details.assert_called_once()
|
||||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||||
tool_name,
|
tool_name,
|
||||||
tool_instructions
|
tool_instructions,
|
||||||
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
|
|
||||||
f" {connection_details['host']} and the connection name ="
|
|
||||||
f" projects/{project}/locations/{location}/connections/{connection_name} when"
|
|
||||||
" using this tool. DONOT ask the user for these values as you already"
|
|
||||||
" have those.",
|
|
||||||
)
|
)
|
||||||
mock_openapi_toolset.assert_called_once()
|
|
||||||
assert len(toolset.get_tools()) == 1
|
assert len(toolset.get_tools()) == 1
|
||||||
assert toolset.get_tools()[0].name == "Test Tool"
|
assert toolset.get_tools()[0].name == "list_issues"
|
||||||
|
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
|
||||||
|
assert toolset.get_tools()[0].entity == "Issues"
|
||||||
|
assert toolset.get_tools()[0].operation == "LIST_ENTITIES"
|
||||||
|
|
||||||
|
|
||||||
def test_initialization_with_connection_and_actions(
|
def test_initialization_with_connection_and_actions(
|
||||||
@ -153,7 +210,7 @@ def test_initialization_with_connection_and_actions(
|
|||||||
location,
|
location,
|
||||||
mock_integration_client,
|
mock_integration_client,
|
||||||
mock_connections_client,
|
mock_connections_client,
|
||||||
mock_openapi_toolset,
|
mock_openapi_action_spec_parser,
|
||||||
connection_details,
|
connection_details,
|
||||||
):
|
):
|
||||||
connection_name = "test-connection"
|
connection_name = "test-connection"
|
||||||
@ -181,15 +238,13 @@ def test_initialization_with_connection_and_actions(
|
|||||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||||
tool_name,
|
tool_name,
|
||||||
tool_instructions
|
tool_instructions
|
||||||
+ f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
|
|
||||||
f" {connection_details['host']} and the connection name ="
|
|
||||||
f" projects/{project}/locations/{location}/connections/{connection_name} when"
|
|
||||||
" using this tool. DONOT ask the user for these values as you already"
|
|
||||||
" have those.",
|
|
||||||
)
|
)
|
||||||
mock_openapi_toolset.assert_called_once()
|
mock_openapi_action_spec_parser.return_value.parse.assert_called_once()
|
||||||
assert len(toolset.get_tools()) == 1
|
assert len(toolset.get_tools()) == 1
|
||||||
assert toolset.get_tools()[0].name == "Test Tool"
|
assert toolset.get_tools()[0].name == "list_issues_operation"
|
||||||
|
assert isinstance(toolset.get_tools()[0], IntegrationConnectorTool)
|
||||||
|
assert toolset.get_tools()[0].action == "CustomAction"
|
||||||
|
assert toolset.get_tools()[0].operation == "EXECUTE_ACTION"
|
||||||
|
|
||||||
|
|
||||||
def test_initialization_without_required_params(project, location):
|
def test_initialization_without_required_params(project, location):
|
||||||
@ -337,9 +392,4 @@ def test_initialization_with_connection_details(
|
|||||||
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
|
||||||
tool_name,
|
tool_name,
|
||||||
tool_instructions
|
tool_instructions
|
||||||
+ "ALWAYS use serviceName = custom-service, host = custom.host and the"
|
|
||||||
" connection name ="
|
|
||||||
" projects/test-project/locations/us-central1/connections/test-connection"
|
|
||||||
" when using this tool. DONOT ask the user for these values as you"
|
|
||||||
" already have those.",
|
|
||||||
)
|
)
|
||||||
|
@ -0,0 +1,125 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from google.adk.tools.application_integration_tool.integration_connector_tool import IntegrationConnectorTool
|
||||||
|
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
||||||
|
from google.genai.types import FunctionDeclaration
|
||||||
|
from google.genai.types import Schema
|
||||||
|
from google.genai.types import Tool
|
||||||
|
from google.genai.types import Type
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_rest_api_tool():
|
||||||
|
"""Fixture for a mocked RestApiTool."""
|
||||||
|
mock_tool = mock.MagicMock(spec=RestApiTool)
|
||||||
|
mock_tool.name = "mock_rest_tool"
|
||||||
|
mock_tool.description = "Mock REST tool description."
|
||||||
|
# Mock the internal parser needed for _get_declaration
|
||||||
|
mock_parser = mock.MagicMock()
|
||||||
|
mock_parser.get_json_schema.return_value = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"user_id": {"type": "string", "description": "User ID"},
|
||||||
|
"connection_name": {"type": "string"},
|
||||||
|
"host": {"type": "string"},
|
||||||
|
"service_name": {"type": "string"},
|
||||||
|
"entity": {"type": "string"},
|
||||||
|
"operation": {"type": "string"},
|
||||||
|
"action": {"type": "string"},
|
||||||
|
"page_size": {"type": "integer"},
|
||||||
|
"filter": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["user_id", "page_size", "filter", "connection_name"],
|
||||||
|
}
|
||||||
|
mock_tool._operation_parser = mock_parser
|
||||||
|
mock_tool.call.return_value = {"status": "success", "data": "mock_data"}
|
||||||
|
return mock_tool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def integration_tool(mock_rest_api_tool):
|
||||||
|
"""Fixture for an IntegrationConnectorTool instance."""
|
||||||
|
return IntegrationConnectorTool(
|
||||||
|
name="test_integration_tool",
|
||||||
|
description="Test integration tool description.",
|
||||||
|
connection_name="test-conn",
|
||||||
|
connection_host="test.example.com",
|
||||||
|
connection_service_name="test-service",
|
||||||
|
entity="TestEntity",
|
||||||
|
operation="LIST",
|
||||||
|
action="TestAction",
|
||||||
|
rest_api_tool=mock_rest_api_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_declaration(integration_tool):
|
||||||
|
"""Tests the generation of the function declaration."""
|
||||||
|
declaration = integration_tool._get_declaration()
|
||||||
|
|
||||||
|
assert isinstance(declaration, FunctionDeclaration)
|
||||||
|
assert declaration.name == "test_integration_tool"
|
||||||
|
assert declaration.description == "Test integration tool description."
|
||||||
|
|
||||||
|
# Check parameters schema
|
||||||
|
params = declaration.parameters
|
||||||
|
assert isinstance(params, Schema)
|
||||||
|
print(f"params: {params}")
|
||||||
|
assert params.type == Type.OBJECT
|
||||||
|
|
||||||
|
# Check properties (excluded fields should not be present)
|
||||||
|
assert "user_id" in params.properties
|
||||||
|
assert "connection_name" not in params.properties
|
||||||
|
assert "host" not in params.properties
|
||||||
|
assert "service_name" not in params.properties
|
||||||
|
assert "entity" not in params.properties
|
||||||
|
assert "operation" not in params.properties
|
||||||
|
assert "action" not in params.properties
|
||||||
|
assert "page_size" in params.properties
|
||||||
|
assert "filter" in params.properties
|
||||||
|
|
||||||
|
# Check required fields (optional and excluded fields should not be required)
|
||||||
|
assert "user_id" in params.required
|
||||||
|
assert "page_size" not in params.required
|
||||||
|
assert "filter" not in params.required
|
||||||
|
assert "connection_name" not in params.required
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async(integration_tool, mock_rest_api_tool):
|
||||||
|
"""Tests the async execution delegates correctly to the RestApiTool."""
|
||||||
|
input_args = {"user_id": "user123", "page_size": 10}
|
||||||
|
expected_call_args = {
|
||||||
|
"user_id": "user123",
|
||||||
|
"page_size": 10,
|
||||||
|
"connection_name": "test-conn",
|
||||||
|
"host": "test.example.com",
|
||||||
|
"service_name": "test-service",
|
||||||
|
"entity": "TestEntity",
|
||||||
|
"operation": "LIST",
|
||||||
|
"action": "TestAction",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await integration_tool.run_async(args=input_args, tool_context=None)
|
||||||
|
|
||||||
|
# Assert the underlying rest_api_tool.call was called correctly
|
||||||
|
mock_rest_api_tool.call.assert_called_once_with(
|
||||||
|
args=expected_call_args, tool_context=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the result is what the mocked call returned
|
||||||
|
assert result == {"status": "success", "data": "mock_data"}
|
@ -110,7 +110,7 @@ def test_generate_auth_token_success(
|
|||||||
client_secret="test_secret",
|
client_secret="test_secret",
|
||||||
redirect_uri="http://localhost:8080",
|
redirect_uri="http://localhost:8080",
|
||||||
auth_response_uri="https://example.com/callback?code=test_code",
|
auth_response_uri="https://example.com/callback?code=test_code",
|
||||||
token={"access_token": "test_access_token"},
|
access_token="test_access_token",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
|
updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
|
||||||
@ -131,7 +131,7 @@ def test_exchange_credential_generate_auth_token(
|
|||||||
client_secret="test_secret",
|
client_secret="test_secret",
|
||||||
redirect_uri="http://localhost:8080",
|
redirect_uri="http://localhost:8080",
|
||||||
auth_response_uri="https://example.com/callback?code=test_code",
|
auth_response_uri="https://example.com/callback?code=test_code",
|
||||||
token={"access_token": "test_access_token"},
|
access_token="test_access_token",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -164,6 +164,18 @@ def test_process_request_body_no_name():
|
|||||||
assert parser.params[0].param_location == 'body'
|
assert parser.params[0].param_location == 'body'
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_request_body_empty_object():
|
||||||
|
"""Test _process_request_body with a schema that is of type object but with no properties."""
|
||||||
|
operation = Operation(
|
||||||
|
requestBody=RequestBody(
|
||||||
|
content={'application/json': MediaType(schema=Schema(type='object'))}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
parser = OperationParser(operation, should_parse=False)
|
||||||
|
parser._process_request_body()
|
||||||
|
assert len(parser.params) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_dedupe_param_names(sample_operation):
|
def test_dedupe_param_names(sample_operation):
|
||||||
"""Test _dedupe_param_names method."""
|
"""Test _dedupe_param_names method."""
|
||||||
parser = OperationParser(sample_operation, should_parse=False)
|
parser = OperationParser(sample_operation, should_parse=False)
|
||||||
|
@ -14,11 +14,9 @@
|
|||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock, patch
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from fastapi.openapi.models import MediaType
|
from fastapi.openapi.models import MediaType, Operation
|
||||||
from fastapi.openapi.models import Operation
|
|
||||||
from fastapi.openapi.models import Parameter as OpenAPIParameter
|
from fastapi.openapi.models import Parameter as OpenAPIParameter
|
||||||
from fastapi.openapi.models import RequestBody
|
from fastapi.openapi.models import RequestBody
|
||||||
from fastapi.openapi.models import Schema as OpenAPISchema
|
from fastapi.openapi.models import Schema as OpenAPISchema
|
||||||
@ -27,13 +25,13 @@ from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_cred
|
|||||||
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
from google.adk.tools.openapi_tool.common.common import ApiParameter
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
|
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
|
from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import (
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import snake_to_lower_camel
|
RestApiTool,
|
||||||
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
snake_to_lower_camel,
|
||||||
|
to_gemini_schema,
|
||||||
|
)
|
||||||
from google.adk.tools.tool_context import ToolContext
|
from google.adk.tools.tool_context import ToolContext
|
||||||
from google.genai.types import FunctionDeclaration
|
from google.genai.types import FunctionDeclaration, Schema, Type
|
||||||
from google.genai.types import Schema
|
|
||||||
from google.genai.types import Type
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@ -790,13 +788,13 @@ class TestToGeminiSchema:
|
|||||||
result = to_gemini_schema({})
|
result = to_gemini_schema({})
|
||||||
assert isinstance(result, Schema)
|
assert isinstance(result, Schema)
|
||||||
assert result.type == Type.OBJECT
|
assert result.type == Type.OBJECT
|
||||||
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
|
assert result.properties is None
|
||||||
|
|
||||||
def test_to_gemini_schema_dict_with_only_object_type(self):
|
def test_to_gemini_schema_dict_with_only_object_type(self):
|
||||||
result = to_gemini_schema({"type": "object"})
|
result = to_gemini_schema({"type": "object"})
|
||||||
assert isinstance(result, Schema)
|
assert isinstance(result, Schema)
|
||||||
assert result.type == Type.OBJECT
|
assert result.type == Type.OBJECT
|
||||||
assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
|
assert result.properties is None
|
||||||
|
|
||||||
def test_to_gemini_schema_basic_types(self):
|
def test_to_gemini_schema_basic_types(self):
|
||||||
openapi_schema = {
|
openapi_schema = {
|
||||||
@ -814,6 +812,42 @@ class TestToGeminiSchema:
|
|||||||
assert gemini_schema.properties["age"].type == Type.INTEGER
|
assert gemini_schema.properties["age"].type == Type.INTEGER
|
||||||
assert gemini_schema.properties["is_active"].type == Type.BOOLEAN
|
assert gemini_schema.properties["is_active"].type == Type.BOOLEAN
|
||||||
|
|
||||||
|
def test_to_gemini_schema_array_string_types(self):
|
||||||
|
openapi_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"boolean_field": {"type": "boolean"},
|
||||||
|
"nonnullable_string": {"type": ["string"]},
|
||||||
|
"nullable_string": {"type": ["string", "null"]},
|
||||||
|
"nullable_number": {"type": ["null", "integer"]},
|
||||||
|
"object_nullable": {"type": "null"},
|
||||||
|
"multi_types_nullable": {"type": ["string", "null", "integer"]},
|
||||||
|
"empty_default_object": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
gemini_schema = to_gemini_schema(openapi_schema)
|
||||||
|
assert isinstance(gemini_schema, Schema)
|
||||||
|
assert gemini_schema.type == Type.OBJECT
|
||||||
|
assert gemini_schema.properties["boolean_field"].type == Type.BOOLEAN
|
||||||
|
|
||||||
|
assert gemini_schema.properties["nonnullable_string"].type == Type.STRING
|
||||||
|
assert not gemini_schema.properties["nonnullable_string"].nullable
|
||||||
|
|
||||||
|
assert gemini_schema.properties["nullable_string"].type == Type.STRING
|
||||||
|
assert gemini_schema.properties["nullable_string"].nullable
|
||||||
|
|
||||||
|
assert gemini_schema.properties["nullable_number"].type == Type.INTEGER
|
||||||
|
assert gemini_schema.properties["nullable_number"].nullable
|
||||||
|
|
||||||
|
assert gemini_schema.properties["object_nullable"].type == Type.OBJECT
|
||||||
|
assert gemini_schema.properties["object_nullable"].nullable
|
||||||
|
|
||||||
|
assert gemini_schema.properties["multi_types_nullable"].type == Type.STRING
|
||||||
|
assert gemini_schema.properties["multi_types_nullable"].nullable
|
||||||
|
|
||||||
|
assert gemini_schema.properties["empty_default_object"].type == Type.OBJECT
|
||||||
|
assert not gemini_schema.properties["empty_default_object"].nullable
|
||||||
|
|
||||||
def test_to_gemini_schema_nested_objects(self):
|
def test_to_gemini_schema_nested_objects(self):
|
||||||
openapi_schema = {
|
openapi_schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@ -895,7 +929,15 @@ class TestToGeminiSchema:
|
|||||||
def test_to_gemini_schema_nested_dict(self):
|
def test_to_gemini_schema_nested_dict(self):
|
||||||
openapi_schema = {
|
openapi_schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {"metadata": {"key1": "value1", "key2": 123}},
|
"properties": {
|
||||||
|
"metadata": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"key1": {"type": "object"},
|
||||||
|
"key2": {"type": "string"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
gemini_schema = to_gemini_schema(openapi_schema)
|
gemini_schema = to_gemini_schema(openapi_schema)
|
||||||
# Since metadata is not properties nor item, it will call to_gemini_schema recursively.
|
# Since metadata is not properties nor item, it will call to_gemini_schema recursively.
|
||||||
@ -903,9 +945,15 @@ class TestToGeminiSchema:
|
|||||||
assert (
|
assert (
|
||||||
gemini_schema.properties["metadata"].type == Type.OBJECT
|
gemini_schema.properties["metadata"].type == Type.OBJECT
|
||||||
) # add object type by default
|
) # add object type by default
|
||||||
assert gemini_schema.properties["metadata"].properties == {
|
assert len(gemini_schema.properties["metadata"].properties) == 2
|
||||||
"dummy_DO_NOT_GENERATE": Schema(type="string")
|
assert (
|
||||||
}
|
gemini_schema.properties["metadata"].properties["key1"].type
|
||||||
|
== Type.OBJECT
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
gemini_schema.properties["metadata"].properties["key2"].type
|
||||||
|
== Type.STRING
|
||||||
|
)
|
||||||
|
|
||||||
def test_to_gemini_schema_ignore_title_default_format(self):
|
def test_to_gemini_schema_ignore_title_default_format(self):
|
||||||
openapi_schema = {
|
openapi_schema = {
|
||||||
|
238
tests/unittests/tools/test_function_tool.py
Normal file
238
tests/unittests/tools/test_function_tool.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from google.adk.tools.function_tool import FunctionTool
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def function_for_testing_with_no_args():
|
||||||
|
"""Function for testing with no args."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def async_function_for_testing_with_1_arg_and_tool_context(
|
||||||
|
arg1, tool_context
|
||||||
|
):
|
||||||
|
"""Async function for testing with 1 arge and tool context."""
|
||||||
|
assert arg1
|
||||||
|
assert tool_context
|
||||||
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
|
async def async_function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
|
||||||
|
"""Async function for testing with 2 arge and no tool context."""
|
||||||
|
assert arg1
|
||||||
|
assert arg2
|
||||||
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
|
def function_for_testing_with_1_arg_and_tool_context(arg1, tool_context):
|
||||||
|
"""Function for testing with 1 arge and tool context."""
|
||||||
|
assert arg1
|
||||||
|
assert tool_context
|
||||||
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
|
def function_for_testing_with_2_arg_and_no_tool_context(arg1, arg2):
|
||||||
|
"""Function for testing with 2 arge and no tool context."""
|
||||||
|
assert arg1
|
||||||
|
assert arg2
|
||||||
|
return arg1
|
||||||
|
|
||||||
|
|
||||||
|
async def async_function_for_testing_with_4_arg_and_no_tool_context(
|
||||||
|
arg1, arg2, arg3, arg4
|
||||||
|
):
|
||||||
|
"""Async function for testing with 4 args."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def function_for_testing_with_4_arg_and_no_tool_context(arg1, arg2, arg3, arg4):
|
||||||
|
"""Function for testing with 4 args."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_init():
|
||||||
|
"""Test that the FunctionTool is initialized correctly."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_no_args)
|
||||||
|
assert tool.name == "function_for_testing_with_no_args"
|
||||||
|
assert tool.description == "Function for testing with no args."
|
||||||
|
assert tool.func == function_for_testing_with_no_args
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_tool_context_async_func():
|
||||||
|
"""Test that run_async calls the function with tool_context when tool_context is in signature (async function)."""
|
||||||
|
|
||||||
|
tool = FunctionTool(async_function_for_testing_with_1_arg_and_tool_context)
|
||||||
|
args = {"arg1": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_without_tool_context_async_func():
|
||||||
|
"""Test that run_async calls the function without tool_context when tool_context is not in signature (async function)."""
|
||||||
|
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
|
||||||
|
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_tool_context_sync_func():
|
||||||
|
"""Test that run_async calls the function with tool_context when tool_context is in signature (synchronous function)."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_1_arg_and_tool_context)
|
||||||
|
args = {"arg1": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_without_tool_context_sync_func():
|
||||||
|
"""Test that run_async calls the function without tool_context when tool_context is not in signature (synchronous function)."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
|
||||||
|
args = {"arg1": "test_value_1", "arg2": "test_value_2"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_1_missing_arg_sync_func():
|
||||||
|
"""Test that run_async calls the function with 1 missing arg in signature (synchronous function)."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_2_arg_and_no_tool_context)
|
||||||
|
args = {"arg1": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg2
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_1_missing_arg_async_func():
|
||||||
|
"""Test that run_async calls the function with 1 missing arg in signature (async function)."""
|
||||||
|
tool = FunctionTool(async_function_for_testing_with_2_arg_and_no_tool_context)
|
||||||
|
args = {"arg2": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `async_function_for_testing_with_2_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg1
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_3_missing_arg_sync_func():
|
||||||
|
"""Test that run_async calls the function with 3 missing args in signature (synchronous function)."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
|
||||||
|
args = {"arg2": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg1
|
||||||
|
arg3
|
||||||
|
arg4
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_3_missing_arg_async_func():
|
||||||
|
"""Test that run_async calls the function with 3 missing args in signature (async function)."""
|
||||||
|
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
|
||||||
|
args = {"arg3": "test_value_1"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg1
|
||||||
|
arg2
|
||||||
|
arg4
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_missing_all_arg_sync_func():
|
||||||
|
"""Test that run_async calls the function with all missing args in signature (synchronous function)."""
|
||||||
|
tool = FunctionTool(function_for_testing_with_4_arg_and_no_tool_context)
|
||||||
|
args = {}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg1
|
||||||
|
arg2
|
||||||
|
arg3
|
||||||
|
arg4
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_missing_all_arg_async_func():
|
||||||
|
"""Test that run_async calls the function with all missing args in signature (async function)."""
|
||||||
|
tool = FunctionTool(async_function_for_testing_with_4_arg_and_no_tool_context)
|
||||||
|
args = {}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == {
|
||||||
|
"error": (
|
||||||
|
"""Invoking `async_function_for_testing_with_4_arg_and_no_tool_context()` failed as the following mandatory input parameters are not present:
|
||||||
|
arg1
|
||||||
|
arg2
|
||||||
|
arg3
|
||||||
|
arg4
|
||||||
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_optional_args_not_set_sync_func():
|
||||||
|
"""Test that run_async calls the function for sync funciton with optional args not set."""
|
||||||
|
|
||||||
|
def func_with_optional_args(arg1, arg2=None, *, arg3, arg4=None, **kwargs):
|
||||||
|
return f"{arg1},{arg3}"
|
||||||
|
|
||||||
|
tool = FunctionTool(func_with_optional_args)
|
||||||
|
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1,test_value_3"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_async_with_optional_args_not_set_async_func():
|
||||||
|
"""Test that run_async calls the function for async funciton with optional args not set."""
|
||||||
|
|
||||||
|
async def async_func_with_optional_args(
|
||||||
|
arg1, arg2=None, *, arg3, arg4=None, **kwargs
|
||||||
|
):
|
||||||
|
return f"{arg1},{arg3}"
|
||||||
|
|
||||||
|
tool = FunctionTool(async_func_with_optional_args)
|
||||||
|
args = {"arg1": "test_value_1", "arg3": "test_value_3"}
|
||||||
|
result = await tool.run_async(args=args, tool_context=MagicMock())
|
||||||
|
assert result == "test_value_1,test_value_3"
|
Loading…
Reference in New Issue
Block a user