diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..bc23aae
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,33 @@
+# How to contribute
+
+We'd love to accept your patches and contributions to this project.
+
+## Before you begin
+
+### Sign our Contributor License Agreement
+
+Contributions to this project must be accompanied by a
+[Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
+You (or your employer) retain the copyright to your contribution; this simply
+gives us permission to use and redistribute your contributions as part of the
+project.
+
+If you or your current employer have already signed the Google CLA (even if it
+was for a different project), you probably don't need to do it again.
+
+Visit to see your current agreements or to
+sign a new one.
+
+### Review our community guidelines
+
+This project follows
+[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
+
+## Contribution process
+
+### Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..7a4a3ea
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
index be6421a..5e228af 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,95 @@
-# adk-python
+# Agent Development Kit (ADK)
-Hello World!
+[](LICENSE)
+
+
+
+**An open-source, code-first Python toolkit for building, evaluating, and deploying sophisticated AI agents with flexibility and control.**
+
+The Agent Development Kit (ADK) is designed for developers seeking fine-grained control and flexibility when building advanced AI agents that are tightly integrated with services in Google Cloud. It allows you to define agent behavior, orchestration, and tool use directly in code, enabling robust debugging, versioning, and deployment anywhere – from your laptop to the cloud.
+
+---
+
+## ✨ Key Features
+
+* **Code-First Development:** Define agents, tools, and orchestration logic for maximum control, testability, and versioning.
+* **Multi-Agent Architecture:** Build modular and scalable applications by composing multiple specialized agents in flexible hierarchies.
+* **Rich Tool Ecosystem:** Equip agents with diverse capabilities using pre-built tools, custom Python functions, API specifications, or integrating existing tools.
+* **Flexible Orchestration:** Define workflows using built-in agents for predictable pipelines, or leverage LLM-driven dynamic routing for adaptive behavior.
+* **Integrated Developer Experience:** Develop, test, and debug locally with a CLI and visual web UI.
+* **Built-in Evaluation:** Measure agent performance by evaluating response quality and step-by-step execution trajectory.
+* **Deployment Ready:** Containerize and deploy your agents anywhere – scale with Vertex AI Agent Engine, Cloud Run, or Docker.
+* **Native Streaming Support:** Build real-time, interactive experiences with native support for bidirectional streaming (text and audio).
+* **State, Memory & Artifacts:** Manage short-term conversational context, configure long-term memory, and handle file uploads/downloads.
+* **Extensibility:** Customize agent behavior deeply with callbacks and easily integrate third-party tools and services.
+
+## 🚀 Installation
+
+You can install the Agent Developer Kit using `pip`:
+
+```bash
+pip install google-adk
+```
+
+## 🏁 Getting Started
+
+Create your first agent (`my_agent/agent.py`):
+
+```python
+# my_agent/agent.py
+from google.adk.agents import Agent
+from google.adk.tools import google_search
+
+root_agent = Agent(
+ name="search_assistant",
+ model="gemini-1.5-flash-latest", # Or your preferred model like gemini-2.0-flash-001
+ instruction="You are a helpful assistant. Answer user questions using Google Search when needed.",
+ description="An assistant that can search the web.",
+ tools=[google_search]
+)
+```
+
+Create `my_agent/__init__.py`:
+
+```python
+# my_agent/__init__.py
+from . import agent
+```
+
+Run it via the CLI (from the directory *containing* `my_agent`):
+
+```bash
+adk run my_agent
+```
+
+Or launch the Web UI from the folder that contains `my_agent` folder:
+
+```bash
+adk web
+```
+
+For a full step-by-step guide, check out the quickstart or sample agents.
+
+## 📚 Resources
+
+Explore the full documentation for detailed guides on building, evaluating, and deploying agents:
+
+* **[Get Started](get-started/introduction.md)**
+* **[Build Agents](build/agents.md)**
+* **[Browse Sample Agents](learn/sample_agents/)**
+* **[Evaluate Agents](evaluate/evaluate-agents.md)**
+* **[Deploy Agents](deploy/overview.md)**
+* **[API Reference](guides/reference.md)**
+* **[Troubleshooting](guides/troubleshooting.md)**
+
+## 🤝 Contributing
+
+We welcome contributions from the community! Whether it's bug reports, feature requests, documentation improvements, or code contributions, please see our [**Contributing Guidelines**](./CONTRIBUTING.md) to get started.
+
+## 📄 License
+
+This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
+
+---
+
+*Happy Agent Building!*
\ No newline at end of file
diff --git a/assets/agent-development-kit.png b/assets/agent-development-kit.png
new file mode 100644
index 0000000..9f967ca
Binary files /dev/null and b/assets/agent-development-kit.png differ
diff --git a/pylintrc b/pylintrc
new file mode 100644
index 0000000..d35fac3
--- /dev/null
+++ b/pylintrc
@@ -0,0 +1,400 @@
+# This Pylint rcfile contains a best-effort configuration to uphold the
+# best-practices and style described in the Google Python style guide:
+# https://google.github.io/styleguide/pyguide.html
+#
+# Its canonical open-source location is:
+# https://google.github.io/styleguide/pylintrc
+
+[MAIN]
+
+# Files or directories to be skipped. They should be base names, not paths.
+ignore=third_party
+
+# Files or directories matching the regex patterns are skipped. The regex
+# matches against base names, not paths.
+ignore-patterns=
+
+# Pickle collected data for later comparisons.
+persistent=no
+
+# List of plugins (as comma separated values of python modules names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Use multiple processes to speed up Pylint.
+jobs=4
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
+confidence=
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+#enable=
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once).You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use"--disable=all --enable=classes
+# --disable=W"
+disable=R,
+ abstract-method,
+ apply-builtin,
+ arguments-differ,
+ attribute-defined-outside-init,
+ backtick,
+ bad-option-value,
+ basestring-builtin,
+ buffer-builtin,
+ c-extension-no-member,
+ consider-using-enumerate,
+ cmp-builtin,
+ cmp-method,
+ coerce-builtin,
+ coerce-method,
+ delslice-method,
+ div-method,
+ eq-without-hash,
+ execfile-builtin,
+ file-builtin,
+ filter-builtin-not-iterating,
+ fixme,
+ getslice-method,
+ global-statement,
+ hex-method,
+ idiv-method,
+ implicit-str-concat,
+ import-error,
+ import-self,
+ import-star-module-level,
+ import-outside-toplevel,
+ input-builtin,
+ intern-builtin,
+ invalid-str-codec,
+ locally-disabled,
+ long-builtin,
+ long-suffix,
+ map-builtin-not-iterating,
+ misplaced-comparison-constant,
+ missing-function-docstring,
+ metaclass-assignment,
+ next-method-called,
+ next-method-defined,
+ no-absolute-import,
+ no-init, # added
+ no-member,
+ no-name-in-module,
+ no-self-use,
+ nonzero-method,
+ oct-method,
+ old-division,
+ old-ne-operator,
+ old-octal-literal,
+ old-raise-syntax,
+ parameter-unpacking,
+ print-statement,
+ raising-string,
+ range-builtin-not-iterating,
+ raw_input-builtin,
+ rdiv-method,
+ reduce-builtin,
+ relative-import,
+ reload-builtin,
+ round-builtin,
+ setslice-method,
+ signature-differs,
+ standarderror-builtin,
+ suppressed-message,
+ sys-max-int,
+ trailing-newlines,
+ unichr-builtin,
+ unicode-builtin,
+ unnecessary-pass,
+ unpacking-in-except,
+ useless-else-on-loop,
+ useless-suppression,
+ using-cmp-argument,
+ wrong-import-order,
+ xrange-builtin,
+ zip-builtin-not-iterating,
+
+
+[REPORTS]
+
+# Set the output format. Available formats are text, parseable, colorized, msvs
+# (visual studio) and html. You can also give a reporter class, eg
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages
+reports=no
+
+# Python expression which should return a note less than 10 (10 is the highest
+# note). You have access to the variables errors warning, statement which
+# respectively contain the number of errors / warnings messages and the total
+# number of statements analyzed. This is used by the global evaluation report
+# (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details
+#msg-template=
+
+
+[BASIC]
+
+# Good variable names which should always be accepted, separated by a comma
+good-names=main,_
+
+# Bad variable names which should always be refused, separated by a comma
+bad-names=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Include a hint for the correct naming format with invalid-name
+include-naming-hint=no
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
+
+# Regular expression matching correct function names
+function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$
+
+# Regular expression matching correct variable names
+variable-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression matching correct constant names
+const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
+
+# Regular expression matching correct attribute names
+attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
+
+# Regular expression matching correct argument names
+argument-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression matching correct class attribute names
+class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
+
+# Regular expression matching correct inline iteration names
+inlinevar-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression matching correct class names
+class-rgx=^_?[A-Z][a-zA-Z0-9]*$
+
+# Regular expression matching correct module names
+module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
+
+# Regular expression matching correct method names
+method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=12
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis. It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+
+[FORMAT]
+
+# Maximum number of characters on a single line.
+max-line-length=80
+
+# TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt
+# lines made too long by directives to pytype.
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=(?x)(
+ ^\s*(\#\ )??$|
+ ^\s*(from\s+\S+\s+)?import\s+.+$)
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=yes
+
+# Maximum number of lines in a module
+max-module-lines=99999
+
+# String used as indentation unit. The internal Google style guide mandates 2
+# spaces. Google's externaly-published style guide says 4, consistent with
+# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
+# projects (like TensorFlow).
+indent-string=' '
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=TODO
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=yes
+
+
+[VARIABLES]
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# A regular expression matching the name of dummy variables (i.e. expectedly
+# not used).
+dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid to define new builtins when possible.
+additional-builtins=
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,_cb
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
+
+
+[LOGGING]
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format
+logging-modules=logging,absl.logging,tensorflow.io.logging
+
+
+[SIMILARITIES]
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+
+[SPELLING]
+
+# Spelling dictionary name. Available dictionaries: none. To make it working
+# install python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to indicated private dictionary in
+# --spelling-private-dict-file option instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[IMPORTS]
+
+# Deprecated modules which should not be used, separated by a comma
+deprecated-modules=regsub,
+ TERMIOS,
+ Bastion,
+ rexec,
+ sets
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled)
+import-graph=
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled)
+ext-import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled)
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant, absl
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+
+[CLASSES]
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+ __new__,
+ setUp
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+ _fields,
+ _replace,
+ _source,
+ _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls,
+ class_
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..05711d1
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,146 @@
+[project]
+# Project metadata. Available keys are documented at:
+# https://packaging.python.org/en/latest/specifications/declaring-project-metadata
+
+name = "google-adk"
+description = "Agent Development Kit"
+readme = "README.md"
+requires-python = ">=3.9"
+license = { file = "LICENSE" }
+authors = [{ name = "Google LLC", email = "googleapis-packages@google.com" }]
+classifiers = [ # List of https://pypi.org/classifiers/
+ "Typing :: Typed",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Science/Research",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.10",
+ "Operating System :: OS Independent",
+ "Topic :: Software Development :: Libraries :: Python Modules",
+ "License :: OSI Approved :: Apache Software License",
+]
+dependencies = [
+ # go/keep-sorted start
+ "authlib>=1.5.1", # For RestAPI Tool
+ "click>=8.1.8", # For CLI tools
+ "fastapi>=0.115.0", # FastAPI framework
+ "google-api-python-client>=2.157.0", # Google API client discovery
+ "google-cloud-aiplatform>=1.87.0", # For VertexAI integrations, e.g. example store.
+ "google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool
+ "google-cloud-speech>=2.30.0", # For Audo Transcription
+ "google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
+ "google-genai>=1.9.0", # Google GenAI SDK
+ "graphviz>=0.20.2", # Graphviz for graph rendering
+ "mcp>=1.5.0;python_version>='3.10'", # For MCP Toolset
+ "opentelemetry-api>=1.31.0", # OpenTelemetry
+ "opentelemetry-exporter-gcp-trace>=1.9.0",
+ "opentelemetry-sdk>=1.31.0",
+ "pydantic>=2.0, <3.0.0", # For data validation/models
+ "python-dotenv>=1.0.0", # To manage environment variables
+ "PyYAML>=6.0.2", # For APIHubToolset.
+ "sqlalchemy>=2.0", # SQL database ORM
+ "tzlocal>=5.3", # Time zone utilities
+ "uvicorn>=0.34.0", # ASGI server for FastAPI
+ # go/keep-sorted end
+]
+dynamic = ["version"]
+
+[project.urls]
+homepage = "https://google.github.io/adk-docs/"
+repository = "https://github.com/google/adk-python"
+changelog = "https://github.com/google/adk-python/blob/main/CHANGELOG.md"
+documentation = "https://google.github.io/adk-docs/"
+
+[project.scripts]
+adk = "google.adk.cli:main"
+
+[project.optional-dependencies]
+
+dev = [
+ # go/keep-sorted start
+ "flit>=3.10.0",
+ "isort>=6.0.0",
+ "pyink>=24.10.0",
+ "pylint>=2.6.0",
+ # go/keep-sorted end
+]
+
+eval = [
+ # go/keep-sorted start
+ "google-cloud-aiplatform[evaluation]>=1.87.0",
+ "pandas>=2.2.3",
+ "tabulate>=0.9.0",
+ # go/keep-sorted end
+]
+
+test = [
+ # go/keep-sorted start
+ "langchain-community>=0.3.17",
+ "pytest-asyncio>=0.25.0",
+ "pytest-mock>=3.14.0",
+ "pytest-xdist>=3.6.1",
+ "pytest>=8.3.4",
+ # go/keep-sorted end
+]
+
+docs = [
+ "autodoc_pydantic",
+ "furo",
+ "myst-parser",
+ "sphinx",
+ "sphinx-autodoc-typehints",
+ "sphinx-rtd-theme",
+]
+
+# Optional extensions
+extensions = [
+ "anthropic>=0.43.0", # For anthropic model support
+ "beautifulsoup4>=3.2.2", # For load_web_page tool.
+ "crewai[tools];python_version>='3.10'", # For CrewaiTool
+ "docker>=7.0.0", # For ContainerCodeExecutor
+ "langgraph>=0.2.60", # For LangGraphAgent
+ "litellm>=1.63.11", # For LiteLLM support
+ "llama-index-readers-file>=0.4.0", # for retrieval usings LlamaIndex.
+ "lxml>=5.3.0", # For load_web_page tool.
+]
+
+
+[tool.pyink]
+# Format py files following Google style-guide
+line-length = 80
+unstable = true
+pyink-indentation = 2
+pyink-use-majority-quotes = true
+
+
+[build-system]
+# Build system specify which backend is used to build/install the project (flit,
+# poetry, setuptools,...). All backends are supported by `pip install`
+requires = ["flit_core >=3.8,<4"]
+build-backend = "flit_core.buildapi"
+
+[tool.flit.sdist]
+include = ['src/**/*', 'README.md', 'pyproject.toml']
+exclude = ['src/**/*.sh']
+
+[tool.flit.module]
+name = "google.adk"
+
+[tool.isort]
+# Organize imports following Google style-guide
+force_single_line = true
+force_sort_within_sections = true
+honor_case_in_force_sorted_sections = true
+known_third_party = ["agents", "google"]
+order_by_type = false
+sort_relative_in_force_sorted_sections = true
+multi_line_output = 3
+line_length = 200
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+asyncio_default_fixture_loop_scope = "function"
diff --git a/src/google/adk/__init__.py b/src/google/adk/__init__.py
new file mode 100644
index 0000000..f52f6e0
--- /dev/null
+++ b/src/google/adk/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import version
+from .agents.llm_agent import Agent
+from .runners import Runner
+
+__version__ = version.__version__
+__all__ = ["Agent", "Runner"]
diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py
new file mode 100644
index 0000000..e1f773c
--- /dev/null
+++ b/src/google/adk/agents/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_agent import BaseAgent
+from .live_request_queue import LiveRequest
+from .live_request_queue import LiveRequestQueue
+from .llm_agent import Agent
+from .llm_agent import LlmAgent
+from .loop_agent import LoopAgent
+from .parallel_agent import ParallelAgent
+from .run_config import RunConfig
+from .sequential_agent import SequentialAgent
+
+__all__ = [
+ 'Agent',
+ 'BaseAgent',
+ 'LlmAgent',
+ 'LoopAgent',
+ 'ParallelAgent',
+ 'SequentialAgent',
+]
diff --git a/src/google/adk/agents/active_streaming_tool.py b/src/google/adk/agents/active_streaming_tool.py
new file mode 100644
index 0000000..e5499e0
--- /dev/null
+++ b/src/google/adk/agents/active_streaming_tool.py
@@ -0,0 +1,38 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import asyncio
+from typing import Optional
+
+from pydantic import BaseModel
+from pydantic import ConfigDict
+
+from .live_request_queue import LiveRequestQueue
+
+
+class ActiveStreamingTool(BaseModel):
+ """Manages streaming tool related resources during invocation."""
+
+ model_config = ConfigDict(
+ arbitrary_types_allowed=True,
+ extra='forbid',
+ )
+
+ task: Optional[asyncio.Task] = None
+ """The active task of this streaming tool."""
+
+ stream: Optional[LiveRequestQueue] = None
+ """The active (input) streams of this streaming tool."""
diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py
new file mode 100644
index 0000000..c47d9e3
--- /dev/null
+++ b/src/google/adk/agents/base_agent.py
@@ -0,0 +1,345 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Any
+from typing import AsyncGenerator
+from typing import Callable
+from typing import final
+from typing import Optional
+from typing import TYPE_CHECKING
+
+from google.genai import types
+from opentelemetry import trace
+from pydantic import BaseModel
+from pydantic import ConfigDict
+from pydantic import Field
+from pydantic import field_validator
+from typing_extensions import override
+
+from ..events.event import Event
+from .callback_context import CallbackContext
+
+if TYPE_CHECKING:
+ from .invocation_context import InvocationContext
+
+tracer = trace.get_tracer('gcp.vertex.agent')
+
+BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
+"""Callback signature that is invoked before the agent run.
+
+Args:
+ callback_context: MUST be named 'callback_context' (enforced).
+
+Returns:
+ The content to return to the user. When set, the agent run will skipped and
+ the provided content will be returned to user.
+"""
+
+AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
+"""Callback signature that is invoked after the agent run.
+
+Args:
+ callback_context: MUST be named 'callback_context' (enforced).
+
+Returns:
+ The content to return to the user. When set, the agent run will skipped and
+ the provided content will be appended to event history as agent response.
+"""
+
+
+class BaseAgent(BaseModel):
+ """Base class for all agents in Agent Development Kit."""
+
+ model_config = ConfigDict(
+ arbitrary_types_allowed=True,
+ extra='forbid',
+ )
+
+ name: str
+ """The agent's name.
+
+ Agent name must be a Python identifier and unique within the agent tree.
+ Agent name cannot be "user", since it's reserved for end-user's input.
+ """
+
+ description: str = ''
+ """Description about the agent's capability.
+
+ The model uses this to determine whether to delegate control to the agent.
+ One-line description is enough and preferred.
+ """
+
+ parent_agent: Optional[BaseAgent] = Field(default=None, init=False)
+ """The parent agent of this agent.
+
+ Note that an agent can ONLY be added as sub-agent once.
+
+ If you want to add one agent twice as sub-agent, consider to create two agent
+ instances with identical config, but with different name and add them to the
+ agent tree.
+ """
+ sub_agents: list[BaseAgent] = Field(default_factory=list)
+ """The sub-agents of this agent."""
+
+ before_agent_callback: Optional[BeforeAgentCallback] = None
+ """Callback signature that is invoked before the agent run.
+
+ Args:
+ callback_context: MUST be named 'callback_context' (enforced).
+
+ Returns:
+ The content to return to the user. When set, the agent run will skipped and
+ the provided content will be returned to user.
+ """
+ after_agent_callback: Optional[AfterAgentCallback] = None
+ """Callback signature that is invoked after the agent run.
+
+ Args:
+ callback_context: MUST be named 'callback_context' (enforced).
+
+ Returns:
+ The content to return to the user. When set, the agent run will skipped and
+ the provided content will be appended to event history as agent response.
+ """
+
+ @final
+ async def run_async(
+ self,
+ parent_context: InvocationContext,
+ ) -> AsyncGenerator[Event, None]:
+ """Entry method to run an agent via text-based conversaction.
+
+ Args:
+ parent_context: InvocationContext, the invocation context of the parent
+ agent.
+
+ Yields:
+ Event: the events generated by the agent.
+ """
+
+ with tracer.start_as_current_span(f'agent_run [{self.name}]'):
+ ctx = self._create_invocation_context(parent_context)
+
+ if event := self.__handle_before_agent_callback(ctx):
+ yield event
+ if ctx.end_invocation:
+ return
+
+ async for event in self._run_async_impl(ctx):
+ yield event
+
+ if ctx.end_invocation:
+ return
+
+ if event := self.__handle_after_agent_callback(ctx):
+ yield event
+
+ @final
+ async def run_live(
+ self,
+ parent_context: InvocationContext,
+ ) -> AsyncGenerator[Event, None]:
+ """Entry method to run an agent via video/audio-based conversaction.
+
+ Args:
+ parent_context: InvocationContext, the invocation context of the parent
+ agent.
+
+ Yields:
+ Event: the events generated by the agent.
+ """
+ with tracer.start_as_current_span(f'agent_run [{self.name}]'):
+ ctx = self._create_invocation_context(parent_context)
+ # TODO(hangfei): support before/after_agent_callback
+
+ async for event in self._run_live_impl(ctx):
+ yield event
+
+ async def _run_async_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ """Core logic to run this agent via text-based conversaction.
+
+ Args:
+ ctx: InvocationContext, the invocation context for this agent.
+
+ Yields:
+ Event: the events generated by the agent.
+ """
+ raise NotImplementedError(
+ f'_run_async_impl for {type(self)} is not implemented.'
+ )
+ yield # AsyncGenerator requires having at least one yield statement
+
+ async def _run_live_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ """Core logic to run this agent via video/audio-based conversaction.
+
+ Args:
+ ctx: InvocationContext, the invocation context for this agent.
+
+ Yields:
+ Event: the events generated by the agent.
+ """
+ raise NotImplementedError(
+ f'_run_live_impl for {type(self)} is not implemented.'
+ )
+ yield # AsyncGenerator requires having at least one yield statement
+
+ @property
+ def root_agent(self) -> BaseAgent:
+ """Gets the root agent of this agent."""
+ root_agent = self
+ while root_agent.parent_agent is not None:
+ root_agent = root_agent.parent_agent
+ return root_agent
+
+ def find_agent(self, name: str) -> Optional[BaseAgent]:
+ """Finds the agent with the given name in this agent and its descendants.
+
+ Args:
+ name: The name of the agent to find.
+
+ Returns:
+ The agent with the matching name, or None if no such agent is found.
+ """
+ if self.name == name:
+ return self
+ return self.find_sub_agent(name)
+
+ def find_sub_agent(self, name: str) -> Optional[BaseAgent]:
+ """Finds the agent with the given name in this agent's descendants.
+
+ Args:
+ name: The name of the agent to find.
+
+ Returns:
+ The agent with the matching name, or None if no such agent is found.
+ """
+ for sub_agent in self.sub_agents:
+ if result := sub_agent.find_agent(name):
+ return result
+ return None
+
+ def _create_invocation_context(
+ self, parent_context: InvocationContext
+ ) -> InvocationContext:
+ """Creates a new invocation context for this agent."""
+ invocation_context = parent_context.model_copy(update={'agent': self})
+ if parent_context.branch:
+ invocation_context.branch = f'{parent_context.branch}.{self.name}'
+ return invocation_context
+
+ def __handle_before_agent_callback(
+ self, ctx: InvocationContext
+ ) -> Optional[Event]:
+ """Runs the before_agent_callback if it exists.
+
+ Returns:
+ Optional[Event]: an event if callback provides content or changed state.
+ """
+ ret_event = None
+
+ if not isinstance(self.before_agent_callback, Callable):
+ return ret_event
+
+ callback_context = CallbackContext(ctx)
+ before_agent_callback_content = self.before_agent_callback(
+ callback_context=callback_context
+ )
+
+ if before_agent_callback_content:
+ ret_event = Event(
+ invocation_id=ctx.invocation_id,
+ author=self.name,
+ branch=ctx.branch,
+ content=before_agent_callback_content,
+ actions=callback_context._event_actions,
+ )
+ ctx.end_invocation = True
+ return ret_event
+
+ if callback_context.state.has_delta():
+ ret_event = Event(
+ invocation_id=ctx.invocation_id,
+ author=self.name,
+ branch=ctx.branch,
+ actions=callback_context._event_actions,
+ )
+
+ return ret_event
+
+ def __handle_after_agent_callback(
+ self, invocation_context: InvocationContext
+ ) -> Optional[Event]:
+ """Runs the after_agent_callback if it exists.
+
+ Returns:
+ Optional[Event]: an event if callback provides content or changed state.
+ """
+ ret_event = None
+
+ if not isinstance(self.after_agent_callback, Callable):
+ return ret_event
+
+ callback_context = CallbackContext(invocation_context)
+ after_agent_callback_content = self.after_agent_callback(
+ callback_context=callback_context
+ )
+
+ if after_agent_callback_content or callback_context.state.has_delta():
+ ret_event = Event(
+ invocation_id=invocation_context.invocation_id,
+ author=self.name,
+ branch=invocation_context.branch,
+ content=after_agent_callback_content,
+ actions=callback_context._event_actions,
+ )
+
+ return ret_event
+
+ @override
+ def model_post_init(self, __context: Any) -> None:
+ self.__set_parent_agent_for_sub_agents()
+
+ @field_validator('name', mode='after')
+ @classmethod
+ def __validate_name(cls, value: str):
+ if not value.isidentifier():
+ raise ValueError(
+ f'Found invalid agent name: `{value}`.'
+ ' Agent name must be a valid identifier. It should start with a'
+ ' letter (a-z, A-Z) or an underscore (_), and can only contain'
+ ' letters, digits (0-9), and underscores.'
+ )
+ if value == 'user':
+ raise ValueError(
+ "Agent name cannot be `user`. `user` is reserved for end-user's"
+ ' input.'
+ )
+ return value
+
+ def __set_parent_agent_for_sub_agents(self) -> BaseAgent:
+ for sub_agent in self.sub_agents:
+ if sub_agent.parent_agent is not None:
+ raise ValueError(
+ f'Agent `{sub_agent.name}` already has a parent agent, current'
+ f' parent: `{sub_agent.parent_agent.name}`, trying to add:'
+ f' `{self.name}`'
+ )
+ sub_agent.parent_agent = self
+ return self
diff --git a/src/google/adk/agents/callback_context.py b/src/google/adk/agents/callback_context.py
new file mode 100644
index 0000000..9b9f237
--- /dev/null
+++ b/src/google/adk/agents/callback_context.py
@@ -0,0 +1,112 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional, TYPE_CHECKING
+
+from typing_extensions import override
+
+from .readonly_context import ReadonlyContext
+
+if TYPE_CHECKING:
+ from google.genai import types
+
+ from ..events.event import Event
+ from ..events.event_actions import EventActions
+ from ..sessions.state import State
+ from .invocation_context import InvocationContext
+
+
+class CallbackContext(ReadonlyContext):
+ """The context of various callbacks within an agent run."""
+
+ def __init__(
+ self,
+ invocation_context: InvocationContext,
+ *,
+ event_actions: Optional[EventActions] = None,
+ ) -> None:
+ super().__init__(invocation_context)
+
+ from ..events.event_actions import EventActions
+ from ..sessions.state import State
+
+ # TODO(weisun): make this public for Agent Development Kit, but private for
+ # users.
+ self._event_actions = event_actions or EventActions()
+ self._state = State(
+ value=invocation_context.session.state,
+ delta=self._event_actions.state_delta,
+ )
+
+ @property
+ @override
+ def state(self) -> State:
+ """The delta-aware state of the current session.
+
+ For any state change, you can mutate this object directly,
+ e.g. `ctx.state['foo'] = 'bar'`
+ """
+ return self._state
+
+ @property
+ def user_content(self) -> Optional[types.Content]:
+ """The user content that started this invocation. READONLY field."""
+ return self._invocation_context.user_content
+
+ def load_artifact(
+ self, filename: str, version: Optional[int] = None
+ ) -> Optional[types.Part]:
+ """Loads an artifact attached to the current session.
+
+ Args:
+ filename: The filename of the artifact.
+ version: The version of the artifact. If None, the latest version will be
+ returned.
+
+ Returns:
+ The artifact.
+ """
+ if self._invocation_context.artifact_service is None:
+ raise ValueError("Artifact service is not initialized.")
+ return self._invocation_context.artifact_service.load_artifact(
+ app_name=self._invocation_context.app_name,
+ user_id=self._invocation_context.user_id,
+ session_id=self._invocation_context.session.id,
+ filename=filename,
+ version=version,
+ )
+
+ def save_artifact(self, filename: str, artifact: types.Part) -> int:
+ """Saves an artifact and records it as delta for the current session.
+
+ Args:
+ filename: The filename of the artifact.
+ artifact: The artifact to save.
+
+ Returns:
+ The version of the artifact.
+ """
+ if self._invocation_context.artifact_service is None:
+ raise ValueError("Artifact service is not initialized.")
+ version = self._invocation_context.artifact_service.save_artifact(
+ app_name=self._invocation_context.app_name,
+ user_id=self._invocation_context.user_id,
+ session_id=self._invocation_context.session.id,
+ filename=filename,
+ artifact=artifact,
+ )
+ self._event_actions.artifact_delta[filename] = version
+ return version
diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py
new file mode 100644
index 0000000..c02d62b
--- /dev/null
+++ b/src/google/adk/agents/invocation_context.py
@@ -0,0 +1,181 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional
+import uuid
+
+from google.genai import types
+from pydantic import BaseModel
+from pydantic import ConfigDict
+
+from ..artifacts.base_artifact_service import BaseArtifactService
+from ..memory.base_memory_service import BaseMemoryService
+from ..sessions.base_session_service import BaseSessionService
+from ..sessions.session import Session
+from .active_streaming_tool import ActiveStreamingTool
+from .base_agent import BaseAgent
+from .live_request_queue import LiveRequestQueue
+from .run_config import RunConfig
+from .transcription_entry import TranscriptionEntry
+
+
+class LlmCallsLimitExceededError(Exception):
+ """Error thrown when the number of LLM calls exceed the limit."""
+
+
+class _InvocationCostManager(BaseModel):
+ """A container to keep track of the cost of invocation.
+
+ While we don't expected the metrics captured here to be a direct
+ representatative of monetary cost incurred in executing the current
+ invocation, but they, in someways have an indirect affect.
+ """
+
+ _number_of_llm_calls: int = 0
+ """A counter that keeps track of number of llm calls made."""
+
+ def increment_and_enforce_llm_calls_limit(
+ self, run_config: Optional[RunConfig]
+ ):
+ """Increments _number_of_llm_calls and enforces the limit."""
+ # We first increment the counter and then check the conditions.
+ self._number_of_llm_calls += 1
+
+ if (
+ run_config
+ and run_config.max_llm_calls > 0
+ and self._number_of_llm_calls > run_config.max_llm_calls
+ ):
+ # We only enforce the limit if the limit is a positive number.
+ raise LlmCallsLimitExceededError(
+ "Max number of llm calls limit of"
+ f" `{run_config.max_llm_calls}` exceeded"
+ )
+
+
+class InvocationContext(BaseModel):
+ """An invocation context represents the data of a single invocation of an agent.
+
+ An invocation:
+ 1. Starts with a user message and ends with a final response.
+ 2. Can contain one or multiple agent calls.
+ 3. Is handled by runner.run_async().
+
+ An invocation runs an agent until it does not request to transfer to another
+ agent.
+
+ An agent call:
+ 1. Is handled by agent.run().
+ 2. Ends when agent.run() ends.
+
+ An LLM agent call is an agent with a BaseLLMFlow.
+ An LLM agent call can contain one or multiple steps.
+
+ An LLM agent runs steps in a loop until:
+ 1. A final response is generated.
+ 2. The agent transfers to another agent.
+ 3. The end_invocation is set to true by any callbacks or tools.
+
+ A step:
+ 1. Calls the LLM only once and yields its response.
+ 2. Calls the tools and yields their responses if requested.
+
+ The summarization of the function response is considered another step, since
+ it is another llm call.
+ A step ends when it's done calling llm and tools, or if the end_invocation
+ is set to true at any time.
+
+ ```
+ ┌─────────────────────── invocation ──────────────────────────┐
+ ┌──────────── llm_agent_call_1 ────────────┐ ┌─ agent_call_2 ─┐
+ ┌──── step_1 ────────┐ ┌───── step_2 ──────┐
+ [call_llm] [call_tool] [call_llm] [transfer]
+ ```
+ """
+
+ model_config = ConfigDict(
+ arbitrary_types_allowed=True,
+ extra="forbid",
+ )
+
+ artifact_service: Optional[BaseArtifactService] = None
+ session_service: BaseSessionService
+ memory_service: Optional[BaseMemoryService] = None
+
+ invocation_id: str
+ """The id of this invocation context. Readonly."""
+ branch: Optional[str] = None
+ """The branch of the invocation context.
+
+ The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of
+ agent_2, and agent_2 is the parent of agent_3.
+
+ Branch is used when multiple sub-agents shouldn't see their peer agents'
+ conversaction history.
+ """
+ agent: BaseAgent
+ """The current agent of this invocation context. Readonly."""
+ user_content: Optional[types.Content] = None
+ """The user content that started this invocation. Readonly."""
+ session: Session
+ """The current session of this invocation context. Readonly."""
+
+ end_invocation: bool = False
+ """Whether to end this invocation.
+
+ Set to True in callbacks or tools to terminate this invocation."""
+
+ live_request_queue: Optional[LiveRequestQueue] = None
+ """The queue to receive live requests."""
+
+ active_streaming_tools: Optional[dict[str, ActiveStreamingTool]] = None
+ """The running streaming tools of this invocation."""
+
+ transcription_cache: Optional[list[TranscriptionEntry]] = None
+ """Caches necessary, data audio or contents, that are needed by transcription."""
+
+ run_config: Optional[RunConfig] = None
+ """Configurations for live agents under this invocation."""
+
+ _invocation_cost_manager: _InvocationCostManager = _InvocationCostManager()
+ """A container to keep track of different kinds of costs incurred as a part
+ of this invocation.
+ """
+
+ def increment_llm_call_count(
+ self,
+ ):
+ """Tracks number of llm calls made.
+
+ Raises:
+ LlmCallsLimitExceededError: If number of llm calls made exceed the set
+ threshold.
+ """
+ self._invocation_cost_manager.increment_and_enforce_llm_calls_limit(
+ self.run_config
+ )
+
+ @property
+ def app_name(self) -> str:
+ return self.session.app_name
+
+ @property
+ def user_id(self) -> str:
+ return self.session.user_id
+
+
+def new_invocation_context_id() -> str:
+ return "e-" + str(uuid.uuid4())
diff --git a/src/google/adk/agents/langgraph_agent.py b/src/google/adk/agents/langgraph_agent.py
new file mode 100644
index 0000000..33a21b4
--- /dev/null
+++ b/src/google/adk/agents/langgraph_agent.py
@@ -0,0 +1,140 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import AsyncGenerator
+from typing import Union
+
+from google.genai import types
+from langchain_core.messages import AIMessage
+from langchain_core.messages import HumanMessage
+from langchain_core.messages import SystemMessage
+from langchain_core.runnables.config import RunnableConfig
+from langgraph.graph.graph import CompiledGraph
+from pydantic import ConfigDict
+from typing_extensions import override
+
+from ..events.event import Event
+from .base_agent import BaseAgent
+from .invocation_context import InvocationContext
+
+
+def _get_last_human_messages(events: list[Event]) -> list[HumanMessage]:
+ """Extracts last human messages from given list of events.
+
+ Args:
+ events: the list of events
+
+ Returns:
+ list of last human messages
+ """
+ messages = []
+ for event in reversed(events):
+ if messages and event.author != 'user':
+ break
+ if event.author == 'user' and event.content and event.content.parts:
+ messages.append(HumanMessage(content=event.content.parts[0].text))
+ return list(reversed(messages))
+
+
+class LangGraphAgent(BaseAgent):
+ """Currently a concept implementation, supports single and multi-turn."""
+
+ model_config = ConfigDict(
+ arbitrary_types_allowed=True,
+ )
+
+ graph: CompiledGraph
+
+ instruction: str = ''
+
+ @override
+ async def _run_async_impl(
+ self,
+ ctx: InvocationContext,
+ ) -> AsyncGenerator[Event, None]:
+
+ # Needed for langgraph checkpointer (for subsequent invocations; multi-turn)
+ config: RunnableConfig = {'configurable': {'thread_id': ctx.session.id}}
+
+ # Add instruction as SystemMessage if graph state is empty
+ current_graph_state = self.graph.get_state(config)
+ graph_messages = (
+ current_graph_state.values.get('messages', [])
+ if current_graph_state.values
+ else []
+ )
+ messages = (
+ [SystemMessage(content=self.instruction)]
+ if self.instruction and not graph_messages
+ else []
+ )
+ # Add events to messages (evaluating the memory used; parent agent vs checkpointer)
+ messages += self._get_messages(ctx.session.events)
+
+ # Use the Runnable
+ final_state = self.graph.invoke({'messages': messages}, config)
+ result = final_state['messages'][-1].content
+
+ result_event = Event(
+ invocation_id=ctx.invocation_id,
+ author=self.name,
+ branch=ctx.branch,
+ content=types.Content(
+ role='model',
+ parts=[types.Part.from_text(text=result)],
+ ),
+ )
+ yield result_event
+
+ def _get_messages(
+ self, events: list[Event]
+ ) -> list[Union[HumanMessage, AIMessage]]:
+ """Extracts messages from given list of events.
+
+ If the developer provides their own memory within langgraph, we return the
+ last user messages only. Otherwise, we return all messages between the user
+ and the agent.
+
+ Args:
+ events: the list of events
+
+ Returns:
+ list of messages
+ """
+ if self.graph.checkpointer:
+ return _get_last_human_messages(events)
+ else:
+ return self._get_conversation_with_agent(events)
+
+ def _get_conversation_with_agent(
+ self, events: list[Event]
+ ) -> list[Union[HumanMessage, AIMessage]]:
+ """Extracts messages from given list of events.
+
+ Args:
+ events: the list of events
+
+ Returns:
+ list of messages
+ """
+
+ messages = []
+ for event in events:
+ if not event.content or not event.content.parts:
+ continue
+ if event.author == 'user':
+ messages.append(HumanMessage(content=event.content.parts[0].text))
+ elif event.author == self.name:
+ messages.append(AIMessage(content=event.content.parts[0].text))
+ return messages
diff --git a/src/google/adk/agents/live_request_queue.py b/src/google/adk/agents/live_request_queue.py
new file mode 100644
index 0000000..3caf725
--- /dev/null
+++ b/src/google/adk/agents/live_request_queue.py
@@ -0,0 +1,64 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+from typing import Optional
+
+from google.genai import types
+from pydantic import BaseModel
+from pydantic import ConfigDict
+
+
+class LiveRequest(BaseModel):
+ """Request send to live agents."""
+
+ model_config = ConfigDict(ser_json_bytes='base64', val_json_bytes='base64')
+
+ content: Optional[types.Content] = None
+ """If set, send the content to the model in turn-by-turn mode."""
+ blob: Optional[types.Blob] = None
+ """If set, send the blob to the model in realtime mode."""
+ close: bool = False
+ """If set, close the queue. queue.shutdown() is only supported in Python 3.13+."""
+
+
+class LiveRequestQueue:
+ """Queue used to send LiveRequest in a live(bidirectional streaming) way."""
+
+ def __init__(self):
+ # Ensure there's an event loop available in this thread
+ try:
+ asyncio.get_running_loop()
+ except RuntimeError:
+ # No running loop, create one
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ # Now create the queue (it will use the event loop we just ensured exists)
+ self._queue = asyncio.Queue()
+
+ def close(self):
+ self._queue.put_nowait(LiveRequest(close=True))
+
+ def send_content(self, content: types.Content):
+ self._queue.put_nowait(LiveRequest(content=content))
+
+ def send_realtime(self, blob: types.Blob):
+ self._queue.put_nowait(LiveRequest(blob=blob))
+
+ def send(self, req: LiveRequest):
+ self._queue.put_nowait(req)
+
+ async def get(self) -> LiveRequest:
+ return await self._queue.get()
diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py
new file mode 100644
index 0000000..a140997
--- /dev/null
+++ b/src/google/adk/agents/llm_agent.py
@@ -0,0 +1,376 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+from typing import AsyncGenerator
+from typing import Callable
+from typing import Literal
+from typing import Optional
+from typing import Union
+
+from google.genai import types
+from pydantic import BaseModel
+from pydantic import Field
+from pydantic import field_validator
+from pydantic import model_validator
+from typing_extensions import override
+from typing_extensions import TypeAlias
+
+from ..code_executors.base_code_executor import BaseCodeExecutor
+from ..events.event import Event
+from ..examples.base_example_provider import BaseExampleProvider
+from ..examples.example import Example
+from ..flows.llm_flows.auto_flow import AutoFlow
+from ..flows.llm_flows.base_llm_flow import BaseLlmFlow
+from ..flows.llm_flows.single_flow import SingleFlow
+from ..models.base_llm import BaseLlm
+from ..models.llm_request import LlmRequest
+from ..models.llm_response import LlmResponse
+from ..models.registry import LLMRegistry
+from ..planners.base_planner import BasePlanner
+from ..tools.base_tool import BaseTool
+from ..tools.function_tool import FunctionTool
+from ..tools.tool_context import ToolContext
+from .base_agent import BaseAgent
+from .callback_context import CallbackContext
+from .invocation_context import InvocationContext
+from .readonly_context import ReadonlyContext
+
+logger = logging.getLogger(__name__)
+
+
+BeforeModelCallback: TypeAlias = Callable[
+ [CallbackContext, LlmRequest], Optional[LlmResponse]
+]
+AfterModelCallback: TypeAlias = Callable[
+ [CallbackContext, LlmResponse],
+ Optional[LlmResponse],
+]
+BeforeToolCallback: TypeAlias = Callable[
+ [BaseTool, dict[str, Any], ToolContext],
+ Optional[dict],
+]
+AfterToolCallback: TypeAlias = Callable[
+ [BaseTool, dict[str, Any], ToolContext, dict],
+ Optional[dict],
+]
+
+InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
+
+ToolUnion: TypeAlias = Union[Callable, BaseTool]
+ExamplesUnion = Union[list[Example], BaseExampleProvider]
+
+
+def _convert_tool_union_to_tool(
+ tool_union: ToolUnion,
+) -> BaseTool:
+ return (
+ tool_union
+ if isinstance(tool_union, BaseTool)
+ else FunctionTool(tool_union)
+ )
+
+
+class LlmAgent(BaseAgent):
+ """LLM-based Agent."""
+
+ model: Union[str, BaseLlm] = ''
+ """The model to use for the agent.
+
+ When not set, the agent will inherit the model from its ancestor.
+ """
+
+ instruction: Union[str, InstructionProvider] = ''
+ """Instructions for the LLM model, guiding the agent's behavior."""
+
+ global_instruction: Union[str, InstructionProvider] = ''
+ """Instructions for all the agents in the entire agent tree.
+
+ global_instruction ONLY takes effect in root agent.
+
+ For example: use global_instruction to make all agents have a stable identity
+ or personality.
+ """
+
+ tools: list[ToolUnion] = Field(default_factory=list)
+ """Tools available to this agent."""
+
+ generate_content_config: Optional[types.GenerateContentConfig] = None
+ """The additional content generation configurations.
+
+ NOTE: not all fields are usable, e.g. tools must be configured via `tools`,
+ thinking_config must be configured via `planner` in LlmAgent.
+
+ For example: use this config to adjust model temperature, configure safety
+ settings, etc.
+ """
+
+ # LLM-based agent transfer configs - Start
+ disallow_transfer_to_parent: bool = False
+ """Disallows LLM-controlled transferring to the parent agent."""
+ disallow_transfer_to_peers: bool = False
+ """Disallows LLM-controlled transferring to the peer agents."""
+ # LLM-based agent transfer configs - End
+
+ include_contents: Literal['default', 'none'] = 'default'
+ """Whether to include contents in the model request.
+
+ When set to 'none', the model request will not include any contents, such as
+ user messages, tool results, etc.
+ """
+
+ # Controlled input/output configurations - Start
+ input_schema: Optional[type[BaseModel]] = None
+ """The input schema when agent is used as a tool."""
+ output_schema: Optional[type[BaseModel]] = None
+ """The output schema when agent replies.
+
+ NOTE: when this is set, agent can ONLY reply and CANNOT use any tools, such as
+ function tools, RAGs, agent transfer, etc.
+ """
+ output_key: Optional[str] = None
+ """The key in session state to store the output of the agent.
+
+ Typically use cases:
+ - Extracts agent reply for later use, such as in tools, callbacks, etc.
+ - Connects agents to coordinate with each other.
+ """
+ # Controlled input/output configurations - End
+
+ # Advance features - Start
+ planner: Optional[BasePlanner] = None
+ """Instructs the agent to make a plan and execute it step by step.
+
+ NOTE: to use model's built-in thinking features, set the `thinking_config`
+ field in `google.adk.planners.built_in_planner`.
+
+ """
+
+ code_executor: Optional[BaseCodeExecutor] = None
+ """Allow agent to execute code blocks from model responses using the provided
+ CodeExecutor.
+
+ Check out available code executions in `google.adk.code_executor` package.
+
+ NOTE: to use model's built-in code executor, don't set this field, add
+ `google.adk.tools.built_in_code_execution` to tools instead.
+ """
+ # Advance features - End
+
+ # TODO: remove below fields after migration. - Start
+ # These fields are added back for easier migration.
+ examples: Optional[ExamplesUnion] = None
+ # TODO: remove above fields after migration. - End
+
+ # Callbacks - Start
+ before_model_callback: Optional[BeforeModelCallback] = None
+ """Called before calling the LLM.
+ Args:
+ callback_context: CallbackContext,
+ llm_request: LlmRequest, The raw model request. Callback can mutate the
+ request.
+
+ Returns:
+ The content to return to the user. When present, the model call will be
+ skipped and the provided content will be returned to user.
+ """
+ after_model_callback: Optional[AfterModelCallback] = None
+ """Called after calling LLM.
+
+ Args:
+ callback_context: CallbackContext,
+ llm_response: LlmResponse, the actual model response.
+
+ Returns:
+ The content to return to the user. When present, the actual model response
+ will be ignored and the provided content will be returned to user.
+ """
+ before_tool_callback: Optional[BeforeToolCallback] = None
+ """Called before the tool is called.
+
+ Args:
+ tool: The tool to be called.
+ args: The arguments to the tool.
+ tool_context: ToolContext,
+
+ Returns:
+ The tool response. When present, the returned tool response will be used and
+ the framework will skip calling the actual tool.
+ """
+ after_tool_callback: Optional[AfterToolCallback] = None
+ """Called after the tool is called.
+
+ Args:
+ tool: The tool to be called.
+ args: The arguments to the tool.
+ tool_context: ToolContext,
+ tool_response: The response from the tool.
+
+ Returns:
+ When present, the returned dict will be used as tool result.
+ """
+ # Callbacks - End
+
+ @override
+ async def _run_async_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ async for event in self._llm_flow.run_async(ctx):
+ self.__maybe_save_output_to_state(event)
+ yield event
+
+ @override
+ async def _run_live_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ async for event in self._llm_flow.run_live(ctx):
+ self.__maybe_save_output_to_state(event)
+ yield event
+ if ctx.end_invocation:
+ return
+
+ @property
+ def canonical_model(self) -> BaseLlm:
+ """The resolved self.model field as BaseLlm.
+
+ This method is only for use by Agent Development Kit.
+ """
+ if isinstance(self.model, BaseLlm):
+ return self.model
+ elif self.model: # model is non-empty str
+ return LLMRegistry.new_llm(self.model)
+ else: # find model from ancestors.
+ ancestor_agent = self.parent_agent
+ while ancestor_agent is not None:
+ if isinstance(ancestor_agent, LlmAgent):
+ return ancestor_agent.canonical_model
+ ancestor_agent = ancestor_agent.parent_agent
+ raise ValueError(f'No model found for {self.name}.')
+
+ def canonical_instruction(self, ctx: ReadonlyContext) -> str:
+ """The resolved self.instruction field to construct instruction for this agent.
+
+ This method is only for use by Agent Development Kit.
+ """
+ if isinstance(self.instruction, str):
+ return self.instruction
+ else:
+ return self.instruction(ctx)
+
+ def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
+ """The resolved self.instruction field to construct global instruction.
+
+ This method is only for use by Agent Development Kit.
+ """
+ if isinstance(self.global_instruction, str):
+ return self.global_instruction
+ else:
+ return self.global_instruction(ctx)
+
+ @property
+ def canonical_tools(self) -> list[BaseTool]:
+ """The resolved self.tools field as a list of BaseTool.
+
+ This method is only for use by Agent Development Kit.
+ """
+ return [_convert_tool_union_to_tool(tool) for tool in self.tools]
+
+ @property
+ def _llm_flow(self) -> BaseLlmFlow:
+ if (
+ self.disallow_transfer_to_parent
+ and self.disallow_transfer_to_peers
+ and not self.sub_agents
+ ):
+ return SingleFlow()
+ else:
+ return AutoFlow()
+
+ def __maybe_save_output_to_state(self, event: Event):
+ """Saves the model output to state if needed."""
+ if (
+ self.output_key
+ and event.is_final_response()
+ and event.content
+ and event.content.parts
+ ):
+ result = ''.join(
+ [part.text if part.text else '' for part in event.content.parts]
+ )
+ if self.output_schema:
+ result = self.output_schema.model_validate_json(result).model_dump(
+ exclude_none=True
+ )
+ event.actions.state_delta[self.output_key] = result
+
+ @model_validator(mode='after')
+ def __model_validator_after(self) -> LlmAgent:
+ self.__check_output_schema()
+ return self
+
+ def __check_output_schema(self):
+ if not self.output_schema:
+ return
+
+ if (
+ not self.disallow_transfer_to_parent
+ or not self.disallow_transfer_to_peers
+ ):
+ logger.warning(
+ 'Invalid config for agent %s: output_schema cannot co-exist with'
+ ' agent transfer configurations. Setting'
+ ' disallow_transfer_to_parent=True, disallow_transfer_to_peers=True',
+ self.name,
+ )
+ self.disallow_transfer_to_parent = True
+ self.disallow_transfer_to_peers = True
+
+ if self.sub_agents:
+ raise ValueError(
+ f'Invalid config for agent {self.name}: if output_schema is set,'
+ ' sub_agents must be empty to disable agent transfer.'
+ )
+
+ if self.tools:
+ raise ValueError(
+ f'Invalid config for agent {self.name}: if output_schema is set,'
+ ' tools must be empty'
+ )
+
+ @field_validator('generate_content_config', mode='after')
+ @classmethod
+ def __validate_generate_content_config(
+ cls, generate_content_config: Optional[types.GenerateContentConfig]
+ ) -> types.GenerateContentConfig:
+ if not generate_content_config:
+ return types.GenerateContentConfig()
+ if generate_content_config.thinking_config:
+ raise ValueError('Thinking config should be set via LlmAgent.planner.')
+ if generate_content_config.tools:
+ raise ValueError('All tools must be set via LlmAgent.tools.')
+ if generate_content_config.system_instruction:
+ raise ValueError(
+ 'System instruction must be set via LlmAgent.instruction.'
+ )
+ if generate_content_config.response_schema:
+ raise ValueError(
+ 'Response schema must be set via LlmAgent.output_schema.'
+ )
+ return generate_content_config
+
+
+Agent: TypeAlias = LlmAgent
diff --git a/src/google/adk/agents/loop_agent.py b/src/google/adk/agents/loop_agent.py
new file mode 100644
index 0000000..c760c37
--- /dev/null
+++ b/src/google/adk/agents/loop_agent.py
@@ -0,0 +1,62 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Loop agent implementation."""
+
+from __future__ import annotations
+
+from typing import AsyncGenerator
+from typing import Optional
+
+from typing_extensions import override
+
+from ..agents.invocation_context import InvocationContext
+from ..events.event import Event
+from .base_agent import BaseAgent
+
+
+class LoopAgent(BaseAgent):
+ """A shell agent that run its sub-agents in a loop.
+
+ When sub-agent generates an event with escalate or max_iterations are
+ reached, the loop agent will stop.
+ """
+
+ max_iterations: Optional[int] = None
+ """The maximum number of iterations to run the loop agent.
+
+ If not set, the loop agent will run indefinitely until a sub-agent
+ escalates.
+ """
+
+ @override
+ async def _run_async_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ times_looped = 0
+ while not self.max_iterations or times_looped < self.max_iterations:
+ for sub_agent in self.sub_agents:
+ async for event in sub_agent.run_async(ctx):
+ yield event
+ if event.actions.escalate:
+ return
+ times_looped += 1
+ return
+
+ @override
+ async def _run_live_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ raise NotImplementedError('The behavior for run_live is not defined yet.')
+ yield # AsyncGenerator requires having at least one yield statement
diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py
new file mode 100644
index 0000000..4647fd4
--- /dev/null
+++ b/src/google/adk/agents/parallel_agent.py
@@ -0,0 +1,96 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Parallel agent implementation."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import AsyncGenerator
+
+from typing_extensions import override
+
+from ..agents.invocation_context import InvocationContext
+from ..events.event import Event
+from .base_agent import BaseAgent
+
+
+def _set_branch_for_current_agent(
+ current_agent: BaseAgent, invocation_context: InvocationContext
+):
+ invocation_context.branch = (
+ f"{invocation_context.branch}.{current_agent.name}"
+ if invocation_context.branch
+ else current_agent.name
+ )
+
+
+async def _merge_agent_run(
+ agent_runs: list[AsyncGenerator[Event, None]],
+) -> AsyncGenerator[Event, None]:
+ """Merges the agent run event generator.
+
+ This implementation guarantees for each agent, it won't move on until the
+ generated event is processed by upstream runner.
+
+ Args:
+ agent_runs: A list of async generators that yield events from each agent.
+
+ Yields:
+ Event: The next event from the merged generator.
+ """
+ tasks = [
+ asyncio.create_task(events_for_one_agent.__anext__())
+ for events_for_one_agent in agent_runs
+ ]
+ pending_tasks = set(tasks)
+
+ while pending_tasks:
+ done, pending_tasks = await asyncio.wait(
+ pending_tasks, return_when=asyncio.FIRST_COMPLETED
+ )
+ for task in done:
+ try:
+ yield task.result()
+
+ # Find the generator that produced this event and move it on.
+ for i, original_task in enumerate(tasks):
+ if task == original_task:
+ new_task = asyncio.create_task(agent_runs[i].__anext__())
+ tasks[i] = new_task
+ pending_tasks.add(new_task)
+ break # stop iterating once found
+
+ except StopAsyncIteration:
+ continue
+
+
+class ParallelAgent(BaseAgent):
+ """A shell agent that run its sub-agents in parallel in isolated manner.
+
+ This approach is beneficial for scenarios requiring multiple perspectives or
+ attempts on a single task, such as:
+
+ - Running different algorithms simultaneously.
+ - Generating multiple responses for review by a subsequent evaluation agent.
+ """
+
+ @override
+ async def _run_async_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ _set_branch_for_current_agent(self, ctx)
+ agent_runs = [agent.run_async(ctx) for agent in self.sub_agents]
+ async for event in _merge_agent_run(agent_runs):
+ yield event
diff --git a/src/google/adk/agents/readonly_context.py b/src/google/adk/agents/readonly_context.py
new file mode 100644
index 0000000..fb373cc
--- /dev/null
+++ b/src/google/adk/agents/readonly_context.py
@@ -0,0 +1,46 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from types import MappingProxyType
+from typing import Any
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from .invocation_context import InvocationContext
+
+
+class ReadonlyContext:
+
+ def __init__(
+ self,
+ invocation_context: InvocationContext,
+ ) -> None:
+ self._invocation_context = invocation_context
+
+ @property
+ def invocation_id(self) -> str:
+ """The current invocation id."""
+ return self._invocation_context.invocation_id
+
+ @property
+ def agent_name(self) -> str:
+ """The name of the agent that is currently running."""
+ return self._invocation_context.agent.name
+
+ @property
+ def state(self) -> MappingProxyType[str, Any]:
+ """The state of the current session. READONLY field."""
+ return MappingProxyType(self._invocation_context.session.state)
diff --git a/src/google/adk/agents/remote_agent.py b/src/google/adk/agents/remote_agent.py
new file mode 100644
index 0000000..2c5bc70
--- /dev/null
+++ b/src/google/adk/agents/remote_agent.py
@@ -0,0 +1,50 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import AsyncGenerator
+
+from pydantic import Field
+import requests
+from typing_extensions import override
+
+from ..events.event import Event
+from .base_agent import BaseAgent
+from .invocation_context import InvocationContext
+
+
+class RemoteAgent(BaseAgent):
+ """Experimental, do not use."""
+
+ url: str
+
+ sub_agents: list[BaseAgent] = Field(
+ default_factory=list, init=False, frozen=True
+ )
+ """Sub-agent is dsiabled in RemoteAgent."""
+
+ @override
+ async def _run_async_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ data = {
+ 'invocation_id': ctx.invocation_id,
+ 'session': ctx.session.model_dump(exclude_none=True),
+ }
+ events = requests.post(self.url, data=json.dumps(data), timeout=120)
+ events.raise_for_status()
+ for event in events.json():
+ e = Event.model_validate(event)
+ e.author = self.name
+ yield e
diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py
new file mode 100644
index 0000000..779ab1b
--- /dev/null
+++ b/src/google/adk/agents/run_config.py
@@ -0,0 +1,87 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+import logging
+import sys
+from typing import Optional
+
+from google.genai import types
+from pydantic import BaseModel
+from pydantic import ConfigDict
+from pydantic import field_validator
+
+logger = logging.getLogger(__name__)
+
+
+class StreamingMode(Enum):
+ NONE = None
+ SSE = 'sse'
+ BIDI = 'bidi'
+
+
+class RunConfig(BaseModel):
+ """Configs for runtime behavior of agents."""
+
+ model_config = ConfigDict(
+ extra='forbid',
+ )
+
+ speech_config: Optional[types.SpeechConfig] = None
+ """Speech configuration for the live agent."""
+
+ response_modalities: Optional[list[str]] = None
+ """The output modalities. If not set, its default to AUDIO."""
+
+ save_input_blobs_as_artifacts: bool = False
+ """Whether or not to save the input blobs as artifacts."""
+
+ support_cfc: bool = False
+ """
+ Whether to support CFC (Compositional Function Calling). Only applicable for
+ StreamingMode.SSE. If it's true. the LIVE API will be invoked. Since only LIVE
+ API supports CFC
+ """
+
+ streaming_mode: StreamingMode = StreamingMode.NONE
+ """Streaming mode, None or StreamingMode.SSE or StreamingMode.BIDI."""
+
+ output_audio_transcription: Optional[types.AudioTranscriptionConfig] = None
+ """Output transcription for live agents with audio response."""
+
+ max_llm_calls: int = 500
+ """
+ A limit on the total number of llm calls for a given run.
+
+ Valid Values:
+ - More than 0 and less than sys.maxsize: The bound on the number of llm
+ calls is enforced, if the value is set in this range.
+ - Less than or equal to 0: This allows for unbounded number of llm calls.
+ """
+
+ @field_validator('max_llm_calls', mode='after')
+ @classmethod
+ def validate_max_llm_calls(cls, value: int) -> int:
+ if value == sys.maxsize:
+ raise ValueError(f'max_llm_calls should be less than {sys.maxsize}.')
+ elif value <= 0:
+ logger.warning(
+ 'max_llm_calls is less than or equal to 0. This will result in'
+ ' no enforcement on total number of llm calls that will be made for a'
+ ' run. This may not be ideal, as this could result in a never'
+ ' ending communication between the model and the agent in certain'
+ ' cases.',
+ )
+
+ return value
diff --git a/src/google/adk/agents/sequential_agent.py b/src/google/adk/agents/sequential_agent.py
new file mode 100644
index 0000000..8dabcff
--- /dev/null
+++ b/src/google/adk/agents/sequential_agent.py
@@ -0,0 +1,45 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Sequential agent implementation."""
+
+from __future__ import annotations
+
+from typing import AsyncGenerator
+
+from typing_extensions import override
+
+from ..agents.invocation_context import InvocationContext
+from ..events.event import Event
+from .base_agent import BaseAgent
+
+
+class SequentialAgent(BaseAgent):
+ """A shell agent that run its sub-agents in sequence."""
+
+ @override
+ async def _run_async_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ for sub_agent in self.sub_agents:
+ async for event in sub_agent.run_async(ctx):
+ yield event
+
+ @override
+ async def _run_live_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ for sub_agent in self.sub_agents:
+ async for event in sub_agent.run_live(ctx):
+ yield event
diff --git a/src/google/adk/agents/transcription_entry.py b/src/google/adk/agents/transcription_entry.py
new file mode 100644
index 0000000..f415e7c
--- /dev/null
+++ b/src/google/adk/agents/transcription_entry.py
@@ -0,0 +1,34 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+from google.genai import types
+from pydantic import BaseModel
+from pydantic import ConfigDict
+
+
+class TranscriptionEntry(BaseModel):
+ """Store the data that can be used for transcription."""
+
+ model_config = ConfigDict(
+ arbitrary_types_allowed=True,
+ extra='forbid',
+ )
+
+ role: str
+ """The role that created this data, typically "user" or "model"""
+
+ data: Union[types.Blob, types.Content]
+ """The data that can be used for transcription"""
diff --git a/src/google/adk/artifacts/__init__.py b/src/google/adk/artifacts/__init__.py
new file mode 100644
index 0000000..4a6c7c6
--- /dev/null
+++ b/src/google/adk/artifacts/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_artifact_service import BaseArtifactService
+from .gcs_artifact_service import GcsArtifactService
+from .in_memory_artifact_service import InMemoryArtifactService
+
+__all__ = [
+ 'BaseArtifactService',
+ 'GcsArtifactService',
+ 'InMemoryArtifactService',
+]
diff --git a/src/google/adk/artifacts/base_artifact_service.py b/src/google/adk/artifacts/base_artifact_service.py
new file mode 100644
index 0000000..0af9146
--- /dev/null
+++ b/src/google/adk/artifacts/base_artifact_service.py
@@ -0,0 +1,128 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Abstract base class for artifact services."""
+
+from abc import ABC
+from abc import abstractmethod
+from typing import Optional
+
+from google.genai import types
+
+
+class BaseArtifactService(ABC):
+ """Abstract base class for artifact services."""
+
+ @abstractmethod
+ def save_artifact(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ filename: str,
+ artifact: types.Part,
+ ) -> int:
+ """Saves an artifact to the artifact service storage.
+
+ The artifact is a file identified by the app name, user ID, session ID, and
+ filename. After saving the artifact, a revision ID is returned to identify
+ the artifact version.
+
+ Args:
+ app_name: The app name.
+ user_id: The user ID.
+ session_id: The session ID.
+ filename: The filename of the artifact.
+ artifact: The artifact to save.
+
+ Returns:
+ The revision ID. The first version of the artifact has a revision ID of 0.
+ This is incremented by 1 after each successful save.
+ """
+
+ @abstractmethod
+ def load_artifact(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ filename: str,
+ version: Optional[int] = None,
+ ) -> Optional[types.Part]:
+ """Gets an artifact from the artifact service storage.
+
+ The artifact is a file identified by the app name, user ID, session ID, and
+ filename.
+
+ Args:
+ app_name: The app name.
+ user_id: The user ID.
+ session_id: The session ID.
+ filename: The filename of the artifact.
+ version: The version of the artifact. If None, the latest version will be
+ returned.
+
+ Returns:
+ The artifact or None if not found.
+ """
+ pass
+
+ @abstractmethod
+ def list_artifact_keys(
+ self, *, app_name: str, user_id: str, session_id: str
+ ) -> list[str]:
+ """Lists all the artifact filenames within a session.
+
+ Args:
+ app_name: The name of the application.
+ user_id: The ID of the user.
+ session_id: The ID of the session.
+
+ Returns:
+ A list of all artifact filenames within a session.
+ """
+ pass
+
+ @abstractmethod
+ def delete_artifact(
+ self, *, app_name: str, user_id: str, session_id: str, filename: str
+ ) -> None:
+ """Deletes an artifact.
+
+ Args:
+ app_name: The name of the application.
+ user_id: The ID of the user.
+ session_id: The ID of the session.
+ filename: The name of the artifact file.
+ """
+ pass
+
+ @abstractmethod
+ def list_versions(
+ self, *, app_name: str, user_id: str, session_id: str, filename: str
+ ) -> list[int]:
+ """Lists all versions of an artifact.
+
+ Args:
+ app_name: The name of the application.
+ user_id: The ID of the user.
+ session_id: The ID of the session.
+ filename: The name of the artifact file.
+
+ Returns:
+ A list of all available versions of the artifact.
+ """
+ pass
diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py
new file mode 100644
index 0000000..279d5e0
--- /dev/null
+++ b/src/google/adk/artifacts/gcs_artifact_service.py
@@ -0,0 +1,195 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An artifact service implementation using Google Cloud Storage (GCS)."""
+
+import logging
+from typing import Optional
+
+from google.cloud import storage
+from google.genai import types
+from typing_extensions import override
+
+from .base_artifact_service import BaseArtifactService
+
+logger = logging.getLogger(__name__)
+
+
+class GcsArtifactService(BaseArtifactService):
+ """An artifact service implementation using Google Cloud Storage (GCS)."""
+
+ def __init__(self, bucket_name: str, **kwargs):
+ """Initializes the GcsArtifactService.
+
+ Args:
+ bucket_name: The name of the bucket to use.
+ **kwargs: Keyword arguments to pass to the Google Cloud Storage client.
+ """
+ self.bucket_name = bucket_name
+ self.storage_client = storage.Client(**kwargs)
+ self.bucket = self.storage_client.bucket(self.bucket_name)
+
+ def _file_has_user_namespace(self, filename: str) -> bool:
+ """Checks if the filename has a user namespace.
+
+ Args:
+ filename: The filename to check.
+
+ Returns:
+ True if the filename has a user namespace (starts with "user:"),
+ False otherwise.
+ """
+ return filename.startswith("user:")
+
+ def _get_blob_name(
+ self,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ filename: str,
+ version: int,
+ ) -> str:
+ """Constructs the blob name in GCS.
+
+ Args:
+ app_name: The name of the application.
+ user_id: The ID of the user.
+ session_id: The ID of the session.
+ filename: The name of the artifact file.
+ version: The version of the artifact.
+
+ Returns:
+ The constructed blob name in GCS.
+ """
+ if self._file_has_user_namespace(filename):
+ return f"{app_name}/{user_id}/user/{filename}/{version}"
+ return f"{app_name}/{user_id}/{session_id}/{filename}/{version}"
+
+ @override
+ def save_artifact(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ filename: str,
+ artifact: types.Part,
+ ) -> int:
+ versions = self.list_versions(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=filename,
+ )
+ version = 0 if not versions else max(versions) + 1
+
+ blob_name = self._get_blob_name(
+ app_name, user_id, session_id, filename, version
+ )
+ blob = self.bucket.blob(blob_name)
+
+ blob.upload_from_string(
+ data=artifact.inline_data.data,
+ content_type=artifact.inline_data.mime_type,
+ )
+
+ return version
+
+ @override
+ def load_artifact(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ filename: str,
+ version: Optional[int] = None,
+ ) -> Optional[types.Part]:
+ if version is None:
+ versions = self.list_versions(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=filename,
+ )
+ if not versions:
+ return None
+ version = max(versions)
+
+ blob_name = self._get_blob_name(
+ app_name, user_id, session_id, filename, version
+ )
+ blob = self.bucket.blob(blob_name)
+
+ artifact_bytes = blob.download_as_bytes()
+ if not artifact_bytes:
+ return None
+ artifact = types.Part.from_bytes(
+ data=artifact_bytes, mime_type=blob.content_type
+ )
+ return artifact
+
+ @override
+ def list_artifact_keys(
+ self, *, app_name: str, user_id: str, session_id: str
+ ) -> list[str]:
+ filenames = set()
+
+ session_prefix = f"{app_name}/{user_id}/{session_id}/"
+ session_blobs = self.storage_client.list_blobs(
+ self.bucket, prefix=session_prefix
+ )
+ for blob in session_blobs:
+ _, _, _, filename, _ = blob.name.split("/")
+ filenames.add(filename)
+
+ user_namespace_prefix = f"{app_name}/{user_id}/user/"
+ user_namespace_blobs = self.storage_client.list_blobs(
+ self.bucket, prefix=user_namespace_prefix
+ )
+ for blob in user_namespace_blobs:
+ _, _, _, filename, _ = blob.name.split("/")
+ filenames.add(filename)
+
+ return sorted(list(filenames))
+
+ @override
+ def delete_artifact(
+ self, *, app_name: str, user_id: str, session_id: str, filename: str
+ ) -> None:
+ versions = self.list_versions(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=filename,
+ )
+ for version in versions:
+ blob_name = self._get_blob_name(
+ app_name, user_id, session_id, filename, version
+ )
+ blob = self.bucket.blob(blob_name)
+ blob.delete()
+ return
+
+ @override
+ def list_versions(
+ self, *, app_name: str, user_id: str, session_id: str, filename: str
+ ) -> list[int]:
+ prefix = self._get_blob_name(app_name, user_id, session_id, filename, "")
+ blobs = self.storage_client.list_blobs(self.bucket, prefix=prefix)
+ versions = []
+ for blob in blobs:
+ _, _, _, _, version = blob.name.split("/")
+ versions.append(int(version))
+ return versions
diff --git a/src/google/adk/artifacts/in_memory_artifact_service.py b/src/google/adk/artifacts/in_memory_artifact_service.py
new file mode 100644
index 0000000..8c886f6
--- /dev/null
+++ b/src/google/adk/artifacts/in_memory_artifact_service.py
@@ -0,0 +1,133 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An in-memory implementation of the artifact service."""
+
+import logging
+from typing import Optional
+
+from google.genai import types
+from pydantic import BaseModel
+from pydantic import Field
+from typing_extensions import override
+
+from .base_artifact_service import BaseArtifactService
+
+logger = logging.getLogger(__name__)
+
+
+class InMemoryArtifactService(BaseArtifactService, BaseModel):
+ """An in-memory implementation of the artifact service."""
+
+ artifacts: dict[str, list[types.Part]] = Field(default_factory=dict)
+
+ def _file_has_user_namespace(self, filename: str) -> bool:
+ """Checks if the filename has a user namespace.
+
+ Args:
+ filename: The filename to check.
+
+ Returns:
+ True if the filename has a user namespace (starts with "user:"),
+ False otherwise.
+ """
+ return filename.startswith("user:")
+
+ def _artifact_path(
+ self, app_name: str, user_id: str, session_id: str, filename: str
+ ) -> str:
+ """Constructs the artifact path.
+
+ Args:
+ app_name: The name of the application.
+ user_id: The ID of the user.
+ session_id: The ID of the session.
+ filename: The name of the artifact file.
+
+ Returns:
+ The constructed artifact path.
+ """
+ if self._file_has_user_namespace(filename):
+ return f"{app_name}/{user_id}/user/{filename}"
+ return f"{app_name}/{user_id}/{session_id}/{filename}"
+
+ @override
+ def save_artifact(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ filename: str,
+ artifact: types.Part,
+ ) -> int:
+ path = self._artifact_path(app_name, user_id, session_id, filename)
+ if path not in self.artifacts:
+ self.artifacts[path] = []
+ version = len(self.artifacts[path])
+ self.artifacts[path].append(artifact)
+ return version
+
+ @override
+ def load_artifact(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ filename: str,
+ version: Optional[int] = None,
+ ) -> Optional[types.Part]:
+ path = self._artifact_path(app_name, user_id, session_id, filename)
+ versions = self.artifacts.get(path)
+ if not versions:
+ return None
+ if version is None:
+ version = -1
+ return versions[version]
+
+ @override
+ def list_artifact_keys(
+ self, *, app_name: str, user_id: str, session_id: str
+ ) -> list[str]:
+ session_prefix = f"{app_name}/{user_id}/{session_id}/"
+ usernamespace_prefix = f"{app_name}/{user_id}/user/"
+ filenames = []
+ for path in self.artifacts:
+ if path.startswith(session_prefix):
+ filename = path.removeprefix(session_prefix)
+ filenames.append(filename)
+ elif path.startswith(usernamespace_prefix):
+ filename = path.removeprefix(usernamespace_prefix)
+ filenames.append(filename)
+ return sorted(filenames)
+
+ @override
+ def delete_artifact(
+ self, *, app_name: str, user_id: str, session_id: str, filename: str
+ ) -> None:
+ path = self._artifact_path(app_name, user_id, session_id, filename)
+ if not self.artifacts.get(path):
+ return None
+ self.artifacts.pop(path, None)
+
+ @override
+ def list_versions(
+ self, *, app_name: str, user_id: str, session_id: str, filename: str
+ ) -> list[int]:
+ path = self._artifact_path(app_name, user_id, session_id, filename)
+ versions = self.artifacts.get(path)
+ if not versions:
+ return []
+ return list(range(len(versions)))
diff --git a/src/google/adk/auth/__init__.py b/src/google/adk/auth/__init__.py
new file mode 100644
index 0000000..49fba37
--- /dev/null
+++ b/src/google/adk/auth/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .auth_credential import AuthCredential
+from .auth_credential import AuthCredentialTypes
+from .auth_credential import OAuth2Auth
+from .auth_handler import AuthHandler
+from .auth_schemes import AuthScheme
+from .auth_schemes import AuthSchemeType
+from .auth_schemes import OpenIdConnectWithConfig
+from .auth_tool import AuthConfig
diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py
new file mode 100644
index 0000000..5d49cee
--- /dev/null
+++ b/src/google/adk/auth/auth_credential.py
@@ -0,0 +1,220 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from pydantic import BaseModel
+from pydantic import Field
+
+
+class BaseModelWithConfig(BaseModel):
+ model_config = {"extra": "allow"}
+
+
+class HttpCredentials(BaseModelWithConfig):
+ """Represents the secret token value for HTTP authentication, like user name, password, oauth token, etc."""
+
+ username: Optional[str] = None
+ password: Optional[str] = None
+ token: Optional[str] = None
+
+ @classmethod
+ def model_validate(cls, data: Dict[str, Any]) -> "HttpCredentials":
+ return cls(
+ username=data.get("username"),
+ password=data.get("password"),
+ token=data.get("token"),
+ )
+
+
+class HttpAuth(BaseModelWithConfig):
+ """The credentials and metadata for HTTP authentication."""
+
+ # The name of the HTTP Authorization scheme to be used in the Authorization
+ # header as defined in RFC7235. The values used SHOULD be registered in the
+ # IANA Authentication Scheme registry.
+ # Examples: 'basic', 'bearer'
+ scheme: str
+ credentials: HttpCredentials
+
+
+class OAuth2Auth(BaseModelWithConfig):
+ """Represents credential value and its metadata for a OAuth2 credential."""
+
+ client_id: Optional[str] = None
+ client_secret: Optional[str] = None
+ # tool or adk can generate the auth_uri with the state info thus client
+ # can verify the state
+ auth_uri: Optional[str] = None
+ state: Optional[str] = None
+ # tool or adk can decide the redirect_uri if they don't want client to decide
+ redirect_uri: Optional[str] = None
+ auth_response_uri: Optional[str] = None
+ auth_code: Optional[str] = None
+ token: Optional[Dict[str, Any]] = None
+
+
+class ServiceAccountCredential(BaseModelWithConfig):
+ """Represents Google Service Account configuration.
+
+ Attributes:
+ type: The type should be "service_account".
+ project_id: The project ID.
+ private_key_id: The ID of the private key.
+ private_key: The private key.
+ client_email: The client email.
+ client_id: The client ID.
+ auth_uri: The authorization URI.
+ token_uri: The token URI.
+ auth_provider_x509_cert_url: URL for auth provider's X.509 cert.
+ client_x509_cert_url: URL for the client's X.509 cert.
+ universe_domain: The universe domain.
+
+ Example:
+
+ config = ServiceAccountCredential(
+ type_="service_account",
+ project_id="your_project_id",
+ private_key_id="your_private_key_id",
+ private_key="-----BEGIN PRIVATE KEY-----...",
+ client_email="...@....iam.gserviceaccount.com",
+ client_id="your_client_id",
+ auth_uri="https://accounts.google.com/o/oauth2/auth",
+ token_uri="https://oauth2.googleapis.com/token",
+ auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
+ client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/...",
+ universe_domain="googleapis.com"
+ )
+
+
+ config = ServiceAccountConfig.model_construct(**{
+ ...service account config dict
+ })
+ """
+
+ type_: str = Field("", alias="type")
+ project_id: str
+ private_key_id: str
+ private_key: str
+ client_email: str
+ client_id: str
+ auth_uri: str
+ token_uri: str
+ auth_provider_x509_cert_url: str
+ client_x509_cert_url: str
+ universe_domain: str
+
+
+class ServiceAccount(BaseModelWithConfig):
+ """Represents Google Service Account configuration."""
+
+ service_account_credential: Optional[ServiceAccountCredential] = None
+ scopes: List[str]
+ use_default_credential: Optional[bool] = False
+
+
+class AuthCredentialTypes(str, Enum):
+ """Represents the type of authentication credential."""
+
+ # API Key credential:
+ # https://swagger.io/docs/specification/v3_0/authentication/api-keys/
+ API_KEY = "apiKey"
+
+ # Credentials for HTTP Auth schemes:
+ # https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml
+ HTTP = "http"
+
+ # OAuth2 credentials:
+ # https://swagger.io/docs/specification/v3_0/authentication/oauth2/
+ OAUTH2 = "oauth2"
+
+ # OpenID Connect credentials:
+ # https://swagger.io/docs/specification/v3_0/authentication/openid-connect-discovery/
+ OPEN_ID_CONNECT = "openIdConnect"
+
+ # Service Account credentials:
+ # https://cloud.google.com/iam/docs/service-account-creds
+ SERVICE_ACCOUNT = "serviceAccount"
+
+
+class AuthCredential(BaseModelWithConfig):
+ """Data class representing an authentication credential.
+
+ To exchange for the actual credential, please use
+ CredentialExchanger.exchange_credential().
+
+ Examples: API Key Auth
+ AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY,
+ api_key="1234",
+ )
+
+ Example: HTTP Auth
+ AuthCredential(
+ auth_type=AuthCredentialTypes.HTTP,
+ http=HttpAuth(
+ scheme="basic",
+ credentials=HttpCredentials(username="user", password="password"),
+ ),
+ )
+
+ Example: OAuth2 Bearer Token in HTTP Header
+ AuthCredential(
+ auth_type=AuthCredentialTypes.HTTP,
+ http=HttpAuth(
+ scheme="bearer",
+ credentials=HttpCredentials(token="eyAkaknabna...."),
+ ),
+ )
+
+ Example: OAuth2 Auth with Authorization Code Flow
+ AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id="1234",
+ client_secret="secret",
+ ),
+ )
+
+ Example: OpenID Connect Auth
+ AuthCredential(
+ auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
+ oauth2=OAuth2Auth(
+ client_id="1234",
+ client_secret="secret",
+ redirect_uri="https://example.com",
+ scopes=["scope1", "scope2"],
+ ),
+ )
+
+ Example: Auth with resource reference
+ AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY,
+ resource_ref="projects/1234/locations/us-central1/resources/resource1",
+ )
+ """
+
+ auth_type: AuthCredentialTypes
+ # Resource reference for the credential.
+ # This will be supported in the future.
+ resource_ref: Optional[str] = None
+
+ api_key: Optional[str] = None
+ http: Optional[HttpAuth] = None
+ service_account: Optional[ServiceAccount] = None
+ oauth2: Optional[OAuth2Auth] = None
diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py
new file mode 100644
index 0000000..268d666
--- /dev/null
+++ b/src/google/adk/auth/auth_handler.py
@@ -0,0 +1,265 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from fastapi.openapi.models import OAuth2
+from fastapi.openapi.models import SecurityBase
+
+from .auth_credential import AuthCredential
+from .auth_credential import AuthCredentialTypes
+from .auth_credential import OAuth2Auth
+from .auth_schemes import AuthSchemeType
+from .auth_schemes import OAuthGrantType
+from .auth_schemes import OpenIdConnectWithConfig
+from .auth_tool import AuthConfig
+
+if TYPE_CHECKING:
+ from ..sessions.state import State
+
+try:
+ from authlib.integrations.requests_client import OAuth2Session
+
+ SUPPORT_TOKEN_EXCHANGE = True
+except ImportError:
+ SUPPORT_TOKEN_EXCHANGE = False
+
+
+class AuthHandler:
+
+ def __init__(self, auth_config: AuthConfig):
+ self.auth_config = auth_config
+
+ def exchange_auth_token(
+ self,
+ ) -> AuthCredential:
+ """Generates an auth token from the authorization response.
+
+ Returns:
+ An AuthCredential object containing the access token.
+
+ Raises:
+ ValueError: If the token endpoint is not configured in the auth
+ scheme.
+ AuthCredentialMissingError: If the access token cannot be retrieved
+ from the token endpoint.
+ """
+ auth_scheme = self.auth_config.auth_scheme
+ auth_credential = self.auth_config.exchanged_auth_credential
+ if not SUPPORT_TOKEN_EXCHANGE:
+ return auth_credential
+ if isinstance(auth_scheme, OpenIdConnectWithConfig):
+ if not hasattr(auth_scheme, "token_endpoint"):
+ return self.auth_config.exchanged_auth_credential
+ token_endpoint = auth_scheme.token_endpoint
+ scopes = auth_scheme.scopes
+ elif isinstance(auth_scheme, OAuth2):
+ if (
+ not auth_scheme.flows.authorizationCode
+ or not auth_scheme.flows.authorizationCode.tokenUrl
+ ):
+ return self.auth_config.exchanged_auth_credential
+ token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
+ scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
+ else:
+ return self.auth_config.exchanged_auth_credential
+
+ if (
+ not auth_credential
+ or not auth_credential.oauth2
+ or not auth_credential.oauth2.client_id
+ or not auth_credential.oauth2.client_secret
+ or auth_credential.oauth2.token
+ ):
+ return self.auth_config.exchanged_auth_credential
+
+ client = OAuth2Session(
+ auth_credential.oauth2.client_id,
+ auth_credential.oauth2.client_secret,
+ scope=",".join(scopes),
+ redirect_uri=auth_credential.oauth2.redirect_uri,
+ state=auth_credential.oauth2.state,
+ )
+ token = client.fetch_token(
+ token_endpoint,
+ authorization_response=auth_credential.oauth2.auth_response_uri,
+ code=auth_credential.oauth2.auth_code,
+ grant_type=OAuthGrantType.AUTHORIZATION_CODE,
+ )
+
+ updated_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(token=dict(token)),
+ )
+ return updated_credential
+
+ def parse_and_store_auth_response(self, state: State) -> None:
+
+ credential_key = self.get_credential_key()
+
+ state[credential_key] = self.auth_config.exchanged_auth_credential
+ if not isinstance(
+ self.auth_config.auth_scheme, SecurityBase
+ ) or self.auth_config.auth_scheme.type_ not in (
+ AuthSchemeType.oauth2,
+ AuthSchemeType.openIdConnect,
+ ):
+ return
+
+ state[credential_key] = self.exchange_auth_token()
+
+ def _validate(self) -> None:
+ if not self.auth_scheme:
+ raise ValueError("auth_scheme is empty.")
+
+ def get_auth_response(self, state: State) -> AuthCredential:
+ credential_key = self.get_credential_key()
+ return state.get(credential_key, None)
+
+ def generate_auth_request(self) -> AuthConfig:
+ if not isinstance(
+ self.auth_config.auth_scheme, SecurityBase
+ ) or self.auth_config.auth_scheme.type_ not in (
+ AuthSchemeType.oauth2,
+ AuthSchemeType.openIdConnect,
+ ):
+ return self.auth_config.model_copy(deep=True)
+
+ # auth_uri already in exchanged credential
+ if (
+ self.auth_config.exchanged_auth_credential
+ and self.auth_config.exchanged_auth_credential.oauth2
+ and self.auth_config.exchanged_auth_credential.oauth2.auth_uri
+ ):
+ return self.auth_config.model_copy(deep=True)
+
+ # Check if raw_auth_credential exists
+ if not self.auth_config.raw_auth_credential:
+ raise ValueError(
+ f"Auth Scheme {self.auth_config.auth_scheme.type_} requires"
+ " auth_credential."
+ )
+
+ # Check if oauth2 exists in raw_auth_credential
+ if not self.auth_config.raw_auth_credential.oauth2:
+ raise ValueError(
+ f"Auth Scheme {self.auth_config.auth_scheme.type_} requires oauth2 in"
+ " auth_credential."
+ )
+
+ # auth_uri in raw credential
+ if self.auth_config.raw_auth_credential.oauth2.auth_uri:
+ return AuthConfig(
+ auth_scheme=self.auth_config.auth_scheme,
+ raw_auth_credential=self.auth_config.raw_auth_credential,
+ exchanged_auth_credential=self.auth_config.raw_auth_credential.model_copy(
+ deep=True
+ ),
+ )
+
+ # Check for client_id and client_secret
+ if (
+ not self.auth_config.raw_auth_credential.oauth2.client_id
+ or not self.auth_config.raw_auth_credential.oauth2.client_secret
+ ):
+ raise ValueError(
+ f"Auth Scheme {self.auth_config.auth_scheme.type_} requires both"
+ " client_id and client_secret in auth_credential.oauth2."
+ )
+
+ # Generate new auth URI
+ exchanged_credential = self.generate_auth_uri()
+ return AuthConfig(
+ auth_scheme=self.auth_config.auth_scheme,
+ raw_auth_credential=self.auth_config.raw_auth_credential,
+ exchanged_auth_credential=exchanged_credential,
+ )
+
+ def get_credential_key(self) -> str:
+ """Generates a unique key for the given auth scheme and credential."""
+ auth_scheme = self.auth_config.auth_scheme
+ auth_credential = self.auth_config.raw_auth_credential
+ if auth_scheme.model_extra:
+ auth_scheme = auth_scheme.model_copy(deep=True)
+ auth_scheme.model_extra.clear()
+ scheme_name = (
+ f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
+ if auth_scheme
+ else ""
+ )
+ if auth_credential.model_extra:
+ auth_credential = auth_credential.model_copy(deep=True)
+ auth_credential.model_extra.clear()
+ credential_name = (
+ f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
+ if auth_credential
+ else ""
+ )
+
+ return f"temp:adk_{scheme_name}_{credential_name}"
+
+ def generate_auth_uri(
+ self,
+ ) -> AuthCredential:
+ """Generates an response containing the auth uri for user to sign in.
+
+ Returns:
+ An AuthCredential object containing the auth URI and state.
+
+ Raises:
+ ValueError: If the authorization endpoint is not configured in the auth
+ scheme.
+ """
+ auth_scheme = self.auth_config.auth_scheme
+ auth_credential = self.auth_config.raw_auth_credential
+
+ if isinstance(auth_scheme, OpenIdConnectWithConfig):
+ authorization_endpoint = auth_scheme.authorization_endpoint
+ scopes = auth_scheme.scopes
+ else:
+ authorization_endpoint = (
+ auth_scheme.flows.implicit
+ and auth_scheme.flows.implicit.authorizationUrl
+ or auth_scheme.flows.authorizationCode
+ and auth_scheme.flows.authorizationCode.authorizationUrl
+ or auth_scheme.flows.clientCredentials
+ and auth_scheme.flows.clientCredentials.tokenUrl
+ or auth_scheme.flows.password
+ and auth_scheme.flows.password.tokenUrl
+ )
+ scopes = (
+ auth_scheme.flows.implicit
+ and auth_scheme.flows.implicit.scopes
+ or auth_scheme.flows.authorizationCode
+ and auth_scheme.flows.authorizationCode.scopes
+ or auth_scheme.flows.clientCredentials
+ and auth_scheme.flows.clientCredentials.scopes
+ or auth_scheme.flows.password
+ and auth_scheme.flows.password.scopes
+ )
+
+ client = OAuth2Session(
+ auth_credential.oauth2.client_id,
+ auth_credential.oauth2.client_secret,
+ scope=" ".join(scopes),
+ redirect_uri=auth_credential.oauth2.redirect_uri,
+ )
+ uri, state = client.create_authorization_url(url=authorization_endpoint)
+ exchanged_auth_credential = auth_credential.model_copy(deep=True)
+ exchanged_auth_credential.oauth2.auth_uri = uri
+ exchanged_auth_credential.oauth2.state = state
+
+ return exchanged_auth_credential
diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py
new file mode 100644
index 0000000..bf2c432
--- /dev/null
+++ b/src/google/adk/auth/auth_preprocessor.py
@@ -0,0 +1,116 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import AsyncGenerator
+from typing import TYPE_CHECKING
+
+from typing_extensions import override
+
+from ..agents.invocation_context import InvocationContext
+from ..events.event import Event
+from ..flows.llm_flows import functions
+from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor
+from ..flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
+from ..models.llm_request import LlmRequest
+from .auth_handler import AuthHandler
+from .auth_tool import AuthConfig
+from .auth_tool import AuthToolArguments
+
+if TYPE_CHECKING:
+ from ..agents.llm_agent import LlmAgent
+
+
+class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
+ """Handles auth information to build the LLM request."""
+
+ @override
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_request: LlmRequest
+ ) -> AsyncGenerator[Event, None]:
+ from ..agents.llm_agent import LlmAgent
+
+ agent = invocation_context.agent
+ if not isinstance(agent, LlmAgent):
+ return
+ events = invocation_context.session.events
+ if not events:
+ return
+ request_euc_function_call_response_event = events[-1]
+ responses = (
+ request_euc_function_call_response_event.get_function_responses()
+ )
+ if not responses:
+ return
+
+ request_euc_function_call_ids = set()
+
+ for function_call_response in responses:
+ if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME:
+ continue
+
+ # found the function call response for the system long running request euc
+ # function call
+ request_euc_function_call_ids.add(function_call_response.id)
+ auth_config = AuthConfig.model_validate(function_call_response.response)
+ AuthHandler(auth_config=auth_config).parse_and_store_auth_response(
+ state=invocation_context.session.state
+ )
+
+ if not request_euc_function_call_ids:
+ return
+
+ for i in range(len(events) - 2, -1, -1):
+ event = events[i]
+ # looking for the system long running reqeust euc function call
+ function_calls = event.get_function_calls()
+ if not function_calls:
+ continue
+
+ tools_to_resume = set()
+
+ for function_call in function_calls:
+ if function_call.id not in request_euc_function_call_ids:
+ continue
+ args = AuthToolArguments.model_validate(function_call.args)
+
+ tools_to_resume.add(args.function_call_id)
+ if not tools_to_resume:
+ continue
+ # found the the system long running reqeust euc function call
+ # looking for original function call that requests euc
+ for j in range(i - 1, -1, -1):
+ event = events[j]
+ function_calls = event.get_function_calls()
+ if not function_calls:
+ continue
+ for function_call in function_calls:
+ function_response_event = None
+ if function_call.id in tools_to_resume:
+ function_response_event = await functions.handle_function_calls_async(
+ invocation_context,
+ event,
+ {tool.name: tool for tool in agent.canonical_tools},
+ # there could be parallel function calls that require auth
+ # auth response would be a dict keyed by function call id
+ tools_to_resume,
+ )
+ if function_response_event:
+ yield function_response_event
+ return
+ return
+
+
+request_processor = _AuthLlmRequestProcessor()
diff --git a/src/google/adk/auth/auth_schemes.py b/src/google/adk/auth/auth_schemes.py
new file mode 100644
index 0000000..baccf64
--- /dev/null
+++ b/src/google/adk/auth/auth_schemes.py
@@ -0,0 +1,67 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from typing import List
+from typing import Optional
+from typing import Union
+
+from fastapi.openapi.models import OAuthFlows
+from fastapi.openapi.models import SecurityBase
+from fastapi.openapi.models import SecurityScheme
+from fastapi.openapi.models import SecuritySchemeType
+from pydantic import Field
+
+
+class OpenIdConnectWithConfig(SecurityBase):
+ type_: SecuritySchemeType = Field(
+ default=SecuritySchemeType.openIdConnect, alias="type"
+ )
+ authorization_endpoint: str
+ token_endpoint: str
+ userinfo_endpoint: Optional[str] = None
+ revocation_endpoint: Optional[str] = None
+ token_endpoint_auth_methods_supported: Optional[List[str]] = None
+ grant_types_supported: Optional[List[str]] = None
+ scopes: Optional[List[str]] = None
+
+
+# AuthSchemes contains SecuritySchemes from OpenAPI 3.0 and an extra flattened OpenIdConnectWithConfig.
+AuthScheme = Union[SecurityScheme, OpenIdConnectWithConfig]
+
+
+class OAuthGrantType(str, Enum):
+ """Represents the OAuth2 flow (or grant type)."""
+
+ CLIENT_CREDENTIALS = "client_credentials"
+ AUTHORIZATION_CODE = "authorization_code"
+ IMPLICIT = "implicit"
+ PASSWORD = "password"
+
+ @staticmethod
+ def from_flow(flow: OAuthFlows) -> "OAuthGrantType":
+ """Converts an OAuthFlows object to a OAuthGrantType."""
+ if flow.clientCredentials:
+ return OAuthGrantType.CLIENT_CREDENTIALS
+ if flow.authorizationCode:
+ return OAuthGrantType.AUTHORIZATION_CODE
+ if flow.implicit:
+ return OAuthGrantType.IMPLICIT
+ if flow.password:
+ return OAuthGrantType.PASSWORD
+ return None
+
+
+# AuthSchemeType re-exports SecuritySchemeType from OpenAPI 3.0.
+AuthSchemeType = SecuritySchemeType
diff --git a/src/google/adk/auth/auth_tool.py b/src/google/adk/auth/auth_tool.py
new file mode 100644
index 0000000..bd3f845
--- /dev/null
+++ b/src/google/adk/auth/auth_tool.py
@@ -0,0 +1,55 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pydantic import BaseModel
+
+from .auth_credential import AuthCredential
+from .auth_schemes import AuthScheme
+
+
+class AuthConfig(BaseModel):
+ """The auth config sent by tool asking client to collect auth credentails and
+
+ adk and client will help to fill in the response
+ """
+
+ auth_scheme: AuthScheme
+ """The auth scheme used to collect credentials"""
+ raw_auth_credential: AuthCredential = None
+ """The raw auth credential used to collect credentials. The raw auth
+ credentials are used in some auth scheme that needs to exchange auth
+ credentials. e.g. OAuth2 and OIDC. For other auth scheme, it could be None.
+ """
+ exchanged_auth_credential: AuthCredential = None
+ """The exchanged auth credential used to collect credentials. adk and client
+ will work together to fill it. For those auth scheme that doesn't need to
+ exchange auth credentials, e.g. API key, service account etc. It's filled by
+ client directly. For those auth scheme that need to exchange auth credentials,
+ e.g. OAuth2 and OIDC, it's first filled by adk. If the raw credentials
+ passed by tool only has client id and client credential, adk will help to
+ generate the corresponding authorization uri and state and store the processed
+ credential in this field. If the raw credentials passed by tool already has
+ authorization uri, state, etc. then it's copied to this field. Client will use
+ this field to guide the user through the OAuth2 flow and fill auth response in
+ this field"""
+
+
+class AuthToolArguments(BaseModel):
+ """the arguments for the special long running function tool that is used to
+
+ request end user credentials.
+ """
+
+ function_call_id: str
+ auth_config: AuthConfig
diff --git a/src/google/adk/cli/__init__.py b/src/google/adk/cli/__init__.py
new file mode 100644
index 0000000..1785dda
--- /dev/null
+++ b/src/google/adk/cli/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .cli_tools_click import main
diff --git a/src/google/adk/cli/__main__.py b/src/google/adk/cli/__main__.py
new file mode 100644
index 0000000..6f62a4d
--- /dev/null
+++ b/src/google/adk/cli/__main__.py
@@ -0,0 +1,18 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .cli_tools_click import main
+
+if __name__ == '__main__':
+ main()
diff --git a/src/google/adk/cli/agent_graph.py b/src/google/adk/cli/agent_graph.py
new file mode 100644
index 0000000..5f4e76d
--- /dev/null
+++ b/src/google/adk/cli/agent_graph.py
@@ -0,0 +1,122 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+import graphviz
+
+from ..agents import BaseAgent
+from ..agents.llm_agent import LlmAgent
+from ..tools.agent_tool import AgentTool
+from ..tools.base_tool import BaseTool
+from ..tools.function_tool import FunctionTool
+from ..tools.retrieval.base_retrieval_tool import BaseRetrievalTool
+
+
+def build_graph(graph, agent: BaseAgent, highlight_pairs):
+ dark_green = '#0F5223'
+ light_green = '#69CB87'
+ light_gray = '#cccccc'
+
+ def get_node_name(tool_or_agent: Union[BaseAgent, BaseTool]):
+ if isinstance(tool_or_agent, BaseAgent):
+ return tool_or_agent.name
+ elif isinstance(tool_or_agent, BaseTool):
+ return tool_or_agent.name
+ else:
+ raise ValueError(f'Unsupported tool type: {tool_or_agent}')
+
+ def get_node_caption(tool_or_agent: Union[BaseAgent, BaseTool]):
+ if isinstance(tool_or_agent, BaseAgent):
+ return '🤖 ' + tool_or_agent.name
+ elif isinstance(tool_or_agent, BaseRetrievalTool):
+ return '🔎 ' + tool_or_agent.name
+ elif isinstance(tool_or_agent, FunctionTool):
+ return '🔧 ' + tool_or_agent.name
+ elif isinstance(tool_or_agent, AgentTool):
+ return '🤖 ' + tool_or_agent.name
+ elif isinstance(tool_or_agent, BaseTool):
+ return '🔧 ' + tool_or_agent.name
+ else:
+ raise ValueError(f'Unsupported tool type: {type(tool)}')
+
+ def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]):
+ if isinstance(tool_or_agent, BaseAgent):
+ return 'ellipse'
+ elif isinstance(tool_or_agent, BaseRetrievalTool):
+ return 'cylinder'
+ elif isinstance(tool_or_agent, FunctionTool):
+ return 'box'
+ elif isinstance(tool_or_agent, BaseTool):
+ return 'box'
+ else:
+ raise ValueError(f'Unsupported tool type: {type(tool_or_agent)}')
+
+ def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
+ name = get_node_name(tool_or_agent)
+ shape = get_node_shape(tool_or_agent)
+ caption = get_node_caption(tool_or_agent)
+ if highlight_pairs:
+ for highlight_tuple in highlight_pairs:
+ if name in highlight_tuple:
+ graph.node(
+ name,
+ caption,
+ style='filled,rounded',
+ fillcolor=dark_green,
+ color=dark_green,
+ shape=shape,
+ fontcolor=light_gray,
+ )
+ return
+ # if not in highlight, draw non-highliht node
+ graph.node(
+ name,
+ caption,
+ shape=shape,
+ style='rounded',
+ color=light_gray,
+ fontcolor=light_gray,
+ )
+
+ def draw_edge(from_name, to_name):
+ if highlight_pairs:
+ for highlight_from, highlight_to in highlight_pairs:
+ if from_name == highlight_from and to_name == highlight_to:
+ graph.edge(from_name, to_name, color=light_green)
+ return
+ elif from_name == highlight_to and to_name == highlight_from:
+ graph.edge(from_name, to_name, color=light_green, dir='back')
+ return
+ # if no need to highlight, color gray
+ graph.edge(from_name, to_name, arrowhead='none', color=light_gray)
+
+ draw_node(agent)
+ for sub_agent in agent.sub_agents:
+ build_graph(graph, sub_agent, highlight_pairs)
+ draw_edge(agent.name, sub_agent.name)
+ if isinstance(agent, LlmAgent):
+ for tool in agent.canonical_tools:
+ draw_node(tool)
+ draw_edge(agent.name, get_node_name(tool))
+
+
+def get_agent_graph(root_agent, highlights_pairs, image=False):
+ print('build graph')
+ graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'})
+ build_graph(graph, root_agent, highlights_pairs)
+ if image:
+ return graph.pipe(format='png')
+ else:
+ return graph
diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py
new file mode 100644
index 0000000..ac6d845
--- /dev/null
+++ b/src/google/adk/cli/cli.py
@@ -0,0 +1,181 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import datetime
+import importlib
+import os
+import sys
+from typing import Optional
+
+import click
+from google.genai import types
+from pydantic import BaseModel
+
+from ..agents.llm_agent import LlmAgent
+from ..artifacts import BaseArtifactService
+from ..artifacts import InMemoryArtifactService
+from ..runners import Runner
+from ..sessions.base_session_service import BaseSessionService
+from ..sessions.in_memory_session_service import InMemorySessionService
+from ..sessions.session import Session
+from .utils import envs
+
+
+class InputFile(BaseModel):
+ state: dict[str, object]
+ queries: list[str]
+
+
+async def run_input_file(
+ app_name: str,
+ root_agent: LlmAgent,
+ artifact_service: BaseArtifactService,
+ session: Session,
+ session_service: BaseSessionService,
+ input_path: str,
+) -> None:
+ runner = Runner(
+ app_name=app_name,
+ agent=root_agent,
+ artifact_service=artifact_service,
+ session_service=session_service,
+ )
+ with open(input_path, 'r', encoding='utf-8') as f:
+ input_file = InputFile.model_validate_json(f.read())
+ input_file.state['_time'] = datetime.now()
+
+ session.state = input_file.state
+ for query in input_file.queries:
+ click.echo(f'user: {query}')
+ content = types.Content(role='user', parts=[types.Part(text=query)])
+ async for event in runner.run_async(
+ user_id=session.user_id, session_id=session.id, new_message=content
+ ):
+ if event.content and event.content.parts:
+ if text := ''.join(part.text or '' for part in event.content.parts):
+ click.echo(f'[{event.author}]: {text}')
+
+
+async def run_interactively(
+ app_name: str,
+ root_agent: LlmAgent,
+ artifact_service: BaseArtifactService,
+ session: Session,
+ session_service: BaseSessionService,
+) -> None:
+ runner = Runner(
+ app_name=app_name,
+ agent=root_agent,
+ artifact_service=artifact_service,
+ session_service=session_service,
+ )
+ while True:
+ query = input('user: ')
+ if query == 'exit':
+ break
+ async for event in runner.run_async(
+ user_id=session.user_id,
+ session_id=session.id,
+ new_message=types.Content(role='user', parts=[types.Part(text=query)]),
+ ):
+ if event.content and event.content.parts:
+ if text := ''.join(part.text or '' for part in event.content.parts):
+ click.echo(f'[{event.author}]: {text}')
+
+
+async def run_cli(
+ *,
+ agent_parent_dir: str,
+ agent_folder_name: str,
+ json_file_path: Optional[str] = None,
+ save_session: bool,
+) -> None:
+ """Runs an interactive CLI for a certain agent.
+
+ Args:
+ agent_parent_dir: str, the absolute path of the parent folder of the agent
+ folder.
+ agent_folder_name: str, the name of the agent folder.
+ json_file_path: Optional[str], the absolute path to the json file, either
+ *.input.json or *.session.json.
+ save_session: bool, whether to save the session on exit.
+ """
+ if agent_parent_dir not in sys.path:
+ sys.path.append(agent_parent_dir)
+
+ artifact_service = InMemoryArtifactService()
+ session_service = InMemorySessionService()
+ session = session_service.create_session(
+ app_name=agent_folder_name, user_id='test_user'
+ )
+
+ agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
+ agent_module = importlib.import_module(agent_folder_name)
+ root_agent = agent_module.agent.root_agent
+ envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
+ if json_file_path:
+ if json_file_path.endswith('.input.json'):
+ await run_input_file(
+ app_name=agent_folder_name,
+ root_agent=root_agent,
+ artifact_service=artifact_service,
+ session=session,
+ session_service=session_service,
+ input_path=json_file_path,
+ )
+ elif json_file_path.endswith('.session.json'):
+ with open(json_file_path, 'r') as f:
+ session = Session.model_validate_json(f.read())
+ for content in session.get_contents():
+ if content.role == 'user':
+ print('user: ', content.parts[0].text)
+ else:
+ print(content.parts[0].text)
+ await run_interactively(
+ agent_folder_name,
+ root_agent,
+ artifact_service,
+ session,
+ session_service,
+ )
+ else:
+ print(f'Unsupported file type: {json_file_path}')
+ exit(1)
+ else:
+ print(f'Running agent {root_agent.name}, type exit to exit.')
+ await run_interactively(
+ agent_folder_name,
+ root_agent,
+ artifact_service,
+ session,
+ session_service,
+ )
+
+ if save_session:
+ if json_file_path:
+ session_path = json_file_path.replace('.input.json', '.session.json')
+ else:
+ session_id = input('Session ID to save: ')
+ session_path = f'{agent_module_path}/{session_id}.session.json'
+ with open(session_path, 'w') as f:
+ f.write(session.model_dump_json(indent=2, exclude_none=True))
+ # TODO: Save from opentelemetry.
+ # logs_path = session_path.replace('.session.json', '.logs.json')
+ # with open(logs_path, 'w') as f:
+ # f.write(
+ # session.model_dump_json(
+ # indent=2, exclude_none=True, include='event_logs'
+ # )
+ # )
+ print('Session saved to', session_path)
diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py
new file mode 100644
index 0000000..478fc91
--- /dev/null
+++ b/src/google/adk/cli/cli_deploy.py
@@ -0,0 +1,181 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import subprocess
+from typing import Optional
+
+import click
+
+_DOCKERFILE_TEMPLATE = """
+FROM python:3.11-slim
+WORKDIR /app
+
+# Create a non-root user
+RUN adduser --disabled-password --gecos "" myuser
+
+# Change ownership of /app to myuser
+RUN chown -R myuser:myuser /app
+
+# Switch to the non-root user
+USER myuser
+
+# Set up environment variables - Start
+ENV PATH="/home/myuser/.local/bin:$PATH"
+
+ENV GOOGLE_GENAI_USE_VERTEXAI=1
+# TODO: use passed-in value
+ENV GOOGLE_CLOUD_PROJECT={gcp_project_id}
+ENV GOOGLE_CLOUD_LOCATION={gcp_region}
+ENV ADK_TRACE_TO_CLOUD={with_cloud_trace}
+
+# Set up environment variables - End
+
+# Install ADK - Start
+RUN pip install google-adk
+# Install ADK - End
+
+# Copy agent - Start
+
+COPY "agents/{app_name}/" "/app/agents/{app_name}/"
+{install_agent_deps}
+
+# Copy agent - End
+
+EXPOSE {port}
+
+CMD adk {command} --port={port} "/app/agents"
+"""
+
+
+def _resolve_project(project_in_option: Optional[str]) -> str:
+ if project_in_option:
+ return project_in_option
+
+ result = subprocess.run(
+ ['gcloud', 'config', 'get-value', 'project'],
+ check=True,
+ capture_output=True,
+ text=True,
+ )
+ project = result.stdout.strip()
+ click.echo(f'Use default project: {project}')
+ return project
+
+
+def to_cloud_run(
+ *,
+ agent_folder: str,
+ project: Optional[str],
+ region: Optional[str],
+ service_name: str,
+ app_name: str,
+ temp_folder: str,
+ port: int,
+ with_cloud_trace: bool,
+ with_ui: bool,
+):
+ """Deploys an agent to Google Cloud Run.
+
+ `agent_folder` should contain the following files:
+
+ - __init__.py
+ - agent.py
+ - requirements.txt (optional, for additional dependencies)
+ - ... (other required source files)
+
+ The folder structure of temp_folder will be
+
+ * dist/[google_adk wheel file]
+ * agents/[app_name]/
+ * agent source code from `agent_folder`
+
+ Args:
+ agent_folder: The folder (absolute path) containing the agent source code.
+ project: Google Cloud project id.
+ region: Google Cloud region.
+ service_name: The service name in Cloud Run.
+ app_name: The name of the app, by default, it's basename of `agent_folder`.
+ temp_folder: The temp folder for the generated Cloud Run source files.
+ port: The port of the ADK api server.
+ with_cloud_trace: Whether to enable Cloud Trace.
+ with_ui: Whether to deploy with UI.
+ """
+ app_name = app_name or os.path.basename(agent_folder)
+
+ click.echo(f'Start generating Cloud Run source files in {temp_folder}')
+
+ # remove temp_folder if exists
+ if os.path.exists(temp_folder):
+ click.echo('Removing existing files')
+ shutil.rmtree(temp_folder)
+
+ try:
+ # copy agent source code
+ click.echo('Copying agent source code...')
+ agent_src_path = os.path.join(temp_folder, 'agents', app_name)
+ shutil.copytree(agent_folder, agent_src_path)
+ requirements_txt_path = os.path.join(agent_src_path, 'requirements.txt')
+ install_agent_deps = (
+ f'RUN pip install -r "/app/agents/{app_name}/requirements.txt"'
+ if os.path.exists(requirements_txt_path)
+ else ''
+ )
+ click.echo('Copying agent source code complete.')
+
+ # create Dockerfile
+ click.echo('Creating Dockerfile...')
+ dockerfile_content = _DOCKERFILE_TEMPLATE.format(
+ gcp_project_id=project,
+ gcp_region=region,
+ app_name=app_name,
+ port=port,
+ command='web' if with_ui else 'api_server',
+ install_agent_deps=install_agent_deps,
+ with_cloud_trace='1' if with_cloud_trace else '0',
+ )
+ dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
+ os.makedirs(temp_folder, exist_ok=True)
+ with open(dockerfile_path, 'w', encoding='utf-8') as f:
+ f.write(
+ dockerfile_content,
+ )
+ click.echo(f'Creating Dockerfile complete: {dockerfile_path}')
+
+ # Deploy to Cloud Run
+ click.echo('Deploying to Cloud Run...')
+ region_options = ['--region', region] if region else []
+ project = _resolve_project(project)
+ subprocess.run(
+ [
+ 'gcloud',
+ 'run',
+ 'deploy',
+ service_name,
+ '--source',
+ temp_folder,
+ '--project',
+ project,
+ *region_options,
+ '--port',
+ str(port),
+ '--labels',
+ 'created-by=adk',
+ ],
+ check=True,
+ )
+ finally:
+ click.echo(f'Cleaning up the temp folder: {temp_folder}')
+ shutil.rmtree(temp_folder)
diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py
new file mode 100644
index 0000000..68fe9f7
--- /dev/null
+++ b/src/google/adk/cli/cli_eval.py
@@ -0,0 +1,282 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+import importlib.util
+import json
+import logging
+import os
+import sys
+import traceback
+from typing import Any
+from typing import Generator
+from typing import Optional
+import uuid
+
+from pydantic import BaseModel
+
+from ..agents import Agent
+
+logger = logging.getLogger(__name__)
+
+
+class EvalStatus(Enum):
+ PASSED = 1
+ FAILED = 2
+ NOT_EVALUATED = 3
+
+
+class EvalMetric(BaseModel):
+ metric_name: str
+ threshold: float
+
+
+class EvalMetricResult(BaseModel):
+ score: Optional[float]
+ eval_status: EvalStatus
+
+
+class EvalResult(BaseModel):
+ eval_set_file: str
+ eval_id: str
+ final_eval_status: EvalStatus
+ eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
+ session_id: str
+
+
+MISSING_EVAL_DEPENDENCIES_MESSAGE = (
+ "Eval module is not installed, please install via `pip install"
+ " google-adk[eval]`."
+)
+TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score"
+RESPONSE_MATCH_SCORE_KEY = "response_match_score"
+# This evaluation is not very stable.
+# This is always optional unless explicitly specified.
+RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"
+
+EVAL_SESSION_ID_PREFIX = "___eval___session___"
+DEFAULT_CRITERIA = {
+ TOOL_TRAJECTORY_SCORE_KEY: 1.0, # 1-point scale; 1.0 is perfect.
+ RESPONSE_MATCH_SCORE_KEY: 0.8,
+}
+
+
+def _import_from_path(module_name, file_path):
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module)
+ return module
+
+
+def _get_agent_module(agent_module_file_path: str):
+ file_path = os.path.join(agent_module_file_path, "__init__.py")
+ module_name = "agent"
+ return _import_from_path(module_name, file_path)
+
+
+def get_evaluation_criteria_or_default(
+ eval_config_file_path: str,
+) -> dict[str, float]:
+ """Returns evaluation criteria from the config file, if present.
+
+ Otherwise a default one is returned.
+ """
+ if eval_config_file_path:
+ with open(eval_config_file_path, "r", encoding="utf-8") as f:
+ config_data = json.load(f)
+
+ if "criteria" in config_data and isinstance(config_data["criteria"], dict):
+ evaluation_criteria = config_data["criteria"]
+ else:
+ raise ValueError(
+ f"Invalid format for test_config.json at {eval_config_file_path}."
+ " Expected a 'criteria' dictionary."
+ )
+ else:
+ logger.info("No config file supplied. Using default criteria.")
+ evaluation_criteria = DEFAULT_CRITERIA
+
+ return evaluation_criteria
+
+
+def get_root_agent(agent_module_file_path: str) -> Agent:
+ """Returns root agent given the agetn module."""
+ agent_module = _get_agent_module(agent_module_file_path)
+ root_agent = agent_module.agent.root_agent
+ return root_agent
+
+
+def try_get_reset_func(agent_module_file_path: str) -> Any:
+ """Returns reset function for the agent, if present, given the agetn module."""
+ agent_module = _get_agent_module(agent_module_file_path)
+ reset_func = getattr(agent_module.agent, "reset_data", None)
+ return reset_func
+
+
+def parse_and_get_evals_to_run(
+ eval_set_file_path: tuple[str],
+) -> dict[str, list[str]]:
+ """Returns a dictionary of eval sets to evals that should be run."""
+ eval_set_to_evals = {}
+ for input_eval_set in eval_set_file_path:
+ evals = []
+ if ":" not in input_eval_set:
+ eval_set_file = input_eval_set
+ else:
+ eval_set_file = input_eval_set.split(":")[0]
+ evals = input_eval_set.split(":")[1].split(",")
+
+ if eval_set_file not in eval_set_to_evals:
+ eval_set_to_evals[eval_set_file] = []
+
+ eval_set_to_evals[eval_set_file].extend(evals)
+
+ return eval_set_to_evals
+
+
+def run_evals(
+ eval_set_to_evals: dict[str, list[str]],
+ root_agent: Agent,
+ reset_func: Optional[Any],
+ eval_metrics: list[EvalMetric],
+ session_service=None,
+ artifact_service=None,
+ print_detailed_results=False,
+) -> Generator[EvalResult, None, None]:
+ try:
+ from ..evaluation.agent_evaluator import EvaluationGenerator
+ from ..evaluation.response_evaluator import ResponseEvaluator
+ from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
+
+ """Returns a summary of eval runs."""
+ for eval_set_file, evals_to_run in eval_set_to_evals.items():
+ with open(eval_set_file, "r", encoding="utf-8") as file:
+ eval_items = json.load(file) # Load JSON into a list
+
+ assert eval_items, f"No eval data found in eval set file: {eval_set_file}"
+
+ for eval_item in eval_items:
+ eval_name = eval_item["name"]
+ eval_data = eval_item["data"]
+ initial_session = eval_item.get("initial_session", {})
+
+ if evals_to_run and eval_name not in evals_to_run:
+ continue
+
+ try:
+ print(f"Running Eval: {eval_set_file}:{eval_name}")
+ session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"
+
+ scrape_result = EvaluationGenerator._process_query_with_root_agent(
+ data=eval_data,
+ root_agent=root_agent,
+ reset_func=reset_func,
+ initial_session=initial_session,
+ session_id=session_id,
+ session_service=session_service,
+ artifact_service=artifact_service,
+ )
+
+ eval_metric_results = []
+ for eval_metric in eval_metrics:
+ eval_metric_result = None
+ if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
+ score = TrajectoryEvaluator.evaluate(
+ [scrape_result], print_detailed_results=print_detailed_results
+ )
+ eval_metric_result = _get_eval_metric_result(eval_metric, score)
+ elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
+ score = ResponseEvaluator.evaluate(
+ [scrape_result],
+ [RESPONSE_MATCH_SCORE_KEY],
+ print_detailed_results=print_detailed_results,
+ )
+ eval_metric_result = _get_eval_metric_result(
+ eval_metric, score["rouge_1/mean"].item()
+ )
+ elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
+ score = ResponseEvaluator.evaluate(
+ [scrape_result],
+ [RESPONSE_EVALUATION_SCORE_KEY],
+ print_detailed_results=print_detailed_results,
+ )
+ eval_metric_result = _get_eval_metric_result(
+ eval_metric, score["coherence/mean"].item()
+ )
+ else:
+ logger.warning("`%s` is not supported.", eval_metric.metric_name)
+ eval_metric_results.append((
+ eval_metric,
+ EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED),
+ ))
+
+ eval_metric_results.append((
+ eval_metric,
+ eval_metric_result,
+ ))
+ _print_eval_metric_result(eval_metric, eval_metric_result)
+
+ final_eval_status = EvalStatus.NOT_EVALUATED
+
+ # Go over the all the eval statuses and mark the final eval status as
+ # passed if all of them pass, otherwise mark the final eval status to
+ # failed.
+ for eval_metric_result in eval_metric_results:
+ eval_status = eval_metric_result[1].eval_status
+ if eval_status == EvalStatus.PASSED:
+ final_eval_status = EvalStatus.PASSED
+ elif eval_status == EvalStatus.NOT_EVALUATED:
+ continue
+ elif eval_status == EvalStatus.FAILED:
+ final_eval_status = EvalStatus.FAILED
+ break
+ else:
+ raise ValueError("Unknown eval status.")
+
+ yield EvalResult(
+ eval_set_file=eval_set_file,
+ eval_id=eval_name,
+ final_eval_status=final_eval_status,
+ eval_metric_results=eval_metric_results,
+ session_id=session_id,
+ )
+
+ if final_eval_status == EvalStatus.PASSED:
+ result = "✅ Passsed"
+ else:
+ result = "❌ Failed"
+
+ print(f"Result: {result}\n")
+
+ except Exception as e:
+ print(f"Error: {e}")
+ logger.info("Error: %s", str(traceback.format_exc()))
+
+
+def _get_eval_metric_result(eval_metric, score):
+ eval_status = (
+ EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
+ )
+ return EvalMetricResult(score=score, eval_status=eval_status)
+
+
+def _print_eval_metric_result(eval_metric, eval_metric_result):
+ print(
+ f"Metric: {eval_metric.metric_name}\tStatus:"
+ f" {eval_metric_result.eval_status}\tScore:"
+ f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
+ )
diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py
new file mode 100644
index 0000000..9fef197
--- /dev/null
+++ b/src/google/adk/cli/cli_tools_click.py
@@ -0,0 +1,479 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+from datetime import datetime
+import logging
+import os
+import tempfile
+from typing import Optional
+
+import click
+import uvicorn
+
+from . import cli_deploy
+from .cli import run_cli
+from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
+from .fast_api import get_fast_api_app
+from .utils import envs
+from .utils import logs
+
+logger = logging.getLogger(__name__)
+
+
+@click.group(context_settings={"max_content_width": 240})
+def main():
+ """Agent Development Kit CLI tools."""
+ pass
+
+
+@main.group()
+def deploy():
+ """Deploy Agent."""
+ pass
+
+
+@main.command("run")
+@click.option(
+ "--save_session",
+ type=bool,
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help="Optional. Whether to save the session to a json file on exit.",
+)
+@click.argument(
+ "agent",
+ type=click.Path(
+ exists=True, dir_okay=True, file_okay=False, resolve_path=True
+ ),
+)
+def cli_run(agent: str, save_session: bool):
+ """Run an interactive CLI for a certain agent.
+
+ AGENT: The path to the agent source code folder.
+
+ Example:
+
+ adk run path/to/my_agent
+ """
+ logs.log_to_tmp_folder()
+
+ agent_parent_folder = os.path.dirname(agent)
+ agent_folder_name = os.path.basename(agent)
+
+ asyncio.run(
+ run_cli(
+ agent_parent_dir=agent_parent_folder,
+ agent_folder_name=agent_folder_name,
+ save_session=save_session,
+ )
+ )
+
+
+@main.command("eval")
+@click.argument(
+ "agent_module_file_path",
+ type=click.Path(
+ exists=True, dir_okay=True, file_okay=False, resolve_path=True
+ ),
+)
+@click.argument("eval_set_file_path", nargs=-1)
+@click.option("--config_file_path", help="Optional. The path to config file.")
+@click.option(
+ "--print_detailed_results",
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help="Optional. Whether to print detailed results on console or not.",
+)
+def eval_command(
+ agent_module_file_path: str,
+ eval_set_file_path: tuple[str],
+ config_file_path: str,
+ print_detailed_results: bool,
+):
+ """Evaluates an agent given the eval sets.
+
+ AGENT_MODULE_FILE_PATH: The path to the __init__.py file that contains a
+ module by the name "agent". "agent" module contains a root_agent.
+
+ EVAL_SET_FILE_PATH: You can specify one or more eval set file paths.
+
+ For each file, all evals will be run by default.
+
+ If you want to run only specific evals from a eval set, first create a comma
+ separated list of eval names and then add that as a suffix to the eval set
+ file name, demarcated by a `:`.
+
+ For example,
+
+ sample_eval_set_file.json:eval_1,eval_2,eval_3
+
+ This will only run eval_1, eval_2 and eval_3 from sample_eval_set_file.json.
+
+ CONFIG_FILE_PATH: The path to config file.
+
+ PRINT_DETAILED_RESULTS: Prints detailed results on the console.
+ """
+ envs.load_dotenv_for_agent(agent_module_file_path, ".")
+
+ try:
+ from .cli_eval import EvalMetric
+ from .cli_eval import EvalResult
+ from .cli_eval import EvalStatus
+ from .cli_eval import get_evaluation_criteria_or_default
+ from .cli_eval import get_root_agent
+ from .cli_eval import parse_and_get_evals_to_run
+ from .cli_eval import run_evals
+ from .cli_eval import try_get_reset_func
+ except ModuleNotFoundError:
+ raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
+
+ evaluation_criteria = get_evaluation_criteria_or_default(config_file_path)
+ eval_metrics = []
+ for metric_name, threshold in evaluation_criteria.items():
+ eval_metrics.append(
+ EvalMetric(metric_name=metric_name, threshold=threshold)
+ )
+
+ print(f"Using evaluation creiteria: {evaluation_criteria}")
+
+ root_agent = get_root_agent(agent_module_file_path)
+ reset_func = try_get_reset_func(agent_module_file_path)
+
+ eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
+
+ try:
+ eval_results = list(
+ run_evals(
+ eval_set_to_evals,
+ root_agent,
+ reset_func,
+ eval_metrics,
+ print_detailed_results=print_detailed_results,
+ )
+ )
+ except ModuleNotFoundError:
+ raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
+
+ print("*********************************************************************")
+ eval_run_summary = {}
+
+ for eval_result in eval_results:
+ eval_result: EvalResult
+
+ if eval_result.eval_set_file not in eval_run_summary:
+ eval_run_summary[eval_result.eval_set_file] = [0, 0]
+
+ if eval_result.final_eval_status == EvalStatus.PASSED:
+ eval_run_summary[eval_result.eval_set_file][0] += 1
+ else:
+ eval_run_summary[eval_result.eval_set_file][1] += 1
+ print("Eval Run Summary")
+ for eval_set_file, pass_fail_count in eval_run_summary.items():
+ print(
+ f"{eval_set_file}:\n Tests passed: {pass_fail_count[0]}\n Tests"
+ f" failed: {pass_fail_count[1]}"
+ )
+
+
+@main.command("web")
+@click.option(
+ "--session_db_url",
+ help=(
+ "Optional. The database URL to store the session.\n\n - Use"
+ " 'agentengine://' to connect to Vertex"
+ " managed session service.\n\n - Use 'sqlite://'"
+ " to connect to a SQLite DB.\n\n - See"
+ " https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
+ " for more details on supported DB URLs."
+ ),
+)
+@click.option(
+ "--port",
+ type=int,
+ help="Optional. The port of the server",
+ default=8000,
+)
+@click.option(
+ "--allow_origins",
+ help="Optional. Any additional origins to allow for CORS.",
+ multiple=True,
+)
+@click.option(
+ "--log_level",
+ type=click.Choice(
+ ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
+ ),
+ default="INFO",
+ help="Optional. Set the logging level",
+)
+@click.option(
+ "--log_to_tmp",
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help=(
+ "Optional. Whether to log to system temp folder instead of console."
+ " This is useful for local debugging."
+ ),
+)
+@click.argument(
+ "agents_dir",
+ type=click.Path(
+ exists=True, dir_okay=True, file_okay=False, resolve_path=True
+ ),
+ default=os.getcwd(),
+)
+def web(
+ agents_dir: str,
+ log_to_tmp: bool,
+ session_db_url: str = "",
+ log_level: str = "INFO",
+ allow_origins: Optional[list[str]] = None,
+ port: int = 8000,
+):
+ """Start a FastAPI server with web UI for a certain agent.
+
+ AGENTS_DIR: The directory of agents, where each sub-directory is a single
+ agent, containing at least `__init__.py` and `agent.py` files.
+
+ Example:
+
+ adk web --session_db_url=[db_url] --port=[port] path/to/agents_dir
+ """
+ if log_to_tmp:
+ logs.log_to_tmp_folder()
+ else:
+ logs.log_to_stderr()
+
+ logging.getLogger().setLevel(log_level)
+
+ config = uvicorn.Config(
+ get_fast_api_app(
+ agent_dir=agents_dir,
+ session_db_url=session_db_url,
+ allow_origins=allow_origins,
+ web=True,
+ ),
+ host="0.0.0.0",
+ port=port,
+ reload=True,
+ )
+ server = uvicorn.Server(config)
+ server.run()
+
+
+@main.command("api_server")
+@click.option(
+ "--session_db_url",
+ help=(
+ "Optional. The database URL to store the session.\n\n - Use"
+ " 'agentengine://' to connect to Vertex"
+ " managed session service.\n\n - Use 'sqlite://'"
+ " to connect to a SQLite DB.\n\n - See"
+ " https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
+ " for more details on supported DB URLs."
+ ),
+)
+@click.option(
+ "--port",
+ type=int,
+ help="Optional. The port of the server",
+ default=8000,
+)
+@click.option(
+ "--allow_origins",
+ help="Optional. Any additional origins to allow for CORS.",
+ multiple=True,
+)
+@click.option(
+ "--log_level",
+ type=click.Choice(
+ ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
+ ),
+ default="INFO",
+ help="Optional. Set the logging level",
+)
+@click.option(
+ "--log_to_tmp",
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help=(
+ "Optional. Whether to log to system temp folder instead of console."
+ " This is useful for local debugging."
+ ),
+)
+# The directory of agents, where each sub-directory is a single agent.
+# By default, it is the current working directory
+@click.argument(
+ "agents_dir",
+ type=click.Path(
+ exists=True, dir_okay=True, file_okay=False, resolve_path=True
+ ),
+ default=os.getcwd(),
+)
+def cli_api_server(
+ agents_dir: str,
+ log_to_tmp: bool,
+ session_db_url: str = "",
+ log_level: str = "INFO",
+ allow_origins: Optional[list[str]] = None,
+ port: int = 8000,
+):
+ """Start an api server for a certain agent.
+
+ AGENTS_DIR: The directory of agents, where each sub-directory is a single
+ agent, containing at least `__init__.py` and `agent.py` files.
+
+ Example:
+
+ adk api_server --session_db_url=[db_url] --port=[port] path/to/agents_dir
+ """
+ if log_to_tmp:
+ logs.log_to_tmp_folder()
+ else:
+ logs.log_to_stderr()
+
+ logging.getLogger().setLevel(log_level)
+
+ config = uvicorn.Config(
+ get_fast_api_app(
+ agent_dir=agents_dir,
+ session_db_url=session_db_url,
+ allow_origins=allow_origins,
+ web=False,
+ ),
+ host="0.0.0.0",
+ port=port,
+ reload=True,
+ )
+ server = uvicorn.Server(config)
+ server.run()
+
+
+@deploy.command("cloud_run")
+@click.option(
+ "--project",
+ type=str,
+ help=(
+ "Required. Google Cloud project to deploy the agent. When absent,"
+ " default project from gcloud config is used."
+ ),
+)
+@click.option(
+ "--region",
+ type=str,
+ help=(
+ "Required. Google Cloud region to deploy the agent. When absent,"
+ " gcloud run deploy will prompt later."
+ ),
+)
+@click.option(
+ "--service_name",
+ type=str,
+ default="adk-default-service-name",
+ help=(
+ "Optional. The service name to use in Cloud Run (default:"
+ " 'adk-default-service-name')."
+ ),
+)
+@click.option(
+ "--app_name",
+ type=str,
+ default="",
+ help=(
+ "Optional. App name of the ADK API server (default: the folder name"
+ " of the AGENT source code)."
+ ),
+)
+@click.option(
+ "--port",
+ type=int,
+ default=8000,
+ help="Optional. The port of the ADK API server (default: 8000).",
+)
+@click.option(
+ "--with_cloud_trace",
+ type=bool,
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help="Optional. Whether to enable Cloud Trace for cloud run.",
+)
+@click.option(
+ "--with_ui",
+ type=bool,
+ is_flag=True,
+ show_default=True,
+ default=False,
+ help=(
+ "Optional. Deploy ADK Web UI if set. (default: deploy ADK API server"
+ " only)"
+ ),
+)
+@click.option(
+ "--temp_folder",
+ type=str,
+ default=os.path.join(
+ tempfile.gettempdir(),
+ "cloud_run_deploy_src",
+ datetime.now().strftime("%Y%m%d_%H%M%S"),
+ ),
+ help=(
+ "Optional. Temp folder for the generated Cloud Run source files"
+ " (default: a timestamped folder in the system temp directory)."
+ ),
+)
+@click.argument(
+ "agent",
+ type=click.Path(
+ exists=True, dir_okay=True, file_okay=False, resolve_path=True
+ ),
+)
+def deploy_to_cloud_run(
+ agent: str,
+ project: Optional[str],
+ region: Optional[str],
+ service_name: str,
+ app_name: str,
+ temp_folder: str,
+ port: int,
+ with_cloud_trace: bool,
+ with_ui: bool,
+):
+ """Deploys agent to Cloud Run.
+
+ AGENT: The path to the agent source code folder.
+
+ Example:
+
+ adk deploy cloud_run --project=[project] --region=[region] path/to/my_agent
+ """
+ try:
+ cli_deploy.to_cloud_run(
+ agent_folder=agent,
+ project=project,
+ region=region,
+ service_name=service_name,
+ app_name=app_name,
+ temp_folder=temp_folder,
+ port=port,
+ with_cloud_trace=with_cloud_trace,
+ with_ui=with_ui,
+ )
+ except Exception as e:
+ click.secho(f"Deploy failed: {e}", fg="red", err=True)
diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py
new file mode 100644
index 0000000..71f367b
--- /dev/null
+++ b/src/google/adk/cli/fast_api.py
@@ -0,0 +1,765 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import importlib
+import json
+import logging
+import os
+from pathlib import Path
+import re
+import sys
+import traceback
+import typing
+from typing import Any
+from typing import List
+from typing import Literal
+from typing import Optional
+
+import click
+from fastapi import FastAPI
+from fastapi import HTTPException
+from fastapi import Query
+from fastapi import Response
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from fastapi.responses import RedirectResponse
+from fastapi.responses import StreamingResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.websockets import WebSocket
+from fastapi.websockets import WebSocketDisconnect
+from google.genai import types
+import graphviz
+from opentelemetry import trace
+from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
+from opentelemetry.sdk.trace import export
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.sdk.trace import TracerProvider
+from pydantic import BaseModel
+from pydantic import ValidationError
+
+from ..agents import RunConfig
+from ..agents.live_request_queue import LiveRequest
+from ..agents.live_request_queue import LiveRequestQueue
+from ..agents.llm_agent import Agent
+from ..agents.run_config import StreamingMode
+from ..artifacts import InMemoryArtifactService
+from ..events.event import Event
+from ..runners import Runner
+from ..sessions.database_session_service import DatabaseSessionService
+from ..sessions.in_memory_session_service import InMemorySessionService
+from ..sessions.session import Session
+from ..sessions.vertex_ai_session_service import VertexAiSessionService
+from .cli_eval import EVAL_SESSION_ID_PREFIX
+from .cli_eval import EvalMetric
+from .cli_eval import EvalMetricResult
+from .cli_eval import EvalStatus
+from .utils import create_empty_state
+from .utils import envs
+from .utils import evals
+
+logger = logging.getLogger(__name__)
+
+_EVAL_SET_FILE_EXTENSION = ".evalset.json"
+
+
+class ApiServerSpanExporter(export.SpanExporter):
+
+ def __init__(self, trace_dict):
+ self.trace_dict = trace_dict
+
+ def export(
+ self, spans: typing.Sequence[ReadableSpan]
+ ) -> export.SpanExportResult:
+ for span in spans:
+ if span.name == "call_llm" or span.name == "send_data":
+ attributes = dict(span.attributes)
+ attributes["trace_id"] = span.get_span_context().trace_id
+ attributes["span_id"] = span.get_span_context().span_id
+ if attributes.get("gcp.vertex.agent.event_id", None):
+ self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes
+ return export.SpanExportResult.SUCCESS
+
+ def force_flush(self, timeout_millis: int = 30000) -> bool:
+ return True
+
+
+class AgentRunRequest(BaseModel):
+ app_name: str
+ user_id: str
+ session_id: str
+ new_message: types.Content
+ streaming: bool = False
+
+
+class AddSessionToEvalSetRequest(BaseModel):
+ eval_id: str
+ session_id: str
+ user_id: str
+
+
+class RunEvalRequest(BaseModel):
+ eval_ids: list[str] # if empty, then all evals in the eval set are run.
+ eval_metrics: list[EvalMetric]
+
+
+class RunEvalResult(BaseModel):
+ eval_set_id: str
+ eval_id: str
+ final_eval_status: EvalStatus
+ eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
+ session_id: str
+
+
+def get_fast_api_app(
+ *,
+ agent_dir: str,
+ session_db_url: str = "",
+ allow_origins: Optional[list[str]] = None,
+ web: bool,
+) -> FastAPI:
+ # InMemory tracing dict.
+ trace_dict: dict[str, Any] = {}
+
+ # Set up tracing in the FastAPI server.
+ provider = TracerProvider()
+ provider.add_span_processor(
+ export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
+ )
+ if os.environ.get("ADK_TRACE_TO_CLOUD", "0") == "1":
+ processor = export.BatchSpanProcessor(
+ CloudTraceSpanExporter(
+ project_id=os.environ.get("GOOGLE_CLOUD_PROJECT", "")
+ )
+ )
+ provider.add_span_processor(processor)
+
+ trace.set_tracer_provider(provider)
+
+ # Run the FastAPI server.
+ app = FastAPI()
+
+ if allow_origins:
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=allow_origins,
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+
+ if agent_dir not in sys.path:
+ sys.path.append(agent_dir)
+
+ runner_dict = {}
+ root_agent_dict = {}
+
+ # Build the Artifact service
+ artifact_service = InMemoryArtifactService()
+
+ # Build the Session service
+ agent_engine_id = ""
+ if session_db_url:
+ if session_db_url.startswith("agentengine://"):
+ # Create vertex session service
+ agent_engine_id = session_db_url.split("://")[1]
+ if not agent_engine_id:
+ raise click.ClickException("Agent engine id can not be empty.")
+ envs.load_dotenv_for_agent("", agent_dir)
+ session_service = VertexAiSessionService(
+ os.environ["GOOGLE_CLOUD_PROJECT"],
+ os.environ["GOOGLE_CLOUD_LOCATION"],
+ )
+ else:
+ session_service = DatabaseSessionService(db_url=session_db_url)
+ else:
+ session_service = InMemorySessionService()
+
+ @app.get("/list-apps")
+ def list_apps() -> list[str]:
+ base_path = Path.cwd() / agent_dir
+ if not base_path.exists():
+ raise HTTPException(status_code=404, detail="Path not found")
+ if not base_path.is_dir():
+ raise HTTPException(status_code=400, detail="Not a directory")
+ agent_names = [
+ x
+ for x in os.listdir(base_path)
+ if os.path.isdir(os.path.join(base_path, x))
+ and not x.startswith(".")
+ and x != "__pycache__"
+ ]
+ agent_names.sort()
+ return agent_names
+
+ @app.get("/debug/trace/{event_id}")
+ def get_trace_dict(event_id: str) -> Any:
+ event_dict = trace_dict.get(event_id, None)
+ if event_dict is None:
+ raise HTTPException(status_code=404, detail="Trace not found")
+ return event_dict
+
+ @app.get(
+ "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
+ response_model_exclude_none=True,
+ )
+ def get_session(app_name: str, user_id: str, session_id: str) -> Session:
+ # Connect to managed session if agent_engine_id is set.
+ app_name = agent_engine_id if agent_engine_id else app_name
+ session = session_service.get_session(
+ app_name=app_name, user_id=user_id, session_id=session_id
+ )
+ if not session:
+ raise HTTPException(status_code=404, detail="Session not found")
+ return session
+
+ @app.get(
+ "/apps/{app_name}/users/{user_id}/sessions",
+ response_model_exclude_none=True,
+ )
+ def list_sessions(app_name: str, user_id: str) -> list[Session]:
+ # Connect to managed session if agent_engine_id is set.
+ app_name = agent_engine_id if agent_engine_id else app_name
+ return [
+ session
+ for session in session_service.list_sessions(
+ app_name=app_name, user_id=user_id
+ ).sessions
+ # Remove sessions that were generated as a part of Eval.
+ if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
+ ]
+
+ @app.post(
+ "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
+ response_model_exclude_none=True,
+ )
+ def create_session_with_id(
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ state: Optional[dict[str, Any]] = None,
+ ) -> Session:
+ # Connect to managed session if agent_engine_id is set.
+ app_name = agent_engine_id if agent_engine_id else app_name
+ if (
+ session_service.get_session(
+ app_name=app_name, user_id=user_id, session_id=session_id
+ )
+ is not None
+ ):
+ logger.warning("Session already exists: %s", session_id)
+ raise HTTPException(
+ status_code=400, detail=f"Session already exists: {session_id}"
+ )
+
+ logger.info("New session created: %s", session_id)
+ return session_service.create_session(
+ app_name=app_name, user_id=user_id, state=state, session_id=session_id
+ )
+
+ @app.post(
+ "/apps/{app_name}/users/{user_id}/sessions",
+ response_model_exclude_none=True,
+ )
+ def create_session(
+ app_name: str,
+ user_id: str,
+ state: Optional[dict[str, Any]] = None,
+ ) -> Session:
+ # Connect to managed session if agent_engine_id is set.
+ app_name = agent_engine_id if agent_engine_id else app_name
+
+ logger.info("New session created")
+ return session_service.create_session(
+ app_name=app_name, user_id=user_id, state=state
+ )
+
+ def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
+ return os.path.join(
+ agent_dir,
+ app_name,
+ eval_set_id + _EVAL_SET_FILE_EXTENSION,
+ )
+
+ @app.post(
+ "/apps/{app_name}/eval_sets/{eval_set_id}",
+ response_model_exclude_none=True,
+ )
+ def create_eval_set(
+ app_name: str,
+ eval_set_id: str,
+ ):
+ """Creates an eval set, given the id."""
+ pattern = r"^[a-zA-Z0-9_]+$"
+ if not bool(re.fullmatch(pattern, eval_set_id)):
+ raise HTTPException(
+ status_code=400,
+ detail=(
+ f"Invalid eval set id. Eval set id should have the `{pattern}`"
+ " format"
+ ),
+ )
+ # Define the file path
+ new_eval_set_path = _get_eval_set_file_path(
+ app_name, agent_dir, eval_set_id
+ )
+
+ logger.info("Creating eval set file `%s`", new_eval_set_path)
+
+ if not os.path.exists(new_eval_set_path):
+ # Write the JSON string to the file
+ logger.info("Eval set file doesn't exist, we will create a new one.")
+ with open(new_eval_set_path, "w") as f:
+ empty_content = json.dumps([], indent=2)
+ f.write(empty_content)
+
+ @app.get(
+ "/apps/{app_name}/eval_sets",
+ response_model_exclude_none=True,
+ )
+ def list_eval_sets(app_name: str) -> list[str]:
+ """Lists all eval sets for the given app."""
+ eval_set_file_path = os.path.join(agent_dir, app_name)
+ eval_sets = []
+ for file in os.listdir(eval_set_file_path):
+ if file.endswith(_EVAL_SET_FILE_EXTENSION):
+ eval_sets.append(
+ os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
+ )
+
+ return sorted(eval_sets)
+
+ @app.post(
+ "/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
+ response_model_exclude_none=True,
+ )
+ def add_session_to_eval_set(
+ app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
+ ):
+ pattern = r"^[a-zA-Z0-9_]+$"
+ if not bool(re.fullmatch(pattern, req.eval_id)):
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
+ )
+
+ # Get the session
+ session = session_service.get_session(
+ app_name=app_name, user_id=req.user_id, session_id=req.session_id
+ )
+ assert session, "Session not found."
+ # Load the eval set file data
+ eval_set_file_path = _get_eval_set_file_path(
+ app_name, agent_dir, eval_set_id
+ )
+ with open(eval_set_file_path, "r") as file:
+ eval_set_data = json.load(file) # Load JSON into a list
+
+ if [x for x in eval_set_data if x["name"] == req.eval_id]:
+ raise HTTPException(
+ status_code=400,
+ detail=(
+ f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
+ " eval set."
+ ),
+ )
+
+ # Convert the session data to evaluation format
+ test_data = evals.convert_session_to_eval_format(session)
+
+ # Populate the session with initial session state.
+ initial_session_state = create_empty_state(_get_root_agent(app_name))
+
+ eval_set_data.append({
+ "name": req.eval_id,
+ "data": test_data,
+ "initial_session": {
+ "state": initial_session_state,
+ "app_name": app_name,
+ "user_id": req.user_id,
+ },
+ })
+ # Serialize the test data to JSON and write to the eval set file.
+ with open(eval_set_file_path, "w") as f:
+ f.write(json.dumps(eval_set_data, indent=2))
+
+ @app.get(
+ "/apps/{app_name}/eval_sets/{eval_set_id}/evals",
+ response_model_exclude_none=True,
+ )
+ def list_evals_in_eval_set(
+ app_name: str,
+ eval_set_id: str,
+ ) -> list[str]:
+ """Lists all evals in an eval set."""
+ # Load the eval set file data
+ eval_set_file_path = _get_eval_set_file_path(
+ app_name, agent_dir, eval_set_id
+ )
+ with open(eval_set_file_path, "r") as file:
+ eval_set_data = json.load(file) # Load JSON into a list
+
+ return sorted([x["name"] for x in eval_set_data])
+
+ @app.post(
+ "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
+ response_model_exclude_none=True,
+ )
+ def run_eval(
+ app_name: str, eval_set_id: str, req: RunEvalRequest
+ ) -> list[RunEvalResult]:
+ from .cli_eval import run_evals
+
+ """Runs an eval given the details in the eval request."""
+ # Create a mapping from eval set file to all the evals that needed to be
+ # run.
+ eval_set_file_path = _get_eval_set_file_path(
+ app_name, agent_dir, eval_set_id
+ )
+ eval_set_to_evals = {eval_set_file_path: req.eval_ids}
+
+ if not req.eval_ids:
+ logger.info(
+ "Eval ids to run list is empty. We will all evals in the eval set."
+ )
+ root_agent = _get_root_agent(app_name)
+ eval_results = list(
+ run_evals(
+ eval_set_to_evals,
+ root_agent,
+ getattr(root_agent, "reset_data", None),
+ req.eval_metrics,
+ session_service=session_service,
+ artifact_service=artifact_service,
+ )
+ )
+
+ run_eval_results = []
+ for eval_result in eval_results:
+ run_eval_results.append(
+ RunEvalResult(
+ app_name=app_name,
+ eval_set_id=eval_set_id,
+ eval_id=eval_result.eval_id,
+ final_eval_status=eval_result.final_eval_status,
+ eval_metric_results=eval_result.eval_metric_results,
+ session_id=eval_result.session_id,
+ )
+ )
+ return run_eval_results
+
+ @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
+ def delete_session(app_name: str, user_id: str, session_id: str):
+ # Connect to managed session if agent_engine_id is set.
+ app_name = agent_engine_id if agent_engine_id else app_name
+ session_service.delete_session(
+ app_name=app_name, user_id=user_id, session_id=session_id
+ )
+
+ @app.get(
+ "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
+ response_model_exclude_none=True,
+ )
+ def load_artifact(
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ artifact_name: str,
+ version: Optional[int] = Query(None),
+ ) -> Optional[types.Part]:
+ artifact = artifact_service.load_artifact(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=artifact_name,
+ version=version,
+ )
+ if not artifact:
+ raise HTTPException(status_code=404, detail="Artifact not found")
+ return artifact
+
+ @app.get(
+ "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
+ response_model_exclude_none=True,
+ )
+ def load_artifact_version(
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ artifact_name: str,
+ version_id: int,
+ ) -> Optional[types.Part]:
+ artifact = artifact_service.load_artifact(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=artifact_name,
+ version=version_id,
+ )
+ if not artifact:
+ raise HTTPException(status_code=404, detail="Artifact not found")
+ return artifact
+
+ @app.get(
+ "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
+ response_model_exclude_none=True,
+ )
+ def list_artifact_names(
+ app_name: str, user_id: str, session_id: str
+ ) -> list[str]:
+ return artifact_service.list_artifact_keys(
+ app_name=app_name, user_id=user_id, session_id=session_id
+ )
+
+ @app.get(
+ "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
+ response_model_exclude_none=True,
+ )
+ def list_artifact_versions(
+ app_name: str, user_id: str, session_id: str, artifact_name: str
+ ) -> list[int]:
+ return artifact_service.list_versions(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=artifact_name,
+ )
+
+ @app.delete(
+ "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
+ )
+ def delete_artifact(
+ app_name: str, user_id: str, session_id: str, artifact_name: str
+ ):
+ artifact_service.delete_artifact(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=artifact_name,
+ )
+
+ @app.post("/run", response_model_exclude_none=True)
+ async def agent_run(req: AgentRunRequest) -> list[Event]:
+ # Connect to managed session if agent_engine_id is set.
+ app_id = agent_engine_id if agent_engine_id else req.app_name
+ session = session_service.get_session(
+ app_name=app_id, user_id=req.user_id, session_id=req.session_id
+ )
+ if not session:
+ raise HTTPException(status_code=404, detail="Session not found")
+ runner = _get_runner(req.app_name)
+ events = [
+ event
+ async for event in runner.run_async(
+ user_id=req.user_id,
+ session_id=req.session_id,
+ new_message=req.new_message,
+ )
+ ]
+ logger.info("Generated %s events in agent run: %s", len(events), events)
+ return events
+
+ @app.post("/run_sse")
+ async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
+ # Connect to managed session if agent_engine_id is set.
+ app_id = agent_engine_id if agent_engine_id else req.app_name
+ # SSE endpoint
+ session = session_service.get_session(
+ app_name=app_id, user_id=req.user_id, session_id=req.session_id
+ )
+ if not session:
+ raise HTTPException(status_code=404, detail="Session not found")
+
+ # Convert the events to properly formatted SSE
+ async def event_generator():
+ try:
+ stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
+ runner = _get_runner(req.app_name)
+ async for event in runner.run_async(
+ user_id=req.user_id,
+ session_id=req.session_id,
+ new_message=req.new_message,
+ run_config=RunConfig(streaming_mode=stream_mode),
+ ):
+ # Format as SSE data
+ sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
+ logger.info("Generated event in agent run streaming: %s", sse_event)
+ yield f"data: {sse_event}\n\n"
+ except Exception as e:
+ logger.exception("Error in event_generator: %s", e)
+ # You might want to yield an error event here
+ yield f'data: {{"error": "{str(e)}"}}\n\n'
+
+ # Returns a streaming response with the proper media type for SSE
+ return StreamingResponse(
+ event_generator(),
+ media_type="text/event-stream",
+ )
+
+ @app.get(
+ "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
+ response_model_exclude_none=True,
+ )
+ def get_event_graph(
+ app_name: str, user_id: str, session_id: str, event_id: str
+ ):
+ # Connect to managed session if agent_engine_id is set.
+ app_id = agent_engine_id if agent_engine_id else app_name
+ session = session_service.get_session(
+ app_name=app_id, user_id=user_id, session_id=session_id
+ )
+ session_events = session.events if session else []
+ event = next((x for x in session_events if x.id == event_id), None)
+ if not event:
+ return {}
+
+ from . import agent_graph
+
+ function_calls = event.get_function_calls()
+ function_responses = event.get_function_responses()
+ root_agent = _get_root_agent(app_name)
+ dot_graph = None
+ if function_calls:
+ function_call_highlights = []
+ for function_call in function_calls:
+ from_name = event.author
+ to_name = function_call.name
+ function_call_highlights.append((from_name, to_name))
+ dot_graph = agent_graph.get_agent_graph(
+ root_agent, function_call_highlights
+ )
+ elif function_responses:
+ function_responses_highlights = []
+ for function_response in function_responses:
+ from_name = function_response.name
+ to_name = event.author
+ function_responses_highlights.append((from_name, to_name))
+ dot_graph = agent_graph.get_agent_graph(
+ root_agent, function_responses_highlights
+ )
+ else:
+ from_name = event.author
+ to_name = ""
+ dot_graph = agent_graph.get_agent_graph(
+ root_agent, [(from_name, to_name)]
+ )
+ if dot_graph and isinstance(dot_graph, graphviz.Digraph):
+ return {"dot_src": dot_graph.source}
+ else:
+ return {}
+
+ @app.websocket("/run_live")
+ async def agent_live_run(
+ websocket: WebSocket,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ modalities: List[Literal["TEXT", "AUDIO"]] = Query(
+ default=["TEXT", "AUDIO"]
+ ), # Only allows "TEXT" or "AUDIO"
+ ) -> None:
+ await websocket.accept()
+
+ # Connect to managed session if agent_engine_id is set.
+ app_id = agent_engine_id if agent_engine_id else app_name
+ session = session_service.get_session(
+ app_name=app_id, user_id=user_id, session_id=session_id
+ )
+ if not session:
+ # Accept first so that the client is aware of connection establishment,
+ # then close with a specific code.
+ await websocket.close(code=1002, reason="Session not found")
+ return
+
+ live_request_queue = LiveRequestQueue()
+
+ async def forward_events():
+ runner = _get_runner(app_name)
+ async for event in runner.run_live(
+ session=session, live_request_queue=live_request_queue
+ ):
+ await websocket.send_text(
+ event.model_dump_json(exclude_none=True, by_alias=True)
+ )
+
+ async def process_messages():
+ try:
+ while True:
+ data = await websocket.receive_text()
+ # Validate and send the received message to the live queue.
+ live_request_queue.send(LiveRequest.model_validate_json(data))
+ except ValidationError as ve:
+ logger.error("Validation error in process_messages: %s", ve)
+
+ # Run both tasks concurrently and cancel all if one fails.
+ tasks = [
+ asyncio.create_task(forward_events()),
+ asyncio.create_task(process_messages()),
+ ]
+ done, pending = await asyncio.wait(
+ tasks, return_when=asyncio.FIRST_EXCEPTION
+ )
+ try:
+ # This will re-raise any exception from the completed tasks.
+ for task in done:
+ task.result()
+ except WebSocketDisconnect:
+ logger.info("Client disconnected during process_messages.")
+ except Exception as e:
+ logger.exception("Error during live websocket communication: %s", e)
+ traceback.print_exc()
+ finally:
+ for task in pending:
+ task.cancel()
+
+ def _get_root_agent(app_name: str) -> Agent:
+ """Returns the root agent for the given app."""
+ if app_name in root_agent_dict:
+ return root_agent_dict[app_name]
+ envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
+ agent_module = importlib.import_module(app_name)
+ root_agent: Agent = agent_module.agent.root_agent
+ root_agent_dict[app_name] = root_agent
+ return root_agent
+
+ def _get_runner(app_name: str) -> Runner:
+ """Returns the runner for the given app."""
+ if app_name in runner_dict:
+ return runner_dict[app_name]
+ root_agent = _get_root_agent(app_name)
+ runner = Runner(
+ app_name=agent_engine_id if agent_engine_id else app_name,
+ agent=root_agent,
+ artifact_service=artifact_service,
+ session_service=session_service,
+ )
+ runner_dict[app_name] = runner
+ return runner
+
+ if web:
+ BASE_DIR = Path(__file__).parent.resolve()
+ ANGULAR_DIST_PATH = BASE_DIR / "browser"
+
+ @app.get("/")
+ async def redirect_to_dev_ui():
+ return RedirectResponse("/dev-ui")
+
+ @app.get("/dev-ui")
+ async def dev_ui():
+ return FileResponse(BASE_DIR / "browser/index.html")
+
+ app.mount(
+ "/", StaticFiles(directory=ANGULAR_DIST_PATH, html=True), name="static"
+ )
+ return app
diff --git a/src/google/adk/cli/utils/__init__.py b/src/google/adk/cli/utils/__init__.py
new file mode 100644
index 0000000..846c156
--- /dev/null
+++ b/src/google/adk/cli/utils/__init__.py
@@ -0,0 +1,49 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Any
+from typing import Optional
+
+from ...agents.base_agent import BaseAgent
+from ...agents.llm_agent import LlmAgent
+
+__all__ = [
+ 'create_empty_state',
+]
+
+
+def _create_empty_state(agent: BaseAgent, all_state: dict[str, Any]):
+ for sub_agent in agent.sub_agents:
+ _create_empty_state(sub_agent, all_state)
+
+ if (
+ isinstance(agent, LlmAgent)
+ and agent.instruction
+ and isinstance(agent.instruction, str)
+ ):
+ for key in re.findall(r'{([\w]+)}', agent.instruction):
+ all_state[key] = ''
+
+
+def create_empty_state(
+ agent: BaseAgent, initialized_states: Optional[dict[str, Any]] = None
+) -> dict[str, Any]:
+ """Creates empty str for non-initialized states."""
+ non_initialized_states = {}
+ _create_empty_state(agent, non_initialized_states)
+ for key in initialized_states or {}:
+ if key in non_initialized_states:
+ del non_initialized_states[key]
+ return non_initialized_states
diff --git a/src/google/adk/cli/utils/envs.py b/src/google/adk/cli/utils/envs.py
new file mode 100644
index 0000000..2a6411c
--- /dev/null
+++ b/src/google/adk/cli/utils/envs.py
@@ -0,0 +1,57 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+
+from dotenv import load_dotenv
+
+logger = logging.getLogger(__file__)
+
+
+def _walk_to_root_until_found(folder, filename) -> str:
+ checkpath = os.path.join(folder, filename)
+ if os.path.exists(checkpath) and os.path.isfile(checkpath):
+ return checkpath
+
+ parent_folder = os.path.dirname(folder)
+ if parent_folder == folder: # reached the root
+ return ''
+
+ return _walk_to_root_until_found(parent_folder, filename)
+
+
+def load_dotenv_for_agent(
+ agent_name: str, agent_parent_folder: str, filename: str = '.env'
+):
+ """Lods the .env file for the agent module."""
+
+ # Gets the folder of agent_module as starting_folder
+ starting_folder = os.path.abspath(
+ os.path.join(agent_parent_folder, agent_name)
+ )
+ dotenv_file_path = _walk_to_root_until_found(starting_folder, filename)
+ if dotenv_file_path:
+ load_dotenv(dotenv_file_path, override=True, verbose=True)
+ logger.info(
+ 'Loaded %s file for %s at %s',
+ filename,
+ agent_name,
+ dotenv_file_path,
+ )
+ logger.info(
+ 'Reloaded %s file for %s at %s', filename, agent_name, dotenv_file_path
+ )
+ else:
+ logger.info('No %s file found for %s', filename, agent_name)
diff --git a/src/google/adk/cli/utils/evals.py b/src/google/adk/cli/utils/evals.py
new file mode 100644
index 0000000..f62a3d9
--- /dev/null
+++ b/src/google/adk/cli/utils/evals.py
@@ -0,0 +1,93 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from ...sessions.session import Session
+
+
+def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
+ """Converts a session data into eval format.
+
+ Args:
+ session: The session that should be converted.
+
+ Returns:
+ list: A single evaluation dataset in the required format.
+ """
+ eval_case = []
+ events = session.events if session and session.events else []
+
+ for event in events:
+ if event.author == 'user':
+ if not event.content or not event.content.parts:
+ continue
+
+ # Extract user query
+ content = event.content
+ parts = content.parts
+
+ query = parts[0].text or ''
+
+ # Find the corresponding tool usage or response for the query
+ expected_tool_use = []
+ intermediate_agent_responses = []
+
+ # Check subsequent events to extract tool uses or responses for this turn.
+ for subsequent_event in events[events.index(event) + 1 :]:
+ event_author = subsequent_event.author or 'agent'
+ if event_author == 'user':
+ # We found an event where the author was the user. This means that a
+ # new turn has started. So close this turn here.
+ break
+
+ if not subsequent_event.content or not subsequent_event.content.parts:
+ continue
+
+ for subsequent_part in subsequent_event.content.parts:
+ # Some events have both function call and reference
+
+ if subsequent_part.function_call:
+ tool_name = subsequent_part.function_call.name or ''
+ tool_input = subsequent_part.function_call.args or {}
+ expected_tool_use.append({
+ 'tool_name': tool_name,
+ 'tool_input': tool_input,
+ })
+ elif subsequent_part.text:
+ # Also keep track of all the natural langauge responses that
+ # agent (or sub agents) generated.
+ intermediate_agent_responses.append(
+ {'author': event_author, 'text': subsequent_part.text}
+ )
+
+ # If we are here then either we are done reading all the events or we
+ # encountered an event that had content authored by the end-user.
+ # This, basically means an end of turn.
+ # We assume that the last natural langauge intermediate response is the
+ # final response from the agent/model. We treat that as a reference.
+ eval_case.append({
+ 'query': query,
+ 'expected_tool_use': expected_tool_use,
+ 'expected_intermediate_agent_responses': intermediate_agent_responses[
+ :-1
+ ],
+ 'reference': (
+ intermediate_agent_responses[-1]['text']
+ if intermediate_agent_responses
+ else ''
+ ),
+ })
+
+ return eval_case
diff --git a/src/google/adk/cli/utils/logs.py b/src/google/adk/cli/utils/logs.py
new file mode 100644
index 0000000..9723df0
--- /dev/null
+++ b/src/google/adk/cli/utils/logs.py
@@ -0,0 +1,72 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import tempfile
+import time
+
+LOGGING_FORMAT = (
+ '%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
+)
+
+
+def log_to_stderr(level=logging.INFO):
+ logging.basicConfig(
+ level=level,
+ format=LOGGING_FORMAT,
+ )
+
+
+def log_to_tmp_folder(
+ level=logging.INFO,
+ *,
+ sub_folder: str = 'agents_log',
+ log_file_prefix: str = 'agent',
+ log_file_timestamp: str = time.strftime('%Y%m%d_%H%M%S'),
+):
+ """Logs to system temp folder, instead of logging to stderr.
+
+ Args
+ sub_folder: str = 'agents_log',
+ log_file_prefix: str = 'agent',
+ log_file_timestamp: str = time.strftime('%Y%m%d_%H%M%S'),
+
+ Returns
+ the log file path.
+ """
+ log_dir = os.path.join(tempfile.gettempdir(), sub_folder)
+ log_filename = f'{log_file_prefix}.{log_file_timestamp}.log'
+ log_filepath = os.path.join(log_dir, log_filename)
+
+ os.makedirs(log_dir, exist_ok=True)
+
+ file_handler = logging.FileHandler(log_filepath, mode='w')
+ file_handler.setLevel(level)
+ file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT))
+
+ root_logger = logging.getLogger()
+ root_logger.setLevel(level)
+ root_logger.handlers = [] # Clear handles to disable logging to stderr
+ root_logger.addHandler(file_handler)
+
+ print(f'Log setup complete: {log_filepath}')
+
+ latest_log_link = os.path.join(log_dir, f'{log_file_prefix}.latest.log')
+ if os.path.islink(latest_log_link):
+ os.unlink(latest_log_link)
+ os.symlink(log_filepath, latest_log_link)
+
+ print(f'To access latest log: tail -F {latest_log_link}')
+ return log_filepath
diff --git a/src/google/adk/code_executors/__init__.py b/src/google/adk/code_executors/__init__.py
new file mode 100644
index 0000000..08fd663
--- /dev/null
+++ b/src/google/adk/code_executors/__init__.py
@@ -0,0 +1,49 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from .base_code_executor import BaseCodeExecutor
+from .code_executor_context import CodeExecutorContext
+from .unsafe_local_code_executor import UnsafeLocalCodeExecutor
+
+logger = logging.getLogger(__name__)
+
+__all__ = [
+ 'BaseCodeExecutor',
+ 'CodeExecutorContext',
+ 'UnsafeLocalCodeExecutor',
+]
+
+try:
+ from .vertex_ai_code_executor import VertexAiCodeExecutor
+
+ __all__.append('VertexAiCodeExecutor')
+except ImportError:
+ logger.debug(
+ 'The Vertex sdk is not installed. If you want to use the Vertex Code'
+ ' Interpreter with agents, please install it. If not, you can ignore this'
+ ' warning.'
+ )
+
+try:
+ from .container_code_executor import ContainerCodeExecutor
+
+ __all__.append('ContainerCodeExecutor')
+except ImportError:
+ logger.debug(
+ 'The docker sdk is not installed. If you want to use the Container Code'
+ ' Executor with agents, please install it. If not, you can ignore this'
+ ' warning.'
+ )
diff --git a/src/google/adk/code_executors/base_code_executor.py b/src/google/adk/code_executors/base_code_executor.py
new file mode 100644
index 0000000..b1c243b
--- /dev/null
+++ b/src/google/adk/code_executors/base_code_executor.py
@@ -0,0 +1,97 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from typing import List
+
+from pydantic import BaseModel
+
+from ..agents.invocation_context import InvocationContext
+from .code_execution_utils import CodeExecutionInput
+from .code_execution_utils import CodeExecutionResult
+
+
+class BaseCodeExecutor(BaseModel):
+ """Abstract base class for all code executors.
+
+ The code executor allows the agent to execute code blocks from model responses
+ and incorporate the execution results into the final response.
+
+ Attributes:
+ optimize_data_file: If true, extract and process data files from the model
+ request and attach them to the code executor. Supported data file
+ MimeTypes are [text/csv]. Default to False.
+ stateful: Whether the code executor is stateful. Default to False.
+ error_retry_attempts: The number of attempts to retry on consecutive code
+ execution errors. Default to 2.
+ code_block_delimiters: The list of the enclosing delimiters to identify the
+ code blocks.
+ execution_result_delimiters: The delimiters to format the code execution
+ result.
+ """
+
+ optimize_data_file: bool = False
+ """
+ If true, extract and process data files from the model request
+ and attach them to the code executor.
+ Supported data file MimeTypes are [text/csv].
+
+ Default to False.
+ """
+
+ stateful: bool = False
+ """
+ Whether the code executor is stateful. Default to False.
+ """
+
+ error_retry_attempts: int = 2
+ """
+ The number of attempts to retry on consecutive code execution errors. Default to 2.
+ """
+
+ code_block_delimiters: List[tuple[str, str]] = [
+ ('```tool_code\n', '\n```'),
+ ('```python\n', '\n```'),
+ ]
+ """
+ The list of the enclosing delimiters to identify the code blocks.
+ For example, the delimiter ('```python\n', '\n```') can be
+ used to identify code blocks with the following format:
+
+ ```python
+ print("hello")
+ ```
+ """
+
+ execution_result_delimiters: tuple[str, str] = ('```tool_output\n', '\n```')
+ """
+ The delimiters to format the code execution result.
+ """
+
+ @abc.abstractmethod
+ def execute_code(
+ self,
+ invocation_context: InvocationContext,
+ code_execution_input: CodeExecutionInput,
+ ) -> CodeExecutionResult:
+ """Executes code and return the code execution result.
+
+ Args:
+ invocation_context: The invocation context of the code execution.
+ code_execution_input: The code execution input.
+
+ Returns:
+ The code execution result.
+ """
+ pass
diff --git a/src/google/adk/code_executors/code_execution_utils.py b/src/google/adk/code_executors/code_execution_utils.py
new file mode 100644
index 0000000..9ebab46
--- /dev/null
+++ b/src/google/adk/code_executors/code_execution_utils.py
@@ -0,0 +1,245 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions for code execution."""
+
+import base64
+import binascii
+import copy
+import dataclasses
+import re
+from typing import List, Optional
+
+from google.genai import types
+
+
+@dataclasses.dataclass(frozen=True)
+class File:
+ """A structure that contains a file name and its content."""
+
+ name: str
+ """
+ The name of the file with file extension (e.g., "file.csv").
+ """
+
+ content: str
+ """
+ The base64-encoded bytes of the file content.
+ """
+
+ mime_type: str = 'text/plain'
+ """
+ The mime type of the file (e.g., "image/png").
+ """
+
+
+@dataclasses.dataclass
+class CodeExecutionInput:
+ """A structure that contains the input of code execution."""
+
+ code: str
+ """
+ The code to execute.
+ """
+
+ input_files: list[File] = dataclasses.field(default_factory=list)
+ """
+ The input files available to the code.
+ """
+
+ execution_id: Optional[str] = None
+ """
+ The execution ID for the stateful code execution.
+ """
+
+
+@dataclasses.dataclass
+class CodeExecutionResult:
+ """A structure that contains the result of code execution."""
+
+ stdout: str = ''
+ """
+ The standard output of the code execution.
+ """
+
+ stderr: str = ''
+ """
+ The standard error of the code execution.
+ """
+
+ output_files: list[File] = dataclasses.field(default_factory=list)
+ """
+ The output files from the code execution.
+ """
+
+
+class CodeExecutionUtils:
+ """Utility functions for code execution."""
+
+ @staticmethod
+ def get_encoded_file_content(data: bytes) -> bytes:
+ """Gets the file content as a base64-encoded bytes.
+
+ Args:
+ data: The file content bytes.
+
+ Returns:
+ The file content as a base64-encoded bytes.
+ """
+
+ def _is_base64_encoded(data: bytes) -> bool:
+ try:
+ return base64.b64encode(base64.b64decode(data)) == data
+ except binascii.Error:
+ return False
+
+ return data if _is_base64_encoded(data) else base64.b64encode(data)
+
+ @staticmethod
+ def extract_code_and_truncate_content(
+ content: types.Content,
+ code_block_delimiters: List[tuple[str, str]],
+ ) -> Optional[str]:
+ """Extracts the first code block from the content and truncate everything after it.
+
+ Args:
+ content: The mutable content to extract the code from.
+ code_block_delimiters: The list of the enclosing delimiters to identify
+ the code blocks.
+
+ Returns:
+ The first code block if found, otherwise None.
+ """
+ if not content or not content.parts:
+ return
+
+ text_parts = [p for p in content.parts if p.text]
+ if not text_parts:
+ return
+
+ first_text_part = copy.deepcopy(text_parts[0])
+ response_text = '\n'.join([p.text for p in text_parts])
+
+ # Find the first code block.
+ leading_delimiter_pattern = '|'.join(d[0] for d in code_block_delimiters)
+ trailing_delimiter_pattern = '|'.join(d[1] for d in code_block_delimiters)
+ pattern = re.compile(
+ (
+ rf'(?P.*?)({leading_delimiter_pattern})(?P.*?)({trailing_delimiter_pattern})(?P.*?)$'
+ ).encode(),
+ re.DOTALL,
+ )
+ pattern_match = pattern.search(response_text.encode())
+ if pattern_match is None:
+ return
+
+ code_str = pattern_match.group('code').decode()
+ if not code_str:
+ return
+
+ content.parts = []
+ if pattern_match.group('prefix'):
+ first_text_part.text = pattern_match.group('prefix').decode()
+ content.parts.append(first_text_part)
+ content.parts.append(
+ CodeExecutionUtils.build_executable_code_part(code_str)
+ )
+ return pattern_match.group('code').decode()
+
+ @staticmethod
+ def build_executable_code_part(code: str) -> types.Part:
+ """Builds an executable code part with code string.
+
+ Args:
+ code: The code string.
+
+ Returns:
+ The constructed executable code part.
+ """
+ return types.Part.from_executable_code(
+ code=code,
+ language='PYTHON',
+ )
+
+ @staticmethod
+ def build_code_execution_result_part(
+ code_execution_result: CodeExecutionResult,
+ ) -> types.Part:
+ """Builds the code execution result part from the code execution result.
+
+ Args:
+ code_execution_result: The code execution result.
+
+ Returns:
+ The constructed code execution result part.
+ """
+ if code_execution_result.stderr:
+ return types.Part.from_code_execution_result(
+ outcome='OUTCOME_FAILED',
+ output=code_execution_result.stderr,
+ )
+ final_result = []
+ if code_execution_result.stdout or not code_execution_result.output_files:
+ final_result.append(
+ 'Code execution result:\n' + '%s\n' % code_execution_result.stdout
+ )
+ if code_execution_result.output_files:
+ final_result.append(
+ 'Saved artifacts:\n'
+ + ','.join(
+ ['`%s`' % f.name for f in code_execution_result.output_files]
+ )
+ )
+ return types.Part.from_code_execution_result(
+ outcome='OUTCOME_OK',
+ output='\n\n'.join(final_result),
+ )
+
+ @staticmethod
+ def convert_code_execution_parts(
+ content: types.Content,
+ code_block_delimiter: tuple[str, str],
+ execution_result_delimiters: tuple[str, str],
+ ):
+ """Converts the code execution parts to text parts in a Content.
+
+ Args:
+ content: The mutable content to convert the code execution parts to text
+ parts.
+ code_block_delimiter: The delimiter to format the code block.
+ execution_result_delimiters: The delimiter to format the code execution
+ result.
+ """
+ if not content.parts:
+ return
+
+ # Handle the conversion of trailing executable code parts.
+ if content.parts[-1].executable_code:
+ content.parts[-1] = types.Part(
+ text=(
+ code_block_delimiter[0]
+ + content.parts[-1].executable_code.code
+ + code_block_delimiter[1]
+ )
+ )
+ # Handle the conversion of trailing code execution result parts.
+ # Skip if the Content has multiple parts, which means the Content is
+ # likely generated by the model.
+ elif len(content.parts) == 1 and content.parts[-1].code_execution_result:
+ content.parts[-1] = types.Part(
+ text=execution_result_delimiters[0]
+ + content.parts[-1].code_execution_result.output
+ + execution_result_delimiters[1]
+ )
+ content.role = 'user'
diff --git a/src/google/adk/code_executors/code_executor_context.py b/src/google/adk/code_executors/code_executor_context.py
new file mode 100644
index 0000000..f649c48
--- /dev/null
+++ b/src/google/adk/code_executors/code_executor_context.py
@@ -0,0 +1,202 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The persistent context used to configure the code executor."""
+
+import copy
+import dataclasses
+import datetime
+from typing import Any
+from typing import Optional
+
+from ..sessions.state import State
+from .code_execution_utils import File
+
+_CONTEXT_KEY = '_code_execution_context'
+_SESSION_ID_KEY = 'execution_session_id'
+_PROCESSED_FILE_NAMES_KEY = 'processed_input_files'
+_INPUT_FILE_KEY = '_code_executor_input_files'
+_ERROR_COUNT_KEY = '_code_executor_error_counts'
+
+_CODE_EXECUTION_RESULTS_KEY = '_code_execution_results'
+
+
+class CodeExecutorContext:
+ """The persistent context used to configure the code executor."""
+
+ _context: dict[str, Any]
+
+ def __init__(self, session_state: State):
+ """Initializes the code executor context.
+
+ Args:
+ session_state: The session state to get the code executor context from.
+ """
+ self._context = self._get_code_executor_context(session_state)
+ self._session_state = session_state
+
+ def get_state_delta(self) -> dict[str, Any]:
+ """Gets the state delta to update in the persistent session state.
+
+ Returns:
+ The state delta to update in the persistent session state.
+ """
+ context_to_update = copy.deepcopy(self._context)
+ return {_CONTEXT_KEY: context_to_update}
+
+ def get_execution_id(self) -> Optional[str]:
+ """Gets the session ID for the code executor.
+
+ Returns:
+ The session ID for the code executor context.
+ """
+ if _SESSION_ID_KEY not in self._context:
+ return None
+ return self._context[_SESSION_ID_KEY]
+
+ def set_execution_id(self, session_id: str):
+ """Sets the session ID for the code executor.
+
+ Args:
+ session_id: The session ID for the code executor.
+ """
+ self._context[_SESSION_ID_KEY] = session_id
+
+ def get_processed_file_names(self) -> list[str]:
+ """Gets the processed file names from the session state.
+
+ Returns:
+ A list of processed file names in the code executor context.
+ """
+ if _PROCESSED_FILE_NAMES_KEY not in self._context:
+ return []
+ return self._context[_PROCESSED_FILE_NAMES_KEY]
+
+ def add_processed_file_names(self, file_names: [str]):
+ """Adds the processed file name to the session state.
+
+ Args:
+ file_names: The processed file names to add to the session state.
+ """
+ if _PROCESSED_FILE_NAMES_KEY not in self._context:
+ self._context[_PROCESSED_FILE_NAMES_KEY] = []
+ self._context[_PROCESSED_FILE_NAMES_KEY].extend(file_names)
+
+ def get_input_files(self) -> list[File]:
+ """Gets the code executor input file names from the session state.
+
+ Returns:
+ A list of input files in the code executor context.
+ """
+ if _INPUT_FILE_KEY not in self._session_state:
+ return []
+ return [File(**file) for file in self._session_state[_INPUT_FILE_KEY]]
+
+ def add_input_files(
+ self,
+ input_files: list[File],
+ ):
+ """Adds the input files to the code executor context.
+
+ Args:
+ input_files: The input files to add to the code executor context.
+ """
+ if _INPUT_FILE_KEY not in self._session_state:
+ self._session_state[_INPUT_FILE_KEY] = []
+ for input_file in input_files:
+ self._session_state[_INPUT_FILE_KEY].append(
+ dataclasses.asdict(input_file)
+ )
+
+ def clear_input_files(self):
+ """Removes the input files and processed file names to the code executor context."""
+ if _INPUT_FILE_KEY in self._session_state:
+ self._session_state[_INPUT_FILE_KEY] = []
+ if _PROCESSED_FILE_NAMES_KEY in self._context:
+ self._context[_PROCESSED_FILE_NAMES_KEY] = []
+
+ def get_error_count(self, invocation_id: str) -> int:
+ """Gets the error count from the session state.
+
+ Args:
+ invocation_id: The invocation ID to get the error count for.
+
+ Returns:
+ The error count for the given invocation ID.
+ """
+ if _ERROR_COUNT_KEY not in self._session_state:
+ return 0
+ return self._session_state[_ERROR_COUNT_KEY].get(invocation_id, 0)
+
+ def increment_error_count(self, invocation_id: str):
+ """Increments the error count from the session state.
+
+ Args:
+ invocation_id: The invocation ID to increment the error count for.
+ """
+ if _ERROR_COUNT_KEY not in self._session_state:
+ self._session_state[_ERROR_COUNT_KEY] = {}
+ self._session_state[_ERROR_COUNT_KEY][invocation_id] = (
+ self.get_error_count(invocation_id) + 1
+ )
+
+ def reset_error_count(self, invocation_id: str):
+ """Resets the error count from the session state.
+
+ Args:
+ invocation_id: The invocation ID to reset the error count for.
+ """
+ if _ERROR_COUNT_KEY not in self._session_state:
+ return
+ if invocation_id in self._session_state[_ERROR_COUNT_KEY]:
+ del self._session_state[_ERROR_COUNT_KEY][invocation_id]
+
+ def update_code_execution_result(
+ self,
+ invocation_id: str,
+ code: str,
+ result_stdout: str,
+ result_stderr: str,
+ ):
+ """Updates the code execution result.
+
+ Args:
+ invocation_id: The invocation ID to update the code execution result for.
+ code: The code to execute.
+ result_stdout: The standard output of the code execution.
+ result_stderr: The standard error of the code execution.
+ """
+ if _CODE_EXECUTION_RESULTS_KEY not in self._session_state:
+ self._session_state[_CODE_EXECUTION_RESULTS_KEY] = {}
+ if invocation_id not in self._session_state[_CODE_EXECUTION_RESULTS_KEY]:
+ self._session_state[_CODE_EXECUTION_RESULTS_KEY][invocation_id] = []
+ self._session_state[_CODE_EXECUTION_RESULTS_KEY][invocation_id].append({
+ 'code': code,
+ 'result_stdout': result_stdout,
+ 'result_stderr': result_stderr,
+ 'timestamp': int(datetime.datetime.now().timestamp()),
+ })
+
+ def _get_code_executor_context(self, session_state: State) -> dict[str, Any]:
+ """Gets the code executor context from the session state.
+
+ Args:
+ session_state: The session state to get the code executor context from.
+
+ Returns:
+ A dict of code executor context.
+ """
+ if _CONTEXT_KEY not in session_state:
+ session_state[_CONTEXT_KEY] = {}
+ return session_state[_CONTEXT_KEY]
diff --git a/src/google/adk/code_executors/container_code_executor.py b/src/google/adk/code_executors/container_code_executor.py
new file mode 100644
index 0000000..0ce2ec3
--- /dev/null
+++ b/src/google/adk/code_executors/container_code_executor.py
@@ -0,0 +1,196 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import atexit
+import os
+from typing import Optional
+
+import docker
+from docker.client import DockerClient
+from docker.models.containers import Container
+from pydantic import Field
+from typing_extensions import override
+
+from ..agents.invocation_context import InvocationContext
+from .base_code_executor import BaseCodeExecutor
+from .code_execution_utils import CodeExecutionInput
+from .code_execution_utils import CodeExecutionResult
+
+
+DEFAULT_IMAGE_TAG = 'adk-code-executor:latest'
+
+
+class ContainerCodeExecutor(BaseCodeExecutor):
+ """A code executor that uses a custom container to execute code.
+
+ Attributes:
+ base_url: Optional. The base url of the user hosted Docker client.
+ image: The tag of the predefined image or custom image to run on the
+ container. Either docker_path or image must be set.
+ docker_path: The path to the directory containing the Dockerfile. If set,
+ build the image from the dockerfile path instead of using the predefined
+ image. Either docker_path or image must be set.
+ """
+
+ base_url: Optional[str] = None
+ """
+ Optional. The base url of the user hosted Docker client.
+ """
+
+ image: str = None
+ """
+ The tag of the predefined image or custom image to run on the container.
+ Either docker_path or image must be set.
+ """
+
+ docker_path: str = None
+ """
+ The path to the directory containing the Dockerfile.
+ If set, build the image from the dockerfile path instead of using the
+ predefined image. Either docker_path or image must be set.
+ """
+
+ # Overrides the BaseCodeExecutor attribute: this executor cannot be stateful.
+ stateful: bool = Field(default=False, frozen=True, exclude=True)
+
+ # Overrides the BaseCodeExecutor attribute: this executor cannot
+ # optimize_data_file.
+ optimize_data_file: bool = Field(default=False, frozen=True, exclude=True)
+
+ _client: DockerClient = None
+ _container: Container = None
+
+ def __init__(
+ self,
+ base_url: Optional[str] = None,
+ image: Optional[str] = None,
+ docker_path: Optional[str] = None,
+ **data,
+ ):
+ """Initializes the ContainerCodeExecutor.
+
+ Args:
+ base_url: Optional. The base url of the user hosted Docker client.
+ image: The tag of the predefined image or custom image to run on the
+ container. Either docker_path or image must be set.
+ docker_path: The path to the directory containing the Dockerfile. If set,
+ build the image from the dockerfile path instead of using the predefined
+ image. Either docker_path or image must be set.
+ **data: The data to initialize the ContainerCodeExecutor.
+ """
+ if not image and not docker_path:
+ raise ValueError(
+ 'Either image or docker_path must be set for ContainerCodeExecutor.'
+ )
+ if 'stateful' in data and data['stateful']:
+ raise ValueError('Cannot set `stateful=True` in ContainerCodeExecutor.')
+ if 'optimize_data_file' in data and data['optimize_data_file']:
+ raise ValueError(
+ 'Cannot set `optimize_data_file=True` in ContainerCodeExecutor.'
+ )
+
+ super().__init__(**data)
+ self.base_url = base_url
+ self.image = image if image else DEFAULT_IMAGE_TAG
+ self.docker_path = os.path.abspath(docker_path) if docker_path else None
+
+ self._client = (
+ docker.from_env()
+ if not self.base_url
+ else docker.DockerClient(base_url=self.base_url)
+ )
+ # Initialize the container.
+ self.__init_container()
+
+ # Close the container when the on exit.
+ atexit.register(self.__cleanup_container)
+
+ @override
+ def execute_code(
+ self,
+ invocation_context: InvocationContext,
+ code_execution_input: CodeExecutionInput,
+ ) -> CodeExecutionResult:
+ output = ''
+ error = ''
+ exec_result = self._container.exec_run(
+ ['python3', '-c', code_execution_input.code],
+ demux=True,
+ )
+
+ if exec_result.output and exec_result.output[0]:
+ output = exec_result.output[0].decode('utf-8')
+ if (
+ exec_result.output
+ and len(exec_result.output) > 1
+ and exec_result.output[1]
+ ):
+ error = exec_result.output[1].decode('utf-8')
+
+ # Collect the final result.
+ return CodeExecutionResult(
+ stdout=output,
+ stderr=error,
+ output_files=[],
+ )
+
+ def _build_docker_image(self):
+ """Builds the Docker image."""
+ if not self.docker_path:
+ raise ValueError('Docker path is not set.')
+ if not os.path.exists(self.docker_path):
+ raise FileNotFoundError(f'Invalid Docker path: {self.docker_path}')
+
+ print('Building Docker image...')
+ self._client.images.build(
+ path=self.docker_path,
+ tag=self.image,
+ rm=True,
+ )
+ print(f'Docker image: {self.image} built.')
+
+ def _verify_python_installation(self):
+ """Verifies the container has python3 installed."""
+ exec_result = self._container.exec_run(['which', 'python3'])
+ if exec_result.exit_code != 0:
+ raise ValueError('python3 is not installed in the container.')
+
+ def __init_container(self):
+ """Initializes the container."""
+ if not self._client:
+ raise RuntimeError('Docker client is not initialized.')
+
+ if self.docker_path:
+ self._build_docker_image()
+
+ print('Starting container for ContainerCodeExecutor...')
+ self._container = self._client.containers.run(
+ image=self.image,
+ detach=True,
+ tty=True,
+ )
+ print(f'Container {self._container.id} started.')
+
+ # Verify the container is able to run python3.
+ self._verify_python_installation()
+
+ def __cleanup_container(self):
+ """Closes the container on exit."""
+ if not self._container:
+ return
+
+ print('[Cleanup] Stopping the container...')
+ self._container.stop()
+ self._container.remove()
+ print(f'Container {self._container.id} stopped and removed.')
diff --git a/src/google/adk/code_executors/unsafe_local_code_executor.py b/src/google/adk/code_executors/unsafe_local_code_executor.py
new file mode 100644
index 0000000..e1e8004
--- /dev/null
+++ b/src/google/adk/code_executors/unsafe_local_code_executor.py
@@ -0,0 +1,71 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import redirect_stdout
+import io
+
+from pydantic import Field
+from typing_extensions import override
+
+from ..agents.invocation_context import InvocationContext
+from .base_code_executor import BaseCodeExecutor
+from .code_execution_utils import CodeExecutionInput
+from .code_execution_utils import CodeExecutionResult
+
+
+class UnsafeLocalCodeExecutor(BaseCodeExecutor):
+ """A code executor that unsafely execute code in the current local context."""
+
+ # Overrides the BaseCodeExecutor attribute: this executor cannot be stateful.
+ stateful: bool = Field(default=False, frozen=True, exclude=True)
+
+ # Overrides the BaseCodeExecutor attribute: this executor cannot
+ # optimize_data_file.
+ optimize_data_file: bool = Field(default=False, frozen=True, exclude=True)
+
+ def __init__(self, **data):
+ """Initializes the UnsafeLocalCodeExecutor."""
+ if 'stateful' in data and data['stateful']:
+ raise ValueError('Cannot set `stateful=True` in UnsafeLocalCodeExecutor.')
+ if 'optimize_data_file' in data and data['optimize_data_file']:
+ raise ValueError(
+ 'Cannot set `optimize_data_file=True` in UnsafeLocalCodeExecutor.'
+ )
+ super().__init__(**data)
+
+ @override
+ def execute_code(
+ self,
+ invocation_context: InvocationContext,
+ code_execution_input: CodeExecutionInput,
+ ) -> CodeExecutionResult:
+ # Execute the code.
+ output = ''
+ error = ''
+ try:
+ globals_ = {}
+ locals_ = {}
+ stdout = io.StringIO()
+ with redirect_stdout(stdout):
+ exec(code_execution_input.code, globals_, locals_)
+ output = stdout.getvalue()
+ except Exception as e:
+ error = str(e)
+
+ # Collect the final result.
+ return CodeExecutionResult(
+ stdout=output,
+ stderr=error,
+ output_files=[],
+ )
diff --git a/src/google/adk/code_executors/vertex_ai_code_executor.py b/src/google/adk/code_executors/vertex_ai_code_executor.py
new file mode 100644
index 0000000..31a0585
--- /dev/null
+++ b/src/google/adk/code_executors/vertex_ai_code_executor.py
@@ -0,0 +1,234 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import mimetypes
+import os
+from typing import Any, Optional
+
+from typing_extensions import override
+from vertexai.preview.extensions import Extension
+
+from ..agents.invocation_context import InvocationContext
+from .base_code_executor import BaseCodeExecutor
+from .code_execution_utils import CodeExecutionInput
+from .code_execution_utils import CodeExecutionResult
+from .code_execution_utils import File
+
+_SUPPORTED_IMAGE_TYPES = ['png', 'jpg', 'jpeg']
+_SUPPORTED_DATA_FILE_TYPES = ['csv']
+
+_IMPORTED_LIBRARIES = '''
+import io
+import math
+import re
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import scipy
+
+def crop(s: str, max_chars: int = 64) -> str:
+ """Crops a string to max_chars characters."""
+ return s[: max_chars - 3] + '...' if len(s) > max_chars else s
+
+
+def explore_df(df: pd.DataFrame) -> None:
+ """Prints some information about a pandas DataFrame."""
+
+ with pd.option_context(
+ 'display.max_columns', None, 'display.expand_frame_repr', False
+ ):
+ # Print the column names to never encounter KeyError when selecting one.
+ df_dtypes = df.dtypes
+
+ # Obtain information about data types and missing values.
+ df_nulls = (len(df) - df.isnull().sum()).apply(
+ lambda x: f'{x} / {df.shape[0]} non-null'
+ )
+
+ # Explore unique total values in columns using `.unique()`.
+ df_unique_count = df.apply(lambda x: len(x.unique()))
+
+ # Explore unique values in columns using `.unique()`.
+ df_unique = df.apply(lambda x: crop(str(list(x.unique()))))
+
+ df_info = pd.concat(
+ (
+ df_dtypes.rename('Dtype'),
+ df_nulls.rename('Non-Null Count'),
+ df_unique_count.rename('Unique Values Count'),
+ df_unique.rename('Unique Values'),
+ ),
+ axis=1,
+ )
+ df_info.index.name = 'Columns'
+ print(f"""Total rows: {df.shape[0]}
+Total columns: {df.shape[1]}
+
+{df_info}""")
+'''
+
+
+def _get_code_interpreter_extension(resource_name: str = None):
+ """Returns: Load or create the code interpreter extension."""
+ if not resource_name:
+ resource_name = os.environ.get('CODE_INTERPRETER_EXTENSION_NAME')
+ if resource_name:
+ new_code_interpreter = Extension(resource_name)
+ else:
+ print('No CODE_INTERPRETER_ID found in the environment. Create a new one.')
+ new_code_interpreter = Extension.from_hub('code_interpreter')
+ os.environ['CODE_INTERPRETER_EXTENSION_NAME'] = (
+ new_code_interpreter.gca_resource.name
+ )
+ return new_code_interpreter
+
+
+class VertexAiCodeExecutor(BaseCodeExecutor):
+ """A code executor that uses Vertex Code Interpreter Extension to execute code.
+
+ Attributes:
+ resource_name: If set, load the existing resource name of the code
+ interpreter extension instead of creating a new one. Format:
+ projects/123/locations/us-central1/extensions/456
+ """
+
+ resource_name: str = None
+ """
+ If set, load the existing resource name of the code interpreter extension
+ instead of creating a new one.
+ Format: projects/123/locations/us-central1/extensions/456
+ """
+
+ _code_interpreter_extension: Extension
+
+ def __init__(
+ self,
+ resource_name: str = None,
+ **data,
+ ):
+ """Initializes the VertexAiCodeExecutor.
+
+ Args:
+ resource_name: If set, load the existing resource name of the code
+ interpreter extension instead of creating a new one. Format:
+ projects/123/locations/us-central1/extensions/456
+ **data: Additional keyword arguments to be passed to the base class.
+ """
+ super().__init__(**data)
+ self.resource_name = resource_name
+ self._code_interpreter_extension = _get_code_interpreter_extension(
+ self.resource_name
+ )
+
+ @override
+ def execute_code(
+ self,
+ invocation_context: InvocationContext,
+ code_execution_input: CodeExecutionInput,
+ ) -> CodeExecutionResult:
+ # Execute the code.
+ code_execution_result = self._execute_code_interpreter(
+ self._get_code_with_imports(code_execution_input.code),
+ code_execution_input.input_files,
+ code_execution_input.execution_id,
+ )
+
+ # Save output file as artifacts.
+ current_timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
+ file_name_prefix = '%s_' % str(current_timestamp)
+ saved_files = []
+ file_count = 0
+ for output_file in code_execution_result['output_files']:
+ file_type = output_file['name'].split('.')[-1]
+ file_name = file_name_prefix + '%d.%s' % (file_count, file_type)
+ if file_type in _SUPPORTED_IMAGE_TYPES:
+ file_count += 1
+ saved_files.append(
+ File(
+ name='plot_' + file_name,
+ content=output_file['contents'],
+ mime_type=f'image/{file_type}',
+ )
+ )
+ elif file_type in _SUPPORTED_DATA_FILE_TYPES:
+ file_count += 1
+ saved_files.append(
+ File(
+ name='data_' + file_name,
+ content=output_file['contents'],
+ mime_type=f'text/{file_type}',
+ )
+ )
+ else:
+ mime_type, _ = mimetypes.guess_type(file_name)
+ saved_files.append(
+ File(
+ name=file_name,
+ content=output_file['contents'],
+ mime_type=mime_type,
+ )
+ )
+
+ # Collect the final result.
+ return CodeExecutionResult(
+ stdout=code_execution_result.get('execution_result', ''),
+ stderr=code_execution_result.get('execution_error', ''),
+ output_files=saved_files,
+ )
+
+ def _execute_code_interpreter(
+ self,
+ code: str,
+ input_files: Optional[list[File]] = None,
+ session_id: Optional[str] = None,
+ ) -> dict[str, Any]:
+ """Executes the code interpreter extension.
+
+ Args:
+ code: The code to execute.
+ input_files: The input files to execute the code with.
+ session_id: The session ID to execute the code with.
+
+ Returns:
+ The response from the code interpreter extension.
+ """
+ operation_params = {'code': code}
+ if input_files:
+ operation_params['files'] = [
+ {'name': f.name, 'contents': f.content} for f in input_files
+ ]
+ if session_id:
+ operation_params['session_id'] = session_id
+ response = self._code_interpreter_extension.execute(
+ operation_id='execute',
+ operation_params=operation_params,
+ )
+ return response
+
+ def _get_code_with_imports(self, code: str) -> str:
+ """Builds the code string with built-in imports.
+
+ Args:
+ code: The code to execute.
+
+ Returns:
+ The code string with built-in imports.
+ """
+ return f"""
+{_IMPORTED_LIBRARIES}
+
+{code}
+"""
diff --git a/src/google/adk/evaluation/__init__.py b/src/google/adk/evaluation/__init__.py
new file mode 100644
index 0000000..ae92ac7
--- /dev/null
+++ b/src/google/adk/evaluation/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+__all__ = []
+
+try:
+ from .agent_evaluator import AgentEvaluator
+
+ __all__.append('AgentEvaluator')
+except ImportError:
+ logger.debug(
+ 'The Vertex[eval] sdk is not installed. If you want to use the Vertex'
+ ' Evaluation with agents, please install it(pip install'
+ ' "google-cloud-aiplatform[evaluation]). If not, you can ignore this'
+ ' warning.'
+ )
diff --git a/src/google/adk/evaluation/agent_evaluator.py b/src/google/adk/evaluation/agent_evaluator.py
new file mode 100644
index 0000000..71d9496
--- /dev/null
+++ b/src/google/adk/evaluation/agent_evaluator.py
@@ -0,0 +1,329 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from os import path
+from typing import Dict
+from typing import List
+from typing import Union
+
+from .evaluation_generator import EvaluationGenerator
+from .response_evaluator import ResponseEvaluator
+from .trajectory_evaluator import TrajectoryEvaluator
+
+# Constants for default runs and evaluation criteria
+NUM_RUNS = 2
+TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score"
+# This evaluation is not very stable.
+# This is always optional unless explicitly specified.
+RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"
+RESPONSE_MATCH_SCORE_KEY = "response_match_score"
+
+ALLOWED_CRITERIA = [
+ TOOL_TRAJECTORY_SCORE_KEY,
+ RESPONSE_EVALUATION_SCORE_KEY,
+ RESPONSE_MATCH_SCORE_KEY,
+]
+
+
+QUERY_COLUMN = "query"
+REFERENCE_COLUMN = "reference"
+EXPECTED_TOOL_USE_COLUMN = "expected_tool_use"
+
+
+DEFAULT_CRITERIA = {
+ TOOL_TRAJECTORY_SCORE_KEY: 1.0, # 1-point scale; 1.0 is perfect.
+ RESPONSE_MATCH_SCORE_KEY: 0.8, # Rouge-1 text match; 0.8 is default.
+}
+
+
+def load_json(file_path: str) -> Union[Dict, List]:
+ with open(file_path, "r") as f:
+ return json.load(f)
+
+
+class AgentEvaluator:
+ """An evaluator for Agents, mainly intented for helping with test cases."""
+
+ @staticmethod
+ def find_config_for_test_file(test_file: str):
+ """Find the test_config.json file in the same folder as the test file."""
+ test_folder = os.path.dirname(test_file)
+ config_path = os.path.join(test_folder, "test_config.json")
+ if os.path.exists(config_path):
+ config_data = load_json(config_path)
+ if "criteria" in config_data and isinstance(
+ config_data["criteria"], dict
+ ):
+ return config_data["criteria"]
+ else:
+ raise ValueError(
+ f"Invalid format for test_config.json at {config_path}. Expected a"
+ " 'criteria' dictionary."
+ )
+ return DEFAULT_CRITERIA
+
+ @staticmethod
+ def evaluate(
+ agent_module,
+ eval_dataset_file_path_or_dir,
+ num_runs=NUM_RUNS,
+ agent_name=None,
+ initial_session_file=None,
+ ):
+ """Evaluates an Agent given eval data.
+
+ Args:
+ agent_module: The path to python module that contains the definition of
+ the agent. There is convention in place here, where the code is going to
+ look for 'root_agent' in the loaded module.
+ eval_dataset: The eval data set. This can be either a string representing
+ full path to the file containing eval dataset, or a directory that is
+ recusively explored for all files that have a `.test.json` suffix.
+ num_runs: Number of times all entries in the eval dataset should be
+ assessed.
+ agent_name: The name of the agent.
+ initial_session_file: File that contains initial session state that is
+ needed by all the evals in the eval dataset.
+ """
+ test_files = []
+ if isinstance(eval_dataset_file_path_or_dir, str) and os.path.isdir(
+ eval_dataset_file_path_or_dir
+ ):
+ for root, _, files in os.walk(eval_dataset_file_path_or_dir):
+ for file in files:
+ if file.endswith(".test.json"):
+ test_files.append(path.join(root, file))
+ else:
+ test_files = [eval_dataset_file_path_or_dir]
+
+ initial_session_state = {}
+ if initial_session_file:
+ with open(initial_session_file, "r") as f:
+ initial_session_state = json.loads(f.read())["state"]
+
+ for test_file in test_files:
+ dataset = AgentEvaluator._load_dataset(test_file)[0]
+ criteria = AgentEvaluator.find_config_for_test_file(test_file)
+
+ AgentEvaluator._validate_input([dataset], criteria)
+
+ evaluation_response = AgentEvaluator._generate_responses(
+ agent_module,
+ [dataset],
+ num_runs,
+ agent_name=agent_name,
+ initial_session={"state": initial_session_state},
+ )
+
+ if AgentEvaluator._response_evaluation_required(criteria, [dataset]):
+ AgentEvaluator._evaluate_response_scores(
+ agent_module, evaluation_response, criteria
+ )
+
+ if AgentEvaluator._trajectory_evaluation_required(criteria, [dataset]):
+ AgentEvaluator._evaluate_tool_trajectory(
+ agent_module, evaluation_response, criteria
+ )
+
+ @staticmethod
+ def _load_dataset(
+ input_data: Union[str, List[str], List[Dict], List[List[Dict]]],
+ ) -> List[List[Dict]]:
+ def load_json_file(file_path: str) -> List[Dict]:
+ data = load_json(file_path)
+ if not isinstance(data, list) or not all(
+ isinstance(d, dict) for d in data
+ ):
+ raise ValueError(f"{file_path} must contain a list of dictionaries.")
+ return data
+
+ if isinstance(input_data, str):
+ if os.path.isdir(input_data):
+ test_files = []
+ for root, _, files in os.walk(input_data):
+ for file in files:
+ if file.endswith(".test.json"):
+ test_files.append(os.path.join(root, file))
+ return [load_json_file(f) for f in test_files]
+ elif os.path.isfile(input_data):
+ return [load_json_file(input_data)]
+ else:
+ raise ValueError(f"Input path {input_data} is invalid.")
+ elif isinstance(input_data, list):
+ if all(isinstance(i, str) and os.path.isfile(i) for i in input_data):
+ return [load_json_file(i) for i in input_data]
+ raise TypeError("Input list must contain valid file paths.")
+ raise TypeError("Invalid input type for dataset loading.")
+
+ @staticmethod
+ def _validate_input(eval_dataset, criteria):
+ """Validates that the evaluation criteria align with the provided dataset.
+
+ For efficiency, we only use first row to validate input.
+ """
+ if not eval_dataset:
+ raise ValueError("The evaluation dataset is None or empty.")
+
+ for key in criteria:
+ if key not in ALLOWED_CRITERIA:
+ raise ValueError(
+ f"Invalid criteria key: {key}. Expected one of {ALLOWED_CRITERIA}."
+ )
+
+ if not eval_dataset:
+ raise ValueError("The evaluation dataset is empty.")
+ sample = eval_dataset[0]
+ first_query = sample[0]
+
+ if not isinstance(sample, list) and not isinstance(first_query, dict):
+ raise ValueError(
+ "Each evaluation dataset sample must be list of dictionary. But it's"
+ f" {eval_dataset}"
+ )
+
+ if TOOL_TRAJECTORY_SCORE_KEY in criteria:
+ if (
+ QUERY_COLUMN not in first_query
+ or EXPECTED_TOOL_USE_COLUMN not in first_query
+ ):
+ raise ValueError(
+ f"Samples for {TOOL_TRAJECTORY_SCORE_KEY} must include"
+ f" '{QUERY_COLUMN}' and '{EXPECTED_TOOL_USE_COLUMN}' keys. The"
+ f" sample is {sample}."
+ )
+
+ if RESPONSE_EVALUATION_SCORE_KEY in criteria:
+ if QUERY_COLUMN not in first_query:
+ raise ValueError(
+ f"Samples for {RESPONSE_EVALUATION_SCORE_KEY} must include"
+ f" '{QUERY_COLUMN}' key. The sample is {sample}."
+ )
+
+ if RESPONSE_MATCH_SCORE_KEY in criteria:
+ if QUERY_COLUMN not in first_query or REFERENCE_COLUMN not in first_query:
+ raise ValueError(
+ f"Samples for {RESPONSE_MATCH_SCORE_KEY} must include"
+ f" '{QUERY_COLUMN}' and '{REFERENCE_COLUMN}' keys. The sample is"
+ f" {sample}."
+ )
+
+ @staticmethod
+ def _get_infer_criteria(eval_dataset):
+ """Infers evaluation criteria based on the provided dataset.
+
+ Args:
+ eval_dataset (list): A list of evaluation samples.
+
+ Returns:
+ dict: Inferred evaluation criteria based on dataset fields.
+ """
+ inferred_criteria = {}
+ sample = eval_dataset[0][0]
+
+ if QUERY_COLUMN in sample and EXPECTED_TOOL_USE_COLUMN in sample:
+ inferred_criteria[TOOL_TRAJECTORY_SCORE_KEY] = DEFAULT_CRITERIA[
+ TOOL_TRAJECTORY_SCORE_KEY
+ ]
+
+ if QUERY_COLUMN in sample and REFERENCE_COLUMN in sample:
+ inferred_criteria[RESPONSE_MATCH_SCORE_KEY] = DEFAULT_CRITERIA[
+ RESPONSE_MATCH_SCORE_KEY
+ ]
+
+ return inferred_criteria
+
+ @staticmethod
+ def _generate_responses(
+ agent_module, eval_dataset, num_runs, agent_name=None, initial_session={}
+ ):
+ """Generates evaluation responses by running the agent module multiple times."""
+ return EvaluationGenerator.generate_responses(
+ eval_dataset,
+ agent_module,
+ repeat_num=num_runs,
+ agent_name=agent_name,
+ initial_session=initial_session,
+ )
+
+ @staticmethod
+ def _generate_responses_from_session(eval_dataset, session_path):
+ """Generates evaluation responses by running the agent module multiple times."""
+ return EvaluationGenerator.generate_responses_from_session(
+ session_path, eval_dataset
+ )
+
+ @staticmethod
+ def _response_evaluation_required(criteria, eval_dataset):
+ """Checks if response evaluation are needed."""
+ return REFERENCE_COLUMN in eval_dataset[0][0] and any(
+ key in criteria
+ for key in [RESPONSE_EVALUATION_SCORE_KEY, RESPONSE_MATCH_SCORE_KEY]
+ )
+
+ @staticmethod
+ def _trajectory_evaluation_required(evaluation_criteria, eval_dataset):
+ """Checks if response evaluation are needed."""
+ return (
+ EXPECTED_TOOL_USE_COLUMN in eval_dataset[0][0]
+ and TOOL_TRAJECTORY_SCORE_KEY in evaluation_criteria
+ )
+
+ @staticmethod
+ def _evaluate_response_scores(agent_module, evaluation_response, criteria):
+ """Evaluates response scores and raises an assertion error if they don't meet the criteria."""
+ metrics = ResponseEvaluator.evaluate(
+ evaluation_response, criteria, print_detailed_results=True
+ )
+
+ AgentEvaluator._assert_score(
+ metrics,
+ "coherence/mean",
+ criteria.get(RESPONSE_EVALUATION_SCORE_KEY),
+ "Average response evaluation score",
+ agent_module,
+ )
+
+ AgentEvaluator._assert_score(
+ metrics,
+ "rouge_1/mean",
+ criteria.get(RESPONSE_MATCH_SCORE_KEY),
+ "Average response match score",
+ agent_module,
+ )
+
+ @staticmethod
+ def _evaluate_tool_trajectory(agent_module, evaluation_response, criteria):
+ """Evaluates tool trajectory scores and raises an assertion error if they don't meet the criteria."""
+ score = TrajectoryEvaluator.evaluate(
+ evaluation_response, print_detailed_results=True
+ )
+ AgentEvaluator._assert_score(
+ {TOOL_TRAJECTORY_SCORE_KEY: score},
+ TOOL_TRAJECTORY_SCORE_KEY,
+ criteria[TOOL_TRAJECTORY_SCORE_KEY],
+ "Average tool trajectory evaluation score",
+ agent_module,
+ )
+
+ @staticmethod
+ def _assert_score(metrics, metric_key, threshold, description, agent_module):
+ """Asserts that a metric meets the specified threshold."""
+ if metric_key in metrics:
+ actual_score = metrics[metric_key]
+ assert actual_score >= threshold, (
+ f"{description} for {agent_module} is lower than expected. "
+ f"Expected >= {threshold}, but got {actual_score}."
+ )
diff --git a/src/google/adk/evaluation/evaluation_constants.py b/src/google/adk/evaluation/evaluation_constants.py
new file mode 100644
index 0000000..73b7b5c
--- /dev/null
+++ b/src/google/adk/evaluation/evaluation_constants.py
@@ -0,0 +1,24 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+class EvalConstants:
+ """Holds constants for evaluation file constants."""
+
+ QUERY = "query"
+ EXPECTED_TOOL_USE = "expected_tool_use"
+ RESPONSE = "response"
+ REFERENCE = "reference"
+ TOOL_NAME = "tool_name"
+ TOOL_INPUT = "tool_input"
+ MOCK_TOOL_OUTPUT = "mock_tool_output"
diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py
new file mode 100644
index 0000000..3d93659
--- /dev/null
+++ b/src/google/adk/evaluation/evaluation_generator.py
@@ -0,0 +1,270 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import uuid
+
+from google.genai import types
+
+from ..agents.base_agent import BaseAgent
+from ..agents.llm_agent import Agent
+from ..agents.llm_agent import BeforeToolCallback
+from ..agents.llm_agent import LlmAgent
+from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
+from ..runners import Runner
+from ..sessions.in_memory_session_service import InMemorySessionService
+from ..sessions.session import Session
+from .evaluation_constants import EvalConstants
+
+
+class EvaluationGenerator:
+ """Generates evaluation responses for agents."""
+
+ @staticmethod
+ def generate_responses(
+ eval_dataset,
+ agent_module_path,
+ repeat_num=3,
+ agent_name=None,
+ initial_session={},
+ ):
+ """Returns evaluation responses for the given dataset and agent.
+
+ Args:
+ eval_dataset: The dataset that needs to be scraped for resposnes.
+ agent_module_path: Path to the module that contains the root agent.
+ repeat_num: Number of time the eval dataset should be repeated. This is
+ usually done to remove uncertainity that a single run may bring.
+ agent_name: The name of the agent that should be evaluated. This is
+ usually the sub-agent.
+ initial_session: Initial session for the eval data.
+ """
+ results = []
+
+ for _ in range(repeat_num):
+ for data in eval_dataset:
+ results.append(
+ EvaluationGenerator._process_query(
+ data, agent_module_path, agent_name, initial_session
+ )
+ )
+
+ return results
+
+ @staticmethod
+ def generate_responses_from_session(session_path, eval_dataset):
+ """Returns evaluation responses by combining session data with eval data.
+
+ Args:
+ session_path: Path to a json file that contains session data.
+ eval_dataset: The eval data set that should be combined with the session
+ data.
+ """
+ results = []
+
+ with open(session_path, "r") as f:
+ session_data = Session.model_validate_json(f.read())
+ print("loaded session", session_path)
+
+ for data in eval_dataset:
+ # load session data from session_path
+ results.append(
+ EvaluationGenerator._process_query_with_session(
+ session_data,
+ data,
+ )
+ )
+
+ return results
+
+ @staticmethod
+ def _process_query(data, module_name, agent_name=None, initial_session={}):
+ """Process a query using the agent and evaluation dataset."""
+ module_path = f"{module_name}"
+ agent_module = importlib.import_module(module_path)
+ root_agent = agent_module.agent.root_agent
+
+ reset_func = getattr(agent_module.agent, "reset_data", None)
+
+ agent_to_evaluate = root_agent
+ if agent_name:
+ agent_to_evaluate = root_agent.find_agent(agent_name)
+ assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."
+
+ return EvaluationGenerator._process_query_with_root_agent(
+ data, agent_to_evaluate, reset_func, initial_session
+ )
+
+ @staticmethod
+ def _process_query_with_root_agent(
+ data,
+ root_agent,
+ reset_func,
+ initial_session={},
+ session_id=None,
+ session_service=None,
+ artifact_service=None,
+ ):
+ """Process a query using the agent and evaluation dataset."""
+
+ # we don't know which tools belong to which agent
+ # so we just apply to any agents that has certain tool outputs
+ all_mock_tools = set()
+ for eval_entry in data:
+ expected_tool_use = eval_entry.get(EvalConstants.EXPECTED_TOOL_USE, [])
+ for expected in expected_tool_use:
+ if EvalConstants.MOCK_TOOL_OUTPUT in expected:
+ all_mock_tools.add(expected[EvalConstants.TOOL_NAME])
+
+ eval_data_copy = data.copy()
+ EvaluationGenerator.apply_before_tool_callback(
+ root_agent,
+ lambda *args: EvaluationGenerator.before_tool_callback(
+ *args, eval_dataset=eval_data_copy
+ ),
+ all_mock_tools,
+ )
+
+ if not session_service:
+ session_service = InMemorySessionService()
+
+ app_name = initial_session.get("app_name", "EvaluationGenerator")
+ user_id = initial_session.get("user_id", "test_user_id")
+ session_id = session_id if session_id else str(uuid.uuid4())
+
+ _ = session_service.create_session(
+ app_name=app_name,
+ user_id=user_id,
+ state=initial_session.get("state", {}),
+ session_id=session_id,
+ )
+
+ if not artifact_service:
+ artifact_service = InMemoryArtifactService()
+ runner = Runner(
+ app_name=app_name,
+ agent=root_agent,
+ artifact_service=artifact_service,
+ session_service=session_service,
+ )
+
+ # Reset agent state for each query
+ if callable(reset_func):
+ reset_func()
+
+ responses = data.copy()
+
+ for index, eval_entry in enumerate(responses):
+ response = None
+ query = eval_entry["query"]
+ content = types.Content(role="user", parts=[types.Part(text=query)])
+ turn_actual_tool_uses = []
+
+ for event in runner.run(
+ user_id=user_id, session_id=session_id, new_message=content
+ ):
+ if event.is_final_response() and event.content and event.content.parts:
+ response = event.content.parts[0].text
+ elif event.get_function_calls():
+ for call in event.get_function_calls():
+ turn_actual_tool_uses.append({
+ EvalConstants.TOOL_NAME: call.name,
+ EvalConstants.TOOL_INPUT: call.args,
+ })
+
+ responses[index]["actual_tool_use"] = turn_actual_tool_uses
+ responses[index]["response"] = response
+
+ return responses
+
+ @staticmethod
+ def _process_query_with_session(session_data, data):
+ """Process the queries using the existing session data without invoking the runner."""
+ responses = data.copy()
+
+ # Iterate through the provided queries and align them with the session events
+ for index, eval_entry in enumerate(responses):
+ query = eval_entry["query"]
+ actual_tool_uses = []
+ response = None
+
+ # Search for the corresponding session events
+ for event in session_data.events:
+ # Match the query to a user event
+ if (
+ event.author == "user"
+ and event.content
+ and event.content.parts
+ and event.content.parts[0].text == query
+ ):
+ # Look for subsequent tool usage or model responses
+ for subsequent_event in session_data.events:
+ if subsequent_event.invocation_id == event.invocation_id:
+ # Extract tool usage
+ if subsequent_event.content.parts[0].function_call:
+ call = subsequent_event.content.parts[0].function_call
+ actual_tool_uses.append(
+ {"tool_name": call.name, "tool_input": call.args}
+ )
+ # Extract final response
+ elif subsequent_event.author != "user":
+ response = subsequent_event.content.parts[0].text
+
+ # Update the results for the current query
+ responses[index]["actual_tool_use"] = actual_tool_uses
+ responses[index]["response"] = response
+ return responses
+
+ @staticmethod
+ def before_tool_callback(tool, args, tool_context, eval_dataset):
+ """Intercept specific tool calls and return predefined outputs
+
+ from eval_dataset.
+ """
+ for index, eval_entry in enumerate(eval_dataset):
+ expected_tool_use = eval_entry.get("expected_tool_use", [])
+ for expected in expected_tool_use:
+ if (
+ EvalConstants.MOCK_TOOL_OUTPUT in expected
+ and tool.name == expected[EvalConstants.TOOL_NAME]
+ and args == expected.get(EvalConstants.TOOL_INPUT, {})
+ ):
+ # pop the matched entry so we don't rematch again
+ eval_dataset.pop(index)
+ return {"result": expected[EvalConstants.MOCK_TOOL_OUTPUT]}
+
+ return None
+
+ @staticmethod
+ def apply_before_tool_callback(
+ agent: BaseAgent,
+ callback: BeforeToolCallback,
+ all_mock_tools: set[str],
+ ):
+ """Recursively apply the before_tool_callback to the root agent and all its subagents."""
+ # check if the agent has tools that defined by evalset
+ # We use function name to check if tools match
+ if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent):
+ return
+
+ for tool in agent.canonical_tools:
+ tool_name = tool.name
+ if tool_name in all_mock_tools:
+ agent.before_tool_callback = callback
+
+ # Apply recursively to subagents if they exist
+ for sub_agent in agent.sub_agents:
+ EvaluationGenerator.apply_before_tool_callback(
+ sub_agent, callback, all_mock_tools
+ )
diff --git a/src/google/adk/evaluation/response_evaluator.py b/src/google/adk/evaluation/response_evaluator.py
new file mode 100644
index 0000000..463f21f
--- /dev/null
+++ b/src/google/adk/evaluation/response_evaluator.py
@@ -0,0 +1,135 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import pandas as pd
+from tabulate import tabulate
+from vertexai.preview.evaluation import EvalTask
+from vertexai.preview.evaluation import MetricPromptTemplateExamples
+
+
+class ResponseEvaluator:
+ """Runs response evaluation for agents."""
+
+ @staticmethod
+ def evaluate(
+ raw_eval_dataset: list[list[dict[str, Any]]],
+ evaluation_criteria: list[str],
+ *,
+ print_detailed_results: bool = False
+ ):
+ r"""Returns the value of requested evaluation metrics.
+
+ Args:
+ raw_eval_dataset: The dataset that will be evaluated.
+ evaluation_criteria: The evaluation criteria to be used. This method
+ support two criterias, `response_evaluation_score` and
+ `response_match_score`.
+ print_detailed_results: Prints detailed results on the console. This is
+ usually helpful during debugging.
+
+ A note on evaluation_criteria:
+ `response_match_score`: This metric compares the agents final natural
+ language reponse with the expected final response, stored in the
+ "reference" field in test/eval files. We use Rouge metric to compare the
+ two responses.
+
+ Value Range: [0, 1]. A score closer to 0 means poor similarity between
+ response and reference. A score closer to 1 means strong similarity
+ between response and reference.
+
+ `response_evaluation_score`: Uses LLM to evalaute coherence of the
+ response, including tool use. This is pointwise metric.
+
+ Value range: [0, 5], where 0 means that the agent's response is not
+ coherent, while 5 means it is . High values are good.
+ A note on raw_eval_dataset:
+ The dataset should be a list session, where each sesssion is represented
+ as a list of interaction that need evaluation. Each evaluation is
+ represented as a dictionary that is expected to have values for the
+ following keys:
+
+ 1) query
+ 2) response
+ 3) acutal_tool_use
+ 4) expected_tool_use
+ 5) reference
+
+ Here is a sample eval_dataset value with one entry:
+ [
+ [
+ {
+ "query": "roll a die for me",
+ "response": "I rolled a 16 sided die and got 13.\n",
+ "expected_tool_use": [
+ {
+ "tool_name": "roll_die",
+ "tool_input": {
+ "sides": 16
+ }
+ }
+ ],
+ "acutal_tool_use": [
+ {
+ "tool_name": "roll_die",
+ "tool_input": {
+ "sides": 16
+ }
+ }
+ ],
+ "reference": "I rolled a 16 sided die and got 13.\n"
+ }
+ ]
+ ]
+ """
+ if not raw_eval_dataset:
+ raise ValueError("The evaluation dataset is empty.")
+
+ metrics = ResponseEvaluator._get_metrics(
+ raw_eval_dataset, evaluation_criteria
+ )
+ flattened_queries = [
+ item for sublist in raw_eval_dataset for item in sublist
+ ]
+ eval_dataset = pd.DataFrame(flattened_queries).rename(
+ columns={"query": "prompt", "expected_tool_use": "reference_trajectory"}
+ )
+ eval_task = EvalTask(dataset=eval_dataset, metrics=metrics)
+
+ eval_result = eval_task.evaluate()
+ if print_detailed_results:
+ ResponseEvaluator._print_results(eval_result)
+ return eval_result.summary_metrics
+
+ @staticmethod
+ def _get_metrics(raw_eval_dataset, criteria):
+ metrics = []
+ if (
+ "response_evaluation_score" in criteria
+ and "query" in raw_eval_dataset[0][0]
+ and "expected_tool_use" in raw_eval_dataset[0][0]
+ ):
+ metrics.append(MetricPromptTemplateExamples.Pointwise.COHERENCE)
+ if (
+ "response_match_score" in criteria
+ and "reference" in raw_eval_dataset[0][0]
+ ):
+ metrics.append("rouge_1")
+ return metrics
+
+ @staticmethod
+ def _print_results(eval_result):
+ print("Evaluation Summary Metrics:", eval_result.summary_metrics)
+ print(tabulate(eval_result.metrics_table, headers="keys", tablefmt="grid"))
diff --git a/src/google/adk/evaluation/trajectory_evaluator.py b/src/google/adk/evaluation/trajectory_evaluator.py
new file mode 100644
index 0000000..8578ed9
--- /dev/null
+++ b/src/google/adk/evaluation/trajectory_evaluator.py
@@ -0,0 +1,184 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import pandas as pd
+from tabulate import tabulate
+
+from .evaluation_constants import EvalConstants
+
+
+class TrajectoryEvaluator:
+ """Evaluates tool use trajectories for accuracy."""
+
+ @staticmethod
+ def evaluate(
+ eval_dataset: list[list[dict[str, Any]]],
+ *,
+ print_detailed_results: bool = False,
+ ):
+ r"""Returns the mean tool use accuracy of the eval dataset.
+
+ Tool use accuracy is calculated by comparing the expected and actuall tool
+ use trajectories. An exact match scores a 1, 0 otherwise. The final number
+ is an
+ average of these individual scores.
+
+ Value range: [0, 1], where 0 is means none of the too use entries aligned,
+ and 1 would mean all of them aligned. Higher value is good.
+
+ Args:
+ eval_dataset: The dataset that will be evaluated.
+ print_detailed_results: Prints detailed results on the console. This is
+ usually helpful during debugging.
+
+ A note on eval_dataset:
+ The dataset should be a list session, where each sesssion is represented
+ as a list of interaction that need evaluation. Each evaluation is
+ represented as a dictionary that is expected to have values for the
+ following keys:
+ 1) query
+ 2) response
+ 3) acutal_tool_use
+ 4) expected_tool_use
+
+ Here is a sample eval_dataset value with one entry:
+
+ [
+ [
+ {
+ "query": "Roll a 16 sided dice for me",
+ "response": "I rolled a 16 sided die and got 13.\n",
+ "expected_tool_use": [
+ {
+ "tool_name": "roll_die",
+ "tool_input": {
+ "sides": 16
+ }
+ }
+ ],
+ "acutal_tool_use": [
+ {
+ "tool_name": "roll_die",
+ "tool_input": {
+ "sides": 16
+ }
+ }
+ ]
+ }
+ ]
+ ]
+ """
+ if not eval_dataset:
+ raise ValueError("The evaluation dataset is empty.")
+
+ results_df = pd.DataFrame(
+ columns=[
+ "query",
+ "response",
+ "actual_tool_use",
+ "expected_tool_use",
+ "tool_use_accuracy",
+ ]
+ )
+ failures = []
+
+ for conversation in eval_dataset:
+ for index, row in enumerate(conversation):
+ new_row, failure = TrajectoryEvaluator._evaluate_row(row)
+ results_df = pd.concat(
+ [results_df, pd.DataFrame([new_row])], ignore_index=True
+ )
+ if failure:
+ failure["turn"] = index + 1
+ failures.append(failure)
+
+ TrajectoryEvaluator._report_failures(failures)
+
+ if print_detailed_results:
+ TrajectoryEvaluator._print_results(results_df)
+
+ return results_df["tool_use_accuracy"].mean()
+
+ @staticmethod
+ def _evaluate_row(row):
+ # We don't evaluate the mock tool outputs.
+ expected = TrajectoryEvaluator._remove_tool_outputs(
+ row["expected_tool_use"]
+ )
+ actual = row["actual_tool_use"]
+ tool_use_accuracy = (
+ 1.0 if TrajectoryEvaluator.are_tools_equal(actual, expected) else 0.0
+ )
+
+ new_row = {
+ "query": row["query"],
+ "response": row["response"],
+ "actual_tool_use": actual,
+ "expected_tool_use": expected,
+ "tool_use_accuracy": tool_use_accuracy,
+ }
+ failure = (
+ None
+ if tool_use_accuracy == 1.0
+ else {"query": row["query"], "actual": actual, "expected": expected}
+ )
+ return new_row, failure
+
+ @staticmethod
+ def are_tools_equal(list_a_original, list_b_original):
+ # Remove other entries that we don't want to evaluate
+ list_a = [
+ {"tool_name": tool["tool_name"], "tool_input": tool["tool_input"]}
+ for tool in list_a_original
+ ]
+
+ list_b = [
+ {"tool_name": tool["tool_name"], "tool_input": tool["tool_input"]}
+ for tool in list_b_original
+ ]
+
+ return list_a == list_b
+
+ @staticmethod
+ def _remove_tool_outputs(tool_use_list):
+ """Removes 'mock_tool_output' from each dictionary in the list."""
+ result = []
+ for tool_use in tool_use_list:
+ new_tool_use = (
+ tool_use.copy()
+ ) # Create a copy to avoid modifying the original
+ new_tool_use.pop(
+ EvalConstants.MOCK_TOOL_OUTPUT, None
+ ) # Remove 'tool_output' if it exists
+ result.append(new_tool_use)
+ return result
+
+ @staticmethod
+ def _report_failures(failures):
+ if failures:
+ print("Failures:")
+ for failure in failures:
+ print(f"""{{
+ "turn": {failure["turn"]},
+ "query": '{failure["query"]}',
+ "actual": {failure["actual"]},
+ "expected_tool_use": {failure["expected"]},
+}}
+""")
+
+ @staticmethod
+ def _print_results(results_df):
+ print(tabulate(results_df, headers="keys", tablefmt="grid"))
diff --git a/src/google/adk/events/__init__.py b/src/google/adk/events/__init__.py
new file mode 100644
index 0000000..955b050
--- /dev/null
+++ b/src/google/adk/events/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .event import Event
+from .event_actions import EventActions
+
+__all__ = [
+ 'Event',
+ 'EventActions',
+]
diff --git a/src/google/adk/events/event.py b/src/google/adk/events/event.py
new file mode 100644
index 0000000..e6a8aba
--- /dev/null
+++ b/src/google/adk/events/event.py
@@ -0,0 +1,130 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from datetime import datetime
+import random
+import string
+from typing import Optional
+
+from google.genai import types
+from pydantic import ConfigDict
+from pydantic import Field
+
+from ..models.llm_response import LlmResponse
+from .event_actions import EventActions
+
+
+class Event(LlmResponse):
+ """Represents an event in a conversation between agents and users.
+
+ It is used to store the content of the conversation, as well as the actions
+ taken by the agents like function calls, etc.
+
+ Attributes:
+ invocation_id: The invocation ID of the event.
+ author: "user" or the name of the agent, indicating who appended the event
+ to the session.
+ actions: The actions taken by the agent.
+ long_running_tool_ids: The ids of the long running function calls.
+ branch: The branch of the event.
+ id: The unique identifier of the event.
+ timestamp: The timestamp of the event.
+ is_final_response: Whether the event is the final response of the agent.
+ get_function_calls: Returns the function calls in the event.
+ """
+
+ model_config = ConfigDict(
+ extra='forbid', ser_json_bytes='base64', val_json_bytes='base64'
+ )
+
+ # TODO: revert to be required after spark migration
+ invocation_id: str = ''
+ """The invocation ID of the event."""
+ author: str
+ """'user' or the name of the agent, indicating who appended the event to the
+ session."""
+ actions: EventActions = Field(default_factory=EventActions)
+ """The actions taken by the agent."""
+
+ long_running_tool_ids: Optional[set[str]] = None
+ """Set of ids of the long running function calls.
+ Agent client will know from this field about which function call is long running.
+ only valid for function call event
+ """
+ branch: Optional[str] = None
+ """The branch of the event.
+
+ The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of
+ agent_2, and agent_2 is the parent of agent_3.
+
+ Branch is used when multiple sub-agent shouldn't see their peer agents'
+ conversaction history.
+ """
+
+ # The following are computed fields.
+ # Do not assign the ID. It will be assigned by the session.
+ id: str = ''
+ """The unique identifier of the event."""
+ timestamp: float = Field(default_factory=lambda: datetime.now().timestamp())
+ """The timestamp of the event."""
+
+ def model_post_init(self, __context):
+ """Post initialization logic for the event."""
+ # Generates a random ID for the event.
+ if not self.id:
+ self.id = Event.new_id()
+
+ def is_final_response(self) -> bool:
+ """Returns whether the event is the final response of the agent."""
+ if self.actions.skip_summarization or self.long_running_tool_ids:
+ return True
+ return (
+ not self.get_function_calls()
+ and not self.get_function_responses()
+ and not self.partial
+ and not self.has_trailing_code_exeuction_result()
+ )
+
+ def get_function_calls(self) -> list[types.FunctionCall]:
+ """Returns the function calls in the event."""
+ func_calls = []
+ if self.content and self.content.parts:
+ for part in self.content.parts:
+ if part.function_call:
+ func_calls.append(part.function_call)
+ return func_calls
+
+ def get_function_responses(self) -> list[types.FunctionResponse]:
+ """Returns the function responses in the event."""
+ func_response = []
+ if self.content and self.content.parts:
+ for part in self.content.parts:
+ if part.function_response:
+ func_response.append(part.function_response)
+ return func_response
+
+ def has_trailing_code_exeuction_result(
+ self,
+ ) -> bool:
+ """Returns whether the event has a trailing code execution result."""
+ if self.content:
+ if self.content.parts:
+ return self.content.parts[-1].code_execution_result is not None
+ return False
+
+ @staticmethod
+ def new_id():
+ characters = string.ascii_letters + string.digits
+ return ''.join(random.choice(characters) for _ in range(8))
diff --git a/src/google/adk/events/event_actions.py b/src/google/adk/events/event_actions.py
new file mode 100644
index 0000000..412546e
--- /dev/null
+++ b/src/google/adk/events/event_actions.py
@@ -0,0 +1,55 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional
+
+from pydantic import BaseModel
+from pydantic import ConfigDict
+from pydantic import Field
+
+from ..auth.auth_tool import AuthConfig
+
+
+class EventActions(BaseModel):
+ """Represents the actions attached to an event."""
+
+ model_config = ConfigDict(extra='forbid')
+
+ skip_summarization: Optional[bool] = None
+ """If true, it won't call model to summarize function response.
+
+ Only used for function_response event.
+ """
+
+ state_delta: dict[str, object] = Field(default_factory=dict)
+ """Indicates that the event is updating the state with the given delta."""
+
+ artifact_delta: dict[str, int] = Field(default_factory=dict)
+ """Indicates that the event is updating an artifact. key is the filename,
+ value is the version."""
+
+ transfer_to_agent: Optional[str] = None
+ """If set, the event transfers to the specified agent."""
+
+ escalate: Optional[bool] = None
+ """The agent is escalating to a higher level agent."""
+
+ requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
+ """Will only be set by a tool response indicating tool request euc.
+ dict key is the function call id since one function call response (from model)
+ could correspond to multiple function calls.
+ dict value is the required auth config.
+ """
diff --git a/src/google/adk/examples/__init__.py b/src/google/adk/examples/__init__.py
new file mode 100644
index 0000000..a193078
--- /dev/null
+++ b/src/google/adk/examples/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_example_provider import BaseExampleProvider
+from .example import Example
+
+__all__ = [
+ 'BaseExampleProvider',
+ 'Example',
+]
+
+try:
+ from .vertex_ai_example_store import VertexAiExampleStore
+
+ __all__.append('VertexAiExampleStore')
+except ImportError:
+ pass
diff --git a/src/google/adk/examples/base_example_provider.py b/src/google/adk/examples/base_example_provider.py
new file mode 100644
index 0000000..bb8aa57
--- /dev/null
+++ b/src/google/adk/examples/base_example_provider.py
@@ -0,0 +1,35 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from .example import Example
+
+
+# A class that provides examples for a given query.
+class BaseExampleProvider(abc.ABC):
+ """Base class for example providers.
+
+ This class defines the interface for providing examples for a given query.
+ """
+
+ @abc.abstractmethod
+ def get_examples(self, query: str) -> list[Example]:
+ """Returns a list of examples for a given query.
+
+ Args:
+ query: The query to get examples for.
+
+ Returns:
+ A list of Example objects.
+ """
diff --git a/src/google/adk/examples/example.py b/src/google/adk/examples/example.py
new file mode 100644
index 0000000..e00ec26
--- /dev/null
+++ b/src/google/adk/examples/example.py
@@ -0,0 +1,27 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.genai import types
+from pydantic import BaseModel
+
+
+class Example(BaseModel):
+ """A few-shot example.
+
+ Attributes:
+ input: The input content for the example.
+ output: The expected output content for the example.
+ """
+ input: types.Content
+ output: list[types.Content]
diff --git a/src/google/adk/examples/example_util.py b/src/google/adk/examples/example_util.py
new file mode 100644
index 0000000..6e264b4
--- /dev/null
+++ b/src/google/adk/examples/example_util.py
@@ -0,0 +1,123 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions for converting examples to a string that can be used in system instructions in the prompt."""
+
+import logging
+from typing import Optional, Union
+from typing import TYPE_CHECKING
+
+from .base_example_provider import BaseExampleProvider
+from .example import Example
+
+if TYPE_CHECKING:
+ from ..sessions.session import Session
+
+logger = logging.getLogger(__name__)
+
+# Constant parts of the example string
+_EXAMPLES_INTRO = (
+ "\nBegin few-shot\nThe following are examples of user queries and"
+ " model responses using the available tools.\n\n"
+)
+_EXAMPLES_END = "End few-shot\n"
+_EXAMPLE_START = "EXAMPLE {}:\nBegin example\n"
+_EXAMPLE_END = "End example\n\n"
+_USER_PREFIX = "[user]\n"
+_MODEL_PREFIX = "[model]\n"
+_FUNCTION_PREFIX = "```\n"
+_FUNCTION_CALL_PREFIX = "```tool_code\n"
+_FUNCTION_CALL_SUFFIX = "\n```\n"
+_FUNCTION_RESPONSE_PREFIX = "```tool_outputs\n"
+_FUNCTION_RESPONSE_SUFFIX = "\n```\n"
+
+
+# TODO(yaojie): Add unit tests for this function.
+def convert_examples_to_text(
+ examples: list[Example], model: Optional[str]
+) -> str:
+ """Converts a list of examples to a string that can be used in a system instruction."""
+ examples_str = ""
+ for example_num, example in enumerate(examples):
+ output = f"{_EXAMPLE_START.format(example_num + 1)}{_USER_PREFIX}"
+ if example.input and example.input.parts:
+ output += (
+ "\n".join(part.text for part in example.input.parts if part.text)
+ + "\n"
+ )
+
+ gemini2 = model is None or "gemini-2" in model
+ previous_role = None
+ for content in example.output:
+ role = _MODEL_PREFIX if content.role == "model" else _USER_PREFIX
+ if role != previous_role:
+ output += role
+ previous_role = role
+ for part in content.parts:
+ if part.function_call:
+ args = []
+ # Convert function call part to python-like function call
+ for k, v in part.function_call.args.items():
+ if isinstance(v, str):
+ args.append(f"{k}='{v}'")
+ else:
+ args.append(f"{k}={v}")
+ prefix = _FUNCTION_PREFIX if gemini2 else _FUNCTION_CALL_PREFIX
+ output += (
+ f"{prefix}{part.function_call.name}({', '.join(args)}){_FUNCTION_CALL_SUFFIX}"
+ )
+ # Convert function response part to json string
+ elif part.function_response:
+ prefix = _FUNCTION_PREFIX if gemini2 else _FUNCTION_RESPONSE_PREFIX
+ output += f"{prefix}{part.function_response.__dict__}{_FUNCTION_RESPONSE_SUFFIX}"
+ elif part.text:
+ output += f"{part.text}\n"
+
+ output += _EXAMPLE_END
+ examples_str += output
+
+ return f"{_EXAMPLES_INTRO}{examples_str}{_EXAMPLES_END}"
+
+
+def _get_latest_message_from_user(session: "Session") -> str:
+ """Gets the latest message from the user.
+
+ Returns:
+ The latest message from the user. If not found, returns an empty string.
+ """
+ events = session.events
+ if not events:
+ return ""
+
+ event = events[-1]
+ if event.author == "user" and not event.get_function_responses():
+ if event.content.parts and event.content.parts[0].text:
+ return event.content.parts[0].text
+ else:
+ logger.warning("No message from user for fetching example.")
+
+ return ""
+
+
+def build_example_si(
+ examples: Union[list[Example], BaseExampleProvider],
+ query: str,
+ model: Optional[str],
+) -> str:
+ if isinstance(examples, list):
+ return convert_examples_to_text(examples, model)
+ if isinstance(examples, BaseExampleProvider):
+ return convert_examples_to_text(examples.get_examples(query), model)
+
+ raise ValueError("Invalid example configuration")
diff --git a/src/google/adk/examples/vertex_ai_example_store.py b/src/google/adk/examples/vertex_ai_example_store.py
new file mode 100644
index 0000000..718003a
--- /dev/null
+++ b/src/google/adk/examples/vertex_ai_example_store.py
@@ -0,0 +1,104 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.genai import types
+from typing_extensions import override
+from vertexai.preview import example_stores
+
+from .base_example_provider import BaseExampleProvider
+from .example import Example
+
+
+class VertexAiExampleStore(BaseExampleProvider):
+ """Provides examples from Vertex example store."""
+
+ def __init__(self, examples_store_name: str):
+ """Initializes the VertexAiExampleStore.
+
+ Args:
+ examples_store_name: The resource name of the vertex example store, in
+ the format of
+ ``projects/{project}/locations/{location}/exampleStores/{example_store}``.
+ """
+ self.examples_store_name = examples_store_name
+
+ @override
+ def get_examples(self, query: str) -> list[Example]:
+ example_store = example_stores.ExampleStore(self.examples_store_name)
+ # Retrieve relevant examples.
+ request = {
+ "stored_contents_example_parameters": {
+ "content_search_key": {
+ "contents": [{"role": "user", "parts": [{"text": query}]}],
+ "search_key_generation_method": {"last_entry": {}},
+ }
+ },
+ "top_k": 10,
+ "example_store": self.examples_store_name,
+ }
+ response = example_store.api_client.search_examples(request)
+
+ returned_examples = []
+ # Convert results to genai formats
+ for result in response.results:
+ if result.similarity_score < 0.5:
+ continue
+ expected_contents = [
+ content.content
+ for content in result.example.stored_contents_example.contents_example.expected_contents
+ ]
+ expected_output = []
+ for content in expected_contents:
+ expected_parts = []
+ for part in content.parts:
+ if part.text:
+ expected_parts.append(types.Part.from_text(text=part.text))
+ elif part.function_call:
+ expected_parts.append(
+ types.Part.from_function_call(
+ name=part.function_call.name,
+ args={
+ key: value
+ for key, value in part.function_call.args.items()
+ },
+ )
+ )
+ elif part.function_response:
+ expected_parts.append(
+ types.Part.from_function_response(
+ name=part.function_response.name,
+ response={
+ key: value
+ for key, value in part.function_response.response.items()
+ },
+ )
+ )
+ expected_output.append(
+ types.Content(role=content.role, parts=expected_parts)
+ )
+
+ returned_examples.append(
+ Example(
+ input=types.Content(
+ role="user",
+ parts=[
+ types.Part.from_text(
+ text=result.example.stored_contents_example.search_key
+ )
+ ],
+ ),
+ output=expected_output,
+ )
+ )
+ return returned_examples
diff --git a/src/google/adk/flows/__init__.py b/src/google/adk/flows/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/flows/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/flows/llm_flows/__init__.py b/src/google/adk/flows/llm_flows/__init__.py
new file mode 100644
index 0000000..6dbd22f
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import _code_execution
+from . import _nl_planning
+from . import contents
+from . import functions
+from . import identity
+from . import instructions
diff --git a/src/google/adk/flows/llm_flows/_base_llm_processor.py b/src/google/adk/flows/llm_flows/_base_llm_processor.py
new file mode 100644
index 0000000..488d994
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/_base_llm_processor.py
@@ -0,0 +1,52 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the processor interface used for BaseLlmFlow."""
+from __future__ import annotations
+
+from abc import ABC
+from abc import abstractmethod
+from typing import AsyncGenerator
+from typing import TYPE_CHECKING
+
+from ...agents.invocation_context import InvocationContext
+from ...events.event import Event
+
+if TYPE_CHECKING:
+ from ...models.llm_request import LlmRequest
+ from ...models.llm_response import LlmResponse
+
+
+class BaseLlmRequestProcessor(ABC):
+ """Base class for LLM request processor."""
+
+ @abstractmethod
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_request: LlmRequest
+ ) -> AsyncGenerator[Event, None]:
+ """Runs the processor."""
+ raise NotImplementedError("Not implemented.")
+ yield # AsyncGenerator requires a yield in function body.
+
+
+class BaseLlmResponseProcessor(ABC):
+ """Base class for LLM response processor."""
+
+ @abstractmethod
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_response: LlmResponse
+ ) -> AsyncGenerator[Event, None]:
+ """Processes the LLM response."""
+ raise NotImplementedError("Not implemented.")
+ yield # AsyncGenerator requires a yield in function body.
diff --git a/src/google/adk/flows/llm_flows/_code_execution.py b/src/google/adk/flows/llm_flows/_code_execution.py
new file mode 100644
index 0000000..7ca57af
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/_code_execution.py
@@ -0,0 +1,458 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Handles Code Execution related logic."""
+
+from __future__ import annotations
+
+import base64
+import copy
+import dataclasses
+import os
+import re
+from typing import AsyncGenerator
+from typing import Generator
+from typing import Optional
+from typing import TYPE_CHECKING
+
+from google.genai import types
+from typing_extensions import override
+
+from ...agents.invocation_context import InvocationContext
+from ...code_executors.base_code_executor import BaseCodeExecutor
+from ...code_executors.code_execution_utils import CodeExecutionInput
+from ...code_executors.code_execution_utils import CodeExecutionResult
+from ...code_executors.code_execution_utils import CodeExecutionUtils
+from ...code_executors.code_execution_utils import File
+from ...code_executors.code_executor_context import CodeExecutorContext
+from ...events.event import Event
+from ...events.event_actions import EventActions
+from ...models.llm_response import LlmResponse
+from ._base_llm_processor import BaseLlmRequestProcessor
+from ._base_llm_processor import BaseLlmResponseProcessor
+
+if TYPE_CHECKING:
+ from ...models.llm_request import LlmRequest
+
+
+@dataclasses.dataclass
+class DataFileUtil:
+ """A structure that contains a data file name and its content."""
+
+ extension: str
+ """
+ The file extension (e.g., ".csv").
+ """
+
+ loader_code_template: str
+ """
+ The code template to load the data file.
+ """
+
+
+_DATA_FILE_UTIL_MAP = {
+ 'text/csv': DataFileUtil(
+ extension='.csv',
+ loader_code_template="pd.read_csv('{filename}')",
+ ),
+}
+
+_DATA_FILE_HELPER_LIB = '''
+import pandas as pd
+
+def explore_df(df: pd.DataFrame) -> None:
+ """Prints some information about a pandas DataFrame."""
+
+ with pd.option_context(
+ 'display.max_columns', None, 'display.expand_frame_repr', False
+ ):
+ # Print the column names to never encounter KeyError when selecting one.
+ df_dtypes = df.dtypes
+
+ # Obtain information about data types and missing values.
+ df_nulls = (len(df) - df.isnull().sum()).apply(
+ lambda x: f'{x} / {df.shape[0]} non-null'
+ )
+
+ # Explore unique total values in columns using `.unique()`.
+ df_unique_count = df.apply(lambda x: len(x.unique()))
+
+ # Explore unique values in columns using `.unique()`.
+ df_unique = df.apply(lambda x: crop(str(list(x.unique()))))
+
+ df_info = pd.concat(
+ (
+ df_dtypes.rename('Dtype'),
+ df_nulls.rename('Non-Null Count'),
+ df_unique_count.rename('Unique Values Count'),
+ df_unique.rename('Unique Values'),
+ ),
+ axis=1,
+ )
+ df_info.index.name = 'Columns'
+ print(f"""Total rows: {df.shape[0]}
+Total columns: {df.shape[1]}
+
+{df_info}""")
+'''
+
+
+class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
+ """Processes code execution requests."""
+
+ @override
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_request: LlmRequest
+ ) -> AsyncGenerator[Event, None]:
+ from ...agents.llm_agent import LlmAgent
+
+ if not isinstance(invocation_context.agent, LlmAgent):
+ return
+ if not invocation_context.agent.code_executor:
+ return
+
+ for event in _run_pre_processor(invocation_context, llm_request):
+ yield event
+
+ # Convert the code execution parts to text parts.
+ if not isinstance(invocation_context.agent.code_executor, BaseCodeExecutor):
+ return
+ for content in llm_request.contents:
+ CodeExecutionUtils.convert_code_execution_parts(
+ content,
+ invocation_context.agent.code_executor.code_block_delimiters[0]
+ if invocation_context.agent.code_executor.code_block_delimiters
+ else ('', ''),
+ invocation_context.agent.code_executor.execution_result_delimiters,
+ )
+
+
+request_processor = _CodeExecutionRequestProcessor()
+
+
+class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
+ """Processes code execution responses."""
+
+ @override
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_response: LlmResponse
+ ) -> AsyncGenerator[Event, None]:
+ # Skip if the response is partial (streaming).
+ if llm_response.partial:
+ return
+
+ for event in _run_post_processor(invocation_context, llm_response):
+ yield event
+
+
+response_processor = _CodeExecutionResponseProcessor()
+
+
+def _run_pre_processor(
+ invocation_context: InvocationContext,
+ llm_request: LlmRequest,
+) -> Generator[Event, None, None]:
+ """Pre-process the user message by adding the user message to the Colab notebook."""
+ from ...agents.llm_agent import LlmAgent
+
+ if not isinstance(invocation_context.agent, LlmAgent):
+ return
+
+ agent = invocation_context.agent
+ code_executor = agent.code_executor
+
+ if not code_executor or not isinstance(code_executor, BaseCodeExecutor):
+ return
+ if not code_executor.optimize_data_file:
+ return
+
+ code_executor_context = CodeExecutorContext(invocation_context.session.state)
+
+ # Skip if the error count exceeds the max retry attempts.
+ if (
+ code_executor_context.get_error_count(invocation_context.invocation_id)
+ >= code_executor.error_retry_attempts
+ ):
+ return
+
+ # [Step 1] Extract data files from the session_history and store them in
+ # memory. Meanwhile, mutate the inline data file to text part in session
+ # history from all turns.
+ all_input_files = _extrac_and_replace_inline_files(
+ code_executor_context, llm_request
+ )
+
+ # [Step 2] Run Explore_Df code on the data files from the current turn. We
+ # only need to explore the new data files because the previous data files
+ # should already be explored and cached in the code execution runtime.
+ processed_file_names = set(code_executor_context.get_processed_file_names())
+ files_to_process = [
+ f for f in all_input_files if f.name not in processed_file_names
+ ]
+ for file in files_to_process:
+ code_str = _get_data_file_preprocessing_code(file)
+ # Skip for unsupported file or executor types.
+ if not code_str:
+ return
+
+ # Emit the code to execute, and add it to the LLM request.
+ code_content = types.Content(
+ role='model',
+ parts=[
+ types.Part(text=f'Processing input file: `{file.name}`'),
+ CodeExecutionUtils.build_executable_code_part(code_str),
+ ],
+ )
+ llm_request.contents.append(copy.deepcopy(code_content))
+ yield Event(
+ invocation_id=invocation_context.invocation_id,
+ author=agent.name,
+ branch=invocation_context.branch,
+ content=code_content,
+ )
+
+ code_execution_result = code_executor.execute_code(
+ invocation_context,
+ CodeExecutionInput(
+ code=code_str,
+ input_files=[file],
+ execution_id=_get_or_set_execution_id(
+ invocation_context, code_executor_context
+ ),
+ ),
+ )
+ # Update the processing results to code executor context.
+ code_executor_context.update_code_execution_result(
+ invocation_context.invocation_id,
+ code_str,
+ code_execution_result.stdout,
+ code_execution_result.stderr,
+ )
+ code_executor_context.add_processed_file_names([file.name])
+
+ # Emit the execution result, and add it to the LLM request.
+ execution_result_event = _post_process_code_execution_result(
+ invocation_context, code_executor_context, code_execution_result
+ )
+ yield execution_result_event
+ llm_request.contents.append(copy.deepcopy(execution_result_event.content))
+
+
+def _run_post_processor(
+ invocation_context: InvocationContext,
+ llm_response,
+) -> Generator[Event, None, None]:
+ """Post-process the model response by extracting and executing the first code block."""
+ agent = invocation_context.agent
+ code_executor = agent.code_executor
+
+ if not code_executor or not isinstance(code_executor, BaseCodeExecutor):
+ return
+ if not llm_response or not llm_response.content:
+ return
+
+ code_executor_context = CodeExecutorContext(invocation_context.session.state)
+ # Skip if the error count exceeds the max retry attempts.
+ if (
+ code_executor_context.get_error_count(invocation_context.invocation_id)
+ >= code_executor.error_retry_attempts
+ ):
+ return
+
+ # [Step 1] Extract code from the model predict response and truncate the
+ # content to the part with the first code block.
+ response_content = llm_response.content
+ code_str = CodeExecutionUtils.extract_code_and_truncate_content(
+ response_content, code_executor.code_block_delimiters
+ )
+ # Terminal state: no code to execute.
+ if not code_str:
+ return
+
+ # [Step 2] Executes the code and emit 2 Events for code and execution result.
+ yield Event(
+ invocation_id=invocation_context.invocation_id,
+ author=agent.name,
+ branch=invocation_context.branch,
+ content=response_content,
+ actions=EventActions(),
+ )
+
+ code_execution_result = code_executor.execute_code(
+ invocation_context,
+ CodeExecutionInput(
+ code=code_str,
+ input_files=code_executor_context.get_input_files(),
+ execution_id=_get_or_set_execution_id(
+ invocation_context, code_executor_context
+ ),
+ ),
+ )
+ code_executor_context.update_code_execution_result(
+ invocation_context.invocation_id,
+ code_str,
+ code_execution_result.stdout,
+ code_execution_result.stderr,
+ )
+ yield _post_process_code_execution_result(
+ invocation_context, code_executor_context, code_execution_result
+ )
+
+ # [Step 3] Skip processing the original model response
+ # to continue code generation loop.
+ llm_response.content = None
+
+
+def _extrac_and_replace_inline_files(
+ code_executor_context: CodeExecutorContext,
+ llm_request: LlmRequest,
+) -> list[File]:
+ """Extracts and replaces inline files with file names in the LLM request."""
+ all_input_files = code_executor_context.get_input_files()
+ saved_file_names = set(f.name for f in all_input_files)
+
+ # [Step 1] Process input files from LlmRequest and cache them in CodeExecutor.
+ for i in range(len(llm_request.contents)):
+ content = llm_request.contents[i]
+ # Only process the user message.
+ if content.role != 'user' and not content.parts:
+ continue
+
+ for j in range(len(content.parts)):
+ part = content.parts[j]
+ # Skip if the inline data is not supported.
+ if (
+ not part.inline_data
+ or part.inline_data.mime_type not in _DATA_FILE_UTIL_MAP
+ ):
+ continue
+
+ # Replace the inline data file with a file name placeholder.
+ mime_type = part.inline_data.mime_type
+ file_name = f'data_{i+1}_{j+1}' + _DATA_FILE_UTIL_MAP[mime_type].extension
+ llm_request.contents[i].parts[j] = types.Part(
+ text='\nAvailable file: `%s`\n' % file_name
+ )
+
+ # Add the inlne data as input file to the code executor context.
+ file = File(
+ name=file_name,
+ content=CodeExecutionUtils.get_encoded_file_content(
+ part.inline_data.data
+ ).decode(),
+ mime_type=mime_type,
+ )
+ if file_name not in saved_file_names:
+ code_executor_context.add_input_files([file])
+ all_input_files.append(file)
+
+ return all_input_files
+
+
+def _get_or_set_execution_id(
+ invocation_context: InvocationContext,
+ code_executor_context: CodeExecutorContext,
+) -> Optional[str]:
+ """Returns the ID for stateful code execution or None if not stateful."""
+ if not invocation_context.agent.code_executor.stateful:
+ return None
+
+ execution_id = code_executor_context.get_execution_id()
+ if not execution_id:
+ execution_id = invocation_context.session.id
+ code_executor_context.set_execution_id(execution_id)
+ return execution_id
+
+
+def _post_process_code_execution_result(
+ invocation_context: InvocationContext,
+ code_executor_context: CodeExecutorContext,
+ code_execution_result: CodeExecutionResult,
+) -> Event:
+ """Post-process the code execution result and emit an Event."""
+ if invocation_context.artifact_service is None:
+ raise ValueError('Artifact service is not initialized.')
+
+ result_content = types.Content(
+ role='model',
+ parts=[
+ CodeExecutionUtils.build_code_execution_result_part(
+ code_execution_result
+ ),
+ ],
+ )
+ event_actions = EventActions(
+ state_delta=code_executor_context.get_state_delta()
+ )
+
+ # Handle code execution error retry.
+ if code_execution_result.stderr:
+ code_executor_context.increment_error_count(
+ invocation_context.invocation_id
+ )
+ else:
+ code_executor_context.reset_error_count(invocation_context.invocation_id)
+
+ # Handle output files.
+ for output_file in code_execution_result.output_files:
+ version = invocation_context.artifact_service.save_artifact(
+ app_name=invocation_context.app_name,
+ user_id=invocation_context.user_id,
+ session_id=invocation_context.session.id,
+ filename=output_file.name,
+ artifact=types.Part.from_bytes(
+ data=base64.b64decode(output_file.content),
+ mime_type=output_file.mime_type,
+ ),
+ )
+ event_actions.artifact_delta[output_file.name] = version
+
+ return Event(
+ invocation_id=invocation_context.invocation_id,
+ author=invocation_context.agent.name,
+ branch=invocation_context.branch,
+ content=result_content,
+ actions=event_actions,
+ )
+
+
+def _get_data_file_preprocessing_code(file: File) -> Optional[str]:
+ """Returns the code to explore the data file."""
+
+ def _get_normalized_file_name(file_name: str) -> str:
+ var_name, _ = os.path.splitext(file_name)
+ # Replace non-alphanumeric characters with underscores
+ var_name = re.sub(r'[^a-zA-Z0-9_]', '_', var_name)
+
+ # If the filename starts with a digit, prepend an underscore
+ if var_name[0].isdigit():
+ var_name = '_' + var_name
+ return var_name
+
+ if file.mime_type not in _DATA_FILE_UTIL_MAP:
+ return
+
+ var_name = _get_normalized_file_name(file.name)
+ loader_code = _DATA_FILE_UTIL_MAP[file.mime_type].loader_code_template.format(
+ filename=file.name
+ )
+ return f"""
+{_DATA_FILE_HELPER_LIB}
+
+# Load the dataframe.
+{var_name} = {loader_code}
+
+# Use `explore_df` to guide my analysis.
+explore_df({var_name})
+"""
diff --git a/src/google/adk/flows/llm_flows/_nl_planning.py b/src/google/adk/flows/llm_flows/_nl_planning.py
new file mode 100644
index 0000000..98aff1e
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/_nl_planning.py
@@ -0,0 +1,129 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Handles NL planning related logic."""
+
+from __future__ import annotations
+
+from typing import AsyncGenerator
+from typing import Generator
+from typing import Optional
+from typing import TYPE_CHECKING
+
+from typing_extensions import override
+
+from ...agents.callback_context import CallbackContext
+from ...agents.invocation_context import InvocationContext
+from ...agents.readonly_context import ReadonlyContext
+from ...events.event import Event
+from ...planners.plan_re_act_planner import PlanReActPlanner
+from ._base_llm_processor import BaseLlmRequestProcessor
+from ._base_llm_processor import BaseLlmResponseProcessor
+
+if TYPE_CHECKING:
+ from ...models.llm_request import LlmRequest
+ from ...models.llm_response import LlmResponse
+ from ...planners.base_planner import BasePlanner
+ from ...planners.built_in_planner import BuiltInPlanner
+
+
+class _NlPlanningRequestProcessor(BaseLlmRequestProcessor):
+ """Processor for NL planning."""
+
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_request: LlmRequest
+ ) -> AsyncGenerator[Event, None]:
+ from ...planners.built_in_planner import BuiltInPlanner
+
+ planner = _get_planner(invocation_context)
+ if not planner:
+ return
+
+ if isinstance(planner, BuiltInPlanner):
+ planner.apply_thinking_config(llm_request)
+
+ planning_instruction = planner.build_planning_instruction(
+ ReadonlyContext(invocation_context), llm_request
+ )
+ if planning_instruction:
+ llm_request.append_instructions([planning_instruction])
+
+ _remove_thought_from_request(llm_request)
+
+ # Maintain async generator behavior
+ if False: # Ensures it behaves as a generator
+ yield # This is a no-op but maintains generator structure
+
+
+request_processor = _NlPlanningRequestProcessor()
+
+
+class _NlPlanningResponse(BaseLlmResponseProcessor):
+
+ @override
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_response: LlmResponse
+ ) -> AsyncGenerator[Event, None]:
+ if (
+ not llm_response
+ or not llm_response.content
+ or not llm_response.content.parts
+ ):
+ return
+
+ planner = _get_planner(invocation_context)
+ if not planner:
+ return
+
+ # Postprocess the LLM response.
+ processed_parts = planner.process_planning_response(
+ CallbackContext(invocation_context), llm_response.content.parts
+ )
+ if processed_parts:
+ llm_response.content.parts = processed_parts
+
+ # Maintain async generator behavior
+ if False: # Ensures it behaves as a generator
+ yield # This is a no-op but maintains generator structure
+
+
+response_processor = _NlPlanningResponse()
+
+
+def _get_planner(
+ invocation_context: InvocationContext,
+) -> Optional[BasePlanner]:
+ from ...agents.llm_agent import Agent
+ from ...planners.base_planner import BasePlanner
+
+ agent = invocation_context.agent
+ if not isinstance(agent, Agent):
+ return None
+ if not agent.planner:
+ return None
+
+ if isinstance(agent.planner, BasePlanner):
+ return agent.planner
+ return PlanReActPlanner()
+
+
+def _remove_thought_from_request(llm_request: LlmRequest):
+ if not llm_request.contents:
+ return
+
+ for content in llm_request.contents:
+ if not content.parts:
+ continue
+ for part in content.parts:
+ part.thought = None
diff --git a/src/google/adk/flows/llm_flows/agent_transfer.py b/src/google/adk/flows/llm_flows/agent_transfer.py
new file mode 100644
index 0000000..7fba6d5
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/agent_transfer.py
@@ -0,0 +1,132 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Handles agent transfer for LLM flow."""
+
+from __future__ import annotations
+
+import typing
+from typing import AsyncGenerator
+
+from typing_extensions import override
+
+from ...agents.invocation_context import InvocationContext
+from ...events.event import Event
+from ...models.llm_request import LlmRequest
+from ...tools.function_tool import FunctionTool
+from ...tools.tool_context import ToolContext
+from ...tools.transfer_to_agent_tool import transfer_to_agent
+from ._base_llm_processor import BaseLlmRequestProcessor
+
+if typing.TYPE_CHECKING:
+ from ...agents import BaseAgent
+ from ...agents import LlmAgent
+
+
+class _AgentTransferLlmRequestProcessor(BaseLlmRequestProcessor):
+ """Agent transfer request processor."""
+
+ @override
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_request: LlmRequest
+ ) -> AsyncGenerator[Event, None]:
+ from ...agents.llm_agent import LlmAgent
+
+ if not isinstance(invocation_context.agent, LlmAgent):
+ return
+
+ transfer_targets = _get_transfer_targets(invocation_context.agent)
+ if not transfer_targets:
+ return
+
+ llm_request.append_instructions([
+ _build_target_agents_instructions(
+ invocation_context.agent, transfer_targets
+ )
+ ])
+
+ transfer_to_agent_tool = FunctionTool(func=transfer_to_agent)
+ tool_context = ToolContext(invocation_context)
+ await transfer_to_agent_tool.process_llm_request(
+ tool_context=tool_context, llm_request=llm_request
+ )
+
+ return
+ yield # AsyncGenerator requires yield statement in function body.
+
+
+request_processor = _AgentTransferLlmRequestProcessor()
+
+
+def _build_target_agents_info(target_agent: BaseAgent) -> str:
+ return f"""
+Agent name: {target_agent.name}
+Agent description: {target_agent.description}
+"""
+
+
+line_break = '\n'
+
+
+def _build_target_agents_instructions(
+ agent: LlmAgent, target_agents: list[BaseAgent]
+) -> str:
+ si = f"""
+You have a list of other agents to transfer to:
+
+{line_break.join([
+ _build_target_agents_info(target_agent) for target_agent in target_agents
+])}
+
+If you are the best to answer the question according to your description, you
+can answer it.
+
+If another agent is better for answering the question according to its
+description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
+question to that agent. When transfering, do not generate any text other than
+the function call.
+"""
+
+ if agent.parent_agent:
+ si += f"""
+Your parent agent is {agent.parent_agent.name}. If neither the other agents nor
+you are best for answering the question according to the descriptions, transfer
+to your parent agent. If you don't have parent agent, try answer by yourself.
+"""
+ return si
+
+
+_TRANSFER_TO_AGENT_FUNCTION_NAME = transfer_to_agent.__name__
+
+
+def _get_transfer_targets(agent: LlmAgent) -> list[BaseAgent]:
+ from ...agents.llm_agent import LlmAgent
+
+ result = []
+ result.extend(agent.sub_agents)
+
+ if not agent.parent_agent or not isinstance(agent.parent_agent, LlmAgent):
+ return result
+
+ if not agent.disallow_transfer_to_parent:
+ result.append(agent.parent_agent)
+
+ if not agent.disallow_transfer_to_peers:
+ result.extend([
+ peer_agent
+ for peer_agent in agent.parent_agent.sub_agents
+ if peer_agent.name != agent.name
+ ])
+
+ return result
diff --git a/src/google/adk/flows/llm_flows/audio_transcriber.py b/src/google/adk/flows/llm_flows/audio_transcriber.py
new file mode 100644
index 0000000..6709bb5
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/audio_transcriber.py
@@ -0,0 +1,109 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from google.cloud import speech
+from google.genai import types as genai_types
+
+if TYPE_CHECKING:
+ from ...agents.invocation_context import InvocationContext
+
+
+class AudioTranscriber:
+ """Transcribes audio using Google Cloud Speech-to-Text."""
+
+ def __init__(self):
+ self.client = speech.SpeechClient()
+
+ def transcribe_file(
+ self, invocation_context: InvocationContext
+ ) -> list[genai_types.Content]:
+ """Transcribe audio, bundling consecutive segments from the same speaker.
+
+ The ordering of speakers will be preserved. Audio blobs will be merged for
+ the same speaker as much as we can do reduce the transcription latency.
+
+ Args:
+ invocation_context: The invocation context to access the transcription
+ cache.
+
+ Returns:
+ A list of Content objects containing the transcribed text.
+ """
+
+ bundled_audio = []
+ current_speaker = None
+ current_audio_data = b''
+ contents = []
+
+ # Step1: merge audio blobs
+ for transcription_entry in invocation_context.transcription_cache or []:
+ speaker, audio_data = (
+ transcription_entry.role,
+ transcription_entry.data,
+ )
+
+ if isinstance(audio_data, genai_types.Content):
+ if current_speaker is not None:
+ bundled_audio.append((current_speaker, current_audio_data))
+ current_speaker = None
+ current_audio_data = b''
+ bundled_audio.append((speaker, audio_data))
+ continue
+
+ if not audio_data.data:
+ continue
+
+ if speaker == current_speaker:
+ current_audio_data += audio_data.data
+ else:
+ if current_speaker is not None:
+ bundled_audio.append((current_speaker, current_audio_data))
+ current_speaker = speaker
+ current_audio_data = audio_data.data
+
+ # Append the last audio segment if any
+ if current_speaker is not None:
+ bundled_audio.append((current_speaker, current_audio_data))
+
+ # reset cache
+ invocation_context.transcription_cache = []
+
+ # Step2: transcription
+ for speaker, data in bundled_audio:
+ if speaker == 'user':
+ audio = speech.RecognitionAudio(content=data)
+
+ config = speech.RecognitionConfig(
+ encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
+ sample_rate_hertz=16000,
+ language_code='en-US',
+ )
+
+ response = self.client.recognize(config=config, audio=audio)
+
+ for result in response.results:
+ transcript = result.alternatives[0].transcript
+
+ parts = [genai_types.Part(text=transcript)]
+ role = speaker.lower()
+ content = genai_types.Content(role=role, parts=parts)
+ contents.append(content)
+ else:
+ # don't need to transcribe model which are already text
+ contents.append(data)
+
+ return contents
diff --git a/src/google/adk/flows/llm_flows/auto_flow.py b/src/google/adk/flows/llm_flows/auto_flow.py
new file mode 100644
index 0000000..a35ea7e
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/auto_flow.py
@@ -0,0 +1,49 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implementation of AutoFlow."""
+
+from . import agent_transfer
+from .single_flow import SingleFlow
+
+
+class AutoFlow(SingleFlow):
+ """AutoFlow is SingleFlow with agent transfer capability.
+
+ Agent transfer is allowed in the following direction:
+
+ 1. from parent to sub-agent;
+ 2. from sub-agent to parent;
+ 3. from sub-agent to its peer agents;
+
+ For peer-agent transfers, it's only enabled when all below conditions are met:
+
+ - The parent agent is also of AutoFlow;
+ - `disallow_transfer_to_peer` option of this agent is False (default).
+
+ Depending on the target agent flow type, the transfer may be automatically
+ reversed. The condition is as below:
+
+ - If the flow type of the tranferee agent is also auto, transfee agent will
+ remain as the active agent. The transfee agent will respond to the user's
+ next message directly.
+ - If the flow type of the transfere agent is not auto, the active agent will
+ be reversed back to previous agent.
+
+ TODO: allow user to config auto-reverse function.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.request_processors += [agent_transfer.request_processor]
diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
new file mode 100644
index 0000000..f8662a9
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -0,0 +1,559 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from abc import ABC
+import asyncio
+import logging
+from typing import AsyncGenerator
+from typing import cast
+from typing import Optional
+from typing import TYPE_CHECKING
+
+from websockets.exceptions import ConnectionClosedOK
+
+from ...agents.base_agent import BaseAgent
+from ...agents.callback_context import CallbackContext
+from ...agents.invocation_context import InvocationContext
+from ...agents.live_request_queue import LiveRequestQueue
+from ...agents.run_config import StreamingMode
+from ...agents.transcription_entry import TranscriptionEntry
+from ...events.event import Event
+from ...models.base_llm_connection import BaseLlmConnection
+from ...models.llm_request import LlmRequest
+from ...models.llm_response import LlmResponse
+from ...telemetry import trace_call_llm
+from ...telemetry import trace_send_data
+from ...telemetry import tracer
+from ...tools.tool_context import ToolContext
+from . import functions
+
+if TYPE_CHECKING:
+ from ...agents.llm_agent import LlmAgent
+ from ...models.base_llm import BaseLlm
+ from ._base_llm_processor import BaseLlmRequestProcessor
+ from ._base_llm_processor import BaseLlmResponseProcessor
+
+logger = logging.getLogger(__name__)
+
+
+class BaseLlmFlow(ABC):
+ """A basic flow that calls the LLM in a loop until a final response is generated.
+
+ This flow ends when it transfer to another agent.
+ """
+
+ def __init__(self):
+ self.request_processors: list[BaseLlmRequestProcessor] = []
+ self.response_processors: list[BaseLlmResponseProcessor] = []
+
+ async def run_live(
+ self,
+ invocation_context: InvocationContext,
+ ) -> AsyncGenerator[Event, None]:
+ """Runs the flow using live api."""
+ llm_request = LlmRequest()
+ event_id = Event.new_id()
+
+ # Preprocess before calling the LLM.
+ async for event in self._preprocess_async(invocation_context, llm_request):
+ yield event
+ if invocation_context.end_invocation:
+ return
+
+ llm = self.__get_llm(invocation_context)
+ logger.info(
+ 'Establishing live connection for agent: %s with llm request: %s',
+ invocation_context.agent.name,
+ llm_request,
+ )
+ async with llm.connect(llm_request) as llm_connection:
+ if llm_request.contents:
+ # Sends the conversation history to the model.
+ with tracer.start_as_current_span('send_data'):
+
+ if invocation_context.transcription_cache:
+ from . import audio_transcriber
+
+ audio_transcriber = audio_transcriber.AudioTranscriber()
+ contents = audio_transcriber.transcribe_file(invocation_context)
+ logger.debug('Sending history to model: %s', contents)
+ await llm_connection.send_history(contents)
+ invocation_context.transcription_cache = None
+ trace_send_data(invocation_context, event_id, contents)
+ else:
+ await llm_connection.send_history(llm_request.contents)
+ trace_send_data(invocation_context, event_id, llm_request.contents)
+
+ send_task = asyncio.create_task(
+ self._send_to_model(llm_connection, invocation_context)
+ )
+
+ try:
+ async for event in self._receive_from_model(
+ llm_connection,
+ event_id,
+ invocation_context,
+ llm_request,
+ ):
+ # Empty event means the queue is closed.
+ if not event:
+ break
+ logger.debug('Receive new event: %s', event)
+ yield event
+ # send back the function response
+ if event.get_function_responses():
+ logger.debug('Sending back last function resonse event: %s', event)
+ invocation_context.live_request_queue.send_content(event.content)
+ if (
+ event.content
+ and event.content.parts
+ and event.content.parts[0].function_response
+ and event.content.parts[0].function_response.name
+ == 'transfer_to_agent'
+ ):
+ await asyncio.sleep(1)
+ # cancel the tasks that belongs to the closed connection.
+ send_task.cancel()
+ await llm_connection.close()
+ finally:
+ # Clean up
+ if not send_task.done():
+ send_task.cancel()
+ try:
+ await send_task
+ except asyncio.CancelledError:
+ pass
+
+ async def _send_to_model(
+ self,
+ llm_connection: BaseLlmConnection,
+ invocation_context: InvocationContext,
+ ):
+ """Sends data to model."""
+ while True:
+ live_request_queue = invocation_context.live_request_queue
+ try:
+ # Streamlit's execution model doesn't preemptively yield to the event
+ # loop. Therefore, we must explicitly introduce timeouts to allow the
+ # event loop to process events.
+ # TODO: revert back(remove timeout) once we move off streamlit.
+ live_request = await asyncio.wait_for(
+ live_request_queue.get(), timeout=0.25
+ )
+ # duplicate the live_request to all the active streams
+ logger.debug(
+ 'Sending live request %s to active streams: %s',
+ live_request,
+ invocation_context.active_streaming_tools,
+ )
+ if invocation_context.active_streaming_tools:
+ for active_streaming_tool in (
+ invocation_context.active_streaming_tools
+ ).values():
+ if active_streaming_tool.stream:
+ active_streaming_tool.stream.send(live_request)
+ await asyncio.sleep(0)
+ except asyncio.TimeoutError:
+ continue
+ if live_request.close:
+ await llm_connection.close()
+ return
+ if live_request.blob:
+ # Cache audio data here for transcription
+ if not invocation_context.transcription_cache:
+ invocation_context.transcription_cache = []
+ invocation_context.transcription_cache.append(
+ TranscriptionEntry(role='user', data=live_request.blob)
+ )
+ await llm_connection.send_realtime(live_request.blob)
+ if live_request.content:
+ await llm_connection.send_content(live_request.content)
+
+ async def _receive_from_model(
+ self,
+ llm_connection: BaseLlmConnection,
+ event_id: str,
+ invocation_context: InvocationContext,
+ llm_request: LlmRequest,
+ ) -> AsyncGenerator[Event, None]:
+ """Receive data from model and process events using BaseLlmConnection."""
+ assert invocation_context.live_request_queue
+ try:
+ while True:
+ async for llm_response in llm_connection.receive():
+ model_response_event = Event(
+ id=Event.new_id(),
+ invocation_id=invocation_context.invocation_id,
+ author=invocation_context.agent.name,
+ )
+ async for event in self._postprocess_live(
+ invocation_context,
+ llm_request,
+ llm_response,
+ model_response_event,
+ ):
+ if (
+ event.content
+ and event.content.parts
+ and event.content.parts[0].text
+ and not event.partial
+ ):
+ if not invocation_context.transcription_cache:
+ invocation_context.transcription_cache = []
+ invocation_context.transcription_cache.append(
+ TranscriptionEntry(role='model', data=event.content)
+ )
+ yield event
+ # Give opportunity for other tasks to run.
+ await asyncio.sleep(0)
+ except ConnectionClosedOK:
+ pass
+
+ async def run_async(
+ self, invocation_context: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ """Runs the flow."""
+ while True:
+ last_event = None
+ async for event in self._run_one_step_async(invocation_context):
+ last_event = event
+ yield event
+ if not last_event or last_event.is_final_response():
+ break
+
+ async def _run_one_step_async(
+ self,
+ invocation_context: InvocationContext,
+ ) -> AsyncGenerator[Event, None]:
+ """One step means one LLM call."""
+ llm_request = LlmRequest()
+
+ # Preprocess before calling the LLM.
+ async for event in self._preprocess_async(invocation_context, llm_request):
+ yield event
+ if invocation_context.end_invocation:
+ return
+
+ # Calls the LLM.
+ model_response_event = Event(
+ id=Event.new_id(),
+ invocation_id=invocation_context.invocation_id,
+ author=invocation_context.agent.name,
+ branch=invocation_context.branch,
+ )
+ async for llm_response in self._call_llm_async(
+ invocation_context, llm_request, model_response_event
+ ):
+ # Postprocess after calling the LLM.
+ async for event in self._postprocess_async(
+ invocation_context, llm_request, llm_response, model_response_event
+ ):
+ yield event
+
+ async def _preprocess_async(
+ self, invocation_context: InvocationContext, llm_request: LlmRequest
+ ) -> AsyncGenerator[Event, None]:
+ from ...agents.llm_agent import LlmAgent
+
+ agent = invocation_context.agent
+ if not isinstance(agent, LlmAgent):
+ return
+
+ # Runs processors.
+ for processor in self.request_processors:
+ async for event in processor.run_async(invocation_context, llm_request):
+ yield event
+
+ # Run processors for tools.
+ for tool in agent.canonical_tools:
+ tool_context = ToolContext(invocation_context)
+ await tool.process_llm_request(
+ tool_context=tool_context, llm_request=llm_request
+ )
+
+ async def _postprocess_async(
+ self,
+ invocation_context: InvocationContext,
+ llm_request: LlmRequest,
+ llm_response: LlmResponse,
+ model_response_event: Event,
+ ) -> AsyncGenerator[Event, None]:
+ """Postprocess after calling the LLM.
+
+ Args:
+ invocation_context: The invocation context.
+ llm_request: The original LLM request.
+ llm_response: The LLM response from the LLM call.
+ model_response_event: A mutable event for the LLM response.
+
+ Yields:
+ A generator of events.
+ """
+
+ # Runs processors.
+ async for event in self._postprocess_run_processors_async(
+ invocation_context, llm_response
+ ):
+ yield event
+
+ # Skip the model response event if there is no content and no error code.
+ # This is needed for the code executor to trigger another loop.
+ if (
+ not llm_response.content
+ and not llm_response.error_code
+ and not llm_response.interrupted
+ ):
+ return
+
+ # Builds the event.
+ model_response_event = self._finalize_model_response_event(
+ llm_request, llm_response, model_response_event
+ )
+ yield model_response_event
+
+ # Handles function calls.
+ if model_response_event.get_function_calls():
+ async for event in self._postprocess_handle_function_calls_async(
+ invocation_context, model_response_event, llm_request
+ ):
+ yield event
+
+ async def _postprocess_live(
+ self,
+ invocation_context: InvocationContext,
+ llm_request: LlmRequest,
+ llm_response: LlmResponse,
+ model_response_event: Event,
+ ) -> AsyncGenerator[Event, None]:
+ """Postprocess after calling the LLM asynchronously.
+
+ Args:
+ invocation_context: The invocation context.
+ llm_request: The original LLM request.
+ llm_response: The LLM response from the LLM call.
+ model_response_event: A mutable event for the LLM response.
+
+ Yields:
+ A generator of events.
+ """
+
+ # Runs processors.
+ async for event in self._postprocess_run_processors_async(
+ invocation_context, llm_response
+ ):
+ yield event
+
+ # Skip the model response event if there is no content and no error code.
+ # This is needed for the code executor to trigger another loop.
+ # But don't skip control events like turn_complete.
+ if (
+ not llm_response.content
+ and not llm_response.error_code
+ and not llm_response.interrupted
+ and not llm_response.turn_complete
+ ):
+ return
+
+ # Builds the event.
+ model_response_event = self._finalize_model_response_event(
+ llm_request, llm_response, model_response_event
+ )
+ yield model_response_event
+
+ # Handles function calls.
+ if model_response_event.get_function_calls():
+ function_response_event = await functions.handle_function_calls_live(
+ invocation_context, model_response_event, llm_request.tools_dict
+ )
+ yield function_response_event
+
+ transfer_to_agent = function_response_event.actions.transfer_to_agent
+ if transfer_to_agent:
+ agent_to_run = self._get_agent_to_run(
+ invocation_context, transfer_to_agent
+ )
+ async for item in agent_to_run.run_live(invocation_context):
+ yield item
+
+ async def _postprocess_run_processors_async(
+ self, invocation_context: InvocationContext, llm_response: LlmResponse
+ ) -> AsyncGenerator[Event, None]:
+ for processor in self.response_processors:
+ async for event in processor.run_async(invocation_context, llm_response):
+ yield event
+
+ async def _postprocess_handle_function_calls_async(
+ self,
+ invocation_context: InvocationContext,
+ function_call_event: Event,
+ llm_request: LlmRequest,
+ ) -> AsyncGenerator[Event, None]:
+ if function_response_event := await functions.handle_function_calls_async(
+ invocation_context, function_call_event, llm_request.tools_dict
+ ):
+ auth_event = functions.generate_auth_event(
+ invocation_context, function_response_event
+ )
+ if auth_event:
+ yield auth_event
+
+ yield function_response_event
+ transfer_to_agent = function_response_event.actions.transfer_to_agent
+ if transfer_to_agent:
+ agent_to_run = self._get_agent_to_run(
+ invocation_context, transfer_to_agent
+ )
+ async for event in agent_to_run.run_async(invocation_context):
+ yield event
+
+ def _get_agent_to_run(
+ self, invocation_context: InvocationContext, transfer_to_agent
+ ) -> BaseAgent:
+ root_agent = invocation_context.agent.root_agent
+ agent_to_run = root_agent.find_agent(transfer_to_agent)
+ if not agent_to_run:
+ raise ValueError(
+ f'Agent {transfer_to_agent} not found in the agent tree.'
+ )
+ return agent_to_run
+
+ async def _call_llm_async(
+ self,
+ invocation_context: InvocationContext,
+ llm_request: LlmRequest,
+ model_response_event: Event,
+ ) -> AsyncGenerator[LlmResponse, None]:
+ # Runs before_model_callback if it exists.
+ if response := self._handle_before_model_callback(
+ invocation_context, llm_request, model_response_event
+ ):
+ yield response
+ return
+
+ # Calls the LLM.
+ llm = self.__get_llm(invocation_context)
+ with tracer.start_as_current_span('call_llm'):
+ if invocation_context.run_config.support_cfc:
+ invocation_context.live_request_queue = LiveRequestQueue()
+ async for llm_response in self.run_live(invocation_context):
+ # Runs after_model_callback if it exists.
+ if altered_llm_response := self._handle_after_model_callback(
+ invocation_context, llm_response, model_response_event
+ ):
+ llm_response = altered_llm_response
+ # only yield partial response in SSE streaming mode
+ if (
+ invocation_context.run_config.streaming_mode == StreamingMode.SSE
+ or not llm_response.partial
+ ):
+ yield llm_response
+ if llm_response.turn_complete:
+ invocation_context.live_request_queue.close()
+ else:
+ # Check if we can make this llm call or not. If the current call pushes
+ # the counter beyond the max set value, then the execution is stopped
+ # right here, and exception is thrown.
+ invocation_context.increment_llm_call_count()
+ async for llm_response in llm.generate_content_async(
+ llm_request,
+ stream=invocation_context.run_config.streaming_mode
+ == StreamingMode.SSE,
+ ):
+ trace_call_llm(
+ invocation_context,
+ model_response_event.id,
+ llm_request,
+ llm_response,
+ )
+ # Runs after_model_callback if it exists.
+ if altered_llm_response := self._handle_after_model_callback(
+ invocation_context, llm_response, model_response_event
+ ):
+ llm_response = altered_llm_response
+
+ yield llm_response
+
+ def _handle_before_model_callback(
+ self,
+ invocation_context: InvocationContext,
+ llm_request: LlmRequest,
+ model_response_event: Event,
+ ) -> Optional[LlmResponse]:
+ from ...agents.llm_agent import LlmAgent
+
+ agent = invocation_context.agent
+ if not isinstance(agent, LlmAgent):
+ return
+
+ if not agent.before_model_callback:
+ return
+
+ callback_context = CallbackContext(
+ invocation_context, event_actions=model_response_event.actions
+ )
+ return agent.before_model_callback(
+ callback_context=callback_context, llm_request=llm_request
+ )
+
+ def _handle_after_model_callback(
+ self,
+ invocation_context: InvocationContext,
+ llm_response: LlmResponse,
+ model_response_event: Event,
+ ) -> Optional[LlmResponse]:
+ from ...agents.llm_agent import LlmAgent
+
+ agent = invocation_context.agent
+ if not isinstance(agent, LlmAgent):
+ return
+
+ if not agent.after_model_callback:
+ return
+
+ callback_context = CallbackContext(
+ invocation_context, event_actions=model_response_event.actions
+ )
+ return agent.after_model_callback(
+ callback_context=callback_context, llm_response=llm_response
+ )
+
+ def _finalize_model_response_event(
+ self,
+ llm_request: LlmRequest,
+ llm_response: LlmResponse,
+ model_response_event: Event,
+ ) -> Event:
+ model_response_event = Event.model_validate({
+ **model_response_event.model_dump(exclude_none=True),
+ **llm_response.model_dump(exclude_none=True),
+ })
+
+ if model_response_event.content:
+ function_calls = model_response_event.get_function_calls()
+ if function_calls:
+ functions.populate_client_function_call_id(model_response_event)
+ model_response_event.long_running_tool_ids = (
+ functions.get_long_running_function_calls(
+ function_calls, llm_request.tools_dict
+ )
+ )
+
+ return model_response_event
+
+ def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
+ from ...agents.llm_agent import LlmAgent
+
+ return cast(LlmAgent, invocation_context.agent).canonical_model
diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py
new file mode 100644
index 0000000..278b4cf
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/basic.py
@@ -0,0 +1,72 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Handles basic information to build the LLM request."""
+
+from __future__ import annotations
+
+from typing import AsyncGenerator
+from typing import Generator
+
+from google.genai import types
+from typing_extensions import override
+
+from ...agents.invocation_context import InvocationContext
+from ...events.event import Event
+from ...models.llm_request import LlmRequest
+from ._base_llm_processor import BaseLlmRequestProcessor
+
+
+class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
+
+ @override
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_request: LlmRequest
+ ) -> AsyncGenerator[Event, None]:
+ from ...agents.llm_agent import LlmAgent
+
+ agent = invocation_context.agent
+ if not isinstance(agent, LlmAgent):
+ return
+
+ llm_request.model = (
+ agent.canonical_model
+ if isinstance(agent.canonical_model, str)
+ else agent.canonical_model.model
+ )
+ llm_request.config = (
+ agent.generate_content_config.model_copy(deep=True)
+ if agent.generate_content_config
+ else types.GenerateContentConfig()
+ )
+ if agent.output_schema:
+ llm_request.set_output_schema(agent.output_schema)
+
+ llm_request.live_connect_config.response_modalities = (
+ invocation_context.run_config.response_modalities
+ )
+ llm_request.live_connect_config.speech_config = (
+ invocation_context.run_config.speech_config
+ )
+ llm_request.live_connect_config.output_audio_transcription = (
+ invocation_context.run_config.output_audio_transcription
+ )
+
+ # TODO: handle tool append here, instead of in BaseTool.process_llm_request.
+
+ return
+ yield # Generator requires yield statement in function body.
+
+
+request_processor = _BasicLlmRequestProcessor()
diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py
new file mode 100644
index 0000000..d2b8dee
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/contents.py
@@ -0,0 +1,370 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import copy
+from typing import AsyncGenerator
+from typing import Generator
+from typing import Optional
+
+from google.genai import types
+from typing_extensions import override
+
+from ...agents.invocation_context import InvocationContext
+from ...events.event import Event
+from ...models.llm_request import LlmRequest
+from ._base_llm_processor import BaseLlmRequestProcessor
+from .functions import remove_client_function_call_id
+
+
+class _ContentLlmRequestProcessor(BaseLlmRequestProcessor):
+ """Builds the contents for the LLM request."""
+
+ @override
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_request: LlmRequest
+ ) -> AsyncGenerator[Event, None]:
+ from ...agents.llm_agent import LlmAgent
+
+ agent = invocation_context.agent
+ if not isinstance(agent, LlmAgent):
+ return
+
+ if agent.include_contents != 'none':
+ llm_request.contents = _get_contents(
+ invocation_context.branch,
+ invocation_context.session.events,
+ agent.name,
+ )
+
+ # Maintain async generator behavior
+ if False: # Ensures it behaves as a generator
+ yield # This is a no-op but maintains generator structure
+
+
+request_processor = _ContentLlmRequestProcessor()
+
+
+def _rearrange_events_for_async_function_responses_in_history(
+ events: list[Event],
+) -> list[Event]:
+ """Rearrange the async function_response events in the history."""
+
+ function_call_id_to_response_events_index: dict[str, list[Event]] = {}
+ for i, event in enumerate(events):
+ function_responses = event.get_function_responses()
+ if function_responses:
+ for function_response in function_responses:
+ function_call_id = function_response.id
+ function_call_id_to_response_events_index[function_call_id] = i
+
+ result_events: list[Event] = []
+ for event in events:
+ if event.get_function_responses():
+ # function_response should be handled together with function_call below.
+ continue
+ elif event.get_function_calls():
+
+ function_response_events_indices = set()
+ for function_call in event.get_function_calls():
+ function_call_id = function_call.id
+ if function_call_id in function_call_id_to_response_events_index:
+ function_response_events_indices.add(
+ function_call_id_to_response_events_index[function_call_id]
+ )
+ result_events.append(event)
+ if not function_response_events_indices:
+ continue
+ if len(function_response_events_indices) == 1:
+ result_events.append(
+ events[next(iter(function_response_events_indices))]
+ )
+ else: # Merge all async function_response as one response event
+ result_events.append(
+ _merge_function_response_events(
+ [events[i] for i in sorted(function_response_events_indices)]
+ )
+ )
+ continue
+ else:
+ result_events.append(event)
+
+ return result_events
+
+
+def _rearrange_events_for_latest_function_response(
+ events: list[Event],
+) -> list[Event]:
+ """Rearrange the events for the latest function_response.
+
+ If the latest function_response is for an async function_call, all events
+ bewteen the initial function_call and the latest function_response will be
+ removed.
+
+ Args:
+ events: A list of events.
+
+ Returns:
+ A list of events with the latest function_response rearranged.
+ """
+ if not events:
+ return events
+
+ function_responses = events[-1].get_function_responses()
+ if not function_responses:
+ # No need to process, since the latest event is not fuction_response.
+ return events
+
+ function_responses_ids = set()
+ for function_response in function_responses:
+ function_responses_ids.add(function_response.id)
+
+ function_calls = events[-2].get_function_calls()
+
+ if function_calls:
+ for function_call in function_calls:
+ # The latest function_response is already matched
+ if function_call.id in function_responses_ids:
+ return events
+
+ function_call_event_idx = -1
+ # look for corresponding function call event reversely
+ for idx in range(len(events) - 2, -1, -1):
+ event = events[idx]
+ function_calls = event.get_function_calls()
+ if function_calls:
+ for function_call in function_calls:
+ if function_call.id in function_responses_ids:
+ function_call_event_idx = idx
+ break
+ if function_call_event_idx != -1:
+ # in case the last response event only have part of the responses
+ # for the function calls in the function call event
+ for function_call in function_calls:
+ function_responses_ids.add(function_call.id)
+ break
+
+ if function_call_event_idx == -1:
+ raise ValueError(
+ 'No function call event found for function responses ids:'
+ f' {function_responses_ids}'
+ )
+
+ # collect all function response between last function response event
+ # and function call event
+
+ function_response_events: list[Event] = []
+ for idx in range(function_call_event_idx + 1, len(events) - 1):
+ event = events[idx]
+ function_responses = event.get_function_responses()
+ if (
+ function_responses
+ and function_responses[0].id in function_responses_ids
+ ):
+ function_response_events.append(event)
+ function_response_events.append(events[-1])
+
+ result_events = events[: function_call_event_idx + 1]
+ result_events.append(
+ _merge_function_response_events(function_response_events)
+ )
+
+ return result_events
+
+
+def _get_contents(
+ current_branch: Optional[str], events: list[Event], agent_name: str = ''
+) -> list[types.Content]:
+ """Get the contents for the LLM request.
+
+ Args:
+ current_branch: The current branch of the agent.
+ events: A list of events.
+ agent_name: The name of the agent.
+
+ Returns:
+ A list of contents.
+ """
+ filtered_events = []
+ # Parse the events, leaving the contents and the function calls and
+ # responses from the current agent.
+ for event in events:
+ if not event.content or not event.content.role:
+ # Skip events without content, or generated neither by user nor by model.
+ # E.g. events purely for mutating session states.
+ continue
+ if not _is_event_belongs_to_branch(current_branch, event):
+ # Skip events not belong to current branch.
+ continue
+
+ filtered_events.append(
+ _convert_foreign_event(event)
+ if _is_other_agent_reply(agent_name, event)
+ else event
+ )
+
+ result_events = _rearrange_events_for_latest_function_response(
+ filtered_events
+ )
+ result_events = _rearrange_events_for_async_function_responses_in_history(
+ result_events
+ )
+ contents = []
+ for event in result_events:
+ content = copy.deepcopy(event.content)
+ remove_client_function_call_id(content)
+ contents.append(content)
+ return contents
+
+
+def _is_other_agent_reply(current_agent_name: str, event: Event) -> bool:
+ """Whether the event is a reply from another agent."""
+ return bool(
+ current_agent_name
+ and event.author != current_agent_name
+ and event.author != 'user'
+ )
+
+
+def _convert_foreign_event(event: Event) -> Event:
+ """Converts an event authored by another agent as a user-content event.
+
+ This is to provide another agent's output as context to the current agent, so
+ that current agent can continue to respond, such as summarizing previous
+ agent's reply, etc.
+
+ Args:
+ event: The event to convert.
+
+ Returns:
+ The converted event.
+
+ """
+ if not event.content or not event.content.parts:
+ return event
+
+ content = types.Content()
+ content.role = 'user'
+ content.parts = [types.Part(text='For context:')]
+ for part in event.content.parts:
+ if part.text:
+ content.parts.append(
+ types.Part(text=f'[{event.author}] said: {part.text}')
+ )
+ elif part.function_call:
+ content.parts.append(
+ types.Part(
+ text=(
+ f'[{event.author}] called tool `{part.function_call.name}`'
+ f' with parameters: {part.function_call.args}'
+ )
+ )
+ )
+ elif part.function_response:
+ # Otherwise, create a new text part.
+ content.parts.append(
+ types.Part(
+ text=(
+ f'[{event.author}] `{part.function_response.name}` tool'
+ f' returned result: {part.function_response.response}'
+ )
+ )
+ )
+ # Fallback to the original part for non-text and non-functionCall parts.
+ else:
+ content.parts.append(part)
+
+ return Event(
+ timestamp=event.timestamp,
+ author='user',
+ content=content,
+ branch=event.branch,
+ )
+
+
+def _merge_function_response_events(
+ function_response_events: list[Event],
+) -> Event:
+ """Merges a list of function_response events into one event.
+
+ The key goal is to ensure:
+ 1. function_call and function_response are always of the same number.
+ 2. The function_call and function_response are consecutively in the content.
+
+ Args:
+ function_response_events: A list of function_response events.
+ NOTE: function_response_events must fulfill these requirements: 1. The
+ list is in increasing order of timestamp; 2. the first event is the
+ initial function_reponse event; 3. all later events should contain at
+ least one function_response part that related to the function_call
+ event. (Note, 3. may not be true when aync function return some
+ intermediate response, there could also be some intermediate model
+ response event without any function_response and such event will be
+ ignored.)
+ Caveat: This implementation doesn't support when a parallel function_call
+ event contains async function_call of the same name.
+
+ Returns:
+ A merged event, that is
+ 1. All later function_response will replace function_response part in
+ the initial function_response event.
+ 2. All non-function_response parts will be appended to the part list of
+ the initial function_response event.
+ """
+ if not function_response_events:
+ raise ValueError('At least one function_response event is required.')
+
+ merged_event = function_response_events[0].model_copy(deep=True)
+ parts_in_merged_event: list[types.Part] = merged_event.content.parts # type: ignore
+
+ if not parts_in_merged_event:
+ raise ValueError('There should be at least one function_response part.')
+
+ part_indices_in_merged_event: dict[str, int] = {}
+ for idx, part in enumerate(parts_in_merged_event):
+ if part.function_response:
+ function_call_id: str = part.function_response.id # type: ignore
+ part_indices_in_merged_event[function_call_id] = idx
+
+ for event in function_response_events[1:]:
+ if not event.content.parts:
+ raise ValueError('There should be at least one function_response part.')
+
+ for part in event.content.parts:
+ if part.function_response:
+ function_call_id: str = part.function_response.id # type: ignore
+ if function_call_id in part_indices_in_merged_event:
+ parts_in_merged_event[
+ part_indices_in_merged_event[function_call_id]
+ ] = part
+ else:
+ parts_in_merged_event.append(part)
+ part_indices_in_merged_event[function_call_id] = (
+ len(parts_in_merged_event) - 1
+ )
+
+ else:
+ parts_in_merged_event.append(part)
+
+ return merged_event
+
+
+def _is_event_belongs_to_branch(
+ invocation_branch: Optional[str], event: Event
+) -> bool:
+ """Event belongs to a branch, when event.branch is prefix of the invocation branch."""
+ if not invocation_branch or not event.branch:
+ return True
+ return invocation_branch.startswith(event.branch)
diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py
new file mode 100644
index 0000000..88f5967
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/functions.py
@@ -0,0 +1,463 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Handles function callings for LLM flow."""
+
+from __future__ import annotations
+
+import asyncio
+import inspect
+import logging
+from typing import Any
+from typing import AsyncGenerator
+from typing import cast
+from typing import Optional
+import uuid
+
+from google.genai import types
+
+from ...agents.active_streaming_tool import ActiveStreamingTool
+from ...agents.invocation_context import InvocationContext
+from ...auth.auth_tool import AuthToolArguments
+from ...events.event import Event
+from ...events.event_actions import EventActions
+from ...telemetry import tracer
+from ...tools.base_tool import BaseTool
+from ...tools.tool_context import ToolContext
+
+AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
+REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
+
+logger = logging.getLogger(__name__)
+
+
+def generate_client_function_call_id() -> str:
+ return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}'
+
+
+def populate_client_function_call_id(model_response_event: Event) -> None:
+ if not model_response_event.get_function_calls():
+ return
+ for function_call in model_response_event.get_function_calls():
+ if not function_call.id:
+ function_call.id = generate_client_function_call_id()
+
+
+def remove_client_function_call_id(content: types.Content) -> None:
+ if content and content.parts:
+ for part in content.parts:
+ if (
+ part.function_call
+ and part.function_call.id
+ and part.function_call.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
+ ):
+ part.function_call.id = None
+ if (
+ part.function_response
+ and part.function_response.id
+ and part.function_response.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
+ ):
+ part.function_response.id = None
+
+
+def get_long_running_function_calls(
+ function_calls: list[types.FunctionCall],
+ tools_dict: dict[str, BaseTool],
+) -> set[str]:
+ long_running_tool_ids = set()
+ for function_call in function_calls:
+ if (
+ function_call.name in tools_dict
+ and tools_dict[function_call.name].is_long_running
+ ):
+ long_running_tool_ids.add(function_call.id)
+
+ return long_running_tool_ids
+
+
+def generate_auth_event(
+ invocation_context: InvocationContext,
+ function_response_event: Event,
+) -> Optional[Event]:
+ if not function_response_event.actions.requested_auth_configs:
+ return None
+ parts = []
+ long_running_tool_ids = set()
+ for (
+ function_call_id,
+ auth_config,
+ ) in function_response_event.actions.requested_auth_configs.items():
+
+ request_euc_function_call = types.FunctionCall(
+ name=REQUEST_EUC_FUNCTION_CALL_NAME,
+ args=AuthToolArguments(
+ function_call_id=function_call_id,
+ auth_config=auth_config,
+ ).model_dump(exclude_none=True),
+ )
+ request_euc_function_call.id = generate_client_function_call_id()
+ long_running_tool_ids.add(request_euc_function_call.id)
+ parts.append(types.Part(function_call=request_euc_function_call))
+
+ return Event(
+ invocation_id=invocation_context.invocation_id,
+ author=invocation_context.agent.name,
+ branch=invocation_context.branch,
+ content=types.Content(parts=parts),
+ long_running_tool_ids=long_running_tool_ids,
+ )
+
+
+async def handle_function_calls_async(
+ invocation_context: InvocationContext,
+ function_call_event: Event,
+ tools_dict: dict[str, BaseTool],
+ filters: Optional[set[str]] = None,
+) -> Optional[Event]:
+ """Calls the functions and returns the function response event."""
+ from ...agents.llm_agent import LlmAgent
+
+ agent = invocation_context.agent
+ if not isinstance(agent, LlmAgent):
+ return
+
+ function_calls = function_call_event.get_function_calls()
+
+ function_response_events: list[Event] = []
+ for function_call in function_calls:
+ if filters and function_call.id not in filters:
+ continue
+ tool, tool_context = _get_tool_and_context(
+ invocation_context,
+ function_call_event,
+ function_call,
+ tools_dict,
+ )
+ # do not use "args" as the variable name, because it is a reserved keyword
+ # in python debugger.
+ function_args = function_call.args or {}
+ function_response = None
+ # Calls the tool if before_tool_callback does not exist or returns None.
+ if agent.before_tool_callback:
+ function_response = agent.before_tool_callback(
+ tool=tool, args=function_args, tool_context=tool_context
+ )
+
+ if not function_response:
+ function_response = await __call_tool_async(
+ tool, args=function_args, tool_context=tool_context
+ )
+
+ # Calls after_tool_callback if it exists.
+ if agent.after_tool_callback:
+ new_response = agent.after_tool_callback(
+ tool=tool,
+ args=function_args,
+ tool_context=tool_context,
+ tool_response=function_response,
+ )
+ if new_response:
+ function_response = new_response
+
+ if tool.is_long_running:
+ # Allow long running function to return None to not provide function response.
+ if not function_response:
+ continue
+
+ # Builds the function response event.
+ function_response_event = __build_response_event(
+ tool, function_response, tool_context, invocation_context
+ )
+ function_response_events.append(function_response_event)
+
+ if not function_response_events:
+ return None
+ merged_event = merge_parallel_function_response_events(
+ function_response_events
+ )
+ return merged_event
+
+
+async def handle_function_calls_live(
+ invocation_context: InvocationContext,
+ function_call_event: Event,
+ tools_dict: dict[str, BaseTool],
+) -> Event:
+ """Calls the functions and returns the function response event."""
+ from ...agents.llm_agent import LlmAgent
+
+ agent = cast(LlmAgent, invocation_context.agent)
+ function_calls = function_call_event.get_function_calls()
+
+ function_response_events: list[Event] = []
+ for function_call in function_calls:
+ tool, tool_context = _get_tool_and_context(
+ invocation_context, function_call_event, function_call, tools_dict
+ )
+ # do not use "args" as the variable name, because it is a reserved keyword
+ # in python debugger.
+ function_args = function_call.args or {}
+ function_response = None
+ # Calls the tool if before_tool_callback does not exist or returns None.
+ if agent.before_tool_callback:
+ function_response = agent.before_tool_callback(
+ tool, function_args, tool_context
+ )
+
+ if not function_response:
+ function_response = await _process_function_live_helper(
+ tool, tool_context, function_call, function_args, invocation_context
+ )
+
+ # Calls after_tool_callback if it exists.
+ if agent.after_tool_callback:
+ new_response = agent.after_tool_callback(
+ tool,
+ function_args,
+ tool_context,
+ function_response,
+ )
+ if new_response:
+ function_response = new_response
+
+ if tool.is_long_running:
+ # Allow async function to return None to not provide function response.
+ if not function_response:
+ continue
+
+ # Builds the function response event.
+ function_response_event = __build_response_event(
+ tool, function_response, tool_context, invocation_context
+ )
+ function_response_events.append(function_response_event)
+
+ if not function_response_events:
+ return None
+ merged_event = merge_parallel_function_response_events(
+ function_response_events
+ )
+ return merged_event
+
+
+async def _process_function_live_helper(
+ tool, tool_context, function_call, function_args, invocation_context
+):
+ function_response = None
+ # Check if this is a stop_streaming function call
+ if (
+ function_call.name == 'stop_streaming'
+ and 'function_name' in function_args
+ ):
+ function_name = function_args['function_name']
+ active_tasks = invocation_context.active_streaming_tools
+ if (
+ function_name in active_tasks
+ and active_tasks[function_name].task
+ and not active_tasks[function_name].task.done()
+ ):
+ task = active_tasks[function_name].task
+ task.cancel()
+ try:
+ # Wait for the task to be cancelled
+ await asyncio.wait_for(task, timeout=1.0)
+ except (asyncio.CancelledError, asyncio.TimeoutError):
+ # Log the specific condition
+ if task.cancelled():
+ logging.info(f'Task {function_name} was cancelled successfully')
+ elif task.done():
+ logging.info(f'Task {function_name} completed during cancellation')
+ else:
+ logging.warning(
+ f'Task {function_name} might still be running after'
+ ' cancellation timeout'
+ )
+ function_response = {
+ 'status': f'The task is not cancelled yet for {function_name}.'
+ }
+ if not function_response:
+ # Clean up the reference
+ active_tasks[function_name].task = None
+
+ function_response = {
+ 'status': f'Successfully stopped streaming function {function_name}'
+ }
+ else:
+ function_response = {
+ 'status': f'No active streaming function named {function_name} found'
+ }
+ elif inspect.isasyncgenfunction(tool.func):
+ print('is async')
+
+ # for streaming tool use case
+ # we require the function to be a async generator function
+ async def run_tool_and_update_queue(tool, function_args, tool_context):
+ try:
+ async for result in __call_tool_live(
+ tool=tool,
+ args=function_args,
+ tool_context=tool_context,
+ invocation_context=invocation_context,
+ ):
+ updated_content = types.Content(
+ role='user',
+ parts=[
+ types.Part.from_text(
+ text=f'Function {tool.name} returned: {result}'
+ )
+ ],
+ )
+ invocation_context.live_request_queue.send_content(updated_content)
+ except asyncio.CancelledError:
+ raise # Re-raise to properly propagate the cancellation
+
+ task = asyncio.create_task(
+ run_tool_and_update_queue(tool, function_args, tool_context)
+ )
+ if invocation_context.active_streaming_tools is None:
+ invocation_context.active_streaming_tools = {}
+ if tool.name in invocation_context.active_streaming_tools:
+ invocation_context.active_streaming_tools[tool.name].task = task
+ else:
+ invocation_context.active_streaming_tools[tool.name] = (
+ ActiveStreamingTool(task=task)
+ )
+ # Immediately return a pending response.
+ # This is required by current live model.
+ function_response = {
+ 'status': (
+ 'The function is running asynchronously and the results are'
+ ' pending.'
+ )
+ }
+ else:
+ function_response = await __call_tool_async(
+ tool, args=function_args, tool_context=tool_context
+ )
+ return function_response
+
+
+def _get_tool_and_context(
+ invocation_context: InvocationContext,
+ function_call_event: Event,
+ function_call: types.FunctionCall,
+ tools_dict: dict[str, BaseTool],
+):
+ if function_call.name not in tools_dict:
+ raise ValueError(
+ f'Function {function_call.name} is not found in the tools_dict.'
+ )
+
+ tool_context = ToolContext(
+ invocation_context=invocation_context,
+ function_call_id=function_call.id,
+ )
+
+ tool = tools_dict[function_call.name]
+
+ return (tool, tool_context)
+
+
+async def __call_tool_live(
+ tool: BaseTool,
+ args: dict[str, object],
+ tool_context: ToolContext,
+ invocation_context: InvocationContext,
+) -> AsyncGenerator[Event, None]:
+ """Calls the tool asynchronously (awaiting the coroutine)."""
+ with tracer.start_as_current_span(f'call_tool [{tool.name}]'):
+ async for item in tool._call_live(
+ args=args,
+ tool_context=tool_context,
+ invocation_context=invocation_context,
+ ):
+ yield item
+
+
+async def __call_tool_async(
+ tool: BaseTool,
+ args: dict[str, Any],
+ tool_context: ToolContext,
+) -> Any:
+ """Calls the tool."""
+ with tracer.start_as_current_span(f'call_tool [{tool.name}]'):
+ return await tool.run_async(args=args, tool_context=tool_context)
+
+
+def __build_response_event(
+ tool: BaseTool,
+ function_result: dict[str, object],
+ tool_context: ToolContext,
+ invocation_context: InvocationContext,
+) -> Event:
+ # Specs requires the result to be a dict.
+ if not isinstance(function_result, dict):
+ function_result = {'result': function_result}
+
+ part_function_response = types.Part.from_function_response(
+ name=tool.name, response=function_result
+ )
+ part_function_response.function_response.id = tool_context.function_call_id
+
+ content = types.Content(
+ role='user',
+ parts=[part_function_response],
+ )
+ return Event(
+ invocation_id=invocation_context.invocation_id,
+ author=invocation_context.agent.name,
+ content=content,
+ actions=tool_context.actions,
+ branch=invocation_context.branch,
+ )
+
+
+def merge_parallel_function_response_events(
+ function_response_events: list['Event'],
+) -> 'Event':
+ if not function_response_events:
+ raise ValueError('No function response events provided.')
+
+ if len(function_response_events) == 1:
+ return function_response_events[0]
+ merged_parts = []
+ for event in function_response_events:
+ if event.content:
+ for part in event.content.parts or []:
+ merged_parts.append(part)
+
+ # Use the first event as the "base" for common attributes
+ base_event = function_response_events[0]
+
+ # Merge actions from all events
+
+ merged_actions = EventActions()
+ merged_requested_auth_configs = {}
+ for event in function_response_events:
+ merged_requested_auth_configs.update(event.actions.requested_auth_configs)
+ merged_actions = merged_actions.model_copy(
+ update=event.actions.model_dump()
+ )
+ merged_actions.requested_auth_configs = merged_requested_auth_configs
+ # Create the new merged event
+ merged_event = Event(
+ invocation_id=Event.new_id(),
+ author=base_event.author,
+ branch=base_event.branch,
+ content=types.Content(role='user', parts=merged_parts),
+ actions=merged_actions, # Optionally merge actions if required
+ )
+
+ # Use the base_event as the timestamp
+ merged_event.timestamp = base_event.timestamp
+ return merged_event
diff --git a/src/google/adk/flows/llm_flows/identity.py b/src/google/adk/flows/llm_flows/identity.py
new file mode 100644
index 0000000..9a94821
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/identity.py
@@ -0,0 +1,47 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Gives the agent identity from the framework."""
+
+from __future__ import annotations
+
+from typing import AsyncGenerator
+
+from typing_extensions import override
+
+from ...agents.invocation_context import InvocationContext
+from ...events.event import Event
+from ...models.llm_request import LlmRequest
+from ._base_llm_processor import BaseLlmRequestProcessor
+
+
+class _IdentityLlmRequestProcessor(BaseLlmRequestProcessor):
+ """Gives the agent identity from the framework."""
+
+ @override
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_request: LlmRequest
+ ) -> AsyncGenerator[Event, None]:
+ agent = invocation_context.agent
+ si = [f'You are an agent. Your internal name is "{agent.name}".']
+ if agent.description:
+ si.append(f' The description about you is "{agent.description}"')
+ llm_request.append_instructions(si)
+
+ # Maintain async generator behavior
+ if False: # Ensures it behaves as a generator
+ yield # This is a no-op but maintains generator structure
+
+
+request_processor = _IdentityLlmRequestProcessor()
diff --git a/src/google/adk/flows/llm_flows/instructions.py b/src/google/adk/flows/llm_flows/instructions.py
new file mode 100644
index 0000000..02af9dc
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/instructions.py
@@ -0,0 +1,137 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Handles instructions and global instructions for LLM flow."""
+
+from __future__ import annotations
+
+import re
+from typing import AsyncGenerator
+from typing import Generator
+from typing import TYPE_CHECKING
+
+from typing_extensions import override
+
+from ...agents.readonly_context import ReadonlyContext
+from ...events.event import Event
+from ...sessions.state import State
+from ._base_llm_processor import BaseLlmRequestProcessor
+
+if TYPE_CHECKING:
+ from ...agents.invocation_context import InvocationContext
+ from ...models.llm_request import LlmRequest
+
+
+class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
+ """Handles instructions and global instructions for LLM flow."""
+
+ @override
+ async def run_async(
+ self, invocation_context: InvocationContext, llm_request: LlmRequest
+ ) -> AsyncGenerator[Event, None]:
+ from ...agents.base_agent import BaseAgent
+ from ...agents.llm_agent import LlmAgent
+
+ agent = invocation_context.agent
+ if not isinstance(agent, LlmAgent):
+ return
+
+ root_agent: BaseAgent = agent.root_agent
+
+ # Appends global instructions if set.
+ if (
+ isinstance(root_agent, LlmAgent) and root_agent.global_instruction
+ ): # not emtpy str
+ raw_si = root_agent.canonical_global_instruction(
+ ReadonlyContext(invocation_context)
+ )
+ si = _populate_values(raw_si, invocation_context)
+ llm_request.append_instructions([si])
+
+ # Appends agent instructions if set.
+ if agent.instruction: # not emtpy str
+ raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
+ si = _populate_values(raw_si, invocation_context)
+ llm_request.append_instructions([si])
+
+ # Maintain async generator behavior
+ if False: # Ensures it behaves as a generator
+ yield # This is a no-op but maintains generator structure
+
+
+request_processor = _InstructionsLlmRequestProcessor()
+
+
+def _populate_values(
+ instruction_template: str,
+ context: InvocationContext,
+) -> str:
+ """Populates values in the instruction template, e.g. state, artifact, etc."""
+
+ def _replace_match(match) -> str:
+ var_name = match.group().lstrip('{').rstrip('}').strip()
+ optional = False
+ if var_name.endswith('?'):
+ optional = True
+ var_name = var_name.removesuffix('?')
+ if var_name.startswith('artifact.'):
+ var_name = var_name.removeprefix('artifact.')
+ if context.artifact_service is None:
+ raise ValueError('Artifact service is not initialized.')
+ artifact = context.artifact_service.load_artifact(
+ app_name=context.session.app_name,
+ user_id=context.session.user_id,
+ session_id=context.session.id,
+ filename=var_name,
+ )
+ if not var_name:
+ raise KeyError(f'Artifact {var_name} not found.')
+ return str(artifact)
+ else:
+ if not _is_valid_state_name(var_name):
+ return match.group()
+ if var_name in context.session.state:
+ return str(context.session.state[var_name])
+ else:
+ if optional:
+ return ''
+ else:
+ raise KeyError(f'Context variable not found: `{var_name}`.')
+
+ return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template)
+
+
+def _is_valid_state_name(var_name):
+ """Checks if the variable name is a valid state name.
+
+ Valid state is either:
+ - Valid identifier
+ - :
+ All the others will just return as it is.
+
+ Args:
+ var_name: The variable name to check.
+
+ Returns:
+ True if the variable name is a valid state name, False otherwise.
+ """
+ parts = var_name.split(':')
+ if len(parts) == 1:
+ return var_name.isidentifier()
+
+ if len(parts) == 2:
+ prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
+ if (parts[0] + ':') in prefixes:
+ return parts[1].isidentifier()
+ return False
diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py
new file mode 100644
index 0000000..8d3239c
--- /dev/null
+++ b/src/google/adk/flows/llm_flows/single_flow.py
@@ -0,0 +1,57 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implementation of single flow."""
+
+import logging
+
+from ...auth import auth_preprocessor
+from . import _code_execution
+from . import _nl_planning
+from . import basic
+from . import contents
+from . import identity
+from . import instructions
+from .base_llm_flow import BaseLlmFlow
+
+logger = logging.getLogger(__name__)
+
+
+class SingleFlow(BaseLlmFlow):
+ """SingleFlow is the LLM flows that handles tools calls.
+
+ A single flow only consider an agent itself and tools.
+ No sub-agents are allowed for single flow.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.request_processors += [
+ basic.request_processor,
+ auth_preprocessor.request_processor,
+ instructions.request_processor,
+ identity.request_processor,
+ contents.request_processor,
+ # Some implementations of NL Planning mark planning contents as thoughts
+ # in the post processor. Since these need to be unmarked, NL Planning
+ # should be after contents.
+ _nl_planning.request_processor,
+ # Code execution should be after the contents as it mutates the contents
+ # to optimize data files.
+ _code_execution.request_processor,
+ ]
+ self.response_processors += [
+ _nl_planning.response_processor,
+ _code_execution.response_processor,
+ ]
diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py
new file mode 100644
index 0000000..473e315
--- /dev/null
+++ b/src/google/adk/memory/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from .base_memory_service import BaseMemoryService
+from .in_memory_memory_service import InMemoryMemoryService
+
+logger = logging.getLogger(__name__)
+
+__all__ = [
+ 'BaseMemoryService',
+ 'InMemoryMemoryService',
+]
+
+try:
+ from .vertex_ai_rag_memory_service import VertexAiRagMemoryService
+
+ __all__.append('VertexAiRagMemoryService')
+except ImportError:
+ logger.debug(
+ 'The Vertex sdk is not installed. If you want to use the'
+ ' VertexAiRagMemoryService please install it. If not, you can ignore this'
+ ' warning.'
+ )
diff --git a/src/google/adk/memory/base_memory_service.py b/src/google/adk/memory/base_memory_service.py
new file mode 100644
index 0000000..8da0c5d
--- /dev/null
+++ b/src/google/adk/memory/base_memory_service.py
@@ -0,0 +1,74 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from ..events.event import Event
+from ..sessions.session import Session
+
+
+class MemoryResult(BaseModel):
+ """Represents a single memory retrieval result.
+
+ Attributes:
+ session_id: The session id associated with the memory.
+ events: A list of events in the session.
+ """
+ session_id: str
+ events: list[Event]
+
+
+class SearchMemoryResponse(BaseModel):
+ """Represents the response from a memory search.
+
+ Attributes:
+ memories: A list of memory results matching the search query.
+ """
+ memories: list[MemoryResult] = Field(default_factory=list)
+
+
+class BaseMemoryService(abc.ABC):
+ """Base class for memory services.
+
+ The service provides functionalities to ingest sessions into memory so that
+ the memory can be used for user queries.
+ """
+
+ @abc.abstractmethod
+ def add_session_to_memory(self, session: Session):
+ """Adds a session to the memory service.
+
+ A session may be added multiple times during its lifetime.
+
+ Args:
+ session: The session to add.
+ """
+
+ @abc.abstractmethod
+ def search_memory(
+ self, *, app_name: str, user_id: str, query: str
+ ) -> SearchMemoryResponse:
+ """Searches for sessions that match the query.
+
+ Args:
+ app_name: The name of the application.
+ user_id: The id of the user.
+ query: The query to search for.
+
+ Returns:
+ A SearchMemoryResponse containing the matching memories.
+ """
diff --git a/src/google/adk/memory/in_memory_memory_service.py b/src/google/adk/memory/in_memory_memory_service.py
new file mode 100644
index 0000000..8976344
--- /dev/null
+++ b/src/google/adk/memory/in_memory_memory_service.py
@@ -0,0 +1,62 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..events.event import Event
+from ..sessions.session import Session
+from .base_memory_service import BaseMemoryService
+from .base_memory_service import MemoryResult
+from .base_memory_service import SearchMemoryResponse
+
+
+class InMemoryMemoryService(BaseMemoryService):
+ """An in-memory memory service for prototyping purpose only.
+
+ Uses keyword matching instead of semantic search.
+ """
+
+ def __init__(self):
+ self.session_events: dict[str, list[Event]] = {}
+ """keys are app_name/user_id/session_id"""
+
+ def add_session_to_memory(self, session: Session):
+ key = f'{session.app_name}/{session.user_id}/{session.id}'
+ self.session_events[key] = [
+ event for event in session.events if event.content
+ ]
+
+ def search_memory(
+ self, *, app_name: str, user_id: str, query: str
+ ) -> SearchMemoryResponse:
+ """Prototyping purpose only."""
+ keywords = set(query.lower().split())
+ response = SearchMemoryResponse()
+ for key, events in self.session_events.items():
+ if not key.startswith(f'{app_name}/{user_id}/'):
+ continue
+ matched_events = []
+ for event in events:
+ if not event.content or not event.content.parts:
+ continue
+ parts = event.content.parts
+ text = '\n'.join([part.text for part in parts if part.text]).lower()
+ for keyword in keywords:
+ if keyword in text:
+ matched_events.append(event)
+ break
+ if matched_events:
+ session_id = key.split('/')[-1]
+ response.memories.append(
+ MemoryResult(session_id=session_id, events=matched_events)
+ )
+ return response
diff --git a/src/google/adk/memory/vertex_ai_rag_memory_service.py b/src/google/adk/memory/vertex_ai_rag_memory_service.py
new file mode 100644
index 0000000..3582260
--- /dev/null
+++ b/src/google/adk/memory/vertex_ai_rag_memory_service.py
@@ -0,0 +1,177 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+import json
+import os
+import tempfile
+
+from google.genai import types
+from typing_extensions import override
+from vertexai.preview import rag
+
+from ..events.event import Event
+from ..sessions.session import Session
+from .base_memory_service import BaseMemoryService
+from .base_memory_service import MemoryResult
+from .base_memory_service import SearchMemoryResponse
+
+
+class VertexAiRagMemoryService(BaseMemoryService):
+ """A memory service that uses Vertex AI RAG for storage and retrieval."""
+
+ def __init__(
+ self,
+ rag_corpus: str = None,
+ similarity_top_k: int = None,
+ vector_distance_threshold: float = 10,
+ ):
+ """Initializes a VertexAiRagMemoryService.
+
+ Args:
+ rag_corpus: The name of the Vertex AI RAG corpus to use. Format:
+ ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}``
+ or ``{rag_corpus_id}``
+ similarity_top_k: The number of contexts to retrieve.
+ vector_distance_threshold: Only returns contexts with vector distance
+ smaller than the threshold..
+ """
+ self.vertex_rag_store = types.VertexRagStore(
+ rag_resources=[rag.RagResource(rag_corpus=rag_corpus)],
+ similarity_top_k=similarity_top_k,
+ vector_distance_threshold=vector_distance_threshold,
+ )
+
+ @override
+ def add_session_to_memory(self, session: Session):
+ with tempfile.NamedTemporaryFile(
+ mode="w", delete=False, suffix=".txt"
+ ) as temp_file:
+
+ output_lines = []
+ for event in session.events:
+ if not event.content or not event.content.parts:
+ continue
+ text_parts = [
+ part.text.replace("\n", " ")
+ for part in event.content.parts
+ if part.text
+ ]
+ if text_parts:
+ output_lines.append(
+ json.dumps({
+ "author": event.author,
+ "timestamp": event.timestamp,
+ "text": ".".join(text_parts),
+ })
+ )
+ output_string = "\n".join(output_lines)
+ temp_file.write(output_string)
+ temp_file_path = temp_file.name
+ for rag_resource in self.vertex_rag_store.rag_resources:
+ rag.upload_file(
+ corpus_name=rag_resource.rag_corpus,
+ path=temp_file_path,
+ # this is the temp workaround as upload file does not support
+ # adding metadata, thus use display_name to store the session info.
+ display_name=f"{session.app_name}.{session.user_id}.{session.id}",
+ )
+
+ os.remove(temp_file_path)
+
+ @override
+ def search_memory(
+ self, *, app_name: str, user_id: str, query: str
+ ) -> SearchMemoryResponse:
+ """Searches for sessions that match the query using rag.retrieval_query."""
+ response = rag.retrieval_query(
+ text=query,
+ rag_resources=self.vertex_rag_store.rag_resources,
+ rag_corpora=self.vertex_rag_store.rag_corpora,
+ similarity_top_k=self.vertex_rag_store.similarity_top_k,
+ vector_distance_threshold=self.vertex_rag_store.vector_distance_threshold,
+ )
+
+ memory_results = []
+ session_events_map = OrderedDict()
+ for context in response.contexts.contexts:
+ # filter out context that is not related
+ # TODO: Add server side filtering by app_name and user_id.
+ # if not context.source_display_name.startswith(f"{app_name}.{user_id}."):
+ # continue
+ session_id = context.source_display_name.split(".")[-1]
+ events = []
+ if context.text:
+ lines = context.text.split("\n")
+
+ for line in lines:
+ line = line.strip()
+ if not line:
+ continue
+
+ try:
+ # Try to parse as JSON
+ event_data = json.loads(line)
+
+ author = event_data.get("author", "")
+ timestamp = float(event_data.get("timestamp", 0))
+ text = event_data.get("text", "")
+
+ content = types.Content(parts=[types.Part(text=text)])
+ event = Event(author=author, timestamp=timestamp, content=content)
+ events.append(event)
+ except json.JSONDecodeError:
+ # Not valid JSON, skip this line
+ continue
+
+ if session_id in session_events_map:
+ session_events_map[session_id].append(events)
+ else:
+ session_events_map[session_id] = [events]
+
+ # Remove overlap and combine events from the same session.
+ for session_id, event_lists in session_events_map.items():
+ for events in _merge_event_lists(event_lists):
+ sorted_events = sorted(events, key=lambda e: e.timestamp)
+ memory_results.append(
+ MemoryResult(session_id=session_id, events=sorted_events)
+ )
+ return SearchMemoryResponse(memories=memory_results)
+
+
+def _merge_event_lists(event_lists: list[list[Event]]) -> list[list[Event]]:
+ """Merge event lists that have overlapping timestamps."""
+ merged = []
+ while event_lists:
+ current = event_lists.pop(0)
+ current_ts = {event.timestamp for event in current}
+ merge_found = True
+
+ # Keep merging until no new overlap is found.
+ while merge_found:
+ merge_found = False
+ remaining = []
+ for other in event_lists:
+ other_ts = {event.timestamp for event in other}
+ # Overlap exists, so we merge and use the merged list to check again
+ if current_ts & other_ts:
+ new_events = [e for e in other if e.timestamp not in current_ts]
+ current.extend(new_events)
+ current_ts.update(e.timestamp for e in new_events)
+ merge_found = True
+ else:
+ remaining.append(other)
+ event_lists = remaining
+ merged.append(current)
+ return merged
diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py
new file mode 100644
index 0000000..fc86c19
--- /dev/null
+++ b/src/google/adk/models/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the interface to support a model."""
+
+from .base_llm import BaseLlm
+from .google_llm import Gemini
+from .llm_request import LlmRequest
+from .llm_response import LlmResponse
+from .registry import LLMRegistry
+
+__all__ = [
+ 'BaseLlm',
+ 'Gemini',
+ 'LLMRegistry',
+]
+
+
+for regex in Gemini.supported_models():
+ LLMRegistry.register(Gemini)
diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py
new file mode 100644
index 0000000..4d625a0
--- /dev/null
+++ b/src/google/adk/models/anthropic_llm.py
@@ -0,0 +1,243 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Anthropic integration for Claude models."""
+
+from __future__ import annotations
+
+from functools import cached_property
+import logging
+import os
+from typing import AsyncGenerator
+from typing import Generator
+from typing import Iterable
+from typing import Literal
+from typing import Optional, Union
+from typing import TYPE_CHECKING
+
+from anthropic import AnthropicVertex
+from anthropic import NOT_GIVEN
+from anthropic import types as anthropic_types
+from google.genai import types
+from pydantic import BaseModel
+from typing_extensions import override
+
+from .base_llm import BaseLlm
+from .llm_response import LlmResponse
+
+if TYPE_CHECKING:
+ from .llm_request import LlmRequest
+
+__all__ = ["Claude"]
+
+logger = logging.getLogger(__name__)
+
+MAX_TOKEN = 1024
+
+
+class ClaudeRequest(BaseModel):
+ system_instruction: str
+ messages: Iterable[anthropic_types.MessageParam]
+ tools: list[anthropic_types.ToolParam]
+
+
+def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
+ if role in ["model", "assistant"]:
+ return "assistant"
+ return "user"
+
+
+def to_google_genai_finish_reason(
+ anthropic_stop_reason: Optional[str],
+) -> types.FinishReason:
+ if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
+ return "STOP"
+ if anthropic_stop_reason == "max_tokens":
+ return "MAX_TOKENS"
+ return "FINISH_REASON_UNSPECIFIED"
+
+
+def part_to_message_block(
+ part: types.Part,
+) -> Union[
+ anthropic_types.TextBlockParam,
+ anthropic_types.ImageBlockParam,
+ anthropic_types.ToolUseBlockParam,
+ anthropic_types.ToolResultBlockParam,
+]:
+ if part.text:
+ return anthropic_types.TextBlockParam(text=part.text, type="text")
+ if part.function_call:
+ assert part.function_call.name
+
+ return anthropic_types.ToolUseBlockParam(
+ id=part.function_call.id or "",
+ name=part.function_call.name,
+ input=part.function_call.args,
+ type="tool_use",
+ )
+ if part.function_response:
+ content = ""
+ if (
+ "result" in part.function_response.response
+ and part.function_response.response["result"]
+ ):
+ # Transformation is required because the content is a list of dict.
+ # ToolResultBlockParam content doesn't support list of dict. Converting
+ # to str to prevent anthropic.BadRequestError from being thrown.
+ content = str(part.function_response.response["result"])
+ return anthropic_types.ToolResultBlockParam(
+ tool_use_id=part.function_response.id or "",
+ type="tool_result",
+ content=content,
+ is_error=False,
+ )
+ raise NotImplementedError("Not supported yet.")
+
+
+def content_to_message_param(
+ content: types.Content,
+) -> anthropic_types.MessageParam:
+ return {
+ "role": to_claude_role(content.role),
+ "content": [part_to_message_block(part) for part in content.parts or []],
+ }
+
+
+def content_block_to_part(
+ content_block: anthropic_types.ContentBlock,
+) -> types.Part:
+ if isinstance(content_block, anthropic_types.TextBlock):
+ return types.Part.from_text(text=content_block.text)
+ if isinstance(content_block, anthropic_types.ToolUseBlock):
+ assert isinstance(content_block.input, dict)
+ part = types.Part.from_function_call(
+ name=content_block.name, args=content_block.input
+ )
+ part.function_call.id = content_block.id
+ return part
+ raise NotImplementedError("Not supported yet.")
+
+
+def message_to_generate_content_response(
+ message: anthropic_types.Message,
+) -> LlmResponse:
+
+ return LlmResponse(
+ content=types.Content(
+ role="model",
+ parts=[content_block_to_part(cb) for cb in message.content],
+ ),
+ # TODO: Deal with these later.
+ # finish_reason=to_google_genai_finish_reason(message.stop_reason),
+ # usage_metadata=types.GenerateContentResponseUsageMetadata(
+ # prompt_token_count=message.usage.input_tokens,
+ # candidates_token_count=message.usage.output_tokens,
+ # total_token_count=(
+ # message.usage.input_tokens + message.usage.output_tokens
+ # ),
+ # ),
+ )
+
+
+def function_declaration_to_tool_param(
+ function_declaration: types.FunctionDeclaration,
+) -> anthropic_types.ToolParam:
+ assert function_declaration.name
+
+ properties = {}
+ if (
+ function_declaration.parameters
+ and function_declaration.parameters.properties
+ ):
+ for key, value in function_declaration.parameters.properties.items():
+ value_dict = value.model_dump(exclude_none=True)
+ if "type" in value_dict:
+ value_dict["type"] = value_dict["type"].lower()
+ properties[key] = value_dict
+
+ return anthropic_types.ToolParam(
+ name=function_declaration.name,
+ description=function_declaration.description or "",
+ input_schema={
+ "type": "object",
+ "properties": properties,
+ },
+ )
+
+
+class Claude(BaseLlm):
+ model: str = "claude-3-5-sonnet-v2@20241022"
+
+ @staticmethod
+ @override
+ def supported_models() -> list[str]:
+ return [r"claude-3-.*"]
+
+ @override
+ async def generate_content_async(
+ self, llm_request: LlmRequest, stream: bool = False
+ ) -> AsyncGenerator[LlmResponse, None]:
+ messages = [
+ content_to_message_param(content)
+ for content in llm_request.contents or []
+ ]
+ tools = NOT_GIVEN
+ if (
+ llm_request.config
+ and llm_request.config.tools
+ and llm_request.config.tools[0].function_declarations
+ ):
+ tools = [
+ function_declaration_to_tool_param(tool)
+ for tool in llm_request.config.tools[0].function_declarations
+ ]
+ tool_choice = (
+ anthropic_types.ToolChoiceAutoParam(
+ type="auto",
+ # TODO: allow parallel tool use.
+ disable_parallel_tool_use=True,
+ )
+ if llm_request.tools_dict
+ else NOT_GIVEN
+ )
+ message = self._anthropic_client.messages.create(
+ model=llm_request.model,
+ system=llm_request.config.system_instruction,
+ messages=messages,
+ tools=tools,
+ tool_choice=tool_choice,
+ max_tokens=MAX_TOKEN,
+ )
+ logger.info(
+ "Claude response: %s",
+ message.model_dump_json(indent=2, exclude_none=True),
+ )
+ yield message_to_generate_content_response(message)
+
+ @cached_property
+ def _anthropic_client(self) -> AnthropicVertex:
+ if (
+ "GOOGLE_CLOUD_PROJECT" not in os.environ
+ or "GOOGLE_CLOUD_LOCATION" not in os.environ
+ ):
+ raise ValueError(
+ "GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
+ " Anthropic on Vertex."
+ )
+
+ return AnthropicVertex(
+ project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
+ region=os.environ["GOOGLE_CLOUD_LOCATION"],
+ )
diff --git a/src/google/adk/models/base_llm.py b/src/google/adk/models/base_llm.py
new file mode 100644
index 0000000..4876434
--- /dev/null
+++ b/src/google/adk/models/base_llm.py
@@ -0,0 +1,87 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import AsyncGenerator
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel
+from pydantic import ConfigDict
+
+from .base_llm_connection import BaseLlmConnection
+
+if TYPE_CHECKING:
+ from .llm_request import LlmRequest
+ from .llm_response import LlmResponse
+
+
+class BaseLlm(BaseModel):
+ """The BaseLLM class.
+
+ Attributes:
+ model: The name of the LLM, e.g. gemini-1.5-flash or gemini-1.5-flash-001.
+ model_config: The model config
+ """
+
+ model_config = ConfigDict(
+ # This allows us to use arbitrary types in the model. E.g. PIL.Image.
+ arbitrary_types_allowed=True,
+ )
+ """The model config."""
+
+ model: str
+ """The name of the LLM, e.g. gemini-1.5-flash or gemini-1.5-flash-001."""
+
+ @classmethod
+ def supported_models(cls) -> list[str]:
+ """Returns a list of supported models in regex for LlmRegistry."""
+ return []
+
+ @abstractmethod
+ async def generate_content_async(
+ self, llm_request: LlmRequest, stream: bool = False
+ ) -> AsyncGenerator[LlmResponse, None]:
+ """Generates one content from the given contents and tools.
+
+ Args:
+ llm_request: LlmRequest, the request to send to the LLM.
+ stream: bool = False, whether to do streaming call.
+
+ Yields:
+ a generator of types.Content.
+
+ For non-streaming call, it will only yield one Content.
+
+ For streaming call, it may yield more than one content, but all yielded
+ contents should be treated as one content by merging the
+ parts list.
+ """
+ raise NotImplementedError(
+ f'Async generation is not supported for {self.model}.'
+ )
+ yield # AsyncGenerator requires a yield statement in function body.
+
+ def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
+ """Creates a live connection to the LLM.
+
+ Args:
+ llm_request: LlmRequest, the request to send to the LLM.
+
+ Returns:
+ BaseLlmConnection, the connection to the LLM.
+ """
+ raise NotImplementedError(
+ f'Live connection is not supported for {self.model}.'
+ )
diff --git a/src/google/adk/models/base_llm_connection.py b/src/google/adk/models/base_llm_connection.py
new file mode 100644
index 0000000..90be8fb
--- /dev/null
+++ b/src/google/adk/models/base_llm_connection.py
@@ -0,0 +1,76 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+from typing import AsyncGenerator
+from google.genai import types
+from .llm_response import LlmResponse
+
+
+class BaseLlmConnection:
+ """The base class for a live model connection."""
+
+ @abstractmethod
+ async def send_history(self, history: list[types.Content]):
+ """Sends the conversation history to the model.
+
+ You call this method right after setting up the model connection.
+ The model will respond if the last content is from user, otherwise it will
+ wait for new user input before responding.
+
+ Args:
+ history: The conversation history to send to the model.
+ """
+ pass
+
+ @abstractmethod
+ async def send_content(self, content: types.Content):
+ """Sends a user content to the model.
+
+ The model will respond immediately upon receiving the content.
+ If you send function responses, all parts in the content should be function
+ responses.
+
+ Args:
+ content: The content to send to the model.
+ """
+ pass
+
+ @abstractmethod
+ async def send_realtime(self, blob: types.Blob):
+ """Sends a chunk of audio or a frame of video to the model in realtime.
+
+ The model may not respond immediately upon receiving the blob. It will do
+ voice activity detection and decide when to respond.
+
+ Args:
+ blob: The blob to send to the model.
+ """
+ pass
+
+ @abstractmethod
+ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
+ """Receives the model response using the llm server connection.
+
+ Args: None.
+
+ Yields:
+ LlmResponse: The model response.
+ """
+ pass
+
+ @abstractmethod
+ async def close(self):
+ """Closes the llm server connection."""
+ pass
diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
new file mode 100644
index 0000000..bbf35e2
--- /dev/null
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -0,0 +1,200 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import AsyncGenerator
+
+from google.genai import live
+from google.genai import types
+
+from .base_llm_connection import BaseLlmConnection
+from .llm_response import LlmResponse
+
+logger = logging.getLogger(__name__)
+
+
+class GeminiLlmConnection(BaseLlmConnection):
+ """The Gemini model connection."""
+
+ def __init__(self, gemini_session: live.AsyncSession):
+ self._gemini_session = gemini_session
+
+ async def send_history(self, history: list[types.Content]):
+ """Sends the conversation history to the gemini model.
+
+ You call this method right after setting up the model connection.
+ The model will respond if the last content is from user, otherwise it will
+ wait for new user input before responding.
+
+ Args:
+ history: The conversation history to send to the model.
+ """
+
+ # TODO: Remove this filter and translate unary contents to streaming
+ # contents properly.
+
+ # We ignore any audio from user during the agent transfer phase
+ contents = [
+ content
+ for content in history
+ if content.parts and content.parts[0].text
+ ]
+
+ if contents:
+ await self._gemini_session.send(
+ input=types.LiveClientContent(
+ turns=contents,
+ turn_complete=contents[-1].role == 'user',
+ ),
+ )
+ else:
+ logger.info('no content is sent')
+
+ async def send_content(self, content: types.Content):
+ """Sends a user content to the gemini model.
+
+ The model will respond immediately upon receiving the content.
+ If you send function responses, all parts in the content should be function
+ responses.
+
+ Args:
+ content: The content to send to the model.
+ """
+
+ assert content.parts
+ if content.parts[0].function_response:
+ # All parts have to be function responses.
+ function_responses = [part.function_response for part in content.parts]
+ logger.debug('Sending LLM function response: %s', function_responses)
+ await self._gemini_session.send(
+ input=types.LiveClientToolResponse(
+ function_responses=function_responses
+ ),
+ )
+ else:
+ logger.debug('Sending LLM new content %s', content)
+ await self._gemini_session.send(
+ input=types.LiveClientContent(
+ turns=[content],
+ turn_complete=True,
+ )
+ )
+
+ async def send_realtime(self, blob: types.Blob):
+ """Sends a chunk of audio or a frame of video to the model in realtime.
+
+ Args:
+ blob: The blob to send to the model.
+ """
+
+ input_blob = blob.model_dump()
+ logger.debug('Sending LLM Blob: %s', input_blob)
+ await self._gemini_session.send(input=input_blob)
+
+ def __build_full_text_response(self, text: str):
+ """Builds a full text response.
+
+ The text should not partial and the returned LlmResponse is not be
+ partial.
+
+ Args:
+ text: The text to be included in the response.
+
+ Returns:
+ An LlmResponse containing the full text.
+ """
+ return LlmResponse(
+ content=types.Content(
+ role='model',
+ parts=[types.Part.from_text(text=text)],
+ ),
+ )
+
+ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
+ """Receives the model response using the llm server connection.
+
+ Yields:
+ LlmResponse: The model response.
+ """
+
+ text = ''
+ async for message in self._gemini_session.receive():
+ logger.debug('Got LLM Live message: %s', message)
+ if message.server_content:
+ content = message.server_content.model_turn
+ if content and content.parts:
+ llm_response = LlmResponse(
+ content=content, interrupted=message.server_content.interrupted
+ )
+ if content.parts[0].text:
+ text += content.parts[0].text
+ llm_response.partial = True
+ # don't yield the merged text event when receiving audio data
+ elif text and not content.parts[0].inline_data:
+ yield self.__build_full_text_response(text)
+ text = ''
+ yield llm_response
+
+ if (
+ message.server_content.output_transcription
+ and message.server_content.output_transcription.text
+ ):
+ # TODO: Right now, we just support output_transcription without
+ # changing interface and data protocol. Later, we can consider to
+ # support output_transcription as a separete field in LlmResponse.
+
+ # Transcription is always considered as partial event
+ # We rely on other control signals to determine when to yield the
+ # full text response(turn_complete, interrupted, or tool_call).
+ text += message.server_content.output_transcription.text
+ parts = [
+ types.Part.from_text(
+ text=message.server_content.output_transcription.text
+ )
+ ]
+ llm_response = LlmResponse(
+ content=types.Content(role='model', parts=parts), partial=True
+ )
+ yield llm_response
+
+ if message.server_content.turn_complete:
+ if text:
+ yield self.__build_full_text_response(text)
+ text = ''
+ yield LlmResponse(
+ turn_complete=True, interrupted=message.server_content.interrupted
+ )
+ break
+ # in case of empty content or parts, we sill surface it
+ # in case it's an interrupted message, we merge the previous partial
+ # text. Other we don't merge. because content can be none when model
+ # safty threshold is triggered
+ if message.server_content.interrupted and text:
+ yield self.__build_full_text_response(text)
+ text = ''
+ yield LlmResponse(interrupted=message.server_content.interrupted)
+ if message.tool_call:
+ if text:
+ yield self.__build_full_text_response(text)
+ text = ''
+ parts = [
+ types.Part(function_call=function_call)
+ for function_call in message.tool_call.function_calls
+ ]
+ yield LlmResponse(content=types.Content(role='model', parts=parts))
+
+ async def close(self):
+ """Closes the llm server connection."""
+
+ await self._gemini_session.close()
diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py
new file mode 100644
index 0000000..79d0d49
--- /dev/null
+++ b/src/google/adk/models/google_llm.py
@@ -0,0 +1,331 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import contextlib
+from functools import cached_property
+import logging
+import sys
+from typing import AsyncGenerator
+from typing import cast
+from typing import Generator
+from typing import TYPE_CHECKING
+
+from google.genai import Client
+from google.genai import types
+from typing_extensions import override
+
+from .. import version
+from .base_llm import BaseLlm
+from .base_llm_connection import BaseLlmConnection
+from .gemini_llm_connection import GeminiLlmConnection
+from .llm_response import LlmResponse
+
+if TYPE_CHECKING:
+ from .llm_request import LlmRequest
+
+logger = logging.getLogger(__name__)
+
+_NEW_LINE = '\n'
+_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
+
+
+class Gemini(BaseLlm):
+ """Integration for Gemini models.
+
+ Attributes:
+ model: The name of the Gemini model.
+ """
+
+ model: str = 'gemini-1.5-flash'
+
+ @staticmethod
+ @override
+ def supported_models() -> list[str]:
+ """Provides the list of supported models.
+
+ Returns:
+ A list of supported models.
+ """
+
+ return [
+ r'gemini-.*',
+ # fine-tuned vertex endpoint pattern
+ r'projects\/.+\/locations\/.+\/endpoints\/.+',
+ # vertex gemini long name
+ r'projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+',
+ ]
+
+ async def generate_content_async(
+ self, llm_request: LlmRequest, stream: bool = False
+ ) -> AsyncGenerator[LlmResponse, None]:
+ """Sends a request to the Gemini model.
+
+ Args:
+ llm_request: LlmRequest, the request to send to the Gemini model.
+ stream: bool = False, whether to do streaming call.
+
+ Yields:
+ LlmResponse: The model response.
+ """
+
+ self._maybe_append_user_content(llm_request)
+ logger.info(
+ 'Sending out request, model: %s, backend: %s, stream: %s',
+ llm_request.model,
+ self._api_backend,
+ stream,
+ )
+ logger.info(_build_request_log(llm_request))
+
+ if stream:
+ responses = await self.api_client.aio.models.generate_content_stream(
+ model=llm_request.model,
+ contents=llm_request.contents,
+ config=llm_request.config,
+ )
+ response = None
+ text = ''
+ # for sse, similar as bidi (see receive method in gemini_llm_connecton.py),
+ # we need to mark those text content as partial and after all partial
+ # contents are sent, we send an accumulated event which contains all the
+ # previous partial content. The only difference is bidi rely on
+ # complete_turn flag to detect end while sse depends on finish_reason.
+ async for response in responses:
+ logger.info(_build_response_log(response))
+ llm_response = LlmResponse.create(response)
+ if (
+ llm_response.content
+ and llm_response.content.parts
+ and llm_response.content.parts[0].text
+ ):
+ text += llm_response.content.parts[0].text
+ llm_response.partial = True
+ elif text and (
+ not llm_response.content
+ or not llm_response.content.parts
+ # don't yield the merged text event when receiving audio data
+ or not llm_response.content.parts[0].inline_data
+ ):
+ yield LlmResponse(
+ content=types.ModelContent(
+ parts=[types.Part.from_text(text=text)],
+ ),
+ )
+ text = ''
+ yield llm_response
+ if (
+ text
+ and response
+ and response.candidates
+ and response.candidates[0].finish_reason == types.FinishReason.STOP
+ ):
+ yield LlmResponse(
+ content=types.ModelContent(
+ parts=[types.Part.from_text(text=text)],
+ ),
+ )
+
+ else:
+ response = await self.api_client.aio.models.generate_content(
+ model=llm_request.model,
+ contents=llm_request.contents,
+ config=llm_request.config,
+ )
+ logger.info(_build_response_log(response))
+ yield LlmResponse.create(response)
+
+ @cached_property
+ def api_client(self) -> Client:
+ """Provides the api client.
+
+ Returns:
+ The api client.
+ """
+ return Client(
+ http_options=types.HttpOptions(headers=self._tracking_headers)
+ )
+
+ @cached_property
+ def _api_backend(self) -> str:
+ return 'vertex' if self.api_client.vertexai else 'ml_dev'
+
+ @cached_property
+ def _tracking_headers(self) -> dict[str, str]:
+ framework_label = f'google-adk/{version.__version__}'
+ language_label = 'gl-python/' + sys.version.split()[0]
+ version_header_value = f'{framework_label} {language_label}'
+ tracking_headers = {
+ 'x-goog-api-client': version_header_value,
+ 'user-agent': version_header_value,
+ }
+ return tracking_headers
+
+ @cached_property
+ def _live_api_client(self) -> Client:
+ if self._api_backend == 'vertex':
+ # use default api version for vertex
+ return Client(
+ http_options=types.HttpOptions(headers=self._tracking_headers)
+ )
+ else:
+ # use v1alpha for ml_dev
+ api_version = 'v1alpha'
+ return Client(
+ http_options=types.HttpOptions(
+ headers=self._tracking_headers, api_version=api_version
+ )
+ )
+
+ @contextlib.asynccontextmanager
+ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
+ """Connects to the Gemini model and returns an llm connection.
+
+ Args:
+ llm_request: LlmRequest, the request to send to the Gemini model.
+
+ Yields:
+ BaseLlmConnection, the connection to the Gemini model.
+ """
+
+ llm_request.live_connect_config.system_instruction = types.Content(
+ role='system',
+ parts=[
+ types.Part.from_text(text=llm_request.config.system_instruction)
+ ],
+ )
+ llm_request.live_connect_config.tools = llm_request.config.tools
+ async with self._live_api_client.aio.live.connect(
+ model=llm_request.model, config=llm_request.live_connect_config
+ ) as live_session:
+ yield GeminiLlmConnection(live_session)
+
+ def _maybe_append_user_content(self, llm_request: LlmRequest):
+ """Appends a user content, so that model can continue to output.
+
+ Args:
+ llm_request: LlmRequest, the request to send to the Gemini model.
+ """
+ # If no content is provided, append a user content to hint model response
+ # using system instruction.
+ if not llm_request.contents:
+ llm_request.contents.append(
+ types.Content(
+ role='user',
+ parts=[
+ types.Part(
+ text=(
+ 'Handle the requests as specified in the System'
+ ' Instruction.'
+ )
+ )
+ ],
+ )
+ )
+ return
+
+ # Insert a user content to preserve user intent and to avoid empty
+ # model response.
+ if llm_request.contents[-1].role != 'user':
+ llm_request.contents.append(
+ types.Content(
+ role='user',
+ parts=[
+ types.Part(
+ text=(
+ 'Continue processing previous requests as instructed.'
+ ' Exit or provide a summary if no more outputs are'
+ ' needed.'
+ )
+ )
+ ],
+ )
+ )
+
+
+def _build_function_declaration_log(
+ func_decl: types.FunctionDeclaration,
+) -> str:
+ param_str = '{}'
+ if func_decl.parameters and func_decl.parameters.properties:
+ param_str = str({
+ k: v.model_dump(exclude_none=True)
+ for k, v in func_decl.parameters.properties.items()
+ })
+ return_str = 'None'
+ if func_decl.response:
+ return_str = str(func_decl.response.model_dump(exclude_none=True))
+ return f'{func_decl.name}: {param_str} -> {return_str}'
+
+
+def _build_request_log(req: LlmRequest) -> str:
+ function_decls: list[types.FunctionDeclaration] = cast(
+ list[types.FunctionDeclaration],
+ req.config.tools[0].function_declarations if req.config.tools else [],
+ )
+ function_logs = (
+ [
+ _build_function_declaration_log(func_decl)
+ for func_decl in function_decls
+ ]
+ if function_decls
+ else []
+ )
+ contents_logs = [
+ content.model_dump_json(
+ exclude_none=True,
+ exclude={
+ 'parts': {
+ i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))
+ }
+ },
+ )
+ for content in req.contents
+ ]
+
+ return f"""
+LLM Request:
+-----------------------------------------------------------
+System Instruction:
+{req.config.system_instruction}
+-----------------------------------------------------------
+Contents:
+{_NEW_LINE.join(contents_logs)}
+-----------------------------------------------------------
+Functions:
+{_NEW_LINE.join(function_logs)}
+-----------------------------------------------------------
+"""
+
+
+def _build_response_log(resp: types.GenerateContentResponse) -> str:
+ function_calls_text = []
+ if function_calls := resp.function_calls:
+ for func_call in function_calls:
+ function_calls_text.append(
+ f'name: {func_call.name}, args: {func_call.args}'
+ )
+ return f"""
+LLM Response:
+-----------------------------------------------------------
+Text:
+{resp.text}
+-----------------------------------------------------------
+Function calls:
+{_NEW_LINE.join(function_calls_text)}
+-----------------------------------------------------------
+Raw response:
+{resp.model_dump_json(exclude_none=True)}
+-----------------------------------------------------------
+"""
diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
new file mode 100644
index 0000000..c96b13f
--- /dev/null
+++ b/src/google/adk/models/lite_llm.py
@@ -0,0 +1,673 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import base64
+import json
+import logging
+from typing import Any
+from typing import AsyncGenerator
+from typing import cast
+from typing import Dict
+from typing import Generator
+from typing import Iterable
+from typing import Literal
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+from google.genai import types
+from litellm import acompletion
+from litellm import ChatCompletionAssistantMessage
+from litellm import ChatCompletionDeveloperMessage
+from litellm import ChatCompletionImageUrlObject
+from litellm import ChatCompletionMessageToolCall
+from litellm import ChatCompletionTextObject
+from litellm import ChatCompletionToolMessage
+from litellm import ChatCompletionUserMessage
+from litellm import ChatCompletionVideoUrlObject
+from litellm import completion
+from litellm import CustomStreamWrapper
+from litellm import Function
+from litellm import Message
+from litellm import ModelResponse
+from litellm import OpenAIMessageContent
+from pydantic import BaseModel
+from pydantic import Field
+from typing_extensions import override
+
+from .base_llm import BaseLlm
+from .llm_request import LlmRequest
+from .llm_response import LlmResponse
+
+logger = logging.getLogger(__name__)
+
+_NEW_LINE = "\n"
+_EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
+
+
+class FunctionChunk(BaseModel):
+ id: Optional[str]
+ name: Optional[str]
+ args: Optional[str]
+
+
+class TextChunk(BaseModel):
+ text: str
+
+
+class LiteLLMClient:
+ """Provides acompletion method (for better testability)."""
+
+ async def acompletion(
+ self, model, messages, tools, **kwargs
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+ """Asynchronously calls acompletion.
+
+ Args:
+ model: The model name.
+ messages: The messages to send to the model.
+ tools: The tools to use for the model.
+ **kwargs: Additional arguments to pass to acompletion.
+
+ Returns:
+ The model response as a message.
+ """
+
+ return await acompletion(
+ model=model,
+ messages=messages,
+ tools=tools,
+ **kwargs,
+ )
+
+ def completion(
+ self, model, messages, tools, stream=False, **kwargs
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+ """Synchronously calls completion. This is used for streaming only.
+
+ Args:
+ model: The model to use.
+ messages: The messages to send.
+ tools: The tools to use for the model.
+ stream: Whether to stream the response.
+ **kwargs: Additional arguments to pass to completion.
+
+ Returns:
+ The response from the model.
+ """
+
+ return completion(
+ model=model,
+ messages=messages,
+ tools=tools,
+ stream=stream,
+ **kwargs,
+ )
+
+
+def _safe_json_serialize(obj) -> str:
+ """Convert any Python object to a JSON-serializable type or string.
+
+ Args:
+ obj: The object to serialize.
+
+ Returns:
+ The JSON-serialized object string or string.
+ """
+
+ try:
+ # Try direct JSON serialization first
+ return json.dumps(obj)
+ except (TypeError, OverflowError):
+ return str(obj)
+
+
+def _content_to_message_param(
+ content: types.Content,
+) -> Message:
+ """Converts a types.Content to a litellm Message.
+
+ Args:
+ content: The content to convert.
+
+ Returns:
+ The litellm Message.
+ """
+
+ if content.parts and content.parts[0].function_response:
+ return ChatCompletionToolMessage(
+ role="tool",
+ tool_call_id=content.parts[0].function_response.id,
+ content=_safe_json_serialize(
+ content.parts[0].function_response.response
+ ),
+ )
+
+ role = _to_litellm_role(content.role)
+
+ if role == "user":
+ return ChatCompletionUserMessage(
+ role="user", content=_get_content(content.parts)
+ )
+ else:
+
+ tool_calls = [
+ ChatCompletionMessageToolCall(
+ type="function",
+ id=part.function_call.id,
+ function=Function(
+ name=part.function_call.name,
+ arguments=part.function_call.args,
+ ),
+ )
+ for part in content.parts
+ if part.function_call
+ ]
+
+ return ChatCompletionAssistantMessage(
+ role=role,
+ content=_get_content(content.parts),
+ tool_calls=tool_calls or None,
+ )
+
+
+def _get_content(parts: Iterable[types.Part]) -> OpenAIMessageContent | str:
+ """Converts a list of parts to litellm content.
+
+ Args:
+ parts: The parts to convert.
+
+ Returns:
+ The litellm content.
+ """
+
+ content_objects = []
+ for part in parts:
+ if part.text:
+ if len(parts) == 1:
+ return part.text
+ content_objects.append(
+ ChatCompletionTextObject(
+ type="text",
+ text=part.text,
+ )
+ )
+ elif (
+ part.inline_data
+ and part.inline_data.data
+ and part.inline_data.mime_type
+ ):
+ base64_string = base64.b64encode(part.inline_data.data).decode("utf-8")
+ data_uri = f"data:{part.inline_data.mime_type};base64,{base64_string}"
+
+ if part.inline_data.mime_type.startswith("image"):
+ content_objects.append(
+ ChatCompletionImageUrlObject(
+ type="image_url",
+ image_url=data_uri,
+ )
+ )
+ elif part.inline_data.mime_type.startswith("video"):
+ content_objects.append(
+ ChatCompletionVideoUrlObject(
+ type="video_url",
+ video_url=data_uri,
+ )
+ )
+ else:
+ raise ValueError("LiteLlm(BaseLlm) does not support this content part.")
+
+ return content_objects
+
+
+def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]:
+ """Converts a types.Content role to a litellm role.
+
+ Args:
+ role: The types.Content role.
+
+ Returns:
+ The litellm role.
+ """
+
+ if role in ["model", "assistant"]:
+ return "assistant"
+ return "user"
+
+
+TYPE_LABELS = {
+ "STRING": "string",
+ "NUMBER": "number",
+ "BOOLEAN": "boolean",
+ "OBJECT": "object",
+ "ARRAY": "array",
+ "INTEGER": "integer",
+}
+
+
+def _schema_to_dict(schema: types.Schema) -> dict:
+ """Recursively converts a types.Schema to a dictionary.
+
+ Args:
+ schema: The schema to convert.
+
+ Returns:
+ The dictionary representation of the schema.
+ """
+
+ schema_dict = schema.model_dump(exclude_none=True)
+ if "type" in schema_dict:
+ schema_dict["type"] = schema_dict["type"].lower()
+ if "items" in schema_dict:
+ if isinstance(schema_dict["items"], dict):
+ schema_dict["items"] = _schema_to_dict(
+ types.Schema.model_validate(schema_dict["items"])
+ )
+ elif isinstance(schema_dict["items"]["type"], types.Type):
+ schema_dict["items"]["type"] = TYPE_LABELS[
+ schema_dict["items"]["type"].value
+ ]
+ if "properties" in schema_dict:
+ properties = {}
+ for key, value in schema_dict["properties"].items():
+ if isinstance(value, types.Schema):
+ properties[key] = _schema_to_dict(value)
+ else:
+ properties[key] = value
+ if "type" in properties[key]:
+ properties[key]["type"] = properties[key]["type"].lower()
+ schema_dict["properties"] = properties
+ return schema_dict
+
+
+def _function_declaration_to_tool_param(
+ function_declaration: types.FunctionDeclaration,
+) -> dict:
+ """Converts a types.FunctionDeclaration to a openapi spec dictionary.
+
+ Args:
+ function_declaration: The function declaration to convert.
+
+ Returns:
+ The openapi spec dictionary representation of the function declaration.
+ """
+
+ assert function_declaration.name
+
+ properties = {}
+ if (
+ function_declaration.parameters
+ and function_declaration.parameters.properties
+ ):
+ for key, value in function_declaration.parameters.properties.items():
+ properties[key] = _schema_to_dict(value)
+
+ return {
+ "type": "function",
+ "function": {
+ "name": function_declaration.name,
+ "description": function_declaration.description or "",
+ "parameters": {
+ "type": "object",
+ "properties": properties,
+ },
+ },
+ }
+
+
+def _model_response_to_chunk(
+ response: ModelResponse,
+) -> Generator[
+ Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None
+]:
+ """Converts a litellm message to text or function chunk.
+
+ Args:
+ response: The response from the model.
+
+ Yields:
+ A tuple of text or function chunk and finish reason.
+ """
+
+ message = None
+ if response.get("choices", None):
+ message = response["choices"][0].get("message", None)
+ finish_reason = response["choices"][0].get("finish_reason", None)
+ # check streaming delta
+ if message is None and response["choices"][0].get("delta", None):
+ message = response["choices"][0]["delta"]
+
+ if message.get("content", None):
+ yield TextChunk(text=message.get("content")), finish_reason
+
+ if message.get("tool_calls", None):
+ for tool_call in message.get("tool_calls"):
+ # aggregate tool_call
+ if tool_call.type == "function":
+ yield FunctionChunk(
+ id=tool_call.id,
+ name=tool_call.function.name,
+ args=tool_call.function.arguments,
+ ), finish_reason
+
+ if finish_reason and not (
+ message.get("content", None) or message.get("tool_calls", None)
+ ):
+ yield None, finish_reason
+
+ if not message:
+ yield None, None
+
+
+def _model_response_to_generate_content_response(
+ response: ModelResponse,
+) -> LlmResponse:
+ """Converts a litellm response to LlmResponse.
+
+ Args:
+ response: The model response.
+
+ Returns:
+ The LlmResponse.
+ """
+
+ message = None
+ if response.get("choices", None):
+ message = response["choices"][0].get("message", None)
+
+ if not message:
+ raise ValueError("No message in response")
+ return _message_to_generate_content_response(message)
+
+
+def _message_to_generate_content_response(
+ message: Message, is_partial: bool = False
+) -> LlmResponse:
+ """Converts a litellm message to LlmResponse.
+
+ Args:
+ message: The message to convert.
+ is_partial: Whether the message is partial.
+
+ Returns:
+ The LlmResponse.
+ """
+
+ parts = []
+ if message.get("content", None):
+ parts.append(types.Part.from_text(text=message.get("content")))
+
+ if message.get("tool_calls", None):
+ for tool_call in message.get("tool_calls"):
+ if tool_call.type == "function":
+ part = types.Part.from_function_call(
+ name=tool_call.function.name,
+ args=json.loads(tool_call.function.arguments or "{}"),
+ )
+ part.function_call.id = tool_call.id
+ parts.append(part)
+
+ return LlmResponse(
+ content=types.Content(role="model", parts=parts), partial=is_partial
+ )
+
+
+def _get_completion_inputs(
+ llm_request: LlmRequest,
+) -> tuple[Iterable[Message], Iterable[dict]]:
+ """Converts an LlmRequest to litellm inputs.
+
+ Args:
+ llm_request: The LlmRequest to convert.
+
+ Returns:
+ The litellm inputs (message list and tool dictionary).
+ """
+ messages = [
+ _content_to_message_param(content)
+ for content in llm_request.contents or []
+ ]
+
+ if llm_request.config.system_instruction:
+ messages.insert(
+ 0,
+ ChatCompletionDeveloperMessage(
+ role="developer",
+ content=llm_request.config.system_instruction,
+ ),
+ )
+
+ tools = None
+ if (
+ llm_request.config
+ and llm_request.config.tools
+ and llm_request.config.tools[0].function_declarations
+ ):
+ tools = [
+ _function_declaration_to_tool_param(tool)
+ for tool in llm_request.config.tools[0].function_declarations
+ ]
+ return messages, tools
+
+
+def _build_function_declaration_log(
+ func_decl: types.FunctionDeclaration,
+) -> str:
+ """Builds a function declaration log.
+
+ Args:
+ func_decl: The function declaration to convert.
+
+ Returns:
+ The function declaration log.
+ """
+
+ param_str = "{}"
+ if func_decl.parameters and func_decl.parameters.properties:
+ param_str = str({
+ k: v.model_dump(exclude_none=True)
+ for k, v in func_decl.parameters.properties.items()
+ })
+ return_str = "None"
+ if func_decl.response:
+ return_str = str(func_decl.response.model_dump(exclude_none=True))
+ return f"{func_decl.name}: {param_str} -> {return_str}"
+
+
+def _build_request_log(req: LlmRequest) -> str:
+ """Builds a request log.
+
+ Args:
+ req: The request to convert.
+
+ Returns:
+ The request log.
+ """
+
+ function_decls: list[types.FunctionDeclaration] = cast(
+ list[types.FunctionDeclaration],
+ req.config.tools[0].function_declarations if req.config.tools else [],
+ )
+ function_logs = (
+ [
+ _build_function_declaration_log(func_decl)
+ for func_decl in function_decls
+ ]
+ if function_decls
+ else []
+ )
+ contents_logs = [
+ content.model_dump_json(
+ exclude_none=True,
+ exclude={
+ "parts": {
+ i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))
+ }
+ },
+ )
+ for content in req.contents
+ ]
+
+ return f"""
+LLM Request:
+-----------------------------------------------------------
+System Instruction:
+{req.config.system_instruction}
+-----------------------------------------------------------
+Contents:
+{_NEW_LINE.join(contents_logs)}
+-----------------------------------------------------------
+Functions:
+{_NEW_LINE.join(function_logs)}
+-----------------------------------------------------------
+"""
+
+
+class LiteLlm(BaseLlm):
+ """Wrapper around litellm.
+
+ This wrapper can be used with any of the models supported by litellm. The
+ environment variable(s) needed for authenticating with the model endpoint must
+ be set prior to instantiating this class.
+
+ Example usage:
+ ```
+ os.environ["VERTEXAI_PROJECT"] = "your-gcp-project-id"
+ os.environ["VERTEXAI_LOCATION"] = "your-gcp-location"
+
+ agent = Agent(
+ model=LiteLlm(model="vertex_ai/claude-3-7-sonnet@20250219"),
+ ...
+ )
+ ```
+
+ Attributes:
+ model: The name of the LiteLlm model.
+ llm_client: The LLM client to use for the model.
+ model_config: The model config.
+ """
+
+ llm_client: LiteLLMClient = Field(default_factory=LiteLLMClient)
+ """The LLM client to use for the model."""
+
+ _additional_args: Dict[str, Any] = None
+
+ def __init__(self, model: str, **kwargs):
+ """Initializes the LiteLlm class.
+
+ Args:
+ model: The name of the LiteLlm model.
+ **kwargs: Additional arguments to pass to the litellm completion api.
+ """
+ super().__init__(model=model, **kwargs)
+ self._additional_args = kwargs
+ # preventing generation call with llm_client
+ # and overriding messages, tools and stream which are managed internally
+ self._additional_args.pop("llm_client", None)
+ self._additional_args.pop("messages", None)
+ self._additional_args.pop("tools", None)
+ # public api called from runner determines to stream or not
+ self._additional_args.pop("stream", None)
+
+ async def generate_content_async(
+ self, llm_request: LlmRequest, stream: bool = False
+ ) -> AsyncGenerator[LlmResponse, None]:
+ """Generates content asynchronously.
+
+ Args:
+ llm_request: LlmRequest, the request to send to the LiteLlm model.
+ stream: bool = False, whether to do streaming call.
+
+ Yields:
+ LlmResponse: The model response.
+ """
+
+ logger.info(_build_request_log(llm_request))
+
+ messages, tools = _get_completion_inputs(llm_request)
+
+ completion_args = {
+ "model": self.model,
+ "messages": messages,
+ "tools": tools,
+ }
+ completion_args.update(self._additional_args)
+
+ if stream:
+ text = ""
+ function_name = ""
+ function_args = ""
+ function_id = None
+ completion_args["stream"] = True
+ for part in self.llm_client.completion(**completion_args):
+ for chunk, finish_reason in _model_response_to_chunk(part):
+ if isinstance(chunk, FunctionChunk):
+ if chunk.name:
+ function_name += chunk.name
+ if chunk.args:
+ function_args += chunk.args
+ function_id = chunk.id or function_id
+ elif isinstance(chunk, TextChunk):
+ text += chunk.text
+ yield _message_to_generate_content_response(
+ ChatCompletionAssistantMessage(
+ role="assistant",
+ content=chunk.text,
+ ),
+ is_partial=True,
+ )
+ if finish_reason == "tool_calls" and function_id:
+ yield _message_to_generate_content_response(
+ ChatCompletionAssistantMessage(
+ role="assistant",
+ content="",
+ tool_calls=[
+ ChatCompletionMessageToolCall(
+ type="function",
+ id=function_id,
+ function=Function(
+ name=function_name,
+ arguments=function_args,
+ ),
+ )
+ ],
+ )
+ )
+ function_name = ""
+ function_args = ""
+ function_id = None
+ elif finish_reason == "stop" and text:
+ yield _message_to_generate_content_response(
+ ChatCompletionAssistantMessage(role="assistant", content=text)
+ )
+ text = ""
+
+ else:
+ response = await self.llm_client.acompletion(**completion_args)
+ yield _model_response_to_generate_content_response(response)
+
+ @staticmethod
+ @override
+ def supported_models() -> list[str]:
+ """Provides the list of supported models.
+
+ LiteLlm supports all models supported by litellm. We do not keep track of
+ these models here. So we return an empty list.
+
+ Returns:
+ A list of supported models.
+ """
+
+ return []
diff --git a/src/google/adk/models/llm_request.py b/src/google/adk/models/llm_request.py
new file mode 100644
index 0000000..cc97c55
--- /dev/null
+++ b/src/google/adk/models/llm_request.py
@@ -0,0 +1,98 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional
+
+from google.genai import types
+from pydantic import BaseModel
+from pydantic import ConfigDict
+from pydantic import Field
+
+from ..tools.base_tool import BaseTool
+
+
+class LlmRequest(BaseModel):
+ """LLM request class that allows passing in tools, output schema and system
+
+ instructions to the model.
+
+ Attributes:
+ model: The model name.
+ contents: The contents to send to the model.
+ config: Additional config for the generate content request.
+ tools_dict: The tools dictionary.
+ """
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+ """The model config."""
+
+ model: Optional[str] = None
+ """The model name."""
+
+ contents: list[types.Content] = Field(default_factory=list)
+ """The contents to send to the model."""
+
+ config: Optional[types.GenerateContentConfig] = None
+ live_connect_config: types.LiveConnectConfig = types.LiveConnectConfig()
+ """Additional config for the generate content request.
+
+ tools in generate_content_config should not be set.
+ """
+ tools_dict: dict[str, BaseTool] = Field(default_factory=dict, exclude=True)
+ """The tools dictionary."""
+
+ def append_instructions(self, instructions: list[str]) -> None:
+ """Appends instructions to the system instruction.
+
+ Args:
+ instructions: The instructions to append.
+ """
+
+ if self.config.system_instruction:
+ self.config.system_instruction += '\n\n' + '\n\n'.join(instructions)
+ else:
+ self.config.system_instruction = '\n\n'.join(instructions)
+
+ def append_tools(self, tools: list[BaseTool]) -> None:
+ """Appends tools to the request.
+
+ Args:
+ tools: The tools to append.
+ """
+
+ if not tools:
+ return
+ declarations = []
+ for tool in tools:
+ if isinstance(tool, BaseTool):
+ declaration = tool._get_declaration()
+ else:
+ declaration = tool.get_declaration()
+ if declaration:
+ declarations.append(declaration)
+ self.tools_dict[tool.name] = tool
+ if declarations:
+ self.config.tools.append(types.Tool(function_declarations=declarations))
+
+ def set_output_schema(self, base_model: type[BaseModel]) -> None:
+ """Sets the output schema for the request.
+
+ Args:
+ base_model: The pydantic base model to set the output schema to.
+ """
+
+ self.config.response_schema = base_model
+ self.config.response_mime_type = 'application/json'
diff --git a/src/google/adk/models/llm_response.py b/src/google/adk/models/llm_response.py
new file mode 100644
index 0000000..895e7a0
--- /dev/null
+++ b/src/google/adk/models/llm_response.py
@@ -0,0 +1,111 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional
+
+from google.genai import types
+from pydantic import BaseModel
+from pydantic import ConfigDict
+
+
+class LlmResponse(BaseModel):
+ """LLM response class that provides the first candidate response from the
+
+ model if available. Otherwise, returns error code and message.
+
+ Attributes:
+ content: The content of the response.
+ grounding_metadata: The grounding metadata of the response.
+ partial: Indicates whether the text content is part of a unfinished text
+ stream. Only used for streaming mode and when the content is plain text.
+ turn_complete: Indicates whether the response from the model is complete.
+ Only used for streaming mode.
+ error_code: Error code if the response is an error. Code varies by model.
+ error_message: Error message if the response is an error.
+ interrupted: Flag indicating that LLM was interrupted when generating the
+ content. Usually it's due to user interruption during a bidi streaming.
+ """
+
+ model_config = ConfigDict(extra='forbid')
+ """The model config."""
+
+ content: Optional[types.Content] = None
+ """The content of the response."""
+
+ grounding_metadata: Optional[types.GroundingMetadata] = None
+ """The grounding metadata of the response."""
+
+ partial: Optional[bool] = None
+ """Indicates whether the text content is part of a unfinished text stream.
+
+ Only used for streaming mode and when the content is plain text.
+ """
+
+ turn_complete: Optional[bool] = None
+ """Indicates whether the response from the model is complete.
+
+ Only used for streaming mode.
+ """
+
+ error_code: Optional[str] = None
+ """Error code if the response is an error. Code varies by model."""
+
+ error_message: Optional[str] = None
+ """Error message if the response is an error."""
+
+ interrupted: Optional[bool] = None
+ """Flag indicating that LLM was interrupted when generating the content.
+ Usually it's due to user interruption during a bidi streaming.
+ """
+
+ @staticmethod
+ def create(
+ generate_content_response: types.GenerateContentResponse,
+ ) -> 'LlmResponse':
+ """Creates an LlmResponse from a GenerateContentResponse.
+
+ Args:
+ generate_content_response: The GenerateContentResponse to create the
+ LlmResponse from.
+
+ Returns:
+ The LlmResponse.
+ """
+
+ if generate_content_response.candidates:
+ candidate = generate_content_response.candidates[0]
+ if candidate.content and candidate.content.parts:
+ return LlmResponse(
+ content=candidate.content,
+ grounding_metadata=candidate.grounding_metadata,
+ )
+ else:
+ return LlmResponse(
+ error_code=candidate.finish_reason,
+ error_message=candidate.finish_message,
+ )
+ else:
+ if generate_content_response.prompt_feedback:
+ prompt_feedback = generate_content_response.prompt_feedback
+ return LlmResponse(
+ error_code=prompt_feedback.block_reason,
+ error_message=prompt_feedback.block_reason_message,
+ )
+ else:
+ return LlmResponse(
+ error_code='UNKNOWN_ERROR',
+ error_message='Unknown error.',
+ )
diff --git a/src/google/adk/models/registry.py b/src/google/adk/models/registry.py
new file mode 100644
index 0000000..68be9eb
--- /dev/null
+++ b/src/google/adk/models/registry.py
@@ -0,0 +1,102 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The registry class for model."""
+
+from __future__ import annotations
+
+from functools import lru_cache
+import logging
+import re
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from .base_llm import BaseLlm
+
+logger = logging.getLogger(__name__)
+
+
+_llm_registry_dict: dict[str, type[BaseLlm]] = {}
+"""Registry for LLMs.
+
+Key is the regex that matches the model name.
+Value is the class that implements the model.
+"""
+
+
+class LLMRegistry:
+ """Registry for LLMs."""
+
+ @staticmethod
+ def new_llm(model: str) -> BaseLlm:
+ """Creates a new LLM instance.
+
+ Args:
+ model: The model name.
+
+ Returns:
+ The LLM instance.
+ """
+
+ return LLMRegistry.resolve(model)(model=model)
+
+ @staticmethod
+ def _register(model_name_regex: str, llm_cls: type[BaseLlm]):
+ """Registers a new LLM class.
+
+ Args:
+ model_name_regex: The regex that matches the model name.
+ llm_cls: The class that implements the model.
+ """
+
+ if model_name_regex in _llm_registry_dict:
+ logger.info(
+ 'Updating LLM class for %s from %s to %s',
+ model_name_regex,
+ _llm_registry_dict[model_name_regex],
+ llm_cls,
+ )
+
+ _llm_registry_dict[model_name_regex] = llm_cls
+
+ @staticmethod
+ def register(llm_cls: type[BaseLlm]):
+ """Registers a new LLM class.
+
+ Args:
+ llm_cls: The class that implements the model.
+ """
+
+ for regex in llm_cls.supported_models():
+ LLMRegistry._register(regex, llm_cls)
+
+ @staticmethod
+ @lru_cache(maxsize=32)
+ def resolve(model: str) -> type[BaseLlm]:
+ """Resolves the model to a BaseLlm subclass.
+
+ Args:
+ model: The model name.
+
+ Returns:
+ The BaseLlm subclass.
+ Raises:
+ ValueError: If the model is not found.
+ """
+
+ for regex, llm_class in _llm_registry_dict.items():
+ if re.compile(regex).fullmatch(model):
+ return llm_class
+
+ raise ValueError(f'Model {model} not found.')
diff --git a/src/google/adk/planners/__init__.py b/src/google/adk/planners/__init__.py
new file mode 100644
index 0000000..a8193f6
--- /dev/null
+++ b/src/google/adk/planners/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_planner import BasePlanner
+from .built_in_planner import BuiltInPlanner
+from .plan_re_act_planner import PlanReActPlanner
+
+__all__ = [
+ 'BasePlanner',
+ 'BuiltInPlanner',
+ 'PlanReActPlanner',
+]
diff --git a/src/google/adk/planners/base_planner.py b/src/google/adk/planners/base_planner.py
new file mode 100644
index 0000000..f248312
--- /dev/null
+++ b/src/google/adk/planners/base_planner.py
@@ -0,0 +1,66 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from abc import ABC
+from typing import List
+from typing import Optional
+
+from google.genai import types
+
+from ..agents.callback_context import CallbackContext
+from ..agents.readonly_context import ReadonlyContext
+from ..models.llm_request import LlmRequest
+
+
+class BasePlanner(ABC):
+ """Abstract base class for all planners.
+
+ The planner allows the agent to generate plans for the queries to guide its
+ action.
+ """
+
+ @abc.abstractmethod
+ def build_planning_instruction(
+ self,
+ readonly_context: ReadonlyContext,
+ llm_request: LlmRequest,
+ ) -> Optional[str]:
+ """Builds the system instruction to be appended to the LLM request for planning.
+
+ Args:
+ readonly_context: The readonly context of the invocation.
+ llm_request: The LLM request. Readonly.
+
+ Returns:
+ The planning system instruction, or None if no instruction is needed.
+ """
+ pass
+
+ @abc.abstractmethod
+ def process_planning_response(
+ self,
+ callback_context: CallbackContext,
+ response_parts: List[types.Part],
+ ) -> Optional[List[types.Part]]:
+ """Processes the LLM response for planning.
+
+ Args:
+ callback_context: The callback context of the invocation.
+ response_parts: The LLM response parts. Readonly.
+
+ Returns:
+ The processed response parts, or None if no processing is needed.
+ """
+ pass
diff --git a/src/google/adk/planners/built_in_planner.py b/src/google/adk/planners/built_in_planner.py
new file mode 100644
index 0000000..84ba453
--- /dev/null
+++ b/src/google/adk/planners/built_in_planner.py
@@ -0,0 +1,75 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+from typing import Optional
+
+from google.genai import types
+from typing_extensions import override
+
+from ..agents.callback_context import CallbackContext
+from ..agents.readonly_context import ReadonlyContext
+from ..models.llm_request import LlmRequest
+from .base_planner import BasePlanner
+
+
+class BuiltInPlanner(BasePlanner):
+ """The built-in planner that uses model's built-in thinking features.
+
+ Attributes:
+ thinking_config: Config for model built-in thinking features. An error
+ will be returned if this field is set for models that don't support
+ thinking.
+ """
+
+ thinking_config: types.ThinkingConfig
+ """
+ Config for model built-in thinking features. An error will be returned if this
+ field is set for models that don't support thinking.
+ """
+
+ def __init__(self, *, thinking_config: types.ThinkingConfig):
+ """Initializes the built-in planner.
+
+ Args:
+ thinking_config: Config for model built-in thinking features. An error
+ will be returned if this field is set for models that don't support
+ thinking.
+ """
+ self.thinking_config = thinking_config
+
+ def apply_thinking_config(self, llm_request: LlmRequest) -> None:
+ """Applies the thinking config to the LLM request.
+
+ Args:
+ llm_request: The LLM request to apply the thinking config to.
+ """
+ if self.thinking_config:
+ llm_request.config.thinking_config = self.thinking_config
+
+ @override
+ def build_planning_instruction(
+ self,
+ readonly_context: ReadonlyContext,
+ llm_request: LlmRequest,
+ ) -> Optional[str]:
+ return
+
+ @override
+ def process_planning_response(
+ self,
+ callback_context: CallbackContext,
+ response_parts: List[types.Part],
+ ) -> Optional[List[types.Part]]:
+ return
diff --git a/src/google/adk/planners/plan_re_act_planner.py b/src/google/adk/planners/plan_re_act_planner.py
new file mode 100644
index 0000000..87da2a5
--- /dev/null
+++ b/src/google/adk/planners/plan_re_act_planner.py
@@ -0,0 +1,208 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+from typing import Optional
+
+from google.genai import types
+from typing_extensions import override
+
+from ..agents.callback_context import CallbackContext
+from ..agents.readonly_context import ReadonlyContext
+from ..models.llm_request import LlmRequest
+from .base_planner import BasePlanner
+
+PLANNING_TAG = '/*PLANNING*/'
+REPLANNING_TAG = '/*REPLANNING*/'
+REASONING_TAG = '/*REASONING*/'
+ACTION_TAG = '/*ACTION*/'
+FINAL_ANSWER_TAG = '/*FINAL_ANSWER*/'
+
+
+class PlanReActPlanner(BasePlanner):
+ """Plan-Re-Act planner that constraints the LLM response to generate a plan before any action/observation.
+
+ Note: this planner does not require the model to support buil-in thinking
+ features or setting the thinking config.
+ """
+
+ @override
+ def build_planning_instruction(
+ self,
+ readonly_context: ReadonlyContext,
+ llm_request: LlmRequest,
+ ) -> str:
+ return self._build_nl_planner_instruction()
+
+ @override
+ def process_planning_response(
+ self,
+ callback_context: CallbackContext,
+ response_parts: List[types.Part],
+ ) -> Optional[List[types.Part]]:
+ if not response_parts:
+ return None
+
+ preserved_parts = []
+ first_fc_part_index = -1
+ for i in range(len(response_parts)):
+ # Stop at the first (group of) function calls.
+ if response_parts[i].function_call:
+ # Ignore and filter out function calls with empty names.
+ if not response_parts[i].function_call.name:
+ continue
+ preserved_parts.append(response_parts[i])
+ first_fc_part_index = i
+ break
+
+ # Split the response into reasoning and final answer parts.
+ self._handle_non_function_call_parts(response_parts[i], preserved_parts)
+
+ if first_fc_part_index > 0:
+ j = first_fc_part_index + 1
+ while j < len(response_parts):
+ if response_parts[j].function_call:
+ preserved_parts.append(response_parts[j])
+ j += 1
+ else:
+ break
+
+ return preserved_parts
+
+ def _split_by_last_pattern(self, text, separator):
+ """Splits the text by the last occurrence of the separator.
+
+ Args:
+ text: The text to split.
+ separator: The separator to split on.
+
+ Returns:
+ A tuple containing the text before the last separator and the text after
+ the last separator.
+ """
+ index = text.rfind(separator)
+ if index == -1:
+ return text, ''
+ return text[: index + len(separator)], text[index + len(separator) :]
+
+ def _handle_non_function_call_parts(
+ self, response_part: types.Part, preserved_parts: list[types.Part]
+ ):
+ """Handles non-function-call parts of the response.
+
+ Args:
+ response_part: The response part to handle.
+ preserved_parts: The mutable list of parts to store the processed parts
+ in.
+ """
+ if response_part.text and FINAL_ANSWER_TAG in response_part.text:
+ reasoning_text, final_answer_text = self._split_by_last_pattern(
+ response_part.text, FINAL_ANSWER_TAG
+ )
+ if reasoning_text:
+ reasoning_part = types.Part(text=reasoning_text)
+ self._mark_as_thought(reasoning_part)
+ preserved_parts.append(reasoning_part)
+ if final_answer_text:
+ preserved_parts.append(
+ types.Part(
+ text=final_answer_text,
+ )
+ )
+ else:
+ response_text = response_part.text or ''
+ # If the part is a text part with a planning/reasoning/action tag,
+ # label it as reasoning.
+ if response_text and (
+ any(
+ response_text.startswith(tag)
+ for tag in [
+ PLANNING_TAG,
+ REASONING_TAG,
+ ACTION_TAG,
+ REPLANNING_TAG,
+ ]
+ )
+ ):
+ self._mark_as_thought(response_part)
+ preserved_parts.append(response_part)
+
+ def _mark_as_thought(self, response_part: types.Part):
+ """Marks the response part as thought.
+
+ Args:
+ response_part: The mutable response part to mark as thought.
+ """
+ if response_part.text:
+ response_part.thought = True
+ return
+
+ def _build_nl_planner_instruction(self) -> str:
+ """Builds the NL planner instruction for the Plan-Re-Act planner.
+
+ Returns:
+ NL planner system instruction.
+ """
+
+ high_level_preamble = f"""
+When answering the question, try to leverage the available tools to gather the information instead of your memorized knowledge.
+
+Follow this process when answering the question: (1) first come up with a plan in natural language text format; (2) Then use tools to execute the plan and provide reasoning between tool code snippets to make a summary of current state and next step. Tool code snippets and reasoning should be interleaved with each other. (3) In the end, return one final answer.
+
+Follow this format when answering the question: (1) The planning part should be under {PLANNING_TAG}. (2) The tool code snippets should be under {ACTION_TAG}, and the reasoning parts should be under {REASONING_TAG}. (3) The final answer part should be under {FINAL_ANSWER_TAG}.
+"""
+
+ planning_preamble = f"""
+Below are the requirements for the planning:
+The plan is made to answer the user query if following the plan. The plan is coherent and covers all aspects of information from user query, and only involves the tools that are accessible by the agent. The plan contains the decomposed steps as a numbered list where each step should use one or multiple available tools. By reading the plan, you can intuitively know which tools to trigger or what actions to take.
+If the initial plan cannot be successfully executed, you should learn from previous execution results and revise your plan. The revised plan should be be under {REPLANNING_TAG}. Then use tools to follow the new plan.
+"""
+
+ reasoning_preamble = """
+Below are the requirements for the reasoning:
+The reasoning makes a summary of the current trajectory based on the user query and tool outputs. Based on the tool outputs and plan, the reasoning also comes up with instructions to the next steps, making the trajectory closer to the final answer.
+"""
+
+ final_answer_preamble = """
+Below are the requirements for the final answer:
+The final answer should be precise and follow query formatting requirements. Some queries may not be answerable with the available tools and information. In those cases, inform the user why you cannot process their query and ask for more information.
+"""
+
+ # Only contains the requirements for custom tool/libraries.
+ tool_code_without_python_libraries_preamble = """
+Below are the requirements for the tool code:
+
+**Custom Tools:** The available tools are described in the context and can be directly used.
+- Code must be valid self-contained Python snippets with no imports and no references to tools or Python libraries that are not in the context.
+- You cannot use any parameters or fields that are not explicitly defined in the APIs in the context.
+- The code snippets should be readable, efficient, and directly relevant to the user query and reasoning steps.
+- When using the tools, you should use the library name together with the function name, e.g., vertex_search.search().
+- If Python libraries are not provided in the context, NEVER write your own code other than the function calls using the provided tools.
+"""
+
+ user_input_preamble = """
+VERY IMPORTANT instruction that you MUST follow in addition to the above instructions:
+
+You should ask for clarification if you need more information to answer the question.
+You should prefer using the information available in the context instead of repeated tool use.
+"""
+
+ return '\n\n'.join([
+ high_level_preamble,
+ planning_preamble,
+ reasoning_preamble,
+ final_answer_preamble,
+ tool_code_without_python_libraries_preamble,
+ user_input_preamble,
+ ])
diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py
new file mode 100644
index 0000000..bc4bb87
--- /dev/null
+++ b/src/google/adk/runners.py
@@ -0,0 +1,456 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import queue
+import threading
+from typing import AsyncGenerator
+from typing import Generator
+from typing import Optional
+
+from deprecated import deprecated
+from google.genai import types
+
+from .agents.active_streaming_tool import ActiveStreamingTool
+from .agents.base_agent import BaseAgent
+from .agents.invocation_context import InvocationContext
+from .agents.invocation_context import new_invocation_context_id
+from .agents.live_request_queue import LiveRequestQueue
+from .agents.llm_agent import LlmAgent
+from .agents.run_config import RunConfig
+from .agents.run_config import StreamingMode
+from .artifacts.base_artifact_service import BaseArtifactService
+from .artifacts.in_memory_artifact_service import InMemoryArtifactService
+from .events.event import Event
+from .memory.base_memory_service import BaseMemoryService
+from .memory.in_memory_memory_service import InMemoryMemoryService
+from .sessions.base_session_service import BaseSessionService
+from .sessions.in_memory_session_service import InMemorySessionService
+from .sessions.session import Session
+from .telemetry import tracer
+from .tools.built_in_code_execution_tool import built_in_code_execution
+
+logger = logging.getLogger(__name__)
+
+
+class Runner:
+ """The Runner class is used to run agents.
+
+ It manages the execution of an agent within a session, handling message
+ processing, event generation, and interaction with various services like
+ artifact storage, session management, and memory.
+
+ Attributes:
+ app_name: The application name of the runner.
+ agent: The root agent to run.
+ artifact_service: The artifact service for the runner.
+ session_service: The session service for the runner.
+ memory_service: The memory service for the runner.
+ """
+
+ app_name: str
+ """The app name of the runner."""
+ agent: BaseAgent
+ """The root agent to run."""
+ artifact_service: Optional[BaseArtifactService] = None
+ """The artifact service for the runner."""
+ session_service: BaseSessionService
+ """The session service for the runner."""
+ memory_service: Optional[BaseMemoryService] = None
+ """The memory service for the runner."""
+
+ def __init__(
+ self,
+ *,
+ app_name: str,
+ agent: BaseAgent,
+ artifact_service: Optional[BaseArtifactService] = None,
+ session_service: BaseSessionService,
+ memory_service: Optional[BaseMemoryService] = None,
+ ):
+ """Initializes the Runner.
+
+ Args:
+ app_name: The application name of the runner.
+ agent: The root agent to run.
+ artifact_service: The artifact service for the runner.
+ session_service: The session service for the runner.
+ memory_service: The memory service for the runner.
+ """
+ self.app_name = app_name
+ self.agent = agent
+ self.artifact_service = artifact_service
+ self.session_service = session_service
+ self.memory_service = memory_service
+
+ def run(
+ self,
+ *,
+ user_id: str,
+ session_id: str,
+ new_message: types.Content,
+ run_config: RunConfig = RunConfig(),
+ ) -> Generator[Event, None, None]:
+ """Runs the agent.
+
+ NOTE: This sync interface is only for local testing and convenience purpose.
+ Consider to use `run_async` for production usage.
+
+ Args:
+ user_id: The user ID of the session.
+ session_id: The session ID of the session.
+ new_message: A new message to append to the session.
+ run_config: The run config for the agent.
+
+ Yields:
+ The events generated by the agent.
+ """
+ event_queue = queue.Queue()
+
+ async def _invoke_run_async():
+ try:
+ async for event in self.run_async(
+ user_id=user_id,
+ session_id=session_id,
+ new_message=new_message,
+ run_config=run_config,
+ ):
+ event_queue.put(event)
+ finally:
+ event_queue.put(None)
+
+ def _asyncio_thread_main():
+ try:
+ asyncio.run(_invoke_run_async())
+ finally:
+ event_queue.put(None)
+
+ thread = threading.Thread(target=_asyncio_thread_main)
+ thread.start()
+
+ # consumes and re-yield the events from background thread.
+ while True:
+ event = event_queue.get()
+ if event is None:
+ break
+ else:
+ yield event
+
+ thread.join()
+
+ async def run_async(
+ self,
+ *,
+ user_id: str,
+ session_id: str,
+ new_message: types.Content,
+ run_config: RunConfig = RunConfig(),
+ ) -> AsyncGenerator[Event, None]:
+ """Main entry method to run the agent in this runner.
+
+ Args:
+ user_id: The user ID of the session.
+ session_id: The session ID of the session.
+ new_message: A new message to append to the session.
+ run_config: The run config for the agent.
+
+ Yields:
+ The events generated by the agent.
+ """
+ with tracer.start_as_current_span('invocation'):
+ session = self.session_service.get_session(
+ app_name=self.app_name, user_id=user_id, session_id=session_id
+ )
+ if not session:
+ raise ValueError(f'Session not found: {session_id}')
+
+ invocation_context = self._new_invocation_context(
+ session,
+ new_message=new_message,
+ run_config=run_config,
+ )
+ root_agent = self.agent
+
+ if new_message:
+ self._append_new_message_to_session(
+ session,
+ new_message,
+ invocation_context,
+ run_config.save_input_blobs_as_artifacts,
+ )
+
+ invocation_context.agent = self._find_agent_to_run(session, root_agent)
+ async for event in invocation_context.agent.run_async(invocation_context):
+ if not event.partial:
+ self.session_service.append_event(session=session, event=event)
+ yield event
+
+ def _append_new_message_to_session(
+ self,
+ session: Session,
+ new_message: types.Content,
+ invocation_context: InvocationContext,
+ save_input_blobs_as_artifacts: bool = False,
+ ):
+ """Appends a new message to the session.
+
+ Args:
+ session: The session to append the message to.
+ new_message: The new message to append.
+ invocation_context: The invocation context for the message.
+ save_input_blobs_as_artifacts: Whether to save input blobs as artifacts.
+ """
+ if not new_message.parts:
+ raise ValueError('No parts in the new_message.')
+
+ if self.artifact_service and save_input_blobs_as_artifacts:
+ # The runner directly saves the artifacts (if applicable) in the
+ # user message and replaces the artifact data with a file name
+ # placeholder.
+ for i, part in enumerate(new_message.parts):
+ if part.inline_data is None:
+ continue
+ file_name = f'artifact_{invocation_context.invocation_id}_{i}'
+ self.artifact_service.save_artifact(
+ app_name=self.app_name,
+ user_id=session.user_id,
+ session_id=session.id,
+ filename=file_name,
+ artifact=part,
+ )
+ new_message.parts[i] = types.Part(
+ text=f'Uploaded file: {file_name}. It is saved into artifacts'
+ )
+ # Appends only. We do not yield the event because it's not from the model.
+ event = Event(
+ invocation_id=invocation_context.invocation_id,
+ author='user',
+ content=new_message,
+ )
+ self.session_service.append_event(session=session, event=event)
+
+ async def run_live(
+ self,
+ *,
+ session: Session,
+ live_request_queue: LiveRequestQueue,
+ run_config: RunConfig = RunConfig(),
+ ) -> AsyncGenerator[Event, None]:
+ """Runs the agent in live mode (experimental feature).
+
+ Args:
+ session: The session to use.
+ live_request_queue: The queue for live requests.
+ run_config: The run config for the agent.
+
+ Yields:
+ The events generated by the agent.
+ """
+ # TODO: right now, only works for a single audio agent without FC.
+ invocation_context = self._new_invocation_context_for_live(
+ session,
+ live_request_queue=live_request_queue,
+ run_config=run_config,
+ )
+
+ root_agent = self.agent
+ invocation_context.agent = self._find_agent_to_run(session, root_agent)
+
+ invocation_context.active_streaming_tools = {}
+ # TODO(hangfei): switch to use canonical_tools.
+ for tool in invocation_context.agent.tools:
+ # replicate a LiveRequestQueue for streaming tools that relis on
+ # LiveRequestQueue
+ from typing import get_type_hints
+
+ type_hints = get_type_hints(tool)
+ for arg_type in type_hints.values():
+ if arg_type is LiveRequestQueue:
+ if not invocation_context.active_streaming_tools:
+ invocation_context.active_streaming_tools = {}
+ active_streaming_tools = ActiveStreamingTool(
+ stream=LiveRequestQueue()
+ )
+ invocation_context.active_streaming_tools[tool.__name__] = (
+ active_streaming_tools
+ )
+
+ async for event in invocation_context.agent.run_live(invocation_context):
+ self.session_service.append_event(session=session, event=event)
+ yield event
+
+ def close_session(self, session: Session):
+ """Closes a session and adds it to the memory service (experimental feature).
+
+ Args:
+ session: The session to close.
+ """
+ if self.memory_service:
+ self.memory_service.add_session_to_memory(session)
+ self.session_service.close_session(session=session)
+
+ def _find_agent_to_run(
+ self, session: Session, root_agent: BaseAgent
+ ) -> BaseAgent:
+ """Finds the agent to run to continue the session.
+
+ A qualified agent must be either of:
+ - The root agent;
+ - An LlmAgent who replied last and is capable to transfer to any other agent
+ in the agent hierarchy.
+
+ Args:
+ session: The session to find the agent for.
+ root_agent: The root agent of the runner.
+
+ Returns:
+ The agent of the last message in the session or the root agent.
+ """
+ for event in filter(lambda e: e.author != 'user', reversed(session.events)):
+ if event.author == root_agent.name:
+ # Found root agent.
+ return root_agent
+ if not (agent := root_agent.find_sub_agent(event.author)):
+ # Agent not found, continue looking.
+ logger.warning(
+ 'Event from an unknown agent: %s, event id: %s',
+ event.author,
+ event.id,
+ )
+ continue
+ if self._is_transferable_across_agent_tree(agent):
+ return agent
+ # Falls back to root agent if no suitable agents are found in the session.
+ return root_agent
+
+ def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool:
+ """Whether the agent to run can transfer to any other agent in the agent tree.
+
+ This typically means all agent_to_run's parent through root agent can
+ transfer to their parent_agent.
+
+ Args:
+ agent_to_run: The agent to check for transferability.
+
+ Returns:
+ True if the agent can transfer, False otherwise.
+ """
+ agent = agent_to_run
+ while agent:
+ if not isinstance(agent, LlmAgent):
+ # Only LLM-based Agent can provider agent transfer capability.
+ return False
+ if agent.disallow_transfer_to_parent:
+ return False
+ agent = agent.parent_agent
+ return True
+
+ def _new_invocation_context(
+ self,
+ session: Session,
+ *,
+ new_message: Optional[types.Content] = None,
+ live_request_queue: Optional[LiveRequestQueue] = None,
+ run_config: RunConfig = RunConfig(),
+ ) -> InvocationContext:
+ """Creates a new invocation context.
+
+ Args:
+ session: The session for the context.
+ new_message: The new message for the context.
+ live_request_queue: The live request queue for the context.
+ run_config: The run config for the context.
+
+ Returns:
+ The new invocation context.
+ """
+ invocation_id = new_invocation_context_id()
+
+ if run_config.support_cfc and isinstance(self.agent, LlmAgent):
+ model_name = self.agent.canonical_model.model
+ if not model_name.startswith('gemini-2'):
+ raise ValueError(
+ f'CFC is not supported for model: {model_name} in agent:'
+ f' {self.agent.name}'
+ )
+ if built_in_code_execution not in self.agent.canonical_tools:
+ self.agent.tools.append(built_in_code_execution)
+
+ return InvocationContext(
+ artifact_service=self.artifact_service,
+ session_service=self.session_service,
+ memory_service=self.memory_service,
+ invocation_id=invocation_id,
+ agent=self.agent,
+ session=session,
+ user_content=new_message,
+ live_request_queue=live_request_queue,
+ run_config=run_config,
+ )
+
+ def _new_invocation_context_for_live(
+ self,
+ session: Session,
+ *,
+ live_request_queue: Optional[LiveRequestQueue] = None,
+ run_config: RunConfig = RunConfig(),
+ ) -> InvocationContext:
+ """Creates a new invocation context for live multi-agent."""
+
+ # For live multi-agent, we need model's text transcription as context for
+ # next agent.
+ if self.agent.sub_agents and live_request_queue:
+ if not run_config.response_modalities:
+ # default
+ run_config.response_modalities = ['AUDIO', 'TEXT']
+ elif 'TEXT' not in run_config.response_modalities:
+ run_config.response_modalities.append('TEXT')
+ return self._new_invocation_context(
+ session,
+ live_request_queue=live_request_queue,
+ run_config=run_config,
+ )
+
+
+class InMemoryRunner(Runner):
+ """An in-memory Runner for testing and development.
+
+ This runner uses in-memory implementations for artifact, session, and memory
+ services, providing a lightweight and self-contained environment for agent
+ execution.
+
+ Attributes:
+ agent: The root agent to run.
+ app_name: The application name of the runner. Defaults to
+ 'InMemoryRunner'.
+ """
+
+ def __init__(self, agent: LlmAgent, *, app_name: str = 'InMemoryRunner'):
+ """Initializes the InMemoryRunner.
+
+ Args:
+ agent: The root agent to run.
+ app_name: The application name of the runner. Defaults to
+ 'InMemoryRunner'.
+ """
+ super().__init__(
+ app_name=app_name,
+ agent=agent,
+ artifact_service=InMemoryArtifactService(),
+ session_service=InMemorySessionService(),
+ memory_service=InMemoryMemoryService(),
+ )
diff --git a/src/google/adk/sessions/__init__.py b/src/google/adk/sessions/__init__.py
new file mode 100644
index 0000000..c9b8390
--- /dev/null
+++ b/src/google/adk/sessions/__init__.py
@@ -0,0 +1,41 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from .base_session_service import BaseSessionService
+from .in_memory_session_service import InMemorySessionService
+from .session import Session
+from .state import State
+from .vertex_ai_session_service import VertexAiSessionService
+
+logger = logging.getLogger(__name__)
+
+
+__all__ = [
+ 'BaseSessionService',
+ 'InMemorySessionService',
+ 'Session',
+ 'State',
+ 'VertexAiSessionService',
+]
+
+try:
+ from .database_session_service import DatabaseSessionService
+
+ __all__.append('DatabaseSessionService')
+except ImportError:
+ logger.debug(
+ 'DatabaseSessionService require sqlalchemy>=2.0, please ensure it is'
+ ' installed correctly.'
+ )
diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py
new file mode 100644
index 0000000..be7f97e
--- /dev/null
+++ b/src/google/adk/sessions/base_session_service.py
@@ -0,0 +1,133 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from typing import Any
+from typing import Optional
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from ..events.event import Event
+from .session import Session
+from .state import State
+
+
+class GetSessionConfig(BaseModel):
+ """The configuration of getting a session."""
+ num_recent_events: Optional[int] = None
+ after_timestamp: Optional[float] = None
+
+
+class ListSessionsResponse(BaseModel):
+ """The response of listing sessions.
+
+ The events and states are not set within each Session object.
+ """
+ sessions: list[Session] = Field(default_factory=list)
+
+
+class ListEventsResponse(BaseModel):
+ """The response of listing events in a session."""
+ events: list[Event] = Field(default_factory=list)
+ next_page_token: Optional[str] = None
+
+
+class BaseSessionService(abc.ABC):
+ """Base class for session services.
+
+ The service provides a set of methods for managing sessions and events.
+ """
+
+ @abc.abstractmethod
+ def create_session(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ state: Optional[dict[str, Any]] = None,
+ session_id: Optional[str] = None,
+ ) -> Session:
+ """Creates a new session.
+
+ Args:
+ app_name: the name of the app.
+ user_id: the id of the user.
+ state: the initial state of the session.
+ session_id: the client-provided id of the session. If not provided, a
+ generated ID will be used.
+
+ Returns:
+ session: The newly created session instance.
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_session(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ config: Optional[GetSessionConfig] = None,
+ ) -> Optional[Session]:
+ """Gets a session."""
+ pass
+
+ @abc.abstractmethod
+ def list_sessions(
+ self, *, app_name: str, user_id: str
+ ) -> ListSessionsResponse:
+ """Lists all the sessions."""
+ pass
+
+ @abc.abstractmethod
+ def delete_session(
+ self, *, app_name: str, user_id: str, session_id: str
+ ) -> None:
+ """Deletes a session."""
+ pass
+
+ @abc.abstractmethod
+ def list_events(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ ) -> ListEventsResponse:
+ """Lists events in a session."""
+ pass
+
+ def close_session(self, *, session: Session):
+ """Closes a session."""
+ # TODO: determine whether we want to finalize the session here.
+ pass
+
+ def append_event(self, session: Session, event: Event) -> Event:
+ """Appends an event to a session object."""
+ if event.partial:
+ return event
+ self.__update_session_state(session, event)
+ session.events.append(event)
+ return event
+
+ def __update_session_state(self, session: Session, event: Event):
+ """Updates the session state based on the event."""
+ if not event.actions or not event.actions.state_delta:
+ return
+ for key, value in event.actions.state_delta.items():
+ if key.startswith(State.TEMP_PREFIX):
+ continue
+ session.state.update({key: value})
diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py
new file mode 100644
index 0000000..32fd66f
--- /dev/null
+++ b/src/google/adk/sessions/database_session_service.py
@@ -0,0 +1,522 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from datetime import datetime
+import json
+import logging
+from typing import Any
+from typing import Optional
+import uuid
+
+from sqlalchemy import delete
+from sqlalchemy import Dialect
+from sqlalchemy import ForeignKeyConstraint
+from sqlalchemy import func
+from sqlalchemy import select
+from sqlalchemy import Text
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.engine import create_engine
+from sqlalchemy.engine import Engine
+from sqlalchemy.ext.mutable import MutableDict
+from sqlalchemy.inspection import inspect
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session as DatabaseSessionFactory
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.schema import MetaData
+from sqlalchemy.types import DateTime
+from sqlalchemy.types import PickleType
+from sqlalchemy.types import String
+from sqlalchemy.types import TypeDecorator
+from typing_extensions import override
+from tzlocal import get_localzone
+
+from ..events.event import Event
+from .base_session_service import BaseSessionService
+from .base_session_service import GetSessionConfig
+from .base_session_service import ListEventsResponse
+from .base_session_service import ListSessionsResponse
+from .session import Session
+from .state import State
+
+logger = logging.getLogger(__name__)
+
+
+class DynamicJSON(TypeDecorator):
+ """A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
+
+ serialization for other databases.
+ """
+
+ impl = Text # Default implementation is TEXT
+
+ def load_dialect_impl(self, dialect: Dialect):
+ if dialect.name == "postgresql":
+ return dialect.type_descriptor(postgresql.JSONB)
+ else:
+ return dialect.type_descriptor(Text) # Default to Text for other dialects
+
+ def process_bind_param(self, value, dialect: Dialect):
+ if value is not None:
+ if dialect.name == "postgresql":
+ return value # JSONB handles dict directly
+ else:
+ return json.dumps(value) # Serialize to JSON string for TEXT
+ return value
+
+ def process_result_value(self, value, dialect: Dialect):
+ if value is not None:
+ if dialect.name == "postgresql":
+ return value # JSONB returns dict directly
+ else:
+ return json.loads(value) # Deserialize from JSON string for TEXT
+ return value
+
+
+class Base(DeclarativeBase):
+ """Base class for database tables."""
+ pass
+
+
+class StorageSession(Base):
+ """Represents a session stored in the database."""
+ __tablename__ = "sessions"
+
+ app_name: Mapped[str] = mapped_column(String, primary_key=True)
+ user_id: Mapped[str] = mapped_column(String, primary_key=True)
+ id: Mapped[str] = mapped_column(
+ String, primary_key=True, default=lambda: str(uuid.uuid4())
+ )
+
+ state: Mapped[dict] = mapped_column(
+ MutableDict.as_mutable(DynamicJSON), default={}
+ )
+
+ create_time: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
+ update_time: Mapped[DateTime] = mapped_column(
+ DateTime(), default=func.now(), onupdate=func.now()
+ )
+
+ storage_events: Mapped[list["StorageEvent"]] = relationship(
+ "StorageEvent",
+ back_populates="storage_session",
+ )
+
+ def __repr__(self):
+ return f""
+
+
+class StorageEvent(Base):
+ """Represents an event stored in the database."""
+ __tablename__ = "events"
+
+ id: Mapped[str] = mapped_column(String, primary_key=True)
+ app_name: Mapped[str] = mapped_column(String, primary_key=True)
+ user_id: Mapped[str] = mapped_column(String, primary_key=True)
+ session_id: Mapped[str] = mapped_column(String, primary_key=True)
+
+ invocation_id: Mapped[str] = mapped_column(String)
+ author: Mapped[str] = mapped_column(String)
+ branch: Mapped[str] = mapped_column(String, nullable=True)
+ timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
+ content: Mapped[dict] = mapped_column(DynamicJSON)
+ actions: Mapped[dict] = mapped_column(PickleType)
+
+ storage_session: Mapped[StorageSession] = relationship(
+ "StorageSession",
+ back_populates="storage_events",
+ )
+
+ __table_args__ = (
+ ForeignKeyConstraint(
+ ["app_name", "user_id", "session_id"],
+ ["sessions.app_name", "sessions.user_id", "sessions.id"],
+ ondelete="CASCADE",
+ ),
+ )
+
+
+class StorageAppState(Base):
+ """Represents an app state stored in the database."""
+ __tablename__ = "app_states"
+
+ app_name: Mapped[str] = mapped_column(String, primary_key=True)
+ state: Mapped[dict] = mapped_column(
+ MutableDict.as_mutable(DynamicJSON), default={}
+ )
+ update_time: Mapped[DateTime] = mapped_column(
+ DateTime(), default=func.now(), onupdate=func.now()
+ )
+
+
+class StorageUserState(Base):
+ """Represents a user state stored in the database."""
+ __tablename__ = "user_states"
+
+ app_name: Mapped[str] = mapped_column(String, primary_key=True)
+ user_id: Mapped[str] = mapped_column(String, primary_key=True)
+ state: Mapped[dict] = mapped_column(
+ MutableDict.as_mutable(DynamicJSON), default={}
+ )
+ update_time: Mapped[DateTime] = mapped_column(
+ DateTime(), default=func.now(), onupdate=func.now()
+ )
+
+
+class DatabaseSessionService(BaseSessionService):
+ """A session service that uses a database for storage."""
+
+ def __init__(self, db_url: str):
+ """
+ Args:
+ db_url: The database URL to connect to.
+ """
+ # 1. Create DB engine for db connection
+ # 2. Create all tables based on schema
+ # 3. Initialize all properies
+
+ supported_dialects = ["postgresql", "mysql", "sqlite"]
+ dialect = db_url.split("://")[0]
+
+ if dialect in supported_dialects:
+ db_engine = create_engine(db_url)
+ else:
+ raise ValueError(f"Unsupported database URL: {db_url}")
+
+ # Get the local timezone
+ local_timezone = get_localzone()
+ logger.info(f"Local timezone: {local_timezone}")
+
+ self.db_engine: Engine = db_engine
+ self.metadata: MetaData = MetaData()
+ self.inspector = inspect(self.db_engine)
+
+ # DB session factory method
+ self.DatabaseSessionFactory: sessionmaker[DatabaseSessionFactory] = (
+ sessionmaker(bind=self.db_engine)
+ )
+
+ # Uncomment to recreate DB every time
+ # Base.metadata.drop_all(self.db_engine)
+ Base.metadata.create_all(self.db_engine)
+
+ @override
+ def create_session(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ state: Optional[dict[str, Any]] = None,
+ session_id: Optional[str] = None,
+ ) -> Session:
+ # 1. Populate states.
+ # 2. Build storage session object
+ # 3. Add the object to the table
+ # 4. Build the session object with generated id
+ # 5. Return the session
+
+ with self.DatabaseSessionFactory() as sessionFactory:
+
+ # Fetch app and user states from storage
+ storage_app_state = sessionFactory.get(StorageAppState, (app_name))
+ storage_user_state = sessionFactory.get(
+ StorageUserState, (app_name, user_id)
+ )
+
+ app_state = storage_app_state.state if storage_app_state else {}
+ user_state = storage_user_state.state if storage_user_state else {}
+
+ # Create state tables if not exist
+ if not storage_app_state:
+ storage_app_state = StorageAppState(app_name=app_name, state={})
+ sessionFactory.add(storage_app_state)
+ if not storage_user_state:
+ storage_user_state = StorageUserState(
+ app_name=app_name, user_id=user_id, state={}
+ )
+ sessionFactory.add(storage_user_state)
+
+ # Extract state deltas
+ app_state_delta, user_state_delta, session_state = _extract_state_delta(
+ state
+ )
+
+ # Apply state delta
+ app_state.update(app_state_delta)
+ user_state.update(user_state_delta)
+
+ # Store app and user state
+ if app_state_delta:
+ storage_app_state.state = app_state
+ if user_state_delta:
+ storage_user_state.state = user_state
+
+ # Store the session
+ storage_session = StorageSession(
+ app_name=app_name,
+ user_id=user_id,
+ id=session_id,
+ state=session_state,
+ )
+ sessionFactory.add(storage_session)
+ sessionFactory.commit()
+
+ sessionFactory.refresh(storage_session)
+
+ # Merge states for response
+ merged_state = _merge_state(app_state, user_state, session_state)
+ session = Session(
+ app_name=str(storage_session.app_name),
+ user_id=str(storage_session.user_id),
+ id=str(storage_session.id),
+ state=merged_state,
+ last_update_time=storage_session.update_time.timestamp(),
+ )
+ return session
+ return None
+
+ @override
+ def get_session(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ config: Optional[GetSessionConfig] = None,
+ ) -> Optional[Session]:
+ # 1. Get the storage session entry from session table
+ # 2. Get all the events based on session id and filtering config
+ # 3. Convert and return the session
+ session: Session = None
+ with self.DatabaseSessionFactory() as sessionFactory:
+ storage_session = sessionFactory.get(
+ StorageSession, (app_name, user_id, session_id)
+ )
+ if storage_session is None:
+ return None
+
+ storage_events = (
+ sessionFactory.query(StorageEvent)
+ .filter(StorageEvent.session_id == storage_session.id)
+ .filter(
+ StorageEvent.timestamp < config.after_timestamp
+ if config
+ else True
+ )
+ .limit(config.num_recent_events if config else None)
+ .all()
+ )
+
+ # Fetch states from storage
+ storage_app_state = sessionFactory.get(StorageAppState, (app_name))
+ storage_user_state = sessionFactory.get(
+ StorageUserState, (app_name, user_id)
+ )
+
+ app_state = storage_app_state.state if storage_app_state else {}
+ user_state = storage_user_state.state if storage_user_state else {}
+ session_state = storage_session.state
+
+ # Merge states
+ merged_state = _merge_state(app_state, user_state, session_state)
+
+ # Convert storage session to session
+ session = Session(
+ app_name=app_name,
+ user_id=user_id,
+ id=session_id,
+ state=merged_state,
+ last_update_time=storage_session.update_time.timestamp(),
+ )
+ session.events = [
+ Event(
+ id=e.id,
+ author=e.author,
+ branch=e.branch,
+ invocation_id=e.invocation_id,
+ content=e.content,
+ actions=e.actions,
+ timestamp=e.timestamp.timestamp(),
+ )
+ for e in storage_events
+ ]
+
+ return session
+
+ @override
+ def list_sessions(
+ self, *, app_name: str, user_id: str
+ ) -> ListSessionsResponse:
+ with self.DatabaseSessionFactory() as sessionFactory:
+ results = (
+ sessionFactory.query(StorageSession)
+ .filter(StorageSession.app_name == app_name)
+ .filter(StorageSession.user_id == user_id)
+ .all()
+ )
+ sessions = []
+ for storage_session in results:
+ session = Session(
+ app_name=app_name,
+ user_id=user_id,
+ id=storage_session.id,
+ state={},
+ last_update_time=storage_session.update_time.timestamp(),
+ )
+ sessions.append(session)
+ return ListSessionsResponse(sessions=sessions)
+ raise ValueError("Failed to retrieve sessions.")
+
+ @override
+ def delete_session(
+ self, app_name: str, user_id: str, session_id: str
+ ) -> None:
+ with self.DatabaseSessionFactory() as sessionFactory:
+ stmt = delete(StorageSession).where(
+ StorageSession.app_name == app_name,
+ StorageSession.user_id == user_id,
+ StorageSession.id == session_id,
+ )
+ sessionFactory.execute(stmt)
+ sessionFactory.commit()
+
+ @override
+ def append_event(self, session: Session, event: Event) -> Event:
+ logger.info(f"Append event: {event} to session {session.id}")
+
+ if event.partial and not event.content:
+ return event
+
+ # 1. Check if timestamp is stale
+ # 2. Update session attributes based on event config
+ # 3. Store event to table
+ with self.DatabaseSessionFactory() as sessionFactory:
+ storage_session = sessionFactory.get(
+ StorageSession, (session.app_name, session.user_id, session.id)
+ )
+
+ if storage_session.update_time.timestamp() > session.last_update_time:
+ raise ValueError(
+ f"Session last_update_time {session.last_update_time} is later than"
+ f" the upate_time in storage {storage_session.update_time}"
+ )
+
+ # Fetch states from storage
+ storage_app_state = sessionFactory.get(
+ StorageAppState, (session.app_name)
+ )
+ storage_user_state = sessionFactory.get(
+ StorageUserState, (session.app_name, session.user_id)
+ )
+
+ app_state = storage_app_state.state if storage_app_state else {}
+ user_state = storage_user_state.state if storage_user_state else {}
+ session_state = storage_session.state
+
+ # Extract state delta
+ app_state_delta = {}
+ user_state_delta = {}
+ session_state_delta = {}
+ if event.actions:
+ if event.actions.state_delta:
+ app_state_delta, user_state_delta, session_state_delta = (
+ _extract_state_delta(event.actions.state_delta)
+ )
+
+ # Merge state
+ app_state.update(app_state_delta)
+ user_state.update(user_state_delta)
+ session_state.update(session_state_delta)
+
+ # Update storage
+ storage_app_state.state = app_state
+ storage_user_state.state = user_state
+ storage_session.state = session_state
+
+ encoded_content = event.content.model_dump(exclude_none=True)
+ storage_event = StorageEvent(
+ id=event.id,
+ invocation_id=event.invocation_id,
+ author=event.author,
+ branch=event.branch,
+ content=encoded_content,
+ actions=event.actions,
+ session_id=session.id,
+ app_name=session.app_name,
+ user_id=session.user_id,
+ timestamp=datetime.fromtimestamp(event.timestamp),
+ )
+
+ sessionFactory.add(storage_event)
+
+ sessionFactory.commit()
+ sessionFactory.refresh(storage_session)
+
+ # Update timestamp with commit time
+ session.last_update_time = storage_session.update_time.timestamp()
+
+ # Also update the in-memory session
+ super().append_event(session=session, event=event)
+ return event
+
+ @override
+ def list_events(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ ) -> ListEventsResponse:
+ pass
+
+
+def convert_event(event: StorageEvent) -> Event:
+ """Converts a storage event to an event."""
+ return Event(
+ id=event.id,
+ author=event.author,
+ branch=event.branch,
+ invocation_id=event.invocation_id,
+ content=event.content,
+ actions=event.actions,
+ timestamp=event.timestamp.timestamp(),
+ )
+
+
+def _extract_state_delta(state: dict):
+ app_state_delta = {}
+ user_state_delta = {}
+ session_state_delta = {}
+ if state:
+ for key in state.keys():
+ if key.startswith(State.APP_PREFIX):
+ app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
+ elif key.startswith(State.USER_PREFIX):
+ user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
+ elif not key.startswith(State.TEMP_PREFIX):
+ session_state_delta[key] = state[key]
+ return app_state_delta, user_state_delta, session_state_delta
+
+
+def _merge_state(app_state, user_state, session_state):
+ # Merge states for response
+ merged_state = copy.deepcopy(session_state)
+ for key in app_state.keys():
+ merged_state[State.APP_PREFIX + key] = app_state[key]
+ for key in user_state.keys():
+ merged_state[State.USER_PREFIX + key] = user_state[key]
+ return merged_state
diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py
new file mode 100644
index 0000000..bcb659a
--- /dev/null
+++ b/src/google/adk/sessions/in_memory_session_service.py
@@ -0,0 +1,206 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import time
+from typing import Any
+from typing import Optional
+import uuid
+
+from typing_extensions import override
+
+from ..events.event import Event
+from .base_session_service import BaseSessionService
+from .base_session_service import GetSessionConfig
+from .base_session_service import ListEventsResponse
+from .base_session_service import ListSessionsResponse
+from .session import Session
+from .state import State
+
+
+class InMemorySessionService(BaseSessionService):
+ """An in-memory implementation of the session service."""
+
+ def __init__(self):
+ # A map from app name to a map from user ID to a map from session ID to session.
+ self.sessions: dict[str, dict[str, dict[str, Session]]] = {}
+ # A map from app name to a map from user ID to a map from key to the value.
+ self.user_state: dict[str, dict[str, dict[str, Any]]] = {}
+ # A map from app name to a map from key to the value.
+ self.app_state: dict[str, dict[str, Any]] = {}
+
+ @override
+ def create_session(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ state: Optional[dict[str, Any]] = None,
+ session_id: Optional[str] = None,
+ ) -> Session:
+ session_id = (
+ session_id.strip()
+ if session_id and session_id.strip()
+ else str(uuid.uuid4())
+ )
+ session = Session(
+ app_name=app_name,
+ user_id=user_id,
+ id=session_id,
+ state=state or {},
+ last_update_time=time.time(),
+ )
+
+ if app_name not in self.sessions:
+ self.sessions[app_name] = {}
+ if user_id not in self.sessions[app_name]:
+ self.sessions[app_name][user_id] = {}
+ self.sessions[app_name][user_id][session_id] = session
+
+ copied_session = copy.deepcopy(session)
+ return self._merge_state(app_name, user_id, copied_session)
+
+ @override
+ def get_session(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ config: Optional[GetSessionConfig] = None,
+ ) -> Session:
+ if app_name not in self.sessions:
+ return None
+ if user_id not in self.sessions[app_name]:
+ return None
+ if session_id not in self.sessions[app_name][user_id]:
+ return None
+
+ session = self.sessions[app_name][user_id].get(session_id)
+ copied_session = copy.deepcopy(session)
+
+ if config:
+ if config.num_recent_events:
+ copied_session.events = copied_session.events[
+ -config.num_recent_events :
+ ]
+ elif config.after_timestamp:
+ i = len(session.events) - 1
+ while i >= 0:
+ if copied_session.events[i].timestamp < config.after_timestamp:
+ break
+ i -= 1
+ if i >= 0:
+ copied_session.events = copied_session.events[i:]
+
+ return self._merge_state(app_name, user_id, copied_session)
+
+ def _merge_state(self, app_name: str, user_id: str, copied_session: Session):
+ # Merge app state
+ if app_name in self.app_state:
+ for key in self.app_state[app_name].keys():
+ copied_session.state[State.APP_PREFIX + key] = self.app_state[app_name][
+ key
+ ]
+
+ if (
+ app_name not in self.user_state
+ or user_id not in self.user_state[app_name]
+ ):
+ return copied_session
+
+ # Merge session state with user state.
+ for key in self.user_state[app_name][user_id].keys():
+ copied_session.state[State.USER_PREFIX + key] = self.user_state[app_name][
+ user_id
+ ][key]
+ return copied_session
+
+ @override
+ def list_sessions(
+ self, *, app_name: str, user_id: str
+ ) -> ListSessionsResponse:
+ empty_response = ListSessionsResponse()
+ if app_name not in self.sessions:
+ return empty_response
+ if user_id not in self.sessions[app_name]:
+ return empty_response
+
+ sessions_without_events = []
+ for session in self.sessions[app_name][user_id].values():
+ copied_session = copy.deepcopy(session)
+ copied_session.events = []
+ copied_session.state = {}
+ sessions_without_events.append(copied_session)
+ return ListSessionsResponse(sessions=sessions_without_events)
+
+ @override
+ def delete_session(
+ self, *, app_name: str, user_id: str, session_id: str
+ ) -> None:
+ if (
+ self.get_session(
+ app_name=app_name, user_id=user_id, session_id=session_id
+ )
+ is None
+ ):
+ return None
+
+ self.sessions[app_name][user_id].pop(session_id)
+
+ @override
+ def append_event(self, session: Session, event: Event) -> Event:
+ # Update the in-memory session.
+ super().append_event(session=session, event=event)
+ session.last_update_time = event.timestamp
+
+ # Update the storage session
+ app_name = session.app_name
+ user_id = session.user_id
+ session_id = session.id
+ if app_name not in self.sessions:
+ return event
+ if user_id not in self.sessions[app_name]:
+ return event
+ if session_id not in self.sessions[app_name][user_id]:
+ return event
+
+ if event.actions and event.actions.state_delta:
+ for key in event.actions.state_delta:
+ if key.startswith(State.APP_PREFIX):
+ self.app_state.setdefault(app_name, {})[
+ key.removeprefix(State.APP_PREFIX)
+ ] = event.actions.state_delta[key]
+
+ if key.startswith(State.USER_PREFIX):
+ self.user_state.setdefault(app_name, {}).setdefault(user_id, {})[
+ key.removeprefix(State.USER_PREFIX)
+ ] = event.actions.state_delta[key]
+
+ storage_session = self.sessions[app_name][user_id].get(session_id)
+ super().append_event(session=storage_session, event=event)
+
+ storage_session.last_update_time = event.timestamp
+
+ return event
+
+ @override
+ def list_events(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ ) -> ListEventsResponse:
+ raise NotImplementedError()
diff --git a/src/google/adk/sessions/session.py b/src/google/adk/sessions/session.py
new file mode 100644
index 0000000..3d3fac5
--- /dev/null
+++ b/src/google/adk/sessions/session.py
@@ -0,0 +1,54 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from pydantic import BaseModel
+from pydantic import ConfigDict
+from pydantic import Field
+
+from ..events.event import Event
+
+
+class Session(BaseModel):
+ """Represents a series of interactions between a user and agents.
+
+ Attributes:
+ id: The unique identifier of the session.
+ app_name: The name of the app.
+ user_id: The id of the user.
+ state: The state of the session.
+ events: The events of the session, e.g. user input, model response, function
+ call/response, etc.
+ last_update_time: The last update time of the session.
+ """
+
+ model_config = ConfigDict(
+ extra='forbid',
+ arbitrary_types_allowed=True,
+ )
+
+ id: str
+ """The unique identifier of the session."""
+ app_name: str
+ """The name of the app."""
+ user_id: str
+ """The id of the user."""
+ state: dict[str, Any] = Field(default_factory=dict)
+ """The state of the session."""
+ events: list[Event] = Field(default_factory=list)
+ """The events of the session, e.g. user input, model response, function
+ call/response, etc."""
+ last_update_time: float = 0.0
+ """The last update time of the session."""
diff --git a/src/google/adk/sessions/state.py b/src/google/adk/sessions/state.py
new file mode 100644
index 0000000..a333a17
--- /dev/null
+++ b/src/google/adk/sessions/state.py
@@ -0,0 +1,71 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+
+class State:
+ """A state dict that maintain the current value and the pending-commit delta."""
+
+ APP_PREFIX = "app:"
+ USER_PREFIX = "user:"
+ TEMP_PREFIX = "temp:"
+
+ def __init__(self, value: dict[str, Any], delta: dict[str, Any]):
+ """
+ Args:
+ value: The current value of the state dict.
+ delta: The delta change to the current value that hasn't been commited.
+ """
+ self._value = value
+ self._delta = delta
+
+ def __getitem__(self, key: str) -> Any:
+ """Returns the value of the state dict for the given key."""
+ if key in self._delta:
+ return self._delta[key]
+ return self._value[key]
+
+ def __setitem__(self, key: str, value: Any):
+ """Sets the value of the state dict for the given key."""
+ # TODO: make new change only store in delta, so that self._value is only
+ # updated at the storage commit time.
+ self._value[key] = value
+ self._delta[key] = value
+
+ def __contains__(self, key: str) -> bool:
+ """Whether the state dict contains the given key."""
+ return key in self._value or key in self._delta
+
+ def has_delta(self) -> bool:
+ """Whether the state has pending detla."""
+ return bool(self._delta)
+
+ def get(self, key: str, default: Any = None) -> Any:
+ """Returns the value of the state dict for the given key."""
+ if key not in self:
+ return default
+ return self[key]
+
+ def update(self, delta: dict[str, Any]):
+ """Updates the state dict with the given delta."""
+ self._value.update(delta)
+ self._delta.update(delta)
+
+ def to_dict(self) -> dict[str, Any]:
+ """Returns the state dict."""
+ result = {}
+ result.update(self._value)
+ result.update(self._delta)
+ return result
diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py
new file mode 100644
index 0000000..5ec45c4
--- /dev/null
+++ b/src/google/adk/sessions/vertex_ai_session_service.py
@@ -0,0 +1,356 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import re
+import time
+from typing import Any
+from typing import Optional
+
+from dateutil.parser import isoparse
+from google import genai
+from typing_extensions import override
+
+from ..events.event import Event
+from ..events.event_actions import EventActions
+from .base_session_service import BaseSessionService
+from .base_session_service import GetSessionConfig
+from .base_session_service import ListEventsResponse
+from .base_session_service import ListSessionsResponse
+from .session import Session
+
+logger = logging.getLogger(__name__)
+
+
+class VertexAiSessionService(BaseSessionService):
+ """Connects to the managed Vertex AI Session Service."""
+
+ def __init__(
+ self,
+ project: str = None,
+ location: str = None,
+ ):
+ self.project = project
+ self.location = location
+
+ client = genai.Client(vertexai=True, project=project, location=location)
+ self.api_client = client._api_client
+
+ @override
+ def create_session(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ state: Optional[dict[str, Any]] = None,
+ session_id: Optional[str] = None,
+ ) -> Session:
+ reasoning_engine_id = _parse_reasoning_engine_id(app_name)
+
+ session_json_dict = {'user_id': user_id}
+ if state:
+ session_json_dict['session_state'] = state
+
+ api_response = self.api_client.request(
+ http_method='POST',
+ path=f'reasoningEngines/{reasoning_engine_id}/sessions',
+ request_dict=session_json_dict,
+ )
+ logger.info(f'Create Session response {api_response}')
+
+ session_id = api_response['name'].split('/')[-3]
+ operation_id = api_response['name'].split('/')[-1]
+
+ max_retry_attempt = 5
+ while max_retry_attempt >= 0:
+ lro_response = self.api_client.request(
+ http_method='GET',
+ path=f'operations/{operation_id}',
+ request_dict={},
+ )
+
+ if lro_response.get('done', None):
+ break
+
+ time.sleep(1)
+ max_retry_attempt -= 1
+
+ # Get session resource
+ get_session_api_response = self.api_client.request(
+ http_method='GET',
+ path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
+ request_dict={},
+ )
+
+ update_timestamp = isoparse(
+ get_session_api_response['updateTime']
+ ).timestamp()
+ session = Session(
+ app_name=str(app_name),
+ user_id=str(user_id),
+ id=str(session_id),
+ state=get_session_api_response.get('sessionState', {}),
+ last_update_time=update_timestamp,
+ )
+ return session
+
+ @override
+ def get_session(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ config: Optional[GetSessionConfig] = None,
+ ) -> Session:
+ reasoning_engine_id = _parse_reasoning_engine_id(app_name)
+
+ # Get session resource
+ get_session_api_response = self.api_client.request(
+ http_method='GET',
+ path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
+ request_dict={},
+ )
+
+ session_id = get_session_api_response['name'].split('/')[-1]
+ update_timestamp = isoparse(
+ get_session_api_response['updateTime']
+ ).timestamp()
+ session = Session(
+ app_name=str(app_name),
+ user_id=str(user_id),
+ id=str(session_id),
+ state=get_session_api_response.get('sessionState', {}),
+ last_update_time=update_timestamp,
+ )
+
+ list_events_api_response = self.api_client.request(
+ http_method='GET',
+ path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
+ request_dict={},
+ )
+
+ # Handles empty response case
+ if list_events_api_response.get('httpHeaders', None):
+ return session
+
+ session.events = [
+ _from_api_event(event)
+ for event in list_events_api_response['sessionEvents']
+ ]
+ session.events = [
+ event for event in session.events if event.timestamp <= update_timestamp
+ ]
+ session.events.sort(key=lambda event: event.timestamp)
+
+ if config:
+ if config.num_recent_events:
+ session.events = session.events[-config.num_recent_events :]
+ elif config.after_timestamp:
+ i = len(session.events) - 1
+ while i >= 0:
+ if session.events[i].timestamp < config.after_timestamp:
+ break
+ i -= 1
+ if i >= 0:
+ session.events = session.events[i:]
+
+ return session
+
+ @override
+ def list_sessions(
+ self, *, app_name: str, user_id: str
+ ) -> ListSessionsResponse:
+ reasoning_engine_id = _parse_reasoning_engine_id(app_name)
+
+ api_response = self.api_client.request(
+ http_method='GET',
+ path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}',
+ request_dict={},
+ )
+
+ # Handles empty response case
+ if api_response.get('httpHeaders', None):
+ return ListSessionsResponse()
+
+ sessions = []
+ for api_session in api_response['sessions']:
+ session = Session(
+ app_name=app_name,
+ user_id=user_id,
+ id=api_session['name'].split('/')[-1],
+ state={},
+ last_update_time=isoparse(api_session['updateTime']).timestamp(),
+ )
+ sessions.append(session)
+ return ListSessionsResponse(sessions=sessions)
+
+ def delete_session(
+ self, *, app_name: str, user_id: str, session_id: str
+ ) -> None:
+ reasoning_engine_id = _parse_reasoning_engine_id(app_name)
+ self.api_client.request(
+ http_method='DELETE',
+ path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
+ request_dict={},
+ )
+
+ @override
+ def list_events(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ ) -> ListEventsResponse:
+ reasoning_engine_id = _parse_reasoning_engine_id(app_name)
+ api_response = self.api_client.request(
+ http_method='GET',
+ path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
+ request_dict={},
+ )
+
+ logger.info(f'List events response {api_response}')
+
+ # Handles empty response case
+ if api_response.get('httpHeaders', None):
+ return ListEventsResponse()
+
+ session_events = api_response['sessionEvents']
+
+ return ListEventsResponse(
+ events=[_from_api_event(event) for event in session_events]
+ )
+
+ @override
+ def append_event(self, session: Session, event: Event) -> Event:
+ # Update the in-memory session.
+ super().append_event(session=session, event=event)
+
+ reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
+ self.api_client.request(
+ http_method='POST',
+ path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
+ request_dict=_convert_event_to_json(event),
+ )
+
+ return event
+
+
+def _convert_event_to_json(event: Event):
+ metadata_json = {
+ 'partial': event.partial,
+ 'turn_complete': event.turn_complete,
+ 'interrupted': event.interrupted,
+ 'branch': event.branch,
+ 'long_running_tool_ids': (
+ list(event.long_running_tool_ids)
+ if event.long_running_tool_ids
+ else None
+ ),
+ }
+ if event.grounding_metadata:
+ metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
+ exclude_none=True
+ )
+
+ event_json = {
+ 'author': event.author,
+ 'invocation_id': event.invocation_id,
+ 'timestamp': {
+ 'seconds': int(event.timestamp),
+ 'nanos': int(
+ (event.timestamp - int(event.timestamp)) * 1_000_000_000
+ ),
+ },
+ 'error_code': event.error_code,
+ 'error_message': event.error_message,
+ 'event_metadata': metadata_json,
+ }
+
+ if event.actions:
+ actions_json = {
+ 'skip_summarization': event.actions.skip_summarization,
+ 'state_delta': event.actions.state_delta,
+ 'artifact_delta': event.actions.artifact_delta,
+ 'transfer_agent': event.actions.transfer_to_agent,
+ 'escalate': event.actions.escalate,
+ 'requested_auth_configs': event.actions.requested_auth_configs,
+ }
+ event_json['actions'] = actions_json
+ if event.content:
+ event_json['content'] = event.content.model_dump(exclude_none=True)
+ if event.error_code:
+ event_json['error_code'] = event.error_code
+ if event.error_message:
+ event_json['error_message'] = event.error_message
+ return event_json
+
+
+def _from_api_event(api_event: dict) -> Event:
+ event_actions = EventActions()
+ if api_event.get('actions', None):
+ event_actions = EventActions(
+ skip_summarization=api_event['actions'].get('skipSummarization', None),
+ state_delta=api_event['actions'].get('stateDelta', {}),
+ artifact_delta=api_event['actions'].get('artifactDelta', {}),
+ transfer_to_agent=api_event['actions'].get('transferAgent', None),
+ escalate=api_event['actions'].get('escalate', None),
+ requested_auth_configs=api_event['actions'].get(
+ 'requestedAuthConfigs', {}
+ ),
+ )
+
+ event = Event(
+ id=api_event['name'].split('/')[-1],
+ invocation_id=api_event['invocationId'],
+ author=api_event['author'],
+ actions=event_actions,
+ content=api_event.get('content', None),
+ timestamp=isoparse(api_event['timestamp']).timestamp(),
+ error_code=api_event.get('errorCode', None),
+ error_message=api_event.get('errorMessage', None),
+ )
+
+ if api_event.get('eventMetadata', None):
+ long_running_tool_ids_list = api_event['eventMetadata'].get(
+ 'longRunningToolIds', None
+ )
+ event.partial = api_event['eventMetadata'].get('partial', None)
+ event.turn_complete = api_event['eventMetadata'].get('turnComplete', None)
+ event.interrupted = api_event['eventMetadata'].get('interrupted', None)
+ event.branch = api_event['eventMetadata'].get('branch', None)
+ event.grounding_metadata = api_event['eventMetadata'].get(
+ 'groundingMetadata', None
+ )
+ event.long_running_tool_ids = (
+ set(long_running_tool_ids_list) if long_running_tool_ids_list else None
+ )
+
+ return event
+
+
+def _parse_reasoning_engine_id(app_name: str):
+ if app_name.isdigit():
+ return app_name
+
+ pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$'
+ match = re.fullmatch(pattern, app_name)
+
+ if not bool(match):
+ raise ValueError(
+ f'App name {app_name} is not valid. It should either be the full'
+ ' ReasoningEngine resource name, or the reasoning engine id.'
+ )
+
+ return match.groups()[-1]
diff --git a/src/google/adk/telemetry.py b/src/google/adk/telemetry.py
new file mode 100644
index 0000000..b74795b
--- /dev/null
+++ b/src/google/adk/telemetry.py
@@ -0,0 +1,137 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# NOTE:
+#
+# We expect that the underlying GenAI SDK will provide a certain
+# level of tracing and logging telemetry aligned with Open Telemetry
+# Semantic Conventions (such as logging prompts, respones, request
+# properties, etc.) and so the information that is recorded by the
+# Agent Development Kit should be focused on the higher-level
+# constructs of the framework that are not observable by the SDK.
+
+import json
+from typing import Any
+
+from google.genai import types
+from opentelemetry import trace
+
+from .agents.invocation_context import InvocationContext
+from .models.llm_request import LlmRequest
+from .models.llm_response import LlmResponse
+
+tracer = trace.get_tracer('gcp.vertex.agent')
+
+
+def trace_call_llm(
+ invocation_context: InvocationContext,
+ event_id: str,
+ llm_request: LlmRequest,
+ llm_response: LlmResponse,
+):
+ """Traces a call to the LLM.
+
+ This function records details about the LLM request and response as
+ attributes on the current OpenTelemetry span.
+
+ Args:
+ invocation_context: The invocation context for the current agent run.
+ event_id: The ID of the event.
+ llm_request: The LLM request object.
+ llm_response: The LLM response object.
+ """
+ span = trace.get_current_span()
+ # Special standard Open Telemetry GenaI attributes that indicate
+ # that this is a span related to a Generative AI system.
+ span.set_attribute('gen_ai.system', 'gcp.vertex.agent')
+ span.set_attribute('gen_ai.request.model', llm_request.model)
+ span.set_attribute(
+ 'gcp.vertex.agent.invocation_id', invocation_context.invocation_id
+ )
+ span.set_attribute('gcp.vertex.agent.event_id', event_id)
+ # Consider removing once GenAI SDK provides a way to record this info.
+ span.set_attribute(
+ 'gcp.vertex.agent.llm_request',
+ json.dumps(_build_llm_request_for_trace(llm_request)),
+ )
+ # Consider removing once GenAI SDK provides a way to record this info.
+ span.set_attribute(
+ 'gcp.vertex.agent.llm_response',
+ llm_response.model_dump_json(exclude_none=True),
+ )
+
+
+def trace_send_data(
+ invocation_context: InvocationContext,
+ event_id: str,
+ data: list[types.Content],
+):
+ """Traces the sending of data to the agent.
+
+ This function records details about the data sent to the agent as
+ attributes on the current OpenTelemetry span.
+
+ Args:
+ invocation_context: The invocation context for the current agent run.
+ event_id: The ID of the event.
+ data: A list of content objects.
+ """
+ span = trace.get_current_span()
+ span.set_attribute(
+ 'gcp.vertex.agent.invocation_id', invocation_context.invocation_id
+ )
+ span.set_attribute('gcp.vertex.agent.event_id', event_id)
+ # Once instrumentation is added to the GenAI SDK, consider whether this
+ # information still needs to be recorded by the Agent Development Kit.
+ span.set_attribute(
+ 'gcp.vertex.agent.data',
+ json.dumps([
+ types.Content(role=content.role, parts=content.parts).model_dump(
+ exclude_none=True
+ )
+ for content in data
+ ]),
+ )
+
+
+def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]:
+ """Builds a dictionary representation of the LLM request for tracing.
+
+ This function prepares a dictionary representation of the LlmRequest
+ object, suitable for inclusion in a trace. It excludes fields that cannot
+ be serialized (e.g., function pointers) and avoids sending bytes data.
+
+ Args:
+ llm_request: The LlmRequest object.
+
+ Returns:
+ A dictionary representation of the LLM request.
+ """
+ # Some fields in LlmRequest are function pointers and can not be serialized.
+ result = {
+ 'model': llm_request.model,
+ 'config': llm_request.config.model_dump(
+ exclude_none=True, exclude='response_schema'
+ ),
+ 'contents': [],
+ }
+ # We do not want to send bytes data to the trace.
+ for content in llm_request.contents:
+ parts = [part for part in content.parts if not part.inline_data]
+ result['contents'].append(
+ types.Content(role=content.role, parts=parts).model_dump(
+ exclude_none=True
+ )
+ )
+ return result
diff --git a/src/google/adk/tests/__init__.py b/src/google/adk/tests/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/integration/.env.example b/src/google/adk/tests/integration/.env.example
new file mode 100644
index 0000000..2ae0bce
--- /dev/null
+++ b/src/google/adk/tests/integration/.env.example
@@ -0,0 +1,10 @@
+# Copy as .env file and fill your values below to run integration tests.
+
+# Choose Backend: GOOGLE_AI_ONLY | VERTEX_ONLY | BOTH (default)
+TEST_BACKEND=BOTH
+
+# ML Dev backend config
+GOOGLE_API_KEY=YOUR_VALUE_HERE
+# Vertex backend config
+GOOGLE_CLOUD_PROJECT=YOUR_VALUE_HERE
+GOOGLE_CLOUD_LOCATION=YOUR_VALUE_HERE
diff --git a/src/google/adk/tests/integration/__init__.py b/src/google/adk/tests/integration/__init__.py
new file mode 100644
index 0000000..a89a39a
--- /dev/null
+++ b/src/google/adk/tests/integration/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+# This allows pytest to show the values of the asserts.
+pytest.register_assert_rewrite('tests.integration.utils')
diff --git a/src/google/adk/tests/integration/conftest.py b/src/google/adk/tests/integration/conftest.py
new file mode 100644
index 0000000..9ee1dc6
--- /dev/null
+++ b/src/google/adk/tests/integration/conftest.py
@@ -0,0 +1,119 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+from typing import Literal
+import warnings
+
+from dotenv import load_dotenv
+from google.adk import Agent
+from pytest import fixture
+from pytest import FixtureRequest
+from pytest import hookimpl
+from pytest import Metafunc
+
+from .utils import TestRunner
+
+logger = logging.getLogger(__name__)
+
+
+def load_env_for_tests():
+ dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
+ if not os.path.exists(dotenv_path):
+ warnings.warn(
+ f'Missing .env file at {dotenv_path}. See dotenv.sample for an example.'
+ )
+ else:
+ load_dotenv(dotenv_path, override=True, verbose=True)
+ if 'GOOGLE_API_KEY' not in os.environ:
+ warnings.warn(
+ 'Missing GOOGLE_API_KEY in the environment variables. GOOGLE_AI backend'
+ ' integration tests will fail.'
+ )
+ for env_var in [
+ 'GOOGLE_CLOUD_PROJECT',
+ 'GOOGLE_CLOUD_LOCATION',
+ ]:
+ if env_var not in os.environ:
+ warnings.warn(
+ f'Missing {env_var} in the environment variables. Vertex backend'
+ ' integration tests will fail.'
+ )
+
+
+load_env_for_tests()
+
+BackendType = Literal['GOOGLE_AI', 'VERTEX']
+
+
+@fixture
+def agent_runner(request: FixtureRequest) -> TestRunner:
+ assert isinstance(request.param, dict)
+
+ if 'agent' in request.param:
+ assert isinstance(request.param['agent'], Agent)
+ return TestRunner(request.param['agent'])
+ elif 'agent_name' in request.param:
+ assert isinstance(request.param['agent_name'], str)
+ return TestRunner.from_agent_name(request.param['agent_name'])
+
+ raise NotImplementedError('Must provide agent or agent_name.')
+
+
+@fixture(autouse=True)
+def llm_backend(request: FixtureRequest):
+ # Set backend environment value.
+ original_val = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI')
+ backend_type = request.param
+ if backend_type == 'GOOGLE_AI':
+ os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = '0'
+ else:
+ os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = '1'
+
+ yield # Run the test
+
+ # Restore the environment
+ if original_val is None:
+ os.environ.pop('GOOGLE_GENAI_USE_VERTEXAI', None)
+ else:
+ os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = original_val
+
+
+@hookimpl(tryfirst=True)
+def pytest_generate_tests(metafunc: Metafunc):
+ if llm_backend.__name__ in metafunc.fixturenames:
+ if not _is_explicitly_marked(llm_backend.__name__, metafunc):
+ test_backend = os.environ.get('TEST_BACKEND', 'BOTH')
+ if test_backend == 'GOOGLE_AI_ONLY':
+ metafunc.parametrize(llm_backend.__name__, ['GOOGLE_AI'], indirect=True)
+ elif test_backend == 'VERTEX_ONLY':
+ metafunc.parametrize(llm_backend.__name__, ['VERTEX'], indirect=True)
+ elif test_backend == 'BOTH':
+ metafunc.parametrize(
+ llm_backend.__name__, ['GOOGLE_AI', 'VERTEX'], indirect=True
+ )
+ else:
+ raise ValueError(
+ f'Invalid TEST_BACKEND value: {test_backend}, should be one of'
+ ' [GOOGLE_AI_ONLY, VERTEX_ONLY, BOTH]'
+ )
+
+
+def _is_explicitly_marked(mark_name: str, metafunc: Metafunc) -> bool:
+ if hasattr(metafunc.function, 'pytestmark'):
+ for mark in metafunc.function.pytestmark:
+ if mark.name == 'parametrize' and mark.args[0] == mark_name:
+ return True
+ return False
diff --git a/src/google/adk/tests/integration/fixture/__init__.py b/src/google/adk/tests/integration/fixture/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/integration/fixture/agent_with_config/__init__.py b/src/google/adk/tests/integration/fixture/agent_with_config/__init__.py
new file mode 100644
index 0000000..c48963c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/agent_with_config/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/src/google/adk/tests/integration/fixture/agent_with_config/agent.py b/src/google/adk/tests/integration/fixture/agent_with_config/agent.py
new file mode 100644
index 0000000..a919b32
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/agent_with_config/agent.py
@@ -0,0 +1,88 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk import Agent
+from google.genai import types
+
+new_message = types.Content(
+ role="user",
+ parts=[types.Part.from_text(text="Count a number")],
+)
+
+google_agent_1 = Agent(
+ model="gemini-1.5-flash",
+ name="agent_1",
+ description="The first agent in the team.",
+ instruction="Just say 1",
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.1,
+ ),
+)
+
+google_agent_2 = Agent(
+ model="gemini-1.5-flash",
+ name="agent_2",
+ description="The second agent in the team.",
+ instruction="Just say 2",
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.2,
+ safety_settings=[{
+ "category": "HARM_CATEGORY_HATE_SPEECH",
+ "threshold": "BLOCK_ONLY_HIGH",
+ }],
+ ),
+)
+
+google_agent_3 = Agent(
+ model="gemini-1.5-flash",
+ name="agent_3",
+ description="The third agent in the team.",
+ instruction="Just say 3",
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.5,
+ safety_settings=[{
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+ "threshold": "BLOCK_NONE",
+ }],
+ ),
+)
+
+google_agent_with_instruction_in_config = Agent(
+ model="gemini-1.5-flash",
+ name="agent",
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.5, system_instruction="Count 1"
+ ),
+)
+
+
+def function():
+ pass
+
+
+google_agent_with_tools_in_config = Agent(
+ model="gemini-1.5-flash",
+ name="agent",
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.5, tools=[function]
+ ),
+)
+
+google_agent_with_response_schema_in_config = Agent(
+ model="gemini-1.5-flash",
+ name="agent",
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.5, response_schema={"key": "value"}
+ ),
+)
diff --git a/src/google/adk/tests/integration/fixture/callback_agent/__init__.py b/src/google/adk/tests/integration/fixture/callback_agent/__init__.py
new file mode 100644
index 0000000..44f7dab
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/callback_agent/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
\ No newline at end of file
diff --git a/src/google/adk/tests/integration/fixture/callback_agent/agent.py b/src/google/adk/tests/integration/fixture/callback_agent/agent.py
new file mode 100644
index 0000000..f57c3aa
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/callback_agent/agent.py
@@ -0,0 +1,105 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from google.adk import Agent
+from google.adk.agents.callback_context import CallbackContext
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.models import LlmRequest
+from google.adk.models import LlmResponse
+from google.genai import types
+
+
+def before_agent_call_end_invocation(
+ callback_context: CallbackContext,
+) -> types.Content:
+ return types.Content(
+ role='model',
+ parts=[types.Part(text='End invocation event before agent call.')],
+ )
+
+
+def before_agent_call(
+ invocation_context: InvocationContext,
+) -> types.Content:
+ return types.Content(
+ role='model',
+ parts=[types.Part.from_text(text='Plain text event before agent call.')],
+ )
+
+
+def before_model_call_end_invocation(
+ callback_context: CallbackContext, llm_request: LlmRequest
+) -> LlmResponse:
+ return LlmResponse(
+ content=types.Content(
+ role='model',
+ parts=[
+ types.Part.from_text(
+ text='End invocation event before model call.'
+ )
+ ],
+ )
+ )
+
+
+def before_model_call(
+ invocation_context: InvocationContext, request: LlmRequest
+) -> LlmResponse:
+ request.config.system_instruction = 'Just return 999 as response.'
+ return LlmResponse(
+ content=types.Content(
+ role='model',
+ parts=[
+ types.Part.from_text(
+ text='Update request event before model call.'
+ )
+ ],
+ )
+ )
+
+
+def after_model_call(
+ callback_context: CallbackContext,
+ llm_response: LlmResponse,
+) -> Optional[LlmResponse]:
+ content = llm_response.content
+ if not content or not content.parts or not content.parts[0].text:
+ return
+
+ content.parts[0].text += 'Update response event after model call.'
+ return llm_response
+
+
+before_agent_callback_agent = Agent(
+ model='gemini-1.5-flash',
+ name='before_agent_callback_agent',
+ instruction='echo 1',
+ before_agent_callback=before_agent_call_end_invocation,
+)
+
+before_model_callback_agent = Agent(
+ model='gemini-1.5-flash',
+ name='before_model_callback_agent',
+ instruction='echo 2',
+ before_model_callback=before_model_call_end_invocation,
+)
+
+after_model_callback_agent = Agent(
+ model='gemini-1.5-flash',
+ name='after_model_callback_agent',
+ instruction='Say hello',
+ after_model_callback=after_model_call,
+)
diff --git a/src/google/adk/tests/integration/fixture/context_update_test/OWNERS b/src/google/adk/tests/integration/fixture/context_update_test/OWNERS
new file mode 100644
index 0000000..02f72c4
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/context_update_test/OWNERS
@@ -0,0 +1 @@
+gkcng
diff --git a/src/google/adk/tests/integration/fixture/context_update_test/__init__.py b/src/google/adk/tests/integration/fixture/context_update_test/__init__.py
new file mode 100644
index 0000000..c48963c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/context_update_test/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/src/google/adk/tests/integration/fixture/context_update_test/agent.py b/src/google/adk/tests/integration/fixture/context_update_test/agent.py
new file mode 100644
index 0000000..e114824
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/context_update_test/agent.py
@@ -0,0 +1,43 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+from typing import Union
+
+from google.adk import Agent
+from google.adk.tools import ToolContext
+from pydantic import BaseModel
+
+
+def update_fc(
+ data_one: str,
+ data_two: Union[int, float, str],
+ data_three: list[str],
+ data_four: List[Union[int, float, str]],
+ tool_context: ToolContext,
+):
+ """Simply ask to update these variables in the context"""
+ tool_context.actions.update_state("data_one", data_one)
+ tool_context.actions.update_state("data_two", data_two)
+ tool_context.actions.update_state("data_three", data_three)
+ tool_context.actions.update_state("data_four", data_four)
+
+
+root_agent = Agent(
+ model="gemini-1.5-flash",
+ name="root_agent",
+ instruction="Call tools",
+ flow="auto",
+ tools=[update_fc],
+)
diff --git a/src/google/adk/tests/integration/fixture/context_update_test/successful_test.session.json b/src/google/adk/tests/integration/fixture/context_update_test/successful_test.session.json
new file mode 100644
index 0000000..d45430a
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/context_update_test/successful_test.session.json
@@ -0,0 +1,582 @@
+{
+ "id": "ead43200-b575-4241-9248-233b4be4f29a",
+ "context": {
+ "_time": "2024-12-01 09:02:43.531503",
+ "data_one": "RRRR",
+ "data_two": "3.141529",
+ "data_three": [
+ "apple",
+ "banana"
+ ],
+ "data_four": [
+ "1",
+ "hello",
+ "3.14"
+ ]
+ },
+ "events": [
+ {
+ "invocation_id": "6BGrtKJu",
+ "author": "user",
+ "content": {
+ "parts": [
+ {
+ "text": "hi"
+ }
+ ],
+ "role": "user"
+ },
+ "options": {},
+ "id": "ltzQTqR4",
+ "timestamp": 1733043686.8428597
+ },
+ {
+ "invocation_id": "6BGrtKJu",
+ "author": "root_agent",
+ "content": {
+ "parts": [
+ {
+ "text": "Hello! 👋 How can I help you today? \n"
+ }
+ ],
+ "role": "model"
+ },
+ "options": {
+ "partial": false
+ },
+ "id": "ClSROx8b",
+ "timestamp": 1733043688.1030986
+ },
+ {
+ "invocation_id": "M3dUcVa8",
+ "author": "user",
+ "content": {
+ "parts": [
+ {
+ "text": "update data_one to be RRRR, data_two to be 3.141529, data_three to be apple and banana, data_four to be 1, hello, and 3.14"
+ }
+ ],
+ "role": "user"
+ },
+ "options": {},
+ "id": "yxigGwIZ",
+ "timestamp": 1733043745.9900541
+ },
+ {
+ "invocation_id": "M3dUcVa8",
+ "author": "root_agent",
+ "content": {
+ "parts": [
+ {
+ "function_call": {
+ "args": {
+ "data_four": [
+ "1",
+ "hello",
+ "3.14"
+ ],
+ "data_two": "3.141529",
+ "data_three": [
+ "apple",
+ "banana"
+ ],
+ "data_one": "RRRR"
+ },
+ "name": "update_fc"
+ }
+ }
+ ],
+ "role": "model"
+ },
+ "options": {
+ "partial": false
+ },
+ "id": "8V6de8th",
+ "timestamp": 1733043747.4545543
+ },
+ {
+ "invocation_id": "M3dUcVa8",
+ "author": "root_agent",
+ "content": {
+ "parts": [
+ {
+ "function_response": {
+ "name": "update_fc",
+ "response": {}
+ }
+ }
+ ],
+ "role": "user"
+ },
+ "options": {
+ "update_context": {
+ "data_one": "RRRR",
+ "data_two": "3.141529",
+ "data_three": [
+ "apple",
+ "banana"
+ ],
+ "data_four": [
+ "1",
+ "hello",
+ "3.14"
+ ]
+ },
+ "function_call_event_id": "8V6de8th"
+ },
+ "id": "dkTj5v8B",
+ "timestamp": 1733043747.457031
+ },
+ {
+ "invocation_id": "M3dUcVa8",
+ "author": "root_agent",
+ "content": {
+ "parts": [
+ {
+ "text": "OK. I've updated the data. Anything else? \n"
+ }
+ ],
+ "role": "model"
+ },
+ "options": {
+ "partial": false
+ },
+ "id": "OZ77XR41",
+ "timestamp": 1733043748.7901294
+ }
+ ],
+ "past_events": [],
+ "pending_events": {},
+ "artifacts": {},
+ "event_logs": [
+ {
+ "invocation_id": "6BGrtKJu",
+ "event_id": "ClSROx8b",
+ "model_request": {
+ "model": "gemini-1.5-flash",
+ "contents": [
+ {
+ "parts": [
+ {
+ "text": "hi"
+ }
+ ],
+ "role": "user"
+ }
+ ],
+ "config": {
+ "system_instruction": "You are an agent. Your name is root_agent.\nCall tools",
+ "tools": [
+ {
+ "function_declarations": [
+ {
+ "description": "Hello",
+ "name": "update_fc",
+ "parameters": {
+ "type": "OBJECT",
+ "properties": {
+ "data_one": {
+ "type": "STRING"
+ },
+ "data_two": {
+ "type": "STRING"
+ },
+ "data_three": {
+ "type": "ARRAY",
+ "items": {
+ "type": "STRING"
+ }
+ },
+ "data_four": {
+ "type": "ARRAY",
+ "items": {
+ "any_of": [
+ {
+ "type": "INTEGER"
+ },
+ {
+ "type": "NUMBER"
+ },
+ {
+ "type": "STRING"
+ }
+ ],
+ "type": "STRING"
+ }
+ }
+ }
+ }
+ }
+ ]
+ }
+ ]
+ }
+ },
+ "model_response": {
+ "candidates": [
+ {
+ "content": {
+ "parts": [
+ {
+ "text": "Hello! 👋 How can I help you today? \n"
+ }
+ ],
+ "role": "model"
+ },
+ "avg_logprobs": -0.15831730915949896,
+ "finish_reason": "STOP",
+ "safety_ratings": [
+ {
+ "category": "HARM_CATEGORY_HATE_SPEECH",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.071777344,
+ "severity": "HARM_SEVERITY_NEGLIGIBLE",
+ "severity_score": 0.07080078
+ },
+ {
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.16308594,
+ "severity": "HARM_SEVERITY_NEGLIGIBLE",
+ "severity_score": 0.14160156
+ },
+ {
+ "category": "HARM_CATEGORY_HARASSMENT",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.09423828,
+ "severity": "HARM_SEVERITY_NEGLIGIBLE",
+ "severity_score": 0.037841797
+ },
+ {
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.059326172,
+ "severity": "HARM_SEVERITY_NEGLIGIBLE",
+ "severity_score": 0.02368164
+ }
+ ]
+ }
+ ],
+ "model_version": "gemini-1.5-flash-001",
+ "usage_metadata": {
+ "candidates_token_count": 13,
+ "prompt_token_count": 32,
+ "total_token_count": 45
+ }
+ }
+ },
+ {
+ "invocation_id": "M3dUcVa8",
+ "event_id": "8V6de8th",
+ "model_request": {
+ "model": "gemini-1.5-flash",
+ "contents": [
+ {
+ "parts": [
+ {
+ "text": "hi"
+ }
+ ],
+ "role": "user"
+ },
+ {
+ "parts": [
+ {
+ "text": "Hello! 👋 How can I help you today? \n"
+ }
+ ],
+ "role": "model"
+ },
+ {
+ "parts": [
+ {
+ "text": "update data_one to be RRRR, data_two to be 3.141529, data_three to be apple and banana, data_four to be 1, hello, and 3.14"
+ }
+ ],
+ "role": "user"
+ }
+ ],
+ "config": {
+ "system_instruction": "You are an agent. Your name is root_agent.\nCall tools",
+ "tools": [
+ {
+ "function_declarations": [
+ {
+ "description": "Hello",
+ "name": "update_fc",
+ "parameters": {
+ "type": "OBJECT",
+ "properties": {
+ "data_one": {
+ "type": "STRING"
+ },
+ "data_two": {
+ "type": "STRING"
+ },
+ "data_three": {
+ "type": "ARRAY",
+ "items": {
+ "type": "STRING"
+ }
+ },
+ "data_four": {
+ "type": "ARRAY",
+ "items": {
+ "any_of": [
+ {
+ "type": "INTEGER"
+ },
+ {
+ "type": "NUMBER"
+ },
+ {
+ "type": "STRING"
+ }
+ ],
+ "type": "STRING"
+ }
+ }
+ }
+ }
+ }
+ ]
+ }
+ ]
+ }
+ },
+ "model_response": {
+ "candidates": [
+ {
+ "content": {
+ "parts": [
+ {
+ "function_call": {
+ "args": {
+ "data_four": [
+ "1",
+ "hello",
+ "3.14"
+ ],
+ "data_two": "3.141529",
+ "data_three": [
+ "apple",
+ "banana"
+ ],
+ "data_one": "RRRR"
+ },
+ "name": "update_fc"
+ }
+ }
+ ],
+ "role": "model"
+ },
+ "avg_logprobs": -2.100960955431219e-6,
+ "finish_reason": "STOP",
+ "safety_ratings": [
+ {
+ "category": "HARM_CATEGORY_HATE_SPEECH",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.12158203,
+ "severity": "HARM_SEVERITY_NEGLIGIBLE",
+ "severity_score": 0.13671875
+ },
+ {
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.421875,
+ "severity": "HARM_SEVERITY_LOW",
+ "severity_score": 0.24511719
+ },
+ {
+ "category": "HARM_CATEGORY_HARASSMENT",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.15722656,
+ "severity": "HARM_SEVERITY_NEGLIGIBLE",
+ "severity_score": 0.072753906
+ },
+ {
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.083984375,
+ "severity": "HARM_SEVERITY_NEGLIGIBLE",
+ "severity_score": 0.03564453
+ }
+ ]
+ }
+ ],
+ "model_version": "gemini-1.5-flash-001",
+ "usage_metadata": {
+ "candidates_token_count": 32,
+ "prompt_token_count": 94,
+ "total_token_count": 126
+ }
+ }
+ },
+ {
+ "invocation_id": "M3dUcVa8",
+ "event_id": "OZ77XR41",
+ "model_request": {
+ "model": "gemini-1.5-flash",
+ "contents": [
+ {
+ "parts": [
+ {
+ "text": "hi"
+ }
+ ],
+ "role": "user"
+ },
+ {
+ "parts": [
+ {
+ "text": "Hello! 👋 How can I help you today? \n"
+ }
+ ],
+ "role": "model"
+ },
+ {
+ "parts": [
+ {
+ "text": "update data_one to be RRRR, data_two to be 3.141529, data_three to be apple and banana, data_four to be 1, hello, and 3.14"
+ }
+ ],
+ "role": "user"
+ },
+ {
+ "parts": [
+ {
+ "function_call": {
+ "args": {
+ "data_four": [
+ "1",
+ "hello",
+ "3.14"
+ ],
+ "data_two": "3.141529",
+ "data_three": [
+ "apple",
+ "banana"
+ ],
+ "data_one": "RRRR"
+ },
+ "name": "update_fc"
+ }
+ }
+ ],
+ "role": "model"
+ },
+ {
+ "parts": [
+ {
+ "function_response": {
+ "name": "update_fc",
+ "response": {}
+ }
+ }
+ ],
+ "role": "user"
+ }
+ ],
+ "config": {
+ "system_instruction": "You are an agent. Your name is root_agent.\nCall tools",
+ "tools": [
+ {
+ "function_declarations": [
+ {
+ "description": "Hello",
+ "name": "update_fc",
+ "parameters": {
+ "type": "OBJECT",
+ "properties": {
+ "data_one": {
+ "type": "STRING"
+ },
+ "data_two": {
+ "type": "STRING"
+ },
+ "data_three": {
+ "type": "ARRAY",
+ "items": {
+ "type": "STRING"
+ }
+ },
+ "data_four": {
+ "type": "ARRAY",
+ "items": {
+ "any_of": [
+ {
+ "type": "INTEGER"
+ },
+ {
+ "type": "NUMBER"
+ },
+ {
+ "type": "STRING"
+ }
+ ],
+ "type": "STRING"
+ }
+ }
+ }
+ }
+ }
+ ]
+ }
+ ]
+ }
+ },
+ "model_response": {
+ "candidates": [
+ {
+ "content": {
+ "parts": [
+ {
+ "text": "OK. I've updated the data. Anything else? \n"
+ }
+ ],
+ "role": "model"
+ },
+ "avg_logprobs": -0.22089435373033797,
+ "finish_reason": "STOP",
+ "safety_ratings": [
+ {
+ "category": "HARM_CATEGORY_HATE_SPEECH",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.04663086,
+ "severity": "HARM_SEVERITY_NEGLIGIBLE",
+ "severity_score": 0.09423828
+ },
+ {
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.18554688,
+ "severity": "HARM_SEVERITY_NEGLIGIBLE",
+ "severity_score": 0.111328125
+ },
+ {
+ "category": "HARM_CATEGORY_HARASSMENT",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.071777344,
+ "severity": "HARM_SEVERITY_NEGLIGIBLE",
+ "severity_score": 0.03112793
+ },
+ {
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+ "probability": "NEGLIGIBLE",
+ "probability_score": 0.043945313,
+ "severity": "HARM_SEVERITY_NEGLIGIBLE",
+ "severity_score": 0.057373047
+ }
+ ]
+ }
+ ],
+ "model_version": "gemini-1.5-flash-001",
+ "usage_metadata": {
+ "candidates_token_count": 14,
+ "prompt_token_count": 129,
+ "total_token_count": 143
+ }
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/src/google/adk/tests/integration/fixture/context_variable_agent/__init__.py b/src/google/adk/tests/integration/fixture/context_variable_agent/__init__.py
new file mode 100644
index 0000000..c48963c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/context_variable_agent/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/src/google/adk/tests/integration/fixture/context_variable_agent/agent.py b/src/google/adk/tests/integration/fixture/context_variable_agent/agent.py
new file mode 100644
index 0000000..a18b61c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/context_variable_agent/agent.py
@@ -0,0 +1,115 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+from typing import Union
+
+from google.adk import Agent
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.planners import PlanReActPlanner
+from google.adk.tools import ToolContext
+
+
+def update_fc(
+ data_one: str,
+ data_two: Union[int, float, str],
+ data_three: list[str],
+ data_four: List[Union[int, float, str]],
+ tool_context: ToolContext,
+) -> str:
+ """Simply ask to update these variables in the context"""
+ tool_context.actions.update_state('data_one', data_one)
+ tool_context.actions.update_state('data_two', data_two)
+ tool_context.actions.update_state('data_three', data_three)
+ tool_context.actions.update_state('data_four', data_four)
+ return 'The function `update_fc` executed successfully'
+
+
+def echo_info(customer_id: str) -> str:
+ """Echo the context variable"""
+ return customer_id
+
+
+def build_global_instruction(invocation_context: InvocationContext) -> str:
+ return (
+ 'This is the gloabl agent instruction for invocation:'
+ f' {invocation_context.invocation_id}.'
+ )
+
+
+def build_sub_agent_instruction(invocation_context: InvocationContext) -> str:
+ return 'This is the plain text sub agent instruction.'
+
+
+context_variable_echo_agent = Agent(
+ model='gemini-1.5-flash',
+ name='context_variable_echo_agent',
+ instruction=(
+ 'Use the echo_info tool to echo {customerId}, {customerInt},'
+ ' {customerFloat}, and {customerJson}. Ask for it if you need to.'
+ ),
+ flow='auto',
+ tools=[echo_info],
+)
+
+context_variable_with_complicated_format_agent = Agent(
+ model='gemini-1.5-flash',
+ name='context_variable_echo_agent',
+ instruction=(
+ 'Use the echo_info tool to echo { customerId }, {{customer_int }, { '
+ " non-identifier-float}}, {artifact.fileName}, {'key1': 'value1'} and"
+ " {{'key2': 'value2'}}. Ask for it if you need to."
+ ),
+ flow='auto',
+ tools=[echo_info],
+)
+
+context_variable_with_nl_planner_agent = Agent(
+ model='gemini-1.5-flash',
+ name='context_variable_with_nl_planner_agent',
+ instruction=(
+ 'Use the echo_info tool to echo {customerId}. Ask for it if you'
+ ' need to.'
+ ),
+ flow='auto',
+ planner=PlanReActPlanner(),
+ tools=[echo_info],
+)
+
+context_variable_with_function_instruction_agent = Agent(
+ model='gemini-1.5-flash',
+ name='context_variable_with_function_instruction_agent',
+ instruction=build_sub_agent_instruction,
+ flow='auto',
+)
+
+context_variable_update_agent = Agent(
+ model='gemini-1.5-flash',
+ name='context_variable_update_agent',
+ instruction='Call tools',
+ flow='auto',
+ tools=[update_fc],
+)
+
+root_agent = Agent(
+ model='gemini-1.5-flash',
+ name='root_agent',
+ description='The root agent.',
+ flow='auto',
+ global_instruction=build_global_instruction,
+ sub_agents=[
+ context_variable_with_nl_planner_agent,
+ context_variable_update_agent,
+ ],
+)
diff --git a/src/google/adk/tests/integration/fixture/customer_support_ma/__init__.py b/src/google/adk/tests/integration/fixture/customer_support_ma/__init__.py
new file mode 100644
index 0000000..c48963c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/customer_support_ma/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/src/google/adk/tests/integration/fixture/customer_support_ma/agent.py b/src/google/adk/tests/integration/fixture/customer_support_ma/agent.py
new file mode 100644
index 0000000..696f380
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/customer_support_ma/agent.py
@@ -0,0 +1,172 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+from google.adk import Agent
+from google.adk.agents import RemoteAgent
+from google.adk.examples import Example
+from google.adk.sessions import Session
+from google.genai import types
+
+
+def reset_data():
+ pass
+
+
+def fetch_user_flight_information(customer_email: str) -> str:
+ """Fetch user flight information."""
+ return """
+[{"ticket_no": "7240005432906569", "book_ref": "C46E9F", "flight_id": 19250, "flight_no": "LX0112", "departure_airport": "CDG", "arrival_airport": "BSL", "scheduled_departure": "2024-12-30 12:09:03.561731-04:00", "scheduled_arrival": "2024-12-30 13:39:03.561731-04:00", "seat_no": "18E", "fare_conditions": "Economy"}]
+"""
+
+
+def list_customer_flights(customer_email: str) -> str:
+ return "{'flights': [{'book_ref': 'C46E9F'}]}"
+
+
+def update_ticket_to_new_flight(ticket_no: str, new_flight_id: str) -> str:
+ return 'OK, your ticket has been updated.'
+
+
+def lookup_company_policy(topic: str) -> str:
+ """Lookup policies for flight cancelation and rebooking."""
+ return """
+1. How can I change my booking?
+ * The ticket number must start with 724 (SWISS ticket no./plate).
+ * The ticket was not paid for by barter or voucher (there are exceptions to voucher payments; if the ticket was paid for in full by voucher, then it may be possible to rebook online under certain circumstances. If it is not possible to rebook online because of the payment method, then you will be informed accordingly during the rebooking process).
+ * There must be an active flight booking for your ticket. It is not possible to rebook open tickets or tickets without the corresponding flight segments online at the moment.
+ * It is currently only possible to rebook outbound (one-way) tickets or return tickets with single flight routes (point-to-point).
+"""
+
+
+def search_flights(
+ departure_airport: str = None,
+ arrival_airport: str = None,
+ start_time: str = None,
+ end_time: str = None,
+) -> list[dict]:
+ return """
+[{"flight_id": 19238, "flight_no": "LX0112", "scheduled_departure": "2024-05-08 12:09:03.561731-04:00", "scheduled_arrival": "2024-05-08 13:39:03.561731-04:00", "departure_airport": "CDG", "arrival_airport": "BSL", "status": "Scheduled", "aircraft_code": "SU9", "actual_departure": null, "actual_arrival": null}, {"flight_id": 19242, "flight_no": "LX0112", "scheduled_departure": "2024-05-09 12:09:03.561731-04:00", "scheduled_arrival": "2024-05-09 13:39:03.561731-04:00", "departure_airport": "CDG", "arrival_airport": "BSL", "status": "Scheduled", "aircraft_code": "SU9", "actual_departure": null, "actual_arrival": null}]"""
+
+
+def search_hotels(
+ location: str = None,
+ price_tier: str = None,
+ checkin_date: str = None,
+ checkout_date: str = None,
+) -> list[dict]:
+ return """
+[{"id": 1, "name": "Hilton Basel", "location": "Basel", "price_tier": "Luxury"}, {"id": 3, "name": "Hyatt Regency Basel", "location": "Basel", "price_tier": "Upper Upscale"}, {"id": 8, "name": "Holiday Inn Basel", "location": "Basel", "price_tier": "Upper Midscale"}]
+"""
+
+
+def book_hotel(hotel_name: str) -> str:
+ return 'OK, your hotel has been booked.'
+
+
+def before_model_call(agent: Agent, session: Session, user_message):
+ if 'expedia' in user_message.lower():
+ response = types.Content(
+ role='model',
+ parts=[types.Part(text="Sorry, I can't answer this question.")],
+ )
+ return response
+ return None
+
+
+def after_model_call(
+ agent: Agent, session: Session, content: types.Content
+) -> bool:
+ model_message = content.parts[0].text
+ if 'expedia' in model_message.lower():
+ response = types.Content(
+ role='model',
+ parts=[types.Part(text="Sorry, I can't answer this question.")],
+ )
+ return response
+ return None
+
+
+flight_agent = Agent(
+ model='gemini-1.5-pro',
+ name='flight_agent',
+ description='Handles flight information, policy and updates',
+ instruction="""
+ You are a specialized assistant for handling flight updates.
+ The primary assistant delegates work to you whenever the user needs help updating their bookings.
+ Confirm the updated flight details with the customer and inform them of any additional fees.
+ When searching, be persistent. Expand your query bounds if the first search returns no results.
+ Remember that a booking isn't completed until after the relevant tool has successfully been used.
+ Do not waste the user's time. Do not make up invalid tools or functions.
+""",
+ tools=[
+ list_customer_flights,
+ lookup_company_policy,
+ fetch_user_flight_information,
+ search_flights,
+ update_ticket_to_new_flight,
+ ],
+)
+
+hotel_agent = Agent(
+ model='gemini-1.5-pro',
+ name='hotel_agent',
+ description='Handles hotel information and booking',
+ instruction="""
+ You are a specialized assistant for handling hotel bookings.
+ The primary assistant delegates work to you whenever the user needs help booking a hotel.
+ Search for available hotels based on the user's preferences and confirm the booking details with the customer.
+ When searching, be persistent. Expand your query bounds if the first search returns no results.
+""",
+ tools=[search_hotels, book_hotel],
+)
+
+
+idea_agent = RemoteAgent(
+ model='gemini-1.5-pro',
+ name='idea_agent',
+ description='Provide travel ideas base on the destination.',
+ url='http://localhost:8000/agent/run',
+)
+
+
+root_agent = Agent(
+ model='gemini-1.5-pro',
+ name='root_agent',
+ instruction="""
+ You are a helpful customer support assistant for Swiss Airlines.
+""",
+ sub_agents=[flight_agent, hotel_agent, idea_agent],
+ flow='auto',
+ examples=[
+ Example(
+ input=types.Content(
+ role='user',
+ parts=[types.Part(text='How were you built?')],
+ ),
+ output=[
+ types.Content(
+ role='model',
+ parts=[
+ types.Part(
+ text='I was built with the best agent framework.'
+ )
+ ],
+ )
+ ],
+ ),
+ ],
+)
diff --git a/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/__init__.py b/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/__init__.py
new file mode 100644
index 0000000..c48963c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/agent.py b/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/agent.py
new file mode 100644
index 0000000..4ac9ba1
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/agent.py
@@ -0,0 +1,338 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk import Agent
+
+# A lightweight in-memory mock database
+ORDER_DB = {
+ "1": "FINISHED",
+ "2": "CANCELED",
+ "3": "PENDING",
+ "4": "PENDING",
+} # Order id to status mapping. Available states: 'FINISHED', 'PENDING', and 'CANCELED'
+USER_TO_ORDER_DB = {
+ "user_a": ["1", "4"],
+ "user_b": ["2"],
+ "user_c": ["3"],
+} # User id to Order id mapping
+TICKET_DB = [{
+ "ticket_id": "1",
+ "user_id": "user_a",
+ "issue_type": "LOGIN_ISSUE",
+ "status": "OPEN",
+}] # Available states: 'OPEN', 'CLOSED', 'ESCALATED'
+USER_INFO_DB = {
+ "user_a": {"name": "Alice", "email": "alice@example.com"},
+ "user_b": {"name": "Bob", "email": "bob@example.com"},
+}
+
+
+def reset_data():
+ global ORDER_DB
+ global USER_TO_ORDER_DB
+ global TICKET_DB
+ global USER_INFO_DB
+ ORDER_DB = {
+ "1": "FINISHED",
+ "2": "CANCELED",
+ "3": "PENDING",
+ "4": "PENDING",
+ }
+ USER_TO_ORDER_DB = {
+ "user_a": ["1", "4"],
+ "user_b": ["2"],
+ "user_c": ["3"],
+ }
+ TICKET_DB = [{
+ "ticket_id": "1",
+ "user_id": "user_a",
+ "issue_type": "LOGIN_ISSUE",
+ "status": "OPEN",
+ }]
+ USER_INFO_DB = {
+ "user_a": {"name": "Alice", "email": "alice@example.com"},
+ "user_b": {"name": "Bob", "email": "bob@example.com"},
+ }
+
+
+def get_order_status(order_id: str) -> str:
+ """Get the status of an order.
+
+ Args:
+ order_id (str): The unique identifier of the order.
+
+ Returns:
+ str: The status of the order (e.g., 'FINISHED', 'CANCELED', 'PENDING'),
+ or 'Order not found' if the order_id does not exist.
+ """
+ return ORDER_DB.get(order_id, "Order not found")
+
+
+def get_order_ids_for_user(user_id: str) -> list:
+ """Get the list of order IDs assigned to a specific transaction associated with a user.
+
+ Args:
+ user_id (str): The unique identifier of the user.
+
+ Returns:
+ List[str]: A list of order IDs associated with the user, or an empty list
+ if no orders are found.
+ """
+ return USER_TO_ORDER_DB.get(user_id, [])
+
+
+def cancel_order(order_id: str) -> str:
+ """Cancel an order if it is in a 'PENDING' state.
+
+ You should call "get_order_status" to check the status first, before calling
+ this tool.
+
+ Args:
+ order_id (str): The unique identifier of the order to be canceled.
+
+ Returns:
+ str: A message indicating whether the order was successfully canceled or
+ not.
+ """
+ if order_id in ORDER_DB and ORDER_DB[order_id] == "PENDING":
+ ORDER_DB[order_id] = "CANCELED"
+ return f"Order {order_id} has been canceled."
+ return f"Order {order_id} cannot be canceled."
+
+
+def refund_order(order_id: str) -> str:
+ """Process a refund for an order if it is in a 'CANCELED' state.
+
+ You should call "get_order_status" to check if status first, before calling
+ this tool.
+
+ Args:
+ order_id (str): The unique identifier of the order to be refunded.
+
+ Returns:
+ str: A message indicating whether the order was successfully refunded or
+ not.
+ """
+ if order_id in ORDER_DB and ORDER_DB[order_id] == "CANCELED":
+ return f"Order {order_id} has been refunded."
+ return f"Order {order_id} cannot be refunded."
+
+
+def create_ticket(user_id: str, issue_type: str) -> str:
+ """Create a new support ticket for a user.
+
+ Args:
+ user_id (str): The unique identifier of the user creating the ticket.
+ issue_type (str): An issue type the user is facing. Available types:
+ 'LOGIN_ISSUE', 'ORDER_ISSUE', 'OTHER'.
+
+ Returns:
+ str: A message indicating that the ticket was created successfully,
+ including the ticket ID.
+ """
+ ticket_id = str(len(TICKET_DB) + 1)
+ TICKET_DB.append({
+ "ticket_id": ticket_id,
+ "user_id": user_id,
+ "issue_type": issue_type,
+ "status": "OPEN",
+ })
+ return f"Ticket {ticket_id} created successfully."
+
+
+def get_ticket_info(ticket_id: str) -> str:
+ """Retrieve the information of a support ticket.
+
+ current status of a support ticket.
+
+ Args:
+ ticket_id (str): The unique identifier of the ticket.
+
+ Returns:
+ A dictionary contains the following fields, or 'Ticket not found' if the
+ ticket_id does not exist:
+ - "ticket_id": str, the current ticket id
+ - "user_id": str, the associated user id
+ - "issue": str, the issue type
+ - "status": The current status of the ticket (e.g., 'OPEN', 'CLOSED',
+ 'ESCALATED')
+
+ Example: {"ticket_id": "1", "user_id": "user_a", "issue": "Login issue",
+ "status": "OPEN"}
+ """
+ for ticket in TICKET_DB:
+ if ticket["ticket_id"] == ticket_id:
+ return ticket
+ return "Ticket not found"
+
+
+def get_tickets_for_user(user_id: str) -> list:
+ """Get all the ticket IDs associated with a user.
+
+ Args:
+ user_id (str): The unique identifier of the user.
+
+ Returns:
+ List[str]: A list of ticket IDs associated with the user.
+ If no tickets are found, returns an empty list.
+ """
+ return [
+ ticket["ticket_id"]
+ for ticket in TICKET_DB
+ if ticket["user_id"] == user_id
+ ]
+
+
+def update_ticket_status(ticket_id: str, status: str) -> str:
+ """Update the status of a support ticket.
+
+ Args:
+ ticket_id (str): The unique identifier of the ticket.
+ status (str): The new status to assign to the ticket (e.g., 'OPEN',
+ 'CLOSED', 'ESCALATED').
+
+ Returns:
+ str: A message indicating whether the ticket status was successfully
+ updated.
+ """
+ for ticket in TICKET_DB:
+ if ticket["ticket_id"] == ticket_id:
+ ticket["status"] = status
+ return f"Ticket {ticket_id} status updated to {status}."
+ return "Ticket not found"
+
+
+def get_user_info(user_id: str) -> dict:
+ """Retrieve information (name, email) about a user.
+
+ Args:
+ user_id (str): The unique identifier of the user.
+
+ Returns:
+ dict or str: A dictionary containing user information of the following
+ fields, or 'User not found' if the user_id does not exist:
+
+ - name: The name of the user
+ - email: The email address of the user
+
+ For example, {"name": "Chelsea", "email": "123@example.com"}
+ """
+ return USER_INFO_DB.get(user_id, "User not found")
+
+
+def send_email(user_id: str, email: str) -> list:
+ """Send email to user for notification.
+
+ Args:
+ user_id (str): The unique identifier of the user.
+ email (str): The email address of the user.
+
+ Returns:
+ str: A message indicating whether the email was successfully sent.
+ """
+ if user_id in USER_INFO_DB:
+ return f"Email sent to {email} for user id {user_id}"
+ return "Cannot find this user"
+
+
+# def update_user_info(user_id: str, new_info: dict[str, str]) -> str:
+def update_user_info(user_id: str, email: str, name: str) -> str:
+ """Update a user's information.
+
+ Args:
+ user_id (str): The unique identifier of the user.
+ new_info (dict): A dictionary containing the fields to be updated (e.g.,
+ {'email': 'new_email@example.com'}). Available field keys: 'email' and
+ 'name'.
+
+ Returns:
+ str: A message indicating whether the user's information was successfully
+ updated or not.
+ """
+ if user_id in USER_INFO_DB:
+ # USER_INFO_DB[user_id].update(new_info)
+ if email and name:
+ USER_INFO_DB[user_id].update({"email": email, "name": name})
+ elif email:
+ USER_INFO_DB[user_id].update({"email": email})
+ elif name:
+ USER_INFO_DB[user_id].update({"name": name})
+ else:
+ raise ValueError("this should not happen.")
+ return f"User {user_id} information updated."
+ return "User not found"
+
+
+def get_user_id_from_cookie() -> str:
+ """Get user ID(username) from the cookie.
+
+ Only use this function when you do not know user ID(username).
+
+ Args: None
+
+ Returns:
+ str: The user ID.
+ """
+ return "user_a"
+
+
+root_agent = Agent(
+ model="gemini-2.0-flash-001",
+ name="Ecommerce_Customer_Service",
+ instruction="""
+ You are an intelligent customer service assistant for an e-commerce platform. Your goal is to accurately understand user queries and use the appropriate tools to fulfill requests. Follow these guidelines:
+
+ 1. **Understand the Query**:
+ - Identify actions and conditions (e.g., create a ticket only for pending orders).
+ - Extract necessary details (e.g., user ID, order ID) from the query or infer them from the context.
+
+ 2. **Plan Multi-Step Workflows**:
+ - Break down complex queries into sequential steps. For example
+ - typical workflow:
+ - Retrieve IDs or references first (e.g., orders for a user).
+ - Evaluate conditions (e.g., check order status).
+ - Perform actions (e.g., create a ticket) only when conditions are met.
+ - another typical workflows - order cancellation and refund:
+ - Retrieve all orders for the user (`get_order_ids_for_user`).
+ - Cancel pending orders (`cancel_order`).
+ - Refund canceled orders (`refund_order`).
+ - Notify the user (`send_email`).
+ - another typical workflows - send user report:
+ - Get user id.
+ - Get user info(like emails)
+ - Send email to user.
+
+ 3. **Avoid Skipping Steps**:
+ - Ensure each intermediate step is completed before moving to the next.
+ - Do not create tickets or take other actions without verifying the conditions specified in the query.
+
+ 4. **Provide Clear Responses**:
+ - Confirm the actions performed, including details like ticket ID or pending orders.
+ - Ensure the response aligns with the steps taken and query intent.
+ """,
+ tools=[
+ get_order_status,
+ cancel_order,
+ get_order_ids_for_user,
+ refund_order,
+ create_ticket,
+ update_ticket_status,
+ get_tickets_for_user,
+ get_ticket_info,
+ get_user_info,
+ send_email,
+ update_user_info,
+ get_user_id_from_cookie,
+ ],
+)
diff --git a/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/order_query.test.json b/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/order_query.test.json
new file mode 100644
index 0000000..ac424f3
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/order_query.test.json
@@ -0,0 +1,69 @@
+[
+ {
+ "query": "Send an email to user user_a whose email address is alice@example.com",
+ "expected_tool_use": [
+ {
+ "tool_name": "send_email",
+ "tool_input": {
+ "email": "alice@example.com",
+ "user_id": "user_a"
+ }
+ }
+ ],
+ "reference": "Email sent to alice@example.com for user id user_a."
+ },
+ {
+ "query": "Can you tell me the status of my order with ID 1?",
+ "expected_tool_use": [
+ {
+ "tool_name": "get_order_status",
+ "tool_input": {
+ "order_id": "1"
+ }
+ }
+ ],
+ "reference": "Your order with ID 1 is FINISHED."
+ },
+ {
+ "query": "Cancel all pending order for the user with user id user_a",
+ "expected_tool_use": [
+ {
+ "tool_name": "get_order_ids_for_user",
+ "tool_input": {
+ "user_id": "user_a"
+ }
+ },
+ {
+ "tool_name": "get_order_status",
+ "tool_input": {
+ "order_id": "1"
+ }
+ },
+ {
+ "tool_name": "get_order_status",
+ "tool_input": {
+ "order_id": "4"
+ }
+ },
+ {
+ "tool_name": "cancel_order",
+ "tool_input": {
+ "order_id": "4"
+ }
+ }
+ ],
+ "reference": "I have checked your orders and order 4 was in pending status, so I have cancelled it. Order 1 was already finished and couldn't be cancelled.\n"
+ },
+ {
+ "query": "What orders have I placed under the username user_b?",
+ "expected_tool_use": [
+ {
+ "tool_name": "get_order_ids_for_user",
+ "tool_input": {
+ "user_id": "user_b"
+ }
+ }
+ ],
+ "reference": "User user_b has placed one order with order ID 2.\n"
+ }
+]
diff --git a/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/test_config.json b/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/test_config.json
new file mode 100644
index 0000000..b16d0a9
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/ecommerce_customer_service_agent/test_config.json
@@ -0,0 +1,6 @@
+{
+ "criteria": {
+ "tool_trajectory_avg_score": 0.7,
+ "response_match_score": 0.5
+ }
+}
diff --git a/src/google/adk/tests/integration/fixture/flow_complex_spark/__init__.py b/src/google/adk/tests/integration/fixture/flow_complex_spark/__init__.py
new file mode 100644
index 0000000..c48963c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/flow_complex_spark/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/src/google/adk/tests/integration/fixture/flow_complex_spark/agent.py b/src/google/adk/tests/integration/fixture/flow_complex_spark/agent.py
new file mode 100644
index 0000000..18ce62f
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/flow_complex_spark/agent.py
@@ -0,0 +1,182 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk import Agent
+from google.genai import types
+
+research_plan_agent = Agent(
+ model="gemini-1.5-flash",
+ name="research_plan_agent",
+ description="I can help generate research plan.",
+ instruction="""\
+Your task is to create a research plan according to the user's query.
+
+# Here are the instructions for creating the research plan:
+
++ Focus on finding specific things, e.g. products, data, etc.
++ Have the personality of a work colleague that is very helpful and explains things very nicely.
++ Don't mention your name unless you are asked.
++ Think about the most common things that you would need to research.
++ Think about possible answers when creating the plan.
++ Your task is to create the sections that should be researched. You will output high level headers, preceded by ##
++ Underneath each header, write a short sentence on what we want to find there.
++ The headers will follow the logical analysis pattern, as well as logical exploration pattern.
++ The headers should be a statement, not be in the form of questions.
++ The header will not include roman numerals or anything of the sort, e.g. ":", etc
++ Do not include things that you cannot possibly know about from using Google Search: e.g. sales forecasting, competitors, profitability analysis, etc.
++ Do not have an executive summary
++ In each section describe specifically what will be researched.
++ Never use "we will", but rather "I will".
++ Don't ask for clarifications from the user.
++ Do not ask the user for clarifications or if they have any other questions.
++ All headers should be bolded.
++ If you have steps in the plan that depend on other information, make sure they are 2 diferent sections in the plan.
++ At the end mention that you will start researching.
+
+# Instruction on replying format
+
++ Start with your name as "[research_plan_agent]: ".
++ Output the content you want to say.
+
+Output summary:
+""",
+ flow="single",
+ sub_agents=[],
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.1,
+ ),
+)
+
+
+question_generation_agent = Agent(
+ model="gemini-1.5-flash",
+ name="question_generation_agent",
+ description="I can help generate questions related to user's question.",
+ instruction="""\
+Generate questions related to the research plan generated by research_plan_agent.
+
+# Instruction on replying format
+
+Your reply should be a numbered lsit.
+
+For each question, reply in the following format: "[question_generation_agent]: [generated questions]"
+
+Here is an example of the generated question list:
+
+1. [question_generation_agent]: which state is San Jose in?
+2. [question_generation_agent]: how google website is designed?
+""",
+ flow="single",
+ sub_agents=[],
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.1,
+ ),
+)
+
+information_retrieval_agent = Agent(
+ model="gemini-1.5-flash",
+ name="information_retrieval_agent",
+ description=(
+ "I can help retrieve information related to question_generation_agent's"
+ " question."
+ ),
+ instruction="""\
+Inspect all the questions after "[question_generation_agent]: " and asnwer them.
+
+# Instruction on replying format
+
+Always start with "[information_retrieval_agent]: "
+
+For the answer of one question:
+
+- Start with a title with one line summary of the reply.
+- The title line should be bolded and starts with No.x of the corresponding question.
+- Have a paragraph of detailed explain.
+
+# Instruction on exiting the loop
+
+- If you see there are less than 20 questions by "question_generation_agent", do not say "[exit]".
+- If you see there are already great or equal to 20 questions asked by "question_generation_agent", say "[exit]" at last to exit the loop.
+""",
+ flow="single",
+ sub_agents=[],
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.1,
+ ),
+)
+
+question_sources_generation_agent = Agent(
+ model="gemini-1.5-flash",
+ name="question_sources_generation_agent",
+ description=(
+ "I can help generate questions and retrieve related information."
+ ),
+ instruction="Generate questions and retrieve information.",
+ flow="loop",
+ sub_agents=[
+ question_generation_agent,
+ information_retrieval_agent,
+ ],
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.1,
+ ),
+)
+
+summary_agent = Agent(
+ model="gemini-1.5-flash",
+ name="summary_agent",
+ description="I can help summarize information of previous content.",
+ instruction="""\
+Summarize information in all historical messages that were replied by "question_generation_agent" and "information_retrieval_agent".
+
+# Instruction on replying format
+
+- The output should be like an essay that has a title, an abstract, multiple paragraphs for each topic and a conclusion.
+- Each paragraph should maps to one or more question in historical content.
+""",
+ flow="single",
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.8,
+ ),
+)
+
+research_assistant = Agent(
+ model="gemini-1.5-flash",
+ name="research_assistant",
+ description="I can help with research question.",
+ instruction="Help customers with their need.",
+ flow="sequential",
+ sub_agents=[
+ research_plan_agent,
+ question_sources_generation_agent,
+ summary_agent,
+ ],
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.1,
+ ),
+)
+
+spark_agent = Agent(
+ model="gemini-1.5-flash",
+ name="spark_assistant",
+ description="I can help with non-research question.",
+ instruction="Help customers with their need.",
+ flow="auto",
+ sub_agents=[research_assistant],
+ generate_content_config=types.GenerateContentConfig(
+ temperature=0.1,
+ ),
+)
+
+root_agent = spark_agent
diff --git a/src/google/adk/tests/integration/fixture/flow_complex_spark/sample.session.json b/src/google/adk/tests/integration/fixture/flow_complex_spark/sample.session.json
new file mode 100644
index 0000000..31575a8
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/flow_complex_spark/sample.session.json
@@ -0,0 +1,190 @@
+{
+ "id": "683bcc98-7359-4b65-9cdb-307f85ae62c0",
+ "context": {
+ "_time": "2024-12-15 11:52:21.958542"
+ },
+ "events": [
+ {
+ "invocation_id": "UfYAgiCT",
+ "author": "user",
+ "content": {
+ "parts": [
+ {
+ "text": "Research on the factors that affect floods in Florida. "
+ }
+ ],
+ "role": "user"
+ },
+ "options": {},
+ "is_greeting": false,
+ "id": "qKEz3Vss",
+ "timestamp": 1734292270.899412
+ },
+ {
+ "invocation_id": "UfYAgiCT",
+ "author": "spark_assistant",
+ "content": {
+ "parts": [
+ {
+ "function_call": {
+ "args": {
+ "agent_name": "research_assistant"
+ },
+ "name": "transfer_to_other_agent"
+ }
+ }
+ ],
+ "role": "model"
+ },
+ "options": {},
+ "is_greeting": false,
+ "id": "MEQQWBGe",
+ "timestamp": 1734292272.522898
+ },
+ {
+ "invocation_id": "UfYAgiCT",
+ "author": "spark_assistant",
+ "content": {
+ "parts": [
+ {
+ "function_response": {
+ "name": "transfer_to_other_agent",
+ "response": {
+ "status": "ok",
+ "target_agent_name": "research_assistant",
+ "message": "Transfered to research_assistant"
+ }
+ }
+ }
+ ],
+ "role": "user"
+ },
+ "options": {
+ "skip_summarization": true,
+ "function_call_event_id": "MEQQWBGe"
+ },
+ "is_greeting": false,
+ "id": "QjLptHaG",
+ "timestamp": 1734292272.524573
+ },
+ {
+ "invocation_id": "UfYAgiCT",
+ "author": "research_plan_agent",
+ "content": {
+ "parts": [
+ {
+ "text": "[research_plan_agent]: I can help you with that! Here is a research plan for the factors that affect floods in Florida:\n\n**Factors that affect floods in Florida**\nI will research the factors that affect floods in Florida, such as rainfall, sea level rise, and urbanization.\n\n**Rainfall in Florida**\nI will research the average rainfall in Florida, as well as the frequency and intensity of rainfall events.\n\n**Sea Level Rise in Florida**\nI will research the rate of sea level rise in Florida, as well as the impact of sea level rise on flooding.\n\n**Urbanization in Florida**\nI will research the impact of urbanization on flooding in Florida, such as the effects of impervious surfaces and drainage systems.\n\n**Flood History in Florida**\nI will research the history of floods in Florida, including the frequency, severity, and causes of past floods.\n\n**Flood Mitigation Efforts in Florida**\nI will research the flood mitigation efforts that are currently in place in Florida, as well as the effectiveness of these efforts.\n\nI will start researching now. \n"
+ }
+ ],
+ "role": "model"
+ },
+ "options": {},
+ "is_greeting": false,
+ "id": "NcWJjpm9",
+ "timestamp": 1734292274.707415
+ },
+ {
+ "invocation_id": "UfYAgiCT",
+ "author": "question_generation_agent",
+ "content": {
+ "parts": [
+ {
+ "text": "1. [question_generation_agent]: What are the specific types of rainfall events that contribute to flooding in Florida?\n2. [question_generation_agent]: How does the rate of sea level rise in Florida compare to the global average?\n3. [question_generation_agent]: What are the specific ways that urbanization in Florida contributes to increased flooding?\n4. [question_generation_agent]: Are there any specific regions in Florida that are particularly vulnerable to flooding?\n5. [question_generation_agent]: What are the most significant flood events in Florida's history, and what were their causes?\n6. [question_generation_agent]: What are the different types of flood mitigation efforts currently in place in Florida?\n7. [question_generation_agent]: How effective have these flood mitigation efforts been in reducing flood risk in Florida?\n8. [question_generation_agent]: What are the potential future impacts of climate change on flooding in Florida?\n9. [question_generation_agent]: What are the economic and social consequences of flooding in Florida?\n10. [question_generation_agent]: What are the recommendations for improving flood mitigation efforts in Florida? \n"
+ }
+ ],
+ "role": "model"
+ },
+ "options": {},
+ "is_greeting": false,
+ "id": "9DZFuhg7",
+ "timestamp": 1734292277.283785
+ },
+ {
+ "invocation_id": "UfYAgiCT",
+ "author": "information_retrieval_agent",
+ "content": {
+ "parts": [
+ {
+ "text": "[information_retrieval_agent]: **No. 1: Types of Rainfall Events Contributing to Flooding in Florida**\n\nFlorida experiences various types of rainfall events that contribute to flooding. These include:\n\n* **Tropical Storms and Hurricanes:** These storms bring heavy rainfall over extended periods, often exceeding the capacity of drainage systems and leading to widespread flooding.\n* **Thunderstorms:** While less intense than hurricanes, thunderstorms can produce heavy downpours in a short time, causing localized flooding, especially in urban areas with limited drainage.\n* **Frontal Systems:** Cold fronts moving across Florida can bring significant rainfall, particularly during the winter months, leading to flooding in low-lying areas.\n* **El Niño-Southern Oscillation (ENSO):** This climate pattern can influence rainfall patterns in Florida, with El Niño years often associated with increased rainfall and flooding.\n\n[information_retrieval_agent]: **No. 2: Sea Level Rise Rate in Florida Compared to Global Average**\n\nFlorida's sea level rise rate is significantly higher than the global average. While the global average sea level rise is about 3.4 millimeters per year, Florida's rate is estimated to be around 9 millimeters per year. This accelerated rate is attributed to factors such as land subsidence and the Gulf Stream's influence.\n\n[information_retrieval_agent]: **No. 3: Urbanization's Impact on Flooding in Florida**\n\nUrbanization in Florida contributes to increased flooding through several mechanisms:\n\n* **Impervious Surfaces:** Concrete and asphalt surfaces in urban areas reduce infiltration, forcing rainwater to run off into drainage systems, which can become overwhelmed during heavy rainfall.\n* **Modified Drainage Patterns:** Urban development often alters natural drainage patterns, leading to concentrated runoff and increased flood risk in certain areas.\n* **Increased Runoff Velocity:** Urban surfaces accelerate runoff, increasing the volume and speed of water flowing into drainage systems, potentially exceeding their capacity.\n\n[information_retrieval_agent]: **No. 4: Regions in Florida Particularly Vulnerable to Flooding**\n\nSeveral regions in Florida are particularly vulnerable to flooding due to their geographic characteristics and exposure to various flood risks:\n\n* **Coastal Areas:** Coastal areas are susceptible to storm surge, sea level rise, and coastal erosion, increasing their vulnerability to flooding.\n* **Low-Lying Areas:** Areas with low elevation are prone to flooding from heavy rainfall and rising sea levels.\n* **Urban Areas:** Urban areas with limited drainage capacity and high concentrations of impervious surfaces are at risk of flooding during heavy rainfall events.\n\n[information_retrieval_agent]: **No. 5: Significant Flood Events in Florida's History**\n\nFlorida has experienced numerous significant flood events throughout its history, with some notable examples:\n\n* **Hurricane Andrew (1992):** This Category 5 hurricane caused widespread flooding in South Florida, primarily due to storm surge and heavy rainfall.\n* **Hurricane Irma (2017):** This Category 4 hurricane brought significant rainfall and storm surge, leading to extensive flooding across the state.\n* **Hurricane Michael (2018):** This Category 5 hurricane caused significant flooding in the Panhandle region, primarily due to heavy rainfall and storm surge.\n\n[information_retrieval_agent]: **No. 6: Flood Mitigation Efforts in Florida**\n\nFlorida has implemented various flood mitigation efforts to reduce flood risk and protect communities:\n\n* **Levees and Seawalls:** These structures are designed to protect coastal areas from storm surge and rising sea levels.\n* **Drainage Systems:** Drainage systems are designed to collect and remove excess rainwater, reducing the risk of flooding in urban areas.\n* **Floodplain Management:** Regulations are in place to restrict development in floodplains, minimizing the risk of damage and loss of life.\n* **Early Warning Systems:** Advanced weather forecasting and warning systems provide timely alerts to residents, allowing them to prepare for potential flooding.\n\n[information_retrieval_agent]: **No. 7: Effectiveness of Flood Mitigation Efforts in Florida**\n\nThe effectiveness of flood mitigation efforts in Florida varies depending on the specific measures implemented and the severity of the flood event. While some efforts have been successful in reducing flood risk, others have proven less effective, particularly in the face of extreme weather events.\n\n[information_retrieval_agent]: **No. 8: Potential Future Impacts of Climate Change on Flooding in Florida**\n\nClimate change is expected to exacerbate flooding in Florida through several mechanisms:\n\n* **Increased Rainfall Intensity:** Climate change is projected to increase the frequency and intensity of rainfall events, leading to more frequent and severe flooding.\n* **Sea Level Rise:** Continued sea level rise will increase the risk of coastal flooding, particularly during storm surge events.\n* **Ocean Warming:** Warmer ocean temperatures can lead to more intense hurricanes, further increasing the risk of flooding.\n\n[information_retrieval_agent]: **No. 9: Economic and Social Consequences of Flooding in Florida**\n\nFlooding in Florida has significant economic and social consequences:\n\n* **Property Damage:** Flooding can cause extensive damage to homes, businesses, and infrastructure, resulting in substantial financial losses.\n* **Disruption of Transportation:** Flooding can disrupt transportation systems, including roads, bridges, and airports, leading to economic losses and social disruption.\n* **Health Risks:** Flooding can create health risks, such as waterborne diseases and injuries from debris.\n* **Displacement of Residents:** Flooding can force residents to evacuate their homes, leading to displacement and social disruption.\n\n[information_retrieval_agent]: **No. 10: Recommendations for Improving Flood Mitigation Efforts in Florida**\n\nTo improve flood mitigation efforts in Florida, several recommendations can be considered:\n\n* **Invest in Infrastructure:** Increased investment in drainage systems, levees, and seawalls is crucial to enhance flood protection.\n* **Promote Sustainable Development:** Encourage development practices that minimize impervious surfaces and preserve natural drainage patterns.\n* **Strengthen Early Warning Systems:** Enhance weather forecasting and warning systems to provide timely and accurate alerts to residents.\n* **Implement Comprehensive Floodplain Management:** Enforce strict regulations to restrict development in floodplains and minimize flood risk.\n* **Promote Public Awareness:** Educate residents about flood risks and encourage them to take steps to prepare for potential flooding.\n\n[exit] \n"
+ }
+ ],
+ "role": "model"
+ },
+ "options": {},
+ "is_greeting": false,
+ "id": "o11CKE8Q",
+ "timestamp": 1734292284.446259
+ },
+ {
+ "invocation_id": "UfYAgiCT",
+ "author": "summary_agent",
+ "content": {
+ "parts": [
+ {
+ "text": "## Flooding in Florida: A Comprehensive Analysis of Causes, Impacts, and Mitigation Strategies\n\n**Abstract:** This essay explores the complex issue of flooding in Florida, a state highly vulnerable to various flood threats. It examines the key factors contributing to flooding, including rainfall patterns, sea level rise, and urbanization, and analyzes their historical impact. Furthermore, it evaluates the effectiveness of existing flood mitigation efforts and explores recommendations for future improvements. \n\n**Rainfall Patterns and Flood Risk:**\n\nFlorida's climate is characterized by frequent and intense rainfall events, significantly contributing to flood risk. Tropical storms and hurricanes, particularly during the hurricane season, bring heavy rainfall over prolonged periods, overwhelming drainage systems and causing widespread flooding. Thunderstorms, though less intense, can generate heavy downpours in short bursts, leading to localized flooding, especially in urban areas. Furthermore, frontal systems, particularly during winter months, can bring substantial rainfall, exacerbating flooding in low-lying areas. El Niño-Southern Oscillation (ENSO) events can also influence rainfall patterns, with El Niño years often associated with increased rainfall and subsequent flooding.\n\n**Sea Level Rise and Coastal Flooding:**\n\nFlorida's sea level rise rate surpasses the global average, posing a significant threat to coastal areas. The rate of sea level rise in Florida is estimated to be around 9 millimeters per year, compared to the global average of 3.4 millimeters per year. This accelerated rise is attributed to factors such as land subsidence and the Gulf Stream's influence. As sea levels continue to rise, coastal areas face an increased risk of flooding, particularly during storm surge events, leading to erosion, saltwater intrusion, and potential displacement of coastal communities. \n\n**Urbanization and Flood Risk:**\n\nUrbanization in Florida, while contributing to economic growth, also exacerbates flood risk. The increasing development of impervious surfaces, such as concrete and asphalt, reduces rainwater infiltration, forcing excess water to flow into drainage systems. This can overwhelm the capacity of drainage systems, leading to flooding in urban areas. Furthermore, urban development often alters natural drainage patterns, diverting water flow and increasing flooding risks in specific locations. Finally, urban surfaces accelerate runoff velocity, increasing the volume and speed of water entering drainage systems, further enhancing the risk of flooding.\n\n**Historical Flood Events and Their Impact:**\n\nFlorida has experienced numerous significant flood events throughout its history, highlighting the state's vulnerability. Hurricane Andrew in 1992, a Category 5 hurricane, caused widespread flooding in South Florida, primarily due to storm surge and heavy rainfall. Hurricane Irma in 2017, a Category 4 hurricane, brought significant rainfall and storm surge, leading to extensive flooding across the state. Hurricane Michael in 2018, a Category 5 hurricane, caused substantial flooding in the Panhandle region, primarily due to heavy rainfall and storm surge. These events demonstrate the devastating impact of flooding on infrastructure, property, and human lives. \n\n**Flood Mitigation Efforts and Their Effectiveness:**\n\nFlorida has implemented various flood mitigation efforts to reduce flood risk and protect communities. These efforts include construction of levees and seawalls to protect coastal areas from storm surge and rising sea levels, construction of drainage systems to collect and remove excess rainwater, implementation of floodplain management regulations to restrict development in floodplains, and development of early warning systems to provide timely alerts to residents. While these efforts have contributed to reducing flood risk in some areas, their effectiveness varies depending on the severity of the flood event. Extreme weather events, such as hurricanes, can overwhelm existing infrastructure and mitigation measures, highlighting the need for continuous improvement and adaptation.\n\n**Future Impacts of Climate Change on Flooding:**\n\nClimate change is expected to exacerbate flooding in Florida, posing significant challenges for the state's future. Increased rainfall intensity, driven by climate change, is projected to lead to more frequent and severe flooding events. Continued sea level rise will further increase the risk of coastal flooding, especially during storm surge events. Ocean warming, a consequence of climate change, can contribute to more intense hurricanes, further amplifying flood risk. These compounding factors necessitate proactive measures to adapt to changing climate conditions and mitigate future flood risks. \n\n**Economic and Social Consequences of Flooding:**\n\nFlooding in Florida carries significant economic and social consequences. Property damage from flooding can be extensive, resulting in substantial financial losses for individuals, businesses, and the state. Flooding disrupts transportation systems, including roads, bridges, and airports, causing economic losses and social disruption. Flooding can also create health risks, such as waterborne diseases and injuries from debris. Furthermore, displacement of residents due to flooding can lead to social disruption and strain on existing resources. The economic and social costs of flooding underscore the importance of effective flood mitigation and adaptation strategies.\n\n**Recommendations for Improving Flood Mitigation Efforts:**\n\nTo improve flood mitigation efforts in Florida and address the challenges posed by climate change, several recommendations are crucial. Increased investment in infrastructure, including drainage systems, levees, and seawalls, is essential to enhance flood protection. Promoting sustainable development practices, such as minimizing impervious surfaces and preserving natural drainage patterns, can reduce runoff and mitigate flood risk. Strengthening early warning systems to provide timely and accurate alerts to residents is crucial for preparedness and response. Implementing comprehensive floodplain management regulations to restrict development in floodplains and minimize flood risk is essential. Finally, promoting public awareness about flood risks and encouraging residents to take steps to prepare for potential flooding is vital for community resilience.\n\n**Conclusion:**\n\nFlooding in Florida is a complex issue with multifaceted causes, impacts, and mitigation strategies. Rainfall patterns, sea level rise, and urbanization all contribute to flood risk, while historical events highlight the state's vulnerability. Existing flood mitigation efforts, while valuable, need continuous improvement and adaptation to address the challenges posed by climate change. Investing in infrastructure, promoting sustainable development, strengthening early warning systems, implementing comprehensive floodplain management, and fostering public awareness are key recommendations for enhancing flood resilience and safeguarding Florida's communities from the devastating impacts of flooding. \n"
+ }
+ ],
+ "role": "model"
+ },
+ "options": {},
+ "is_greeting": false,
+ "id": "om1YdTV1",
+ "timestamp": 1734292291.449593
+ },
+ {
+ "invocation_id": "UfYAgiCT",
+ "author": "spark_assistant",
+ "content": {
+ "parts": [
+ {
+ "function_call": {
+ "args": {
+ "agent_name": "spark_assistant"
+ },
+ "name": "transfer_to_other_agent"
+ }
+ }
+ ],
+ "role": "model"
+ },
+ "options": {
+ "skip_summarization": true
+ },
+ "is_greeting": false,
+ "id": "oRxzw1Wp",
+ "timestamp": 1734292291.451956
+ },
+ {
+ "invocation_id": "UfYAgiCT",
+ "author": "spark_assistant",
+ "content": {
+ "parts": [
+ {
+ "function_response": {
+ "name": "transfer_to_other_agent",
+ "response": {
+ "status": "ok",
+ "target_agent_name": "spark_assistant",
+ "message": "Transfered to spark_assistant"
+ }
+ }
+ }
+ ],
+ "role": "user"
+ },
+ "options": {
+ "skip_summarization": true,
+ "function_call_event_id": "oRxzw1Wp"
+ },
+ "is_greeting": false,
+ "id": "s2OiE3PP",
+ "timestamp": 1734292291.453157
+ }
+ ],
+ "past_events": [],
+ "current_agent_name": "spark_assistant",
+ "pending_events": {},
+ "artifacts": {},
+ "last_update_time": 1734292291.453157,
+ "event_logs": []
+}
\ No newline at end of file
diff --git a/src/google/adk/tests/integration/fixture/hello_world_agent/__init__.py b/src/google/adk/tests/integration/fixture/hello_world_agent/__init__.py
new file mode 100644
index 0000000..c48963c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/hello_world_agent/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/src/google/adk/tests/integration/fixture/hello_world_agent/agent.py b/src/google/adk/tests/integration/fixture/hello_world_agent/agent.py
new file mode 100644
index 0000000..b3ead16
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/hello_world_agent/agent.py
@@ -0,0 +1,95 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Hello world agent from agent 1.0 - https://colab.sandbox.google.com/drive/1Zq-nqmgK0nCERCv8jKIaoeTTgbNn6oSo?resourcekey=0-GYaz9pFT4wY8CI8Cvjy5GA#scrollTo=u3X3XwDOaCv9
+import random
+
+from google.adk import Agent
+from google.genai import types
+
+
+def roll_die(sides: int) -> int:
+ """Roll a die and return the rolled result.
+
+ Args:
+ sides: The integer number of sides the die has.
+
+ Returns:
+ An integer of the result of rolling the die.
+ """
+ return random.randint(1, sides)
+
+
+def check_prime(nums: list[int]) -> list[str]:
+ """Check if a given list of numbers are prime.
+
+ Args:
+ nums: The list of numbers to check.
+
+ Returns:
+ A str indicating which number is prime.
+ """
+ primes = set()
+ for number in nums:
+ number = int(number)
+ if number <= 1:
+ continue
+ is_prime = True
+ for i in range(2, int(number**0.5) + 1):
+ if number % i == 0:
+ is_prime = False
+ break
+ if is_prime:
+ primes.add(number)
+ return (
+ 'No prime numbers found.'
+ if not primes
+ else f"{', '.join(str(num) for num in primes)} are prime numbers."
+ )
+
+
+root_agent = Agent(
+ model='gemini-2.0-flash-001',
+ name='data_processing_agent',
+ instruction="""
+ You roll dice and answer questions about the outcome of the dice rolls.
+ You can roll dice of different sizes.
+ You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
+ The only things you do are roll dice for the user and discuss the outcomes.
+ It is ok to discuss previous dice roles, and comment on the dice rolls.
+ When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
+ You should never roll a die on your own.
+ When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
+ You should not check prime numbers before calling the tool.
+ When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
+ 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
+ 2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
+ 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
+ 3. When you respond, you must include the roll_die result from step 1.
+ You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
+ You should not rely on the previous history on prime results.
+ """,
+ tools=[
+ roll_die,
+ check_prime,
+ ],
+ generate_content_config=types.GenerateContentConfig(
+ safety_settings=[
+ types.SafetySetting( # avoid false alarm about rolling dice.
+ category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+ threshold=types.HarmBlockThreshold.OFF,
+ ),
+ ]
+ ),
+)
diff --git a/src/google/adk/tests/integration/fixture/hello_world_agent/roll_die.test.json b/src/google/adk/tests/integration/fixture/hello_world_agent/roll_die.test.json
new file mode 100644
index 0000000..fdc8127
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/hello_world_agent/roll_die.test.json
@@ -0,0 +1,24 @@
+[
+ {
+ "query": "Hi who are you?",
+ "expected_tool_use": [],
+ "reference": "I am a data processing agent. I can roll dice and check if the results are prime numbers. What would you like me to do? \n"
+ },
+ {
+ "query": "What can you do?",
+ "expected_tool_use": [],
+ "reference": "I can roll dice for you of different sizes, and I can check if the results are prime numbers. I can also remember previous rolls if you'd like to check those for primes as well. What would you like me to do? \n"
+ },
+ {
+ "query": "Can you roll a die with 6 sides",
+ "expected_tool_use": [
+ {
+ "tool_name": "roll_die",
+ "tool_input": {
+ "sides": 6
+ }
+ }
+ ],
+ "reference": null
+ }
+]
diff --git a/src/google/adk/tests/integration/fixture/hello_world_agent/test_config.json b/src/google/adk/tests/integration/fixture/hello_world_agent/test_config.json
new file mode 100644
index 0000000..c7fba6a
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/hello_world_agent/test_config.json
@@ -0,0 +1,6 @@
+{
+ "criteria": {
+ "tool_trajectory_avg_score": 1.0,
+ "response_match_score": 0.5
+ }
+}
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/__init__.py b/src/google/adk/tests/integration/fixture/home_automation_agent/__init__.py
new file mode 100644
index 0000000..c48963c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/agent.py b/src/google/adk/tests/integration/fixture/home_automation_agent/agent.py
new file mode 100644
index 0000000..7a037a4
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/agent.py
@@ -0,0 +1,304 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+from google.adk import Agent
+
+DEVICE_DB = {
+ "device_1": {"status": "ON", "location": "Living Room"},
+ "device_2": {"status": "OFF", "location": "Bedroom"},
+ "device_3": {"status": "OFF", "location": "Kitchen"},
+}
+
+TEMPERATURE_DB = {
+ "Living Room": 22,
+ "Bedroom": 20,
+ "Kitchen": 24,
+}
+
+SCHEDULE_DB = {
+ "device_1": {"time": "18:00", "status": "ON"},
+ "device_2": {"time": "22:00", "status": "OFF"},
+}
+
+USER_PREFERENCES_DB = {
+ "user_x": {"preferred_temp": 21, "location": "Bedroom"},
+ "user_x": {"preferred_temp": 21, "location": "Living Room"},
+ "user_y": {"preferred_temp": 23, "location": "Living Room"},
+}
+
+
+def reset_data():
+ global DEVICE_DB
+ global TEMPERATURE_DB
+ global SCHEDULE_DB
+ global USER_PREFERENCES_DB
+ DEVICE_DB = {
+ "device_1": {"status": "ON", "location": "Living Room"},
+ "device_2": {"status": "OFF", "location": "Bedroom"},
+ "device_3": {"status": "OFF", "location": "Kitchen"},
+ }
+
+ TEMPERATURE_DB = {
+ "Living Room": 22,
+ "Bedroom": 20,
+ "Kitchen": 24,
+ }
+
+ SCHEDULE_DB = {
+ "device_1": {"time": "18:00", "status": "ON"},
+ "device_2": {"time": "22:00", "status": "OFF"},
+ }
+
+ USER_PREFERENCES_DB = {
+ "user_x": {"preferred_temp": 21, "location": "Bedroom"},
+ "user_x": {"preferred_temp": 21, "location": "Living Room"},
+ "user_y": {"preferred_temp": 23, "location": "Living Room"},
+ }
+
+
+def get_device_info(device_id: str) -> dict:
+ """Get the current status and location of a AC device.
+
+ Args:
+ device_id (str): The unique identifier of the device.
+
+ Returns:
+ dict: A dictionary containing the following fields, or 'Device not found'
+ if the device_id does not exist:
+ - status: The current status of the device (e.g., 'ON', 'OFF')
+ - location: The location where the device is installed (e.g., 'Living
+ Room', 'Bedroom', ''Kitchen')
+ """
+ return DEVICE_DB.get(device_id, "Device not found")
+
+
+# def set_device_info(device_id: str, updates: dict) -> str:
+# """Update the information of a AC device, specifically its status and/or location.
+
+# Args:
+# device_id (str): Required. The unique identifier of the device.
+# updates (dict): Required. A dictionary containing the fields to be
+# updated. Supported keys: - "status" (str): The new status to set for the
+# device. Accepted values: 'ON', 'OFF'. **Only these values are allowed.**
+# - "location" (str): The new location to set for the device. Accepted
+# values: 'Living Room', 'Bedroom', 'Kitchen'. **Only these values are
+# allowed.**
+
+
+# Returns:
+# str: A message indicating whether the device information was successfully
+# updated.
+# """
+# if device_id in DEVICE_DB:
+# if "status" in updates:
+# DEVICE_DB[device_id]["status"] = updates["status"]
+# if "location" in updates:
+# DEVICE_DB[device_id]["location"] = updates["location"]
+# return f"Device {device_id} information updated: {updates}."
+# return "Device not found"
+def set_device_info(
+ device_id: str, status: str = "", location: str = ""
+) -> str:
+ """Update the information of a AC device, specifically its status and/or location.
+
+ Args:
+ device_id (str): Required. The unique identifier of the device.
+ status (str): The new status to set for the
+ device. Accepted values: 'ON', 'OFF'. **Only these values are allowed.**
+ location (str): The new location to set for the device. Accepted
+ values: 'Living Room', 'Bedroom', 'Kitchen'. **Only these values are
+ allowed.**
+
+ Returns:
+ str: A message indicating whether the device information was successfully
+ updated.
+ """
+ if device_id in DEVICE_DB:
+ if status:
+ DEVICE_DB[device_id]["status"] = status
+ return f"Device {device_id} information updated: status -> {status}."
+ if location:
+ DEVICE_DB[device_id]["location"] = location
+ return f"Device {device_id} information updated: location -> {location}."
+ return "Device not found"
+
+
+def get_temperature(location: str) -> int:
+ """Get the current temperature in celsius of a location (e.g., 'Living Room', 'Bedroom', 'Kitchen').
+
+ Args:
+ location (str): The location for which to retrieve the temperature (e.g.,
+ 'Living Room', 'Bedroom', 'Kitchen').
+
+ Returns:
+ int: The current temperature in celsius in the specified location, or
+ 'Location not found' if the location does not exist.
+ """
+ return TEMPERATURE_DB.get(location, "Location not found")
+
+
+def set_temperature(location: str, temperature: int) -> str:
+ """Set the desired temperature in celsius for a location.
+
+ Acceptable range of temperature: 18-30 celsius. If it's out of the range, do
+ not call this tool.
+
+ Args:
+ location (str): The location where the temperature should be set.
+ temperature (int): The desired temperature as integer to set in celsius.
+ Acceptable range: 18-30 celsius.
+
+ Returns:
+ str: A message indicating whether the temperature was successfully set.
+ """
+ if location in TEMPERATURE_DB:
+ TEMPERATURE_DB[location] = temperature
+ return f"Temperature in {location} set to {temperature}°C."
+ return "Location not found"
+
+
+def get_user_preferences(user_id: str) -> dict:
+ """Get the temperature preferences and preferred location of a user_id.
+
+ user_id must be provided.
+
+ Args:
+ user_id (str): The unique identifier of the user.
+
+ Returns:
+ dict: A dictionary containing the following fields, or 'User not found' if
+ the user_id does not exist:
+ - preferred_temp: The user's preferred temperature.
+ - location: The location where the user prefers to be.
+ """
+ return USER_PREFERENCES_DB.get(user_id, "User not found")
+
+
+def set_device_schedule(device_id: str, time: str, status: str) -> str:
+ """Schedule a device to change its status at a specific time.
+
+ Args:
+ device_id (str): The unique identifier of the device.
+ time (str): The time at which the device should change its status (format:
+ 'HH:MM').
+ status (str): The status to set for the device at the specified time
+ (e.g., 'ON', 'OFF').
+
+ Returns:
+ str: A message indicating whether the schedule was successfully set.
+ """
+ if device_id in DEVICE_DB:
+ SCHEDULE_DB[device_id] = {"time": time, "status": status}
+ return f"Device {device_id} scheduled to turn {status} at {time}."
+ return "Device not found"
+
+
+def get_device_schedule(device_id: str) -> dict:
+ """Retrieve the schedule of a device.
+
+ Args:
+ device_id (str): The unique identifier of the device.
+
+ Returns:
+ dict: A dictionary containing the following fields, or 'Schedule not
+ found' if the device_id does not exist:
+ - time: The scheduled time for the device to change its status (format:
+ 'HH:MM').
+ - status: The status that will be set at the scheduled time (e.g., 'ON',
+ 'OFF').
+ """
+ return SCHEDULE_DB.get(device_id, "Schedule not found")
+
+
+def celsius_to_fahrenheit(celsius: int) -> float:
+ """Convert Celsius to Fahrenheit.
+
+ You must call this to do the conversion of temperature, so you can get the
+ precise number in required format.
+
+ Args:
+ celsius (int): Temperature in Celsius.
+
+ Returns:
+ float: Temperature in Fahrenheit.
+ """
+ return (celsius * 9 / 5) + 32
+
+
+def fahrenheit_to_celsius(fahrenheit: float) -> int:
+ """Convert Fahrenheit to Celsius.
+
+ You must call this to do the conversion of temperature, so you can get the
+ precise number in required format.
+
+ Args:
+ fahrenheit (float): Temperature in Fahrenheit.
+
+ Returns:
+ int: Temperature in Celsius.
+ """
+ return int((fahrenheit - 32) * 5 / 9)
+
+
+def list_devices(status: str = "", location: str = "") -> list:
+ """Retrieve a list of AC devices, filtered by status and/or location when provided.
+
+ For cost efficiency, always apply as many filters (status and location) as
+ available in the input arguments.
+
+ Args:
+ status (str, optional): The status to filter devices by (e.g., 'ON',
+ 'OFF'). Defaults to None.
+ location (str, optional): The location to filter devices by (e.g., 'Living
+ Room', 'Bedroom', ''Kitchen'). Defaults to None.
+
+ Returns:
+ list: A list of dictionaries, each containing the device ID, status, and
+ location, or an empty list if no devices match the criteria.
+ """
+ devices = []
+ for device_id, info in DEVICE_DB.items():
+ if ((not status) or info["status"] == status) and (
+ (not location) or info["location"] == location
+ ):
+ devices.append({
+ "device_id": device_id,
+ "status": info["status"],
+ "location": info["location"],
+ })
+ return devices if devices else "No devices found matching the criteria."
+
+
+root_agent = Agent(
+ model="gemini-2.0-flash-001",
+ name="Home_automation_agent",
+ instruction="""
+ You are Home Automation Agent. You are responsible for controlling the devices in the home.
+ """,
+ tools=[
+ get_device_info,
+ set_device_info,
+ get_temperature,
+ set_temperature,
+ get_user_preferences,
+ set_device_schedule,
+ get_device_schedule,
+ celsius_to_fahrenheit,
+ fahrenheit_to_celsius,
+ list_devices,
+ ],
+)
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/simple_test.test.json b/src/google/adk/tests/integration/fixture/home_automation_agent/simple_test.test.json
new file mode 100644
index 0000000..978c36f
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/simple_test.test.json
@@ -0,0 +1,5 @@
+[{
+ "query": "Turn off device_2 in the Bedroom.",
+ "expected_tool_use": [{"tool_name": "set_device_info", "tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}}],
+ "reference": "I have set the device_2 status to off."
+}]
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/simple_test2.test.json b/src/google/adk/tests/integration/fixture/home_automation_agent/simple_test2.test.json
new file mode 100644
index 0000000..5ba5d82
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/simple_test2.test.json
@@ -0,0 +1,5 @@
+[{
+ "query": "Turn off device_3 in the Bedroom.",
+ "expected_tool_use": [{"tool_name": "set_device_info", "tool_input": {"location": "Bedroom", "device_id": "device_3", "status": "OFF"}}],
+ "reference": "I have set the device_3 status to off."
+}]
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/test_config.json b/src/google/adk/tests/integration/fixture/home_automation_agent/test_config.json
new file mode 100644
index 0000000..424c95d
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/test_config.json
@@ -0,0 +1,5 @@
+{
+ "criteria": {
+ "tool_trajectory_avg_score": 1.0
+ }
+}
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/dependent_tool_calls.test.json b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/dependent_tool_calls.test.json
new file mode 100644
index 0000000..0633eab
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/dependent_tool_calls.test.json
@@ -0,0 +1,18 @@
+[
+ {
+ "query": "Turn off device_2 in the Bedroom.",
+ "expected_tool_use": [{
+ "tool_name": "set_device_info",
+ "tool_input": {"location": "Bedroom", "status": "OFF", "device_id": "device_2"}
+ }],
+ "reference": "I have set the device 2 status to off."
+ },
+ {
+ "query": "What's the status of device_2 in the Bedroom?",
+ "expected_tool_use": [{
+ "tool_name": "get_device_info",
+ "tool_input": {"device_id": "device_2"}
+ }],
+ "reference": "Status of device_2 is off."
+ }
+]
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json
new file mode 100644
index 0000000..0e5778b
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json
@@ -0,0 +1,17 @@
+[
+ {
+ "query": "Turn off device_2 in the Bedroom.",
+ "expected_tool_use": [
+ {
+ "tool_name": "set_device_info",
+ "tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}
+ }
+ ],
+ "reference": "OK. I've turned off device_2 in the Bedroom. Anything else?\n"
+ },
+ {
+ "query": "What's the command I just issued?",
+ "expected_tool_use": [],
+ "reference": "You asked me to turn off device_2 in the Bedroom.\n"
+ }
+]
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/test_config.json b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/test_config.json
new file mode 100644
index 0000000..c7fba6a
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/test_config.json
@@ -0,0 +1,6 @@
+{
+ "criteria": {
+ "tool_trajectory_avg_score": 1.0,
+ "response_match_score": 0.5
+ }
+}
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_multi_turn_conversation.test.json b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_multi_turn_conversation.test.json
new file mode 100644
index 0000000..334dd2d
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_multi_turn_conversation.test.json
@@ -0,0 +1,18 @@
+[
+ {
+ "query": "Turn off device_2 in the Bedroom.",
+ "expected_tool_use": [{
+ "tool_name": "set_device_info",
+ "tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}
+ }],
+ "reference": "I have set the device 2 status to off."
+ },
+ {
+ "query": "Turn on device_2 in the Bedroom.",
+ "expected_tool_use": [{
+ "tool_name": "set_device_info",
+ "tool_input": {"location": "Bedroom", "status": "ON", "device_id": "device_2"}
+ }],
+ "reference": "I have set the device 2 status to on."
+ }
+]
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test.test.json b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test.test.json
new file mode 100644
index 0000000..0e5778b
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test.test.json
@@ -0,0 +1,17 @@
+[
+ {
+ "query": "Turn off device_2 in the Bedroom.",
+ "expected_tool_use": [
+ {
+ "tool_name": "set_device_info",
+ "tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}
+ }
+ ],
+ "reference": "OK. I've turned off device_2 in the Bedroom. Anything else?\n"
+ },
+ {
+ "query": "What's the command I just issued?",
+ "expected_tool_use": [],
+ "reference": "You asked me to turn off device_2 in the Bedroom.\n"
+ }
+]
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test2.test.json b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test2.test.json
new file mode 100644
index 0000000..5ba5d82
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/simple_test2.test.json
@@ -0,0 +1,5 @@
+[{
+ "query": "Turn off device_3 in the Bedroom.",
+ "expected_tool_use": [{"tool_name": "set_device_info", "tool_input": {"location": "Bedroom", "device_id": "device_3", "status": "OFF"}}],
+ "reference": "I have set the device_3 status to off."
+}]
diff --git a/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/test_config.json b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/test_config.json
new file mode 100644
index 0000000..424c95d
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/home_automation_agent/test_files/test_config.json
@@ -0,0 +1,5 @@
+{
+ "criteria": {
+ "tool_trajectory_avg_score": 1.0
+ }
+}
diff --git a/src/google/adk/tests/integration/fixture/tool_agent/__init__.py b/src/google/adk/tests/integration/fixture/tool_agent/__init__.py
new file mode 100644
index 0000000..c48963c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/tool_agent/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/src/google/adk/tests/integration/fixture/tool_agent/agent.py b/src/google/adk/tests/integration/fixture/tool_agent/agent.py
new file mode 100644
index 0000000..a89d208
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/tool_agent/agent.py
@@ -0,0 +1,218 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Any
+
+from crewai_tools import DirectoryReadTool
+from google.adk import Agent
+from google.adk.tools.agent_tool import AgentTool
+from google.adk.tools.crewai_tool import CrewaiTool
+from google.adk.tools.langchain_tool import LangchainTool
+from google.adk.tools.retrieval.files_retrieval import FilesRetrieval
+from google.adk.tools.retrieval.vertex_ai_rag_retrieval import VertexAiRagRetrieval
+from langchain_community.tools import ShellTool
+from pydantic import BaseModel
+
+
+class TestCase(BaseModel):
+ case: str
+
+
+class Test(BaseModel):
+ test_title: list[str]
+
+
+def simple_function(param: str) -> str:
+ if isinstance(param, str):
+ return "Called simple function successfully"
+ return "Called simple function with wrong param type"
+
+
+def no_param_function() -> str:
+ return "Called no param function successfully"
+
+
+def no_output_function(param: str):
+ return
+
+
+def multiple_param_types_function(
+ param1: str, param2: int, param3: float, param4: bool
+) -> str:
+ if (
+ isinstance(param1, str)
+ and isinstance(param2, int)
+ and isinstance(param3, float)
+ and isinstance(param4, bool)
+ ):
+ return "Called multiple param types function successfully"
+ return "Called multiple param types function with wrong param types"
+
+
+def throw_error_function(param: str) -> str:
+ raise ValueError("Error thrown by throw_error_function")
+
+
+def list_str_param_function(param: list[str]) -> str:
+ if isinstance(param, list) and all(isinstance(item, str) for item in param):
+ return "Called list str param function successfully"
+ return "Called list str param function with wrong param type"
+
+
+def return_list_str_function(param: str) -> list[str]:
+ return ["Called return list str function successfully"]
+
+
+def complex_function_list_dict(
+ param1: dict[str, Any], param2: list[dict[str, Any]]
+) -> list[Test]:
+ if (
+ isinstance(param1, dict)
+ and isinstance(param2, list)
+ and all(isinstance(item, dict) for item in param2)
+ ):
+ return [
+ Test(test_title=["function test 1", "function test 2"]),
+ Test(test_title=["retrieval test"]),
+ ]
+ raise ValueError("Wrong param")
+
+
+def repetive_call_1(param: str):
+ return f"Call repetive_call_2 tool with param {param + '_repetive'}"
+
+
+def repetive_call_2(param: str):
+ return param
+
+
+test_case_retrieval = FilesRetrieval(
+ name="test_case_retrieval",
+ description="General guidence for agent test cases",
+ input_dir=os.path.join(os.path.dirname(__file__), "files"),
+)
+
+valid_rag_retrieval = VertexAiRagRetrieval(
+ name="valid_rag_retrieval",
+ rag_corpora=[
+ "projects/1096655024998/locations/us-central1/ragCorpora/4985766262475849728"
+ ],
+ description="General guidence for agent test cases",
+)
+
+invalid_rag_retrieval = VertexAiRagRetrieval(
+ name="invalid_rag_retrieval",
+ rag_corpora=[
+ "projects/1096655024998/locations/us-central1/InValidRagCorporas/4985766262475849728"
+ ],
+ description="Invalid rag retrieval resource name",
+)
+
+non_exist_rag_retrieval = VertexAiRagRetrieval(
+ name="non_exist_rag_retrieval",
+ rag_corpora=[
+ "projects/1096655024998/locations/us-central1/RagCorpora/1234567"
+ ],
+ description="Non exist rag retrieval resource name",
+)
+
+shell_tool = LangchainTool(ShellTool())
+
+docs_tool = CrewaiTool(
+ name="direcotry_read_tool",
+ description="use this to find files for you.",
+ tool=DirectoryReadTool(directory="."),
+)
+
+no_schema_agent = Agent(
+ model="gemini-1.5-flash",
+ name="no_schema_agent",
+ instruction="""Just say 'Hi'
+""",
+)
+
+schema_agent = Agent(
+ model="gemini-1.5-flash",
+ name="schema_agent",
+ instruction="""
+ You will be given a test case.
+ Return a list of the received test case appended with '_success' and '_failure' as test_titles
+""",
+ input_schema=TestCase,
+ output_schema=Test,
+)
+
+no_input_schema_agent = Agent(
+ model="gemini-1.5-flash",
+ name="no_input_schema_agent",
+ instruction="""
+ Just return ['Tools_success, Tools_failure']
+""",
+ output_schema=Test,
+)
+
+no_output_schema_agent = Agent(
+ model="gemini-1.5-flash",
+ name="no_output_schema_agent",
+ instruction="""
+ Just say 'Hi'
+""",
+ input_schema=TestCase,
+)
+
+single_function_agent = Agent(
+ model="gemini-1.5-flash",
+ name="single_function_agent",
+ description="An agent that calls a single function",
+ instruction="When calling tools, just return what the tool returns.",
+ tools=[simple_function],
+)
+
+root_agent = Agent(
+ model="gemini-1.5-flash",
+ name="tool_agent",
+ description="An agent that can call other tools",
+ instruction="When calling tools, just return what the tool returns.",
+ tools=[
+ simple_function,
+ no_param_function,
+ no_output_function,
+ multiple_param_types_function,
+ throw_error_function,
+ list_str_param_function,
+ return_list_str_function,
+ # complex_function_list_dict,
+ repetive_call_1,
+ repetive_call_2,
+ test_case_retrieval,
+ valid_rag_retrieval,
+ invalid_rag_retrieval,
+ non_exist_rag_retrieval,
+ shell_tool,
+ docs_tool,
+ AgentTool(
+ agent=no_schema_agent,
+ ),
+ AgentTool(
+ agent=schema_agent,
+ ),
+ AgentTool(
+ agent=no_input_schema_agent,
+ ),
+ AgentTool(
+ agent=no_output_schema_agent,
+ ),
+ ],
+)
diff --git a/src/google/adk/tests/integration/fixture/tool_agent/files/Agent_test_plan.pdf b/src/google/adk/tests/integration/fixture/tool_agent/files/Agent_test_plan.pdf
new file mode 100644
index 0000000..d8a1ac5
Binary files /dev/null and b/src/google/adk/tests/integration/fixture/tool_agent/files/Agent_test_plan.pdf differ
diff --git a/src/google/adk/tests/integration/fixture/trip_planner_agent/__init__.py b/src/google/adk/tests/integration/fixture/trip_planner_agent/__init__.py
new file mode 100644
index 0000000..c48963c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/trip_planner_agent/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/src/google/adk/tests/integration/fixture/trip_planner_agent/agent.py b/src/google/adk/tests/integration/fixture/trip_planner_agent/agent.py
new file mode 100644
index 0000000..ea8a33a
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/trip_planner_agent/agent.py
@@ -0,0 +1,110 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# https://github.com/crewAIInc/crewAI-examples/tree/main/trip_planner
+
+from google.adk import Agent
+
+# Agent that selects the best city for the trip.
+identify_agent = Agent(
+ name='identify_agent',
+ description='Select the best city based on weather, season, and prices.',
+ instruction="""
+ Analyze and select the best city for the trip based
+ on specific criteria such as weather patterns, seasonal
+ events, and travel costs. This task involves comparing
+ multiple cities, considering factors like current weather
+ conditions, upcoming cultural or seasonal events, and
+ overall travel expenses.
+
+ Your final answer must be a detailed
+ report on the chosen city, and everything you found out
+ about it, including the actual flight costs, weather
+ forecast and attractions.
+
+ Traveling from: {origin}
+ City Options: {cities}
+ Trip Date: {range}
+ Traveler Interests: {interests}
+""",
+)
+
+# Agent that gathers information about the city.
+gather_agent = Agent(
+ name='gather_agent',
+ description='Provide the BEST insights about the selected city',
+ instruction="""
+ As a local expert on this city you must compile an
+ in-depth guide for someone traveling there and wanting
+ to have THE BEST trip ever!
+ Gather information about key attractions, local customs,
+ special events, and daily activity recommendations.
+ Find the best spots to go to, the kind of place only a
+ local would know.
+ This guide should provide a thorough overview of what
+ the city has to offer, including hidden gems, cultural
+ hotspots, must-visit landmarks, weather forecasts, and
+ high level costs.
+
+ The final answer must be a comprehensive city guide,
+ rich in cultural insights and practical tips,
+ tailored to enhance the travel experience.
+
+ Trip Date: {range}
+ Traveling from: {origin}
+ Traveler Interests: {interests}
+""",
+)
+
+# Agent that plans the trip.
+plan_agent = Agent(
+ name='plan_agent',
+ description="""Create the most amazing travel itineraries with budget and
+ packing suggestions for the city""",
+ instruction="""
+ Expand this guide into a full 7-day travel
+ itinerary with detailed per-day plans, including
+ weather forecasts, places to eat, packing suggestions,
+ and a budget breakdown.
+
+ You MUST suggest actual places to visit, actual hotels
+ to stay and actual restaurants to go to.
+
+ This itinerary should cover all aspects of the trip,
+ from arrival to departure, integrating the city guide
+ information with practical travel logistics.
+
+ Your final answer MUST be a complete expanded travel plan,
+ formatted as markdown, encompassing a daily schedule,
+ anticipated weather conditions, recommended clothing and
+ items to pack, and a detailed budget, ensuring THE BEST
+ TRIP EVER. Be specific and give it a reason why you picked
+ each place, what makes them special!
+
+ Trip Date: {range}
+ Traveling from: {origin}
+ Traveler Interests: {interests}
+""",
+)
+
+root_agent = Agent(
+ model='gemini-2.0-flash-001',
+ name='trip_planner',
+ description='Plan the best trip ever',
+ instruction="""
+ Your goal is to plan the best trip according to information listed above.
+ You describe why did you choose the city, list top 3
+ attactions and provide a detailed itinerary for each day.""",
+ sub_agents=[identify_agent, gather_agent, plan_agent],
+)
diff --git a/src/google/adk/tests/integration/fixture/trip_planner_agent/initial.session.json b/src/google/adk/tests/integration/fixture/trip_planner_agent/initial.session.json
new file mode 100644
index 0000000..b33840c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/trip_planner_agent/initial.session.json
@@ -0,0 +1,13 @@
+{
+ "id": "test_id",
+ "app_name": "trip_planner_agent",
+ "user_id": "test_user",
+ "state": {
+ "origin": "San Francisco",
+ "interests": "Food, Shopping, Museums",
+ "range": "1000 miles",
+ "cities": ""
+ },
+ "events": [],
+ "last_update_time": 1741218714.258285
+}
diff --git a/src/google/adk/tests/integration/fixture/trip_planner_agent/test_config.json b/src/google/adk/tests/integration/fixture/trip_planner_agent/test_config.json
new file mode 100644
index 0000000..fffaa15
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/trip_planner_agent/test_config.json
@@ -0,0 +1,5 @@
+{
+ "criteria": {
+ "response_match_score": 0.5
+ }
+}
diff --git a/src/google/adk/tests/integration/fixture/trip_planner_agent/test_files/initial.session.json b/src/google/adk/tests/integration/fixture/trip_planner_agent/test_files/initial.session.json
new file mode 100644
index 0000000..b33840c
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/trip_planner_agent/test_files/initial.session.json
@@ -0,0 +1,13 @@
+{
+ "id": "test_id",
+ "app_name": "trip_planner_agent",
+ "user_id": "test_user",
+ "state": {
+ "origin": "San Francisco",
+ "interests": "Food, Shopping, Museums",
+ "range": "1000 miles",
+ "cities": ""
+ },
+ "events": [],
+ "last_update_time": 1741218714.258285
+}
diff --git a/src/google/adk/tests/integration/fixture/trip_planner_agent/test_files/test_config.json b/src/google/adk/tests/integration/fixture/trip_planner_agent/test_files/test_config.json
new file mode 100644
index 0000000..fffaa15
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/trip_planner_agent/test_files/test_config.json
@@ -0,0 +1,5 @@
+{
+ "criteria": {
+ "response_match_score": 0.5
+ }
+}
diff --git a/src/google/adk/tests/integration/fixture/trip_planner_agent/test_files/trip_inquiry_sub_agent.test.json b/src/google/adk/tests/integration/fixture/trip_planner_agent/test_files/trip_inquiry_sub_agent.test.json
new file mode 100644
index 0000000..03f52ab
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/trip_planner_agent/test_files/trip_inquiry_sub_agent.test.json
@@ -0,0 +1,7 @@
+[
+ {
+ "query": "Based on my interests, where should I go, Yosemite national park or Los Angeles?",
+ "expected_tool_use": [],
+ "reference": "Given your interests in food, shopping, and museums, Los Angeles would be a better choice than Yosemite National Park. Yosemite is primarily focused on outdoor activities and natural landscapes, while Los Angeles offers a diverse range of culinary experiences, shopping districts, and world-class museums. I will now gather information to create an in-depth guide for your trip to Los Angeles.\n"
+ }
+]
diff --git a/src/google/adk/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json b/src/google/adk/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json
new file mode 100644
index 0000000..c504f68
--- /dev/null
+++ b/src/google/adk/tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json
@@ -0,0 +1,19 @@
+[
+ {
+ "query": "Hi, who are you? What can you do?",
+ "expected_tool_use": [],
+ "reference": "I am trip_planner, and my goal is to plan the best trip ever. I can describe why a city was chosen, list its top attractions, and provide a detailed itinerary for each day of the trip.\n"
+ },
+ {
+ "query": "I want to travel from San Francisco to an European country in fall next year. I am considering London and Paris. What is your advice?",
+ "expected_tool_use": [
+ {
+ "tool_name": "transfer_to_agent",
+ "tool_input": {
+ "agent_name": "indentify_agent"
+ }
+ }
+ ],
+ "reference": "Okay, I can help you analyze London and Paris to determine which city is better for your trip next fall. I will consider weather patterns, seasonal events, travel costs (including flights from San Francisco), and your interests (food, shopping, and museums). After gathering this information, I'll provide a detailed report on my chosen city.\n"
+ }
+]
diff --git a/src/google/adk/tests/integration/models/__init__.py b/src/google/adk/tests/integration/models/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/integration/models/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/integration/models/test_google_llm.py b/src/google/adk/tests/integration/models/test_google_llm.py
new file mode 100644
index 0000000..daa0b51
--- /dev/null
+++ b/src/google/adk/tests/integration/models/test_google_llm.py
@@ -0,0 +1,65 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.models import LlmRequest
+from google.adk.models import LlmResponse
+from google.adk.models.google_llm import Gemini
+from google.genai import types
+from google.genai.types import Content
+from google.genai.types import Part
+import pytest
+
+
+@pytest.fixture
+def gemini_llm():
+ return Gemini(model="gemini-1.5-flash")
+
+
+@pytest.fixture
+def llm_request():
+ return LlmRequest(
+ model="gemini-1.5-flash",
+ contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
+ config=types.GenerateContentConfig(
+ temperature=0.1,
+ response_modalities=[types.Modality.TEXT],
+ system_instruction="You are a helpful assistant",
+ ),
+ )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async(gemini_llm, llm_request):
+ async for response in gemini_llm.generate_content_async(llm_request):
+ assert isinstance(response, LlmResponse)
+ assert response.content.parts[0].text
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream(gemini_llm, llm_request):
+ responses = [
+ resp
+ async for resp in gemini_llm.generate_content_async(
+ llm_request, stream=True
+ )
+ ]
+ text = ""
+ for i in range(len(responses) - 1):
+ assert responses[i].partial is True
+ assert responses[i].content.parts[0].text
+ text += responses[i].content.parts[0].text
+
+ # Last message should be accumulated text
+ assert responses[-1].content.parts[0].text == text
+ assert not responses[-1].partial
diff --git a/src/google/adk/tests/integration/test_callback.py b/src/google/adk/tests/integration/test_callback.py
new file mode 100644
index 0000000..4d9f1d3
--- /dev/null
+++ b/src/google/adk/tests/integration/test_callback.py
@@ -0,0 +1,70 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pytest import mark
+
+from ..unittests.utils import simplify_events
+from .fixture import callback_agent
+from .utils import assert_agent_says
+from .utils import TestRunner
+
+
+@mark.parametrize(
+ "agent_runner",
+ [{"agent": callback_agent.agent.before_agent_callback_agent}],
+ indirect=True,
+)
+def test_before_agent_call(agent_runner: TestRunner):
+ agent_runner.run("Hi.")
+
+ # Assert the response content
+ assert_agent_says(
+ "End invocation event before agent call.",
+ agent_name="before_agent_callback_agent",
+ agent_runner=agent_runner,
+ )
+
+
+@mark.parametrize(
+ "agent_runner",
+ [{"agent": callback_agent.agent.before_model_callback_agent}],
+ indirect=True,
+)
+def test_before_model_call(agent_runner: TestRunner):
+ agent_runner.run("Hi.")
+
+ # Assert the response content
+ assert_agent_says(
+ "End invocation event before model call.",
+ agent_name="before_model_callback_agent",
+ agent_runner=agent_runner,
+ )
+
+
+# TODO: re-enable vertex by removing below line after fixing.
+@mark.parametrize("llm_backend", ["GOOGLE_AI"], indirect=True)
+@mark.parametrize(
+ "agent_runner",
+ [{"agent": callback_agent.agent.after_model_callback_agent}],
+ indirect=True,
+)
+def test_after_model_call(agent_runner: TestRunner):
+ events = agent_runner.run("Hi.")
+
+ # Assert the response content
+ simplified_events = simplify_events(events)
+ assert simplified_events[0][0] == "after_model_callback_agent"
+ assert simplified_events[0][1].endswith(
+ "Update response event after model call."
+ )
diff --git a/src/google/adk/tests/integration/test_context_variable.py b/src/google/adk/tests/integration/test_context_variable.py
new file mode 100644
index 0000000..ba6af06
--- /dev/null
+++ b/src/google/adk/tests/integration/test_context_variable.py
@@ -0,0 +1,67 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import pytest
+
+# Skip until fixed.
+pytest.skip(allow_module_level=True)
+
+from .fixture import context_variable_agent
+from .utils import TestRunner
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": context_variable_agent.agent.state_variable_echo_agent}],
+ indirect=True,
+)
+def test_context_variable_missing(agent_runner: TestRunner):
+ with pytest.raises(KeyError) as e_info:
+ agent_runner.run("Hi echo my customer id.")
+ assert "customerId" in str(e_info.value)
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": context_variable_agent.agent.state_variable_update_agent}],
+ indirect=True,
+)
+def test_context_variable_update(agent_runner: TestRunner):
+ _call_function_and_assert(
+ agent_runner,
+ "update_fc",
+ ["RRRR", "3.141529", ["apple", "banana"], [1, 3.14, "hello"]],
+ "successfully",
+ )
+
+
+def _call_function_and_assert(
+ agent_runner: TestRunner, function_name: str, params, expected
+):
+ param_section = (
+ " with params"
+ f" {params if isinstance(params, str) else json.dumps(params)}"
+ if params is not None
+ else ""
+ )
+ agent_runner.run(
+ f"Call {function_name}{param_section} and show me the result"
+ )
+
+ model_response_event = agent_runner.get_events()[-1]
+ assert model_response_event.author == "context_variable_update_agent"
+ assert model_response_event.content.role == "model"
+ assert expected in model_response_event.content.parts[0].text.strip()
diff --git a/src/google/adk/tests/integration/test_evalute_agent_in_fixture.py b/src/google/adk/tests/integration/test_evalute_agent_in_fixture.py
new file mode 100644
index 0000000..234e71b
--- /dev/null
+++ b/src/google/adk/tests/integration/test_evalute_agent_in_fixture.py
@@ -0,0 +1,76 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluate all agents in fixture folder if evaluation test files exist."""
+
+import os
+
+from google.adk.evaluation import AgentEvaluator
+import pytest
+
+def agent_eval_artifacts_in_fixture():
+ """Get all agents from fixture folder."""
+ agent_eval_artifacts = []
+ fixture_dir = os.path.join(os.path.dirname(__file__), 'fixture')
+ for agent_name in os.listdir(fixture_dir):
+ agent_dir = os.path.join(fixture_dir, agent_name)
+ if not os.path.isdir(agent_dir):
+ continue
+ for filename in os.listdir(agent_dir):
+ # Evaluation test files end with test.json
+ if not filename.endswith('test.json'):
+ continue
+ initial_session_file = (
+ f'tests/integration/fixture/{agent_name}/initial.session.json'
+ )
+ agent_eval_artifacts.append((
+ f'tests.integration.fixture.{agent_name}',
+ f'tests/integration/fixture/{agent_name}/{filename}',
+ initial_session_file
+ if os.path.exists(initial_session_file)
+ else None,
+ ))
+
+ # This method gets invoked twice, sorting helps ensure that both the
+ # invocations have the same view.
+ agent_eval_artifacts = sorted(
+ agent_eval_artifacts, key=lambda item: f'{item[0]}|{item[1]}'
+ )
+ return agent_eval_artifacts
+
+
+@pytest.mark.parametrize(
+ 'agent_name, evalfile, initial_session_file',
+ agent_eval_artifacts_in_fixture(),
+ ids=[agent_name for agent_name, _, _ in agent_eval_artifacts_in_fixture()],
+)
+def test_evaluate_agents_long_running_4_runs_per_eval_item(
+ agent_name, evalfile, initial_session_file
+):
+ """Test agents evaluation in fixture folder.
+
+ After querying the fixture folder, we have 5 eval items. For each eval item
+ we use 4 runs.
+
+ A single eval item is a session that can have multiple queries in it.
+ """
+ AgentEvaluator.evaluate(
+ agent_module=agent_name,
+ eval_dataset_file_path_or_dir=evalfile,
+ initial_session_file=initial_session_file,
+ # Using a slightly higher value helps us manange the variances that may
+ # happen in each eval.
+ # This, of course, comes at a cost of incrased test run times.
+ num_runs=4,
+ )
diff --git a/src/google/adk/tests/integration/test_multi_agent.py b/src/google/adk/tests/integration/test_multi_agent.py
new file mode 100644
index 0000000..48b5897
--- /dev/null
+++ b/src/google/adk/tests/integration/test_multi_agent.py
@@ -0,0 +1,28 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.evaluation import AgentEvaluator
+
+
+def test_eval_agent():
+ AgentEvaluator.evaluate(
+ agent_module="tests.integration.fixture.trip_planner_agent",
+ eval_dataset_file_path_or_dir=(
+ "tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json"
+ ),
+ initial_session_file=(
+ "tests/integration/fixture/trip_planner_agent/initial.session.json"
+ ),
+ num_runs=4,
+ )
diff --git a/src/google/adk/tests/integration/test_multi_turn.py b/src/google/adk/tests/integration/test_multi_turn.py
new file mode 100644
index 0000000..0281082
--- /dev/null
+++ b/src/google/adk/tests/integration/test_multi_turn.py
@@ -0,0 +1,42 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.evaluation import AgentEvaluator
+
+
+def test_simple_multi_turn_conversation():
+ """Test a simple multi-turn conversation."""
+ AgentEvaluator.evaluate(
+ agent_module="tests.integration.fixture.home_automation_agent",
+ eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/simple_multi_turn_conversation.test.json",
+ num_runs=4,
+ )
+
+
+def test_dependent_tool_calls():
+ """Test subsequent tool calls that are dependent on previous tool calls."""
+ AgentEvaluator.evaluate(
+ agent_module="tests.integration.fixture.home_automation_agent",
+ eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/dependent_tool_calls.test.json",
+ num_runs=4,
+ )
+
+
+def test_memorizing_past_events():
+ """Test memorizing past events."""
+ AgentEvaluator.evaluate(
+ agent_module="tests.integration.fixture.home_automation_agent",
+ eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json",
+ num_runs=4,
+ )
diff --git a/src/google/adk/tests/integration/test_single_agent.py b/src/google/adk/tests/integration/test_single_agent.py
new file mode 100644
index 0000000..5c7c26d
--- /dev/null
+++ b/src/google/adk/tests/integration/test_single_agent.py
@@ -0,0 +1,23 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.evaluation import AgentEvaluator
+
+
+def test_eval_agent():
+ AgentEvaluator.evaluate(
+ agent_module="tests.integration.fixture.home_automation_agent",
+ eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/simple_test.test.json",
+ num_runs=4,
+ )
diff --git a/src/google/adk/tests/integration/test_sub_agent.py b/src/google/adk/tests/integration/test_sub_agent.py
new file mode 100644
index 0000000..27646bf
--- /dev/null
+++ b/src/google/adk/tests/integration/test_sub_agent.py
@@ -0,0 +1,26 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.evaluation import AgentEvaluator
+
+
+def test_eval_agent():
+ """Test hotel sub agent in a multi-agent system."""
+ AgentEvaluator.evaluate(
+ agent_module="tests.integration.fixture.trip_planner_agent",
+ eval_dataset_file_path_or_dir="tests/integration/fixture/trip_planner_agent/test_files/trip_inquiry_sub_agent.test.json",
+ initial_session_file="tests/integration/fixture/trip_planner_agent/test_files/initial.session.json",
+ agent_name="identify_agent",
+ num_runs=4,
+ )
diff --git a/src/google/adk/tests/integration/test_system_instruction.py b/src/google/adk/tests/integration/test_system_instruction.py
new file mode 100644
index 0000000..8ce1b09
--- /dev/null
+++ b/src/google/adk/tests/integration/test_system_instruction.py
@@ -0,0 +1,177 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+# Skip until fixed.
+pytest.skip(allow_module_level=True)
+
+from google.adk.agents import InvocationContext
+from google.adk.sessions import Session
+from google.genai import types
+
+from .fixture import context_variable_agent
+from .utils import TestRunner
+
+nl_planner_si = """
+You are an intelligent tool use agent built upon the Gemini large language model. When answering the question, try to leverage the available tools to gather the information instead of your memorized knowledge.
+
+Follow this process when answering the question: (1) first come up with a plan in natural language text format; (2) Then use tools to execute the plan and provide reasoning between tool code snippets to make a summary of current state and next step. Tool code snippets and reasoning should be interleaved with each other. (3) In the end, return one final answer.
+
+Follow this format when answering the question: (1) The planning part should be under /*PLANNING*/. (2) The tool code snippets should be under /*ACTION*/, and the reasoning parts should be under /*REASONING*/. (3) The final answer part should be under /*FINAL_ANSWER*/.
+
+
+Below are the requirements for the planning:
+The plan is made to answer the user query if following the plan. The plan is coherent and covers all aspects of information from user query, and only involves the tools that are accessible by the agent. The plan contains the decomposed steps as a numbered list where each step should use one or multiple available tools. By reading the plan, you can intuitively know which tools to trigger or what actions to take.
+If the initial plan cannot be successfully executed, you should learn from previous execution results and revise your plan. The revised plan should be be under /*REPLANNING*/. Then use tools to follow the new plan.
+
+Below are the requirements for the reasoning:
+The reasoning makes a summary of the current trajectory based on the user query and tool outputs. Based on the tool outputs and plan, the reasoning also comes up with instructions to the next steps, making the trajectory closer to the final answer.
+
+
+
+Below are the requirements for the final answer:
+The final answer should be precise and follow query formatting requirements. Some queries may not be answerable with the available tools and information. In those cases, inform the user why you cannot process their query and ask for more information.
+
+
+
+Below are the requirements for the tool code:
+
+**Custom Tools:** The available tools are described in the context and can be directly used.
+- Code must be valid self-contained Python snippets with no imports and no references to tools or Python libraries that are not in the context.
+- You cannot use any parameters or fields that are not explicitly defined in the APIs in the context.
+- Use "print" to output execution results for the next step or final answer that you need for responding to the user. Never generate ```tool_outputs yourself.
+- The code snippets should be readable, efficient, and directly relevant to the user query and reasoning steps.
+- When using the tools, you should use the library name together with the function name, e.g., vertex_search.search().
+- If Python libraries are not provided in the context, NEVER write your own code other than the function calls using the provided tools.
+
+
+
+VERY IMPORTANT instruction that you MUST follow in addition to the above instructions:
+
+You should ask for clarification if you need more information to answer the question.
+You should prefer using the information available in the context instead of repeated tool use.
+
+You should ONLY generate code snippets prefixed with "```tool_code" if you need to use the tools to answer the question.
+
+If you are asked to write code by user specifically,
+- you should ALWAYS use "```python" to format the code.
+- you should NEVER put "tool_code" to format the code.
+- Good example:
+```python
+print('hello')
+```
+- Bad example:
+```tool_code
+print('hello')
+```
+"""
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": context_variable_agent.agent.state_variable_echo_agent}],
+ indirect=True,
+)
+def test_context_variable(agent_runner: TestRunner):
+ session = Session(
+ context={
+ "customerId": "1234567890",
+ "customerInt": 30,
+ "customerFloat": 12.34,
+ "customerJson": {"name": "John Doe", "age": 30, "count": 11.1},
+ }
+ )
+ si = UnitFlow()._build_system_instruction(
+ InvocationContext(
+ invocation_id="1234567890", agent=agent_runner.agent, session=session
+ )
+ )
+
+ assert (
+ "Use the echo_info tool to echo 1234567890, 30, 12.34, and {'name': 'John"
+ " Doe', 'age': 30, 'count': 11.1}. Ask for it if you need to."
+ in si
+ )
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{
+ "agent": (
+ context_variable_agent.agent.state_variable_with_complicated_format_agent
+ )
+ }],
+ indirect=True,
+)
+def test_context_variable_with_complicated_format(agent_runner: TestRunner):
+ session = Session(
+ context={"customerId": "1234567890", "customer_int": 30},
+ artifacts={"fileName": [types.Part(text="test artifact")]},
+ )
+ si = _context_formatter.populate_context_and_artifact_variable_values(
+ agent_runner.agent.instruction,
+ session.get_state(),
+ session.get_artifact_dict(),
+ )
+
+ assert (
+ si
+ == "Use the echo_info tool to echo 1234567890, 30, { "
+ " non-identifier-float}}, test artifact, {'key1': 'value1'} and"
+ " {{'key2': 'value2'}}. Ask for it if you need to."
+ )
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{
+ "agent": (
+ context_variable_agent.agent.state_variable_with_nl_planner_agent
+ )
+ }],
+ indirect=True,
+)
+def test_nl_planner(agent_runner: TestRunner):
+ session = Session(context={"customerId": "1234567890"})
+ si = UnitFlow()._build_system_instruction(
+ InvocationContext(
+ invocation_id="1234567890",
+ agent=agent_runner.agent,
+ session=session,
+ )
+ )
+
+ for line in nl_planner_si.splitlines():
+ assert line in si
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{
+ "agent": (
+ context_variable_agent.agent.state_variable_with_function_instruction_agent
+ )
+ }],
+ indirect=True,
+)
+def test_function_instruction(agent_runner: TestRunner):
+ session = Session(context={"customerId": "1234567890"})
+ si = UnitFlow()._build_system_instruction(
+ InvocationContext(
+ invocation_id="1234567890", agent=agent_runner.agent, session=session
+ )
+ )
+
+ assert "This is the plain text sub agent instruction." in si
diff --git a/src/google/adk/tests/integration/test_tools.py b/src/google/adk/tests/integration/test_tools.py
new file mode 100644
index 0000000..dab1487
--- /dev/null
+++ b/src/google/adk/tests/integration/test_tools.py
@@ -0,0 +1,287 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import pytest
+
+# Skip until fixed.
+pytest.skip(allow_module_level=True)
+
+from .fixture import tool_agent
+from .utils import TestRunner
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": tool_agent.agent.single_function_agent}],
+ indirect=True,
+)
+def test_single_function_calls_success(agent_runner: TestRunner):
+ _call_function_and_assert(
+ agent_runner,
+ "simple_function",
+ "test",
+ "success",
+ )
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": tool_agent.agent.root_agent}],
+ indirect=True,
+)
+def test_multiple_function_calls_success(agent_runner: TestRunner):
+ _call_function_and_assert(
+ agent_runner,
+ "simple_function",
+ "test",
+ "success",
+ )
+ _call_function_and_assert(
+ agent_runner,
+ "no_param_function",
+ None,
+ "Called no param function successfully",
+ )
+ _call_function_and_assert(
+ agent_runner,
+ "no_output_function",
+ "test",
+ "",
+ )
+ _call_function_and_assert(
+ agent_runner,
+ "multiple_param_types_function",
+ ["test", 1, 2.34, True],
+ "success",
+ )
+ _call_function_and_assert(
+ agent_runner,
+ "return_list_str_function",
+ "test",
+ "success",
+ )
+ _call_function_and_assert(
+ agent_runner,
+ "list_str_param_function",
+ ["test", "test2", "test3", "test4"],
+ "success",
+ )
+
+
+@pytest.mark.skip(reason="Currently failing with 400 on MLDev.")
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": tool_agent.agent.root_agent}],
+ indirect=True,
+)
+def test_complex_function_calls_success(agent_runner: TestRunner):
+ param1 = {"name": "Test", "count": 3}
+ param2 = [
+ {"name": "Function", "count": 2},
+ {"name": "Retrieval", "count": 1},
+ ]
+ _call_function_and_assert(
+ agent_runner,
+ "complex_function_list_dict",
+ [param1, param2],
+ "test",
+ )
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": tool_agent.agent.root_agent}],
+ indirect=True,
+)
+def test_repetive_call_success(agent_runner: TestRunner):
+ _call_function_and_assert(
+ agent_runner,
+ "repetive_call_1",
+ "test",
+ "test_repetive",
+ )
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": tool_agent.agent.root_agent}],
+ indirect=True,
+)
+def test_function_calls_fail(agent_runner: TestRunner):
+ _call_function_and_assert(
+ agent_runner,
+ "throw_error_function",
+ "test",
+ None,
+ ValueError,
+ )
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": tool_agent.agent.root_agent}],
+ indirect=True,
+)
+def test_agent_tools_success(agent_runner: TestRunner):
+ _call_function_and_assert(
+ agent_runner,
+ "no_schema_agent",
+ "Hi",
+ "Hi",
+ )
+ _call_function_and_assert(
+ agent_runner,
+ "schema_agent",
+ "Agent_tools",
+ "Agent_tools_success",
+ )
+ _call_function_and_assert(
+ agent_runner, "no_input_schema_agent", "Tools", "Tools_success"
+ )
+ _call_function_and_assert(agent_runner, "no_output_schema_agent", "Hi", "Hi")
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": tool_agent.agent.root_agent}],
+ indirect=True,
+)
+def test_files_retrieval_success(agent_runner: TestRunner):
+ _call_function_and_assert(
+ agent_runner,
+ "test_case_retrieval",
+ "What is the testing strategy of agent 2.0?",
+ "test",
+ )
+ # For non relevant query, the agent should still be running fine, just return
+ # response might be different for different calls, so we don't compare the
+ # response here.
+ _call_function_and_assert(
+ agent_runner,
+ "test_case_retrieval",
+ "What is the whether in bay area?",
+ "",
+ )
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": tool_agent.agent.root_agent}],
+ indirect=True,
+)
+def test_rag_retrieval_success(agent_runner: TestRunner):
+ _call_function_and_assert(
+ agent_runner,
+ "valid_rag_retrieval",
+ "What is the testing strategy of agent 2.0?",
+ "test",
+ )
+ _call_function_and_assert(
+ agent_runner,
+ "valid_rag_retrieval",
+ "What is the whether in bay area?",
+ "No",
+ )
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": tool_agent.agent.root_agent}],
+ indirect=True,
+)
+def test_rag_retrieval_fail(agent_runner: TestRunner):
+ _call_function_and_assert(
+ agent_runner,
+ "invalid_rag_retrieval",
+ "What is the testing strategy of agent 2.0?",
+ None,
+ ValueError,
+ )
+ _call_function_and_assert(
+ agent_runner,
+ "non_exist_rag_retrieval",
+ "What is the whether in bay area?",
+ None,
+ ValueError,
+ )
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": tool_agent.agent.root_agent}],
+ indirect=True,
+)
+def test_langchain_tool_success(agent_runner: TestRunner):
+ _call_function_and_assert(
+ agent_runner,
+ "terminal",
+ "Run the following shell command 'echo test!'",
+ "test",
+ )
+
+
+@pytest.mark.parametrize(
+ "agent_runner",
+ [{"agent": tool_agent.agent.root_agent}],
+ indirect=True,
+)
+def test_crewai_tool_success(agent_runner: TestRunner):
+ _call_function_and_assert(
+ agent_runner,
+ "direcotry_read_tool",
+ "Find all the file paths",
+ "file",
+ )
+
+
+def _call_function_and_assert(
+ agent_runner: TestRunner,
+ function_name: str,
+ params,
+ expected=None,
+ exception: Exception = None,
+):
+ param_section = (
+ " with params"
+ f" {params if isinstance(params, str) else json.dumps(params)}"
+ if params is not None
+ else ""
+ )
+ query = f"Call {function_name}{param_section} and show me the result"
+ if exception:
+ _assert_raises(agent_runner, query, exception)
+ return
+
+ _assert_function_output(agent_runner, query, expected)
+
+
+def _assert_raises(agent_runner: TestRunner, query: str, exception: Exception):
+ with pytest.raises(exception):
+ agent_runner.run(query)
+
+
+def _assert_function_output(agent_runner: TestRunner, query: str, expected):
+ agent_runner.run(query)
+
+ # Retrieve the latest model response event
+ model_response_event = agent_runner.get_events()[-1]
+
+ # Assert the response content
+ assert model_response_event.content.role == "model"
+ assert (
+ expected.lower()
+ in model_response_event.content.parts[0].text.strip().lower()
+ )
diff --git a/src/google/adk/tests/integration/test_with_test_file.py b/src/google/adk/tests/integration/test_with_test_file.py
new file mode 100644
index 0000000..2024e1d
--- /dev/null
+++ b/src/google/adk/tests/integration/test_with_test_file.py
@@ -0,0 +1,34 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.evaluation import AgentEvaluator
+
+
+def test_with_single_test_file():
+ """Test the agent's basic ability via session file."""
+ AgentEvaluator.evaluate(
+ agent_module="tests.integration.fixture.home_automation_agent",
+ eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/simple_test.test.json",
+ )
+
+
+def test_with_folder_of_test_files_long_running():
+ """Test the agent's basic ability via a folder of session files."""
+ AgentEvaluator.evaluate(
+ agent_module="tests.integration.fixture.home_automation_agent",
+ eval_dataset_file_path_or_dir=(
+ "tests/integration/fixture/home_automation_agent/test_files"
+ ),
+ num_runs=4,
+ )
diff --git a/src/google/adk/tests/integration/tools/__init__.py b/src/google/adk/tests/integration/tools/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/integration/tools/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/integration/utils/__init__.py b/src/google/adk/tests/integration/utils/__init__.py
new file mode 100644
index 0000000..cde5caf
--- /dev/null
+++ b/src/google/adk/tests/integration/utils/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .asserts import *
+from .test_runner import TestRunner
diff --git a/src/google/adk/tests/integration/utils/asserts.py b/src/google/adk/tests/integration/utils/asserts.py
new file mode 100644
index 0000000..98d25b8
--- /dev/null
+++ b/src/google/adk/tests/integration/utils/asserts.py
@@ -0,0 +1,75 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TypedDict
+
+from .test_runner import TestRunner
+
+
+class Message(TypedDict):
+ agent_name: str
+ expected_text: str
+
+
+def assert_current_agent_is(agent_name: str, *, agent_runner: TestRunner):
+ assert agent_runner.get_current_agent_name() == agent_name
+
+
+def assert_agent_says(
+ expected_text: str, *, agent_name: str, agent_runner: TestRunner
+):
+ for event in reversed(agent_runner.get_events()):
+ if event.author == agent_name and event.content.parts[0].text:
+ assert event.content.parts[0].text.strip() == expected_text
+ return
+
+
+def assert_agent_says_in_order(
+ expected_conversaction: list[Message], agent_runner: TestRunner
+):
+ expected_conversaction_idx = len(expected_conversaction) - 1
+ for event in reversed(agent_runner.get_events()):
+ if event.content.parts and event.content.parts[0].text:
+ assert (
+ event.author
+ == expected_conversaction[expected_conversaction_idx]['agent_name']
+ )
+ assert (
+ event.content.parts[0].text.strip()
+ == expected_conversaction[expected_conversaction_idx]['expected_text']
+ )
+ expected_conversaction_idx -= 1
+ if expected_conversaction_idx < 0:
+ return
+
+
+def assert_agent_transfer_path(
+ expected_path: list[str], *, agent_runner: TestRunner
+):
+ events = agent_runner.get_events()
+ idx_in_expected_path = len(expected_path) - 1
+ # iterate events in reverse order
+ for event in reversed(events):
+ function_calls = event.get_function_calls()
+ if (
+ len(function_calls) == 1
+ and function_calls[0].name == 'transfer_to_agent'
+ ):
+ assert (
+ function_calls[0].args['agent_name']
+ == expected_path[idx_in_expected_path]
+ )
+ idx_in_expected_path -= 1
+ if idx_in_expected_path < 0:
+ return
diff --git a/src/google/adk/tests/integration/utils/test_runner.py b/src/google/adk/tests/integration/utils/test_runner.py
new file mode 100644
index 0000000..13e7566
--- /dev/null
+++ b/src/google/adk/tests/integration/utils/test_runner.py
@@ -0,0 +1,97 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+from typing import Optional
+
+from google.adk import Agent
+from google.adk import Runner
+from google.adk.artifacts import BaseArtifactService
+from google.adk.artifacts import InMemoryArtifactService
+from google.adk.events import Event
+from google.adk.sessions import BaseSessionService
+from google.adk.sessions import InMemorySessionService
+from google.adk.sessions import Session
+from google.genai import types
+
+
+class TestRunner:
+ """Agents runner for testings."""
+
+ app_name = "test_app"
+ user_id = "test_user"
+
+ def __init__(
+ self,
+ agent: Agent,
+ artifact_service: BaseArtifactService = InMemoryArtifactService(),
+ session_service: BaseSessionService = InMemorySessionService(),
+ ) -> None:
+ self.agent = agent
+ self.agent_client = Runner(
+ app_name=self.app_name,
+ agent=agent,
+ artifact_service=artifact_service,
+ session_service=session_service,
+ )
+ self.session_service = session_service
+ self.current_session_id = session_service.create_session(
+ app_name=self.app_name, user_id=self.user_id
+ ).id
+
+ def new_session(self, session_id: Optional[str] = None) -> None:
+ self.current_session_id = self.session_service.create_session(
+ app_name=self.app_name, user_id=self.user_id, session_id=session_id
+ ).id
+
+ def run(self, prompt: str) -> list[Event]:
+ current_session = self.session_service.get_session(
+ app_name=self.app_name,
+ user_id=self.user_id,
+ session_id=self.current_session_id,
+ )
+ assert current_session is not None
+
+ return list(
+ self.agent_client.run(
+ user_id=current_session.user_id,
+ session_id=current_session.id,
+ new_message=types.Content(
+ role="user",
+ parts=[types.Part.from_text(text=prompt)],
+ ),
+ )
+ )
+
+ def get_current_session(self) -> Optional[Session]:
+ return self.session_service.get_session(
+ app_name=self.app_name,
+ user_id=self.user_id,
+ session_id=self.current_session_id,
+ )
+
+ def get_events(self) -> list[Event]:
+ return self.get_current_session().events
+
+ @classmethod
+ def from_agent_name(cls, agent_name: str):
+ agent_module_path = f"tests.integration.fixture.{agent_name}"
+ agent_module = importlib.import_module(agent_module_path)
+ agent: Agent = agent_module.agent.root_agent
+ return cls(agent)
+
+ def get_current_agent_name(self) -> str:
+ return self.agent_client._find_agent_to_run(
+ self.get_current_session(), self.agent
+ ).name
diff --git a/src/google/adk/tests/unittests/__init__.py b/src/google/adk/tests/unittests/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/unittests/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/unittests/agents/__init__.py b/src/google/adk/tests/unittests/agents/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/unittests/agents/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/unittests/agents/test_base_agent.py b/src/google/adk/tests/unittests/agents/test_base_agent.py
new file mode 100644
index 0000000..d4e7387
--- /dev/null
+++ b/src/google/adk/tests/unittests/agents/test_base_agent.py
@@ -0,0 +1,407 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Testings for the BaseAgent."""
+
+from typing import AsyncGenerator
+from typing import Optional
+from typing import Union
+
+from google.adk.agents.base_agent import BaseAgent
+from google.adk.agents.callback_context import CallbackContext
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.events import Event
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
+from google.genai import types
+import pytest
+import pytest_mock
+from typing_extensions import override
+
+
+def _before_agent_callback_noop(callback_context: CallbackContext) -> None:
+ pass
+
+
+def _before_agent_callback_bypass_agent(
+ callback_context: CallbackContext,
+) -> types.Content:
+ return types.Content(parts=[types.Part(text='agent run is bypassed.')])
+
+
+def _after_agent_callback_noop(callback_context: CallbackContext) -> None:
+ pass
+
+
+def _after_agent_callback_append_agent_reply(
+ callback_context: CallbackContext,
+) -> types.Content:
+ return types.Content(
+ parts=[types.Part(text='Agent reply from after agent callback.')]
+ )
+
+
+class _IncompleteAgent(BaseAgent):
+ pass
+
+
+class _TestingAgent(BaseAgent):
+
+ @override
+ async def _run_async_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ yield Event(
+ author=self.name,
+ branch=ctx.branch,
+ invocation_id=ctx.invocation_id,
+ content=types.Content(parts=[types.Part(text='Hello, world!')]),
+ )
+
+ @override
+ async def _run_live_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ yield Event(
+ author=self.name,
+ invocation_id=ctx.invocation_id,
+ branch=ctx.branch,
+ content=types.Content(parts=[types.Part(text='Hello, live!')]),
+ )
+
+
+def _create_parent_invocation_context(
+ test_name: str, agent: BaseAgent, branch: Optional[str] = None
+) -> InvocationContext:
+ session_service = InMemorySessionService()
+ session = session_service.create_session(
+ app_name='test_app', user_id='test_user'
+ )
+ return InvocationContext(
+ invocation_id=f'{test_name}_invocation_id',
+ branch=branch,
+ agent=agent,
+ session=session,
+ session_service=session_service,
+ )
+
+
+def test_invalid_agent_name():
+ with pytest.raises(ValueError):
+ _ = _TestingAgent(name='not an identifier')
+
+
+@pytest.mark.asyncio
+async def test_run_async(request: pytest.FixtureRequest):
+ agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, agent
+ )
+
+ events = [e async for e in agent.run_async(parent_ctx)]
+
+ assert len(events) == 1
+ assert events[0].author == agent.name
+ assert events[0].content.parts[0].text == 'Hello, world!'
+
+
+@pytest.mark.asyncio
+async def test_run_async_with_branch(request: pytest.FixtureRequest):
+ agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, agent, branch='parent_branch'
+ )
+
+ events = [e async for e in agent.run_async(parent_ctx)]
+
+ assert len(events) == 1
+ assert events[0].author == agent.name
+ assert events[0].content.parts[0].text == 'Hello, world!'
+ assert events[0].branch.endswith(agent.name)
+
+
+@pytest.mark.asyncio
+async def test_run_async_before_agent_callback_noop(
+ request: pytest.FixtureRequest,
+ mocker: pytest_mock.MockerFixture,
+) -> Union[types.Content, None]:
+ # Arrange
+ agent = _TestingAgent(
+ name=f'{request.function.__name__}_test_agent',
+ before_agent_callback=_before_agent_callback_noop,
+ )
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, agent
+ )
+ spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
+ spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
+
+ # Act
+ _ = [e async for e in agent.run_async(parent_ctx)]
+
+ # Assert
+ spy_before_agent_callback.assert_called_once()
+ _, kwargs = spy_before_agent_callback.call_args
+ assert 'callback_context' in kwargs
+ assert isinstance(kwargs['callback_context'], CallbackContext)
+
+ spy_run_async_impl.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_run_async_before_agent_callback_bypass_agent(
+ request: pytest.FixtureRequest,
+ mocker: pytest_mock.MockerFixture,
+):
+ # Arrange
+ agent = _TestingAgent(
+ name=f'{request.function.__name__}_test_agent',
+ before_agent_callback=_before_agent_callback_bypass_agent,
+ )
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, agent
+ )
+ spy_run_async_impl = mocker.spy(agent, BaseAgent._run_async_impl.__name__)
+ spy_before_agent_callback = mocker.spy(agent, 'before_agent_callback')
+
+ # Act
+ events = [e async for e in agent.run_async(parent_ctx)]
+
+ # Assert
+ spy_before_agent_callback.assert_called_once()
+ spy_run_async_impl.assert_not_called()
+
+ assert len(events) == 1
+ assert events[0].content.parts[0].text == 'agent run is bypassed.'
+
+
+@pytest.mark.asyncio
+async def test_run_async_after_agent_callback_noop(
+ request: pytest.FixtureRequest,
+ mocker: pytest_mock.MockerFixture,
+):
+ # Arrange
+ agent = _TestingAgent(
+ name=f'{request.function.__name__}_test_agent',
+ after_agent_callback=_after_agent_callback_noop,
+ )
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, agent
+ )
+ spy_after_agent_callback = mocker.spy(agent, 'after_agent_callback')
+
+ # Act
+ events = [e async for e in agent.run_async(parent_ctx)]
+
+ # Assert
+ spy_after_agent_callback.assert_called_once()
+ _, kwargs = spy_after_agent_callback.call_args
+ assert 'callback_context' in kwargs
+ assert isinstance(kwargs['callback_context'], CallbackContext)
+ assert len(events) == 1
+
+
+@pytest.mark.asyncio
+async def test_run_async_after_agent_callback_append_reply(
+ request: pytest.FixtureRequest,
+):
+ # Arrange
+ agent = _TestingAgent(
+ name=f'{request.function.__name__}_test_agent',
+ after_agent_callback=_after_agent_callback_append_agent_reply,
+ )
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, agent
+ )
+
+ # Act
+ events = [e async for e in agent.run_async(parent_ctx)]
+
+ # Assert
+ assert len(events) == 2
+ assert events[1].author == agent.name
+ assert (
+ events[1].content.parts[0].text
+ == 'Agent reply from after agent callback.'
+ )
+
+
+@pytest.mark.asyncio
+async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
+ agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, agent
+ )
+
+ with pytest.raises(NotImplementedError):
+ [e async for e in agent.run_async(parent_ctx)]
+
+
+@pytest.mark.asyncio
+async def test_run_live(request: pytest.FixtureRequest):
+ agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, agent
+ )
+
+ events = [e async for e in agent.run_live(parent_ctx)]
+
+ assert len(events) == 1
+ assert events[0].author == agent.name
+ assert events[0].content.parts[0].text == 'Hello, live!'
+
+
+@pytest.mark.asyncio
+async def test_run_live_with_branch(request: pytest.FixtureRequest):
+ agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, agent, branch='parent_branch'
+ )
+
+ events = [e async for e in agent.run_live(parent_ctx)]
+
+ assert len(events) == 1
+ assert events[0].author == agent.name
+ assert events[0].content.parts[0].text == 'Hello, live!'
+ assert events[0].branch.endswith(agent.name)
+
+
+@pytest.mark.asyncio
+async def test_run_live_incomplete_agent(request: pytest.FixtureRequest):
+ agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, agent
+ )
+
+ with pytest.raises(NotImplementedError):
+ [e async for e in agent.run_live(parent_ctx)]
+
+
+def test_set_parent_agent_for_sub_agents(request: pytest.FixtureRequest):
+ sub_agents: list[BaseAgent] = [
+ _TestingAgent(name=f'{request.function.__name__}_sub_agent_1'),
+ _TestingAgent(name=f'{request.function.__name__}_sub_agent_2'),
+ ]
+ parent = _TestingAgent(
+ name=f'{request.function.__name__}_parent',
+ sub_agents=sub_agents,
+ )
+
+ for sub_agent in sub_agents:
+ assert sub_agent.parent_agent == parent
+
+
+def test_find_agent(request: pytest.FixtureRequest):
+ grand_sub_agent_1 = _TestingAgent(
+ name=f'{request.function.__name__}__grand_sub_agent_1'
+ )
+ grand_sub_agent_2 = _TestingAgent(
+ name=f'{request.function.__name__}__grand_sub_agent_2'
+ )
+ sub_agent_1 = _TestingAgent(
+ name=f'{request.function.__name__}_sub_agent_1',
+ sub_agents=[grand_sub_agent_1],
+ )
+ sub_agent_2 = _TestingAgent(
+ name=f'{request.function.__name__}_sub_agent_2',
+ sub_agents=[grand_sub_agent_2],
+ )
+ parent = _TestingAgent(
+ name=f'{request.function.__name__}_parent',
+ sub_agents=[sub_agent_1, sub_agent_2],
+ )
+
+ assert parent.find_agent(parent.name) == parent
+ assert parent.find_agent(sub_agent_1.name) == sub_agent_1
+ assert parent.find_agent(sub_agent_2.name) == sub_agent_2
+ assert parent.find_agent(grand_sub_agent_1.name) == grand_sub_agent_1
+ assert parent.find_agent(grand_sub_agent_2.name) == grand_sub_agent_2
+ assert sub_agent_1.find_agent(grand_sub_agent_1.name) == grand_sub_agent_1
+ assert sub_agent_1.find_agent(grand_sub_agent_2.name) is None
+ assert sub_agent_2.find_agent(grand_sub_agent_1.name) is None
+ assert sub_agent_2.find_agent(sub_agent_2.name) == sub_agent_2
+ assert parent.find_agent('not_exist') is None
+
+
+def test_find_sub_agent(request: pytest.FixtureRequest):
+ grand_sub_agent_1 = _TestingAgent(
+ name=f'{request.function.__name__}__grand_sub_agent_1'
+ )
+ grand_sub_agent_2 = _TestingAgent(
+ name=f'{request.function.__name__}__grand_sub_agent_2'
+ )
+ sub_agent_1 = _TestingAgent(
+ name=f'{request.function.__name__}_sub_agent_1',
+ sub_agents=[grand_sub_agent_1],
+ )
+ sub_agent_2 = _TestingAgent(
+ name=f'{request.function.__name__}_sub_agent_2',
+ sub_agents=[grand_sub_agent_2],
+ )
+ parent = _TestingAgent(
+ name=f'{request.function.__name__}_parent',
+ sub_agents=[sub_agent_1, sub_agent_2],
+ )
+
+ assert parent.find_sub_agent(sub_agent_1.name) == sub_agent_1
+ assert parent.find_sub_agent(sub_agent_2.name) == sub_agent_2
+ assert parent.find_sub_agent(grand_sub_agent_1.name) == grand_sub_agent_1
+ assert parent.find_sub_agent(grand_sub_agent_2.name) == grand_sub_agent_2
+ assert sub_agent_1.find_sub_agent(grand_sub_agent_1.name) == grand_sub_agent_1
+ assert sub_agent_1.find_sub_agent(grand_sub_agent_2.name) is None
+ assert sub_agent_2.find_sub_agent(grand_sub_agent_1.name) is None
+ assert sub_agent_2.find_sub_agent(grand_sub_agent_2.name) == grand_sub_agent_2
+ assert parent.find_sub_agent(parent.name) is None
+ assert parent.find_sub_agent('not_exist') is None
+
+
+def test_root_agent(request: pytest.FixtureRequest):
+ grand_sub_agent_1 = _TestingAgent(
+ name=f'{request.function.__name__}__grand_sub_agent_1'
+ )
+ grand_sub_agent_2 = _TestingAgent(
+ name=f'{request.function.__name__}__grand_sub_agent_2'
+ )
+ sub_agent_1 = _TestingAgent(
+ name=f'{request.function.__name__}_sub_agent_1',
+ sub_agents=[grand_sub_agent_1],
+ )
+ sub_agent_2 = _TestingAgent(
+ name=f'{request.function.__name__}_sub_agent_2',
+ sub_agents=[grand_sub_agent_2],
+ )
+ parent = _TestingAgent(
+ name=f'{request.function.__name__}_parent',
+ sub_agents=[sub_agent_1, sub_agent_2],
+ )
+
+ assert parent.root_agent == parent
+ assert sub_agent_1.root_agent == parent
+ assert sub_agent_2.root_agent == parent
+ assert grand_sub_agent_1.root_agent == parent
+ assert grand_sub_agent_2.root_agent == parent
+
+
+def test_set_parent_agent_for_sub_agent_twice(
+ request: pytest.FixtureRequest,
+):
+ sub_agent = _TestingAgent(name=f'{request.function.__name__}_sub_agent')
+ _ = _TestingAgent(
+ name=f'{request.function.__name__}_parent_1',
+ sub_agents=[sub_agent],
+ )
+ with pytest.raises(ValueError):
+ _ = _TestingAgent(
+ name=f'{request.function.__name__}_parent_2',
+ sub_agents=[sub_agent],
+ )
diff --git a/src/google/adk/tests/unittests/agents/test_langgraph_agent.py b/src/google/adk/tests/unittests/agents/test_langgraph_agent.py
new file mode 100644
index 0000000..aa6cb6a
--- /dev/null
+++ b/src/google/adk/tests/unittests/agents/test_langgraph_agent.py
@@ -0,0 +1,191 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock
+
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.agents.langgraph_agent import LangGraphAgent
+from google.adk.events import Event
+from google.genai import types
+from langchain_core.messages import AIMessage
+from langchain_core.messages import HumanMessage
+from langchain_core.messages import SystemMessage
+from langgraph.graph.graph import CompiledGraph
+import pytest
+
+
+@pytest.mark.parametrize(
+ "checkpointer_value, events_list, expected_messages",
+ [
+ (
+ MagicMock(),
+ [
+ Event(
+ invocation_id="test_invocation_id",
+ author="user",
+ content=types.Content(
+ role="user",
+ parts=[types.Part.from_text(text="test prompt")],
+ ),
+ ),
+ Event(
+ invocation_id="test_invocation_id",
+ author="root_agent",
+ content=types.Content(
+ role="model",
+ parts=[types.Part.from_text(text="(some delegation)")],
+ ),
+ ),
+ ],
+ [
+ SystemMessage(content="test system prompt"),
+ HumanMessage(content="test prompt"),
+ ],
+ ),
+ (
+ None,
+ [
+ Event(
+ invocation_id="test_invocation_id",
+ author="user",
+ content=types.Content(
+ role="user",
+ parts=[types.Part.from_text(text="user prompt 1")],
+ ),
+ ),
+ Event(
+ invocation_id="test_invocation_id",
+ author="root_agent",
+ content=types.Content(
+ role="model",
+ parts=[
+ types.Part.from_text(text="root agent response")
+ ],
+ ),
+ ),
+ Event(
+ invocation_id="test_invocation_id",
+ author="weather_agent",
+ content=types.Content(
+ role="model",
+ parts=[
+ types.Part.from_text(text="weather agent response")
+ ],
+ ),
+ ),
+ Event(
+ invocation_id="test_invocation_id",
+ author="user",
+ content=types.Content(
+ role="user",
+ parts=[types.Part.from_text(text="user prompt 2")],
+ ),
+ ),
+ ],
+ [
+ SystemMessage(content="test system prompt"),
+ HumanMessage(content="user prompt 1"),
+ AIMessage(content="weather agent response"),
+ HumanMessage(content="user prompt 2"),
+ ],
+ ),
+ (
+ MagicMock(),
+ [
+ Event(
+ invocation_id="test_invocation_id",
+ author="user",
+ content=types.Content(
+ role="user",
+ parts=[types.Part.from_text(text="user prompt 1")],
+ ),
+ ),
+ Event(
+ invocation_id="test_invocation_id",
+ author="root_agent",
+ content=types.Content(
+ role="model",
+ parts=[
+ types.Part.from_text(text="root agent response")
+ ],
+ ),
+ ),
+ Event(
+ invocation_id="test_invocation_id",
+ author="weather_agent",
+ content=types.Content(
+ role="model",
+ parts=[
+ types.Part.from_text(text="weather agent response")
+ ],
+ ),
+ ),
+ Event(
+ invocation_id="test_invocation_id",
+ author="user",
+ content=types.Content(
+ role="user",
+ parts=[types.Part.from_text(text="user prompt 2")],
+ ),
+ ),
+ ],
+ [
+ SystemMessage(content="test system prompt"),
+ HumanMessage(content="user prompt 2"),
+ ],
+ ),
+ ],
+)
+@pytest.mark.asyncio
+async def test_langgraph_agent(
+ checkpointer_value, events_list, expected_messages
+):
+ mock_graph = MagicMock(spec=CompiledGraph)
+ mock_graph_state = MagicMock()
+ mock_graph_state.values = {}
+ mock_graph.get_state.return_value = mock_graph_state
+
+ mock_graph.checkpointer = checkpointer_value
+ mock_graph.invoke.return_value = {
+ "messages": [AIMessage(content="test response")]
+ }
+
+ mock_parent_context = MagicMock(spec=InvocationContext)
+ mock_session = MagicMock()
+ mock_parent_context.session = mock_session
+ mock_parent_context.branch = "parent_agent"
+ mock_parent_context.end_invocation = False
+ mock_session.events = events_list
+ mock_parent_context.invocation_id = "test_invocation_id"
+ mock_parent_context.model_copy.return_value = mock_parent_context
+
+ weather_agent = LangGraphAgent(
+ name="weather_agent",
+ description="A agent that answers weather questions",
+ instruction="test system prompt",
+ graph=mock_graph,
+ )
+
+ result_event = None
+ async for event in weather_agent.run_async(mock_parent_context):
+ result_event = event
+
+ assert result_event.author == "weather_agent"
+ assert result_event.content.parts[0].text == "test response"
+
+ mock_graph.invoke.assert_called_once()
+ mock_graph.invoke.assert_called_with(
+ {"messages": expected_messages},
+ {"configurable": {"thread_id": mock_session.id}},
+ )
diff --git a/src/google/adk/tests/unittests/agents/test_llm_agent_callbacks.py b/src/google/adk/tests/unittests/agents/test_llm_agent_callbacks.py
new file mode 100644
index 0000000..377e1cf
--- /dev/null
+++ b/src/google/adk/tests/unittests/agents/test_llm_agent_callbacks.py
@@ -0,0 +1,138 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from typing import Optional
+
+from google.adk.agents.callback_context import CallbackContext
+from google.adk.agents.llm_agent import Agent
+from google.adk.models import LlmRequest
+from google.adk.models import LlmResponse
+from google.genai import types
+from pydantic import BaseModel
+import pytest
+
+from .. import utils
+
+
+class MockBeforeModelCallback(BaseModel):
+ mock_response: str
+
+ def __call__(
+ self,
+ callback_context: CallbackContext,
+ llm_request: LlmRequest,
+ ) -> LlmResponse:
+ return LlmResponse(
+ content=utils.ModelContent(
+ [types.Part.from_text(text=self.mock_response)]
+ )
+ )
+
+
+class MockAfterModelCallback(BaseModel):
+ mock_response: str
+
+ def __call__(
+ self,
+ callback_context: CallbackContext,
+ llm_response: LlmResponse,
+ ) -> LlmResponse:
+ return LlmResponse(
+ content=utils.ModelContent(
+ [types.Part.from_text(text=self.mock_response)]
+ )
+ )
+
+
+def noop_callback(**kwargs) -> Optional[LlmResponse]:
+ pass
+
+
+@pytest.mark.asyncio
+async def test_before_model_callback():
+ responses = ['model_response']
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ before_model_callback=MockBeforeModelCallback(
+ mock_response='before_model_callback'
+ ),
+ )
+
+ runner = utils.TestInMemoryRunner(agent)
+ assert utils.simplify_events(
+ await runner.run_async_with_new_session('test')
+ ) == [
+ ('root_agent', 'before_model_callback'),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_before_model_callback_noop():
+ responses = ['model_response']
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ before_model_callback=noop_callback,
+ )
+
+ runner = utils.TestInMemoryRunner(agent)
+ assert utils.simplify_events(
+ await runner.run_async_with_new_session('test')
+ ) == [
+ ('root_agent', 'model_response'),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_before_model_callback_end():
+ responses = ['model_response']
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ before_model_callback=MockBeforeModelCallback(
+ mock_response='before_model_callback',
+ ),
+ )
+
+ runner = utils.TestInMemoryRunner(agent)
+ assert utils.simplify_events(
+ await runner.run_async_with_new_session('test')
+ ) == [
+ ('root_agent', 'before_model_callback'),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_after_model_callback():
+ responses = ['model_response']
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ after_model_callback=MockAfterModelCallback(
+ mock_response='after_model_callback'
+ ),
+ )
+
+ runner = utils.TestInMemoryRunner(agent)
+ assert utils.simplify_events(
+ await runner.run_async_with_new_session('test')
+ ) == [
+ ('root_agent', 'after_model_callback'),
+ ]
diff --git a/src/google/adk/tests/unittests/agents/test_llm_agent_fields.py b/src/google/adk/tests/unittests/agents/test_llm_agent_fields.py
new file mode 100644
index 0000000..42ad5ca
--- /dev/null
+++ b/src/google/adk/tests/unittests/agents/test_llm_agent_fields.py
@@ -0,0 +1,231 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for canonical_xxx fields in LlmAgent."""
+
+from typing import Any
+from typing import Optional
+
+from google.adk.agents.callback_context import CallbackContext
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.agents.llm_agent import LlmAgent
+from google.adk.agents.loop_agent import LoopAgent
+from google.adk.agents.readonly_context import ReadonlyContext
+from google.adk.models.llm_request import LlmRequest
+from google.adk.models.registry import LLMRegistry
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
+from google.genai import types
+from pydantic import BaseModel
+import pytest
+
+
+def _create_readonly_context(
+ agent: LlmAgent, state: Optional[dict[str, Any]] = None
+) -> ReadonlyContext:
+ session_service = InMemorySessionService()
+ session = session_service.create_session(
+ app_name='test_app', user_id='test_user', state=state
+ )
+ invocation_context = InvocationContext(
+ invocation_id='test_id',
+ agent=agent,
+ session=session,
+ session_service=session_service,
+ )
+ return ReadonlyContext(invocation_context)
+
+
+def test_canonical_model_empty():
+ agent = LlmAgent(name='test_agent')
+
+ with pytest.raises(ValueError):
+ _ = agent.canonical_model
+
+
+def test_canonical_model_str():
+ agent = LlmAgent(name='test_agent', model='gemini-pro')
+
+ assert agent.canonical_model.model == 'gemini-pro'
+
+
+def test_canonical_model_llm():
+ llm = LLMRegistry.new_llm('gemini-pro')
+ agent = LlmAgent(name='test_agent', model=llm)
+
+ assert agent.canonical_model == llm
+
+
+def test_canonical_model_inherit():
+ sub_agent = LlmAgent(name='sub_agent')
+ parent_agent = LlmAgent(
+ name='parent_agent', model='gemini-pro', sub_agents=[sub_agent]
+ )
+
+ assert sub_agent.canonical_model == parent_agent.canonical_model
+
+
+def test_canonical_instruction_str():
+ agent = LlmAgent(name='test_agent', instruction='instruction')
+ ctx = _create_readonly_context(agent)
+
+ assert agent.canonical_instruction(ctx) == 'instruction'
+
+
+def test_canonical_instruction():
+ def _instruction_provider(ctx: ReadonlyContext) -> str:
+ return f'instruction: {ctx.state["state_var"]}'
+
+ agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
+ ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
+
+ assert agent.canonical_instruction(ctx) == 'instruction: state_value'
+
+
+def test_canonical_global_instruction_str():
+ agent = LlmAgent(name='test_agent', global_instruction='global instruction')
+ ctx = _create_readonly_context(agent)
+
+ assert agent.canonical_global_instruction(ctx) == 'global instruction'
+
+
+def test_canonical_global_instruction():
+ def _global_instruction_provider(ctx: ReadonlyContext) -> str:
+ return f'global instruction: {ctx.state["state_var"]}'
+
+ agent = LlmAgent(
+ name='test_agent', global_instruction=_global_instruction_provider
+ )
+ ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
+
+ assert (
+ agent.canonical_global_instruction(ctx)
+ == 'global instruction: state_value'
+ )
+
+
+def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):
+ with caplog.at_level('WARNING'):
+
+ class Schema(BaseModel):
+ pass
+
+ agent = LlmAgent(
+ name='test_agent',
+ output_schema=Schema,
+ )
+
+ # Transfer is automatically disabled
+ assert agent.disallow_transfer_to_parent
+ assert agent.disallow_transfer_to_peers
+ assert (
+ 'output_schema cannot co-exist with agent transfer configurations.'
+ in caplog.text
+ )
+
+
+def test_output_schema_with_sub_agents_will_throw():
+ class Schema(BaseModel):
+ pass
+
+ sub_agent = LlmAgent(
+ name='sub_agent',
+ )
+
+ with pytest.raises(ValueError):
+ _ = LlmAgent(
+ name='test_agent',
+ output_schema=Schema,
+ sub_agents=[sub_agent],
+ )
+
+
+def test_output_schema_with_tools_will_throw():
+ class Schema(BaseModel):
+ pass
+
+ def _a_tool():
+ pass
+
+ with pytest.raises(ValueError):
+ _ = LlmAgent(
+ name='test_agent',
+ output_schema=Schema,
+ tools=[_a_tool],
+ )
+
+
+def test_before_model_callback():
+ def _before_model_callback(
+ callback_context: CallbackContext,
+ llm_request: LlmRequest,
+ ) -> None:
+ return None
+
+ agent = LlmAgent(
+ name='test_agent', before_model_callback=_before_model_callback
+ )
+
+ # TODO: add more logic assertions later.
+ assert agent.before_model_callback is not None
+
+
+def test_validate_generate_content_config_thinking_config_throw():
+ with pytest.raises(ValueError):
+ _ = LlmAgent(
+ name='test_agent',
+ generate_content_config=types.GenerateContentConfig(
+ thinking_config=types.ThinkingConfig()
+ ),
+ )
+
+
+def test_validate_generate_content_config_tools_throw():
+ with pytest.raises(ValueError):
+ _ = LlmAgent(
+ name='test_agent',
+ generate_content_config=types.GenerateContentConfig(
+ tools=[types.Tool(function_declarations=[])]
+ ),
+ )
+
+
+def test_validate_generate_content_config_system_instruction_throw():
+ with pytest.raises(ValueError):
+ _ = LlmAgent(
+ name='test_agent',
+ generate_content_config=types.GenerateContentConfig(
+ system_instruction='system instruction'
+ ),
+ )
+
+
+def test_validate_generate_content_config_response_schema_throw():
+ class Schema(BaseModel):
+ pass
+
+ with pytest.raises(ValueError):
+ _ = LlmAgent(
+ name='test_agent',
+ generate_content_config=types.GenerateContentConfig(
+ response_schema=Schema
+ ),
+ )
+
+
+def test_allow_transfer_by_default():
+ sub_agent = LlmAgent(name='sub_agent')
+ agent = LlmAgent(name='test_agent', sub_agents=[sub_agent])
+
+ assert not agent.disallow_transfer_to_parent
+ assert not agent.disallow_transfer_to_peers
diff --git a/src/google/adk/tests/unittests/agents/test_loop_agent.py b/src/google/adk/tests/unittests/agents/test_loop_agent.py
new file mode 100644
index 0000000..deafaf2
--- /dev/null
+++ b/src/google/adk/tests/unittests/agents/test_loop_agent.py
@@ -0,0 +1,136 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Testings for the SequentialAgent."""
+
+from typing import AsyncGenerator
+
+from google.adk.agents.base_agent import BaseAgent
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.agents.loop_agent import LoopAgent
+from google.adk.events import Event
+from google.adk.events import EventActions
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
+from google.genai import types
+import pytest
+from typing_extensions import override
+
+
+class _TestingAgent(BaseAgent):
+
+ @override
+ async def _run_async_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ yield Event(
+ author=self.name,
+ invocation_id=ctx.invocation_id,
+ content=types.Content(
+ parts=[types.Part(text=f'Hello, async {self.name}!')]
+ ),
+ )
+
+ @override
+ async def _run_live_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ yield Event(
+ author=self.name,
+ invocation_id=ctx.invocation_id,
+ content=types.Content(
+ parts=[types.Part(text=f'Hello, live {self.name}!')]
+ ),
+ )
+
+
+class _TestingAgentWithEscalateAction(BaseAgent):
+
+ @override
+ async def _run_async_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ yield Event(
+ author=self.name,
+ invocation_id=ctx.invocation_id,
+ content=types.Content(
+ parts=[types.Part(text=f'Hello, async {self.name}!')]
+ ),
+ actions=EventActions(escalate=True),
+ )
+
+
+def _create_parent_invocation_context(
+ test_name: str, agent: BaseAgent
+) -> InvocationContext:
+ session_service = InMemorySessionService()
+ session = session_service.create_session(
+ app_name='test_app', user_id='test_user'
+ )
+ return InvocationContext(
+ invocation_id=f'{test_name}_invocation_id',
+ agent=agent,
+ session=session,
+ session_service=session_service,
+ )
+
+
+@pytest.mark.asyncio
+async def test_run_async(request: pytest.FixtureRequest):
+ agent = _TestingAgent(name=f'{request.function.__name__}_test_agent')
+ loop_agent = LoopAgent(
+ name=f'{request.function.__name__}_test_loop_agent',
+ max_iterations=2,
+ sub_agents=[
+ agent,
+ ],
+ )
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, loop_agent
+ )
+ events = [e async for e in loop_agent.run_async(parent_ctx)]
+
+ assert len(events) == 2
+ assert events[0].author == agent.name
+ assert events[1].author == agent.name
+ assert events[0].content.parts[0].text == f'Hello, async {agent.name}!'
+ assert events[1].content.parts[0].text == f'Hello, async {agent.name}!'
+
+
+@pytest.mark.asyncio
+async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
+ non_escalating_agent = _TestingAgent(
+ name=f'{request.function.__name__}_test_non_escalating_agent'
+ )
+ escalating_agent = _TestingAgentWithEscalateAction(
+ name=f'{request.function.__name__}_test_escalating_agent'
+ )
+ loop_agent = LoopAgent(
+ name=f'{request.function.__name__}_test_loop_agent',
+ sub_agents=[non_escalating_agent, escalating_agent],
+ )
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, loop_agent
+ )
+ events = [e async for e in loop_agent.run_async(parent_ctx)]
+
+ # Only two events are generated because the sub escalating_agent escalates.
+ assert len(events) == 2
+ assert events[0].author == non_escalating_agent.name
+ assert events[1].author == escalating_agent.name
+ assert events[0].content.parts[0].text == (
+ f'Hello, async {non_escalating_agent.name}!'
+ )
+ assert events[1].content.parts[0].text == (
+ f'Hello, async {escalating_agent.name}!'
+ )
diff --git a/src/google/adk/tests/unittests/agents/test_parallel_agent.py b/src/google/adk/tests/unittests/agents/test_parallel_agent.py
new file mode 100644
index 0000000..4d4ff1c
--- /dev/null
+++ b/src/google/adk/tests/unittests/agents/test_parallel_agent.py
@@ -0,0 +1,92 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the ParallelAgent."""
+
+import asyncio
+from typing import AsyncGenerator
+
+from google.adk.agents.base_agent import BaseAgent
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.agents.parallel_agent import ParallelAgent
+from google.adk.events import Event
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
+from google.genai import types
+import pytest
+from typing_extensions import override
+
+
+class _TestingAgent(BaseAgent):
+
+ delay: float = 0
+ """The delay before the agent generates an event."""
+
+ @override
+ async def _run_async_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ await asyncio.sleep(self.delay)
+ yield Event(
+ author=self.name,
+ branch=ctx.branch,
+ invocation_id=ctx.invocation_id,
+ content=types.Content(
+ parts=[types.Part(text=f'Hello, async {self.name}!')]
+ ),
+ )
+
+
+def _create_parent_invocation_context(
+ test_name: str, agent: BaseAgent
+) -> InvocationContext:
+ session_service = InMemorySessionService()
+ session = session_service.create_session(
+ app_name='test_app', user_id='test_user'
+ )
+ return InvocationContext(
+ invocation_id=f'{test_name}_invocation_id',
+ agent=agent,
+ session=session,
+ session_service=session_service,
+ )
+
+
+@pytest.mark.asyncio
+async def test_run_async(request: pytest.FixtureRequest):
+ agent1 = _TestingAgent(
+ name=f'{request.function.__name__}_test_agent_1',
+ delay=0.5,
+ )
+ agent2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
+ parallel_agent = ParallelAgent(
+ name=f'{request.function.__name__}_test_parallel_agent',
+ sub_agents=[
+ agent1,
+ agent2,
+ ],
+ )
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, parallel_agent
+ )
+ events = [e async for e in parallel_agent.run_async(parent_ctx)]
+
+ assert len(events) == 2
+ # agent2 generates an event first, then agent1. Because they run in parallel
+ # and agent1 has a delay.
+ assert events[0].author == agent2.name
+ assert events[1].author == agent1.name
+ assert events[0].branch.endswith(agent2.name)
+ assert events[1].branch.endswith(agent1.name)
+ assert events[0].content.parts[0].text == f'Hello, async {agent2.name}!'
+ assert events[1].content.parts[0].text == f'Hello, async {agent1.name}!'
diff --git a/src/google/adk/tests/unittests/agents/test_sequential_agent.py b/src/google/adk/tests/unittests/agents/test_sequential_agent.py
new file mode 100644
index 0000000..f964737
--- /dev/null
+++ b/src/google/adk/tests/unittests/agents/test_sequential_agent.py
@@ -0,0 +1,114 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Testings for the SequentialAgent."""
+
+from typing import AsyncGenerator
+
+from google.adk.agents.base_agent import BaseAgent
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.agents.sequential_agent import SequentialAgent
+from google.adk.events import Event
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
+from google.genai import types
+import pytest
+from typing_extensions import override
+
+
+class _TestingAgent(BaseAgent):
+
+ @override
+ async def _run_async_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ yield Event(
+ author=self.name,
+ invocation_id=ctx.invocation_id,
+ content=types.Content(
+ parts=[types.Part(text=f'Hello, async {self.name}!')]
+ ),
+ )
+
+ @override
+ async def _run_live_impl(
+ self, ctx: InvocationContext
+ ) -> AsyncGenerator[Event, None]:
+ yield Event(
+ author=self.name,
+ invocation_id=ctx.invocation_id,
+ content=types.Content(
+ parts=[types.Part(text=f'Hello, live {self.name}!')]
+ ),
+ )
+
+
+def _create_parent_invocation_context(
+ test_name: str, agent: BaseAgent
+) -> InvocationContext:
+ session_service = InMemorySessionService()
+ session = session_service.create_session(
+ app_name='test_app', user_id='test_user'
+ )
+ return InvocationContext(
+ invocation_id=f'{test_name}_invocation_id',
+ agent=agent,
+ session=session,
+ session_service=session_service,
+ )
+
+
+@pytest.mark.asyncio
+async def test_run_async(request: pytest.FixtureRequest):
+ agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')
+ agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
+ sequential_agent = SequentialAgent(
+ name=f'{request.function.__name__}_test_agent',
+ sub_agents=[
+ agent_1,
+ agent_2,
+ ],
+ )
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, sequential_agent
+ )
+ events = [e async for e in sequential_agent.run_async(parent_ctx)]
+
+ assert len(events) == 2
+ assert events[0].author == agent_1.name
+ assert events[1].author == agent_2.name
+ assert events[0].content.parts[0].text == f'Hello, async {agent_1.name}!'
+ assert events[1].content.parts[0].text == f'Hello, async {agent_2.name}!'
+
+
+@pytest.mark.asyncio
+async def test_run_live(request: pytest.FixtureRequest):
+ agent_1 = _TestingAgent(name=f'{request.function.__name__}_test_agent_1')
+ agent_2 = _TestingAgent(name=f'{request.function.__name__}_test_agent_2')
+ sequential_agent = SequentialAgent(
+ name=f'{request.function.__name__}_test_agent',
+ sub_agents=[
+ agent_1,
+ agent_2,
+ ],
+ )
+ parent_ctx = _create_parent_invocation_context(
+ request.function.__name__, sequential_agent
+ )
+ events = [e async for e in sequential_agent.run_live(parent_ctx)]
+
+ assert len(events) == 2
+ assert events[0].author == agent_1.name
+ assert events[1].author == agent_2.name
+ assert events[0].content.parts[0].text == f'Hello, live {agent_1.name}!'
+ assert events[1].content.parts[0].text == f'Hello, live {agent_2.name}!'
diff --git a/src/google/adk/tests/unittests/artifacts/__init__.py b/src/google/adk/tests/unittests/artifacts/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/unittests/artifacts/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/unittests/artifacts/test_artifact_service.py b/src/google/adk/tests/unittests/artifacts/test_artifact_service.py
new file mode 100644
index 0000000..e8ce497
--- /dev/null
+++ b/src/google/adk/tests/unittests/artifacts/test_artifact_service.py
@@ -0,0 +1,276 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the artifact service."""
+
+import enum
+from typing import Optional
+from typing import Union
+
+from google.adk.artifacts import GcsArtifactService
+from google.adk.artifacts import InMemoryArtifactService
+from google.genai import types
+import pytest
+
+Enum = enum.Enum
+
+
+class ArtifactServiceType(Enum):
+ IN_MEMORY = "IN_MEMORY"
+ GCS = "GCS"
+
+
+class MockBlob:
+ """Mocks a GCS Blob object.
+
+ This class provides mock implementations for a few common GCS Blob methods,
+ allowing the user to test code that interacts with GCS without actually
+ connecting to a real bucket.
+ """
+
+ def __init__(self, name: str) -> None:
+ """Initializes a MockBlob.
+
+ Args:
+ name: The name of the blob.
+ """
+ self.name = name
+ self.content: Optional[bytes] = None
+ self.content_type: Optional[str] = None
+
+ def upload_from_string(
+ self, data: Union[str, bytes], content_type: Optional[str] = None
+ ) -> None:
+ """Mocks uploading data to the blob (from a string or bytes).
+
+ Args:
+ data: The data to upload (string or bytes).
+ content_type: The content type of the data (optional).
+ """
+ if isinstance(data, str):
+ self.content = data.encode("utf-8")
+ elif isinstance(data, bytes):
+ self.content = data
+ else:
+ raise TypeError("data must be str or bytes")
+
+ if content_type:
+ self.content_type = content_type
+
+ def download_as_bytes(self) -> bytes:
+ """Mocks downloading the blob's content as bytes.
+
+ Returns:
+ bytes: The content of the blob as bytes.
+
+ Raises:
+ Exception: If the blob doesn't exist (hasn't been uploaded to).
+ """
+ if self.content is None:
+ return b""
+ return self.content
+
+ def delete(self) -> None:
+ """Mocks deleting a blob."""
+ self.content = None
+ self.content_type = None
+
+
+class MockBucket:
+ """Mocks a GCS Bucket object."""
+
+ def __init__(self, name: str) -> None:
+ """Initializes a MockBucket.
+
+ Args:
+ name: The name of the bucket.
+ """
+ self.name = name
+ self.blobs: dict[str, MockBlob] = {}
+
+ def blob(self, blob_name: str) -> MockBlob:
+ """Mocks getting a Blob object (doesn't create it in storage).
+
+ Args:
+ blob_name: The name of the blob.
+
+ Returns:
+ A MockBlob instance.
+ """
+ if blob_name not in self.blobs:
+ self.blobs[blob_name] = MockBlob(blob_name)
+ return self.blobs[blob_name]
+
+
+class MockClient:
+ """Mocks the GCS Client."""
+
+ def __init__(self) -> None:
+ """Initializes MockClient."""
+ self.buckets: dict[str, MockBucket] = {}
+
+ def bucket(self, bucket_name: str) -> MockBucket:
+ """Mocks getting a Bucket object."""
+ if bucket_name not in self.buckets:
+ self.buckets[bucket_name] = MockBucket(bucket_name)
+ return self.buckets[bucket_name]
+
+ def list_blobs(self, bucket: MockBucket, prefix: Optional[str] = None):
+ """Mocks listing blobs in a bucket, optionally with a prefix."""
+ if prefix:
+ return [
+ blob for name, blob in bucket.blobs.items() if name.startswith(prefix)
+ ]
+ return list(bucket.blobs.values())
+
+
+def mock_gcs_artifact_service():
+ """Creates a mock GCS artifact service for testing."""
+ service = GcsArtifactService(bucket_name="test_bucket")
+ service.storage_client = MockClient()
+ service.bucket = service.storage_client.bucket("test_bucket")
+ return service
+
+
+def get_artifact_service(
+ service_type: ArtifactServiceType = ArtifactServiceType.IN_MEMORY,
+):
+ """Creates an artifact service for testing."""
+ if service_type == ArtifactServiceType.GCS:
+ return mock_gcs_artifact_service()
+ return InMemoryArtifactService()
+
+
+@pytest.mark.parametrize(
+ "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
+)
+def test_load_empty(service_type):
+ """Tests loading an artifact when none exists."""
+ artifact_service = get_artifact_service(service_type)
+ assert not artifact_service.load_artifact(
+ app_name="test_app",
+ user_id="test_user",
+ session_id="session_id",
+ filename="filename",
+ )
+
+
+@pytest.mark.parametrize(
+ "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
+)
+def test_save_load_delete(service_type):
+ """Tests saving, loading, and deleting an artifact."""
+ artifact_service = get_artifact_service(service_type)
+ artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
+ app_name = "app0"
+ user_id = "user0"
+ session_id = "123"
+ filename = "file456"
+
+ artifact_service.save_artifact(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=filename,
+ artifact=artifact,
+ )
+ assert (
+ artifact_service.load_artifact(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=filename,
+ )
+ == artifact
+ )
+
+ artifact_service.delete_artifact(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=filename,
+ )
+ assert not artifact_service.load_artifact(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=filename,
+ )
+
+
+@pytest.mark.parametrize(
+ "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
+)
+def test_list_keys(service_type):
+ """Tests listing keys in the artifact service."""
+ artifact_service = get_artifact_service(service_type)
+ artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
+ app_name = "app0"
+ user_id = "user0"
+ session_id = "123"
+ filename = "filename"
+ filenames = [filename + str(i) for i in range(5)]
+
+ for f in filenames:
+ artifact_service.save_artifact(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=f,
+ artifact=artifact,
+ )
+
+ assert (
+ artifact_service.list_artifact_keys(
+ app_name=app_name, user_id=user_id, session_id=session_id
+ )
+ == filenames
+ )
+
+
+@pytest.mark.parametrize(
+ "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
+)
+def test_list_versions(service_type):
+ """Tests listing versions of an artifact."""
+ artifact_service = get_artifact_service(service_type)
+
+ app_name = "app0"
+ user_id = "user0"
+ session_id = "123"
+ filename = "filename"
+ versions = [
+ types.Part.from_bytes(
+ data=i.to_bytes(2, byteorder="big"), mime_type="text/plain"
+ )
+ for i in range(3)
+ ]
+
+ for i in range(3):
+ artifact_service.save_artifact(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=filename,
+ artifact=versions[i],
+ )
+
+ response_versions = artifact_service.list_versions(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=filename,
+ )
+
+ assert response_versions == list(range(3))
diff --git a/src/google/adk/tests/unittests/auth/test_auth_handler.py b/src/google/adk/tests/unittests/auth/test_auth_handler.py
new file mode 100644
index 0000000..1481749
--- /dev/null
+++ b/src/google/adk/tests/unittests/auth/test_auth_handler.py
@@ -0,0 +1,575 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from unittest.mock import patch
+
+from fastapi.openapi.models import APIKey
+from fastapi.openapi.models import APIKeyIn
+from fastapi.openapi.models import OAuth2
+from fastapi.openapi.models import OAuthFlowAuthorizationCode
+from fastapi.openapi.models import OAuthFlows
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.auth.auth_credential import AuthCredentialTypes
+from google.adk.auth.auth_credential import OAuth2Auth
+from google.adk.auth.auth_handler import AuthHandler
+from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
+from google.adk.auth.auth_tool import AuthConfig
+import pytest
+
+
+# Mock classes for testing
+class MockState(dict):
+ """Mock State class for testing."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def get(self, key, default=None):
+ return super().get(key, default)
+
+
+class MockOAuth2Session:
+ """Mock OAuth2Session for testing."""
+
+ def __init__(
+ self,
+ client_id=None,
+ client_secret=None,
+ scope=None,
+ redirect_uri=None,
+ state=None,
+ ):
+ self.client_id = client_id
+ self.client_secret = client_secret
+ self.scope = scope
+ self.redirect_uri = redirect_uri
+ self.state = state
+
+ def create_authorization_url(self, url):
+ return f"{url}?client_id={self.client_id}&scope={self.scope}", "mock_state"
+
+ def fetch_token(
+ self,
+ token_endpoint,
+ authorization_response=None,
+ code=None,
+ grant_type=None,
+ ):
+ return {
+ "access_token": "mock_access_token",
+ "token_type": "bearer",
+ "expires_in": 3600,
+ "refresh_token": "mock_refresh_token",
+ }
+
+
+# Fixtures for common test objects
+@pytest.fixture
+def oauth2_auth_scheme():
+ """Create an OAuth2 auth scheme for testing."""
+ # Create the OAuthFlows object first
+ flows = OAuthFlows(
+ authorizationCode=OAuthFlowAuthorizationCode(
+ authorizationUrl="https://example.com/oauth2/authorize",
+ tokenUrl="https://example.com/oauth2/token",
+ scopes={"read": "Read access", "write": "Write access"},
+ )
+ )
+
+ # Then create the OAuth2 object with the flows
+ return OAuth2(flows=flows)
+
+
+@pytest.fixture
+def openid_auth_scheme():
+ """Create an OpenID Connect auth scheme for testing."""
+ return OpenIdConnectWithConfig(
+ openIdConnectUrl="https://example.com/.well-known/openid-configuration",
+ authorization_endpoint="https://example.com/oauth2/authorize",
+ token_endpoint="https://example.com/oauth2/token",
+ scopes=["openid", "profile", "email"],
+ )
+
+
+@pytest.fixture
+def oauth2_credentials():
+ """Create OAuth2 credentials for testing."""
+ return AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id="mock_client_id",
+ client_secret="mock_client_secret",
+ redirect_uri="https://example.com/callback",
+ ),
+ )
+
+
+@pytest.fixture
+def oauth2_credentials_with_token():
+ """Create OAuth2 credentials with a token for testing."""
+ return AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id="mock_client_id",
+ client_secret="mock_client_secret",
+ redirect_uri="https://example.com/callback",
+ token={
+ "access_token": "mock_access_token",
+ "token_type": "bearer",
+ "expires_in": 3600,
+ "refresh_token": "mock_refresh_token",
+ },
+ ),
+ )
+
+
+@pytest.fixture
+def oauth2_credentials_with_auth_uri():
+ """Create OAuth2 credentials with an auth URI for testing."""
+ return AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id="mock_client_id",
+ client_secret="mock_client_secret",
+ redirect_uri="https://example.com/callback",
+ auth_uri="https://example.com/oauth2/authorize?client_id=mock_client_id&scope=read,write",
+ state="mock_state",
+ ),
+ )
+
+
+@pytest.fixture
+def oauth2_credentials_with_auth_code():
+ """Create OAuth2 credentials with an auth code for testing."""
+ return AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id="mock_client_id",
+ client_secret="mock_client_secret",
+ redirect_uri="https://example.com/callback",
+ auth_uri="https://example.com/oauth2/authorize?client_id=mock_client_id&scope=read,write",
+ state="mock_state",
+ auth_code="mock_auth_code",
+ auth_response_uri="https://example.com/callback?code=mock_auth_code&state=mock_state",
+ ),
+ )
+
+
+@pytest.fixture
+def auth_config(oauth2_auth_scheme, oauth2_credentials):
+ """Create an AuthConfig for testing."""
+ # Create a copy of the credentials for the exchanged_auth_credential
+ exchanged_credential = oauth2_credentials.model_copy(deep=True)
+
+ return AuthConfig(
+ auth_scheme=oauth2_auth_scheme,
+ raw_auth_credential=oauth2_credentials,
+ exchanged_auth_credential=exchanged_credential,
+ )
+
+
+@pytest.fixture
+def auth_config_with_exchanged(
+ oauth2_auth_scheme, oauth2_credentials, oauth2_credentials_with_auth_uri
+):
+ """Create an AuthConfig with exchanged credentials for testing."""
+ return AuthConfig(
+ auth_scheme=oauth2_auth_scheme,
+ raw_auth_credential=oauth2_credentials,
+ exchanged_auth_credential=oauth2_credentials_with_auth_uri,
+ )
+
+
+@pytest.fixture
+def auth_config_with_auth_code(
+ oauth2_auth_scheme, oauth2_credentials, oauth2_credentials_with_auth_code
+):
+ """Create an AuthConfig with auth code for testing."""
+ return AuthConfig(
+ auth_scheme=oauth2_auth_scheme,
+ raw_auth_credential=oauth2_credentials,
+ exchanged_auth_credential=oauth2_credentials_with_auth_code,
+ )
+
+
+class TestAuthHandlerInit:
+ """Tests for the AuthHandler initialization."""
+
+ def test_init(self, auth_config):
+ """Test the initialization of AuthHandler."""
+ handler = AuthHandler(auth_config)
+ assert handler.auth_config == auth_config
+
+
+class TestGetCredentialKey:
+ """Tests for the get_credential_key method."""
+
+ def test_get_credential_key(self, auth_config):
+ """Test generating a unique credential key."""
+ handler = AuthHandler(auth_config)
+ key = handler.get_credential_key()
+ assert key.startswith("temp:adk_oauth2_")
+ assert "_oauth2_" in key
+
+ def test_get_credential_key_with_extras(self, auth_config):
+ """Test generating a key when model_extra exists."""
+ # Add model_extra to test cleanup
+
+ original_key = AuthHandler(auth_config).get_credential_key()
+ key = AuthHandler(auth_config).get_credential_key()
+
+ auth_config.auth_scheme.model_extra["extra_field"] = "value"
+ auth_config.raw_auth_credential.model_extra["extra_field"] = "value"
+
+ assert original_key == key
+ assert "extra_field" in auth_config.auth_scheme.model_extra
+ assert "extra_field" in auth_config.raw_auth_credential.model_extra
+
+
+class TestGenerateAuthUri:
+ """Tests for the generate_auth_uri method."""
+
+ @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
+ def test_generate_auth_uri_oauth2(self, auth_config):
+ """Test generating an auth URI for OAuth2."""
+ handler = AuthHandler(auth_config)
+ result = handler.generate_auth_uri()
+
+ assert result.oauth2.auth_uri.startswith(
+ "https://example.com/oauth2/authorize"
+ )
+ assert "client_id=mock_client_id" in result.oauth2.auth_uri
+ assert result.oauth2.state == "mock_state"
+
+ @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
+ def test_generate_auth_uri_openid(
+ self, openid_auth_scheme, oauth2_credentials
+ ):
+ """Test generating an auth URI for OpenID Connect."""
+ # Create a copy for the exchanged credential
+ exchanged = oauth2_credentials.model_copy(deep=True)
+
+ config = AuthConfig(
+ auth_scheme=openid_auth_scheme,
+ raw_auth_credential=oauth2_credentials,
+ exchanged_auth_credential=exchanged,
+ )
+ handler = AuthHandler(config)
+ result = handler.generate_auth_uri()
+
+ assert result.oauth2.auth_uri.startswith(
+ "https://example.com/oauth2/authorize"
+ )
+ assert "client_id=mock_client_id" in result.oauth2.auth_uri
+ assert result.oauth2.state == "mock_state"
+
+
+class TestGenerateAuthRequest:
+ """Tests for the generate_auth_request method."""
+
+ def test_non_oauth_scheme(self):
+ """Test with a non-OAuth auth scheme."""
+ # Use a SecurityBase instance without using APIKey which has validation issues
+ api_key_scheme = APIKey(**{"name": "test_api_key", "in": APIKeyIn.header})
+
+ credential = AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, api_key="test_api_key"
+ )
+
+ # Create a copy for the exchanged credential
+ exchanged = credential.model_copy(deep=True)
+
+ config = AuthConfig(
+ auth_scheme=api_key_scheme,
+ raw_auth_credential=credential,
+ exchanged_auth_credential=exchanged,
+ )
+
+ handler = AuthHandler(config)
+ result = handler.generate_auth_request()
+
+ assert result == config
+
+ def test_with_existing_auth_uri(self, auth_config_with_exchanged):
+ """Test when auth_uri already exists in exchanged credential."""
+ handler = AuthHandler(auth_config_with_exchanged)
+ result = handler.generate_auth_request()
+
+ assert (
+ result.exchanged_auth_credential.oauth2.auth_uri
+ == auth_config_with_exchanged.exchanged_auth_credential.oauth2.auth_uri
+ )
+
+ def test_missing_raw_credential(self, oauth2_auth_scheme):
+ """Test when raw_auth_credential is missing."""
+
+ config = AuthConfig(
+ auth_scheme=oauth2_auth_scheme,
+ )
+ handler = AuthHandler(config)
+
+ with pytest.raises(ValueError, match="requires auth_credential"):
+ handler.generate_auth_request()
+
+ def test_missing_oauth2_in_raw_credential(self, oauth2_auth_scheme):
+ """Test when oauth2 is missing in raw_auth_credential."""
+ credential = AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, api_key="test_api_key"
+ )
+
+ # Create a copy for the exchanged credential
+ exchanged = credential.model_copy(deep=True)
+
+ config = AuthConfig(
+ auth_scheme=oauth2_auth_scheme,
+ raw_auth_credential=credential,
+ exchanged_auth_credential=exchanged,
+ )
+ handler = AuthHandler(config)
+
+ with pytest.raises(ValueError, match="requires oauth2 in auth_credential"):
+ handler.generate_auth_request()
+
+ def test_auth_uri_in_raw_credential(
+ self, oauth2_auth_scheme, oauth2_credentials_with_auth_uri
+ ):
+ """Test when auth_uri exists in raw_credential."""
+ config = AuthConfig(
+ auth_scheme=oauth2_auth_scheme,
+ raw_auth_credential=oauth2_credentials_with_auth_uri,
+ exchanged_auth_credential=oauth2_credentials_with_auth_uri.model_copy(
+ deep=True
+ ),
+ )
+ handler = AuthHandler(config)
+ result = handler.generate_auth_request()
+
+ assert (
+ result.exchanged_auth_credential.oauth2.auth_uri
+ == oauth2_credentials_with_auth_uri.oauth2.auth_uri
+ )
+
+ def test_missing_client_credentials(self, oauth2_auth_scheme):
+ """Test when client_id or client_secret is missing."""
+ bad_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(redirect_uri="https://example.com/callback"),
+ )
+
+ # Create a copy for the exchanged credential
+ exchanged = bad_credential.model_copy(deep=True)
+
+ config = AuthConfig(
+ auth_scheme=oauth2_auth_scheme,
+ raw_auth_credential=bad_credential,
+ exchanged_auth_credential=exchanged,
+ )
+ handler = AuthHandler(config)
+
+ with pytest.raises(
+ ValueError, match="requires both client_id and client_secret"
+ ):
+ handler.generate_auth_request()
+
+ @patch("google.adk.auth.auth_handler.AuthHandler.generate_auth_uri")
+ def test_generate_new_auth_uri(self, mock_generate_auth_uri, auth_config):
+ """Test generating a new auth URI."""
+ mock_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id="mock_client_id",
+ client_secret="mock_client_secret",
+ redirect_uri="https://example.com/callback",
+ auth_uri="https://example.com/generated",
+ state="generated_state",
+ ),
+ )
+ mock_generate_auth_uri.return_value = mock_credential
+
+ handler = AuthHandler(auth_config)
+ result = handler.generate_auth_request()
+
+ assert mock_generate_auth_uri.called
+ assert result.exchanged_auth_credential == mock_credential
+
+
+class TestGetAuthResponse:
+ """Tests for the get_auth_response method."""
+
+ def test_get_auth_response_exists(
+ self, auth_config, oauth2_credentials_with_auth_uri
+ ):
+ """Test retrieving an existing auth response from state."""
+ handler = AuthHandler(auth_config)
+ state = MockState()
+
+ # Store a credential in the state
+ credential_key = handler.get_credential_key()
+ state[credential_key] = oauth2_credentials_with_auth_uri
+
+ result = handler.get_auth_response(state)
+ assert result == oauth2_credentials_with_auth_uri
+
+ def test_get_auth_response_not_exists(self, auth_config):
+ """Test retrieving a non-existent auth response from state."""
+ handler = AuthHandler(auth_config)
+ state = MockState()
+
+ result = handler.get_auth_response(state)
+ assert result is None
+
+
+class TestParseAndStoreAuthResponse:
+ """Tests for the parse_and_store_auth_response method."""
+
+ def test_non_oauth_scheme(self, auth_config_with_exchanged):
+ """Test with a non-OAuth auth scheme."""
+ # Modify the auth scheme type to be non-OAuth
+ auth_config = copy.deepcopy(auth_config_with_exchanged)
+ auth_config.auth_scheme = APIKey(
+ **{"name": "test_api_key", "in": APIKeyIn.header}
+ )
+
+ handler = AuthHandler(auth_config)
+ state = MockState()
+
+ handler.parse_and_store_auth_response(state)
+
+ credential_key = handler.get_credential_key()
+ assert state[credential_key] == auth_config.exchanged_auth_credential
+
+ @patch("google.adk.auth.auth_handler.AuthHandler.exchange_auth_token")
+ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged):
+ """Test with an OAuth auth scheme."""
+ mock_exchange_token.return_value = AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(token={"access_token": "exchanged_token"}),
+ )
+
+ handler = AuthHandler(auth_config_with_exchanged)
+ state = MockState()
+
+ handler.parse_and_store_auth_response(state)
+
+ credential_key = handler.get_credential_key()
+ assert state[credential_key] == mock_exchange_token.return_value
+ assert mock_exchange_token.called
+
+
+class TestExchangeAuthToken:
+ """Tests for the exchange_auth_token method."""
+
+ def test_token_exchange_not_supported(
+ self, auth_config_with_auth_code, monkeypatch
+ ):
+ """Test when token exchange is not supported."""
+ monkeypatch.setattr(
+ "google.adk.auth.auth_handler.SUPPORT_TOKEN_EXCHANGE", False
+ )
+
+ handler = AuthHandler(auth_config_with_auth_code)
+ result = handler.exchange_auth_token()
+
+ assert result == auth_config_with_auth_code.exchanged_auth_credential
+
+ def test_openid_missing_token_endpoint(
+ self, openid_auth_scheme, oauth2_credentials_with_auth_code
+ ):
+ """Test OpenID Connect without a token endpoint."""
+ # Create a scheme without token_endpoint
+ scheme_without_token = copy.deepcopy(openid_auth_scheme)
+ delattr(scheme_without_token, "token_endpoint")
+
+ config = AuthConfig(
+ auth_scheme=scheme_without_token,
+ raw_auth_credential=oauth2_credentials_with_auth_code,
+ exchanged_auth_credential=oauth2_credentials_with_auth_code,
+ )
+
+ handler = AuthHandler(config)
+ result = handler.exchange_auth_token()
+
+ assert result == oauth2_credentials_with_auth_code
+
+ def test_oauth2_missing_token_url(
+ self, oauth2_auth_scheme, oauth2_credentials_with_auth_code
+ ):
+ """Test OAuth2 without a token URL."""
+ # Create a scheme without tokenUrl
+ scheme_without_token = copy.deepcopy(oauth2_auth_scheme)
+ scheme_without_token.flows.authorizationCode.tokenUrl = None
+
+ config = AuthConfig(
+ auth_scheme=scheme_without_token,
+ raw_auth_credential=oauth2_credentials_with_auth_code,
+ exchanged_auth_credential=oauth2_credentials_with_auth_code,
+ )
+
+ handler = AuthHandler(config)
+ result = handler.exchange_auth_token()
+
+ assert result == oauth2_credentials_with_auth_code
+
+ def test_non_oauth_scheme(self, auth_config_with_auth_code):
+ """Test with a non-OAuth auth scheme."""
+ # Modify the auth scheme type to be non-OAuth
+ auth_config = copy.deepcopy(auth_config_with_auth_code)
+ auth_config.auth_scheme = APIKey(
+ **{"name": "test_api_key", "in": APIKeyIn.header}
+ )
+
+ handler = AuthHandler(auth_config)
+ result = handler.exchange_auth_token()
+
+ assert result == auth_config.exchanged_auth_credential
+
+ def test_missing_credentials(self, oauth2_auth_scheme):
+ """Test with missing credentials."""
+ empty_credential = AuthCredential(auth_type=AuthCredentialTypes.OAUTH2)
+
+ config = AuthConfig(
+ auth_scheme=oauth2_auth_scheme,
+ exchanged_auth_credential=empty_credential,
+ )
+
+ handler = AuthHandler(config)
+ result = handler.exchange_auth_token()
+
+ assert result == empty_credential
+
+ def test_credentials_with_token(
+ self, auth_config, oauth2_credentials_with_token
+ ):
+ """Test when credentials already have a token."""
+ config = AuthConfig(
+ auth_scheme=auth_config.auth_scheme,
+ raw_auth_credential=auth_config.raw_auth_credential,
+ exchanged_auth_credential=oauth2_credentials_with_token,
+ )
+
+ handler = AuthHandler(config)
+ result = handler.exchange_auth_token()
+
+ assert result == oauth2_credentials_with_token
+
+ @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
+ def test_successful_token_exchange(self, auth_config_with_auth_code):
+ """Test a successful token exchange."""
+ handler = AuthHandler(auth_config_with_auth_code)
+ result = handler.exchange_auth_token()
+
+ assert result.oauth2.token["access_token"] == "mock_access_token"
+ assert result.oauth2.token["refresh_token"] == "mock_refresh_token"
+ assert result.auth_type == AuthCredentialTypes.OAUTH2
diff --git a/src/google/adk/tests/unittests/conftest.py b/src/google/adk/tests/unittests/conftest.py
new file mode 100644
index 0000000..ad20400
--- /dev/null
+++ b/src/google/adk/tests/unittests/conftest.py
@@ -0,0 +1,73 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from pytest import fixture
+from pytest import FixtureRequest
+from pytest import hookimpl
+from pytest import Metafunc
+
+_ENV_VARS = {
+ 'GOOGLE_API_KEY': 'fake_google_api_key',
+ 'GOOGLE_CLOUD_PROJECT': 'fake_google_cloud_project',
+ 'GOOGLE_CLOUD_LOCATION': 'fake_google_cloud_location',
+}
+
+ENV_SETUPS = {
+ 'GOOGLE_AI': {
+ 'GOOGLE_GENAI_USE_VERTEXAI': '0',
+ **_ENV_VARS,
+ },
+ 'VERTEX': {
+ 'GOOGLE_GENAI_USE_VERTEXAI': '1',
+ **_ENV_VARS,
+ },
+}
+
+
+@fixture(autouse=True)
+def env_variables(request: FixtureRequest):
+ # Set up the environment
+ env_name: str = request.param
+ envs = ENV_SETUPS[env_name]
+ original_env = {key: os.environ.get(key) for key in envs}
+ os.environ.update(envs)
+
+ yield # Run the test
+
+ # Restore the environment
+ for key in envs:
+ if (original_val := original_env.get(key)) is None:
+ os.environ.pop(key, None)
+ else:
+ os.environ[key] = original_val
+
+
+@hookimpl(tryfirst=True)
+def pytest_generate_tests(metafunc: Metafunc):
+ """Generate test cases for each environment setup."""
+ if env_variables.__name__ in metafunc.fixturenames:
+ if not _is_explicitly_marked(env_variables.__name__, metafunc):
+ metafunc.parametrize(
+ env_variables.__name__, ENV_SETUPS.keys(), indirect=True
+ )
+
+
+def _is_explicitly_marked(mark_name: str, metafunc: Metafunc) -> bool:
+ if hasattr(metafunc.function, 'pytestmark'):
+ for mark in metafunc.function.pytestmark:
+ if mark.name == 'parametrize' and mark.args[0] == mark_name:
+ return True
+ return False
diff --git a/src/google/adk/tests/unittests/fast_api/__init__.py b/src/google/adk/tests/unittests/fast_api/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/unittests/fast_api/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/unittests/fast_api/test_fast_api.py b/src/google/adk/tests/unittests/fast_api/test_fast_api.py
new file mode 100644
index 0000000..1f7fd17
--- /dev/null
+++ b/src/google/adk/tests/unittests/fast_api/test_fast_api.py
@@ -0,0 +1,269 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import json
+import sys
+import threading
+import time
+import types as ptypes
+from typing import AsyncGenerator
+
+from google.adk.agents import BaseAgent
+from google.adk.agents import LiveRequest
+from google.adk.agents.run_config import RunConfig
+from google.adk.cli.fast_api import AgentRunRequest
+from google.adk.cli.fast_api import get_fast_api_app
+from google.adk.cli.utils import envs
+from google.adk.events import Event
+from google.adk.runners import Runner
+from google.genai import types
+import httpx
+import pytest
+from uvicorn.main import run as uvicorn_run
+import websockets
+
+
+# Here we “fake” the agent module that get_fast_api_app expects.
+# The server code does: `agent_module = importlib.import_module(agent_name)`
+# and then accesses: agent_module.agent.root_agent.
+class DummyAgent(BaseAgent):
+ pass
+
+
+dummy_module = ptypes.ModuleType("test_agent")
+dummy_module.agent = ptypes.SimpleNamespace(
+ root_agent=DummyAgent(name="dummy_agent")
+)
+sys.modules["test_app"] = dummy_module
+envs.load_dotenv_for_agent("test_app", ".")
+
+event1 = Event(
+ author="dummy agent",
+ invocation_id="invocation_id",
+ content=types.Content(
+ role="model", parts=[types.Part(text="LLM reply", inline_data=None)]
+ ),
+)
+
+event2 = Event(
+ author="dummy agent",
+ invocation_id="invocation_id",
+ content=types.Content(
+ role="model",
+ parts=[
+ types.Part(
+ text=None,
+ inline_data=types.Blob(
+ mime_type="audio/pcm;rate=24000", data=b"\x00\xFF"
+ ),
+ )
+ ],
+ ),
+)
+
+event3 = Event(
+ author="dummy agent", invocation_id="invocation_id", interrupted=True
+)
+
+
+# For simplicity, we patch Runner.run_live to yield dummy events.
+# We use SimpleNamespace to mimic attribute-access (i.e. event.content.parts).
+async def dummy_run_live(
+ self, session, live_request_queue
+) -> AsyncGenerator[Event, None]:
+ # Immediately yield a dummy event with a text reply.
+ yield event1
+ await asyncio.sleep(0)
+
+ yield event2
+ await asyncio.sleep(0)
+
+ yield event3
+
+ raise Exception()
+
+
+async def dummy_run_async(
+ self,
+ user_id,
+ session_id,
+ new_message,
+ run_config: RunConfig = RunConfig(),
+) -> AsyncGenerator[Event, None]:
+ # Immediately yield a dummy event with a text reply.
+ yield event1
+ await asyncio.sleep(0)
+
+ yield event2
+ await asyncio.sleep(0)
+
+ yield event3
+
+ return
+
+
+###############################################################################
+# Pytest fixtures to patch methods and start the server
+###############################################################################
+
+
+@pytest.fixture(autouse=True)
+def patch_runner(monkeypatch):
+ # Patch the Runner methods to use our dummy implementations.
+ monkeypatch.setattr(Runner, "run_live", dummy_run_live)
+ monkeypatch.setattr(Runner, "run_async", dummy_run_async)
+
+
+@pytest.fixture(scope="module", autouse=True)
+def start_server():
+ """Start the FastAPI server in a background thread."""
+
+ def run_server():
+ uvicorn_run(
+ get_fast_api_app(agent_dir=".", web=True),
+ host="0.0.0.0",
+ log_config=None,
+ )
+
+ server_thread = threading.Thread(target=run_server, daemon=True)
+ server_thread.start()
+ # Wait a moment to ensure the server is up.
+ time.sleep(2)
+ yield
+ # The daemon thread will be terminated when tests complete.
+
+
+@pytest.mark.asyncio
+async def test_sse_endpoint():
+ base_http_url = "http://127.0.0.1:8000"
+ user_id = "test_user"
+ session_id = "test_session"
+
+ # Ensure that the session exists (create if necessary).
+ url_create = (
+ f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}"
+ )
+ httpx.post(url_create, json={"state": {}})
+
+ async with httpx.AsyncClient() as client:
+ # Make a POST request to the SSE endpoint.
+ async with client.stream(
+ "POST",
+ f"{base_http_url}/run_sse",
+ json=json.loads(
+ AgentRunRequest(
+ app_name="test_app",
+ user_id=user_id,
+ session_id=session_id,
+ new_message=types.Content(
+ parts=[types.Part(text="Hello via SSE", inline_data=None)]
+ ),
+ streaming=False,
+ ).model_dump_json(exclude_none=True)
+ ),
+ ) as response:
+ # Ensure the status code and header are as expected.
+ assert response.status_code == 200
+ assert (
+ response.headers.get("content-type")
+ == "text/event-stream; charset=utf-8"
+ )
+
+ # Iterate over events from the stream.
+ event_count = 0
+ event_buffer = ""
+
+ async for line in response.aiter_lines():
+ event_buffer += line + "\n"
+
+ # An SSE event is terminated by an empty line (double newline)
+ if line == "" and event_buffer.strip():
+ # Process the complete event
+ event_data = None
+ for event_line in event_buffer.split("\n"):
+ if event_line.startswith("data: "):
+ event_data = event_line[6:] # Remove "data: " prefix
+
+ if event_data:
+ event_count += 1
+ if event_count == 1:
+ assert event_data == event1.model_dump_json(
+ exclude_none=True, by_alias=True
+ )
+ elif event_count == 2:
+ assert event_data == event2.model_dump_json(
+ exclude_none=True, by_alias=True
+ )
+ elif event_count == 3:
+ assert event_data == event3.model_dump_json(
+ exclude_none=True, by_alias=True
+ )
+ else:
+ pass
+
+ # Reset buffer for next event
+ event_buffer = ""
+
+ assert event_count == 3 # Expecting 3 events from dummy_run_async
+
+
+@pytest.mark.asyncio
+async def test_websocket_endpoint():
+ base_http_url = "http://127.0.0.1:8000"
+ base_ws_url = "ws://127.0.0.1:8000"
+ user_id = "test_user"
+ session_id = "test_session"
+
+ # Ensure that the session exists (create if necessary).
+ url_create = (
+ f"{base_http_url}/apps/test_app/users/{user_id}/sessions/{session_id}"
+ )
+ httpx.post(url_create, json={"state": {}})
+
+ ws_url = f"{base_ws_url}/run_live?app_name=test_app&user_id={user_id}&session_id={session_id}"
+ async with websockets.connect(ws_url) as ws:
+ # --- Test sending text data ---
+ text_payload = LiveRequest(
+ content=types.Content(
+ parts=[types.Part(text="Hello via WebSocket", inline_data=None)]
+ )
+ )
+ await ws.send(text_payload.model_dump_json())
+ # Wait for a reply from our dummy_run_live.
+ reply = await ws.recv()
+ event = Event.model_validate_json(reply)
+ assert event.content.parts[0].text == "LLM reply"
+
+ # --- Test sending binary data (allowed mime type "audio/pcm") ---
+ sample_audio = b"\x00\xFF"
+ binary_payload = LiveRequest(
+ blob=types.Blob(
+ mime_type="audio/pcm",
+ data=sample_audio,
+ )
+ )
+ await ws.send(binary_payload.model_dump_json())
+ # Wait for a reply.
+ reply = await ws.recv()
+ event = Event.model_validate_json(reply)
+ assert (
+ event.content.parts[0].inline_data.mime_type == "audio/pcm;rate=24000"
+ )
+ assert event.content.parts[0].inline_data.data == b"\x00\xFF"
+
+ reply = await ws.recv()
+ event = Event.model_validate_json(reply)
+ assert event.interrupted is True
+ assert event.content is None
diff --git a/src/google/adk/tests/unittests/flows/__init__.py b/src/google/adk/tests/unittests/flows/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/__init__.py b/src/google/adk/tests/unittests/flows/llm_flows/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/_test_examples.py b/src/google/adk/tests/unittests/flows/llm_flows/_test_examples.py
new file mode 100644
index 0000000..9b51460
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/_test_examples.py
@@ -0,0 +1,142 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# TODO: delete and rewrite unit tests
+from google.adk.agents import Agent
+from google.adk.examples import BaseExampleProvider
+from google.adk.examples import Example
+from google.adk.flows.llm_flows import examples
+from google.adk.models.base_llm import LlmRequest
+from google.genai import types
+import pytest
+
+from ... import utils
+
+
+@pytest.mark.asyncio
+async def test_no_examples():
+ request = LlmRequest(
+ model="gemini-1.5-flash",
+ config=types.GenerateContentConfig(system_instruction=""),
+ )
+ agent = Agent(model="gemini-1.5-flash", name="agent", examples=[])
+ invocation_context = utils.create_invocation_context(
+ agent=agent, user_content=""
+ )
+
+ async for _ in examples.request_processor.run_async(
+ invocation_context,
+ request,
+ ):
+ pass
+
+ assert request.config.system_instruction == ""
+
+
+@pytest.mark.asyncio
+async def test_agent_examples():
+ example_list = [
+ Example(
+ input=types.Content(
+ role="user",
+ parts=[types.Part.from_text(text="test1")],
+ ),
+ output=[
+ types.Content(
+ role="model",
+ parts=[types.Part.from_text(text="response1")],
+ ),
+ ],
+ )
+ ]
+ request = LlmRequest(
+ model="gemini-1.5-flash",
+ config=types.GenerateContentConfig(system_instruction=""),
+ )
+ agent = Agent(
+ model="gemini-1.5-flash",
+ name="agent",
+ examples=example_list,
+ )
+ invocation_context = utils.create_invocation_context(
+ agent=agent, user_content="test"
+ )
+
+ async for _ in examples.request_processor.run_async(
+ invocation_context,
+ request,
+ ):
+ pass
+
+ assert (
+ request.config.system_instruction
+ == "\nBegin few-shot\nThe following are examples of user"
+ " queries and model responses using the available tools.\n\nEXAMPLE"
+ " 1:\nBegin example\n[user]\ntest1\n\n[model]\nresponse1\nEnd"
+ " example\n\nEnd few-shot\nNow, try to follow these examples and"
+ " complete the following conversation\n"
+ )
+
+
+@pytest.mark.asyncio
+async def test_agent_base_example_provider():
+ class TestExampleProvider(BaseExampleProvider):
+
+ def get_examples(self, query: str) -> list[Example]:
+ if query == "test":
+ return [
+ Example(
+ input=types.Content(
+ role="user",
+ parts=[types.Part.from_text(text="test")],
+ ),
+ output=[
+ types.Content(
+ role="model",
+ parts=[types.Part.from_text(text="response1")],
+ ),
+ ],
+ )
+ ]
+ else:
+ return []
+
+ provider = TestExampleProvider()
+ request = LlmRequest(
+ model="gemini-1.5-flash",
+ config=types.GenerateContentConfig(system_instruction=""),
+ )
+ agent = Agent(
+ model="gemini-1.5-flash",
+ name="agent",
+ examples=provider,
+ )
+ invocation_context = utils.create_invocation_context(
+ agent=agent, user_content="test"
+ )
+
+ async for _ in examples.request_processor.run_async(
+ invocation_context,
+ request,
+ ):
+ pass
+
+ assert (
+ request.config.system_instruction
+ == "\nBegin few-shot\nThe following are examples of user"
+ " queries and model responses using the available tools.\n\nEXAMPLE"
+ " 1:\nBegin example\n[user]\ntest\n\n[model]\nresponse1\nEnd"
+ " example\n\nEnd few-shot\nNow, try to follow these examples and"
+ " complete the following conversation\n"
+ )
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/test_agent_transfer.py b/src/google/adk/tests/unittests/flows/llm_flows/test_agent_transfer.py
new file mode 100644
index 0000000..f236077
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/test_agent_transfer.py
@@ -0,0 +1,311 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.agents.llm_agent import Agent
+from google.adk.agents.loop_agent import LoopAgent
+from google.adk.agents.sequential_agent import SequentialAgent
+from google.adk.tools import exit_loop
+from google.genai.types import Part
+
+from ... import utils
+
+
+def transfer_call_part(agent_name: str) -> Part:
+ return Part.from_function_call(
+ name='transfer_to_agent', args={'agent_name': agent_name}
+ )
+
+
+TRANSFER_RESPONSE_PART = Part.from_function_response(
+ name='transfer_to_agent', response={}
+)
+
+
+def test_auto_to_auto():
+ response = [
+ transfer_call_part('sub_agent_1'),
+ 'response1',
+ 'response2',
+ ]
+ mockModel = utils.MockModel.create(responses=response)
+ # root (auto) - sub_agent_1 (auto)
+ sub_agent_1 = Agent(name='sub_agent_1', model=mockModel)
+ root_agent = Agent(
+ name='root_agent',
+ model=mockModel,
+ sub_agents=[sub_agent_1],
+ )
+
+ runner = utils.InMemoryRunner(root_agent)
+
+ # Asserts the transfer.
+ assert utils.simplify_events(runner.run('test1')) == [
+ ('root_agent', transfer_call_part('sub_agent_1')),
+ ('root_agent', TRANSFER_RESPONSE_PART),
+ ('sub_agent_1', 'response1'),
+ ]
+
+ # sub_agent_1 should still be the current agent.
+ assert utils.simplify_events(runner.run('test2')) == [
+ ('sub_agent_1', 'response2'),
+ ]
+
+
+def test_auto_to_single():
+ response = [
+ transfer_call_part('sub_agent_1'),
+ 'response1',
+ 'response2',
+ ]
+ mockModel = utils.MockModel.create(responses=response)
+ # root (auto) - sub_agent_1 (single)
+ sub_agent_1 = Agent(
+ name='sub_agent_1',
+ model=mockModel,
+ disallow_transfer_to_parent=True,
+ disallow_transfer_to_peers=True,
+ )
+ root_agent = Agent(
+ name='root_agent', model=mockModel, sub_agents=[sub_agent_1]
+ )
+
+ runner = utils.InMemoryRunner(root_agent)
+
+ # Asserts the responses.
+ assert utils.simplify_events(runner.run('test1')) == [
+ ('root_agent', transfer_call_part('sub_agent_1')),
+ ('root_agent', TRANSFER_RESPONSE_PART),
+ ('sub_agent_1', 'response1'),
+ ]
+
+ # root_agent should still be the current agent, becaues sub_agent_1 is single.
+ assert utils.simplify_events(runner.run('test2')) == [
+ ('root_agent', 'response2'),
+ ]
+
+
+def test_auto_to_auto_to_single():
+ response = [
+ transfer_call_part('sub_agent_1'),
+ # sub_agent_1 transfers to sub_agent_1_1.
+ transfer_call_part('sub_agent_1_1'),
+ 'response1',
+ 'response2',
+ ]
+ mockModel = utils.MockModel.create(responses=response)
+ # root (auto) - sub_agent_1 (auto) - sub_agent_1_1 (single)
+ sub_agent_1_1 = Agent(
+ name='sub_agent_1_1',
+ model=mockModel,
+ disallow_transfer_to_parent=True,
+ disallow_transfer_to_peers=True,
+ )
+ sub_agent_1 = Agent(
+ name='sub_agent_1', model=mockModel, sub_agents=[sub_agent_1_1]
+ )
+ root_agent = Agent(
+ name='root_agent', model=mockModel, sub_agents=[sub_agent_1]
+ )
+
+ runner = utils.InMemoryRunner(root_agent)
+
+ # Asserts the responses.
+ assert utils.simplify_events(runner.run('test1')) == [
+ ('root_agent', transfer_call_part('sub_agent_1')),
+ ('root_agent', TRANSFER_RESPONSE_PART),
+ ('sub_agent_1', transfer_call_part('sub_agent_1_1')),
+ ('sub_agent_1', TRANSFER_RESPONSE_PART),
+ ('sub_agent_1_1', 'response1'),
+ ]
+
+ # sub_agent_1 should still be the current agent. sub_agent_1_1 is single so it should
+ # not be the current agent, otherwise the conversation will be tied to
+ # sub_agent_1_1 forever.
+ assert utils.simplify_events(runner.run('test2')) == [
+ ('sub_agent_1', 'response2'),
+ ]
+
+
+def test_auto_to_sequential():
+ response = [
+ transfer_call_part('sub_agent_1'),
+ # sub_agent_1 responds directly instead of transfering.
+ 'response1',
+ 'response2',
+ 'response3',
+ ]
+ mockModel = utils.MockModel.create(responses=response)
+ # root (auto) - sub_agent_1 (sequential) - sub_agent_1_1 (single)
+ # \ sub_agent_1_2 (single)
+ sub_agent_1_1 = Agent(
+ name='sub_agent_1_1',
+ model=mockModel,
+ disallow_transfer_to_parent=True,
+ disallow_transfer_to_peers=True,
+ )
+ sub_agent_1_2 = Agent(
+ name='sub_agent_1_2',
+ model=mockModel,
+ disallow_transfer_to_parent=True,
+ disallow_transfer_to_peers=True,
+ )
+ sub_agent_1 = SequentialAgent(
+ name='sub_agent_1',
+ sub_agents=[sub_agent_1_1, sub_agent_1_2],
+ )
+ root_agent = Agent(
+ name='root_agent',
+ model=mockModel,
+ sub_agents=[sub_agent_1],
+ )
+
+ runner = utils.InMemoryRunner(root_agent)
+
+ # Asserts the transfer.
+ assert utils.simplify_events(runner.run('test1')) == [
+ ('root_agent', transfer_call_part('sub_agent_1')),
+ ('root_agent', TRANSFER_RESPONSE_PART),
+ ('sub_agent_1_1', 'response1'),
+ ('sub_agent_1_2', 'response2'),
+ ]
+
+ # root_agent should still be the current agent because sub_agent_1 is sequential.
+ assert utils.simplify_events(runner.run('test2')) == [
+ ('root_agent', 'response3'),
+ ]
+
+
+def test_auto_to_sequential_to_auto():
+ response = [
+ transfer_call_part('sub_agent_1'),
+ # sub_agent_1 responds directly instead of transfering.
+ 'response1',
+ transfer_call_part('sub_agent_1_2_1'),
+ 'response2',
+ 'response3',
+ 'response4',
+ ]
+ mockModel = utils.MockModel.create(responses=response)
+ # root (auto) - sub_agent_1 (seq) - sub_agent_1_1 (single)
+ # \ sub_agent_1_2 (auto) - sub_agent_1_2_1 (auto)
+ # \ sub_agent_1_3 (single)
+ sub_agent_1_1 = Agent(
+ name='sub_agent_1_1',
+ model=mockModel,
+ disallow_transfer_to_parent=True,
+ disallow_transfer_to_peers=True,
+ )
+ sub_agent_1_2_1 = Agent(name='sub_agent_1_2_1', model=mockModel)
+ sub_agent_1_2 = Agent(
+ name='sub_agent_1_2',
+ model=mockModel,
+ sub_agents=[sub_agent_1_2_1],
+ )
+ sub_agent_1_3 = Agent(
+ name='sub_agent_1_3',
+ model=mockModel,
+ disallow_transfer_to_parent=True,
+ disallow_transfer_to_peers=True,
+ )
+ sub_agent_1 = SequentialAgent(
+ name='sub_agent_1',
+ sub_agents=[sub_agent_1_1, sub_agent_1_2, sub_agent_1_3],
+ )
+ root_agent = Agent(
+ name='root_agent',
+ model=mockModel,
+ sub_agents=[sub_agent_1],
+ )
+
+ runner = utils.InMemoryRunner(root_agent)
+
+ # Asserts the transfer.
+ assert utils.simplify_events(runner.run('test1')) == [
+ ('root_agent', transfer_call_part('sub_agent_1')),
+ ('root_agent', TRANSFER_RESPONSE_PART),
+ ('sub_agent_1_1', 'response1'),
+ ('sub_agent_1_2', transfer_call_part('sub_agent_1_2_1')),
+ ('sub_agent_1_2', TRANSFER_RESPONSE_PART),
+ ('sub_agent_1_2_1', 'response2'),
+ ('sub_agent_1_3', 'response3'),
+ ]
+
+ # root_agent should still be the current agent because sub_agent_1 is sequential.
+ assert utils.simplify_events(runner.run('test2')) == [
+ ('root_agent', 'response4'),
+ ]
+
+
+def test_auto_to_loop():
+ response = [
+ transfer_call_part('sub_agent_1'),
+ # sub_agent_1 responds directly instead of transfering.
+ 'response1',
+ 'response2',
+ 'response3',
+ Part.from_function_call(name='exit_loop', args={}),
+ 'response4',
+ 'response5',
+ ]
+ mockModel = utils.MockModel.create(responses=response)
+ # root (auto) - sub_agent_1 (loop) - sub_agent_1_1 (single)
+ # \ sub_agent_1_2 (single)
+ sub_agent_1_1 = Agent(
+ name='sub_agent_1_1',
+ model=mockModel,
+ disallow_transfer_to_parent=True,
+ disallow_transfer_to_peers=True,
+ )
+ sub_agent_1_2 = Agent(
+ name='sub_agent_1_2',
+ model=mockModel,
+ disallow_transfer_to_parent=True,
+ disallow_transfer_to_peers=True,
+ tools=[exit_loop],
+ )
+ sub_agent_1 = LoopAgent(
+ name='sub_agent_1',
+ sub_agents=[sub_agent_1_1, sub_agent_1_2],
+ )
+ root_agent = Agent(
+ name='root_agent',
+ model=mockModel,
+ sub_agents=[sub_agent_1],
+ )
+
+ runner = utils.InMemoryRunner(root_agent)
+
+ # Asserts the transfer.
+ assert utils.simplify_events(runner.run('test1')) == [
+ # Transfers to sub_agent_1.
+ ('root_agent', transfer_call_part('sub_agent_1')),
+ ('root_agent', TRANSFER_RESPONSE_PART),
+ # Loops.
+ ('sub_agent_1_1', 'response1'),
+ ('sub_agent_1_2', 'response2'),
+ ('sub_agent_1_1', 'response3'),
+ # Exits.
+ ('sub_agent_1_2', Part.from_function_call(name='exit_loop', args={})),
+ (
+ 'sub_agent_1_2',
+ Part.from_function_response(name='exit_loop', response={}),
+ ),
+ # root_agent summarizes.
+ ('root_agent', 'response4'),
+ ]
+
+ # root_agent should still be the current agent because sub_agent_1 is loop.
+ assert utils.simplify_events(runner.run('test2')) == [
+ ('root_agent', 'response5'),
+ ]
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/test_functions_long_running.py b/src/google/adk/tests/unittests/flows/llm_flows/test_functions_long_running.py
new file mode 100644
index 0000000..a547517
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/test_functions_long_running.py
@@ -0,0 +1,244 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.agents import Agent
+from google.adk.tools import ToolContext
+from google.adk.tools.long_running_tool import LongRunningFunctionTool
+from google.genai.types import Part
+
+from ... import utils
+
+
+def test_async_function():
+ responses = [
+ Part.from_function_call(name='increase_by_one', args={'x': 1}),
+ 'response1',
+ 'response2',
+ 'response3',
+ 'response4',
+ ]
+ mockModel = utils.MockModel.create(responses=responses)
+ function_called = 0
+
+ def increase_by_one(x: int, tool_context: ToolContext) -> int:
+ nonlocal function_called
+
+ function_called += 1
+ return {'status': 'pending'}
+
+ # Calls the first time.
+ agent = Agent(
+ name='root_agent',
+ model=mockModel,
+ tools=[LongRunningFunctionTool(func=increase_by_one)],
+ )
+ runner = utils.InMemoryRunner(agent)
+ events = runner.run('test1')
+
+ # Asserts the requests.
+ assert len(mockModel.requests) == 2
+ # 1 item: user content
+ assert mockModel.requests[0].contents == [
+ utils.UserContent('test1'),
+ ]
+ increase_by_one_call = Part.from_function_call(
+ name='increase_by_one', args={'x': 1}
+ )
+ pending_response = Part.from_function_response(
+ name='increase_by_one', response={'status': 'pending'}
+ )
+
+ assert utils.simplify_contents(mockModel.requests[1].contents) == [
+ ('user', 'test1'),
+ ('model', increase_by_one_call),
+ ('user', pending_response),
+ ]
+
+ # Asserts the function calls.
+ assert function_called == 1
+
+ # Asserts the responses.
+ assert utils.simplify_events(events) == [
+ (
+ 'root_agent',
+ Part.from_function_call(name='increase_by_one', args={'x': 1}),
+ ),
+ (
+ 'root_agent',
+ Part.from_function_response(
+ name='increase_by_one', response={'status': 'pending'}
+ ),
+ ),
+ ('root_agent', 'response1'),
+ ]
+ assert events[0].long_running_tool_ids
+
+ # Updates with another pending progress.
+ still_waiting_response = Part.from_function_response(
+ name='increase_by_one', response={'status': 'still waiting'}
+ )
+ events = runner.run(utils.UserContent(still_waiting_response))
+ # We have one new request.
+ assert len(mockModel.requests) == 3
+ assert utils.simplify_contents(mockModel.requests[2].contents) == [
+ ('user', 'test1'),
+ ('model', increase_by_one_call),
+ ('user', still_waiting_response),
+ ]
+
+ assert utils.simplify_events(events) == [('root_agent', 'response2')]
+
+ # Calls when the result is ready.
+ result_response = Part.from_function_response(
+ name='increase_by_one', response={'result': 2}
+ )
+ events = runner.run(utils.UserContent(result_response))
+ # We have one new request.
+ assert len(mockModel.requests) == 4
+ assert utils.simplify_contents(mockModel.requests[3].contents) == [
+ ('user', 'test1'),
+ ('model', increase_by_one_call),
+ ('user', result_response),
+ ]
+ assert utils.simplify_events(events) == [('root_agent', 'response3')]
+
+ # Calls when the result is ready. Here we still accept the result and do
+ # another summarization. Whether this is the right behavior is TBD.
+ another_result_response = Part.from_function_response(
+ name='increase_by_one', response={'result': 3}
+ )
+ events = runner.run(utils.UserContent(another_result_response))
+ # We have one new request.
+ assert len(mockModel.requests) == 5
+ assert utils.simplify_contents(mockModel.requests[4].contents) == [
+ ('user', 'test1'),
+ ('model', increase_by_one_call),
+ ('user', another_result_response),
+ ]
+ assert utils.simplify_events(events) == [('root_agent', 'response4')]
+
+ # At the end, function_called should still be 1.
+ assert function_called == 1
+
+
+def test_async_function_with_none_response():
+ responses = [
+ Part.from_function_call(name='increase_by_one', args={'x': 1}),
+ 'response1',
+ 'response2',
+ 'response3',
+ 'response4',
+ ]
+ mockModel = utils.MockModel.create(responses=responses)
+ function_called = 0
+
+ def increase_by_one(x: int, tool_context: ToolContext) -> int:
+ nonlocal function_called
+ function_called += 1
+ return 'pending'
+
+ # Calls the first time.
+ agent = Agent(
+ name='root_agent',
+ model=mockModel,
+ tools=[LongRunningFunctionTool(func=increase_by_one)],
+ )
+ runner = utils.InMemoryRunner(agent)
+ events = runner.run('test1')
+
+ # Asserts the requests.
+ assert len(mockModel.requests) == 2
+ # 1 item: user content
+ assert mockModel.requests[0].contents == [
+ utils.UserContent('test1'),
+ ]
+ increase_by_one_call = Part.from_function_call(
+ name='increase_by_one', args={'x': 1}
+ )
+
+ assert utils.simplify_contents(mockModel.requests[1].contents) == [
+ ('user', 'test1'),
+ ('model', increase_by_one_call),
+ (
+ 'user',
+ Part.from_function_response(
+ name='increase_by_one', response={'result': 'pending'}
+ ),
+ ),
+ ]
+
+ # Asserts the function calls.
+ assert function_called == 1
+
+ # Asserts the responses.
+ assert utils.simplify_events(events) == [
+ (
+ 'root_agent',
+ Part.from_function_call(name='increase_by_one', args={'x': 1}),
+ ),
+ (
+ 'root_agent',
+ Part.from_function_response(
+ name='increase_by_one', response={'result': 'pending'}
+ ),
+ ),
+ ('root_agent', 'response1'),
+ ]
+
+ # Updates with another pending progress.
+ still_waiting_response = Part.from_function_response(
+ name='increase_by_one', response={'status': 'still waiting'}
+ )
+ events = runner.run(utils.UserContent(still_waiting_response))
+ # We have one new request.
+ assert len(mockModel.requests) == 3
+ assert utils.simplify_contents(mockModel.requests[2].contents) == [
+ ('user', 'test1'),
+ ('model', increase_by_one_call),
+ ('user', still_waiting_response),
+ ]
+
+ assert utils.simplify_events(events) == [('root_agent', 'response2')]
+
+ # Calls when the result is ready.
+ result_response = Part.from_function_response(
+ name='increase_by_one', response={'result': 2}
+ )
+ events = runner.run(utils.UserContent(result_response))
+ # We have one new request.
+ assert len(mockModel.requests) == 4
+ assert utils.simplify_contents(mockModel.requests[3].contents) == [
+ ('user', 'test1'),
+ ('model', increase_by_one_call),
+ ('user', result_response),
+ ]
+ assert utils.simplify_events(events) == [('root_agent', 'response3')]
+
+ # Calls when the result is ready. Here we still accept the result and do
+ # another summarization. Whether this is the right behavior is TBD.
+ another_result_response = Part.from_function_response(
+ name='increase_by_one', response={'result': 3}
+ )
+ events = runner.run(utils.UserContent(another_result_response))
+ # We have one new request.
+ assert len(mockModel.requests) == 5
+ assert utils.simplify_contents(mockModel.requests[4].contents) == [
+ ('user', 'test1'),
+ ('model', increase_by_one_call),
+ ('user', another_result_response),
+ ]
+ assert utils.simplify_events(events) == [('root_agent', 'response4')]
+
+ # At the end, function_called should still be 1.
+ assert function_called == 1
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/test_functions_request_euc.py b/src/google/adk/tests/unittests/flows/llm_flows/test_functions_request_euc.py
new file mode 100644
index 0000000..5c6b784
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/test_functions_request_euc.py
@@ -0,0 +1,346 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from typing import Optional
+
+from fastapi.openapi.models import OAuth2
+from fastapi.openapi.models import OAuthFlowAuthorizationCode
+from fastapi.openapi.models import OAuthFlows
+from google.adk.agents import Agent
+from google.adk.auth import AuthConfig
+from google.adk.auth import AuthCredential
+from google.adk.auth import AuthCredentialTypes
+from google.adk.auth import OAuth2Auth
+from google.adk.flows.llm_flows import functions
+from google.adk.tools import AuthToolArguments
+from google.adk.tools import ToolContext
+from google.genai import types
+
+from ... import utils
+
+
+def function_call(function_call_id, name, args: dict[str, Any]) -> types.Part:
+ part = types.Part.from_function_call(name=name, args=args)
+ part.function_call.id = function_call_id
+ return part
+
+
+def test_function_request_euc():
+ responses = [
+ [
+ types.Part.from_function_call(name='call_external_api1', args={}),
+ types.Part.from_function_call(name='call_external_api2', args={}),
+ ],
+ [
+ types.Part.from_text(text='response1'),
+ ],
+ ]
+
+ auth_config1 = AuthConfig(
+ auth_scheme=OAuth2(
+ flows=OAuthFlows(
+ authorizationCode=OAuthFlowAuthorizationCode(
+ authorizationUrl='https://accounts.google.com/o/oauth2/auth',
+ tokenUrl='https://oauth2.googleapis.com/token',
+ scopes={
+ 'https://www.googleapis.com/auth/calendar': (
+ 'See, edit, share, and permanently delete all the'
+ ' calendars you can access using Google Calendar'
+ )
+ },
+ )
+ )
+ ),
+ raw_auth_credential=AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id='oauth_client_id_1',
+ client_secret='oauth_client_secret1',
+ ),
+ ),
+ )
+ auth_config2 = AuthConfig(
+ auth_scheme=OAuth2(
+ flows=OAuthFlows(
+ authorizationCode=OAuthFlowAuthorizationCode(
+ authorizationUrl='https://accounts.google.com/o/oauth2/auth',
+ tokenUrl='https://oauth2.googleapis.com/token',
+ scopes={
+ 'https://www.googleapis.com/auth/calendar': (
+ 'See, edit, share, and permanently delete all the'
+ ' calendars you can access using Google Calendar'
+ )
+ },
+ )
+ )
+ ),
+ raw_auth_credential=AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id='oauth_client_id_2',
+ client_secret='oauth_client_secret2',
+ ),
+ ),
+ )
+
+ mock_model = utils.MockModel.create(responses=responses)
+
+ def call_external_api1(tool_context: ToolContext) -> Optional[int]:
+ tool_context.request_credential(auth_config1)
+
+ def call_external_api2(tool_context: ToolContext) -> Optional[int]:
+ tool_context.request_credential(auth_config2)
+
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ tools=[call_external_api1, call_external_api2],
+ )
+ runner = utils.InMemoryRunner(agent)
+ events = runner.run('test')
+ assert events[0].content.parts[0].function_call is not None
+ assert events[0].content.parts[1].function_call is not None
+ auth_configs = list(events[2].actions.requested_auth_configs.values())
+ exchanged_auth_config1 = auth_configs[0]
+ exchanged_auth_config2 = auth_configs[1]
+ assert exchanged_auth_config1.auth_scheme == auth_config1.auth_scheme
+ assert (
+ exchanged_auth_config1.raw_auth_credential
+ == auth_config1.raw_auth_credential
+ )
+ assert (
+ exchanged_auth_config1.exchanged_auth_credential.oauth2.auth_uri
+ is not None
+ )
+ assert exchanged_auth_config2.auth_scheme == auth_config2.auth_scheme
+ assert (
+ exchanged_auth_config2.raw_auth_credential
+ == auth_config2.raw_auth_credential
+ )
+ assert (
+ exchanged_auth_config2.exchanged_auth_credential.oauth2.auth_uri
+ is not None
+ )
+ function_call_ids = list(events[2].actions.requested_auth_configs.keys())
+
+ for idx, part in enumerate(events[1].content.parts):
+ reqeust_euc_function_call = part.function_call
+ assert reqeust_euc_function_call is not None
+ assert (
+ reqeust_euc_function_call.name
+ == functions.REQUEST_EUC_FUNCTION_CALL_NAME
+ )
+ args = AuthToolArguments.model_validate(reqeust_euc_function_call.args)
+
+ assert args.function_call_id == function_call_ids[idx]
+ args.auth_config.auth_scheme.model_extra.clear()
+ assert args.auth_config.auth_scheme == auth_configs[idx].auth_scheme
+ assert (
+ args.auth_config.raw_auth_credential
+ == auth_configs[idx].raw_auth_credential
+ )
+
+
+def test_function_get_auth_response():
+ id_1 = 'id_1'
+ id_2 = 'id_2'
+ responses = [
+ [
+ function_call(id_1, 'call_external_api1', {}),
+ function_call(id_2, 'call_external_api2', {}),
+ ],
+ [
+ types.Part.from_text(text='response1'),
+ ],
+ [
+ types.Part.from_text(text='response2'),
+ ],
+ ]
+
+ mock_model = utils.MockModel.create(responses=responses)
+ function_invoked = 0
+
+ auth_config1 = AuthConfig(
+ auth_scheme=OAuth2(
+ flows=OAuthFlows(
+ authorizationCode=OAuthFlowAuthorizationCode(
+ authorizationUrl='https://accounts.google.com/o/oauth2/auth',
+ tokenUrl='https://oauth2.googleapis.com/token',
+ scopes={
+ 'https://www.googleapis.com/auth/calendar': (
+ 'See, edit, share, and permanently delete all the'
+ ' calendars you can access using Google Calendar'
+ )
+ },
+ )
+ )
+ ),
+ raw_auth_credential=AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id='oauth_client_id_1',
+ client_secret='oauth_client_secret1',
+ ),
+ ),
+ )
+ auth_config2 = AuthConfig(
+ auth_scheme=OAuth2(
+ flows=OAuthFlows(
+ authorizationCode=OAuthFlowAuthorizationCode(
+ authorizationUrl='https://accounts.google.com/o/oauth2/auth',
+ tokenUrl='https://oauth2.googleapis.com/token',
+ scopes={
+ 'https://www.googleapis.com/auth/calendar': (
+ 'See, edit, share, and permanently delete all the'
+ ' calendars you can access using Google Calendar'
+ )
+ },
+ )
+ )
+ ),
+ raw_auth_credential=AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id='oauth_client_id_2',
+ client_secret='oauth_client_secret2',
+ ),
+ ),
+ )
+
+ auth_response1 = AuthConfig(
+ auth_scheme=OAuth2(
+ flows=OAuthFlows(
+ authorizationCode=OAuthFlowAuthorizationCode(
+ authorizationUrl='https://accounts.google.com/o/oauth2/auth',
+ tokenUrl='https://oauth2.googleapis.com/token',
+ scopes={
+ 'https://www.googleapis.com/auth/calendar': (
+ 'See, edit, share, and permanently delete all the'
+ ' calendars you can access using Google Calendar'
+ )
+ },
+ )
+ )
+ ),
+ raw_auth_credential=AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id='oauth_client_id_1',
+ client_secret='oauth_client_secret1',
+ ),
+ ),
+ exchanged_auth_credential=AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id='oauth_client_id_1',
+ client_secret='oauth_client_secret1',
+ token={'access_token': 'token1'},
+ ),
+ ),
+ )
+ auth_response2 = AuthConfig(
+ auth_scheme=OAuth2(
+ flows=OAuthFlows(
+ authorizationCode=OAuthFlowAuthorizationCode(
+ authorizationUrl='https://accounts.google.com/o/oauth2/auth',
+ tokenUrl='https://oauth2.googleapis.com/token',
+ scopes={
+ 'https://www.googleapis.com/auth/calendar': (
+ 'See, edit, share, and permanently delete all the'
+ ' calendars you can access using Google Calendar'
+ )
+ },
+ )
+ )
+ ),
+ raw_auth_credential=AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id='oauth_client_id_2',
+ client_secret='oauth_client_secret2',
+ ),
+ ),
+ exchanged_auth_credential=AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id='oauth_client_id_2',
+ client_secret='oauth_client_secret2',
+ token={'access_token': 'token2'},
+ ),
+ ),
+ )
+
+ def call_external_api1(tool_context: ToolContext) -> int:
+ nonlocal function_invoked
+ function_invoked += 1
+ auth_response = tool_context.get_auth_response(auth_config1)
+ if not auth_response:
+ tool_context.request_credential(auth_config1)
+ return
+ assert auth_response == auth_response1.exchanged_auth_credential
+ return 1
+
+ def call_external_api2(tool_context: ToolContext) -> int:
+ nonlocal function_invoked
+ function_invoked += 1
+ auth_response = tool_context.get_auth_response(auth_config2)
+ if not auth_response:
+ tool_context.request_credential(auth_config2)
+ return
+ assert auth_response == auth_response2.exchanged_auth_credential
+ return 2
+
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ tools=[call_external_api1, call_external_api2],
+ )
+ runner = utils.InMemoryRunner(agent)
+ runner.run('test')
+ request_euc_function_call_event = runner.session.events[-3]
+ function_response1 = types.FunctionResponse(
+ name=request_euc_function_call_event.content.parts[0].function_call.name,
+ response=auth_response1.model_dump(),
+ )
+ function_response1.id = request_euc_function_call_event.content.parts[
+ 0
+ ].function_call.id
+
+ function_response2 = types.FunctionResponse(
+ name=request_euc_function_call_event.content.parts[1].function_call.name,
+ response=auth_response2.model_dump(),
+ )
+ function_response2.id = request_euc_function_call_event.content.parts[
+ 1
+ ].function_call.id
+ runner.run(
+ new_message=types.Content(
+ role='user',
+ parts=[
+ types.Part(function_response=function_response1),
+ types.Part(function_response=function_response2),
+ ],
+ ),
+ )
+
+ assert function_invoked == 4
+ reqeust = mock_model.requests[-1]
+ content = reqeust.contents[-1]
+ parts = content.parts
+ assert len(parts) == 2
+ assert parts[0].function_response.name == 'call_external_api1'
+ assert parts[0].function_response.response == {'result': 1}
+ assert parts[1].function_response.name == 'call_external_api2'
+ assert parts[1].function_response.response == {'result': 2}
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/test_functions_sequential.py b/src/google/adk/tests/unittests/flows/llm_flows/test_functions_sequential.py
new file mode 100644
index 0000000..02c2d41
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/test_functions_sequential.py
@@ -0,0 +1,93 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from google.adk.agents import Agent
+from google.genai import types
+
+from ... import utils
+
+
+def function_call(args: dict[str, Any]) -> types.Part:
+ return types.Part.from_function_call(name='increase_by_one', args=args)
+
+
+def function_response(response: dict[str, Any]) -> types.Part:
+ return types.Part.from_function_response(
+ name='increase_by_one', response=response
+ )
+
+
+def test_sequential_calls():
+ responses = [
+ function_call({'x': 1}),
+ function_call({'x': 2}),
+ function_call({'x': 3}),
+ 'response1',
+ ]
+ mockModel = utils.MockModel.create(responses=responses)
+ function_called = 0
+
+ def increase_by_one(x: int) -> int:
+ nonlocal function_called
+ function_called += 1
+ return x + 1
+
+ agent = Agent(name='root_agent', model=mockModel, tools=[increase_by_one])
+ runner = utils.InMemoryRunner(agent)
+ result = utils.simplify_events(runner.run('test'))
+ assert result == [
+ ('root_agent', function_call({'x': 1})),
+ ('root_agent', function_response({'result': 2})),
+ ('root_agent', function_call({'x': 2})),
+ ('root_agent', function_response({'result': 3})),
+ ('root_agent', function_call({'x': 3})),
+ ('root_agent', function_response({'result': 4})),
+ ('root_agent', 'response1'),
+ ]
+
+ # Asserts the requests.
+ assert len(mockModel.requests) == 4
+ # 1 item: user content
+ assert utils.simplify_contents(mockModel.requests[0].contents) == [
+ ('user', 'test')
+ ]
+ # 3 items: user content, functaion call / response for the 1st call
+ assert utils.simplify_contents(mockModel.requests[1].contents) == [
+ ('user', 'test'),
+ ('model', function_call({'x': 1})),
+ ('user', function_response({'result': 2})),
+ ]
+ # 5 items: user content, functaion call / response for two calls
+ assert utils.simplify_contents(mockModel.requests[2].contents) == [
+ ('user', 'test'),
+ ('model', function_call({'x': 1})),
+ ('user', function_response({'result': 2})),
+ ('model', function_call({'x': 2})),
+ ('user', function_response({'result': 3})),
+ ]
+ # 7 items: user content, functaion call / response for three calls
+ assert utils.simplify_contents(mockModel.requests[3].contents) == [
+ ('user', 'test'),
+ ('model', function_call({'x': 1})),
+ ('user', function_response({'result': 2})),
+ ('model', function_call({'x': 2})),
+ ('user', function_response({'result': 3})),
+ ('model', function_call({'x': 3})),
+ ('user', function_response({'result': 4})),
+ ]
+
+ # Asserts the function calls.
+ assert function_called == 3
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/test_functions_simple.py b/src/google/adk/tests/unittests/flows/llm_flows/test_functions_simple.py
new file mode 100644
index 0000000..0e9e43f
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/test_functions_simple.py
@@ -0,0 +1,258 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from typing import AsyncGenerator
+from typing import Callable
+
+from google.adk.agents import Agent
+from google.adk.tools import ToolContext
+from google.adk.tools.function_tool import FunctionTool
+from google.genai import types
+import pytest
+
+from ... import utils
+
+
+def test_simple_function():
+ function_call_1 = types.Part.from_function_call(
+ name='increase_by_one', args={'x': 1}
+ )
+ function_respones_2 = types.Part.from_function_response(
+ name='increase_by_one', response={'result': 2}
+ )
+ responses: list[types.Content] = [
+ function_call_1,
+ 'response1',
+ 'response2',
+ 'response3',
+ 'response4',
+ ]
+ function_called = 0
+ mock_model = utils.MockModel.create(responses=responses)
+
+ def increase_by_one(x: int) -> int:
+ nonlocal function_called
+ function_called += 1
+ return x + 1
+
+ agent = Agent(name='root_agent', model=mock_model, tools=[increase_by_one])
+ runner = utils.InMemoryRunner(agent)
+ assert utils.simplify_events(runner.run('test')) == [
+ ('root_agent', function_call_1),
+ ('root_agent', function_respones_2),
+ ('root_agent', 'response1'),
+ ]
+
+ # Asserts the requests.
+ assert utils.simplify_contents(mock_model.requests[0].contents) == [
+ ('user', 'test')
+ ]
+ assert utils.simplify_contents(mock_model.requests[1].contents) == [
+ ('user', 'test'),
+ ('model', function_call_1),
+ ('user', function_respones_2),
+ ]
+
+ # Asserts the function calls.
+ assert function_called == 1
+
+
+@pytest.mark.asyncio
+async def test_async_function():
+ function_calls = [
+ types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
+ types.Part.from_function_call(name='multiple_by_two', args={'x': 2}),
+ types.Part.from_function_call(name='multiple_by_two_sync', args={'x': 3}),
+ ]
+ function_responses = [
+ types.Part.from_function_response(
+ name='increase_by_one', response={'result': 2}
+ ),
+ types.Part.from_function_response(
+ name='multiple_by_two', response={'result': 4}
+ ),
+ types.Part.from_function_response(
+ name='multiple_by_two_sync', response={'result': 6}
+ ),
+ ]
+
+ responses: list[types.Content] = [
+ function_calls,
+ 'response1',
+ 'response2',
+ 'response3',
+ 'response4',
+ ]
+ function_called = 0
+ mock_model = utils.MockModel.create(responses=responses)
+
+ async def increase_by_one(x: int) -> int:
+ nonlocal function_called
+ function_called += 1
+ return x + 1
+
+ async def multiple_by_two(x: int) -> int:
+ nonlocal function_called
+ function_called += 1
+ return x * 2
+
+ def multiple_by_two_sync(x: int) -> int:
+ nonlocal function_called
+ function_called += 1
+ return x * 2
+
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ tools=[increase_by_one, multiple_by_two, multiple_by_two_sync],
+ )
+ runner = utils.TestInMemoryRunner(agent)
+ events = await runner.run_async_with_new_session('test')
+ assert utils.simplify_events(events) == [
+ ('root_agent', function_calls),
+ ('root_agent', function_responses),
+ ('root_agent', 'response1'),
+ ]
+
+ # Asserts the requests.
+ assert utils.simplify_contents(mock_model.requests[0].contents) == [
+ ('user', 'test')
+ ]
+ assert utils.simplify_contents(mock_model.requests[1].contents) == [
+ ('user', 'test'),
+ ('model', function_calls),
+ ('user', function_responses),
+ ]
+
+ # Asserts the function calls.
+ assert function_called == 3
+
+
+@pytest.mark.asyncio
+async def test_function_tool():
+ function_calls = [
+ types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
+ types.Part.from_function_call(name='multiple_by_two', args={'x': 2}),
+ types.Part.from_function_call(name='multiple_by_two_sync', args={'x': 3}),
+ ]
+ function_responses = [
+ types.Part.from_function_response(
+ name='increase_by_one', response={'result': 2}
+ ),
+ types.Part.from_function_response(
+ name='multiple_by_two', response={'result': 4}
+ ),
+ types.Part.from_function_response(
+ name='multiple_by_two_sync', response={'result': 6}
+ ),
+ ]
+
+ responses: list[types.Content] = [
+ function_calls,
+ 'response1',
+ 'response2',
+ 'response3',
+ 'response4',
+ ]
+ function_called = 0
+ mock_model = utils.MockModel.create(responses=responses)
+
+ async def increase_by_one(x: int) -> int:
+ nonlocal function_called
+ function_called += 1
+ return x + 1
+
+ async def multiple_by_two(x: int) -> int:
+ nonlocal function_called
+ function_called += 1
+ return x * 2
+
+ def multiple_by_two_sync(x: int) -> int:
+ nonlocal function_called
+ function_called += 1
+ return x * 2
+
+ class TestTool(FunctionTool):
+
+ def __init__(self, func: Callable[..., Any]):
+ super().__init__(func=func)
+
+ wrapped_increase_by_one = TestTool(func=increase_by_one)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ tools=[wrapped_increase_by_one, multiple_by_two, multiple_by_two_sync],
+ )
+ runner = utils.TestInMemoryRunner(agent)
+ events = await runner.run_async_with_new_session('test')
+ assert utils.simplify_events(events) == [
+ ('root_agent', function_calls),
+ ('root_agent', function_responses),
+ ('root_agent', 'response1'),
+ ]
+
+ # Asserts the requests.
+ assert utils.simplify_contents(mock_model.requests[0].contents) == [
+ ('user', 'test')
+ ]
+ assert utils.simplify_contents(mock_model.requests[1].contents) == [
+ ('user', 'test'),
+ ('model', function_calls),
+ ('user', function_responses),
+ ]
+
+ # Asserts the function calls.
+ assert function_called == 3
+
+
+def test_update_state():
+ mock_model = utils.MockModel.create(
+ responses=[
+ types.Part.from_function_call(name='update_state', args={}),
+ 'response1',
+ ]
+ )
+
+ def update_state(tool_context: ToolContext):
+ tool_context.state['x'] = 1
+
+ agent = Agent(name='root_agent', model=mock_model, tools=[update_state])
+ runner = utils.InMemoryRunner(agent)
+ runner.run('test')
+ assert runner.session.state['x'] == 1
+
+
+def test_function_call_id():
+ responses = [
+ types.Part.from_function_call(name='increase_by_one', args={'x': 1}),
+ 'response1',
+ ]
+ mock_model = utils.MockModel.create(responses=responses)
+
+ def increase_by_one(x: int) -> int:
+ return x + 1
+
+ agent = Agent(name='root_agent', model=mock_model, tools=[increase_by_one])
+ runner = utils.InMemoryRunner(agent)
+ events = runner.run('test')
+ for reqeust in mock_model.requests:
+ for content in reqeust.contents:
+ for part in content.parts:
+ if part.function_call:
+ assert part.function_call.id is None
+ if part.function_response:
+ assert part.function_response.id is None
+ assert events[0].content.parts[0].function_call.id.startswith('adk-')
+ assert events[1].content.parts[0].function_response.id.startswith('adk-')
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/test_identity.py b/src/google/adk/tests/unittests/flows/llm_flows/test_identity.py
new file mode 100644
index 0000000..0e88527
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/test_identity.py
@@ -0,0 +1,66 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.agents import Agent
+from google.adk.flows.llm_flows import identity
+from google.adk.models import LlmRequest
+from google.genai import types
+import pytest
+
+from ... import utils
+
+
+@pytest.mark.asyncio
+async def test_no_description():
+ request = LlmRequest(
+ model="gemini-1.5-flash",
+ config=types.GenerateContentConfig(system_instruction=""),
+ )
+ agent = Agent(model="gemini-1.5-flash", name="agent")
+ invocation_context = utils.create_invocation_context(agent=agent)
+
+ async for _ in identity.request_processor.run_async(
+ invocation_context,
+ request,
+ ):
+ pass
+
+ assert request.config.system_instruction == (
+ """You are an agent. Your internal name is "agent"."""
+ )
+
+
+@pytest.mark.asyncio
+async def test_with_description():
+ request = LlmRequest(
+ model="gemini-1.5-flash",
+ config=types.GenerateContentConfig(system_instruction=""),
+ )
+ agent = Agent(
+ model="gemini-1.5-flash",
+ name="agent",
+ description="test description",
+ )
+ invocation_context = utils.create_invocation_context(agent=agent)
+
+ async for _ in identity.request_processor.run_async(
+ invocation_context,
+ request,
+ ):
+ pass
+
+ assert request.config.system_instruction == "\n\n".join([
+ 'You are an agent. Your internal name is "agent".',
+ ' The description about you is "test description"',
+ ])
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/test_instructions.py b/src/google/adk/tests/unittests/flows/llm_flows/test_instructions.py
new file mode 100644
index 0000000..edc7902
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/test_instructions.py
@@ -0,0 +1,164 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.agents import Agent
+from google.adk.agents.readonly_context import ReadonlyContext
+from google.adk.flows.llm_flows import instructions
+from google.adk.models import LlmRequest
+from google.adk.sessions import Session
+from google.genai import types
+import pytest
+
+from ... import utils
+
+
+@pytest.mark.asyncio
+async def test_build_system_instruction():
+ request = LlmRequest(
+ model="gemini-1.5-flash",
+ config=types.GenerateContentConfig(system_instruction=""),
+ )
+ agent = Agent(
+ model="gemini-1.5-flash",
+ name="agent",
+ instruction=("""Use the echo_info tool to echo { customerId }, \
+{{customer_int }, { non-identifier-float}}, \
+{'key1': 'value1'} and {{'key2': 'value2'}}."""),
+ )
+ invocation_context = utils.create_invocation_context(agent=agent)
+ invocation_context.session = Session(
+ app_name="test_app",
+ user_id="test_user",
+ id="test_id",
+ state={"customerId": "1234567890", "customer_int": 30},
+ )
+
+ async for _ in instructions.request_processor.run_async(
+ invocation_context,
+ request,
+ ):
+ pass
+
+ assert request.config.system_instruction == (
+ """Use the echo_info tool to echo 1234567890, 30, \
+{ non-identifier-float}}, {'key1': 'value1'} and {{'key2': 'value2'}}."""
+ )
+
+
+@pytest.mark.asyncio
+async def test_function_system_instruction():
+ def build_function_instruction(readonly_context: ReadonlyContext) -> str:
+ return (
+ "This is the function agent instruction for invocation:"
+ f" {readonly_context.invocation_id}."
+ )
+
+ request = LlmRequest(
+ model="gemini-1.5-flash",
+ config=types.GenerateContentConfig(system_instruction=""),
+ )
+ agent = Agent(
+ model="gemini-1.5-flash",
+ name="agent",
+ instruction=build_function_instruction,
+ )
+ invocation_context = utils.create_invocation_context(agent=agent)
+ invocation_context.session = Session(
+ app_name="test_app",
+ user_id="test_user",
+ id="test_id",
+ state={"customerId": "1234567890", "customer_int": 30},
+ )
+
+ async for _ in instructions.request_processor.run_async(
+ invocation_context,
+ request,
+ ):
+ pass
+
+ assert request.config.system_instruction == (
+ "This is the function agent instruction for invocation: test_id."
+ )
+
+
+@pytest.mark.asyncio
+async def test_global_system_instruction():
+ sub_agent = Agent(
+ model="gemini-1.5-flash",
+ name="sub_agent",
+ instruction="This is the sub agent instruction.",
+ )
+ root_agent = Agent(
+ model="gemini-1.5-flash",
+ name="root_agent",
+ global_instruction="This is the global instruction.",
+ sub_agents=[sub_agent],
+ )
+ request = LlmRequest(
+ model="gemini-1.5-flash",
+ config=types.GenerateContentConfig(system_instruction=""),
+ )
+ invocation_context = utils.create_invocation_context(agent=sub_agent)
+ invocation_context.session = Session(
+ app_name="test_app",
+ user_id="test_user",
+ id="test_id",
+ state={"customerId": "1234567890", "customer_int": 30},
+ )
+
+ async for _ in instructions.request_processor.run_async(
+ invocation_context,
+ request,
+ ):
+ pass
+
+ assert request.config.system_instruction == (
+ "This is the global instruction.\n\nThis is the sub agent instruction."
+ )
+
+
+@pytest.mark.asyncio
+async def test_build_system_instruction_with_namespace():
+ request = LlmRequest(
+ model="gemini-1.5-flash",
+ config=types.GenerateContentConfig(system_instruction=""),
+ )
+ agent = Agent(
+ model="gemini-1.5-flash",
+ name="agent",
+ instruction=(
+ """Use the echo_info tool to echo { customerId }, {app:key}, {user:key}, {a:key}."""
+ ),
+ )
+ invocation_context = utils.create_invocation_context(agent=agent)
+ invocation_context.session = Session(
+ app_name="test_app",
+ user_id="test_user",
+ id="test_id",
+ state={
+ "customerId": "1234567890",
+ "app:key": "app_value",
+ "user:key": "user_value",
+ },
+ )
+
+ async for _ in instructions.request_processor.run_async(
+ invocation_context,
+ request,
+ ):
+ pass
+
+ assert request.config.system_instruction == (
+ """Use the echo_info tool to echo 1234567890, app_value, user_value, {a:key}."""
+ )
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/test_model_callbacks.py b/src/google/adk/tests/unittests/flows/llm_flows/test_model_callbacks.py
new file mode 100644
index 0000000..dd2d3cf
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/test_model_callbacks.py
@@ -0,0 +1,142 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from typing import Optional
+
+from google.adk.agents import Agent
+from google.adk.agents.callback_context import CallbackContext
+from google.adk.models import LlmRequest
+from google.adk.models import LlmResponse
+from google.genai import types
+from pydantic import BaseModel
+import pytest
+
+from ... import utils
+
+
+class MockBeforeModelCallback(BaseModel):
+ mock_response: str
+
+ def __call__(
+ self,
+ callback_context: CallbackContext,
+ llm_request: LlmRequest,
+ ) -> LlmResponse:
+ return LlmResponse(
+ content=utils.ModelContent(
+ [types.Part.from_text(text=self.mock_response)]
+ )
+ )
+
+
+class MockAfterModelCallback(BaseModel):
+ mock_response: str
+
+ def __call__(
+ self,
+ callback_context: CallbackContext,
+ llm_response: LlmResponse,
+ ) -> LlmResponse:
+ return LlmResponse(
+ content=utils.ModelContent(
+ [types.Part.from_text(text=self.mock_response)]
+ )
+ )
+
+
+def noop_callback(**kwargs) -> Optional[LlmResponse]:
+ pass
+
+
+def test_before_model_callback():
+ responses = ['model_response']
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ before_model_callback=MockBeforeModelCallback(
+ mock_response='before_model_callback'
+ ),
+ )
+
+ runner = utils.InMemoryRunner(agent)
+ assert utils.simplify_events(runner.run('test')) == [
+ ('root_agent', 'before_model_callback'),
+ ]
+
+
+def test_before_model_callback_noop():
+ responses = ['model_response']
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ before_model_callback=noop_callback,
+ )
+
+ runner = utils.InMemoryRunner(agent)
+ assert utils.simplify_events(runner.run('test')) == [
+ ('root_agent', 'model_response'),
+ ]
+
+
+def test_before_model_callback_end():
+ responses = ['model_response']
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ before_model_callback=MockBeforeModelCallback(
+ mock_response='before_model_callback',
+ ),
+ )
+
+ runner = utils.InMemoryRunner(agent)
+ assert utils.simplify_events(runner.run('test')) == [
+ ('root_agent', 'before_model_callback'),
+ ]
+
+
+def test_after_model_callback():
+ responses = ['model_response']
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ after_model_callback=MockAfterModelCallback(
+ mock_response='after_model_callback'
+ ),
+ )
+
+ runner = utils.InMemoryRunner(agent)
+ assert utils.simplify_events(runner.run('test')) == [
+ ('root_agent', 'after_model_callback'),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_after_model_callback_noop():
+ responses = ['model_response']
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ after_model_callback=noop_callback,
+ )
+
+ runner = utils.TestInMemoryRunner(agent)
+ assert utils.simplify_events(
+ await runner.run_async_with_new_session('test')
+ ) == [('root_agent', 'model_response')]
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/test_other_configs.py b/src/google/adk/tests/unittests/flows/llm_flows/test_other_configs.py
new file mode 100644
index 0000000..63ba950
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/test_other_configs.py
@@ -0,0 +1,46 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.agents import Agent
+from google.adk.tools import ToolContext
+from google.genai.types import Part
+from pydantic import BaseModel
+
+from ... import utils
+
+
+def test_output_schema():
+ class CustomOutput(BaseModel):
+ custom_field: str
+
+ response = [
+ 'response1',
+ ]
+ mockModel = utils.MockModel.create(responses=response)
+ root_agent = Agent(
+ name='root_agent',
+ model=mockModel,
+ output_schema=CustomOutput,
+ disallow_transfer_to_parent=True,
+ disallow_transfer_to_peers=True,
+ )
+
+ runner = utils.InMemoryRunner(root_agent)
+
+ assert utils.simplify_events(runner.run('test1')) == [
+ ('root_agent', 'response1'),
+ ]
+ assert len(mockModel.requests) == 1
+ assert mockModel.requests[0].config.response_schema == CustomOutput
+ assert mockModel.requests[0].config.response_mime_type == 'application/json'
diff --git a/src/google/adk/tests/unittests/flows/llm_flows/test_tool_callbacks.py b/src/google/adk/tests/unittests/flows/llm_flows/test_tool_callbacks.py
new file mode 100644
index 0000000..5383f41
--- /dev/null
+++ b/src/google/adk/tests/unittests/flows/llm_flows/test_tool_callbacks.py
@@ -0,0 +1,269 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from google.adk.agents import Agent
+from google.adk.tools import BaseTool
+from google.adk.tools import ToolContext
+from google.genai import types
+from google.genai.types import Part
+from pydantic import BaseModel
+
+from ... import utils
+
+
+def simple_function(input_str: str) -> str:
+ return {'result': input_str}
+
+
+class MockBeforeToolCallback(BaseModel):
+ mock_response: dict[str, object]
+ modify_tool_request: bool = False
+
+ def __call__(
+ self,
+ tool: BaseTool,
+ args: dict[str, Any],
+ tool_context: ToolContext,
+ ) -> dict[str, object]:
+ if self.modify_tool_request:
+ args['input_str'] = 'modified_input'
+ return None
+ return self.mock_response
+
+
+class MockAfterToolCallback(BaseModel):
+ mock_response: dict[str, object]
+ modify_tool_request: bool = False
+ modify_tool_response: bool = False
+
+ def __call__(
+ self,
+ tool: BaseTool,
+ args: dict[str, Any],
+ tool_context: ToolContext,
+ tool_response: dict[str, Any] = None,
+ ) -> dict[str, object]:
+ if self.modify_tool_request:
+ args['input_str'] = 'modified_input'
+ return None
+ if self.modify_tool_response:
+ tool_response['result'] = 'modified_output'
+ return tool_response
+ return self.mock_response
+
+
+def noop_callback(
+ **kwargs,
+) -> dict[str, object]:
+ pass
+
+
+def test_before_tool_callback():
+ responses = [
+ types.Part.from_function_call(name='simple_function', args={}),
+ 'response1',
+ ]
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ before_tool_callback=MockBeforeToolCallback(
+ mock_response={'test': 'before_tool_callback'}
+ ),
+ tools=[simple_function],
+ )
+
+ runner = utils.InMemoryRunner(agent)
+ assert utils.simplify_events(runner.run('test')) == [
+ ('root_agent', Part.from_function_call(name='simple_function', args={})),
+ (
+ 'root_agent',
+ Part.from_function_response(
+ name='simple_function', response={'test': 'before_tool_callback'}
+ ),
+ ),
+ ('root_agent', 'response1'),
+ ]
+
+
+def test_before_tool_callback_noop():
+ responses = [
+ types.Part.from_function_call(
+ name='simple_function', args={'input_str': 'simple_function_call'}
+ ),
+ 'response1',
+ ]
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ before_tool_callback=noop_callback,
+ tools=[simple_function],
+ )
+
+ runner = utils.InMemoryRunner(agent)
+ assert utils.simplify_events(runner.run('test')) == [
+ (
+ 'root_agent',
+ Part.from_function_call(
+ name='simple_function', args={'input_str': 'simple_function_call'}
+ ),
+ ),
+ (
+ 'root_agent',
+ Part.from_function_response(
+ name='simple_function',
+ response={'result': 'simple_function_call'},
+ ),
+ ),
+ ('root_agent', 'response1'),
+ ]
+
+
+def test_before_tool_callback_modify_tool_request():
+ responses = [
+ types.Part.from_function_call(name='simple_function', args={}),
+ 'response1',
+ ]
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ before_tool_callback=MockBeforeToolCallback(
+ mock_response={'test': 'before_tool_callback'},
+ modify_tool_request=True,
+ ),
+ tools=[simple_function],
+ )
+
+ runner = utils.InMemoryRunner(agent)
+ assert utils.simplify_events(runner.run('test')) == [
+ ('root_agent', Part.from_function_call(name='simple_function', args={})),
+ (
+ 'root_agent',
+ Part.from_function_response(
+ name='simple_function',
+ response={'result': 'modified_input'},
+ ),
+ ),
+ ('root_agent', 'response1'),
+ ]
+
+
+def test_after_tool_callback():
+ responses = [
+ types.Part.from_function_call(
+ name='simple_function', args={'input_str': 'simple_function_call'}
+ ),
+ 'response1',
+ ]
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ after_tool_callback=MockAfterToolCallback(
+ mock_response={'test': 'after_tool_callback'}
+ ),
+ tools=[simple_function],
+ )
+
+ runner = utils.InMemoryRunner(agent)
+ assert utils.simplify_events(runner.run('test')) == [
+ (
+ 'root_agent',
+ Part.from_function_call(
+ name='simple_function', args={'input_str': 'simple_function_call'}
+ ),
+ ),
+ (
+ 'root_agent',
+ Part.from_function_response(
+ name='simple_function', response={'test': 'after_tool_callback'}
+ ),
+ ),
+ ('root_agent', 'response1'),
+ ]
+
+
+def test_after_tool_callback_noop():
+ responses = [
+ types.Part.from_function_call(
+ name='simple_function', args={'input_str': 'simple_function_call'}
+ ),
+ 'response1',
+ ]
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ after_tool_callback=noop_callback,
+ tools=[simple_function],
+ )
+
+ runner = utils.InMemoryRunner(agent)
+ assert utils.simplify_events(runner.run('test')) == [
+ (
+ 'root_agent',
+ Part.from_function_call(
+ name='simple_function', args={'input_str': 'simple_function_call'}
+ ),
+ ),
+ (
+ 'root_agent',
+ Part.from_function_response(
+ name='simple_function',
+ response={'result': 'simple_function_call'},
+ ),
+ ),
+ ('root_agent', 'response1'),
+ ]
+
+
+def test_after_tool_callback_modify_tool_response():
+ responses = [
+ types.Part.from_function_call(
+ name='simple_function', args={'input_str': 'simple_function_call'}
+ ),
+ 'response1',
+ ]
+ mock_model = utils.MockModel.create(responses=responses)
+ agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ after_tool_callback=MockAfterToolCallback(
+ mock_response={'result': 'after_tool_callback'},
+ modify_tool_response=True,
+ ),
+ tools=[simple_function],
+ )
+
+ runner = utils.InMemoryRunner(agent)
+ assert utils.simplify_events(runner.run('test')) == [
+ (
+ 'root_agent',
+ Part.from_function_call(
+ name='simple_function', args={'input_str': 'simple_function_call'}
+ ),
+ ),
+ (
+ 'root_agent',
+ Part.from_function_response(
+ name='simple_function',
+ response={'result': 'modified_output'},
+ ),
+ ),
+ ('root_agent', 'response1'),
+ ]
diff --git a/src/google/adk/tests/unittests/models/__init__.py b/src/google/adk/tests/unittests/models/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/unittests/models/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/unittests/models/test_google_llm.py b/src/google/adk/tests/unittests/models/test_google_llm.py
new file mode 100644
index 0000000..73f6167
--- /dev/null
+++ b/src/google/adk/tests/unittests/models/test_google_llm.py
@@ -0,0 +1,224 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from unittest import mock
+
+from google.adk import version
+from google.adk.models.gemini_llm_connection import GeminiLlmConnection
+from google.adk.models.google_llm import Gemini
+from google.adk.models.llm_request import LlmRequest
+from google.adk.models.llm_response import LlmResponse
+from google.genai import types
+from google.genai.types import Content
+from google.genai.types import Part
+import pytest
+
+
+@pytest.fixture
+def generate_content_response():
+ return types.GenerateContentResponse(
+ candidates=[
+ types.Candidate(
+ content=Content(
+ role="model",
+ parts=[Part.from_text(text="Hello, how can I help you?")],
+ ),
+ finish_reason=types.FinishReason.STOP,
+ )
+ ]
+ )
+
+
+@pytest.fixture
+def gemini_llm():
+ return Gemini(model="gemini-1.5-flash")
+
+
+@pytest.fixture
+def llm_request():
+ return LlmRequest(
+ model="gemini-1.5-flash",
+ contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
+ config=types.GenerateContentConfig(
+ temperature=0.1,
+ response_modalities=[types.Modality.TEXT],
+ system_instruction="You are a helpful assistant",
+ ),
+ )
+
+
+def test_supported_models():
+ models = Gemini.supported_models()
+ assert len(models) == 3
+ assert models[0] == r"gemini-.*"
+ assert models[1] == r"projects\/.+\/locations\/.+\/endpoints\/.+"
+ assert (
+ models[2]
+ == r"projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+"
+ )
+
+
+def test_client_version_header():
+ model = Gemini(model="gemini-1.5-flash")
+ client = model.api_client
+ expected_header = (
+ f"google-adk/{version.__version__}"
+ f" gl-python/{sys.version.split()[0]} google-genai-sdk/"
+ )
+ assert (
+ expected_header
+ in client._api_client._http_options.headers["x-goog-api-client"]
+ )
+ assert (
+ expected_header in client._api_client._http_options.headers["user-agent"]
+ )
+
+
+def test_maybe_append_user_content(gemini_llm, llm_request):
+ # Test with user content already present
+ gemini_llm._maybe_append_user_content(llm_request)
+ assert len(llm_request.contents) == 1
+
+ # Test with model content as the last message
+ llm_request.contents.append(
+ Content(role="model", parts=[Part.from_text(text="Response")])
+ )
+ gemini_llm._maybe_append_user_content(llm_request)
+ assert len(llm_request.contents) == 3
+ assert llm_request.contents[-1].role == "user"
+ assert "Continue processing" in llm_request.contents[-1].parts[0].text
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async(
+ gemini_llm, llm_request, generate_content_response
+):
+ with mock.patch.object(gemini_llm, "api_client") as mock_client:
+ # Create a mock coroutine that returns the generate_content_response
+ async def mock_coro():
+ return generate_content_response
+
+ # Assign the coroutine to the mocked method
+ mock_client.aio.models.generate_content.return_value = mock_coro()
+
+ responses = [
+ resp
+ async for resp in gemini_llm.generate_content_async(
+ llm_request, stream=False
+ )
+ ]
+
+ assert len(responses) == 1
+ assert isinstance(responses[0], LlmResponse)
+ assert responses[0].content.parts[0].text == "Hello, how can I help you?"
+ mock_client.aio.models.generate_content.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream(gemini_llm, llm_request):
+ with mock.patch.object(gemini_llm, "api_client") as mock_client:
+ # Create mock stream responses
+ class MockAsyncIterator:
+
+ def __init__(self, seq):
+ self.iter = iter(seq)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ try:
+ return next(self.iter)
+ except StopIteration:
+ raise StopAsyncIteration
+
+ mock_responses = [
+ types.GenerateContentResponse(
+ candidates=[
+ types.Candidate(
+ content=Content(
+ role="model", parts=[Part.from_text(text="Hello")]
+ ),
+ finish_reason=None,
+ )
+ ]
+ ),
+ types.GenerateContentResponse(
+ candidates=[
+ types.Candidate(
+ content=Content(
+ role="model", parts=[Part.from_text(text=", how")]
+ ),
+ finish_reason=None,
+ )
+ ]
+ ),
+ types.GenerateContentResponse(
+ candidates=[
+ types.Candidate(
+ content=Content(
+ role="model",
+ parts=[Part.from_text(text=" can I help you?")],
+ ),
+ finish_reason=types.FinishReason.STOP,
+ )
+ ]
+ ),
+ ]
+
+ # Create a mock coroutine that returns the MockAsyncIterator
+ async def mock_coro():
+ return MockAsyncIterator(mock_responses)
+
+ # Set the mock to return the coroutine
+ mock_client.aio.models.generate_content_stream.return_value = mock_coro()
+
+ responses = [
+ resp
+ async for resp in gemini_llm.generate_content_async(
+ llm_request, stream=True
+ )
+ ]
+
+ # Assertions remain the same
+ assert len(responses) == 4
+ assert responses[0].partial is True
+ assert responses[1].partial is True
+ assert responses[2].partial is True
+ assert responses[3].content.parts[0].text == "Hello, how can I help you?"
+ mock_client.aio.models.generate_content_stream.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_connect(gemini_llm, llm_request):
+ # Create a mock connection
+ mock_connection = mock.MagicMock(spec=GeminiLlmConnection)
+
+ # Create a mock context manager
+ class MockContextManager:
+
+ async def __aenter__(self):
+ return mock_connection
+
+ async def __aexit__(self, *args):
+ pass
+
+ # Mock the connect method at the class level
+ with mock.patch(
+ "google.adk.models.google_llm.Gemini.connect",
+ return_value=MockContextManager(),
+ ):
+ async with gemini_llm.connect(llm_request) as connection:
+ assert connection is mock_connection
diff --git a/src/google/adk/tests/unittests/models/test_litellm.py b/src/google/adk/tests/unittests/models/test_litellm.py
new file mode 100644
index 0000000..6b09b1c
--- /dev/null
+++ b/src/google/adk/tests/unittests/models/test_litellm.py
@@ -0,0 +1,804 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from unittest.mock import AsyncMock
+from unittest.mock import Mock
+from google.adk.models.lite_llm import _content_to_message_param
+from google.adk.models.lite_llm import _function_declaration_to_tool_param
+from google.adk.models.lite_llm import _get_content
+from google.adk.models.lite_llm import _message_to_generate_content_response
+from google.adk.models.lite_llm import _model_response_to_chunk
+from google.adk.models.lite_llm import _to_litellm_role
+from google.adk.models.lite_llm import FunctionChunk
+from google.adk.models.lite_llm import LiteLlm
+from google.adk.models.lite_llm import LiteLLMClient
+from google.adk.models.lite_llm import TextChunk
+from google.adk.models.llm_request import LlmRequest
+from google.genai import types
+from litellm import ChatCompletionAssistantMessage
+from litellm import ChatCompletionMessageToolCall
+from litellm import Function
+from litellm.types.utils import ChatCompletionDeltaToolCall
+from litellm.types.utils import Choices
+from litellm.types.utils import Delta
+from litellm.types.utils import ModelResponse
+from litellm.types.utils import StreamingChoices
+import pytest
+
+LLM_REQUEST_WITH_FUNCTION_DECLARATION = LlmRequest(
+ contents=[
+ types.Content(
+ role="user", parts=[types.Part.from_text(text="Test prompt")]
+ )
+ ],
+ config=types.GenerateContentConfig(
+ tools=[
+ types.Tool(
+ function_declarations=[
+ types.FunctionDeclaration(
+ name="test_function",
+ description="Test function description",
+ parameters=types.Schema(
+ type=types.Type.OBJECT,
+ properties={
+ "test_arg": types.Schema(
+ type=types.Type.STRING
+ ),
+ "array_arg": types.Schema(
+ type=types.Type.ARRAY,
+ items={
+ "type": types.Type.STRING,
+ },
+ ),
+ "nested_arg": types.Schema(
+ type=types.Type.OBJECT,
+ properties={
+ "nested_key1": types.Schema(
+ type=types.Type.STRING
+ ),
+ "nested_key2": types.Schema(
+ type=types.Type.STRING
+ ),
+ },
+ ),
+ },
+ ),
+ )
+ ]
+ )
+ ],
+ ),
+)
+
+
+STREAMING_MODEL_RESPONSE = [
+ ModelResponse(
+ choices=[
+ StreamingChoices(
+ finish_reason=None,
+ delta=Delta(
+ role="assistant",
+ content="zero, ",
+ ),
+ )
+ ]
+ ),
+ ModelResponse(
+ choices=[
+ StreamingChoices(
+ finish_reason=None,
+ delta=Delta(
+ role="assistant",
+ content="one, ",
+ ),
+ )
+ ]
+ ),
+ ModelResponse(
+ choices=[
+ StreamingChoices(
+ finish_reason=None,
+ delta=Delta(
+ role="assistant",
+ content="two:",
+ ),
+ )
+ ]
+ ),
+ ModelResponse(
+ choices=[
+ StreamingChoices(
+ finish_reason=None,
+ delta=Delta(
+ role="assistant",
+ tool_calls=[
+ ChatCompletionDeltaToolCall(
+ type="function",
+ id="test_tool_call_id",
+ function=Function(
+ name="test_function",
+ arguments='{"test_arg": "test_',
+ ),
+ index=0,
+ )
+ ],
+ ),
+ )
+ ]
+ ),
+ ModelResponse(
+ choices=[
+ StreamingChoices(
+ finish_reason=None,
+ delta=Delta(
+ role="assistant",
+ tool_calls=[
+ ChatCompletionDeltaToolCall(
+ type="function",
+ id=None,
+ function=Function(
+ name=None,
+ arguments='value"}',
+ ),
+ index=0,
+ )
+ ],
+ ),
+ )
+ ]
+ ),
+ ModelResponse(
+ choices=[
+ StreamingChoices(
+ finish_reason="tool_use",
+ )
+ ]
+ ),
+]
+
+@pytest.fixture
+def mock_response():
+ return ModelResponse(
+ choices=[
+ Choices(
+ message=ChatCompletionAssistantMessage(
+ role="assistant",
+ content="Test response",
+ tool_calls=[
+ ChatCompletionMessageToolCall(
+ type="function",
+ id="test_tool_call_id",
+ function=Function(
+ name="test_function",
+ arguments='{"test_arg": "test_value"}',
+ ),
+ )
+ ],
+ )
+ )
+ ]
+ )
+
+
+@pytest.fixture
+def mock_acompletion(mock_response):
+ return AsyncMock(return_value=mock_response)
+
+
+@pytest.fixture
+def mock_completion(mock_response):
+ return Mock(return_value=mock_response)
+
+
+@pytest.fixture
+def mock_client(mock_acompletion, mock_completion):
+ return MockLLMClient(mock_acompletion, mock_completion)
+
+
+@pytest.fixture
+def lite_llm_instance(mock_client):
+ return LiteLlm(model="test_model", llm_client=mock_client)
+
+
+class MockLLMClient(LiteLLMClient):
+
+ def __init__(self, acompletion_mock, completion_mock):
+ self.acompletion_mock = acompletion_mock
+ self.completion_mock = completion_mock
+
+ async def acompletion(self, model, messages, tools, **kwargs):
+ return await self.acompletion_mock(
+ model=model, messages=messages, tools=tools, **kwargs
+ )
+
+ def completion(self, model, messages, tools, stream, **kwargs):
+ return self.completion_mock(
+ model=model, messages=messages, tools=tools, stream=stream, **kwargs
+ )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async(mock_acompletion, lite_llm_instance):
+
+ async for response in lite_llm_instance.generate_content_async(
+ LLM_REQUEST_WITH_FUNCTION_DECLARATION
+ ):
+ assert response.content.role == "model"
+ assert response.content.parts[0].text == "Test response"
+ assert response.content.parts[1].function_call.name == "test_function"
+ assert response.content.parts[1].function_call.args == {
+ "test_arg": "test_value"
+ }
+ assert response.content.parts[1].function_call.id == "test_tool_call_id"
+
+ mock_acompletion.assert_called_once()
+
+ _, kwargs = mock_acompletion.call_args
+ assert kwargs["model"] == "test_model"
+ assert kwargs["messages"][0]["role"] == "user"
+ assert kwargs["messages"][0]["content"] == "Test prompt"
+ assert kwargs["tools"][0]["function"]["name"] == "test_function"
+ assert (
+ kwargs["tools"][0]["function"]["description"]
+ == "Test function description"
+ )
+ assert (
+ kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
+ "type"
+ ]
+ == "string"
+ )
+
+
+function_declaration_test_cases = [
+ (
+ "simple_function",
+ types.FunctionDeclaration(
+ name="test_function",
+ description="Test function description",
+ parameters=types.Schema(
+ type=types.Type.OBJECT,
+ properties={
+ "test_arg": types.Schema(type=types.Type.STRING),
+ "array_arg": types.Schema(
+ type=types.Type.ARRAY,
+ items=types.Schema(
+ type=types.Type.STRING,
+ ),
+ ),
+ "nested_arg": types.Schema(
+ type=types.Type.OBJECT,
+ properties={
+ "nested_key1": types.Schema(type=types.Type.STRING),
+ "nested_key2": types.Schema(type=types.Type.STRING),
+ },
+ ),
+ },
+ ),
+ ),
+ {
+ "type": "function",
+ "function": {
+ "name": "test_function",
+ "description": "Test function description",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "test_arg": {"type": "string"},
+ "array_arg": {
+ "items": {"type": "string"},
+ "type": "array",
+ },
+ "nested_arg": {
+ "properties": {
+ "nested_key1": {"type": "string"},
+ "nested_key2": {"type": "string"},
+ },
+ "type": "object",
+ },
+ },
+ },
+ },
+ },
+ ),
+ (
+ "no_description",
+ types.FunctionDeclaration(
+ name="test_function_no_description",
+ parameters=types.Schema(
+ type=types.Type.OBJECT,
+ properties={
+ "test_arg": types.Schema(type=types.Type.STRING),
+ },
+ ),
+ ),
+ {
+ "type": "function",
+ "function": {
+ "name": "test_function_no_description",
+ "description": "",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "test_arg": {"type": "string"},
+ },
+ },
+ },
+ },
+ ),
+ (
+ "empty_parameters",
+ types.FunctionDeclaration(
+ name="test_function_empty_params",
+ parameters=types.Schema(type=types.Type.OBJECT, properties={}),
+ ),
+ {
+ "type": "function",
+ "function": {
+ "name": "test_function_empty_params",
+ "description": "",
+ "parameters": {
+ "type": "object",
+ "properties": {},
+ },
+ },
+ },
+ ),
+ (
+ "nested_array",
+ types.FunctionDeclaration(
+ name="test_function_nested_array",
+ parameters=types.Schema(
+ type=types.Type.OBJECT,
+ properties={
+ "array_arg": types.Schema(
+ type=types.Type.ARRAY,
+ items=types.Schema(
+ type=types.Type.OBJECT,
+ properties={
+ "nested_key": types.Schema(
+ type=types.Type.STRING
+ )
+ },
+ ),
+ ),
+ },
+ ),
+ ),
+ {
+ "type": "function",
+ "function": {
+ "name": "test_function_nested_array",
+ "description": "",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "array_arg": {
+ "items": {
+ "properties": {
+ "nested_key": {"type": "string"}
+ },
+ "type": "object",
+ },
+ "type": "array",
+ },
+ },
+ },
+ },
+ },
+ ),
+]
+
+
+@pytest.mark.parametrize(
+ "_, function_declaration, expected_output",
+ function_declaration_test_cases,
+ ids=[case[0] for case in function_declaration_test_cases],
+)
+def test_function_declaration_to_tool_param(
+ _, function_declaration, expected_output
+):
+ assert (
+ _function_declaration_to_tool_param(function_declaration)
+ == expected_output
+ )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_with_system_instruction(
+ lite_llm_instance, mock_acompletion
+):
+ mock_response_with_system_instruction = ModelResponse(
+ choices=[
+ Choices(
+ message=ChatCompletionAssistantMessage(
+ role="assistant",
+ content="Test response",
+ )
+ )
+ ]
+ )
+ mock_acompletion.return_value = mock_response_with_system_instruction
+
+ llm_request = LlmRequest(
+ contents=[
+ types.Content(
+ role="user", parts=[types.Part.from_text(text="Test prompt")]
+ )
+ ],
+ config=types.GenerateContentConfig(
+ system_instruction="Test system instruction"
+ ),
+ )
+
+ async for response in lite_llm_instance.generate_content_async(llm_request):
+ assert response.content.role == "model"
+ assert response.content.parts[0].text == "Test response"
+
+ mock_acompletion.assert_called_once()
+
+ _, kwargs = mock_acompletion.call_args
+ assert kwargs["model"] == "test_model"
+ assert kwargs["messages"][0]["role"] == "developer"
+ assert kwargs["messages"][0]["content"] == "Test system instruction"
+ assert kwargs["messages"][1]["role"] == "user"
+ assert kwargs["messages"][1]["content"] == "Test prompt"
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_with_tool_response(
+ lite_llm_instance, mock_acompletion
+):
+ mock_response_with_tool_response = ModelResponse(
+ choices=[
+ Choices(
+ message=ChatCompletionAssistantMessage(
+ role="tool",
+ content='{"result": "test_result"}',
+ tool_call_id="test_tool_call_id",
+ )
+ )
+ ]
+ )
+ mock_acompletion.return_value = mock_response_with_tool_response
+
+ llm_request = LlmRequest(
+ contents=[
+ types.Content(
+ role="user", parts=[types.Part.from_text(text="Test prompt")]
+ ),
+ types.Content(
+ role="tool",
+ parts=[
+ types.Part.from_function_response(
+ name="test_function",
+ response={"result": "test_result"},
+ )
+ ],
+ ),
+ ],
+ config=types.GenerateContentConfig(
+ system_instruction="test instruction",
+ ),
+ )
+ async for response in lite_llm_instance.generate_content_async(llm_request):
+ assert response.content.role == "model"
+ assert response.content.parts[0].text == '{"result": "test_result"}'
+
+ mock_acompletion.assert_called_once()
+
+ _, kwargs = mock_acompletion.call_args
+ assert kwargs["model"] == "test_model"
+
+ assert kwargs["messages"][2]["role"] == "tool"
+ assert kwargs["messages"][2]["content"] == '{"result": "test_result"}'
+
+
+def test_content_to_message_param_user_message():
+ content = types.Content(
+ role="user", parts=[types.Part.from_text(text="Test prompt")]
+ )
+ message = _content_to_message_param(content)
+ assert message["role"] == "user"
+ assert message["content"] == "Test prompt"
+
+
+def test_content_to_message_param_assistant_message():
+ content = types.Content(
+ role="assistant", parts=[types.Part.from_text(text="Test response")]
+ )
+ message = _content_to_message_param(content)
+ assert message["role"] == "assistant"
+ assert message["content"] == "Test response"
+
+
+def test_content_to_message_param_function_call():
+ content = types.Content(
+ role="assistant",
+ parts=[
+ types.Part.from_function_call(
+ name="test_function", args={"test_arg": "test_value"}
+ )
+ ],
+ )
+ content.parts[0].function_call.id = "test_tool_call_id"
+ message = _content_to_message_param(content)
+ assert message["role"] == "assistant"
+ assert message["content"] == []
+ assert message["tool_calls"][0].type == "function"
+ assert message["tool_calls"][0].id == "test_tool_call_id"
+ assert message["tool_calls"][0].function.name == "test_function"
+ assert (
+ message["tool_calls"][0].function.arguments
+ == '{"test_arg": "test_value"}'
+ )
+
+
+def test_message_to_generate_content_response_text():
+ message = ChatCompletionAssistantMessage(
+ role="assistant",
+ content="Test response",
+ )
+ response = _message_to_generate_content_response(message)
+ assert response.content.role == "model"
+ assert response.content.parts[0].text == "Test response"
+
+
+def test_message_to_generate_content_response_tool_call():
+ message = ChatCompletionAssistantMessage(
+ role="assistant",
+ content=None,
+ tool_calls=[
+ ChatCompletionMessageToolCall(
+ type="function",
+ id="test_tool_call_id",
+ function=Function(
+ name="test_function",
+ arguments='{"test_arg": "test_value"}',
+ ),
+ )
+ ],
+ )
+
+ response = _message_to_generate_content_response(message)
+ assert response.content.role == "model"
+ assert response.content.parts[0].function_call.name == "test_function"
+ assert response.content.parts[0].function_call.args == {
+ "test_arg": "test_value"
+ }
+ assert response.content.parts[0].function_call.id == "test_tool_call_id"
+
+
+def test_get_content_text():
+ parts = [types.Part.from_text(text="Test text")]
+ content = _get_content(parts)
+ assert content == "Test text"
+
+
+def test_get_content_image():
+ parts = [
+ types.Part.from_bytes(data=b"test_image_data", mime_type="image/png")
+ ]
+ content = _get_content(parts)
+ assert content[0]["type"] == "image_url"
+ assert content[0]["image_url"] == ""
+
+
+def test_get_content_video():
+ parts = [
+ types.Part.from_bytes(data=b"test_video_data", mime_type="video/mp4")
+ ]
+ content = _get_content(parts)
+ assert content[0]["type"] == "video_url"
+ assert content[0]["video_url"] == "data:video/mp4;base64,dGVzdF92aWRlb19kYXRh"
+
+
+def test_to_litellm_role():
+ assert _to_litellm_role("model") == "assistant"
+ assert _to_litellm_role("assistant") == "assistant"
+ assert _to_litellm_role("user") == "user"
+ assert _to_litellm_role(None) == "user"
+
+
+@pytest.mark.parametrize(
+ "response, expected_chunk, expected_finished",
+ [
+ (
+ ModelResponse(
+ choices=[
+ {
+ "message": {
+ "content": "this is a test",
+ }
+ }
+ ]
+ ),
+ TextChunk(text="this is a test"),
+ "stop",
+ ),
+ (
+ ModelResponse(
+ choices=[
+ StreamingChoices(
+ finish_reason=None,
+ delta=Delta(
+ role="assistant",
+ tool_calls=[
+ ChatCompletionDeltaToolCall(
+ type="function",
+ id="1",
+ function=Function(
+ name="test_function",
+ arguments='{"key": "va',
+ ),
+ index=0,
+ )
+ ],
+ ),
+ )
+ ]
+ ),
+ FunctionChunk(id="1", name="test_function", args='{"key": "va'),
+ None,
+ ),
+ (
+ ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
+ None,
+ "tool_calls",
+ ),
+ (ModelResponse(choices=[{}]), None, "stop"),
+ ],
+)
+def test_model_response_to_chunk(response, expected_chunk, expected_finished):
+ result = list(_model_response_to_chunk(response))
+ assert len(result) == 1
+ chunk, finished = result[0]
+ if expected_chunk:
+ assert isinstance(chunk, type(expected_chunk))
+ assert chunk == expected_chunk
+ else:
+ assert chunk is None
+ assert finished == expected_finished
+
+
+@pytest.mark.asyncio
+async def test_acompletion_additional_args(mock_acompletion, mock_client):
+ lite_llm_instance = LiteLlm(
+ # valid args
+ model="test_model",
+ llm_client=mock_client,
+ api_key="test_key",
+ api_base="some://url",
+ api_version="2024-09-12",
+ # invalid args (ignored)
+ stream=True,
+ messages=[{"role": "invalid", "content": "invalid"}],
+ tools=[{
+ "type": "function",
+ "function": {
+ "name": "invalid",
+ },
+ }],
+ )
+
+ async for response in lite_llm_instance.generate_content_async(
+ LLM_REQUEST_WITH_FUNCTION_DECLARATION
+ ):
+ assert response.content.role == "model"
+ assert response.content.parts[0].text == "Test response"
+ assert response.content.parts[1].function_call.name == "test_function"
+ assert response.content.parts[1].function_call.args == {
+ "test_arg": "test_value"
+ }
+ assert response.content.parts[1].function_call.id == "test_tool_call_id"
+
+ mock_acompletion.assert_called_once()
+
+ _, kwargs = mock_acompletion.call_args
+
+ assert kwargs["model"] == "test_model"
+ assert kwargs["messages"][0]["role"] == "user"
+ assert kwargs["messages"][0]["content"] == "Test prompt"
+ assert kwargs["tools"][0]["function"]["name"] == "test_function"
+ assert "stream" not in kwargs
+ assert "llm_client" not in kwargs
+ assert kwargs["api_base"] == "some://url"
+
+
+@pytest.mark.asyncio
+async def test_completion_additional_args(mock_completion, mock_client):
+ lite_llm_instance = LiteLlm(
+ # valid args
+ model="test_model",
+ llm_client=mock_client,
+ api_key="test_key",
+ api_base="some://url",
+ api_version="2024-09-12",
+ # invalid args (ignored)
+ stream=False,
+ messages=[{"role": "invalid", "content": "invalid"}],
+ tools=[{
+ "type": "function",
+ "function": {
+ "name": "invalid",
+ },
+ }],
+ )
+
+ mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE)
+
+ responses = [
+ response
+ async for response in lite_llm_instance.generate_content_async(
+ LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
+ )
+ ]
+ assert len(responses) == 4
+ mock_completion.assert_called_once()
+
+ _, kwargs = mock_completion.call_args
+
+ assert kwargs["model"] == "test_model"
+ assert kwargs["messages"][0]["role"] == "user"
+ assert kwargs["messages"][0]["content"] == "Test prompt"
+ assert kwargs["tools"][0]["function"]["name"] == "test_function"
+ assert kwargs["stream"]
+ assert "llm_client" not in kwargs
+ assert kwargs["api_base"] == "some://url"
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream(
+ mock_completion, lite_llm_instance
+):
+
+ mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE)
+
+ responses = [
+ response
+ async for response in lite_llm_instance.generate_content_async(
+ LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
+ )
+ ]
+ assert len(responses) == 4
+ assert responses[0].content.role == "model"
+ assert responses[0].content.parts[0].text == "zero, "
+ assert responses[1].content.role == "model"
+ assert responses[1].content.parts[0].text == "one, "
+ assert responses[2].content.role == "model"
+ assert responses[2].content.parts[0].text == "two:"
+ assert responses[3].content.role == "model"
+ assert responses[3].content.parts[0].function_call.name == "test_function"
+ assert responses[3].content.parts[0].function_call.args == {
+ "test_arg": "test_value"
+ }
+ assert responses[3].content.parts[0].function_call.id == "test_tool_call_id"
+ mock_completion.assert_called_once()
+
+ _, kwargs = mock_completion.call_args
+ assert kwargs["model"] == "test_model"
+ assert kwargs["messages"][0]["role"] == "user"
+ assert kwargs["messages"][0]["content"] == "Test prompt"
+ assert kwargs["tools"][0]["function"]["name"] == "test_function"
+ assert (
+ kwargs["tools"][0]["function"]["description"]
+ == "Test function description"
+ )
+ assert (
+ kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
+ "type"
+ ]
+ == "string"
+ )
diff --git a/src/google/adk/tests/unittests/models/test_models.py b/src/google/adk/tests/unittests/models/test_models.py
new file mode 100644
index 0000000..fb21171
--- /dev/null
+++ b/src/google/adk/tests/unittests/models/test_models.py
@@ -0,0 +1,60 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk import models
+from google.adk.models.anthropic_llm import Claude
+from google.adk.models.google_llm import Gemini
+from google.adk.models.registry import LLMRegistry
+import pytest
+
+
+@pytest.mark.parametrize(
+ 'model_name',
+ [
+ 'gemini-1.5-flash',
+ 'gemini-1.5-flash-001',
+ 'gemini-1.5-flash-002',
+ 'gemini-1.5-pro',
+ 'gemini-1.5-pro-001',
+ 'gemini-1.5-pro-002',
+ 'gemini-2.0-flash-exp',
+ 'projects/123456/locations/us-central1/endpoints/123456', # finetuned vertex gemini endpoint
+ 'projects/123456/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp', # vertex gemini long name
+ ],
+)
+def test_match_gemini_family(model_name):
+ assert models.LLMRegistry.resolve(model_name) is Gemini
+
+
+@pytest.mark.parametrize(
+ 'model_name',
+ [
+ 'claude-3-5-haiku@20241022',
+ 'claude-3-5-sonnet-v2@20241022',
+ 'claude-3-5-sonnet@20240620',
+ 'claude-3-haiku@20240307',
+ 'claude-3-opus@20240229',
+ 'claude-3-sonnet@20240229',
+ ],
+)
+def test_match_claude_family(model_name):
+ LLMRegistry.register(Claude)
+
+ assert models.LLMRegistry.resolve(model_name) is Claude
+
+
+def test_non_exist_model():
+ with pytest.raises(ValueError) as e_info:
+ models.LLMRegistry.resolve('non-exist-model')
+ assert 'Model non-exist-model not found.' in str(e_info.value)
diff --git a/src/google/adk/tests/unittests/sessions/__init__.py b/src/google/adk/tests/unittests/sessions/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/unittests/sessions/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/unittests/sessions/test_session_service.py b/src/google/adk/tests/unittests/sessions/test_session_service.py
new file mode 100644
index 0000000..953d497
--- /dev/null
+++ b/src/google/adk/tests/unittests/sessions/test_session_service.py
@@ -0,0 +1,227 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+import pytest
+
+from google.adk.events import Event
+from google.adk.events import EventActions
+from google.adk.sessions import DatabaseSessionService
+from google.adk.sessions import InMemorySessionService
+from google.genai import types
+
+
+class SessionServiceType(enum.Enum):
+ IN_MEMORY = 'IN_MEMORY'
+ DATABASE = 'DATABASE'
+
+
+def get_session_service(
+ service_type: SessionServiceType = SessionServiceType.IN_MEMORY,
+):
+ """Creates a session service for testing."""
+ if service_type == SessionServiceType.DATABASE:
+ return DatabaseSessionService('sqlite:///:memory:')
+ return InMemorySessionService()
+
+
+@pytest.mark.parametrize(
+ 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
+)
+def test_get_empty_session(service_type):
+ session_service = get_session_service(service_type)
+ assert not session_service.get_session(
+ app_name='my_app', user_id='test_user', session_id='123'
+ )
+
+
+@pytest.mark.parametrize(
+ 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
+)
+def test_create_get_session(service_type):
+ session_service = get_session_service(service_type)
+ app_name = 'my_app'
+ user_id = 'test_user'
+ state = {'key': 'value'}
+
+ session = session_service.create_session(
+ app_name=app_name, user_id=user_id, state=state
+ )
+ assert session.app_name == app_name
+ assert session.user_id == user_id
+ assert session.id
+ assert session.state == state
+ assert (
+ session_service.get_session(
+ app_name=app_name, user_id=user_id, session_id=session.id
+ )
+ == session
+ )
+
+ session_id = session.id
+ session_service.delete_session(
+ app_name=app_name, user_id=user_id, session_id=session_id
+ )
+
+ assert (
+ not session_service.get_session(
+ app_name=app_name, user_id=user_id, session_id=session.id
+ )
+ == session
+ )
+
+
+@pytest.mark.parametrize(
+ 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
+)
+def test_create_and_list_sessions(service_type):
+ session_service = get_session_service(service_type)
+ app_name = 'my_app'
+ user_id = 'test_user'
+
+ session_ids = ['session' + str(i) for i in range(5)]
+ for session_id in session_ids:
+ session_service.create_session(
+ app_name=app_name, user_id=user_id, session_id=session_id
+ )
+
+ sessions = session_service.list_sessions(
+ app_name=app_name, user_id=user_id
+ ).sessions
+ for i in range(len(sessions)):
+ assert sessions[i].id == session_ids[i]
+
+
+@pytest.mark.parametrize(
+ 'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
+)
+def test_session_state(service_type):
+ session_service = get_session_service(service_type)
+ app_name = 'my_app'
+ user_id_1 = 'user1'
+ user_id_2 = 'user2'
+ session_id_11 = 'session11'
+ session_id_12 = 'session12'
+ session_id_2 = 'session2'
+ state_11 = {'key11': 'value11'}
+ state_12 = {'key12': 'value12'}
+
+ session_11 = session_service.create_session(
+ app_name=app_name,
+ user_id=user_id_1,
+ state=state_11,
+ session_id=session_id_11,
+ )
+ session_service.create_session(
+ app_name=app_name,
+ user_id=user_id_1,
+ state=state_12,
+ session_id=session_id_12,
+ )
+ session_service.create_session(
+ app_name=app_name, user_id=user_id_2, session_id=session_id_2
+ )
+
+ assert session_11.state.get('key11') == 'value11'
+
+ event = Event(
+ invocation_id='invocation',
+ author='user',
+ content=types.Content(role='user', parts=[types.Part(text='text')]),
+ actions=EventActions(
+ state_delta={
+ 'app:key': 'value',
+ 'user:key1': 'value1',
+ 'temp:key': 'temp',
+ 'key11': 'value11_new',
+ }
+ ),
+ )
+ session_service.append_event(session=session_11, event=event)
+
+ # User and app state is stored, temp state is filtered.
+ assert session_11.state.get('app:key') == 'value'
+ assert session_11.state.get('key11') == 'value11_new'
+ assert session_11.state.get('user:key1') == 'value1'
+ assert not session_11.state.get('temp:key')
+
+ session_12 = session_service.get_session(
+ app_name=app_name, user_id=user_id_1, session_id=session_id_12
+ )
+ # After getting a new instance, the session_12 got the user and app state,
+ # even append_event is not applied to it, temp state has no effect
+ assert session_12.state.get('key12') == 'value12'
+ assert not session_12.state.get('temp:key')
+
+ # The user1's state is not visible to user2, app state is visible
+ session_2 = session_service.get_session(
+ app_name=app_name, user_id=user_id_2, session_id=session_id_2
+ )
+ assert session_2.state.get('app:key') == 'value'
+ assert not session_2.state.get('user:key1')
+
+ assert not session_2.state.get('user:key1')
+
+ # The change to session_11 is persisted
+ session_11 = session_service.get_session(
+ app_name=app_name, user_id=user_id_1, session_id=session_id_11
+ )
+ assert session_11.state.get('key11') == 'value11_new'
+ assert session_11.state.get('user:key1') == 'value1'
+ assert not session_11.state.get('temp:key')
+
+
+@pytest.mark.parametrize(
+ "service_type", [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
+)
+def test_create_new_session_will_merge_states(service_type):
+ session_service = get_session_service(service_type)
+ app_name = 'my_app'
+ user_id = 'user'
+ session_id_1 = 'session1'
+ session_id_2 = 'session2'
+ state_1 = {'key1': 'value1'}
+
+ session_1 = session_service.create_session(
+ app_name=app_name, user_id=user_id, state=state_1, session_id=session_id_1
+ )
+
+ event = Event(
+ invocation_id='invocation',
+ author='user',
+ content=types.Content(role='user', parts=[types.Part(text='text')]),
+ actions=EventActions(
+ state_delta={
+ 'app:key': 'value',
+ 'user:key1': 'value1',
+ 'temp:key': 'temp',
+ }
+ ),
+ )
+ session_service.append_event(session=session_1, event=event)
+
+ # User and app state is stored, temp state is filtered.
+ assert session_1.state.get('app:key') == 'value'
+ assert session_1.state.get('key1') == 'value1'
+ assert session_1.state.get('user:key1') == 'value1'
+ assert not session_1.state.get('temp:key')
+
+ session_2 = session_service.create_session(
+ app_name=app_name, user_id=user_id, state={}, session_id=session_id_2
+ )
+ # Session 2 has the persisted states
+ assert session_2.state.get('app:key') == 'value'
+ assert session_2.state.get('user:key1') == 'value1'
+ assert not session_2.state.get('key1')
+ assert not session_2.state.get('temp:key')
diff --git a/src/google/adk/tests/unittests/sessions/test_vertex_ai_session_service.py b/src/google/adk/tests/unittests/sessions/test_vertex_ai_session_service.py
new file mode 100644
index 0000000..d6aad34
--- /dev/null
+++ b/src/google/adk/tests/unittests/sessions/test_vertex_ai_session_service.py
@@ -0,0 +1,246 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import this
+from typing import Any
+import uuid
+from dateutil.parser import isoparse
+from google.adk.events import Event
+from google.adk.events import EventActions
+from google.adk.sessions import Session
+from google.adk.sessions import VertexAiSessionService
+from google.genai import types
+import pytest
+
+
+MOCK_SESSION_JSON_1 = {
+ 'name': (
+ 'projects/test-project/locations/test-location/'
+ 'reasoningEngines/123/sessions/1'
+ ),
+ 'createTime': '2024-12-12T12:12:12.123456Z',
+ 'updateTime': '2024-12-12T12:12:12.123456Z',
+ 'sessionState': {
+ 'key': {'value': 'test_value'},
+ },
+ 'userId': 'user',
+}
+MOCK_SESSION_JSON_2 = {
+ 'name': (
+ 'projects/test-project/locations/test-location/'
+ 'reasoningEngines/123/sessions/2'
+ ),
+ 'updateTime': '2024-12-13T12:12:12.123456Z',
+ 'userId': 'user',
+}
+MOCK_SESSION_JSON_3 = {
+ 'name': (
+ 'projects/test-project/locations/test-location/'
+ 'reasoningEngines/123/sessions/3'
+ ),
+ 'updateTime': '2024-12-14T12:12:12.123456Z',
+ 'userId': 'user2',
+}
+MOCK_EVENT_JSON = [
+ {
+ 'name': (
+ 'projects/test-project/locations/test-location/'
+ 'reasoningEngines/test_engine/sessions/1/events/123'
+ ),
+ 'invocationId': '123',
+ 'author': 'user',
+ 'timestamp': '2024-12-12T12:12:12.123456Z',
+ 'content': {
+ 'parts': [
+ {'text': 'test_content'},
+ ],
+ },
+ 'actions': {
+ 'stateDelta': {
+ 'key': {'value': 'test_value'},
+ },
+ 'transferAgent': 'agent',
+ },
+ 'eventMetadata': {
+ 'partial': False,
+ 'turnComplete': True,
+ 'interrupted': False,
+ 'branch': '',
+ 'longRunningToolIds': ['tool1'],
+ },
+ },
+]
+
+MOCK_SESSION = Session(
+ app_name='123',
+ user_id='user',
+ id='1',
+ state=MOCK_SESSION_JSON_1['sessionState'],
+ last_update_time=isoparse(MOCK_SESSION_JSON_1['updateTime']).timestamp(),
+ events=[
+ Event(
+ id='123',
+ invocation_id='123',
+ author='user',
+ timestamp=isoparse(MOCK_EVENT_JSON[0]['timestamp']).timestamp(),
+ content=types.Content(parts=[types.Part(text='test_content')]),
+ actions=EventActions(
+ transfer_to_agent='agent',
+ state_delta={'key': {'value': 'test_value'}},
+ ),
+ partial=False,
+ turn_complete=True,
+ interrupted=False,
+ branch='',
+ long_running_tool_ids={'tool1'},
+ ),
+ ],
+)
+
+
+SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
+SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions$'
+EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
+LRO_REGEX = r'^operations/([^/]+)$'
+
+
+class MockApiClient:
+ """Mocks the API Client."""
+
+ def __init__(self) -> None:
+ """Initializes MockClient."""
+ this.session_dict: dict[str, Any] = {}
+ this.event_dict: dict[str, list[Any]] = {}
+
+ def request(self, http_method: str, path: str, request_dict: dict[str, Any]):
+ """Mocks the API Client request method."""
+ if http_method == 'GET':
+ if re.match(SESSION_REGEX, path):
+ match = re.match(SESSION_REGEX, path)
+ if match:
+ session_id = match.group(2)
+ if session_id in self.session_dict:
+ return self.session_dict[session_id]
+ else:
+ raise ValueError(f'Session not found: {session_id}')
+ elif re.match(SESSIONS_REGEX, path):
+ return {
+ 'sessions': self.session_dict.values(),
+ }
+ elif re.match(EVENTS_REGEX, path):
+ match = re.match(EVENTS_REGEX, path)
+ if match:
+ return {'sessionEvents': self.event_dict[match.group(2)]}
+ elif re.match(LRO_REGEX, path):
+ return {
+ 'name': (
+ 'projects/test-project/locations/test-location/'
+ 'reasoningEngines/123/sessions/123'
+ ),
+ 'done': True,
+ }
+ else:
+ raise ValueError(f'Unsupported path: {path}')
+ elif http_method == 'POST':
+ id = str(uuid.uuid4())
+ self.session_dict[id] = {
+ 'name': (
+ 'projects/test-project/locations/test-location/'
+ 'reasoningEngines/123/sessions/'
+ + id
+ ),
+ 'userId': request_dict['user_id'],
+ 'sessionState': request_dict.get('sessionState', {}),
+ 'updateTime': '2024-12-12T12:12:12.123456Z',
+ }
+ return {
+ 'name': (
+ 'projects/test_project/locations/test_location/'
+ 'reasoningEngines/test_engine/sessions/123'
+ ),
+ 'done': False,
+ }
+ elif http_method == 'DELETE':
+ match = re.match(SESSION_REGEX, path)
+ if match:
+ self.session_dict.pop(match.group(2))
+ else:
+ raise ValueError(f'Unsupported http method: {http_method}')
+
+
+def mock_vertex_ai_session_service():
+ """Creates a mock Vertex AI Session service for testing."""
+ service = VertexAiSessionService(
+ project='test-project', location='test-location'
+ )
+ service.api_client = MockApiClient()
+ service.api_client.session_dict = {
+ '1': MOCK_SESSION_JSON_1,
+ '2': MOCK_SESSION_JSON_2,
+ '3': MOCK_SESSION_JSON_3,
+ }
+ service.api_client.event_dict = {
+ '1': MOCK_EVENT_JSON,
+ }
+ return service
+
+
+def test_get_empty_session():
+ session_service = mock_vertex_ai_session_service()
+ with pytest.raises(ValueError) as excinfo:
+ assert session_service.get_session(
+ app_name='123', user_id='user', session_id='0'
+ )
+ assert str(excinfo.value) == 'Session not found: 0'
+
+
+def test_get_and_delete_session():
+ session_service = mock_vertex_ai_session_service()
+
+ assert (
+ session_service.get_session(
+ app_name='123', user_id='user', session_id='1'
+ )
+ == MOCK_SESSION
+ )
+
+ session_service.delete_session(app_name='123', user_id='user', session_id='1')
+ with pytest.raises(ValueError) as excinfo:
+ assert session_service.get_session(
+ app_name='123', user_id='user', session_id='1'
+ )
+ assert str(excinfo.value) == 'Session not found: 1'
+
+ def test_list_sessions():
+ session_service = mock_vertex_ai_session_service()
+ sessions = session_service.list_sessions(app_name='123', user_id='user')
+ assert len(sessions.sessions) == 2
+ assert sessions.sessions[0].id == '1'
+ assert sessions.sessions[1].id == '2'
+
+ def test_create_session():
+ session_service = mock_vertex_ai_session_service()
+ session = session_service.create_session(
+ app_name='123', user_id='user', state={'key': 'value'}
+ )
+ assert session.state == {'key': 'value'}
+ assert session.app_name == '123'
+ assert session.user_id == 'user'
+ assert session.last_update_time is not None
+
+ session_id = session.id
+ assert session == session_service.get_session(
+ app_name='123', user_id='user', session_id=session_id
+ )
diff --git a/src/google/adk/tests/unittests/streaming/__init__.py b/src/google/adk/tests/unittests/streaming/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/unittests/streaming/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/unittests/streaming/test_streaming.py b/src/google/adk/tests/unittests/streaming/test_streaming.py
new file mode 100644
index 0000000..00d387b
--- /dev/null
+++ b/src/google/adk/tests/unittests/streaming/test_streaming.py
@@ -0,0 +1,50 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.agents import Agent
+from google.adk.agents import LiveRequestQueue
+from google.adk.models import LlmResponse
+from google.genai import types
+import pytest
+
+from .. import utils
+
+
+@pytest.mark.skip(reason='Streaming is hanging.')
+def test_streaming():
+ response1 = LlmResponse(
+ turn_complete=True,
+ )
+
+ mock_model = utils.MockModel.create([response1])
+
+ root_agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ tools=[],
+ )
+
+ runner = utils.InMemoryRunner(
+ root_agent=root_agent, response_modalities=['AUDIO']
+ )
+ live_request_queue = LiveRequestQueue()
+ live_request_queue.send_realtime(
+ blob=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm')
+ )
+ res_events = runner.run_live(live_request_queue)
+
+ assert res_events is not None, 'Expected a list of events, got None.'
+ assert (
+ len(res_events) > 0
+ ), 'Expected at least one response, but got an empty list.'
diff --git a/src/google/adk/tests/unittests/tools/__init__.py b/src/google/adk/tests/unittests/tools/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/unittests/tools/apihub_tool/clients/test_apihub_client.py b/src/google/adk/tests/unittests/tools/apihub_tool/clients/test_apihub_client.py
new file mode 100644
index 0000000..9a84ee9
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/apihub_tool/clients/test_apihub_client.py
@@ -0,0 +1,499 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import json
+from unittest.mock import MagicMock, patch
+from google.adk.tools.apihub_tool.clients.apihub_client import APIHubClient
+import pytest
+from requests.exceptions import HTTPError
+
+# Mock data for API responses
+MOCK_API_LIST = {
+ "apis": [
+ {"name": "projects/test-project/locations/us-central1/apis/api1"},
+ {"name": "projects/test-project/locations/us-central1/apis/api2"},
+ ]
+}
+MOCK_API_DETAIL = {
+ "name": "projects/test-project/locations/us-central1/apis/api1",
+ "versions": [
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1"
+ ],
+}
+MOCK_API_VERSION = {
+ "name": "projects/test-project/locations/us-central1/apis/api1/versions/v1",
+ "specs": [
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
+ ],
+}
+MOCK_SPEC_CONTENT = {"contents": base64.b64encode(b"spec content").decode()}
+
+
+# Test cases
+class TestAPIHubClient:
+
+ @pytest.fixture
+ def client(self):
+ return APIHubClient(access_token="mocked_token")
+
+ @pytest.fixture
+ def service_account_config(self):
+ return json.dumps({
+ "type": "service_account",
+ "project_id": "test",
+ "token_uri": "test.com",
+ "client_email": "test@example.com",
+ "private_key": "1234",
+ })
+
+ @patch("requests.get")
+ def test_list_apis(self, mock_get, client):
+ mock_get.return_value.json.return_value = MOCK_API_LIST
+ mock_get.return_value.status_code = 200
+
+ apis = client.list_apis("test-project", "us-central1")
+ assert apis == MOCK_API_LIST["apis"]
+ mock_get.assert_called_once_with(
+ "https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis",
+ headers={
+ "accept": "application/json, text/plain, */*",
+ "Authorization": "Bearer mocked_token",
+ },
+ )
+
+ @patch("requests.get")
+ def test_list_apis_empty(self, mock_get, client):
+ mock_get.return_value.json.return_value = {"apis": []}
+ mock_get.return_value.status_code = 200
+
+ apis = client.list_apis("test-project", "us-central1")
+ assert apis == []
+
+ @patch("requests.get")
+ def test_list_apis_error(self, mock_get, client):
+ mock_get.return_value.raise_for_status.side_effect = HTTPError
+
+ with pytest.raises(HTTPError):
+ client.list_apis("test-project", "us-central1")
+
+ @patch("requests.get")
+ def test_get_api(self, mock_get, client):
+ mock_get.return_value.json.return_value = MOCK_API_DETAIL
+ mock_get.return_value.status_code = 200
+ api = client.get_api(
+ "projects/test-project/locations/us-central1/apis/api1"
+ )
+ assert api == MOCK_API_DETAIL
+ mock_get.assert_called_once_with(
+ "https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1",
+ headers={
+ "accept": "application/json, text/plain, */*",
+ "Authorization": "Bearer mocked_token",
+ },
+ )
+
+ @patch("requests.get")
+ def test_get_api_error(self, mock_get, client):
+ mock_get.return_value.raise_for_status.side_effect = HTTPError
+ with pytest.raises(HTTPError):
+ client.get_api("projects/test-project/locations/us-central1/apis/api1")
+
+ @patch("requests.get")
+ def test_get_api_version(self, mock_get, client):
+ mock_get.return_value.json.return_value = MOCK_API_VERSION
+ mock_get.return_value.status_code = 200
+ api_version = client.get_api_version(
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1"
+ )
+ assert api_version == MOCK_API_VERSION
+ mock_get.assert_called_once_with(
+ "https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1/versions/v1",
+ headers={
+ "accept": "application/json, text/plain, */*",
+ "Authorization": "Bearer mocked_token",
+ },
+ )
+
+ @patch("requests.get")
+ def test_get_api_version_error(self, mock_get, client):
+ mock_get.return_value.raise_for_status.side_effect = HTTPError
+ with pytest.raises(HTTPError):
+ client.get_api_version(
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1"
+ )
+
+ @patch("requests.get")
+ def test_get_spec_content(self, mock_get, client):
+ mock_get.return_value.json.return_value = MOCK_SPEC_CONTENT
+ mock_get.return_value.status_code = 200
+ spec_content = client.get_spec_content(
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
+ )
+ assert spec_content == "spec content"
+ mock_get.assert_called_once_with(
+ "https://apihub.googleapis.com/v1/projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1:contents",
+ headers={
+ "accept": "application/json, text/plain, */*",
+ "Authorization": "Bearer mocked_token",
+ },
+ )
+
+ @patch("requests.get")
+ def test_get_spec_content_empty(self, mock_get, client):
+ mock_get.return_value.json.return_value = {"contents": ""}
+ mock_get.return_value.status_code = 200
+ spec_content = client.get_spec_content(
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
+ )
+ assert spec_content == ""
+
+ @patch("requests.get")
+ def test_get_spec_content_error(self, mock_get, client):
+ mock_get.return_value.raise_for_status.side_effect = HTTPError
+ with pytest.raises(HTTPError):
+ client.get_spec_content(
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
+ )
+
+ @pytest.mark.parametrize(
+ "url_or_path, expected",
+ [
+ (
+ "projects/test-project/locations/us-central1/apis/api1",
+ (
+ "projects/test-project/locations/us-central1/apis/api1",
+ None,
+ None,
+ ),
+ ),
+ (
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1",
+ (
+ "projects/test-project/locations/us-central1/apis/api1",
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1",
+ None,
+ ),
+ ),
+ (
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
+ (
+ "projects/test-project/locations/us-central1/apis/api1",
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1",
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
+ ),
+ ),
+ (
+ "https://console.cloud.google.com/apigee/api-hub/projects/test-project/locations/us-central1/apis/api1/versions/v1?project=test-project",
+ (
+ "projects/test-project/locations/us-central1/apis/api1",
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1",
+ None,
+ ),
+ ),
+ (
+ "https://console.cloud.google.com/apigee/api-hub/projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1?project=test-project",
+ (
+ "projects/test-project/locations/us-central1/apis/api1",
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1",
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1",
+ ),
+ ),
+ (
+ "/projects/test-project/locations/us-central1/apis/api1/versions/v1",
+ (
+ "projects/test-project/locations/us-central1/apis/api1",
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1",
+ None,
+ ),
+ ),
+ ( # Added trailing slashes
+ "projects/test-project/locations/us-central1/apis/api1/",
+ (
+ "projects/test-project/locations/us-central1/apis/api1",
+ None,
+ None,
+ ),
+ ),
+ ( # case location name
+ "projects/test-project/locations/LOCATION/apis/api1/",
+ (
+ "projects/test-project/locations/LOCATION/apis/api1",
+ None,
+ None,
+ ),
+ ),
+ (
+ "projects/p1/locations/l1/apis/a1/versions/v1/specs/s1",
+ (
+ "projects/p1/locations/l1/apis/a1",
+ "projects/p1/locations/l1/apis/a1/versions/v1",
+ "projects/p1/locations/l1/apis/a1/versions/v1/specs/s1",
+ ),
+ ),
+ ],
+ )
+ def test_extract_resource_name(self, client, url_or_path, expected):
+ result = client._extract_resource_name(url_or_path)
+ assert result == expected
+
+ @pytest.mark.parametrize(
+ "url_or_path, expected_error_message",
+ [
+ (
+ "invalid-path",
+ "Project ID not found in URL or path in APIHubClient.",
+ ),
+ (
+ "projects/test-project",
+ "Location not found in URL or path in APIHubClient.",
+ ),
+ (
+ "projects/test-project/locations/us-central1",
+ "API id not found in URL or path in APIHubClient.",
+ ),
+ ],
+ )
+ def test_extract_resource_name_invalid(
+ self, client, url_or_path, expected_error_message
+ ):
+ with pytest.raises(ValueError, match=expected_error_message):
+ client._extract_resource_name(url_or_path)
+
+ @patch(
+ "google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
+ )
+ @patch(
+ "google.adk.tools.apihub_tool.clients.apihub_client.service_account.Credentials.from_service_account_info"
+ )
+ def test_get_access_token_use_default_credential(
+ self,
+ mock_from_service_account_info,
+ mock_default_service_credential,
+ ):
+ mock_credential = MagicMock()
+ mock_credential.token = "default_token"
+ mock_default_service_credential.return_value = (
+ mock_credential,
+ "project_id",
+ )
+ mock_config_credential = MagicMock()
+ mock_config_credential.token = "config_token"
+ mock_from_service_account_info.return_value = mock_config_credential
+
+ client = APIHubClient()
+ token = client._get_access_token()
+ assert token == "default_token"
+ mock_credential.refresh.assert_called_once()
+ assert client.credential_cache == mock_credential
+
+ @patch(
+ "google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
+ )
+ @patch(
+ "google.adk.tools.apihub_tool.clients.apihub_client.service_account.Credentials.from_service_account_info"
+ )
+ def test_get_access_token_use_configured_service_account(
+ self,
+ mock_from_service_account_info,
+ mock_default_service_credential,
+ service_account_config,
+ ):
+ mock_credential = MagicMock()
+ mock_credential.token = "default_token"
+ mock_default_service_credential.return_value = (
+ mock_credential,
+ "project_id",
+ )
+ mock_config_credential = MagicMock()
+ mock_config_credential.token = "config_token"
+ mock_from_service_account_info.return_value = mock_config_credential
+
+ client = APIHubClient(service_account_json=service_account_config)
+ token = client._get_access_token()
+
+ assert token == "config_token"
+ mock_from_service_account_info.assert_called_once_with(
+ json.loads(service_account_config),
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ mock_config_credential.refresh.assert_called_once()
+ assert client.credential_cache == mock_config_credential
+
+ @patch(
+ "google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
+ )
+ def test_get_access_token_not_expired_use_cached_token(
+ self, mock_default_credential
+ ):
+ mock_credentials = MagicMock()
+ mock_credentials.token = "default_service_account_token"
+ mock_default_credential.return_value = (mock_credentials, "")
+
+ client = APIHubClient()
+ # Call #1: Setup cache
+ token = client._get_access_token()
+ assert token == "default_service_account_token"
+ mock_default_credential.assert_called_once()
+
+ # Call #2: Reuse cache
+ mock_credentials.reset_mock()
+ mock_credentials.expired = False
+ token = client._get_access_token()
+ assert token == "default_service_account_token"
+ mock_credentials.refresh.assert_not_called()
+
+ @patch(
+ "google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
+ )
+ def test_get_access_token_expired_refresh(self, mock_default_credential):
+ mock_credentials = MagicMock()
+ mock_credentials.token = "default_service_account_token"
+ mock_default_credential.return_value = (mock_credentials, "")
+ client = APIHubClient()
+
+ # Call #1: Setup cache
+ token = client._get_access_token()
+ assert token == "default_service_account_token"
+ mock_default_credential.assert_called_once()
+
+ # Call #2: Cache expired
+ mock_credentials.reset_mock()
+ mock_credentials.expired = True
+ token = client._get_access_token()
+ mock_credentials.refresh.assert_called_once()
+ assert token == "default_service_account_token"
+
+ @patch(
+ "google.adk.tools.apihub_tool.clients.apihub_client.default_service_credential"
+ )
+ def test_get_access_token_no_credentials(
+ self, mock_default_service_credential
+ ):
+ mock_default_service_credential.return_value = (None, None)
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Please provide a service account or an access token to API Hub"
+ " client."
+ ),
+ ):
+ # no service account client
+ APIHubClient()._get_access_token()
+
+ @patch("requests.get")
+ def test_get_spec_content_api_level(self, mock_get, client):
+ mock_get.side_effect = [
+ MagicMock(status_code=200, json=lambda: MOCK_API_DETAIL), # For get_api
+ MagicMock(
+ status_code=200, json=lambda: MOCK_API_VERSION
+ ), # For get_api_version
+ MagicMock(
+ status_code=200, json=lambda: MOCK_SPEC_CONTENT
+ ), # For get_spec_content
+ ]
+
+ content = client.get_spec_content(
+ "projects/test-project/locations/us-central1/apis/api1"
+ )
+ assert content == "spec content"
+ # Check calls - get_api, get_api_version, then get_spec_content
+ assert mock_get.call_count == 3
+
+ @patch("requests.get")
+ def test_get_spec_content_version_level(self, mock_get, client):
+ mock_get.side_effect = [
+ MagicMock(
+ status_code=200, json=lambda: MOCK_API_VERSION
+ ), # For get_api_version
+ MagicMock(
+ status_code=200, json=lambda: MOCK_SPEC_CONTENT
+ ), # For get_spec_content
+ ]
+
+ content = client.get_spec_content(
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1"
+ )
+ assert content == "spec content"
+ assert mock_get.call_count == 2 # get_api_version and get_spec_content
+
+ @patch("requests.get")
+ def test_get_spec_content_spec_level(self, mock_get, client):
+ mock_get.return_value.json.return_value = MOCK_SPEC_CONTENT
+ mock_get.return_value.status_code = 200
+
+ content = client.get_spec_content(
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1/specs/spec1"
+ )
+ assert content == "spec content"
+ mock_get.assert_called_once() # Only get_spec_content should be called
+
+ @patch("requests.get")
+ def test_get_spec_content_no_versions(self, mock_get, client):
+ mock_get.return_value.json.return_value = {
+ "name": "projects/test-project/locations/us-central1/apis/api1",
+ "versions": [],
+ } # No versions
+ mock_get.return_value.status_code = 200
+ with pytest.raises(
+ ValueError,
+ match=(
+ "No versions found in API Hub resource:"
+ " projects/test-project/locations/us-central1/apis/api1"
+ ),
+ ):
+ client.get_spec_content(
+ "projects/test-project/locations/us-central1/apis/api1"
+ )
+
+ @patch("requests.get")
+ def test_get_spec_content_no_specs(self, mock_get, client):
+ mock_get.side_effect = [
+ MagicMock(status_code=200, json=lambda: MOCK_API_DETAIL),
+ MagicMock(
+ status_code=200,
+ json=lambda: {
+ "name": (
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1"
+ ),
+ "specs": [],
+ },
+ ), # No specs
+ ]
+
+ with pytest.raises(
+ ValueError,
+ match=(
+ "No specs found in API Hub version:"
+ " projects/test-project/locations/us-central1/apis/api1/versions/v1"
+ ),
+ ):
+ client.get_spec_content(
+ "projects/test-project/locations/us-central1/apis/api1/versions/v1"
+ )
+
+ @patch("requests.get")
+ def test_get_spec_content_invalid_path(self, mock_get, client):
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Project ID not found in URL or path in APIHubClient. Input"
+ " path is 'invalid-path'."
+ ),
+ ):
+ client.get_spec_content("invalid-path")
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/src/google/adk/tests/unittests/tools/apihub_tool/test_apihub_toolset.py b/src/google/adk/tests/unittests/tools/apihub_tool/test_apihub_toolset.py
new file mode 100644
index 0000000..9ec68fa
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/apihub_tool/test_apihub_toolset.py
@@ -0,0 +1,204 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock
+
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.auth.auth_schemes import AuthScheme
+from google.adk.tools.apihub_tool.apihub_toolset import APIHubToolset
+from google.adk.tools.apihub_tool.clients.apihub_client import BaseAPIHubClient
+import pytest
+import yaml
+
+
+class MockAPIHubClient(BaseAPIHubClient):
+
+ def get_spec_content(self, apihub_resource_name: str) -> str:
+ return """
+openapi: 3.0.0
+info:
+ version: 1.0.0
+ title: Mock API
+ description: Mock API Description
+paths:
+ /test:
+ get:
+ summary: Test GET endpoint
+ operationId: testGet
+ responses:
+ '200':
+ description: Successful response
+ """
+
+
+# Fixture for a basic APIHubToolset
+@pytest.fixture
+def basic_apihub_toolset():
+ apihub_client = MockAPIHubClient()
+ tool = APIHubToolset(
+ apihub_resource_name='test_resource', apihub_client=apihub_client
+ )
+ return tool
+
+
+# Fixture for an APIHubToolset with lazy loading
+@pytest.fixture
+def lazy_apihub_toolset():
+ apihub_client = MockAPIHubClient()
+ tool = APIHubToolset(
+ apihub_resource_name='test_resource',
+ apihub_client=apihub_client,
+ lazy_load_spec=True,
+ )
+ return tool
+
+
+# Fixture for auth scheme
+@pytest.fixture
+def mock_auth_scheme():
+ return MagicMock(spec=AuthScheme)
+
+
+# Fixture for auth credential
+@pytest.fixture
+def mock_auth_credential():
+ return MagicMock(spec=AuthCredential)
+
+
+# Test cases
+def test_apihub_toolset_initialization(basic_apihub_toolset):
+ assert basic_apihub_toolset.name == 'mock_api'
+ assert basic_apihub_toolset.description == 'Mock API Description'
+ assert basic_apihub_toolset.apihub_resource_name == 'test_resource'
+ assert not basic_apihub_toolset.lazy_load_spec
+ assert len(basic_apihub_toolset.generated_tools) == 1
+ assert 'test_get' in basic_apihub_toolset.generated_tools
+
+
+def test_apihub_toolset_lazy_loading(lazy_apihub_toolset):
+ assert lazy_apihub_toolset.lazy_load_spec
+ assert not lazy_apihub_toolset.generated_tools
+
+ tools = lazy_apihub_toolset.get_tools()
+ assert len(tools) == 1
+ assert lazy_apihub_toolset.get_tool('test_get') == tools[0]
+
+
+def test_apihub_toolset_no_title_in_spec(basic_apihub_toolset):
+ spec = """
+openapi: 3.0.0
+info:
+ version: 1.0.0
+paths:
+ /empty_desc_test:
+ delete:
+ summary: Test DELETE endpoint
+ operationId: emptyDescTest
+ responses:
+ '200':
+ description: Successful response
+ """
+
+ class MockAPIHubClientEmptySpec(BaseAPIHubClient):
+
+ def get_spec_content(self, apihub_resource_name: str) -> str:
+ return spec
+
+ apihub_client = MockAPIHubClientEmptySpec()
+ toolset = APIHubToolset(
+ apihub_resource_name='test_resource',
+ apihub_client=apihub_client,
+ )
+
+ assert toolset.name == 'unnamed'
+
+
+def test_apihub_toolset_empty_description_in_spec():
+ spec = """
+openapi: 3.0.0
+info:
+ version: 1.0.0
+ title: Empty Description API
+paths:
+ /empty_desc_test:
+ delete:
+ summary: Test DELETE endpoint
+ operationId: emptyDescTest
+ responses:
+ '200':
+ description: Successful response
+ """
+
+ class MockAPIHubClientEmptySpec(BaseAPIHubClient):
+
+ def get_spec_content(self, apihub_resource_name: str) -> str:
+ return spec
+
+ apihub_client = MockAPIHubClientEmptySpec()
+ toolset = APIHubToolset(
+ apihub_resource_name='test_resource',
+ apihub_client=apihub_client,
+ )
+
+ assert toolset.name == 'empty_description_api'
+ assert toolset.description == ''
+
+
+def test_get_tools_with_auth(mock_auth_scheme, mock_auth_credential):
+ apihub_client = MockAPIHubClient()
+ tool = APIHubToolset(
+ apihub_resource_name='test_resource',
+ apihub_client=apihub_client,
+ auth_scheme=mock_auth_scheme,
+ auth_credential=mock_auth_credential,
+ )
+ tools = tool.get_tools()
+ assert len(tools) == 1
+
+
+def test_apihub_toolset_get_tools_lazy_load_empty_spec():
+
+ class MockAPIHubClientEmptySpec(BaseAPIHubClient):
+
+ def get_spec_content(self, apihub_resource_name: str) -> str:
+ return ''
+
+ apihub_client = MockAPIHubClientEmptySpec()
+ tool = APIHubToolset(
+ apihub_resource_name='test_resource',
+ apihub_client=apihub_client,
+ lazy_load_spec=True,
+ )
+ tools = tool.get_tools()
+ assert not tools
+
+
+def test_apihub_toolset_get_tools_invalid_yaml():
+
+ class MockAPIHubClientInvalidYAML(BaseAPIHubClient):
+
+ def get_spec_content(self, apihub_resource_name: str) -> str:
+ return '{invalid yaml' # Return invalid YAML
+
+ with pytest.raises(yaml.YAMLError):
+ apihub_client = MockAPIHubClientInvalidYAML()
+ tool = APIHubToolset(
+ apihub_resource_name='test_resource',
+ apihub_client=apihub_client,
+ )
+ tool.get_tools()
+
+
+if __name__ == '__main__':
+ pytest.main([__file__])
diff --git a/src/google/adk/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py b/src/google/adk/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py
new file mode 100644
index 0000000..975073f
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/application_integration_tool/clients/test_connections_client.py
@@ -0,0 +1,600 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from unittest import mock
+
+from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
+import google.auth
+import pytest
+import requests
+from requests import exceptions
+
+
+@pytest.fixture
+def project():
+ return "test-project"
+
+
+@pytest.fixture
+def location():
+ return "us-central1"
+
+
+@pytest.fixture
+def connection_name():
+ return "test-connection"
+
+
+@pytest.fixture
+def mock_credentials():
+ creds = mock.create_autospec(google.auth.credentials.Credentials)
+ creds.token = "test_token"
+ creds.expired = False
+ return creds
+
+
+@pytest.fixture
+def mock_auth_request():
+ return mock.create_autospec(google.auth.transport.requests.Request)
+
+
+class TestConnectionsClient:
+
+ def test_initialization(self, project, location, connection_name):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(
+ project, location, connection_name, json.dumps(credentials)
+ )
+ assert client.project == project
+ assert client.location == location
+ assert client.connection == connection_name
+ assert client.connector_url == "https://connectors.googleapis.com"
+ assert client.service_account_json == json.dumps(credentials)
+ assert client.credential_cache is None
+
+ def test_execute_api_call_success(
+ self, project, location, connection_name, mock_credentials
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ mock_response = mock.MagicMock()
+ mock_response.status_code = 200
+ mock_response.raise_for_status.return_value = None
+ mock_response.json.return_value = {"data": "test"}
+
+ with mock.patch.object(
+ client, "_get_access_token", return_value=mock_credentials.token
+ ), mock.patch("requests.get", return_value=mock_response):
+ response = client._execute_api_call("https://test.url")
+ assert response.json() == {"data": "test"}
+ requests.get.assert_called_once_with(
+ "https://test.url",
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {mock_credentials.token}",
+ },
+ )
+
+ def test_execute_api_call_credential_error(
+ self, project, location, connection_name
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ with mock.patch.object(
+ client,
+ "_get_access_token",
+ side_effect=google.auth.exceptions.DefaultCredentialsError("Test"),
+ ):
+ with pytest.raises(PermissionError, match="Credentials error: Test"):
+ client._execute_api_call("https://test.url")
+
+ @pytest.mark.parametrize(
+ "status_code, response_text",
+ [(404, "Not Found"), (400, "Bad Request")],
+ )
+ def test_execute_api_call_request_error_not_found_or_bad_request(
+ self,
+ project,
+ location,
+ connection_name,
+ mock_credentials,
+ status_code,
+ response_text,
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ mock_response = mock.MagicMock()
+ mock_response.status_code = status_code
+ mock_response.raise_for_status.side_effect = exceptions.HTTPError(
+ f"HTTP error {status_code}: {response_text}"
+ )
+
+ with mock.patch.object(
+ client, "_get_access_token", return_value=mock_credentials.token
+ ), mock.patch("requests.get", return_value=mock_response):
+ with pytest.raises(
+ ValueError, match="Invalid request. Please check the provided"
+ ):
+ client._execute_api_call("https://test.url")
+
+ def test_execute_api_call_other_request_error(
+ self, project, location, connection_name, mock_credentials
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ mock_response = mock.MagicMock()
+ mock_response.status_code = 500
+ mock_response.raise_for_status.side_effect = exceptions.HTTPError(
+ "Internal Server Error"
+ )
+
+ with mock.patch.object(
+ client, "_get_access_token", return_value=mock_credentials.token
+ ), mock.patch("requests.get", return_value=mock_response):
+ with pytest.raises(ValueError, match="Request error: "):
+ client._execute_api_call("https://test.url")
+
+ def test_execute_api_call_unexpected_error(
+ self, project, location, connection_name, mock_credentials
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ with mock.patch.object(
+ client, "_get_access_token", return_value=mock_credentials.token
+ ), mock.patch(
+ "requests.get", side_effect=Exception("Something went wrong")
+ ):
+ with pytest.raises(
+ Exception, match="An unexpected error occurred: Something went wrong"
+ ):
+ client._execute_api_call("https://test.url")
+
+ def test_get_connection_details_success_with_host(
+ self, project, location, connection_name, mock_credentials
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ mock_response = mock.MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {
+ "serviceDirectory": "test_service",
+ "host": "test.host",
+ "tlsServiceDirectory": "tls_test_service",
+ "authOverrideEnabled": True,
+ }
+
+ with mock.patch.object(
+ client, "_execute_api_call", return_value=mock_response
+ ):
+ details = client.get_connection_details()
+ assert details == {
+ "serviceName": "tls_test_service",
+ "host": "test.host",
+ "authOverrideEnabled": True,
+ }
+
+ def test_get_connection_details_success_without_host(
+ self, project, location, connection_name, mock_credentials
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ mock_response = mock.MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {
+ "serviceDirectory": "test_service",
+ "authOverrideEnabled": False,
+ }
+
+ with mock.patch.object(
+ client, "_execute_api_call", return_value=mock_response
+ ):
+ details = client.get_connection_details()
+ assert details == {
+ "serviceName": "test_service",
+ "host": "",
+ "authOverrideEnabled": False,
+ }
+
+ def test_get_connection_details_error(
+ self, project, location, connection_name
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ with mock.patch.object(
+ client, "_execute_api_call", side_effect=ValueError("Request error")
+ ):
+ with pytest.raises(ValueError, match="Request error"):
+ client.get_connection_details()
+
+ def test_get_entity_schema_and_operations_success(
+ self, project, location, connection_name, mock_credentials
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ mock_execute_response_initial = mock.MagicMock()
+ mock_execute_response_initial.status_code = 200
+ mock_execute_response_initial.json.return_value = {
+ "name": "operations/test_op"
+ }
+
+ mock_execute_response_poll_done = mock.MagicMock()
+ mock_execute_response_poll_done.status_code = 200
+ mock_execute_response_poll_done.json.return_value = {
+ "done": True,
+ "response": {
+ "jsonSchema": {"type": "object"},
+ "operations": ["LIST", "GET"],
+ },
+ }
+
+ with mock.patch.object(
+ client,
+ "_execute_api_call",
+ side_effect=[
+ mock_execute_response_initial,
+ mock_execute_response_poll_done,
+ ],
+ ):
+ schema, operations = client.get_entity_schema_and_operations("entity1")
+ assert schema == {"type": "object"}
+ assert operations == ["LIST", "GET"]
+ assert (
+ mock.call(
+ f"https://connectors.googleapis.com/v1/projects/{project}/locations/{location}/connections/{connection_name}/connectionSchemaMetadata:getEntityType?entityId=entity1"
+ )
+ in client._execute_api_call.mock_calls
+ )
+ assert (
+ mock.call(f"https://connectors.googleapis.com/v1/operations/test_op")
+ in client._execute_api_call.mock_calls
+ )
+
+ def test_get_entity_schema_and_operations_no_operation_id(
+ self, project, location, connection_name, mock_credentials
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ mock_execute_response = mock.MagicMock()
+ mock_execute_response.status_code = 200
+ mock_execute_response.json.return_value = {}
+
+ with mock.patch.object(
+ client, "_execute_api_call", return_value=mock_execute_response
+ ):
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Failed to get entity schema and operations for entity: entity1"
+ ),
+ ):
+ client.get_entity_schema_and_operations("entity1")
+
+ def test_get_entity_schema_and_operations_execute_api_call_error(
+ self, project, location, connection_name
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ with mock.patch.object(
+ client, "_execute_api_call", side_effect=ValueError("Request error")
+ ):
+ with pytest.raises(ValueError, match="Request error"):
+ client.get_entity_schema_and_operations("entity1")
+
+ def test_get_action_schema_success(
+ self, project, location, connection_name, mock_credentials
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ mock_execute_response_initial = mock.MagicMock()
+ mock_execute_response_initial.status_code = 200
+ mock_execute_response_initial.json.return_value = {
+ "name": "operations/test_op"
+ }
+
+ mock_execute_response_poll_done = mock.MagicMock()
+ mock_execute_response_poll_done.status_code = 200
+ mock_execute_response_poll_done.json.return_value = {
+ "done": True,
+ "response": {
+ "inputJsonSchema": {
+ "type": "object",
+ "properties": {"input": {"type": "string"}},
+ },
+ "outputJsonSchema": {
+ "type": "object",
+ "properties": {"output": {"type": "string"}},
+ },
+ "description": "Test Action Description",
+ "displayName": "TestAction",
+ },
+ }
+
+ with mock.patch.object(
+ client,
+ "_execute_api_call",
+ side_effect=[
+ mock_execute_response_initial,
+ mock_execute_response_poll_done,
+ ],
+ ):
+ schema = client.get_action_schema("action1")
+ assert schema == {
+ "inputSchema": {
+ "type": "object",
+ "properties": {"input": {"type": "string"}},
+ },
+ "outputSchema": {
+ "type": "object",
+ "properties": {"output": {"type": "string"}},
+ },
+ "description": "Test Action Description",
+ "displayName": "TestAction",
+ }
+ assert (
+ mock.call(
+ f"https://connectors.googleapis.com/v1/projects/{project}/locations/{location}/connections/{connection_name}/connectionSchemaMetadata:getAction?actionId=action1"
+ )
+ in client._execute_api_call.mock_calls
+ )
+ assert (
+ mock.call(f"https://connectors.googleapis.com/v1/operations/test_op")
+ in client._execute_api_call.mock_calls
+ )
+
+ def test_get_action_schema_no_operation_id(
+ self, project, location, connection_name, mock_credentials
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ mock_execute_response = mock.MagicMock()
+ mock_execute_response.status_code = 200
+ mock_execute_response.json.return_value = {}
+
+ with mock.patch.object(
+ client, "_execute_api_call", return_value=mock_execute_response
+ ):
+ with pytest.raises(
+ ValueError, match="Failed to get action schema for action: action1"
+ ):
+ client.get_action_schema("action1")
+
+ def test_get_action_schema_execute_api_call_error(
+ self, project, location, connection_name
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ with mock.patch.object(
+ client, "_execute_api_call", side_effect=ValueError("Request error")
+ ):
+ with pytest.raises(ValueError, match="Request error"):
+ client.get_action_schema("action1")
+
+ def test_get_connector_base_spec(self):
+ spec = ConnectionsClient.get_connector_base_spec()
+ assert "openapi" in spec
+ assert spec["info"]["title"] == "ExecuteConnection"
+ assert "components" in spec
+ assert "schemas" in spec["components"]
+ assert "operation" in spec["components"]["schemas"]
+
+ def test_get_action_operation(self):
+ operation = ConnectionsClient.get_action_operation(
+ "TestAction", "EXECUTE_ACTION", "TestActionDisplayName", "test_tool"
+ )
+ assert "post" in operation
+ assert operation["post"]["summary"] == "TestActionDisplayName"
+ assert "operationId" in operation["post"]
+ assert operation["post"]["operationId"] == "test_tool_TestActionDisplayName"
+
+ def test_list_operation(self):
+ operation = ConnectionsClient.list_operation(
+ "Entity1", '{"type": "object"}', "test_tool"
+ )
+ assert "post" in operation
+ assert operation["post"]["summary"] == "List Entity1"
+ assert "operationId" in operation["post"]
+ assert operation["post"]["operationId"] == "test_tool_list_Entity1"
+
+ def test_get_operation_static(self):
+ operation = ConnectionsClient.get_operation(
+ "Entity1", '{"type": "object"}', "test_tool"
+ )
+ assert "post" in operation
+ assert operation["post"]["summary"] == "Get Entity1"
+ assert "operationId" in operation["post"]
+ assert operation["post"]["operationId"] == "test_tool_get_Entity1"
+
+ def test_create_operation(self):
+ operation = ConnectionsClient.create_operation("Entity1", "test_tool")
+ assert "post" in operation
+ assert operation["post"]["summary"] == "Create Entity1"
+ assert "operationId" in operation["post"]
+ assert operation["post"]["operationId"] == "test_tool_create_Entity1"
+
+ def test_update_operation(self):
+ operation = ConnectionsClient.update_operation("Entity1", "test_tool")
+ assert "post" in operation
+ assert operation["post"]["summary"] == "Update Entity1"
+ assert "operationId" in operation["post"]
+ assert operation["post"]["operationId"] == "test_tool_update_Entity1"
+
+ def test_delete_operation(self):
+ operation = ConnectionsClient.delete_operation("Entity1", "test_tool")
+ assert "post" in operation
+ assert operation["post"]["summary"] == "Delete Entity1"
+ assert operation["post"]["operationId"] == "test_tool_delete_Entity1"
+
+ def test_create_operation_request(self):
+ schema = ConnectionsClient.create_operation_request("Entity1")
+ assert "type" in schema
+ assert schema["type"] == "object"
+ assert "properties" in schema
+ assert "connectorInputPayload" in schema["properties"]
+
+ def test_update_operation_request(self):
+ schema = ConnectionsClient.update_operation_request("Entity1")
+ assert "type" in schema
+ assert schema["type"] == "object"
+ assert "properties" in schema
+ assert "entityId" in schema["properties"]
+
+ def test_get_operation_request_static(self):
+ schema = ConnectionsClient.get_operation_request()
+ assert "type" in schema
+ assert schema["type"] == "object"
+ assert "properties" in schema
+ assert "entityId" in schema["properties"]
+
+ def test_delete_operation_request(self):
+ schema = ConnectionsClient.delete_operation_request()
+ assert "type" in schema
+ assert schema["type"] == "object"
+ assert "properties" in schema
+ assert "entityId" in schema["properties"]
+
+ def test_list_operation_request(self):
+ schema = ConnectionsClient.list_operation_request()
+ assert "type" in schema
+ assert schema["type"] == "object"
+ assert "properties" in schema
+ assert "filterClause" in schema["properties"]
+
+ def test_action_request(self):
+ schema = ConnectionsClient.action_request("TestAction")
+ assert "type" in schema
+ assert schema["type"] == "object"
+ assert "properties" in schema
+ assert "connectorInputPayload" in schema["properties"]
+
+ def test_action_response(self):
+ schema = ConnectionsClient.action_response("TestAction")
+ assert "type" in schema
+ assert schema["type"] == "object"
+ assert "properties" in schema
+ assert "connectorOutputPayload" in schema["properties"]
+
+ def test_execute_custom_query_request(self):
+ schema = ConnectionsClient.execute_custom_query_request()
+ assert "type" in schema
+ assert schema["type"] == "object"
+ assert "properties" in schema
+ assert "query" in schema["properties"]
+
+ def test_connector_payload(self):
+ client = ConnectionsClient("test-project", "us-central1", "test-connection")
+ schema = client.connector_payload(
+ json_schema={
+ "type": "object",
+ "properties": {
+ "input": {
+ "type": ["null", "string"],
+ "description": "description",
+ }
+ },
+ }
+ )
+ assert schema == {
+ "type": "object",
+ "properties": {
+ "input": {
+ "type": "string",
+ "nullable": True,
+ "description": "description",
+ }
+ },
+ }
+
+ def test_get_access_token_uses_cached_token(
+ self, project, location, connection_name, mock_credentials
+ ):
+ credentials = {"email": "test@example.com"}
+ client = ConnectionsClient(project, location, connection_name, credentials)
+ client.credential_cache = mock_credentials
+ token = client._get_access_token()
+ assert token == "test_token"
+
+ def test_get_access_token_with_service_account_credentials(
+ self, project, location, connection_name
+ ):
+ service_account_json = json.dumps({
+ "client_email": "test@example.com",
+ "private_key": "test_key",
+ })
+ client = ConnectionsClient(
+ project, location, connection_name, service_account_json
+ )
+ mock_creds = mock.create_autospec(google.oauth2.service_account.Credentials)
+ mock_creds.token = "sa_token"
+ mock_creds.expired = False
+
+ with mock.patch(
+ "google.oauth2.service_account.Credentials.from_service_account_info",
+ return_value=mock_creds,
+ ), mock.patch.object(mock_creds, "refresh", return_value=None):
+ token = client._get_access_token()
+ assert token == "sa_token"
+ google.oauth2.service_account.Credentials.from_service_account_info.assert_called_once_with(
+ json.loads(service_account_json),
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ mock_creds.refresh.assert_called_once()
+
+ def test_get_access_token_with_default_credentials(
+ self, project, location, connection_name, mock_credentials
+ ):
+ client = ConnectionsClient(project, location, connection_name, None)
+ with mock.patch(
+ "google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
+ return_value=(mock_credentials, "test_project_id"),
+ ), mock.patch.object(mock_credentials, "refresh", return_value=None):
+ token = client._get_access_token()
+ assert token == "test_token"
+
+ def test_get_access_token_no_valid_credentials(
+ self, project, location, connection_name
+ ):
+ client = ConnectionsClient(project, location, connection_name, None)
+ with mock.patch(
+ "google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
+ return_value=(None, None),
+ ):
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Please provide a service account that has the required"
+ " permissions"
+ ),
+ ):
+ client._get_access_token()
+
+ def test_get_access_token_refreshes_expired_token(
+ self, project, location, connection_name, mock_credentials
+ ):
+ client = ConnectionsClient(project, location, connection_name, None)
+ mock_credentials.expired = True
+ mock_credentials.token = "old_token"
+ mock_credentials.refresh.return_value = None
+
+ client.credential_cache = mock_credentials
+ with mock.patch(
+ "google.adk.tools.application_integration_tool.clients.connections_client.default_service_credential",
+ return_value=(mock_credentials, "test_project_id"),
+ ):
+ # Mock the refresh method directly on the instance within the context
+ with mock.patch.object(mock_credentials, "refresh") as mock_refresh:
+ mock_credentials.token = "new_token" # Set the expected new token
+ token = client._get_access_token()
+ assert token == "new_token"
+ mock_refresh.assert_called_once()
diff --git a/src/google/adk/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py b/src/google/adk/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py
new file mode 100644
index 0000000..469fa62
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py
@@ -0,0 +1,630 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from unittest import mock
+
+from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
+from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
+import google.auth
+import google.auth.transport.requests
+from google.auth.transport.requests import Request
+from google.oauth2 import service_account
+import pytest
+import requests
+from requests import exceptions
+
+
+@pytest.fixture
+def project():
+ return "test-project"
+
+
+@pytest.fixture
+def location():
+ return "us-central1"
+
+
+@pytest.fixture
+def integration_name():
+ return "test-integration"
+
+
+@pytest.fixture
+def trigger_name():
+ return "test-trigger"
+
+
+@pytest.fixture
+def connection_name():
+ return "test-connection"
+
+
+@pytest.fixture
+def mock_credentials():
+ creds = mock.create_autospec(google.auth.credentials.Credentials)
+ creds.token = "test_token"
+ return creds
+
+
+@pytest.fixture
+def mock_auth_request():
+ return mock.create_autospec(Request)
+
+
+@pytest.fixture
+def mock_connections_client():
+ with mock.patch(
+ "google.adk.tools.application_integration_tool.clients.integration_client.ConnectionsClient"
+ ) as mock_client:
+ mock_instance = mock.create_autospec(ConnectionsClient)
+ mock_client.return_value = mock_instance
+ yield mock_client
+
+
+class TestIntegrationClient:
+
+ def test_initialization(
+ self, project, location, integration_name, trigger_name, connection_name
+ ):
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=integration_name,
+ trigger=trigger_name,
+ connection=connection_name,
+ entity_operations={"entity": ["LIST"]},
+ actions=["action1"],
+ service_account_json=json.dumps({"email": "test@example.com"}),
+ )
+ assert client.project == project
+ assert client.location == location
+ assert client.integration == integration_name
+ assert client.trigger == trigger_name
+ assert client.connection == connection_name
+ assert client.entity_operations == {"entity": ["LIST"]}
+ assert client.actions == ["action1"]
+ assert client.service_account_json == json.dumps(
+ {"email": "test@example.com"}
+ )
+ assert client.credential_cache is None
+
+ def test_get_openapi_spec_for_integration_success(
+ self,
+ project,
+ location,
+ integration_name,
+ trigger_name,
+ mock_credentials,
+ mock_connections_client,
+ ):
+ expected_spec = {"openapi": "3.0.0", "info": {"title": "Test Integration"}}
+ mock_response = mock.MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {"openApiSpec": json.dumps(expected_spec)}
+
+ with mock.patch.object(
+ IntegrationClient,
+ "_get_access_token",
+ return_value=mock_credentials.token,
+ ), mock.patch("requests.post", return_value=mock_response):
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=integration_name,
+ trigger=trigger_name,
+ connection=None,
+ entity_operations=None,
+ actions=None,
+ service_account_json=None,
+ )
+ spec = client.get_openapi_spec_for_integration()
+ assert spec == expected_spec
+ requests.post.assert_called_once_with(
+ f"https://{location}-integrations.googleapis.com/v1/projects/{project}/locations/{location}:generateOpenApiSpec",
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {mock_credentials.token}",
+ },
+ json={
+ "apiTriggerResources": [{
+ "integrationResource": integration_name,
+ "triggerId": [trigger_name],
+ }],
+ "fileFormat": "JSON",
+ },
+ )
+
+ def test_get_openapi_spec_for_integration_credential_error(
+ self,
+ project,
+ location,
+ integration_name,
+ trigger_name,
+ mock_connections_client,
+ ):
+ with mock.patch.object(
+ IntegrationClient,
+ "_get_access_token",
+ side_effect=ValueError(
+ "Please provide a service account that has the required permissions"
+ " to access the connection."
+ ),
+ ):
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=integration_name,
+ trigger=trigger_name,
+ connection=None,
+ entity_operations=None,
+ actions=None,
+ service_account_json=None,
+ )
+ with pytest.raises(
+ Exception,
+ match=(
+ "An unexpected error occurred: Please provide a service account"
+ " that has the required permissions to access the connection."
+ ),
+ ):
+ client.get_openapi_spec_for_integration()
+
+ @pytest.mark.parametrize(
+ "status_code, response_text",
+ [(404, "Not Found"), (400, "Bad Request"), (404, ""), (400, "")],
+ )
+ def test_get_openapi_spec_for_integration_request_error_not_found_or_bad_request(
+ self,
+ project,
+ location,
+ integration_name,
+ trigger_name,
+ mock_credentials,
+ status_code,
+ response_text,
+ mock_connections_client,
+ ):
+ mock_response = mock.MagicMock()
+ mock_response.status_code = status_code
+ mock_response.raise_for_status.side_effect = exceptions.HTTPError(
+ f"HTTP error {status_code}: {response_text}"
+ )
+
+ with mock.patch.object(
+ IntegrationClient,
+ "_get_access_token",
+ return_value=mock_credentials.token,
+ ), mock.patch("requests.post", return_value=mock_response):
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=integration_name,
+ trigger=trigger_name,
+ connection=None,
+ entity_operations=None,
+ actions=None,
+ service_account_json=None,
+ )
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Invalid request. Please check the provided values of"
+ f" project\\({project}\\), location\\({location}\\),"
+ f" integration\\({integration_name}\\) and"
+ f" trigger\\({trigger_name}\\)."
+ ),
+ ):
+ client.get_openapi_spec_for_integration()
+
+ def test_get_openapi_spec_for_integration_other_request_error(
+ self,
+ project,
+ location,
+ integration_name,
+ trigger_name,
+ mock_credentials,
+ mock_connections_client,
+ ):
+ mock_response = mock.MagicMock()
+ mock_response.status_code = 500
+ mock_response.raise_for_status.side_effect = exceptions.HTTPError(
+ "Internal Server Error"
+ )
+
+ with mock.patch.object(
+ IntegrationClient,
+ "_get_access_token",
+ return_value=mock_credentials.token,
+ ), mock.patch("requests.post", return_value=mock_response):
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=integration_name,
+ trigger=trigger_name,
+ connection=None,
+ entity_operations=None,
+ actions=None,
+ service_account_json=None,
+ )
+ with pytest.raises(ValueError, match="Request error: "):
+ client.get_openapi_spec_for_integration()
+
+ def test_get_openapi_spec_for_integration_unexpected_error(
+ self,
+ project,
+ location,
+ integration_name,
+ trigger_name,
+ mock_credentials,
+ mock_connections_client,
+ ):
+ with mock.patch.object(
+ IntegrationClient,
+ "_get_access_token",
+ return_value=mock_credentials.token,
+ ), mock.patch(
+ "requests.post", side_effect=Exception("Something went wrong")
+ ):
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=integration_name,
+ trigger=trigger_name,
+ connection=None,
+ entity_operations=None,
+ actions=None,
+ service_account_json=None,
+ )
+ with pytest.raises(
+ Exception, match="An unexpected error occurred: Something went wrong"
+ ):
+ client.get_openapi_spec_for_integration()
+
+ def test_get_openapi_spec_for_connection_no_entity_operations_or_actions(
+ self, project, location, connection_name, mock_connections_client
+ ):
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=None,
+ trigger=None,
+ connection=connection_name,
+ entity_operations=None,
+ actions=None,
+ service_account_json=None,
+ )
+ with pytest.raises(
+ ValueError,
+ match=(
+ "No entity operations or actions provided. Please provide at least"
+ " one of them."
+ ),
+ ):
+ client.get_openapi_spec_for_connection()
+
+ def test_get_openapi_spec_for_connection_with_entity_operations(
+ self, project, location, connection_name, mock_connections_client
+ ):
+ entity_operations = {"entity1": ["LIST", "GET"]}
+
+ mock_connections_client_instance = mock_connections_client.return_value
+ mock_connections_client_instance.get_connector_base_spec.return_value = {
+ "components": {"schemas": {}},
+ "paths": {},
+ }
+ mock_connections_client_instance.get_entity_schema_and_operations.return_value = (
+ {"type": "object", "properties": {"id": {"type": "string"}}},
+ ["LIST", "GET"],
+ )
+ mock_connections_client_instance.connector_payload.return_value = {
+ "type": "object"
+ }
+ mock_connections_client_instance.list_operation.return_value = {"get": {}}
+ mock_connections_client_instance.list_operation_request.return_value = {
+ "type": "object"
+ }
+ mock_connections_client_instance.get_operation.return_value = {"get": {}}
+ mock_connections_client_instance.get_operation_request.return_value = {
+ "type": "object"
+ }
+
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=None,
+ trigger=None,
+ connection=connection_name,
+ entity_operations=entity_operations,
+ actions=None,
+ service_account_json=None,
+ )
+ spec = client.get_openapi_spec_for_connection()
+ assert "paths" in spec
+ assert (
+ f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#list_entity1"
+ in spec["paths"]
+ )
+ assert (
+ f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#get_entity1"
+ in spec["paths"]
+ )
+ mock_connections_client.assert_called_once_with(
+ project, location, connection_name, None
+ )
+ mock_connections_client_instance.get_connector_base_spec.assert_called_once()
+ mock_connections_client_instance.get_entity_schema_and_operations.assert_any_call(
+ "entity1"
+ )
+ mock_connections_client_instance.connector_payload.assert_any_call(
+ {"type": "object", "properties": {"id": {"type": "string"}}}
+ )
+ mock_connections_client_instance.list_operation.assert_called_once()
+ mock_connections_client_instance.get_operation.assert_called_once()
+
+ def test_get_openapi_spec_for_connection_with_actions(
+ self, project, location, connection_name, mock_connections_client
+ ):
+ actions = ["TestAction"]
+ mock_connections_client_instance = (
+ mock_connections_client.return_value
+ ) # Corrected line
+ mock_connections_client_instance.get_connector_base_spec.return_value = {
+ "components": {"schemas": {}},
+ "paths": {},
+ }
+ mock_connections_client_instance.get_action_schema.return_value = {
+ "inputSchema": {
+ "type": "object",
+ "properties": {"input": {"type": "string"}},
+ },
+ "outputSchema": {
+ "type": "object",
+ "properties": {"output": {"type": "string"}},
+ },
+ "displayName": "TestAction",
+ }
+ mock_connections_client_instance.connector_payload.side_effect = [
+ {"type": "object"},
+ {"type": "object"},
+ ]
+ mock_connections_client_instance.action_request.return_value = {
+ "type": "object"
+ }
+ mock_connections_client_instance.action_response.return_value = {
+ "type": "object"
+ }
+ mock_connections_client_instance.get_action_operation.return_value = {
+ "post": {}
+ }
+
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=None,
+ trigger=None,
+ connection=connection_name,
+ entity_operations=None,
+ actions=actions,
+ service_account_json=None,
+ )
+ spec = client.get_openapi_spec_for_connection()
+ assert "paths" in spec
+ assert (
+ f"/v2/projects/{project}/locations/{location}/integrations/ExecuteConnection:execute?triggerId=api_trigger/ExecuteConnection#TestAction"
+ in spec["paths"]
+ )
+ mock_connections_client.assert_called_once_with(
+ project, location, connection_name, None
+ )
+ mock_connections_client_instance.get_connector_base_spec.assert_called_once()
+ mock_connections_client_instance.get_action_schema.assert_called_once_with(
+ "TestAction"
+ )
+ mock_connections_client_instance.connector_payload.assert_any_call(
+ {"type": "object", "properties": {"input": {"type": "string"}}}
+ )
+ mock_connections_client_instance.connector_payload.assert_any_call(
+ {"type": "object", "properties": {"output": {"type": "string"}}}
+ )
+ mock_connections_client_instance.action_request.assert_called_once_with(
+ "TestAction"
+ )
+ mock_connections_client_instance.action_response.assert_called_once_with(
+ "TestAction"
+ )
+ mock_connections_client_instance.get_action_operation.assert_called_once()
+
+ def test_get_openapi_spec_for_connection_invalid_operation(
+ self, project, location, connection_name, mock_connections_client
+ ):
+ entity_operations = {"entity1": ["INVALID"]}
+ mock_connections_client_instance = mock_connections_client.return_value
+ mock_connections_client_instance.get_connector_base_spec.return_value = {
+ "components": {"schemas": {}},
+ "paths": {},
+ }
+ mock_connections_client_instance.get_entity_schema_and_operations.return_value = (
+ {"type": "object", "properties": {"id": {"type": "string"}}},
+ ["LIST", "GET"],
+ )
+
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=None,
+ trigger=None,
+ connection=connection_name,
+ entity_operations=entity_operations,
+ actions=None,
+ service_account_json=None,
+ )
+ with pytest.raises(
+ ValueError, match="Invalid operation: INVALID for entity: entity1"
+ ):
+ client.get_openapi_spec_for_connection()
+
+ def test_get_access_token_with_service_account_json(
+ self, project, location, integration_name, trigger_name, connection_name
+ ):
+ service_account_json = json.dumps({
+ "client_email": "test@example.com",
+ "private_key": "test_key",
+ })
+ mock_creds = mock.create_autospec(service_account.Credentials)
+ mock_creds.token = "sa_token"
+ mock_creds.expired = False
+
+ with mock.patch(
+ "google.oauth2.service_account.Credentials.from_service_account_info",
+ return_value=mock_creds,
+ ), mock.patch.object(mock_creds, "refresh", return_value=None):
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=integration_name,
+ trigger=trigger_name,
+ connection=connection_name,
+ entity_operations=None,
+ actions=None,
+ service_account_json=service_account_json,
+ )
+ token = client._get_access_token()
+ assert token == "sa_token"
+ service_account.Credentials.from_service_account_info.assert_called_once_with(
+ json.loads(service_account_json),
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ mock_creds.refresh.assert_called_once()
+
+ def test_get_access_token_with_default_credentials(
+ self,
+ project,
+ location,
+ integration_name,
+ trigger_name,
+ connection_name,
+ mock_credentials,
+ ):
+ mock_credentials.expired = False
+ with mock.patch(
+ "google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
+ return_value=(mock_credentials, "test_project_id"),
+ ), mock.patch.object(mock_credentials, "refresh", return_value=None):
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=integration_name,
+ trigger=trigger_name,
+ connection=connection_name,
+ entity_operations=None,
+ actions=None,
+ service_account_json=None,
+ )
+ token = client._get_access_token()
+ assert token == "test_token"
+
+ def test_get_access_token_no_valid_credentials(
+ self, project, location, integration_name, trigger_name, connection_name
+ ):
+ with mock.patch(
+ "google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
+ return_value=(None, None),
+ ), mock.patch(
+ "google.oauth2.service_account.Credentials.from_service_account_info",
+ return_value=None,
+ ):
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=integration_name,
+ trigger=trigger_name,
+ connection=connection_name,
+ entity_operations=None,
+ actions=None,
+ service_account_json=None,
+ )
+ try:
+ client._get_access_token()
+ assert False, "ValueError was not raised" # Explicitly fail if no error
+ except ValueError as e:
+ assert (
+ "Please provide a service account that has the required permissions"
+ " to access the connection."
+ in str(e)
+ )
+
+ def test_get_access_token_uses_cached_token(
+ self,
+ project,
+ location,
+ integration_name,
+ trigger_name,
+ connection_name,
+ mock_credentials,
+ ):
+ mock_credentials.token = "cached_token"
+ mock_credentials.expired = False
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=integration_name,
+ trigger=trigger_name,
+ connection=connection_name,
+ entity_operations=None,
+ actions=None,
+ service_account_json=None,
+ )
+ client.credential_cache = mock_credentials # Simulate a cached credential
+ with mock.patch("google.auth.default") as mock_default, mock.patch(
+ "google.oauth2.service_account.Credentials.from_service_account_info"
+ ) as mock_sa:
+ token = client._get_access_token()
+ assert token == "cached_token"
+ mock_default.assert_not_called()
+ mock_sa.assert_not_called()
+
+ def test_get_access_token_refreshes_expired_token(
+ self,
+ project,
+ location,
+ integration_name,
+ trigger_name,
+ connection_name,
+ mock_credentials,
+ ):
+ mock_credentials = mock.create_autospec(google.auth.credentials.Credentials)
+ mock_credentials.token = "old_token"
+ mock_credentials.expired = True
+ mock_credentials.refresh.return_value = None
+ mock_credentials.token = "new_token" # Simulate token refresh
+
+ with mock.patch(
+ "google.adk.tools.application_integration_tool.clients.integration_client.default_service_credential",
+ return_value=(mock_credentials, "test_project_id"),
+ ):
+ client = IntegrationClient(
+ project=project,
+ location=location,
+ integration=integration_name,
+ trigger=trigger_name,
+ connection=connection_name,
+ entity_operations=None,
+ actions=None,
+ service_account_json=None,
+ )
+ client.credential_cache = mock_credentials
+ token = client._get_access_token()
+ assert token == "new_token"
+ mock_credentials.refresh.assert_called_once()
diff --git a/src/google/adk/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py b/src/google/adk/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py
new file mode 100644
index 0000000..3a7f6ea
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/application_integration_tool/test_application_integration_toolset.py
@@ -0,0 +1,345 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from unittest import mock
+
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.tools.application_integration_tool.application_integration_toolset import ApplicationIntegrationToolset
+from google.adk.tools.openapi_tool.openapi_spec_parser import rest_api_tool
+import pytest
+
+
+@pytest.fixture
+def mock_integration_client():
+ with mock.patch(
+ "google.adk.tools.application_integration_tool.application_integration_toolset.IntegrationClient"
+ ) as mock_client:
+ yield mock_client
+
+
+@pytest.fixture
+def mock_connections_client():
+ with mock.patch(
+ "google.adk.tools.application_integration_tool.application_integration_toolset.ConnectionsClient"
+ ) as mock_client:
+ yield mock_client
+
+
+@pytest.fixture
+def mock_openapi_toolset():
+ with mock.patch(
+ "google.adk.tools.application_integration_tool.application_integration_toolset.OpenAPIToolset"
+ ) as mock_toolset:
+ mock_toolset_instance = mock.MagicMock()
+ mock_rest_api_tool = mock.MagicMock(spec=rest_api_tool.RestApiTool)
+ mock_rest_api_tool.name = "Test Tool"
+ mock_toolset_instance.get_tools.return_value = [mock_rest_api_tool]
+ mock_toolset.return_value = mock_toolset_instance
+ yield mock_toolset
+
+
+@pytest.fixture
+def project():
+ return "test-project"
+
+
+@pytest.fixture
+def location():
+ return "us-central1"
+
+
+@pytest.fixture
+def integration_spec():
+ return {"openapi": "3.0.0", "info": {"title": "Integration API"}}
+
+
+@pytest.fixture
+def connection_spec():
+ return {"openapi": "3.0.0", "info": {"title": "Connection API"}}
+
+
+@pytest.fixture
+def connection_details():
+ return {"serviceName": "test-service", "host": "test.host"}
+
+
+def test_initialization_with_integration_and_trigger(
+ project,
+ location,
+ mock_integration_client,
+ mock_connections_client,
+ mock_openapi_toolset,
+):
+ integration_name = "test-integration"
+ trigger_name = "test-trigger"
+ toolset = ApplicationIntegrationToolset(
+ project, location, integration=integration_name, trigger=trigger_name
+ )
+ mock_integration_client.assert_called_once_with(
+ project, location, integration_name, trigger_name, None, None, None, None
+ )
+ mock_integration_client.return_value.get_openapi_spec_for_integration.assert_called_once()
+ mock_connections_client.assert_not_called()
+ mock_openapi_toolset.assert_called_once()
+ assert len(toolset.get_tools()) == 1
+ assert toolset.get_tools()[0].name == "Test Tool"
+
+
+def test_initialization_with_connection_and_entity_operations(
+ project,
+ location,
+ mock_integration_client,
+ mock_connections_client,
+ mock_openapi_toolset,
+ connection_details,
+):
+ connection_name = "test-connection"
+ entity_operations_list = ["list", "get"]
+ tool_name = "My Connection Tool"
+ tool_instructions = "Use this tool to manage entities."
+ mock_connections_client.return_value.get_connection_details.return_value = (
+ connection_details
+ )
+ toolset = ApplicationIntegrationToolset(
+ project,
+ location,
+ connection=connection_name,
+ entity_operations=entity_operations_list,
+ tool_name=tool_name,
+ tool_instructions=tool_instructions,
+ )
+ mock_integration_client.assert_called_once_with(
+ project,
+ location,
+ None,
+ None,
+ connection_name,
+ entity_operations_list,
+ None,
+ None,
+ )
+ mock_connections_client.assert_called_once_with(
+ project, location, connection_name, None
+ )
+ mock_connections_client.return_value.get_connection_details.assert_called_once()
+ mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
+ tool_name,
+ tool_instructions
+ + f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
+ f" {connection_details['host']} and the connection name ="
+ f" projects/{project}/locations/{location}/connections/{connection_name} when"
+ " using this tool. DONOT ask the user for these values as you already"
+ " have those.",
+ )
+ mock_openapi_toolset.assert_called_once()
+ assert len(toolset.get_tools()) == 1
+ assert toolset.get_tools()[0].name == "Test Tool"
+
+
+def test_initialization_with_connection_and_actions(
+ project,
+ location,
+ mock_integration_client,
+ mock_connections_client,
+ mock_openapi_toolset,
+ connection_details,
+):
+ connection_name = "test-connection"
+ actions_list = ["create", "delete"]
+ tool_name = "My Actions Tool"
+ tool_instructions = "Perform actions using this tool."
+ mock_connections_client.return_value.get_connection_details.return_value = (
+ connection_details
+ )
+ toolset = ApplicationIntegrationToolset(
+ project,
+ location,
+ connection=connection_name,
+ actions=actions_list,
+ tool_name=tool_name,
+ tool_instructions=tool_instructions,
+ )
+ mock_integration_client.assert_called_once_with(
+ project, location, None, None, connection_name, None, actions_list, None
+ )
+ mock_connections_client.assert_called_once_with(
+ project, location, connection_name, None
+ )
+ mock_connections_client.return_value.get_connection_details.assert_called_once()
+ mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
+ tool_name,
+ tool_instructions
+ + f"ALWAYS use serviceName = {connection_details['serviceName']}, host ="
+ f" {connection_details['host']} and the connection name ="
+ f" projects/{project}/locations/{location}/connections/{connection_name} when"
+ " using this tool. DONOT ask the user for these values as you already"
+ " have those.",
+ )
+ mock_openapi_toolset.assert_called_once()
+ assert len(toolset.get_tools()) == 1
+ assert toolset.get_tools()[0].name == "Test Tool"
+
+
+def test_initialization_without_required_params(project, location):
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Either \\(integration and trigger\\) or \\(connection and"
+ " \\(entity_operations or actions\\)\\) should be provided."
+ ),
+ ):
+ ApplicationIntegrationToolset(project, location)
+
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Either \\(integration and trigger\\) or \\(connection and"
+ " \\(entity_operations or actions\\)\\) should be provided."
+ ),
+ ):
+ ApplicationIntegrationToolset(project, location, integration="test")
+
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Either \\(integration and trigger\\) or \\(connection and"
+ " \\(entity_operations or actions\\)\\) should be provided."
+ ),
+ ):
+ ApplicationIntegrationToolset(project, location, trigger="test")
+
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Either \\(integration and trigger\\) or \\(connection and"
+ " \\(entity_operations or actions\\)\\) should be provided."
+ ),
+ ):
+ ApplicationIntegrationToolset(project, location, connection="test")
+
+
+def test_initialization_with_service_account_credentials(
+ project, location, mock_integration_client, mock_openapi_toolset
+):
+ service_account_json = json.dumps({
+ "type": "service_account",
+ "project_id": "dummy",
+ "private_key_id": "dummy",
+ "private_key": "dummy",
+ "client_email": "test@example.com",
+ "client_id": "131331543646416",
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+ "token_uri": "https://oauth2.googleapis.com/token",
+ "auth_provider_x509_cert_url": (
+ "https://www.googleapis.com/oauth2/v1/certs"
+ ),
+ "client_x509_cert_url": (
+ "http://www.googleapis.com/robot/v1/metadata/x509/dummy%40dummy.com"
+ ),
+ "universe_domain": "googleapis.com",
+ })
+ integration_name = "test-integration"
+ trigger_name = "test-trigger"
+ toolset = ApplicationIntegrationToolset(
+ project,
+ location,
+ integration=integration_name,
+ trigger=trigger_name,
+ service_account_json=service_account_json,
+ )
+ mock_integration_client.assert_called_once_with(
+ project,
+ location,
+ integration_name,
+ trigger_name,
+ None,
+ None,
+ None,
+ service_account_json,
+ )
+ mock_openapi_toolset.assert_called_once()
+ _, kwargs = mock_openapi_toolset.call_args
+ assert isinstance(kwargs["auth_credential"], AuthCredential)
+ assert (
+ kwargs[
+ "auth_credential"
+ ].service_account.service_account_credential.client_email
+ == "test@example.com"
+ )
+
+
+def test_initialization_without_explicit_service_account_credentials(
+ project, location, mock_integration_client, mock_openapi_toolset
+):
+ integration_name = "test-integration"
+ trigger_name = "test-trigger"
+ toolset = ApplicationIntegrationToolset(
+ project, location, integration=integration_name, trigger=trigger_name
+ )
+ mock_integration_client.assert_called_once_with(
+ project, location, integration_name, trigger_name, None, None, None, None
+ )
+ mock_openapi_toolset.assert_called_once()
+ _, kwargs = mock_openapi_toolset.call_args
+ assert isinstance(kwargs["auth_credential"], AuthCredential)
+ assert kwargs["auth_credential"].service_account.use_default_credential
+
+
+def test_get_tools(
+ project, location, mock_integration_client, mock_openapi_toolset
+):
+ integration_name = "test-integration"
+ trigger_name = "test-trigger"
+ toolset = ApplicationIntegrationToolset(
+ project, location, integration=integration_name, trigger=trigger_name
+ )
+ tools = toolset.get_tools()
+ assert len(tools) == 1
+ assert isinstance(tools[0], rest_api_tool.RestApiTool)
+ assert tools[0].name == "Test Tool"
+
+
+def test_initialization_with_connection_details(
+ project,
+ location,
+ mock_integration_client,
+ mock_connections_client,
+ mock_openapi_toolset,
+):
+ connection_name = "test-connection"
+ entity_operations_list = ["list"]
+ tool_name = "My Connection Tool"
+ tool_instructions = "Use this tool."
+ mock_connections_client.return_value.get_connection_details.return_value = {
+ "serviceName": "custom-service",
+ "host": "custom.host",
+ }
+ toolset = ApplicationIntegrationToolset(
+ project,
+ location,
+ connection=connection_name,
+ entity_operations=entity_operations_list,
+ tool_name=tool_name,
+ tool_instructions=tool_instructions,
+ )
+ mock_integration_client.return_value.get_openapi_spec_for_connection.assert_called_once_with(
+ tool_name,
+ tool_instructions
+ + "ALWAYS use serviceName = custom-service, host = custom.host and the"
+ " connection name ="
+ " projects/test-project/locations/us-central1/connections/test-connection"
+ " when using this tool. DONOT ask the user for these values as you"
+ " already have those.",
+ )
diff --git a/src/google/adk/tests/unittests/tools/google_api_tool/__init__.py b/src/google/adk/tests/unittests/tools/google_api_tool/__init__.py
new file mode 100644
index 0000000..0a2669d
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/google_api_tool/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/google/adk/tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py b/src/google/adk/tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py
new file mode 100644
index 0000000..f35a1cd
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py
@@ -0,0 +1,657 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock
+
+from google.adk.tools.google_api_tool.googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
+# Import the converter class
+from googleapiclient.errors import HttpError
+import pytest
+
+
+@pytest.fixture
+def calendar_api_spec():
+ """Fixture that provides a mock Google Calendar API spec for testing."""
+ return {
+ "kind": "discovery#restDescription",
+ "id": "calendar:v3",
+ "name": "calendar",
+ "version": "v3",
+ "title": "Google Calendar API",
+ "description": "Accesses the Google Calendar API",
+ "documentationLink": "https://developers.google.com/calendar/",
+ "protocol": "rest",
+ "rootUrl": "https://www.googleapis.com/",
+ "servicePath": "calendar/v3/",
+ "auth": {
+ "oauth2": {
+ "scopes": {
+ "https://www.googleapis.com/auth/calendar": {
+ "description": "Full access to Google Calendar"
+ },
+ "https://www.googleapis.com/auth/calendar.readonly": {
+ "description": "Read-only access to Google Calendar"
+ },
+ }
+ }
+ },
+ "schemas": {
+ "Calendar": {
+ "type": "object",
+ "description": "A calendar resource",
+ "properties": {
+ "id": {
+ "type": "string",
+ "description": "Calendar identifier",
+ },
+ "summary": {
+ "type": "string",
+ "description": "Calendar summary",
+ "required": True,
+ },
+ "timeZone": {
+ "type": "string",
+ "description": "Calendar timezone",
+ },
+ },
+ },
+ "Event": {
+ "type": "object",
+ "description": "An event resource",
+ "properties": {
+ "id": {"type": "string", "description": "Event identifier"},
+ "summary": {"type": "string", "description": "Event summary"},
+ "start": {"$ref": "EventDateTime"},
+ "end": {"$ref": "EventDateTime"},
+ "attendees": {
+ "type": "array",
+ "description": "Event attendees",
+ "items": {"$ref": "EventAttendee"},
+ },
+ },
+ },
+ "EventDateTime": {
+ "type": "object",
+ "description": "Date/time for an event",
+ "properties": {
+ "dateTime": {
+ "type": "string",
+ "format": "date-time",
+ "description": "Date/time in RFC3339 format",
+ },
+ "timeZone": {
+ "type": "string",
+ "description": "Timezone for the date/time",
+ },
+ },
+ },
+ "EventAttendee": {
+ "type": "object",
+ "description": "An attendee of an event",
+ "properties": {
+ "email": {"type": "string", "description": "Attendee email"},
+ "responseStatus": {
+ "type": "string",
+ "description": "Response status",
+ "enum": [
+ "needsAction",
+ "declined",
+ "tentative",
+ "accepted",
+ ],
+ },
+ },
+ },
+ },
+ "resources": {
+ "calendars": {
+ "methods": {
+ "get": {
+ "id": "calendar.calendars.get",
+ "path": "calendars/{calendarId}",
+ "httpMethod": "GET",
+ "description": "Returns metadata for a calendar.",
+ "parameters": {
+ "calendarId": {
+ "type": "string",
+ "description": "Calendar identifier",
+ "required": True,
+ "location": "path",
+ }
+ },
+ "response": {"$ref": "Calendar"},
+ "scopes": [
+ "https://www.googleapis.com/auth/calendar",
+ "https://www.googleapis.com/auth/calendar.readonly",
+ ],
+ },
+ "insert": {
+ "id": "calendar.calendars.insert",
+ "path": "calendars",
+ "httpMethod": "POST",
+ "description": "Creates a secondary calendar.",
+ "request": {"$ref": "Calendar"},
+ "response": {"$ref": "Calendar"},
+ "scopes": ["https://www.googleapis.com/auth/calendar"],
+ },
+ },
+ "resources": {
+ "events": {
+ "methods": {
+ "list": {
+ "id": "calendar.events.list",
+ "path": "calendars/{calendarId}/events",
+ "httpMethod": "GET",
+ "description": (
+ "Returns events on the specified calendar."
+ ),
+ "parameters": {
+ "calendarId": {
+ "type": "string",
+ "description": "Calendar identifier",
+ "required": True,
+ "location": "path",
+ },
+ "maxResults": {
+ "type": "integer",
+ "description": (
+ "Maximum number of events returned"
+ ),
+ "format": "int32",
+ "minimum": "1",
+ "maximum": "2500",
+ "default": "250",
+ "location": "query",
+ },
+ "orderBy": {
+ "type": "string",
+ "description": (
+ "Order of the events returned"
+ ),
+ "enum": ["startTime", "updated"],
+ "location": "query",
+ },
+ },
+ "response": {"$ref": "Events"},
+ "scopes": [
+ "https://www.googleapis.com/auth/calendar",
+ "https://www.googleapis.com/auth/calendar.readonly",
+ ],
+ }
+ }
+ }
+ },
+ }
+ },
+ }
+
+
+@pytest.fixture
+def converter():
+ """Fixture that provides a basic converter instance."""
+ return GoogleApiToOpenApiConverter("calendar", "v3")
+
+
+@pytest.fixture
+def mock_api_resource(calendar_api_spec):
+ """Fixture that provides a mock API resource with the test spec."""
+ mock_resource = MagicMock()
+ mock_resource._rootDesc = calendar_api_spec
+ return mock_resource
+
+
+@pytest.fixture
+def prepared_converter(converter, calendar_api_spec):
+ """Fixture that provides a converter with the API spec already set."""
+ converter.google_api_spec = calendar_api_spec
+ return converter
+
+
+@pytest.fixture
+def converter_with_patched_build(monkeypatch, mock_api_resource):
+ """Fixture that provides a converter with the build function patched.
+
+ This simulates a successful API spec fetch.
+ """
+ # Create a mock for the build function
+ mock_build = MagicMock(return_value=mock_api_resource)
+
+ # Patch the build function in the target module
+ monkeypatch.setattr(
+ "google.adk.tools.google_api_tool.googleapi_to_openapi_converter.build",
+ mock_build,
+ )
+
+ # Create and return a converter instance
+ return GoogleApiToOpenApiConverter("calendar", "v3")
+
+
+class TestGoogleApiToOpenApiConverter:
+ """Test suite for the GoogleApiToOpenApiConverter class."""
+
+ def test_init(self, converter):
+ """Test converter initialization."""
+ assert converter.api_name == "calendar"
+ assert converter.api_version == "v3"
+ assert converter.google_api_resource is None
+ assert converter.google_api_spec is None
+ assert converter.openapi_spec["openapi"] == "3.0.0"
+ assert "info" in converter.openapi_spec
+ assert "paths" in converter.openapi_spec
+ assert "components" in converter.openapi_spec
+
+ def test_fetch_google_api_spec(
+ self, converter_with_patched_build, calendar_api_spec
+ ):
+ """Test fetching Google API specification."""
+ # Call the method
+ converter_with_patched_build.fetch_google_api_spec()
+
+ # Verify the results
+ assert converter_with_patched_build.google_api_spec == calendar_api_spec
+
+ def test_fetch_google_api_spec_error(self, monkeypatch, converter):
+ """Test error handling when fetching Google API specification."""
+ # Create a mock that raises an error
+ mock_build = MagicMock(
+ side_effect=HttpError(resp=MagicMock(status=404), content=b"Not Found")
+ )
+ monkeypatch.setattr(
+ "google.adk.tools.google_api_tool.googleapi_to_openapi_converter.build",
+ mock_build,
+ )
+
+ # Verify exception is raised
+ with pytest.raises(HttpError):
+ converter.fetch_google_api_spec()
+
+ def test_convert_info(self, prepared_converter):
+ """Test conversion of basic API information."""
+ # Call the method
+ prepared_converter._convert_info()
+
+ # Verify the results
+ info = prepared_converter.openapi_spec["info"]
+ assert info["title"] == "Google Calendar API"
+ assert info["description"] == "Accesses the Google Calendar API"
+ assert info["version"] == "v3"
+ assert info["termsOfService"] == "https://developers.google.com/calendar/"
+
+ # Check external docs
+ external_docs = prepared_converter.openapi_spec["externalDocs"]
+ assert external_docs["url"] == "https://developers.google.com/calendar/"
+
+ def test_convert_servers(self, prepared_converter):
+ """Test conversion of server information."""
+ # Call the method
+ prepared_converter._convert_servers()
+
+ # Verify the results
+ servers = prepared_converter.openapi_spec["servers"]
+ assert len(servers) == 1
+ assert servers[0]["url"] == "https://www.googleapis.com/calendar/v3"
+ assert servers[0]["description"] == "calendar v3 API"
+
+ def test_convert_security_schemes(self, prepared_converter):
+ """Test conversion of security schemes."""
+ # Call the method
+ prepared_converter._convert_security_schemes()
+
+ # Verify the results
+ security_schemes = prepared_converter.openapi_spec["components"][
+ "securitySchemes"
+ ]
+
+ # Check OAuth2 configuration
+ assert "oauth2" in security_schemes
+ oauth2 = security_schemes["oauth2"]
+ assert oauth2["type"] == "oauth2"
+
+ # Check OAuth2 scopes
+ scopes = oauth2["flows"]["authorizationCode"]["scopes"]
+ assert "https://www.googleapis.com/auth/calendar" in scopes
+ assert "https://www.googleapis.com/auth/calendar.readonly" in scopes
+
+ # Check API key configuration
+ assert "apiKey" in security_schemes
+ assert security_schemes["apiKey"]["type"] == "apiKey"
+ assert security_schemes["apiKey"]["in"] == "query"
+ assert security_schemes["apiKey"]["name"] == "key"
+
+ def test_convert_schemas(self, prepared_converter):
+ """Test conversion of schema definitions."""
+ # Call the method
+ prepared_converter._convert_schemas()
+
+ # Verify the results
+ schemas = prepared_converter.openapi_spec["components"]["schemas"]
+
+ # Check Calendar schema
+ assert "Calendar" in schemas
+ calendar_schema = schemas["Calendar"]
+ assert calendar_schema["type"] == "object"
+ assert calendar_schema["description"] == "A calendar resource"
+
+ # Check required properties
+ assert "required" in calendar_schema
+ assert "summary" in calendar_schema["required"]
+
+ # Check Event schema references
+ assert "Event" in schemas
+ event_schema = schemas["Event"]
+ assert (
+ event_schema["properties"]["start"]["$ref"]
+ == "#/components/schemas/EventDateTime"
+ )
+
+ # Check array type with references
+ attendees_schema = event_schema["properties"]["attendees"]
+ assert attendees_schema["type"] == "array"
+ assert (
+ attendees_schema["items"]["$ref"]
+ == "#/components/schemas/EventAttendee"
+ )
+
+ # Check enum values
+ attendee_schema = schemas["EventAttendee"]
+ response_status = attendee_schema["properties"]["responseStatus"]
+ assert "enum" in response_status
+ assert "accepted" in response_status["enum"]
+
+ @pytest.mark.parametrize(
+ "schema_def, expected_type, expected_attrs",
+ [
+ # Test object type
+ (
+ {
+ "type": "object",
+ "description": "Test object",
+ "properties": {
+ "id": {"type": "string", "required": True},
+ "name": {"type": "string"},
+ },
+ },
+ "object",
+ {"description": "Test object", "required": ["id"]},
+ ),
+ # Test array type
+ (
+ {
+ "type": "array",
+ "description": "Test array",
+ "items": {"type": "string"},
+ },
+ "array",
+ {"description": "Test array", "items": {"type": "string"}},
+ ),
+ # Test reference conversion
+ (
+ {"$ref": "Calendar"},
+ None, # No type for references
+ {"$ref": "#/components/schemas/Calendar"},
+ ),
+ # Test enum conversion
+ (
+ {"type": "string", "enum": ["value1", "value2"]},
+ "string",
+ {"enum": ["value1", "value2"]},
+ ),
+ ],
+ )
+ def test_convert_schema_object(
+ self, converter, schema_def, expected_type, expected_attrs
+ ):
+ """Test conversion of individual schema objects with different input variations."""
+ converted = converter._convert_schema_object(schema_def)
+
+ # Check type if expected
+ if expected_type:
+ assert converted["type"] == expected_type
+
+ # Check other expected attributes
+ for key, value in expected_attrs.items():
+ assert converted[key] == value
+
+ @pytest.mark.parametrize(
+ "path, expected_params",
+ [
+ # Path with parameters
+ (
+ "/calendars/{calendarId}/events/{eventId}",
+ ["calendarId", "eventId"],
+ ),
+ # Path without parameters
+ ("/calendars/events", []),
+ # Mixed path
+ ("/users/{userId}/calendars/default", ["userId"]),
+ ],
+ )
+ def test_extract_path_parameters(self, converter, path, expected_params):
+ """Test extraction of path parameters from URL path with various inputs."""
+ params = converter._extract_path_parameters(path)
+ assert set(params) == set(expected_params)
+ assert len(params) == len(expected_params)
+
+ @pytest.mark.parametrize(
+ "param_data, expected_result",
+ [
+ # String parameter
+ (
+ {
+ "type": "string",
+ "description": "String parameter",
+ "pattern": "^[a-z]+$",
+ },
+ {"type": "string", "pattern": "^[a-z]+$"},
+ ),
+ # Integer parameter with format
+ (
+ {"type": "integer", "format": "int32", "default": "10"},
+ {"type": "integer", "format": "int32", "default": "10"},
+ ),
+ # Enum parameter
+ (
+ {"type": "string", "enum": ["option1", "option2"]},
+ {"type": "string", "enum": ["option1", "option2"]},
+ ),
+ ],
+ )
+ def test_convert_parameter_schema(
+ self, converter, param_data, expected_result
+ ):
+ """Test conversion of parameter definitions to OpenAPI schemas."""
+ converted = converter._convert_parameter_schema(param_data)
+
+ # Check all expected attributes
+ for key, value in expected_result.items():
+ assert converted[key] == value
+
+ def test_convert(self, converter_with_patched_build):
+ """Test the complete conversion process."""
+ # Call the method
+ result = converter_with_patched_build.convert()
+
+ # Verify basic structure
+ assert result["openapi"] == "3.0.0"
+ assert "info" in result
+ assert "servers" in result
+ assert "paths" in result
+ assert "components" in result
+
+ # Verify paths
+ paths = result["paths"]
+ assert "/calendars/{calendarId}" in paths
+ assert "get" in paths["/calendars/{calendarId}"]
+
+ # Verify nested resources
+ assert "/calendars/{calendarId}/events" in paths
+
+ # Verify method details
+ get_calendar = paths["/calendars/{calendarId}"]["get"]
+ assert get_calendar["operationId"] == "calendar.calendars.get"
+ assert "parameters" in get_calendar
+
+ # Verify request body
+ insert_calendar = paths["/calendars"]["post"]
+ assert "requestBody" in insert_calendar
+ request_schema = insert_calendar["requestBody"]["content"][
+ "application/json"
+ ]["schema"]
+ assert request_schema["$ref"] == "#/components/schemas/Calendar"
+
+ # Verify response body
+ assert "responses" in get_calendar
+ response_schema = get_calendar["responses"]["200"]["content"][
+ "application/json"
+ ]["schema"]
+ assert response_schema["$ref"] == "#/components/schemas/Calendar"
+
+ def test_convert_methods(self, prepared_converter, calendar_api_spec):
+ """Test conversion of API methods."""
+ # Convert methods
+ methods = calendar_api_spec["resources"]["calendars"]["methods"]
+ prepared_converter._convert_methods(methods, "/calendars")
+
+ # Verify the results
+ paths = prepared_converter.openapi_spec["paths"]
+
+ # Check GET method
+ assert "/calendars/{calendarId}" in paths
+ get_method = paths["/calendars/{calendarId}"]["get"]
+ assert get_method["operationId"] == "calendar.calendars.get"
+
+ # Check parameters
+ params = get_method["parameters"]
+ param_names = [p["name"] for p in params]
+ assert "calendarId" in param_names
+
+ # Check POST method
+ assert "/calendars" in paths
+ post_method = paths["/calendars"]["post"]
+ assert post_method["operationId"] == "calendar.calendars.insert"
+
+ # Check request body
+ assert "requestBody" in post_method
+ assert (
+ post_method["requestBody"]["content"]["application/json"]["schema"][
+ "$ref"
+ ]
+ == "#/components/schemas/Calendar"
+ )
+
+ # Check response
+ assert (
+ post_method["responses"]["200"]["content"]["application/json"][
+ "schema"
+ ]["$ref"]
+ == "#/components/schemas/Calendar"
+ )
+
+ def test_convert_resources(self, prepared_converter, calendar_api_spec):
+ """Test conversion of nested resources."""
+ # Convert resources
+ resources = calendar_api_spec["resources"]
+ prepared_converter._convert_resources(resources)
+
+ # Verify the results
+ paths = prepared_converter.openapi_spec["paths"]
+
+ # Check top-level resource methods
+ assert "/calendars/{calendarId}" in paths
+
+ # Check nested resource methods
+ assert "/calendars/{calendarId}/events" in paths
+ events_method = paths["/calendars/{calendarId}/events"]["get"]
+ assert events_method["operationId"] == "calendar.events.list"
+
+ # Check parameters in nested resource
+ params = events_method["parameters"]
+ param_names = [p["name"] for p in params]
+ assert "calendarId" in param_names
+ assert "maxResults" in param_names
+ assert "orderBy" in param_names
+
+ def test_integration_calendar_api(self, converter_with_patched_build):
+ """Integration test using Calendar API specification."""
+ # Create and run the converter
+ openapi_spec = converter_with_patched_build.convert()
+
+ # Verify conversion results
+ assert openapi_spec["info"]["title"] == "Google Calendar API"
+ assert (
+ openapi_spec["servers"][0]["url"]
+ == "https://www.googleapis.com/calendar/v3"
+ )
+
+ # Check security schemes
+ security_schemes = openapi_spec["components"]["securitySchemes"]
+ assert "oauth2" in security_schemes
+ assert "apiKey" in security_schemes
+
+ # Check schemas
+ schemas = openapi_spec["components"]["schemas"]
+ assert "Calendar" in schemas
+ assert "Event" in schemas
+ assert "EventDateTime" in schemas
+
+ # Check paths
+ paths = openapi_spec["paths"]
+ assert "/calendars/{calendarId}" in paths
+ assert "/calendars" in paths
+ assert "/calendars/{calendarId}/events" in paths
+
+ # Check method details
+ get_events = paths["/calendars/{calendarId}/events"]["get"]
+ assert get_events["operationId"] == "calendar.events.list"
+
+ # Check parameter details
+ param_dict = {p["name"]: p for p in get_events["parameters"]}
+ assert "maxResults" in param_dict
+ max_results = param_dict["maxResults"]
+ assert max_results["in"] == "query"
+ assert max_results["schema"]["type"] == "integer"
+ assert max_results["schema"]["default"] == "250"
+
+
+@pytest.fixture
+def conftest_content():
+ """Returns content for a conftest.py file to help with testing."""
+ return """
+import pytest
+from unittest.mock import MagicMock
+
+# This file contains fixtures that can be shared across multiple test modules
+
+@pytest.fixture
+def mock_google_response():
+ \"\"\"Fixture that provides a mock response from Google's API.\"\"\"
+ return {"key": "value", "items": [{"id": 1}, {"id": 2}]}
+
+@pytest.fixture
+def mock_http_error():
+ \"\"\"Fixture that provides a mock HTTP error.\"\"\"
+ mock_resp = MagicMock()
+ mock_resp.status = 404
+ return HttpError(resp=mock_resp, content=b'Not Found')
+"""
+
+
+def test_generate_conftest_example(conftest_content):
+ """This is a meta-test that demonstrates how to generate a conftest.py file.
+
+ In a real project, you would create a separate conftest.py file.
+ """
+ # In a real scenario, you would write this to a file named conftest.py
+ # This test just verifies the conftest content is not empty
+ assert len(conftest_content) > 0
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_auto_auth_credential_exchanger.py b/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_auto_auth_credential_exchanger.py
new file mode 100644
index 0000000..bca0596
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_auto_auth_credential_exchanger.py
@@ -0,0 +1,145 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for AutoAuthCredentialExchanger."""
+
+from typing import Dict
+from typing import Optional
+from typing import Type
+from unittest.mock import MagicMock
+
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.auth.auth_credential import AuthCredentialTypes
+from google.adk.auth.auth_schemes import AuthScheme
+from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
+from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
+from google.adk.tools.openapi_tool.auth.credential_exchangers.oauth2_exchanger import OAuth2CredentialExchanger
+from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
+import pytest
+
+
+class MockCredentialExchanger(BaseAuthCredentialExchanger):
+ """Mock credential exchanger for testing."""
+
+ def exchange_credential(
+ self,
+ auth_scheme: AuthScheme,
+ auth_credential: Optional[AuthCredential] = None,
+ ) -> AuthCredential:
+ """Mock exchange credential method."""
+ return auth_credential
+
+
+@pytest.fixture
+def auto_exchanger():
+ """Fixture for creating an AutoAuthCredentialExchanger instance."""
+ return AutoAuthCredentialExchanger()
+
+
+@pytest.fixture
+def auth_scheme():
+ """Fixture for creating a mock AuthScheme instance."""
+ scheme = MagicMock(spec=AuthScheme)
+ return scheme
+
+
+def test_init_with_custom_exchangers():
+ """Test initialization with custom exchangers."""
+ custom_exchangers: Dict[str, Type[BaseAuthCredentialExchanger]] = {
+ AuthCredentialTypes.API_KEY: MockCredentialExchanger
+ }
+
+ auto_exchanger = AutoAuthCredentialExchanger(
+ custom_exchangers=custom_exchangers
+ )
+
+ assert (
+ auto_exchanger.exchangers[AuthCredentialTypes.API_KEY]
+ == MockCredentialExchanger
+ )
+ assert (
+ auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT]
+ == OAuth2CredentialExchanger
+ )
+
+
+def test_exchange_credential_no_auth_credential(auto_exchanger, auth_scheme):
+ """Test exchange_credential with no auth_credential."""
+
+ assert auto_exchanger.exchange_credential(auth_scheme, None) is None
+
+
+def test_exchange_credential_no_exchange(auto_exchanger, auth_scheme):
+ """Test exchange_credential with NoExchangeCredentialExchanger."""
+ auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
+
+ result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
+
+ assert result == auth_credential
+
+
+def test_exchange_credential_open_id_connect(auto_exchanger, auth_scheme):
+ """Test exchange_credential with OpenID Connect scheme."""
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.OPEN_ID_CONNECT
+ )
+ mock_exchanger = MagicMock(spec=OAuth2CredentialExchanger)
+ mock_exchanger.exchange_credential.return_value = "exchanged_credential"
+ auto_exchanger.exchangers[AuthCredentialTypes.OPEN_ID_CONNECT] = (
+ lambda: mock_exchanger
+ )
+
+ result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
+
+ assert result == "exchanged_credential"
+ mock_exchanger.exchange_credential.assert_called_once_with(
+ auth_scheme, auth_credential
+ )
+
+
+def test_exchange_credential_service_account(auto_exchanger, auth_scheme):
+ """Test exchange_credential with Service Account scheme."""
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.SERVICE_ACCOUNT
+ )
+ mock_exchanger = MagicMock(spec=ServiceAccountCredentialExchanger)
+ mock_exchanger.exchange_credential.return_value = "exchanged_credential_sa"
+ auto_exchanger.exchangers[AuthCredentialTypes.SERVICE_ACCOUNT] = (
+ lambda: mock_exchanger
+ )
+
+ result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
+
+ assert result == "exchanged_credential_sa"
+ mock_exchanger.exchange_credential.assert_called_once_with(
+ auth_scheme, auth_credential
+ )
+
+
+def test_exchange_credential_custom_exchanger(auto_exchanger, auth_scheme):
+ """Test that exchange_credential calls the correct (custom) exchanger."""
+ # Use a custom exchanger via the initialization
+ mock_exchanger = MagicMock(spec=MockCredentialExchanger)
+ mock_exchanger.exchange_credential.return_value = "custom_credential"
+ auto_exchanger.exchangers[AuthCredentialTypes.API_KEY] = (
+ lambda: mock_exchanger
+ )
+ auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
+
+ result = auto_exchanger.exchange_credential(auth_scheme, auth_credential)
+
+ assert result == "custom_credential"
+ mock_exchanger.exchange_credential.assert_called_once_with(
+ auth_scheme, auth_credential
+ )
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_base_auth_credential_exchanger.py b/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_base_auth_credential_exchanger.py
new file mode 100644
index 0000000..1b00afb
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_base_auth_credential_exchanger.py
@@ -0,0 +1,68 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the BaseAuthCredentialExchanger class."""
+
+from typing import Optional
+from unittest.mock import MagicMock
+
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.auth.auth_credential import AuthCredentialTypes
+from google.adk.auth.auth_schemes import AuthScheme
+from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
+from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
+import pytest
+
+
+class MockAuthCredentialExchanger(BaseAuthCredentialExchanger):
+
+ def exchange_credential(
+ self,
+ auth_scheme: AuthScheme,
+ auth_credential: Optional[AuthCredential] = None,
+ ) -> AuthCredential:
+ return AuthCredential(token="some-token")
+
+
+class TestBaseAuthCredentialExchanger:
+ """Tests for the BaseAuthCredentialExchanger class."""
+
+ @pytest.fixture
+ def base_exchanger(self):
+ return BaseAuthCredentialExchanger()
+
+ @pytest.fixture
+ def auth_scheme(self):
+ scheme = MagicMock(spec=AuthScheme)
+ scheme.type = "apiKey"
+ scheme.name = "x-api-key"
+ return scheme
+
+ def test_exchange_credential_not_implemented(
+ self, base_exchanger, auth_scheme
+ ):
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, token="some-token"
+ )
+ with pytest.raises(NotImplementedError) as exc_info:
+ base_exchanger.exchange_credential(auth_scheme, auth_credential)
+ assert "Subclasses must implement exchange_credential." in str(
+ exc_info.value
+ )
+
+ def test_auth_credential_missing_error(self):
+ error_message = "Test missing credential"
+ error = AuthCredentialMissingError(error_message)
+ # assert error.message == error_message
+ assert str(error) == error_message
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py b/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py
new file mode 100644
index 0000000..c028e0e
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py
@@ -0,0 +1,153 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for OAuth2CredentialExchanger."""
+
+import copy
+from unittest.mock import MagicMock
+
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.auth.auth_credential import AuthCredentialTypes
+from google.adk.auth.auth_credential import OAuth2Auth
+from google.adk.auth.auth_schemes import AuthSchemeType
+from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
+from google.adk.tools.openapi_tool.auth.credential_exchangers import OAuth2CredentialExchanger
+from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
+import pytest
+
+
+@pytest.fixture
+def oauth2_exchanger():
+ return OAuth2CredentialExchanger()
+
+
+@pytest.fixture
+def auth_scheme():
+ openid_config = OpenIdConnectWithConfig(
+ type_=AuthSchemeType.openIdConnect,
+ authorization_endpoint="https://example.com/auth",
+ token_endpoint="https://example.com/token",
+ scopes=["openid", "profile"],
+ )
+ return openid_config
+
+
+def test_check_scheme_credential_type_success(oauth2_exchanger, auth_scheme):
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id="test_client",
+ client_secret="test_secret",
+ redirect_uri="http://localhost:8080",
+ ),
+ )
+ # Check that the method does not raise an exception
+ oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
+
+
+def test_check_scheme_credential_type_missing_credential(
+ oauth2_exchanger, auth_scheme
+):
+ # Test case: auth_credential is None
+ with pytest.raises(ValueError) as exc_info:
+ oauth2_exchanger._check_scheme_credential_type(auth_scheme, None)
+ assert "auth_credential is empty" in str(exc_info.value)
+
+
+def test_check_scheme_credential_type_invalid_scheme_type(
+ oauth2_exchanger, auth_scheme: OpenIdConnectWithConfig
+):
+ """Test case: Invalid AuthSchemeType."""
+ # Test case: Invalid AuthSchemeType
+ invalid_scheme = copy.deepcopy(auth_scheme)
+ invalid_scheme.type_ = AuthSchemeType.apiKey
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id="test_client",
+ client_secret="test_secret",
+ redirect_uri="http://localhost:8080",
+ ),
+ )
+ with pytest.raises(ValueError) as exc_info:
+ oauth2_exchanger._check_scheme_credential_type(
+ invalid_scheme, auth_credential
+ )
+ assert "Invalid security scheme" in str(exc_info.value)
+
+
+def test_check_scheme_credential_type_missing_openid_connect(
+ oauth2_exchanger, auth_scheme
+):
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ )
+ with pytest.raises(ValueError) as exc_info:
+ oauth2_exchanger._check_scheme_credential_type(auth_scheme, auth_credential)
+ assert "auth_credential is not configured with oauth2" in str(exc_info.value)
+
+
+def test_generate_auth_token_success(
+ oauth2_exchanger, auth_scheme, monkeypatch
+):
+ """Test case: Successful generation of access token."""
+ # Test case: Successful generation of access token
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id="test_client",
+ client_secret="test_secret",
+ redirect_uri="http://localhost:8080",
+ auth_response_uri="https://example.com/callback?code=test_code",
+ token={"access_token": "test_access_token"},
+ ),
+ )
+ updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
+
+ assert updated_credential.auth_type == AuthCredentialTypes.HTTP
+ assert updated_credential.http.scheme == "bearer"
+ assert updated_credential.http.credentials.token == "test_access_token"
+
+
+def test_exchange_credential_generate_auth_token(
+ oauth2_exchanger, auth_scheme, monkeypatch
+):
+ """Test exchange_credential when auth_response_uri is present."""
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.OAUTH2,
+ oauth2=OAuth2Auth(
+ client_id="test_client",
+ client_secret="test_secret",
+ redirect_uri="http://localhost:8080",
+ auth_response_uri="https://example.com/callback?code=test_code",
+ token={"access_token": "test_access_token"},
+ ),
+ )
+
+ updated_credential = oauth2_exchanger.exchange_credential(
+ auth_scheme, auth_credential
+ )
+
+ assert updated_credential.auth_type == AuthCredentialTypes.HTTP
+ assert updated_credential.http.scheme == "bearer"
+ assert updated_credential.http.credentials.token == "test_access_token"
+
+
+def test_exchange_credential_auth_missing(oauth2_exchanger, auth_scheme):
+ """Test exchange_credential when auth_credential is missing."""
+ with pytest.raises(ValueError) as exc_info:
+ oauth2_exchanger.exchange_credential(auth_scheme, None)
+ assert "auth_credential is empty. Please create AuthCredential using" in str(
+ exc_info.value
+ )
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py b/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py
new file mode 100644
index 0000000..32a144d
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py
@@ -0,0 +1,196 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for the service account credential exchanger."""
+
+from unittest.mock import MagicMock
+
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.auth.auth_credential import AuthCredentialTypes
+from google.adk.auth.auth_credential import ServiceAccount
+from google.adk.auth.auth_credential import ServiceAccountCredential
+from google.adk.auth.auth_schemes import AuthScheme
+from google.adk.auth.auth_schemes import AuthSchemeType
+from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
+from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
+import google.auth
+import pytest
+
+
+@pytest.fixture
+def service_account_exchanger():
+ return ServiceAccountCredentialExchanger()
+
+
+@pytest.fixture
+def auth_scheme():
+ scheme = MagicMock(spec=AuthScheme)
+ scheme.type_ = AuthSchemeType.oauth2
+ scheme.description = "Google Service Account"
+ return scheme
+
+
+def test_exchange_credential_success(
+ service_account_exchanger, auth_scheme, monkeypatch
+):
+ """Test successful exchange of service account credentials."""
+ mock_credentials = MagicMock()
+ mock_credentials.token = "mock_access_token"
+
+ # Mock the from_service_account_info method
+ mock_from_service_account_info = MagicMock(return_value=mock_credentials)
+ target_path = (
+ "google.adk.tools.openapi_tool.auth.credential_exchangers."
+ "service_account_exchanger.service_account.Credentials."
+ "from_service_account_info"
+ )
+ monkeypatch.setattr(
+ target_path,
+ mock_from_service_account_info,
+ )
+
+ # Mock the refresh method
+ mock_credentials.refresh = MagicMock()
+
+ # Create a valid AuthCredential with service account info
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
+ service_account=ServiceAccount(
+ service_account_credential=ServiceAccountCredential(
+ type_="service_account",
+ project_id="your_project_id",
+ private_key_id="your_private_key_id",
+ private_key="-----BEGIN PRIVATE KEY-----...",
+ client_email="...@....iam.gserviceaccount.com",
+ client_id="your_client_id",
+ auth_uri="https://accounts.google.com/o/oauth2/auth",
+ token_uri="https://oauth2.googleapis.com/token",
+ auth_provider_x509_cert_url=(
+ "https://www.googleapis.com/oauth2/v1/certs"
+ ),
+ client_x509_cert_url=(
+ "https://www.googleapis.com/robot/v1/metadata/x509/..."
+ ),
+ universe_domain="googleapis.com",
+ ),
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ ),
+ )
+
+ result = service_account_exchanger.exchange_credential(
+ auth_scheme, auth_credential
+ )
+
+ assert result.auth_type == AuthCredentialTypes.HTTP
+ assert result.http.scheme == "bearer"
+ assert result.http.credentials.token == "mock_access_token"
+ mock_from_service_account_info.assert_called_once()
+ mock_credentials.refresh.assert_called_once()
+
+
+def test_exchange_credential_use_default_credential_success(
+ service_account_exchanger, auth_scheme, monkeypatch
+):
+ """Test successful exchange of service account credentials using default credential."""
+ mock_credentials = MagicMock()
+ mock_credentials.token = "mock_access_token"
+ mock_google_auth_default = MagicMock(
+ return_value=(mock_credentials, "test_project")
+ )
+ monkeypatch.setattr(google.auth, "default", mock_google_auth_default)
+
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
+ service_account=ServiceAccount(
+ use_default_credential=True,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ ),
+ )
+
+ result = service_account_exchanger.exchange_credential(
+ auth_scheme, auth_credential
+ )
+
+ assert result.auth_type == AuthCredentialTypes.HTTP
+ assert result.http.scheme == "bearer"
+ assert result.http.credentials.token == "mock_access_token"
+ mock_google_auth_default.assert_called_once()
+ mock_credentials.refresh.assert_called_once()
+
+
+def test_exchange_credential_missing_auth_credential(
+ service_account_exchanger, auth_scheme
+):
+ """Test missing auth credential during exchange."""
+ with pytest.raises(AuthCredentialMissingError) as exc_info:
+ service_account_exchanger.exchange_credential(auth_scheme, None)
+ assert "Service account credentials are missing" in str(exc_info.value)
+
+
+def test_exchange_credential_missing_service_account_info(
+ service_account_exchanger, auth_scheme
+):
+ """Test missing service account info during exchange."""
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
+ )
+ with pytest.raises(AuthCredentialMissingError) as exc_info:
+ service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
+ assert "Service account credentials are missing" in str(exc_info.value)
+
+
+def test_exchange_credential_exchange_failure(
+ service_account_exchanger, auth_scheme, monkeypatch
+):
+ """Test failure during service account token exchange."""
+ mock_from_service_account_info = MagicMock(
+ side_effect=Exception("Failed to load credentials")
+ )
+ target_path = (
+ "google.adk.tools.openapi_tool.auth.credential_exchangers."
+ "service_account_exchanger.service_account.Credentials."
+ "from_service_account_info"
+ )
+ monkeypatch.setattr(
+ target_path,
+ mock_from_service_account_info,
+ )
+
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
+ service_account=ServiceAccount(
+ service_account_credential=ServiceAccountCredential(
+ type_="service_account",
+ project_id="your_project_id",
+ private_key_id="your_private_key_id",
+ private_key="-----BEGIN PRIVATE KEY-----...",
+ client_email="...@....iam.gserviceaccount.com",
+ client_id="your_client_id",
+ auth_uri="https://accounts.google.com/o/oauth2/auth",
+ token_uri="https://oauth2.googleapis.com/token",
+ auth_provider_x509_cert_url=(
+ "https://www.googleapis.com/oauth2/v1/certs"
+ ),
+ client_x509_cert_url=(
+ "https://www.googleapis.com/robot/v1/metadata/x509/..."
+ ),
+ universe_domain="googleapis.com",
+ ),
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ ),
+ )
+ with pytest.raises(AuthCredentialMissingError) as exc_info:
+ service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
+ assert "Failed to exchange service account token" in str(exc_info.value)
+ mock_from_service_account_info.assert_called_once()
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/auth/test_auth_helper.py b/src/google/adk/tests/unittests/tools/openapi_tool/auth/test_auth_helper.py
new file mode 100644
index 0000000..af7bf4f
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/auth/test_auth_helper.py
@@ -0,0 +1,573 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import patch
+
+from fastapi.openapi.models import APIKey
+from fastapi.openapi.models import APIKeyIn
+from fastapi.openapi.models import HTTPBase
+from fastapi.openapi.models import HTTPBearer
+from fastapi.openapi.models import OAuth2
+from fastapi.openapi.models import OpenIdConnect
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.auth.auth_credential import AuthCredentialTypes
+from google.adk.auth.auth_credential import HttpAuth
+from google.adk.auth.auth_credential import HttpCredentials
+from google.adk.auth.auth_credential import ServiceAccount
+from google.adk.auth.auth_credential import ServiceAccountCredential
+from google.adk.auth.auth_schemes import AuthSchemeType
+from google.adk.auth.auth_schemes import OpenIdConnectWithConfig
+from google.adk.tools.openapi_tool.auth.auth_helpers import credential_to_param
+from google.adk.tools.openapi_tool.auth.auth_helpers import dict_to_auth_scheme
+from google.adk.tools.openapi_tool.auth.auth_helpers import INTERNAL_AUTH_PREFIX
+from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
+from google.adk.tools.openapi_tool.auth.auth_helpers import openid_url_to_scheme_credential
+from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_dict_to_scheme_credential
+from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
+from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
+import pytest
+import requests
+
+
+def test_token_to_scheme_credential_api_key_header():
+ scheme, credential = token_to_scheme_credential(
+ "apikey", "header", "X-API-Key", "test_key"
+ )
+
+ assert isinstance(scheme, APIKey)
+ assert scheme.type_ == AuthSchemeType.apiKey
+ assert scheme.in_ == APIKeyIn.header
+ assert scheme.name == "X-API-Key"
+ assert credential == AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
+ )
+
+
+def test_token_to_scheme_credential_api_key_query():
+ scheme, credential = token_to_scheme_credential(
+ "apikey", "query", "api_key", "test_key"
+ )
+
+ assert isinstance(scheme, APIKey)
+ assert scheme.type_ == AuthSchemeType.apiKey
+ assert scheme.in_ == APIKeyIn.query
+ assert scheme.name == "api_key"
+ assert credential == AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
+ )
+
+
+def test_token_to_scheme_credential_api_key_cookie():
+ scheme, credential = token_to_scheme_credential(
+ "apikey", "cookie", "session_id", "test_key"
+ )
+
+ assert isinstance(scheme, APIKey)
+ assert scheme.type_ == AuthSchemeType.apiKey
+ assert scheme.in_ == APIKeyIn.cookie
+ assert scheme.name == "session_id"
+ assert credential == AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
+ )
+
+
+def test_token_to_scheme_credential_api_key_no_credential():
+ scheme, credential = token_to_scheme_credential(
+ "apikey", "cookie", "session_id"
+ )
+
+ assert isinstance(scheme, APIKey)
+ assert credential is None
+
+
+def test_token_to_scheme_credential_oauth2_token():
+ scheme, credential = token_to_scheme_credential(
+ "oauth2Token", "header", "Authorization", "test_token"
+ )
+
+ assert isinstance(scheme, HTTPBearer)
+ assert scheme.bearerFormat == "JWT"
+ assert credential == AuthCredential(
+ auth_type=AuthCredentialTypes.HTTP,
+ http=HttpAuth(
+ scheme="bearer", credentials=HttpCredentials(token="test_token")
+ ),
+ )
+
+
+def test_token_to_scheme_credential_oauth2_no_credential():
+ scheme, credential = token_to_scheme_credential(
+ "oauth2Token", "header", "Authorization"
+ )
+
+ assert isinstance(scheme, HTTPBearer)
+ assert credential is None
+
+
+def test_service_account_dict_to_scheme_credential():
+ config = {
+ "type": "service_account",
+ "project_id": "project_id",
+ "private_key_id": "private_key_id",
+ "private_key": "private_key",
+ "client_email": "client_email",
+ "client_id": "client_id",
+ "auth_uri": "auth_uri",
+ "token_uri": "token_uri",
+ "auth_provider_x509_cert_url": "auth_provider_x509_cert_url",
+ "client_x509_cert_url": "client_x509_cert_url",
+ "universe_domain": "universe_domain",
+ }
+ scopes = ["scope1", "scope2"]
+
+ scheme, credential = service_account_dict_to_scheme_credential(config, scopes)
+
+ assert isinstance(scheme, HTTPBearer)
+ assert scheme.bearerFormat == "JWT"
+ assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
+ assert credential.service_account.scopes == scopes
+ assert (
+ credential.service_account.service_account_credential.project_id
+ == "project_id"
+ )
+
+
+def test_service_account_scheme_credential():
+ config = ServiceAccount(
+ service_account_credential=ServiceAccountCredential(
+ type="service_account",
+ project_id="project_id",
+ private_key_id="private_key_id",
+ private_key="private_key",
+ client_email="client_email",
+ client_id="client_id",
+ auth_uri="auth_uri",
+ token_uri="token_uri",
+ auth_provider_x509_cert_url="auth_provider_x509_cert_url",
+ client_x509_cert_url="client_x509_cert_url",
+ universe_domain="universe_domain",
+ ),
+ scopes=["scope1", "scope2"],
+ )
+
+ scheme, credential = service_account_scheme_credential(config)
+
+ assert isinstance(scheme, HTTPBearer)
+ assert scheme.bearerFormat == "JWT"
+ assert credential.auth_type == AuthCredentialTypes.SERVICE_ACCOUNT
+ assert credential.service_account == config
+
+
+def test_openid_dict_to_scheme_credential():
+ config_dict = {
+ "authorization_endpoint": "auth_url",
+ "token_endpoint": "token_url",
+ "openIdConnectUrl": "openid_url",
+ }
+ credential_dict = {
+ "client_id": "client_id",
+ "client_secret": "client_secret",
+ "redirect_uri": "redirect_uri",
+ }
+ scopes = ["scope1", "scope2"]
+
+ scheme, credential = openid_dict_to_scheme_credential(
+ config_dict, scopes, credential_dict
+ )
+
+ assert isinstance(scheme, OpenIdConnectWithConfig)
+ assert scheme.authorization_endpoint == "auth_url"
+ assert scheme.token_endpoint == "token_url"
+ assert scheme.scopes == scopes
+ assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
+ assert credential.oauth2.client_id == "client_id"
+ assert credential.oauth2.client_secret == "client_secret"
+ assert credential.oauth2.redirect_uri == "redirect_uri"
+
+
+def test_openid_dict_to_scheme_credential_no_openid_url():
+ config_dict = {
+ "authorization_endpoint": "auth_url",
+ "token_endpoint": "token_url",
+ }
+ credential_dict = {
+ "client_id": "client_id",
+ "client_secret": "client_secret",
+ "redirect_uri": "redirect_uri",
+ }
+ scopes = ["scope1", "scope2"]
+
+ scheme, credential = openid_dict_to_scheme_credential(
+ config_dict, scopes, credential_dict
+ )
+
+ assert scheme.openIdConnectUrl == ""
+
+
+def test_openid_dict_to_scheme_credential_google_oauth_credential():
+ config_dict = {
+ "authorization_endpoint": "auth_url",
+ "token_endpoint": "token_url",
+ "openIdConnectUrl": "openid_url",
+ }
+ credential_dict = {
+ "web": {
+ "client_id": "client_id",
+ "client_secret": "client_secret",
+ "redirect_uri": "redirect_uri",
+ }
+ }
+ scopes = ["scope1", "scope2"]
+
+ scheme, credential = openid_dict_to_scheme_credential(
+ config_dict, scopes, credential_dict
+ )
+
+ assert isinstance(scheme, OpenIdConnectWithConfig)
+ assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
+ assert credential.oauth2.client_id == "client_id"
+ assert credential.oauth2.client_secret == "client_secret"
+ assert credential.oauth2.redirect_uri == "redirect_uri"
+
+
+def test_openid_dict_to_scheme_credential_invalid_config():
+ config_dict = {
+ "invalid_field": "value",
+ }
+ credential_dict = {
+ "client_id": "client_id",
+ "client_secret": "client_secret",
+ }
+ scopes = ["scope1", "scope2"]
+
+ with pytest.raises(ValueError, match="Invalid OpenID Connect configuration"):
+ openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
+
+
+def test_openid_dict_to_scheme_credential_missing_credential_fields():
+ config_dict = {
+ "authorization_endpoint": "auth_url",
+ "token_endpoint": "token_url",
+ }
+ credential_dict = {
+ "client_id": "client_id",
+ }
+ scopes = ["scope1", "scope2"]
+
+ with pytest.raises(
+ ValueError,
+ match="Missing required fields in credential_dict: client_secret",
+ ):
+ openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
+
+
+@patch("requests.get")
+def test_openid_url_to_scheme_credential(mock_get):
+ mock_response = {
+ "authorization_endpoint": "auth_url",
+ "token_endpoint": "token_url",
+ "userinfo_endpoint": "userinfo_url",
+ }
+ mock_get.return_value.json.return_value = mock_response
+ mock_get.return_value.raise_for_status.return_value = None
+ credential_dict = {
+ "client_id": "client_id",
+ "client_secret": "client_secret",
+ "redirect_uri": "redirect_uri",
+ }
+ scopes = ["scope1", "scope2"]
+
+ scheme, credential = openid_url_to_scheme_credential(
+ "openid_url", scopes, credential_dict
+ )
+
+ assert isinstance(scheme, OpenIdConnectWithConfig)
+ assert scheme.authorization_endpoint == "auth_url"
+ assert scheme.token_endpoint == "token_url"
+ assert scheme.scopes == scopes
+ assert credential.auth_type == AuthCredentialTypes.OPEN_ID_CONNECT
+ assert credential.oauth2.client_id == "client_id"
+ assert credential.oauth2.client_secret == "client_secret"
+ assert credential.oauth2.redirect_uri == "redirect_uri"
+ mock_get.assert_called_once_with("openid_url", timeout=10)
+
+
+@patch("requests.get")
+def test_openid_url_to_scheme_credential_no_openid_url(mock_get):
+ mock_response = {
+ "authorization_endpoint": "auth_url",
+ "token_endpoint": "token_url",
+ "userinfo_endpoint": "userinfo_url",
+ }
+ mock_get.return_value.json.return_value = mock_response
+ mock_get.return_value.raise_for_status.return_value = None
+ credential_dict = {
+ "client_id": "client_id",
+ "client_secret": "client_secret",
+ "redirect_uri": "redirect_uri",
+ }
+ scopes = ["scope1", "scope2"]
+
+ scheme, credential = openid_url_to_scheme_credential(
+ "openid_url", scopes, credential_dict
+ )
+
+ assert scheme.openIdConnectUrl == "openid_url"
+
+
+@patch("requests.get")
+def test_openid_url_to_scheme_credential_request_exception(mock_get):
+ mock_get.side_effect = requests.exceptions.RequestException("Test Error")
+ credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
+
+ with pytest.raises(
+ ValueError, match="Failed to fetch OpenID configuration from openid_url"
+ ):
+ openid_url_to_scheme_credential("openid_url", [], credential_dict)
+
+
+@patch("requests.get")
+def test_openid_url_to_scheme_credential_invalid_json(mock_get):
+ mock_get.return_value.json.side_effect = ValueError("Invalid JSON")
+ mock_get.return_value.raise_for_status.return_value = None
+ credential_dict = {"client_id": "client_id", "client_secret": "client_secret"}
+
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Invalid JSON response from OpenID configuration endpoint openid_url"
+ ),
+ ):
+ openid_url_to_scheme_credential("openid_url", [], credential_dict)
+
+
+def test_credential_to_param_api_key_header():
+ auth_scheme = APIKey(
+ **{"type": "apiKey", "in": "header", "name": "X-API-Key"}
+ )
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
+ )
+
+ param, kwargs = credential_to_param(auth_scheme, auth_credential)
+
+ assert param.original_name == "X-API-Key"
+ assert param.param_location == "header"
+ assert kwargs == {INTERNAL_AUTH_PREFIX + "X-API-Key": "test_key"}
+
+
+def test_credential_to_param_api_key_query():
+ auth_scheme = APIKey(**{"type": "apiKey", "in": "query", "name": "api_key"})
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
+ )
+
+ param, kwargs = credential_to_param(auth_scheme, auth_credential)
+
+ assert param.original_name == "api_key"
+ assert param.param_location == "query"
+ assert kwargs == {INTERNAL_AUTH_PREFIX + "api_key": "test_key"}
+
+
+def test_credential_to_param_api_key_cookie():
+ auth_scheme = APIKey(
+ **{"type": "apiKey", "in": "cookie", "name": "session_id"}
+ )
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, api_key="test_key"
+ )
+
+ param, kwargs = credential_to_param(auth_scheme, auth_credential)
+
+ assert param.original_name == "session_id"
+ assert param.param_location == "cookie"
+ assert kwargs == {INTERNAL_AUTH_PREFIX + "session_id": "test_key"}
+
+
+def test_credential_to_param_http_bearer():
+ auth_scheme = HTTPBearer(bearerFormat="JWT")
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.HTTP,
+ http=HttpAuth(
+ scheme="bearer", credentials=HttpCredentials(token="test_token")
+ ),
+ )
+
+ param, kwargs = credential_to_param(auth_scheme, auth_credential)
+
+ assert param.original_name == "Authorization"
+ assert param.param_location == "header"
+ assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
+
+
+def test_credential_to_param_http_basic_not_supported():
+ auth_scheme = HTTPBase(scheme="basic")
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.HTTP,
+ http=HttpAuth(
+ scheme="basic",
+ credentials=HttpCredentials(username="user", password="password"),
+ ),
+ )
+
+ with pytest.raises(
+ NotImplementedError, match="Basic Authentication is not supported."
+ ):
+ credential_to_param(auth_scheme, auth_credential)
+
+
+def test_credential_to_param_http_invalid_credentials_no_http():
+ auth_scheme = HTTPBase(scheme="basic")
+ auth_credential = AuthCredential(auth_type=AuthCredentialTypes.HTTP)
+
+ with pytest.raises(ValueError, match="Invalid HTTP auth credentials"):
+ credential_to_param(auth_scheme, auth_credential)
+
+
+def test_credential_to_param_oauth2():
+ auth_scheme = OAuth2(flows={})
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.HTTP,
+ http=HttpAuth(
+ scheme="bearer", credentials=HttpCredentials(token="test_token")
+ ),
+ )
+
+ param, kwargs = credential_to_param(auth_scheme, auth_credential)
+
+ assert param.original_name == "Authorization"
+ assert param.param_location == "header"
+ assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
+
+
+def test_credential_to_param_openid_connect():
+ auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.HTTP,
+ http=HttpAuth(
+ scheme="bearer", credentials=HttpCredentials(token="test_token")
+ ),
+ )
+
+ param, kwargs = credential_to_param(auth_scheme, auth_credential)
+
+ assert param.original_name == "Authorization"
+ assert param.param_location == "header"
+ assert kwargs == {INTERNAL_AUTH_PREFIX + "Authorization": "Bearer test_token"}
+
+
+def test_credential_to_param_openid_no_credential():
+ auth_scheme = OpenIdConnect(openIdConnectUrl="openid_url")
+
+ param, kwargs = credential_to_param(auth_scheme, None)
+
+ assert param == None
+ assert kwargs == None
+
+
+def test_credential_to_param_oauth2_no_credential():
+ auth_scheme = OAuth2(flows={})
+
+ param, kwargs = credential_to_param(auth_scheme, None)
+
+ assert param == None
+ assert kwargs == None
+
+
+def test_dict_to_auth_scheme_api_key():
+ data = {"type": "apiKey", "in": "header", "name": "X-API-Key"}
+
+ scheme = dict_to_auth_scheme(data)
+
+ assert isinstance(scheme, APIKey)
+ assert scheme.type_ == AuthSchemeType.apiKey
+ assert scheme.in_ == APIKeyIn.header
+ assert scheme.name == "X-API-Key"
+
+
+def test_dict_to_auth_scheme_http_bearer():
+ data = {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
+
+ scheme = dict_to_auth_scheme(data)
+
+ assert isinstance(scheme, HTTPBearer)
+ assert scheme.scheme == "bearer"
+ assert scheme.bearerFormat == "JWT"
+
+
+def test_dict_to_auth_scheme_http_base():
+ data = {"type": "http", "scheme": "basic"}
+
+ scheme = dict_to_auth_scheme(data)
+
+ assert isinstance(scheme, HTTPBase)
+ assert scheme.scheme == "basic"
+
+
+def test_dict_to_auth_scheme_oauth2():
+ data = {
+ "type": "oauth2",
+ "flows": {
+ "authorizationCode": {
+ "authorizationUrl": "https://example.com/auth",
+ "tokenUrl": "https://example.com/token",
+ }
+ },
+ }
+
+ scheme = dict_to_auth_scheme(data)
+
+ assert isinstance(scheme, OAuth2)
+ assert hasattr(scheme.flows, "authorizationCode")
+
+
+def test_dict_to_auth_scheme_openid_connect():
+ data = {
+ "type": "openIdConnect",
+ "openIdConnectUrl": (
+ "https://example.com/.well-known/openid-configuration"
+ ),
+ }
+
+ scheme = dict_to_auth_scheme(data)
+
+ assert isinstance(scheme, OpenIdConnect)
+ assert (
+ scheme.openIdConnectUrl
+ == "https://example.com/.well-known/openid-configuration"
+ )
+
+
+def test_dict_to_auth_scheme_missing_type():
+ data = {"in": "header", "name": "X-API-Key"}
+ with pytest.raises(
+ ValueError, match="Missing 'type' field in security scheme dictionary."
+ ):
+ dict_to_auth_scheme(data)
+
+
+def test_dict_to_auth_scheme_invalid_type():
+ data = {"type": "invalid", "in": "header", "name": "X-API-Key"}
+ with pytest.raises(ValueError, match="Invalid security scheme type: invalid"):
+ dict_to_auth_scheme(data)
+
+
+def test_dict_to_auth_scheme_invalid_data():
+ data = {"type": "apiKey", "in": "header"} # Missing 'name'
+ with pytest.raises(ValueError, match="Invalid security scheme data"):
+ dict_to_auth_scheme(data)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/common/test_common.py b/src/google/adk/tests/unittests/tools/openapi_tool/common/test_common.py
new file mode 100644
index 0000000..f20de57
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/common/test_common.py
@@ -0,0 +1,436 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from typing import Dict
+from typing import List
+
+from fastapi.openapi.models import Response, Schema
+from google.adk.tools.openapi_tool.common.common import ApiParameter
+from google.adk.tools.openapi_tool.common.common import PydocHelper
+from google.adk.tools.openapi_tool.common.common import rename_python_keywords
+from google.adk.tools.openapi_tool.common.common import to_snake_case
+from google.adk.tools.openapi_tool.common.common import TypeHintHelper
+import pytest
+
+
+def dict_to_responses(input: Dict[str, Any]) -> Dict[str, Response]:
+ return {k: Response.model_validate(input[k]) for k in input}
+
+
+class TestToSnakeCase:
+
+ @pytest.mark.parametrize(
+ 'input_str, expected_output',
+ [
+ ('lowerCamelCase', 'lower_camel_case'),
+ ('UpperCamelCase', 'upper_camel_case'),
+ ('space separated', 'space_separated'),
+ ('REST API', 'rest_api'),
+ ('Mixed_CASE with_Spaces', 'mixed_case_with_spaces'),
+ ('__init__', 'init'),
+ ('APIKey', 'api_key'),
+ ('SomeLongURL', 'some_long_url'),
+ ('CONSTANT_CASE', 'constant_case'),
+ ('already_snake_case', 'already_snake_case'),
+ ('single', 'single'),
+ ('', ''),
+ (' spaced ', 'spaced'),
+ ('with123numbers', 'with123numbers'),
+ ('With_Mixed_123_and_SPACES', 'with_mixed_123_and_spaces'),
+ ('HTMLParser', 'html_parser'),
+ ('HTTPResponseCode', 'http_response_code'),
+ ('a_b_c', 'a_b_c'),
+ ('A_B_C', 'a_b_c'),
+ ('fromAtoB', 'from_ato_b'),
+ ('XMLHTTPRequest', 'xmlhttp_request'),
+ ('_leading', 'leading'),
+ ('trailing_', 'trailing'),
+ (' leading_and_trailing_ ', 'leading_and_trailing'),
+ ('Multiple___Underscores', 'multiple_underscores'),
+ (' spaces_and___underscores ', 'spaces_and_underscores'),
+ (' _mixed_Case ', 'mixed_case'),
+ ('123Start', '123_start'),
+ ('End123', 'end123'),
+ ('Mid123dle', 'mid123dle'),
+ ],
+ )
+ def test_to_snake_case(self, input_str, expected_output):
+ assert to_snake_case(input_str) == expected_output
+
+
+class TestRenamePythonKeywords:
+
+ @pytest.mark.parametrize(
+ 'input_str, expected_output',
+ [
+ ('in', 'param_in'),
+ ('for', 'param_for'),
+ ('class', 'param_class'),
+ ('normal', 'normal'),
+ ('param_if', 'param_if'),
+ ('', ''),
+ ],
+ )
+ def test_rename_python_keywords(self, input_str, expected_output):
+ assert rename_python_keywords(input_str) == expected_output
+
+
+class TestApiParameter:
+
+ def test_api_parameter_initialization(self):
+ schema = Schema(type='string', description='A string parameter')
+ param = ApiParameter(
+ original_name='testParam',
+ description='A string description',
+ param_location='query',
+ param_schema=schema,
+ )
+ assert param.original_name == 'testParam'
+ assert param.param_location == 'query'
+ assert param.param_schema.type == 'string'
+ assert param.param_schema.description == 'A string parameter'
+ assert param.py_name == 'test_param'
+ assert param.type_hint == 'str'
+ assert param.type_value == str
+ assert param.description == 'A string description'
+
+ def test_api_parameter_keyword_rename(self):
+ schema = Schema(type='string')
+ param = ApiParameter(
+ original_name='in',
+ param_location='query',
+ param_schema=schema,
+ )
+ assert param.py_name == 'param_in'
+
+ def test_api_parameter_custom_py_name(self):
+ schema = Schema(type='integer')
+ param = ApiParameter(
+ original_name='testParam',
+ param_location='query',
+ param_schema=schema,
+ py_name='custom_name',
+ )
+ assert param.py_name == 'custom_name'
+
+ def test_api_parameter_str_representation(self):
+ schema = Schema(type='number')
+ param = ApiParameter(
+ original_name='testParam',
+ param_location='query',
+ param_schema=schema,
+ )
+ assert str(param) == 'test_param: float'
+
+ def test_api_parameter_to_arg_string(self):
+ schema = Schema(type='boolean')
+ param = ApiParameter(
+ original_name='testParam',
+ param_location='query',
+ param_schema=schema,
+ )
+ assert param.to_arg_string() == 'test_param=test_param'
+
+ def test_api_parameter_to_dict_property(self):
+ schema = Schema(type='string')
+ param = ApiParameter(
+ original_name='testParam',
+ param_location='path',
+ param_schema=schema,
+ )
+ assert param.to_dict_property() == '"test_param": test_param'
+
+ def test_api_parameter_model_serializer(self):
+ schema = Schema(type='string', description='test description')
+ param = ApiParameter(
+ original_name='TestParam',
+ param_location='path',
+ param_schema=schema,
+ py_name='test_param_custom',
+ description='test description',
+ )
+
+ serialized_param = param.model_dump(mode='json', exclude_none=True)
+
+ assert serialized_param == {
+ 'original_name': 'TestParam',
+ 'param_location': 'path',
+ 'param_schema': {'type': 'string', 'description': 'test description'},
+ 'description': 'test description',
+ 'py_name': 'test_param_custom',
+ }
+
+ @pytest.mark.parametrize(
+ 'schema, expected_type_value, expected_type_hint',
+ [
+ ({'type': 'integer'}, int, 'int'),
+ ({'type': 'number'}, float, 'float'),
+ ({'type': 'boolean'}, bool, 'bool'),
+ ({'type': 'string'}, str, 'str'),
+ (
+ {'type': 'string', 'format': 'date'},
+ str,
+ 'str',
+ ),
+ (
+ {'type': 'string', 'format': 'date-time'},
+ str,
+ 'str',
+ ),
+ (
+ {'type': 'array', 'items': {'type': 'integer'}},
+ List[int],
+ 'List[int]',
+ ),
+ (
+ {'type': 'array', 'items': {'type': 'string'}},
+ List[str],
+ 'List[str]',
+ ),
+ (
+ {
+ 'type': 'array',
+ 'items': {'type': 'object'},
+ },
+ List[Dict[str, Any]],
+ 'List[Dict[str, Any]]',
+ ),
+ ({'type': 'object'}, Dict[str, Any], 'Dict[str, Any]'),
+ ({'type': 'unknown'}, Any, 'Any'),
+ ({}, Any, 'Any'),
+ ],
+ )
+ def test_api_parameter_type_hint_helper(
+ self, schema, expected_type_value, expected_type_hint
+ ):
+ param = ApiParameter(
+ original_name='test', param_location='query', param_schema=schema
+ )
+ assert param.type_value == expected_type_value
+ assert param.type_hint == expected_type_hint
+ assert (
+ TypeHintHelper.get_type_hint(param.param_schema) == expected_type_hint
+ )
+ assert (
+ TypeHintHelper.get_type_value(param.param_schema) == expected_type_value
+ )
+
+ def test_api_parameter_description(self):
+ schema = Schema(type='string')
+ param = ApiParameter(
+ original_name='param1',
+ param_location='query',
+ param_schema=schema,
+ description='The description',
+ )
+ assert param.description == 'The description'
+
+ def test_api_parameter_description_use_schema_fallback(self):
+ schema = Schema(type='string', description='The description')
+ param = ApiParameter(
+ original_name='param1',
+ param_location='query',
+ param_schema=schema,
+ )
+ assert param.description == 'The description'
+
+
+class TestTypeHintHelper:
+
+ @pytest.mark.parametrize(
+ 'schema, expected_type_value, expected_type_hint',
+ [
+ ({'type': 'integer'}, int, 'int'),
+ ({'type': 'number'}, float, 'float'),
+ ({'type': 'string'}, str, 'str'),
+ (
+ {
+ 'type': 'array',
+ 'items': {'type': 'string'},
+ },
+ List[str],
+ 'List[str]',
+ ),
+ ],
+ )
+ def test_get_type_value_and_hint(
+ self, schema, expected_type_value, expected_type_hint
+ ):
+
+ param = ApiParameter(
+ original_name='test_param',
+ param_location='query',
+ param_schema=schema,
+ description='Test parameter',
+ )
+ assert (
+ TypeHintHelper.get_type_value(param.param_schema) == expected_type_value
+ )
+ assert (
+ TypeHintHelper.get_type_hint(param.param_schema) == expected_type_hint
+ )
+
+
+class TestPydocHelper:
+
+ def test_generate_param_doc_simple(self):
+ schema = Schema(type='string')
+ param = ApiParameter(
+ original_name='test_param',
+ param_location='query',
+ param_schema=schema,
+ description='Test description',
+ )
+
+ expected_doc = 'test_param (str): Test description'
+ assert PydocHelper.generate_param_doc(param) == expected_doc
+
+ def test_generate_param_doc_no_description(self):
+ schema = Schema(type='integer')
+ param = ApiParameter(
+ original_name='test_param',
+ param_location='query',
+ param_schema=schema,
+ )
+ expected_doc = 'test_param (int): '
+ assert PydocHelper.generate_param_doc(param) == expected_doc
+
+ def test_generate_param_doc_object(self):
+ schema = Schema(
+ type='object',
+ properties={
+ 'prop1': {'type': 'string', 'description': 'Prop1 desc'},
+ 'prop2': {'type': 'integer'},
+ },
+ )
+ param = ApiParameter(
+ original_name='test_param',
+ param_location='query',
+ param_schema=schema,
+ description='Test object parameter',
+ )
+ expected_doc = (
+ 'test_param (Dict[str, Any]): Test object parameter Object'
+ ' properties:\n prop1 (str): Prop1 desc\n prop2'
+ ' (int): \n'
+ )
+ assert PydocHelper.generate_param_doc(param) == expected_doc
+
+ def test_generate_param_doc_object_no_properties(self):
+ schema = Schema(type='object', description='A test schema')
+ param = ApiParameter(
+ original_name='test_param',
+ param_location='query',
+ param_schema=schema,
+ description='The description.',
+ )
+ expected_doc = 'test_param (Dict[str, Any]): The description.'
+ assert PydocHelper.generate_param_doc(param) == expected_doc
+
+ def test_generate_return_doc_simple(self):
+ responses = {
+ '200': {
+ 'description': 'Successful response',
+ 'content': {'application/json': {'schema': {'type': 'string'}}},
+ }
+ }
+ expected_doc = 'Returns (str): Successful response'
+ assert (
+ PydocHelper.generate_return_doc(dict_to_responses(responses))
+ == expected_doc
+ )
+
+ def test_generate_return_doc_no_content(self):
+ responses = {'204': {'description': 'No content'}}
+ assert not PydocHelper.generate_return_doc(dict_to_responses(responses))
+
+ def test_generate_return_doc_object(self):
+ responses = {
+ '200': {
+ 'description': 'Successful object response',
+ 'content': {
+ 'application/json': {
+ 'schema': {
+ 'type': 'object',
+ 'properties': {
+ 'prop1': {
+ 'type': 'string',
+ 'description': 'Prop1 desc',
+ },
+ 'prop2': {'type': 'integer'},
+ },
+ }
+ }
+ },
+ }
+ }
+
+ return_doc = PydocHelper.generate_return_doc(dict_to_responses(responses))
+
+ assert 'Returns (Dict[str, Any]): Successful object response' in return_doc
+ assert 'prop1 (str): Prop1 desc' in return_doc
+ assert 'prop2 (int):' in return_doc
+
+ def test_generate_return_doc_multiple_success(self):
+ responses = {
+ '200': {
+ 'description': 'Successful response',
+ 'content': {'application/json': {'schema': {'type': 'string'}}},
+ },
+ '400': {'description': 'Bad request'},
+ }
+ expected_doc = 'Returns (str): Successful response'
+ assert (
+ PydocHelper.generate_return_doc(dict_to_responses(responses))
+ == expected_doc
+ )
+
+ def test_generate_return_doc_2xx_smallest_status_code_response(self):
+ responses = {
+ '201': {
+ 'description': '201 response',
+ 'content': {'application/json': {'schema': {'type': 'integer'}}},
+ },
+ '200': {
+ 'description': '200 response',
+ 'content': {'application/json': {'schema': {'type': 'string'}}},
+ },
+ '400': {'description': 'Bad request'},
+ }
+
+ expected_doc = 'Returns (str): 200 response'
+ assert (
+ PydocHelper.generate_return_doc(dict_to_responses(responses))
+ == expected_doc
+ )
+
+ def test_generate_return_doc_contentful_response(self):
+ responses = {
+ '200': {'description': 'No content response'},
+ '201': {
+ 'description': '201 response',
+ 'content': {'application/json': {'schema': {'type': 'string'}}},
+ },
+ '400': {'description': 'Bad request'},
+ }
+ expected_doc = 'Returns (str): 201 response'
+ assert (
+ PydocHelper.generate_return_doc(dict_to_responses(responses))
+ == expected_doc
+ )
+
+
+if __name__ == '__main__':
+ pytest.main([__file__])
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml
new file mode 100644
index 0000000..0cea003
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test.yaml
@@ -0,0 +1,1367 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+openapi: 3.0.0
+servers:
+ - url: https://www.googleapis.com/calendar/v3
+info:
+ contact:
+ name: Google
+ url: https://google.com
+ x-twitter: youtube
+ description: Manipulates events and other calendar data.
+ license:
+ name: Creative Commons Attribution 3.0
+ url: http://creativecommons.org/licenses/by/3.0/
+ termsOfService: https://developers.google.com/terms/
+ title: Calendar API
+ version: v3
+ x-apiClientRegistration:
+ url: https://console.developers.google.com
+ x-apisguru-categories:
+ - analytics
+ - media
+ x-logo:
+ url: https://www.google.com/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png
+ x-origin:
+ - format: google
+ url: https://calendar-json.googleapis.com/$discovery/rest?version=v3
+ version: v1
+ x-providerName: googleapis.com
+ x-serviceName: calendar
+externalDocs:
+ url: https://developers.google.com/google-apps/calendar/firstapp
+tags:
+ - name: acl
+ - name: calendarList
+ - name: calendars
+ - name: channels
+ - name: colors
+ - name: events
+ - name: freebusy
+ - name: settings
+paths:
+ /calendars:
+ parameters:
+ - $ref: "#/components/parameters/alt"
+ - $ref: "#/components/parameters/fields"
+ - $ref: "#/components/parameters/key"
+ - $ref: "#/components/parameters/oauth_token"
+ - $ref: "#/components/parameters/prettyPrint"
+ - $ref: "#/components/parameters/quotaUser"
+ - $ref: "#/components/parameters/userIp"
+ post:
+ description: Creates a secondary calendar.
+ operationId: calendar.calendars.insert
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/Calendar"
+ responses:
+ "200":
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/Calendar"
+ description: Successful response
+ security:
+ - Oauth2:
+ - https://www.googleapis.com/auth/calendar
+ Oauth2c:
+ - https://www.googleapis.com/auth/calendar
+ tags:
+ - calendars
+ /calendars/{calendarId}:
+ delete:
+ description: Deletes a secondary calendar. Use calendars.clear for clearing all events on primary calendars.
+ operationId: calendar.calendars.delete
+ parameters:
+ - description: Calendar identifier. To retrieve calendar IDs call the calendarList.list method. If you want to access the primary calendar of the currently logged in user, use the "primary" keyword.
+ in: path
+ name: calendarId
+ required: true
+ schema:
+ type: string
+ responses:
+ "200":
+ description: Successful response
+ security:
+ - Oauth2:
+ - https://www.googleapis.com/auth/calendar
+ Oauth2c:
+ - https://www.googleapis.com/auth/calendar
+ tags:
+ - calendars
+ get:
+ description: Returns metadata for a calendar.
+ operationId: calendar.calendars.get
+ parameters:
+ - description: Calendar identifier. To retrieve calendar IDs call the calendarList.list method. If you want to access the primary calendar of the currently logged in user, use the "primary" keyword.
+ in: path
+ name: calendarId
+ required: true
+ schema:
+ type: string
+ responses:
+ "200":
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/Calendar"
+ description: Successful response
+ security:
+ - Oauth2:
+ - https://www.googleapis.com/auth/calendar
+ Oauth2c:
+ - https://www.googleapis.com/auth/calendar
+ - Oauth2:
+ - https://www.googleapis.com/auth/calendar.readonly
+ Oauth2c:
+ - https://www.googleapis.com/auth/calendar.readonly
+ tags:
+ - calendars
+ parameters:
+ - $ref: "#/components/parameters/alt"
+ - $ref: "#/components/parameters/fields"
+ - $ref: "#/components/parameters/key"
+ - $ref: "#/components/parameters/oauth_token"
+ - $ref: "#/components/parameters/prettyPrint"
+ - $ref: "#/components/parameters/quotaUser"
+ - $ref: "#/components/parameters/userIp"
+ patch:
+ description: Updates metadata for a calendar. This method supports patch semantics.
+ operationId: calendar.calendars.patch
+ parameters:
+ - description: Calendar identifier. To retrieve calendar IDs call the calendarList.list method. If you want to access the primary calendar of the currently logged in user, use the "primary" keyword.
+ in: path
+ name: calendarId
+ required: true
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/Calendar"
+ responses:
+ "200":
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/Calendar"
+ description: Successful response
+ security:
+ - Oauth2:
+ - https://www.googleapis.com/auth/calendar
+ Oauth2c:
+ - https://www.googleapis.com/auth/calendar
+ tags:
+ - calendars
+ put:
+ description: Updates metadata for a calendar.
+ operationId: calendar.calendars.update
+ parameters:
+ - description: Calendar identifier. To retrieve calendar IDs call the calendarList.list method. If you want to access the primary calendar of the currently logged in user, use the "primary" keyword.
+ in: path
+ name: calendarId
+ required: true
+ schema:
+ type: string
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/Calendar"
+ responses:
+ "200":
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/Calendar"
+ description: Successful response
+ security:
+ - Oauth2:
+ - https://www.googleapis.com/auth/calendar
+ Oauth2c:
+ - https://www.googleapis.com/auth/calendar
+ tags:
+ - calendars
+components:
+ parameters:
+ alt:
+ description: Data format for the response.
+ in: query
+ name: alt
+ schema:
+ enum:
+ - json
+ type: string
+ fields:
+ description: Selector specifying which fields to include in a partial response.
+ in: query
+ name: fields
+ schema:
+ type: string
+ key:
+ description: API key. Your API key identifies your project and provides you with API access, quota, and reports. Required unless you provide an OAuth 2.0 token.
+ in: query
+ name: key
+ schema:
+ type: string
+ oauth_token:
+ description: OAuth 2.0 token for the current user.
+ in: query
+ name: oauth_token
+ schema:
+ type: string
+ prettyPrint:
+ description: Returns response with indentations and line breaks.
+ in: query
+ name: prettyPrint
+ schema:
+ type: boolean
+ quotaUser:
+ description: An opaque string that represents a user for quota purposes. Must not exceed 40 characters.
+ in: query
+ name: quotaUser
+ schema:
+ type: string
+ userIp:
+ description: Deprecated. Please use quotaUser instead.
+ in: query
+ name: userIp
+ schema:
+ type: string
+ schemas:
+ Acl:
+ properties:
+ etag:
+ description: ETag of the collection.
+ type: string
+ items:
+ description: List of rules on the access control list.
+ items:
+ $ref: "#/components/schemas/AclRule"
+ type: array
+ kind:
+ default: calendar#acl
+ description: Type of the collection ("calendar#acl").
+ type: string
+ nextPageToken:
+ description: Token used to access the next page of this result. Omitted if no further results are available, in which case nextSyncToken is provided.
+ type: string
+ nextSyncToken:
+ description: Token used at a later point in time to retrieve only the entries that have changed since this result was returned. Omitted if further results are available, in which case nextPageToken is provided.
+ type: string
+ type: object
+ AclRule:
+ properties:
+ etag:
+ description: ETag of the resource.
+ type: string
+ id:
+ description: Identifier of the Access Control List (ACL) rule. See Sharing calendars.
+ type: string
+ kind:
+ default: calendar#aclRule
+ description: Type of the resource ("calendar#aclRule").
+ type: string
+ role:
+ description: |-
+ The role assigned to the scope. Possible values are:
+ - "none" - Provides no access.
+ - "freeBusyReader" - Provides read access to free/busy information.
+ - "reader" - Provides read access to the calendar. Private events will appear to users with reader access, but event details will be hidden.
+ - "writer" - Provides read and write access to the calendar. Private events will appear to users with writer access, and event details will be visible.
+ - "owner" - Provides ownership of the calendar. This role has all of the permissions of the writer role with the additional ability to see and manipulate ACLs.
+ type: string
+ scope:
+ description: The extent to which calendar access is granted by this ACL rule.
+ properties:
+ type:
+ description: |-
+ The type of the scope. Possible values are:
+ - "default" - The public scope. This is the default value.
+ - "user" - Limits the scope to a single user.
+ - "group" - Limits the scope to a group.
+ - "domain" - Limits the scope to a domain. Note: The permissions granted to the "default", or public, scope apply to any user, authenticated or not.
+ type: string
+ value:
+ description: The email address of a user or group, or the name of a domain, depending on the scope type. Omitted for type "default".
+ type: string
+ type: object
+ type: object
+ Calendar:
+ properties:
+ conferenceProperties:
+ $ref: "#/components/schemas/ConferenceProperties"
+ description: Conferencing properties for this calendar, for example what types of conferences are allowed.
+ description:
+ description: Description of the calendar. Optional.
+ type: string
+ etag:
+ description: ETag of the resource.
+ type: string
+ id:
+ description: Identifier of the calendar. To retrieve IDs call the calendarList.list() method.
+ type: string
+ kind:
+ default: calendar#calendar
+ description: Type of the resource ("calendar#calendar").
+ type: string
+ location:
+ description: Geographic location of the calendar as free-form text. Optional.
+ type: string
+ summary:
+ description: Title of the calendar.
+ type: string
+ timeZone:
+ description: The time zone of the calendar. (Formatted as an IANA Time Zone Database name, e.g. "Europe/Zurich".) Optional.
+ type: string
+ type: object
+ CalendarList:
+ properties:
+ etag:
+ description: ETag of the collection.
+ type: string
+ items:
+ description: Calendars that are present on the user's calendar list.
+ items:
+ $ref: "#/components/schemas/CalendarListEntry"
+ type: array
+ kind:
+ default: calendar#calendarList
+ description: Type of the collection ("calendar#calendarList").
+ type: string
+ nextPageToken:
+ description: Token used to access the next page of this result. Omitted if no further results are available, in which case nextSyncToken is provided.
+ type: string
+ nextSyncToken:
+ description: Token used at a later point in time to retrieve only the entries that have changed since this result was returned. Omitted if further results are available, in which case nextPageToken is provided.
+ type: string
+ type: object
+ CalendarListEntry:
+ properties:
+ accessRole:
+ description: |-
+ The effective access role that the authenticated user has on the calendar. Read-only. Possible values are:
+ - "freeBusyReader" - Provides read access to free/busy information.
+ - "reader" - Provides read access to the calendar. Private events will appear to users with reader access, but event details will be hidden.
+ - "writer" - Provides read and write access to the calendar. Private events will appear to users with writer access, and event details will be visible.
+ - "owner" - Provides ownership of the calendar. This role has all of the permissions of the writer role with the additional ability to see and manipulate ACLs.
+ type: string
+ backgroundColor:
+ description: The main color of the calendar in the hexadecimal format "#0088aa". This property supersedes the index-based colorId property. To set or change this property, you need to specify colorRgbFormat=true in the parameters of the insert, update and patch methods. Optional.
+ type: string
+ colorId:
+ description: The color of the calendar. This is an ID referring to an entry in the calendar section of the colors definition (see the colors endpoint). This property is superseded by the backgroundColor and foregroundColor properties and can be ignored when using these properties. Optional.
+ type: string
+ conferenceProperties:
+ $ref: "#/components/schemas/ConferenceProperties"
+ description: Conferencing properties for this calendar, for example what types of conferences are allowed.
+ defaultReminders:
+ description: The default reminders that the authenticated user has for this calendar.
+ items:
+ $ref: "#/components/schemas/EventReminder"
+ type: array
+ deleted:
+ default: false
+ description: Whether this calendar list entry has been deleted from the calendar list. Read-only. Optional. The default is False.
+ type: boolean
+ description:
+ description: Description of the calendar. Optional. Read-only.
+ type: string
+ etag:
+ description: ETag of the resource.
+ type: string
+ foregroundColor:
+ description: The foreground color of the calendar in the hexadecimal format "#ffffff". This property supersedes the index-based colorId property. To set or change this property, you need to specify colorRgbFormat=true in the parameters of the insert, update and patch methods. Optional.
+ type: string
+ hidden:
+ default: false
+ description: Whether the calendar has been hidden from the list. Optional. The attribute is only returned when the calendar is hidden, in which case the value is true.
+ type: boolean
+ id:
+ description: Identifier of the calendar.
+ type: string
+ kind:
+ default: calendar#calendarListEntry
+ description: Type of the resource ("calendar#calendarListEntry").
+ type: string
+ location:
+ description: Geographic location of the calendar as free-form text. Optional. Read-only.
+ type: string
+ notificationSettings:
+ description: The notifications that the authenticated user is receiving for this calendar.
+ properties:
+ notifications:
+ description: The list of notifications set for this calendar.
+ items:
+ $ref: "#/components/schemas/CalendarNotification"
+ type: array
+ type: object
+ primary:
+ default: false
+ description: Whether the calendar is the primary calendar of the authenticated user. Read-only. Optional. The default is False.
+ type: boolean
+ selected:
+ default: false
+ description: Whether the calendar content shows up in the calendar UI. Optional. The default is False.
+ type: boolean
+ summary:
+ description: Title of the calendar. Read-only.
+ type: string
+ summaryOverride:
+ description: The summary that the authenticated user has set for this calendar. Optional.
+ type: string
+ timeZone:
+ description: The time zone of the calendar. Optional. Read-only.
+ type: string
+ type: object
+ CalendarNotification:
+ properties:
+ method:
+ description: |-
+ The method used to deliver the notification. The possible value is:
+ - "email" - Notifications are sent via email.
+ Required when adding a notification.
+ type: string
+ type:
+ description: |-
+ The type of notification. Possible values are:
+ - "eventCreation" - Notification sent when a new event is put on the calendar.
+ - "eventChange" - Notification sent when an event is changed.
+ - "eventCancellation" - Notification sent when an event is cancelled.
+ - "eventResponse" - Notification sent when an attendee responds to the event invitation.
+ - "agenda" - An agenda with the events of the day (sent out in the morning).
+ Required when adding a notification.
+ type: string
+ type: object
+ Channel:
+ properties:
+ address:
+ description: The address where notifications are delivered for this channel.
+ type: string
+ expiration:
+ description: Date and time of notification channel expiration, expressed as a Unix timestamp, in milliseconds. Optional.
+ format: int64
+ type: string
+ id:
+ description: A UUID or similar unique string that identifies this channel.
+ type: string
+ kind:
+ default: api#channel
+ description: Identifies this as a notification channel used to watch for changes to a resource, which is "api#channel".
+ type: string
+ params:
+ additionalProperties:
+ description: Declares a new parameter by name.
+ type: string
+ description: Additional parameters controlling delivery channel behavior. Optional.
+ type: object
+ payload:
+ description: A Boolean value to indicate whether payload is wanted. Optional.
+ type: boolean
+ resourceId:
+ description: An opaque ID that identifies the resource being watched on this channel. Stable across different API versions.
+ type: string
+ resourceUri:
+ description: A version-specific identifier for the watched resource.
+ type: string
+ token:
+ description: An arbitrary string delivered to the target address with each notification delivered over this channel. Optional.
+ type: string
+ type:
+ description: The type of delivery mechanism used for this channel. Valid values are "web_hook" (or "webhook"). Both values refer to a channel where Http requests are used to deliver messages.
+ type: string
+ type: object
+ ColorDefinition:
+ properties:
+ background:
+ description: The background color associated with this color definition.
+ type: string
+ foreground:
+ description: The foreground color that can be used to write on top of a background with 'background' color.
+ type: string
+ type: object
+ Colors:
+ properties:
+ calendar:
+ additionalProperties:
+ $ref: "#/components/schemas/ColorDefinition"
+ description: A calendar color definition.
+ description: A global palette of calendar colors, mapping from the color ID to its definition. A calendarListEntry resource refers to one of these color IDs in its colorId field. Read-only.
+ type: object
+ event:
+ additionalProperties:
+ $ref: "#/components/schemas/ColorDefinition"
+ description: An event color definition.
+ description: A global palette of event colors, mapping from the color ID to its definition. An event resource may refer to one of these color IDs in its colorId field. Read-only.
+ type: object
+ kind:
+ default: calendar#colors
+ description: Type of the resource ("calendar#colors").
+ type: string
+ updated:
+ description: Last modification time of the color palette (as a RFC3339 timestamp). Read-only.
+ format: date-time
+ type: string
+ type: object
+ ConferenceData:
+ properties:
+ conferenceId:
+ description: |-
+ The ID of the conference.
+ Can be used by developers to keep track of conferences, should not be displayed to users.
+ The ID value is formed differently for each conference solution type:
+ - eventHangout: ID is not set. (This conference type is deprecated.)
+ - eventNamedHangout: ID is the name of the Hangout. (This conference type is deprecated.)
+ - hangoutsMeet: ID is the 10-letter meeting code, for example aaa-bbbb-ccc.
+ - addOn: ID is defined by the third-party provider. Optional.
+ type: string
+ conferenceSolution:
+ $ref: "#/components/schemas/ConferenceSolution"
+ description: |-
+ The conference solution, such as Google Meet.
+ Unset for a conference with a failed create request.
+ Either conferenceSolution and at least one entryPoint, or createRequest is required.
+ createRequest:
+ $ref: "#/components/schemas/CreateConferenceRequest"
+ description: |-
+ A request to generate a new conference and attach it to the event. The data is generated asynchronously. To see whether the data is present check the status field.
+ Either conferenceSolution and at least one entryPoint, or createRequest is required.
+ entryPoints:
+ description: |-
+ Information about individual conference entry points, such as URLs or phone numbers.
+ All of them must belong to the same conference.
+ Either conferenceSolution and at least one entryPoint, or createRequest is required.
+ items:
+ $ref: "#/components/schemas/EntryPoint"
+ type: array
+ notes:
+ description: Additional notes (such as instructions from the domain administrator, legal notices) to display to the user. Can contain HTML. The maximum length is 2048 characters. Optional.
+ type: string
+ parameters:
+ $ref: "#/components/schemas/ConferenceParameters"
+ description: Additional properties related to a conference. An example would be a solution-specific setting for enabling video streaming.
+ signature:
+ description: |-
+ The signature of the conference data.
+ Generated on server side.
+ Unset for a conference with a failed create request.
+ Optional for a conference with a pending create request.
+ type: string
+ type: object
+ ConferenceParameters:
+ properties:
+ addOnParameters:
+ $ref: "#/components/schemas/ConferenceParametersAddOnParameters"
+ description: Additional add-on specific data.
+ type: object
+ ConferenceParametersAddOnParameters:
+ properties:
+ parameters:
+ additionalProperties:
+ type: string
+ type: object
+ type: object
+ ConferenceProperties:
+ properties:
+ allowedConferenceSolutionTypes:
+ description: |-
+ The types of conference solutions that are supported for this calendar.
+ The possible values are:
+ - "eventHangout"
+ - "eventNamedHangout"
+ - "hangoutsMeet" Optional.
+ items:
+ type: string
+ type: array
+ type: object
+ ConferenceRequestStatus:
+ properties:
+ statusCode:
+ description: |-
+ The current status of the conference create request. Read-only.
+ The possible values are:
+ - "pending": the conference create request is still being processed.
+ - "success": the conference create request succeeded, the entry points are populated.
+ - "failure": the conference create request failed, there are no entry points.
+ type: string
+ type: object
+ ConferenceSolution:
+ properties:
+ iconUri:
+ description: The user-visible icon for this solution.
+ type: string
+ key:
+ $ref: "#/components/schemas/ConferenceSolutionKey"
+ description: The key which can uniquely identify the conference solution for this event.
+ name:
+ description: The user-visible name of this solution. Not localized.
+ type: string
+ type: object
+ ConferenceSolutionKey:
+ properties:
+ type:
+ description: |-
+ The conference solution type.
+ If a client encounters an unfamiliar or empty type, it should still be able to display the entry points. However, it should disallow modifications.
+ The possible values are:
+ - "eventHangout" for Hangouts for consumers (deprecated; existing events may show this conference solution type but new conferences cannot be created)
+ - "eventNamedHangout" for classic Hangouts for Google Workspace users (deprecated; existing events may show this conference solution type but new conferences cannot be created)
+ - "hangoutsMeet" for Google Meet (http://meet.google.com)
+ - "addOn" for 3P conference providers
+ type: string
+ type: object
+ CreateConferenceRequest:
+ properties:
+ conferenceSolutionKey:
+ $ref: "#/components/schemas/ConferenceSolutionKey"
+ description: The conference solution, such as Hangouts or Google Meet.
+ requestId:
+ description: |-
+ The client-generated unique ID for this request.
+ Clients should regenerate this ID for every new request. If an ID provided is the same as for the previous request, the request is ignored.
+ type: string
+ status:
+ $ref: "#/components/schemas/ConferenceRequestStatus"
+ description: The status of the conference create request.
+ type: object
+ EntryPoint:
+ properties:
+ accessCode:
+ description: |-
+ The access code to access the conference. The maximum length is 128 characters.
+ When creating new conference data, populate only the subset of {meetingCode, accessCode, passcode, password, pin} fields that match the terminology that the conference provider uses. Only the populated fields should be displayed.
+ Optional.
+ type: string
+ entryPointFeatures:
+ description: Features of the entry point, such as being toll or toll-free. One entry point can have multiple features. However, toll and toll-free cannot be both set on the same entry point.
+ items:
+ type: string
+ type: array
+ entryPointType:
+ description: |-
+ The type of the conference entry point.
+ Possible values are:
+ - "video" - joining a conference over HTTP. A conference can have zero or one video entry point.
+ - "phone" - joining a conference by dialing a phone number. A conference can have zero or more phone entry points.
+ - "sip" - joining a conference over SIP. A conference can have zero or one sip entry point.
+ - "more" - further conference joining instructions, for example additional phone numbers. A conference can have zero or one more entry point. A conference with only a more entry point is not a valid conference.
+ type: string
+ label:
+ description: |-
+ The label for the URI. Visible to end users. Not localized. The maximum length is 512 characters.
+ Examples:
+ - for video: meet.google.com/aaa-bbbb-ccc
+ - for phone: +1 123 268 2601
+ - for sip: 12345678@altostrat.com
+ - for more: should not be filled
+ Optional.
+ type: string
+ meetingCode:
+ description: |-
+ The meeting code to access the conference. The maximum length is 128 characters.
+ When creating new conference data, populate only the subset of {meetingCode, accessCode, passcode, password, pin} fields that match the terminology that the conference provider uses. Only the populated fields should be displayed.
+ Optional.
+ type: string
+ passcode:
+ description: |-
+ The passcode to access the conference. The maximum length is 128 characters.
+ When creating new conference data, populate only the subset of {meetingCode, accessCode, passcode, password, pin} fields that match the terminology that the conference provider uses. Only the populated fields should be displayed.
+ type: string
+ password:
+ description: |-
+ The password to access the conference. The maximum length is 128 characters.
+ When creating new conference data, populate only the subset of {meetingCode, accessCode, passcode, password, pin} fields that match the terminology that the conference provider uses. Only the populated fields should be displayed.
+ Optional.
+ type: string
+ pin:
+ description: |-
+ The PIN to access the conference. The maximum length is 128 characters.
+ When creating new conference data, populate only the subset of {meetingCode, accessCode, passcode, password, pin} fields that match the terminology that the conference provider uses. Only the populated fields should be displayed.
+ Optional.
+ type: string
+ regionCode:
+ description: |-
+ The CLDR/ISO 3166 region code for the country associated with this phone access. Example: "SE" for Sweden.
+ Calendar backend will populate this field only for EntryPointType.PHONE.
+ type: string
+ uri:
+ description: |-
+ The URI of the entry point. The maximum length is 1300 characters.
+ Format:
+ - for video, http: or https: schema is required.
+ - for phone, tel: schema is required. The URI should include the entire dial sequence (e.g., tel:+12345678900,,,123456789;1234).
+ - for sip, sip: schema is required, e.g., sip:12345678@myprovider.com.
+ - for more, http: or https: schema is required.
+ type: string
+ type: object
+ Error:
+ properties:
+ domain:
+ description: Domain, or broad category, of the error.
+ type: string
+ reason:
+ description: |-
+ Specific reason for the error. Some of the possible values are:
+ - "groupTooBig" - The group of users requested is too large for a single query.
+ - "tooManyCalendarsRequested" - The number of calendars requested is too large for a single query.
+ - "notFound" - The requested resource was not found.
+ - "internalError" - The API service has encountered an internal error. Additional error types may be added in the future, so clients should gracefully handle additional error statuses not included in this list.
+ type: string
+ type: object
+ Event:
+ properties:
+ anyoneCanAddSelf:
+ default: false
+ description: Whether anyone can invite themselves to the event (deprecated). Optional. The default is False.
+ type: boolean
+ attachments:
+ description: |-
+ File attachments for the event.
+ In order to modify attachments the supportsAttachments request parameter should be set to true.
+ There can be at most 25 attachments per event,
+ items:
+ $ref: "#/components/schemas/EventAttachment"
+ type: array
+ attendees:
+ description: The attendees of the event. See the Events with attendees guide for more information on scheduling events with other calendar users. Service accounts need to use domain-wide delegation of authority to populate the attendee list.
+ items:
+ $ref: "#/components/schemas/EventAttendee"
+ type: array
+ attendeesOmitted:
+ default: false
+ description: Whether attendees may have been omitted from the event's representation. When retrieving an event, this may be due to a restriction specified by the maxAttendee query parameter. When updating an event, this can be used to only update the participant's response. Optional. The default is False.
+ type: boolean
+ colorId:
+ description: The color of the event. This is an ID referring to an entry in the event section of the colors definition (see the colors endpoint). Optional.
+ type: string
+ conferenceData:
+ $ref: "#/components/schemas/ConferenceData"
+ description: The conference-related information, such as details of a Google Meet conference. To create new conference details use the createRequest field. To persist your changes, remember to set the conferenceDataVersion request parameter to 1 for all event modification requests.
+ created:
+ description: Creation time of the event (as a RFC3339 timestamp). Read-only.
+ format: date-time
+ type: string
+ creator:
+ description: The creator of the event. Read-only.
+ properties:
+ displayName:
+ description: The creator's name, if available.
+ type: string
+ email:
+ description: The creator's email address, if available.
+ type: string
+ id:
+ description: The creator's Profile ID, if available.
+ type: string
+ self:
+ default: false
+ description: Whether the creator corresponds to the calendar on which this copy of the event appears. Read-only. The default is False.
+ type: boolean
+ type: object
+ description:
+ description: Description of the event. Can contain HTML. Optional.
+ type: string
+ end:
+ $ref: "#/components/schemas/EventDateTime"
+ description: The (exclusive) end time of the event. For a recurring event, this is the end time of the first instance.
+ endTimeUnspecified:
+ default: false
+ description: Whether the end time is actually unspecified. An end time is still provided for compatibility reasons, even if this attribute is set to True. The default is False.
+ type: boolean
+ etag:
+ description: ETag of the resource.
+ type: string
+ eventType:
+ default: default
+ description: |-
+ Specific type of the event. This cannot be modified after the event is created. Possible values are:
+ - "default" - A regular event or not further specified.
+ - "outOfOffice" - An out-of-office event.
+ - "focusTime" - A focus-time event.
+ - "workingLocation" - A working location event. Currently, only "default " and "workingLocation" events can be created using the API. Extended support for other event types will be made available in later releases.
+ type: string
+ extendedProperties:
+ description: Extended properties of the event.
+ properties:
+ private:
+ additionalProperties:
+ description: The name of the private property and the corresponding value.
+ type: string
+ description: Properties that are private to the copy of the event that appears on this calendar.
+ type: object
+ shared:
+ additionalProperties:
+ description: The name of the shared property and the corresponding value.
+ type: string
+ description: Properties that are shared between copies of the event on other attendees' calendars.
+ type: object
+ type: object
+ focusTimeProperties:
+ $ref: "#/components/schemas/EventFocusTimeProperties"
+ description: Focus Time event data. Used if eventType is focusTime.
+ gadget:
+ description: A gadget that extends this event. Gadgets are deprecated; this structure is instead only used for returning birthday calendar metadata.
+ properties:
+ display:
+ description: |-
+ The gadget's display mode. Deprecated. Possible values are:
+ - "icon" - The gadget displays next to the event's title in the calendar view.
+ - "chip" - The gadget displays when the event is clicked.
+ type: string
+ height:
+ description: The gadget's height in pixels. The height must be an integer greater than 0. Optional. Deprecated.
+ format: int32
+ type: integer
+ iconLink:
+ description: The gadget's icon URL. The URL scheme must be HTTPS. Deprecated.
+ type: string
+ link:
+ description: The gadget's URL. The URL scheme must be HTTPS. Deprecated.
+ type: string
+ preferences:
+ additionalProperties:
+ description: The preference name and corresponding value.
+ type: string
+ description: Preferences.
+ type: object
+ title:
+ description: The gadget's title. Deprecated.
+ type: string
+ type:
+ description: The gadget's type. Deprecated.
+ type: string
+ width:
+ description: The gadget's width in pixels. The width must be an integer greater than 0. Optional. Deprecated.
+ format: int32
+ type: integer
+ type: object
+ guestsCanInviteOthers:
+ default: true
+ description: Whether attendees other than the organizer can invite others to the event. Optional. The default is True.
+ type: boolean
+ guestsCanModify:
+ default: false
+ description: Whether attendees other than the organizer can modify the event. Optional. The default is False.
+ type: boolean
+ guestsCanSeeOtherGuests:
+ default: true
+ description: Whether attendees other than the organizer can see who the event's attendees are. Optional. The default is True.
+ type: boolean
+ hangoutLink:
+ description: An absolute link to the Google Hangout associated with this event. Read-only.
+ type: string
+ htmlLink:
+ description: An absolute link to this event in the Google Calendar Web UI. Read-only.
+ type: string
+ iCalUID:
+ description: |-
+ Event unique identifier as defined in RFC5545. It is used to uniquely identify events accross calendaring systems and must be supplied when importing events via the import method.
+ Note that the iCalUID and the id are not identical and only one of them should be supplied at event creation time. One difference in their semantics is that in recurring events, all occurrences of one event have different ids while they all share the same iCalUIDs. To retrieve an event using its iCalUID, call the events.list method using the iCalUID parameter. To retrieve an event using its id, call the events.get method.
+ type: string
+ id:
+ description: |-
+ Opaque identifier of the event. When creating new single or recurring events, you can specify their IDs. Provided IDs must follow these rules:
+ - characters allowed in the ID are those used in base32hex encoding, i.e. lowercase letters a-v and digits 0-9, see section 3.1.2 in RFC2938
+ - the length of the ID must be between 5 and 1024 characters
+ - the ID must be unique per calendar Due to the globally distributed nature of the system, we cannot guarantee that ID collisions will be detected at event creation time. To minimize the risk of collisions we recommend using an established UUID algorithm such as one described in RFC4122.
+ If you do not specify an ID, it will be automatically generated by the server.
+ Note that the icalUID and the id are not identical and only one of them should be supplied at event creation time. One difference in their semantics is that in recurring events, all occurrences of one event have different ids while they all share the same icalUIDs.
+ type: string
+ kind:
+ default: calendar#event
+ description: Type of the resource ("calendar#event").
+ type: string
+ location:
+ description: Geographic location of the event as free-form text. Optional.
+ type: string
+ locked:
+ default: false
+ description: Whether this is a locked event copy where no changes can be made to the main event fields "summary", "description", "location", "start", "end" or "recurrence". The default is False. Read-Only.
+ type: boolean
+ organizer:
+ description: The organizer of the event. If the organizer is also an attendee, this is indicated with a separate entry in attendees with the organizer field set to True. To change the organizer, use the move operation. Read-only, except when importing an event.
+ properties:
+ displayName:
+ description: The organizer's name, if available.
+ type: string
+ email:
+ description: The organizer's email address, if available. It must be a valid email address as per RFC5322.
+ type: string
+ id:
+ description: The organizer's Profile ID, if available.
+ type: string
+ self:
+ default: false
+ description: Whether the organizer corresponds to the calendar on which this copy of the event appears. Read-only. The default is False.
+ type: boolean
+ type: object
+ originalStartTime:
+ $ref: "#/components/schemas/EventDateTime"
+ description: For an instance of a recurring event, this is the time at which this event would start according to the recurrence data in the recurring event identified by recurringEventId. It uniquely identifies the instance within the recurring event series even if the instance was moved to a different time. Immutable.
+ outOfOfficeProperties:
+ $ref: "#/components/schemas/EventOutOfOfficeProperties"
+ description: Out of office event data. Used if eventType is outOfOffice.
+ privateCopy:
+ default: false
+ description: If set to True, Event propagation is disabled. Note that it is not the same thing as Private event properties. Optional. Immutable. The default is False.
+ type: boolean
+ recurrence:
+ description: List of RRULE, EXRULE, RDATE and EXDATE lines for a recurring event, as specified in RFC5545. Note that DTSTART and DTEND lines are not allowed in this field; event start and end times are specified in the start and end fields. This field is omitted for single events or instances of recurring events.
+ items:
+ type: string
+ type: array
+ recurringEventId:
+ description: For an instance of a recurring event, this is the id of the recurring event to which this instance belongs. Immutable.
+ type: string
+ reminders:
+ description: Information about the event's reminders for the authenticated user.
+ properties:
+ overrides:
+ description: If the event doesn't use the default reminders, this lists the reminders specific to the event, or, if not set, indicates that no reminders are set for this event. The maximum number of override reminders is 5.
+ items:
+ $ref: "#/components/schemas/EventReminder"
+ type: array
+ useDefault:
+ description: Whether the default reminders of the calendar apply to the event.
+ type: boolean
+ type: object
+ sequence:
+ description: Sequence number as per iCalendar.
+ format: int32
+ type: integer
+ source:
+ description: Source from which the event was created. For example, a web page, an email message or any document identifiable by an URL with HTTP or HTTPS scheme. Can only be seen or modified by the creator of the event.
+ properties:
+ title:
+ description: Title of the source; for example a title of a web page or an email subject.
+ type: string
+ url:
+ description: URL of the source pointing to a resource. The URL scheme must be HTTP or HTTPS.
+ type: string
+ type: object
+ start:
+ $ref: "#/components/schemas/EventDateTime"
+ description: The (inclusive) start time of the event. For a recurring event, this is the start time of the first instance.
+ status:
+ description: |-
+ Status of the event. Optional. Possible values are:
+ - "confirmed" - The event is confirmed. This is the default status.
+ - "tentative" - The event is tentatively confirmed.
+ - "cancelled" - The event is cancelled (deleted). The list method returns cancelled events only on incremental sync (when syncToken or updatedMin are specified) or if the showDeleted flag is set to true. The get method always returns them.
+ A cancelled status represents two different states depending on the event type:
+ - Cancelled exceptions of an uncancelled recurring event indicate that this instance should no longer be presented to the user. Clients should store these events for the lifetime of the parent recurring event.
+ Cancelled exceptions are only guaranteed to have values for the id, recurringEventId and originalStartTime fields populated. The other fields might be empty.
+ - All other cancelled events represent deleted events. Clients should remove their locally synced copies. Such cancelled events will eventually disappear, so do not rely on them being available indefinitely.
+ Deleted events are only guaranteed to have the id field populated. On the organizer's calendar, cancelled events continue to expose event details (summary, location, etc.) so that they can be restored (undeleted). Similarly, the events to which the user was invited and that they manually removed continue to provide details. However, incremental sync requests with showDeleted set to false will not return these details.
+ If an event changes its organizer (for example via the move operation) and the original organizer is not on the attendee list, it will leave behind a cancelled event where only the id field is guaranteed to be populated.
+ type: string
+ summary:
+ description: Title of the event.
+ type: string
+ transparency:
+ default: opaque
+ description: |-
+ Whether the event blocks time on the calendar. Optional. Possible values are:
+ - "opaque" - Default value. The event does block time on the calendar. This is equivalent to setting Show me as to Busy in the Calendar UI.
+ - "transparent" - The event does not block time on the calendar. This is equivalent to setting Show me as to Available in the Calendar UI.
+ type: string
+ updated:
+ description: Last modification time of the event (as a RFC3339 timestamp). Read-only.
+ format: date-time
+ type: string
+ visibility:
+ default: default
+ description: |-
+ Visibility of the event. Optional. Possible values are:
+ - "default" - Uses the default visibility for events on the calendar. This is the default value.
+ - "public" - The event is public and event details are visible to all readers of the calendar.
+ - "private" - The event is private and only event attendees may view event details.
+ - "confidential" - The event is private. This value is provided for compatibility reasons.
+ type: string
+ workingLocationProperties:
+ $ref: "#/components/schemas/EventWorkingLocationProperties"
+ description: Working location event data.
+ type: object
+ EventAttachment:
+ properties:
+ fileId:
+ description: |-
+ ID of the attached file. Read-only.
+ For Google Drive files, this is the ID of the corresponding Files resource entry in the Drive API.
+ type: string
+ fileUrl:
+ description: |-
+ URL link to the attachment.
+ For adding Google Drive file attachments use the same format as in alternateLink property of the Files resource in the Drive API.
+ Required when adding an attachment.
+ type: string
+ iconLink:
+ description: URL link to the attachment's icon. This field can only be modified for custom third-party attachments.
+ type: string
+ mimeType:
+ description: Internet media type (MIME type) of the attachment.
+ type: string
+ title:
+ description: Attachment title.
+ type: string
+ type: object
+ EventAttendee:
+ properties:
+ additionalGuests:
+ default: 0
+ description: Number of additional guests. Optional. The default is 0.
+ format: int32
+ type: integer
+ comment:
+ description: The attendee's response comment. Optional.
+ type: string
+ displayName:
+ description: The attendee's name, if available. Optional.
+ type: string
+ email:
+ description: |-
+ The attendee's email address, if available. This field must be present when adding an attendee. It must be a valid email address as per RFC5322.
+ Required when adding an attendee.
+ type: string
+ id:
+ description: The attendee's Profile ID, if available.
+ type: string
+ optional:
+ default: false
+ description: Whether this is an optional attendee. Optional. The default is False.
+ type: boolean
+ organizer:
+ description: Whether the attendee is the organizer of the event. Read-only. The default is False.
+ type: boolean
+ resource:
+ default: false
+ description: Whether the attendee is a resource. Can only be set when the attendee is added to the event for the first time. Subsequent modifications are ignored. Optional. The default is False.
+ type: boolean
+ responseStatus:
+ description: |-
+ The attendee's response status. Possible values are:
+ - "needsAction" - The attendee has not responded to the invitation (recommended for new events).
+ - "declined" - The attendee has declined the invitation.
+ - "tentative" - The attendee has tentatively accepted the invitation.
+ - "accepted" - The attendee has accepted the invitation. Warning: If you add an event using the values declined, tentative, or accepted, attendees with the "Add invitations to my calendar" setting set to "When I respond to invitation in email" won't see an event on their calendar unless they choose to change their invitation response in the event invitation email.
+ type: string
+ self:
+ default: false
+ description: Whether this entry represents the calendar on which this copy of the event appears. Read-only. The default is False.
+ type: boolean
+ type: object
+ EventDateTime:
+ properties:
+ date:
+ description: The date, in the format "yyyy-mm-dd", if this is an all-day event.
+ format: date
+ type: string
+ dateTime:
+ description: The time, as a combined date-time value (formatted according to RFC3339). A time zone offset is required unless a time zone is explicitly specified in timeZone.
+ format: date-time
+ type: string
+ timeZone:
+ description: The time zone in which the time is specified. (Formatted as an IANA Time Zone Database name, e.g. "Europe/Zurich".) For recurring events this field is required and specifies the time zone in which the recurrence is expanded. For single events this field is optional and indicates a custom time zone for the event start/end.
+ type: string
+ type: object
+ EventFocusTimeProperties:
+ properties:
+ autoDeclineMode:
+ description: Whether to decline meeting invitations which overlap Focus Time events. Valid values are declineNone, meaning that no meeting invitations are declined; declineAllConflictingInvitations, meaning that all conflicting meeting invitations that conflict with the event are declined; and declineOnlyNewConflictingInvitations, meaning that only new conflicting meeting invitations which arrive while the Focus Time event is present are to be declined.
+ type: string
+ chatStatus:
+ description: The status to mark the user in Chat and related products. This can be available or doNotDisturb.
+ type: string
+ declineMessage:
+ description: Response message to set if an existing event or new invitation is automatically declined by Calendar.
+ type: string
+ type: object
+ EventOutOfOfficeProperties:
+ properties:
+ autoDeclineMode:
+ description: Whether to decline meeting invitations which overlap Out of office events. Valid values are declineNone, meaning that no meeting invitations are declined; declineAllConflictingInvitations, meaning that all conflicting meeting invitations that conflict with the event are declined; and declineOnlyNewConflictingInvitations, meaning that only new conflicting meeting invitations which arrive while the Out of office event is present are to be declined.
+ type: string
+ declineMessage:
+ description: Response message to set if an existing event or new invitation is automatically declined by Calendar.
+ type: string
+ type: object
+ EventReminder:
+ properties:
+ method:
+ description: |-
+ The method used by this reminder. Possible values are:
+ - "email" - Reminders are sent via email.
+ - "popup" - Reminders are sent via a UI popup.
+ Required when adding a reminder.
+ type: string
+ minutes:
+ description: |-
+ Number of minutes before the start of the event when the reminder should trigger. Valid values are between 0 and 40320 (4 weeks in minutes).
+ Required when adding a reminder.
+ format: int32
+ type: integer
+ type: object
+ EventWorkingLocationProperties:
+ properties:
+ customLocation:
+ description: If present, specifies that the user is working from a custom location.
+ properties:
+ label:
+ description: An optional extra label for additional information.
+ type: string
+ type: object
+ homeOffice:
+ description: If present, specifies that the user is working at home.
+ officeLocation:
+ description: If present, specifies that the user is working from an office.
+ properties:
+ buildingId:
+ description: An optional building identifier. This should reference a building ID in the organization's Resources database.
+ type: string
+ deskId:
+ description: An optional desk identifier.
+ type: string
+ floorId:
+ description: An optional floor identifier.
+ type: string
+ floorSectionId:
+ description: An optional floor section identifier.
+ type: string
+ label:
+ description: The office name that's displayed in Calendar Web and Mobile clients. We recommend you reference a building name in the organization's Resources database.
+ type: string
+ type: object
+ type:
+ description: |-
+ Type of the working location. Possible values are:
+ - "homeOffice" - The user is working at home.
+ - "officeLocation" - The user is working from an office.
+ - "customLocation" - The user is working from a custom location. Any details are specified in a sub-field of the specified name, but this field may be missing if empty. Any other fields are ignored.
+ Required when adding working location properties.
+ type: string
+ type: object
+ Events:
+ properties:
+ accessRole:
+ description: |-
+ The user's access role for this calendar. Read-only. Possible values are:
+ - "none" - The user has no access.
+ - "freeBusyReader" - The user has read access to free/busy information.
+ - "reader" - The user has read access to the calendar. Private events will appear to users with reader access, but event details will be hidden.
+ - "writer" - The user has read and write access to the calendar. Private events will appear to users with writer access, and event details will be visible.
+ - "owner" - The user has ownership of the calendar. This role has all of the permissions of the writer role with the additional ability to see and manipulate ACLs.
+ type: string
+ defaultReminders:
+ description: The default reminders on the calendar for the authenticated user. These reminders apply to all events on this calendar that do not explicitly override them (i.e. do not have reminders.useDefault set to True).
+ items:
+ $ref: "#/components/schemas/EventReminder"
+ type: array
+ description:
+ description: Description of the calendar. Read-only.
+ type: string
+ etag:
+ description: ETag of the collection.
+ type: string
+ items:
+ description: List of events on the calendar.
+ items:
+ $ref: "#/components/schemas/Event"
+ type: array
+ kind:
+ default: calendar#events
+ description: Type of the collection ("calendar#events").
+ type: string
+ nextPageToken:
+ description: Token used to access the next page of this result. Omitted if no further results are available, in which case nextSyncToken is provided.
+ type: string
+ nextSyncToken:
+ description: Token used at a later point in time to retrieve only the entries that have changed since this result was returned. Omitted if further results are available, in which case nextPageToken is provided.
+ type: string
+ summary:
+ description: Title of the calendar. Read-only.
+ type: string
+ timeZone:
+ description: The time zone of the calendar. Read-only.
+ type: string
+ updated:
+ description: Last modification time of the calendar (as a RFC3339 timestamp). Read-only.
+ format: date-time
+ type: string
+ type: object
+ FreeBusyCalendar:
+ properties:
+ busy:
+ description: List of time ranges during which this calendar should be regarded as busy.
+ items:
+ $ref: "#/components/schemas/TimePeriod"
+ type: array
+ errors:
+ description: Optional error(s) (if computation for the calendar failed).
+ items:
+ $ref: "#/components/schemas/Error"
+ type: array
+ type: object
+ FreeBusyGroup:
+ properties:
+ calendars:
+ description: List of calendars' identifiers within a group.
+ items:
+ type: string
+ type: array
+ errors:
+ description: Optional error(s) (if computation for the group failed).
+ items:
+ $ref: "#/components/schemas/Error"
+ type: array
+ type: object
+ FreeBusyRequest:
+ properties:
+ calendarExpansionMax:
+ description: Maximal number of calendars for which FreeBusy information is to be provided. Optional. Maximum value is 50.
+ format: int32
+ type: integer
+ groupExpansionMax:
+ description: Maximal number of calendar identifiers to be provided for a single group. Optional. An error is returned for a group with more members than this value. Maximum value is 100.
+ format: int32
+ type: integer
+ items:
+ description: List of calendars and/or groups to query.
+ items:
+ $ref: "#/components/schemas/FreeBusyRequestItem"
+ type: array
+ timeMax:
+ description: The end of the interval for the query formatted as per RFC3339.
+ format: date-time
+ type: string
+ timeMin:
+ description: The start of the interval for the query formatted as per RFC3339.
+ format: date-time
+ type: string
+ timeZone:
+ default: UTC
+ description: Time zone used in the response. Optional. The default is UTC.
+ type: string
+ type: object
+ FreeBusyRequestItem:
+ properties:
+ id:
+ description: The identifier of a calendar or a group.
+ type: string
+ type: object
+ FreeBusyResponse:
+ properties:
+ calendars:
+ additionalProperties:
+ $ref: "#/components/schemas/FreeBusyCalendar"
+ description: Free/busy expansions for a single calendar.
+ description: List of free/busy information for calendars.
+ type: object
+ groups:
+ additionalProperties:
+ $ref: "#/components/schemas/FreeBusyGroup"
+ description: List of calendars that are members of this group.
+ description: Expansion of groups.
+ type: object
+ kind:
+ default: calendar#freeBusy
+ description: Type of the resource ("calendar#freeBusy").
+ type: string
+ timeMax:
+ description: The end of the interval.
+ format: date-time
+ type: string
+ timeMin:
+ description: The start of the interval.
+ format: date-time
+ type: string
+ type: object
+ Setting:
+ properties:
+ etag:
+ description: ETag of the resource.
+ type: string
+ id:
+ description: The id of the user setting.
+ type: string
+ kind:
+ default: calendar#setting
+ description: Type of the resource ("calendar#setting").
+ type: string
+ value:
+ description: Value of the user setting. The format of the value depends on the ID of the setting. It must always be a UTF-8 string of length up to 1024 characters.
+ type: string
+ type: object
+ Settings:
+ properties:
+ etag:
+ description: Etag of the collection.
+ type: string
+ items:
+ description: List of user settings.
+ items:
+ $ref: "#/components/schemas/Setting"
+ type: array
+ kind:
+ default: calendar#settings
+ description: Type of the collection ("calendar#settings").
+ type: string
+ nextPageToken:
+ description: Token used to access the next page of this result. Omitted if no further results are available, in which case nextSyncToken is provided.
+ type: string
+ nextSyncToken:
+ description: Token used at a later point in time to retrieve only the entries that have changed since this result was returned. Omitted if further results are available, in which case nextPageToken is provided.
+ type: string
+ type: object
+ TimePeriod:
+ properties:
+ end:
+ description: The (exclusive) end of the time period.
+ format: date-time
+ type: string
+ start:
+ description: The (inclusive) start of the time period.
+ format: date-time
+ type: string
+ type: object
+ securitySchemes:
+ Oauth2:
+ description: Oauth 2.0 implicit authentication
+ flows:
+ implicit:
+ authorizationUrl: https://accounts.google.com/o/oauth2/auth
+ scopes:
+ https://www.googleapis.com/auth/calendar: See, edit, share, and permanently delete all the calendars you can access using Google Calendar
+ https://www.googleapis.com/auth/calendar.events: View and edit events on all your calendars
+ https://www.googleapis.com/auth/calendar.events.readonly: View events on all your calendars
+ https://www.googleapis.com/auth/calendar.readonly: See and download any calendar you can access using your Google Calendar
+ https://www.googleapis.com/auth/calendar.settings.readonly: View your Calendar settings
+ type: oauth2
+ Oauth2c:
+ description: Oauth 2.0 authorizationCode authentication
+ flows:
+ authorizationCode:
+ authorizationUrl: https://accounts.google.com/o/oauth2/auth
+ scopes:
+ https://www.googleapis.com/auth/calendar: See, edit, share, and permanently delete all the calendars you can access using Google Calendar
+ https://www.googleapis.com/auth/calendar.events: View and edit events on all your calendars
+ https://www.googleapis.com/auth/calendar.events.readonly: View events on all your calendars
+ https://www.googleapis.com/auth/calendar.readonly: See and download any calendar you can access using your Google Calendar
+ https://www.googleapis.com/auth/calendar.settings.readonly: View your Calendar settings
+ tokenUrl: https://accounts.google.com/o/oauth2/token
+ type: oauth2
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_spec_parser.py b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_spec_parser.py
new file mode 100644
index 0000000..de3156e
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_spec_parser.py
@@ -0,0 +1,628 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from typing import Dict
+
+from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
+import pytest
+
+
+def create_minimal_openapi_spec() -> Dict[str, Any]:
+ """Creates a minimal valid OpenAPI spec."""
+ return {
+ "openapi": "3.1.0",
+ "info": {"title": "Minimal API", "version": "1.0.0"},
+ "paths": {
+ "/test": {
+ "get": {
+ "summary": "Test GET endpoint",
+ "operationId": "testGet",
+ "responses": {
+ "200": {
+ "description": "Successful response",
+ "content": {
+ "application/json": {"schema": {"type": "string"}}
+ },
+ }
+ },
+ }
+ }
+ },
+ }
+
+
+@pytest.fixture
+def openapi_spec_generator():
+ """Fixture for creating an OperationGenerator instance."""
+ return OpenApiSpecParser()
+
+
+def test_parse_minimal_spec(openapi_spec_generator):
+ """Test parsing a minimal OpenAPI specification."""
+ openapi_spec = create_minimal_openapi_spec()
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+ op = parsed_operations[0]
+
+ assert len(parsed_operations) == 1
+ assert op.name == "test_get"
+ assert op.endpoint.path == "/test"
+ assert op.endpoint.method == "get"
+ assert op.return_value.type_value == str
+
+
+def test_parse_spec_with_no_operation_id(openapi_spec_generator):
+ """Test parsing a spec where operationId is missing (auto-generation)."""
+ openapi_spec = create_minimal_openapi_spec()
+ del openapi_spec["paths"]["/test"]["get"]["operationId"] # Remove operationId
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+
+ assert len(parsed_operations) == 1
+ # Check if operationId is auto generated based on path and method.
+ assert parsed_operations[0].name == "test_get"
+
+
+def test_parse_spec_with_multiple_methods(openapi_spec_generator):
+ """Test parsing a spec with multiple HTTP methods for the same path."""
+ openapi_spec = create_minimal_openapi_spec()
+ openapi_spec["paths"]["/test"]["post"] = {
+ "summary": "Test POST endpoint",
+ "operationId": "testPost",
+ "responses": {"200": {"description": "Successful response"}},
+ }
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+ operation_names = {op.name for op in parsed_operations}
+
+ assert len(parsed_operations) == 2
+ assert "test_get" in operation_names
+ assert "test_post" in operation_names
+
+
+def test_parse_spec_with_parameters(openapi_spec_generator):
+ openapi_spec = create_minimal_openapi_spec()
+ openapi_spec["paths"]["/test"]["get"]["parameters"] = [
+ {"name": "param1", "in": "query", "schema": {"type": "string"}},
+ {"name": "param2", "in": "header", "schema": {"type": "integer"}},
+ ]
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+
+ assert len(parsed_operations[0].parameters) == 2
+ assert parsed_operations[0].parameters[0].original_name == "param1"
+ assert parsed_operations[0].parameters[0].param_location == "query"
+ assert parsed_operations[0].parameters[1].original_name == "param2"
+ assert parsed_operations[0].parameters[1].param_location == "header"
+
+
+def test_parse_spec_with_request_body(openapi_spec_generator):
+ openapi_spec = create_minimal_openapi_spec()
+ openapi_spec["paths"]["/test"]["post"] = {
+ "summary": "Endpoint with request body",
+ "operationId": "testPostWithBody",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "type": "object",
+ "properties": {"name": {"type": "string"}},
+ }
+ }
+ }
+ },
+ "responses": {"200": {"description": "OK"}},
+ }
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+ post_operations = [
+ op for op in parsed_operations if op.endpoint.method == "post"
+ ]
+ op = post_operations[0]
+
+ assert len(post_operations) == 1
+ assert op.name == "test_post_with_body"
+ assert len(op.parameters) == 1
+ assert op.parameters[0].original_name == "name"
+ assert op.parameters[0].type_value == str
+
+
+def test_parse_spec_with_reference(openapi_spec_generator):
+ """Test parsing a specification with $ref."""
+ openapi_spec = {
+ "openapi": "3.1.0",
+ "info": {"title": "API with Refs", "version": "1.0.0"},
+ "paths": {
+ "/test_ref": {
+ "get": {
+ "summary": "Endpoint with ref",
+ "operationId": "testGetRef",
+ "responses": {
+ "200": {
+ "description": "Success",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/MySchema"
+ }
+ }
+ },
+ }
+ },
+ }
+ }
+ },
+ "components": {
+ "schemas": {
+ "MySchema": {
+ "type": "object",
+ "properties": {"name": {"type": "string"}},
+ }
+ }
+ },
+ }
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+ op = parsed_operations[0]
+
+ assert len(parsed_operations) == 1
+ assert op.return_value.type_value.__origin__ is dict
+
+
+def test_parse_spec_with_circular_reference(openapi_spec_generator):
+ """Test correct handling of circular $ref (important!)."""
+ openapi_spec = {
+ "openapi": "3.1.0",
+ "info": {"title": "Circular Ref API", "version": "1.0.0"},
+ "paths": {
+ "/circular": {
+ "get": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {"$ref": "#/components/schemas/A"}
+ }
+ },
+ }
+ }
+ }
+ }
+ },
+ "components": {
+ "schemas": {
+ "A": {
+ "type": "object",
+ "properties": {"b": {"$ref": "#/components/schemas/B"}},
+ },
+ "B": {
+ "type": "object",
+ "properties": {"a": {"$ref": "#/components/schemas/A"}},
+ },
+ }
+ },
+ }
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+ assert len(parsed_operations) == 1
+
+ op = parsed_operations[0]
+ assert op.return_value.type_value.__origin__ is dict
+ assert op.return_value.type_hint == "Dict[str, Any]"
+
+
+def test_parse_no_paths(openapi_spec_generator):
+ """Test with a spec that has no paths defined."""
+ openapi_spec = {
+ "openapi": "3.1.0",
+ "info": {"title": "No Paths API", "version": "1.0.0"},
+ }
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+ assert len(parsed_operations) == 0 # Should be empty
+
+
+def test_parse_empty_path_item(openapi_spec_generator):
+ """Test a path item that is present but empty."""
+ openapi_spec = {
+ "openapi": "3.1.0",
+ "info": {"title": "Empty Path Item API", "version": "1.0.0"},
+ "paths": {"/empty": None},
+ }
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+
+ assert len(parsed_operations) == 0
+
+
+def test_parse_spec_with_global_auth_scheme(openapi_spec_generator):
+ """Test parsing with a global security scheme."""
+ openapi_spec = create_minimal_openapi_spec()
+ openapi_spec["security"] = [{"api_key": []}]
+ openapi_spec["components"] = {
+ "securitySchemes": {
+ "api_key": {"type": "apiKey", "in": "header", "name": "X-API-Key"}
+ }
+ }
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+ op = parsed_operations[0]
+
+ assert len(parsed_operations) == 1
+ assert op.auth_scheme is not None
+ assert op.auth_scheme.type_.value == "apiKey"
+
+
+def test_parse_spec_with_local_auth_scheme(openapi_spec_generator):
+ """Test parsing with a local (operation-level) security scheme."""
+ openapi_spec = create_minimal_openapi_spec()
+ openapi_spec["paths"]["/test"]["get"]["security"] = [{"local_auth": []}]
+ openapi_spec["components"] = {
+ "securitySchemes": {"local_auth": {"type": "http", "scheme": "bearer"}}
+ }
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+ op = parsed_operations[0]
+
+ assert op.auth_scheme is not None
+ assert op.auth_scheme.type_.value == "http"
+ assert op.auth_scheme.scheme == "bearer"
+
+
+def test_parse_spec_with_servers(openapi_spec_generator):
+ """Test parsing with server URLs."""
+ openapi_spec = create_minimal_openapi_spec()
+ openapi_spec["servers"] = [
+ {"url": "https://api.example.com"},
+ {"url": "http://localhost:8000"},
+ ]
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+
+ assert len(parsed_operations) == 1
+ assert parsed_operations[0].endpoint.base_url == "https://api.example.com"
+
+
+def test_parse_spec_with_no_servers(openapi_spec_generator):
+ """Test with no servers defined (should default to empty string)."""
+ openapi_spec = create_minimal_openapi_spec()
+ if "servers" in openapi_spec:
+ del openapi_spec["servers"]
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+
+ assert len(parsed_operations) == 1
+ assert parsed_operations[0].endpoint.base_url == ""
+
+
+def test_parse_spec_with_description(openapi_spec_generator):
+ openapi_spec = create_minimal_openapi_spec()
+ expected_description = "This is a test description."
+ openapi_spec["paths"]["/test"]["get"]["description"] = expected_description
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+
+ assert len(parsed_operations) == 1
+ assert parsed_operations[0].description == expected_description
+
+
+def test_parse_spec_with_empty_description(openapi_spec_generator):
+ openapi_spec = create_minimal_openapi_spec()
+ openapi_spec["paths"]["/test"]["get"]["description"] = ""
+ openapi_spec["paths"]["/test"]["get"]["summary"] = ""
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+
+ assert len(parsed_operations) == 1
+ assert parsed_operations[0].description == ""
+
+
+def test_parse_spec_with_no_description(openapi_spec_generator):
+ openapi_spec = create_minimal_openapi_spec()
+
+ # delete description
+ if "description" in openapi_spec["paths"]["/test"]["get"]:
+ del openapi_spec["paths"]["/test"]["get"]["description"]
+ if "summary" in openapi_spec["paths"]["/test"]["get"]:
+ del openapi_spec["paths"]["/test"]["get"]["summary"]
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+
+ assert len(parsed_operations) == 1
+ assert (
+ parsed_operations[0].description == ""
+ ) # it should be initialized with empty string
+
+
+def test_parse_invalid_openapi_spec_type(openapi_spec_generator):
+ """Test that passing a non-dict object to parse raises TypeError"""
+ with pytest.raises(AttributeError):
+ openapi_spec_generator.parse(123) # type: ignore
+
+ with pytest.raises(AttributeError):
+ openapi_spec_generator.parse("openapi_spec") # type: ignore
+
+ with pytest.raises(AttributeError):
+ openapi_spec_generator.parse([]) # type: ignore
+
+
+def test_parse_external_ref_raises_error(openapi_spec_generator):
+ """Check that external references (not starting with #) raise ValueError."""
+ openapi_spec = {
+ "openapi": "3.1.0",
+ "info": {"title": "External Ref API", "version": "1.0.0"},
+ "paths": {
+ "/external": {
+ "get": {
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": (
+ "external_file.json#/components/schemas/ExternalSchema"
+ )
+ }
+ }
+ },
+ }
+ }
+ }
+ }
+ },
+ }
+ with pytest.raises(ValueError):
+ openapi_spec_generator.parse(openapi_spec)
+
+
+def test_parse_spec_with_multiple_paths_deep_refs(openapi_spec_generator):
+ """Test specs with multiple paths, request/response bodies using deep refs."""
+ openapi_spec = {
+ "openapi": "3.1.0",
+ "info": {"title": "Multiple Paths Deep Refs API", "version": "1.0.0"},
+ "paths": {
+ "/path1": {
+ "post": {
+ "operationId": "postPath1",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Request1"
+ }
+ }
+ }
+ },
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Response1"
+ }
+ }
+ },
+ }
+ },
+ }
+ },
+ "/path2": {
+ "put": {
+ "operationId": "putPath2",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Request2"
+ }
+ }
+ }
+ },
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Response2"
+ }
+ }
+ },
+ }
+ },
+ },
+ "get": {
+ "operationId": "getPath2",
+ "responses": {
+ "200": {
+ "description": "OK",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/Response2"
+ }
+ }
+ },
+ }
+ },
+ },
+ },
+ },
+ "components": {
+ "schemas": {
+ "Request1": {
+ "type": "object",
+ "properties": {
+ "req1_prop1": {"$ref": "#/components/schemas/Level1_1"}
+ },
+ },
+ "Response1": {
+ "type": "object",
+ "properties": {
+ "res1_prop1": {"$ref": "#/components/schemas/Level1_2"}
+ },
+ },
+ "Request2": {
+ "type": "object",
+ "properties": {
+ "req2_prop1": {"$ref": "#/components/schemas/Level1_1"}
+ },
+ },
+ "Response2": {
+ "type": "object",
+ "properties": {
+ "res2_prop1": {"$ref": "#/components/schemas/Level1_2"}
+ },
+ },
+ "Level1_1": {
+ "type": "object",
+ "properties": {
+ "level1_1_prop1": {
+ "$ref": "#/components/schemas/Level2_1"
+ }
+ },
+ },
+ "Level1_2": {
+ "type": "object",
+ "properties": {
+ "level1_2_prop1": {
+ "$ref": "#/components/schemas/Level2_2"
+ }
+ },
+ },
+ "Level2_1": {
+ "type": "object",
+ "properties": {
+ "level2_1_prop1": {"$ref": "#/components/schemas/Level3"}
+ },
+ },
+ "Level2_2": {
+ "type": "object",
+ "properties": {"level2_2_prop1": {"type": "string"}},
+ },
+ "Level3": {"type": "integer"},
+ }
+ },
+ }
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+ assert len(parsed_operations) == 3
+
+ # Verify Path 1
+ path1_ops = [op for op in parsed_operations if op.endpoint.path == "/path1"]
+ assert len(path1_ops) == 1
+ path1_op = path1_ops[0]
+ assert path1_op.name == "post_path1"
+
+ assert len(path1_op.parameters) == 1
+ assert path1_op.parameters[0].original_name == "req1_prop1"
+ assert (
+ path1_op.parameters[0]
+ .param_schema.properties["level1_1_prop1"]
+ .properties["level2_1_prop1"]
+ .type
+ == "integer"
+ )
+ assert (
+ path1_op.return_value.param_schema.properties["res1_prop1"]
+ .properties["level1_2_prop1"]
+ .properties["level2_2_prop1"]
+ .type
+ == "string"
+ )
+
+ # Verify Path 2
+ path2_ops = [
+ op
+ for op in parsed_operations
+ if op.endpoint.path == "/path2" and op.name == "put_path2"
+ ]
+ path2_op = path2_ops[0]
+ assert path2_op is not None
+ assert len(path2_op.parameters) == 1
+ assert path2_op.parameters[0].original_name == "req2_prop1"
+ assert (
+ path2_op.parameters[0]
+ .param_schema.properties["level1_1_prop1"]
+ .properties["level2_1_prop1"]
+ .type
+ == "integer"
+ )
+ assert (
+ path2_op.return_value.param_schema.properties["res2_prop1"]
+ .properties["level1_2_prop1"]
+ .properties["level2_2_prop1"]
+ .type
+ == "string"
+ )
+
+
+def test_parse_spec_with_duplicate_parameter_names(openapi_spec_generator):
+ """Test handling of duplicate parameter names (one in query, one in body).
+
+ The expected behavior is that both parameters should be captured but with
+ different suffix, and
+ their `original_name` attributes should reflect their origin (query or body).
+ """
+ openapi_spec = {
+ "openapi": "3.1.0",
+ "info": {"title": "Duplicate Parameter Names API", "version": "1.0.0"},
+ "paths": {
+ "/duplicate": {
+ "post": {
+ "operationId": "createWithDuplicate",
+ "parameters": [{
+ "name": "name",
+ "in": "query",
+ "schema": {"type": "string"},
+ }],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "type": "object",
+ "properties": {"name": {"type": "integer"}},
+ }
+ }
+ }
+ },
+ "responses": {"200": {"description": "OK"}},
+ }
+ }
+ },
+ }
+
+ parsed_operations = openapi_spec_generator.parse(openapi_spec)
+ assert len(parsed_operations) == 1
+ op = parsed_operations[0]
+ assert op.name == "create_with_duplicate"
+ assert len(op.parameters) == 2
+
+ query_param = None
+ body_param = None
+ for param in op.parameters:
+ if param.param_location == "query" and param.original_name == "name":
+ query_param = param
+ elif param.param_location == "body" and param.original_name == "name":
+ body_param = param
+
+ assert query_param is not None
+ assert query_param.original_name == "name"
+ assert query_param.py_name == "name"
+
+ assert body_param is not None
+ assert body_param.original_name == "name"
+ assert body_param.py_name == "name_0"
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py
new file mode 100644
index 0000000..1b1e218
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_openapi_toolset.py
@@ -0,0 +1,139 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Dict
+
+from fastapi.openapi.models import APIKey
+from fastapi.openapi.models import APIKeyIn
+from fastapi.openapi.models import MediaType
+from fastapi.openapi.models import OAuth2
+from fastapi.openapi.models import ParameterInType
+from fastapi.openapi.models import SecuritySchemeType
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.auth.auth_credential import AuthCredentialTypes
+from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
+from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
+import pytest
+import yaml
+
+
+def load_spec(file_path: str) -> Dict:
+ """Loads the OpenAPI specification from a YAML file."""
+ with open(file_path, "r", encoding="utf-8") as f:
+ return yaml.safe_load(f)
+
+
+@pytest.fixture
+def openapi_spec() -> Dict:
+ """Fixture to load the OpenAPI specification."""
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ # Join the directory path with the filename
+ yaml_path = os.path.join(current_dir, "test.yaml")
+ return load_spec(yaml_path)
+
+
+def test_openapi_toolset_initialization_from_dict(openapi_spec: Dict):
+ """Test initialization of OpenAPIToolset with a dictionary."""
+ toolset = OpenAPIToolset(spec_dict=openapi_spec)
+ assert isinstance(toolset.tools, list)
+ assert len(toolset.tools) == 5
+ assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
+
+
+def test_openapi_toolset_initialization_from_yaml_string(openapi_spec: Dict):
+ """Test initialization of OpenAPIToolset with a YAML string."""
+ spec_str = yaml.dump(openapi_spec)
+ toolset = OpenAPIToolset(spec_str=spec_str, spec_str_type="yaml")
+ assert isinstance(toolset.tools, list)
+ assert len(toolset.tools) == 5
+ assert all(isinstance(tool, RestApiTool) for tool in toolset.tools)
+
+
+def test_openapi_toolset_tool_existing(openapi_spec: Dict):
+ """Test the tool() method for an existing tool."""
+ toolset = OpenAPIToolset(spec_dict=openapi_spec)
+ tool_name = "calendar_calendars_insert" # Example operationId from the spec
+ tool = toolset.get_tool(tool_name)
+ assert isinstance(tool, RestApiTool)
+ assert tool.name == tool_name
+ assert tool.description == "Creates a secondary calendar."
+ assert tool.endpoint.method == "post"
+ assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
+ assert tool.endpoint.path == "/calendars"
+ assert tool.is_long_running is False
+ assert tool.operation.operationId == "calendar.calendars.insert"
+ assert tool.operation.description == "Creates a secondary calendar."
+ assert isinstance(
+ tool.operation.requestBody.content["application/json"], MediaType
+ )
+ assert len(tool.operation.responses) == 1
+ response = tool.operation.responses["200"]
+ assert response.description == "Successful response"
+ assert isinstance(response.content["application/json"], MediaType)
+ assert isinstance(tool.auth_scheme, OAuth2)
+
+ tool_name = "calendar_calendars_get"
+ tool = toolset.get_tool(tool_name)
+ assert isinstance(tool, RestApiTool)
+ assert tool.name == tool_name
+ assert tool.description == "Returns metadata for a calendar."
+ assert tool.endpoint.method == "get"
+ assert tool.endpoint.base_url == "https://www.googleapis.com/calendar/v3"
+ assert tool.endpoint.path == "/calendars/{calendarId}"
+ assert tool.is_long_running is False
+ assert tool.operation.operationId == "calendar.calendars.get"
+ assert tool.operation.description == "Returns metadata for a calendar."
+ assert len(tool.operation.parameters) == 1
+ assert tool.operation.parameters[0].name == "calendarId"
+ assert tool.operation.parameters[0].in_ == ParameterInType.path
+ assert tool.operation.parameters[0].required is True
+ assert tool.operation.parameters[0].schema_.type == "string"
+ assert (
+ tool.operation.parameters[0].description
+ == "Calendar identifier. To retrieve calendar IDs call the"
+ " calendarList.list method. If you want to access the primary calendar"
+ ' of the currently logged in user, use the "primary" keyword.'
+ )
+ assert isinstance(tool.auth_scheme, OAuth2)
+
+ assert isinstance(toolset.get_tool("calendar_calendars_update"), RestApiTool)
+ assert isinstance(toolset.get_tool("calendar_calendars_delete"), RestApiTool)
+ assert isinstance(toolset.get_tool("calendar_calendars_patch"), RestApiTool)
+
+
+def test_openapi_toolset_tool_non_existing(openapi_spec: Dict):
+ """Test the tool() method for a non-existing tool."""
+ toolset = OpenAPIToolset(spec_dict=openapi_spec)
+ tool = toolset.get_tool("non_existent_tool")
+ assert tool is None
+
+
+def test_openapi_toolset_configure_auth_on_init(openapi_spec: Dict):
+ """Test configuring auth during initialization."""
+
+ auth_scheme = APIKey(**{
+ "in": APIKeyIn.header, # Use alias name in dict
+ "name": "api_key",
+ "type": SecuritySchemeType.http,
+ })
+ auth_credential = AuthCredential(auth_type=AuthCredentialTypes.API_KEY)
+ toolset = OpenAPIToolset(
+ spec_dict=openapi_spec,
+ auth_scheme=auth_scheme,
+ auth_credential=auth_credential,
+ )
+ for tool in toolset.tools:
+ assert tool.auth_scheme == auth_scheme
+ assert tool.auth_credential == auth_credential
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py
new file mode 100644
index 0000000..aa6fc5b
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py
@@ -0,0 +1,406 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from fastapi.openapi.models import MediaType
+from fastapi.openapi.models import Operation
+from fastapi.openapi.models import Parameter
+from fastapi.openapi.models import RequestBody
+from fastapi.openapi.models import Response
+from fastapi.openapi.models import Schema
+from google.adk.tools.openapi_tool.common.common import ApiParameter
+from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
+import pytest
+
+
+@pytest.fixture
+def sample_operation() -> Operation:
+ """Fixture to provide a sample OpenAPI Operation object."""
+ return Operation(
+ operationId='test_operation',
+ summary='Test Summary',
+ description='Test Description',
+ parameters=[
+ Parameter(**{
+ 'name': 'param1',
+ 'in': 'query',
+ 'schema': Schema(type='string'),
+ 'description': 'Parameter 1',
+ }),
+ Parameter(**{
+ 'name': 'param2',
+ 'in': 'header',
+ 'schema': Schema(type='string'),
+ 'description': 'Parameter 2',
+ }),
+ ],
+ requestBody=RequestBody(
+ content={
+ 'application/json': MediaType(
+ schema=Schema(
+ type='object',
+ properties={
+ 'prop1': Schema(
+ type='string', description='Property 1'
+ ),
+ 'prop2': Schema(
+ type='integer', description='Property 2'
+ ),
+ },
+ )
+ )
+ },
+ description='Request body description',
+ ),
+ responses={
+ '200': Response(
+ description='Success',
+ content={
+ 'application/json': MediaType(schema=Schema(type='string'))
+ },
+ ),
+ '400': Response(description='Client Error'),
+ },
+ security=[{'oauth2': ['resource: read', 'resource: write']}],
+ )
+
+
+def test_operation_parser_initialization(sample_operation):
+ """Test initialization of OperationParser."""
+ parser = OperationParser(sample_operation)
+ assert parser.operation == sample_operation
+ assert len(parser.params) == 4 # 2 params + 2 request body props
+ assert parser.return_value is not None
+
+
+def test_process_operation_parameters(sample_operation):
+ """Test _process_operation_parameters method."""
+ parser = OperationParser(sample_operation, should_parse=False)
+ parser._process_operation_parameters()
+ assert len(parser.params) == 2
+ assert parser.params[0].original_name == 'param1'
+ assert parser.params[0].param_location == 'query'
+ assert parser.params[1].original_name == 'param2'
+ assert parser.params[1].param_location == 'header'
+
+
+def test_process_request_body(sample_operation):
+ """Test _process_request_body method."""
+ parser = OperationParser(sample_operation, should_parse=False)
+ parser._process_request_body()
+ assert len(parser.params) == 2 # 2 properties in request body
+ assert parser.params[0].original_name == 'prop1'
+ assert parser.params[0].param_location == 'body'
+ assert parser.params[1].original_name == 'prop2'
+ assert parser.params[1].param_location == 'body'
+
+
+def test_process_request_body_array():
+ """Test _process_request_body method with array schema."""
+ operation = Operation(
+ requestBody=RequestBody(
+ content={
+ 'application/json': MediaType(
+ schema=Schema(
+ type='array',
+ items=Schema(
+ type='object',
+ properties={
+ 'item_prop1': Schema(
+ type='string', description='Item Property 1'
+ ),
+ 'item_prop2': Schema(
+ type='integer', description='Item Property 2'
+ ),
+ },
+ ),
+ )
+ )
+ }
+ )
+ )
+
+ parser = OperationParser(operation, should_parse=False)
+ parser._process_request_body()
+ assert len(parser.params) == 1
+ assert parser.params[0].original_name == 'array'
+ assert parser.params[0].param_location == 'body'
+ # Check that schema is correctly propagated and is a dictionary
+ assert parser.params[0].param_schema.type == 'array'
+ assert parser.params[0].param_schema.items.type == 'object'
+ assert 'item_prop1' in parser.params[0].param_schema.items.properties
+ assert 'item_prop2' in parser.params[0].param_schema.items.properties
+ assert (
+ parser.params[0].param_schema.items.properties['item_prop1'].description
+ == 'Item Property 1'
+ )
+ assert (
+ parser.params[0].param_schema.items.properties['item_prop2'].description
+ == 'Item Property 2'
+ )
+
+
+def test_process_request_body_no_name():
+ """Test _process_request_body with a schema that has no properties (unnamed)"""
+ operation = Operation(
+ requestBody=RequestBody(
+ content={'application/json': MediaType(schema=Schema(type='string'))}
+ )
+ )
+ parser = OperationParser(operation, should_parse=False)
+ parser._process_request_body()
+ assert len(parser.params) == 1
+ assert parser.params[0].original_name == '' # No name
+ assert parser.params[0].param_location == 'body'
+
+
+def test_dedupe_param_names(sample_operation):
+ """Test _dedupe_param_names method."""
+ parser = OperationParser(sample_operation, should_parse=False)
+ # Add duplicate named parameters.
+ parser.params = [
+ ApiParameter(original_name='test', param_location='', param_schema={}),
+ ApiParameter(original_name='test', param_location='', param_schema={}),
+ ApiParameter(original_name='test', param_location='', param_schema={}),
+ ]
+ parser._dedupe_param_names()
+ assert parser.params[0].py_name == 'test'
+ assert parser.params[1].py_name == 'test_0'
+ assert parser.params[2].py_name == 'test_1'
+
+
+def test_process_return_value(sample_operation):
+ """Test _process_return_value method."""
+ parser = OperationParser(sample_operation, should_parse=False)
+ parser._process_return_value()
+ assert parser.return_value is not None
+ assert parser.return_value.type_hint == 'str'
+
+
+def test_process_return_value_no_2xx(sample_operation):
+ """Tests _process_return_value when no 2xx response exists."""
+ operation_no_2xx = Operation(
+ responses={'400': Response(description='Client Error')}
+ )
+ parser = OperationParser(operation_no_2xx, should_parse=False)
+ parser._process_return_value()
+ assert parser.return_value is not None
+ assert parser.return_value.type_hint == 'Any'
+
+
+def test_process_return_value_multiple_2xx(sample_operation):
+ """Tests _process_return_value when multiple 2xx responses exist."""
+ operation_multi_2xx = Operation(
+ responses={
+ '201': Response(
+ description='Success',
+ content={
+ 'application/json': MediaType(schema=Schema(type='integer'))
+ },
+ ),
+ '202': Response(
+ description='Success',
+ content={'text/plain': MediaType(schema=Schema(type='string'))},
+ ),
+ '200': Response(
+ description='Success',
+ content={
+ 'application/pdf': MediaType(schema=Schema(type='boolean'))
+ },
+ ),
+ '400': Response(
+ description='Failure',
+ content={
+ 'application/xml': MediaType(schema=Schema(type='object'))
+ },
+ ),
+ }
+ )
+
+ parser = OperationParser(operation_multi_2xx, should_parse=False)
+ parser._process_return_value()
+
+ assert parser.return_value is not None
+ # Take the content type of the 200 response since it's the smallest response
+ # code
+ assert parser.return_value.param_schema.type == 'boolean'
+
+
+def test_process_return_value_no_content(sample_operation):
+ """Test when 2xx response has no content"""
+ operation_no_content = Operation(
+ responses={'200': Response(description='Success', content={})}
+ )
+ parser = OperationParser(operation_no_content, should_parse=False)
+ parser._process_return_value()
+ assert parser.return_value.type_hint == 'Any'
+
+
+def test_process_return_value_no_schema(sample_operation):
+ """Tests when the 2xx response's content has no schema."""
+ operation_no_schema = Operation(
+ responses={
+ '200': Response(
+ description='Success',
+ content={'application/json': MediaType(schema=None)},
+ )
+ }
+ )
+ parser = OperationParser(operation_no_schema, should_parse=False)
+ parser._process_return_value()
+ assert parser.return_value.type_hint == 'Any'
+
+
+def test_get_function_name(sample_operation):
+ """Test get_function_name method."""
+ parser = OperationParser(sample_operation)
+ assert parser.get_function_name() == 'test_operation'
+
+
+def test_get_function_name_missing_id():
+ """Tests get_function_name when operationId is missing"""
+ operation = Operation() # No ID
+ parser = OperationParser(operation)
+ with pytest.raises(ValueError, match='Operation ID is missing'):
+ parser.get_function_name()
+
+
+def test_get_return_type_hint(sample_operation):
+ """Test get_return_type_hint method."""
+ parser = OperationParser(sample_operation)
+ assert parser.get_return_type_hint() == 'str'
+
+
+def test_get_return_type_value(sample_operation):
+ """Test get_return_type_value method."""
+ parser = OperationParser(sample_operation)
+ assert parser.get_return_type_value() == str
+
+
+def test_get_parameters(sample_operation):
+ """Test get_parameters method."""
+ parser = OperationParser(sample_operation)
+ params = parser.get_parameters()
+ assert len(params) == 4 # Correct count after processing
+ assert all(isinstance(p, ApiParameter) for p in params)
+
+
+def test_get_return_value(sample_operation):
+ """Test get_return_value method."""
+ parser = OperationParser(sample_operation)
+ return_value = parser.get_return_value()
+ assert isinstance(return_value, ApiParameter)
+
+
+def test_get_auth_scheme_name(sample_operation):
+ """Test get_auth_scheme_name method."""
+ parser = OperationParser(sample_operation)
+ assert parser.get_auth_scheme_name() == 'oauth2'
+
+
+def test_get_auth_scheme_name_no_security():
+ """Test get_auth_scheme_name when no security is present."""
+ operation = Operation(responses={})
+ parser = OperationParser(operation)
+ assert parser.get_auth_scheme_name() == ''
+
+
+def test_get_pydoc_string(sample_operation):
+ """Test get_pydoc_string method."""
+ parser = OperationParser(sample_operation)
+ pydoc_string = parser.get_pydoc_string()
+ assert 'Test Summary' in pydoc_string
+ assert 'Args:' in pydoc_string
+ assert 'param1 (str): Parameter 1' in pydoc_string
+ assert 'prop1 (str): Property 1' in pydoc_string
+ assert 'Returns (str):' in pydoc_string
+ assert 'Success' in pydoc_string
+
+
+def test_get_json_schema(sample_operation):
+ """Test get_json_schema method."""
+ parser = OperationParser(sample_operation)
+ json_schema = parser.get_json_schema()
+ assert json_schema['title'] == 'test_operation_Arguments'
+ assert json_schema['type'] == 'object'
+ assert 'param1' in json_schema['properties']
+ assert 'prop1' in json_schema['properties']
+ assert 'param1' in json_schema['required']
+ assert 'prop1' in json_schema['required']
+
+
+def test_get_signature_parameters(sample_operation):
+ """Test get_signature_parameters method."""
+ parser = OperationParser(sample_operation)
+ signature_params = parser.get_signature_parameters()
+ assert len(signature_params) == 4
+ assert signature_params[0].name == 'param1'
+ assert signature_params[0].annotation == str
+ assert signature_params[2].name == 'prop1'
+ assert signature_params[2].annotation == str
+
+
+def test_get_annotations(sample_operation):
+ """Test get_annotations method."""
+ parser = OperationParser(sample_operation)
+ annotations = parser.get_annotations()
+ assert len(annotations) == 5 # 4 parameters + return
+ assert annotations['param1'] == str
+ assert annotations['prop1'] == str
+ assert annotations['return'] == str
+
+
+def test_load():
+ """Test the load classmethod."""
+ operation = Operation(operationId='my_op') # Minimal operation
+ params = [
+ ApiParameter(
+ original_name='p1',
+ param_location='',
+ param_schema={'type': 'integer'},
+ )
+ ]
+ return_value = ApiParameter(
+ original_name='', param_location='', param_schema={'type': 'string'}
+ )
+
+ parser = OperationParser.load(operation, params, return_value)
+
+ assert isinstance(parser, OperationParser)
+ assert parser.operation == operation
+ assert parser.params == params
+ assert parser.return_value == return_value
+ assert (
+ parser.get_function_name() == 'my_op'
+ ) # Check that the operation is loaded
+
+
+def test_operation_parser_with_dict():
+ """Test initialization of OperationParser with a dictionary."""
+ operation_dict = {
+ 'operationId': 'test_dict_operation',
+ 'parameters': [
+ {'name': 'dict_param', 'in': 'query', 'schema': {'type': 'string'}}
+ ],
+ 'responses': {
+ '200': {
+ 'description': 'Dict Success',
+ 'content': {'application/json': {'schema': {'type': 'string'}}},
+ }
+ },
+ }
+ parser = OperationParser(operation_dict)
+ assert parser.operation.operationId == 'test_dict_operation'
+ assert len(parser.params) == 1
+ assert parser.params[0].original_name == 'dict_param'
+ assert parser.return_value.type_hint == 'str'
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py
new file mode 100644
index 0000000..f3976f8
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py
@@ -0,0 +1,966 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+from fastapi.openapi.models import MediaType
+from fastapi.openapi.models import Operation
+from fastapi.openapi.models import Parameter as OpenAPIParameter
+from fastapi.openapi.models import RequestBody
+from fastapi.openapi.models import Schema as OpenAPISchema
+from google.adk.sessions.state import State
+from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
+from google.adk.tools.openapi_tool.common.common import ApiParameter
+from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_spec_parser import OperationEndpoint
+from google.adk.tools.openapi_tool.openapi_spec_parser.operation_parser import OperationParser
+from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
+from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import snake_to_lower_camel
+from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
+from google.adk.tools.tool_context import ToolContext
+from google.genai.types import FunctionDeclaration
+from google.genai.types import Schema
+from google.genai.types import Type
+import pytest
+
+
+class TestRestApiTool:
+
+ @pytest.fixture
+ def mock_tool_context(self):
+ """Fixture for a mock OperationParser."""
+ mock_context = MagicMock(spec=ToolContext)
+ mock_context.state = State({}, {})
+ mock_context.get_auth_response.return_value = {}
+ mock_context.request_credential.return_value = {}
+ return mock_context
+
+ @pytest.fixture
+ def mock_operation_parser(self):
+ """Fixture for a mock OperationParser."""
+ mock_parser = MagicMock(spec=OperationParser)
+ mock_parser.get_function_name.return_value = "mock_function_name"
+ mock_parser.get_json_schema.return_value = {}
+ mock_parser.get_parameters.return_value = []
+ mock_parser.get_return_type_hint.return_value = "str"
+ mock_parser.get_pydoc_string.return_value = "Mock docstring"
+ mock_parser.get_signature_parameters.return_value = []
+ mock_parser.get_return_type_value.return_value = str
+ mock_parser.get_annotations.return_value = {}
+ return mock_parser
+
+ @pytest.fixture
+ def sample_endpiont(self):
+ return OperationEndpoint(
+ base_url="https://example.com", path="/test", method="GET"
+ )
+
+ @pytest.fixture
+ def sample_operation(self):
+ return Operation(
+ operationId="testOperation",
+ description="Test operation",
+ parameters=[],
+ requestBody=RequestBody(
+ content={
+ "application/json": MediaType(
+ schema=OpenAPISchema(
+ type="object",
+ properties={
+ "testBodyParam": OpenAPISchema(type="string")
+ },
+ )
+ )
+ }
+ ),
+ )
+
+ @pytest.fixture
+ def sample_api_parameters(self):
+ return [
+ ApiParameter(
+ original_name="test_param",
+ py_name="test_param",
+ param_location="query",
+ param_schema=OpenAPISchema(type="string"),
+ is_required=True,
+ ),
+ ApiParameter(
+ original_name="",
+ py_name="test_body_param",
+ param_location="body",
+ param_schema=OpenAPISchema(type="string"),
+ is_required=True,
+ ),
+ ]
+
+ @pytest.fixture
+ def sample_return_parameter(self):
+ return ApiParameter(
+ original_name="test_param",
+ py_name="test_param",
+ param_location="query",
+ param_schema=OpenAPISchema(type="string"),
+ is_required=True,
+ )
+
+ @pytest.fixture
+ def sample_auth_scheme(self):
+ scheme, _ = token_to_scheme_credential(
+ "apikey", "header", "", "sample_auth_credential_internal_test"
+ )
+ return scheme
+
+ @pytest.fixture
+ def sample_auth_credential(self):
+ _, credential = token_to_scheme_credential(
+ "apikey", "header", "", "sample_auth_credential_internal_test"
+ )
+ return credential
+
+ def test_init(
+ self,
+ sample_endpiont,
+ sample_operation,
+ sample_auth_scheme,
+ sample_auth_credential,
+ ):
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=sample_endpiont,
+ operation=sample_operation,
+ auth_scheme=sample_auth_scheme,
+ auth_credential=sample_auth_credential,
+ )
+ assert tool.name == "test_tool"
+ assert tool.description == "Test Tool"
+ assert tool.endpoint == sample_endpiont
+ assert tool.operation == sample_operation
+ assert tool.auth_credential == sample_auth_credential
+ assert tool.auth_scheme == sample_auth_scheme
+ assert tool.credential_exchanger is not None
+
+ def test_from_parsed_operation_str(
+ self,
+ sample_endpiont,
+ sample_api_parameters,
+ sample_return_parameter,
+ sample_operation,
+ ):
+ parsed_operation_str = json.dumps({
+ "name": "test_operation",
+ "description": "Test Description",
+ "endpoint": sample_endpiont.model_dump(),
+ "operation": sample_operation.model_dump(),
+ "auth_scheme": None,
+ "auth_credential": None,
+ "parameters": [p.model_dump() for p in sample_api_parameters],
+ "return_value": sample_return_parameter.model_dump(),
+ })
+
+ tool = RestApiTool.from_parsed_operation_str(parsed_operation_str)
+ assert tool.name == "test_operation"
+
+ def test_get_declaration(
+ self, sample_endpiont, sample_operation, mock_operation_parser
+ ):
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test description",
+ endpoint=sample_endpiont,
+ operation=sample_operation,
+ should_parse_operation=False,
+ )
+ tool._operation_parser = mock_operation_parser
+
+ declaration = tool._get_declaration()
+ assert isinstance(declaration, FunctionDeclaration)
+ assert declaration.name == "test_tool"
+ assert declaration.description == "Test description"
+ assert isinstance(declaration.parameters, Schema)
+
+ @patch(
+ "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
+ )
+ def test_call_success(
+ self,
+ mock_request,
+ mock_tool_context,
+ sample_endpiont,
+ sample_operation,
+ sample_auth_scheme,
+ sample_auth_credential,
+ ):
+ mock_response = MagicMock()
+ mock_response.json.return_value = {"result": "success"}
+ mock_request.return_value = mock_response
+
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=sample_endpiont,
+ operation=sample_operation,
+ auth_scheme=sample_auth_scheme,
+ auth_credential=sample_auth_credential,
+ )
+
+ # Call the method
+ result = tool.call(args={}, tool_context=mock_tool_context)
+
+ # Check the result
+ assert result == {"result": "success"}
+
+ @patch(
+ "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.requests.request"
+ )
+ def test_call_auth_pending(
+ self,
+ mock_request,
+ sample_endpiont,
+ sample_operation,
+ sample_auth_scheme,
+ sample_auth_credential,
+ ):
+
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=sample_endpiont,
+ operation=sample_operation,
+ auth_scheme=sample_auth_scheme,
+ auth_credential=sample_auth_credential,
+ )
+ with patch(
+ "google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool.ToolAuthHandler.from_tool_context"
+ ) as mock_from_tool_context:
+ mock_tool_auth_handler_instance = MagicMock()
+ mock_tool_auth_handler_instance.prepare_auth_credentials.return_value.state = (
+ "pending"
+ )
+ mock_from_tool_context.return_value = mock_tool_auth_handler_instance
+
+ response = tool.call(args={}, tool_context=None)
+ assert response == {
+ "pending": True,
+ "message": "Needs your authorization to access your data.",
+ }
+
+ def test_prepare_request_params_query_body(
+ self, sample_endpiont, sample_auth_credential, sample_auth_scheme
+ ):
+ # Create a mock Operation object
+ mock_operation = Operation(
+ operationId="test_op",
+ parameters=[
+ OpenAPIParameter(**{
+ "name": "testQueryParam",
+ "in": "query",
+ "schema": OpenAPISchema(type="string"),
+ })
+ ],
+ requestBody=RequestBody(
+ content={
+ "application/json": MediaType(
+ schema=OpenAPISchema(
+ type="object",
+ properties={
+ "param1": OpenAPISchema(type="string"),
+ "param2": OpenAPISchema(type="integer"),
+ },
+ )
+ )
+ }
+ ),
+ )
+
+ tool = RestApiTool(
+ name="test_tool",
+ description="test",
+ endpoint=sample_endpiont,
+ operation=mock_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+
+ params = [
+ ApiParameter(
+ original_name="param1",
+ py_name="param1",
+ param_location="body",
+ param_schema=OpenAPISchema(type="string"),
+ ),
+ ApiParameter(
+ original_name="param2",
+ py_name="param2",
+ param_location="body",
+ param_schema=OpenAPISchema(type="integer"),
+ ),
+ ApiParameter(
+ original_name="testQueryParam",
+ py_name="test_query_param",
+ param_location="query",
+ param_schema=OpenAPISchema(type="string"),
+ ),
+ ]
+ kwargs = {
+ "param1": "value1",
+ "param2": 123,
+ "test_query_param": "query_value",
+ }
+
+ request_params = tool._prepare_request_params(params, kwargs)
+ assert request_params["method"] == "get"
+ assert request_params["url"] == "https://example.com/test"
+ assert request_params["json"] == {"param1": "value1", "param2": 123}
+ assert request_params["params"] == {"testQueryParam": "query_value"}
+
+ def test_prepare_request_params_array(
+ self, sample_endpiont, sample_auth_scheme, sample_auth_credential
+ ):
+ mock_operation = Operation(
+ operationId="test_op",
+ requestBody=RequestBody(
+ content={
+ "application/json": MediaType(
+ schema=OpenAPISchema(
+ type="array", items=OpenAPISchema(type="string")
+ )
+ )
+ }
+ ),
+ )
+
+ tool = RestApiTool(
+ name="test_tool",
+ description="test",
+ endpoint=sample_endpiont,
+ operation=mock_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = [
+ ApiParameter(
+ original_name="array", # Match the parameter name
+ py_name="array",
+ param_location="body",
+ param_schema=OpenAPISchema(
+ type="array", items=OpenAPISchema(type="string")
+ ),
+ )
+ ]
+ kwargs = {"array": ["item1", "item2"]}
+
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ assert request_params["json"] == ["item1", "item2"]
+
+ def test_prepare_request_params_string(
+ self, sample_endpiont, sample_auth_credential, sample_auth_scheme
+ ):
+ mock_operation = Operation(
+ operationId="test_op",
+ requestBody=RequestBody(
+ content={
+ "text/plain": MediaType(schema=OpenAPISchema(type="string"))
+ }
+ ),
+ )
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=sample_endpiont,
+ operation=mock_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = [
+ ApiParameter(
+ original_name="",
+ py_name="input_string",
+ param_location="body",
+ param_schema=OpenAPISchema(type="string"),
+ )
+ ]
+ kwargs = {"input_string": "test_value"}
+
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ assert request_params["data"] == "test_value"
+ assert request_params["headers"]["Content-Type"] == "text/plain"
+
+ def test_prepare_request_params_form_data(
+ self, sample_endpiont, sample_auth_scheme, sample_auth_credential
+ ):
+ mock_operation = Operation(
+ operationId="test_op",
+ requestBody=RequestBody(
+ content={
+ "application/x-www-form-urlencoded": MediaType(
+ schema=OpenAPISchema(
+ type="object",
+ properties={"key1": OpenAPISchema(type="string")},
+ )
+ )
+ }
+ ),
+ )
+ tool = RestApiTool(
+ name="test_tool",
+ description="test",
+ endpoint=sample_endpiont,
+ operation=mock_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = [
+ ApiParameter(
+ original_name="key1",
+ py_name="key1",
+ param_location="body",
+ param_schema=OpenAPISchema(type="string"),
+ )
+ ]
+ kwargs = {"key1": "value1"}
+
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ assert request_params["data"] == {"key1": "value1"}
+ assert (
+ request_params["headers"]["Content-Type"]
+ == "application/x-www-form-urlencoded"
+ )
+
+ def test_prepare_request_params_multipart(
+ self, sample_endpiont, sample_auth_credential, sample_auth_scheme
+ ):
+ mock_operation = Operation(
+ operationId="test_op",
+ requestBody=RequestBody(
+ content={
+ "multipart/form-data": MediaType(
+ schema=OpenAPISchema(
+ type="object",
+ properties={
+ "file1": OpenAPISchema(
+ type="string", format="binary"
+ )
+ },
+ )
+ )
+ }
+ ),
+ )
+ tool = RestApiTool(
+ name="test_tool",
+ description="test",
+ endpoint=sample_endpiont,
+ operation=mock_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = [
+ ApiParameter(
+ original_name="file1",
+ py_name="file1",
+ param_location="body",
+ param_schema=OpenAPISchema(type="string", format="binary"),
+ )
+ ]
+ kwargs = {"file1": b"file_content"}
+
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ assert request_params["files"] == {"file1": b"file_content"}
+ assert request_params["headers"]["Content-Type"] == "multipart/form-data"
+
+ def test_prepare_request_params_octet_stream(
+ self, sample_endpiont, sample_auth_scheme, sample_auth_credential
+ ):
+ mock_operation = Operation(
+ operationId="test_op",
+ requestBody=RequestBody(
+ content={
+ "application/octet-stream": MediaType(
+ schema=OpenAPISchema(type="string", format="binary")
+ )
+ }
+ ),
+ )
+ tool = RestApiTool(
+ name="test_tool",
+ description="test",
+ endpoint=sample_endpiont,
+ operation=mock_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = [
+ ApiParameter(
+ original_name="",
+ py_name="data",
+ param_location="body",
+ param_schema=OpenAPISchema(type="string", format="binary"),
+ )
+ ]
+ kwargs = {"data": b"binary_data"}
+
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ assert request_params["data"] == b"binary_data"
+ assert (
+ request_params["headers"]["Content-Type"] == "application/octet-stream"
+ )
+
+ def test_prepare_request_params_path_param(
+ self, sample_endpiont, sample_auth_credential, sample_auth_scheme
+ ):
+ mock_operation = Operation(operationId="test_op")
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=sample_endpiont,
+ operation=mock_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = [
+ ApiParameter(
+ original_name="user_id",
+ py_name="user_id",
+ param_location="path",
+ param_schema=OpenAPISchema(type="string"),
+ )
+ ]
+ kwargs = {"user_id": "123"}
+ endpoint_with_path = OperationEndpoint(
+ base_url="https://example.com", path="/test/{user_id}", method="get"
+ )
+ tool.endpoint = endpoint_with_path
+
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ assert (
+ request_params["url"] == "https://example.com/test/123"
+ ) # Path param replaced
+
+ def test_prepare_request_params_header_param(
+ self,
+ sample_endpiont,
+ sample_auth_credential,
+ sample_auth_scheme,
+ sample_operation,
+ ):
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=sample_endpiont,
+ operation=sample_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = [
+ ApiParameter(
+ original_name="X-Custom-Header",
+ py_name="x_custom_header",
+ param_location="header",
+ param_schema=OpenAPISchema(type="string"),
+ )
+ ]
+ kwargs = {"x_custom_header": "header_value"}
+
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ assert request_params["headers"]["X-Custom-Header"] == "header_value"
+
+ def test_prepare_request_params_cookie_param(
+ self,
+ sample_endpiont,
+ sample_auth_credential,
+ sample_auth_scheme,
+ sample_operation,
+ ):
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=sample_endpiont,
+ operation=sample_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = [
+ ApiParameter(
+ original_name="session_id",
+ py_name="session_id",
+ param_location="cookie",
+ param_schema=OpenAPISchema(type="string"),
+ )
+ ]
+ kwargs = {"session_id": "cookie_value"}
+
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ assert request_params["cookies"]["session_id"] == "cookie_value"
+
+ def test_prepare_request_params_multiple_mime_types(
+ self, sample_endpiont, sample_auth_credential, sample_auth_scheme
+ ):
+ # Test what happens when multiple mime types are specified. It should take
+ # the first one.
+ mock_operation = Operation(
+ operationId="test_op",
+ requestBody=RequestBody(
+ content={
+ "application/json": MediaType(
+ schema=OpenAPISchema(type="string")
+ ),
+ "text/plain": MediaType(schema=OpenAPISchema(type="string")),
+ }
+ ),
+ )
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=sample_endpiont,
+ operation=mock_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = [
+ ApiParameter(
+ original_name="",
+ py_name="input",
+ param_location="body",
+ param_schema=OpenAPISchema(type="string"),
+ )
+ ]
+ kwargs = {"input": "some_value"}
+
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ assert request_params["headers"]["Content-Type"] == "application/json"
+
+ def test_prepare_request_params_unknown_parameter(
+ self,
+ sample_endpiont,
+ sample_auth_credential,
+ sample_auth_scheme,
+ sample_operation,
+ ):
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=sample_endpiont,
+ operation=sample_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = [
+ ApiParameter(
+ original_name="known_param",
+ py_name="known_param",
+ param_location="query",
+ param_schema=OpenAPISchema(type="string"),
+ )
+ ]
+ kwargs = {"known_param": "value", "unknown_param": "unknown"}
+
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ # Make sure unknown parameters are ignored and do not raise errors.
+ assert "unknown_param" not in request_params["params"]
+
+ def test_prepare_request_params_base_url_handling(
+ self, sample_auth_credential, sample_auth_scheme, sample_operation
+ ):
+ # No base_url provided, should use path as is
+ tool_no_base = RestApiTool(
+ name="test_tool_no_base",
+ description="Test Tool",
+ endpoint=OperationEndpoint(base_url="", path="/no_base", method="get"),
+ operation=sample_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = []
+ kwargs = {}
+
+ request_params_no_base = tool_no_base._prepare_request_params(
+ params, kwargs
+ )
+ assert request_params_no_base["url"] == "/no_base"
+
+ tool_trailing_slash = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=OperationEndpoint(
+ base_url="https://example.com/", path="/trailing", method="get"
+ ),
+ operation=sample_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+
+ request_params_trailing = tool_trailing_slash._prepare_request_params(
+ params, kwargs
+ )
+ assert request_params_trailing["url"] == "https://example.com/trailing"
+
+ def test_prepare_request_params_no_unrecognized_query_parameter(
+ self,
+ sample_endpiont,
+ sample_auth_credential,
+ sample_auth_scheme,
+ sample_operation,
+ ):
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=sample_endpiont,
+ operation=sample_operation,
+ auth_credential=sample_auth_credential,
+ auth_scheme=sample_auth_scheme,
+ )
+ params = [
+ ApiParameter(
+ original_name="unrecognized_param",
+ py_name="unrecognized_param",
+ param_location="query",
+ param_schema=OpenAPISchema(type="string"),
+ )
+ ]
+ kwargs = {"unrecognized_param": None} # Explicitly passing None
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ # Query param not in sample_operation. It should be ignored.
+ assert "unrecognized_param" not in request_params["params"]
+
+ def test_prepare_request_params_no_credential(
+ self,
+ sample_endpiont,
+ sample_operation,
+ ):
+ tool = RestApiTool(
+ name="test_tool",
+ description="Test Tool",
+ endpoint=sample_endpiont,
+ operation=sample_operation,
+ auth_credential=None,
+ auth_scheme=None,
+ )
+ params = [
+ ApiParameter(
+ original_name="param_name",
+ py_name="param_name",
+ param_location="query",
+ param_schema=OpenAPISchema(type="string"),
+ )
+ ]
+ kwargs = {"param_name": "aaa", "empty_param": ""}
+
+ request_params = tool._prepare_request_params(params, kwargs)
+
+ assert "param_name" in request_params["params"]
+ assert "empty_param" not in request_params["params"]
+
+
+class TestToGeminiSchema:
+
+ def test_to_gemini_schema_none(self):
+ assert to_gemini_schema(None) is None
+
+ def test_to_gemini_schema_not_dict(self):
+ with pytest.raises(TypeError, match="openapi_schema must be a dictionary"):
+ to_gemini_schema("not a dict")
+
+ def test_to_gemini_schema_empty_dict(self):
+ result = to_gemini_schema({})
+ assert isinstance(result, Schema)
+ assert result.type == Type.OBJECT
+ assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
+
+ def test_to_gemini_schema_dict_with_only_object_type(self):
+ result = to_gemini_schema({"type": "object"})
+ assert isinstance(result, Schema)
+ assert result.type == Type.OBJECT
+ assert result.properties == {"dummy_DO_NOT_GENERATE": Schema(type="string")}
+
+ def test_to_gemini_schema_basic_types(self):
+ openapi_schema = {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string"},
+ "age": {"type": "integer"},
+ "is_active": {"type": "boolean"},
+ },
+ }
+ gemini_schema = to_gemini_schema(openapi_schema)
+ assert isinstance(gemini_schema, Schema)
+ assert gemini_schema.type == Type.OBJECT
+ assert gemini_schema.properties["name"].type == Type.STRING
+ assert gemini_schema.properties["age"].type == Type.INTEGER
+ assert gemini_schema.properties["is_active"].type == Type.BOOLEAN
+
+ def test_to_gemini_schema_nested_objects(self):
+ openapi_schema = {
+ "type": "object",
+ "properties": {
+ "address": {
+ "type": "object",
+ "properties": {
+ "street": {"type": "string"},
+ "city": {"type": "string"},
+ },
+ }
+ },
+ }
+ gemini_schema = to_gemini_schema(openapi_schema)
+ assert gemini_schema.properties["address"].type == Type.OBJECT
+ assert (
+ gemini_schema.properties["address"].properties["street"].type
+ == Type.STRING
+ )
+ assert (
+ gemini_schema.properties["address"].properties["city"].type
+ == Type.STRING
+ )
+
+ def test_to_gemini_schema_array(self):
+ openapi_schema = {
+ "type": "array",
+ "items": {"type": "string"},
+ }
+ gemini_schema = to_gemini_schema(openapi_schema)
+ assert gemini_schema.type == Type.ARRAY
+ assert gemini_schema.items.type == Type.STRING
+
+ def test_to_gemini_schema_nested_array(self):
+ openapi_schema = {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {"name": {"type": "string"}},
+ },
+ }
+ gemini_schema = to_gemini_schema(openapi_schema)
+ assert gemini_schema.items.properties["name"].type == Type.STRING
+
+ def test_to_gemini_schema_any_of(self):
+ openapi_schema = {
+ "anyOf": [{"type": "string"}, {"type": "integer"}],
+ }
+ gemini_schema = to_gemini_schema(openapi_schema)
+ assert len(gemini_schema.any_of) == 2
+ assert gemini_schema.any_of[0].type == Type.STRING
+ assert gemini_schema.any_of[1].type == Type.INTEGER
+
+ def test_to_gemini_schema_general_list(self):
+ openapi_schema = {
+ "type": "array",
+ "properties": {
+ "list_field": {"type": "array", "items": {"type": "string"}},
+ },
+ }
+ gemini_schema = to_gemini_schema(openapi_schema)
+ assert gemini_schema.properties["list_field"].type == Type.ARRAY
+ assert gemini_schema.properties["list_field"].items.type == Type.STRING
+
+ def test_to_gemini_schema_enum(self):
+ openapi_schema = {"type": "string", "enum": ["a", "b", "c"]}
+ gemini_schema = to_gemini_schema(openapi_schema)
+ assert gemini_schema.enum == ["a", "b", "c"]
+
+ def test_to_gemini_schema_required(self):
+ openapi_schema = {
+ "type": "object",
+ "required": ["name"],
+ "properties": {"name": {"type": "string"}},
+ }
+ gemini_schema = to_gemini_schema(openapi_schema)
+ assert gemini_schema.required == ["name"]
+
+ def test_to_gemini_schema_nested_dict(self):
+ openapi_schema = {
+ "type": "object",
+ "properties": {"metadata": {"key1": "value1", "key2": 123}},
+ }
+ gemini_schema = to_gemini_schema(openapi_schema)
+ # Since metadata is not properties nor item, it will call to_gemini_schema recursively.
+ assert isinstance(gemini_schema.properties["metadata"], Schema)
+ assert (
+ gemini_schema.properties["metadata"].type == Type.OBJECT
+ ) # add object type by default
+ assert gemini_schema.properties["metadata"].properties == {
+ "dummy_DO_NOT_GENERATE": Schema(type="string")
+ }
+
+ def test_to_gemini_schema_ignore_title_default_format(self):
+ openapi_schema = {
+ "type": "string",
+ "title": "Test Title",
+ "default": "default_value",
+ "format": "date",
+ }
+ gemini_schema = to_gemini_schema(openapi_schema)
+
+ assert gemini_schema.title is None
+ assert gemini_schema.default is None
+ assert gemini_schema.format is None
+
+ def test_to_gemini_schema_property_ordering(self):
+ openapi_schema = {
+ "type": "object",
+ "propertyOrdering": ["name", "age"],
+ "properties": {
+ "name": {"type": "string"},
+ "age": {"type": "integer"},
+ },
+ }
+
+ gemini_schema = to_gemini_schema(openapi_schema)
+ assert gemini_schema.property_ordering == ["name", "age"]
+
+ def test_to_gemini_schema_converts_property_dict(self):
+ openapi_schema = {
+ "properties": {
+ "name": {"type": "string", "description": "The property key"},
+ "value": {"type": "string", "description": "The property value"},
+ },
+ "type": "object",
+ "description": "A single property entry in the Properties message.",
+ }
+ gemini_schema = to_gemini_schema(openapi_schema)
+ assert gemini_schema.type == Type.OBJECT
+ assert gemini_schema.properties["name"].type == Type.STRING
+ assert gemini_schema.properties["value"].type == Type.STRING
+
+ def test_to_gemini_schema_remove_unrecognized_fields(self):
+ openapi_schema = {
+ "type": "string",
+ "description": "A single date string.",
+ "format": "date",
+ }
+ gemini_schema = to_gemini_schema(openapi_schema)
+ assert gemini_schema.type == Type.STRING
+ assert not gemini_schema.format
+
+
+def test_snake_to_lower_camel():
+ assert snake_to_lower_camel("single") == "single"
+ assert snake_to_lower_camel("two_words") == "twoWords"
+ assert snake_to_lower_camel("three_word_example") == "threeWordExample"
+ assert not snake_to_lower_camel("")
+ assert snake_to_lower_camel("alreadyCamelCase") == "alreadyCamelCase"
diff --git a/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py
new file mode 100644
index 0000000..0a3b8cc
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py
@@ -0,0 +1,201 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+from unittest.mock import MagicMock
+
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.agents.llm_agent import LlmAgent
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.auth.auth_credential import AuthCredentialTypes
+from google.adk.auth.auth_credential import HttpAuth
+from google.adk.auth.auth_credential import HttpCredentials
+from google.adk.auth.auth_credential import OAuth2Auth
+from google.adk.auth.auth_schemes import AuthScheme
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
+from google.adk.sessions.session import Session
+from google.adk.tools.openapi_tool.auth.auth_helpers import openid_dict_to_scheme_credential
+from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
+from google.adk.tools.openapi_tool.auth.credential_exchangers.auto_auth_credential_exchanger import OAuth2CredentialExchanger
+from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolAuthHandler
+from google.adk.tools.openapi_tool.openapi_spec_parser.tool_auth_handler import ToolContextCredentialStore
+from google.adk.tools.tool_context import ToolContext
+import pytest
+
+
+# Helper function to create a mock ToolContext
+def create_mock_tool_context():
+ return ToolContext(
+ function_call_id='test-fc-id',
+ invocation_context=InvocationContext(
+ agent=LlmAgent(name='test'),
+ session=Session(app_name='test', user_id='123', id='123'),
+ invocation_id='123',
+ session_service=InMemorySessionService(),
+ ),
+ )
+
+
+# Test cases for OpenID Connect
+class MockOpenIdConnectCredentialExchanger(OAuth2CredentialExchanger):
+
+ def __init__(
+ self, expected_scheme, expected_credential, expected_access_token
+ ):
+ self.expected_scheme = expected_scheme
+ self.expected_credential = expected_credential
+ self.expected_access_token = expected_access_token
+
+ def exchange_credential(
+ self,
+ auth_scheme: AuthScheme,
+ auth_credential: Optional[AuthCredential] = None,
+ ) -> AuthCredential:
+ if auth_credential.oauth2 and (
+ auth_credential.oauth2.auth_response_uri
+ or auth_credential.oauth2.auth_code
+ ):
+ auth_code = (
+ auth_credential.oauth2.auth_response_uri
+ if auth_credential.oauth2.auth_response_uri
+ else auth_credential.oauth2.auth_code
+ )
+ # Simulate the token exchange
+ updated_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
+ http=HttpAuth(
+ scheme='bearer',
+ credentials=HttpCredentials(
+ token=auth_code + self.expected_access_token
+ ),
+ ),
+ )
+ return updated_credential
+
+ # simulate the case of getting auth_uri
+ return None
+
+
+def get_mock_openid_scheme_credential():
+ config_dict = {
+ 'authorization_endpoint': 'test.com',
+ 'token_endpoint': 'test.com',
+ }
+ scopes = ['test_scope']
+ credential_dict = {
+ 'client_id': '123',
+ 'client_secret': '456',
+ 'redirect_uri': 'test.com',
+ }
+ return openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
+
+
+# Fixture for the OpenID Connect security scheme
+@pytest.fixture
+def openid_connect_scheme():
+ scheme, _ = get_mock_openid_scheme_credential()
+ return scheme
+
+
+# Fixture for a base OpenID Connect credential
+@pytest.fixture
+def openid_connect_credential():
+ _, credential = get_mock_openid_scheme_credential()
+ return credential
+
+
+def test_openid_connect_no_auth_response(
+ openid_connect_scheme, openid_connect_credential
+):
+ # Setup Mock exchanger
+ mock_exchanger = MockOpenIdConnectCredentialExchanger(
+ openid_connect_scheme, openid_connect_credential, None
+ )
+ tool_context = create_mock_tool_context()
+ credential_store = ToolContextCredentialStore(tool_context=tool_context)
+ handler = ToolAuthHandler(
+ tool_context,
+ openid_connect_scheme,
+ openid_connect_credential,
+ credential_exchanger=mock_exchanger,
+ credential_store=credential_store,
+ )
+ result = handler.prepare_auth_credentials()
+ assert result.state == 'pending'
+ assert result.auth_credential == openid_connect_credential
+
+
+def test_openid_connect_with_auth_response(
+ openid_connect_scheme, openid_connect_credential, monkeypatch
+):
+ mock_exchanger = MockOpenIdConnectCredentialExchanger(
+ openid_connect_scheme,
+ openid_connect_credential,
+ 'test_access_token',
+ )
+ tool_context = create_mock_tool_context()
+
+ mock_auth_handler = MagicMock()
+ mock_auth_handler.get_auth_response.return_value = AuthCredential(
+ auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
+ oauth2=OAuth2Auth(auth_response_uri='test_auth_response_uri'),
+ )
+ mock_auth_handler_path = 'google.adk.tools.tool_context.AuthHandler'
+ monkeypatch.setattr(
+ mock_auth_handler_path, lambda *args, **kwargs: mock_auth_handler
+ )
+
+ credential_store = ToolContextCredentialStore(tool_context=tool_context)
+ handler = ToolAuthHandler(
+ tool_context,
+ openid_connect_scheme,
+ openid_connect_credential,
+ credential_exchanger=mock_exchanger,
+ credential_store=credential_store,
+ )
+ result = handler.prepare_auth_credentials()
+ assert result.state == 'done'
+ assert result.auth_credential.auth_type == AuthCredentialTypes.HTTP
+ assert 'test_access_token' in result.auth_credential.http.credentials.token
+ # Verify that the credential was stored:
+ stored_credential = credential_store.get_credential(
+ openid_connect_scheme, openid_connect_credential
+ )
+ assert stored_credential == result.auth_credential
+ mock_auth_handler.get_auth_response.assert_called_once()
+
+
+def test_openid_connect_existing_token(
+ openid_connect_scheme, openid_connect_credential
+):
+ _, existing_credential = token_to_scheme_credential(
+ 'oauth2Token', 'header', 'bearer', '123123123'
+ )
+ tool_context = create_mock_tool_context()
+ # Store the credential to simulate existing credential
+ credential_store = ToolContextCredentialStore(tool_context=tool_context)
+ key = credential_store.get_credential_key(
+ openid_connect_scheme, openid_connect_credential
+ )
+ credential_store.store_credential(key, existing_credential)
+
+ handler = ToolAuthHandler(
+ tool_context,
+ openid_connect_scheme,
+ openid_connect_credential,
+ credential_store=credential_store,
+ )
+ result = handler.prepare_auth_credentials()
+ assert result.state == 'done'
+ assert result.auth_credential == existing_credential
diff --git a/src/google/adk/tests/unittests/tools/retrieval/__init__.py b/src/google/adk/tests/unittests/tools/retrieval/__init__.py
new file mode 100644
index 0000000..36a1e8d
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/retrieval/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/src/google/adk/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py b/src/google/adk/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py
new file mode 100644
index 0000000..f8d122c
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py
@@ -0,0 +1,147 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.agents import Agent
+from google.adk.tools.function_tool import FunctionTool
+from google.adk.tools.retrieval.vertex_ai_rag_retrieval import VertexAiRagRetrieval
+from google.genai import types
+
+from ... import utils
+
+
+def noop_tool(x: str) -> str:
+ return x
+
+
+def test_vertex_rag_retrieval_for_gemini_1_x():
+ responses = [
+ 'response1',
+ ]
+ mockModel = utils.MockModel.create(responses=responses)
+ mockModel.model = 'gemini-1.5-pro'
+
+ # Calls the first time.
+ agent = Agent(
+ name='root_agent',
+ model=mockModel,
+ tools=[
+ VertexAiRagRetrieval(
+ name='rag_retrieval',
+ description='rag_retrieval',
+ rag_corpora=[
+ 'projects/123456789/locations/us-central1/ragCorpora/1234567890'
+ ],
+ )
+ ],
+ )
+ runner = utils.InMemoryRunner(agent)
+ events = runner.run('test1')
+
+ # Asserts the requests.
+ assert len(mockModel.requests) == 1
+ assert utils.simplify_contents(mockModel.requests[0].contents) == [
+ ('user', 'test1'),
+ ]
+ assert len(mockModel.requests[0].config.tools) == 1
+ assert (
+ mockModel.requests[0].config.tools[0].function_declarations[0].name
+ == 'rag_retrieval'
+ )
+ assert mockModel.requests[0].tools_dict['rag_retrieval'] is not None
+
+
+def test_vertex_rag_retrieval_for_gemini_1_x_with_another_function_tool():
+ responses = [
+ 'response1',
+ ]
+ mockModel = utils.MockModel.create(responses=responses)
+ mockModel.model = 'gemini-1.5-pro'
+
+ # Calls the first time.
+ agent = Agent(
+ name='root_agent',
+ model=mockModel,
+ tools=[
+ VertexAiRagRetrieval(
+ name='rag_retrieval',
+ description='rag_retrieval',
+ rag_corpora=[
+ 'projects/123456789/locations/us-central1/ragCorpora/1234567890'
+ ],
+ ),
+ FunctionTool(func=noop_tool),
+ ],
+ )
+ runner = utils.InMemoryRunner(agent)
+ events = runner.run('test1')
+
+ # Asserts the requests.
+ assert len(mockModel.requests) == 1
+ assert utils.simplify_contents(mockModel.requests[0].contents) == [
+ ('user', 'test1'),
+ ]
+ assert len(mockModel.requests[0].config.tools[0].function_declarations) == 2
+ assert (
+ mockModel.requests[0].config.tools[0].function_declarations[0].name
+ == 'rag_retrieval'
+ )
+ assert (
+ mockModel.requests[0].config.tools[0].function_declarations[1].name
+ == 'noop_tool'
+ )
+ assert mockModel.requests[0].tools_dict['rag_retrieval'] is not None
+
+
+def test_vertex_rag_retrieval_for_gemini_2_x():
+ responses = [
+ 'response1',
+ ]
+ mockModel = utils.MockModel.create(responses=responses)
+ mockModel.model = 'gemini-2.0-flash'
+
+ # Calls the first time.
+ agent = Agent(
+ name='root_agent',
+ model=mockModel,
+ tools=[
+ VertexAiRagRetrieval(
+ name='rag_retrieval',
+ description='rag_retrieval',
+ rag_corpora=[
+ 'projects/123456789/locations/us-central1/ragCorpora/1234567890'
+ ],
+ )
+ ],
+ )
+ runner = utils.InMemoryRunner(agent)
+ events = runner.run('test1')
+
+ # Asserts the requests.
+ assert len(mockModel.requests) == 1
+ assert utils.simplify_contents(mockModel.requests[0].contents) == [
+ ('user', 'test1'),
+ ]
+ assert len(mockModel.requests[0].config.tools) == 1
+ assert mockModel.requests[0].config.tools == [
+ types.Tool(
+ retrieval=types.Retrieval(
+ vertex_rag_store=types.VertexRagStore(
+ rag_corpora=[
+ 'projects/123456789/locations/us-central1/ragCorpora/1234567890'
+ ]
+ )
+ )
+ )
+ ]
+ assert 'rag_retrieval' not in mockModel.requests[0].tools_dict
diff --git a/src/google/adk/tests/unittests/tools/test_agent_tool.py b/src/google/adk/tests/unittests/tools/test_agent_tool.py
new file mode 100644
index 0000000..dc8cdeb
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/test_agent_tool.py
@@ -0,0 +1,167 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk.agents import Agent
+from google.adk.agents.callback_context import CallbackContext
+from google.adk.tools.agent_tool import AgentTool
+from google.genai.types import Part
+from pydantic import BaseModel
+import pytest
+from pytest import mark
+
+from .. import utils
+
+pytestmark = pytest.mark.skip(
+ reason='Skipping until tool.func evaluations are fixed (async)'
+)
+
+
+function_call_custom = Part.from_function_call(
+ name='tool_agent', args={'custom_input': 'test1'}
+)
+
+function_call_no_schema = Part.from_function_call(
+ name='tool_agent', args={'request': 'test1'}
+)
+
+function_response_custom = Part.from_function_response(
+ name='tool_agent', response={'custom_output': 'response1'}
+)
+
+function_response_no_schema = Part.from_function_response(
+ name='tool_agent', response={'result': 'response1'}
+)
+
+
+def change_state_callback(callback_context: CallbackContext):
+ callback_context.state['state_1'] = 'changed_value'
+ print('change_state_callback: ', callback_context.state)
+
+
+def test_no_schema():
+ mock_model = utils.MockModel.create(
+ responses=[
+ function_call_no_schema,
+ 'response1',
+ 'response2',
+ ]
+ )
+
+ tool_agent = Agent(
+ name='tool_agent',
+ model=mock_model,
+ )
+
+ root_agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ tools=[AgentTool(agent=tool_agent)],
+ )
+
+ runner = utils.InMemoryRunner(root_agent)
+
+ assert utils.simplify_events(runner.run('test1')) == [
+ ('root_agent', function_call_no_schema),
+ ('root_agent', function_response_no_schema),
+ ('root_agent', 'response2'),
+ ]
+
+
+def test_update_state():
+ """The agent tool can read and change parent state."""
+
+ mock_model = utils.MockModel.create(
+ responses=[
+ function_call_no_schema,
+ '{"custom_output": "response1"}',
+ 'response2',
+ ]
+ )
+
+ tool_agent = Agent(
+ name='tool_agent',
+ model=mock_model,
+ instruction='input: {state_1}',
+ before_agent_callback=change_state_callback,
+ )
+
+ root_agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ tools=[AgentTool(agent=tool_agent)],
+ )
+
+ runner = utils.InMemoryRunner(root_agent)
+ runner.session.state['state_1'] = 'state1_value'
+
+ runner.run('test1')
+ assert (
+ 'input: changed_value' in mock_model.requests[1].config.system_instruction
+ )
+ assert runner.session.state['state_1'] == 'changed_value'
+
+
+@mark.parametrize(
+ 'env_variables',
+ [
+ 'GOOGLE_AI',
+ # TODO(wanyif): re-enable after fix.
+ # 'VERTEX',
+ ],
+ indirect=True,
+)
+def test_custom_schema():
+ class CustomInput(BaseModel):
+ custom_input: str
+
+ class CustomOutput(BaseModel):
+ custom_output: str
+
+ mock_model = utils.MockModel.create(
+ responses=[
+ function_call_custom,
+ '{"custom_output": "response1"}',
+ 'response2',
+ ]
+ )
+
+ tool_agent = Agent(
+ name='tool_agent',
+ model=mock_model,
+ input_schema=CustomInput,
+ output_schema=CustomOutput,
+ output_key='tool_output',
+ )
+
+ root_agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ tools=[AgentTool(agent=tool_agent)],
+ )
+
+ runner = utils.InMemoryRunner(root_agent)
+ runner.session.state['state_1'] = 'state1_value'
+
+ assert utils.simplify_events(runner.run('test1')) == [
+ ('root_agent', function_call_custom),
+ ('root_agent', function_response_custom),
+ ('root_agent', 'response2'),
+ ]
+
+ assert runner.session.state['tool_output'] == {'custom_output': 'response1'}
+
+ assert len(mock_model.requests) == 3
+ # The second request is the tool agent request.
+ assert mock_model.requests[1].config.response_schema == CustomOutput
+ assert mock_model.requests[1].config.response_mime_type == 'application/json'
diff --git a/src/google/adk/tests/unittests/tools/test_base_tool.py b/src/google/adk/tests/unittests/tools/test_base_tool.py
new file mode 100644
index 0000000..13f06d7
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/test_base_tool.py
@@ -0,0 +1,141 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.agents.sequential_agent import SequentialAgent
+from google.adk.models.llm_request import LlmRequest
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
+from google.adk.tools.base_tool import BaseTool
+from google.adk.tools.tool_context import ToolContext
+from google.genai import types
+import pytest
+
+
+class _TestingTool(BaseTool):
+
+ def __init__(
+ self,
+ declaration: Optional[types.FunctionDeclaration] = None,
+ ):
+ super().__init__(name='test_tool', description='test_description')
+ self.declaration = declaration
+
+ def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
+ return self.declaration
+
+
+def _create_tool_context() -> ToolContext:
+ session_service = InMemorySessionService()
+ session = session_service.create_session(
+ app_name='test_app', user_id='test_user'
+ )
+ agent = SequentialAgent(name='test_agent')
+ invocation_context = InvocationContext(
+ invocation_id='invocation_id',
+ agent=agent,
+ session=session,
+ session_service=session_service,
+ )
+ return ToolContext(invocation_context)
+
+
+@pytest.mark.asyncio
+async def test_process_llm_request_no_declaration():
+ tool = _TestingTool()
+ tool_context = _create_tool_context()
+ llm_request = LlmRequest()
+
+ await tool.process_llm_request(
+ tool_context=tool_context, llm_request=llm_request
+ )
+
+ assert llm_request.config is None
+
+
+@pytest.mark.asyncio
+async def test_process_llm_request_with_declaration():
+ declaration = types.FunctionDeclaration(
+ name='test_tool',
+ description='test_description',
+ parameters=types.Schema(
+ type=types.Type.STRING,
+ title='param_1',
+ ),
+ )
+ tool = _TestingTool(declaration)
+ llm_request = LlmRequest()
+ tool_context = _create_tool_context()
+
+ await tool.process_llm_request(
+ tool_context=tool_context, llm_request=llm_request
+ )
+
+ assert llm_request.config.tools[0].function_declarations == [declaration]
+
+
+@pytest.mark.asyncio
+async def test_process_llm_request_with_builtin_tool():
+ declaration = types.FunctionDeclaration(
+ name='test_tool',
+ description='test_description',
+ parameters=types.Schema(
+ type=types.Type.STRING,
+ title='param_1',
+ ),
+ )
+ tool = _TestingTool(declaration)
+ llm_request = LlmRequest(
+ config=types.GenerateContentConfig(
+ tools=[types.Tool(google_search=types.GoogleSearch())]
+ )
+ )
+ tool_context = _create_tool_context()
+
+ await tool.process_llm_request(
+ tool_context=tool_context, llm_request=llm_request
+ )
+
+ # function_declaration is added to another types.Tool without builtin tool.
+ assert llm_request.config.tools[1].function_declarations == [declaration]
+
+
+@pytest.mark.asyncio
+async def test_process_llm_request_with_builtin_tool_and_another_declaration():
+ declaration = types.FunctionDeclaration(
+ name='test_tool',
+ description='test_description',
+ parameters=types.Schema(
+ type=types.Type.STRING,
+ title='param_1',
+ ),
+ )
+ tool = _TestingTool(declaration)
+ llm_request = LlmRequest(
+ config=types.GenerateContentConfig(
+ tools=[
+ types.Tool(google_search=types.GoogleSearch()),
+ types.Tool(function_declarations=[types.FunctionDeclaration()]),
+ ]
+ )
+ )
+ tool_context = _create_tool_context()
+
+ await tool.process_llm_request(
+ tool_context=tool_context, llm_request=llm_request
+ )
+
+ # function_declaration is added to existing types.Tool with function_declaration.
+ assert llm_request.config.tools[1].function_declarations[1] == declaration
diff --git a/src/google/adk/tests/unittests/tools/test_build_function_declaration.py b/src/google/adk/tests/unittests/tools/test_build_function_declaration.py
new file mode 100644
index 0000000..d71d6d2
--- /dev/null
+++ b/src/google/adk/tests/unittests/tools/test_build_function_declaration.py
@@ -0,0 +1,277 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+from typing import List
+
+from google.adk.tools import _automatic_function_calling_util
+from google.adk.tools.agent_tool import ToolContext
+from google.adk.tools.langchain_tool import LangchainTool
+# TODO: crewai requires python 3.10 as minimum
+# from crewai_tools import FileReadTool
+from langchain_community.tools import ShellTool
+from pydantic import BaseModel
+import pytest
+
+
+def test_unsupported_variant():
+ def simple_function(input_str: str) -> str:
+ return {'result': input_str}
+
+ with pytest.raises(ValueError):
+ _automatic_function_calling_util.build_function_declaration(
+ func=simple_function, variant='Unsupported'
+ )
+
+
+def test_string_input():
+ def simple_function(input_str: str) -> str:
+ return {'result': input_str}
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=simple_function
+ )
+
+ assert function_decl.name == 'simple_function'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input_str'].type == 'STRING'
+
+
+def test_int_input():
+ def simple_function(input_str: int) -> str:
+ return {'result': input_str}
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=simple_function
+ )
+
+ assert function_decl.name == 'simple_function'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input_str'].type == 'INTEGER'
+
+
+def test_float_input():
+ def simple_function(input_str: float) -> str:
+ return {'result': input_str}
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=simple_function
+ )
+
+ assert function_decl.name == 'simple_function'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input_str'].type == 'NUMBER'
+
+
+def test_bool_input():
+ def simple_function(input_str: bool) -> str:
+ return {'result': input_str}
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=simple_function
+ )
+
+ assert function_decl.name == 'simple_function'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input_str'].type == 'BOOLEAN'
+
+
+def test_array_input():
+ def simple_function(input_str: List[str]) -> str:
+ return {'result': input_str}
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=simple_function
+ )
+
+ assert function_decl.name == 'simple_function'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input_str'].type == 'ARRAY'
+
+
+def test_dict_input():
+ def simple_function(input_str: Dict[str, str]) -> str:
+ return {'result': input_str}
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=simple_function
+ )
+
+ assert function_decl.name == 'simple_function'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input_str'].type == 'OBJECT'
+
+
+def test_basemodel_input():
+ class CustomInput(BaseModel):
+ input_str: str
+
+ def simple_function(input: CustomInput) -> str:
+ return {'result': input}
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=simple_function
+ )
+
+ assert function_decl.name == 'simple_function'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input'].type == 'OBJECT'
+ assert (
+ function_decl.parameters.properties['input'].properties['input_str'].type
+ == 'STRING'
+ )
+
+
+def test_toolcontext_ignored():
+ def simple_function(input_str: str, tool_context: ToolContext) -> str:
+ return {'result': input_str}
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=simple_function, ignore_params=['tool_context']
+ )
+
+ assert function_decl.name == 'simple_function'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input_str'].type == 'STRING'
+ assert 'tool_context' not in function_decl.parameters.properties
+
+
+def test_basemodel():
+ class SimpleFunction(BaseModel):
+ input_str: str
+ custom_input: int
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=SimpleFunction, ignore_params=['custom_input']
+ )
+
+ assert function_decl.name == 'SimpleFunction'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input_str'].type == 'STRING'
+ assert 'custom_input' not in function_decl.parameters.properties
+
+
+def test_nested_basemodel_input():
+ class ChildInput(BaseModel):
+ input_str: str
+
+ class CustomInput(BaseModel):
+ child: ChildInput
+
+ def simple_function(input: CustomInput) -> str:
+ return {'result': input}
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=simple_function
+ )
+
+ assert function_decl.name == 'simple_function'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input'].type == 'OBJECT'
+ assert (
+ function_decl.parameters.properties['input'].properties['child'].type
+ == 'OBJECT'
+ )
+ assert (
+ function_decl.parameters.properties['input']
+ .properties['child']
+ .properties['input_str']
+ .type
+ == 'STRING'
+ )
+
+
+def test_basemodel_with_nested_basemodel():
+ class ChildInput(BaseModel):
+ input_str: str
+
+ class CustomInput(BaseModel):
+ child: ChildInput
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=CustomInput, ignore_params=['custom_input']
+ )
+
+ assert function_decl.name == 'CustomInput'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['child'].type == 'OBJECT'
+ assert (
+ function_decl.parameters.properties['child'].properties['input_str'].type
+ == 'STRING'
+ )
+ assert 'custom_input' not in function_decl.parameters.properties
+
+
+def test_list():
+ def simple_function(
+ input_str: List[str], input_dir: List[Dict[str, str]]
+ ) -> str:
+ return {'result': input_str}
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=simple_function
+ )
+
+ assert function_decl.name == 'simple_function'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input_str'].type == 'ARRAY'
+ assert function_decl.parameters.properties['input_str'].items.type == 'STRING'
+ assert function_decl.parameters.properties['input_dir'].type == 'ARRAY'
+ assert function_decl.parameters.properties['input_dir'].items.type == 'OBJECT'
+
+
+def test_basemodel_list():
+ class ChildInput(BaseModel):
+ input_str: str
+
+ class CustomInput(BaseModel):
+ child: ChildInput
+
+ def simple_function(input_str: List[CustomInput]) -> str:
+ return {'result': input_str}
+
+ function_decl = _automatic_function_calling_util.build_function_declaration(
+ func=simple_function
+ )
+
+ assert function_decl.name == 'simple_function'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['input_str'].type == 'ARRAY'
+ assert function_decl.parameters.properties['input_str'].items.type == 'OBJECT'
+ assert (
+ function_decl.parameters.properties['input_str']
+ .items.properties['child']
+ .type
+ == 'OBJECT'
+ )
+ assert (
+ function_decl.parameters.properties['input_str']
+ .items.properties['child']
+ .properties['input_str']
+ .type
+ == 'STRING'
+ )
+
+
+# TODO: comment out this test for now as crewai requires python 3.10 as minimum
+# def test_crewai_tool():
+# docs_tool = CrewaiTool(
+# name='direcotry_read_tool',
+# description='use this to find files for you.',
+# tool=FileReadTool(),
+# )
+# function_decl = docs_tool.get_declaration()
+# assert function_decl.name == 'direcotry_read_tool'
+# assert function_decl.parameters.type == 'OBJECT'
+# assert function_decl.parameters.properties['file_path'].type == 'STRING'
diff --git a/src/google/adk/tests/unittests/utils.py b/src/google/adk/tests/unittests/utils.py
new file mode 100644
index 0000000..592ff15
--- /dev/null
+++ b/src/google/adk/tests/unittests/utils.py
@@ -0,0 +1,304 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import contextlib
+from typing import AsyncGenerator
+from typing import Generator
+from typing import Union
+
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.agents.live_request_queue import LiveRequestQueue
+from google.adk.agents.llm_agent import Agent
+from google.adk.agents.llm_agent import LlmAgent
+from google.adk.agents.run_config import RunConfig
+from google.adk.artifacts import InMemoryArtifactService
+from google.adk.events.event import Event
+from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
+from google.adk.models.base_llm import BaseLlm
+from google.adk.models.base_llm_connection import BaseLlmConnection
+from google.adk.models.llm_request import LlmRequest
+from google.adk.models.llm_response import LlmResponse
+from google.adk.runners import InMemoryRunner as AfInMemoryRunner
+from google.adk.runners import Runner
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
+from google.adk.sessions.session import Session
+from google.genai import types
+from google.genai.types import Part
+from typing_extensions import override
+
+
+class UserContent(types.Content):
+
+ def __init__(self, text_or_part: str):
+ parts = [
+ types.Part.from_text(text=text_or_part)
+ if isinstance(text_or_part, str)
+ else text_or_part
+ ]
+ super().__init__(role='user', parts=parts)
+
+
+class ModelContent(types.Content):
+
+ def __init__(self, parts: list[types.Part]):
+ super().__init__(role='model', parts=parts)
+
+
+def create_invocation_context(agent: Agent, user_content: str = ''):
+ invocation_id = 'test_id'
+ artifact_service = InMemoryArtifactService()
+ session_service = InMemorySessionService()
+ memory_service = InMemoryMemoryService()
+ invocation_context = InvocationContext(
+ artifact_service=artifact_service,
+ session_service=session_service,
+ memory_service=memory_service,
+ invocation_id=invocation_id,
+ agent=agent,
+ session=session_service.create_session(
+ app_name='test_app', user_id='test_user'
+ ),
+ user_content=types.Content(
+ role='user', parts=[types.Part.from_text(text=user_content)]
+ ),
+ run_config=RunConfig(),
+ )
+ if user_content:
+ append_user_content(
+ invocation_context, [types.Part.from_text(text=user_content)]
+ )
+ return invocation_context
+
+
+def append_user_content(
+ invocation_context: InvocationContext, parts: list[types.Part]
+) -> Event:
+ session = invocation_context.session
+ event = Event(
+ invocation_id=invocation_context.invocation_id,
+ author='user',
+ content=types.Content(role='user', parts=parts),
+ )
+ session.events.append(event)
+ return event
+
+
+# Extracts the contents from the events and transform them into a list of
+# (author, simplified_content) tuples.
+def simplify_events(events: list[Event]) -> list[(str, types.Part)]:
+ return [(event.author, simplify_content(event.content)) for event in events]
+
+
+# Simplifies the contents into a list of (author, simplified_content) tuples.
+def simplify_contents(contents: list[types.Content]) -> list[(str, types.Part)]:
+ return [(content.role, simplify_content(content)) for content in contents]
+
+
+# Simplifies the content so it's easier to assert.
+# - If there is only one part, return part
+# - If the only part is pure text, return stripped_text
+# - If there are multiple parts, return parts
+# - remove function_call_id if it exists
+def simplify_content(
+ content: types.Content,
+) -> Union[str, types.Part, list[types.Part]]:
+ for part in content.parts:
+ if part.function_call and part.function_call.id:
+ part.function_call.id = None
+ if part.function_response and part.function_response.id:
+ part.function_response.id = None
+ if len(content.parts) == 1:
+ if content.parts[0].text:
+ return content.parts[0].text.strip()
+ else:
+ return content.parts[0]
+ return content.parts
+
+
+def get_user_content(message: types.ContentUnion) -> types.Content:
+ return message if isinstance(message, types.Content) else UserContent(message)
+
+
+class TestInMemoryRunner(AfInMemoryRunner):
+ """InMemoryRunner that is tailored for tests, features async run method.
+
+ app_name is hardcoded as InMemoryRunner in the parent class.
+ """
+
+ async def run_async_with_new_session(
+ self, new_message: types.ContentUnion
+ ) -> list[Event]:
+
+ session = self.session_service.create_session(
+ app_name='InMemoryRunner', user_id='test_user'
+ )
+ collected_events = []
+
+ async for event in self.run_async(
+ user_id=session.user_id,
+ session_id=session.id,
+ new_message=get_user_content(new_message),
+ ):
+ collected_events.append(event)
+
+ return collected_events
+
+
+class InMemoryRunner:
+ """InMemoryRunner that is tailored for tests."""
+
+ def __init__(
+ self,
+ root_agent: Union[Agent, LlmAgent],
+ response_modalities: list[str] = None,
+ ):
+ self.root_agent = root_agent
+ self.runner = Runner(
+ app_name='test_app',
+ agent=root_agent,
+ artifact_service=InMemoryArtifactService(),
+ session_service=InMemorySessionService(),
+ memory_service=InMemoryMemoryService(),
+ )
+ self.session_id = self.runner.session_service.create_session(
+ app_name='test_app', user_id='test_user'
+ ).id
+
+ @property
+ def session(self) -> Session:
+ return self.runner.session_service.get_session(
+ app_name='test_app', user_id='test_user', session_id=self.session_id
+ )
+
+ def run(self, new_message: types.ContentUnion) -> list[Event]:
+ return list(
+ self.runner.run(
+ user_id=self.session.user_id,
+ session_id=self.session.id,
+ new_message=get_user_content(new_message),
+ )
+ )
+
+ def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:
+ collected_responses = []
+
+ async def consume_responses():
+ run_res = self.runner.run_live(
+ session=self.session,
+ live_request_queue=live_request_queue,
+ )
+
+ async for response in run_res:
+ collected_responses.append(response)
+ # When we have enough response, we should return
+ if len(collected_responses) >= 1:
+ return
+
+ try:
+ asyncio.run(consume_responses())
+ except asyncio.TimeoutError:
+ print('Returning any partial results collected so far.')
+
+ return collected_responses
+
+
+class MockModel(BaseLlm):
+ model: str = 'mock'
+
+ requests: list[LlmRequest] = []
+ responses: list[LlmResponse]
+ response_index: int = -1
+
+ @classmethod
+ def create(
+ cls,
+ responses: Union[
+ list[types.Part], list[LlmResponse], list[str], list[list[types.Part]]
+ ],
+ ):
+ if not responses:
+ return cls(responses=[])
+ elif isinstance(responses[0], LlmResponse):
+ # reponses is list[LlmResponse]
+ return cls(responses=responses)
+ else:
+ responses = [
+ LlmResponse(content=ModelContent(item))
+ if isinstance(item, list) and isinstance(item[0], types.Part)
+ # responses is list[list[Part]]
+ else LlmResponse(
+ content=ModelContent(
+ # responses is list[str] or list[Part]
+ [Part(text=item) if isinstance(item, str) else item]
+ )
+ )
+ for item in responses
+ if item
+ ]
+
+ return cls(responses=responses)
+
+ @staticmethod
+ def supported_models() -> list[str]:
+ return ['mock']
+
+ def generate_content(
+ self, llm_request: LlmRequest, stream: bool = False
+ ) -> Generator[LlmResponse, None, None]:
+ # Increasement of the index has to happen before the yield.
+ self.response_index += 1
+ self.requests.append(llm_request)
+ # yield LlmResponse(content=self.responses[self.response_index])
+ yield self.responses[self.response_index]
+
+ @override
+ async def generate_content_async(
+ self, llm_request: LlmRequest, stream: bool = False
+ ) -> AsyncGenerator[LlmResponse, None]:
+ # Increasement of the index has to happen before the yield.
+ self.response_index += 1
+ self.requests.append(llm_request)
+ yield self.responses[self.response_index]
+
+ @contextlib.asynccontextmanager
+ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
+ """Creates a live connection to the LLM."""
+ yield MockLlmConnection(self.responses)
+
+
+class MockLlmConnection(BaseLlmConnection):
+
+ def __init__(self, llm_responses: list[LlmResponse]):
+ self.llm_responses = llm_responses
+
+ async def send_history(self, history: list[types.Content]):
+ pass
+
+ async def send_content(self, content: types.Content):
+ pass
+
+ async def send(self, data):
+ pass
+
+ async def send_realtime(self, blob: types.Blob):
+ pass
+
+ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
+ """Yield each of the pre-defined LlmResponses."""
+ for response in self.llm_responses:
+ yield response
+
+ async def close(self):
+ pass
diff --git a/src/google/adk/tools/__init__.py b/src/google/adk/tools/__init__.py
new file mode 100644
index 0000000..8c74f0d
--- /dev/null
+++ b/src/google/adk/tools/__init__.py
@@ -0,0 +1,51 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=g-bad-import-order
+from .base_tool import BaseTool
+
+from ..auth.auth_tool import AuthToolArguments
+from .apihub_tool.apihub_toolset import APIHubToolset
+from .built_in_code_execution_tool import built_in_code_execution
+from .google_search_tool import google_search
+from .vertex_ai_search_tool import VertexAiSearchTool
+from .example_tool import ExampleTool
+from .exit_loop_tool import exit_loop
+from .function_tool import FunctionTool
+from .get_user_choice_tool import get_user_choice_tool as get_user_choice
+from .load_artifacts_tool import load_artifacts_tool as load_artifacts
+from .load_memory_tool import load_memory_tool as load_memory
+from .long_running_tool import LongRunningFunctionTool
+from .preload_memory_tool import preload_memory_tool as preload_memory
+from .tool_context import ToolContext
+from .transfer_to_agent_tool import transfer_to_agent
+
+
+__all__ = [
+ 'APIHubToolset',
+ 'AuthToolArguments',
+ 'BaseTool',
+ 'built_in_code_execution',
+ 'google_search',
+ 'VertexAiSearchTool',
+ 'ExampleTool',
+ 'exit_loop',
+ 'FunctionTool',
+ 'get_user_choice',
+ 'load_artifacts',
+ 'load_memory',
+ 'LongRunningFunctionTool',
+ 'preload_memory',
+ 'ToolContext',
+ 'transfer_to_agent',
+]
diff --git a/src/google/adk/tools/_automatic_function_calling_util.py b/src/google/adk/tools/_automatic_function_calling_util.py
new file mode 100644
index 0000000..abfb4e7
--- /dev/null
+++ b/src/google/adk/tools/_automatic_function_calling_util.py
@@ -0,0 +1,346 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Forked from google3/third_party/py/google/genai/_automatic_function_calling_util.py temporarily."""
+
+import inspect
+from types import FunctionType
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Literal
+from typing import Optional
+from typing import Union
+
+from google.genai import types
+import pydantic
+from pydantic import BaseModel
+from pydantic import create_model
+from pydantic import fields as pydantic_fields
+
+from . import function_parameter_parse_util
+
+_py_type_2_schema_type = {
+ 'str': types.Type.STRING,
+ 'int': types.Type.INTEGER,
+ 'float': types.Type.NUMBER,
+ 'bool': types.Type.BOOLEAN,
+ 'string': types.Type.STRING,
+ 'integer': types.Type.INTEGER,
+ 'number': types.Type.NUMBER,
+ 'boolean': types.Type.BOOLEAN,
+ 'list': types.Type.ARRAY,
+ 'array': types.Type.ARRAY,
+ 'tuple': types.Type.ARRAY,
+ 'object': types.Type.OBJECT,
+ 'Dict': types.Type.OBJECT,
+ 'List': types.Type.ARRAY,
+ 'Tuple': types.Type.ARRAY,
+ 'Any': types.Type.TYPE_UNSPECIFIED,
+}
+
+
+def _get_fields_dict(func: Callable) -> Dict:
+ param_signature = dict(inspect.signature(func).parameters)
+ fields_dict = {
+ name: (
+ # 1. We infer the argument type here: use Any rather than None so
+ # it will not try to auto-infer the type based on the default value.
+ (
+ param.annotation
+ if param.annotation != inspect.Parameter.empty
+ else Any
+ ),
+ pydantic.Field(
+ # 2. We do not support default values for now.
+ default=(
+ param.default
+ if param.default != inspect.Parameter.empty
+ # ! Need to use Undefined instead of None
+ else pydantic_fields.PydanticUndefined
+ ),
+ # 3. Do not support parameter description for now.
+ description=None,
+ ),
+ )
+ for name, param in param_signature.items()
+ # We do not support *args or **kwargs
+ if param.kind
+ in (
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.KEYWORD_ONLY,
+ inspect.Parameter.POSITIONAL_ONLY,
+ )
+ }
+ return fields_dict
+
+
+def _annotate_nullable_fields(schema: Dict):
+ for _, property_schema in schema.get('properties', {}).items():
+ # for Optional[T], the pydantic schema is:
+ # {
+ # "type": "object",
+ # "properties": {
+ # "anyOf": [
+ # {
+ # "type": "null"
+ # },
+ # {
+ # "type": "T"
+ # }
+ # ]
+ # }
+ # }
+ for type_ in property_schema.get('anyOf', []):
+ if type_.get('type') == 'null':
+ property_schema['nullable'] = True
+ property_schema['anyOf'].remove(type_)
+ break
+
+
+def _annotate_required_fields(schema: Dict):
+ required = [
+ field_name
+ for field_name, field_schema in schema.get('properties', {}).items()
+ if not field_schema.get('nullable') and 'default' not in field_schema
+ ]
+ schema['required'] = required
+
+
+def _remove_any_of(schema: Dict):
+ for _, property_schema in schema.get('properties', {}).items():
+ union_types = property_schema.pop('anyOf', None)
+ # Take the first non-null type.
+ if union_types:
+ for type_ in union_types:
+ if type_.get('type') != 'null':
+ property_schema.update(type_)
+
+
+def _remove_default(schema: Dict):
+ for _, property_schema in schema.get('properties', {}).items():
+ property_schema.pop('default', None)
+
+
+def _remove_nullable(schema: Dict):
+ for _, property_schema in schema.get('properties', {}).items():
+ property_schema.pop('nullable', None)
+
+
+def _remove_title(schema: Dict):
+ for _, property_schema in schema.get('properties', {}).items():
+ property_schema.pop('title', None)
+
+
+def _get_pydantic_schema(func: Callable) -> Dict:
+ fields_dict = _get_fields_dict(func)
+ if 'tool_context' in fields_dict.keys():
+ fields_dict.pop('tool_context')
+ return pydantic.create_model(func.__name__, **fields_dict).model_json_schema()
+
+
+def _process_pydantic_schema(vertexai: bool, schema: Dict) -> Dict:
+ _annotate_nullable_fields(schema)
+ _annotate_required_fields(schema)
+ if not vertexai:
+ _remove_any_of(schema)
+ _remove_default(schema)
+ _remove_nullable(schema)
+ _remove_title(schema)
+ return schema
+
+
+def _map_pydantic_type_to_property_schema(property_schema: Dict):
+ if 'type' in property_schema:
+ property_schema['type'] = _py_type_2_schema_type.get(
+ property_schema['type'], 'TYPE_UNSPECIFIED'
+ )
+ if property_schema['type'] == 'ARRAY':
+ _map_pydantic_type_to_property_schema(property_schema['items'])
+ for type_ in property_schema.get('anyOf', []):
+ if 'type' in type_:
+ type_['type'] = _py_type_2_schema_type.get(
+ type_['type'], 'TYPE_UNSPECIFIED'
+ )
+ # TODO: To investigate. Unclear why a Type is needed with 'anyOf' to
+ # avoid google.genai.errors.ClientError: 400 INVALID_ARGUMENT.
+ property_schema['type'] = type_['type']
+
+
+def _map_pydantic_type_to_schema_type(schema: Dict):
+ for _, property_schema in schema.get('properties', {}).items():
+ _map_pydantic_type_to_property_schema(property_schema)
+
+
+def _get_return_type(func: Callable) -> Any:
+ return _py_type_2_schema_type.get(
+ inspect.signature(func).return_annotation.__name__,
+ inspect.signature(func).return_annotation.__name__,
+ )
+
+
+def build_function_declaration(
+ func: Union[Callable, BaseModel],
+ ignore_params: Optional[list[str]] = None,
+ variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI',
+) -> types.FunctionDeclaration:
+ signature = inspect.signature(func)
+ should_update_signature = False
+ new_func = None
+ if not ignore_params:
+ ignore_params = []
+ for name, _ in signature.parameters.items():
+ if name in ignore_params:
+ should_update_signature = True
+ break
+ if should_update_signature:
+ new_params = [
+ param
+ for name, param in signature.parameters.items()
+ if name not in ignore_params
+ ]
+ if isinstance(func, type):
+ fields = {
+ name: (param.annotation, param.default)
+ for name, param in signature.parameters.items()
+ if name not in ignore_params
+ }
+ new_func = create_model(func.__name__, **fields)
+ else:
+ new_sig = signature.replace(parameters=new_params)
+ new_func = FunctionType(
+ func.__code__,
+ func.__globals__,
+ func.__name__,
+ func.__defaults__,
+ func.__closure__,
+ )
+ new_func.__signature__ = new_sig
+
+ return (
+ from_function_with_options(func, variant)
+ if not should_update_signature
+ else from_function_with_options(new_func, variant)
+ )
+
+
+def build_function_declaration_for_langchain(
+ vertexai: bool, name, description, func, param_pydantic_schema
+) -> types.FunctionDeclaration:
+ param_pydantic_schema = _process_pydantic_schema(
+ vertexai, {'properties': param_pydantic_schema}
+ )['properties']
+ param_copy = param_pydantic_schema.copy()
+ required_fields = param_copy.pop('required', [])
+ before_param_pydantic_schema = {
+ 'properties': param_copy,
+ 'required': required_fields,
+ }
+ return build_function_declaration_util(
+ vertexai, name, description, func, before_param_pydantic_schema
+ )
+
+
+def build_function_declaration_for_params_for_crewai(
+ vertexai: bool, name, description, func, param_pydantic_schema
+) -> types.FunctionDeclaration:
+ param_pydantic_schema = _process_pydantic_schema(
+ vertexai, param_pydantic_schema
+ )
+ param_copy = param_pydantic_schema.copy()
+ return build_function_declaration_util(
+ vertexai, name, description, func, param_copy
+ )
+
+
+def build_function_declaration_util(
+ vertexai: bool, name, description, func, before_param_pydantic_schema
+) -> types.FunctionDeclaration:
+ _map_pydantic_type_to_schema_type(before_param_pydantic_schema)
+ properties = before_param_pydantic_schema.get('properties', {})
+ function_declaration = types.FunctionDeclaration(
+ parameters=types.Schema(
+ type='OBJECT',
+ properties=properties,
+ )
+ if properties
+ else None,
+ description=description,
+ name=name,
+ )
+ if vertexai and isinstance(func, Callable):
+ return_pydantic_schema = _get_return_type(func)
+ function_declaration.response = types.Schema(
+ type=return_pydantic_schema,
+ )
+ return function_declaration
+
+
+def from_function_with_options(
+ func: Callable,
+ variant: Literal['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT'] = 'GOOGLE_AI',
+) -> 'types.FunctionDeclaration':
+
+ supported_variants = ['GOOGLE_AI', 'VERTEX_AI', 'DEFAULT']
+ if variant not in supported_variants:
+ raise ValueError(
+ f'Unsupported variant: {variant}. Supported variants are:'
+ f' {", ".join(supported_variants)}'
+ )
+
+ parameters_properties = {}
+ for name, param in inspect.signature(func).parameters.items():
+ if param.kind in (
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.KEYWORD_ONLY,
+ inspect.Parameter.POSITIONAL_ONLY,
+ ):
+ schema = function_parameter_parse_util._parse_schema_from_parameter(
+ variant, param, func.__name__
+ )
+ parameters_properties[name] = schema
+ declaration = types.FunctionDeclaration(
+ name=func.__name__,
+ description=func.__doc__,
+ )
+ if parameters_properties:
+ declaration.parameters = types.Schema(
+ type='OBJECT',
+ properties=parameters_properties,
+ )
+ if variant == 'VERTEX_AI':
+ declaration.parameters.required = (
+ function_parameter_parse_util._get_required_fields(
+ declaration.parameters
+ )
+ )
+ if not variant == 'VERTEX_AI':
+ return declaration
+
+ return_annotation = inspect.signature(func).return_annotation
+ if return_annotation is inspect._empty:
+ return declaration
+
+ declaration.response = (
+ function_parameter_parse_util._parse_schema_from_parameter(
+ variant,
+ inspect.Parameter(
+ 'return_value',
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ annotation=return_annotation,
+ ),
+ func.__name__,
+ )
+ )
+ return declaration
diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py
new file mode 100644
index 0000000..460a83f
--- /dev/null
+++ b/src/google/adk/tools/agent_tool.py
@@ -0,0 +1,176 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Any
+from typing import TYPE_CHECKING
+
+from google.genai import types
+from pydantic import model_validator
+from typing_extensions import override
+
+from ..memory.in_memory_memory_service import InMemoryMemoryService
+from ..runners import Runner
+from ..sessions.in_memory_session_service import InMemorySessionService
+from . import _automatic_function_calling_util
+from .base_tool import BaseTool
+from .tool_context import ToolContext
+
+if TYPE_CHECKING:
+ from ..agents.base_agent import BaseAgent
+ from ..agents.llm_agent import LlmAgent
+
+
+class AgentTool(BaseTool):
+ """A tool that wraps an agent.
+
+ This tool allows an agent to be called as a tool within a larger application.
+ The agent's input schema is used to define the tool's input parameters, and
+ the agent's output is returned as the tool's result.
+
+ Attributes:
+ agent: The agent to wrap.
+ skip_summarization: Whether to skip summarization of the agent output.
+ """
+
+ def __init__(self, agent: BaseAgent):
+ self.agent = agent
+ self.skip_summarization: bool = False
+ """Whether to skip summarization of the agent output."""
+
+ super().__init__(name=agent.name, description=agent.description)
+
+ @model_validator(mode='before')
+ @classmethod
+ def populate_name(cls, data: Any) -> Any:
+ data['name'] = data['agent'].name
+ return data
+
+ @override
+ def _get_declaration(self) -> types.FunctionDeclaration:
+ from ..agents.llm_agent import LlmAgent
+
+ if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
+ result = _automatic_function_calling_util.build_function_declaration(
+ func=self.agent.input_schema, variant=self._api_variant
+ )
+ else:
+ result = types.FunctionDeclaration(
+ parameters=types.Schema(
+ type=types.Type.OBJECT,
+ properties={
+ 'request': types.Schema(
+ type=types.Type.STRING,
+ ),
+ },
+ required=['request'],
+ ),
+ description=self.agent.description,
+ name=self.name,
+ )
+ result.name = self.name
+ return result
+
+ @override
+ async def run_async(
+ self,
+ *,
+ args: dict[str, Any],
+ tool_context: ToolContext,
+ ) -> Any:
+ from ..agents.llm_agent import LlmAgent
+
+ if self.skip_summarization:
+ tool_context.actions.skip_summarization = True
+
+ if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
+ input_value = self.agent.input_schema.model_validate(args)
+ else:
+ input_value = args['request']
+
+ if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
+ if isinstance(input_value, dict):
+ input_value = self.agent.input_schema.model_validate(input_value)
+ if not isinstance(input_value, self.agent.input_schema):
+ raise ValueError(
+ f'Input value {input_value} is not of type'
+ f' `{self.agent.input_schema}`.'
+ )
+ content = types.Content(
+ role='user',
+ parts=[
+ types.Part.from_text(
+ text=input_value.model_dump_json(exclude_none=True)
+ )
+ ],
+ )
+ else:
+ content = types.Content(
+ role='user',
+ parts=[types.Part.from_text(text=input_value)],
+ )
+ runner = Runner(
+ app_name=self.agent.name,
+ agent=self.agent,
+ # TODO(kech): Remove the access to the invocation context.
+ # It seems we don't need re-use artifact_service if we forward below.
+ artifact_service=tool_context._invocation_context.artifact_service,
+ session_service=InMemorySessionService(),
+ memory_service=InMemoryMemoryService(),
+ )
+ session = runner.session_service.create_session(
+ app_name=self.agent.name,
+ user_id='tmp_user',
+ state=tool_context.state.to_dict(),
+ )
+
+ last_event = None
+ async for event in runner.run_async(
+ user_id=session.user_id, session_id=session.id, new_message=content
+ ):
+ # Forward state delta to parent session.
+ if event.actions.state_delta:
+ tool_context.state.update(event.actions.state_delta)
+ last_event = event
+
+ if runner.artifact_service:
+ # Forward all artifacts to parent session.
+ for artifact_name in runner.artifact_service.list_artifact_keys(
+ app_name=session.app_name,
+ user_id=session.user_id,
+ session_id=session.id,
+ ):
+ if artifact := runner.artifact_service.load_artifact(
+ app_name=session.app_name,
+ user_id=session.user_id,
+ session_id=session.id,
+ filename=artifact_name,
+ ):
+ tool_context.save_artifact(filename=artifact_name, artifact=artifact)
+
+ if (
+ not last_event
+ or not last_event.content
+ or not last_event.content.parts
+ or not last_event.content.parts[0].text
+ ):
+ return ''
+ if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
+ tool_result = self.agent.output_schema.model_validate_json(
+ last_event.content.parts[0].text
+ ).model_dump(exclude_none=True)
+ else:
+ tool_result = last_event.content.parts[0].text
+ return tool_result
diff --git a/src/google/adk/tools/apihub_tool/__init__.py b/src/google/adk/tools/apihub_tool/__init__.py
new file mode 100644
index 0000000..141e990
--- /dev/null
+++ b/src/google/adk/tools/apihub_tool/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .apihub_toolset import APIHubToolset
+
+__all__ = [
+ 'APIHubToolset',
+]
diff --git a/src/google/adk/tools/apihub_tool/apihub_toolset.py b/src/google/adk/tools/apihub_tool/apihub_toolset.py
new file mode 100644
index 0000000..0cf160e
--- /dev/null
+++ b/src/google/adk/tools/apihub_tool/apihub_toolset.py
@@ -0,0 +1,209 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Dict, List, Optional
+
+import yaml
+
+from ...auth.auth_credential import AuthCredential
+from ...auth.auth_schemes import AuthScheme
+from ..openapi_tool.common.common import to_snake_case
+from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
+from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
+from .clients.apihub_client import APIHubClient
+
+
+class APIHubToolset:
+ """APIHubTool generates tools from a given API Hub resource.
+
+ Examples:
+
+ ```
+ apihub_toolset = APIHubToolset(
+ apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
+ service_account_json="...",
+ )
+
+ # Get all available tools
+ agent = LlmAgent(tools=apihub_toolset.get_tools())
+
+ # Get a specific tool
+ agent = LlmAgent(tools=[
+ ...
+ apihub_toolset.get_tool('my_tool'),
+ ])
+ ```
+
+ **apihub_resource_name** is the resource name from API Hub. It must include
+ API name, and can optionally include API version and spec name.
+ - If apihub_resource_name includes a spec resource name, the content of that
+ spec will be used for generating the tools.
+ - If apihub_resource_name includes only an api or a version name, the
+ first spec of the first version of that API will be used.
+ """
+
+ def __init__(
+ self,
+ *,
+ # Parameters for fetching API Hub resource
+ apihub_resource_name: str,
+ access_token: Optional[str] = None,
+ service_account_json: Optional[str] = None,
+ # Parameters for the toolset itself
+ name: str = '',
+ description: str = '',
+ # Parameters for generating tools
+ lazy_load_spec=False,
+ auth_scheme: Optional[AuthScheme] = None,
+ auth_credential: Optional[AuthCredential] = None,
+ # Optionally, you can provide a custom API Hub client
+ apihub_client: Optional[APIHubClient] = None,
+ ):
+ """Initializes the APIHubTool with the given parameters.
+
+ Examples:
+ ```
+ apihub_toolset = APIHubToolset(
+ apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
+ service_account_json="...",
+ )
+
+ # Get all available tools
+ agent = LlmAgent(tools=apihub_toolset.get_tools())
+
+ # Get a specific tool
+ agent = LlmAgent(tools=[
+ ...
+ apihub_toolset.get_tool('my_tool'),
+ ])
+ ```
+
+ **apihub_resource_name** is the resource name from API Hub. It must include
+ API name, and can optionally include API version and spec name.
+ - If apihub_resource_name includes a spec resource name, the content of that
+ spec will be used for generating the tools.
+ - If apihub_resource_name includes only an api or a version name, the
+ first spec of the first version of that API will be used.
+
+ Example:
+ * projects/xxx/locations/us-central1/apis/apiname/...
+ * https://console.cloud.google.com/apigee/api-hub/apis/apiname?project=xxx
+
+ Args:
+ apihub_resource_name: The resource name of the API in API Hub.
+ Example: `projects/test-project/locations/us-central1/apis/test-api`.
+ access_token: Google Access token. Generate with gcloud cli `gcloud auth
+ auth print-access-token`. Used for fetching API Specs from API Hub.
+ service_account_json: The service account config as a json string.
+ Required if not using default service credential. It is used for
+ creating the API Hub client and fetching the API Specs from API Hub.
+ apihub_client: Optional custom API Hub client.
+ name: Name of the toolset. Optional.
+ description: Description of the toolset. Optional.
+ auth_scheme: Auth scheme that applies to all the tool in the toolset.
+ auth_credential: Auth credential that applies to all the tool in the
+ toolset.
+ lazy_load_spec: If True, the spec will be loaded lazily when needed.
+ Otherwise, the spec will be loaded immediately and the tools will be
+ generated during initialization.
+ """
+ self.name = name
+ self.description = description
+ self.apihub_resource_name = apihub_resource_name
+ self.lazy_load_spec = lazy_load_spec
+ self.apihub_client = apihub_client or APIHubClient(
+ access_token=access_token,
+ service_account_json=service_account_json,
+ )
+
+ self.generated_tools: Dict[str, RestApiTool] = {}
+ self.auth_scheme = auth_scheme
+ self.auth_credential = auth_credential
+
+ if not self.lazy_load_spec:
+ self._prepare_tools()
+
+ def get_tool(self, name: str) -> Optional[RestApiTool]:
+ """Retrieves a specific tool by its name.
+
+ Example:
+ ```
+ apihub_tool = apihub_toolset.get_tool('my_tool')
+ ```
+
+ Args:
+ name: The name of the tool to retrieve.
+
+ Returns:
+ The tool with the given name, or None if no such tool exists.
+ """
+ if not self._are_tools_ready():
+ self._prepare_tools()
+
+ return self.generated_tools[name] if name in self.generated_tools else None
+
+ def get_tools(self) -> List[RestApiTool]:
+ """Retrieves all available tools.
+
+ Returns:
+ A list of all available RestApiTool objects.
+ """
+ if not self._are_tools_ready():
+ self._prepare_tools()
+
+ return list(self.generated_tools.values())
+
+ def _are_tools_ready(self) -> bool:
+ return not self.lazy_load_spec or self.generated_tools
+
+ def _prepare_tools(self) -> str:
+ """Fetches the spec from API Hub and generates the tools.
+
+ Returns:
+ True if the tools are ready, False otherwise.
+ """
+ # For each API, get the first version and the first spec of that version.
+ spec = self.apihub_client.get_spec_content(self.apihub_resource_name)
+ self.generated_tools: Dict[str, RestApiTool] = {}
+
+ tools = self._parse_spec_to_tools(spec)
+ for tool in tools:
+ self.generated_tools[tool.name] = tool
+
+ def _parse_spec_to_tools(self, spec_str: str) -> List[RestApiTool]:
+ """Parses the spec string to a list of RestApiTool.
+
+ Args:
+ spec_str: The spec string to parse.
+
+ Returns:
+ A list of RestApiTool objects.
+ """
+ spec_dict = yaml.safe_load(spec_str)
+ if not spec_dict:
+ return []
+
+ self.name = self.name or to_snake_case(
+ spec_dict.get('info', {}).get('title', 'unnamed')
+ )
+ self.description = self.description or spec_dict.get('info', {}).get(
+ 'description', ''
+ )
+ tools = OpenAPIToolset(
+ spec_dict=spec_dict,
+ auth_credential=self.auth_credential,
+ auth_scheme=self.auth_scheme,
+ ).get_tools()
+ return tools
diff --git a/src/google/adk/tools/apihub_tool/clients/__init__.py b/src/google/adk/tools/apihub_tool/clients/__init__.py
new file mode 100644
index 0000000..0a2669d
--- /dev/null
+++ b/src/google/adk/tools/apihub_tool/clients/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/google/adk/tools/apihub_tool/clients/apihub_client.py b/src/google/adk/tools/apihub_tool/clients/apihub_client.py
new file mode 100644
index 0000000..25cf98b
--- /dev/null
+++ b/src/google/adk/tools/apihub_tool/clients/apihub_client.py
@@ -0,0 +1,332 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+import base64
+import json
+from typing import Any, Dict, List, Optional, Tuple
+from urllib.parse import parse_qs, urlparse
+from google.auth import default as default_service_credential
+from google.auth.transport.requests import Request
+from google.oauth2 import service_account
+import requests
+
+
+class BaseAPIHubClient(ABC):
+ """Base class for API Hub clients."""
+
+ @abstractmethod
+ def get_spec_content(self, resource_name: str) -> str:
+ """From a given resource name, get the soec in the API Hub."""
+ raise NotImplementedError()
+
+
+class APIHubClient(BaseAPIHubClient):
+ """Client for interacting with the API Hub service."""
+
+ def __init__(
+ self,
+ *,
+ access_token: Optional[str] = None,
+ service_account_json: Optional[str] = None,
+ ):
+ """Initializes the APIHubClient.
+
+ You must set either access_token or service_account_json. This
+ credential is used for sending request to API Hub API.
+
+ Args:
+ access_token: Google Access token. Generate with gcloud cli `gcloud auth
+ print-access-token`. Useful for local testing.
+ service_account_json: The service account configuration as a dictionary.
+ Required if not using default service credential.
+ """
+ self.root_url = "https://apihub.googleapis.com/v1"
+ self.credential_cache = None
+ self.access_token, self.service_account = None, None
+
+ if access_token:
+ self.access_token = access_token
+ elif service_account_json:
+ self.service_account = service_account_json
+
+ def get_spec_content(self, path: str) -> str:
+ """From a given path, get the first spec available in the API Hub.
+
+ - If path includes /apis/apiname, get the first spec of that API
+ - If path includes /apis/apiname/versions/versionname, get the first spec
+ of that API Version
+ - If path includes /apis/apiname/versions/versionname/specs/specname, return
+ that spec
+
+ Path can be resource name (projects/xxx/locations/us-central1/apis/apiname),
+ and URL from the UI
+ (https://console.cloud.google.com/apigee/api-hub/apis/apiname?project=xxx)
+
+ Args:
+ path: The path to the API, API Version, or API Spec.
+
+ Returns:
+ The content of the first spec available in the API Hub.
+ """
+ apihub_resource_name, api_version_resource_name, api_spec_resource_name = (
+ self._extract_resource_name(path)
+ )
+
+ if apihub_resource_name and not api_version_resource_name:
+ api = self.get_api(apihub_resource_name)
+ versions = api.get("versions", [])
+ if not versions:
+ raise ValueError(
+ f"No versions found in API Hub resource: {apihub_resource_name}"
+ )
+ api_version_resource_name = versions[0]
+
+ if api_version_resource_name and not api_spec_resource_name:
+ api_version = self.get_api_version(api_version_resource_name)
+ spec_resource_names = api_version.get("specs", [])
+ if not spec_resource_names:
+ raise ValueError(
+ f"No specs found in API Hub version: {api_version_resource_name}"
+ )
+ api_spec_resource_name = spec_resource_names[0]
+
+ if api_spec_resource_name:
+ spec_content = self._fetch_spec(api_spec_resource_name)
+ return spec_content
+
+ raise ValueError("No API Hub resource found in path: {path}")
+
+ def list_apis(self, project: str, location: str) -> List[Dict[str, Any]]:
+ """Lists all APIs in the specified project and location.
+
+ Args:
+ project: The Google Cloud project name.
+ location: The location of the API Hub resources (e.g., 'us-central1').
+
+ Returns:
+ A list of API dictionaries, or an empty list if an error occurs.
+ """
+ url = f"{self.root_url}/projects/{project}/locations/{location}/apis"
+ headers = {
+ "accept": "application/json, text/plain, */*",
+ "Authorization": f"Bearer {self._get_access_token()}",
+ }
+ response = requests.get(url, headers=headers)
+ response.raise_for_status()
+ apis = response.json().get("apis", [])
+ return apis
+
+ def get_api(self, api_resource_name: str) -> Dict[str, Any]:
+ """Get API detail by API name.
+
+ Args:
+ api_resource_name: Resource name of this API, like
+ projects/xxx/locations/us-central1/apis/apiname
+
+ Returns:
+ An API and details in a dict.
+ """
+ url = f"{self.root_url}/{api_resource_name}"
+ headers = {
+ "accept": "application/json, text/plain, */*",
+ "Authorization": f"Bearer {self._get_access_token()}",
+ }
+ response = requests.get(url, headers=headers)
+ response.raise_for_status()
+ apis = response.json()
+ return apis
+
+ def get_api_version(self, api_version_name: str) -> Dict[str, Any]:
+ """Gets details of a specific API version.
+
+ Args:
+ api_version_name: The resource name of the API version.
+
+ Returns:
+ The API version details as a dictionary, or an empty dictionary if an
+ error occurs.
+ """
+ url = f"{self.root_url}/{api_version_name}"
+ headers = {
+ "accept": "application/json, text/plain, */*",
+ "Authorization": f"Bearer {self._get_access_token()}",
+ }
+ response = requests.get(url, headers=headers)
+ response.raise_for_status()
+ return response.json()
+
+ def _fetch_spec(self, api_spec_resource_name: str) -> str:
+ """Retrieves the content of a specific API specification.
+
+ Args:
+ api_spec_resource_name: The resource name of the API spec.
+
+ Returns:
+ The decoded content of the specification as a string, or an empty string
+ if an error occurs.
+ """
+ url = f"{self.root_url}/{api_spec_resource_name}:contents"
+ headers = {
+ "accept": "application/json, text/plain, */*",
+ "Authorization": f"Bearer {self._get_access_token()}",
+ }
+ response = requests.get(url, headers=headers)
+ response.raise_for_status()
+ content_base64 = response.json().get("contents", "")
+ if content_base64:
+ content_decoded = base64.b64decode(content_base64).decode("utf-8")
+ return content_decoded
+ else:
+ return ""
+
+ def _extract_resource_name(self, url_or_path: str) -> Tuple[str, str, str]:
+ """Extracts the resource names of an API, API Version, and API Spec from a given URL or path.
+
+ Args:
+ url_or_path: The URL (UI or resource) or path string.
+
+ Returns:
+ A dictionary containing the resource names:
+ {
+ "api_resource_name": "projects/*/locations/*/apis/*",
+ "api_version_resource_name":
+ "projects/*/locations/*/apis/*/versions/*",
+ "api_spec_resource_name":
+ "projects/*/locations/*/apis/*/versions/*/specs/*"
+ }
+ or raises ValueError if extraction fails.
+
+ Raises:
+ ValueError: If the URL or path is invalid or if required components
+ (project, location, api) are missing.
+ """
+
+ query_params = None
+ try:
+ parsed_url = urlparse(url_or_path)
+ path = parsed_url.path
+ query_params = parse_qs(parsed_url.query)
+
+ # This is a path from UI. Remove unnecessary prefix.
+ if "api-hub/" in path:
+ path = path.split("api-hub")[1]
+ except Exception:
+ path = url_or_path
+
+ path_segments = [segment for segment in path.split("/") if segment]
+
+ project = None
+ location = None
+ api_id = None
+ version_id = None
+ spec_id = None
+
+ if "projects" in path_segments:
+ project_index = path_segments.index("projects")
+ if project_index + 1 < len(path_segments):
+ project = path_segments[project_index + 1]
+ elif query_params and "project" in query_params:
+ project = query_params["project"][0]
+
+ if not project:
+ raise ValueError(
+ "Project ID not found in URL or path in APIHubClient. Input path is"
+ f" '{url_or_path}'. Please make sure there is either"
+ " '/projects/PROJECT_ID' in the path or 'project=PROJECT_ID' query"
+ " param in the input."
+ )
+
+ if "locations" in path_segments:
+ location_index = path_segments.index("locations")
+ if location_index + 1 < len(path_segments):
+ location = path_segments[location_index + 1]
+ if not location:
+ raise ValueError(
+ "Location not found in URL or path in APIHubClient. Input path is"
+ f" '{url_or_path}'. Please make sure there is either"
+ " '/location/LOCATION_ID' in the path."
+ )
+
+ if "apis" in path_segments:
+ api_index = path_segments.index("apis")
+ if api_index + 1 < len(path_segments):
+ api_id = path_segments[api_index + 1]
+ if not api_id:
+ raise ValueError(
+ "API id not found in URL or path in APIHubClient. Input path is"
+ f" '{url_or_path}'. Please make sure there is either"
+ " '/apis/API_ID' in the path."
+ )
+ if "versions" in path_segments:
+ version_index = path_segments.index("versions")
+ if version_index + 1 < len(path_segments):
+ version_id = path_segments[version_index + 1]
+
+ if "specs" in path_segments:
+ spec_index = path_segments.index("specs")
+ if spec_index + 1 < len(path_segments):
+ spec_id = path_segments[spec_index + 1]
+
+ api_resource_name = f"projects/{project}/locations/{location}/apis/{api_id}"
+ api_version_resource_name = (
+ f"{api_resource_name}/versions/{version_id}" if version_id else None
+ )
+ api_spec_resource_name = (
+ f"{api_version_resource_name}/specs/{spec_id}"
+ if version_id and spec_id
+ else None
+ )
+
+ return (
+ api_resource_name,
+ api_version_resource_name,
+ api_spec_resource_name,
+ )
+
+ def _get_access_token(self) -> str:
+ """Gets the access token for the service account.
+
+ Returns:
+ The access token.
+ """
+ if self.access_token:
+ return self.access_token
+
+ if self.credential_cache and not self.credential_cache.expired:
+ return self.credential_cache.token
+
+ if self.service_account:
+ try:
+ credentials = service_account.Credentials.from_service_account_info(
+ json.loads(self.service_account),
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid service account JSON: {e}") from e
+ else:
+ try:
+ credentials, _ = default_service_credential()
+ except:
+ credentials = None
+
+ if not credentials:
+ raise ValueError(
+ "Please provide a service account or an access token to API Hub"
+ " client."
+ )
+
+ credentials.refresh(Request())
+ self.credential_cache = credentials
+ return credentials.token
diff --git a/src/google/adk/tools/apihub_tool/clients/secret_client.py b/src/google/adk/tools/apihub_tool/clients/secret_client.py
new file mode 100644
index 0000000..2813861
--- /dev/null
+++ b/src/google/adk/tools/apihub_tool/clients/secret_client.py
@@ -0,0 +1,115 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Optional
+import google.auth
+from google.auth import default as default_service_credential
+import google.auth.transport.requests
+from google.cloud import secretmanager
+from google.oauth2 import service_account
+
+
+class SecretManagerClient:
+ """A client for interacting with Google Cloud Secret Manager.
+
+ This class provides a simplified interface for retrieving secrets from
+ Secret Manager, handling authentication using either a service account
+ JSON keyfile (passed as a string) or a pre-existing authorization token.
+
+ Attributes:
+ _credentials: Google Cloud credentials object (ServiceAccountCredentials
+ or Credentials).
+ _client: Secret Manager client instance.
+ """
+
+ def __init__(
+ self,
+ service_account_json: Optional[str] = None,
+ auth_token: Optional[str] = None,
+ ):
+ """Initializes the SecretManagerClient.
+
+ Args:
+ service_account_json: The content of a service account JSON keyfile (as
+ a string), not the file path. Must be valid JSON.
+ auth_token: An existing Google Cloud authorization token.
+
+ Raises:
+ ValueError: If neither `service_account_json` nor `auth_token` is
+ provided,
+ or if both are provided. Also raised if the service_account_json
+ is not valid JSON.
+ google.auth.exceptions.GoogleAuthError: If authentication fails.
+ """
+ if service_account_json:
+ try:
+ credentials = service_account.Credentials.from_service_account_info(
+ json.loads(service_account_json)
+ )
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Invalid service account JSON: {e}") from e
+ elif auth_token:
+ credentials = google.auth.credentials.Credentials(
+ token=auth_token,
+ refresh_token=None,
+ token_uri=None,
+ client_id=None,
+ client_secret=None,
+ )
+ request = google.auth.transport.requests.Request()
+ credentials.refresh(request)
+ else:
+ try:
+ credentials, _ = default_service_credential()
+ except Exception as e:
+ raise ValueError(
+ "'service_account_json' or 'auth_token' are both missing, and"
+ f" error occurred while trying to use default credentials: {e}"
+ ) from e
+
+ if not credentials:
+ raise ValueError(
+ "Must provide either 'service_account_json' or 'auth_token', not both"
+ " or neither."
+ )
+
+ self._credentials = credentials
+ self._client = secretmanager.SecretManagerServiceClient(
+ credentials=self._credentials
+ )
+
+ def get_secret(self, resource_name: str) -> str:
+ """Retrieves a secret from Google Cloud Secret Manager.
+
+ Args:
+ resource_name: The full resource name of the secret, in the format
+ "projects/*/secrets/*/versions/*". Usually you want the "latest"
+ version, e.g.,
+ "projects/my-project/secrets/my-secret/versions/latest".
+
+ Returns:
+ The secret payload as a string.
+
+ Raises:
+ google.api_core.exceptions.GoogleAPIError: If the Secret Manager API
+ returns an error (e.g., secret not found, permission denied).
+ Exception: For other unexpected errors.
+ """
+ try:
+ response = self._client.access_secret_version(name=resource_name)
+ return response.payload.data.decode("UTF-8")
+ except Exception as e:
+ raise e # Re-raise the exception to allow for handling by the caller
+ # Consider logging the exception here before re-raising.
diff --git a/src/google/adk/tools/application_integration_tool/__init__.py b/src/google/adk/tools/application_integration_tool/__init__.py
new file mode 100644
index 0000000..fd9eb51
--- /dev/null
+++ b/src/google/adk/tools/application_integration_tool/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .application_integration_toolset import ApplicationIntegrationToolset
+
+__all__ = [
+ 'ApplicationIntegrationToolset',
+]
diff --git a/src/google/adk/tools/application_integration_tool/application_integration_toolset.py b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py
new file mode 100644
index 0000000..5874bb5
--- /dev/null
+++ b/src/google/adk/tools/application_integration_tool/application_integration_toolset.py
@@ -0,0 +1,230 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from fastapi.openapi.models import HTTPBearer
+from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
+from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
+from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
+from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
+from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
+
+from ...auth.auth_credential import AuthCredential
+from ...auth.auth_credential import AuthCredentialTypes
+from ...auth.auth_credential import ServiceAccount
+from ...auth.auth_credential import ServiceAccountCredential
+
+
+# TODO(cheliu): Apply a common toolset interface
+class ApplicationIntegrationToolset:
+ """ApplicationIntegrationToolset generates tools from a given Application
+
+ Integration or Integration Connector resource.
+ Example Usage:
+ ```
+ # Get all available tools for an integration with api trigger
+ application_integration_toolset = ApplicationIntegrationToolset(
+
+ project="test-project",
+ location="us-central1"
+ integration="test-integration",
+ trigger="api_trigger/test_trigger",
+ service_account_credentials={...},
+ )
+
+ # Get all available tools for a connection using entity operations and
+ # actions
+ # Note: Find the list of supported entity operations and actions for a
+ connection
+ # using integration connector apis:
+ #
+ https://cloud.google.com/integration-connectors/docs/reference/rest/v1/projects.locations.connections.connectionSchemaMetadata
+ application_integration_toolset = ApplicationIntegrationToolset(
+ project="test-project",
+ location="us-central1"
+ connection="test-connection",
+ entity_operations=["EntityId1": ["LIST","CREATE"], "EntityId2": []],
+ #empty list for actions means all operations on the entity are supported
+ actions=["action1"],
+ service_account_credentials={...},
+ )
+
+ # Get all available tools
+ agent = LlmAgent(tools=[
+ ...
+ *application_integration_toolset.get_tools(),
+ ])
+ ```
+ """
+
+ def __init__(
+ self,
+ project: str,
+ location: str,
+ integration: Optional[str] = None,
+ trigger: Optional[str] = None,
+ connection: Optional[str] = None,
+ entity_operations: Optional[str] = None,
+ actions: Optional[str] = None,
+ # Optional parameter for the toolset. This is prepended to the generated
+ # tool/python function name.
+ tool_name: Optional[str] = "",
+ # Optional parameter for the toolset. This is appended to the generated
+ # tool/python function description.
+ tool_instructions: Optional[str] = "",
+ service_account_json: Optional[str] = None,
+ ):
+ """Initializes the ApplicationIntegrationToolset.
+
+ Example Usage:
+ ```
+ # Get all available tools for an integration with api trigger
+ application_integration_toolset = ApplicationIntegrationToolset(
+
+ project="test-project",
+ location="us-central1"
+ integration="test-integration",
+ trigger="api_trigger/test_trigger",
+ service_account_credentials={...},
+ )
+
+ # Get all available tools for a connection using entity operations and
+ # actions
+ # Note: Find the list of supported entity operations and actions for a
+ connection
+ # using integration connector apis:
+ #
+ https://cloud.google.com/integration-connectors/docs/reference/rest/v1/projects.locations.connections.connectionSchemaMetadata
+ application_integration_toolset = ApplicationIntegrationToolset(
+ project="test-project",
+ location="us-central1"
+ connection="test-connection",
+ entity_operations=["EntityId1": ["LIST","CREATE"], "EntityId2": []],
+ #empty list for actions means all operations on the entity are supported
+ actions=["action1"],
+ service_account_credentials={...},
+ )
+
+ # Get all available tools
+ agent = LlmAgent(tools=[
+ ...
+ *application_integration_toolset.get_tools(),
+ ])
+ ```
+
+ Args:
+ project: The GCP project ID.
+ location: The GCP location.
+ integration: The integration name.
+ trigger: The trigger name.
+ connection: The connection name.
+ entity_operations: The entity operations supported by the connection.
+ actions: The actions supported by the connection.
+ tool_name: The name of the tool.
+ tool_instructions: The instructions for the tool.
+ service_account_json: The service account configuration as a dictionary.
+ Required if not using default service credential. Used for fetching
+ the Application Integration or Integration Connector resource.
+
+ Raises:
+ ValueError: If neither integration and trigger nor connection and
+ (entity_operations or actions) is provided.
+ Exception: If there is an error during the initialization of the
+ integration or connection client.
+ """
+ self.project = project
+ self.location = location
+ self.integration = integration
+ self.trigger = trigger
+ self.connection = connection
+ self.entity_operations = entity_operations
+ self.actions = actions
+ self.tool_name = tool_name
+ self.tool_instructions = tool_instructions
+ self.service_account_json = service_account_json
+ self.generated_tools: Dict[str, RestApiTool] = {}
+
+ integration_client = IntegrationClient(
+ project,
+ location,
+ integration,
+ trigger,
+ connection,
+ entity_operations,
+ actions,
+ service_account_json,
+ )
+ if integration and trigger:
+ spec = integration_client.get_openapi_spec_for_integration()
+ elif connection and (entity_operations or actions):
+ connections_client = ConnectionsClient(
+ project, location, connection, service_account_json
+ )
+ connection_details = connections_client.get_connection_details()
+ tool_instructions += (
+ "ALWAYS use serviceName = "
+ + connection_details["serviceName"]
+ + ", host = "
+ + connection_details["host"]
+ + " and the connection name = "
+ + f"projects/{project}/locations/{location}/connections/{connection} when"
+ " using this tool"
+ + ". DONOT ask the user for these values as you already have those."
+ )
+ spec = integration_client.get_openapi_spec_for_connection(
+ tool_name,
+ tool_instructions,
+ )
+ else:
+ raise ValueError(
+ "Either (integration and trigger) or (connection and"
+ " (entity_operations or actions)) should be provided."
+ )
+ self._parse_spec_to_tools(spec)
+
+ def _parse_spec_to_tools(self, spec_dict):
+ """Parses the spec dict to a list of RestApiTool."""
+ if self.service_account_json:
+ sa_credential = ServiceAccountCredential.model_validate_json(
+ self.service_account_json
+ )
+ service_account = ServiceAccount(
+ service_account_credential=sa_credential,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ auth_scheme, auth_credential = service_account_scheme_credential(
+ config=service_account
+ )
+ else:
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
+ service_account=ServiceAccount(
+ use_default_credential=True,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ ),
+ )
+ auth_scheme = HTTPBearer(bearerFormat="JWT")
+ tools = OpenAPIToolset(
+ spec_dict=spec_dict,
+ auth_credential=auth_credential,
+ auth_scheme=auth_scheme,
+ ).get_tools()
+ for tool in tools:
+ self.generated_tools[tool.name] = tool
+
+ def get_tools(self) -> List[RestApiTool]:
+ return list(self.generated_tools.values())
diff --git a/src/google/adk/tools/application_integration_tool/clients/connections_client.py b/src/google/adk/tools/application_integration_tool/clients/connections_client.py
new file mode 100644
index 0000000..06b4acf
--- /dev/null
+++ b/src/google/adk/tools/application_integration_tool/clients/connections_client.py
@@ -0,0 +1,903 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import time
+from typing import Any, Dict, List, Optional, Tuple
+
+import google.auth
+from google.auth import default as default_service_credential
+from google.auth.transport.requests import Request
+from google.oauth2 import service_account
+import requests
+
+
+class ConnectionsClient:
+ """Utility class for interacting with Google Cloud Connectors API."""
+
+ def __init__(
+ self,
+ project: str,
+ location: str,
+ connection: str,
+ service_account_json: Optional[str] = None,
+ ):
+ """Initializes the ConnectionsClient.
+
+ Args:
+ project: The Google Cloud project ID.
+ location: The Google Cloud location (e.g., us-central1).
+ connection: The connection name.
+ service_account_json: The service account configuration as a dictionary.
+ Required if not using default service credential. Used for fetching
+ connection details.
+ """
+ self.project = project
+ self.location = location
+ self.connection = connection
+ self.connector_url = "https://connectors.googleapis.com"
+ self.service_account_json = service_account_json
+ self.credential_cache = None
+
+ def get_connection_details(self) -> Dict[str, Any]:
+ """Retrieves service details (service name and host) for a given connection.
+
+ Also returns if auth override is enabled for the connection.
+
+ Returns:
+ tuple: A tuple containing (service_name, host).
+
+ Raises:
+ PermissionError: If there are credential issues.
+ ValueError: If there's a request error.
+ Exception: For any other unexpected errors.
+ """
+ url = f"{self.connector_url}/v1/projects/{self.project}/locations/{self.location}/connections/{self.connection}?view=BASIC"
+
+ response = self._execute_api_call(url)
+
+ connection_data = response.json()
+ service_name = connection_data.get("serviceDirectory", "")
+ host = connection_data.get("host", "")
+ if host:
+ service_name = connection_data.get("tlsServiceDirectory", "")
+ auth_override_enabled = connection_data.get("authOverrideEnabled", False)
+ return {
+ "serviceName": service_name,
+ "host": host,
+ "authOverrideEnabled": auth_override_enabled,
+ }
+
+ def get_entity_schema_and_operations(
+ self, entity: str
+ ) -> Tuple[Dict[str, Any], List[str]]:
+ """Retrieves the JSON schema for a given entity in a connection.
+
+ Args:
+ entity (str): The entity name.
+
+ Returns:
+ tuple: A tuple containing (schema, operations).
+
+ Raises:
+ PermissionError: If there are credential issues.
+ ValueError: If there's a request or processing error.
+ Exception: For any other unexpected errors.
+ """
+ url = f"{self.connector_url}/v1/projects/{self.project}/locations/{self.location}/connections/{self.connection}/connectionSchemaMetadata:getEntityType?entityId={entity}"
+
+ response = self._execute_api_call(url)
+ operation_id = response.json().get("name")
+
+ if not operation_id:
+ raise ValueError(
+ f"Failed to get entity schema and operations for entity: {entity}"
+ )
+
+ operation_response = self._poll_operation(operation_id)
+
+ schema = operation_response.get("response", {}).get("jsonSchema", {})
+ operations = operation_response.get("response", {}).get("operations", [])
+ return schema, operations
+
+ def get_action_schema(self, action: str) -> Dict[str, Any]:
+ """Retrieves the input and output JSON schema for a given action in a connection.
+
+ Args:
+ action (str): The action name.
+
+ Returns:
+ tuple: A tuple containing (input_schema, output_schema).
+
+ Raises:
+ PermissionError: If there are credential issues.
+ ValueError: If there's a request or processing error.
+ Exception: For any other unexpected errors.
+ """
+ url = f"{self.connector_url}/v1/projects/{self.project}/locations/{self.location}/connections/{self.connection}/connectionSchemaMetadata:getAction?actionId={action}"
+
+ response = self._execute_api_call(url)
+
+ operation_id = response.json().get("name")
+
+ if not operation_id:
+ raise ValueError(f"Failed to get action schema for action: {action}")
+
+ operation_response = self._poll_operation(operation_id)
+
+ input_schema = operation_response.get("response", {}).get(
+ "inputJsonSchema", {}
+ )
+ output_schema = operation_response.get("response", {}).get(
+ "outputJsonSchema", {}
+ )
+ description = operation_response.get("response", {}).get("description", "")
+ display_name = operation_response.get("response", {}).get("displayName", "")
+ return {
+ "inputSchema": input_schema,
+ "outputSchema": output_schema,
+ "description": description,
+ "displayName": display_name,
+ }
+
+ @staticmethod
+ def get_connector_base_spec() -> Dict[str, Any]:
+ return {
+ "openapi": "3.0.1",
+ "info": {
+ "title": "ExecuteConnection",
+ "description": "This tool can execute a query on connection",
+ "version": "4",
+ },
+ "servers": [{"url": "https://integrations.googleapis.com"}],
+ "security": [
+ {"google_auth": ["https://www.googleapis.com/auth/cloud-platform"]}
+ ],
+ "paths": {},
+ "components": {
+ "schemas": {
+ "operation": {
+ "type": "string",
+ "default": "LIST_ENTITIES",
+ "description": (
+ "Operation to execute. Possible values are"
+ " LIST_ENTITIES, GET_ENTITY, CREATE_ENTITY,"
+ " UPDATE_ENTITY, DELETE_ENTITY in case of entities."
+ " EXECUTE_ACTION in case of actions. and EXECUTE_QUERY"
+ " in case of custom queries."
+ ),
+ },
+ "entityId": {
+ "type": "string",
+ "description": "Name of the entity",
+ },
+ "connectorInputPayload": {"type": "object"},
+ "filterClause": {
+ "type": "string",
+ "default": "",
+ "description": "WHERE clause in SQL query",
+ },
+ "pageSize": {
+ "type": "integer",
+ "default": 50,
+ "description": (
+ "Number of entities to return in the response"
+ ),
+ },
+ "pageToken": {
+ "type": "string",
+ "default": "",
+ "description": (
+ "Page token to return the next page of entities"
+ ),
+ },
+ "connectionName": {
+ "type": "string",
+ "default": "",
+ "description": (
+ "Connection resource name to run the query for"
+ ),
+ },
+ "serviceName": {
+ "type": "string",
+ "default": "",
+ "description": "Service directory for the connection",
+ },
+ "host": {
+ "type": "string",
+ "default": "",
+ "description": "Host name incase of tls service directory",
+ },
+ "entity": {
+ "type": "string",
+ "default": "Issues",
+ "description": "Entity to run the query for",
+ },
+ "action": {
+ "type": "string",
+ "default": "ExecuteCustomQuery",
+ "description": "Action to run the query for",
+ },
+ "query": {
+ "type": "string",
+ "default": "",
+ "description": "Custom Query to execute on the connection",
+ },
+ "dynamicAuthConfig": {
+ "type": "object",
+ "default": {},
+ "description": "Dynamic auth config for the connection",
+ },
+ "timeout": {
+ "type": "integer",
+ "default": 120,
+ "description": (
+ "Timeout in seconds for execution of custom query"
+ ),
+ },
+ "connectorOutputPayload": {"type": "object"},
+ "nextPageToken": {"type": "string"},
+ "execute-connector_Response": {
+ "required": ["connectorOutputPayload"],
+ "type": "object",
+ "properties": {
+ "connectorOutputPayload": {
+ "$ref": (
+ "#/components/schemas/connectorOutputPayload"
+ )
+ },
+ "nextPageToken": {
+ "$ref": "#/components/schemas/nextPageToken"
+ },
+ },
+ },
+ },
+ "securitySchemes": {
+ "google_auth": {
+ "type": "oauth2",
+ "flows": {
+ "implicit": {
+ "authorizationUrl": (
+ "https://accounts.google.com/o/oauth2/auth"
+ ),
+ "scopes": {
+ "https://www.googleapis.com/auth/cloud-platform": (
+ "Auth for google cloud services"
+ )
+ },
+ }
+ },
+ }
+ },
+ },
+ }
+
+ @staticmethod
+ def get_action_operation(
+ action: str,
+ operation: str,
+ action_display_name: str,
+ tool_name: str = "",
+ tool_instructions: str = "",
+ ) -> Dict[str, Any]:
+ description = (
+ f"Use this tool with" f' action = "{action}" and'
+ ) + f' operation = "{operation}" only. Dont ask these values from user.'
+ if operation == "EXECUTE_QUERY":
+ description = (
+ (f"Use this tool with" f' action = "{action}" and')
+ + f' operation = "{operation}" only. Dont ask these values from user.'
+ " Use pageSize = 50 and timeout = 120 until user specifies a"
+ " different value otherwise. If user provides a query in natural"
+ " language, convert it to SQL query and then execute it using the"
+ " tool."
+ )
+ return {
+ "post": {
+ "summary": f"{action_display_name}",
+ "description": f"{description} {tool_instructions}",
+ "operationId": f"{tool_name}_{action_display_name}",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": (
+ f"#/components/schemas/{action_display_name}_Request"
+ )
+ }
+ }
+ }
+ },
+ "responses": {
+ "200": {
+ "description": "Success response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": (
+ f"#/components/schemas/{action_display_name}_Response"
+ ),
+ }
+ }
+ },
+ }
+ },
+ }
+ }
+
+ @staticmethod
+ def list_operation(
+ entity: str,
+ schema_as_string: str = "",
+ tool_name: str = "",
+ tool_instructions: str = "",
+ ) -> Dict[str, Any]:
+ return {
+ "post": {
+ "summary": f"List {entity}",
+ "description": (
+ f"Returns all entities of type {entity}. Use this tool with"
+ + f' entity = "{entity}" and'
+ + ' operation = "LIST_ENTITIES" only. Dont ask these values'
+ " from"
+ + ' user. Always use ""'
+ + ' as filter clause and ""'
+ + " as page token and 50 as page size until user specifies a"
+ " different value otherwise. Use single quotes for strings in"
+ f" filter clause. {tool_instructions}"
+ ),
+ "operationId": f"{tool_name}_list_{entity}",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": (
+ f"#/components/schemas/list_{entity}_Request"
+ )
+ }
+ }
+ }
+ },
+ "responses": {
+ "200": {
+ "description": "Success response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "description": (
+ f"Returns a list of {entity} of json"
+ f" schema: {schema_as_string}"
+ ),
+ "$ref": (
+ "#/components/schemas/execute-connector_Response"
+ ),
+ }
+ }
+ },
+ }
+ },
+ }
+ }
+
+ @staticmethod
+ def get_operation(
+ entity: str,
+ schema_as_string: str = "",
+ tool_name: str = "",
+ tool_instructions: str = "",
+ ) -> Dict[str, Any]:
+ return {
+ "post": {
+ "summary": f"Get {entity}",
+ "description": (
+ (
+ f"Returns the details of the {entity}. Use this tool with"
+ f' entity = "{entity}" and'
+ )
+ + ' operation = "GET_ENTITY" only. Dont ask these values from'
+ f" user. {tool_instructions}"
+ ),
+ "operationId": f"{tool_name}_get_{entity}",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": f"#/components/schemas/get_{entity}_Request"
+ }
+ }
+ }
+ },
+ "responses": {
+ "200": {
+ "description": "Success response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "description": (
+ f"Returns {entity} of json schema:"
+ f" {schema_as_string}"
+ ),
+ "$ref": (
+ "#/components/schemas/execute-connector_Response"
+ ),
+ }
+ }
+ },
+ }
+ },
+ }
+ }
+
+ @staticmethod
+ def create_operation(
+ entity: str, tool_name: str = "", tool_instructions: str = ""
+ ) -> Dict[str, Any]:
+ return {
+ "post": {
+ "summary": f"Create {entity}",
+ "description": (
+ (
+ f"Creates a new entity of type {entity}. Use this tool with"
+ f' entity = "{entity}" and'
+ )
+ + ' operation = "CREATE_ENTITY" only. Dont ask these values'
+ " from"
+ + " user. Follow the schema of the entity provided in the"
+ f" instructions to create {entity}. {tool_instructions}"
+ ),
+ "operationId": f"{tool_name}_create_{entity}",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": (
+ f"#/components/schemas/create_{entity}_Request"
+ )
+ }
+ }
+ }
+ },
+ "responses": {
+ "200": {
+ "description": "Success response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": (
+ "#/components/schemas/execute-connector_Response"
+ )
+ }
+ }
+ },
+ }
+ },
+ }
+ }
+
+ @staticmethod
+ def update_operation(
+ entity: str, tool_name: str = "", tool_instructions: str = ""
+ ) -> Dict[str, Any]:
+ return {
+ "post": {
+ "summary": f"Update {entity}",
+ "description": (
+ (
+ f"Updates an entity of type {entity}. Use this tool with"
+ f' entity = "{entity}" and'
+ )
+ + ' operation = "UPDATE_ENTITY" only. Dont ask these values'
+ " from"
+ + " user. Use entityId to uniquely identify the entity to"
+ " update. Follow the schema of the entity provided in the"
+ f" instructions to update {entity}. {tool_instructions}"
+ ),
+ "operationId": f"{tool_name}_update_{entity}",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": (
+ f"#/components/schemas/update_{entity}_Request"
+ )
+ }
+ }
+ }
+ },
+ "responses": {
+ "200": {
+ "description": "Success response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": (
+ "#/components/schemas/execute-connector_Response"
+ )
+ }
+ }
+ },
+ }
+ },
+ }
+ }
+
+ @staticmethod
+ def delete_operation(
+ entity: str, tool_name: str = "", tool_instructions: str = ""
+ ) -> Dict[str, Any]:
+ return {
+ "post": {
+ "summary": f"Delete {entity}",
+ "description": (
+ (
+ f"Deletes an entity of type {entity}. Use this tool with"
+ f' entity = "{entity}" and'
+ )
+ + ' operation = "DELETE_ENTITY" only. Dont ask these values'
+ " from"
+ f" user. {tool_instructions}"
+ ),
+ "operationId": f"{tool_name}_delete_{entity}",
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": (
+ f"#/components/schemas/delete_{entity}_Request"
+ )
+ }
+ }
+ }
+ },
+ "responses": {
+ "200": {
+ "description": "Success response",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": (
+ "#/components/schemas/execute-connector_Response"
+ )
+ }
+ }
+ },
+ }
+ },
+ }
+ }
+
+ @staticmethod
+ def create_operation_request(entity: str) -> Dict[str, Any]:
+ return {
+ "type": "object",
+ "required": [
+ "connectorInputPayload",
+ "operation",
+ "connectionName",
+ "serviceName",
+ "host",
+ "entity",
+ ],
+ "properties": {
+ "connectorInputPayload": {
+ "$ref": f"#/components/schemas/connectorInputPayload_{entity}"
+ },
+ "operation": {"$ref": "#/components/schemas/operation"},
+ "connectionName": {"$ref": "#/components/schemas/connectionName"},
+ "serviceName": {"$ref": "#/components/schemas/serviceName"},
+ "host": {"$ref": "#/components/schemas/host"},
+ "entity": {"$ref": "#/components/schemas/entity"},
+ },
+ }
+
+ @staticmethod
+ def update_operation_request(entity: str) -> Dict[str, Any]:
+ return {
+ "type": "object",
+ "required": [
+ "connectorInputPayload",
+ "entityId",
+ "operation",
+ "connectionName",
+ "serviceName",
+ "host",
+ "entity",
+ ],
+ "properties": {
+ "connectorInputPayload": {
+ "$ref": f"#/components/schemas/connectorInputPayload_{entity}"
+ },
+ "entityId": {"$ref": "#/components/schemas/entityId"},
+ "operation": {"$ref": "#/components/schemas/operation"},
+ "connectionName": {"$ref": "#/components/schemas/connectionName"},
+ "serviceName": {"$ref": "#/components/schemas/serviceName"},
+ "host": {"$ref": "#/components/schemas/host"},
+ "entity": {"$ref": "#/components/schemas/entity"},
+ },
+ }
+
+ @staticmethod
+ def get_operation_request() -> Dict[str, Any]:
+ return {
+ "type": "object",
+ "required": [
+ "entityId",
+ "operation",
+ "connectionName",
+ "serviceName",
+ "host",
+ "entity",
+ ],
+ "properties": {
+ "entityId": {"$ref": "#/components/schemas/entityId"},
+ "operation": {"$ref": "#/components/schemas/operation"},
+ "connectionName": {"$ref": "#/components/schemas/connectionName"},
+ "serviceName": {"$ref": "#/components/schemas/serviceName"},
+ "host": {"$ref": "#/components/schemas/host"},
+ "entity": {"$ref": "#/components/schemas/entity"},
+ },
+ }
+
+ @staticmethod
+ def delete_operation_request() -> Dict[str, Any]:
+ return {
+ "type": "object",
+ "required": [
+ "entityId",
+ "operation",
+ "connectionName",
+ "serviceName",
+ "host",
+ "entity",
+ ],
+ "properties": {
+ "entityId": {"$ref": "#/components/schemas/entityId"},
+ "operation": {"$ref": "#/components/schemas/operation"},
+ "connectionName": {"$ref": "#/components/schemas/connectionName"},
+ "serviceName": {"$ref": "#/components/schemas/serviceName"},
+ "host": {"$ref": "#/components/schemas/host"},
+ "entity": {"$ref": "#/components/schemas/entity"},
+ },
+ }
+
+ @staticmethod
+ def list_operation_request() -> Dict[str, Any]:
+ return {
+ "type": "object",
+ "required": [
+ "operation",
+ "connectionName",
+ "serviceName",
+ "host",
+ "entity",
+ ],
+ "properties": {
+ "filterClause": {"$ref": "#/components/schemas/filterClause"},
+ "pageSize": {"$ref": "#/components/schemas/pageSize"},
+ "pageToken": {"$ref": "#/components/schemas/pageToken"},
+ "operation": {"$ref": "#/components/schemas/operation"},
+ "connectionName": {"$ref": "#/components/schemas/connectionName"},
+ "serviceName": {"$ref": "#/components/schemas/serviceName"},
+ "host": {"$ref": "#/components/schemas/host"},
+ "entity": {"$ref": "#/components/schemas/entity"},
+ },
+ }
+
+ @staticmethod
+ def action_request(action: str) -> Dict[str, Any]:
+ return {
+ "type": "object",
+ "required": [
+ "operation",
+ "connectionName",
+ "serviceName",
+ "host",
+ "action",
+ "connectorInputPayload",
+ ],
+ "properties": {
+ "operation": {"$ref": "#/components/schemas/operation"},
+ "connectionName": {"$ref": "#/components/schemas/connectionName"},
+ "serviceName": {"$ref": "#/components/schemas/serviceName"},
+ "host": {"$ref": "#/components/schemas/host"},
+ "action": {"$ref": "#/components/schemas/action"},
+ "connectorInputPayload": {
+ "$ref": f"#/components/schemas/connectorInputPayload_{action}"
+ },
+ },
+ }
+
+ @staticmethod
+ def action_response(action: str) -> Dict[str, Any]:
+ return {
+ "type": "object",
+ "properties": {
+ "connectorOutputPayload": {
+ "$ref": f"#/components/schemas/connectorOutputPayload_{action}"
+ },
+ },
+ }
+
+ @staticmethod
+ def execute_custom_query_request() -> Dict[str, Any]:
+ return {
+ "type": "object",
+ "required": [
+ "operation",
+ "connectionName",
+ "serviceName",
+ "host",
+ "action",
+ "query",
+ "timeout",
+ "pageSize",
+ ],
+ "properties": {
+ "operation": {"$ref": "#/components/schemas/operation"},
+ "connectionName": {"$ref": "#/components/schemas/connectionName"},
+ "serviceName": {"$ref": "#/components/schemas/serviceName"},
+ "host": {"$ref": "#/components/schemas/host"},
+ "action": {"$ref": "#/components/schemas/action"},
+ "query": {"$ref": "#/components/schemas/query"},
+ "timeout": {"$ref": "#/components/schemas/timeout"},
+ "pageSize": {"$ref": "#/components/schemas/pageSize"},
+ },
+ }
+
+ def connector_payload(self, json_schema: Dict[str, Any]) -> Dict[str, Any]:
+ return self._convert_json_schema_to_openapi_schema(json_schema)
+
+ def _convert_json_schema_to_openapi_schema(self, json_schema):
+ """Converts a JSON schema dictionary to an OpenAPI schema dictionary, handling variable types, properties, items, nullable, and description.
+
+ Args:
+ json_schema (dict): The input JSON schema dictionary.
+
+ Returns:
+ dict: The converted OpenAPI schema dictionary.
+ """
+ openapi_schema = {}
+
+ if "description" in json_schema:
+ openapi_schema["description"] = json_schema["description"]
+
+ if "type" in json_schema:
+ if isinstance(json_schema["type"], list):
+ if "null" in json_schema["type"]:
+ openapi_schema["nullable"] = True
+ other_types = [t for t in json_schema["type"] if t != "null"]
+ if other_types:
+ openapi_schema["type"] = other_types[0]
+ else:
+ openapi_schema["type"] = json_schema["type"][0]
+ else:
+ openapi_schema["type"] = json_schema["type"]
+
+ if openapi_schema.get("type") == "object" and "properties" in json_schema:
+ openapi_schema["properties"] = {}
+ for prop_name, prop_schema in json_schema["properties"].items():
+ openapi_schema["properties"][prop_name] = (
+ self._convert_json_schema_to_openapi_schema(prop_schema)
+ )
+
+ elif openapi_schema.get("type") == "array" and "items" in json_schema:
+ if isinstance(json_schema["items"], list):
+ openapi_schema["items"] = [
+ self._convert_json_schema_to_openapi_schema(item)
+ for item in json_schema["items"]
+ ]
+ else:
+ openapi_schema["items"] = self._convert_json_schema_to_openapi_schema(
+ json_schema["items"]
+ )
+
+ return openapi_schema
+
+ def _get_access_token(self) -> str:
+ """Gets the access token for the service account.
+
+ Returns:
+ The access token.
+ """
+ if self.credential_cache and not self.credential_cache.expired:
+ return self.credential_cache.token
+
+ if self.service_account_json:
+ credentials = service_account.Credentials.from_service_account_info(
+ json.loads(self.service_account_json),
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ else:
+ try:
+ credentials, _ = default_service_credential()
+ except:
+ credentials = None
+
+ if not credentials:
+ raise ValueError(
+ "Please provide a service account that has the required permissions"
+ " to access the connection."
+ )
+
+ credentials.refresh(Request())
+ self.credential_cache = credentials
+ return credentials.token
+
+ def _execute_api_call(self, url):
+ """Executes an API call to the given URL.
+
+ Args:
+ url (str): The URL to call.
+
+ Returns:
+ requests.Response: The response object from the API call.
+
+ Raises:
+ PermissionError: If there are credential issues.
+ ValueError: If there's a request error.
+ Exception: For any other unexpected errors.
+ """
+ try:
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {self._get_access_token()}",
+ }
+
+ response = requests.get(url, headers=headers)
+ response.raise_for_status()
+ return response
+
+ except google.auth.exceptions.DefaultCredentialsError as e:
+ raise PermissionError(f"Credentials error: {e}") from e
+
+ except requests.exceptions.RequestException as e:
+ if (
+ "404" in str(e)
+ or "Not found" in str(e)
+ or "400" in str(e)
+ or "Bad request" in str(e)
+ ):
+ raise ValueError(
+ "Invalid request. Please check the provided"
+ f" values of project({self.project}), location({self.location}),"
+ f" connection({self.connection})."
+ ) from e
+ raise ValueError(f"Request error: {e}") from e
+
+ except Exception as e:
+ raise Exception(f"An unexpected error occurred: {e}") from e
+
+ def _poll_operation(self, operation_id: str) -> Dict[str, Any]:
+ """Polls an operation until it is done.
+
+ Args:
+ operation_id: The ID of the operation to poll.
+
+ Returns:
+ The final response of the operation.
+
+ Raises:
+ PermissionError: If there are credential issues.
+ ValueError: If there's a request error.
+ Exception: For any other unexpected errors.
+ """
+ operation_done: bool = False
+ operation_response: Dict[str, Any] = {}
+ while not operation_done:
+ get_operation_url = f"{self.connector_url}/v1/{operation_id}"
+ response = self._execute_api_call(get_operation_url)
+ operation_response = response.json()
+ operation_done = operation_response.get("done", False)
+ time.sleep(1)
+ return operation_response
diff --git a/src/google/adk/tools/application_integration_tool/clients/integration_client.py b/src/google/adk/tools/application_integration_tool/clients/integration_client.py
new file mode 100644
index 0000000..e5be60f
--- /dev/null
+++ b/src/google/adk/tools/application_integration_tool/clients/integration_client.py
@@ -0,0 +1,253 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Optional
+from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
+import google.auth
+from google.auth import default as default_service_credential
+import google.auth.transport.requests
+from google.auth.transport.requests import Request
+from google.oauth2 import service_account
+import requests
+
+
+class IntegrationClient:
+ """A client for interacting with Google Cloud Application Integration.
+
+ This class provides methods for retrieving OpenAPI spec for an integration or
+ a connection.
+ """
+
+ def __init__(
+ self,
+ project: str,
+ location: str,
+ integration: Optional[str] = None,
+ trigger: Optional[str] = None,
+ connection: Optional[str] = None,
+ entity_operations: Optional[dict[str, list[str]]] = None,
+ actions: Optional[list[str]] = None,
+ service_account_json: Optional[str] = None,
+ ):
+ """Initializes the ApplicationIntegrationClient.
+
+ Args:
+ project: The Google Cloud project ID.
+ location: The Google Cloud location (e.g., us-central1).
+ integration: The integration name.
+ trigger: The trigger ID for the integration.
+ connection: The connection name.
+ entity_operations: A dictionary mapping entity names to a list of
+ operations (e.g., LIST, CREATE, UPDATE, DELETE, GET).
+ actions: List of actions.
+ service_account_json: The service account configuration as a dictionary.
+ Required if not using default service credential. Used for fetching
+ connection details.
+ """
+ self.project = project
+ self.location = location
+ self.integration = integration
+ self.trigger = trigger
+ self.connection = connection
+ self.entity_operations = (
+ entity_operations if entity_operations is not None else {}
+ )
+ self.actions = actions if actions is not None else []
+ self.service_account_json = service_account_json
+ self.credential_cache = None
+
+ def get_openapi_spec_for_integration(self):
+ """Gets the OpenAPI spec for the integration.
+
+ Returns:
+ dict: The OpenAPI spec as a dictionary.
+ Raises:
+ PermissionError: If there are credential issues.
+ ValueError: If there's a request error or processing error.
+ Exception: For any other unexpected errors.
+ """
+ try:
+ url = f"https://{self.location}-integrations.googleapis.com/v1/projects/{self.project}/locations/{self.location}:generateOpenApiSpec"
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {self._get_access_token()}",
+ }
+ data = {
+ "apiTriggerResources": [
+ {
+ "integrationResource": self.integration,
+ "triggerId": [self.trigger],
+ },
+ ],
+ "fileFormat": "JSON",
+ }
+ response = requests.post(url, headers=headers, json=data)
+ response.raise_for_status()
+ spec = response.json().get("openApiSpec", {})
+ return json.loads(spec)
+ except google.auth.exceptions.DefaultCredentialsError as e:
+ raise PermissionError(f"Credentials error: {e}") from e
+ except requests.exceptions.RequestException as e:
+ if (
+ "404" in str(e)
+ or "Not found" in str(e)
+ or "400" in str(e)
+ or "Bad request" in str(e)
+ ):
+ raise ValueError(
+ "Invalid request. Please check the provided values of"
+ f" project({self.project}), location({self.location}),"
+ f" integration({self.integration}) and trigger({self.trigger})."
+ ) from e
+ raise ValueError(f"Request error: {e}") from e
+ except Exception as e:
+ raise Exception(f"An unexpected error occurred: {e}") from e
+
+ def get_openapi_spec_for_connection(self, tool_name="", tool_instructions=""):
+ """Gets the OpenAPI spec for the connection.
+
+ Returns:
+ dict: The OpenAPI spec as a dictionary.
+ Raises:
+ ValueError: If there's an error retrieving the OpenAPI spec.
+ PermissionError: If there are credential issues.
+ Exception: For any other unexpected errors.
+ """
+ # Application Integration needs to be provisioned in the same region as connection and an integration with name "ExecuteConnection" and trigger "api_trigger/ExecuteConnection" should be created as per the documentation.
+ integration_name = "ExecuteConnection"
+ connections_client = ConnectionsClient(
+ self.project,
+ self.location,
+ self.connection,
+ self.service_account_json,
+ )
+ if not self.entity_operations and not self.actions:
+ raise ValueError(
+ "No entity operations or actions provided. Please provide at least"
+ " one of them."
+ )
+ connector_spec = connections_client.get_connector_base_spec()
+ for entity, operations in self.entity_operations.items():
+ schema, supported_operations = (
+ connections_client.get_entity_schema_and_operations(entity)
+ )
+ if not operations:
+ operations = supported_operations
+ json_schema_as_string = json.dumps(schema)
+ entity_lower = entity
+ connector_spec["components"]["schemas"][
+ f"connectorInputPayload_{entity_lower}"
+ ] = connections_client.connector_payload(schema)
+ for operation in operations:
+ operation_lower = operation.lower()
+ path = f"/v2/projects/{self.project}/locations/{self.location}/integrations/{integration_name}:execute?triggerId=api_trigger/{integration_name}#{operation_lower}_{entity_lower}"
+ if operation_lower == "create":
+ connector_spec["paths"][path] = connections_client.create_operation(
+ entity_lower, tool_name, tool_instructions
+ )
+ connector_spec["components"]["schemas"][
+ f"create_{entity_lower}_Request"
+ ] = connections_client.create_operation_request(entity_lower)
+ elif operation_lower == "update":
+ connector_spec["paths"][path] = connections_client.update_operation(
+ entity_lower, tool_name, tool_instructions
+ )
+ connector_spec["components"]["schemas"][
+ f"update_{entity_lower}_Request"
+ ] = connections_client.update_operation_request(entity_lower)
+ elif operation_lower == "delete":
+ connector_spec["paths"][path] = connections_client.delete_operation(
+ entity_lower, tool_name, tool_instructions
+ )
+ connector_spec["components"]["schemas"][
+ f"delete_{entity_lower}_Request"
+ ] = connections_client.delete_operation_request()
+ elif operation_lower == "list":
+ connector_spec["paths"][path] = connections_client.list_operation(
+ entity_lower, json_schema_as_string, tool_name, tool_instructions
+ )
+ connector_spec["components"]["schemas"][
+ f"list_{entity_lower}_Request"
+ ] = connections_client.list_operation_request()
+ elif operation_lower == "get":
+ connector_spec["paths"][path] = connections_client.get_operation(
+ entity_lower, json_schema_as_string, tool_name, tool_instructions
+ )
+ connector_spec["components"]["schemas"][
+ f"get_{entity_lower}_Request"
+ ] = connections_client.get_operation_request()
+ else:
+ raise ValueError(
+ f"Invalid operation: {operation} for entity: {entity}"
+ )
+ for action in self.actions:
+ action_details = connections_client.get_action_schema(action)
+ input_schema = action_details["inputSchema"]
+ output_schema = action_details["outputSchema"]
+ action_display_name = action_details["displayName"]
+ operation = "EXECUTE_ACTION"
+ if action == "ExecuteCustomQuery":
+ connector_spec["components"]["schemas"][
+ f"{action}_Request"
+ ] = connections_client.execute_custom_query_request()
+ operation = "EXECUTE_QUERY"
+ else:
+ connector_spec["components"]["schemas"][
+ f"{action_display_name}_Request"
+ ] = connections_client.action_request(action_display_name)
+ connector_spec["components"]["schemas"][
+ f"connectorInputPayload_{action_display_name}"
+ ] = connections_client.connector_payload(input_schema)
+ connector_spec["components"]["schemas"][
+ f"connectorOutputPayload_{action_display_name}"
+ ] = connections_client.connector_payload(output_schema)
+ connector_spec["components"]["schemas"][
+ f"{action_display_name}_Response"
+ ] = connections_client.action_response(action_display_name)
+ path = f"/v2/projects/{self.project}/locations/{self.location}/integrations/{integration_name}:execute?triggerId=api_trigger/{integration_name}#{action}"
+ connector_spec["paths"][path] = connections_client.get_action_operation(
+ action, operation, action_display_name, tool_name, tool_instructions
+ )
+ return connector_spec
+
+ def _get_access_token(self) -> str:
+ """Gets the access token for the service account or using default credentials.
+
+ Returns:
+ The access token.
+ """
+ if self.credential_cache and not self.credential_cache.expired:
+ return self.credential_cache.token
+
+ if self.service_account_json:
+ credentials = service_account.Credentials.from_service_account_info(
+ json.loads(self.service_account_json),
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ else:
+ try:
+ credentials, _ = default_service_credential()
+ except:
+ credentials = None
+
+ if not credentials:
+ raise ValueError(
+ "Please provide a service account that has the required permissions"
+ " to access the connection."
+ )
+
+ credentials.refresh(Request())
+ self.credential_cache = credentials
+ return credentials.token
diff --git a/src/google/adk/tools/base_tool.py b/src/google/adk/tools/base_tool.py
new file mode 100644
index 0000000..88e9b1e
--- /dev/null
+++ b/src/google/adk/tools/base_tool.py
@@ -0,0 +1,144 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from abc import ABC
+import os
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+
+from deprecated import deprecated
+from google.genai import types
+
+from .tool_context import ToolContext
+
+if TYPE_CHECKING:
+ from ..models.llm_request import LlmRequest
+
+
+class BaseTool(ABC):
+ """The base class for all tools."""
+
+ name: str
+ """The name of the tool."""
+ description: str
+ """The description of the tool."""
+
+ is_long_running: bool = False
+ """Whether the tool is a long running operation, which typically returns a
+ resource id first and finishes the operation later."""
+
+ def __init__(self, *, name, description, is_long_running: bool = False):
+ self.name = name
+ self.description = description
+ self.is_long_running = is_long_running
+
+ def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
+ """Gets the OpenAPI specification of this tool in the form of a FunctionDeclaration.
+
+ NOTE
+ - Required if subclass uses the default implementation of
+ `process_llm_request` to add function declaration to LLM request.
+ - Otherwise, can be skipped, e.g. for a built-in GoogleSearch tool for
+ Gemini.
+
+ Returns:
+ The FunctionDeclaration of this tool, or None if it doesn't need to be
+ added to LlmRequest.config.
+ """
+ return None
+
+ async def run_async(
+ self, *, args: dict[str, Any], tool_context: ToolContext
+ ) -> Any:
+ """Runs the tool with the given arguments and context.
+
+ NOTE
+ - Required if this tool needs to run at the client side.
+ - Otherwise, can be skipped, e.g. for a built-in GoogleSearch tool for
+ Gemini.
+
+ Args:
+ args: The LLM-filled arguments.
+ ctx: The context of the tool.
+
+ Returns:
+ The result of running the tool.
+ """
+ raise NotImplementedError(f'{type(self)} is not implemented')
+
+ async def process_llm_request(
+ self, *, tool_context: ToolContext, llm_request: LlmRequest
+ ) -> None:
+ """Processes the outgoing LLM request for this tool.
+
+ Use cases:
+ - Most common use case is adding this tool to the LLM request.
+ - Some tools may just preprocess the LLM request before it's sent out.
+
+ Args:
+ tool_context: The context of the tool.
+ llm_request: The outgoing LLM request, mutable this method.
+ """
+ if (function_declaration := self._get_declaration()) is None:
+ return
+
+ llm_request.tools_dict[self.name] = self
+ if tool_with_function_declarations := _find_tool_with_function_declarations(
+ llm_request
+ ):
+ if tool_with_function_declarations.function_declarations is None:
+ tool_with_function_declarations.function_declarations = []
+ tool_with_function_declarations.function_declarations.append(
+ function_declaration
+ )
+ else:
+ llm_request.config = (
+ types.GenerateContentConfig()
+ if not llm_request.config
+ else llm_request.config
+ )
+ llm_request.config.tools = (
+ [] if not llm_request.config.tools else llm_request.config.tools
+ )
+ llm_request.config.tools.append(
+ types.Tool(function_declarations=[function_declaration])
+ )
+
+ @property
+ def _api_variant(self) -> str:
+ use_vertexai = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [
+ 'true',
+ '1',
+ ]
+ return 'VERTEX_AI' if use_vertexai else 'GOOGLE_AI'
+
+
+def _find_tool_with_function_declarations(
+ llm_request: LlmRequest,
+) -> Optional[types.Tool]:
+ # TODO: add individual tool with declaration and merge in google_llm.py
+ if not llm_request.config or not llm_request.config.tools:
+ return None
+
+ return next(
+ (
+ tool
+ for tool in llm_request.config.tools
+ if isinstance(tool, types.Tool) and tool.function_declarations
+ ),
+ None,
+ )
diff --git a/src/google/adk/tools/built_in_code_execution_tool.py b/src/google/adk/tools/built_in_code_execution_tool.py
new file mode 100644
index 0000000..742059a
--- /dev/null
+++ b/src/google/adk/tools/built_in_code_execution_tool.py
@@ -0,0 +1,59 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from google.genai import types
+from typing_extensions import override
+
+from .base_tool import BaseTool
+from .tool_context import ToolContext
+
+if TYPE_CHECKING:
+ from ..models import LlmRequest
+
+
+class BuiltInCodeExecutionTool(BaseTool):
+ """A built-in code execution tool that is automatically invoked by Gemini 2 models.
+
+ This tool operates internally within the model and does not require or perform
+ local code execution.
+ """
+
+ def __init__(self):
+ # Name and description are not used because this is a model built-in tool.
+ super().__init__(name='code_execution', description='code_execution')
+
+ @override
+ async def process_llm_request(
+ self,
+ *,
+ tool_context: ToolContext,
+ llm_request: LlmRequest,
+ ) -> None:
+ if llm_request.model and llm_request.model.startswith('gemini-2'):
+ llm_request.config = llm_request.config or types.GenerateContentConfig()
+ llm_request.config.tools = llm_request.config.tools or []
+ llm_request.config.tools.append(
+ types.Tool(code_execution=types.ToolCodeExecution())
+ )
+ else:
+ raise ValueError(
+ f'Code execution tool is not supported for model {llm_request.model}'
+ )
+
+
+built_in_code_execution = BuiltInCodeExecutionTool()
diff --git a/src/google/adk/tools/crewai_tool.py b/src/google/adk/tools/crewai_tool.py
new file mode 100644
index 0000000..db4c533
--- /dev/null
+++ b/src/google/adk/tools/crewai_tool.py
@@ -0,0 +1,72 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from google.genai import types
+from typing_extensions import override
+
+from . import _automatic_function_calling_util
+from .function_tool import FunctionTool
+
+try:
+ from crewai.tools import BaseTool as CrewaiBaseTool
+except ImportError as e:
+ import sys
+
+ if sys.version_info < (3, 10):
+ raise ImportError(
+ "Crewai Tools require Python 3.10+. Please upgrade your Python version."
+ ) from e
+ else:
+ raise ImportError(
+ "Crewai Tools require pip install 'google-adk[extensions]'."
+ ) from e
+
+
+class CrewaiTool(FunctionTool):
+ """Use this class to wrap a CrewAI tool.
+
+ If the original tool name and description are not suitable, you can override
+ them in the constructor.
+ """
+
+ tool: CrewaiBaseTool
+ """The wrapped CrewAI tool."""
+
+ def __init__(self, tool: CrewaiBaseTool, *, name: str, description: str):
+ super().__init__(tool.run)
+ self.tool = tool
+ if name:
+ self.name = name
+ elif tool.name:
+ # Right now, CrewAI tool name contains white spaces. White spaces are
+ # not supported in our framework. So we replace them with "_".
+ self.name = tool.name.replace(" ", "_").lower()
+ if description:
+ self.description = description
+ elif tool.description:
+ self.description = tool.description
+
+ @override
+ def _get_declaration(self) -> types.FunctionDeclaration:
+ """Build the function declaration for the tool."""
+ function_declaration = _automatic_function_calling_util.build_function_declaration_for_params_for_crewai(
+ False,
+ self.name,
+ self.description,
+ self.func,
+ self.tool.args_schema.model_json_schema(),
+ )
+ return function_declaration
diff --git a/src/google/adk/tools/example_tool.py b/src/google/adk/tools/example_tool.py
new file mode 100644
index 0000000..a59c0a2
--- /dev/null
+++ b/src/google/adk/tools/example_tool.py
@@ -0,0 +1,62 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from typing import Union
+
+from pydantic import TypeAdapter
+from typing_extensions import override
+
+from ..examples import example_util
+from ..examples.base_example_provider import BaseExampleProvider
+from ..examples.example import Example
+from .base_tool import BaseTool
+from .tool_context import ToolContext
+
+if TYPE_CHECKING:
+ from ..models.llm_request import LlmRequest
+
+
+class ExampleTool(BaseTool):
+ """A tool that adds (few-shot) examples to the LLM request.
+
+ Attributes:
+ examples: The examples to add to the LLM request.
+ """
+
+ def __init__(self, examples: Union[list[Example], BaseExampleProvider]):
+ # Name and description are not used because this tool only changes
+ # llm_request.
+ super().__init__(name='example_tool', description='example tool')
+ self.examples = (
+ TypeAdapter(list[Example]).validate_python(examples)
+ if isinstance(examples, list)
+ else examples
+ )
+
+ @override
+ async def process_llm_request(
+ self, *, tool_context: ToolContext, llm_request: LlmRequest
+ ) -> None:
+ parts = tool_context.user_content.parts
+ if not parts or not parts[0].text:
+ return
+
+ llm_request.append_instructions([
+ example_util.build_example_si(
+ self.examples, parts[0].text, llm_request.model
+ )
+ ])
diff --git a/src/google/adk/tools/exit_loop_tool.py b/src/google/adk/tools/exit_loop_tool.py
new file mode 100644
index 0000000..181dc7e
--- /dev/null
+++ b/src/google/adk/tools/exit_loop_tool.py
@@ -0,0 +1,23 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .tool_context import ToolContext
+
+
+def exit_loop(tool_context: ToolContext):
+ """Exits the loop.
+
+ Call this function only when you are instructed to do so.
+ """
+ tool_context.actions.escalate = True
diff --git a/src/google/adk/tools/function_parameter_parse_util.py b/src/google/adk/tools/function_parameter_parse_util.py
new file mode 100644
index 0000000..4252c4f
--- /dev/null
+++ b/src/google/adk/tools/function_parameter_parse_util.py
@@ -0,0 +1,307 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import inspect
+import logging
+import types as typing_types
+from typing import _GenericAlias
+from typing import Any
+from typing import get_args
+from typing import get_origin
+from typing import Literal
+from typing import Union
+
+from google.genai import types
+import pydantic
+
+_py_builtin_type_to_schema_type = {
+ str: types.Type.STRING,
+ int: types.Type.INTEGER,
+ float: types.Type.NUMBER,
+ bool: types.Type.BOOLEAN,
+ list: types.Type.ARRAY,
+ dict: types.Type.OBJECT,
+}
+
+logger = logging.getLogger(__name__)
+
+
+def _is_builtin_primitive_or_compound(
+ annotation: inspect.Parameter.annotation,
+) -> bool:
+ return annotation in _py_builtin_type_to_schema_type.keys()
+
+
+def _raise_for_any_of_if_mldev(schema: types.Schema):
+ if schema.any_of:
+ raise ValueError(
+ 'AnyOf is not supported in function declaration schema for Google AI.'
+ )
+
+
+def _update_for_default_if_mldev(schema: types.Schema):
+ if schema.default is not None:
+ # TODO(kech): Remove this walkaround once mldev supports default value.
+ schema.default = None
+ logger.warning(
+ 'Default value is not supported in function declaration schema for'
+ ' Google AI.'
+ )
+
+
+def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
+ if not variant == 'VERTEX_AI':
+ _raise_for_any_of_if_mldev(schema)
+ _update_for_default_if_mldev(schema)
+
+
+def _is_default_value_compatible(
+ default_value: Any, annotation: inspect.Parameter.annotation
+) -> bool:
+ # None type is expected to be handled external to this function
+ if _is_builtin_primitive_or_compound(annotation):
+ return isinstance(default_value, annotation)
+
+ if (
+ isinstance(annotation, _GenericAlias)
+ or isinstance(annotation, typing_types.GenericAlias)
+ or isinstance(annotation, typing_types.UnionType)
+ ):
+ origin = get_origin(annotation)
+ if origin in (Union, typing_types.UnionType):
+ return any(
+ _is_default_value_compatible(default_value, arg)
+ for arg in get_args(annotation)
+ )
+
+ if origin is dict:
+ return isinstance(default_value, dict)
+
+ if origin is list:
+ if not isinstance(default_value, list):
+ return False
+ # most tricky case, element in list is union type
+ # need to apply any logic within all
+ # see test case test_generic_alias_complex_array_with_default_value
+ # a: typing.List[int | str | float | bool]
+ # default_value: [1, 'a', 1.1, True]
+ return all(
+ any(
+ _is_default_value_compatible(item, arg)
+ for arg in get_args(annotation)
+ )
+ for item in default_value
+ )
+
+ if origin is Literal:
+ return default_value in get_args(annotation)
+
+ # return False for any other unrecognized annotation
+ # let caller handle the raise
+ return False
+
+
+def _parse_schema_from_parameter(
+ variant: str, param: inspect.Parameter, func_name: str
+) -> types.Schema:
+ """parse schema from parameter.
+
+ from the simplest case to the most complex case.
+ """
+ schema = types.Schema()
+ default_value_error_msg = (
+ f'Default value {param.default} of parameter {param} of function'
+ f' {func_name} is not compatible with the parameter annotation'
+ f' {param.annotation}.'
+ )
+ if _is_builtin_primitive_or_compound(param.annotation):
+ if param.default is not inspect.Parameter.empty:
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ schema.type = _py_builtin_type_to_schema_type[param.annotation]
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ if (
+ get_origin(param.annotation) is Union
+ # only parse simple UnionType, example int | str | float | bool
+ # complex types.UnionType will be invoked in raise branch
+ and all(
+ (_is_builtin_primitive_or_compound(arg) or arg is type(None))
+ for arg in get_args(param.annotation)
+ )
+ ):
+ schema.type = types.Type.OBJECT
+ schema.any_of = []
+ unique_types = set()
+ for arg in get_args(param.annotation):
+ if arg.__name__ == 'NoneType': # Optional type
+ schema.nullable = True
+ continue
+ schema_in_any_of = _parse_schema_from_parameter(
+ variant,
+ inspect.Parameter(
+ 'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
+ ),
+ func_name,
+ )
+ if (
+ schema_in_any_of.model_dump_json(exclude_none=True)
+ not in unique_types
+ ):
+ schema.any_of.append(schema_in_any_of)
+ unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
+ if len(schema.any_of) == 1: # param: list | None -> Array
+ schema.type = schema.any_of[0].type
+ schema.any_of = None
+ if (
+ param.default is not inspect.Parameter.empty
+ and param.default is not None
+ ):
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ if isinstance(param.annotation, _GenericAlias) or isinstance(
+ param.annotation, typing_types.GenericAlias
+ ):
+ origin = get_origin(param.annotation)
+ args = get_args(param.annotation)
+ if origin is dict:
+ schema.type = types.Type.OBJECT
+ if param.default is not inspect.Parameter.empty:
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ if origin is Literal:
+ if not all(isinstance(arg, str) for arg in args):
+ raise ValueError(
+ f'Literal type {param.annotation} must be a list of strings.'
+ )
+ schema.type = types.Type.STRING
+ schema.enum = list(args)
+ if param.default is not inspect.Parameter.empty:
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ if origin is list:
+ schema.type = types.Type.ARRAY
+ schema.items = _parse_schema_from_parameter(
+ variant,
+ inspect.Parameter(
+ 'item',
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ annotation=args[0],
+ ),
+ func_name,
+ )
+ if param.default is not inspect.Parameter.empty:
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ if origin is Union:
+ schema.any_of = []
+ schema.type = types.Type.OBJECT
+ unique_types = set()
+ for arg in args:
+ if arg.__name__ == 'NoneType': # Optional type
+ schema.nullable = True
+ continue
+ schema_in_any_of = _parse_schema_from_parameter(
+ variant,
+ inspect.Parameter(
+ 'item',
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ annotation=arg,
+ ),
+ func_name,
+ )
+ if (
+ len(param.annotation.__args__) == 2
+ and type(None) in param.annotation.__args__
+ ): # Optional type
+ for optional_arg in param.annotation.__args__:
+ if (
+ hasattr(optional_arg, '__origin__')
+ and optional_arg.__origin__ is list
+ ):
+ # Optional type with list, for example Optional[list[str]]
+ schema.items = schema_in_any_of.items
+ if (
+ schema_in_any_of.model_dump_json(exclude_none=True)
+ not in unique_types
+ ):
+ schema.any_of.append(schema_in_any_of)
+ unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True))
+ if len(schema.any_of) == 1: # param: Union[List, None] -> Array
+ schema.type = schema.any_of[0].type
+ schema.any_of = None
+ if (
+ param.default is not None
+ and param.default is not inspect.Parameter.empty
+ ):
+ if not _is_default_value_compatible(param.default, param.annotation):
+ raise ValueError(default_value_error_msg)
+ schema.default = param.default
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ # all other generic alias will be invoked in raise branch
+ if (
+ inspect.isclass(param.annotation)
+ # for user defined class, we only support pydantic model
+ and issubclass(param.annotation, pydantic.BaseModel)
+ ):
+ if (
+ param.default is not inspect.Parameter.empty
+ and param.default is not None
+ ):
+ schema.default = param.default
+ schema.type = types.Type.OBJECT
+ schema.properties = {}
+ for field_name, field_info in param.annotation.model_fields.items():
+ schema.properties[field_name] = _parse_schema_from_parameter(
+ variant,
+ inspect.Parameter(
+ field_name,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ annotation=field_info.annotation,
+ ),
+ func_name,
+ )
+ _raise_if_schema_unsupported(variant, schema)
+ return schema
+ raise ValueError(
+ f'Failed to parse the parameter {param} of function {func_name} for'
+ ' automatic function calling.Automatic function calling works best with'
+ ' simpler function signature schema,consider manually parse your'
+ f' function declaration for function {func_name}.'
+ )
+
+
+def _get_required_fields(schema: types.Schema) -> list[str]:
+ if not schema.properties:
+ return
+ return [
+ field_name
+ for field_name, field_schema in schema.properties.items()
+ if not field_schema.nullable and field_schema.default is None
+ ]
diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py
new file mode 100644
index 0000000..15b9c6b
--- /dev/null
+++ b/src/google/adk/tools/function_tool.py
@@ -0,0 +1,87 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any
+from typing import Callable
+from typing import Optional
+
+from google.genai import types
+from typing_extensions import override
+
+from ._automatic_function_calling_util import build_function_declaration
+from .base_tool import BaseTool
+from .tool_context import ToolContext
+
+
+class FunctionTool(BaseTool):
+ """A tool that wraps a user-defined Python function.
+
+ Attributes:
+ func: The function to wrap.
+ """
+
+ def __init__(self, func: Callable[..., Any]):
+ super().__init__(name=func.__name__, description=func.__doc__)
+ self.func = func
+
+ @override
+ def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
+ function_decl = types.FunctionDeclaration.model_validate(
+ build_function_declaration(
+ func=self.func,
+ # The model doesn't understand the function context.
+ # input_stream is for streaming tool
+ ignore_params=['tool_context', 'input_stream'],
+ variant=self._api_variant,
+ )
+ )
+
+ return function_decl
+
+ @override
+ async def run_async(
+ self, *, args: dict[str, Any], tool_context: ToolContext
+ ) -> Any:
+ args_to_call = args.copy()
+ signature = inspect.signature(self.func)
+ if 'tool_context' in signature.parameters:
+ args_to_call['tool_context'] = tool_context
+
+ if inspect.iscoroutinefunction(self.func):
+ return await self.func(**args_to_call) or {}
+ else:
+ return self.func(**args_to_call) or {}
+
+ # TODO(hangfei): fix call live for function stream.
+ async def _call_live(
+ self,
+ *,
+ args: dict[str, Any],
+ tool_context: ToolContext,
+ invocation_context,
+ ) -> Any:
+ args_to_call = args.copy()
+ signature = inspect.signature(self.func)
+ if (
+ self.name in invocation_context.active_streaming_tools
+ and invocation_context.active_streaming_tools[self.name].stream
+ ):
+ args_to_call['input_stream'] = invocation_context.active_streaming_tools[
+ self.name
+ ].stream
+ if 'tool_context' in signature.parameters:
+ args_to_call['tool_context'] = tool_context
+ async for item in self.func(**args_to_call):
+ yield item
diff --git a/src/google/adk/tools/get_user_choice_tool.py b/src/google/adk/tools/get_user_choice_tool.py
new file mode 100644
index 0000000..99d86f0
--- /dev/null
+++ b/src/google/adk/tools/get_user_choice_tool.py
@@ -0,0 +1,28 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+from .long_running_tool import LongRunningFunctionTool
+from .tool_context import ToolContext
+
+
+def get_user_choice(
+ options: list[str], tool_context: ToolContext
+) -> Optional[str]:
+ """Provides the options to the user and asks them to choose one."""
+ tool_context.actions.skip_summarization = True
+ return None
+
+
+get_user_choice_tool = LongRunningFunctionTool(func=get_user_choice)
diff --git a/src/google/adk/tools/google_api_tool/__init__.py b/src/google/adk/tools/google_api_tool/__init__.py
new file mode 100644
index 0000000..be55eb5
--- /dev/null
+++ b/src/google/adk/tools/google_api_tool/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .google_api_tool_sets import calendar_tool_set
diff --git a/src/google/adk/tools/google_api_tool/google_api_tool.py b/src/google/adk/tools/google_api_tool/google_api_tool.py
new file mode 100644
index 0000000..921ac2d
--- /dev/null
+++ b/src/google/adk/tools/google_api_tool/google_api_tool.py
@@ -0,0 +1,59 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from typing import Dict
+from typing import Optional
+
+from google.genai.types import FunctionDeclaration
+from typing_extensions import override
+
+from ...auth import AuthCredential
+from ...auth import AuthCredentialTypes
+from ...auth import OAuth2Auth
+from .. import BaseTool
+from ..openapi_tool import RestApiTool
+from ..tool_context import ToolContext
+
+
+class GoogleApiTool(BaseTool):
+
+ def __init__(self, rest_api_tool: RestApiTool):
+ super().__init__(
+ name=rest_api_tool.name,
+ description=rest_api_tool.description,
+ is_long_running=rest_api_tool.is_long_running,
+ )
+ self.rest_api_tool = rest_api_tool
+
+ @override
+ def _get_declaration(self) -> FunctionDeclaration:
+ return self.rest_api_tool._get_declaration()
+
+ @override
+ async def run_async(
+ self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
+ ) -> Dict[str, Any]:
+ return await self.rest_api_tool.run_async(
+ args=args, tool_context=tool_context
+ )
+
+ def configure_auth(self, client_id: str, client_secret: str):
+ self.rest_api_tool.auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
+ oauth2=OAuth2Auth(
+ client_id=client_id,
+ client_secret=client_secret,
+ ),
+ )
diff --git a/src/google/adk/tools/google_api_tool/google_api_tool_set.py b/src/google/adk/tools/google_api_tool/google_api_tool_set.py
new file mode 100644
index 0000000..53201ba
--- /dev/null
+++ b/src/google/adk/tools/google_api_tool/google_api_tool_set.py
@@ -0,0 +1,107 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import os
+from typing import Any
+from typing import Dict
+from typing import Final
+from typing import List
+from typing import Optional
+from typing import Type
+
+from ...auth import OpenIdConnectWithConfig
+from ..openapi_tool import OpenAPIToolset
+from ..openapi_tool import RestApiTool
+from .google_api_tool import GoogleApiTool
+from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
+
+
+class GoogleApiToolSet:
+
+ def __init__(self, tools: List[RestApiTool]):
+ self.tools: Final[List[GoogleApiTool]] = [
+ GoogleApiTool(tool) for tool in tools
+ ]
+
+ def get_tools(self) -> List[GoogleApiTool]:
+ """Get all tools in the toolset."""
+ return self.tools
+
+ def get_tool(self, tool_name: str) -> Optional[GoogleApiTool]:
+ """Get a tool by name."""
+ matching_tool = filter(lambda t: t.name == tool_name, self.tools)
+ return next(matching_tool, None)
+
+ @staticmethod
+ def _load_tool_set_with_oidc_auth(
+ spec_file: str = None,
+ spec_dict: Dict[str, Any] = None,
+ scopes: list[str] = None,
+ ) -> Optional[OpenAPIToolset]:
+ spec_str = None
+ if spec_file:
+ # Get the frame of the caller
+ caller_frame = inspect.stack()[1]
+ # Get the filename of the caller
+ caller_filename = caller_frame.filename
+ # Get the directory of the caller
+ caller_dir = os.path.dirname(os.path.abspath(caller_filename))
+ # Join the directory path with the filename
+ yaml_path = os.path.join(caller_dir, spec_file)
+ with open(yaml_path, 'r', encoding='utf-8') as file:
+ spec_str = file.read()
+ tool_set = OpenAPIToolset(
+ spec_dict=spec_dict,
+ spec_str=spec_str,
+ spec_str_type='yaml',
+ auth_scheme=OpenIdConnectWithConfig(
+ authorization_endpoint=(
+ 'https://accounts.google.com/o/oauth2/v2/auth'
+ ),
+ token_endpoint='https://oauth2.googleapis.com/token',
+ userinfo_endpoint=(
+ 'https://openidconnect.googleapis.com/v1/userinfo'
+ ),
+ revocation_endpoint='https://oauth2.googleapis.com/revoke',
+ token_endpoint_auth_methods_supported=[
+ 'client_secret_post',
+ 'client_secret_basic',
+ ],
+ grant_types_supported=['authorization_code'],
+ scopes=scopes,
+ ),
+ )
+ return tool_set
+
+ def configure_auth(self, client_id: str, client_secret: str):
+ for tool in self.tools:
+ tool.configure_auth(client_id, client_secret)
+
+ @classmethod
+ def load_tool_set(
+ cl: Type['GoogleApiToolSet'],
+ api_name: str,
+ api_version: str,
+ ) -> 'GoogleApiToolSet':
+ spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
+ scope = list(
+ spec_dict['components']['securitySchemes']['oauth2']['flows'][
+ 'authorizationCode'
+ ]['scopes'].keys()
+ )[0]
+ return cl(
+ cl._load_tool_set_with_oidc_auth(
+ spec_dict=spec_dict, scopes=[scope]
+ ).get_tools()
+ )
diff --git a/src/google/adk/tools/google_api_tool/google_api_tool_sets.py b/src/google/adk/tools/google_api_tool/google_api_tool_sets.py
new file mode 100644
index 0000000..a8e30c7
--- /dev/null
+++ b/src/google/adk/tools/google_api_tool/google_api_tool_sets.py
@@ -0,0 +1,55 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+from .google_api_tool_set import GoogleApiToolSet
+
+logger = logging.getLogger(__name__)
+
+calendar_tool_set = GoogleApiToolSet.load_tool_set(
+ api_name="calendar",
+ api_version="v3",
+)
+
+bigquery_tool_set = GoogleApiToolSet.load_tool_set(
+ api_name="bigquery",
+ api_version="v2",
+)
+
+gmail_tool_set = GoogleApiToolSet.load_tool_set(
+ api_name="gmail",
+ api_version="v1",
+)
+
+youtube_tool_set = GoogleApiToolSet.load_tool_set(
+ api_name="youtube",
+ api_version="v3",
+)
+
+slides_tool_set = GoogleApiToolSet.load_tool_set(
+ api_name="slides",
+ api_version="v1",
+)
+
+sheets_tool_set = GoogleApiToolSet.load_tool_set(
+ api_name="sheets",
+ api_version="v4",
+)
+
+docs_tool_set = GoogleApiToolSet.load_tool_set(
+ api_name="docs",
+ api_version="v1",
+)
diff --git a/src/google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py b/src/google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py
new file mode 100644
index 0000000..818d244
--- /dev/null
+++ b/src/google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py
@@ -0,0 +1,521 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import logging
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+# Google API client
+from googleapiclient.discovery import build
+from googleapiclient.discovery import Resource
+from googleapiclient.errors import HttpError
+
+# Configure logging
+logging.basicConfig(
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+
+class GoogleApiToOpenApiConverter:
+ """Converts Google API Discovery documents to OpenAPI v3 format."""
+
+ def __init__(self, api_name: str, api_version: str):
+ """Initialize the converter with the API name and version.
+
+ Args:
+ api_name: The name of the Google API (e.g., "calendar")
+ api_version: The version of the API (e.g., "v3")
+ """
+ self.api_name = api_name
+ self.api_version = api_version
+ self.google_api_resource = None
+ self.google_api_spec = None
+ self.openapi_spec = {
+ "openapi": "3.0.0",
+ "info": {},
+ "servers": [],
+ "paths": {},
+ "components": {"schemas": {}, "securitySchemes": {}},
+ }
+
+ def fetch_google_api_spec(self) -> None:
+ """Fetches the Google API specification using discovery service."""
+ try:
+ logger.info(
+ "Fetching Google API spec for %s %s", self.api_name, self.api_version
+ )
+ # Build a resource object for the specified API
+ self.google_api_resource = build(self.api_name, self.api_version)
+
+ # Access the underlying API discovery document
+ self.google_api_spec = self.google_api_resource._rootDesc
+
+ if not self.google_api_spec:
+ raise ValueError("Failed to retrieve API specification")
+
+ logger.info("Successfully fetched %s API specification", self.api_name)
+ except HttpError as e:
+ logger.error("HTTP Error: %s", e)
+ raise
+ except Exception as e:
+ logger.error("Error fetching API spec: %s", e)
+ raise
+
+ def convert(self) -> Dict[str, Any]:
+ """Convert the Google API spec to OpenAPI v3 format.
+
+ Returns:
+ Dict containing the converted OpenAPI v3 specification
+ """
+ if not self.google_api_spec:
+ self.fetch_google_api_spec()
+
+ # Convert basic API information
+ self._convert_info()
+
+ # Convert server information
+ self._convert_servers()
+
+ # Convert authentication/authorization schemes
+ self._convert_security_schemes()
+
+ # Convert schemas (models)
+ self._convert_schemas()
+
+ # Convert endpoints/paths
+ self._convert_resources(self.google_api_spec.get("resources", {}))
+
+ # Convert top-level methods, if any
+ self._convert_methods(self.google_api_spec.get("methods", {}), "/")
+
+ return self.openapi_spec
+
+ def _convert_info(self) -> None:
+ """Convert basic API information."""
+ self.openapi_spec["info"] = {
+ "title": self.google_api_spec.get("title", f"{self.api_name} API"),
+ "description": self.google_api_spec.get("description", ""),
+ "version": self.google_api_spec.get("version", self.api_version),
+ "contact": {},
+ "termsOfService": self.google_api_spec.get("documentationLink", ""),
+ }
+
+ # Add documentation links if available
+ docs_link = self.google_api_spec.get("documentationLink")
+ if docs_link:
+ self.openapi_spec["externalDocs"] = {
+ "description": "API Documentation",
+ "url": docs_link,
+ }
+
+ def _convert_servers(self) -> None:
+ """Convert server information."""
+ base_url = self.google_api_spec.get(
+ "rootUrl", ""
+ ) + self.google_api_spec.get("servicePath", "")
+
+ # Remove trailing slash if present
+ if base_url.endswith("/"):
+ base_url = base_url[:-1]
+
+ self.openapi_spec["servers"] = [{
+ "url": base_url,
+ "description": f"{self.api_name} {self.api_version} API",
+ }]
+
+ def _convert_security_schemes(self) -> None:
+ """Convert authentication and authorization schemes."""
+ auth = self.google_api_spec.get("auth", {})
+ oauth2 = auth.get("oauth2", {})
+
+ if oauth2:
+ # Handle OAuth2
+ scopes = oauth2.get("scopes", {})
+ formatted_scopes = {}
+
+ for scope, scope_info in scopes.items():
+ formatted_scopes[scope] = scope_info.get("description", "")
+
+ self.openapi_spec["components"]["securitySchemes"]["oauth2"] = {
+ "type": "oauth2",
+ "description": "OAuth 2.0 authentication",
+ "flows": {
+ "authorizationCode": {
+ "authorizationUrl": (
+ "https://accounts.google.com/o/oauth2/auth"
+ ),
+ "tokenUrl": "https://oauth2.googleapis.com/token",
+ "scopes": formatted_scopes,
+ }
+ },
+ }
+
+ # Add API key authentication (most Google APIs support this)
+ self.openapi_spec["components"]["securitySchemes"]["apiKey"] = {
+ "type": "apiKey",
+ "in": "query",
+ "name": "key",
+ "description": "API key for accessing this API",
+ }
+
+ # Create global security requirement
+ self.openapi_spec["security"] = [
+ {"oauth2": list(formatted_scopes.keys())} if oauth2 else {},
+ {"apiKey": []},
+ ]
+
+ def _convert_schemas(self) -> None:
+ """Convert schema definitions (models)."""
+ schemas = self.google_api_spec.get("schemas", {})
+
+ for schema_name, schema_def in schemas.items():
+ converted_schema = self._convert_schema_object(schema_def)
+ self.openapi_spec["components"]["schemas"][schema_name] = converted_schema
+
+ def _convert_schema_object(
+ self, schema_def: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """Recursively convert a Google API schema object to OpenAPI schema.
+
+ Args:
+ schema_def: Google API schema definition
+
+ Returns:
+ Converted OpenAPI schema object
+ """
+ result = {}
+
+ # Convert the type
+ if "type" in schema_def:
+ gtype = schema_def["type"]
+ if gtype == "object":
+ result["type"] = "object"
+
+ # Handle properties
+ if "properties" in schema_def:
+ result["properties"] = {}
+ for prop_name, prop_def in schema_def["properties"].items():
+ result["properties"][prop_name] = self._convert_schema_object(
+ prop_def
+ )
+
+ # Handle required fields
+ required_fields = []
+ for prop_name, prop_def in schema_def.get("properties", {}).items():
+ if prop_def.get("required", False):
+ required_fields.append(prop_name)
+ if required_fields:
+ result["required"] = required_fields
+
+ elif gtype == "array":
+ result["type"] = "array"
+ if "items" in schema_def:
+ result["items"] = self._convert_schema_object(schema_def["items"])
+
+ elif gtype == "any":
+ # OpenAPI doesn't have direct "any" type
+ # Use oneOf with multiple options as alternative
+ result["oneOf"] = [
+ {"type": "object"},
+ {"type": "array"},
+ {"type": "string"},
+ {"type": "number"},
+ {"type": "boolean"},
+ {"type": "null"},
+ ]
+
+ else:
+ # Handle other primitive types
+ result["type"] = gtype
+
+ # Handle references
+ if "$ref" in schema_def:
+ ref = schema_def["$ref"]
+ # Google refs use "#" at start, OpenAPI uses "#/components/schemas/"
+ if ref.startswith("#"):
+ ref = ref.replace("#", "#/components/schemas/")
+ else:
+ ref = "#/components/schemas/" + ref
+ result["$ref"] = ref
+
+ # Handle format
+ if "format" in schema_def:
+ result["format"] = schema_def["format"]
+
+ # Handle enum values
+ if "enum" in schema_def:
+ result["enum"] = schema_def["enum"]
+
+ # Handle description
+ if "description" in schema_def:
+ result["description"] = schema_def["description"]
+
+ # Handle pattern
+ if "pattern" in schema_def:
+ result["pattern"] = schema_def["pattern"]
+
+ # Handle default value
+ if "default" in schema_def:
+ result["default"] = schema_def["default"]
+
+ return result
+
+ def _convert_resources(
+ self, resources: Dict[str, Any], parent_path: str = ""
+ ) -> None:
+ """Recursively convert all resources and their methods.
+
+ Args:
+ resources: Dictionary of resources from the Google API spec
+ parent_path: The parent path prefix for nested resources
+ """
+ for resource_name, resource_data in resources.items():
+ # Process methods for this resource
+ resource_path = f"{parent_path}/{resource_name}"
+ methods = resource_data.get("methods", {})
+ self._convert_methods(methods, resource_path)
+
+ # Process nested resources recursively
+ nested_resources = resource_data.get("resources", {})
+ if nested_resources:
+ self._convert_resources(nested_resources, resource_path)
+
+ def _convert_methods(
+ self, methods: Dict[str, Any], resource_path: str
+ ) -> None:
+ """Convert methods for a specific resource path.
+
+ Args:
+ methods: Dictionary of methods from the Google API spec
+ resource_path: The path of the resource these methods belong to
+ """
+ for method_name, method_data in methods.items():
+ http_method = method_data.get("httpMethod", "GET").lower()
+
+ # Determine the actual endpoint path
+ # Google often has the format something like 'users.messages.list'
+ rest_path = method_data.get("path", "/")
+ if not rest_path.startswith("/"):
+ rest_path = "/" + rest_path
+
+ path_params = self._extract_path_parameters(rest_path)
+
+ # Create path entry if it doesn't exist
+ if rest_path not in self.openapi_spec["paths"]:
+ self.openapi_spec["paths"][rest_path] = {}
+
+ # Add the operation for this method
+ self.openapi_spec["paths"][rest_path][http_method] = (
+ self._convert_operation(method_data, path_params)
+ )
+
+ def _extract_path_parameters(self, path: str) -> List[str]:
+ """Extract path parameters from a URL path.
+
+ Args:
+ path: The URL path with path parameters
+
+ Returns:
+ List of parameter names
+ """
+ params = []
+ segments = path.split("/")
+
+ for segment in segments:
+ # Google APIs often use {param} format for path parameters
+ if segment.startswith("{") and segment.endswith("}"):
+ param_name = segment[1:-1]
+ params.append(param_name)
+
+ return params
+
+ def _convert_operation(
+ self, method_data: Dict[str, Any], path_params: List[str]
+ ) -> Dict[str, Any]:
+ """Convert a Google API method to an OpenAPI operation.
+
+ Args:
+ method_data: Google API method data
+ path_params: List of path parameter names
+
+ Returns:
+ OpenAPI operation object
+ """
+ operation = {
+ "operationId": method_data.get("id", ""),
+ "summary": method_data.get("description", ""),
+ "description": method_data.get("description", ""),
+ "parameters": [],
+ "responses": {
+ "200": {"description": "Successful operation"},
+ "400": {"description": "Bad request"},
+ "401": {"description": "Unauthorized"},
+ "403": {"description": "Forbidden"},
+ "404": {"description": "Not found"},
+ "500": {"description": "Server error"},
+ },
+ }
+
+ # Add path parameters
+ for param_name in path_params:
+ param = {
+ "name": param_name,
+ "in": "path",
+ "required": True,
+ "schema": {"type": "string"},
+ }
+ operation["parameters"].append(param)
+
+ # Add query parameters
+ for param_name, param_data in method_data.get("parameters", {}).items():
+ # Skip parameters already included in path
+ if param_name in path_params:
+ continue
+
+ param = {
+ "name": param_name,
+ "in": "query",
+ "description": param_data.get("description", ""),
+ "required": param_data.get("required", False),
+ "schema": self._convert_parameter_schema(param_data),
+ }
+ operation["parameters"].append(param)
+
+ # Handle request body
+ if "request" in method_data:
+ request_ref = method_data.get("request", {}).get("$ref", "")
+ if request_ref:
+ if request_ref.startswith("#"):
+ # Convert Google's reference format to OpenAPI format
+ openapi_ref = request_ref.replace("#", "#/components/schemas/")
+ else:
+ openapi_ref = "#/components/schemas/" + request_ref
+ operation["requestBody"] = {
+ "description": "Request body",
+ "content": {"application/json": {"schema": {"$ref": openapi_ref}}},
+ "required": True,
+ }
+
+ # Handle response body
+ if "response" in method_data:
+ response_ref = method_data.get("response", {}).get("$ref", "")
+ if response_ref:
+ if response_ref.startswith("#"):
+ # Convert Google's reference format to OpenAPI format
+ openapi_ref = response_ref.replace("#", "#/components/schemas/")
+ else:
+ openapi_ref = "#/components/schemas/" + response_ref
+ operation["responses"]["200"]["content"] = {
+ "application/json": {"schema": {"$ref": openapi_ref}}
+ }
+
+ # Add scopes if available
+ scopes = method_data.get("scopes", [])
+ if scopes:
+ # Add method-specific security requirement if different from global
+ operation["security"] = [{"oauth2": scopes}]
+
+ return operation
+
+ def _convert_parameter_schema(
+ self, param_data: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """Convert a parameter definition to an OpenAPI schema.
+
+ Args:
+ param_data: Google API parameter data
+
+ Returns:
+ OpenAPI schema for the parameter
+ """
+ schema = {}
+
+ # Convert type
+ param_type = param_data.get("type", "string")
+ schema["type"] = param_type
+
+ # Handle enum values
+ if "enum" in param_data:
+ schema["enum"] = param_data["enum"]
+
+ # Handle format
+ if "format" in param_data:
+ schema["format"] = param_data["format"]
+
+ # Handle default value
+ if "default" in param_data:
+ schema["default"] = param_data["default"]
+
+ # Handle pattern
+ if "pattern" in param_data:
+ schema["pattern"] = param_data["pattern"]
+
+ return schema
+
+ def save_openapi_spec(self, output_path: str) -> None:
+ """Save the OpenAPI specification to a file.
+
+ Args:
+ output_path: Path where the OpenAPI spec should be saved
+ """
+ with open(output_path, "w", encoding="utf-8") as f:
+ json.dump(self.openapi_spec, f, indent=2)
+ logger.info("OpenAPI specification saved to %s", output_path)
+
+
+def main():
+ """Command line interface for the converter."""
+ parser = argparse.ArgumentParser(
+ description=(
+ "Convert Google API Discovery documents to OpenAPI v3 specifications"
+ )
+ )
+ parser.add_argument(
+ "api_name", help="Name of the Google API (e.g., 'calendar')"
+ )
+ parser.add_argument("api_version", help="Version of the API (e.g., 'v3')")
+ parser.add_argument(
+ "--output",
+ "-o",
+ default="openapi_spec.json",
+ help="Output file path for the OpenAPI specification",
+ )
+
+ args = parser.parse_args()
+
+ try:
+ # Create and run the converter
+ converter = GoogleApiToOpenApiConverter(args.api_name, args.api_version)
+ converter.convert()
+ converter.save_openapi_spec(args.output)
+ print(
+ f"Successfully converted {args.api_name} {args.api_version} to"
+ " OpenAPI v3"
+ )
+ print(f"Output saved to {args.output}")
+ except Exception as e:
+ logger.error("Conversion failed: %s", e)
+ return 1
+
+ return 0
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/google/adk/tools/google_search_tool.py b/src/google/adk/tools/google_search_tool.py
new file mode 100644
index 0000000..e029a09
--- /dev/null
+++ b/src/google/adk/tools/google_search_tool.py
@@ -0,0 +1,68 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from google.genai import types
+from typing_extensions import override
+
+from .base_tool import BaseTool
+from .tool_context import ToolContext
+
+if TYPE_CHECKING:
+ from ..models import LlmRequest
+
+
+class GoogleSearchTool(BaseTool):
+ """A built-in tool that is automatically invoked by Gemini 2 models to retrieve search results from Google Search.
+
+ This tool operates internally within the model and does not require or perform
+ local code execution.
+ """
+
+ def __init__(self):
+ # Name and description are not used because this is a model built-in tool.
+ super().__init__(name='google_search', description='google_search')
+
+ @override
+ async def process_llm_request(
+ self,
+ *,
+ tool_context: ToolContext,
+ llm_request: LlmRequest,
+ ) -> None:
+ llm_request.config = llm_request.config or types.GenerateContentConfig()
+ llm_request.config.tools = llm_request.config.tools or []
+ if llm_request.model and llm_request.model.startswith('gemini-1'):
+ if llm_request.config.tools:
+ print(llm_request.config.tools)
+ raise ValueError(
+ 'Google search tool can not be used with other tools in Gemini 1.x.'
+ )
+ llm_request.config.tools.append(
+ types.Tool(google_search_retrieval=types.GoogleSearchRetrieval())
+ )
+ elif llm_request.model and llm_request.model.startswith('gemini-2'):
+ llm_request.config.tools.append(
+ types.Tool(google_search=types.GoogleSearch())
+ )
+ else:
+ raise ValueError(
+ f'Google search tool is not supported for model {llm_request.model}'
+ )
+
+
+google_search = GoogleSearchTool()
diff --git a/src/google/adk/tools/langchain_tool.py b/src/google/adk/tools/langchain_tool.py
new file mode 100644
index 0000000..b275926
--- /dev/null
+++ b/src/google/adk/tools/langchain_tool.py
@@ -0,0 +1,86 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from typing import Callable
+
+from google.genai import types
+from pydantic import model_validator
+from typing_extensions import override
+
+from . import _automatic_function_calling_util
+from .function_tool import FunctionTool
+
+
+class LangchainTool(FunctionTool):
+ """Use this class to wrap a langchain tool.
+
+ If the original tool name and description are not suitable, you can override
+ them in the constructor.
+ """
+
+ tool: Any
+ """The wrapped langchain tool."""
+
+ def __init__(self, tool: Any):
+ super().__init__(tool._run)
+ self.tool = tool
+ if tool.name:
+ self.name = tool.name
+ if tool.description:
+ self.description = tool.description
+
+ @model_validator(mode='before')
+ @classmethod
+ def populate_name(cls, data: Any) -> Any:
+ # Override this to not use function's signature name as it's
+ # mostly "run" or "invoke" for thir-party tools.
+ return data
+
+ @override
+ def _get_declaration(self) -> types.FunctionDeclaration:
+ """Build the function declaration for the tool."""
+ from langchain.agents import Tool
+ from langchain_core.tools import BaseTool
+
+ # There are two types of tools:
+ # 1. BaseTool: the tool is defined in langchain.tools.
+ # 2. Other tools: the tool doesn't inherit any class but follow some
+ # conventions, like having a "run" method.
+ if isinstance(self.tool, BaseTool):
+ tool_wrapper = Tool(
+ name=self.name,
+ func=self.func,
+ description=self.description,
+ )
+ if self.tool.args_schema:
+ tool_wrapper.args_schema = self.tool.args_schema
+ function_declaration = _automatic_function_calling_util.build_function_declaration_for_langchain(
+ False,
+ self.name,
+ self.description,
+ tool_wrapper.func,
+ tool_wrapper.args,
+ )
+ return function_declaration
+ else:
+ # Need to provide a way to override the function names and descriptions
+ # as the original function names are mostly ".run" and the descriptions
+ # may not meet users' needs.
+ function_declaration = (
+ _automatic_function_calling_util.build_function_declaration(
+ func=self.tool.run,
+ )
+ )
+ return function_declaration
diff --git a/src/google/adk/tools/load_artifacts_tool.py b/src/google/adk/tools/load_artifacts_tool.py
new file mode 100644
index 0000000..4972856
--- /dev/null
+++ b/src/google/adk/tools/load_artifacts_tool.py
@@ -0,0 +1,113 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+from typing import Any
+from typing import TYPE_CHECKING
+
+from google.genai import types
+from typing_extensions import override
+
+from .base_tool import BaseTool
+
+if TYPE_CHECKING:
+ from ..models.llm_request import LlmRequest
+ from .tool_context import ToolContext
+
+
+class LoadArtifactsTool(BaseTool):
+ """A tool that loads the artifacts and adds them to the session."""
+
+ def __init__(self):
+ super().__init__(
+ name='load_artifacts',
+ description='Loads the artifacts and adds them to the session.',
+ )
+
+ def _get_declaration(self) -> types.FunctionDeclaration | None:
+ return types.FunctionDeclaration(
+ name=self.name,
+ description=self.description,
+ parameters=types.Schema(
+ type=types.Type.OBJECT,
+ properties={
+ 'artifact_names': types.Schema(
+ type=types.Type.ARRAY,
+ items=types.Schema(
+ type=types.Type.STRING,
+ ),
+ )
+ },
+ ),
+ )
+
+ @override
+ async def run_async(
+ self, *, args: dict[str, Any], tool_context: ToolContext
+ ) -> Any:
+ artifact_names: list[str] = args.get('artifact_names', [])
+ return {'artifact_names': artifact_names}
+
+ @override
+ async def process_llm_request(
+ self, *, tool_context: ToolContext, llm_request: LlmRequest
+ ) -> None:
+ await super().process_llm_request(
+ tool_context=tool_context,
+ llm_request=llm_request,
+ )
+ self._append_artifacts_to_llm_request(
+ tool_context=tool_context, llm_request=llm_request
+ )
+
+ def _append_artifacts_to_llm_request(
+ self, *, tool_context: ToolContext, llm_request: LlmRequest
+ ):
+ artifact_names = tool_context.list_artifacts()
+ if not artifact_names:
+ return
+
+ # Tell the model about the available artifacts.
+ llm_request.append_instructions([f"""You have a list of artifacts:
+ {json.dumps(artifact_names)}
+
+ When the user asks questions about any of the artifacts, you should call the
+ `load_artifacts` function to load the artifact. Do not generate any text other
+ than the function call.
+ """])
+
+ # Attache the content of the artifacts if the model requests them.
+ # This only adds the content to the model request, instead of the session.
+ if llm_request.contents and llm_request.contents[-1].parts:
+ function_response = llm_request.contents[-1].parts[0].function_response
+ if function_response and function_response.name == 'load_artifacts':
+ artifact_names = function_response.response['artifact_names']
+ for artifact_name in artifact_names:
+ artifact = tool_context.load_artifact(artifact_name)
+ llm_request.contents.append(
+ types.Content(
+ role='user',
+ parts=[
+ types.Part.from_text(
+ text=f'Artifact {artifact_name} is:'
+ ),
+ artifact,
+ ],
+ )
+ )
+
+
+load_artifacts_tool = LoadArtifactsTool()
diff --git a/src/google/adk/tools/load_memory_tool.py b/src/google/adk/tools/load_memory_tool.py
new file mode 100644
index 0000000..28782d8
--- /dev/null
+++ b/src/google/adk/tools/load_memory_tool.py
@@ -0,0 +1,58 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from typing_extensions import override
+
+from .function_tool import FunctionTool
+from .tool_context import ToolContext
+
+if TYPE_CHECKING:
+ from ..models import LlmRequest
+ from ..memory.base_memory_service import MemoryResult
+
+
+def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
+ """Loads the memory for the current user."""
+ response = tool_context.search_memory(query)
+ return response.memories
+
+
+class LoadMemoryTool(FunctionTool):
+ """A tool that loads the memory for the current user."""
+
+ def __init__(self):
+ super().__init__(load_memory)
+
+ @override
+ async def process_llm_request(
+ self,
+ *,
+ tool_context: ToolContext,
+ llm_request: LlmRequest,
+ ) -> None:
+ await super().process_llm_request(
+ tool_context=tool_context, llm_request=llm_request
+ )
+ # Tell the model about the memory.
+ llm_request.append_instructions(["""
+You have memory. You can use it to answer questions. If any questions need
+you to look up the memory, you should call load_memory function with a query.
+"""])
+
+
+load_memory_tool = LoadMemoryTool()
diff --git a/src/google/adk/tools/load_web_page.py b/src/google/adk/tools/load_web_page.py
new file mode 100644
index 0000000..eaefedc
--- /dev/null
+++ b/src/google/adk/tools/load_web_page.py
@@ -0,0 +1,41 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tool for web browse."""
+
+import requests
+
+
+def load_web_page(url: str) -> str:
+ """Fetches the content in the url and returns the text in it.
+
+ Args:
+ url (str): The url to browse.
+
+ Returns:
+ str: The text content of the url.
+ """
+ from bs4 import BeautifulSoup
+
+ response = requests.get(url)
+
+ if response.status_code == 200:
+ soup = BeautifulSoup(response.content, 'lxml')
+ text = soup.get_text(separator='\n', strip=True)
+ else:
+ text = f'Failed to fetch url: {url}'
+
+ # Split the text into lines, filtering out very short lines
+ # (e.g., single words or short subtitles)
+ return '\n'.join(line for line in text.splitlines() if len(line.split()) > 3)
diff --git a/src/google/adk/tools/long_running_tool.py b/src/google/adk/tools/long_running_tool.py
new file mode 100644
index 0000000..1d5ce31
--- /dev/null
+++ b/src/google/adk/tools/long_running_tool.py
@@ -0,0 +1,39 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable
+
+from .function_tool import FunctionTool
+
+
+class LongRunningFunctionTool(FunctionTool):
+ """A function tool that returns the result asynchronously.
+
+ This tool is used for long-running operations that may take a significant
+ amount of time to complete. The framework will call the function. Once the
+ function returns, the response will be returned asynchronously to the
+ framework which is identified by the function_call_id.
+
+ Example:
+ ```python
+ tool = LongRunningFunctionTool(a_long_running_function)
+ ```
+
+ Attributes:
+ is_long_running: Whether the tool is a long running operation.
+ """
+
+ def __init__(self, func: Callable):
+ super().__init__(func)
+ self.is_long_running = True
diff --git a/src/google/adk/tools/mcp_tool/__init__.py b/src/google/adk/tools/mcp_tool/__init__.py
new file mode 100644
index 0000000..8b93a1a
--- /dev/null
+++ b/src/google/adk/tools/mcp_tool/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = []
+
+try:
+ from .conversion_utils import adk_to_mcp_tool_type, gemini_to_json_schema
+ from .mcp_tool import MCPTool
+ from .mcp_toolset import MCPToolset
+
+ __all__.extend([
+ 'adk_to_mcp_tool_type',
+ 'gemini_to_json_schema',
+ 'MCPTool',
+ 'MCPToolset',
+ ])
+
+except ImportError as e:
+ import logging
+ import sys
+
+ logger = logging.getLogger(__name__)
+
+ if sys.version_info < (3, 10):
+ logger.warning(
+ 'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
+ ' version.'
+ )
+ else:
+ logger.debug('MCP Tool is not installed')
+ logger.debug(e)
diff --git a/src/google/adk/tools/mcp_tool/conversion_utils.py b/src/google/adk/tools/mcp_tool/conversion_utils.py
new file mode 100644
index 0000000..8afa301
--- /dev/null
+++ b/src/google/adk/tools/mcp_tool/conversion_utils.py
@@ -0,0 +1,161 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict
+from google.genai.types import Schema, Type
+import mcp.types as mcp_types
+from ..base_tool import BaseTool
+
+
+def adk_to_mcp_tool_type(tool: BaseTool) -> mcp_types.Tool:
+ """Convert a Tool in ADK into MCP tool type.
+
+ This function transforms an ADK tool definition into its equivalent
+ representation in the MCP (Model Control Plane) system.
+
+ Args:
+ tool: The ADK tool to convert. It should be an instance of a class derived
+ from `BaseTool`.
+
+ Returns:
+ An object of MCP Tool type, representing the converted tool.
+
+ Examples:
+ # Assuming 'my_tool' is an instance of a BaseTool derived class
+ mcp_tool = adk_to_mcp_tool_type(my_tool)
+ print(mcp_tool)
+ """
+ tool_declaration = tool._get_declaration()
+ if not tool_declaration:
+ input_schema = {}
+ else:
+ input_schema = gemini_to_json_schema(tool._get_declaration().parameters)
+ return mcp_types.Tool(
+ name=tool.name,
+ description=tool.description,
+ inputSchema=input_schema,
+ )
+
+
+def gemini_to_json_schema(gemini_schema: Schema) -> Dict[str, Any]:
+ """Converts a Gemini Schema object into a JSON Schema dictionary.
+
+ Args:
+ gemini_schema: An instance of the Gemini Schema class.
+
+ Returns:
+ A dictionary representing the equivalent JSON Schema.
+
+ Raises:
+ TypeError: If the input is not an instance of the expected Schema class.
+ ValueError: If an invalid Gemini Type enum value is encountered.
+ """
+ if not isinstance(gemini_schema, Schema):
+ raise TypeError(
+ f"Input must be an instance of Schema, got {type(gemini_schema)}"
+ )
+
+ json_schema_dict: Dict[str, Any] = {}
+
+ # Map Type
+ gemini_type = getattr(gemini_schema, "type", None)
+ if gemini_type and gemini_type != Type.TYPE_UNSPECIFIED:
+ json_schema_dict["type"] = gemini_type.lower()
+ else:
+ json_schema_dict["type"] = "null"
+
+ # Map Nullable
+ if getattr(gemini_schema, "nullable", None) == True:
+ json_schema_dict["nullable"] = True
+
+ # --- Map direct fields ---
+ direct_mappings = {
+ "title": "title",
+ "description": "description",
+ "default": "default",
+ "enum": "enum",
+ "format": "format",
+ "example": "example",
+ }
+ for gemini_key, json_key in direct_mappings.items():
+ value = getattr(gemini_schema, gemini_key, None)
+ if value is not None:
+ json_schema_dict[json_key] = value
+
+ # String validation
+ if gemini_type == Type.STRING:
+ str_mappings = {
+ "pattern": "pattern",
+ "min_length": "minLength",
+ "max_length": "maxLength",
+ }
+ for gemini_key, json_key in str_mappings.items():
+ value = getattr(gemini_schema, gemini_key, None)
+ if value is not None:
+ json_schema_dict[json_key] = value
+
+ # Number/Integer validation
+ if gemini_type in (Type.NUMBER, Type.INTEGER):
+ num_mappings = {
+ "minimum": "minimum",
+ "maximum": "maximum",
+ }
+ for gemini_key, json_key in num_mappings.items():
+ value = getattr(gemini_schema, gemini_key, None)
+ if value is not None:
+ json_schema_dict[json_key] = value
+
+ # Array validation (Recursive call for items)
+ if gemini_type == Type.ARRAY:
+ items_schema = getattr(gemini_schema, "items", None)
+ if items_schema is not None:
+ json_schema_dict["items"] = gemini_to_json_schema(items_schema)
+
+ arr_mappings = {
+ "min_items": "minItems",
+ "max_items": "maxItems",
+ }
+ for gemini_key, json_key in arr_mappings.items():
+ value = getattr(gemini_schema, gemini_key, None)
+ if value is not None:
+ json_schema_dict[json_key] = value
+
+ # Object validation (Recursive call for properties)
+ if gemini_type == Type.OBJECT:
+ properties_dict = getattr(gemini_schema, "properties", None)
+ if properties_dict is not None:
+ json_schema_dict["properties"] = {
+ prop_name: gemini_to_json_schema(prop_schema)
+ for prop_name, prop_schema in properties_dict.items()
+ }
+
+ obj_mappings = {
+ "required": "required",
+ "min_properties": "minProperties",
+ "max_properties": "maxProperties",
+ # Note: Ignoring 'property_ordering' as it's not standard JSON Schema
+ }
+ for gemini_key, json_key in obj_mappings.items():
+ value = getattr(gemini_schema, gemini_key, None)
+ if value is not None:
+ json_schema_dict[json_key] = value
+
+ # Map anyOf (Recursive call for subschemas)
+ any_of_list = getattr(gemini_schema, "any_of", None)
+ if any_of_list is not None:
+ json_schema_dict["anyOf"] = [
+ gemini_to_json_schema(sub_schema) for sub_schema in any_of_list
+ ]
+
+ return json_schema_dict
diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py
new file mode 100644
index 0000000..df12381
--- /dev/null
+++ b/src/google/adk/tools/mcp_tool/mcp_tool.py
@@ -0,0 +1,113 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from google.genai.types import FunctionDeclaration
+from typing_extensions import override
+
+# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
+# their Python version to 3.10 if it fails.
+try:
+ from mcp import ClientSession
+ from mcp.types import Tool as McpBaseTool
+except ImportError as e:
+ import sys
+
+ if sys.version_info < (3, 10):
+ raise ImportError(
+ "MCP Tool requires Python 3.10 or above. Please upgrade your Python"
+ " version."
+ ) from e
+ else:
+ raise e
+
+from ..base_tool import BaseTool
+from ...auth.auth_credential import AuthCredential
+from ...auth.auth_schemes import AuthScheme
+from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
+from ..tool_context import ToolContext
+
+
+class MCPTool(BaseTool):
+ """Turns a MCP Tool into a Vertex Agent Framework Tool.
+
+ Internally, the tool initializes from a MCP Tool, and uses the MCP Session to
+ call the tool.
+ """
+
+ def __init__(
+ self,
+ mcp_tool: McpBaseTool,
+ mcp_session: ClientSession,
+ auth_scheme: Optional[AuthScheme] = None,
+ auth_credential: Optional[AuthCredential] | None = None,
+ ):
+ """Initializes a MCPTool.
+
+ This tool wraps a MCP Tool interface and an active MCP Session. It invokes
+ the MCP Tool through executing the tool from remote MCP Session.
+
+ Example:
+ tool = MCPTool(mcp_tool=mcp_tool, mcp_session=mcp_session)
+
+ Args:
+ mcp_tool: The MCP tool to wrap.
+ mcp_session: The MCP session to use to call the tool.
+ auth_scheme: The authentication scheme to use.
+ auth_credential: The authentication credential to use.
+
+ Raises:
+ ValueError: If mcp_tool or mcp_session is None.
+ """
+ if mcp_tool is None:
+ raise ValueError("mcp_tool cannot be None")
+ if mcp_session is None:
+ raise ValueError("mcp_session cannot be None")
+ self.name = mcp_tool.name
+ self.description = mcp_tool.description if mcp_tool.description else ""
+ self.mcp_tool = mcp_tool
+ self.mcp_session = mcp_session
+ # TODO(cheliu): Support passing auth to MCP Server.
+ self.auth_scheme = auth_scheme
+ self.auth_credential = auth_credential
+
+ @override
+ def _get_declaration(self) -> FunctionDeclaration:
+ """Gets the function declaration for the tool.
+
+ Returns:
+ FunctionDeclaration: The Gemini function declaration for the tool.
+ """
+ schema_dict = self.mcp_tool.inputSchema
+ parameters = to_gemini_schema(schema_dict)
+ function_decl = FunctionDeclaration(
+ name=self.name, description=self.description, parameters=parameters
+ )
+ return function_decl
+
+ @override
+ async def run_async(self, *, args, tool_context: ToolContext):
+ """Runs the tool asynchronously.
+
+ Args:
+ args: The arguments as a dict to pass to the tool.
+ tool_context: The tool context from upper level ADK agent.
+
+ Returns:
+ Any: The response from the tool.
+ """
+ # TODO(cheliu): Support passing tool context to MCP Server.
+ response = await self.mcp_session.call_tool(self.name, arguments=args)
+ return response
diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py
new file mode 100644
index 0000000..8e2b64b
--- /dev/null
+++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py
@@ -0,0 +1,272 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import AsyncExitStack
+from types import TracebackType
+from typing import Any, List, Optional, Tuple, Type
+
+# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
+# their Python version to 3.10 if it fails.
+try:
+ from mcp import ClientSession, StdioServerParameters
+ from mcp.client.sse import sse_client
+ from mcp.client.stdio import stdio_client
+ from mcp.types import ListToolsResult
+except ImportError as e:
+ import sys
+
+ if sys.version_info < (3, 10):
+ raise ImportError(
+ 'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
+ ' version.'
+ ) from e
+ else:
+ raise e
+
+from pydantic import BaseModel
+
+from .mcp_tool import MCPTool
+
+
+class SseServerParams(BaseModel):
+ url: str
+ headers: dict[str, Any] | None = None
+ timeout: float = 5
+ sse_read_timeout: float = 60 * 5
+
+
+class MCPToolset:
+ """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
+
+ Usage:
+ Example 1: (using from_server helper):
+ ```
+ async def load_tools():
+ return await MCPToolset.from_server(
+ connection_params=StdioServerParameters(
+ command='npx',
+ args=["-y", "@modelcontextprotocol/server-filesystem"],
+ )
+ )
+
+ # Use the tools in an LLM agent
+ tools, exit_stack = await load_tools()
+ agent = LlmAgent(
+ tools=tools
+ )
+ ...
+ await exit_stack.aclose()
+ ```
+
+ Example 2: (using `async with`):
+
+ ```
+ async def load_tools():
+ async with MCPToolset(
+ connection_params=SseServerParams(url="http://0.0.0.0:8090/sse")
+ ) as toolset:
+ tools = await toolset.load_tools()
+
+ agent = LlmAgent(
+ ...
+ tools=tools
+ )
+ ```
+
+ Example 3: (provide AsyncExitStack):
+ ```
+ async def load_tools():
+ async_exit_stack = AsyncExitStack()
+ toolset = MCPToolset(
+ connection_params=StdioServerParameters(...),
+ )
+ async_exit_stack.enter_async_context(toolset)
+ tools = await toolset.load_tools()
+ agent = LlmAgent(
+ ...
+ tools=tools
+ )
+ ...
+ await async_exit_stack.aclose()
+
+ ```
+
+ Attributes:
+ connection_params: The connection parameters to the MCP server. Can be
+ either `StdioServerParameters` or `SseServerParams`.
+ exit_stack: The async exit stack to manage the connection to the MCP server.
+ session: The MCP session being initialized with the connection.
+ """
+
+ def __init__(
+ self, *, connection_params: StdioServerParameters | SseServerParams
+ ):
+ """Initializes the MCPToolset.
+
+ Usage:
+ Example 1: (using from_server helper):
+ ```
+ async def load_tools():
+ return await MCPToolset.from_server(
+ connection_params=StdioServerParameters(
+ command='npx',
+ args=["-y", "@modelcontextprotocol/server-filesystem"],
+ )
+ )
+
+ # Use the tools in an LLM agent
+ tools, exit_stack = await load_tools()
+ agent = LlmAgent(
+ tools=tools
+ )
+ ...
+ await exit_stack.aclose()
+ ```
+
+ Example 2: (using `async with`):
+
+ ```
+ async def load_tools():
+ async with MCPToolset(
+ connection_params=SseServerParams(url="http://0.0.0.0:8090/sse")
+ ) as toolset:
+ tools = await toolset.load_tools()
+
+ agent = LlmAgent(
+ ...
+ tools=tools
+ )
+ ```
+
+ Example 3: (provide AsyncExitStack):
+ ```
+ async def load_tools():
+ async_exit_stack = AsyncExitStack()
+ toolset = MCPToolset(
+ connection_params=StdioServerParameters(...),
+ )
+ async_exit_stack.enter_async_context(toolset)
+ tools = await toolset.load_tools()
+ agent = LlmAgent(
+ ...
+ tools=tools
+ )
+ ...
+ await async_exit_stack.aclose()
+
+ ```
+
+ Args:
+ connection_params: The connection parameters to the MCP server. Can be:
+ `StdioServerParameters` for using local mcp server (e.g. using `npx` or
+ `python3`); or `SseServerParams` for a local/remote SSE server.
+ """
+ if not connection_params:
+ raise ValueError('Missing connection params in MCPToolset.')
+ self.connection_params = connection_params
+ self.exit_stack = AsyncExitStack()
+
+ @classmethod
+ async def from_server(
+ cls,
+ *,
+ connection_params: StdioServerParameters | SseServerParams,
+ async_exit_stack: Optional[AsyncExitStack] = None,
+ ) -> Tuple[List[MCPTool], AsyncExitStack]:
+ """Retrieve all tools from the MCP connection.
+
+ Usage:
+ ```
+ async def load_tools():
+ tools, exit_stack = await MCPToolset.from_server(
+ connection_params=StdioServerParameters(
+ command='npx',
+ args=["-y", "@modelcontextprotocol/server-filesystem"],
+ )
+ )
+ ```
+
+ Args:
+ connection_params: The connection parameters to the MCP server.
+ async_exit_stack: The async exit stack to use. If not provided, a new
+ AsyncExitStack will be created.
+
+ Returns:
+ A tuple of the list of MCPTools and the AsyncExitStack.
+ - tools: The list of MCPTools.
+ - async_exit_stack: The AsyncExitStack used to manage the connection to
+ the MCP server. Use `await async_exit_stack.aclose()` to close the
+ connection when server shuts down.
+ """
+ toolset = cls(connection_params=connection_params)
+ async_exit_stack = async_exit_stack or AsyncExitStack()
+ await async_exit_stack.enter_async_context(toolset)
+ tools = await toolset.load_tools()
+ return (tools, async_exit_stack)
+
+ async def _initialize(self) -> ClientSession:
+ """Connects to the MCP Server and initializes the ClientSession."""
+ if isinstance(self.connection_params, StdioServerParameters):
+ client = stdio_client(self.connection_params)
+ elif isinstance(self.connection_params, SseServerParams):
+ client = sse_client(
+ url=self.connection_params.url,
+ headers=self.connection_params.headers,
+ timeout=self.connection_params.timeout,
+ sse_read_timeout=self.connection_params.sse_read_timeout,
+ )
+ else:
+ raise ValueError(
+ 'Unable to initialize connection. Connection should be'
+ ' StdioServerParameters or SseServerParams, but got'
+ f' {self.connection_params}'
+ )
+
+ transports = await self.exit_stack.enter_async_context(client)
+ self.session = await self.exit_stack.enter_async_context(
+ ClientSession(*transports)
+ )
+ await self.session.initialize()
+ return self.session
+
+ async def _exit(self):
+ """Closes the connection to MCP Server."""
+ await self.exit_stack.aclose()
+
+ async def load_tools(self) -> List[MCPTool]:
+ """Loads all tools from the MCP Server.
+
+ Returns:
+ A list of MCPTools imported from the MCP Server.
+ """
+ tools_response: ListToolsResult = await self.session.list_tools()
+ return [
+ MCPTool(mcp_tool=tool, mcp_session=self.session)
+ for tool in tools_response.tools
+ ]
+
+ async def __aenter__(self):
+ try:
+ await self._initialize()
+ return self
+ except Exception as e:
+ raise e
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> None:
+ await self._exit()
diff --git a/src/google/adk/tools/openapi_tool/__init__.py b/src/google/adk/tools/openapi_tool/__init__.py
new file mode 100644
index 0000000..20a6a82
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .openapi_spec_parser import OpenAPIToolset
+from .openapi_spec_parser import RestApiTool
+
+__all__ = [
+ 'OpenAPIToolset',
+ 'RestApiTool',
+]
diff --git a/src/google/adk/tools/openapi_tool/auth/__init__.py b/src/google/adk/tools/openapi_tool/auth/__init__.py
new file mode 100644
index 0000000..1b92af3
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/auth/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import auth_helpers
+
+__all__ = [
+ 'auth_helpers',
+]
diff --git a/src/google/adk/tools/openapi_tool/auth/auth_helpers.py b/src/google/adk/tools/openapi_tool/auth/auth_helpers.py
new file mode 100644
index 0000000..197b7fa
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/auth/auth_helpers.py
@@ -0,0 +1,498 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Tuple
+
+from fastapi.openapi.models import APIKey
+from fastapi.openapi.models import APIKeyIn
+from fastapi.openapi.models import HTTPBase
+from fastapi.openapi.models import HTTPBearer
+from fastapi.openapi.models import OAuth2
+from fastapi.openapi.models import OpenIdConnect
+from fastapi.openapi.models import Schema
+from pydantic import BaseModel
+from pydantic import ValidationError
+import requests
+
+from ....auth.auth_credential import AuthCredential
+from ....auth.auth_credential import AuthCredentialTypes
+from ....auth.auth_credential import HttpAuth
+from ....auth.auth_credential import HttpCredentials
+from ....auth.auth_credential import OAuth2Auth
+from ....auth.auth_credential import ServiceAccount
+from ....auth.auth_credential import ServiceAccountCredential
+from ....auth.auth_schemes import AuthScheme
+from ....auth.auth_schemes import AuthSchemeType
+from ....auth.auth_schemes import OpenIdConnectWithConfig
+from ..common.common import ApiParameter
+
+
+class OpenIdConfig(BaseModel):
+ """Represents OpenID Connect configuration.
+
+ Attributes:
+ client_id: The client ID.
+ auth_uri: The authorization URI.
+ token_uri: The token URI.
+ client_secret: The client secret.
+
+ Example:
+ config = OpenIdConfig(
+ client_id="your_client_id",
+ auth_uri="https://accounts.google.com/o/oauth2/auth",
+ token_uri="https://oauth2.googleapis.com/token",
+ client_secret="your_client_secret",
+ redirect
+ )
+ """
+
+ client_id: str
+ auth_uri: str
+ token_uri: str
+ client_secret: str
+ redirect_uri: Optional[str]
+
+
+def token_to_scheme_credential(
+ token_type: Literal["apikey", "oauth2Token"],
+ location: Optional[Literal["header", "query", "cookie"]] = None,
+ name: Optional[str] = None,
+ credential_value: Optional[str] = None,
+) -> Tuple[AuthScheme, AuthCredential]:
+ """Creates a AuthScheme and AuthCredential for API key or bearer token.
+
+ Examples:
+ ```
+ # API Key in header
+ auth_scheme, auth_credential = token_to_scheme_credential("apikey", "header",
+ "X-API-Key", "your_api_key_value")
+
+ # API Key in query parameter
+ auth_scheme, auth_credential = token_to_scheme_credential("apikey", "query",
+ "api_key", "your_api_key_value")
+
+ # OAuth2 Bearer Token in Authorization header
+ auth_scheme, auth_credential = token_to_scheme_credential("oauth2Token",
+ "header", "Authorization", "your_bearer_token_value")
+ ```
+
+ Args:
+ type: 'apikey' or 'oauth2Token'.
+ location: 'header', 'query', or 'cookie' (only 'header' for oauth2Token).
+ name: The name of the header, query parameter, or cookie.
+ credential_value: The value of the API Key/ Token.
+
+ Returns:
+ Tuple: (AuthScheme, AuthCredential)
+
+ Raises:
+ ValueError: For invalid type or location.
+ """
+ if token_type == "apikey":
+ in_: APIKeyIn
+ if location == "header":
+ in_ = APIKeyIn.header
+ elif location == "query":
+ in_ = APIKeyIn.query
+ elif location == "cookie":
+ in_ = APIKeyIn.cookie
+ else:
+ raise ValueError(f"Invalid location for apiKey: {location}")
+ auth_scheme = APIKey(**{
+ "type": AuthSchemeType.apiKey,
+ "in": in_,
+ "name": name,
+ })
+ if credential_value:
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, api_key=credential_value
+ )
+ else:
+ auth_credential = None
+
+ return auth_scheme, auth_credential
+
+ elif token_type == "oauth2Token":
+ # ignore location. OAuth2 Bearer Token is always in Authorization header.
+ auth_scheme = HTTPBearer(
+ bearerFormat="JWT"
+ ) # Common format, can be omitted.
+ if credential_value:
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.HTTP,
+ http=HttpAuth(
+ scheme="bearer",
+ credentials=HttpCredentials(token=credential_value),
+ ),
+ )
+ else:
+ auth_credential = None
+
+ return auth_scheme, auth_credential
+
+ else:
+ raise ValueError(f"Invalid security scheme type: {type}")
+
+
+def service_account_dict_to_scheme_credential(
+ config: Dict[str, Any],
+ scopes: List[str],
+) -> Tuple[AuthScheme, AuthCredential]:
+ """Creates AuthScheme and AuthCredential for Google Service Account.
+
+ Returns a bearer token scheme, and a service account credential.
+
+ Args:
+ config: A ServiceAccount object containing the Google Service Account
+ configuration.
+ scopes: A list of scopes to be used.
+
+ Returns:
+ Tuple: (AuthScheme, AuthCredential)
+ """
+ auth_scheme = HTTPBearer(bearerFormat="JWT")
+ service_account = ServiceAccount(
+ service_account_credential=ServiceAccountCredential.model_construct(
+ **config
+ ),
+ scopes=scopes,
+ )
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
+ service_account=service_account,
+ )
+ return auth_scheme, auth_credential
+
+
+def service_account_scheme_credential(
+ config: ServiceAccount,
+) -> Tuple[AuthScheme, AuthCredential]:
+ """Creates AuthScheme and AuthCredential for Google Service Account.
+
+ Returns a bearer token scheme, and a service account credential.
+
+ Args:
+ config: A ServiceAccount object containing the Google Service Account
+ configuration.
+
+ Returns:
+ Tuple: (AuthScheme, AuthCredential)
+ """
+ auth_scheme = HTTPBearer(bearerFormat="JWT")
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=config
+ )
+ return auth_scheme, auth_credential
+
+
+def openid_dict_to_scheme_credential(
+ config_dict: Dict[str, Any],
+ scopes: List[str],
+ credential_dict: Dict[str, Any],
+) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
+ """Constructs OpenID scheme and credential from configuration and credential dictionaries.
+
+ Args:
+ config_dict: Dictionary containing OpenID Connect configuration, must
+ include at least 'authorization_endpoint' and 'token_endpoint'.
+ scopes: List of scopes to be used.
+ credential_dict: Dictionary containing credential information, must
+ include 'client_id', 'client_secret', and 'scopes'. May optionally
+ include 'redirect_uri'.
+
+ Returns:
+ Tuple: (OpenIdConnectWithConfig, AuthCredential)
+
+ Raises:
+ ValueError: If required fields are missing in the input dictionaries.
+ """
+
+ # Validate and create the OpenIdConnectWithConfig scheme
+ try:
+ config_dict["scopes"] = scopes
+ # If user provides the OpenID Config as a static dict, it may not contain
+ # openIdConnect URL.
+ if "openIdConnectUrl" not in config_dict:
+ config_dict["openIdConnectUrl"] = ""
+ openid_scheme = OpenIdConnectWithConfig.model_validate(config_dict)
+ except ValidationError as e:
+ raise ValueError(f"Invalid OpenID Connect configuration: {e}") from e
+
+ # Attempt to adjust credential_dict if this is a key downloaded from Google
+ # OAuth config
+ if len(list(credential_dict.values())) == 1:
+ credential_value = list(credential_dict.values())[0]
+ if "client_id" in credential_value and "client_secret" in credential_value:
+ credential_dict = credential_value
+
+ # Validate credential_dict
+ required_credential_fields = ["client_id", "client_secret"]
+ missing_fields = [
+ field
+ for field in required_credential_fields
+ if field not in credential_dict
+ ]
+ if missing_fields:
+ raise ValueError(
+ "Missing required fields in credential_dict:"
+ f" {', '.join(missing_fields)}"
+ )
+
+ # Construct AuthCredential
+ auth_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
+ oauth2=OAuth2Auth(
+ client_id=credential_dict["client_id"],
+ client_secret=credential_dict["client_secret"],
+ redirect_uri=credential_dict.get("redirect_uri", None),
+ ),
+ )
+
+ return openid_scheme, auth_credential
+
+
+def openid_url_to_scheme_credential(
+ openid_url: str, scopes: List[str], credential_dict: Dict[str, Any]
+) -> Tuple[OpenIdConnectWithConfig, AuthCredential]:
+ """Constructs OpenID scheme and credential from OpenID URL, scopes, and credential dictionary.
+
+ Fetches OpenID configuration from the provided URL.
+
+ Args:
+ openid_url: The OpenID Connect discovery URL.
+ scopes: List of scopes to be used.
+ credential_dict: Dictionary containing credential information, must
+ include at least "client_id" and "client_secret", may optionally include
+ "redirect_uri" and "scope"
+
+ Returns:
+ Tuple: (AuthScheme, AuthCredential)
+
+ Raises:
+ ValueError: If the OpenID URL is invalid, fetching fails, or required
+ fields are missing.
+ requests.exceptions.RequestException: If there's an error during the
+ HTTP request.
+ """
+ try:
+ response = requests.get(openid_url, timeout=10)
+ response.raise_for_status()
+ config_dict = response.json()
+ except requests.exceptions.RequestException as e:
+ raise ValueError(
+ f"Failed to fetch OpenID configuration from {openid_url}: {e}"
+ ) from e
+ except ValueError as e:
+ raise ValueError(
+ "Invalid JSON response from OpenID configuration endpoint"
+ f" {openid_url}: {e}"
+ ) from e
+
+ # Add openIdConnectUrl to config dict
+ config_dict["openIdConnectUrl"] = openid_url
+
+ return openid_dict_to_scheme_credential(config_dict, scopes, credential_dict)
+
+
+INTERNAL_AUTH_PREFIX = "_auth_prefix_vaf_"
+
+
+def credential_to_param(
+ auth_scheme: AuthScheme,
+ auth_credential: AuthCredential,
+) -> Tuple[Optional[ApiParameter], Optional[Dict[str, Any]]]:
+ """Converts AuthCredential and AuthScheme to a Parameter and a dictionary for additional kwargs.
+
+ This function now supports all credential types returned by the exchangers:
+ - API Key
+ - HTTP Bearer (for Bearer tokens, OAuth2, Service Account, OpenID Connect)
+ - OAuth2 and OpenID Connect (returns None, None, as the token is now a Bearer
+ token)
+ - Service Account (returns None, None, as the token is now a Bearer token)
+
+ Args:
+ auth_scheme: The AuthScheme object.
+ auth_credential: The AuthCredential object.
+
+ Returns:
+ Tuple: (ApiParameter, Dict[str, Any])
+ """
+ if not auth_credential:
+ return None, None
+
+ if (
+ auth_scheme.type_ == AuthSchemeType.apiKey
+ and auth_credential
+ and auth_credential.api_key
+ ):
+ param_name = auth_scheme.name or ""
+ python_name = INTERNAL_AUTH_PREFIX + param_name
+ if auth_scheme.in_ == APIKeyIn.header:
+ param_location = "header"
+ elif auth_scheme.in_ == APIKeyIn.query:
+ param_location = "query"
+ elif auth_scheme.in_ == APIKeyIn.cookie:
+ param_location = "cookie"
+ else:
+ raise ValueError(f"Invalid API Key location: {auth_scheme.in_}")
+
+ param = ApiParameter(
+ original_name=param_name,
+ param_location=param_location,
+ param_schema=Schema(type="string"),
+ description=auth_scheme.description or "",
+ py_name=python_name,
+ )
+ kwargs = {param.py_name: auth_credential.api_key}
+ return param, kwargs
+
+ # TODO(cheliu): Split handling for OpenIDConnect scheme and native HTTPBearer
+ # Scheme
+ elif (
+ auth_credential and auth_credential.auth_type == AuthCredentialTypes.HTTP
+ ):
+ if (
+ auth_credential
+ and auth_credential.http
+ and auth_credential.http.credentials
+ and auth_credential.http.credentials.token
+ ):
+ param = ApiParameter(
+ original_name="Authorization",
+ param_location="header",
+ param_schema=Schema(type="string"),
+ description=auth_scheme.description or "Bearer token",
+ py_name=INTERNAL_AUTH_PREFIX + "Authorization",
+ )
+ kwargs = {
+ param.py_name: f"Bearer {auth_credential.http.credentials.token}"
+ }
+ return param, kwargs
+ elif (
+ auth_credential
+ and auth_credential.http
+ and auth_credential.http.credentials
+ and (
+ auth_credential.http.credentials.username
+ or auth_credential.http.credentials.password
+ )
+ ):
+ # Basic Auth is explicitly NOT supported
+ raise NotImplementedError("Basic Authentication is not supported.")
+ else:
+ raise ValueError("Invalid HTTP auth credentials")
+
+ # Service Account tokens, OAuth2 Tokens and OpenID Tokens are now handled as
+ # Bearer tokens.
+ elif (auth_scheme.type_ == AuthSchemeType.oauth2 and auth_credential) or (
+ auth_scheme.type_ == AuthSchemeType.openIdConnect and auth_credential
+ ):
+ if (
+ auth_credential.http
+ and auth_credential.http.credentials
+ and auth_credential.http.credentials.token
+ ):
+ param = ApiParameter(
+ original_name="Authorization",
+ param_location="header",
+ param_schema=Schema(type="string"),
+ description=auth_scheme.description or "Bearer token",
+ py_name=INTERNAL_AUTH_PREFIX + "Authorization",
+ )
+ kwargs = {
+ param.py_name: f"Bearer {auth_credential.http.credentials.token}"
+ }
+ return param, kwargs
+ return None, None
+ else:
+ raise ValueError("Invalid security scheme and credential combination")
+
+
+def dict_to_auth_scheme(data: Dict[str, Any]) -> AuthScheme:
+ """Converts a dictionary to a FastAPI AuthScheme object.
+
+ Args:
+ data: The dictionary representing the security scheme.
+
+ Returns:
+ A AuthScheme object (APIKey, HTTPBase, OAuth2, OpenIdConnect, or
+ HTTPBearer).
+
+ Raises:
+ ValueError: If the 'type' field is missing or invalid, or if the
+ dictionary cannot be converted to the corresponding Pydantic model.
+
+ Example:
+ ```python
+ api_key_data = {
+ "type": "apiKey",
+ "in": "header",
+ "name": "X-API-Key",
+ }
+ api_key_scheme = dict_to_auth_scheme(api_key_data)
+
+ bearer_data = {
+ "type": "http",
+ "scheme": "bearer",
+ "bearerFormat": "JWT",
+ }
+ bearer_scheme = dict_to_auth_scheme(bearer_data)
+
+
+ oauth2_data = {
+ "type": "oauth2",
+ "flows": {
+ "authorizationCode": {
+ "authorizationUrl": "https://example.com/auth",
+ "tokenUrl": "https://example.com/token",
+ }
+ }
+ }
+ oauth2_scheme = dict_to_auth_scheme(oauth2_data)
+
+ openid_data = {
+ "type": "openIdConnect",
+ "openIdConnectUrl": "https://example.com/.well-known/openid-configuration"
+ }
+ openid_scheme = dict_to_auth_scheme(openid_data)
+
+
+ ```
+ """
+ if "type" not in data:
+ raise ValueError("Missing 'type' field in security scheme dictionary.")
+
+ security_type = data["type"]
+ try:
+ if security_type == "apiKey":
+ return APIKey.model_validate(data)
+ elif security_type == "http":
+ if data.get("scheme") == "bearer":
+ return HTTPBearer.model_validate(data)
+ else:
+ return HTTPBase.model_validate(data) # Generic HTTP
+ elif security_type == "oauth2":
+ return OAuth2.model_validate(data)
+ elif security_type == "openIdConnect":
+ return OpenIdConnect.model_validate(data)
+ else:
+ raise ValueError(f"Invalid security scheme type: {security_type}")
+
+ except ValidationError as e:
+ raise ValueError(f"Invalid security scheme data: {e}") from e
diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/__init__.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/__init__.py
new file mode 100644
index 0000000..ac97ddb
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .auto_auth_credential_exchanger import AutoAuthCredentialExchanger
+from .base_credential_exchanger import BaseAuthCredentialExchanger
+from .oauth2_exchanger import OAuth2CredentialExchanger
+from .service_account_exchanger import ServiceAccountCredentialExchanger
+
+__all__ = [
+ 'AutoAuthCredentialExchanger',
+ 'BaseAuthCredentialExchanger',
+ 'OAuth2CredentialExchanger',
+ 'ServiceAccountCredentialExchanger',
+]
diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/auto_auth_credential_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/auto_auth_credential_exchanger.py
new file mode 100644
index 0000000..1bcc6fd
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/auto_auth_credential_exchanger.py
@@ -0,0 +1,105 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+from typing import Optional
+from typing import Type
+
+from .....auth.auth_credential import AuthCredential
+from .....auth.auth_credential import AuthCredentialTypes
+from .....auth.auth_schemes import AuthScheme
+from .base_credential_exchanger import BaseAuthCredentialExchanger
+from .oauth2_exchanger import OAuth2CredentialExchanger
+from .service_account_exchanger import ServiceAccountCredentialExchanger
+
+
+class AutoAuthCredentialExchanger(BaseAuthCredentialExchanger):
+ """Automatically selects the appropriate credential exchanger based on the auth scheme.
+
+ Optionally, an override can be provided to use a specific exchanger for a
+ given auth scheme.
+
+ Example (common case):
+ ```
+ exchanger = AutoAuthCredentialExchanger()
+ auth_credential = exchanger.exchange_credential(
+ auth_scheme=service_account_scheme,
+ auth_credential=service_account_credential,
+ )
+ # Returns an oauth token in the form of a bearer token.
+ ```
+
+ Example (use CustomAuthExchanger for OAuth2):
+ ```
+ exchanger = AutoAuthCredentialExchanger(
+ custom_exchangers={
+ AuthScheme.OAUTH2: CustomAuthExchanger,
+ }
+ )
+ ```
+
+ Attributes:
+ exchangers: A dictionary mapping auth scheme to credential exchanger class.
+ """
+
+ def __init__(
+ self,
+ custom_exchangers: Optional[
+ Dict[str, Type[BaseAuthCredentialExchanger]]
+ ] = None,
+ ):
+ """Initializes the AutoAuthCredentialExchanger.
+
+ Args:
+ custom_exchangers: Optional dictionary for adding or overriding auth
+ exchangers. The key is the auth scheme, and the value is the credential
+ exchanger class.
+ """
+ self.exchangers = {
+ AuthCredentialTypes.OAUTH2: OAuth2CredentialExchanger,
+ AuthCredentialTypes.OPEN_ID_CONNECT: OAuth2CredentialExchanger,
+ AuthCredentialTypes.SERVICE_ACCOUNT: ServiceAccountCredentialExchanger,
+ }
+
+ if custom_exchangers:
+ self.exchangers.update(custom_exchangers)
+
+ def exchange_credential(
+ self,
+ auth_scheme: AuthScheme,
+ auth_credential: Optional[AuthCredential] = None,
+ ) -> Optional[AuthCredential]:
+ """Automatically exchanges for the credential uses the appropriate credential exchanger.
+
+ Args:
+ auth_scheme (AuthScheme): The security scheme.
+ auth_credential (AuthCredential): Optional. The authentication
+ credential.
+
+ Returns: (AuthCredential)
+ A new AuthCredential object containing the exchanged credential.
+
+ """
+ if not auth_credential:
+ return None
+
+ exchanger_class = self.exchangers.get(
+ auth_credential.auth_type if auth_credential else None
+ )
+
+ if not exchanger_class:
+ return auth_credential
+
+ exchanger = exchanger_class()
+ return exchanger.exchange_credential(auth_scheme, auth_credential)
diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py
new file mode 100644
index 0000000..44ceec5
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py
@@ -0,0 +1,55 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from typing import Optional
+
+from .....auth.auth_credential import (
+ AuthCredential,
+)
+from .....auth.auth_schemes import AuthScheme
+
+
+class AuthCredentialMissingError(Exception):
+ """Exception raised when required authentication credentials are missing."""
+
+ def __init__(self, message: str):
+ super().__init__(message)
+ self.message = message
+
+
+class BaseAuthCredentialExchanger:
+ """Base class for authentication credential exchangers."""
+
+ @abc.abstractmethod
+ def exchange_credential(
+ self,
+ auth_scheme: AuthScheme,
+ auth_credential: Optional[AuthCredential] = None,
+ ) -> AuthCredential:
+ """Exchanges the provided authentication credential for a usable token/credential.
+
+ Args:
+ auth_scheme: The security scheme.
+ auth_credential: The authentication credential.
+
+ Returns:
+ An updated AuthCredential object containing the fetched credential.
+ For simple schemes like API key, it may return the original credential
+ if no exchange is needed.
+
+ Raises:
+ NotImplementedError: If the method is not implemented by a subclass.
+ """
+ raise NotImplementedError("Subclasses must implement exchange_credential.")
diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py
new file mode 100644
index 0000000..ee1935d
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py
@@ -0,0 +1,117 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Credential fetcher for OpenID Connect."""
+
+from typing import Optional
+
+from .....auth.auth_credential import AuthCredential
+from .....auth.auth_credential import AuthCredentialTypes
+from .....auth.auth_credential import HttpAuth
+from .....auth.auth_credential import HttpCredentials
+from .....auth.auth_schemes import AuthScheme
+from .....auth.auth_schemes import AuthSchemeType
+from .base_credential_exchanger import BaseAuthCredentialExchanger
+
+
+class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
+ """Fetches credentials for OAuth2 and OpenID Connect."""
+
+ def _check_scheme_credential_type(
+ self,
+ auth_scheme: AuthScheme,
+ auth_credential: Optional[AuthCredential] = None,
+ ):
+ if not auth_credential:
+ raise ValueError(
+ "auth_credential is empty. Please create AuthCredential using"
+ " OAuth2Auth."
+ )
+
+ if auth_scheme.type_ not in (
+ AuthSchemeType.openIdConnect,
+ AuthSchemeType.oauth2,
+ ):
+ raise ValueError(
+ "Invalid security scheme, expect AuthSchemeType.openIdConnect or "
+ f"AuthSchemeType.oauth2 auth scheme, but got {auth_scheme.type_}"
+ )
+
+ if not auth_credential.oauth2 and not auth_credential.http:
+ raise ValueError(
+ "auth_credential is not configured with oauth2. Please"
+ " create AuthCredential and set OAuth2Auth."
+ )
+
+ def generate_auth_token(
+ self,
+ auth_credential: Optional[AuthCredential] = None,
+ ) -> AuthCredential:
+ """Generates an auth token from the authorization response.
+
+ Args:
+ auth_scheme: The OpenID Connect or OAuth2 auth scheme.
+ auth_credential: The auth credential.
+
+ Returns:
+ An AuthCredential object containing the HTTP bearer access token. If the
+ HTTO bearer token cannot be generated, return the origianl credential
+ """
+
+ if "access_token" not in auth_credential.oauth2.token:
+ return auth_credential
+
+ # Return the access token as a bearer token.
+ updated_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
+ http=HttpAuth(
+ scheme="bearer",
+ credentials=HttpCredentials(
+ token=auth_credential.oauth2.token["access_token"]
+ ),
+ ),
+ )
+ return updated_credential
+
+ def exchange_credential(
+ self,
+ auth_scheme: AuthScheme,
+ auth_credential: Optional[AuthCredential] = None,
+ ) -> AuthCredential:
+ """Exchanges the OpenID Connect auth credential for an access token or an auth URI.
+
+ Args:
+ auth_scheme: The auth scheme.
+ auth_credential: The auth credential.
+
+ Returns:
+ An AuthCredential object containing the HTTP Bearer access token.
+
+ Raises:
+ ValueError: If the auth scheme or auth credential is invalid.
+ """
+ # TODO(cheliu): Implement token refresh flow
+
+ self._check_scheme_credential_type(auth_scheme, auth_credential)
+
+ # If token is already HTTPBearer token, do nothing assuming that this token
+ # is valid.
+ if auth_credential.http:
+ return auth_credential
+
+ # If access token is exchanged, exchange a HTTPBearer token.
+ if auth_credential.oauth2.token:
+ return self.generate_auth_token(auth_credential)
+
+ return None
diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py
new file mode 100644
index 0000000..4dbcb6e
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py
@@ -0,0 +1,97 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Credential fetcher for Google Service Account."""
+
+from typing import Optional
+
+import google.auth
+from google.auth.transport.requests import Request
+from google.oauth2 import service_account
+import google.oauth2.credentials
+
+from .....auth.auth_credential import (
+ AuthCredential,
+ AuthCredentialTypes,
+ HttpAuth,
+ HttpCredentials,
+)
+from .....auth.auth_schemes import AuthScheme
+from .base_credential_exchanger import AuthCredentialMissingError, BaseAuthCredentialExchanger
+
+
+class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
+ """Fetches credentials for Google Service Account.
+
+ Uses the default service credential if `use_default_credential = True`.
+ Otherwise, uses the service account credential provided in the auth
+ credential.
+ """
+
+ def exchange_credential(
+ self,
+ auth_scheme: AuthScheme,
+ auth_credential: Optional[AuthCredential] = None,
+ ) -> AuthCredential:
+ """Exchanges the service account auth credential for an access token.
+
+ If auth_credential contains a service account credential, it will be used
+ to fetch an access token. Otherwise, the default service credential will be
+ used for fetching an access token.
+
+ Args:
+ auth_scheme: The auth scheme.
+ auth_credential: The auth credential.
+
+ Returns:
+ An AuthCredential in HTTPBearer format, containing the access token.
+ """
+ if (
+ auth_credential is None
+ or auth_credential.service_account is None
+ or (
+ auth_credential.service_account.service_account_credential is None
+ and not auth_credential.service_account.use_default_credential
+ )
+ ):
+ raise AuthCredentialMissingError(
+ "Service account credentials are missing. Please provide them, or set"
+ " `use_default_credential = True` to use application default"
+ " credential in a hosted service like Cloud Run."
+ )
+
+ try:
+ if auth_credential.service_account.use_default_credential:
+ credentials, _ = google.auth.default()
+ else:
+ config = auth_credential.service_account
+ credentials = service_account.Credentials.from_service_account_info(
+ config.service_account_credential.model_dump(), scopes=config.scopes
+ )
+
+ credentials.refresh(Request())
+
+ updated_credential = AuthCredential(
+ auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
+ http=HttpAuth(
+ scheme="bearer",
+ credentials=HttpCredentials(token=credentials.token),
+ ),
+ )
+ return updated_credential
+
+ except Exception as e:
+ raise AuthCredentialMissingError(
+ f"Failed to exchange service account token: {e}"
+ ) from e
diff --git a/src/google/adk/tools/openapi_tool/common/__init__.py b/src/google/adk/tools/openapi_tool/common/__init__.py
new file mode 100644
index 0000000..35faa05
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/common/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import common
+
+__all__ = [
+ 'common',
+]
diff --git a/src/google/adk/tools/openapi_tool/common/common.py b/src/google/adk/tools/openapi_tool/common/common.py
new file mode 100644
index 0000000..cc3bca7
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/common/common.py
@@ -0,0 +1,300 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keyword
+import re
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+from fastapi.openapi.models import Response
+from fastapi.openapi.models import Schema
+from pydantic import BaseModel
+from pydantic import Field
+from pydantic import model_serializer
+
+
+def to_snake_case(text: str) -> str:
+ """Converts a string into snake_case.
+
+ Handles lowerCamelCase, UpperCamelCase, or space-separated case, acronyms
+ (e.g., "REST API") and consecutive uppercase letters correctly. Also handles
+ mixed cases with and without spaces.
+
+ Examples:
+ ```
+ to_snake_case('camelCase') -> 'camel_case'
+ to_snake_case('UpperCamelCase') -> 'upper_camel_case'
+ to_snake_case('space separated') -> 'space_separated'
+ ```
+
+ Args:
+ text: The input string.
+
+ Returns:
+ The snake_case version of the string.
+ """
+
+ # Handle spaces and non-alphanumeric characters (replace with underscores)
+ text = re.sub(r'[^a-zA-Z0-9]+', '_', text)
+
+ # Insert underscores before uppercase letters (handling both CamelCases)
+ text = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', text) # lowerCamelCase
+ text = re.sub(
+ r'([A-Z]+)([A-Z][a-z])', r'\1_\2', text
+ ) # UpperCamelCase and acronyms
+
+ # Convert to lowercase
+ text = text.lower()
+
+ # Remove consecutive underscores (clean up extra underscores)
+ text = re.sub(r'_+', '_', text)
+
+ # Remove leading and trailing underscores
+ text = text.strip('_')
+
+ return text
+
+
+def rename_python_keywords(s: str, prefix: str = 'param_') -> str:
+ """Renames Python keywords by adding a prefix.
+
+ Example:
+ ```
+ rename_python_keywords('if') -> 'param_if'
+ rename_python_keywords('for') -> 'param_for'
+ ```
+
+ Args:
+ s: The input string.
+ prefix: The prefix to add to the keyword.
+
+ Returns:
+ The renamed string.
+ """
+ if keyword.iskeyword(s):
+ return prefix + s
+ return s
+
+
+class ApiParameter(BaseModel):
+ """Data class representing a function parameter."""
+
+ original_name: str
+ param_location: str
+ param_schema: Union[str, Schema]
+ description: Optional[str] = ''
+ py_name: Optional[str] = ''
+ type_value: type[Any] = Field(default=None, init_var=False)
+ type_hint: str = Field(default=None, init_var=False)
+
+ def model_post_init(self, _: Any):
+ self.py_name = (
+ self.py_name
+ if self.py_name
+ else rename_python_keywords(to_snake_case(self.original_name))
+ )
+ if isinstance(self.param_schema, str):
+ self.param_schema = Schema.model_validate_json(self.param_schema)
+
+ self.description = self.description or self.param_schema.description or ''
+ self.type_value = TypeHintHelper.get_type_value(self.param_schema)
+ self.type_hint = TypeHintHelper.get_type_hint(self.param_schema)
+ return self
+
+ @model_serializer
+ def _serialize(self):
+ return {
+ 'original_name': self.original_name,
+ 'param_location': self.param_location,
+ 'param_schema': self.param_schema,
+ 'description': self.description,
+ 'py_name': self.py_name,
+ }
+
+ def __str__(self):
+ return f'{self.py_name}: {self.type_hint}'
+
+ def to_arg_string(self):
+ """Converts the parameter to an argument string for function call."""
+ return f'{self.py_name}={self.py_name}'
+
+ def to_dict_property(self):
+ """Converts the parameter to a key:value string for dict property."""
+ return f'"{self.py_name}": {self.py_name}'
+
+ def to_pydoc_string(self):
+ """Converts the parameter to a PyDoc parameter docstr."""
+ return PydocHelper.generate_param_doc(self)
+
+
+class TypeHintHelper:
+ """Helper class for generating type hints."""
+
+ @staticmethod
+ def get_type_value(schema: Schema) -> Any:
+ """Generates the Python type value for a given parameter."""
+ param_type = schema.type if schema.type else Any
+
+ if param_type == 'integer':
+ return int
+ elif param_type == 'number':
+ return float
+ elif param_type == 'boolean':
+ return bool
+ elif param_type == 'string':
+ return str
+ elif param_type == 'array':
+ items_type = Any
+ if schema.items and schema.items.type:
+ items_type = schema.items.type
+
+ if items_type == 'object':
+ return List[Dict[str, Any]]
+ else:
+ type_map = {
+ 'integer': int,
+ 'number': float,
+ 'boolean': bool,
+ 'string': str,
+ 'object': Dict[str, Any],
+ 'array': List[Any],
+ }
+ return List[type_map.get(items_type, 'Any')]
+ elif param_type == 'object':
+ return Dict[str, Any]
+ else:
+ return Any
+
+ @staticmethod
+ def get_type_hint(schema: Schema) -> str:
+ """Generates the Python type in string for a given parameter."""
+ param_type = schema.type if schema.type else 'Any'
+
+ if param_type == 'integer':
+ return 'int'
+ elif param_type == 'number':
+ return 'float'
+ elif param_type == 'boolean':
+ return 'bool'
+ elif param_type == 'string':
+ return 'str'
+ elif param_type == 'array':
+ items_type = 'Any'
+ if schema.items and schema.items.type:
+ items_type = schema.items.type
+
+ if items_type == 'object':
+ return 'List[Dict[str, Any]]'
+ else:
+ type_map = {
+ 'integer': 'int',
+ 'number': 'float',
+ 'boolean': 'bool',
+ 'string': 'str',
+ }
+ return f"List[{type_map.get(items_type, 'Any')}]"
+ elif param_type == 'object':
+ return 'Dict[str, Any]'
+ else:
+ return 'Any'
+
+
+class PydocHelper:
+ """Helper class for generating PyDoc strings."""
+
+ @staticmethod
+ def generate_param_doc(
+ param: ApiParameter,
+ ) -> str:
+ """Generates a parameter documentation string.
+
+ Args:
+ param: ApiParameter - The parameter to generate the documentation for.
+
+ Returns:
+ str: The generated parameter Python documentation string.
+ """
+ description = param.description.strip() if param.description else ''
+ param_doc = f'{param.py_name} ({param.type_hint}): {description}'
+
+ if param.param_schema.type == 'object':
+ properties = param.param_schema.properties
+ if properties:
+ param_doc += ' Object properties:\n'
+ for prop_name, prop_details in properties.items():
+ prop_desc = prop_details.description or ''
+ prop_type = TypeHintHelper.get_type_hint(prop_details)
+ param_doc += f' {prop_name} ({prop_type}): {prop_desc}\n'
+
+ return param_doc
+
+ @staticmethod
+ def generate_return_doc(responses: Dict[str, Response]) -> str:
+ """Generates a return value documentation string.
+
+ Args:
+ responses: Dict[str, TypedDict[Response]] - Response in an OpenAPI
+ Operation
+
+ Returns:
+ str: The generated return value Python documentation string.
+ """
+ return_doc = ''
+
+ # Only consider 2xx responses for return type hinting.
+ # Returns the 2xx response with the smallest status code number and with
+ # content defined.
+ sorted_responses = sorted(responses.items(), key=lambda item: int(item[0]))
+ qualified_response = next(
+ filter(
+ lambda r: r[0].startswith('2') and r[1].content,
+ sorted_responses,
+ ),
+ None,
+ )
+ if not qualified_response:
+ return ''
+ response_details = qualified_response[1]
+
+ description = (response_details.description or '').strip()
+ content = response_details.content or {}
+
+ # Generate return type hint and properties for the first response type.
+ # TODO(cheliu): Handle multiple content types.
+ for _, schema_details in content.items():
+ schema = schema_details.schema_ or {}
+
+ # Use a dummy Parameter object for return type hinting.
+ dummy_param = ApiParameter(
+ original_name='', param_location='', param_schema=schema
+ )
+ return_doc = f'Returns ({dummy_param.type_hint}): {description}'
+
+ response_type = schema.type or 'Any'
+ if response_type != 'object':
+ break
+ properties = schema.properties
+ if not properties:
+ break
+ return_doc += ' Object properties:\n'
+ for prop_name, prop_details in properties.items():
+ prop_desc = prop_details.description or ''
+ prop_type = TypeHintHelper.get_type_hint(prop_details)
+ return_doc += f' {prop_name} ({prop_type}): {prop_desc}\n'
+ break
+
+ return return_doc
diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py
new file mode 100644
index 0000000..171d5e2
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .openapi_spec_parser import OpenApiSpecParser, OperationEndpoint, ParsedOperation
+from .openapi_toolset import OpenAPIToolset
+from .operation_parser import OperationParser
+from .rest_api_tool import AuthPreparationState, RestApiTool, snake_to_lower_camel, to_gemini_schema
+from .tool_auth_handler import ToolAuthHandler
+
+__all__ = [
+ 'OpenApiSpecParser',
+ 'OperationEndpoint',
+ 'ParsedOperation',
+ 'OpenAPIToolset',
+ 'OperationParser',
+ 'RestApiTool',
+ 'to_gemini_schema',
+ 'snake_to_lower_camel',
+ 'AuthPreparationState',
+ 'ToolAuthHandler',
+]
diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py
new file mode 100644
index 0000000..9535953
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py
@@ -0,0 +1,231 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from fastapi.openapi.models import Operation
+from pydantic import BaseModel
+
+from ....auth.auth_credential import AuthCredential
+from ....auth.auth_schemes import AuthScheme
+from ..common.common import ApiParameter
+from ..common.common import to_snake_case
+from .operation_parser import OperationParser
+
+
+class OperationEndpoint(BaseModel):
+ base_url: str
+ path: str
+ method: str
+
+
+class ParsedOperation(BaseModel):
+ name: str
+ description: str
+ endpoint: OperationEndpoint
+ operation: Operation
+ parameters: List[ApiParameter]
+ return_value: ApiParameter
+ auth_scheme: Optional[AuthScheme] = None
+ auth_credential: Optional[AuthCredential] = None
+ additional_context: Optional[Any] = None
+
+
+class OpenApiSpecParser:
+ """Generates Python code, JSON schema, and callables for an OpenAPI operation.
+
+ This class takes an OpenApiOperation object and provides methods to generate:
+ 1. A string representation of a Python function that handles the operation.
+ 2. A JSON schema representing the input parameters of the operation.
+ 3. A callable Python object (a function) that can execute the operation.
+ """
+
+ def parse(self, openapi_spec_dict: Dict[str, Any]) -> List[ParsedOperation]:
+ """Extracts an OpenAPI spec dict into a list of ParsedOperation objects.
+
+ ParsedOperation objects are further used for generating RestApiTool.
+
+ Args:
+ openapi_spec_dict: A dictionary representing the OpenAPI specification.
+
+ Returns:
+ A list of ParsedOperation objects.
+ """
+
+ openapi_spec_dict = self._resolve_references(openapi_spec_dict)
+ operations = self._collect_operations(openapi_spec_dict)
+ return operations
+
+ def _collect_operations(
+ self, openapi_spec: Dict[str, Any]
+ ) -> List[ParsedOperation]:
+ """Collects operations from an OpenAPI spec."""
+ operations = []
+
+ # Taking first server url, or default to empty string if not present
+ base_url = ""
+ if openapi_spec.get("servers"):
+ base_url = openapi_spec["servers"][0].get("url", "")
+
+ # Get global security scheme (if any)
+ global_scheme_name = None
+ if openapi_spec.get("security"):
+ # Use first scheme by default.
+ scheme_names = list(openapi_spec["security"][0].keys())
+ global_scheme_name = scheme_names[0] if scheme_names else None
+
+ auth_schemes = openapi_spec.get("components", {}).get("securitySchemes", {})
+
+ for path, path_item in openapi_spec.get("paths", {}).items():
+ if path_item is None:
+ continue
+
+ for method in (
+ "get",
+ "post",
+ "put",
+ "delete",
+ "patch",
+ "head",
+ "options",
+ "trace",
+ ):
+ operation_dict = path_item.get(method)
+ if operation_dict is None:
+ continue
+
+ # If operation ID is missing, assign an operation id based on path
+ # and method
+ if "operationId" not in operation_dict:
+ temp_id = to_snake_case(f"{path}_{method}")
+ operation_dict["operationId"] = temp_id
+
+ url = OperationEndpoint(base_url=base_url, path=path, method=method)
+ operation = Operation.model_validate(operation_dict)
+ operation_parser = OperationParser(operation)
+
+ # Check for operation-specific auth scheme
+ auth_scheme_name = operation_parser.get_auth_scheme_name()
+ auth_scheme_name = (
+ auth_scheme_name if auth_scheme_name else global_scheme_name
+ )
+ auth_scheme = (
+ auth_schemes.get(auth_scheme_name) if auth_scheme_name else None
+ )
+
+ parsed_op = ParsedOperation(
+ name=operation_parser.get_function_name(),
+ description=operation.description or operation.summary or "",
+ endpoint=url,
+ operation=operation,
+ parameters=operation_parser.get_parameters(),
+ return_value=operation_parser.get_return_value(),
+ auth_scheme=auth_scheme,
+ auth_credential=None, # Placeholder
+ additional_context={},
+ )
+ operations.append(parsed_op)
+
+ return operations
+
+ def _resolve_references(self, openapi_spec: Dict[str, Any]) -> Dict[str, Any]:
+ """Recursively resolves all $ref references in an OpenAPI specification.
+
+ Handles circular references correctly.
+
+ Args:
+ openapi_spec: A dictionary representing the OpenAPI specification.
+
+ Returns:
+ A dictionary representing the OpenAPI specification with all references
+ resolved.
+ """
+
+ openapi_spec = copy.deepcopy(openapi_spec) # Work on a copy
+ resolved_cache = {} # Cache resolved references
+
+ def resolve_ref(ref_string, current_doc):
+ """Resolves a single $ref string."""
+ parts = ref_string.split("/")
+ if parts[0] != "#":
+ raise ValueError(f"External references not supported: {ref_string}")
+
+ current = current_doc
+ for part in parts[1:]:
+ if part in current:
+ current = current[part]
+ else:
+ return None # Reference not found
+ return current
+
+ def recursive_resolve(obj, current_doc, seen_refs=None):
+ """Recursively resolves references, handling circularity.
+
+ Args:
+ obj: The object to traverse.
+ current_doc: Document to search for refs.
+ seen_refs: A set to track already-visited references (for circularity
+ detection).
+
+ Returns:
+ The resolved object.
+ """
+ if seen_refs is None:
+ seen_refs = set() # Initialize the set if it's the first call
+
+ if isinstance(obj, dict):
+ if "$ref" in obj and isinstance(obj["$ref"], str):
+ ref_string = obj["$ref"]
+
+ # Check for circularity
+ if ref_string in seen_refs and ref_string not in resolved_cache:
+ # Circular reference detected! Return a *copy* of the object,
+ # but *without* the $ref. This breaks the cycle while
+ # still maintaining the overall structure.
+ return {k: v for k, v in obj.items() if k != "$ref"}
+
+ seen_refs.add(ref_string) # Add the reference to the set
+
+ # Check if we have a cached resolved value
+ if ref_string in resolved_cache:
+ return copy.deepcopy(resolved_cache[ref_string])
+
+ resolved_value = resolve_ref(ref_string, current_doc)
+ if resolved_value is not None:
+ # Recursively resolve the *resolved* value,
+ # passing along the 'seen_refs' set
+ resolved_value = recursive_resolve(
+ resolved_value, current_doc, seen_refs
+ )
+ resolved_cache[ref_string] = resolved_value
+ return copy.deepcopy(resolved_value) # return the cached result
+ else:
+ return obj # return original if no resolved value.
+
+ else:
+ new_dict = {}
+ for key, value in obj.items():
+ new_dict[key] = recursive_resolve(value, current_doc, seen_refs)
+ return new_dict
+
+ elif isinstance(obj, list):
+ return [recursive_resolve(item, current_doc, seen_refs) for item in obj]
+ else:
+ return obj
+
+ return recursive_resolve(openapi_spec, openapi_spec)
diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py
new file mode 100644
index 0000000..dfe0b18
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py
@@ -0,0 +1,144 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+from typing import Any
+from typing import Dict
+from typing import Final
+from typing import List
+from typing import Literal
+from typing import Optional
+
+import yaml
+
+from ....auth.auth_credential import AuthCredential
+from ....auth.auth_schemes import AuthScheme
+from .openapi_spec_parser import OpenApiSpecParser
+from .rest_api_tool import RestApiTool
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAPIToolset:
+ """Class for parsing OpenAPI spec into a list of RestApiTool.
+
+ Usage:
+ ```
+ # Initialize OpenAPI toolset from a spec string.
+ openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
+ spec_str_type="json")
+ # Or, initialize OpenAPI toolset from a spec dictionary.
+ openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
+
+ # Add all tools to an agent.
+ agent = Agent(
+ tools=[*openapi_toolset.get_tools()]
+ )
+ # Or, add a single tool to an agent.
+ agent = Agent(
+ tools=[openapi_toolset.get_tool('tool_name')]
+ )
+ ```
+ """
+
+ def __init__(
+ self,
+ *,
+ spec_dict: Optional[Dict[str, Any]] = None,
+ spec_str: Optional[str] = None,
+ spec_str_type: Literal["json", "yaml"] = "json",
+ auth_scheme: Optional[AuthScheme] = None,
+ auth_credential: Optional[AuthCredential] = None,
+ ):
+ """Initializes the OpenAPIToolset.
+
+ Usage:
+ ```
+ # Initialize OpenAPI toolset from a spec string.
+ openapi_toolset = OpenAPIToolset(spec_str=openapi_spec_str,
+ spec_str_type="json")
+ # Or, initialize OpenAPI toolset from a spec dictionary.
+ openapi_toolset = OpenAPIToolset(spec_dict=openapi_spec_dict)
+
+ # Add all tools to an agent.
+ agent = Agent(
+ tools=[*openapi_toolset.get_tools()]
+ )
+ # Or, add a single tool to an agent.
+ agent = Agent(
+ tools=[openapi_toolset.get_tool('tool_name')]
+ )
+ ```
+
+ Args:
+ spec_dict: The OpenAPI spec dictionary. If provided, it will be used
+ instead of loading the spec from a string.
+ spec_str: The OpenAPI spec string in JSON or YAML format. It will be used
+ when spec_dict is not provided.
+ spec_str_type: The type of the OpenAPI spec string. Can be "json" or
+ "yaml".
+ auth_scheme: The auth scheme to use for all tools. Use AuthScheme or use
+ helpers in `google.adk.tools.openapi_tool.auth.auth_helpers`
+ auth_credential: The auth credential to use for all tools. Use
+ AuthCredential or use helpers in
+ `google.adk.tools.openapi_tool.auth.auth_helpers`
+ """
+ if not spec_dict:
+ spec_dict = self._load_spec(spec_str, spec_str_type)
+ self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
+ if auth_scheme or auth_credential:
+ self._configure_auth_all(auth_scheme, auth_credential)
+
+ def _configure_auth_all(
+ self, auth_scheme: AuthScheme, auth_credential: AuthCredential
+ ):
+ """Configure auth scheme and credential for all tools."""
+
+ for tool in self.tools:
+ if auth_scheme:
+ tool.configure_auth_scheme(auth_scheme)
+ if auth_credential:
+ tool.configure_auth_credential(auth_credential)
+
+ def get_tools(self) -> List[RestApiTool]:
+ """Get all tools in the toolset."""
+ return self.tools
+
+ def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
+ """Get a tool by name."""
+ matching_tool = filter(lambda t: t.name == tool_name, self.tools)
+ return next(matching_tool, None)
+
+ def _load_spec(
+ self, spec_str: str, spec_type: Literal["json", "yaml"]
+ ) -> Dict[str, Any]:
+ """Loads the OpenAPI spec string into adictionary."""
+ if spec_type == "json":
+ return json.loads(spec_str)
+ elif spec_type == "yaml":
+ return yaml.safe_load(spec_str)
+ else:
+ raise ValueError(f"Unsupported spec type: {spec_type}")
+
+ def _parse(self, openapi_spec_dict: Dict[str, Any]) -> List[RestApiTool]:
+ """Parse OpenAPI spec into a list of RestApiTool."""
+ operations = OpenApiSpecParser().parse(openapi_spec_dict)
+
+ tools = []
+ for o in operations:
+ tool = RestApiTool.from_parsed_operation(o)
+ logger.info("Parsed tool: %s", tool.name)
+ tools.append(tool)
+ return tools
diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py
new file mode 100644
index 0000000..baed0f4
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py
@@ -0,0 +1,260 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from textwrap import dedent
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+from fastapi.encoders import jsonable_encoder
+from fastapi.openapi.models import Operation
+from fastapi.openapi.models import Parameter
+from fastapi.openapi.models import Schema
+
+from ..common.common import ApiParameter
+from ..common.common import PydocHelper
+from ..common.common import to_snake_case
+
+
+class OperationParser:
+ """Generates parameters for Python functions from an OpenAPI operation.
+
+ This class processes an OpenApiOperation object and provides helper methods
+ to extract information needed to generate Python function declarations,
+ docstrings, signatures, and JSON schemas. It handles parameter processing,
+ name deduplication, and type hint generation.
+ """
+
+ def __init__(
+ self, operation: Union[Operation, Dict[str, Any], str], should_parse=True
+ ):
+ """Initializes the OperationParser with an OpenApiOperation.
+
+ Args:
+ operation: The OpenApiOperation object or a dictionary to process.
+ should_parse: Whether to parse the operation during initialization.
+ """
+ if isinstance(operation, dict):
+ self.operation = Operation.model_validate(operation)
+ elif isinstance(operation, str):
+ self.operation = Operation.model_validate_json(operation)
+ else:
+ self.operation = operation
+
+ self.params: List[ApiParameter] = []
+ self.return_value: Optional[ApiParameter] = None
+ if should_parse:
+ self._process_operation_parameters()
+ self._process_request_body()
+ self._process_return_value()
+ self._dedupe_param_names()
+
+ @classmethod
+ def load(
+ cls,
+ operation: Union[Operation, Dict[str, Any]],
+ params: List[ApiParameter],
+ return_value: Optional[ApiParameter] = None,
+ ) -> 'OperationParser':
+ parser = cls(operation, should_parse=False)
+ parser.params = params
+ parser.return_value = return_value
+ return parser
+
+ def _process_operation_parameters(self):
+ """Processes parameters from the OpenAPI operation."""
+ parameters = self.operation.parameters or []
+ for param in parameters:
+ if isinstance(param, Parameter):
+ original_name = param.name
+ description = param.description or ''
+ location = param.in_ or ''
+ schema = param.schema_ or {} # Use schema_ instead of .schema
+
+ self.params.append(
+ ApiParameter(
+ original_name=original_name,
+ param_location=location,
+ param_schema=schema,
+ description=description,
+ )
+ )
+
+ def _process_request_body(self):
+ """Processes the request body from the OpenAPI operation."""
+ request_body = self.operation.requestBody
+ if not request_body:
+ return
+
+ content = request_body.content or {}
+ if not content:
+ return
+
+ # If request body is an object, expand the properties as parameters
+ for _, media_type_object in content.items():
+ schema = media_type_object.schema_ or {}
+ description = request_body.description or ''
+
+ if schema and schema.type == 'object':
+ for prop_name, prop_details in schema.properties.items():
+ self.params.append(
+ ApiParameter(
+ original_name=prop_name,
+ param_location='body',
+ param_schema=prop_details,
+ description=prop_details.description,
+ )
+ )
+
+ elif schema and schema.type == 'array':
+ self.params.append(
+ ApiParameter(
+ original_name='array',
+ param_location='body',
+ param_schema=schema,
+ description=description,
+ )
+ )
+ else:
+ self.params.append(
+ # Empty name for unnamed body param
+ ApiParameter(
+ original_name='',
+ param_location='body',
+ param_schema=schema,
+ description=description,
+ )
+ )
+ break # Process first mime type only
+
+ def _dedupe_param_names(self):
+ """Deduplicates parameter names to avoid conflicts."""
+ params_cnt = {}
+ for param in self.params:
+ name = param.py_name
+ if name not in params_cnt:
+ params_cnt[name] = 0
+ else:
+ params_cnt[name] += 1
+ param.py_name = f'{name}_{params_cnt[name] -1}'
+
+ def _process_return_value(self) -> Parameter:
+ """Returns a Parameter object representing the return type."""
+ responses = self.operation.responses or {}
+ # Default to Any if no 2xx response or if schema is missing
+ return_schema = Schema(type='Any')
+
+ # Take the 20x response with the smallest response code.
+ valid_codes = list(
+ filter(lambda k: k.startswith('2'), list(responses.keys()))
+ )
+ min_20x_status_code = min(valid_codes) if valid_codes else None
+
+ if min_20x_status_code and responses[min_20x_status_code].content:
+ content = responses[min_20x_status_code].content
+ for mime_type in content:
+ if content[mime_type].schema_:
+ return_schema = content[mime_type].schema_
+ break
+
+ self.return_value = ApiParameter(
+ original_name='',
+ param_location='',
+ param_schema=return_schema,
+ )
+
+ def get_function_name(self) -> str:
+ """Returns the generated function name."""
+ operation_id = self.operation.operationId
+ if not operation_id:
+ raise ValueError('Operation ID is missing')
+ return to_snake_case(operation_id)[:60]
+
+ def get_return_type_hint(self) -> str:
+ """Returns the return type hint string (like 'str', 'int', etc.)."""
+ return self.return_value.type_hint
+
+ def get_return_type_value(self) -> Any:
+ """Returns the return type value (like str, int, List[str], etc.)."""
+ return self.return_value.type_value
+
+ def get_parameters(self) -> List[ApiParameter]:
+ """Returns the list of Parameter objects."""
+ return self.params
+
+ def get_return_value(self) -> ApiParameter:
+ """Returns the list of Parameter objects."""
+ return self.return_value
+
+ def get_auth_scheme_name(self) -> str:
+ """Returns the name of the auth scheme for this operation from the spec."""
+ if self.operation.security:
+ scheme_name = list(self.operation.security[0].keys())[0]
+ return scheme_name
+ return ''
+
+ def get_pydoc_string(self) -> str:
+ """Returns the generated PyDoc string."""
+ pydoc_params = [param.to_pydoc_string() for param in self.params]
+ pydoc_description = (
+ self.operation.summary or self.operation.description or ''
+ )
+ pydoc_return = PydocHelper.generate_return_doc(
+ self.operation.responses or {}
+ )
+ pydoc_arg_list = chr(10).join(
+ f' {param_doc}' for param_doc in pydoc_params
+ )
+ return dedent(f"""
+ \"\"\"{pydoc_description}
+
+ Args:
+ {pydoc_arg_list}
+
+ {pydoc_return}
+ \"\"\"
+ """).strip()
+
+ def get_json_schema(self) -> Dict[str, Any]:
+ """Returns the JSON schema for the function arguments."""
+ properties = {
+ p.py_name: jsonable_encoder(p.param_schema, exclude_none=True)
+ for p in self.params
+ }
+ return {
+ 'properties': properties,
+ 'required': [p.py_name for p in self.params],
+ 'title': f"{self.operation.operationId or 'unnamed'}_Arguments",
+ 'type': 'object',
+ }
+
+ def get_signature_parameters(self) -> List[inspect.Parameter]:
+ """Returns a list of inspect.Parameter objects for the function."""
+ return [
+ inspect.Parameter(
+ param.py_name,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ annotation=param.type_value,
+ )
+ for param in self.params
+ ]
+
+ def get_annotations(self) -> Dict[str, Any]:
+ """Returns a dictionary of parameter annotations for the function."""
+ annotations = {p.py_name: p.type_value for p in self.params}
+ annotations['return'] = self.get_return_type_value()
+ return annotations
diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py
new file mode 100644
index 0000000..b6962b7
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py
@@ -0,0 +1,496 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+from fastapi.openapi.models import Operation
+from google.genai.types import FunctionDeclaration
+from google.genai.types import Schema
+import requests
+from typing_extensions import override
+
+from ....auth.auth_credential import AuthCredential
+from ....auth.auth_schemes import AuthScheme
+from ....tools import BaseTool
+from ...tool_context import ToolContext
+from ..auth.auth_helpers import credential_to_param
+from ..auth.auth_helpers import dict_to_auth_scheme
+from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
+from ..common.common import ApiParameter
+from ..common.common import to_snake_case
+from .openapi_spec_parser import OperationEndpoint
+from .openapi_spec_parser import ParsedOperation
+from .operation_parser import OperationParser
+from .tool_auth_handler import ToolAuthHandler
+
+
+def snake_to_lower_camel(snake_case_string: str):
+ """Converts a snake_case string to a lower_camel_case string.
+
+ Args:
+ snake_case_string: The input snake_case string.
+
+ Returns:
+ The lower_camel_case string.
+ """
+ if "_" not in snake_case_string:
+ return snake_case_string
+
+ return "".join([
+ s.lower() if i == 0 else s.capitalize()
+ for i, s in enumerate(snake_case_string.split("_"))
+ ])
+
+
+def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
+ """Converts an OpenAPI schema dictionary to a Gemini Schema object.
+
+ Args:
+ openapi_schema: The OpenAPI schema dictionary.
+
+ Returns:
+ A Pydantic Schema object. Returns None if input is None.
+ Raises TypeError if input is not a dict.
+ """
+ if openapi_schema is None:
+ return None
+
+ if not isinstance(openapi_schema, dict):
+ raise TypeError("openapi_schema must be a dictionary")
+
+ pydantic_schema_data = {}
+
+ # Adding this to force adding a type to an empty dict
+ # This avoid "... one_of or any_of must specify a type" error
+ if not openapi_schema.get("type"):
+ openapi_schema["type"] = "object"
+
+ # Adding this to avoid "properties: should be non-empty for OBJECT type" error
+ # See b/385165182
+ if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
+ "properties"
+ ):
+ openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}
+
+ for key, value in openapi_schema.items():
+ snake_case_key = to_snake_case(key)
+ # Check if the snake_case_key exists in the Schema model's fields.
+ if snake_case_key in Schema.model_fields:
+ if snake_case_key in ["title", "default", "format"]:
+ # Ignore these fields as Gemini backend doesn't recognize them, and will
+ # throw exception if they appear in the schema.
+ # Format: properties[expiration].format: only 'enum' and 'date-time' are
+ # supported for STRING type
+ continue
+ if snake_case_key == "properties" and isinstance(value, dict):
+ pydantic_schema_data[snake_case_key] = {
+ k: to_gemini_schema(v) for k, v in value.items()
+ }
+ elif snake_case_key == "items" and isinstance(value, dict):
+ pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
+ elif snake_case_key == "any_of" and isinstance(value, list):
+ pydantic_schema_data[snake_case_key] = [
+ to_gemini_schema(item) for item in value
+ ]
+ # Important: Handle cases where the OpenAPI schema might contain lists
+ # or other structures that need to be recursively processed.
+ elif isinstance(value, list) and snake_case_key not in (
+ "enum",
+ "required",
+ "property_ordering",
+ ):
+ new_list = []
+ for item in value:
+ if isinstance(item, dict):
+ new_list.append(to_gemini_schema(item))
+ else:
+ new_list.append(item)
+ pydantic_schema_data[snake_case_key] = new_list
+ elif isinstance(value, dict) and snake_case_key not in ("properties"):
+ # Handle dictionary which is neither properties or items
+ pydantic_schema_data[snake_case_key] = to_gemini_schema(value)
+ else:
+ # Simple value assignment (int, str, bool, etc.)
+ pydantic_schema_data[snake_case_key] = value
+
+ return Schema(**pydantic_schema_data)
+
+
+AuthPreparationState = Literal["pending", "done"]
+
+
+class RestApiTool(BaseTool):
+ """A generic tool that interacts with a REST API.
+
+ * Generates request params and body
+ * Attaches auth credentials to API call.
+
+ Example:
+ ```
+ # Each API operation in the spec will be turned into its own tool
+ # Name of the tool is the operationId of that operation, in snake case
+ operations = OperationGenerator().parse(openapi_spec_dict)
+ tool = [RestApiTool.from_parsed_operation(o) for o in operations]
+ ```
+ """
+
+ def __init__(
+ self,
+ name: str,
+ description: str,
+ endpoint: Union[OperationEndpoint, str],
+ operation: Union[Operation, str],
+ auth_scheme: Optional[Union[AuthScheme, str]] = None,
+ auth_credential: Optional[Union[AuthCredential, str]] = None,
+ should_parse_operation=True,
+ ):
+ """Initializes the RestApiTool with the given parameters.
+
+ To generate RestApiTool from OpenAPI Specs, use OperationGenerator.
+ Example:
+ ```
+ # Each API operation in the spec will be turned into its own tool
+ # Name of the tool is the operationId of that operation, in snake case
+ operations = OperationGenerator().parse(openapi_spec_dict)
+ tool = [RestApiTool.from_parsed_operation(o) for o in operations]
+ ```
+
+ Hint: Use google.adk.tools.openapi_tool.auth.auth_helpers to construct
+ auth_scheme and auth_credential.
+
+ Args:
+ name: The name of the tool.
+ description: The description of the tool.
+ endpoint: Include the base_url, path, and method of the tool.
+ operation: Pydantic object or a dict. Representing the OpenAPI Operation
+ object
+ (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#operation-object)
+ auth_scheme: The auth scheme of the tool. Representing the OpenAPI
+ SecurityScheme object
+ (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#security-scheme-object)
+ auth_credential: The authentication credential of the tool.
+ should_parse_operation: Whether to parse the operation.
+ """
+ # Gemini restrict the length of function name to be less than 64 characters
+ self.name = name[:60]
+ self.description = description
+ self.endpoint = (
+ OperationEndpoint.model_validate_json(endpoint)
+ if isinstance(endpoint, str)
+ else endpoint
+ )
+ self.operation = (
+ Operation.model_validate_json(operation)
+ if isinstance(operation, str)
+ else operation
+ )
+ self.auth_credential, self.auth_scheme = None, None
+
+ self.configure_auth_credential(auth_credential)
+ self.configure_auth_scheme(auth_scheme)
+
+ # Private properties
+ self.credential_exchanger = AutoAuthCredentialExchanger()
+ if should_parse_operation:
+ self._operation_parser = OperationParser(self.operation)
+
+ @classmethod
+ def from_parsed_operation(cls, parsed: ParsedOperation) -> "RestApiTool":
+ """Initializes the RestApiTool from a ParsedOperation object.
+
+ Args:
+ parsed: A ParsedOperation object.
+
+ Returns:
+ A RestApiTool object.
+ """
+ operation_parser = OperationParser.load(
+ parsed.operation, parsed.parameters, parsed.return_value
+ )
+
+ tool_name = to_snake_case(operation_parser.get_function_name())
+ generated = cls(
+ name=tool_name,
+ description=parsed.operation.description
+ or parsed.operation.summary
+ or "",
+ endpoint=parsed.endpoint,
+ operation=parsed.operation,
+ auth_scheme=parsed.auth_scheme,
+ auth_credential=parsed.auth_credential,
+ )
+ generated._operation_parser = operation_parser
+ return generated
+
+ @classmethod
+ def from_parsed_operation_str(
+ cls, parsed_operation_str: str
+ ) -> "RestApiTool":
+ """Initializes the RestApiTool from a dict.
+
+ Args:
+ parsed: A dict representation of a ParsedOperation object.
+
+ Returns:
+ A RestApiTool object.
+ """
+ operation = ParsedOperation.model_validate_json(parsed_operation_str)
+ return RestApiTool.from_parsed_operation(operation)
+
+ @override
+ def _get_declaration(self) -> FunctionDeclaration:
+ """Returns the function declaration in the Gemini Schema format."""
+ schema_dict = self._operation_parser.get_json_schema()
+ parameters = to_gemini_schema(schema_dict)
+ function_decl = FunctionDeclaration(
+ name=self.name, description=self.description, parameters=parameters
+ )
+ return function_decl
+
+ def configure_auth_scheme(
+ self, auth_scheme: Union[AuthScheme, Dict[str, Any]]
+ ):
+ """Configures the authentication scheme for the API call.
+
+ Args:
+ auth_scheme: AuthScheme|dict -: The authentication scheme. The dict is
+ converted to a AuthScheme object.
+ """
+ if isinstance(auth_scheme, dict):
+ auth_scheme = dict_to_auth_scheme(auth_scheme)
+ self.auth_scheme = auth_scheme
+
+ def configure_auth_credential(
+ self, auth_credential: Optional[Union[AuthCredential, str]] = None
+ ):
+ """Configures the authentication credential for the API call.
+
+ Args:
+ auth_credential: AuthCredential|dict - The authentication credential.
+ The dict is converted to an AuthCredential object.
+ """
+ if isinstance(auth_credential, str):
+ auth_credential = AuthCredential.model_validate_json(auth_credential)
+ self.auth_credential = auth_credential
+
+ def _prepare_auth_request_params(
+ self,
+ auth_scheme: AuthScheme,
+ auth_credential: AuthCredential,
+ ) -> Tuple[List[ApiParameter], Dict[str, Any]]:
+ # Handle Authentication
+ if not auth_scheme or not auth_credential:
+ return
+
+ return credential_to_param(auth_scheme, auth_credential)
+
+ def _prepare_request_params(
+ self, parameters: List[ApiParameter], kwargs: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ """Prepares the request parameters for the API call.
+
+ Args:
+ parameters: A list of ApiParameter objects representing the parameters
+ for the API call.
+ kwargs: The keyword arguments passed to the call function from the Tool
+ caller.
+
+ Returns:
+ A dictionary containing the request parameters for the API call. This
+ initializes a requests.request() call.
+
+ Example:
+ self._prepare_request_params({"input_id": "test-id"})
+ """
+ method = self.endpoint.method.lower()
+ if not method:
+ raise ValueError("Operation method not found.")
+
+ path_params: Dict[str, Any] = {}
+ query_params: Dict[str, Any] = {}
+ header_params: Dict[str, Any] = {}
+ cookie_params: Dict[str, Any] = {}
+
+ params_map: Dict[str, ApiParameter] = {p.py_name: p for p in parameters}
+
+ # Fill in path, query, header and cookie parameters to the request
+ for param_k, v in kwargs.items():
+ param_obj = params_map.get(param_k)
+ if not param_obj:
+ continue # If input arg not in the ApiParameter list, ignore it.
+
+ original_k = param_obj.original_name
+ param_location = param_obj.param_location
+
+ if param_location == "path":
+ path_params[original_k] = v
+ elif param_location == "query":
+ if v:
+ query_params[original_k] = v
+ elif param_location == "header":
+ header_params[original_k] = v
+ elif param_location == "cookie":
+ cookie_params[original_k] = v
+
+ # Construct URL
+ base_url = self.endpoint.base_url or ""
+ base_url = base_url[:-1] if base_url.endswith("/") else base_url
+ url = f"{base_url}{self.endpoint.path.format(**path_params)}"
+
+ # Construct body
+ body_kwargs: Dict[str, Any] = {}
+ request_body = self.operation.requestBody
+ if request_body:
+ for mime_type, media_type_object in request_body.content.items():
+ schema = media_type_object.schema_
+ body_data = None
+
+ if schema.type == "object":
+ body_data = {}
+ for param in parameters:
+ if param.param_location == "body" and param.py_name in kwargs:
+ body_data[param.original_name] = kwargs[param.py_name]
+
+ elif schema.type == "array":
+ for param in parameters:
+ if param.param_location == "body" and param.py_name == "array":
+ body_data = kwargs.get("array")
+ break
+ else: # like string
+ for param in parameters:
+ # original_name = '' indicating this param applies to the full body.
+ if param.param_location == "body" and not param.original_name:
+ body_data = (
+ kwargs.get(param.py_name) if param.py_name in kwargs else None
+ )
+ break
+
+ if mime_type == "application/json" or mime_type.endswith("+json"):
+ if body_data is not None:
+ body_kwargs["json"] = body_data
+ elif mime_type == "application/x-www-form-urlencoded":
+ body_kwargs["data"] = body_data
+ elif mime_type == "multipart/form-data":
+ body_kwargs["files"] = body_data
+ elif mime_type == "application/octet-stream":
+ body_kwargs["data"] = body_data
+ elif mime_type == "text/plain":
+ body_kwargs["data"] = body_data
+
+ if mime_type:
+ header_params["Content-Type"] = mime_type
+ break # Process only the first mime_type
+
+ filtered_query_params: Dict[str, Any] = {
+ k: v for k, v in query_params.items() if v is not None
+ }
+
+ request_params: Dict[str, Any] = {
+ "method": method,
+ "url": url,
+ "params": filtered_query_params,
+ "headers": header_params,
+ "cookies": cookie_params,
+ **body_kwargs,
+ }
+
+ return request_params
+
+ @override
+ async def run_async(
+ self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
+ ) -> Dict[str, Any]:
+ return self.call(args=args, tool_context=tool_context)
+
+ def call(
+ self, *, args: dict[str, Any], tool_context: Optional[ToolContext]
+ ) -> Dict[str, Any]:
+ """Executes the REST API call.
+
+ Args:
+ args: Keyword arguments representing the operation parameters.
+ tool_context: The tool context (not used here, but required by the
+ interface).
+
+ Returns:
+ The API response as a dictionary.
+ """
+ # Prepare auth credentials for the API call
+ tool_auth_handler = ToolAuthHandler.from_tool_context(
+ tool_context, self.auth_scheme, self.auth_credential
+ )
+ auth_result = tool_auth_handler.prepare_auth_credentials()
+ auth_state, auth_scheme, auth_credential = (
+ auth_result.state,
+ auth_result.auth_scheme,
+ auth_result.auth_credential,
+ )
+
+ if auth_state == "pending":
+ return {
+ "pending": True,
+ "message": "Needs your authorization to access your data.",
+ }
+
+ # Attach parameters from auth into main parameters list
+ api_params, api_args = self._operation_parser.get_parameters().copy(), args
+ if auth_credential:
+ # Attach parameters from auth into main parameters list
+ auth_param, auth_args = self._prepare_auth_request_params(
+ auth_scheme, auth_credential
+ )
+ if auth_param and auth_args:
+ api_params = [auth_param] + api_params
+ api_args.update(auth_args)
+
+ # Got all parameters. Call the API.
+ request_params = self._prepare_request_params(api_params, api_args)
+ response = requests.request(**request_params)
+
+ # Parse API response
+ try:
+ response.raise_for_status() # Raise HTTPError for bad responses
+ return response.json() # Try to decode JSON
+ except requests.exceptions.HTTPError:
+ error_details = response.content.decode("utf-8")
+ return {
+ "error": (
+ f"Tool {self.name} execution failed. Analyze this execution error"
+ " and your inputs. Retry with adjustments if applicable. But"
+ " make sure don't retry more than 3 times. Execution Error:"
+ f" {error_details}"
+ )
+ }
+ except ValueError:
+ return {"text": response.text} # Return text if not JSON
+
+ def __str__(self):
+ return (
+ f'RestApiTool(name="{self.name}", description="{self.description}",'
+ f' endpoint="{self.endpoint}")'
+ )
+
+ def __repr__(self):
+ return (
+ f'RestApiTool(name="{self.name}", description="{self.description}",'
+ f' endpoint="{self.endpoint}", operation="{self.operation}",'
+ f' auth_scheme="{self.auth_scheme}",'
+ f' auth_credential="{self.auth_credential}")'
+ )
diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py
new file mode 100644
index 0000000..eac1ef3
--- /dev/null
+++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py
@@ -0,0 +1,268 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from typing import Literal
+from typing import Optional
+
+from fastapi.encoders import jsonable_encoder
+from pydantic import BaseModel
+
+from ....auth.auth_credential import AuthCredential
+from ....auth.auth_credential import AuthCredentialTypes
+from ....auth.auth_schemes import AuthScheme
+from ....auth.auth_schemes import AuthSchemeType
+from ....auth.auth_tool import AuthConfig
+from ...tool_context import ToolContext
+from ..auth.credential_exchangers.auto_auth_credential_exchanger import AutoAuthCredentialExchanger
+from ..auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
+from ..auth.credential_exchangers.base_credential_exchanger import BaseAuthCredentialExchanger
+
+logger = logging.getLogger(__name__)
+
+AuthPreparationState = Literal["pending", "done"]
+
+
+class AuthPreparationResult(BaseModel):
+ """Result of the credential preparation process."""
+
+ state: AuthPreparationState
+ auth_scheme: Optional[AuthScheme] = None
+ auth_credential: Optional[AuthCredential] = None
+
+
+class ToolContextCredentialStore:
+ """Handles storage and retrieval of credentials within a ToolContext."""
+
+ def __init__(self, tool_context: ToolContext):
+ self.tool_context = tool_context
+
+ def get_credential_key(
+ self,
+ auth_scheme: Optional[AuthScheme],
+ auth_credential: Optional[AuthCredential],
+ ) -> str:
+ """Generates a unique key for the given auth scheme and credential."""
+ scheme_name = (
+ f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
+ if auth_scheme
+ else ""
+ )
+ credential_name = (
+ f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
+ if auth_credential
+ else ""
+ )
+ # no need to prepend temp: namespace, session state is a copy, changes to
+ # it won't be persisted , only changes in event_action.state_delta will be
+ # persisted. temp: namespace will be cleared after current run. but tool
+ # want access token to be there stored across runs
+
+ return f"{scheme_name}_{credential_name}_existing_exchanged_credential"
+
+ def get_credential(
+ self,
+ auth_scheme: Optional[AuthScheme],
+ auth_credential: Optional[AuthCredential],
+ ) -> Optional[AuthCredential]:
+ if not self.tool_context:
+ return None
+
+ token_key = self.get_credential_key(auth_scheme, auth_credential)
+ # TODO try not to use session state, this looks a hacky way, depend on
+ # session implementation, we don't want session to persist the token,
+ # meanwhile we want the token shared across runs.
+ serialized_credential = self.tool_context.state.get(token_key)
+ if not serialized_credential:
+ return None
+ return AuthCredential.model_validate(serialized_credential)
+
+ def store_credential(
+ self,
+ key: str,
+ auth_credential: Optional[AuthCredential],
+ ):
+ if self.tool_context:
+ serializable_credential = jsonable_encoder(
+ auth_credential, exclude_none=True
+ )
+ self.tool_context.state[key] = serializable_credential
+
+ def remove_credential(self, key: str):
+ del self.tool_context.state[key]
+
+
+class ToolAuthHandler:
+ """Handles the preparation and exchange of authentication credentials for tools."""
+
+ def __init__(
+ self,
+ tool_context: ToolContext,
+ auth_scheme: Optional[AuthScheme],
+ auth_credential: Optional[AuthCredential],
+ credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
+ credential_store: Optional["ToolContextCredentialStore"] = None,
+ ):
+ self.tool_context = tool_context
+ self.auth_scheme = (
+ auth_scheme.model_copy(deep=True) if auth_scheme else None
+ )
+ self.auth_credential = (
+ auth_credential.model_copy(deep=True) if auth_credential else None
+ )
+ self.credential_exchanger = (
+ credential_exchanger or AutoAuthCredentialExchanger()
+ )
+ self.credential_store = credential_store
+ self.should_store_credential = True
+
+ @classmethod
+ def from_tool_context(
+ cls,
+ tool_context: ToolContext,
+ auth_scheme: Optional[AuthScheme],
+ auth_credential: Optional[AuthCredential],
+ credential_exchanger: Optional[BaseAuthCredentialExchanger] = None,
+ ) -> "ToolAuthHandler":
+ """Creates a ToolAuthHandler instance from a ToolContext."""
+ credential_store = ToolContextCredentialStore(tool_context)
+ return cls(
+ tool_context,
+ auth_scheme,
+ auth_credential,
+ credential_exchanger,
+ credential_store,
+ )
+
+ def _handle_existing_credential(
+ self,
+ ) -> Optional[AuthPreparationResult]:
+ """Checks for and returns an existing, exchanged credential."""
+ if self.credential_store:
+ existing_credential = self.credential_store.get_credential(
+ self.auth_scheme, self.auth_credential
+ )
+ if existing_credential:
+ return AuthPreparationResult(
+ state="done",
+ auth_scheme=self.auth_scheme,
+ auth_credential=existing_credential,
+ )
+ return None
+
+ def _exchange_credential(
+ self, auth_credential: AuthCredential
+ ) -> Optional[AuthPreparationResult]:
+ """Handles an OpenID Connect authorization response."""
+
+ exchanged_credential = None
+ try:
+ exchanged_credential = self.credential_exchanger.exchange_credential(
+ self.auth_scheme, auth_credential
+ )
+ except Exception as e:
+ logger.error("Failed to exchange credential: %s", e)
+ return exchanged_credential
+
+ def _store_credential(self, auth_credential: AuthCredential) -> None:
+ """stores the auth_credential."""
+
+ if self.credential_store:
+ key = self.credential_store.get_credential_key(
+ self.auth_scheme, self.auth_credential
+ )
+ self.credential_store.store_credential(key, auth_credential)
+
+ def _reqeust_credential(self) -> None:
+ """Handles the case where an OpenID Connect or OAuth2 authentication request is needed."""
+ if self.auth_scheme.type_ in (
+ AuthSchemeType.openIdConnect,
+ AuthSchemeType.oauth2,
+ ):
+ if not self.auth_credential or not self.auth_credential.oauth2:
+ raise ValueError(
+ f"auth_credential is empty for scheme {self.auth_scheme.type_}."
+ "Please create AuthCredential using OAuth2Auth."
+ )
+
+ if not self.auth_credential.oauth2.client_id:
+ raise AuthCredentialMissingError(
+ "OAuth2 credentials client_id is missing."
+ )
+
+ if not self.auth_credential.oauth2.client_secret:
+ raise AuthCredentialMissingError(
+ "OAuth2 credentials client_secret is missing."
+ )
+
+ self.tool_context.request_credential(
+ AuthConfig(
+ auth_scheme=self.auth_scheme,
+ raw_auth_credential=self.auth_credential,
+ )
+ )
+ return None
+
+ def _get_auth_response(self) -> AuthCredential:
+ return self.tool_context.get_auth_response(
+ AuthConfig(
+ auth_scheme=self.auth_scheme,
+ raw_auth_credential=self.auth_credential,
+ )
+ )
+
+ def _request_credential(self, auth_config: AuthConfig):
+ if not self.tool_context:
+ return
+ self.tool_context.request_credential(auth_config)
+
+ def prepare_auth_credentials(
+ self,
+ ) -> AuthPreparationResult:
+ """Prepares authentication credentials, handling exchange and user interaction."""
+
+ # no auth is needed
+ if not self.auth_scheme:
+ return AuthPreparationResult(state="done")
+
+ # Check for existing credential.
+ existing_result = self._handle_existing_credential()
+ if existing_result:
+ return existing_result
+
+ # fetch credential from adk framework
+ # Some auth scheme like OAuth2 AuthCode & OpenIDConnect may require
+ # multi-step exchange:
+ # client_id , client_secret -> auth_uri -> auth_code -> access_token
+ # -> bearer token
+ # adk framework supports exchange access_token already
+ fetched_credential = self._get_auth_response() or self.auth_credential
+
+ exchanged_credential = self._exchange_credential(fetched_credential)
+
+ if exchanged_credential:
+ self._store_credential(exchanged_credential)
+ return AuthPreparationResult(
+ state="done",
+ auth_scheme=self.auth_scheme,
+ auth_credential=exchanged_credential,
+ )
+ else:
+ self._reqeust_credential()
+ return AuthPreparationResult(
+ state="pending",
+ auth_scheme=self.auth_scheme,
+ auth_credential=self.auth_credential,
+ )
diff --git a/src/google/adk/tools/preload_memory_tool.py b/src/google/adk/tools/preload_memory_tool.py
new file mode 100644
index 0000000..ebc682d
--- /dev/null
+++ b/src/google/adk/tools/preload_memory_tool.py
@@ -0,0 +1,72 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+from typing_extensions import override
+
+from .base_tool import BaseTool
+from .tool_context import ToolContext
+
+if TYPE_CHECKING:
+ from ..models import LlmRequest
+
+
+class PreloadMemoryTool(BaseTool):
+ """A tool that preloads the memory for the current user."""
+
+ def __init__(self):
+ # Name and description are not used because this tool only
+ # changes llm_request.
+ super().__init__(name='preload_memory', description='preload_memory')
+
+ @override
+ async def process_llm_request(
+ self,
+ *,
+ tool_context: ToolContext,
+ llm_request: LlmRequest,
+ ) -> None:
+ parts = tool_context.user_content.parts
+ if not parts or not parts[0].text:
+ return
+ query = parts[0].text
+ response = tool_context.search_memory(query)
+ if not response.memories:
+ return
+ memory_text = ''
+ for memory in response.memories:
+ time_str = datetime.fromtimestamp(memory.events[0].timestamp).isoformat()
+ memory_text += f'Time: {time_str}\n'
+ for event in memory.events:
+ # TODO: support multi-part content.
+ if (
+ event.content
+ and event.content.parts
+ and event.content.parts[0].text
+ ):
+ memory_text += f'{event.author}: {event.content.parts[0].text}\n'
+ si = f"""The following content is from your previous conversations with the user.
+They may be useful for answering the user's current query.
+
+{memory_text}
+
+"""
+ llm_request.append_instructions([si])
+
+
+preload_memory_tool = PreloadMemoryTool()
diff --git a/src/google/adk/tools/retrieval/__init__.py b/src/google/adk/tools/retrieval/__init__.py
new file mode 100644
index 0000000..424b75a
--- /dev/null
+++ b/src/google/adk/tools/retrieval/__init__.py
@@ -0,0 +1,36 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_retrieval_tool import BaseRetrievalTool
+from .files_retrieval import FilesRetrieval
+from .llama_index_retrieval import LlamaIndexRetrieval
+
+__all__ = [
+ 'BaseRetrievalTool',
+ 'FilesRetrieval',
+ 'LlamaIndexRetrieval',
+]
+
+try:
+ from .vertex_ai_rag_retrieval import VertexAiRagRetrieval
+
+ __all__.append('VertexAiRagRetrieval')
+except ImportError:
+ import logging
+
+ logger = logging.getLogger(__name__)
+ logger.debug(
+ 'The Vertex sdk is not installed. If you want to use the Vertex RAG with'
+ ' agents, please install it. If not, you can ignore this warning.'
+ )
diff --git a/src/google/adk/tools/retrieval/base_retrieval_tool.py b/src/google/adk/tools/retrieval/base_retrieval_tool.py
new file mode 100644
index 0000000..64f3ec9
--- /dev/null
+++ b/src/google/adk/tools/retrieval/base_retrieval_tool.py
@@ -0,0 +1,37 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.genai import types
+from typing_extensions import override
+
+from ..base_tool import BaseTool
+
+
+class BaseRetrievalTool(BaseTool):
+
+ @override
+ def _get_declaration(self) -> types.FunctionDeclaration:
+ return types.FunctionDeclaration(
+ name=self.name,
+ description=self.description,
+ parameters=types.Schema(
+ type=types.Type.OBJECT,
+ properties={
+ 'query': types.Schema(
+ type=types.Type.STRING,
+ description='The query to retrieve.',
+ ),
+ },
+ ),
+ )
diff --git a/src/google/adk/tools/retrieval/files_retrieval.py b/src/google/adk/tools/retrieval/files_retrieval.py
new file mode 100644
index 0000000..d65a709
--- /dev/null
+++ b/src/google/adk/tools/retrieval/files_retrieval.py
@@ -0,0 +1,33 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Provides data for the agent."""
+
+from llama_index.core import SimpleDirectoryReader
+from llama_index.core import VectorStoreIndex
+
+from .llama_index_retrieval import LlamaIndexRetrieval
+
+
+class FilesRetrieval(LlamaIndexRetrieval):
+
+ def __init__(self, *, name: str, description: str, input_dir: str):
+
+ self.input_dir = input_dir
+
+ print(f'Loading data from {input_dir}')
+ retriever = VectorStoreIndex.from_documents(
+ SimpleDirectoryReader(input_dir).load_data()
+ ).as_retriever()
+ super().__init__(name=name, description=description, retriever=retriever)
diff --git a/src/google/adk/tools/retrieval/llama_index_retrieval.py b/src/google/adk/tools/retrieval/llama_index_retrieval.py
new file mode 100644
index 0000000..5297870
--- /dev/null
+++ b/src/google/adk/tools/retrieval/llama_index_retrieval.py
@@ -0,0 +1,41 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Provides data for the agent."""
+
+from __future__ import annotations
+
+from typing import Any
+from typing import TYPE_CHECKING
+
+from typing_extensions import override
+
+from ..tool_context import ToolContext
+from .base_retrieval_tool import BaseRetrievalTool
+
+if TYPE_CHECKING:
+ from llama_index.core.base.base_retriever import BaseRetriever
+
+
+class LlamaIndexRetrieval(BaseRetrievalTool):
+
+ def __init__(self, *, name: str, description: str, retriever: BaseRetriever):
+ super().__init__(name=name, description=description)
+ self.retriever = retriever
+
+ @override
+ async def run_async(
+ self, *, args: dict[str, Any], tool_context: ToolContext
+ ) -> Any:
+ return self.retriever.retrieve(args['query'])[0].text
diff --git a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py
new file mode 100644
index 0000000..ad33262
--- /dev/null
+++ b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py
@@ -0,0 +1,107 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A retrieval tool that uses Vertex AI RAG to retrieve data."""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+from typing import TYPE_CHECKING
+
+from google.genai import types
+from typing_extensions import override
+from vertexai.preview import rag
+
+from ..tool_context import ToolContext
+from .base_retrieval_tool import BaseRetrievalTool
+
+if TYPE_CHECKING:
+ from ...models.llm_request import LlmRequest
+
+logger = logging.getLogger(__name__)
+
+
+class VertexAiRagRetrieval(BaseRetrievalTool):
+ """A retrieval tool that uses Vertex AI RAG (Retrieval-Augmented Generation) to retrieve data."""
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ description: str,
+ rag_corpora: list[str] = None,
+ rag_resources: list[rag.RagResource] = None,
+ similarity_top_k: int = None,
+ vector_distance_threshold: float = None,
+ ):
+ super().__init__(name=name, description=description)
+ self.vertex_rag_store = types.VertexRagStore(
+ rag_corpora=rag_corpora,
+ rag_resources=rag_resources,
+ similarity_top_k=similarity_top_k,
+ vector_distance_threshold=vector_distance_threshold,
+ )
+
+ @override
+ async def process_llm_request(
+ self,
+ *,
+ tool_context: ToolContext,
+ llm_request: LlmRequest,
+ ) -> None:
+ # Use Gemini built-in Vertex AI RAG tool for Gemini 2 models.
+ if llm_request.model and llm_request.model.startswith('gemini-2'):
+ llm_request.config = (
+ types.GenerateContentConfig()
+ if not llm_request.config
+ else llm_request.config
+ )
+ llm_request.config.tools = (
+ [] if not llm_request.config.tools else llm_request.config.tools
+ )
+ llm_request.config.tools.append(
+ types.Tool(
+ retrieval=types.Retrieval(vertex_rag_store=self.vertex_rag_store)
+ )
+ )
+ else:
+ # Add the function declaration to the tools
+ await super().process_llm_request(
+ tool_context=tool_context, llm_request=llm_request
+ )
+
+ @override
+ async def run_async(
+ self,
+ *,
+ args: dict[str, Any],
+ tool_context: ToolContext,
+ ) -> Any:
+
+ response = rag.retrieval_query(
+ text=args['query'],
+ rag_resources=self.vertex_rag_store.rag_resources,
+ rag_corpora=self.vertex_rag_store.rag_corpora,
+ similarity_top_k=self.vertex_rag_store.similarity_top_k,
+ vector_distance_threshold=self.vertex_rag_store.vector_distance_threshold,
+ )
+
+ logging.debug('RAG raw response: %s', response)
+
+ return (
+ f'No matching result found with the config: {self.vertex_rag_store}'
+ if not response.contexts.contexts
+ else [context.text for context in response.contexts.contexts]
+ )
diff --git a/src/google/adk/tools/tool_context.py b/src/google/adk/tools/tool_context.py
new file mode 100644
index 0000000..e2d1262
--- /dev/null
+++ b/src/google/adk/tools/tool_context.py
@@ -0,0 +1,90 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional
+from typing import TYPE_CHECKING
+
+from ..agents.callback_context import CallbackContext
+from ..auth.auth_credential import AuthCredential
+from ..auth.auth_handler import AuthHandler
+from ..auth.auth_tool import AuthConfig
+
+if TYPE_CHECKING:
+ from ..agents.invocation_context import InvocationContext
+ from ..events.event_actions import EventActions
+ from ..memory.base_memory_service import SearchMemoryResponse
+
+
+class ToolContext(CallbackContext):
+ """The context of the tool.
+
+ This class provides the context for a tool invocation, including access to
+ the invocation context, function call ID, event actions, and authentication
+ response. It also provides methods for requesting credentials, retrieving
+ authentication responses, listing artifacts, and searching memory.
+
+ Attributes:
+ invocation_context: The invocation context of the tool.
+ function_call_id: The function call id of the current tool call. This id was
+ returned in the function call event from LLM to identify a function call.
+ If LLM didn't return this id, ADK will assign one to it. This id is used
+ to map function call response to the original function call.
+ event_actions: The event actions of the current tool call.
+ """
+
+ def __init__(
+ self,
+ invocation_context: InvocationContext,
+ *,
+ function_call_id: Optional[str] = None,
+ event_actions: Optional[EventActions] = None,
+ ):
+ super().__init__(invocation_context, event_actions=event_actions)
+ self.function_call_id = function_call_id
+
+ @property
+ def actions(self) -> EventActions:
+ return self._event_actions
+
+ def request_credential(self, auth_config: AuthConfig) -> None:
+ if not self.function_call_id:
+ raise ValueError('function_call_id is not set.')
+ self._event_actions.requested_auth_configs[self.function_call_id] = (
+ AuthHandler(auth_config).generate_auth_request()
+ )
+
+ def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential:
+ return AuthHandler(auth_config).get_auth_response(self.state)
+
+ def list_artifacts(self) -> list[str]:
+ """Lists the filenames of the artifacts attached to the current session."""
+ if self._invocation_context.artifact_service is None:
+ raise ValueError('Artifact service is not initialized.')
+ return self._invocation_context.artifact_service.list_artifact_keys(
+ app_name=self._invocation_context.app_name,
+ user_id=self._invocation_context.user_id,
+ session_id=self._invocation_context.session.id,
+ )
+
+ def search_memory(self, query: str) -> 'SearchMemoryResponse':
+ """Searches the memory of the current user."""
+ if self._invocation_context.memory_service is None:
+ raise ValueError('Memory service is not available.')
+ return self._invocation_context.memory_service.search_memory(
+ app_name=self._invocation_context.app_name,
+ user_id=self._invocation_context.user_id,
+ query=query,
+ )
diff --git a/src/google/adk/tools/toolbox_tool.py b/src/google/adk/tools/toolbox_tool.py
new file mode 100644
index 0000000..06e8390
--- /dev/null
+++ b/src/google/adk/tools/toolbox_tool.py
@@ -0,0 +1,46 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from . import _automatic_function_calling_util
+from .langchain_tool import LangchainTool
+
+
+class ToolboxTool:
+ """A class that provides access to toolbox tools.
+
+ Example:
+ ```python
+ toolbox = ToolboxTool("http://127.0.0.1:8080")
+ tool = toolbox.get_tool("tool_name")
+ toolset = toolbox.get_toolset("toolset_name")
+ ```
+ """
+
+ toolbox_client: Any
+ """The toolbox client."""
+
+ def __init__(self, url: str):
+ from toolbox_langchain import ToolboxClient
+
+ self.toolbox_client = ToolboxClient(url)
+
+ def get_tool(self, tool_name: str) -> LangchainTool:
+ tool = self.toolbox_client.load_tool(tool_name)
+ return LangchainTool(tool)
+
+ def get_toolset(self, toolset_name: str) -> list[LangchainTool]:
+ tools = self.toolbox_client.load_toolset(toolset_name)
+ return [LangchainTool(tool) for tool in tools]
diff --git a/src/google/adk/tools/transfer_to_agent_tool.py b/src/google/adk/tools/transfer_to_agent_tool.py
new file mode 100644
index 0000000..dea624e
--- /dev/null
+++ b/src/google/adk/tools/transfer_to_agent_tool.py
@@ -0,0 +1,21 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .tool_context import ToolContext
+
+
+# TODO: make this internal, since user doesn't need to use this tool directly.
+def transfer_to_agent(agent_name: str, tool_context: ToolContext):
+ """Transfer the question to another agent."""
+ tool_context.actions.transfer_to_agent = agent_name
diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py
new file mode 100644
index 0000000..ebe236e
--- /dev/null
+++ b/src/google/adk/tools/vertex_ai_search_tool.py
@@ -0,0 +1,96 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional
+from typing import TYPE_CHECKING
+
+from google.genai import types
+from typing_extensions import override
+
+from .base_tool import BaseTool
+from .tool_context import ToolContext
+
+if TYPE_CHECKING:
+ from ..models import LlmRequest
+
+
+class VertexAiSearchTool(BaseTool):
+ """A built-in tool using Vertex AI Search.
+
+ Attributes:
+ data_store_id: The Vertex AI search data store resource ID.
+ search_engine_id: The Vertex AI search engine resource ID.
+ """
+
+ def __init__(
+ self,
+ *,
+ data_store_id: Optional[str] = None,
+ search_engine_id: Optional[str] = None,
+ ):
+ """Initializes the Vertex AI Search tool.
+
+ Args:
+ data_store_id: The Vertex AI search data store resource ID in the format
+ of
+ "projects/{project}/locations/{location}/collections/{collection}/dataStores/{dataStore}".
+ search_engine_id: The Vertex AI search engine resource ID in the format of
+ "projects/{project}/locations/{location}/collections/{collection}/engines/{engine}".
+
+ Raises:
+ ValueError: If both data_store_id and search_engine_id are not specified
+ or both are specified.
+ """
+ # Name and description are not used because this is a model built-in tool.
+ super().__init__(name='vertex_ai_search', description='vertex_ai_search')
+ if (data_store_id is None and search_engine_id is None) or (
+ data_store_id is not None and search_engine_id is not None
+ ):
+ raise ValueError(
+ 'Either data_store_id or search_engine_id must be specified.'
+ )
+ self.data_store_id = data_store_id
+ self.search_engine_id = search_engine_id
+
+ @override
+ async def process_llm_request(
+ self,
+ *,
+ tool_context: ToolContext,
+ llm_request: LlmRequest,
+ ) -> None:
+ if llm_request.model and llm_request.model.startswith('gemini-'):
+ if llm_request.model.startswith('gemini-1') and llm_request.config.tools:
+ raise ValueError(
+ 'Vertex AI search tool can not be used with other tools in Gemini'
+ ' 1.x.'
+ )
+ llm_request.config = llm_request.config or types.GenerateContentConfig()
+ llm_request.config.tools = llm_request.config.tools or []
+ llm_request.config.tools.append(
+ types.Tool(
+ retrieval=types.Retrieval(
+ vertex_ai_search=types.VertexAISearch(
+ datastore=self.data_store_id, engine=self.search_engine_id
+ )
+ )
+ )
+ )
+ else:
+ raise ValueError(
+ 'Vertex AI search tool is not supported for model'
+ f' {llm_request.model}'
+ )
diff --git a/src/google/adk/version.py b/src/google/adk/version.py
new file mode 100644
index 0000000..8993cf8
--- /dev/null
+++ b/src/google/adk/version.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# version: date+base_cl
+__version__ = "0.0.2"