mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-10 18:39:37 -06:00
Agent Development Kit(ADK)
An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
parent
f92478bd5c
commit
9827820143
33
CONTRIBUTING.md
Normal file
33
CONTRIBUTING.md
Normal file
@ -0,0 +1,33 @@
|
||||
# How to contribute
|
||||
|
||||
We'd love to accept your patches and contributions to this project.
|
||||
|
||||
## Before you begin
|
||||
|
||||
### Sign our Contributor License Agreement
|
||||
|
||||
Contributions to this project must be accompanied by a
|
||||
[Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
|
||||
You (or your employer) retain the copyright to your contribution; this simply
|
||||
gives us permission to use and redistribute your contributions as part of the
|
||||
project.
|
||||
|
||||
If you or your current employer have already signed the Google CLA (even if it
|
||||
was for a different project), you probably don't need to do it again.
|
||||
|
||||
Visit <https://cla.developers.google.com/> to see your current agreements or to
|
||||
sign a new one.
|
||||
|
||||
### Review our community guidelines
|
||||
|
||||
This project follows
|
||||
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
|
||||
|
||||
## Contribution process
|
||||
|
||||
### Code reviews
|
||||
|
||||
All submissions, including submissions by project members, require review. We
|
||||
use GitHub pull requests for this purpose. Consult
|
||||
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
|
||||
information on using pull requests.
|
||||
202
LICENSE
Normal file
202
LICENSE
Normal file
@ -0,0 +1,202 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
96
README.md
96
README.md
@ -1,3 +1,95 @@
|
||||
# adk-python
|
||||
# Agent Development Kit (ADK)
|
||||
|
||||
Hello World!
|
||||
[](LICENSE)
|
||||
|
||||
<img src="assets/agent-development-kit.png" alt="Agent Development Kit Logo" width="150">
|
||||
|
||||
**An open-source, code-first Python toolkit for building, evaluating, and deploying sophisticated AI agents with flexibility and control.**
|
||||
|
||||
The Agent Development Kit (ADK) is designed for developers seeking fine-grained control and flexibility when building advanced AI agents that are tightly integrated with services in Google Cloud. It allows you to define agent behavior, orchestration, and tool use directly in code, enabling robust debugging, versioning, and deployment anywhere – from your laptop to the cloud.
|
||||
|
||||
---
|
||||
|
||||
## ✨ Key Features
|
||||
|
||||
* **Code-First Development:** Define agents, tools, and orchestration logic for maximum control, testability, and versioning.
|
||||
* **Multi-Agent Architecture:** Build modular and scalable applications by composing multiple specialized agents in flexible hierarchies.
|
||||
* **Rich Tool Ecosystem:** Equip agents with diverse capabilities using pre-built tools, custom Python functions, API specifications, or integrating existing tools.
|
||||
* **Flexible Orchestration:** Define workflows using built-in agents for predictable pipelines, or leverage LLM-driven dynamic routing for adaptive behavior.
|
||||
* **Integrated Developer Experience:** Develop, test, and debug locally with a CLI and visual web UI.
|
||||
* **Built-in Evaluation:** Measure agent performance by evaluating response quality and step-by-step execution trajectory.
|
||||
* **Deployment Ready:** Containerize and deploy your agents anywhere – scale with Vertex AI Agent Engine, Cloud Run, or Docker.
|
||||
* **Native Streaming Support:** Build real-time, interactive experiences with native support for bidirectional streaming (text and audio).
|
||||
* **State, Memory & Artifacts:** Manage short-term conversational context, configure long-term memory, and handle file uploads/downloads.
|
||||
* **Extensibility:** Customize agent behavior deeply with callbacks and easily integrate third-party tools and services.
|
||||
|
||||
## 🚀 Installation
|
||||
|
||||
You can install the Agent Developer Kit using `pip`:
|
||||
|
||||
```bash
|
||||
pip install google-adk
|
||||
```
|
||||
|
||||
## 🏁 Getting Started
|
||||
|
||||
Create your first agent (`my_agent/agent.py`):
|
||||
|
||||
```python
|
||||
# my_agent/agent.py
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools import google_search
|
||||
|
||||
root_agent = Agent(
|
||||
name="search_assistant",
|
||||
model="gemini-1.5-flash-latest", # Or your preferred model like gemini-2.0-flash-001
|
||||
instruction="You are a helpful assistant. Answer user questions using Google Search when needed.",
|
||||
description="An assistant that can search the web.",
|
||||
tools=[google_search]
|
||||
)
|
||||
```
|
||||
|
||||
Create `my_agent/__init__.py`:
|
||||
|
||||
```python
|
||||
# my_agent/__init__.py
|
||||
from . import agent
|
||||
```
|
||||
|
||||
Run it via the CLI (from the directory *containing* `my_agent`):
|
||||
|
||||
```bash
|
||||
adk run my_agent
|
||||
```
|
||||
|
||||
Or launch the Web UI from the folder that contains `my_agent` folder:
|
||||
|
||||
```bash
|
||||
adk web
|
||||
```
|
||||
|
||||
For a full step-by-step guide, check out the quickstart or sample agents.
|
||||
|
||||
## 📚 Resources
|
||||
|
||||
Explore the full documentation for detailed guides on building, evaluating, and deploying agents:
|
||||
|
||||
* **[Get Started](get-started/introduction.md)**
|
||||
* **[Build Agents](build/agents.md)**
|
||||
* **[Browse Sample Agents](learn/sample_agents/)**
|
||||
* **[Evaluate Agents](evaluate/evaluate-agents.md)**
|
||||
* **[Deploy Agents](deploy/overview.md)**
|
||||
* **[API Reference](guides/reference.md)**
|
||||
* **[Troubleshooting](guides/troubleshooting.md)**
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
We welcome contributions from the community! Whether it's bug reports, feature requests, documentation improvements, or code contributions, please see our [**Contributing Guidelines**](./CONTRIBUTING.md) to get started.
|
||||
|
||||
## 📄 License
|
||||
|
||||
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
---
|
||||
|
||||
*Happy Agent Building!*
|
||||
BIN
assets/agent-development-kit.png
Normal file
BIN
assets/agent-development-kit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 12 KiB |
400
pylintrc
Normal file
400
pylintrc
Normal file
@ -0,0 +1,400 @@
|
||||
# This Pylint rcfile contains a best-effort configuration to uphold the
|
||||
# best-practices and style described in the Google Python style guide:
|
||||
# https://google.github.io/styleguide/pyguide.html
|
||||
#
|
||||
# Its canonical open-source location is:
|
||||
# https://google.github.io/styleguide/pylintrc
|
||||
|
||||
[MAIN]
|
||||
|
||||
# Files or directories to be skipped. They should be base names, not paths.
|
||||
ignore=third_party
|
||||
|
||||
# Files or directories matching the regex patterns are skipped. The regex
|
||||
# matches against base names, not paths.
|
||||
ignore-patterns=
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=no
|
||||
|
||||
# List of plugins (as comma separated values of python modules names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=
|
||||
|
||||
# Use multiple processes to speed up Pylint.
|
||||
jobs=4
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
|
||||
confidence=
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
#enable=
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once).You can also use "--disable=all" to
|
||||
# disable everything first and then reenable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use"--disable=all --enable=classes
|
||||
# --disable=W"
|
||||
disable=R,
|
||||
abstract-method,
|
||||
apply-builtin,
|
||||
arguments-differ,
|
||||
attribute-defined-outside-init,
|
||||
backtick,
|
||||
bad-option-value,
|
||||
basestring-builtin,
|
||||
buffer-builtin,
|
||||
c-extension-no-member,
|
||||
consider-using-enumerate,
|
||||
cmp-builtin,
|
||||
cmp-method,
|
||||
coerce-builtin,
|
||||
coerce-method,
|
||||
delslice-method,
|
||||
div-method,
|
||||
eq-without-hash,
|
||||
execfile-builtin,
|
||||
file-builtin,
|
||||
filter-builtin-not-iterating,
|
||||
fixme,
|
||||
getslice-method,
|
||||
global-statement,
|
||||
hex-method,
|
||||
idiv-method,
|
||||
implicit-str-concat,
|
||||
import-error,
|
||||
import-self,
|
||||
import-star-module-level,
|
||||
import-outside-toplevel,
|
||||
input-builtin,
|
||||
intern-builtin,
|
||||
invalid-str-codec,
|
||||
locally-disabled,
|
||||
long-builtin,
|
||||
long-suffix,
|
||||
map-builtin-not-iterating,
|
||||
misplaced-comparison-constant,
|
||||
missing-function-docstring,
|
||||
metaclass-assignment,
|
||||
next-method-called,
|
||||
next-method-defined,
|
||||
no-absolute-import,
|
||||
no-init, # added
|
||||
no-member,
|
||||
no-name-in-module,
|
||||
no-self-use,
|
||||
nonzero-method,
|
||||
oct-method,
|
||||
old-division,
|
||||
old-ne-operator,
|
||||
old-octal-literal,
|
||||
old-raise-syntax,
|
||||
parameter-unpacking,
|
||||
print-statement,
|
||||
raising-string,
|
||||
range-builtin-not-iterating,
|
||||
raw_input-builtin,
|
||||
rdiv-method,
|
||||
reduce-builtin,
|
||||
relative-import,
|
||||
reload-builtin,
|
||||
round-builtin,
|
||||
setslice-method,
|
||||
signature-differs,
|
||||
standarderror-builtin,
|
||||
suppressed-message,
|
||||
sys-max-int,
|
||||
trailing-newlines,
|
||||
unichr-builtin,
|
||||
unicode-builtin,
|
||||
unnecessary-pass,
|
||||
unpacking-in-except,
|
||||
useless-else-on-loop,
|
||||
useless-suppression,
|
||||
using-cmp-argument,
|
||||
wrong-import-order,
|
||||
xrange-builtin,
|
||||
zip-builtin-not-iterating,
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
||||
# Set the output format. Available formats are text, parseable, colorized, msvs
|
||||
# (visual studio) and html. You can also give a reporter class, eg
|
||||
# mypackage.mymodule.MyReporterClass.
|
||||
output-format=text
|
||||
|
||||
# Tells whether to display a full report or only the messages
|
||||
reports=no
|
||||
|
||||
# Python expression which should return a note less than 10 (10 is the highest
|
||||
# note). You have access to the variables errors warning, statement which
|
||||
# respectively contain the number of errors / warnings messages and the total
|
||||
# number of statements analyzed. This is used by the global evaluation report
|
||||
# (RP0004).
|
||||
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||
|
||||
# Template used to display messages. This is a python new-style format string
|
||||
# used to format the message information. See doc for all details
|
||||
#msg-template=
|
||||
|
||||
|
||||
[BASIC]
|
||||
|
||||
# Good variable names which should always be accepted, separated by a comma
|
||||
good-names=main,_
|
||||
|
||||
# Bad variable names which should always be refused, separated by a comma
|
||||
bad-names=
|
||||
|
||||
# Colon-delimited sets of names that determine each other's naming style when
|
||||
# the name regexes allow several styles.
|
||||
name-group=
|
||||
|
||||
# Include a hint for the correct naming format with invalid-name
|
||||
include-naming-hint=no
|
||||
|
||||
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||
# to this list to register other decorators that produce valid properties.
|
||||
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
|
||||
|
||||
# Regular expression matching correct function names
|
||||
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression matching correct variable names
|
||||
variable-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct constant names
|
||||
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct attribute names
|
||||
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct argument names
|
||||
argument-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class attribute names
|
||||
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct inline iteration names
|
||||
inlinevar-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class names
|
||||
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
|
||||
|
||||
# Regular expression matching correct module names
|
||||
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
|
||||
|
||||
# Regular expression matching correct method names
|
||||
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression which should only match function or class names that do
|
||||
# not require a docstring.
|
||||
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
|
||||
|
||||
# Minimum line length for functions/classes that require docstrings, shorter
|
||||
# ones are exempt.
|
||||
docstring-min-length=12
|
||||
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of decorators that produce context managers, such as
|
||||
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||
# produce valid context managers.
|
||||
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
|
||||
|
||||
# List of module names for which member attributes should not be checked
|
||||
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||
# supports qualified module names, as well as Unix pattern matching.
|
||||
ignored-modules=
|
||||
|
||||
# List of class names for which member attributes should not be checked (useful
|
||||
# for classes with dynamically set attributes). This supports the use of
|
||||
# qualified names.
|
||||
ignored-classes=optparse.Values,thread._local,_thread._local
|
||||
|
||||
# List of members which are set dynamically and missed by pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||
# expressions are accepted.
|
||||
generated-members=
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=80
|
||||
|
||||
# TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt
|
||||
# lines made too long by directives to pytype.
|
||||
|
||||
# Regexp for a line that is allowed to be longer than the limit.
|
||||
ignore-long-lines=(?x)(
|
||||
^\s*(\#\ )?<?https?://\S+>?$|
|
||||
^\s*(from\s+\S+\s+)?import\s+.+$)
|
||||
|
||||
# Allow the body of an if to be on the same line as the test if there is no
|
||||
# else.
|
||||
single-line-if-stmt=yes
|
||||
|
||||
# Maximum number of lines in a module
|
||||
max-module-lines=99999
|
||||
|
||||
# String used as indentation unit. The internal Google style guide mandates 2
|
||||
# spaces. Google's externaly-published style guide says 4, consistent with
|
||||
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
|
||||
# projects (like TensorFlow).
|
||||
indent-string=' '
|
||||
|
||||
# Number of spaces of indent required inside a hanging or continued line.
|
||||
indent-after-paren=4
|
||||
|
||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||
expected-line-ending-format=
|
||||
|
||||
|
||||
[MISCELLANEOUS]
|
||||
|
||||
# List of note tags to take in consideration, separated by a comma.
|
||||
notes=TODO
|
||||
|
||||
|
||||
[STRING]
|
||||
|
||||
# This flag controls whether inconsistent-quotes generates a warning when the
|
||||
# character used as a quote delimiter is used inconsistently within a module.
|
||||
check-quote-consistency=yes
|
||||
|
||||
|
||||
[VARIABLES]
|
||||
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=no
|
||||
|
||||
# A regular expression matching the name of dummy variables (i.e. expectedly
|
||||
# not used).
|
||||
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
|
||||
|
||||
# List of additional names supposed to be defined in builtins. Remember that
|
||||
# you should avoid to define new builtins when possible.
|
||||
additional-builtins=
|
||||
|
||||
# List of strings which can identify a callback function by name. A callback
|
||||
# name must start or end with one of those strings.
|
||||
callbacks=cb_,_cb
|
||||
|
||||
# List of qualified module names which can have objects that can redefine
|
||||
# builtins.
|
||||
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
|
||||
|
||||
|
||||
[LOGGING]
|
||||
|
||||
# Logging modules to check that the string format arguments are in logging
|
||||
# function parameter format
|
||||
logging-modules=logging,absl.logging,tensorflow.io.logging
|
||||
|
||||
|
||||
[SIMILARITIES]
|
||||
|
||||
# Minimum lines number of a similarity.
|
||||
min-similarity-lines=4
|
||||
|
||||
# Ignore comments when computing similarities.
|
||||
ignore-comments=yes
|
||||
|
||||
# Ignore docstrings when computing similarities.
|
||||
ignore-docstrings=yes
|
||||
|
||||
# Ignore imports when computing similarities.
|
||||
ignore-imports=no
|
||||
|
||||
|
||||
[SPELLING]
|
||||
|
||||
# Spelling dictionary name. Available dictionaries: none. To make it working
|
||||
# install python-enchant package.
|
||||
spelling-dict=
|
||||
|
||||
# List of comma separated words that should not be checked.
|
||||
spelling-ignore-words=
|
||||
|
||||
# A path to a file that contains private dictionary; one word per line.
|
||||
spelling-private-dict-file=
|
||||
|
||||
# Tells whether to store unknown words to indicated private dictionary in
|
||||
# --spelling-private-dict-file option instead of raising a message.
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
|
||||
[IMPORTS]
|
||||
|
||||
# Deprecated modules which should not be used, separated by a comma
|
||||
deprecated-modules=regsub,
|
||||
TERMIOS,
|
||||
Bastion,
|
||||
rexec,
|
||||
sets
|
||||
|
||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||
# given file (report RP0402 must not be disabled)
|
||||
import-graph=
|
||||
|
||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
ext-import-graph=
|
||||
|
||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
int-import-graph=
|
||||
|
||||
# Force import order to recognize a module as part of the standard
|
||||
# compatibility libraries.
|
||||
known-standard-library=
|
||||
|
||||
# Force import order to recognize a module as part of a third party library.
|
||||
known-third-party=enchant, absl
|
||||
|
||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||
# 3 compatible code, which means that the block might have code that exists
|
||||
# only in one or another interpreter, leading to false positives when analysed.
|
||||
analyse-fallback-blocks=no
|
||||
|
||||
|
||||
[CLASSES]
|
||||
|
||||
# List of method names used to declare (i.e. assign) instance attributes.
|
||||
defining-attr-methods=__init__,
|
||||
__new__,
|
||||
setUp
|
||||
|
||||
# List of member names, which should be excluded from the protected access
|
||||
# warning.
|
||||
exclude-protected=_asdict,
|
||||
_fields,
|
||||
_replace,
|
||||
_source,
|
||||
_make
|
||||
|
||||
# List of valid names for the first argument in a class method.
|
||||
valid-classmethod-first-arg=cls,
|
||||
class_
|
||||
|
||||
# List of valid names for the first argument in a metaclass class method.
|
||||
valid-metaclass-classmethod-first-arg=mcs
|
||||
146
pyproject.toml
Normal file
146
pyproject.toml
Normal file
@ -0,0 +1,146 @@
|
||||
[project]
|
||||
# Project metadata. Available keys are documented at:
|
||||
# https://packaging.python.org/en/latest/specifications/declaring-project-metadata
|
||||
|
||||
name = "google-adk"
|
||||
description = "Agent Development Kit"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = { file = "LICENSE" }
|
||||
authors = [{ name = "Google LLC", email = "googleapis-packages@google.com" }]
|
||||
classifiers = [ # List of https://pypi.org/classifiers/
|
||||
"Typing :: Typed",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Operating System :: OS Independent",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
]
|
||||
dependencies = [
|
||||
# go/keep-sorted start
|
||||
"authlib>=1.5.1", # For RestAPI Tool
|
||||
"click>=8.1.8", # For CLI tools
|
||||
"fastapi>=0.115.0", # FastAPI framework
|
||||
"google-api-python-client>=2.157.0", # Google API client discovery
|
||||
"google-cloud-aiplatform>=1.87.0", # For VertexAI integrations, e.g. example store.
|
||||
"google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool
|
||||
"google-cloud-speech>=2.30.0", # For Audo Transcription
|
||||
"google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
|
||||
"google-genai>=1.9.0", # Google GenAI SDK
|
||||
"graphviz>=0.20.2", # Graphviz for graph rendering
|
||||
"mcp>=1.5.0;python_version>='3.10'", # For MCP Toolset
|
||||
"opentelemetry-api>=1.31.0", # OpenTelemetry
|
||||
"opentelemetry-exporter-gcp-trace>=1.9.0",
|
||||
"opentelemetry-sdk>=1.31.0",
|
||||
"pydantic>=2.0, <3.0.0", # For data validation/models
|
||||
"python-dotenv>=1.0.0", # To manage environment variables
|
||||
"PyYAML>=6.0.2", # For APIHubToolset.
|
||||
"sqlalchemy>=2.0", # SQL database ORM
|
||||
"tzlocal>=5.3", # Time zone utilities
|
||||
"uvicorn>=0.34.0", # ASGI server for FastAPI
|
||||
# go/keep-sorted end
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
[project.urls]
|
||||
homepage = "https://google.github.io/adk-docs/"
|
||||
repository = "https://github.com/google/adk-python"
|
||||
changelog = "https://github.com/google/adk-python/blob/main/CHANGELOG.md"
|
||||
documentation = "https://google.github.io/adk-docs/"
|
||||
|
||||
[project.scripts]
|
||||
adk = "google.adk.cli:main"
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
dev = [
|
||||
# go/keep-sorted start
|
||||
"flit>=3.10.0",
|
||||
"isort>=6.0.0",
|
||||
"pyink>=24.10.0",
|
||||
"pylint>=2.6.0",
|
||||
# go/keep-sorted end
|
||||
]
|
||||
|
||||
eval = [
|
||||
# go/keep-sorted start
|
||||
"google-cloud-aiplatform[evaluation]>=1.87.0",
|
||||
"pandas>=2.2.3",
|
||||
"tabulate>=0.9.0",
|
||||
# go/keep-sorted end
|
||||
]
|
||||
|
||||
test = [
|
||||
# go/keep-sorted start
|
||||
"langchain-community>=0.3.17",
|
||||
"pytest-asyncio>=0.25.0",
|
||||
"pytest-mock>=3.14.0",
|
||||
"pytest-xdist>=3.6.1",
|
||||
"pytest>=8.3.4",
|
||||
# go/keep-sorted end
|
||||
]
|
||||
|
||||
docs = [
|
||||
"autodoc_pydantic",
|
||||
"furo",
|
||||
"myst-parser",
|
||||
"sphinx",
|
||||
"sphinx-autodoc-typehints",
|
||||
"sphinx-rtd-theme",
|
||||
]
|
||||
|
||||
# Optional extensions
|
||||
extensions = [
|
||||
"anthropic>=0.43.0", # For anthropic model support
|
||||
"beautifulsoup4>=3.2.2", # For load_web_page tool.
|
||||
"crewai[tools];python_version>='3.10'", # For CrewaiTool
|
||||
"docker>=7.0.0", # For ContainerCodeExecutor
|
||||
"langgraph>=0.2.60", # For LangGraphAgent
|
||||
"litellm>=1.63.11", # For LiteLLM support
|
||||
"llama-index-readers-file>=0.4.0", # for retrieval usings LlamaIndex.
|
||||
"lxml>=5.3.0", # For load_web_page tool.
|
||||
]
|
||||
|
||||
|
||||
[tool.pyink]
|
||||
# Format py files following Google style-guide
|
||||
line-length = 80
|
||||
unstable = true
|
||||
pyink-indentation = 2
|
||||
pyink-use-majority-quotes = true
|
||||
|
||||
|
||||
[build-system]
|
||||
# Build system specify which backend is used to build/install the project (flit,
|
||||
# poetry, setuptools,...). All backends are supported by `pip install`
|
||||
requires = ["flit_core >=3.8,<4"]
|
||||
build-backend = "flit_core.buildapi"
|
||||
|
||||
[tool.flit.sdist]
|
||||
include = ['src/**/*', 'README.md', 'pyproject.toml']
|
||||
exclude = ['src/**/*.sh']
|
||||
|
||||
[tool.flit.module]
|
||||
name = "google.adk"
|
||||
|
||||
[tool.isort]
|
||||
# Organize imports following Google style-guide
|
||||
force_single_line = true
|
||||
force_sort_within_sections = true
|
||||
honor_case_in_force_sorted_sections = true
|
||||
known_third_party = ["agents", "google"]
|
||||
order_by_type = false
|
||||
sort_relative_in_force_sorted_sections = true
|
||||
multi_line_output = 3
|
||||
line_length = 200
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
20
src/google/adk/__init__.py
Normal file
20
src/google/adk/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import version
|
||||
from .agents.llm_agent import Agent
|
||||
from .runners import Runner
|
||||
|
||||
__version__ = version.__version__
|
||||
__all__ = ["Agent", "Runner"]
|
||||
32
src/google/adk/agents/__init__.py
Normal file
32
src/google/adk/agents/__init__.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_agent import BaseAgent
|
||||
from .live_request_queue import LiveRequest
|
||||
from .live_request_queue import LiveRequestQueue
|
||||
from .llm_agent import Agent
|
||||
from .llm_agent import LlmAgent
|
||||
from .loop_agent import LoopAgent
|
||||
from .parallel_agent import ParallelAgent
|
||||
from .run_config import RunConfig
|
||||
from .sequential_agent import SequentialAgent
|
||||
|
||||
__all__ = [
|
||||
'Agent',
|
||||
'BaseAgent',
|
||||
'LlmAgent',
|
||||
'LoopAgent',
|
||||
'ParallelAgent',
|
||||
'SequentialAgent',
|
||||
]
|
||||
38
src/google/adk/agents/active_streaming_tool.py
Normal file
38
src/google/adk/agents/active_streaming_tool.py
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from .live_request_queue import LiveRequestQueue
|
||||
|
||||
|
||||
class ActiveStreamingTool(BaseModel):
|
||||
"""Manages streaming tool related resources during invocation."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra='forbid',
|
||||
)
|
||||
|
||||
task: Optional[asyncio.Task] = None
|
||||
"""The active task of this streaming tool."""
|
||||
|
||||
stream: Optional[LiveRequestQueue] = None
|
||||
"""The active (input) streams of this streaming tool."""
|
||||
345
src/google/adk/agents/base_agent.py
Normal file
345
src/google/adk/agents/base_agent.py
Normal file
@ -0,0 +1,345 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Callable
|
||||
from typing import final
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from opentelemetry import trace
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from .callback_context import CallbackContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .invocation_context import InvocationContext
|
||||
|
||||
tracer = trace.get_tracer('gcp.vertex.agent')
|
||||
|
||||
BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
|
||||
"""Callback signature that is invoked before the agent run.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When set, the agent run will skipped and
|
||||
the provided content will be returned to user.
|
||||
"""
|
||||
|
||||
AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
|
||||
"""Callback signature that is invoked after the agent run.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When set, the agent run will skipped and
|
||||
the provided content will be appended to event history as agent response.
|
||||
"""
|
||||
|
||||
|
||||
class BaseAgent(BaseModel):
|
||||
"""Base class for all agents in Agent Development Kit."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra='forbid',
|
||||
)
|
||||
|
||||
name: str
|
||||
"""The agent's name.
|
||||
|
||||
Agent name must be a Python identifier and unique within the agent tree.
|
||||
Agent name cannot be "user", since it's reserved for end-user's input.
|
||||
"""
|
||||
|
||||
description: str = ''
|
||||
"""Description about the agent's capability.
|
||||
|
||||
The model uses this to determine whether to delegate control to the agent.
|
||||
One-line description is enough and preferred.
|
||||
"""
|
||||
|
||||
parent_agent: Optional[BaseAgent] = Field(default=None, init=False)
|
||||
"""The parent agent of this agent.
|
||||
|
||||
Note that an agent can ONLY be added as sub-agent once.
|
||||
|
||||
If you want to add one agent twice as sub-agent, consider to create two agent
|
||||
instances with identical config, but with different name and add them to the
|
||||
agent tree.
|
||||
"""
|
||||
sub_agents: list[BaseAgent] = Field(default_factory=list)
|
||||
"""The sub-agents of this agent."""
|
||||
|
||||
before_agent_callback: Optional[BeforeAgentCallback] = None
|
||||
"""Callback signature that is invoked before the agent run.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When set, the agent run will skipped and
|
||||
the provided content will be returned to user.
|
||||
"""
|
||||
after_agent_callback: Optional[AfterAgentCallback] = None
|
||||
"""Callback signature that is invoked after the agent run.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When set, the agent run will skipped and
|
||||
the provided content will be appended to event history as agent response.
|
||||
"""
|
||||
|
||||
@final
|
||||
async def run_async(
|
||||
self,
|
||||
parent_context: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Entry method to run an agent via text-based conversaction.
|
||||
|
||||
Args:
|
||||
parent_context: InvocationContext, the invocation context of the parent
|
||||
agent.
|
||||
|
||||
Yields:
|
||||
Event: the events generated by the agent.
|
||||
"""
|
||||
|
||||
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
|
||||
ctx = self._create_invocation_context(parent_context)
|
||||
|
||||
if event := self.__handle_before_agent_callback(ctx):
|
||||
yield event
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
|
||||
async for event in self._run_async_impl(ctx):
|
||||
yield event
|
||||
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
|
||||
if event := self.__handle_after_agent_callback(ctx):
|
||||
yield event
|
||||
|
||||
@final
|
||||
async def run_live(
|
||||
self,
|
||||
parent_context: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Entry method to run an agent via video/audio-based conversaction.
|
||||
|
||||
Args:
|
||||
parent_context: InvocationContext, the invocation context of the parent
|
||||
agent.
|
||||
|
||||
Yields:
|
||||
Event: the events generated by the agent.
|
||||
"""
|
||||
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
|
||||
ctx = self._create_invocation_context(parent_context)
|
||||
# TODO(hangfei): support before/after_agent_callback
|
||||
|
||||
async for event in self._run_live_impl(ctx):
|
||||
yield event
|
||||
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Core logic to run this agent via text-based conversaction.
|
||||
|
||||
Args:
|
||||
ctx: InvocationContext, the invocation context for this agent.
|
||||
|
||||
Yields:
|
||||
Event: the events generated by the agent.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f'_run_async_impl for {type(self)} is not implemented.'
|
||||
)
|
||||
yield # AsyncGenerator requires having at least one yield statement
|
||||
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Core logic to run this agent via video/audio-based conversaction.
|
||||
|
||||
Args:
|
||||
ctx: InvocationContext, the invocation context for this agent.
|
||||
|
||||
Yields:
|
||||
Event: the events generated by the agent.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f'_run_live_impl for {type(self)} is not implemented.'
|
||||
)
|
||||
yield # AsyncGenerator requires having at least one yield statement
|
||||
|
||||
@property
|
||||
def root_agent(self) -> BaseAgent:
|
||||
"""Gets the root agent of this agent."""
|
||||
root_agent = self
|
||||
while root_agent.parent_agent is not None:
|
||||
root_agent = root_agent.parent_agent
|
||||
return root_agent
|
||||
|
||||
def find_agent(self, name: str) -> Optional[BaseAgent]:
|
||||
"""Finds the agent with the given name in this agent and its descendants.
|
||||
|
||||
Args:
|
||||
name: The name of the agent to find.
|
||||
|
||||
Returns:
|
||||
The agent with the matching name, or None if no such agent is found.
|
||||
"""
|
||||
if self.name == name:
|
||||
return self
|
||||
return self.find_sub_agent(name)
|
||||
|
||||
def find_sub_agent(self, name: str) -> Optional[BaseAgent]:
|
||||
"""Finds the agent with the given name in this agent's descendants.
|
||||
|
||||
Args:
|
||||
name: The name of the agent to find.
|
||||
|
||||
Returns:
|
||||
The agent with the matching name, or None if no such agent is found.
|
||||
"""
|
||||
for sub_agent in self.sub_agents:
|
||||
if result := sub_agent.find_agent(name):
|
||||
return result
|
||||
return None
|
||||
|
||||
def _create_invocation_context(
|
||||
self, parent_context: InvocationContext
|
||||
) -> InvocationContext:
|
||||
"""Creates a new invocation context for this agent."""
|
||||
invocation_context = parent_context.model_copy(update={'agent': self})
|
||||
if parent_context.branch:
|
||||
invocation_context.branch = f'{parent_context.branch}.{self.name}'
|
||||
return invocation_context
|
||||
|
||||
def __handle_before_agent_callback(
|
||||
self, ctx: InvocationContext
|
||||
) -> Optional[Event]:
|
||||
"""Runs the before_agent_callback if it exists.
|
||||
|
||||
Returns:
|
||||
Optional[Event]: an event if callback provides content or changed state.
|
||||
"""
|
||||
ret_event = None
|
||||
|
||||
if not isinstance(self.before_agent_callback, Callable):
|
||||
return ret_event
|
||||
|
||||
callback_context = CallbackContext(ctx)
|
||||
before_agent_callback_content = self.before_agent_callback(
|
||||
callback_context=callback_context
|
||||
)
|
||||
|
||||
if before_agent_callback_content:
|
||||
ret_event = Event(
|
||||
invocation_id=ctx.invocation_id,
|
||||
author=self.name,
|
||||
branch=ctx.branch,
|
||||
content=before_agent_callback_content,
|
||||
actions=callback_context._event_actions,
|
||||
)
|
||||
ctx.end_invocation = True
|
||||
return ret_event
|
||||
|
||||
if callback_context.state.has_delta():
|
||||
ret_event = Event(
|
||||
invocation_id=ctx.invocation_id,
|
||||
author=self.name,
|
||||
branch=ctx.branch,
|
||||
actions=callback_context._event_actions,
|
||||
)
|
||||
|
||||
return ret_event
|
||||
|
||||
def __handle_after_agent_callback(
|
||||
self, invocation_context: InvocationContext
|
||||
) -> Optional[Event]:
|
||||
"""Runs the after_agent_callback if it exists.
|
||||
|
||||
Returns:
|
||||
Optional[Event]: an event if callback provides content or changed state.
|
||||
"""
|
||||
ret_event = None
|
||||
|
||||
if not isinstance(self.after_agent_callback, Callable):
|
||||
return ret_event
|
||||
|
||||
callback_context = CallbackContext(invocation_context)
|
||||
after_agent_callback_content = self.after_agent_callback(
|
||||
callback_context=callback_context
|
||||
)
|
||||
|
||||
if after_agent_callback_content or callback_context.state.has_delta():
|
||||
ret_event = Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=self.name,
|
||||
branch=invocation_context.branch,
|
||||
content=after_agent_callback_content,
|
||||
actions=callback_context._event_actions,
|
||||
)
|
||||
|
||||
return ret_event
|
||||
|
||||
@override
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self.__set_parent_agent_for_sub_agents()
|
||||
|
||||
@field_validator('name', mode='after')
|
||||
@classmethod
|
||||
def __validate_name(cls, value: str):
|
||||
if not value.isidentifier():
|
||||
raise ValueError(
|
||||
f'Found invalid agent name: `{value}`.'
|
||||
' Agent name must be a valid identifier. It should start with a'
|
||||
' letter (a-z, A-Z) or an underscore (_), and can only contain'
|
||||
' letters, digits (0-9), and underscores.'
|
||||
)
|
||||
if value == 'user':
|
||||
raise ValueError(
|
||||
"Agent name cannot be `user`. `user` is reserved for end-user's"
|
||||
' input.'
|
||||
)
|
||||
return value
|
||||
|
||||
def __set_parent_agent_for_sub_agents(self) -> BaseAgent:
|
||||
for sub_agent in self.sub_agents:
|
||||
if sub_agent.parent_agent is not None:
|
||||
raise ValueError(
|
||||
f'Agent `{sub_agent.name}` already has a parent agent, current'
|
||||
f' parent: `{sub_agent.parent_agent.name}`, trying to add:'
|
||||
f' `{self.name}`'
|
||||
)
|
||||
sub_agent.parent_agent = self
|
||||
return self
|
||||
112
src/google/adk/agents/callback_context.py
Normal file
112
src/google/adk/agents/callback_context.py
Normal file
@ -0,0 +1,112 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .readonly_context import ReadonlyContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.genai import types
|
||||
|
||||
from ..events.event import Event
|
||||
from ..events.event_actions import EventActions
|
||||
from ..sessions.state import State
|
||||
from .invocation_context import InvocationContext
|
||||
|
||||
|
||||
class CallbackContext(ReadonlyContext):
|
||||
"""The context of various callbacks within an agent run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
*,
|
||||
event_actions: Optional[EventActions] = None,
|
||||
) -> None:
|
||||
super().__init__(invocation_context)
|
||||
|
||||
from ..events.event_actions import EventActions
|
||||
from ..sessions.state import State
|
||||
|
||||
# TODO(weisun): make this public for Agent Development Kit, but private for
|
||||
# users.
|
||||
self._event_actions = event_actions or EventActions()
|
||||
self._state = State(
|
||||
value=invocation_context.session.state,
|
||||
delta=self._event_actions.state_delta,
|
||||
)
|
||||
|
||||
@property
|
||||
@override
|
||||
def state(self) -> State:
|
||||
"""The delta-aware state of the current session.
|
||||
|
||||
For any state change, you can mutate this object directly,
|
||||
e.g. `ctx.state['foo'] = 'bar'`
|
||||
"""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def user_content(self) -> Optional[types.Content]:
|
||||
"""The user content that started this invocation. READONLY field."""
|
||||
return self._invocation_context.user_content
|
||||
|
||||
def load_artifact(
|
||||
self, filename: str, version: Optional[int] = None
|
||||
) -> Optional[types.Part]:
|
||||
"""Loads an artifact attached to the current session.
|
||||
|
||||
Args:
|
||||
filename: The filename of the artifact.
|
||||
version: The version of the artifact. If None, the latest version will be
|
||||
returned.
|
||||
|
||||
Returns:
|
||||
The artifact.
|
||||
"""
|
||||
if self._invocation_context.artifact_service is None:
|
||||
raise ValueError("Artifact service is not initialized.")
|
||||
return self._invocation_context.artifact_service.load_artifact(
|
||||
app_name=self._invocation_context.app_name,
|
||||
user_id=self._invocation_context.user_id,
|
||||
session_id=self._invocation_context.session.id,
|
||||
filename=filename,
|
||||
version=version,
|
||||
)
|
||||
|
||||
def save_artifact(self, filename: str, artifact: types.Part) -> int:
|
||||
"""Saves an artifact and records it as delta for the current session.
|
||||
|
||||
Args:
|
||||
filename: The filename of the artifact.
|
||||
artifact: The artifact to save.
|
||||
|
||||
Returns:
|
||||
The version of the artifact.
|
||||
"""
|
||||
if self._invocation_context.artifact_service is None:
|
||||
raise ValueError("Artifact service is not initialized.")
|
||||
version = self._invocation_context.artifact_service.save_artifact(
|
||||
app_name=self._invocation_context.app_name,
|
||||
user_id=self._invocation_context.user_id,
|
||||
session_id=self._invocation_context.session.id,
|
||||
filename=filename,
|
||||
artifact=artifact,
|
||||
)
|
||||
self._event_actions.artifact_delta[filename] = version
|
||||
return version
|
||||
181
src/google/adk/agents/invocation_context.py
Normal file
181
src/google/adk/agents/invocation_context.py
Normal file
@ -0,0 +1,181 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from ..artifacts.base_artifact_service import BaseArtifactService
|
||||
from ..memory.base_memory_service import BaseMemoryService
|
||||
from ..sessions.base_session_service import BaseSessionService
|
||||
from ..sessions.session import Session
|
||||
from .active_streaming_tool import ActiveStreamingTool
|
||||
from .base_agent import BaseAgent
|
||||
from .live_request_queue import LiveRequestQueue
|
||||
from .run_config import RunConfig
|
||||
from .transcription_entry import TranscriptionEntry
|
||||
|
||||
|
||||
class LlmCallsLimitExceededError(Exception):
|
||||
"""Error thrown when the number of LLM calls exceed the limit."""
|
||||
|
||||
|
||||
class _InvocationCostManager(BaseModel):
|
||||
"""A container to keep track of the cost of invocation.
|
||||
|
||||
While we don't expected the metrics captured here to be a direct
|
||||
representatative of monetary cost incurred in executing the current
|
||||
invocation, but they, in someways have an indirect affect.
|
||||
"""
|
||||
|
||||
_number_of_llm_calls: int = 0
|
||||
"""A counter that keeps track of number of llm calls made."""
|
||||
|
||||
def increment_and_enforce_llm_calls_limit(
|
||||
self, run_config: Optional[RunConfig]
|
||||
):
|
||||
"""Increments _number_of_llm_calls and enforces the limit."""
|
||||
# We first increment the counter and then check the conditions.
|
||||
self._number_of_llm_calls += 1
|
||||
|
||||
if (
|
||||
run_config
|
||||
and run_config.max_llm_calls > 0
|
||||
and self._number_of_llm_calls > run_config.max_llm_calls
|
||||
):
|
||||
# We only enforce the limit if the limit is a positive number.
|
||||
raise LlmCallsLimitExceededError(
|
||||
"Max number of llm calls limit of"
|
||||
f" `{run_config.max_llm_calls}` exceeded"
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext(BaseModel):
|
||||
"""An invocation context represents the data of a single invocation of an agent.
|
||||
|
||||
An invocation:
|
||||
1. Starts with a user message and ends with a final response.
|
||||
2. Can contain one or multiple agent calls.
|
||||
3. Is handled by runner.run_async().
|
||||
|
||||
An invocation runs an agent until it does not request to transfer to another
|
||||
agent.
|
||||
|
||||
An agent call:
|
||||
1. Is handled by agent.run().
|
||||
2. Ends when agent.run() ends.
|
||||
|
||||
An LLM agent call is an agent with a BaseLLMFlow.
|
||||
An LLM agent call can contain one or multiple steps.
|
||||
|
||||
An LLM agent runs steps in a loop until:
|
||||
1. A final response is generated.
|
||||
2. The agent transfers to another agent.
|
||||
3. The end_invocation is set to true by any callbacks or tools.
|
||||
|
||||
A step:
|
||||
1. Calls the LLM only once and yields its response.
|
||||
2. Calls the tools and yields their responses if requested.
|
||||
|
||||
The summarization of the function response is considered another step, since
|
||||
it is another llm call.
|
||||
A step ends when it's done calling llm and tools, or if the end_invocation
|
||||
is set to true at any time.
|
||||
|
||||
```
|
||||
┌─────────────────────── invocation ──────────────────────────┐
|
||||
┌──────────── llm_agent_call_1 ────────────┐ ┌─ agent_call_2 ─┐
|
||||
┌──── step_1 ────────┐ ┌───── step_2 ──────┐
|
||||
[call_llm] [call_tool] [call_llm] [transfer]
|
||||
```
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
artifact_service: Optional[BaseArtifactService] = None
|
||||
session_service: BaseSessionService
|
||||
memory_service: Optional[BaseMemoryService] = None
|
||||
|
||||
invocation_id: str
|
||||
"""The id of this invocation context. Readonly."""
|
||||
branch: Optional[str] = None
|
||||
"""The branch of the invocation context.
|
||||
|
||||
The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of
|
||||
agent_2, and agent_2 is the parent of agent_3.
|
||||
|
||||
Branch is used when multiple sub-agents shouldn't see their peer agents'
|
||||
conversaction history.
|
||||
"""
|
||||
agent: BaseAgent
|
||||
"""The current agent of this invocation context. Readonly."""
|
||||
user_content: Optional[types.Content] = None
|
||||
"""The user content that started this invocation. Readonly."""
|
||||
session: Session
|
||||
"""The current session of this invocation context. Readonly."""
|
||||
|
||||
end_invocation: bool = False
|
||||
"""Whether to end this invocation.
|
||||
|
||||
Set to True in callbacks or tools to terminate this invocation."""
|
||||
|
||||
live_request_queue: Optional[LiveRequestQueue] = None
|
||||
"""The queue to receive live requests."""
|
||||
|
||||
active_streaming_tools: Optional[dict[str, ActiveStreamingTool]] = None
|
||||
"""The running streaming tools of this invocation."""
|
||||
|
||||
transcription_cache: Optional[list[TranscriptionEntry]] = None
|
||||
"""Caches necessary, data audio or contents, that are needed by transcription."""
|
||||
|
||||
run_config: Optional[RunConfig] = None
|
||||
"""Configurations for live agents under this invocation."""
|
||||
|
||||
_invocation_cost_manager: _InvocationCostManager = _InvocationCostManager()
|
||||
"""A container to keep track of different kinds of costs incurred as a part
|
||||
of this invocation.
|
||||
"""
|
||||
|
||||
def increment_llm_call_count(
|
||||
self,
|
||||
):
|
||||
"""Tracks number of llm calls made.
|
||||
|
||||
Raises:
|
||||
LlmCallsLimitExceededError: If number of llm calls made exceed the set
|
||||
threshold.
|
||||
"""
|
||||
self._invocation_cost_manager.increment_and_enforce_llm_calls_limit(
|
||||
self.run_config
|
||||
)
|
||||
|
||||
@property
|
||||
def app_name(self) -> str:
|
||||
return self.session.app_name
|
||||
|
||||
@property
|
||||
def user_id(self) -> str:
|
||||
return self.session.user_id
|
||||
|
||||
|
||||
def new_invocation_context_id() -> str:
|
||||
return "e-" + str(uuid.uuid4())
|
||||
140
src/google/adk/agents/langgraph_agent.py
Normal file
140
src/google/adk/agents/langgraph_agent.py
Normal file
@ -0,0 +1,140 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from .base_agent import BaseAgent
|
||||
from .invocation_context import InvocationContext
|
||||
|
||||
|
||||
def _get_last_human_messages(events: list[Event]) -> list[HumanMessage]:
|
||||
"""Extracts last human messages from given list of events.
|
||||
|
||||
Args:
|
||||
events: the list of events
|
||||
|
||||
Returns:
|
||||
list of last human messages
|
||||
"""
|
||||
messages = []
|
||||
for event in reversed(events):
|
||||
if messages and event.author != 'user':
|
||||
break
|
||||
if event.author == 'user' and event.content and event.content.parts:
|
||||
messages.append(HumanMessage(content=event.content.parts[0].text))
|
||||
return list(reversed(messages))
|
||||
|
||||
|
||||
class LangGraphAgent(BaseAgent):
|
||||
"""Currently a concept implementation, supports single and multi-turn."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
graph: CompiledGraph
|
||||
|
||||
instruction: str = ''
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self,
|
||||
ctx: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
|
||||
# Needed for langgraph checkpointer (for subsequent invocations; multi-turn)
|
||||
config: RunnableConfig = {'configurable': {'thread_id': ctx.session.id}}
|
||||
|
||||
# Add instruction as SystemMessage if graph state is empty
|
||||
current_graph_state = self.graph.get_state(config)
|
||||
graph_messages = (
|
||||
current_graph_state.values.get('messages', [])
|
||||
if current_graph_state.values
|
||||
else []
|
||||
)
|
||||
messages = (
|
||||
[SystemMessage(content=self.instruction)]
|
||||
if self.instruction and not graph_messages
|
||||
else []
|
||||
)
|
||||
# Add events to messages (evaluating the memory used; parent agent vs checkpointer)
|
||||
messages += self._get_messages(ctx.session.events)
|
||||
|
||||
# Use the Runnable
|
||||
final_state = self.graph.invoke({'messages': messages}, config)
|
||||
result = final_state['messages'][-1].content
|
||||
|
||||
result_event = Event(
|
||||
invocation_id=ctx.invocation_id,
|
||||
author=self.name,
|
||||
branch=ctx.branch,
|
||||
content=types.Content(
|
||||
role='model',
|
||||
parts=[types.Part.from_text(text=result)],
|
||||
),
|
||||
)
|
||||
yield result_event
|
||||
|
||||
def _get_messages(
|
||||
self, events: list[Event]
|
||||
) -> list[Union[HumanMessage, AIMessage]]:
|
||||
"""Extracts messages from given list of events.
|
||||
|
||||
If the developer provides their own memory within langgraph, we return the
|
||||
last user messages only. Otherwise, we return all messages between the user
|
||||
and the agent.
|
||||
|
||||
Args:
|
||||
events: the list of events
|
||||
|
||||
Returns:
|
||||
list of messages
|
||||
"""
|
||||
if self.graph.checkpointer:
|
||||
return _get_last_human_messages(events)
|
||||
else:
|
||||
return self._get_conversation_with_agent(events)
|
||||
|
||||
def _get_conversation_with_agent(
|
||||
self, events: list[Event]
|
||||
) -> list[Union[HumanMessage, AIMessage]]:
|
||||
"""Extracts messages from given list of events.
|
||||
|
||||
Args:
|
||||
events: the list of events
|
||||
|
||||
Returns:
|
||||
list of messages
|
||||
"""
|
||||
|
||||
messages = []
|
||||
for event in events:
|
||||
if not event.content or not event.content.parts:
|
||||
continue
|
||||
if event.author == 'user':
|
||||
messages.append(HumanMessage(content=event.content.parts[0].text))
|
||||
elif event.author == self.name:
|
||||
messages.append(AIMessage(content=event.content.parts[0].text))
|
||||
return messages
|
||||
64
src/google/adk/agents/live_request_queue.py
Normal file
64
src/google/adk/agents/live_request_queue.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class LiveRequest(BaseModel):
|
||||
"""Request send to live agents."""
|
||||
|
||||
model_config = ConfigDict(ser_json_bytes='base64', val_json_bytes='base64')
|
||||
|
||||
content: Optional[types.Content] = None
|
||||
"""If set, send the content to the model in turn-by-turn mode."""
|
||||
blob: Optional[types.Blob] = None
|
||||
"""If set, send the blob to the model in realtime mode."""
|
||||
close: bool = False
|
||||
"""If set, close the queue. queue.shutdown() is only supported in Python 3.13+."""
|
||||
|
||||
|
||||
class LiveRequestQueue:
|
||||
"""Queue used to send LiveRequest in a live(bidirectional streaming) way."""
|
||||
|
||||
def __init__(self):
|
||||
# Ensure there's an event loop available in this thread
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No running loop, create one
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Now create the queue (it will use the event loop we just ensured exists)
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
def close(self):
|
||||
self._queue.put_nowait(LiveRequest(close=True))
|
||||
|
||||
def send_content(self, content: types.Content):
|
||||
self._queue.put_nowait(LiveRequest(content=content))
|
||||
|
||||
def send_realtime(self, blob: types.Blob):
|
||||
self._queue.put_nowait(LiveRequest(blob=blob))
|
||||
|
||||
def send(self, req: LiveRequest):
|
||||
self._queue.put_nowait(req)
|
||||
|
||||
async def get(self) -> LiveRequest:
|
||||
return await self._queue.get()
|
||||
376
src/google/adk/agents/llm_agent.py
Normal file
376
src/google/adk/agents/llm_agent.py
Normal file
@ -0,0 +1,376 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Callable
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import override
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ..code_executors.base_code_executor import BaseCodeExecutor
|
||||
from ..events.event import Event
|
||||
from ..examples.base_example_provider import BaseExampleProvider
|
||||
from ..examples.example import Example
|
||||
from ..flows.llm_flows.auto_flow import AutoFlow
|
||||
from ..flows.llm_flows.base_llm_flow import BaseLlmFlow
|
||||
from ..flows.llm_flows.single_flow import SingleFlow
|
||||
from ..models.base_llm import BaseLlm
|
||||
from ..models.llm_request import LlmRequest
|
||||
from ..models.llm_response import LlmResponse
|
||||
from ..models.registry import LLMRegistry
|
||||
from ..planners.base_planner import BasePlanner
|
||||
from ..tools.base_tool import BaseTool
|
||||
from ..tools.function_tool import FunctionTool
|
||||
from ..tools.tool_context import ToolContext
|
||||
from .base_agent import BaseAgent
|
||||
from .callback_context import CallbackContext
|
||||
from .invocation_context import InvocationContext
|
||||
from .readonly_context import ReadonlyContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BeforeModelCallback: TypeAlias = Callable[
|
||||
[CallbackContext, LlmRequest], Optional[LlmResponse]
|
||||
]
|
||||
AfterModelCallback: TypeAlias = Callable[
|
||||
[CallbackContext, LlmResponse],
|
||||
Optional[LlmResponse],
|
||||
]
|
||||
BeforeToolCallback: TypeAlias = Callable[
|
||||
[BaseTool, dict[str, Any], ToolContext],
|
||||
Optional[dict],
|
||||
]
|
||||
AfterToolCallback: TypeAlias = Callable[
|
||||
[BaseTool, dict[str, Any], ToolContext, dict],
|
||||
Optional[dict],
|
||||
]
|
||||
|
||||
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
||||
|
||||
ToolUnion: TypeAlias = Union[Callable, BaseTool]
|
||||
ExamplesUnion = Union[list[Example], BaseExampleProvider]
|
||||
|
||||
|
||||
def _convert_tool_union_to_tool(
|
||||
tool_union: ToolUnion,
|
||||
) -> BaseTool:
|
||||
return (
|
||||
tool_union
|
||||
if isinstance(tool_union, BaseTool)
|
||||
else FunctionTool(tool_union)
|
||||
)
|
||||
|
||||
|
||||
class LlmAgent(BaseAgent):
|
||||
"""LLM-based Agent."""
|
||||
|
||||
model: Union[str, BaseLlm] = ''
|
||||
"""The model to use for the agent.
|
||||
|
||||
When not set, the agent will inherit the model from its ancestor.
|
||||
"""
|
||||
|
||||
instruction: Union[str, InstructionProvider] = ''
|
||||
"""Instructions for the LLM model, guiding the agent's behavior."""
|
||||
|
||||
global_instruction: Union[str, InstructionProvider] = ''
|
||||
"""Instructions for all the agents in the entire agent tree.
|
||||
|
||||
global_instruction ONLY takes effect in root agent.
|
||||
|
||||
For example: use global_instruction to make all agents have a stable identity
|
||||
or personality.
|
||||
"""
|
||||
|
||||
tools: list[ToolUnion] = Field(default_factory=list)
|
||||
"""Tools available to this agent."""
|
||||
|
||||
generate_content_config: Optional[types.GenerateContentConfig] = None
|
||||
"""The additional content generation configurations.
|
||||
|
||||
NOTE: not all fields are usable, e.g. tools must be configured via `tools`,
|
||||
thinking_config must be configured via `planner` in LlmAgent.
|
||||
|
||||
For example: use this config to adjust model temperature, configure safety
|
||||
settings, etc.
|
||||
"""
|
||||
|
||||
# LLM-based agent transfer configs - Start
|
||||
disallow_transfer_to_parent: bool = False
|
||||
"""Disallows LLM-controlled transferring to the parent agent."""
|
||||
disallow_transfer_to_peers: bool = False
|
||||
"""Disallows LLM-controlled transferring to the peer agents."""
|
||||
# LLM-based agent transfer configs - End
|
||||
|
||||
include_contents: Literal['default', 'none'] = 'default'
|
||||
"""Whether to include contents in the model request.
|
||||
|
||||
When set to 'none', the model request will not include any contents, such as
|
||||
user messages, tool results, etc.
|
||||
"""
|
||||
|
||||
# Controlled input/output configurations - Start
|
||||
input_schema: Optional[type[BaseModel]] = None
|
||||
"""The input schema when agent is used as a tool."""
|
||||
output_schema: Optional[type[BaseModel]] = None
|
||||
"""The output schema when agent replies.
|
||||
|
||||
NOTE: when this is set, agent can ONLY reply and CANNOT use any tools, such as
|
||||
function tools, RAGs, agent transfer, etc.
|
||||
"""
|
||||
output_key: Optional[str] = None
|
||||
"""The key in session state to store the output of the agent.
|
||||
|
||||
Typically use cases:
|
||||
- Extracts agent reply for later use, such as in tools, callbacks, etc.
|
||||
- Connects agents to coordinate with each other.
|
||||
"""
|
||||
# Controlled input/output configurations - End
|
||||
|
||||
# Advance features - Start
|
||||
planner: Optional[BasePlanner] = None
|
||||
"""Instructs the agent to make a plan and execute it step by step.
|
||||
|
||||
NOTE: to use model's built-in thinking features, set the `thinking_config`
|
||||
field in `google.adk.planners.built_in_planner`.
|
||||
|
||||
"""
|
||||
|
||||
code_executor: Optional[BaseCodeExecutor] = None
|
||||
"""Allow agent to execute code blocks from model responses using the provided
|
||||
CodeExecutor.
|
||||
|
||||
Check out available code executions in `google.adk.code_executor` package.
|
||||
|
||||
NOTE: to use model's built-in code executor, don't set this field, add
|
||||
`google.adk.tools.built_in_code_execution` to tools instead.
|
||||
"""
|
||||
# Advance features - End
|
||||
|
||||
# TODO: remove below fields after migration. - Start
|
||||
# These fields are added back for easier migration.
|
||||
examples: Optional[ExamplesUnion] = None
|
||||
# TODO: remove above fields after migration. - End
|
||||
|
||||
# Callbacks - Start
|
||||
before_model_callback: Optional[BeforeModelCallback] = None
|
||||
"""Called before calling the LLM.
|
||||
Args:
|
||||
callback_context: CallbackContext,
|
||||
llm_request: LlmRequest, The raw model request. Callback can mutate the
|
||||
request.
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When present, the model call will be
|
||||
skipped and the provided content will be returned to user.
|
||||
"""
|
||||
after_model_callback: Optional[AfterModelCallback] = None
|
||||
"""Called after calling LLM.
|
||||
|
||||
Args:
|
||||
callback_context: CallbackContext,
|
||||
llm_response: LlmResponse, the actual model response.
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When present, the actual model response
|
||||
will be ignored and the provided content will be returned to user.
|
||||
"""
|
||||
before_tool_callback: Optional[BeforeToolCallback] = None
|
||||
"""Called before the tool is called.
|
||||
|
||||
Args:
|
||||
tool: The tool to be called.
|
||||
args: The arguments to the tool.
|
||||
tool_context: ToolContext,
|
||||
|
||||
Returns:
|
||||
The tool response. When present, the returned tool response will be used and
|
||||
the framework will skip calling the actual tool.
|
||||
"""
|
||||
after_tool_callback: Optional[AfterToolCallback] = None
|
||||
"""Called after the tool is called.
|
||||
|
||||
Args:
|
||||
tool: The tool to be called.
|
||||
args: The arguments to the tool.
|
||||
tool_context: ToolContext,
|
||||
tool_response: The response from the tool.
|
||||
|
||||
Returns:
|
||||
When present, the returned dict will be used as tool result.
|
||||
"""
|
||||
# Callbacks - End
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
async for event in self._llm_flow.run_async(ctx):
|
||||
self.__maybe_save_output_to_state(event)
|
||||
yield event
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
async for event in self._llm_flow.run_live(ctx):
|
||||
self.__maybe_save_output_to_state(event)
|
||||
yield event
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
|
||||
@property
|
||||
def canonical_model(self) -> BaseLlm:
|
||||
"""The resolved self.model field as BaseLlm.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if isinstance(self.model, BaseLlm):
|
||||
return self.model
|
||||
elif self.model: # model is non-empty str
|
||||
return LLMRegistry.new_llm(self.model)
|
||||
else: # find model from ancestors.
|
||||
ancestor_agent = self.parent_agent
|
||||
while ancestor_agent is not None:
|
||||
if isinstance(ancestor_agent, LlmAgent):
|
||||
return ancestor_agent.canonical_model
|
||||
ancestor_agent = ancestor_agent.parent_agent
|
||||
raise ValueError(f'No model found for {self.name}.')
|
||||
|
||||
def canonical_instruction(self, ctx: ReadonlyContext) -> str:
|
||||
"""The resolved self.instruction field to construct instruction for this agent.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if isinstance(self.instruction, str):
|
||||
return self.instruction
|
||||
else:
|
||||
return self.instruction(ctx)
|
||||
|
||||
def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
|
||||
"""The resolved self.instruction field to construct global instruction.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if isinstance(self.global_instruction, str):
|
||||
return self.global_instruction
|
||||
else:
|
||||
return self.global_instruction(ctx)
|
||||
|
||||
@property
|
||||
def canonical_tools(self) -> list[BaseTool]:
|
||||
"""The resolved self.tools field as a list of BaseTool.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
|
||||
|
||||
@property
|
||||
def _llm_flow(self) -> BaseLlmFlow:
|
||||
if (
|
||||
self.disallow_transfer_to_parent
|
||||
and self.disallow_transfer_to_peers
|
||||
and not self.sub_agents
|
||||
):
|
||||
return SingleFlow()
|
||||
else:
|
||||
return AutoFlow()
|
||||
|
||||
def __maybe_save_output_to_state(self, event: Event):
|
||||
"""Saves the model output to state if needed."""
|
||||
if (
|
||||
self.output_key
|
||||
and event.is_final_response()
|
||||
and event.content
|
||||
and event.content.parts
|
||||
):
|
||||
result = ''.join(
|
||||
[part.text if part.text else '' for part in event.content.parts]
|
||||
)
|
||||
if self.output_schema:
|
||||
result = self.output_schema.model_validate_json(result).model_dump(
|
||||
exclude_none=True
|
||||
)
|
||||
event.actions.state_delta[self.output_key] = result
|
||||
|
||||
@model_validator(mode='after')
|
||||
def __model_validator_after(self) -> LlmAgent:
|
||||
self.__check_output_schema()
|
||||
return self
|
||||
|
||||
def __check_output_schema(self):
|
||||
if not self.output_schema:
|
||||
return
|
||||
|
||||
if (
|
||||
not self.disallow_transfer_to_parent
|
||||
or not self.disallow_transfer_to_peers
|
||||
):
|
||||
logger.warning(
|
||||
'Invalid config for agent %s: output_schema cannot co-exist with'
|
||||
' agent transfer configurations. Setting'
|
||||
' disallow_transfer_to_parent=True, disallow_transfer_to_peers=True',
|
||||
self.name,
|
||||
)
|
||||
self.disallow_transfer_to_parent = True
|
||||
self.disallow_transfer_to_peers = True
|
||||
|
||||
if self.sub_agents:
|
||||
raise ValueError(
|
||||
f'Invalid config for agent {self.name}: if output_schema is set,'
|
||||
' sub_agents must be empty to disable agent transfer.'
|
||||
)
|
||||
|
||||
if self.tools:
|
||||
raise ValueError(
|
||||
f'Invalid config for agent {self.name}: if output_schema is set,'
|
||||
' tools must be empty'
|
||||
)
|
||||
|
||||
@field_validator('generate_content_config', mode='after')
|
||||
@classmethod
|
||||
def __validate_generate_content_config(
|
||||
cls, generate_content_config: Optional[types.GenerateContentConfig]
|
||||
) -> types.GenerateContentConfig:
|
||||
if not generate_content_config:
|
||||
return types.GenerateContentConfig()
|
||||
if generate_content_config.thinking_config:
|
||||
raise ValueError('Thinking config should be set via LlmAgent.planner.')
|
||||
if generate_content_config.tools:
|
||||
raise ValueError('All tools must be set via LlmAgent.tools.')
|
||||
if generate_content_config.system_instruction:
|
||||
raise ValueError(
|
||||
'System instruction must be set via LlmAgent.instruction.'
|
||||
)
|
||||
if generate_content_config.response_schema:
|
||||
raise ValueError(
|
||||
'Response schema must be set via LlmAgent.output_schema.'
|
||||
)
|
||||
return generate_content_config
|
||||
|
||||
|
||||
Agent: TypeAlias = LlmAgent
|
||||
62
src/google/adk/agents/loop_agent.py
Normal file
62
src/google/adk/agents/loop_agent.py
Normal file
@ -0,0 +1,62 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Loop agent implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event import Event
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
class LoopAgent(BaseAgent):
|
||||
"""A shell agent that run its sub-agents in a loop.
|
||||
|
||||
When sub-agent generates an event with escalate or max_iterations are
|
||||
reached, the loop agent will stop.
|
||||
"""
|
||||
|
||||
max_iterations: Optional[int] = None
|
||||
"""The maximum number of iterations to run the loop agent.
|
||||
|
||||
If not set, the loop agent will run indefinitely until a sub-agent
|
||||
escalates.
|
||||
"""
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
times_looped = 0
|
||||
while not self.max_iterations or times_looped < self.max_iterations:
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_async(ctx):
|
||||
yield event
|
||||
if event.actions.escalate:
|
||||
return
|
||||
times_looped += 1
|
||||
return
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
raise NotImplementedError('The behavior for run_live is not defined yet.')
|
||||
yield # AsyncGenerator requires having at least one yield statement
|
||||
96
src/google/adk/agents/parallel_agent.py
Normal file
96
src/google/adk/agents/parallel_agent.py
Normal file
@ -0,0 +1,96 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Parallel agent implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event import Event
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
def _set_branch_for_current_agent(
|
||||
current_agent: BaseAgent, invocation_context: InvocationContext
|
||||
):
|
||||
invocation_context.branch = (
|
||||
f"{invocation_context.branch}.{current_agent.name}"
|
||||
if invocation_context.branch
|
||||
else current_agent.name
|
||||
)
|
||||
|
||||
|
||||
async def _merge_agent_run(
|
||||
agent_runs: list[AsyncGenerator[Event, None]],
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Merges the agent run event generator.
|
||||
|
||||
This implementation guarantees for each agent, it won't move on until the
|
||||
generated event is processed by upstream runner.
|
||||
|
||||
Args:
|
||||
agent_runs: A list of async generators that yield events from each agent.
|
||||
|
||||
Yields:
|
||||
Event: The next event from the merged generator.
|
||||
"""
|
||||
tasks = [
|
||||
asyncio.create_task(events_for_one_agent.__anext__())
|
||||
for events_for_one_agent in agent_runs
|
||||
]
|
||||
pending_tasks = set(tasks)
|
||||
|
||||
while pending_tasks:
|
||||
done, pending_tasks = await asyncio.wait(
|
||||
pending_tasks, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in done:
|
||||
try:
|
||||
yield task.result()
|
||||
|
||||
# Find the generator that produced this event and move it on.
|
||||
for i, original_task in enumerate(tasks):
|
||||
if task == original_task:
|
||||
new_task = asyncio.create_task(agent_runs[i].__anext__())
|
||||
tasks[i] = new_task
|
||||
pending_tasks.add(new_task)
|
||||
break # stop iterating once found
|
||||
|
||||
except StopAsyncIteration:
|
||||
continue
|
||||
|
||||
|
||||
class ParallelAgent(BaseAgent):
|
||||
"""A shell agent that run its sub-agents in parallel in isolated manner.
|
||||
|
||||
This approach is beneficial for scenarios requiring multiple perspectives or
|
||||
attempts on a single task, such as:
|
||||
|
||||
- Running different algorithms simultaneously.
|
||||
- Generating multiple responses for review by a subsequent evaluation agent.
|
||||
"""
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
_set_branch_for_current_agent(self, ctx)
|
||||
agent_runs = [agent.run_async(ctx) for agent in self.sub_agents]
|
||||
async for event in _merge_agent_run(agent_runs):
|
||||
yield event
|
||||
46
src/google/adk/agents/readonly_context.py
Normal file
46
src/google/adk/agents/readonly_context.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .invocation_context import InvocationContext
|
||||
|
||||
|
||||
class ReadonlyContext:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
) -> None:
|
||||
self._invocation_context = invocation_context
|
||||
|
||||
@property
|
||||
def invocation_id(self) -> str:
|
||||
"""The current invocation id."""
|
||||
return self._invocation_context.invocation_id
|
||||
|
||||
@property
|
||||
def agent_name(self) -> str:
|
||||
"""The name of the agent that is currently running."""
|
||||
return self._invocation_context.agent.name
|
||||
|
||||
@property
|
||||
def state(self) -> MappingProxyType[str, Any]:
|
||||
"""The state of the current session. READONLY field."""
|
||||
return MappingProxyType(self._invocation_context.session.state)
|
||||
50
src/google/adk/agents/remote_agent.py
Normal file
50
src/google/adk/agents/remote_agent.py
Normal file
@ -0,0 +1,50 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pydantic import Field
|
||||
import requests
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from .base_agent import BaseAgent
|
||||
from .invocation_context import InvocationContext
|
||||
|
||||
|
||||
class RemoteAgent(BaseAgent):
|
||||
"""Experimental, do not use."""
|
||||
|
||||
url: str
|
||||
|
||||
sub_agents: list[BaseAgent] = Field(
|
||||
default_factory=list, init=False, frozen=True
|
||||
)
|
||||
"""Sub-agent is dsiabled in RemoteAgent."""
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
data = {
|
||||
'invocation_id': ctx.invocation_id,
|
||||
'session': ctx.session.model_dump(exclude_none=True),
|
||||
}
|
||||
events = requests.post(self.url, data=json.dumps(data), timeout=120)
|
||||
events.raise_for_status()
|
||||
for event in events.json():
|
||||
e = Event.model_validate(event)
|
||||
e.author = self.name
|
||||
yield e
|
||||
87
src/google/adk/agents/run_config.py
Normal file
87
src/google/adk/agents/run_config.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamingMode(Enum):
|
||||
NONE = None
|
||||
SSE = 'sse'
|
||||
BIDI = 'bidi'
|
||||
|
||||
|
||||
class RunConfig(BaseModel):
|
||||
"""Configs for runtime behavior of agents."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra='forbid',
|
||||
)
|
||||
|
||||
speech_config: Optional[types.SpeechConfig] = None
|
||||
"""Speech configuration for the live agent."""
|
||||
|
||||
response_modalities: Optional[list[str]] = None
|
||||
"""The output modalities. If not set, its default to AUDIO."""
|
||||
|
||||
save_input_blobs_as_artifacts: bool = False
|
||||
"""Whether or not to save the input blobs as artifacts."""
|
||||
|
||||
support_cfc: bool = False
|
||||
"""
|
||||
Whether to support CFC (Compositional Function Calling). Only applicable for
|
||||
StreamingMode.SSE. If it's true. the LIVE API will be invoked. Since only LIVE
|
||||
API supports CFC
|
||||
"""
|
||||
|
||||
streaming_mode: StreamingMode = StreamingMode.NONE
|
||||
"""Streaming mode, None or StreamingMode.SSE or StreamingMode.BIDI."""
|
||||
|
||||
output_audio_transcription: Optional[types.AudioTranscriptionConfig] = None
|
||||
"""Output transcription for live agents with audio response."""
|
||||
|
||||
max_llm_calls: int = 500
|
||||
"""
|
||||
A limit on the total number of llm calls for a given run.
|
||||
|
||||
Valid Values:
|
||||
- More than 0 and less than sys.maxsize: The bound on the number of llm
|
||||
calls is enforced, if the value is set in this range.
|
||||
- Less than or equal to 0: This allows for unbounded number of llm calls.
|
||||
"""
|
||||
|
||||
@field_validator('max_llm_calls', mode='after')
|
||||
@classmethod
|
||||
def validate_max_llm_calls(cls, value: int) -> int:
|
||||
if value == sys.maxsize:
|
||||
raise ValueError(f'max_llm_calls should be less than {sys.maxsize}.')
|
||||
elif value <= 0:
|
||||
logger.warning(
|
||||
'max_llm_calls is less than or equal to 0. This will result in'
|
||||
' no enforcement on total number of llm calls that will be made for a'
|
||||
' run. This may not be ideal, as this could result in a never'
|
||||
' ending communication between the model and the agent in certain'
|
||||
' cases.',
|
||||
)
|
||||
|
||||
return value
|
||||
45
src/google/adk/agents/sequential_agent.py
Normal file
45
src/google/adk/agents/sequential_agent.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Sequential agent implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event import Event
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
class SequentialAgent(BaseAgent):
|
||||
"""A shell agent that run its sub-agents in sequence."""
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_async(ctx):
|
||||
yield event
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_live(ctx):
|
||||
yield event
|
||||
34
src/google/adk/agents/transcription_entry.py
Normal file
34
src/google/adk/agents/transcription_entry.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class TranscriptionEntry(BaseModel):
|
||||
"""Store the data that can be used for transcription."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra='forbid',
|
||||
)
|
||||
|
||||
role: str
|
||||
"""The role that created this data, typically "user" or "model"""
|
||||
|
||||
data: Union[types.Blob, types.Content]
|
||||
"""The data that can be used for transcription"""
|
||||
23
src/google/adk/artifacts/__init__.py
Normal file
23
src/google/adk/artifacts/__init__.py
Normal file
@ -0,0 +1,23 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
from .gcs_artifact_service import GcsArtifactService
|
||||
from .in_memory_artifact_service import InMemoryArtifactService
|
||||
|
||||
__all__ = [
|
||||
'BaseArtifactService',
|
||||
'GcsArtifactService',
|
||||
'InMemoryArtifactService',
|
||||
]
|
||||
128
src/google/adk/artifacts/base_artifact_service.py
Normal file
128
src/google/adk/artifacts/base_artifact_service.py
Normal file
@ -0,0 +1,128 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Abstract base class for artifact services."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
|
||||
class BaseArtifactService(ABC):
|
||||
"""Abstract base class for artifact services."""
|
||||
|
||||
@abstractmethod
|
||||
def save_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
) -> int:
|
||||
"""Saves an artifact to the artifact service storage.
|
||||
|
||||
The artifact is a file identified by the app name, user ID, session ID, and
|
||||
filename. After saving the artifact, a revision ID is returned to identify
|
||||
the artifact version.
|
||||
|
||||
Args:
|
||||
app_name: The app name.
|
||||
user_id: The user ID.
|
||||
session_id: The session ID.
|
||||
filename: The filename of the artifact.
|
||||
artifact: The artifact to save.
|
||||
|
||||
Returns:
|
||||
The revision ID. The first version of the artifact has a revision ID of 0.
|
||||
This is incremented by 1 after each successful save.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[types.Part]:
|
||||
"""Gets an artifact from the artifact service storage.
|
||||
|
||||
The artifact is a file identified by the app name, user ID, session ID, and
|
||||
filename.
|
||||
|
||||
Args:
|
||||
app_name: The app name.
|
||||
user_id: The user ID.
|
||||
session_id: The session ID.
|
||||
filename: The filename of the artifact.
|
||||
version: The version of the artifact. If None, the latest version will be
|
||||
returned.
|
||||
|
||||
Returns:
|
||||
The artifact or None if not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_artifact_keys(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
"""Lists all the artifact filenames within a session.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
|
||||
Returns:
|
||||
A list of all artifact filenames within a session.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_artifact(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> None:
|
||||
"""Deletes an artifact.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_versions(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> list[int]:
|
||||
"""Lists all versions of an artifact.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
|
||||
Returns:
|
||||
A list of all available versions of the artifact.
|
||||
"""
|
||||
pass
|
||||
195
src/google/adk/artifacts/gcs_artifact_service.py
Normal file
195
src/google/adk/artifacts/gcs_artifact_service.py
Normal file
@ -0,0 +1,195 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""An artifact service implementation using Google Cloud Storage (GCS)."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from google.cloud import storage
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GcsArtifactService(BaseArtifactService):
|
||||
"""An artifact service implementation using Google Cloud Storage (GCS)."""
|
||||
|
||||
def __init__(self, bucket_name: str, **kwargs):
|
||||
"""Initializes the GcsArtifactService.
|
||||
|
||||
Args:
|
||||
bucket_name: The name of the bucket to use.
|
||||
**kwargs: Keyword arguments to pass to the Google Cloud Storage client.
|
||||
"""
|
||||
self.bucket_name = bucket_name
|
||||
self.storage_client = storage.Client(**kwargs)
|
||||
self.bucket = self.storage_client.bucket(self.bucket_name)
|
||||
|
||||
def _file_has_user_namespace(self, filename: str) -> bool:
|
||||
"""Checks if the filename has a user namespace.
|
||||
|
||||
Args:
|
||||
filename: The filename to check.
|
||||
|
||||
Returns:
|
||||
True if the filename has a user namespace (starts with "user:"),
|
||||
False otherwise.
|
||||
"""
|
||||
return filename.startswith("user:")
|
||||
|
||||
def _get_blob_name(
|
||||
self,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: int,
|
||||
) -> str:
|
||||
"""Constructs the blob name in GCS.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
version: The version of the artifact.
|
||||
|
||||
Returns:
|
||||
The constructed blob name in GCS.
|
||||
"""
|
||||
if self._file_has_user_namespace(filename):
|
||||
return f"{app_name}/{user_id}/user/{filename}/{version}"
|
||||
return f"{app_name}/{user_id}/{session_id}/{filename}/{version}"
|
||||
|
||||
@override
|
||||
def save_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
) -> int:
|
||||
versions = self.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
version = 0 if not versions else max(versions) + 1
|
||||
|
||||
blob_name = self._get_blob_name(
|
||||
app_name, user_id, session_id, filename, version
|
||||
)
|
||||
blob = self.bucket.blob(blob_name)
|
||||
|
||||
blob.upload_from_string(
|
||||
data=artifact.inline_data.data,
|
||||
content_type=artifact.inline_data.mime_type,
|
||||
)
|
||||
|
||||
return version
|
||||
|
||||
@override
|
||||
def load_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[types.Part]:
|
||||
if version is None:
|
||||
versions = self.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
if not versions:
|
||||
return None
|
||||
version = max(versions)
|
||||
|
||||
blob_name = self._get_blob_name(
|
||||
app_name, user_id, session_id, filename, version
|
||||
)
|
||||
blob = self.bucket.blob(blob_name)
|
||||
|
||||
artifact_bytes = blob.download_as_bytes()
|
||||
if not artifact_bytes:
|
||||
return None
|
||||
artifact = types.Part.from_bytes(
|
||||
data=artifact_bytes, mime_type=blob.content_type
|
||||
)
|
||||
return artifact
|
||||
|
||||
@override
|
||||
def list_artifact_keys(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
filenames = set()
|
||||
|
||||
session_prefix = f"{app_name}/{user_id}/{session_id}/"
|
||||
session_blobs = self.storage_client.list_blobs(
|
||||
self.bucket, prefix=session_prefix
|
||||
)
|
||||
for blob in session_blobs:
|
||||
_, _, _, filename, _ = blob.name.split("/")
|
||||
filenames.add(filename)
|
||||
|
||||
user_namespace_prefix = f"{app_name}/{user_id}/user/"
|
||||
user_namespace_blobs = self.storage_client.list_blobs(
|
||||
self.bucket, prefix=user_namespace_prefix
|
||||
)
|
||||
for blob in user_namespace_blobs:
|
||||
_, _, _, filename, _ = blob.name.split("/")
|
||||
filenames.add(filename)
|
||||
|
||||
return sorted(list(filenames))
|
||||
|
||||
@override
|
||||
def delete_artifact(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> None:
|
||||
versions = self.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
for version in versions:
|
||||
blob_name = self._get_blob_name(
|
||||
app_name, user_id, session_id, filename, version
|
||||
)
|
||||
blob = self.bucket.blob(blob_name)
|
||||
blob.delete()
|
||||
return
|
||||
|
||||
@override
|
||||
def list_versions(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> list[int]:
|
||||
prefix = self._get_blob_name(app_name, user_id, session_id, filename, "")
|
||||
blobs = self.storage_client.list_blobs(self.bucket, prefix=prefix)
|
||||
versions = []
|
||||
for blob in blobs:
|
||||
_, _, _, _, version = blob.name.split("/")
|
||||
versions.append(int(version))
|
||||
return versions
|
||||
133
src/google/adk/artifacts/in_memory_artifact_service.py
Normal file
133
src/google/adk/artifacts/in_memory_artifact_service.py
Normal file
@ -0,0 +1,133 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""An in-memory implementation of the artifact service."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
||||
"""An in-memory implementation of the artifact service."""
|
||||
|
||||
artifacts: dict[str, list[types.Part]] = Field(default_factory=dict)
|
||||
|
||||
def _file_has_user_namespace(self, filename: str) -> bool:
|
||||
"""Checks if the filename has a user namespace.
|
||||
|
||||
Args:
|
||||
filename: The filename to check.
|
||||
|
||||
Returns:
|
||||
True if the filename has a user namespace (starts with "user:"),
|
||||
False otherwise.
|
||||
"""
|
||||
return filename.startswith("user:")
|
||||
|
||||
def _artifact_path(
|
||||
self, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> str:
|
||||
"""Constructs the artifact path.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
|
||||
Returns:
|
||||
The constructed artifact path.
|
||||
"""
|
||||
if self._file_has_user_namespace(filename):
|
||||
return f"{app_name}/{user_id}/user/{filename}"
|
||||
return f"{app_name}/{user_id}/{session_id}/{filename}"
|
||||
|
||||
@override
|
||||
def save_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
) -> int:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
if path not in self.artifacts:
|
||||
self.artifacts[path] = []
|
||||
version = len(self.artifacts[path])
|
||||
self.artifacts[path].append(artifact)
|
||||
return version
|
||||
|
||||
@override
|
||||
def load_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[types.Part]:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
versions = self.artifacts.get(path)
|
||||
if not versions:
|
||||
return None
|
||||
if version is None:
|
||||
version = -1
|
||||
return versions[version]
|
||||
|
||||
@override
|
||||
def list_artifact_keys(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
session_prefix = f"{app_name}/{user_id}/{session_id}/"
|
||||
usernamespace_prefix = f"{app_name}/{user_id}/user/"
|
||||
filenames = []
|
||||
for path in self.artifacts:
|
||||
if path.startswith(session_prefix):
|
||||
filename = path.removeprefix(session_prefix)
|
||||
filenames.append(filename)
|
||||
elif path.startswith(usernamespace_prefix):
|
||||
filename = path.removeprefix(usernamespace_prefix)
|
||||
filenames.append(filename)
|
||||
return sorted(filenames)
|
||||
|
||||
@override
|
||||
def delete_artifact(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> None:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
if not self.artifacts.get(path):
|
||||
return None
|
||||
self.artifacts.pop(path, None)
|
||||
|
||||
@override
|
||||
def list_versions(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> list[int]:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
versions = self.artifacts.get(path)
|
||||
if not versions:
|
||||
return []
|
||||
return list(range(len(versions)))
|
||||
22
src/google/adk/auth/__init__.py
Normal file
22
src/google/adk/auth/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .auth_credential import AuthCredential
|
||||
from .auth_credential import AuthCredentialTypes
|
||||
from .auth_credential import OAuth2Auth
|
||||
from .auth_handler import AuthHandler
|
||||
from .auth_schemes import AuthScheme
|
||||
from .auth_schemes import AuthSchemeType
|
||||
from .auth_schemes import OpenIdConnectWithConfig
|
||||
from .auth_tool import AuthConfig
|
||||
220
src/google/adk/auth/auth_credential.py
Normal file
220
src/google/adk/auth/auth_credential.py
Normal file
@ -0,0 +1,220 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class HttpCredentials(BaseModelWithConfig):
|
||||
"""Represents the secret token value for HTTP authentication, like user name, password, oauth token, etc."""
|
||||
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
token: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, data: Dict[str, Any]) -> "HttpCredentials":
|
||||
return cls(
|
||||
username=data.get("username"),
|
||||
password=data.get("password"),
|
||||
token=data.get("token"),
|
||||
)
|
||||
|
||||
|
||||
class HttpAuth(BaseModelWithConfig):
|
||||
"""The credentials and metadata for HTTP authentication."""
|
||||
|
||||
# The name of the HTTP Authorization scheme to be used in the Authorization
|
||||
# header as defined in RFC7235. The values used SHOULD be registered in the
|
||||
# IANA Authentication Scheme registry.
|
||||
# Examples: 'basic', 'bearer'
|
||||
scheme: str
|
||||
credentials: HttpCredentials
|
||||
|
||||
|
||||
class OAuth2Auth(BaseModelWithConfig):
|
||||
"""Represents credential value and its metadata for a OAuth2 credential."""
|
||||
|
||||
client_id: Optional[str] = None
|
||||
client_secret: Optional[str] = None
|
||||
# tool or adk can generate the auth_uri with the state info thus client
|
||||
# can verify the state
|
||||
auth_uri: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
# tool or adk can decide the redirect_uri if they don't want client to decide
|
||||
redirect_uri: Optional[str] = None
|
||||
auth_response_uri: Optional[str] = None
|
||||
auth_code: Optional[str] = None
|
||||
token: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ServiceAccountCredential(BaseModelWithConfig):
|
||||
"""Represents Google Service Account configuration.
|
||||
|
||||
Attributes:
|
||||
type: The type should be "service_account".
|
||||
project_id: The project ID.
|
||||
private_key_id: The ID of the private key.
|
||||
private_key: The private key.
|
||||
client_email: The client email.
|
||||
client_id: The client ID.
|
||||
auth_uri: The authorization URI.
|
||||
token_uri: The token URI.
|
||||
auth_provider_x509_cert_url: URL for auth provider's X.509 cert.
|
||||
client_x509_cert_url: URL for the client's X.509 cert.
|
||||
universe_domain: The universe domain.
|
||||
|
||||
Example:
|
||||
|
||||
config = ServiceAccountCredential(
|
||||
type_="service_account",
|
||||
project_id="your_project_id",
|
||||
private_key_id="your_private_key_id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----...",
|
||||
client_email="...@....iam.gserviceaccount.com",
|
||||
client_id="your_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/...",
|
||||
universe_domain="googleapis.com"
|
||||
)
|
||||
|
||||
|
||||
config = ServiceAccountConfig.model_construct(**{
|
||||
...service account config dict
|
||||
})
|
||||
"""
|
||||
|
||||
type_: str = Field("", alias="type")
|
||||
project_id: str
|
||||
private_key_id: str
|
||||
private_key: str
|
||||
client_email: str
|
||||
client_id: str
|
||||
auth_uri: str
|
||||
token_uri: str
|
||||
auth_provider_x509_cert_url: str
|
||||
client_x509_cert_url: str
|
||||
universe_domain: str
|
||||
|
||||
|
||||
class ServiceAccount(BaseModelWithConfig):
|
||||
"""Represents Google Service Account configuration."""
|
||||
|
||||
service_account_credential: Optional[ServiceAccountCredential] = None
|
||||
scopes: List[str]
|
||||
use_default_credential: Optional[bool] = False
|
||||
|
||||
|
||||
class AuthCredentialTypes(str, Enum):
|
||||
"""Represents the type of authentication credential."""
|
||||
|
||||
# API Key credential:
|
||||
# https://swagger.io/docs/specification/v3_0/authentication/api-keys/
|
||||
API_KEY = "apiKey"
|
||||
|
||||
# Credentials for HTTP Auth schemes:
|
||||
# https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml
|
||||
HTTP = "http"
|
||||
|
||||
# OAuth2 credentials:
|
||||
# https://swagger.io/docs/specification/v3_0/authentication/oauth2/
|
||||
OAUTH2 = "oauth2"
|
||||
|
||||
# OpenID Connect credentials:
|
||||
# https://swagger.io/docs/specification/v3_0/authentication/openid-connect-discovery/
|
||||
OPEN_ID_CONNECT = "openIdConnect"
|
||||
|
||||
# Service Account credentials:
|
||||
# https://cloud.google.com/iam/docs/service-account-creds
|
||||
SERVICE_ACCOUNT = "serviceAccount"
|
||||
|
||||
|
||||
class AuthCredential(BaseModelWithConfig):
|
||||
"""Data class representing an authentication credential.
|
||||
|
||||
To exchange for the actual credential, please use
|
||||
CredentialExchanger.exchange_credential().
|
||||
|
||||
Examples: API Key Auth
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY,
|
||||
api_key="1234",
|
||||
)
|
||||
|
||||
Example: HTTP Auth
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="basic",
|
||||
credentials=HttpCredentials(username="user", password="password"),
|
||||
),
|
||||
)
|
||||
|
||||
Example: OAuth2 Bearer Token in HTTP Header
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer",
|
||||
credentials=HttpCredentials(token="eyAkaknabna...."),
|
||||
),
|
||||
)
|
||||
|
||||
Example: OAuth2 Auth with Authorization Code Flow
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="1234",
|
||||
client_secret="secret",
|
||||
),
|
||||
)
|
||||
|
||||
Example: OpenID Connect Auth
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="1234",
|
||||
client_secret="secret",
|
||||
redirect_uri="https://example.com",
|
||||
scopes=["scope1", "scope2"],
|
||||
),
|
||||
)
|
||||
|
||||
Example: Auth with resource reference
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY,
|
||||
resource_ref="projects/1234/locations/us-central1/resources/resource1",
|
||||
)
|
||||
"""
|
||||
|
||||
auth_type: AuthCredentialTypes
|
||||
# Resource reference for the credential.
|
||||
# This will be supported in the future.
|
||||
resource_ref: Optional[str] = None
|
||||
|
||||
api_key: Optional[str] = None
|
||||
http: Optional[HttpAuth] = None
|
||||
service_account: Optional[ServiceAccount] = None
|
||||
oauth2: Optional[OAuth2Auth] = None
|
||||
265
src/google/adk/auth/auth_handler.py
Normal file
265
src/google/adk/auth/auth_handler.py
Normal file
@ -0,0 +1,265 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import SecurityBase
|
||||
|
||||
from .auth_credential import AuthCredential
|
||||
from .auth_credential import AuthCredentialTypes
|
||||
from .auth_credential import OAuth2Auth
|
||||
from .auth_schemes import AuthSchemeType
|
||||
from .auth_schemes import OAuthGrantType
|
||||
from .auth_schemes import OpenIdConnectWithConfig
|
||||
from .auth_tool import AuthConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..sessions.state import State
|
||||
|
||||
try:
|
||||
from authlib.integrations.requests_client import OAuth2Session
|
||||
|
||||
SUPPORT_TOKEN_EXCHANGE = True
|
||||
except ImportError:
|
||||
SUPPORT_TOKEN_EXCHANGE = False
|
||||
|
||||
|
||||
class AuthHandler:
|
||||
|
||||
def __init__(self, auth_config: AuthConfig):
|
||||
self.auth_config = auth_config
|
||||
|
||||
def exchange_auth_token(
|
||||
self,
|
||||
) -> AuthCredential:
|
||||
"""Generates an auth token from the authorization response.
|
||||
|
||||
Returns:
|
||||
An AuthCredential object containing the access token.
|
||||
|
||||
Raises:
|
||||
ValueError: If the token endpoint is not configured in the auth
|
||||
scheme.
|
||||
AuthCredentialMissingError: If the access token cannot be retrieved
|
||||
from the token endpoint.
|
||||
"""
|
||||
auth_scheme = self.auth_config.auth_scheme
|
||||
auth_credential = self.auth_config.exchanged_auth_credential
|
||||
if not SUPPORT_TOKEN_EXCHANGE:
|
||||
return auth_credential
|
||||
if isinstance(auth_scheme, OpenIdConnectWithConfig):
|
||||
if not hasattr(auth_scheme, "token_endpoint"):
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
token_endpoint = auth_scheme.token_endpoint
|
||||
scopes = auth_scheme.scopes
|
||||
elif isinstance(auth_scheme, OAuth2):
|
||||
if (
|
||||
not auth_scheme.flows.authorizationCode
|
||||
or not auth_scheme.flows.authorizationCode.tokenUrl
|
||||
):
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
|
||||
scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
|
||||
else:
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
|
||||
if (
|
||||
not auth_credential
|
||||
or not auth_credential.oauth2
|
||||
or not auth_credential.oauth2.client_id
|
||||
or not auth_credential.oauth2.client_secret
|
||||
or auth_credential.oauth2.token
|
||||
):
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
|
||||
client = OAuth2Session(
|
||||
auth_credential.oauth2.client_id,
|
||||
auth_credential.oauth2.client_secret,
|
||||
scope=",".join(scopes),
|
||||
redirect_uri=auth_credential.oauth2.redirect_uri,
|
||||
state=auth_credential.oauth2.state,
|
||||
)
|
||||
token = client.fetch_token(
|
||||
token_endpoint,
|
||||
authorization_response=auth_credential.oauth2.auth_response_uri,
|
||||
code=auth_credential.oauth2.auth_code,
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
)
|
||||
|
||||
updated_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(token=dict(token)),
|
||||
)
|
||||
return updated_credential
|
||||
|
||||
def parse_and_store_auth_response(self, state: State) -> None:
|
||||
|
||||
credential_key = self.get_credential_key()
|
||||
|
||||
state[credential_key] = self.auth_config.exchanged_auth_credential
|
||||
if not isinstance(
|
||||
self.auth_config.auth_scheme, SecurityBase
|
||||
) or self.auth_config.auth_scheme.type_ not in (
|
||||
AuthSchemeType.oauth2,
|
||||
AuthSchemeType.openIdConnect,
|
||||
):
|
||||
return
|
||||
|
||||
state[credential_key] = self.exchange_auth_token()
|
||||
|
||||
def _validate(self) -> None:
|
||||
if not self.auth_scheme:
|
||||
raise ValueError("auth_scheme is empty.")
|
||||
|
||||
def get_auth_response(self, state: State) -> AuthCredential:
|
||||
credential_key = self.get_credential_key()
|
||||
return state.get(credential_key, None)
|
||||
|
||||
def generate_auth_request(self) -> AuthConfig:
|
||||
if not isinstance(
|
||||
self.auth_config.auth_scheme, SecurityBase
|
||||
) or self.auth_config.auth_scheme.type_ not in (
|
||||
AuthSchemeType.oauth2,
|
||||
AuthSchemeType.openIdConnect,
|
||||
):
|
||||
return self.auth_config.model_copy(deep=True)
|
||||
|
||||
# auth_uri already in exchanged credential
|
||||
if (
|
||||
self.auth_config.exchanged_auth_credential
|
||||
and self.auth_config.exchanged_auth_credential.oauth2
|
||||
and self.auth_config.exchanged_auth_credential.oauth2.auth_uri
|
||||
):
|
||||
return self.auth_config.model_copy(deep=True)
|
||||
|
||||
# Check if raw_auth_credential exists
|
||||
if not self.auth_config.raw_auth_credential:
|
||||
raise ValueError(
|
||||
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires"
|
||||
" auth_credential."
|
||||
)
|
||||
|
||||
# Check if oauth2 exists in raw_auth_credential
|
||||
if not self.auth_config.raw_auth_credential.oauth2:
|
||||
raise ValueError(
|
||||
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires oauth2 in"
|
||||
" auth_credential."
|
||||
)
|
||||
|
||||
# auth_uri in raw credential
|
||||
if self.auth_config.raw_auth_credential.oauth2.auth_uri:
|
||||
return AuthConfig(
|
||||
auth_scheme=self.auth_config.auth_scheme,
|
||||
raw_auth_credential=self.auth_config.raw_auth_credential,
|
||||
exchanged_auth_credential=self.auth_config.raw_auth_credential.model_copy(
|
||||
deep=True
|
||||
),
|
||||
)
|
||||
|
||||
# Check for client_id and client_secret
|
||||
if (
|
||||
not self.auth_config.raw_auth_credential.oauth2.client_id
|
||||
or not self.auth_config.raw_auth_credential.oauth2.client_secret
|
||||
):
|
||||
raise ValueError(
|
||||
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires both"
|
||||
" client_id and client_secret in auth_credential.oauth2."
|
||||
)
|
||||
|
||||
# Generate new auth URI
|
||||
exchanged_credential = self.generate_auth_uri()
|
||||
return AuthConfig(
|
||||
auth_scheme=self.auth_config.auth_scheme,
|
||||
raw_auth_credential=self.auth_config.raw_auth_credential,
|
||||
exchanged_auth_credential=exchanged_credential,
|
||||
)
|
||||
|
||||
def get_credential_key(self) -> str:
|
||||
"""Generates a unique key for the given auth scheme and credential."""
|
||||
auth_scheme = self.auth_config.auth_scheme
|
||||
auth_credential = self.auth_config.raw_auth_credential
|
||||
if auth_scheme.model_extra:
|
||||
auth_scheme = auth_scheme.model_copy(deep=True)
|
||||
auth_scheme.model_extra.clear()
|
||||
scheme_name = (
|
||||
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
|
||||
if auth_scheme
|
||||
else ""
|
||||
)
|
||||
if auth_credential.model_extra:
|
||||
auth_credential = auth_credential.model_copy(deep=True)
|
||||
auth_credential.model_extra.clear()
|
||||
credential_name = (
|
||||
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
|
||||
if auth_credential
|
||||
else ""
|
||||
)
|
||||
|
||||
return f"temp:adk_{scheme_name}_{credential_name}"
|
||||
|
||||
def generate_auth_uri(
|
||||
self,
|
||||
) -> AuthCredential:
|
||||
"""Generates an response containing the auth uri for user to sign in.
|
||||
|
||||
Returns:
|
||||
An AuthCredential object containing the auth URI and state.
|
||||
|
||||
Raises:
|
||||
ValueError: If the authorization endpoint is not configured in the auth
|
||||
scheme.
|
||||
"""
|
||||
auth_scheme = self.auth_config.auth_scheme
|
||||
auth_credential = self.auth_config.raw_auth_credential
|
||||
|
||||
if isinstance(auth_scheme, OpenIdConnectWithConfig):
|
||||
authorization_endpoint = auth_scheme.authorization_endpoint
|
||||
scopes = auth_scheme.scopes
|
||||
else:
|
||||
authorization_endpoint = (
|
||||
auth_scheme.flows.implicit
|
||||
and auth_scheme.flows.implicit.authorizationUrl
|
||||
or auth_scheme.flows.authorizationCode
|
||||
and auth_scheme.flows.authorizationCode.authorizationUrl
|
||||
or auth_scheme.flows.clientCredentials
|
||||
and auth_scheme.flows.clientCredentials.tokenUrl
|
||||
or auth_scheme.flows.password
|
||||
and auth_scheme.flows.password.tokenUrl
|
||||
)
|
||||
scopes = (
|
||||
auth_scheme.flows.implicit
|
||||
and auth_scheme.flows.implicit.scopes
|
||||
or auth_scheme.flows.authorizationCode
|
||||
and auth_scheme.flows.authorizationCode.scopes
|
||||
or auth_scheme.flows.clientCredentials
|
||||
and auth_scheme.flows.clientCredentials.scopes
|
||||
or auth_scheme.flows.password
|
||||
and auth_scheme.flows.password.scopes
|
||||
)
|
||||
|
||||
client = OAuth2Session(
|
||||
auth_credential.oauth2.client_id,
|
||||
auth_credential.oauth2.client_secret,
|
||||
scope=" ".join(scopes),
|
||||
redirect_uri=auth_credential.oauth2.redirect_uri,
|
||||
)
|
||||
uri, state = client.create_authorization_url(url=authorization_endpoint)
|
||||
exchanged_auth_credential = auth_credential.model_copy(deep=True)
|
||||
exchanged_auth_credential.oauth2.auth_uri = uri
|
||||
exchanged_auth_credential.oauth2.state = state
|
||||
|
||||
return exchanged_auth_credential
|
||||
116
src/google/adk/auth/auth_preprocessor.py
Normal file
116
src/google/adk/auth/auth_preprocessor.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event import Event
|
||||
from ..flows.llm_flows import functions
|
||||
from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor
|
||||
from ..flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
from ..models.llm_request import LlmRequest
|
||||
from .auth_handler import AuthHandler
|
||||
from .auth_tool import AuthConfig
|
||||
from .auth_tool import AuthToolArguments
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
|
||||
|
||||
class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Handles auth information to build the LLM request."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
events = invocation_context.session.events
|
||||
if not events:
|
||||
return
|
||||
request_euc_function_call_response_event = events[-1]
|
||||
responses = (
|
||||
request_euc_function_call_response_event.get_function_responses()
|
||||
)
|
||||
if not responses:
|
||||
return
|
||||
|
||||
request_euc_function_call_ids = set()
|
||||
|
||||
for function_call_response in responses:
|
||||
if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME:
|
||||
continue
|
||||
|
||||
# found the function call response for the system long running request euc
|
||||
# function call
|
||||
request_euc_function_call_ids.add(function_call_response.id)
|
||||
auth_config = AuthConfig.model_validate(function_call_response.response)
|
||||
AuthHandler(auth_config=auth_config).parse_and_store_auth_response(
|
||||
state=invocation_context.session.state
|
||||
)
|
||||
|
||||
if not request_euc_function_call_ids:
|
||||
return
|
||||
|
||||
for i in range(len(events) - 2, -1, -1):
|
||||
event = events[i]
|
||||
# looking for the system long running reqeust euc function call
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
continue
|
||||
|
||||
tools_to_resume = set()
|
||||
|
||||
for function_call in function_calls:
|
||||
if function_call.id not in request_euc_function_call_ids:
|
||||
continue
|
||||
args = AuthToolArguments.model_validate(function_call.args)
|
||||
|
||||
tools_to_resume.add(args.function_call_id)
|
||||
if not tools_to_resume:
|
||||
continue
|
||||
# found the the system long running reqeust euc function call
|
||||
# looking for original function call that requests euc
|
||||
for j in range(i - 1, -1, -1):
|
||||
event = events[j]
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
continue
|
||||
for function_call in function_calls:
|
||||
function_response_event = None
|
||||
if function_call.id in tools_to_resume:
|
||||
function_response_event = await functions.handle_function_calls_async(
|
||||
invocation_context,
|
||||
event,
|
||||
{tool.name: tool for tool in agent.canonical_tools},
|
||||
# there could be parallel function calls that require auth
|
||||
# auth response would be a dict keyed by function call id
|
||||
tools_to_resume,
|
||||
)
|
||||
if function_response_event:
|
||||
yield function_response_event
|
||||
return
|
||||
return
|
||||
|
||||
|
||||
request_processor = _AuthLlmRequestProcessor()
|
||||
67
src/google/adk/auth/auth_schemes.py
Normal file
67
src/google/adk/auth/auth_schemes.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from fastapi.openapi.models import OAuthFlows
|
||||
from fastapi.openapi.models import SecurityBase
|
||||
from fastapi.openapi.models import SecurityScheme
|
||||
from fastapi.openapi.models import SecuritySchemeType
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class OpenIdConnectWithConfig(SecurityBase):
|
||||
type_: SecuritySchemeType = Field(
|
||||
default=SecuritySchemeType.openIdConnect, alias="type"
|
||||
)
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
userinfo_endpoint: Optional[str] = None
|
||||
revocation_endpoint: Optional[str] = None
|
||||
token_endpoint_auth_methods_supported: Optional[List[str]] = None
|
||||
grant_types_supported: Optional[List[str]] = None
|
||||
scopes: Optional[List[str]] = None
|
||||
|
||||
|
||||
# AuthSchemes contains SecuritySchemes from OpenAPI 3.0 and an extra flattened OpenIdConnectWithConfig.
|
||||
AuthScheme = Union[SecurityScheme, OpenIdConnectWithConfig]
|
||||
|
||||
|
||||
class OAuthGrantType(str, Enum):
|
||||
"""Represents the OAuth2 flow (or grant type)."""
|
||||
|
||||
CLIENT_CREDENTIALS = "client_credentials"
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
IMPLICIT = "implicit"
|
||||
PASSWORD = "password"
|
||||
|
||||
@staticmethod
|
||||
def from_flow(flow: OAuthFlows) -> "OAuthGrantType":
|
||||
"""Converts an OAuthFlows object to a OAuthGrantType."""
|
||||
if flow.clientCredentials:
|
||||
return OAuthGrantType.CLIENT_CREDENTIALS
|
||||
if flow.authorizationCode:
|
||||
return OAuthGrantType.AUTHORIZATION_CODE
|
||||
if flow.implicit:
|
||||
return OAuthGrantType.IMPLICIT
|
||||
if flow.password:
|
||||
return OAuthGrantType.PASSWORD
|
||||
return None
|
||||
|
||||
|
||||
# AuthSchemeType re-exports SecuritySchemeType from OpenAPI 3.0.
|
||||
AuthSchemeType = SecuritySchemeType
|
||||
55
src/google/adk/auth/auth_tool.py
Normal file
55
src/google/adk/auth/auth_tool.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .auth_credential import AuthCredential
|
||||
from .auth_schemes import AuthScheme
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
"""The auth config sent by tool asking client to collect auth credentails and
|
||||
|
||||
adk and client will help to fill in the response
|
||||
"""
|
||||
|
||||
auth_scheme: AuthScheme
|
||||
"""The auth scheme used to collect credentials"""
|
||||
raw_auth_credential: AuthCredential = None
|
||||
"""The raw auth credential used to collect credentials. The raw auth
|
||||
credentials are used in some auth scheme that needs to exchange auth
|
||||
credentials. e.g. OAuth2 and OIDC. For other auth scheme, it could be None.
|
||||
"""
|
||||
exchanged_auth_credential: AuthCredential = None
|
||||
"""The exchanged auth credential used to collect credentials. adk and client
|
||||
will work together to fill it. For those auth scheme that doesn't need to
|
||||
exchange auth credentials, e.g. API key, service account etc. It's filled by
|
||||
client directly. For those auth scheme that need to exchange auth credentials,
|
||||
e.g. OAuth2 and OIDC, it's first filled by adk. If the raw credentials
|
||||
passed by tool only has client id and client credential, adk will help to
|
||||
generate the corresponding authorization uri and state and store the processed
|
||||
credential in this field. If the raw credentials passed by tool already has
|
||||
authorization uri, state, etc. then it's copied to this field. Client will use
|
||||
this field to guide the user through the OAuth2 flow and fill auth response in
|
||||
this field"""
|
||||
|
||||
|
||||
class AuthToolArguments(BaseModel):
|
||||
"""the arguments for the special long running function tool that is used to
|
||||
|
||||
request end user credentials.
|
||||
"""
|
||||
|
||||
function_call_id: str
|
||||
auth_config: AuthConfig
|
||||
15
src/google/adk/cli/__init__.py
Normal file
15
src/google/adk/cli/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .cli_tools_click import main
|
||||
18
src/google/adk/cli/__main__.py
Normal file
18
src/google/adk/cli/__main__.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .cli_tools_click import main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
122
src/google/adk/cli/agent_graph.py
Normal file
122
src/google/adk/cli/agent_graph.py
Normal file
@ -0,0 +1,122 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
import graphviz
|
||||
|
||||
from ..agents import BaseAgent
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
from ..tools.agent_tool import AgentTool
|
||||
from ..tools.base_tool import BaseTool
|
||||
from ..tools.function_tool import FunctionTool
|
||||
from ..tools.retrieval.base_retrieval_tool import BaseRetrievalTool
|
||||
|
||||
|
||||
def build_graph(graph, agent: BaseAgent, highlight_pairs):
|
||||
dark_green = '#0F5223'
|
||||
light_green = '#69CB87'
|
||||
light_gray = '#cccccc'
|
||||
|
||||
def get_node_name(tool_or_agent: Union[BaseAgent, BaseTool]):
|
||||
if isinstance(tool_or_agent, BaseAgent):
|
||||
return tool_or_agent.name
|
||||
elif isinstance(tool_or_agent, BaseTool):
|
||||
return tool_or_agent.name
|
||||
else:
|
||||
raise ValueError(f'Unsupported tool type: {tool_or_agent}')
|
||||
|
||||
def get_node_caption(tool_or_agent: Union[BaseAgent, BaseTool]):
|
||||
if isinstance(tool_or_agent, BaseAgent):
|
||||
return '🤖 ' + tool_or_agent.name
|
||||
elif isinstance(tool_or_agent, BaseRetrievalTool):
|
||||
return '🔎 ' + tool_or_agent.name
|
||||
elif isinstance(tool_or_agent, FunctionTool):
|
||||
return '🔧 ' + tool_or_agent.name
|
||||
elif isinstance(tool_or_agent, AgentTool):
|
||||
return '🤖 ' + tool_or_agent.name
|
||||
elif isinstance(tool_or_agent, BaseTool):
|
||||
return '🔧 ' + tool_or_agent.name
|
||||
else:
|
||||
raise ValueError(f'Unsupported tool type: {type(tool)}')
|
||||
|
||||
def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]):
|
||||
if isinstance(tool_or_agent, BaseAgent):
|
||||
return 'ellipse'
|
||||
elif isinstance(tool_or_agent, BaseRetrievalTool):
|
||||
return 'cylinder'
|
||||
elif isinstance(tool_or_agent, FunctionTool):
|
||||
return 'box'
|
||||
elif isinstance(tool_or_agent, BaseTool):
|
||||
return 'box'
|
||||
else:
|
||||
raise ValueError(f'Unsupported tool type: {type(tool_or_agent)}')
|
||||
|
||||
def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
|
||||
name = get_node_name(tool_or_agent)
|
||||
shape = get_node_shape(tool_or_agent)
|
||||
caption = get_node_caption(tool_or_agent)
|
||||
if highlight_pairs:
|
||||
for highlight_tuple in highlight_pairs:
|
||||
if name in highlight_tuple:
|
||||
graph.node(
|
||||
name,
|
||||
caption,
|
||||
style='filled,rounded',
|
||||
fillcolor=dark_green,
|
||||
color=dark_green,
|
||||
shape=shape,
|
||||
fontcolor=light_gray,
|
||||
)
|
||||
return
|
||||
# if not in highlight, draw non-highliht node
|
||||
graph.node(
|
||||
name,
|
||||
caption,
|
||||
shape=shape,
|
||||
style='rounded',
|
||||
color=light_gray,
|
||||
fontcolor=light_gray,
|
||||
)
|
||||
|
||||
def draw_edge(from_name, to_name):
|
||||
if highlight_pairs:
|
||||
for highlight_from, highlight_to in highlight_pairs:
|
||||
if from_name == highlight_from and to_name == highlight_to:
|
||||
graph.edge(from_name, to_name, color=light_green)
|
||||
return
|
||||
elif from_name == highlight_to and to_name == highlight_from:
|
||||
graph.edge(from_name, to_name, color=light_green, dir='back')
|
||||
return
|
||||
# if no need to highlight, color gray
|
||||
graph.edge(from_name, to_name, arrowhead='none', color=light_gray)
|
||||
|
||||
draw_node(agent)
|
||||
for sub_agent in agent.sub_agents:
|
||||
build_graph(graph, sub_agent, highlight_pairs)
|
||||
draw_edge(agent.name, sub_agent.name)
|
||||
if isinstance(agent, LlmAgent):
|
||||
for tool in agent.canonical_tools:
|
||||
draw_node(tool)
|
||||
draw_edge(agent.name, get_node_name(tool))
|
||||
|
||||
|
||||
def get_agent_graph(root_agent, highlights_pairs, image=False):
|
||||
print('build graph')
|
||||
graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'})
|
||||
build_graph(graph, root_agent, highlights_pairs)
|
||||
if image:
|
||||
return graph.pipe(format='png')
|
||||
else:
|
||||
return graph
|
||||
181
src/google/adk/cli/cli.py
Normal file
181
src/google/adk/cli/cli.py
Normal file
@ -0,0 +1,181 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from datetime import datetime
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
from ..artifacts import BaseArtifactService
|
||||
from ..artifacts import InMemoryArtifactService
|
||||
from ..runners import Runner
|
||||
from ..sessions.base_session_service import BaseSessionService
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from ..sessions.session import Session
|
||||
from .utils import envs
|
||||
|
||||
|
||||
class InputFile(BaseModel):
|
||||
state: dict[str, object]
|
||||
queries: list[str]
|
||||
|
||||
|
||||
async def run_input_file(
|
||||
app_name: str,
|
||||
root_agent: LlmAgent,
|
||||
artifact_service: BaseArtifactService,
|
||||
session: Session,
|
||||
session_service: BaseSessionService,
|
||||
input_path: str,
|
||||
) -> None:
|
||||
runner = Runner(
|
||||
app_name=app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
)
|
||||
with open(input_path, 'r', encoding='utf-8') as f:
|
||||
input_file = InputFile.model_validate_json(f.read())
|
||||
input_file.state['_time'] = datetime.now()
|
||||
|
||||
session.state = input_file.state
|
||||
for query in input_file.queries:
|
||||
click.echo(f'user: {query}')
|
||||
content = types.Content(role='user', parts=[types.Part(text=query)])
|
||||
async for event in runner.run_async(
|
||||
user_id=session.user_id, session_id=session.id, new_message=content
|
||||
):
|
||||
if event.content and event.content.parts:
|
||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||
click.echo(f'[{event.author}]: {text}')
|
||||
|
||||
|
||||
async def run_interactively(
|
||||
app_name: str,
|
||||
root_agent: LlmAgent,
|
||||
artifact_service: BaseArtifactService,
|
||||
session: Session,
|
||||
session_service: BaseSessionService,
|
||||
) -> None:
|
||||
runner = Runner(
|
||||
app_name=app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
)
|
||||
while True:
|
||||
query = input('user: ')
|
||||
if query == 'exit':
|
||||
break
|
||||
async for event in runner.run_async(
|
||||
user_id=session.user_id,
|
||||
session_id=session.id,
|
||||
new_message=types.Content(role='user', parts=[types.Part(text=query)]),
|
||||
):
|
||||
if event.content and event.content.parts:
|
||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||
click.echo(f'[{event.author}]: {text}')
|
||||
|
||||
|
||||
async def run_cli(
|
||||
*,
|
||||
agent_parent_dir: str,
|
||||
agent_folder_name: str,
|
||||
json_file_path: Optional[str] = None,
|
||||
save_session: bool,
|
||||
) -> None:
|
||||
"""Runs an interactive CLI for a certain agent.
|
||||
|
||||
Args:
|
||||
agent_parent_dir: str, the absolute path of the parent folder of the agent
|
||||
folder.
|
||||
agent_folder_name: str, the name of the agent folder.
|
||||
json_file_path: Optional[str], the absolute path to the json file, either
|
||||
*.input.json or *.session.json.
|
||||
save_session: bool, whether to save the session on exit.
|
||||
"""
|
||||
if agent_parent_dir not in sys.path:
|
||||
sys.path.append(agent_parent_dir)
|
||||
|
||||
artifact_service = InMemoryArtifactService()
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
app_name=agent_folder_name, user_id='test_user'
|
||||
)
|
||||
|
||||
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
|
||||
agent_module = importlib.import_module(agent_folder_name)
|
||||
root_agent = agent_module.agent.root_agent
|
||||
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
|
||||
if json_file_path:
|
||||
if json_file_path.endswith('.input.json'):
|
||||
await run_input_file(
|
||||
app_name=agent_folder_name,
|
||||
root_agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
input_path=json_file_path,
|
||||
)
|
||||
elif json_file_path.endswith('.session.json'):
|
||||
with open(json_file_path, 'r') as f:
|
||||
session = Session.model_validate_json(f.read())
|
||||
for content in session.get_contents():
|
||||
if content.role == 'user':
|
||||
print('user: ', content.parts[0].text)
|
||||
else:
|
||||
print(content.parts[0].text)
|
||||
await run_interactively(
|
||||
agent_folder_name,
|
||||
root_agent,
|
||||
artifact_service,
|
||||
session,
|
||||
session_service,
|
||||
)
|
||||
else:
|
||||
print(f'Unsupported file type: {json_file_path}')
|
||||
exit(1)
|
||||
else:
|
||||
print(f'Running agent {root_agent.name}, type exit to exit.')
|
||||
await run_interactively(
|
||||
agent_folder_name,
|
||||
root_agent,
|
||||
artifact_service,
|
||||
session,
|
||||
session_service,
|
||||
)
|
||||
|
||||
if save_session:
|
||||
if json_file_path:
|
||||
session_path = json_file_path.replace('.input.json', '.session.json')
|
||||
else:
|
||||
session_id = input('Session ID to save: ')
|
||||
session_path = f'{agent_module_path}/{session_id}.session.json'
|
||||
with open(session_path, 'w') as f:
|
||||
f.write(session.model_dump_json(indent=2, exclude_none=True))
|
||||
# TODO: Save from opentelemetry.
|
||||
# logs_path = session_path.replace('.session.json', '.logs.json')
|
||||
# with open(logs_path, 'w') as f:
|
||||
# f.write(
|
||||
# session.model_dump_json(
|
||||
# indent=2, exclude_none=True, include='event_logs'
|
||||
# )
|
||||
# )
|
||||
print('Session saved to', session_path)
|
||||
181
src/google/adk/cli/cli_deploy.py
Normal file
181
src/google/adk/cli/cli_deploy.py
Normal file
@ -0,0 +1,181 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
_DOCKERFILE_TEMPLATE = """
|
||||
FROM python:3.11-slim
|
||||
WORKDIR /app
|
||||
|
||||
# Create a non-root user
|
||||
RUN adduser --disabled-password --gecos "" myuser
|
||||
|
||||
# Change ownership of /app to myuser
|
||||
RUN chown -R myuser:myuser /app
|
||||
|
||||
# Switch to the non-root user
|
||||
USER myuser
|
||||
|
||||
# Set up environment variables - Start
|
||||
ENV PATH="/home/myuser/.local/bin:$PATH"
|
||||
|
||||
ENV GOOGLE_GENAI_USE_VERTEXAI=1
|
||||
# TODO: use passed-in value
|
||||
ENV GOOGLE_CLOUD_PROJECT={gcp_project_id}
|
||||
ENV GOOGLE_CLOUD_LOCATION={gcp_region}
|
||||
ENV ADK_TRACE_TO_CLOUD={with_cloud_trace}
|
||||
|
||||
# Set up environment variables - End
|
||||
|
||||
# Install ADK - Start
|
||||
RUN pip install google-adk
|
||||
# Install ADK - End
|
||||
|
||||
# Copy agent - Start
|
||||
|
||||
COPY "agents/{app_name}/" "/app/agents/{app_name}/"
|
||||
{install_agent_deps}
|
||||
|
||||
# Copy agent - End
|
||||
|
||||
EXPOSE {port}
|
||||
|
||||
CMD adk {command} --port={port} "/app/agents"
|
||||
"""
|
||||
|
||||
|
||||
def _resolve_project(project_in_option: Optional[str]) -> str:
|
||||
if project_in_option:
|
||||
return project_in_option
|
||||
|
||||
result = subprocess.run(
|
||||
['gcloud', 'config', 'get-value', 'project'],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
project = result.stdout.strip()
|
||||
click.echo(f'Use default project: {project}')
|
||||
return project
|
||||
|
||||
|
||||
def to_cloud_run(
|
||||
*,
|
||||
agent_folder: str,
|
||||
project: Optional[str],
|
||||
region: Optional[str],
|
||||
service_name: str,
|
||||
app_name: str,
|
||||
temp_folder: str,
|
||||
port: int,
|
||||
with_cloud_trace: bool,
|
||||
with_ui: bool,
|
||||
):
|
||||
"""Deploys an agent to Google Cloud Run.
|
||||
|
||||
`agent_folder` should contain the following files:
|
||||
|
||||
- __init__.py
|
||||
- agent.py
|
||||
- requirements.txt (optional, for additional dependencies)
|
||||
- ... (other required source files)
|
||||
|
||||
The folder structure of temp_folder will be
|
||||
|
||||
* dist/[google_adk wheel file]
|
||||
* agents/[app_name]/
|
||||
* agent source code from `agent_folder`
|
||||
|
||||
Args:
|
||||
agent_folder: The folder (absolute path) containing the agent source code.
|
||||
project: Google Cloud project id.
|
||||
region: Google Cloud region.
|
||||
service_name: The service name in Cloud Run.
|
||||
app_name: The name of the app, by default, it's basename of `agent_folder`.
|
||||
temp_folder: The temp folder for the generated Cloud Run source files.
|
||||
port: The port of the ADK api server.
|
||||
with_cloud_trace: Whether to enable Cloud Trace.
|
||||
with_ui: Whether to deploy with UI.
|
||||
"""
|
||||
app_name = app_name or os.path.basename(agent_folder)
|
||||
|
||||
click.echo(f'Start generating Cloud Run source files in {temp_folder}')
|
||||
|
||||
# remove temp_folder if exists
|
||||
if os.path.exists(temp_folder):
|
||||
click.echo('Removing existing files')
|
||||
shutil.rmtree(temp_folder)
|
||||
|
||||
try:
|
||||
# copy agent source code
|
||||
click.echo('Copying agent source code...')
|
||||
agent_src_path = os.path.join(temp_folder, 'agents', app_name)
|
||||
shutil.copytree(agent_folder, agent_src_path)
|
||||
requirements_txt_path = os.path.join(agent_src_path, 'requirements.txt')
|
||||
install_agent_deps = (
|
||||
f'RUN pip install -r "/app/agents/{app_name}/requirements.txt"'
|
||||
if os.path.exists(requirements_txt_path)
|
||||
else ''
|
||||
)
|
||||
click.echo('Copying agent source code complete.')
|
||||
|
||||
# create Dockerfile
|
||||
click.echo('Creating Dockerfile...')
|
||||
dockerfile_content = _DOCKERFILE_TEMPLATE.format(
|
||||
gcp_project_id=project,
|
||||
gcp_region=region,
|
||||
app_name=app_name,
|
||||
port=port,
|
||||
command='web' if with_ui else 'api_server',
|
||||
install_agent_deps=install_agent_deps,
|
||||
with_cloud_trace='1' if with_cloud_trace else '0',
|
||||
)
|
||||
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
|
||||
os.makedirs(temp_folder, exist_ok=True)
|
||||
with open(dockerfile_path, 'w', encoding='utf-8') as f:
|
||||
f.write(
|
||||
dockerfile_content,
|
||||
)
|
||||
click.echo(f'Creating Dockerfile complete: {dockerfile_path}')
|
||||
|
||||
# Deploy to Cloud Run
|
||||
click.echo('Deploying to Cloud Run...')
|
||||
region_options = ['--region', region] if region else []
|
||||
project = _resolve_project(project)
|
||||
subprocess.run(
|
||||
[
|
||||
'gcloud',
|
||||
'run',
|
||||
'deploy',
|
||||
service_name,
|
||||
'--source',
|
||||
temp_folder,
|
||||
'--project',
|
||||
project,
|
||||
*region_options,
|
||||
'--port',
|
||||
str(port),
|
||||
'--labels',
|
||||
'created-by=adk',
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
finally:
|
||||
click.echo(f'Cleaning up the temp folder: {temp_folder}')
|
||||
shutil.rmtree(temp_folder)
|
||||
282
src/google/adk/cli/cli_eval.py
Normal file
282
src/google/adk/cli/cli_eval.py
Normal file
@ -0,0 +1,282 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any
|
||||
from typing import Generator
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..agents import Agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvalStatus(Enum):
|
||||
PASSED = 1
|
||||
FAILED = 2
|
||||
NOT_EVALUATED = 3
|
||||
|
||||
|
||||
class EvalMetric(BaseModel):
|
||||
metric_name: str
|
||||
threshold: float
|
||||
|
||||
|
||||
class EvalMetricResult(BaseModel):
|
||||
score: Optional[float]
|
||||
eval_status: EvalStatus
|
||||
|
||||
|
||||
class EvalResult(BaseModel):
|
||||
eval_set_file: str
|
||||
eval_id: str
|
||||
final_eval_status: EvalStatus
|
||||
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
|
||||
session_id: str
|
||||
|
||||
|
||||
MISSING_EVAL_DEPENDENCIES_MESSAGE = (
|
||||
"Eval module is not installed, please install via `pip install"
|
||||
" google-adk[eval]`."
|
||||
)
|
||||
TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score"
|
||||
RESPONSE_MATCH_SCORE_KEY = "response_match_score"
|
||||
# This evaluation is not very stable.
|
||||
# This is always optional unless explicitly specified.
|
||||
RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"
|
||||
|
||||
EVAL_SESSION_ID_PREFIX = "___eval___session___"
|
||||
DEFAULT_CRITERIA = {
|
||||
TOOL_TRAJECTORY_SCORE_KEY: 1.0, # 1-point scale; 1.0 is perfect.
|
||||
RESPONSE_MATCH_SCORE_KEY: 0.8,
|
||||
}
|
||||
|
||||
|
||||
def _import_from_path(module_name, file_path):
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _get_agent_module(agent_module_file_path: str):
|
||||
file_path = os.path.join(agent_module_file_path, "__init__.py")
|
||||
module_name = "agent"
|
||||
return _import_from_path(module_name, file_path)
|
||||
|
||||
|
||||
def get_evaluation_criteria_or_default(
|
||||
eval_config_file_path: str,
|
||||
) -> dict[str, float]:
|
||||
"""Returns evaluation criteria from the config file, if present.
|
||||
|
||||
Otherwise a default one is returned.
|
||||
"""
|
||||
if eval_config_file_path:
|
||||
with open(eval_config_file_path, "r", encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
if "criteria" in config_data and isinstance(config_data["criteria"], dict):
|
||||
evaluation_criteria = config_data["criteria"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid format for test_config.json at {eval_config_file_path}."
|
||||
" Expected a 'criteria' dictionary."
|
||||
)
|
||||
else:
|
||||
logger.info("No config file supplied. Using default criteria.")
|
||||
evaluation_criteria = DEFAULT_CRITERIA
|
||||
|
||||
return evaluation_criteria
|
||||
|
||||
|
||||
def get_root_agent(agent_module_file_path: str) -> Agent:
|
||||
"""Returns root agent given the agetn module."""
|
||||
agent_module = _get_agent_module(agent_module_file_path)
|
||||
root_agent = agent_module.agent.root_agent
|
||||
return root_agent
|
||||
|
||||
|
||||
def try_get_reset_func(agent_module_file_path: str) -> Any:
|
||||
"""Returns reset function for the agent, if present, given the agetn module."""
|
||||
agent_module = _get_agent_module(agent_module_file_path)
|
||||
reset_func = getattr(agent_module.agent, "reset_data", None)
|
||||
return reset_func
|
||||
|
||||
|
||||
def parse_and_get_evals_to_run(
|
||||
eval_set_file_path: tuple[str],
|
||||
) -> dict[str, list[str]]:
|
||||
"""Returns a dictionary of eval sets to evals that should be run."""
|
||||
eval_set_to_evals = {}
|
||||
for input_eval_set in eval_set_file_path:
|
||||
evals = []
|
||||
if ":" not in input_eval_set:
|
||||
eval_set_file = input_eval_set
|
||||
else:
|
||||
eval_set_file = input_eval_set.split(":")[0]
|
||||
evals = input_eval_set.split(":")[1].split(",")
|
||||
|
||||
if eval_set_file not in eval_set_to_evals:
|
||||
eval_set_to_evals[eval_set_file] = []
|
||||
|
||||
eval_set_to_evals[eval_set_file].extend(evals)
|
||||
|
||||
return eval_set_to_evals
|
||||
|
||||
|
||||
def run_evals(
|
||||
eval_set_to_evals: dict[str, list[str]],
|
||||
root_agent: Agent,
|
||||
reset_func: Optional[Any],
|
||||
eval_metrics: list[EvalMetric],
|
||||
session_service=None,
|
||||
artifact_service=None,
|
||||
print_detailed_results=False,
|
||||
) -> Generator[EvalResult, None, None]:
|
||||
try:
|
||||
from ..evaluation.agent_evaluator import EvaluationGenerator
|
||||
from ..evaluation.response_evaluator import ResponseEvaluator
|
||||
from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
|
||||
|
||||
"""Returns a summary of eval runs."""
|
||||
for eval_set_file, evals_to_run in eval_set_to_evals.items():
|
||||
with open(eval_set_file, "r", encoding="utf-8") as file:
|
||||
eval_items = json.load(file) # Load JSON into a list
|
||||
|
||||
assert eval_items, f"No eval data found in eval set file: {eval_set_file}"
|
||||
|
||||
for eval_item in eval_items:
|
||||
eval_name = eval_item["name"]
|
||||
eval_data = eval_item["data"]
|
||||
initial_session = eval_item.get("initial_session", {})
|
||||
|
||||
if evals_to_run and eval_name not in evals_to_run:
|
||||
continue
|
||||
|
||||
try:
|
||||
print(f"Running Eval: {eval_set_file}:{eval_name}")
|
||||
session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"
|
||||
|
||||
scrape_result = EvaluationGenerator._process_query_with_root_agent(
|
||||
data=eval_data,
|
||||
root_agent=root_agent,
|
||||
reset_func=reset_func,
|
||||
initial_session=initial_session,
|
||||
session_id=session_id,
|
||||
session_service=session_service,
|
||||
artifact_service=artifact_service,
|
||||
)
|
||||
|
||||
eval_metric_results = []
|
||||
for eval_metric in eval_metrics:
|
||||
eval_metric_result = None
|
||||
if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
|
||||
score = TrajectoryEvaluator.evaluate(
|
||||
[scrape_result], print_detailed_results=print_detailed_results
|
||||
)
|
||||
eval_metric_result = _get_eval_metric_result(eval_metric, score)
|
||||
elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
|
||||
score = ResponseEvaluator.evaluate(
|
||||
[scrape_result],
|
||||
[RESPONSE_MATCH_SCORE_KEY],
|
||||
print_detailed_results=print_detailed_results,
|
||||
)
|
||||
eval_metric_result = _get_eval_metric_result(
|
||||
eval_metric, score["rouge_1/mean"].item()
|
||||
)
|
||||
elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
|
||||
score = ResponseEvaluator.evaluate(
|
||||
[scrape_result],
|
||||
[RESPONSE_EVALUATION_SCORE_KEY],
|
||||
print_detailed_results=print_detailed_results,
|
||||
)
|
||||
eval_metric_result = _get_eval_metric_result(
|
||||
eval_metric, score["coherence/mean"].item()
|
||||
)
|
||||
else:
|
||||
logger.warning("`%s` is not supported.", eval_metric.metric_name)
|
||||
eval_metric_results.append((
|
||||
eval_metric,
|
||||
EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED),
|
||||
))
|
||||
|
||||
eval_metric_results.append((
|
||||
eval_metric,
|
||||
eval_metric_result,
|
||||
))
|
||||
_print_eval_metric_result(eval_metric, eval_metric_result)
|
||||
|
||||
final_eval_status = EvalStatus.NOT_EVALUATED
|
||||
|
||||
# Go over the all the eval statuses and mark the final eval status as
|
||||
# passed if all of them pass, otherwise mark the final eval status to
|
||||
# failed.
|
||||
for eval_metric_result in eval_metric_results:
|
||||
eval_status = eval_metric_result[1].eval_status
|
||||
if eval_status == EvalStatus.PASSED:
|
||||
final_eval_status = EvalStatus.PASSED
|
||||
elif eval_status == EvalStatus.NOT_EVALUATED:
|
||||
continue
|
||||
elif eval_status == EvalStatus.FAILED:
|
||||
final_eval_status = EvalStatus.FAILED
|
||||
break
|
||||
else:
|
||||
raise ValueError("Unknown eval status.")
|
||||
|
||||
yield EvalResult(
|
||||
eval_set_file=eval_set_file,
|
||||
eval_id=eval_name,
|
||||
final_eval_status=final_eval_status,
|
||||
eval_metric_results=eval_metric_results,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if final_eval_status == EvalStatus.PASSED:
|
||||
result = "✅ Passsed"
|
||||
else:
|
||||
result = "❌ Failed"
|
||||
|
||||
print(f"Result: {result}\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
logger.info("Error: %s", str(traceback.format_exc()))
|
||||
|
||||
|
||||
def _get_eval_metric_result(eval_metric, score):
|
||||
eval_status = (
|
||||
EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
|
||||
)
|
||||
return EvalMetricResult(score=score, eval_status=eval_status)
|
||||
|
||||
|
||||
def _print_eval_metric_result(eval_metric, eval_metric_result):
|
||||
print(
|
||||
f"Metric: {eval_metric.metric_name}\tStatus:"
|
||||
f" {eval_metric_result.eval_status}\tScore:"
|
||||
f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
|
||||
)
|
||||
479
src/google/adk/cli/cli_tools_click.py
Normal file
479
src/google/adk/cli/cli_tools_click.py
Normal file
@ -0,0 +1,479 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import uvicorn
|
||||
|
||||
from . import cli_deploy
|
||||
from .cli import run_cli
|
||||
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
|
||||
from .fast_api import get_fast_api_app
|
||||
from .utils import envs
|
||||
from .utils import logs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@click.group(context_settings={"max_content_width": 240})
|
||||
def main():
|
||||
"""Agent Development Kit CLI tools."""
|
||||
pass
|
||||
|
||||
|
||||
@main.group()
|
||||
def deploy():
|
||||
"""Deploy Agent."""
|
||||
pass
|
||||
|
||||
|
||||
@main.command("run")
|
||||
@click.option(
|
||||
"--save_session",
|
||||
type=bool,
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help="Optional. Whether to save the session to a json file on exit.",
|
||||
)
|
||||
@click.argument(
|
||||
"agent",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
)
|
||||
def cli_run(agent: str, save_session: bool):
|
||||
"""Run an interactive CLI for a certain agent.
|
||||
|
||||
AGENT: The path to the agent source code folder.
|
||||
|
||||
Example:
|
||||
|
||||
adk run path/to/my_agent
|
||||
"""
|
||||
logs.log_to_tmp_folder()
|
||||
|
||||
agent_parent_folder = os.path.dirname(agent)
|
||||
agent_folder_name = os.path.basename(agent)
|
||||
|
||||
asyncio.run(
|
||||
run_cli(
|
||||
agent_parent_dir=agent_parent_folder,
|
||||
agent_folder_name=agent_folder_name,
|
||||
save_session=save_session,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@main.command("eval")
|
||||
@click.argument(
|
||||
"agent_module_file_path",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
)
|
||||
@click.argument("eval_set_file_path", nargs=-1)
|
||||
@click.option("--config_file_path", help="Optional. The path to config file.")
|
||||
@click.option(
|
||||
"--print_detailed_results",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help="Optional. Whether to print detailed results on console or not.",
|
||||
)
|
||||
def eval_command(
|
||||
agent_module_file_path: str,
|
||||
eval_set_file_path: tuple[str],
|
||||
config_file_path: str,
|
||||
print_detailed_results: bool,
|
||||
):
|
||||
"""Evaluates an agent given the eval sets.
|
||||
|
||||
AGENT_MODULE_FILE_PATH: The path to the __init__.py file that contains a
|
||||
module by the name "agent". "agent" module contains a root_agent.
|
||||
|
||||
EVAL_SET_FILE_PATH: You can specify one or more eval set file paths.
|
||||
|
||||
For each file, all evals will be run by default.
|
||||
|
||||
If you want to run only specific evals from a eval set, first create a comma
|
||||
separated list of eval names and then add that as a suffix to the eval set
|
||||
file name, demarcated by a `:`.
|
||||
|
||||
For example,
|
||||
|
||||
sample_eval_set_file.json:eval_1,eval_2,eval_3
|
||||
|
||||
This will only run eval_1, eval_2 and eval_3 from sample_eval_set_file.json.
|
||||
|
||||
CONFIG_FILE_PATH: The path to config file.
|
||||
|
||||
PRINT_DETAILED_RESULTS: Prints detailed results on the console.
|
||||
"""
|
||||
envs.load_dotenv_for_agent(agent_module_file_path, ".")
|
||||
|
||||
try:
|
||||
from .cli_eval import EvalMetric
|
||||
from .cli_eval import EvalResult
|
||||
from .cli_eval import EvalStatus
|
||||
from .cli_eval import get_evaluation_criteria_or_default
|
||||
from .cli_eval import get_root_agent
|
||||
from .cli_eval import parse_and_get_evals_to_run
|
||||
from .cli_eval import run_evals
|
||||
from .cli_eval import try_get_reset_func
|
||||
except ModuleNotFoundError:
|
||||
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
|
||||
|
||||
evaluation_criteria = get_evaluation_criteria_or_default(config_file_path)
|
||||
eval_metrics = []
|
||||
for metric_name, threshold in evaluation_criteria.items():
|
||||
eval_metrics.append(
|
||||
EvalMetric(metric_name=metric_name, threshold=threshold)
|
||||
)
|
||||
|
||||
print(f"Using evaluation creiteria: {evaluation_criteria}")
|
||||
|
||||
root_agent = get_root_agent(agent_module_file_path)
|
||||
reset_func = try_get_reset_func(agent_module_file_path)
|
||||
|
||||
eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
|
||||
|
||||
try:
|
||||
eval_results = list(
|
||||
run_evals(
|
||||
eval_set_to_evals,
|
||||
root_agent,
|
||||
reset_func,
|
||||
eval_metrics,
|
||||
print_detailed_results=print_detailed_results,
|
||||
)
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
|
||||
|
||||
print("*********************************************************************")
|
||||
eval_run_summary = {}
|
||||
|
||||
for eval_result in eval_results:
|
||||
eval_result: EvalResult
|
||||
|
||||
if eval_result.eval_set_file not in eval_run_summary:
|
||||
eval_run_summary[eval_result.eval_set_file] = [0, 0]
|
||||
|
||||
if eval_result.final_eval_status == EvalStatus.PASSED:
|
||||
eval_run_summary[eval_result.eval_set_file][0] += 1
|
||||
else:
|
||||
eval_run_summary[eval_result.eval_set_file][1] += 1
|
||||
print("Eval Run Summary")
|
||||
for eval_set_file, pass_fail_count in eval_run_summary.items():
|
||||
print(
|
||||
f"{eval_set_file}:\n Tests passed: {pass_fail_count[0]}\n Tests"
|
||||
f" failed: {pass_fail_count[1]}"
|
||||
)
|
||||
|
||||
|
||||
@main.command("web")
|
||||
@click.option(
|
||||
"--session_db_url",
|
||||
help=(
|
||||
"Optional. The database URL to store the session.\n\n - Use"
|
||||
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
|
||||
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
|
||||
" to connect to a SQLite DB.\n\n - See"
|
||||
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
|
||||
" for more details on supported DB URLs."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
help="Optional. The port of the server",
|
||||
default=8000,
|
||||
)
|
||||
@click.option(
|
||||
"--allow_origins",
|
||||
help="Optional. Any additional origins to allow for CORS.",
|
||||
multiple=True,
|
||||
)
|
||||
@click.option(
|
||||
"--log_level",
|
||||
type=click.Choice(
|
||||
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
|
||||
),
|
||||
default="INFO",
|
||||
help="Optional. Set the logging level",
|
||||
)
|
||||
@click.option(
|
||||
"--log_to_tmp",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help=(
|
||||
"Optional. Whether to log to system temp folder instead of console."
|
||||
" This is useful for local debugging."
|
||||
),
|
||||
)
|
||||
@click.argument(
|
||||
"agents_dir",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
default=os.getcwd(),
|
||||
)
|
||||
def web(
|
||||
agents_dir: str,
|
||||
log_to_tmp: bool,
|
||||
session_db_url: str = "",
|
||||
log_level: str = "INFO",
|
||||
allow_origins: Optional[list[str]] = None,
|
||||
port: int = 8000,
|
||||
):
|
||||
"""Start a FastAPI server with web UI for a certain agent.
|
||||
|
||||
AGENTS_DIR: The directory of agents, where each sub-directory is a single
|
||||
agent, containing at least `__init__.py` and `agent.py` files.
|
||||
|
||||
Example:
|
||||
|
||||
adk web --session_db_url=[db_url] --port=[port] path/to/agents_dir
|
||||
"""
|
||||
if log_to_tmp:
|
||||
logs.log_to_tmp_folder()
|
||||
else:
|
||||
logs.log_to_stderr()
|
||||
|
||||
logging.getLogger().setLevel(log_level)
|
||||
|
||||
config = uvicorn.Config(
|
||||
get_fast_api_app(
|
||||
agent_dir=agents_dir,
|
||||
session_db_url=session_db_url,
|
||||
allow_origins=allow_origins,
|
||||
web=True,
|
||||
),
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
reload=True,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
server.run()
|
||||
|
||||
|
||||
@main.command("api_server")
|
||||
@click.option(
|
||||
"--session_db_url",
|
||||
help=(
|
||||
"Optional. The database URL to store the session.\n\n - Use"
|
||||
" 'agentengine://<agent_engine_resource_id>' to connect to Vertex"
|
||||
" managed session service.\n\n - Use 'sqlite://<path_to_sqlite_file>'"
|
||||
" to connect to a SQLite DB.\n\n - See"
|
||||
" https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls"
|
||||
" for more details on supported DB URLs."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
help="Optional. The port of the server",
|
||||
default=8000,
|
||||
)
|
||||
@click.option(
|
||||
"--allow_origins",
|
||||
help="Optional. Any additional origins to allow for CORS.",
|
||||
multiple=True,
|
||||
)
|
||||
@click.option(
|
||||
"--log_level",
|
||||
type=click.Choice(
|
||||
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
|
||||
),
|
||||
default="INFO",
|
||||
help="Optional. Set the logging level",
|
||||
)
|
||||
@click.option(
|
||||
"--log_to_tmp",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help=(
|
||||
"Optional. Whether to log to system temp folder instead of console."
|
||||
" This is useful for local debugging."
|
||||
),
|
||||
)
|
||||
# The directory of agents, where each sub-directory is a single agent.
|
||||
# By default, it is the current working directory
|
||||
@click.argument(
|
||||
"agents_dir",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
default=os.getcwd(),
|
||||
)
|
||||
def cli_api_server(
|
||||
agents_dir: str,
|
||||
log_to_tmp: bool,
|
||||
session_db_url: str = "",
|
||||
log_level: str = "INFO",
|
||||
allow_origins: Optional[list[str]] = None,
|
||||
port: int = 8000,
|
||||
):
|
||||
"""Start an api server for a certain agent.
|
||||
|
||||
AGENTS_DIR: The directory of agents, where each sub-directory is a single
|
||||
agent, containing at least `__init__.py` and `agent.py` files.
|
||||
|
||||
Example:
|
||||
|
||||
adk api_server --session_db_url=[db_url] --port=[port] path/to/agents_dir
|
||||
"""
|
||||
if log_to_tmp:
|
||||
logs.log_to_tmp_folder()
|
||||
else:
|
||||
logs.log_to_stderr()
|
||||
|
||||
logging.getLogger().setLevel(log_level)
|
||||
|
||||
config = uvicorn.Config(
|
||||
get_fast_api_app(
|
||||
agent_dir=agents_dir,
|
||||
session_db_url=session_db_url,
|
||||
allow_origins=allow_origins,
|
||||
web=False,
|
||||
),
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
reload=True,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
server.run()
|
||||
|
||||
|
||||
@deploy.command("cloud_run")
|
||||
@click.option(
|
||||
"--project",
|
||||
type=str,
|
||||
help=(
|
||||
"Required. Google Cloud project to deploy the agent. When absent,"
|
||||
" default project from gcloud config is used."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--region",
|
||||
type=str,
|
||||
help=(
|
||||
"Required. Google Cloud region to deploy the agent. When absent,"
|
||||
" gcloud run deploy will prompt later."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--service_name",
|
||||
type=str,
|
||||
default="adk-default-service-name",
|
||||
help=(
|
||||
"Optional. The service name to use in Cloud Run (default:"
|
||||
" 'adk-default-service-name')."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--app_name",
|
||||
type=str,
|
||||
default="",
|
||||
help=(
|
||||
"Optional. App name of the ADK API server (default: the folder name"
|
||||
" of the AGENT source code)."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Optional. The port of the ADK API server (default: 8000).",
|
||||
)
|
||||
@click.option(
|
||||
"--with_cloud_trace",
|
||||
type=bool,
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help="Optional. Whether to enable Cloud Trace for cloud run.",
|
||||
)
|
||||
@click.option(
|
||||
"--with_ui",
|
||||
type=bool,
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help=(
|
||||
"Optional. Deploy ADK Web UI if set. (default: deploy ADK API server"
|
||||
" only)"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--temp_folder",
|
||||
type=str,
|
||||
default=os.path.join(
|
||||
tempfile.gettempdir(),
|
||||
"cloud_run_deploy_src",
|
||||
datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
),
|
||||
help=(
|
||||
"Optional. Temp folder for the generated Cloud Run source files"
|
||||
" (default: a timestamped folder in the system temp directory)."
|
||||
),
|
||||
)
|
||||
@click.argument(
|
||||
"agent",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
)
|
||||
def deploy_to_cloud_run(
|
||||
agent: str,
|
||||
project: Optional[str],
|
||||
region: Optional[str],
|
||||
service_name: str,
|
||||
app_name: str,
|
||||
temp_folder: str,
|
||||
port: int,
|
||||
with_cloud_trace: bool,
|
||||
with_ui: bool,
|
||||
):
|
||||
"""Deploys agent to Cloud Run.
|
||||
|
||||
AGENT: The path to the agent source code folder.
|
||||
|
||||
Example:
|
||||
|
||||
adk deploy cloud_run --project=[project] --region=[region] path/to/my_agent
|
||||
"""
|
||||
try:
|
||||
cli_deploy.to_cloud_run(
|
||||
agent_folder=agent,
|
||||
project=project,
|
||||
region=region,
|
||||
service_name=service_name,
|
||||
app_name=app_name,
|
||||
temp_folder=temp_folder,
|
||||
port=port,
|
||||
with_cloud_trace=with_cloud_trace,
|
||||
with_ui=with_ui,
|
||||
)
|
||||
except Exception as e:
|
||||
click.secho(f"Deploy failed: {e}", fg="red", err=True)
|
||||
765
src/google/adk/cli/fast_api.py
Normal file
765
src/google/adk/cli/fast_api.py
Normal file
@ -0,0 +1,765 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.websockets import WebSocket
|
||||
from fastapi.websockets import WebSocketDisconnect
|
||||
from google.genai import types
|
||||
import graphviz
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
|
||||
from opentelemetry.sdk.trace import export
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ..agents import RunConfig
|
||||
from ..agents.live_request_queue import LiveRequest
|
||||
from ..agents.live_request_queue import LiveRequestQueue
|
||||
from ..agents.llm_agent import Agent
|
||||
from ..agents.run_config import StreamingMode
|
||||
from ..artifacts import InMemoryArtifactService
|
||||
from ..events.event import Event
|
||||
from ..runners import Runner
|
||||
from ..sessions.database_session_service import DatabaseSessionService
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from ..sessions.session import Session
|
||||
from ..sessions.vertex_ai_session_service import VertexAiSessionService
|
||||
from .cli_eval import EVAL_SESSION_ID_PREFIX
|
||||
from .cli_eval import EvalMetric
|
||||
from .cli_eval import EvalMetricResult
|
||||
from .cli_eval import EvalStatus
|
||||
from .utils import create_empty_state
|
||||
from .utils import envs
|
||||
from .utils import evals
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_EVAL_SET_FILE_EXTENSION = ".evalset.json"
|
||||
|
||||
|
||||
class ApiServerSpanExporter(export.SpanExporter):
|
||||
|
||||
def __init__(self, trace_dict):
|
||||
self.trace_dict = trace_dict
|
||||
|
||||
def export(
|
||||
self, spans: typing.Sequence[ReadableSpan]
|
||||
) -> export.SpanExportResult:
|
||||
for span in spans:
|
||||
if span.name == "call_llm" or span.name == "send_data":
|
||||
attributes = dict(span.attributes)
|
||||
attributes["trace_id"] = span.get_span_context().trace_id
|
||||
attributes["span_id"] = span.get_span_context().span_id
|
||||
if attributes.get("gcp.vertex.agent.event_id", None):
|
||||
self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes
|
||||
return export.SpanExportResult.SUCCESS
|
||||
|
||||
def force_flush(self, timeout_millis: int = 30000) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class AgentRunRequest(BaseModel):
|
||||
app_name: str
|
||||
user_id: str
|
||||
session_id: str
|
||||
new_message: types.Content
|
||||
streaming: bool = False
|
||||
|
||||
|
||||
class AddSessionToEvalSetRequest(BaseModel):
|
||||
eval_id: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
|
||||
|
||||
class RunEvalRequest(BaseModel):
|
||||
eval_ids: list[str] # if empty, then all evals in the eval set are run.
|
||||
eval_metrics: list[EvalMetric]
|
||||
|
||||
|
||||
class RunEvalResult(BaseModel):
|
||||
eval_set_id: str
|
||||
eval_id: str
|
||||
final_eval_status: EvalStatus
|
||||
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
|
||||
session_id: str
|
||||
|
||||
|
||||
def get_fast_api_app(
|
||||
*,
|
||||
agent_dir: str,
|
||||
session_db_url: str = "",
|
||||
allow_origins: Optional[list[str]] = None,
|
||||
web: bool,
|
||||
) -> FastAPI:
|
||||
# InMemory tracing dict.
|
||||
trace_dict: dict[str, Any] = {}
|
||||
|
||||
# Set up tracing in the FastAPI server.
|
||||
provider = TracerProvider()
|
||||
provider.add_span_processor(
|
||||
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
|
||||
)
|
||||
if os.environ.get("ADK_TRACE_TO_CLOUD", "0") == "1":
|
||||
processor = export.BatchSpanProcessor(
|
||||
CloudTraceSpanExporter(
|
||||
project_id=os.environ.get("GOOGLE_CLOUD_PROJECT", "")
|
||||
)
|
||||
)
|
||||
provider.add_span_processor(processor)
|
||||
|
||||
trace.set_tracer_provider(provider)
|
||||
|
||||
# Run the FastAPI server.
|
||||
app = FastAPI()
|
||||
|
||||
if allow_origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=allow_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
if agent_dir not in sys.path:
|
||||
sys.path.append(agent_dir)
|
||||
|
||||
runner_dict = {}
|
||||
root_agent_dict = {}
|
||||
|
||||
# Build the Artifact service
|
||||
artifact_service = InMemoryArtifactService()
|
||||
|
||||
# Build the Session service
|
||||
agent_engine_id = ""
|
||||
if session_db_url:
|
||||
if session_db_url.startswith("agentengine://"):
|
||||
# Create vertex session service
|
||||
agent_engine_id = session_db_url.split("://")[1]
|
||||
if not agent_engine_id:
|
||||
raise click.ClickException("Agent engine id can not be empty.")
|
||||
envs.load_dotenv_for_agent("", agent_dir)
|
||||
session_service = VertexAiSessionService(
|
||||
os.environ["GOOGLE_CLOUD_PROJECT"],
|
||||
os.environ["GOOGLE_CLOUD_LOCATION"],
|
||||
)
|
||||
else:
|
||||
session_service = DatabaseSessionService(db_url=session_db_url)
|
||||
else:
|
||||
session_service = InMemorySessionService()
|
||||
|
||||
@app.get("/list-apps")
|
||||
def list_apps() -> list[str]:
|
||||
base_path = Path.cwd() / agent_dir
|
||||
if not base_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Path not found")
|
||||
if not base_path.is_dir():
|
||||
raise HTTPException(status_code=400, detail="Not a directory")
|
||||
agent_names = [
|
||||
x
|
||||
for x in os.listdir(base_path)
|
||||
if os.path.isdir(os.path.join(base_path, x))
|
||||
and not x.startswith(".")
|
||||
and x != "__pycache__"
|
||||
]
|
||||
agent_names.sort()
|
||||
return agent_names
|
||||
|
||||
@app.get("/debug/trace/{event_id}")
|
||||
def get_trace_dict(event_id: str) -> Any:
|
||||
event_dict = trace_dict.get(event_id, None)
|
||||
if event_dict is None:
|
||||
raise HTTPException(status_code=404, detail="Trace not found")
|
||||
return event_dict
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def get_session(app_name: str, user_id: str, session_id: str) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
session = session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return session
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_sessions(app_name: str, user_id: str) -> list[Session]:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
return [
|
||||
session
|
||||
for session in session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
).sessions
|
||||
# Remove sessions that were generated as a part of Eval.
|
||||
if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
|
||||
]
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def create_session_with_id(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
if (
|
||||
session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
is not None
|
||||
):
|
||||
logger.warning("Session already exists: %s", session_id)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Session already exists: {session_id}"
|
||||
)
|
||||
|
||||
logger.info("New session created: %s", session_id)
|
||||
return session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state, session_id=session_id
|
||||
)
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/users/{user_id}/sessions",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def create_session(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
|
||||
logger.info("New session created")
|
||||
return session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state
|
||||
)
|
||||
|
||||
def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
|
||||
return os.path.join(
|
||||
agent_dir,
|
||||
app_name,
|
||||
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
||||
)
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def create_eval_set(
|
||||
app_name: str,
|
||||
eval_set_id: str,
|
||||
):
|
||||
"""Creates an eval set, given the id."""
|
||||
pattern = r"^[a-zA-Z0-9_]+$"
|
||||
if not bool(re.fullmatch(pattern, eval_set_id)):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Invalid eval set id. Eval set id should have the `{pattern}`"
|
||||
" format"
|
||||
),
|
||||
)
|
||||
# Define the file path
|
||||
new_eval_set_path = _get_eval_set_file_path(
|
||||
app_name, agent_dir, eval_set_id
|
||||
)
|
||||
|
||||
logger.info("Creating eval set file `%s`", new_eval_set_path)
|
||||
|
||||
if not os.path.exists(new_eval_set_path):
|
||||
# Write the JSON string to the file
|
||||
logger.info("Eval set file doesn't exist, we will create a new one.")
|
||||
with open(new_eval_set_path, "w") as f:
|
||||
empty_content = json.dumps([], indent=2)
|
||||
f.write(empty_content)
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/eval_sets",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_eval_sets(app_name: str) -> list[str]:
|
||||
"""Lists all eval sets for the given app."""
|
||||
eval_set_file_path = os.path.join(agent_dir, app_name)
|
||||
eval_sets = []
|
||||
for file in os.listdir(eval_set_file_path):
|
||||
if file.endswith(_EVAL_SET_FILE_EXTENSION):
|
||||
eval_sets.append(
|
||||
os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
|
||||
)
|
||||
|
||||
return sorted(eval_sets)
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def add_session_to_eval_set(
|
||||
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
||||
):
|
||||
pattern = r"^[a-zA-Z0-9_]+$"
|
||||
if not bool(re.fullmatch(pattern, req.eval_id)):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
|
||||
)
|
||||
|
||||
# Get the session
|
||||
session = session_service.get_session(
|
||||
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
assert session, "Session not found."
|
||||
# Load the eval set file data
|
||||
eval_set_file_path = _get_eval_set_file_path(
|
||||
app_name, agent_dir, eval_set_id
|
||||
)
|
||||
with open(eval_set_file_path, "r") as file:
|
||||
eval_set_data = json.load(file) # Load JSON into a list
|
||||
|
||||
if [x for x in eval_set_data if x["name"] == req.eval_id]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
|
||||
" eval set."
|
||||
),
|
||||
)
|
||||
|
||||
# Convert the session data to evaluation format
|
||||
test_data = evals.convert_session_to_eval_format(session)
|
||||
|
||||
# Populate the session with initial session state.
|
||||
initial_session_state = create_empty_state(_get_root_agent(app_name))
|
||||
|
||||
eval_set_data.append({
|
||||
"name": req.eval_id,
|
||||
"data": test_data,
|
||||
"initial_session": {
|
||||
"state": initial_session_state,
|
||||
"app_name": app_name,
|
||||
"user_id": req.user_id,
|
||||
},
|
||||
})
|
||||
# Serialize the test data to JSON and write to the eval set file.
|
||||
with open(eval_set_file_path, "w") as f:
|
||||
f.write(json.dumps(eval_set_data, indent=2))
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}/evals",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_evals_in_eval_set(
|
||||
app_name: str,
|
||||
eval_set_id: str,
|
||||
) -> list[str]:
|
||||
"""Lists all evals in an eval set."""
|
||||
# Load the eval set file data
|
||||
eval_set_file_path = _get_eval_set_file_path(
|
||||
app_name, agent_dir, eval_set_id
|
||||
)
|
||||
with open(eval_set_file_path, "r") as file:
|
||||
eval_set_data = json.load(file) # Load JSON into a list
|
||||
|
||||
return sorted([x["name"] for x in eval_set_data])
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def run_eval(
|
||||
app_name: str, eval_set_id: str, req: RunEvalRequest
|
||||
) -> list[RunEvalResult]:
|
||||
from .cli_eval import run_evals
|
||||
|
||||
"""Runs an eval given the details in the eval request."""
|
||||
# Create a mapping from eval set file to all the evals that needed to be
|
||||
# run.
|
||||
eval_set_file_path = _get_eval_set_file_path(
|
||||
app_name, agent_dir, eval_set_id
|
||||
)
|
||||
eval_set_to_evals = {eval_set_file_path: req.eval_ids}
|
||||
|
||||
if not req.eval_ids:
|
||||
logger.info(
|
||||
"Eval ids to run list is empty. We will all evals in the eval set."
|
||||
)
|
||||
root_agent = _get_root_agent(app_name)
|
||||
eval_results = list(
|
||||
run_evals(
|
||||
eval_set_to_evals,
|
||||
root_agent,
|
||||
getattr(root_agent, "reset_data", None),
|
||||
req.eval_metrics,
|
||||
session_service=session_service,
|
||||
artifact_service=artifact_service,
|
||||
)
|
||||
)
|
||||
|
||||
run_eval_results = []
|
||||
for eval_result in eval_results:
|
||||
run_eval_results.append(
|
||||
RunEvalResult(
|
||||
app_name=app_name,
|
||||
eval_set_id=eval_set_id,
|
||||
eval_id=eval_result.eval_id,
|
||||
final_eval_status=eval_result.final_eval_status,
|
||||
eval_metric_results=eval_result.eval_metric_results,
|
||||
session_id=eval_result.session_id,
|
||||
)
|
||||
)
|
||||
return run_eval_results
|
||||
|
||||
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
|
||||
def delete_session(app_name: str, user_id: str, session_id: str):
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
session_service.delete_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def load_artifact(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
artifact_name: str,
|
||||
version: Optional[int] = Query(None),
|
||||
) -> Optional[types.Part]:
|
||||
artifact = artifact_service.load_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=artifact_name,
|
||||
version=version,
|
||||
)
|
||||
if not artifact:
|
||||
raise HTTPException(status_code=404, detail="Artifact not found")
|
||||
return artifact
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def load_artifact_version(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
artifact_name: str,
|
||||
version_id: int,
|
||||
) -> Optional[types.Part]:
|
||||
artifact = artifact_service.load_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=artifact_name,
|
||||
version=version_id,
|
||||
)
|
||||
if not artifact:
|
||||
raise HTTPException(status_code=404, detail="Artifact not found")
|
||||
return artifact
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_artifact_names(
|
||||
app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
return artifact_service.list_artifact_keys(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_artifact_versions(
|
||||
app_name: str, user_id: str, session_id: str, artifact_name: str
|
||||
) -> list[int]:
|
||||
return artifact_service.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=artifact_name,
|
||||
)
|
||||
|
||||
@app.delete(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
|
||||
)
|
||||
def delete_artifact(
|
||||
app_name: str, user_id: str, session_id: str, artifact_name: str
|
||||
):
|
||||
artifact_service.delete_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=artifact_name,
|
||||
)
|
||||
|
||||
@app.post("/run", response_model_exclude_none=True)
|
||||
async def agent_run(req: AgentRunRequest) -> list[Event]:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
runner = _get_runner(req.app_name)
|
||||
events = [
|
||||
event
|
||||
async for event in runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
)
|
||||
]
|
||||
logger.info("Generated %s events in agent run: %s", len(events), events)
|
||||
return events
|
||||
|
||||
@app.post("/run_sse")
|
||||
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
||||
# SSE endpoint
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Convert the events to properly formatted SSE
|
||||
async def event_generator():
|
||||
try:
|
||||
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
|
||||
runner = _get_runner(req.app_name)
|
||||
async for event in runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
run_config=RunConfig(streaming_mode=stream_mode),
|
||||
):
|
||||
# Format as SSE data
|
||||
sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
|
||||
logger.info("Generated event in agent run streaming: %s", sse_event)
|
||||
yield f"data: {sse_event}\n\n"
|
||||
except Exception as e:
|
||||
logger.exception("Error in event_generator: %s", e)
|
||||
# You might want to yield an error event here
|
||||
yield f'data: {{"error": "{str(e)}"}}\n\n'
|
||||
|
||||
# Returns a streaming response with the proper media type for SSE
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def get_event_graph(
|
||||
app_name: str, user_id: str, session_id: str, event_id: str
|
||||
):
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else app_name
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=user_id, session_id=session_id
|
||||
)
|
||||
session_events = session.events if session else []
|
||||
event = next((x for x in session_events if x.id == event_id), None)
|
||||
if not event:
|
||||
return {}
|
||||
|
||||
from . import agent_graph
|
||||
|
||||
function_calls = event.get_function_calls()
|
||||
function_responses = event.get_function_responses()
|
||||
root_agent = _get_root_agent(app_name)
|
||||
dot_graph = None
|
||||
if function_calls:
|
||||
function_call_highlights = []
|
||||
for function_call in function_calls:
|
||||
from_name = event.author
|
||||
to_name = function_call.name
|
||||
function_call_highlights.append((from_name, to_name))
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
root_agent, function_call_highlights
|
||||
)
|
||||
elif function_responses:
|
||||
function_responses_highlights = []
|
||||
for function_response in function_responses:
|
||||
from_name = function_response.name
|
||||
to_name = event.author
|
||||
function_responses_highlights.append((from_name, to_name))
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
root_agent, function_responses_highlights
|
||||
)
|
||||
else:
|
||||
from_name = event.author
|
||||
to_name = ""
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
root_agent, [(from_name, to_name)]
|
||||
)
|
||||
if dot_graph and isinstance(dot_graph, graphviz.Digraph):
|
||||
return {"dot_src": dot_graph.source}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@app.websocket("/run_live")
|
||||
async def agent_live_run(
|
||||
websocket: WebSocket,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
modalities: List[Literal["TEXT", "AUDIO"]] = Query(
|
||||
default=["TEXT", "AUDIO"]
|
||||
), # Only allows "TEXT" or "AUDIO"
|
||||
) -> None:
|
||||
await websocket.accept()
|
||||
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else app_name
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
# Accept first so that the client is aware of connection establishment,
|
||||
# then close with a specific code.
|
||||
await websocket.close(code=1002, reason="Session not found")
|
||||
return
|
||||
|
||||
live_request_queue = LiveRequestQueue()
|
||||
|
||||
async def forward_events():
|
||||
runner = _get_runner(app_name)
|
||||
async for event in runner.run_live(
|
||||
session=session, live_request_queue=live_request_queue
|
||||
):
|
||||
await websocket.send_text(
|
||||
event.model_dump_json(exclude_none=True, by_alias=True)
|
||||
)
|
||||
|
||||
async def process_messages():
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
# Validate and send the received message to the live queue.
|
||||
live_request_queue.send(LiveRequest.model_validate_json(data))
|
||||
except ValidationError as ve:
|
||||
logger.error("Validation error in process_messages: %s", ve)
|
||||
|
||||
# Run both tasks concurrently and cancel all if one fails.
|
||||
tasks = [
|
||||
asyncio.create_task(forward_events()),
|
||||
asyncio.create_task(process_messages()),
|
||||
]
|
||||
done, pending = await asyncio.wait(
|
||||
tasks, return_when=asyncio.FIRST_EXCEPTION
|
||||
)
|
||||
try:
|
||||
# This will re-raise any exception from the completed tasks.
|
||||
for task in done:
|
||||
task.result()
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Client disconnected during process_messages.")
|
||||
except Exception as e:
|
||||
logger.exception("Error during live websocket communication: %s", e)
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
def _get_root_agent(app_name: str) -> Agent:
|
||||
"""Returns the root agent for the given app."""
|
||||
if app_name in root_agent_dict:
|
||||
return root_agent_dict[app_name]
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
||||
agent_module = importlib.import_module(app_name)
|
||||
root_agent: Agent = agent_module.agent.root_agent
|
||||
root_agent_dict[app_name] = root_agent
|
||||
return root_agent
|
||||
|
||||
def _get_runner(app_name: str) -> Runner:
|
||||
"""Returns the runner for the given app."""
|
||||
if app_name in runner_dict:
|
||||
return runner_dict[app_name]
|
||||
root_agent = _get_root_agent(app_name)
|
||||
runner = Runner(
|
||||
app_name=agent_engine_id if agent_engine_id else app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
)
|
||||
runner_dict[app_name] = runner
|
||||
return runner
|
||||
|
||||
if web:
|
||||
BASE_DIR = Path(__file__).parent.resolve()
|
||||
ANGULAR_DIST_PATH = BASE_DIR / "browser"
|
||||
|
||||
@app.get("/")
|
||||
async def redirect_to_dev_ui():
|
||||
return RedirectResponse("/dev-ui")
|
||||
|
||||
@app.get("/dev-ui")
|
||||
async def dev_ui():
|
||||
return FileResponse(BASE_DIR / "browser/index.html")
|
||||
|
||||
app.mount(
|
||||
"/", StaticFiles(directory=ANGULAR_DIST_PATH, html=True), name="static"
|
||||
)
|
||||
return app
|
||||
49
src/google/adk/cli/utils/__init__.py
Normal file
49
src/google/adk/cli/utils/__init__.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from ...agents.base_agent import BaseAgent
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
__all__ = [
|
||||
'create_empty_state',
|
||||
]
|
||||
|
||||
|
||||
def _create_empty_state(agent: BaseAgent, all_state: dict[str, Any]):
|
||||
for sub_agent in agent.sub_agents:
|
||||
_create_empty_state(sub_agent, all_state)
|
||||
|
||||
if (
|
||||
isinstance(agent, LlmAgent)
|
||||
and agent.instruction
|
||||
and isinstance(agent.instruction, str)
|
||||
):
|
||||
for key in re.findall(r'{([\w]+)}', agent.instruction):
|
||||
all_state[key] = ''
|
||||
|
||||
|
||||
def create_empty_state(
|
||||
agent: BaseAgent, initialized_states: Optional[dict[str, Any]] = None
|
||||
) -> dict[str, Any]:
|
||||
"""Creates empty str for non-initialized states."""
|
||||
non_initialized_states = {}
|
||||
_create_empty_state(agent, non_initialized_states)
|
||||
for key in initialized_states or {}:
|
||||
if key in non_initialized_states:
|
||||
del non_initialized_states[key]
|
||||
return non_initialized_states
|
||||
57
src/google/adk/cli/utils/envs.py
Normal file
57
src/google/adk/cli/utils/envs.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def _walk_to_root_until_found(folder, filename) -> str:
|
||||
checkpath = os.path.join(folder, filename)
|
||||
if os.path.exists(checkpath) and os.path.isfile(checkpath):
|
||||
return checkpath
|
||||
|
||||
parent_folder = os.path.dirname(folder)
|
||||
if parent_folder == folder: # reached the root
|
||||
return ''
|
||||
|
||||
return _walk_to_root_until_found(parent_folder, filename)
|
||||
|
||||
|
||||
def load_dotenv_for_agent(
|
||||
agent_name: str, agent_parent_folder: str, filename: str = '.env'
|
||||
):
|
||||
"""Lods the .env file for the agent module."""
|
||||
|
||||
# Gets the folder of agent_module as starting_folder
|
||||
starting_folder = os.path.abspath(
|
||||
os.path.join(agent_parent_folder, agent_name)
|
||||
)
|
||||
dotenv_file_path = _walk_to_root_until_found(starting_folder, filename)
|
||||
if dotenv_file_path:
|
||||
load_dotenv(dotenv_file_path, override=True, verbose=True)
|
||||
logger.info(
|
||||
'Loaded %s file for %s at %s',
|
||||
filename,
|
||||
agent_name,
|
||||
dotenv_file_path,
|
||||
)
|
||||
logger.info(
|
||||
'Reloaded %s file for %s at %s', filename, agent_name, dotenv_file_path
|
||||
)
|
||||
else:
|
||||
logger.info('No %s file found for %s', filename, agent_name)
|
||||
93
src/google/adk/cli/utils/evals.py
Normal file
93
src/google/adk/cli/utils/evals.py
Normal file
@ -0,0 +1,93 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ...sessions.session import Session
|
||||
|
||||
|
||||
def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
|
||||
"""Converts a session data into eval format.
|
||||
|
||||
Args:
|
||||
session: The session that should be converted.
|
||||
|
||||
Returns:
|
||||
list: A single evaluation dataset in the required format.
|
||||
"""
|
||||
eval_case = []
|
||||
events = session.events if session and session.events else []
|
||||
|
||||
for event in events:
|
||||
if event.author == 'user':
|
||||
if not event.content or not event.content.parts:
|
||||
continue
|
||||
|
||||
# Extract user query
|
||||
content = event.content
|
||||
parts = content.parts
|
||||
|
||||
query = parts[0].text or ''
|
||||
|
||||
# Find the corresponding tool usage or response for the query
|
||||
expected_tool_use = []
|
||||
intermediate_agent_responses = []
|
||||
|
||||
# Check subsequent events to extract tool uses or responses for this turn.
|
||||
for subsequent_event in events[events.index(event) + 1 :]:
|
||||
event_author = subsequent_event.author or 'agent'
|
||||
if event_author == 'user':
|
||||
# We found an event where the author was the user. This means that a
|
||||
# new turn has started. So close this turn here.
|
||||
break
|
||||
|
||||
if not subsequent_event.content or not subsequent_event.content.parts:
|
||||
continue
|
||||
|
||||
for subsequent_part in subsequent_event.content.parts:
|
||||
# Some events have both function call and reference
|
||||
|
||||
if subsequent_part.function_call:
|
||||
tool_name = subsequent_part.function_call.name or ''
|
||||
tool_input = subsequent_part.function_call.args or {}
|
||||
expected_tool_use.append({
|
||||
'tool_name': tool_name,
|
||||
'tool_input': tool_input,
|
||||
})
|
||||
elif subsequent_part.text:
|
||||
# Also keep track of all the natural langauge responses that
|
||||
# agent (or sub agents) generated.
|
||||
intermediate_agent_responses.append(
|
||||
{'author': event_author, 'text': subsequent_part.text}
|
||||
)
|
||||
|
||||
# If we are here then either we are done reading all the events or we
|
||||
# encountered an event that had content authored by the end-user.
|
||||
# This, basically means an end of turn.
|
||||
# We assume that the last natural langauge intermediate response is the
|
||||
# final response from the agent/model. We treat that as a reference.
|
||||
eval_case.append({
|
||||
'query': query,
|
||||
'expected_tool_use': expected_tool_use,
|
||||
'expected_intermediate_agent_responses': intermediate_agent_responses[
|
||||
:-1
|
||||
],
|
||||
'reference': (
|
||||
intermediate_agent_responses[-1]['text']
|
||||
if intermediate_agent_responses
|
||||
else ''
|
||||
),
|
||||
})
|
||||
|
||||
return eval_case
|
||||
72
src/google/adk/cli/utils/logs.py
Normal file
72
src/google/adk/cli/utils/logs.py
Normal file
@ -0,0 +1,72 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
LOGGING_FORMAT = (
|
||||
'%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
|
||||
)
|
||||
|
||||
|
||||
def log_to_stderr(level=logging.INFO):
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format=LOGGING_FORMAT,
|
||||
)
|
||||
|
||||
|
||||
def log_to_tmp_folder(
|
||||
level=logging.INFO,
|
||||
*,
|
||||
sub_folder: str = 'agents_log',
|
||||
log_file_prefix: str = 'agent',
|
||||
log_file_timestamp: str = time.strftime('%Y%m%d_%H%M%S'),
|
||||
):
|
||||
"""Logs to system temp folder, instead of logging to stderr.
|
||||
|
||||
Args
|
||||
sub_folder: str = 'agents_log',
|
||||
log_file_prefix: str = 'agent',
|
||||
log_file_timestamp: str = time.strftime('%Y%m%d_%H%M%S'),
|
||||
|
||||
Returns
|
||||
the log file path.
|
||||
"""
|
||||
log_dir = os.path.join(tempfile.gettempdir(), sub_folder)
|
||||
log_filename = f'{log_file_prefix}.{log_file_timestamp}.log'
|
||||
log_filepath = os.path.join(log_dir, log_filename)
|
||||
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
file_handler = logging.FileHandler(log_filepath, mode='w')
|
||||
file_handler.setLevel(level)
|
||||
file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT))
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(level)
|
||||
root_logger.handlers = [] # Clear handles to disable logging to stderr
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
print(f'Log setup complete: {log_filepath}')
|
||||
|
||||
latest_log_link = os.path.join(log_dir, f'{log_file_prefix}.latest.log')
|
||||
if os.path.islink(latest_log_link):
|
||||
os.unlink(latest_log_link)
|
||||
os.symlink(log_filepath, latest_log_link)
|
||||
|
||||
print(f'To access latest log: tail -F {latest_log_link}')
|
||||
return log_filepath
|
||||
49
src/google/adk/code_executors/__init__.py
Normal file
49
src/google/adk/code_executors/__init__.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from .base_code_executor import BaseCodeExecutor
|
||||
from .code_executor_context import CodeExecutorContext
|
||||
from .unsafe_local_code_executor import UnsafeLocalCodeExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
'BaseCodeExecutor',
|
||||
'CodeExecutorContext',
|
||||
'UnsafeLocalCodeExecutor',
|
||||
]
|
||||
|
||||
try:
|
||||
from .vertex_ai_code_executor import VertexAiCodeExecutor
|
||||
|
||||
__all__.append('VertexAiCodeExecutor')
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
'The Vertex sdk is not installed. If you want to use the Vertex Code'
|
||||
' Interpreter with agents, please install it. If not, you can ignore this'
|
||||
' warning.'
|
||||
)
|
||||
|
||||
try:
|
||||
from .container_code_executor import ContainerCodeExecutor
|
||||
|
||||
__all__.append('ContainerCodeExecutor')
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
'The docker sdk is not installed. If you want to use the Container Code'
|
||||
' Executor with agents, please install it. If not, you can ignore this'
|
||||
' warning.'
|
||||
)
|
||||
97
src/google/adk/code_executors/base_code_executor.py
Normal file
97
src/google/adk/code_executors/base_code_executor.py
Normal file
@ -0,0 +1,97 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from .code_execution_utils import CodeExecutionInput
|
||||
from .code_execution_utils import CodeExecutionResult
|
||||
|
||||
|
||||
class BaseCodeExecutor(BaseModel):
|
||||
"""Abstract base class for all code executors.
|
||||
|
||||
The code executor allows the agent to execute code blocks from model responses
|
||||
and incorporate the execution results into the final response.
|
||||
|
||||
Attributes:
|
||||
optimize_data_file: If true, extract and process data files from the model
|
||||
request and attach them to the code executor. Supported data file
|
||||
MimeTypes are [text/csv]. Default to False.
|
||||
stateful: Whether the code executor is stateful. Default to False.
|
||||
error_retry_attempts: The number of attempts to retry on consecutive code
|
||||
execution errors. Default to 2.
|
||||
code_block_delimiters: The list of the enclosing delimiters to identify the
|
||||
code blocks.
|
||||
execution_result_delimiters: The delimiters to format the code execution
|
||||
result.
|
||||
"""
|
||||
|
||||
optimize_data_file: bool = False
|
||||
"""
|
||||
If true, extract and process data files from the model request
|
||||
and attach them to the code executor.
|
||||
Supported data file MimeTypes are [text/csv].
|
||||
|
||||
Default to False.
|
||||
"""
|
||||
|
||||
stateful: bool = False
|
||||
"""
|
||||
Whether the code executor is stateful. Default to False.
|
||||
"""
|
||||
|
||||
error_retry_attempts: int = 2
|
||||
"""
|
||||
The number of attempts to retry on consecutive code execution errors. Default to 2.
|
||||
"""
|
||||
|
||||
code_block_delimiters: List[tuple[str, str]] = [
|
||||
('```tool_code\n', '\n```'),
|
||||
('```python\n', '\n```'),
|
||||
]
|
||||
"""
|
||||
The list of the enclosing delimiters to identify the code blocks.
|
||||
For example, the delimiter ('```python\n', '\n```') can be
|
||||
used to identify code blocks with the following format:
|
||||
|
||||
```python
|
||||
print("hello")
|
||||
```
|
||||
"""
|
||||
|
||||
execution_result_delimiters: tuple[str, str] = ('```tool_output\n', '\n```')
|
||||
"""
|
||||
The delimiters to format the code execution result.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def execute_code(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
code_execution_input: CodeExecutionInput,
|
||||
) -> CodeExecutionResult:
|
||||
"""Executes code and return the code execution result.
|
||||
|
||||
Args:
|
||||
invocation_context: The invocation context of the code execution.
|
||||
code_execution_input: The code execution input.
|
||||
|
||||
Returns:
|
||||
The code execution result.
|
||||
"""
|
||||
pass
|
||||
245
src/google/adk/code_executors/code_execution_utils.py
Normal file
245
src/google/adk/code_executors/code_execution_utils.py
Normal file
@ -0,0 +1,245 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utility functions for code execution."""
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import copy
|
||||
import dataclasses
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class File:
|
||||
"""A structure that contains a file name and its content."""
|
||||
|
||||
name: str
|
||||
"""
|
||||
The name of the file with file extension (e.g., "file.csv").
|
||||
"""
|
||||
|
||||
content: str
|
||||
"""
|
||||
The base64-encoded bytes of the file content.
|
||||
"""
|
||||
|
||||
mime_type: str = 'text/plain'
|
||||
"""
|
||||
The mime type of the file (e.g., "image/png").
|
||||
"""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CodeExecutionInput:
|
||||
"""A structure that contains the input of code execution."""
|
||||
|
||||
code: str
|
||||
"""
|
||||
The code to execute.
|
||||
"""
|
||||
|
||||
input_files: list[File] = dataclasses.field(default_factory=list)
|
||||
"""
|
||||
The input files available to the code.
|
||||
"""
|
||||
|
||||
execution_id: Optional[str] = None
|
||||
"""
|
||||
The execution ID for the stateful code execution.
|
||||
"""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CodeExecutionResult:
|
||||
"""A structure that contains the result of code execution."""
|
||||
|
||||
stdout: str = ''
|
||||
"""
|
||||
The standard output of the code execution.
|
||||
"""
|
||||
|
||||
stderr: str = ''
|
||||
"""
|
||||
The standard error of the code execution.
|
||||
"""
|
||||
|
||||
output_files: list[File] = dataclasses.field(default_factory=list)
|
||||
"""
|
||||
The output files from the code execution.
|
||||
"""
|
||||
|
||||
|
||||
class CodeExecutionUtils:
|
||||
"""Utility functions for code execution."""
|
||||
|
||||
@staticmethod
|
||||
def get_encoded_file_content(data: bytes) -> bytes:
|
||||
"""Gets the file content as a base64-encoded bytes.
|
||||
|
||||
Args:
|
||||
data: The file content bytes.
|
||||
|
||||
Returns:
|
||||
The file content as a base64-encoded bytes.
|
||||
"""
|
||||
|
||||
def _is_base64_encoded(data: bytes) -> bool:
|
||||
try:
|
||||
return base64.b64encode(base64.b64decode(data)) == data
|
||||
except binascii.Error:
|
||||
return False
|
||||
|
||||
return data if _is_base64_encoded(data) else base64.b64encode(data)
|
||||
|
||||
@staticmethod
|
||||
def extract_code_and_truncate_content(
|
||||
content: types.Content,
|
||||
code_block_delimiters: List[tuple[str, str]],
|
||||
) -> Optional[str]:
|
||||
"""Extracts the first code block from the content and truncate everything after it.
|
||||
|
||||
Args:
|
||||
content: The mutable content to extract the code from.
|
||||
code_block_delimiters: The list of the enclosing delimiters to identify
|
||||
the code blocks.
|
||||
|
||||
Returns:
|
||||
The first code block if found, otherwise None.
|
||||
"""
|
||||
if not content or not content.parts:
|
||||
return
|
||||
|
||||
text_parts = [p for p in content.parts if p.text]
|
||||
if not text_parts:
|
||||
return
|
||||
|
||||
first_text_part = copy.deepcopy(text_parts[0])
|
||||
response_text = '\n'.join([p.text for p in text_parts])
|
||||
|
||||
# Find the first code block.
|
||||
leading_delimiter_pattern = '|'.join(d[0] for d in code_block_delimiters)
|
||||
trailing_delimiter_pattern = '|'.join(d[1] for d in code_block_delimiters)
|
||||
pattern = re.compile(
|
||||
(
|
||||
rf'(?P<prefix>.*?)({leading_delimiter_pattern})(?P<code>.*?)({trailing_delimiter_pattern})(?P<suffix>.*?)$'
|
||||
).encode(),
|
||||
re.DOTALL,
|
||||
)
|
||||
pattern_match = pattern.search(response_text.encode())
|
||||
if pattern_match is None:
|
||||
return
|
||||
|
||||
code_str = pattern_match.group('code').decode()
|
||||
if not code_str:
|
||||
return
|
||||
|
||||
content.parts = []
|
||||
if pattern_match.group('prefix'):
|
||||
first_text_part.text = pattern_match.group('prefix').decode()
|
||||
content.parts.append(first_text_part)
|
||||
content.parts.append(
|
||||
CodeExecutionUtils.build_executable_code_part(code_str)
|
||||
)
|
||||
return pattern_match.group('code').decode()
|
||||
|
||||
@staticmethod
|
||||
def build_executable_code_part(code: str) -> types.Part:
|
||||
"""Builds an executable code part with code string.
|
||||
|
||||
Args:
|
||||
code: The code string.
|
||||
|
||||
Returns:
|
||||
The constructed executable code part.
|
||||
"""
|
||||
return types.Part.from_executable_code(
|
||||
code=code,
|
||||
language='PYTHON',
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_code_execution_result_part(
|
||||
code_execution_result: CodeExecutionResult,
|
||||
) -> types.Part:
|
||||
"""Builds the code execution result part from the code execution result.
|
||||
|
||||
Args:
|
||||
code_execution_result: The code execution result.
|
||||
|
||||
Returns:
|
||||
The constructed code execution result part.
|
||||
"""
|
||||
if code_execution_result.stderr:
|
||||
return types.Part.from_code_execution_result(
|
||||
outcome='OUTCOME_FAILED',
|
||||
output=code_execution_result.stderr,
|
||||
)
|
||||
final_result = []
|
||||
if code_execution_result.stdout or not code_execution_result.output_files:
|
||||
final_result.append(
|
||||
'Code execution result:\n' + '%s\n' % code_execution_result.stdout
|
||||
)
|
||||
if code_execution_result.output_files:
|
||||
final_result.append(
|
||||
'Saved artifacts:\n'
|
||||
+ ','.join(
|
||||
['`%s`' % f.name for f in code_execution_result.output_files]
|
||||
)
|
||||
)
|
||||
return types.Part.from_code_execution_result(
|
||||
outcome='OUTCOME_OK',
|
||||
output='\n\n'.join(final_result),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_code_execution_parts(
|
||||
content: types.Content,
|
||||
code_block_delimiter: tuple[str, str],
|
||||
execution_result_delimiters: tuple[str, str],
|
||||
):
|
||||
"""Converts the code execution parts to text parts in a Content.
|
||||
|
||||
Args:
|
||||
content: The mutable content to convert the code execution parts to text
|
||||
parts.
|
||||
code_block_delimiter: The delimiter to format the code block.
|
||||
execution_result_delimiters: The delimiter to format the code execution
|
||||
result.
|
||||
"""
|
||||
if not content.parts:
|
||||
return
|
||||
|
||||
# Handle the conversion of trailing executable code parts.
|
||||
if content.parts[-1].executable_code:
|
||||
content.parts[-1] = types.Part(
|
||||
text=(
|
||||
code_block_delimiter[0]
|
||||
+ content.parts[-1].executable_code.code
|
||||
+ code_block_delimiter[1]
|
||||
)
|
||||
)
|
||||
# Handle the conversion of trailing code execution result parts.
|
||||
# Skip if the Content has multiple parts, which means the Content is
|
||||
# likely generated by the model.
|
||||
elif len(content.parts) == 1 and content.parts[-1].code_execution_result:
|
||||
content.parts[-1] = types.Part(
|
||||
text=execution_result_delimiters[0]
|
||||
+ content.parts[-1].code_execution_result.output
|
||||
+ execution_result_delimiters[1]
|
||||
)
|
||||
content.role = 'user'
|
||||
202
src/google/adk/code_executors/code_executor_context.py
Normal file
202
src/google/adk/code_executors/code_executor_context.py
Normal file
@ -0,0 +1,202 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The persistent context used to configure the code executor."""
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import datetime
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from ..sessions.state import State
|
||||
from .code_execution_utils import File
|
||||
|
||||
_CONTEXT_KEY = '_code_execution_context'
|
||||
_SESSION_ID_KEY = 'execution_session_id'
|
||||
_PROCESSED_FILE_NAMES_KEY = 'processed_input_files'
|
||||
_INPUT_FILE_KEY = '_code_executor_input_files'
|
||||
_ERROR_COUNT_KEY = '_code_executor_error_counts'
|
||||
|
||||
_CODE_EXECUTION_RESULTS_KEY = '_code_execution_results'
|
||||
|
||||
|
||||
class CodeExecutorContext:
|
||||
"""The persistent context used to configure the code executor."""
|
||||
|
||||
_context: dict[str, Any]
|
||||
|
||||
def __init__(self, session_state: State):
|
||||
"""Initializes the code executor context.
|
||||
|
||||
Args:
|
||||
session_state: The session state to get the code executor context from.
|
||||
"""
|
||||
self._context = self._get_code_executor_context(session_state)
|
||||
self._session_state = session_state
|
||||
|
||||
def get_state_delta(self) -> dict[str, Any]:
|
||||
"""Gets the state delta to update in the persistent session state.
|
||||
|
||||
Returns:
|
||||
The state delta to update in the persistent session state.
|
||||
"""
|
||||
context_to_update = copy.deepcopy(self._context)
|
||||
return {_CONTEXT_KEY: context_to_update}
|
||||
|
||||
def get_execution_id(self) -> Optional[str]:
|
||||
"""Gets the session ID for the code executor.
|
||||
|
||||
Returns:
|
||||
The session ID for the code executor context.
|
||||
"""
|
||||
if _SESSION_ID_KEY not in self._context:
|
||||
return None
|
||||
return self._context[_SESSION_ID_KEY]
|
||||
|
||||
def set_execution_id(self, session_id: str):
|
||||
"""Sets the session ID for the code executor.
|
||||
|
||||
Args:
|
||||
session_id: The session ID for the code executor.
|
||||
"""
|
||||
self._context[_SESSION_ID_KEY] = session_id
|
||||
|
||||
def get_processed_file_names(self) -> list[str]:
|
||||
"""Gets the processed file names from the session state.
|
||||
|
||||
Returns:
|
||||
A list of processed file names in the code executor context.
|
||||
"""
|
||||
if _PROCESSED_FILE_NAMES_KEY not in self._context:
|
||||
return []
|
||||
return self._context[_PROCESSED_FILE_NAMES_KEY]
|
||||
|
||||
def add_processed_file_names(self, file_names: [str]):
|
||||
"""Adds the processed file name to the session state.
|
||||
|
||||
Args:
|
||||
file_names: The processed file names to add to the session state.
|
||||
"""
|
||||
if _PROCESSED_FILE_NAMES_KEY not in self._context:
|
||||
self._context[_PROCESSED_FILE_NAMES_KEY] = []
|
||||
self._context[_PROCESSED_FILE_NAMES_KEY].extend(file_names)
|
||||
|
||||
def get_input_files(self) -> list[File]:
|
||||
"""Gets the code executor input file names from the session state.
|
||||
|
||||
Returns:
|
||||
A list of input files in the code executor context.
|
||||
"""
|
||||
if _INPUT_FILE_KEY not in self._session_state:
|
||||
return []
|
||||
return [File(**file) for file in self._session_state[_INPUT_FILE_KEY]]
|
||||
|
||||
def add_input_files(
|
||||
self,
|
||||
input_files: list[File],
|
||||
):
|
||||
"""Adds the input files to the code executor context.
|
||||
|
||||
Args:
|
||||
input_files: The input files to add to the code executor context.
|
||||
"""
|
||||
if _INPUT_FILE_KEY not in self._session_state:
|
||||
self._session_state[_INPUT_FILE_KEY] = []
|
||||
for input_file in input_files:
|
||||
self._session_state[_INPUT_FILE_KEY].append(
|
||||
dataclasses.asdict(input_file)
|
||||
)
|
||||
|
||||
def clear_input_files(self):
|
||||
"""Removes the input files and processed file names to the code executor context."""
|
||||
if _INPUT_FILE_KEY in self._session_state:
|
||||
self._session_state[_INPUT_FILE_KEY] = []
|
||||
if _PROCESSED_FILE_NAMES_KEY in self._context:
|
||||
self._context[_PROCESSED_FILE_NAMES_KEY] = []
|
||||
|
||||
def get_error_count(self, invocation_id: str) -> int:
|
||||
"""Gets the error count from the session state.
|
||||
|
||||
Args:
|
||||
invocation_id: The invocation ID to get the error count for.
|
||||
|
||||
Returns:
|
||||
The error count for the given invocation ID.
|
||||
"""
|
||||
if _ERROR_COUNT_KEY not in self._session_state:
|
||||
return 0
|
||||
return self._session_state[_ERROR_COUNT_KEY].get(invocation_id, 0)
|
||||
|
||||
def increment_error_count(self, invocation_id: str):
|
||||
"""Increments the error count from the session state.
|
||||
|
||||
Args:
|
||||
invocation_id: The invocation ID to increment the error count for.
|
||||
"""
|
||||
if _ERROR_COUNT_KEY not in self._session_state:
|
||||
self._session_state[_ERROR_COUNT_KEY] = {}
|
||||
self._session_state[_ERROR_COUNT_KEY][invocation_id] = (
|
||||
self.get_error_count(invocation_id) + 1
|
||||
)
|
||||
|
||||
def reset_error_count(self, invocation_id: str):
|
||||
"""Resets the error count from the session state.
|
||||
|
||||
Args:
|
||||
invocation_id: The invocation ID to reset the error count for.
|
||||
"""
|
||||
if _ERROR_COUNT_KEY not in self._session_state:
|
||||
return
|
||||
if invocation_id in self._session_state[_ERROR_COUNT_KEY]:
|
||||
del self._session_state[_ERROR_COUNT_KEY][invocation_id]
|
||||
|
||||
def update_code_execution_result(
|
||||
self,
|
||||
invocation_id: str,
|
||||
code: str,
|
||||
result_stdout: str,
|
||||
result_stderr: str,
|
||||
):
|
||||
"""Updates the code execution result.
|
||||
|
||||
Args:
|
||||
invocation_id: The invocation ID to update the code execution result for.
|
||||
code: The code to execute.
|
||||
result_stdout: The standard output of the code execution.
|
||||
result_stderr: The standard error of the code execution.
|
||||
"""
|
||||
if _CODE_EXECUTION_RESULTS_KEY not in self._session_state:
|
||||
self._session_state[_CODE_EXECUTION_RESULTS_KEY] = {}
|
||||
if invocation_id not in self._session_state[_CODE_EXECUTION_RESULTS_KEY]:
|
||||
self._session_state[_CODE_EXECUTION_RESULTS_KEY][invocation_id] = []
|
||||
self._session_state[_CODE_EXECUTION_RESULTS_KEY][invocation_id].append({
|
||||
'code': code,
|
||||
'result_stdout': result_stdout,
|
||||
'result_stderr': result_stderr,
|
||||
'timestamp': int(datetime.datetime.now().timestamp()),
|
||||
})
|
||||
|
||||
def _get_code_executor_context(self, session_state: State) -> dict[str, Any]:
|
||||
"""Gets the code executor context from the session state.
|
||||
|
||||
Args:
|
||||
session_state: The session state to get the code executor context from.
|
||||
|
||||
Returns:
|
||||
A dict of code executor context.
|
||||
"""
|
||||
if _CONTEXT_KEY not in session_state:
|
||||
session_state[_CONTEXT_KEY] = {}
|
||||
return session_state[_CONTEXT_KEY]
|
||||
196
src/google/adk/code_executors/container_code_executor.py
Normal file
196
src/google/adk/code_executors/container_code_executor.py
Normal file
@ -0,0 +1,196 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import atexit
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import docker
|
||||
from docker.client import DockerClient
|
||||
from docker.models.containers import Container
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from .base_code_executor import BaseCodeExecutor
|
||||
from .code_execution_utils import CodeExecutionInput
|
||||
from .code_execution_utils import CodeExecutionResult
|
||||
|
||||
|
||||
DEFAULT_IMAGE_TAG = 'adk-code-executor:latest'
|
||||
|
||||
|
||||
class ContainerCodeExecutor(BaseCodeExecutor):
|
||||
"""A code executor that uses a custom container to execute code.
|
||||
|
||||
Attributes:
|
||||
base_url: Optional. The base url of the user hosted Docker client.
|
||||
image: The tag of the predefined image or custom image to run on the
|
||||
container. Either docker_path or image must be set.
|
||||
docker_path: The path to the directory containing the Dockerfile. If set,
|
||||
build the image from the dockerfile path instead of using the predefined
|
||||
image. Either docker_path or image must be set.
|
||||
"""
|
||||
|
||||
base_url: Optional[str] = None
|
||||
"""
|
||||
Optional. The base url of the user hosted Docker client.
|
||||
"""
|
||||
|
||||
image: str = None
|
||||
"""
|
||||
The tag of the predefined image or custom image to run on the container.
|
||||
Either docker_path or image must be set.
|
||||
"""
|
||||
|
||||
docker_path: str = None
|
||||
"""
|
||||
The path to the directory containing the Dockerfile.
|
||||
If set, build the image from the dockerfile path instead of using the
|
||||
predefined image. Either docker_path or image must be set.
|
||||
"""
|
||||
|
||||
# Overrides the BaseCodeExecutor attribute: this executor cannot be stateful.
|
||||
stateful: bool = Field(default=False, frozen=True, exclude=True)
|
||||
|
||||
# Overrides the BaseCodeExecutor attribute: this executor cannot
|
||||
# optimize_data_file.
|
||||
optimize_data_file: bool = Field(default=False, frozen=True, exclude=True)
|
||||
|
||||
_client: DockerClient = None
|
||||
_container: Container = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
docker_path: Optional[str] = None,
|
||||
**data,
|
||||
):
|
||||
"""Initializes the ContainerCodeExecutor.
|
||||
|
||||
Args:
|
||||
base_url: Optional. The base url of the user hosted Docker client.
|
||||
image: The tag of the predefined image or custom image to run on the
|
||||
container. Either docker_path or image must be set.
|
||||
docker_path: The path to the directory containing the Dockerfile. If set,
|
||||
build the image from the dockerfile path instead of using the predefined
|
||||
image. Either docker_path or image must be set.
|
||||
**data: The data to initialize the ContainerCodeExecutor.
|
||||
"""
|
||||
if not image and not docker_path:
|
||||
raise ValueError(
|
||||
'Either image or docker_path must be set for ContainerCodeExecutor.'
|
||||
)
|
||||
if 'stateful' in data and data['stateful']:
|
||||
raise ValueError('Cannot set `stateful=True` in ContainerCodeExecutor.')
|
||||
if 'optimize_data_file' in data and data['optimize_data_file']:
|
||||
raise ValueError(
|
||||
'Cannot set `optimize_data_file=True` in ContainerCodeExecutor.'
|
||||
)
|
||||
|
||||
super().__init__(**data)
|
||||
self.base_url = base_url
|
||||
self.image = image if image else DEFAULT_IMAGE_TAG
|
||||
self.docker_path = os.path.abspath(docker_path) if docker_path else None
|
||||
|
||||
self._client = (
|
||||
docker.from_env()
|
||||
if not self.base_url
|
||||
else docker.DockerClient(base_url=self.base_url)
|
||||
)
|
||||
# Initialize the container.
|
||||
self.__init_container()
|
||||
|
||||
# Close the container when the on exit.
|
||||
atexit.register(self.__cleanup_container)
|
||||
|
||||
@override
|
||||
def execute_code(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
code_execution_input: CodeExecutionInput,
|
||||
) -> CodeExecutionResult:
|
||||
output = ''
|
||||
error = ''
|
||||
exec_result = self._container.exec_run(
|
||||
['python3', '-c', code_execution_input.code],
|
||||
demux=True,
|
||||
)
|
||||
|
||||
if exec_result.output and exec_result.output[0]:
|
||||
output = exec_result.output[0].decode('utf-8')
|
||||
if (
|
||||
exec_result.output
|
||||
and len(exec_result.output) > 1
|
||||
and exec_result.output[1]
|
||||
):
|
||||
error = exec_result.output[1].decode('utf-8')
|
||||
|
||||
# Collect the final result.
|
||||
return CodeExecutionResult(
|
||||
stdout=output,
|
||||
stderr=error,
|
||||
output_files=[],
|
||||
)
|
||||
|
||||
def _build_docker_image(self):
|
||||
"""Builds the Docker image."""
|
||||
if not self.docker_path:
|
||||
raise ValueError('Docker path is not set.')
|
||||
if not os.path.exists(self.docker_path):
|
||||
raise FileNotFoundError(f'Invalid Docker path: {self.docker_path}')
|
||||
|
||||
print('Building Docker image...')
|
||||
self._client.images.build(
|
||||
path=self.docker_path,
|
||||
tag=self.image,
|
||||
rm=True,
|
||||
)
|
||||
print(f'Docker image: {self.image} built.')
|
||||
|
||||
def _verify_python_installation(self):
|
||||
"""Verifies the container has python3 installed."""
|
||||
exec_result = self._container.exec_run(['which', 'python3'])
|
||||
if exec_result.exit_code != 0:
|
||||
raise ValueError('python3 is not installed in the container.')
|
||||
|
||||
def __init_container(self):
|
||||
"""Initializes the container."""
|
||||
if not self._client:
|
||||
raise RuntimeError('Docker client is not initialized.')
|
||||
|
||||
if self.docker_path:
|
||||
self._build_docker_image()
|
||||
|
||||
print('Starting container for ContainerCodeExecutor...')
|
||||
self._container = self._client.containers.run(
|
||||
image=self.image,
|
||||
detach=True,
|
||||
tty=True,
|
||||
)
|
||||
print(f'Container {self._container.id} started.')
|
||||
|
||||
# Verify the container is able to run python3.
|
||||
self._verify_python_installation()
|
||||
|
||||
def __cleanup_container(self):
|
||||
"""Closes the container on exit."""
|
||||
if not self._container:
|
||||
return
|
||||
|
||||
print('[Cleanup] Stopping the container...')
|
||||
self._container.stop()
|
||||
self._container.remove()
|
||||
print(f'Container {self._container.id} stopped and removed.')
|
||||
71
src/google/adk/code_executors/unsafe_local_code_executor.py
Normal file
71
src/google/adk/code_executors/unsafe_local_code_executor.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import redirect_stdout
|
||||
import io
|
||||
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from .base_code_executor import BaseCodeExecutor
|
||||
from .code_execution_utils import CodeExecutionInput
|
||||
from .code_execution_utils import CodeExecutionResult
|
||||
|
||||
|
||||
class UnsafeLocalCodeExecutor(BaseCodeExecutor):
|
||||
"""A code executor that unsafely execute code in the current local context."""
|
||||
|
||||
# Overrides the BaseCodeExecutor attribute: this executor cannot be stateful.
|
||||
stateful: bool = Field(default=False, frozen=True, exclude=True)
|
||||
|
||||
# Overrides the BaseCodeExecutor attribute: this executor cannot
|
||||
# optimize_data_file.
|
||||
optimize_data_file: bool = Field(default=False, frozen=True, exclude=True)
|
||||
|
||||
def __init__(self, **data):
|
||||
"""Initializes the UnsafeLocalCodeExecutor."""
|
||||
if 'stateful' in data and data['stateful']:
|
||||
raise ValueError('Cannot set `stateful=True` in UnsafeLocalCodeExecutor.')
|
||||
if 'optimize_data_file' in data and data['optimize_data_file']:
|
||||
raise ValueError(
|
||||
'Cannot set `optimize_data_file=True` in UnsafeLocalCodeExecutor.'
|
||||
)
|
||||
super().__init__(**data)
|
||||
|
||||
@override
|
||||
def execute_code(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
code_execution_input: CodeExecutionInput,
|
||||
) -> CodeExecutionResult:
|
||||
# Execute the code.
|
||||
output = ''
|
||||
error = ''
|
||||
try:
|
||||
globals_ = {}
|
||||
locals_ = {}
|
||||
stdout = io.StringIO()
|
||||
with redirect_stdout(stdout):
|
||||
exec(code_execution_input.code, globals_, locals_)
|
||||
output = stdout.getvalue()
|
||||
except Exception as e:
|
||||
error = str(e)
|
||||
|
||||
# Collect the final result.
|
||||
return CodeExecutionResult(
|
||||
stdout=output,
|
||||
stderr=error,
|
||||
output_files=[],
|
||||
)
|
||||
234
src/google/adk/code_executors/vertex_ai_code_executor.py
Normal file
234
src/google/adk/code_executors/vertex_ai_code_executor.py
Normal file
@ -0,0 +1,234 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from typing_extensions import override
|
||||
from vertexai.preview.extensions import Extension
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from .base_code_executor import BaseCodeExecutor
|
||||
from .code_execution_utils import CodeExecutionInput
|
||||
from .code_execution_utils import CodeExecutionResult
|
||||
from .code_execution_utils import File
|
||||
|
||||
_SUPPORTED_IMAGE_TYPES = ['png', 'jpg', 'jpeg']
|
||||
_SUPPORTED_DATA_FILE_TYPES = ['csv']
|
||||
|
||||
_IMPORTED_LIBRARIES = '''
|
||||
import io
|
||||
import math
|
||||
import re
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import scipy
|
||||
|
||||
def crop(s: str, max_chars: int = 64) -> str:
|
||||
"""Crops a string to max_chars characters."""
|
||||
return s[: max_chars - 3] + '...' if len(s) > max_chars else s
|
||||
|
||||
|
||||
def explore_df(df: pd.DataFrame) -> None:
|
||||
"""Prints some information about a pandas DataFrame."""
|
||||
|
||||
with pd.option_context(
|
||||
'display.max_columns', None, 'display.expand_frame_repr', False
|
||||
):
|
||||
# Print the column names to never encounter KeyError when selecting one.
|
||||
df_dtypes = df.dtypes
|
||||
|
||||
# Obtain information about data types and missing values.
|
||||
df_nulls = (len(df) - df.isnull().sum()).apply(
|
||||
lambda x: f'{x} / {df.shape[0]} non-null'
|
||||
)
|
||||
|
||||
# Explore unique total values in columns using `.unique()`.
|
||||
df_unique_count = df.apply(lambda x: len(x.unique()))
|
||||
|
||||
# Explore unique values in columns using `.unique()`.
|
||||
df_unique = df.apply(lambda x: crop(str(list(x.unique()))))
|
||||
|
||||
df_info = pd.concat(
|
||||
(
|
||||
df_dtypes.rename('Dtype'),
|
||||
df_nulls.rename('Non-Null Count'),
|
||||
df_unique_count.rename('Unique Values Count'),
|
||||
df_unique.rename('Unique Values'),
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
df_info.index.name = 'Columns'
|
||||
print(f"""Total rows: {df.shape[0]}
|
||||
Total columns: {df.shape[1]}
|
||||
|
||||
{df_info}""")
|
||||
'''
|
||||
|
||||
|
||||
def _get_code_interpreter_extension(resource_name: str = None):
|
||||
"""Returns: Load or create the code interpreter extension."""
|
||||
if not resource_name:
|
||||
resource_name = os.environ.get('CODE_INTERPRETER_EXTENSION_NAME')
|
||||
if resource_name:
|
||||
new_code_interpreter = Extension(resource_name)
|
||||
else:
|
||||
print('No CODE_INTERPRETER_ID found in the environment. Create a new one.')
|
||||
new_code_interpreter = Extension.from_hub('code_interpreter')
|
||||
os.environ['CODE_INTERPRETER_EXTENSION_NAME'] = (
|
||||
new_code_interpreter.gca_resource.name
|
||||
)
|
||||
return new_code_interpreter
|
||||
|
||||
|
||||
class VertexAiCodeExecutor(BaseCodeExecutor):
|
||||
"""A code executor that uses Vertex Code Interpreter Extension to execute code.
|
||||
|
||||
Attributes:
|
||||
resource_name: If set, load the existing resource name of the code
|
||||
interpreter extension instead of creating a new one. Format:
|
||||
projects/123/locations/us-central1/extensions/456
|
||||
"""
|
||||
|
||||
resource_name: str = None
|
||||
"""
|
||||
If set, load the existing resource name of the code interpreter extension
|
||||
instead of creating a new one.
|
||||
Format: projects/123/locations/us-central1/extensions/456
|
||||
"""
|
||||
|
||||
_code_interpreter_extension: Extension
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
resource_name: str = None,
|
||||
**data,
|
||||
):
|
||||
"""Initializes the VertexAiCodeExecutor.
|
||||
|
||||
Args:
|
||||
resource_name: If set, load the existing resource name of the code
|
||||
interpreter extension instead of creating a new one. Format:
|
||||
projects/123/locations/us-central1/extensions/456
|
||||
**data: Additional keyword arguments to be passed to the base class.
|
||||
"""
|
||||
super().__init__(**data)
|
||||
self.resource_name = resource_name
|
||||
self._code_interpreter_extension = _get_code_interpreter_extension(
|
||||
self.resource_name
|
||||
)
|
||||
|
||||
@override
|
||||
def execute_code(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
code_execution_input: CodeExecutionInput,
|
||||
) -> CodeExecutionResult:
|
||||
# Execute the code.
|
||||
code_execution_result = self._execute_code_interpreter(
|
||||
self._get_code_with_imports(code_execution_input.code),
|
||||
code_execution_input.input_files,
|
||||
code_execution_input.execution_id,
|
||||
)
|
||||
|
||||
# Save output file as artifacts.
|
||||
current_timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
file_name_prefix = '%s_' % str(current_timestamp)
|
||||
saved_files = []
|
||||
file_count = 0
|
||||
for output_file in code_execution_result['output_files']:
|
||||
file_type = output_file['name'].split('.')[-1]
|
||||
file_name = file_name_prefix + '%d.%s' % (file_count, file_type)
|
||||
if file_type in _SUPPORTED_IMAGE_TYPES:
|
||||
file_count += 1
|
||||
saved_files.append(
|
||||
File(
|
||||
name='plot_' + file_name,
|
||||
content=output_file['contents'],
|
||||
mime_type=f'image/{file_type}',
|
||||
)
|
||||
)
|
||||
elif file_type in _SUPPORTED_DATA_FILE_TYPES:
|
||||
file_count += 1
|
||||
saved_files.append(
|
||||
File(
|
||||
name='data_' + file_name,
|
||||
content=output_file['contents'],
|
||||
mime_type=f'text/{file_type}',
|
||||
)
|
||||
)
|
||||
else:
|
||||
mime_type, _ = mimetypes.guess_type(file_name)
|
||||
saved_files.append(
|
||||
File(
|
||||
name=file_name,
|
||||
content=output_file['contents'],
|
||||
mime_type=mime_type,
|
||||
)
|
||||
)
|
||||
|
||||
# Collect the final result.
|
||||
return CodeExecutionResult(
|
||||
stdout=code_execution_result.get('execution_result', ''),
|
||||
stderr=code_execution_result.get('execution_error', ''),
|
||||
output_files=saved_files,
|
||||
)
|
||||
|
||||
def _execute_code_interpreter(
|
||||
self,
|
||||
code: str,
|
||||
input_files: Optional[list[File]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Executes the code interpreter extension.
|
||||
|
||||
Args:
|
||||
code: The code to execute.
|
||||
input_files: The input files to execute the code with.
|
||||
session_id: The session ID to execute the code with.
|
||||
|
||||
Returns:
|
||||
The response from the code interpreter extension.
|
||||
"""
|
||||
operation_params = {'code': code}
|
||||
if input_files:
|
||||
operation_params['files'] = [
|
||||
{'name': f.name, 'contents': f.content} for f in input_files
|
||||
]
|
||||
if session_id:
|
||||
operation_params['session_id'] = session_id
|
||||
response = self._code_interpreter_extension.execute(
|
||||
operation_id='execute',
|
||||
operation_params=operation_params,
|
||||
)
|
||||
return response
|
||||
|
||||
def _get_code_with_imports(self, code: str) -> str:
|
||||
"""Builds the code string with built-in imports.
|
||||
|
||||
Args:
|
||||
code: The code to execute.
|
||||
|
||||
Returns:
|
||||
The code string with built-in imports.
|
||||
"""
|
||||
return f"""
|
||||
{_IMPORTED_LIBRARIES}
|
||||
|
||||
{code}
|
||||
"""
|
||||
31
src/google/adk/evaluation/__init__.py
Normal file
31
src/google/adk/evaluation/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = []
|
||||
|
||||
try:
|
||||
from .agent_evaluator import AgentEvaluator
|
||||
|
||||
__all__.append('AgentEvaluator')
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
'The Vertex[eval] sdk is not installed. If you want to use the Vertex'
|
||||
' Evaluation with agents, please install it(pip install'
|
||||
' "google-cloud-aiplatform[evaluation]). If not, you can ignore this'
|
||||
' warning.'
|
||||
)
|
||||
329
src/google/adk/evaluation/agent_evaluator.py
Normal file
329
src/google/adk/evaluation/agent_evaluator.py
Normal file
@ -0,0 +1,329 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from os import path
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
from .evaluation_generator import EvaluationGenerator
|
||||
from .response_evaluator import ResponseEvaluator
|
||||
from .trajectory_evaluator import TrajectoryEvaluator
|
||||
|
||||
# Constants for default runs and evaluation criteria
|
||||
NUM_RUNS = 2
|
||||
TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score"
|
||||
# This evaluation is not very stable.
|
||||
# This is always optional unless explicitly specified.
|
||||
RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"
|
||||
RESPONSE_MATCH_SCORE_KEY = "response_match_score"
|
||||
|
||||
ALLOWED_CRITERIA = [
|
||||
TOOL_TRAJECTORY_SCORE_KEY,
|
||||
RESPONSE_EVALUATION_SCORE_KEY,
|
||||
RESPONSE_MATCH_SCORE_KEY,
|
||||
]
|
||||
|
||||
|
||||
QUERY_COLUMN = "query"
|
||||
REFERENCE_COLUMN = "reference"
|
||||
EXPECTED_TOOL_USE_COLUMN = "expected_tool_use"
|
||||
|
||||
|
||||
DEFAULT_CRITERIA = {
|
||||
TOOL_TRAJECTORY_SCORE_KEY: 1.0, # 1-point scale; 1.0 is perfect.
|
||||
RESPONSE_MATCH_SCORE_KEY: 0.8, # Rouge-1 text match; 0.8 is default.
|
||||
}
|
||||
|
||||
|
||||
def load_json(file_path: str) -> Union[Dict, List]:
|
||||
with open(file_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
class AgentEvaluator:
|
||||
"""An evaluator for Agents, mainly intented for helping with test cases."""
|
||||
|
||||
@staticmethod
|
||||
def find_config_for_test_file(test_file: str):
|
||||
"""Find the test_config.json file in the same folder as the test file."""
|
||||
test_folder = os.path.dirname(test_file)
|
||||
config_path = os.path.join(test_folder, "test_config.json")
|
||||
if os.path.exists(config_path):
|
||||
config_data = load_json(config_path)
|
||||
if "criteria" in config_data and isinstance(
|
||||
config_data["criteria"], dict
|
||||
):
|
||||
return config_data["criteria"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid format for test_config.json at {config_path}. Expected a"
|
||||
" 'criteria' dictionary."
|
||||
)
|
||||
return DEFAULT_CRITERIA
|
||||
|
||||
@staticmethod
|
||||
def evaluate(
|
||||
agent_module,
|
||||
eval_dataset_file_path_or_dir,
|
||||
num_runs=NUM_RUNS,
|
||||
agent_name=None,
|
||||
initial_session_file=None,
|
||||
):
|
||||
"""Evaluates an Agent given eval data.
|
||||
|
||||
Args:
|
||||
agent_module: The path to python module that contains the definition of
|
||||
the agent. There is convention in place here, where the code is going to
|
||||
look for 'root_agent' in the loaded module.
|
||||
eval_dataset: The eval data set. This can be either a string representing
|
||||
full path to the file containing eval dataset, or a directory that is
|
||||
recusively explored for all files that have a `.test.json` suffix.
|
||||
num_runs: Number of times all entries in the eval dataset should be
|
||||
assessed.
|
||||
agent_name: The name of the agent.
|
||||
initial_session_file: File that contains initial session state that is
|
||||
needed by all the evals in the eval dataset.
|
||||
"""
|
||||
test_files = []
|
||||
if isinstance(eval_dataset_file_path_or_dir, str) and os.path.isdir(
|
||||
eval_dataset_file_path_or_dir
|
||||
):
|
||||
for root, _, files in os.walk(eval_dataset_file_path_or_dir):
|
||||
for file in files:
|
||||
if file.endswith(".test.json"):
|
||||
test_files.append(path.join(root, file))
|
||||
else:
|
||||
test_files = [eval_dataset_file_path_or_dir]
|
||||
|
||||
initial_session_state = {}
|
||||
if initial_session_file:
|
||||
with open(initial_session_file, "r") as f:
|
||||
initial_session_state = json.loads(f.read())["state"]
|
||||
|
||||
for test_file in test_files:
|
||||
dataset = AgentEvaluator._load_dataset(test_file)[0]
|
||||
criteria = AgentEvaluator.find_config_for_test_file(test_file)
|
||||
|
||||
AgentEvaluator._validate_input([dataset], criteria)
|
||||
|
||||
evaluation_response = AgentEvaluator._generate_responses(
|
||||
agent_module,
|
||||
[dataset],
|
||||
num_runs,
|
||||
agent_name=agent_name,
|
||||
initial_session={"state": initial_session_state},
|
||||
)
|
||||
|
||||
if AgentEvaluator._response_evaluation_required(criteria, [dataset]):
|
||||
AgentEvaluator._evaluate_response_scores(
|
||||
agent_module, evaluation_response, criteria
|
||||
)
|
||||
|
||||
if AgentEvaluator._trajectory_evaluation_required(criteria, [dataset]):
|
||||
AgentEvaluator._evaluate_tool_trajectory(
|
||||
agent_module, evaluation_response, criteria
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _load_dataset(
|
||||
input_data: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||
) -> List[List[Dict]]:
|
||||
def load_json_file(file_path: str) -> List[Dict]:
|
||||
data = load_json(file_path)
|
||||
if not isinstance(data, list) or not all(
|
||||
isinstance(d, dict) for d in data
|
||||
):
|
||||
raise ValueError(f"{file_path} must contain a list of dictionaries.")
|
||||
return data
|
||||
|
||||
if isinstance(input_data, str):
|
||||
if os.path.isdir(input_data):
|
||||
test_files = []
|
||||
for root, _, files in os.walk(input_data):
|
||||
for file in files:
|
||||
if file.endswith(".test.json"):
|
||||
test_files.append(os.path.join(root, file))
|
||||
return [load_json_file(f) for f in test_files]
|
||||
elif os.path.isfile(input_data):
|
||||
return [load_json_file(input_data)]
|
||||
else:
|
||||
raise ValueError(f"Input path {input_data} is invalid.")
|
||||
elif isinstance(input_data, list):
|
||||
if all(isinstance(i, str) and os.path.isfile(i) for i in input_data):
|
||||
return [load_json_file(i) for i in input_data]
|
||||
raise TypeError("Input list must contain valid file paths.")
|
||||
raise TypeError("Invalid input type for dataset loading.")
|
||||
|
||||
@staticmethod
|
||||
def _validate_input(eval_dataset, criteria):
|
||||
"""Validates that the evaluation criteria align with the provided dataset.
|
||||
|
||||
For efficiency, we only use first row to validate input.
|
||||
"""
|
||||
if not eval_dataset:
|
||||
raise ValueError("The evaluation dataset is None or empty.")
|
||||
|
||||
for key in criteria:
|
||||
if key not in ALLOWED_CRITERIA:
|
||||
raise ValueError(
|
||||
f"Invalid criteria key: {key}. Expected one of {ALLOWED_CRITERIA}."
|
||||
)
|
||||
|
||||
if not eval_dataset:
|
||||
raise ValueError("The evaluation dataset is empty.")
|
||||
sample = eval_dataset[0]
|
||||
first_query = sample[0]
|
||||
|
||||
if not isinstance(sample, list) and not isinstance(first_query, dict):
|
||||
raise ValueError(
|
||||
"Each evaluation dataset sample must be list of dictionary. But it's"
|
||||
f" {eval_dataset}"
|
||||
)
|
||||
|
||||
if TOOL_TRAJECTORY_SCORE_KEY in criteria:
|
||||
if (
|
||||
QUERY_COLUMN not in first_query
|
||||
or EXPECTED_TOOL_USE_COLUMN not in first_query
|
||||
):
|
||||
raise ValueError(
|
||||
f"Samples for {TOOL_TRAJECTORY_SCORE_KEY} must include"
|
||||
f" '{QUERY_COLUMN}' and '{EXPECTED_TOOL_USE_COLUMN}' keys. The"
|
||||
f" sample is {sample}."
|
||||
)
|
||||
|
||||
if RESPONSE_EVALUATION_SCORE_KEY in criteria:
|
||||
if QUERY_COLUMN not in first_query:
|
||||
raise ValueError(
|
||||
f"Samples for {RESPONSE_EVALUATION_SCORE_KEY} must include"
|
||||
f" '{QUERY_COLUMN}' key. The sample is {sample}."
|
||||
)
|
||||
|
||||
if RESPONSE_MATCH_SCORE_KEY in criteria:
|
||||
if QUERY_COLUMN not in first_query or REFERENCE_COLUMN not in first_query:
|
||||
raise ValueError(
|
||||
f"Samples for {RESPONSE_MATCH_SCORE_KEY} must include"
|
||||
f" '{QUERY_COLUMN}' and '{REFERENCE_COLUMN}' keys. The sample is"
|
||||
f" {sample}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_infer_criteria(eval_dataset):
|
||||
"""Infers evaluation criteria based on the provided dataset.
|
||||
|
||||
Args:
|
||||
eval_dataset (list): A list of evaluation samples.
|
||||
|
||||
Returns:
|
||||
dict: Inferred evaluation criteria based on dataset fields.
|
||||
"""
|
||||
inferred_criteria = {}
|
||||
sample = eval_dataset[0][0]
|
||||
|
||||
if QUERY_COLUMN in sample and EXPECTED_TOOL_USE_COLUMN in sample:
|
||||
inferred_criteria[TOOL_TRAJECTORY_SCORE_KEY] = DEFAULT_CRITERIA[
|
||||
TOOL_TRAJECTORY_SCORE_KEY
|
||||
]
|
||||
|
||||
if QUERY_COLUMN in sample and REFERENCE_COLUMN in sample:
|
||||
inferred_criteria[RESPONSE_MATCH_SCORE_KEY] = DEFAULT_CRITERIA[
|
||||
RESPONSE_MATCH_SCORE_KEY
|
||||
]
|
||||
|
||||
return inferred_criteria
|
||||
|
||||
@staticmethod
|
||||
def _generate_responses(
|
||||
agent_module, eval_dataset, num_runs, agent_name=None, initial_session={}
|
||||
):
|
||||
"""Generates evaluation responses by running the agent module multiple times."""
|
||||
return EvaluationGenerator.generate_responses(
|
||||
eval_dataset,
|
||||
agent_module,
|
||||
repeat_num=num_runs,
|
||||
agent_name=agent_name,
|
||||
initial_session=initial_session,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_responses_from_session(eval_dataset, session_path):
|
||||
"""Generates evaluation responses by running the agent module multiple times."""
|
||||
return EvaluationGenerator.generate_responses_from_session(
|
||||
session_path, eval_dataset
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _response_evaluation_required(criteria, eval_dataset):
|
||||
"""Checks if response evaluation are needed."""
|
||||
return REFERENCE_COLUMN in eval_dataset[0][0] and any(
|
||||
key in criteria
|
||||
for key in [RESPONSE_EVALUATION_SCORE_KEY, RESPONSE_MATCH_SCORE_KEY]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _trajectory_evaluation_required(evaluation_criteria, eval_dataset):
|
||||
"""Checks if response evaluation are needed."""
|
||||
return (
|
||||
EXPECTED_TOOL_USE_COLUMN in eval_dataset[0][0]
|
||||
and TOOL_TRAJECTORY_SCORE_KEY in evaluation_criteria
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_response_scores(agent_module, evaluation_response, criteria):
|
||||
"""Evaluates response scores and raises an assertion error if they don't meet the criteria."""
|
||||
metrics = ResponseEvaluator.evaluate(
|
||||
evaluation_response, criteria, print_detailed_results=True
|
||||
)
|
||||
|
||||
AgentEvaluator._assert_score(
|
||||
metrics,
|
||||
"coherence/mean",
|
||||
criteria.get(RESPONSE_EVALUATION_SCORE_KEY),
|
||||
"Average response evaluation score",
|
||||
agent_module,
|
||||
)
|
||||
|
||||
AgentEvaluator._assert_score(
|
||||
metrics,
|
||||
"rouge_1/mean",
|
||||
criteria.get(RESPONSE_MATCH_SCORE_KEY),
|
||||
"Average response match score",
|
||||
agent_module,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_tool_trajectory(agent_module, evaluation_response, criteria):
|
||||
"""Evaluates tool trajectory scores and raises an assertion error if they don't meet the criteria."""
|
||||
score = TrajectoryEvaluator.evaluate(
|
||||
evaluation_response, print_detailed_results=True
|
||||
)
|
||||
AgentEvaluator._assert_score(
|
||||
{TOOL_TRAJECTORY_SCORE_KEY: score},
|
||||
TOOL_TRAJECTORY_SCORE_KEY,
|
||||
criteria[TOOL_TRAJECTORY_SCORE_KEY],
|
||||
"Average tool trajectory evaluation score",
|
||||
agent_module,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _assert_score(metrics, metric_key, threshold, description, agent_module):
|
||||
"""Asserts that a metric meets the specified threshold."""
|
||||
if metric_key in metrics:
|
||||
actual_score = metrics[metric_key]
|
||||
assert actual_score >= threshold, (
|
||||
f"{description} for {agent_module} is lower than expected. "
|
||||
f"Expected >= {threshold}, but got {actual_score}."
|
||||
)
|
||||
24
src/google/adk/evaluation/evaluation_constants.py
Normal file
24
src/google/adk/evaluation/evaluation_constants.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
class EvalConstants:
|
||||
"""Holds constants for evaluation file constants."""
|
||||
|
||||
QUERY = "query"
|
||||
EXPECTED_TOOL_USE = "expected_tool_use"
|
||||
RESPONSE = "response"
|
||||
REFERENCE = "reference"
|
||||
TOOL_NAME = "tool_name"
|
||||
TOOL_INPUT = "tool_input"
|
||||
MOCK_TOOL_OUTPUT = "mock_tool_output"
|
||||
270
src/google/adk/evaluation/evaluation_generator.py
Normal file
270
src/google/adk/evaluation/evaluation_generator.py
Normal file
@ -0,0 +1,270 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import uuid
|
||||
|
||||
from google.genai import types
|
||||
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..agents.llm_agent import Agent
|
||||
from ..agents.llm_agent import BeforeToolCallback
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from ..runners import Runner
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from ..sessions.session import Session
|
||||
from .evaluation_constants import EvalConstants
|
||||
|
||||
|
||||
class EvaluationGenerator:
|
||||
"""Generates evaluation responses for agents."""
|
||||
|
||||
@staticmethod
|
||||
def generate_responses(
|
||||
eval_dataset,
|
||||
agent_module_path,
|
||||
repeat_num=3,
|
||||
agent_name=None,
|
||||
initial_session={},
|
||||
):
|
||||
"""Returns evaluation responses for the given dataset and agent.
|
||||
|
||||
Args:
|
||||
eval_dataset: The dataset that needs to be scraped for resposnes.
|
||||
agent_module_path: Path to the module that contains the root agent.
|
||||
repeat_num: Number of time the eval dataset should be repeated. This is
|
||||
usually done to remove uncertainity that a single run may bring.
|
||||
agent_name: The name of the agent that should be evaluated. This is
|
||||
usually the sub-agent.
|
||||
initial_session: Initial session for the eval data.
|
||||
"""
|
||||
results = []
|
||||
|
||||
for _ in range(repeat_num):
|
||||
for data in eval_dataset:
|
||||
results.append(
|
||||
EvaluationGenerator._process_query(
|
||||
data, agent_module_path, agent_name, initial_session
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def generate_responses_from_session(session_path, eval_dataset):
|
||||
"""Returns evaluation responses by combining session data with eval data.
|
||||
|
||||
Args:
|
||||
session_path: Path to a json file that contains session data.
|
||||
eval_dataset: The eval data set that should be combined with the session
|
||||
data.
|
||||
"""
|
||||
results = []
|
||||
|
||||
with open(session_path, "r") as f:
|
||||
session_data = Session.model_validate_json(f.read())
|
||||
print("loaded session", session_path)
|
||||
|
||||
for data in eval_dataset:
|
||||
# load session data from session_path
|
||||
results.append(
|
||||
EvaluationGenerator._process_query_with_session(
|
||||
session_data,
|
||||
data,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _process_query(data, module_name, agent_name=None, initial_session={}):
|
||||
"""Process a query using the agent and evaluation dataset."""
|
||||
module_path = f"{module_name}"
|
||||
agent_module = importlib.import_module(module_path)
|
||||
root_agent = agent_module.agent.root_agent
|
||||
|
||||
reset_func = getattr(agent_module.agent, "reset_data", None)
|
||||
|
||||
agent_to_evaluate = root_agent
|
||||
if agent_name:
|
||||
agent_to_evaluate = root_agent.find_agent(agent_name)
|
||||
assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."
|
||||
|
||||
return EvaluationGenerator._process_query_with_root_agent(
|
||||
data, agent_to_evaluate, reset_func, initial_session
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_query_with_root_agent(
|
||||
data,
|
||||
root_agent,
|
||||
reset_func,
|
||||
initial_session={},
|
||||
session_id=None,
|
||||
session_service=None,
|
||||
artifact_service=None,
|
||||
):
|
||||
"""Process a query using the agent and evaluation dataset."""
|
||||
|
||||
# we don't know which tools belong to which agent
|
||||
# so we just apply to any agents that has certain tool outputs
|
||||
all_mock_tools = set()
|
||||
for eval_entry in data:
|
||||
expected_tool_use = eval_entry.get(EvalConstants.EXPECTED_TOOL_USE, [])
|
||||
for expected in expected_tool_use:
|
||||
if EvalConstants.MOCK_TOOL_OUTPUT in expected:
|
||||
all_mock_tools.add(expected[EvalConstants.TOOL_NAME])
|
||||
|
||||
eval_data_copy = data.copy()
|
||||
EvaluationGenerator.apply_before_tool_callback(
|
||||
root_agent,
|
||||
lambda *args: EvaluationGenerator.before_tool_callback(
|
||||
*args, eval_dataset=eval_data_copy
|
||||
),
|
||||
all_mock_tools,
|
||||
)
|
||||
|
||||
if not session_service:
|
||||
session_service = InMemorySessionService()
|
||||
|
||||
app_name = initial_session.get("app_name", "EvaluationGenerator")
|
||||
user_id = initial_session.get("user_id", "test_user_id")
|
||||
session_id = session_id if session_id else str(uuid.uuid4())
|
||||
|
||||
_ = session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
state=initial_session.get("state", {}),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not artifact_service:
|
||||
artifact_service = InMemoryArtifactService()
|
||||
runner = Runner(
|
||||
app_name=app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
# Reset agent state for each query
|
||||
if callable(reset_func):
|
||||
reset_func()
|
||||
|
||||
responses = data.copy()
|
||||
|
||||
for index, eval_entry in enumerate(responses):
|
||||
response = None
|
||||
query = eval_entry["query"]
|
||||
content = types.Content(role="user", parts=[types.Part(text=query)])
|
||||
turn_actual_tool_uses = []
|
||||
|
||||
for event in runner.run(
|
||||
user_id=user_id, session_id=session_id, new_message=content
|
||||
):
|
||||
if event.is_final_response() and event.content and event.content.parts:
|
||||
response = event.content.parts[0].text
|
||||
elif event.get_function_calls():
|
||||
for call in event.get_function_calls():
|
||||
turn_actual_tool_uses.append({
|
||||
EvalConstants.TOOL_NAME: call.name,
|
||||
EvalConstants.TOOL_INPUT: call.args,
|
||||
})
|
||||
|
||||
responses[index]["actual_tool_use"] = turn_actual_tool_uses
|
||||
responses[index]["response"] = response
|
||||
|
||||
return responses
|
||||
|
||||
@staticmethod
|
||||
def _process_query_with_session(session_data, data):
|
||||
"""Process the queries using the existing session data without invoking the runner."""
|
||||
responses = data.copy()
|
||||
|
||||
# Iterate through the provided queries and align them with the session events
|
||||
for index, eval_entry in enumerate(responses):
|
||||
query = eval_entry["query"]
|
||||
actual_tool_uses = []
|
||||
response = None
|
||||
|
||||
# Search for the corresponding session events
|
||||
for event in session_data.events:
|
||||
# Match the query to a user event
|
||||
if (
|
||||
event.author == "user"
|
||||
and event.content
|
||||
and event.content.parts
|
||||
and event.content.parts[0].text == query
|
||||
):
|
||||
# Look for subsequent tool usage or model responses
|
||||
for subsequent_event in session_data.events:
|
||||
if subsequent_event.invocation_id == event.invocation_id:
|
||||
# Extract tool usage
|
||||
if subsequent_event.content.parts[0].function_call:
|
||||
call = subsequent_event.content.parts[0].function_call
|
||||
actual_tool_uses.append(
|
||||
{"tool_name": call.name, "tool_input": call.args}
|
||||
)
|
||||
# Extract final response
|
||||
elif subsequent_event.author != "user":
|
||||
response = subsequent_event.content.parts[0].text
|
||||
|
||||
# Update the results for the current query
|
||||
responses[index]["actual_tool_use"] = actual_tool_uses
|
||||
responses[index]["response"] = response
|
||||
return responses
|
||||
|
||||
@staticmethod
|
||||
def before_tool_callback(tool, args, tool_context, eval_dataset):
|
||||
"""Intercept specific tool calls and return predefined outputs
|
||||
|
||||
from eval_dataset.
|
||||
"""
|
||||
for index, eval_entry in enumerate(eval_dataset):
|
||||
expected_tool_use = eval_entry.get("expected_tool_use", [])
|
||||
for expected in expected_tool_use:
|
||||
if (
|
||||
EvalConstants.MOCK_TOOL_OUTPUT in expected
|
||||
and tool.name == expected[EvalConstants.TOOL_NAME]
|
||||
and args == expected.get(EvalConstants.TOOL_INPUT, {})
|
||||
):
|
||||
# pop the matched entry so we don't rematch again
|
||||
eval_dataset.pop(index)
|
||||
return {"result": expected[EvalConstants.MOCK_TOOL_OUTPUT]}
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def apply_before_tool_callback(
|
||||
agent: BaseAgent,
|
||||
callback: BeforeToolCallback,
|
||||
all_mock_tools: set[str],
|
||||
):
|
||||
"""Recursively apply the before_tool_callback to the root agent and all its subagents."""
|
||||
# check if the agent has tools that defined by evalset
|
||||
# We use function name to check if tools match
|
||||
if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
for tool in agent.canonical_tools:
|
||||
tool_name = tool.name
|
||||
if tool_name in all_mock_tools:
|
||||
agent.before_tool_callback = callback
|
||||
|
||||
# Apply recursively to subagents if they exist
|
||||
for sub_agent in agent.sub_agents:
|
||||
EvaluationGenerator.apply_before_tool_callback(
|
||||
sub_agent, callback, all_mock_tools
|
||||
)
|
||||
135
src/google/adk/evaluation/response_evaluator.py
Normal file
135
src/google/adk/evaluation/response_evaluator.py
Normal file
@ -0,0 +1,135 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
from tabulate import tabulate
|
||||
from vertexai.preview.evaluation import EvalTask
|
||||
from vertexai.preview.evaluation import MetricPromptTemplateExamples
|
||||
|
||||
|
||||
class ResponseEvaluator:
|
||||
"""Runs response evaluation for agents."""
|
||||
|
||||
@staticmethod
|
||||
def evaluate(
|
||||
raw_eval_dataset: list[list[dict[str, Any]]],
|
||||
evaluation_criteria: list[str],
|
||||
*,
|
||||
print_detailed_results: bool = False
|
||||
):
|
||||
r"""Returns the value of requested evaluation metrics.
|
||||
|
||||
Args:
|
||||
raw_eval_dataset: The dataset that will be evaluated.
|
||||
evaluation_criteria: The evaluation criteria to be used. This method
|
||||
support two criterias, `response_evaluation_score` and
|
||||
`response_match_score`.
|
||||
print_detailed_results: Prints detailed results on the console. This is
|
||||
usually helpful during debugging.
|
||||
|
||||
A note on evaluation_criteria:
|
||||
`response_match_score`: This metric compares the agents final natural
|
||||
language reponse with the expected final response, stored in the
|
||||
"reference" field in test/eval files. We use Rouge metric to compare the
|
||||
two responses.
|
||||
|
||||
Value Range: [0, 1]. A score closer to 0 means poor similarity between
|
||||
response and reference. A score closer to 1 means strong similarity
|
||||
between response and reference.
|
||||
|
||||
`response_evaluation_score`: Uses LLM to evalaute coherence of the
|
||||
response, including tool use. This is pointwise metric.
|
||||
|
||||
Value range: [0, 5], where 0 means that the agent's response is not
|
||||
coherent, while 5 means it is . High values are good.
|
||||
A note on raw_eval_dataset:
|
||||
The dataset should be a list session, where each sesssion is represented
|
||||
as a list of interaction that need evaluation. Each evaluation is
|
||||
represented as a dictionary that is expected to have values for the
|
||||
following keys:
|
||||
|
||||
1) query
|
||||
2) response
|
||||
3) acutal_tool_use
|
||||
4) expected_tool_use
|
||||
5) reference
|
||||
|
||||
Here is a sample eval_dataset value with one entry:
|
||||
[
|
||||
[
|
||||
{
|
||||
"query": "roll a die for me",
|
||||
"response": "I rolled a 16 sided die and got 13.\n",
|
||||
"expected_tool_use": [
|
||||
{
|
||||
"tool_name": "roll_die",
|
||||
"tool_input": {
|
||||
"sides": 16
|
||||
}
|
||||
}
|
||||
],
|
||||
"acutal_tool_use": [
|
||||
{
|
||||
"tool_name": "roll_die",
|
||||
"tool_input": {
|
||||
"sides": 16
|
||||
}
|
||||
}
|
||||
],
|
||||
"reference": "I rolled a 16 sided die and got 13.\n"
|
||||
}
|
||||
]
|
||||
]
|
||||
"""
|
||||
if not raw_eval_dataset:
|
||||
raise ValueError("The evaluation dataset is empty.")
|
||||
|
||||
metrics = ResponseEvaluator._get_metrics(
|
||||
raw_eval_dataset, evaluation_criteria
|
||||
)
|
||||
flattened_queries = [
|
||||
item for sublist in raw_eval_dataset for item in sublist
|
||||
]
|
||||
eval_dataset = pd.DataFrame(flattened_queries).rename(
|
||||
columns={"query": "prompt", "expected_tool_use": "reference_trajectory"}
|
||||
)
|
||||
eval_task = EvalTask(dataset=eval_dataset, metrics=metrics)
|
||||
|
||||
eval_result = eval_task.evaluate()
|
||||
if print_detailed_results:
|
||||
ResponseEvaluator._print_results(eval_result)
|
||||
return eval_result.summary_metrics
|
||||
|
||||
@staticmethod
|
||||
def _get_metrics(raw_eval_dataset, criteria):
|
||||
metrics = []
|
||||
if (
|
||||
"response_evaluation_score" in criteria
|
||||
and "query" in raw_eval_dataset[0][0]
|
||||
and "expected_tool_use" in raw_eval_dataset[0][0]
|
||||
):
|
||||
metrics.append(MetricPromptTemplateExamples.Pointwise.COHERENCE)
|
||||
if (
|
||||
"response_match_score" in criteria
|
||||
and "reference" in raw_eval_dataset[0][0]
|
||||
):
|
||||
metrics.append("rouge_1")
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def _print_results(eval_result):
|
||||
print("Evaluation Summary Metrics:", eval_result.summary_metrics)
|
||||
print(tabulate(eval_result.metrics_table, headers="keys", tablefmt="grid"))
|
||||
184
src/google/adk/evaluation/trajectory_evaluator.py
Normal file
184
src/google/adk/evaluation/trajectory_evaluator.py
Normal file
@ -0,0 +1,184 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
from tabulate import tabulate
|
||||
|
||||
from .evaluation_constants import EvalConstants
|
||||
|
||||
|
||||
class TrajectoryEvaluator:
|
||||
"""Evaluates tool use trajectories for accuracy."""
|
||||
|
||||
@staticmethod
|
||||
def evaluate(
|
||||
eval_dataset: list[list[dict[str, Any]]],
|
||||
*,
|
||||
print_detailed_results: bool = False,
|
||||
):
|
||||
r"""Returns the mean tool use accuracy of the eval dataset.
|
||||
|
||||
Tool use accuracy is calculated by comparing the expected and actuall tool
|
||||
use trajectories. An exact match scores a 1, 0 otherwise. The final number
|
||||
is an
|
||||
average of these individual scores.
|
||||
|
||||
Value range: [0, 1], where 0 is means none of the too use entries aligned,
|
||||
and 1 would mean all of them aligned. Higher value is good.
|
||||
|
||||
Args:
|
||||
eval_dataset: The dataset that will be evaluated.
|
||||
print_detailed_results: Prints detailed results on the console. This is
|
||||
usually helpful during debugging.
|
||||
|
||||
A note on eval_dataset:
|
||||
The dataset should be a list session, where each sesssion is represented
|
||||
as a list of interaction that need evaluation. Each evaluation is
|
||||
represented as a dictionary that is expected to have values for the
|
||||
following keys:
|
||||
1) query
|
||||
2) response
|
||||
3) acutal_tool_use
|
||||
4) expected_tool_use
|
||||
|
||||
Here is a sample eval_dataset value with one entry:
|
||||
|
||||
[
|
||||
[
|
||||
{
|
||||
"query": "Roll a 16 sided dice for me",
|
||||
"response": "I rolled a 16 sided die and got 13.\n",
|
||||
"expected_tool_use": [
|
||||
{
|
||||
"tool_name": "roll_die",
|
||||
"tool_input": {
|
||||
"sides": 16
|
||||
}
|
||||
}
|
||||
],
|
||||
"acutal_tool_use": [
|
||||
{
|
||||
"tool_name": "roll_die",
|
||||
"tool_input": {
|
||||
"sides": 16
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
]
|
||||
"""
|
||||
if not eval_dataset:
|
||||
raise ValueError("The evaluation dataset is empty.")
|
||||
|
||||
results_df = pd.DataFrame(
|
||||
columns=[
|
||||
"query",
|
||||
"response",
|
||||
"actual_tool_use",
|
||||
"expected_tool_use",
|
||||
"tool_use_accuracy",
|
||||
]
|
||||
)
|
||||
failures = []
|
||||
|
||||
for conversation in eval_dataset:
|
||||
for index, row in enumerate(conversation):
|
||||
new_row, failure = TrajectoryEvaluator._evaluate_row(row)
|
||||
results_df = pd.concat(
|
||||
[results_df, pd.DataFrame([new_row])], ignore_index=True
|
||||
)
|
||||
if failure:
|
||||
failure["turn"] = index + 1
|
||||
failures.append(failure)
|
||||
|
||||
TrajectoryEvaluator._report_failures(failures)
|
||||
|
||||
if print_detailed_results:
|
||||
TrajectoryEvaluator._print_results(results_df)
|
||||
|
||||
return results_df["tool_use_accuracy"].mean()
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_row(row):
|
||||
# We don't evaluate the mock tool outputs.
|
||||
expected = TrajectoryEvaluator._remove_tool_outputs(
|
||||
row["expected_tool_use"]
|
||||
)
|
||||
actual = row["actual_tool_use"]
|
||||
tool_use_accuracy = (
|
||||
1.0 if TrajectoryEvaluator.are_tools_equal(actual, expected) else 0.0
|
||||
)
|
||||
|
||||
new_row = {
|
||||
"query": row["query"],
|
||||
"response": row["response"],
|
||||
"actual_tool_use": actual,
|
||||
"expected_tool_use": expected,
|
||||
"tool_use_accuracy": tool_use_accuracy,
|
||||
}
|
||||
failure = (
|
||||
None
|
||||
if tool_use_accuracy == 1.0
|
||||
else {"query": row["query"], "actual": actual, "expected": expected}
|
||||
)
|
||||
return new_row, failure
|
||||
|
||||
@staticmethod
|
||||
def are_tools_equal(list_a_original, list_b_original):
|
||||
# Remove other entries that we don't want to evaluate
|
||||
list_a = [
|
||||
{"tool_name": tool["tool_name"], "tool_input": tool["tool_input"]}
|
||||
for tool in list_a_original
|
||||
]
|
||||
|
||||
list_b = [
|
||||
{"tool_name": tool["tool_name"], "tool_input": tool["tool_input"]}
|
||||
for tool in list_b_original
|
||||
]
|
||||
|
||||
return list_a == list_b
|
||||
|
||||
@staticmethod
|
||||
def _remove_tool_outputs(tool_use_list):
|
||||
"""Removes 'mock_tool_output' from each dictionary in the list."""
|
||||
result = []
|
||||
for tool_use in tool_use_list:
|
||||
new_tool_use = (
|
||||
tool_use.copy()
|
||||
) # Create a copy to avoid modifying the original
|
||||
new_tool_use.pop(
|
||||
EvalConstants.MOCK_TOOL_OUTPUT, None
|
||||
) # Remove 'tool_output' if it exists
|
||||
result.append(new_tool_use)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _report_failures(failures):
|
||||
if failures:
|
||||
print("Failures:")
|
||||
for failure in failures:
|
||||
print(f"""{{
|
||||
"turn": {failure["turn"]},
|
||||
"query": '{failure["query"]}',
|
||||
"actual": {failure["actual"]},
|
||||
"expected_tool_use": {failure["expected"]},
|
||||
}}
|
||||
""")
|
||||
|
||||
@staticmethod
|
||||
def _print_results(results_df):
|
||||
print(tabulate(results_df, headers="keys", tablefmt="grid"))
|
||||
21
src/google/adk/events/__init__.py
Normal file
21
src/google/adk/events/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .event import Event
|
||||
from .event_actions import EventActions
|
||||
|
||||
__all__ = [
|
||||
'Event',
|
||||
'EventActions',
|
||||
]
|
||||
130
src/google/adk/events/event.py
Normal file
130
src/google/adk/events/event.py
Normal file
@ -0,0 +1,130 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
import random
|
||||
import string
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from ..models.llm_response import LlmResponse
|
||||
from .event_actions import EventActions
|
||||
|
||||
|
||||
class Event(LlmResponse):
|
||||
"""Represents an event in a conversation between agents and users.
|
||||
|
||||
It is used to store the content of the conversation, as well as the actions
|
||||
taken by the agents like function calls, etc.
|
||||
|
||||
Attributes:
|
||||
invocation_id: The invocation ID of the event.
|
||||
author: "user" or the name of the agent, indicating who appended the event
|
||||
to the session.
|
||||
actions: The actions taken by the agent.
|
||||
long_running_tool_ids: The ids of the long running function calls.
|
||||
branch: The branch of the event.
|
||||
id: The unique identifier of the event.
|
||||
timestamp: The timestamp of the event.
|
||||
is_final_response: Whether the event is the final response of the agent.
|
||||
get_function_calls: Returns the function calls in the event.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra='forbid', ser_json_bytes='base64', val_json_bytes='base64'
|
||||
)
|
||||
|
||||
# TODO: revert to be required after spark migration
|
||||
invocation_id: str = ''
|
||||
"""The invocation ID of the event."""
|
||||
author: str
|
||||
"""'user' or the name of the agent, indicating who appended the event to the
|
||||
session."""
|
||||
actions: EventActions = Field(default_factory=EventActions)
|
||||
"""The actions taken by the agent."""
|
||||
|
||||
long_running_tool_ids: Optional[set[str]] = None
|
||||
"""Set of ids of the long running function calls.
|
||||
Agent client will know from this field about which function call is long running.
|
||||
only valid for function call event
|
||||
"""
|
||||
branch: Optional[str] = None
|
||||
"""The branch of the event.
|
||||
|
||||
The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of
|
||||
agent_2, and agent_2 is the parent of agent_3.
|
||||
|
||||
Branch is used when multiple sub-agent shouldn't see their peer agents'
|
||||
conversaction history.
|
||||
"""
|
||||
|
||||
# The following are computed fields.
|
||||
# Do not assign the ID. It will be assigned by the session.
|
||||
id: str = ''
|
||||
"""The unique identifier of the event."""
|
||||
timestamp: float = Field(default_factory=lambda: datetime.now().timestamp())
|
||||
"""The timestamp of the event."""
|
||||
|
||||
def model_post_init(self, __context):
|
||||
"""Post initialization logic for the event."""
|
||||
# Generates a random ID for the event.
|
||||
if not self.id:
|
||||
self.id = Event.new_id()
|
||||
|
||||
def is_final_response(self) -> bool:
|
||||
"""Returns whether the event is the final response of the agent."""
|
||||
if self.actions.skip_summarization or self.long_running_tool_ids:
|
||||
return True
|
||||
return (
|
||||
not self.get_function_calls()
|
||||
and not self.get_function_responses()
|
||||
and not self.partial
|
||||
and not self.has_trailing_code_exeuction_result()
|
||||
)
|
||||
|
||||
def get_function_calls(self) -> list[types.FunctionCall]:
|
||||
"""Returns the function calls in the event."""
|
||||
func_calls = []
|
||||
if self.content and self.content.parts:
|
||||
for part in self.content.parts:
|
||||
if part.function_call:
|
||||
func_calls.append(part.function_call)
|
||||
return func_calls
|
||||
|
||||
def get_function_responses(self) -> list[types.FunctionResponse]:
|
||||
"""Returns the function responses in the event."""
|
||||
func_response = []
|
||||
if self.content and self.content.parts:
|
||||
for part in self.content.parts:
|
||||
if part.function_response:
|
||||
func_response.append(part.function_response)
|
||||
return func_response
|
||||
|
||||
def has_trailing_code_exeuction_result(
|
||||
self,
|
||||
) -> bool:
|
||||
"""Returns whether the event has a trailing code execution result."""
|
||||
if self.content:
|
||||
if self.content.parts:
|
||||
return self.content.parts[-1].code_execution_result is not None
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def new_id():
|
||||
characters = string.ascii_letters + string.digits
|
||||
return ''.join(random.choice(characters) for _ in range(8))
|
||||
55
src/google/adk/events/event_actions.py
Normal file
55
src/google/adk/events/event_actions.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from ..auth.auth_tool import AuthConfig
|
||||
|
||||
|
||||
class EventActions(BaseModel):
|
||||
"""Represents the actions attached to an event."""
|
||||
|
||||
model_config = ConfigDict(extra='forbid')
|
||||
|
||||
skip_summarization: Optional[bool] = None
|
||||
"""If true, it won't call model to summarize function response.
|
||||
|
||||
Only used for function_response event.
|
||||
"""
|
||||
|
||||
state_delta: dict[str, object] = Field(default_factory=dict)
|
||||
"""Indicates that the event is updating the state with the given delta."""
|
||||
|
||||
artifact_delta: dict[str, int] = Field(default_factory=dict)
|
||||
"""Indicates that the event is updating an artifact. key is the filename,
|
||||
value is the version."""
|
||||
|
||||
transfer_to_agent: Optional[str] = None
|
||||
"""If set, the event transfers to the specified agent."""
|
||||
|
||||
escalate: Optional[bool] = None
|
||||
"""The agent is escalating to a higher level agent."""
|
||||
|
||||
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
|
||||
"""Will only be set by a tool response indicating tool request euc.
|
||||
dict key is the function call id since one function call response (from model)
|
||||
could correspond to multiple function calls.
|
||||
dict value is the required auth config.
|
||||
"""
|
||||
28
src/google/adk/examples/__init__.py
Normal file
28
src/google/adk/examples/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_example_provider import BaseExampleProvider
|
||||
from .example import Example
|
||||
|
||||
__all__ = [
|
||||
'BaseExampleProvider',
|
||||
'Example',
|
||||
]
|
||||
|
||||
try:
|
||||
from .vertex_ai_example_store import VertexAiExampleStore
|
||||
|
||||
__all__.append('VertexAiExampleStore')
|
||||
except ImportError:
|
||||
pass
|
||||
35
src/google/adk/examples/base_example_provider.py
Normal file
35
src/google/adk/examples/base_example_provider.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from .example import Example
|
||||
|
||||
|
||||
# A class that provides examples for a given query.
|
||||
class BaseExampleProvider(abc.ABC):
|
||||
"""Base class for example providers.
|
||||
|
||||
This class defines the interface for providing examples for a given query.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_examples(self, query: str) -> list[Example]:
|
||||
"""Returns a list of examples for a given query.
|
||||
|
||||
Args:
|
||||
query: The query to get examples for.
|
||||
|
||||
Returns:
|
||||
A list of Example objects.
|
||||
"""
|
||||
27
src/google/adk/examples/example.py
Normal file
27
src/google/adk/examples/example.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Example(BaseModel):
|
||||
"""A few-shot example.
|
||||
|
||||
Attributes:
|
||||
input: The input content for the example.
|
||||
output: The expected output content for the example.
|
||||
"""
|
||||
input: types.Content
|
||||
output: list[types.Content]
|
||||
123
src/google/adk/examples/example_util.py
Normal file
123
src/google/adk/examples/example_util.py
Normal file
@ -0,0 +1,123 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utility functions for converting examples to a string that can be used in system instructions in the prompt."""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base_example_provider import BaseExampleProvider
|
||||
from .example import Example
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..sessions.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constant parts of the example string
|
||||
_EXAMPLES_INTRO = (
|
||||
"<EXAMPLES>\nBegin few-shot\nThe following are examples of user queries and"
|
||||
" model responses using the available tools.\n\n"
|
||||
)
|
||||
_EXAMPLES_END = "End few-shot\n<EXAMPLES>"
|
||||
_EXAMPLE_START = "EXAMPLE {}:\nBegin example\n"
|
||||
_EXAMPLE_END = "End example\n\n"
|
||||
_USER_PREFIX = "[user]\n"
|
||||
_MODEL_PREFIX = "[model]\n"
|
||||
_FUNCTION_PREFIX = "```\n"
|
||||
_FUNCTION_CALL_PREFIX = "```tool_code\n"
|
||||
_FUNCTION_CALL_SUFFIX = "\n```\n"
|
||||
_FUNCTION_RESPONSE_PREFIX = "```tool_outputs\n"
|
||||
_FUNCTION_RESPONSE_SUFFIX = "\n```\n"
|
||||
|
||||
|
||||
# TODO(yaojie): Add unit tests for this function.
|
||||
def convert_examples_to_text(
|
||||
examples: list[Example], model: Optional[str]
|
||||
) -> str:
|
||||
"""Converts a list of examples to a string that can be used in a system instruction."""
|
||||
examples_str = ""
|
||||
for example_num, example in enumerate(examples):
|
||||
output = f"{_EXAMPLE_START.format(example_num + 1)}{_USER_PREFIX}"
|
||||
if example.input and example.input.parts:
|
||||
output += (
|
||||
"\n".join(part.text for part in example.input.parts if part.text)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
gemini2 = model is None or "gemini-2" in model
|
||||
previous_role = None
|
||||
for content in example.output:
|
||||
role = _MODEL_PREFIX if content.role == "model" else _USER_PREFIX
|
||||
if role != previous_role:
|
||||
output += role
|
||||
previous_role = role
|
||||
for part in content.parts:
|
||||
if part.function_call:
|
||||
args = []
|
||||
# Convert function call part to python-like function call
|
||||
for k, v in part.function_call.args.items():
|
||||
if isinstance(v, str):
|
||||
args.append(f"{k}='{v}'")
|
||||
else:
|
||||
args.append(f"{k}={v}")
|
||||
prefix = _FUNCTION_PREFIX if gemini2 else _FUNCTION_CALL_PREFIX
|
||||
output += (
|
||||
f"{prefix}{part.function_call.name}({', '.join(args)}){_FUNCTION_CALL_SUFFIX}"
|
||||
)
|
||||
# Convert function response part to json string
|
||||
elif part.function_response:
|
||||
prefix = _FUNCTION_PREFIX if gemini2 else _FUNCTION_RESPONSE_PREFIX
|
||||
output += f"{prefix}{part.function_response.__dict__}{_FUNCTION_RESPONSE_SUFFIX}"
|
||||
elif part.text:
|
||||
output += f"{part.text}\n"
|
||||
|
||||
output += _EXAMPLE_END
|
||||
examples_str += output
|
||||
|
||||
return f"{_EXAMPLES_INTRO}{examples_str}{_EXAMPLES_END}"
|
||||
|
||||
|
||||
def _get_latest_message_from_user(session: "Session") -> str:
|
||||
"""Gets the latest message from the user.
|
||||
|
||||
Returns:
|
||||
The latest message from the user. If not found, returns an empty string.
|
||||
"""
|
||||
events = session.events
|
||||
if not events:
|
||||
return ""
|
||||
|
||||
event = events[-1]
|
||||
if event.author == "user" and not event.get_function_responses():
|
||||
if event.content.parts and event.content.parts[0].text:
|
||||
return event.content.parts[0].text
|
||||
else:
|
||||
logger.warning("No message from user for fetching example.")
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def build_example_si(
|
||||
examples: Union[list[Example], BaseExampleProvider],
|
||||
query: str,
|
||||
model: Optional[str],
|
||||
) -> str:
|
||||
if isinstance(examples, list):
|
||||
return convert_examples_to_text(examples, model)
|
||||
if isinstance(examples, BaseExampleProvider):
|
||||
return convert_examples_to_text(examples.get_examples(query), model)
|
||||
|
||||
raise ValueError("Invalid example configuration")
|
||||
104
src/google/adk/examples/vertex_ai_example_store.py
Normal file
104
src/google/adk/examples/vertex_ai_example_store.py
Normal file
@ -0,0 +1,104 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
from vertexai.preview import example_stores
|
||||
|
||||
from .base_example_provider import BaseExampleProvider
|
||||
from .example import Example
|
||||
|
||||
|
||||
class VertexAiExampleStore(BaseExampleProvider):
|
||||
"""Provides examples from Vertex example store."""
|
||||
|
||||
def __init__(self, examples_store_name: str):
|
||||
"""Initializes the VertexAiExampleStore.
|
||||
|
||||
Args:
|
||||
examples_store_name: The resource name of the vertex example store, in
|
||||
the format of
|
||||
``projects/{project}/locations/{location}/exampleStores/{example_store}``.
|
||||
"""
|
||||
self.examples_store_name = examples_store_name
|
||||
|
||||
@override
|
||||
def get_examples(self, query: str) -> list[Example]:
|
||||
example_store = example_stores.ExampleStore(self.examples_store_name)
|
||||
# Retrieve relevant examples.
|
||||
request = {
|
||||
"stored_contents_example_parameters": {
|
||||
"content_search_key": {
|
||||
"contents": [{"role": "user", "parts": [{"text": query}]}],
|
||||
"search_key_generation_method": {"last_entry": {}},
|
||||
}
|
||||
},
|
||||
"top_k": 10,
|
||||
"example_store": self.examples_store_name,
|
||||
}
|
||||
response = example_store.api_client.search_examples(request)
|
||||
|
||||
returned_examples = []
|
||||
# Convert results to genai formats
|
||||
for result in response.results:
|
||||
if result.similarity_score < 0.5:
|
||||
continue
|
||||
expected_contents = [
|
||||
content.content
|
||||
for content in result.example.stored_contents_example.contents_example.expected_contents
|
||||
]
|
||||
expected_output = []
|
||||
for content in expected_contents:
|
||||
expected_parts = []
|
||||
for part in content.parts:
|
||||
if part.text:
|
||||
expected_parts.append(types.Part.from_text(text=part.text))
|
||||
elif part.function_call:
|
||||
expected_parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=part.function_call.name,
|
||||
args={
|
||||
key: value
|
||||
for key, value in part.function_call.args.items()
|
||||
},
|
||||
)
|
||||
)
|
||||
elif part.function_response:
|
||||
expected_parts.append(
|
||||
types.Part.from_function_response(
|
||||
name=part.function_response.name,
|
||||
response={
|
||||
key: value
|
||||
for key, value in part.function_response.response.items()
|
||||
},
|
||||
)
|
||||
)
|
||||
expected_output.append(
|
||||
types.Content(role=content.role, parts=expected_parts)
|
||||
)
|
||||
|
||||
returned_examples.append(
|
||||
Example(
|
||||
input=types.Content(
|
||||
role="user",
|
||||
parts=[
|
||||
types.Part.from_text(
|
||||
text=result.example.stored_contents_example.search_key
|
||||
)
|
||||
],
|
||||
),
|
||||
output=expected_output,
|
||||
)
|
||||
)
|
||||
return returned_examples
|
||||
14
src/google/adk/flows/__init__.py
Normal file
14
src/google/adk/flows/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
20
src/google/adk/flows/llm_flows/__init__.py
Normal file
20
src/google/adk/flows/llm_flows/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import _code_execution
|
||||
from . import _nl_planning
|
||||
from . import contents
|
||||
from . import functions
|
||||
from . import identity
|
||||
from . import instructions
|
||||
52
src/google/adk/flows/llm_flows/_base_llm_processor.py
Normal file
52
src/google/adk/flows/llm_flows/_base_llm_processor.py
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Defines the processor interface used for BaseLlmFlow."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...events.event import Event
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ...models.llm_response import LlmResponse
|
||||
|
||||
|
||||
class BaseLlmRequestProcessor(ABC):
|
||||
"""Base class for LLM request processor."""
|
||||
|
||||
@abstractmethod
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Runs the processor."""
|
||||
raise NotImplementedError("Not implemented.")
|
||||
yield # AsyncGenerator requires a yield in function body.
|
||||
|
||||
|
||||
class BaseLlmResponseProcessor(ABC):
|
||||
"""Base class for LLM response processor."""
|
||||
|
||||
@abstractmethod
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_response: LlmResponse
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Processes the LLM response."""
|
||||
raise NotImplementedError("Not implemented.")
|
||||
yield # AsyncGenerator requires a yield in function body.
|
||||
458
src/google/adk/flows/llm_flows/_code_execution.py
Normal file
458
src/google/adk/flows/llm_flows/_code_execution.py
Normal file
@ -0,0 +1,458 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Handles Code Execution related logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import dataclasses
|
||||
import os
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...code_executors.base_code_executor import BaseCodeExecutor
|
||||
from ...code_executors.code_execution_utils import CodeExecutionInput
|
||||
from ...code_executors.code_execution_utils import CodeExecutionResult
|
||||
from ...code_executors.code_execution_utils import CodeExecutionUtils
|
||||
from ...code_executors.code_execution_utils import File
|
||||
from ...code_executors.code_executor_context import CodeExecutorContext
|
||||
from ...events.event import Event
|
||||
from ...events.event_actions import EventActions
|
||||
from ...models.llm_response import LlmResponse
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
from ._base_llm_processor import BaseLlmResponseProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models.llm_request import LlmRequest
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DataFileUtil:
|
||||
"""A structure that contains a data file name and its content."""
|
||||
|
||||
extension: str
|
||||
"""
|
||||
The file extension (e.g., ".csv").
|
||||
"""
|
||||
|
||||
loader_code_template: str
|
||||
"""
|
||||
The code template to load the data file.
|
||||
"""
|
||||
|
||||
|
||||
_DATA_FILE_UTIL_MAP = {
|
||||
'text/csv': DataFileUtil(
|
||||
extension='.csv',
|
||||
loader_code_template="pd.read_csv('{filename}')",
|
||||
),
|
||||
}
|
||||
|
||||
_DATA_FILE_HELPER_LIB = '''
|
||||
import pandas as pd
|
||||
|
||||
def explore_df(df: pd.DataFrame) -> None:
|
||||
"""Prints some information about a pandas DataFrame."""
|
||||
|
||||
with pd.option_context(
|
||||
'display.max_columns', None, 'display.expand_frame_repr', False
|
||||
):
|
||||
# Print the column names to never encounter KeyError when selecting one.
|
||||
df_dtypes = df.dtypes
|
||||
|
||||
# Obtain information about data types and missing values.
|
||||
df_nulls = (len(df) - df.isnull().sum()).apply(
|
||||
lambda x: f'{x} / {df.shape[0]} non-null'
|
||||
)
|
||||
|
||||
# Explore unique total values in columns using `.unique()`.
|
||||
df_unique_count = df.apply(lambda x: len(x.unique()))
|
||||
|
||||
# Explore unique values in columns using `.unique()`.
|
||||
df_unique = df.apply(lambda x: crop(str(list(x.unique()))))
|
||||
|
||||
df_info = pd.concat(
|
||||
(
|
||||
df_dtypes.rename('Dtype'),
|
||||
df_nulls.rename('Non-Null Count'),
|
||||
df_unique_count.rename('Unique Values Count'),
|
||||
df_unique.rename('Unique Values'),
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
df_info.index.name = 'Columns'
|
||||
print(f"""Total rows: {df.shape[0]}
|
||||
Total columns: {df.shape[1]}
|
||||
|
||||
{df_info}""")
|
||||
'''
|
||||
|
||||
|
||||
class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Processes code execution requests."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
if not isinstance(invocation_context.agent, LlmAgent):
|
||||
return
|
||||
if not invocation_context.agent.code_executor:
|
||||
return
|
||||
|
||||
for event in _run_pre_processor(invocation_context, llm_request):
|
||||
yield event
|
||||
|
||||
# Convert the code execution parts to text parts.
|
||||
if not isinstance(invocation_context.agent.code_executor, BaseCodeExecutor):
|
||||
return
|
||||
for content in llm_request.contents:
|
||||
CodeExecutionUtils.convert_code_execution_parts(
|
||||
content,
|
||||
invocation_context.agent.code_executor.code_block_delimiters[0]
|
||||
if invocation_context.agent.code_executor.code_block_delimiters
|
||||
else ('', ''),
|
||||
invocation_context.agent.code_executor.execution_result_delimiters,
|
||||
)
|
||||
|
||||
|
||||
request_processor = _CodeExecutionRequestProcessor()
|
||||
|
||||
|
||||
class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
|
||||
"""Processes code execution responses."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_response: LlmResponse
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
# Skip if the response is partial (streaming).
|
||||
if llm_response.partial:
|
||||
return
|
||||
|
||||
for event in _run_post_processor(invocation_context, llm_response):
|
||||
yield event
|
||||
|
||||
|
||||
response_processor = _CodeExecutionResponseProcessor()
|
||||
|
||||
|
||||
def _run_pre_processor(
|
||||
invocation_context: InvocationContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> Generator[Event, None, None]:
|
||||
"""Pre-process the user message by adding the user message to the Colab notebook."""
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
if not isinstance(invocation_context.agent, LlmAgent):
|
||||
return
|
||||
|
||||
agent = invocation_context.agent
|
||||
code_executor = agent.code_executor
|
||||
|
||||
if not code_executor or not isinstance(code_executor, BaseCodeExecutor):
|
||||
return
|
||||
if not code_executor.optimize_data_file:
|
||||
return
|
||||
|
||||
code_executor_context = CodeExecutorContext(invocation_context.session.state)
|
||||
|
||||
# Skip if the error count exceeds the max retry attempts.
|
||||
if (
|
||||
code_executor_context.get_error_count(invocation_context.invocation_id)
|
||||
>= code_executor.error_retry_attempts
|
||||
):
|
||||
return
|
||||
|
||||
# [Step 1] Extract data files from the session_history and store them in
|
||||
# memory. Meanwhile, mutate the inline data file to text part in session
|
||||
# history from all turns.
|
||||
all_input_files = _extrac_and_replace_inline_files(
|
||||
code_executor_context, llm_request
|
||||
)
|
||||
|
||||
# [Step 2] Run Explore_Df code on the data files from the current turn. We
|
||||
# only need to explore the new data files because the previous data files
|
||||
# should already be explored and cached in the code execution runtime.
|
||||
processed_file_names = set(code_executor_context.get_processed_file_names())
|
||||
files_to_process = [
|
||||
f for f in all_input_files if f.name not in processed_file_names
|
||||
]
|
||||
for file in files_to_process:
|
||||
code_str = _get_data_file_preprocessing_code(file)
|
||||
# Skip for unsupported file or executor types.
|
||||
if not code_str:
|
||||
return
|
||||
|
||||
# Emit the code to execute, and add it to the LLM request.
|
||||
code_content = types.Content(
|
||||
role='model',
|
||||
parts=[
|
||||
types.Part(text=f'Processing input file: `{file.name}`'),
|
||||
CodeExecutionUtils.build_executable_code_part(code_str),
|
||||
],
|
||||
)
|
||||
llm_request.contents.append(copy.deepcopy(code_content))
|
||||
yield Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=agent.name,
|
||||
branch=invocation_context.branch,
|
||||
content=code_content,
|
||||
)
|
||||
|
||||
code_execution_result = code_executor.execute_code(
|
||||
invocation_context,
|
||||
CodeExecutionInput(
|
||||
code=code_str,
|
||||
input_files=[file],
|
||||
execution_id=_get_or_set_execution_id(
|
||||
invocation_context, code_executor_context
|
||||
),
|
||||
),
|
||||
)
|
||||
# Update the processing results to code executor context.
|
||||
code_executor_context.update_code_execution_result(
|
||||
invocation_context.invocation_id,
|
||||
code_str,
|
||||
code_execution_result.stdout,
|
||||
code_execution_result.stderr,
|
||||
)
|
||||
code_executor_context.add_processed_file_names([file.name])
|
||||
|
||||
# Emit the execution result, and add it to the LLM request.
|
||||
execution_result_event = _post_process_code_execution_result(
|
||||
invocation_context, code_executor_context, code_execution_result
|
||||
)
|
||||
yield execution_result_event
|
||||
llm_request.contents.append(copy.deepcopy(execution_result_event.content))
|
||||
|
||||
|
||||
def _run_post_processor(
|
||||
invocation_context: InvocationContext,
|
||||
llm_response,
|
||||
) -> Generator[Event, None, None]:
|
||||
"""Post-process the model response by extracting and executing the first code block."""
|
||||
agent = invocation_context.agent
|
||||
code_executor = agent.code_executor
|
||||
|
||||
if not code_executor or not isinstance(code_executor, BaseCodeExecutor):
|
||||
return
|
||||
if not llm_response or not llm_response.content:
|
||||
return
|
||||
|
||||
code_executor_context = CodeExecutorContext(invocation_context.session.state)
|
||||
# Skip if the error count exceeds the max retry attempts.
|
||||
if (
|
||||
code_executor_context.get_error_count(invocation_context.invocation_id)
|
||||
>= code_executor.error_retry_attempts
|
||||
):
|
||||
return
|
||||
|
||||
# [Step 1] Extract code from the model predict response and truncate the
|
||||
# content to the part with the first code block.
|
||||
response_content = llm_response.content
|
||||
code_str = CodeExecutionUtils.extract_code_and_truncate_content(
|
||||
response_content, code_executor.code_block_delimiters
|
||||
)
|
||||
# Terminal state: no code to execute.
|
||||
if not code_str:
|
||||
return
|
||||
|
||||
# [Step 2] Executes the code and emit 2 Events for code and execution result.
|
||||
yield Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=agent.name,
|
||||
branch=invocation_context.branch,
|
||||
content=response_content,
|
||||
actions=EventActions(),
|
||||
)
|
||||
|
||||
code_execution_result = code_executor.execute_code(
|
||||
invocation_context,
|
||||
CodeExecutionInput(
|
||||
code=code_str,
|
||||
input_files=code_executor_context.get_input_files(),
|
||||
execution_id=_get_or_set_execution_id(
|
||||
invocation_context, code_executor_context
|
||||
),
|
||||
),
|
||||
)
|
||||
code_executor_context.update_code_execution_result(
|
||||
invocation_context.invocation_id,
|
||||
code_str,
|
||||
code_execution_result.stdout,
|
||||
code_execution_result.stderr,
|
||||
)
|
||||
yield _post_process_code_execution_result(
|
||||
invocation_context, code_executor_context, code_execution_result
|
||||
)
|
||||
|
||||
# [Step 3] Skip processing the original model response
|
||||
# to continue code generation loop.
|
||||
llm_response.content = None
|
||||
|
||||
|
||||
def _extrac_and_replace_inline_files(
|
||||
code_executor_context: CodeExecutorContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> list[File]:
|
||||
"""Extracts and replaces inline files with file names in the LLM request."""
|
||||
all_input_files = code_executor_context.get_input_files()
|
||||
saved_file_names = set(f.name for f in all_input_files)
|
||||
|
||||
# [Step 1] Process input files from LlmRequest and cache them in CodeExecutor.
|
||||
for i in range(len(llm_request.contents)):
|
||||
content = llm_request.contents[i]
|
||||
# Only process the user message.
|
||||
if content.role != 'user' and not content.parts:
|
||||
continue
|
||||
|
||||
for j in range(len(content.parts)):
|
||||
part = content.parts[j]
|
||||
# Skip if the inline data is not supported.
|
||||
if (
|
||||
not part.inline_data
|
||||
or part.inline_data.mime_type not in _DATA_FILE_UTIL_MAP
|
||||
):
|
||||
continue
|
||||
|
||||
# Replace the inline data file with a file name placeholder.
|
||||
mime_type = part.inline_data.mime_type
|
||||
file_name = f'data_{i+1}_{j+1}' + _DATA_FILE_UTIL_MAP[mime_type].extension
|
||||
llm_request.contents[i].parts[j] = types.Part(
|
||||
text='\nAvailable file: `%s`\n' % file_name
|
||||
)
|
||||
|
||||
# Add the inlne data as input file to the code executor context.
|
||||
file = File(
|
||||
name=file_name,
|
||||
content=CodeExecutionUtils.get_encoded_file_content(
|
||||
part.inline_data.data
|
||||
).decode(),
|
||||
mime_type=mime_type,
|
||||
)
|
||||
if file_name not in saved_file_names:
|
||||
code_executor_context.add_input_files([file])
|
||||
all_input_files.append(file)
|
||||
|
||||
return all_input_files
|
||||
|
||||
|
||||
def _get_or_set_execution_id(
|
||||
invocation_context: InvocationContext,
|
||||
code_executor_context: CodeExecutorContext,
|
||||
) -> Optional[str]:
|
||||
"""Returns the ID for stateful code execution or None if not stateful."""
|
||||
if not invocation_context.agent.code_executor.stateful:
|
||||
return None
|
||||
|
||||
execution_id = code_executor_context.get_execution_id()
|
||||
if not execution_id:
|
||||
execution_id = invocation_context.session.id
|
||||
code_executor_context.set_execution_id(execution_id)
|
||||
return execution_id
|
||||
|
||||
|
||||
def _post_process_code_execution_result(
|
||||
invocation_context: InvocationContext,
|
||||
code_executor_context: CodeExecutorContext,
|
||||
code_execution_result: CodeExecutionResult,
|
||||
) -> Event:
|
||||
"""Post-process the code execution result and emit an Event."""
|
||||
if invocation_context.artifact_service is None:
|
||||
raise ValueError('Artifact service is not initialized.')
|
||||
|
||||
result_content = types.Content(
|
||||
role='model',
|
||||
parts=[
|
||||
CodeExecutionUtils.build_code_execution_result_part(
|
||||
code_execution_result
|
||||
),
|
||||
],
|
||||
)
|
||||
event_actions = EventActions(
|
||||
state_delta=code_executor_context.get_state_delta()
|
||||
)
|
||||
|
||||
# Handle code execution error retry.
|
||||
if code_execution_result.stderr:
|
||||
code_executor_context.increment_error_count(
|
||||
invocation_context.invocation_id
|
||||
)
|
||||
else:
|
||||
code_executor_context.reset_error_count(invocation_context.invocation_id)
|
||||
|
||||
# Handle output files.
|
||||
for output_file in code_execution_result.output_files:
|
||||
version = invocation_context.artifact_service.save_artifact(
|
||||
app_name=invocation_context.app_name,
|
||||
user_id=invocation_context.user_id,
|
||||
session_id=invocation_context.session.id,
|
||||
filename=output_file.name,
|
||||
artifact=types.Part.from_bytes(
|
||||
data=base64.b64decode(output_file.content),
|
||||
mime_type=output_file.mime_type,
|
||||
),
|
||||
)
|
||||
event_actions.artifact_delta[output_file.name] = version
|
||||
|
||||
return Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=invocation_context.agent.name,
|
||||
branch=invocation_context.branch,
|
||||
content=result_content,
|
||||
actions=event_actions,
|
||||
)
|
||||
|
||||
|
||||
def _get_data_file_preprocessing_code(file: File) -> Optional[str]:
|
||||
"""Returns the code to explore the data file."""
|
||||
|
||||
def _get_normalized_file_name(file_name: str) -> str:
|
||||
var_name, _ = os.path.splitext(file_name)
|
||||
# Replace non-alphanumeric characters with underscores
|
||||
var_name = re.sub(r'[^a-zA-Z0-9_]', '_', var_name)
|
||||
|
||||
# If the filename starts with a digit, prepend an underscore
|
||||
if var_name[0].isdigit():
|
||||
var_name = '_' + var_name
|
||||
return var_name
|
||||
|
||||
if file.mime_type not in _DATA_FILE_UTIL_MAP:
|
||||
return
|
||||
|
||||
var_name = _get_normalized_file_name(file.name)
|
||||
loader_code = _DATA_FILE_UTIL_MAP[file.mime_type].loader_code_template.format(
|
||||
filename=file.name
|
||||
)
|
||||
return f"""
|
||||
{_DATA_FILE_HELPER_LIB}
|
||||
|
||||
# Load the dataframe.
|
||||
{var_name} = {loader_code}
|
||||
|
||||
# Use `explore_df` to guide my analysis.
|
||||
explore_df({var_name})
|
||||
"""
|
||||
129
src/google/adk/flows/llm_flows/_nl_planning.py
Normal file
129
src/google/adk/flows/llm_flows/_nl_planning.py
Normal file
@ -0,0 +1,129 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Handles NL planning related logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.callback_context import CallbackContext
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...agents.readonly_context import ReadonlyContext
|
||||
from ...events.event import Event
|
||||
from ...planners.plan_re_act_planner import PlanReActPlanner
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
from ._base_llm_processor import BaseLlmResponseProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ...models.llm_response import LlmResponse
|
||||
from ...planners.base_planner import BasePlanner
|
||||
from ...planners.built_in_planner import BuiltInPlanner
|
||||
|
||||
|
||||
class _NlPlanningRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Processor for NL planning."""
|
||||
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...planners.built_in_planner import BuiltInPlanner
|
||||
|
||||
planner = _get_planner(invocation_context)
|
||||
if not planner:
|
||||
return
|
||||
|
||||
if isinstance(planner, BuiltInPlanner):
|
||||
planner.apply_thinking_config(llm_request)
|
||||
|
||||
planning_instruction = planner.build_planning_instruction(
|
||||
ReadonlyContext(invocation_context), llm_request
|
||||
)
|
||||
if planning_instruction:
|
||||
llm_request.append_instructions([planning_instruction])
|
||||
|
||||
_remove_thought_from_request(llm_request)
|
||||
|
||||
# Maintain async generator behavior
|
||||
if False: # Ensures it behaves as a generator
|
||||
yield # This is a no-op but maintains generator structure
|
||||
|
||||
|
||||
request_processor = _NlPlanningRequestProcessor()
|
||||
|
||||
|
||||
class _NlPlanningResponse(BaseLlmResponseProcessor):
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_response: LlmResponse
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
if (
|
||||
not llm_response
|
||||
or not llm_response.content
|
||||
or not llm_response.content.parts
|
||||
):
|
||||
return
|
||||
|
||||
planner = _get_planner(invocation_context)
|
||||
if not planner:
|
||||
return
|
||||
|
||||
# Postprocess the LLM response.
|
||||
processed_parts = planner.process_planning_response(
|
||||
CallbackContext(invocation_context), llm_response.content.parts
|
||||
)
|
||||
if processed_parts:
|
||||
llm_response.content.parts = processed_parts
|
||||
|
||||
# Maintain async generator behavior
|
||||
if False: # Ensures it behaves as a generator
|
||||
yield # This is a no-op but maintains generator structure
|
||||
|
||||
|
||||
response_processor = _NlPlanningResponse()
|
||||
|
||||
|
||||
def _get_planner(
|
||||
invocation_context: InvocationContext,
|
||||
) -> Optional[BasePlanner]:
|
||||
from ...agents.llm_agent import Agent
|
||||
from ...planners.base_planner import BasePlanner
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, Agent):
|
||||
return None
|
||||
if not agent.planner:
|
||||
return None
|
||||
|
||||
if isinstance(agent.planner, BasePlanner):
|
||||
return agent.planner
|
||||
return PlanReActPlanner()
|
||||
|
||||
|
||||
def _remove_thought_from_request(llm_request: LlmRequest):
|
||||
if not llm_request.contents:
|
||||
return
|
||||
|
||||
for content in llm_request.contents:
|
||||
if not content.parts:
|
||||
continue
|
||||
for part in content.parts:
|
||||
part.thought = None
|
||||
132
src/google/adk/flows/llm_flows/agent_transfer.py
Normal file
132
src/google/adk/flows/llm_flows/agent_transfer.py
Normal file
@ -0,0 +1,132 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Handles agent transfer for LLM flow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...events.event import Event
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ...tools.function_tool import FunctionTool
|
||||
from ...tools.tool_context import ToolContext
|
||||
from ...tools.transfer_to_agent_tool import transfer_to_agent
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...agents import BaseAgent
|
||||
from ...agents import LlmAgent
|
||||
|
||||
|
||||
class _AgentTransferLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Agent transfer request processor."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
if not isinstance(invocation_context.agent, LlmAgent):
|
||||
return
|
||||
|
||||
transfer_targets = _get_transfer_targets(invocation_context.agent)
|
||||
if not transfer_targets:
|
||||
return
|
||||
|
||||
llm_request.append_instructions([
|
||||
_build_target_agents_instructions(
|
||||
invocation_context.agent, transfer_targets
|
||||
)
|
||||
])
|
||||
|
||||
transfer_to_agent_tool = FunctionTool(func=transfer_to_agent)
|
||||
tool_context = ToolContext(invocation_context)
|
||||
await transfer_to_agent_tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
return
|
||||
yield # AsyncGenerator requires yield statement in function body.
|
||||
|
||||
|
||||
request_processor = _AgentTransferLlmRequestProcessor()
|
||||
|
||||
|
||||
def _build_target_agents_info(target_agent: BaseAgent) -> str:
|
||||
return f"""
|
||||
Agent name: {target_agent.name}
|
||||
Agent description: {target_agent.description}
|
||||
"""
|
||||
|
||||
|
||||
line_break = '\n'
|
||||
|
||||
|
||||
def _build_target_agents_instructions(
|
||||
agent: LlmAgent, target_agents: list[BaseAgent]
|
||||
) -> str:
|
||||
si = f"""
|
||||
You have a list of other agents to transfer to:
|
||||
|
||||
{line_break.join([
|
||||
_build_target_agents_info(target_agent) for target_agent in target_agents
|
||||
])}
|
||||
|
||||
If you are the best to answer the question according to your description, you
|
||||
can answer it.
|
||||
|
||||
If another agent is better for answering the question according to its
|
||||
description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
|
||||
question to that agent. When transfering, do not generate any text other than
|
||||
the function call.
|
||||
"""
|
||||
|
||||
if agent.parent_agent:
|
||||
si += f"""
|
||||
Your parent agent is {agent.parent_agent.name}. If neither the other agents nor
|
||||
you are best for answering the question according to the descriptions, transfer
|
||||
to your parent agent. If you don't have parent agent, try answer by yourself.
|
||||
"""
|
||||
return si
|
||||
|
||||
|
||||
_TRANSFER_TO_AGENT_FUNCTION_NAME = transfer_to_agent.__name__
|
||||
|
||||
|
||||
def _get_transfer_targets(agent: LlmAgent) -> list[BaseAgent]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
result = []
|
||||
result.extend(agent.sub_agents)
|
||||
|
||||
if not agent.parent_agent or not isinstance(agent.parent_agent, LlmAgent):
|
||||
return result
|
||||
|
||||
if not agent.disallow_transfer_to_parent:
|
||||
result.append(agent.parent_agent)
|
||||
|
||||
if not agent.disallow_transfer_to_peers:
|
||||
result.extend([
|
||||
peer_agent
|
||||
for peer_agent in agent.parent_agent.sub_agents
|
||||
if peer_agent.name != agent.name
|
||||
])
|
||||
|
||||
return result
|
||||
109
src/google/adk/flows/llm_flows/audio_transcriber.py
Normal file
109
src/google/adk/flows/llm_flows/audio_transcriber.py
Normal file
@ -0,0 +1,109 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.cloud import speech
|
||||
from google.genai import types as genai_types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
|
||||
|
||||
class AudioTranscriber:
|
||||
"""Transcribes audio using Google Cloud Speech-to-Text."""
|
||||
|
||||
def __init__(self):
|
||||
self.client = speech.SpeechClient()
|
||||
|
||||
def transcribe_file(
|
||||
self, invocation_context: InvocationContext
|
||||
) -> list[genai_types.Content]:
|
||||
"""Transcribe audio, bundling consecutive segments from the same speaker.
|
||||
|
||||
The ordering of speakers will be preserved. Audio blobs will be merged for
|
||||
the same speaker as much as we can do reduce the transcription latency.
|
||||
|
||||
Args:
|
||||
invocation_context: The invocation context to access the transcription
|
||||
cache.
|
||||
|
||||
Returns:
|
||||
A list of Content objects containing the transcribed text.
|
||||
"""
|
||||
|
||||
bundled_audio = []
|
||||
current_speaker = None
|
||||
current_audio_data = b''
|
||||
contents = []
|
||||
|
||||
# Step1: merge audio blobs
|
||||
for transcription_entry in invocation_context.transcription_cache or []:
|
||||
speaker, audio_data = (
|
||||
transcription_entry.role,
|
||||
transcription_entry.data,
|
||||
)
|
||||
|
||||
if isinstance(audio_data, genai_types.Content):
|
||||
if current_speaker is not None:
|
||||
bundled_audio.append((current_speaker, current_audio_data))
|
||||
current_speaker = None
|
||||
current_audio_data = b''
|
||||
bundled_audio.append((speaker, audio_data))
|
||||
continue
|
||||
|
||||
if not audio_data.data:
|
||||
continue
|
||||
|
||||
if speaker == current_speaker:
|
||||
current_audio_data += audio_data.data
|
||||
else:
|
||||
if current_speaker is not None:
|
||||
bundled_audio.append((current_speaker, current_audio_data))
|
||||
current_speaker = speaker
|
||||
current_audio_data = audio_data.data
|
||||
|
||||
# Append the last audio segment if any
|
||||
if current_speaker is not None:
|
||||
bundled_audio.append((current_speaker, current_audio_data))
|
||||
|
||||
# reset cache
|
||||
invocation_context.transcription_cache = []
|
||||
|
||||
# Step2: transcription
|
||||
for speaker, data in bundled_audio:
|
||||
if speaker == 'user':
|
||||
audio = speech.RecognitionAudio(content=data)
|
||||
|
||||
config = speech.RecognitionConfig(
|
||||
encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
|
||||
sample_rate_hertz=16000,
|
||||
language_code='en-US',
|
||||
)
|
||||
|
||||
response = self.client.recognize(config=config, audio=audio)
|
||||
|
||||
for result in response.results:
|
||||
transcript = result.alternatives[0].transcript
|
||||
|
||||
parts = [genai_types.Part(text=transcript)]
|
||||
role = speaker.lower()
|
||||
content = genai_types.Content(role=role, parts=parts)
|
||||
contents.append(content)
|
||||
else:
|
||||
# don't need to transcribe model which are already text
|
||||
contents.append(data)
|
||||
|
||||
return contents
|
||||
49
src/google/adk/flows/llm_flows/auto_flow.py
Normal file
49
src/google/adk/flows/llm_flows/auto_flow.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Implementation of AutoFlow."""
|
||||
|
||||
from . import agent_transfer
|
||||
from .single_flow import SingleFlow
|
||||
|
||||
|
||||
class AutoFlow(SingleFlow):
|
||||
"""AutoFlow is SingleFlow with agent transfer capability.
|
||||
|
||||
Agent transfer is allowed in the following direction:
|
||||
|
||||
1. from parent to sub-agent;
|
||||
2. from sub-agent to parent;
|
||||
3. from sub-agent to its peer agents;
|
||||
|
||||
For peer-agent transfers, it's only enabled when all below conditions are met:
|
||||
|
||||
- The parent agent is also of AutoFlow;
|
||||
- `disallow_transfer_to_peer` option of this agent is False (default).
|
||||
|
||||
Depending on the target agent flow type, the transfer may be automatically
|
||||
reversed. The condition is as below:
|
||||
|
||||
- If the flow type of the tranferee agent is also auto, transfee agent will
|
||||
remain as the active agent. The transfee agent will respond to the user's
|
||||
next message directly.
|
||||
- If the flow type of the transfere agent is not auto, the active agent will
|
||||
be reversed back to previous agent.
|
||||
|
||||
TODO: allow user to config auto-reverse function.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.request_processors += [agent_transfer.request_processor]
|
||||
559
src/google/adk/flows/llm_flows/base_llm_flow.py
Normal file
559
src/google/adk/flows/llm_flows/base_llm_flow.py
Normal file
@ -0,0 +1,559 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from websockets.exceptions import ConnectionClosedOK
|
||||
|
||||
from ...agents.base_agent import BaseAgent
|
||||
from ...agents.callback_context import CallbackContext
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...agents.live_request_queue import LiveRequestQueue
|
||||
from ...agents.run_config import StreamingMode
|
||||
from ...agents.transcription_entry import TranscriptionEntry
|
||||
from ...events.event import Event
|
||||
from ...models.base_llm_connection import BaseLlmConnection
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ...models.llm_response import LlmResponse
|
||||
from ...telemetry import trace_call_llm
|
||||
from ...telemetry import trace_send_data
|
||||
from ...telemetry import tracer
|
||||
from ...tools.tool_context import ToolContext
|
||||
from . import functions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
from ...models.base_llm import BaseLlm
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
from ._base_llm_processor import BaseLlmResponseProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseLlmFlow(ABC):
|
||||
"""A basic flow that calls the LLM in a loop until a final response is generated.
|
||||
|
||||
This flow ends when it transfer to another agent.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.request_processors: list[BaseLlmRequestProcessor] = []
|
||||
self.response_processors: list[BaseLlmResponseProcessor] = []
|
||||
|
||||
async def run_live(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Runs the flow using live api."""
|
||||
llm_request = LlmRequest()
|
||||
event_id = Event.new_id()
|
||||
|
||||
# Preprocess before calling the LLM.
|
||||
async for event in self._preprocess_async(invocation_context, llm_request):
|
||||
yield event
|
||||
if invocation_context.end_invocation:
|
||||
return
|
||||
|
||||
llm = self.__get_llm(invocation_context)
|
||||
logger.info(
|
||||
'Establishing live connection for agent: %s with llm request: %s',
|
||||
invocation_context.agent.name,
|
||||
llm_request,
|
||||
)
|
||||
async with llm.connect(llm_request) as llm_connection:
|
||||
if llm_request.contents:
|
||||
# Sends the conversation history to the model.
|
||||
with tracer.start_as_current_span('send_data'):
|
||||
|
||||
if invocation_context.transcription_cache:
|
||||
from . import audio_transcriber
|
||||
|
||||
audio_transcriber = audio_transcriber.AudioTranscriber()
|
||||
contents = audio_transcriber.transcribe_file(invocation_context)
|
||||
logger.debug('Sending history to model: %s', contents)
|
||||
await llm_connection.send_history(contents)
|
||||
invocation_context.transcription_cache = None
|
||||
trace_send_data(invocation_context, event_id, contents)
|
||||
else:
|
||||
await llm_connection.send_history(llm_request.contents)
|
||||
trace_send_data(invocation_context, event_id, llm_request.contents)
|
||||
|
||||
send_task = asyncio.create_task(
|
||||
self._send_to_model(llm_connection, invocation_context)
|
||||
)
|
||||
|
||||
try:
|
||||
async for event in self._receive_from_model(
|
||||
llm_connection,
|
||||
event_id,
|
||||
invocation_context,
|
||||
llm_request,
|
||||
):
|
||||
# Empty event means the queue is closed.
|
||||
if not event:
|
||||
break
|
||||
logger.debug('Receive new event: %s', event)
|
||||
yield event
|
||||
# send back the function response
|
||||
if event.get_function_responses():
|
||||
logger.debug('Sending back last function resonse event: %s', event)
|
||||
invocation_context.live_request_queue.send_content(event.content)
|
||||
if (
|
||||
event.content
|
||||
and event.content.parts
|
||||
and event.content.parts[0].function_response
|
||||
and event.content.parts[0].function_response.name
|
||||
== 'transfer_to_agent'
|
||||
):
|
||||
await asyncio.sleep(1)
|
||||
# cancel the tasks that belongs to the closed connection.
|
||||
send_task.cancel()
|
||||
await llm_connection.close()
|
||||
finally:
|
||||
# Clean up
|
||||
if not send_task.done():
|
||||
send_task.cancel()
|
||||
try:
|
||||
await send_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _send_to_model(
|
||||
self,
|
||||
llm_connection: BaseLlmConnection,
|
||||
invocation_context: InvocationContext,
|
||||
):
|
||||
"""Sends data to model."""
|
||||
while True:
|
||||
live_request_queue = invocation_context.live_request_queue
|
||||
try:
|
||||
# Streamlit's execution model doesn't preemptively yield to the event
|
||||
# loop. Therefore, we must explicitly introduce timeouts to allow the
|
||||
# event loop to process events.
|
||||
# TODO: revert back(remove timeout) once we move off streamlit.
|
||||
live_request = await asyncio.wait_for(
|
||||
live_request_queue.get(), timeout=0.25
|
||||
)
|
||||
# duplicate the live_request to all the active streams
|
||||
logger.debug(
|
||||
'Sending live request %s to active streams: %s',
|
||||
live_request,
|
||||
invocation_context.active_streaming_tools,
|
||||
)
|
||||
if invocation_context.active_streaming_tools:
|
||||
for active_streaming_tool in (
|
||||
invocation_context.active_streaming_tools
|
||||
).values():
|
||||
if active_streaming_tool.stream:
|
||||
active_streaming_tool.stream.send(live_request)
|
||||
await asyncio.sleep(0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
if live_request.close:
|
||||
await llm_connection.close()
|
||||
return
|
||||
if live_request.blob:
|
||||
# Cache audio data here for transcription
|
||||
if not invocation_context.transcription_cache:
|
||||
invocation_context.transcription_cache = []
|
||||
invocation_context.transcription_cache.append(
|
||||
TranscriptionEntry(role='user', data=live_request.blob)
|
||||
)
|
||||
await llm_connection.send_realtime(live_request.blob)
|
||||
if live_request.content:
|
||||
await llm_connection.send_content(live_request.content)
|
||||
|
||||
async def _receive_from_model(
|
||||
self,
|
||||
llm_connection: BaseLlmConnection,
|
||||
event_id: str,
|
||||
invocation_context: InvocationContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Receive data from model and process events using BaseLlmConnection."""
|
||||
assert invocation_context.live_request_queue
|
||||
try:
|
||||
while True:
|
||||
async for llm_response in llm_connection.receive():
|
||||
model_response_event = Event(
|
||||
id=Event.new_id(),
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=invocation_context.agent.name,
|
||||
)
|
||||
async for event in self._postprocess_live(
|
||||
invocation_context,
|
||||
llm_request,
|
||||
llm_response,
|
||||
model_response_event,
|
||||
):
|
||||
if (
|
||||
event.content
|
||||
and event.content.parts
|
||||
and event.content.parts[0].text
|
||||
and not event.partial
|
||||
):
|
||||
if not invocation_context.transcription_cache:
|
||||
invocation_context.transcription_cache = []
|
||||
invocation_context.transcription_cache.append(
|
||||
TranscriptionEntry(role='model', data=event.content)
|
||||
)
|
||||
yield event
|
||||
# Give opportunity for other tasks to run.
|
||||
await asyncio.sleep(0)
|
||||
except ConnectionClosedOK:
|
||||
pass
|
||||
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Runs the flow."""
|
||||
while True:
|
||||
last_event = None
|
||||
async for event in self._run_one_step_async(invocation_context):
|
||||
last_event = event
|
||||
yield event
|
||||
if not last_event or last_event.is_final_response():
|
||||
break
|
||||
|
||||
async def _run_one_step_async(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""One step means one LLM call."""
|
||||
llm_request = LlmRequest()
|
||||
|
||||
# Preprocess before calling the LLM.
|
||||
async for event in self._preprocess_async(invocation_context, llm_request):
|
||||
yield event
|
||||
if invocation_context.end_invocation:
|
||||
return
|
||||
|
||||
# Calls the LLM.
|
||||
model_response_event = Event(
|
||||
id=Event.new_id(),
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=invocation_context.agent.name,
|
||||
branch=invocation_context.branch,
|
||||
)
|
||||
async for llm_response in self._call_llm_async(
|
||||
invocation_context, llm_request, model_response_event
|
||||
):
|
||||
# Postprocess after calling the LLM.
|
||||
async for event in self._postprocess_async(
|
||||
invocation_context, llm_request, llm_response, model_response_event
|
||||
):
|
||||
yield event
|
||||
|
||||
async def _preprocess_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
# Runs processors.
|
||||
for processor in self.request_processors:
|
||||
async for event in processor.run_async(invocation_context, llm_request):
|
||||
yield event
|
||||
|
||||
# Run processors for tools.
|
||||
for tool in agent.canonical_tools:
|
||||
tool_context = ToolContext(invocation_context)
|
||||
await tool.process_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
async def _postprocess_async(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
llm_request: LlmRequest,
|
||||
llm_response: LlmResponse,
|
||||
model_response_event: Event,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Postprocess after calling the LLM.
|
||||
|
||||
Args:
|
||||
invocation_context: The invocation context.
|
||||
llm_request: The original LLM request.
|
||||
llm_response: The LLM response from the LLM call.
|
||||
model_response_event: A mutable event for the LLM response.
|
||||
|
||||
Yields:
|
||||
A generator of events.
|
||||
"""
|
||||
|
||||
# Runs processors.
|
||||
async for event in self._postprocess_run_processors_async(
|
||||
invocation_context, llm_response
|
||||
):
|
||||
yield event
|
||||
|
||||
# Skip the model response event if there is no content and no error code.
|
||||
# This is needed for the code executor to trigger another loop.
|
||||
if (
|
||||
not llm_response.content
|
||||
and not llm_response.error_code
|
||||
and not llm_response.interrupted
|
||||
):
|
||||
return
|
||||
|
||||
# Builds the event.
|
||||
model_response_event = self._finalize_model_response_event(
|
||||
llm_request, llm_response, model_response_event
|
||||
)
|
||||
yield model_response_event
|
||||
|
||||
# Handles function calls.
|
||||
if model_response_event.get_function_calls():
|
||||
async for event in self._postprocess_handle_function_calls_async(
|
||||
invocation_context, model_response_event, llm_request
|
||||
):
|
||||
yield event
|
||||
|
||||
async def _postprocess_live(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
llm_request: LlmRequest,
|
||||
llm_response: LlmResponse,
|
||||
model_response_event: Event,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Postprocess after calling the LLM asynchronously.
|
||||
|
||||
Args:
|
||||
invocation_context: The invocation context.
|
||||
llm_request: The original LLM request.
|
||||
llm_response: The LLM response from the LLM call.
|
||||
model_response_event: A mutable event for the LLM response.
|
||||
|
||||
Yields:
|
||||
A generator of events.
|
||||
"""
|
||||
|
||||
# Runs processors.
|
||||
async for event in self._postprocess_run_processors_async(
|
||||
invocation_context, llm_response
|
||||
):
|
||||
yield event
|
||||
|
||||
# Skip the model response event if there is no content and no error code.
|
||||
# This is needed for the code executor to trigger another loop.
|
||||
# But don't skip control events like turn_complete.
|
||||
if (
|
||||
not llm_response.content
|
||||
and not llm_response.error_code
|
||||
and not llm_response.interrupted
|
||||
and not llm_response.turn_complete
|
||||
):
|
||||
return
|
||||
|
||||
# Builds the event.
|
||||
model_response_event = self._finalize_model_response_event(
|
||||
llm_request, llm_response, model_response_event
|
||||
)
|
||||
yield model_response_event
|
||||
|
||||
# Handles function calls.
|
||||
if model_response_event.get_function_calls():
|
||||
function_response_event = await functions.handle_function_calls_live(
|
||||
invocation_context, model_response_event, llm_request.tools_dict
|
||||
)
|
||||
yield function_response_event
|
||||
|
||||
transfer_to_agent = function_response_event.actions.transfer_to_agent
|
||||
if transfer_to_agent:
|
||||
agent_to_run = self._get_agent_to_run(
|
||||
invocation_context, transfer_to_agent
|
||||
)
|
||||
async for item in agent_to_run.run_live(invocation_context):
|
||||
yield item
|
||||
|
||||
async def _postprocess_run_processors_async(
|
||||
self, invocation_context: InvocationContext, llm_response: LlmResponse
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
for processor in self.response_processors:
|
||||
async for event in processor.run_async(invocation_context, llm_response):
|
||||
yield event
|
||||
|
||||
async def _postprocess_handle_function_calls_async(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
function_call_event: Event,
|
||||
llm_request: LlmRequest,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
if function_response_event := await functions.handle_function_calls_async(
|
||||
invocation_context, function_call_event, llm_request.tools_dict
|
||||
):
|
||||
auth_event = functions.generate_auth_event(
|
||||
invocation_context, function_response_event
|
||||
)
|
||||
if auth_event:
|
||||
yield auth_event
|
||||
|
||||
yield function_response_event
|
||||
transfer_to_agent = function_response_event.actions.transfer_to_agent
|
||||
if transfer_to_agent:
|
||||
agent_to_run = self._get_agent_to_run(
|
||||
invocation_context, transfer_to_agent
|
||||
)
|
||||
async for event in agent_to_run.run_async(invocation_context):
|
||||
yield event
|
||||
|
||||
def _get_agent_to_run(
|
||||
self, invocation_context: InvocationContext, transfer_to_agent
|
||||
) -> BaseAgent:
|
||||
root_agent = invocation_context.agent.root_agent
|
||||
agent_to_run = root_agent.find_agent(transfer_to_agent)
|
||||
if not agent_to_run:
|
||||
raise ValueError(
|
||||
f'Agent {transfer_to_agent} not found in the agent tree.'
|
||||
)
|
||||
return agent_to_run
|
||||
|
||||
async def _call_llm_async(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
llm_request: LlmRequest,
|
||||
model_response_event: Event,
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
# Runs before_model_callback if it exists.
|
||||
if response := self._handle_before_model_callback(
|
||||
invocation_context, llm_request, model_response_event
|
||||
):
|
||||
yield response
|
||||
return
|
||||
|
||||
# Calls the LLM.
|
||||
llm = self.__get_llm(invocation_context)
|
||||
with tracer.start_as_current_span('call_llm'):
|
||||
if invocation_context.run_config.support_cfc:
|
||||
invocation_context.live_request_queue = LiveRequestQueue()
|
||||
async for llm_response in self.run_live(invocation_context):
|
||||
# Runs after_model_callback if it exists.
|
||||
if altered_llm_response := self._handle_after_model_callback(
|
||||
invocation_context, llm_response, model_response_event
|
||||
):
|
||||
llm_response = altered_llm_response
|
||||
# only yield partial response in SSE streaming mode
|
||||
if (
|
||||
invocation_context.run_config.streaming_mode == StreamingMode.SSE
|
||||
or not llm_response.partial
|
||||
):
|
||||
yield llm_response
|
||||
if llm_response.turn_complete:
|
||||
invocation_context.live_request_queue.close()
|
||||
else:
|
||||
# Check if we can make this llm call or not. If the current call pushes
|
||||
# the counter beyond the max set value, then the execution is stopped
|
||||
# right here, and exception is thrown.
|
||||
invocation_context.increment_llm_call_count()
|
||||
async for llm_response in llm.generate_content_async(
|
||||
llm_request,
|
||||
stream=invocation_context.run_config.streaming_mode
|
||||
== StreamingMode.SSE,
|
||||
):
|
||||
trace_call_llm(
|
||||
invocation_context,
|
||||
model_response_event.id,
|
||||
llm_request,
|
||||
llm_response,
|
||||
)
|
||||
# Runs after_model_callback if it exists.
|
||||
if altered_llm_response := self._handle_after_model_callback(
|
||||
invocation_context, llm_response, model_response_event
|
||||
):
|
||||
llm_response = altered_llm_response
|
||||
|
||||
yield llm_response
|
||||
|
||||
def _handle_before_model_callback(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
llm_request: LlmRequest,
|
||||
model_response_event: Event,
|
||||
) -> Optional[LlmResponse]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
if not agent.before_model_callback:
|
||||
return
|
||||
|
||||
callback_context = CallbackContext(
|
||||
invocation_context, event_actions=model_response_event.actions
|
||||
)
|
||||
return agent.before_model_callback(
|
||||
callback_context=callback_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
def _handle_after_model_callback(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
llm_response: LlmResponse,
|
||||
model_response_event: Event,
|
||||
) -> Optional[LlmResponse]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
if not agent.after_model_callback:
|
||||
return
|
||||
|
||||
callback_context = CallbackContext(
|
||||
invocation_context, event_actions=model_response_event.actions
|
||||
)
|
||||
return agent.after_model_callback(
|
||||
callback_context=callback_context, llm_response=llm_response
|
||||
)
|
||||
|
||||
def _finalize_model_response_event(
|
||||
self,
|
||||
llm_request: LlmRequest,
|
||||
llm_response: LlmResponse,
|
||||
model_response_event: Event,
|
||||
) -> Event:
|
||||
model_response_event = Event.model_validate({
|
||||
**model_response_event.model_dump(exclude_none=True),
|
||||
**llm_response.model_dump(exclude_none=True),
|
||||
})
|
||||
|
||||
if model_response_event.content:
|
||||
function_calls = model_response_event.get_function_calls()
|
||||
if function_calls:
|
||||
functions.populate_client_function_call_id(model_response_event)
|
||||
model_response_event.long_running_tool_ids = (
|
||||
functions.get_long_running_function_calls(
|
||||
function_calls, llm_request.tools_dict
|
||||
)
|
||||
)
|
||||
|
||||
return model_response_event
|
||||
|
||||
def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
return cast(LlmAgent, invocation_context.agent).canonical_model
|
||||
72
src/google/adk/flows/llm_flows/basic.py
Normal file
72
src/google/adk/flows/llm_flows/basic.py
Normal file
@ -0,0 +1,72 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Handles basic information to build the LLM request."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...events.event import Event
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
|
||||
|
||||
class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
llm_request.model = (
|
||||
agent.canonical_model
|
||||
if isinstance(agent.canonical_model, str)
|
||||
else agent.canonical_model.model
|
||||
)
|
||||
llm_request.config = (
|
||||
agent.generate_content_config.model_copy(deep=True)
|
||||
if agent.generate_content_config
|
||||
else types.GenerateContentConfig()
|
||||
)
|
||||
if agent.output_schema:
|
||||
llm_request.set_output_schema(agent.output_schema)
|
||||
|
||||
llm_request.live_connect_config.response_modalities = (
|
||||
invocation_context.run_config.response_modalities
|
||||
)
|
||||
llm_request.live_connect_config.speech_config = (
|
||||
invocation_context.run_config.speech_config
|
||||
)
|
||||
llm_request.live_connect_config.output_audio_transcription = (
|
||||
invocation_context.run_config.output_audio_transcription
|
||||
)
|
||||
|
||||
# TODO: handle tool append here, instead of in BaseTool.process_llm_request.
|
||||
|
||||
return
|
||||
yield # Generator requires yield statement in function body.
|
||||
|
||||
|
||||
request_processor = _BasicLlmRequestProcessor()
|
||||
370
src/google/adk/flows/llm_flows/contents.py
Normal file
370
src/google/adk/flows/llm_flows/contents.py
Normal file
@ -0,0 +1,370 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...events.event import Event
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
from .functions import remove_client_function_call_id
|
||||
|
||||
|
||||
class _ContentLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Builds the contents for the LLM request."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
if agent.include_contents != 'none':
|
||||
llm_request.contents = _get_contents(
|
||||
invocation_context.branch,
|
||||
invocation_context.session.events,
|
||||
agent.name,
|
||||
)
|
||||
|
||||
# Maintain async generator behavior
|
||||
if False: # Ensures it behaves as a generator
|
||||
yield # This is a no-op but maintains generator structure
|
||||
|
||||
|
||||
request_processor = _ContentLlmRequestProcessor()
|
||||
|
||||
|
||||
def _rearrange_events_for_async_function_responses_in_history(
|
||||
events: list[Event],
|
||||
) -> list[Event]:
|
||||
"""Rearrange the async function_response events in the history."""
|
||||
|
||||
function_call_id_to_response_events_index: dict[str, list[Event]] = {}
|
||||
for i, event in enumerate(events):
|
||||
function_responses = event.get_function_responses()
|
||||
if function_responses:
|
||||
for function_response in function_responses:
|
||||
function_call_id = function_response.id
|
||||
function_call_id_to_response_events_index[function_call_id] = i
|
||||
|
||||
result_events: list[Event] = []
|
||||
for event in events:
|
||||
if event.get_function_responses():
|
||||
# function_response should be handled together with function_call below.
|
||||
continue
|
||||
elif event.get_function_calls():
|
||||
|
||||
function_response_events_indices = set()
|
||||
for function_call in event.get_function_calls():
|
||||
function_call_id = function_call.id
|
||||
if function_call_id in function_call_id_to_response_events_index:
|
||||
function_response_events_indices.add(
|
||||
function_call_id_to_response_events_index[function_call_id]
|
||||
)
|
||||
result_events.append(event)
|
||||
if not function_response_events_indices:
|
||||
continue
|
||||
if len(function_response_events_indices) == 1:
|
||||
result_events.append(
|
||||
events[next(iter(function_response_events_indices))]
|
||||
)
|
||||
else: # Merge all async function_response as one response event
|
||||
result_events.append(
|
||||
_merge_function_response_events(
|
||||
[events[i] for i in sorted(function_response_events_indices)]
|
||||
)
|
||||
)
|
||||
continue
|
||||
else:
|
||||
result_events.append(event)
|
||||
|
||||
return result_events
|
||||
|
||||
|
||||
def _rearrange_events_for_latest_function_response(
|
||||
events: list[Event],
|
||||
) -> list[Event]:
|
||||
"""Rearrange the events for the latest function_response.
|
||||
|
||||
If the latest function_response is for an async function_call, all events
|
||||
bewteen the initial function_call and the latest function_response will be
|
||||
removed.
|
||||
|
||||
Args:
|
||||
events: A list of events.
|
||||
|
||||
Returns:
|
||||
A list of events with the latest function_response rearranged.
|
||||
"""
|
||||
if not events:
|
||||
return events
|
||||
|
||||
function_responses = events[-1].get_function_responses()
|
||||
if not function_responses:
|
||||
# No need to process, since the latest event is not fuction_response.
|
||||
return events
|
||||
|
||||
function_responses_ids = set()
|
||||
for function_response in function_responses:
|
||||
function_responses_ids.add(function_response.id)
|
||||
|
||||
function_calls = events[-2].get_function_calls()
|
||||
|
||||
if function_calls:
|
||||
for function_call in function_calls:
|
||||
# The latest function_response is already matched
|
||||
if function_call.id in function_responses_ids:
|
||||
return events
|
||||
|
||||
function_call_event_idx = -1
|
||||
# look for corresponding function call event reversely
|
||||
for idx in range(len(events) - 2, -1, -1):
|
||||
event = events[idx]
|
||||
function_calls = event.get_function_calls()
|
||||
if function_calls:
|
||||
for function_call in function_calls:
|
||||
if function_call.id in function_responses_ids:
|
||||
function_call_event_idx = idx
|
||||
break
|
||||
if function_call_event_idx != -1:
|
||||
# in case the last response event only have part of the responses
|
||||
# for the function calls in the function call event
|
||||
for function_call in function_calls:
|
||||
function_responses_ids.add(function_call.id)
|
||||
break
|
||||
|
||||
if function_call_event_idx == -1:
|
||||
raise ValueError(
|
||||
'No function call event found for function responses ids:'
|
||||
f' {function_responses_ids}'
|
||||
)
|
||||
|
||||
# collect all function response between last function response event
|
||||
# and function call event
|
||||
|
||||
function_response_events: list[Event] = []
|
||||
for idx in range(function_call_event_idx + 1, len(events) - 1):
|
||||
event = events[idx]
|
||||
function_responses = event.get_function_responses()
|
||||
if (
|
||||
function_responses
|
||||
and function_responses[0].id in function_responses_ids
|
||||
):
|
||||
function_response_events.append(event)
|
||||
function_response_events.append(events[-1])
|
||||
|
||||
result_events = events[: function_call_event_idx + 1]
|
||||
result_events.append(
|
||||
_merge_function_response_events(function_response_events)
|
||||
)
|
||||
|
||||
return result_events
|
||||
|
||||
|
||||
def _get_contents(
|
||||
current_branch: Optional[str], events: list[Event], agent_name: str = ''
|
||||
) -> list[types.Content]:
|
||||
"""Get the contents for the LLM request.
|
||||
|
||||
Args:
|
||||
current_branch: The current branch of the agent.
|
||||
events: A list of events.
|
||||
agent_name: The name of the agent.
|
||||
|
||||
Returns:
|
||||
A list of contents.
|
||||
"""
|
||||
filtered_events = []
|
||||
# Parse the events, leaving the contents and the function calls and
|
||||
# responses from the current agent.
|
||||
for event in events:
|
||||
if not event.content or not event.content.role:
|
||||
# Skip events without content, or generated neither by user nor by model.
|
||||
# E.g. events purely for mutating session states.
|
||||
continue
|
||||
if not _is_event_belongs_to_branch(current_branch, event):
|
||||
# Skip events not belong to current branch.
|
||||
continue
|
||||
|
||||
filtered_events.append(
|
||||
_convert_foreign_event(event)
|
||||
if _is_other_agent_reply(agent_name, event)
|
||||
else event
|
||||
)
|
||||
|
||||
result_events = _rearrange_events_for_latest_function_response(
|
||||
filtered_events
|
||||
)
|
||||
result_events = _rearrange_events_for_async_function_responses_in_history(
|
||||
result_events
|
||||
)
|
||||
contents = []
|
||||
for event in result_events:
|
||||
content = copy.deepcopy(event.content)
|
||||
remove_client_function_call_id(content)
|
||||
contents.append(content)
|
||||
return contents
|
||||
|
||||
|
||||
def _is_other_agent_reply(current_agent_name: str, event: Event) -> bool:
|
||||
"""Whether the event is a reply from another agent."""
|
||||
return bool(
|
||||
current_agent_name
|
||||
and event.author != current_agent_name
|
||||
and event.author != 'user'
|
||||
)
|
||||
|
||||
|
||||
def _convert_foreign_event(event: Event) -> Event:
|
||||
"""Converts an event authored by another agent as a user-content event.
|
||||
|
||||
This is to provide another agent's output as context to the current agent, so
|
||||
that current agent can continue to respond, such as summarizing previous
|
||||
agent's reply, etc.
|
||||
|
||||
Args:
|
||||
event: The event to convert.
|
||||
|
||||
Returns:
|
||||
The converted event.
|
||||
|
||||
"""
|
||||
if not event.content or not event.content.parts:
|
||||
return event
|
||||
|
||||
content = types.Content()
|
||||
content.role = 'user'
|
||||
content.parts = [types.Part(text='For context:')]
|
||||
for part in event.content.parts:
|
||||
if part.text:
|
||||
content.parts.append(
|
||||
types.Part(text=f'[{event.author}] said: {part.text}')
|
||||
)
|
||||
elif part.function_call:
|
||||
content.parts.append(
|
||||
types.Part(
|
||||
text=(
|
||||
f'[{event.author}] called tool `{part.function_call.name}`'
|
||||
f' with parameters: {part.function_call.args}'
|
||||
)
|
||||
)
|
||||
)
|
||||
elif part.function_response:
|
||||
# Otherwise, create a new text part.
|
||||
content.parts.append(
|
||||
types.Part(
|
||||
text=(
|
||||
f'[{event.author}] `{part.function_response.name}` tool'
|
||||
f' returned result: {part.function_response.response}'
|
||||
)
|
||||
)
|
||||
)
|
||||
# Fallback to the original part for non-text and non-functionCall parts.
|
||||
else:
|
||||
content.parts.append(part)
|
||||
|
||||
return Event(
|
||||
timestamp=event.timestamp,
|
||||
author='user',
|
||||
content=content,
|
||||
branch=event.branch,
|
||||
)
|
||||
|
||||
|
||||
def _merge_function_response_events(
|
||||
function_response_events: list[Event],
|
||||
) -> Event:
|
||||
"""Merges a list of function_response events into one event.
|
||||
|
||||
The key goal is to ensure:
|
||||
1. function_call and function_response are always of the same number.
|
||||
2. The function_call and function_response are consecutively in the content.
|
||||
|
||||
Args:
|
||||
function_response_events: A list of function_response events.
|
||||
NOTE: function_response_events must fulfill these requirements: 1. The
|
||||
list is in increasing order of timestamp; 2. the first event is the
|
||||
initial function_reponse event; 3. all later events should contain at
|
||||
least one function_response part that related to the function_call
|
||||
event. (Note, 3. may not be true when aync function return some
|
||||
intermediate response, there could also be some intermediate model
|
||||
response event without any function_response and such event will be
|
||||
ignored.)
|
||||
Caveat: This implementation doesn't support when a parallel function_call
|
||||
event contains async function_call of the same name.
|
||||
|
||||
Returns:
|
||||
A merged event, that is
|
||||
1. All later function_response will replace function_response part in
|
||||
the initial function_response event.
|
||||
2. All non-function_response parts will be appended to the part list of
|
||||
the initial function_response event.
|
||||
"""
|
||||
if not function_response_events:
|
||||
raise ValueError('At least one function_response event is required.')
|
||||
|
||||
merged_event = function_response_events[0].model_copy(deep=True)
|
||||
parts_in_merged_event: list[types.Part] = merged_event.content.parts # type: ignore
|
||||
|
||||
if not parts_in_merged_event:
|
||||
raise ValueError('There should be at least one function_response part.')
|
||||
|
||||
part_indices_in_merged_event: dict[str, int] = {}
|
||||
for idx, part in enumerate(parts_in_merged_event):
|
||||
if part.function_response:
|
||||
function_call_id: str = part.function_response.id # type: ignore
|
||||
part_indices_in_merged_event[function_call_id] = idx
|
||||
|
||||
for event in function_response_events[1:]:
|
||||
if not event.content.parts:
|
||||
raise ValueError('There should be at least one function_response part.')
|
||||
|
||||
for part in event.content.parts:
|
||||
if part.function_response:
|
||||
function_call_id: str = part.function_response.id # type: ignore
|
||||
if function_call_id in part_indices_in_merged_event:
|
||||
parts_in_merged_event[
|
||||
part_indices_in_merged_event[function_call_id]
|
||||
] = part
|
||||
else:
|
||||
parts_in_merged_event.append(part)
|
||||
part_indices_in_merged_event[function_call_id] = (
|
||||
len(parts_in_merged_event) - 1
|
||||
)
|
||||
|
||||
else:
|
||||
parts_in_merged_event.append(part)
|
||||
|
||||
return merged_event
|
||||
|
||||
|
||||
def _is_event_belongs_to_branch(
|
||||
invocation_branch: Optional[str], event: Event
|
||||
) -> bool:
|
||||
"""Event belongs to a branch, when event.branch is prefix of the invocation branch."""
|
||||
if not invocation_branch or not event.branch:
|
||||
return True
|
||||
return invocation_branch.startswith(event.branch)
|
||||
463
src/google/adk/flows/llm_flows/functions.py
Normal file
463
src/google/adk/flows/llm_flows/functions.py
Normal file
@ -0,0 +1,463 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Handles function callings for LLM flow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from google.genai import types
|
||||
|
||||
from ...agents.active_streaming_tool import ActiveStreamingTool
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...auth.auth_tool import AuthToolArguments
|
||||
from ...events.event import Event
|
||||
from ...events.event_actions import EventActions
|
||||
from ...telemetry import tracer
|
||||
from ...tools.base_tool import BaseTool
|
||||
from ...tools.tool_context import ToolContext
|
||||
|
||||
AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
|
||||
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_client_function_call_id() -> str:
|
||||
return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}'
|
||||
|
||||
|
||||
def populate_client_function_call_id(model_response_event: Event) -> None:
|
||||
if not model_response_event.get_function_calls():
|
||||
return
|
||||
for function_call in model_response_event.get_function_calls():
|
||||
if not function_call.id:
|
||||
function_call.id = generate_client_function_call_id()
|
||||
|
||||
|
||||
def remove_client_function_call_id(content: types.Content) -> None:
|
||||
if content and content.parts:
|
||||
for part in content.parts:
|
||||
if (
|
||||
part.function_call
|
||||
and part.function_call.id
|
||||
and part.function_call.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
|
||||
):
|
||||
part.function_call.id = None
|
||||
if (
|
||||
part.function_response
|
||||
and part.function_response.id
|
||||
and part.function_response.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
|
||||
):
|
||||
part.function_response.id = None
|
||||
|
||||
|
||||
def get_long_running_function_calls(
|
||||
function_calls: list[types.FunctionCall],
|
||||
tools_dict: dict[str, BaseTool],
|
||||
) -> set[str]:
|
||||
long_running_tool_ids = set()
|
||||
for function_call in function_calls:
|
||||
if (
|
||||
function_call.name in tools_dict
|
||||
and tools_dict[function_call.name].is_long_running
|
||||
):
|
||||
long_running_tool_ids.add(function_call.id)
|
||||
|
||||
return long_running_tool_ids
|
||||
|
||||
|
||||
def generate_auth_event(
|
||||
invocation_context: InvocationContext,
|
||||
function_response_event: Event,
|
||||
) -> Optional[Event]:
|
||||
if not function_response_event.actions.requested_auth_configs:
|
||||
return None
|
||||
parts = []
|
||||
long_running_tool_ids = set()
|
||||
for (
|
||||
function_call_id,
|
||||
auth_config,
|
||||
) in function_response_event.actions.requested_auth_configs.items():
|
||||
|
||||
request_euc_function_call = types.FunctionCall(
|
||||
name=REQUEST_EUC_FUNCTION_CALL_NAME,
|
||||
args=AuthToolArguments(
|
||||
function_call_id=function_call_id,
|
||||
auth_config=auth_config,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
request_euc_function_call.id = generate_client_function_call_id()
|
||||
long_running_tool_ids.add(request_euc_function_call.id)
|
||||
parts.append(types.Part(function_call=request_euc_function_call))
|
||||
|
||||
return Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=invocation_context.agent.name,
|
||||
branch=invocation_context.branch,
|
||||
content=types.Content(parts=parts),
|
||||
long_running_tool_ids=long_running_tool_ids,
|
||||
)
|
||||
|
||||
|
||||
async def handle_function_calls_async(
|
||||
invocation_context: InvocationContext,
|
||||
function_call_event: Event,
|
||||
tools_dict: dict[str, BaseTool],
|
||||
filters: Optional[set[str]] = None,
|
||||
) -> Optional[Event]:
|
||||
"""Calls the functions and returns the function response event."""
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
function_calls = function_call_event.get_function_calls()
|
||||
|
||||
function_response_events: list[Event] = []
|
||||
for function_call in function_calls:
|
||||
if filters and function_call.id not in filters:
|
||||
continue
|
||||
tool, tool_context = _get_tool_and_context(
|
||||
invocation_context,
|
||||
function_call_event,
|
||||
function_call,
|
||||
tools_dict,
|
||||
)
|
||||
# do not use "args" as the variable name, because it is a reserved keyword
|
||||
# in python debugger.
|
||||
function_args = function_call.args or {}
|
||||
function_response = None
|
||||
# Calls the tool if before_tool_callback does not exist or returns None.
|
||||
if agent.before_tool_callback:
|
||||
function_response = agent.before_tool_callback(
|
||||
tool=tool, args=function_args, tool_context=tool_context
|
||||
)
|
||||
|
||||
if not function_response:
|
||||
function_response = await __call_tool_async(
|
||||
tool, args=function_args, tool_context=tool_context
|
||||
)
|
||||
|
||||
# Calls after_tool_callback if it exists.
|
||||
if agent.after_tool_callback:
|
||||
new_response = agent.after_tool_callback(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
tool_context=tool_context,
|
||||
tool_response=function_response,
|
||||
)
|
||||
if new_response:
|
||||
function_response = new_response
|
||||
|
||||
if tool.is_long_running:
|
||||
# Allow long running function to return None to not provide function response.
|
||||
if not function_response:
|
||||
continue
|
||||
|
||||
# Builds the function response event.
|
||||
function_response_event = __build_response_event(
|
||||
tool, function_response, tool_context, invocation_context
|
||||
)
|
||||
function_response_events.append(function_response_event)
|
||||
|
||||
if not function_response_events:
|
||||
return None
|
||||
merged_event = merge_parallel_function_response_events(
|
||||
function_response_events
|
||||
)
|
||||
return merged_event
|
||||
|
||||
|
||||
async def handle_function_calls_live(
|
||||
invocation_context: InvocationContext,
|
||||
function_call_event: Event,
|
||||
tools_dict: dict[str, BaseTool],
|
||||
) -> Event:
|
||||
"""Calls the functions and returns the function response event."""
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = cast(LlmAgent, invocation_context.agent)
|
||||
function_calls = function_call_event.get_function_calls()
|
||||
|
||||
function_response_events: list[Event] = []
|
||||
for function_call in function_calls:
|
||||
tool, tool_context = _get_tool_and_context(
|
||||
invocation_context, function_call_event, function_call, tools_dict
|
||||
)
|
||||
# do not use "args" as the variable name, because it is a reserved keyword
|
||||
# in python debugger.
|
||||
function_args = function_call.args or {}
|
||||
function_response = None
|
||||
# Calls the tool if before_tool_callback does not exist or returns None.
|
||||
if agent.before_tool_callback:
|
||||
function_response = agent.before_tool_callback(
|
||||
tool, function_args, tool_context
|
||||
)
|
||||
|
||||
if not function_response:
|
||||
function_response = await _process_function_live_helper(
|
||||
tool, tool_context, function_call, function_args, invocation_context
|
||||
)
|
||||
|
||||
# Calls after_tool_callback if it exists.
|
||||
if agent.after_tool_callback:
|
||||
new_response = agent.after_tool_callback(
|
||||
tool,
|
||||
function_args,
|
||||
tool_context,
|
||||
function_response,
|
||||
)
|
||||
if new_response:
|
||||
function_response = new_response
|
||||
|
||||
if tool.is_long_running:
|
||||
# Allow async function to return None to not provide function response.
|
||||
if not function_response:
|
||||
continue
|
||||
|
||||
# Builds the function response event.
|
||||
function_response_event = __build_response_event(
|
||||
tool, function_response, tool_context, invocation_context
|
||||
)
|
||||
function_response_events.append(function_response_event)
|
||||
|
||||
if not function_response_events:
|
||||
return None
|
||||
merged_event = merge_parallel_function_response_events(
|
||||
function_response_events
|
||||
)
|
||||
return merged_event
|
||||
|
||||
|
||||
async def _process_function_live_helper(
|
||||
tool, tool_context, function_call, function_args, invocation_context
|
||||
):
|
||||
function_response = None
|
||||
# Check if this is a stop_streaming function call
|
||||
if (
|
||||
function_call.name == 'stop_streaming'
|
||||
and 'function_name' in function_args
|
||||
):
|
||||
function_name = function_args['function_name']
|
||||
active_tasks = invocation_context.active_streaming_tools
|
||||
if (
|
||||
function_name in active_tasks
|
||||
and active_tasks[function_name].task
|
||||
and not active_tasks[function_name].task.done()
|
||||
):
|
||||
task = active_tasks[function_name].task
|
||||
task.cancel()
|
||||
try:
|
||||
# Wait for the task to be cancelled
|
||||
await asyncio.wait_for(task, timeout=1.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
# Log the specific condition
|
||||
if task.cancelled():
|
||||
logging.info(f'Task {function_name} was cancelled successfully')
|
||||
elif task.done():
|
||||
logging.info(f'Task {function_name} completed during cancellation')
|
||||
else:
|
||||
logging.warning(
|
||||
f'Task {function_name} might still be running after'
|
||||
' cancellation timeout'
|
||||
)
|
||||
function_response = {
|
||||
'status': f'The task is not cancelled yet for {function_name}.'
|
||||
}
|
||||
if not function_response:
|
||||
# Clean up the reference
|
||||
active_tasks[function_name].task = None
|
||||
|
||||
function_response = {
|
||||
'status': f'Successfully stopped streaming function {function_name}'
|
||||
}
|
||||
else:
|
||||
function_response = {
|
||||
'status': f'No active streaming function named {function_name} found'
|
||||
}
|
||||
elif inspect.isasyncgenfunction(tool.func):
|
||||
print('is async')
|
||||
|
||||
# for streaming tool use case
|
||||
# we require the function to be a async generator function
|
||||
async def run_tool_and_update_queue(tool, function_args, tool_context):
|
||||
try:
|
||||
async for result in __call_tool_live(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
tool_context=tool_context,
|
||||
invocation_context=invocation_context,
|
||||
):
|
||||
updated_content = types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part.from_text(
|
||||
text=f'Function {tool.name} returned: {result}'
|
||||
)
|
||||
],
|
||||
)
|
||||
invocation_context.live_request_queue.send_content(updated_content)
|
||||
except asyncio.CancelledError:
|
||||
raise # Re-raise to properly propagate the cancellation
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_tool_and_update_queue(tool, function_args, tool_context)
|
||||
)
|
||||
if invocation_context.active_streaming_tools is None:
|
||||
invocation_context.active_streaming_tools = {}
|
||||
if tool.name in invocation_context.active_streaming_tools:
|
||||
invocation_context.active_streaming_tools[tool.name].task = task
|
||||
else:
|
||||
invocation_context.active_streaming_tools[tool.name] = (
|
||||
ActiveStreamingTool(task=task)
|
||||
)
|
||||
# Immediately return a pending response.
|
||||
# This is required by current live model.
|
||||
function_response = {
|
||||
'status': (
|
||||
'The function is running asynchronously and the results are'
|
||||
' pending.'
|
||||
)
|
||||
}
|
||||
else:
|
||||
function_response = await __call_tool_async(
|
||||
tool, args=function_args, tool_context=tool_context
|
||||
)
|
||||
return function_response
|
||||
|
||||
|
||||
def _get_tool_and_context(
|
||||
invocation_context: InvocationContext,
|
||||
function_call_event: Event,
|
||||
function_call: types.FunctionCall,
|
||||
tools_dict: dict[str, BaseTool],
|
||||
):
|
||||
if function_call.name not in tools_dict:
|
||||
raise ValueError(
|
||||
f'Function {function_call.name} is not found in the tools_dict.'
|
||||
)
|
||||
|
||||
tool_context = ToolContext(
|
||||
invocation_context=invocation_context,
|
||||
function_call_id=function_call.id,
|
||||
)
|
||||
|
||||
tool = tools_dict[function_call.name]
|
||||
|
||||
return (tool, tool_context)
|
||||
|
||||
|
||||
async def __call_tool_live(
|
||||
tool: BaseTool,
|
||||
args: dict[str, object],
|
||||
tool_context: ToolContext,
|
||||
invocation_context: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Calls the tool asynchronously (awaiting the coroutine)."""
|
||||
with tracer.start_as_current_span(f'call_tool [{tool.name}]'):
|
||||
async for item in tool._call_live(
|
||||
args=args,
|
||||
tool_context=tool_context,
|
||||
invocation_context=invocation_context,
|
||||
):
|
||||
yield item
|
||||
|
||||
|
||||
async def __call_tool_async(
|
||||
tool: BaseTool,
|
||||
args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
) -> Any:
|
||||
"""Calls the tool."""
|
||||
with tracer.start_as_current_span(f'call_tool [{tool.name}]'):
|
||||
return await tool.run_async(args=args, tool_context=tool_context)
|
||||
|
||||
|
||||
def __build_response_event(
|
||||
tool: BaseTool,
|
||||
function_result: dict[str, object],
|
||||
tool_context: ToolContext,
|
||||
invocation_context: InvocationContext,
|
||||
) -> Event:
|
||||
# Specs requires the result to be a dict.
|
||||
if not isinstance(function_result, dict):
|
||||
function_result = {'result': function_result}
|
||||
|
||||
part_function_response = types.Part.from_function_response(
|
||||
name=tool.name, response=function_result
|
||||
)
|
||||
part_function_response.function_response.id = tool_context.function_call_id
|
||||
|
||||
content = types.Content(
|
||||
role='user',
|
||||
parts=[part_function_response],
|
||||
)
|
||||
return Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=invocation_context.agent.name,
|
||||
content=content,
|
||||
actions=tool_context.actions,
|
||||
branch=invocation_context.branch,
|
||||
)
|
||||
|
||||
|
||||
def merge_parallel_function_response_events(
|
||||
function_response_events: list['Event'],
|
||||
) -> 'Event':
|
||||
if not function_response_events:
|
||||
raise ValueError('No function response events provided.')
|
||||
|
||||
if len(function_response_events) == 1:
|
||||
return function_response_events[0]
|
||||
merged_parts = []
|
||||
for event in function_response_events:
|
||||
if event.content:
|
||||
for part in event.content.parts or []:
|
||||
merged_parts.append(part)
|
||||
|
||||
# Use the first event as the "base" for common attributes
|
||||
base_event = function_response_events[0]
|
||||
|
||||
# Merge actions from all events
|
||||
|
||||
merged_actions = EventActions()
|
||||
merged_requested_auth_configs = {}
|
||||
for event in function_response_events:
|
||||
merged_requested_auth_configs.update(event.actions.requested_auth_configs)
|
||||
merged_actions = merged_actions.model_copy(
|
||||
update=event.actions.model_dump()
|
||||
)
|
||||
merged_actions.requested_auth_configs = merged_requested_auth_configs
|
||||
# Create the new merged event
|
||||
merged_event = Event(
|
||||
invocation_id=Event.new_id(),
|
||||
author=base_event.author,
|
||||
branch=base_event.branch,
|
||||
content=types.Content(role='user', parts=merged_parts),
|
||||
actions=merged_actions, # Optionally merge actions if required
|
||||
)
|
||||
|
||||
# Use the base_event as the timestamp
|
||||
merged_event.timestamp = base_event.timestamp
|
||||
return merged_event
|
||||
47
src/google/adk/flows/llm_flows/identity.py
Normal file
47
src/google/adk/flows/llm_flows/identity.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Gives the agent identity from the framework."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...events.event import Event
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
|
||||
|
||||
class _IdentityLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Gives the agent identity from the framework."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
agent = invocation_context.agent
|
||||
si = [f'You are an agent. Your internal name is "{agent.name}".']
|
||||
if agent.description:
|
||||
si.append(f' The description about you is "{agent.description}"')
|
||||
llm_request.append_instructions(si)
|
||||
|
||||
# Maintain async generator behavior
|
||||
if False: # Ensures it behaves as a generator
|
||||
yield # This is a no-op but maintains generator structure
|
||||
|
||||
|
||||
request_processor = _IdentityLlmRequestProcessor()
|
||||
137
src/google/adk/flows/llm_flows/instructions.py
Normal file
137
src/google/adk/flows/llm_flows/instructions.py
Normal file
@ -0,0 +1,137 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Handles instructions and global instructions for LLM flow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.readonly_context import ReadonlyContext
|
||||
from ...events.event import Event
|
||||
from ...sessions.state import State
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...models.llm_request import LlmRequest
|
||||
|
||||
|
||||
class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Handles instructions and global instructions for LLM flow."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.base_agent import BaseAgent
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
root_agent: BaseAgent = agent.root_agent
|
||||
|
||||
# Appends global instructions if set.
|
||||
if (
|
||||
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
|
||||
): # not emtpy str
|
||||
raw_si = root_agent.canonical_global_instruction(
|
||||
ReadonlyContext(invocation_context)
|
||||
)
|
||||
si = _populate_values(raw_si, invocation_context)
|
||||
llm_request.append_instructions([si])
|
||||
|
||||
# Appends agent instructions if set.
|
||||
if agent.instruction: # not emtpy str
|
||||
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
|
||||
si = _populate_values(raw_si, invocation_context)
|
||||
llm_request.append_instructions([si])
|
||||
|
||||
# Maintain async generator behavior
|
||||
if False: # Ensures it behaves as a generator
|
||||
yield # This is a no-op but maintains generator structure
|
||||
|
||||
|
||||
request_processor = _InstructionsLlmRequestProcessor()
|
||||
|
||||
|
||||
def _populate_values(
|
||||
instruction_template: str,
|
||||
context: InvocationContext,
|
||||
) -> str:
|
||||
"""Populates values in the instruction template, e.g. state, artifact, etc."""
|
||||
|
||||
def _replace_match(match) -> str:
|
||||
var_name = match.group().lstrip('{').rstrip('}').strip()
|
||||
optional = False
|
||||
if var_name.endswith('?'):
|
||||
optional = True
|
||||
var_name = var_name.removesuffix('?')
|
||||
if var_name.startswith('artifact.'):
|
||||
var_name = var_name.removeprefix('artifact.')
|
||||
if context.artifact_service is None:
|
||||
raise ValueError('Artifact service is not initialized.')
|
||||
artifact = context.artifact_service.load_artifact(
|
||||
app_name=context.session.app_name,
|
||||
user_id=context.session.user_id,
|
||||
session_id=context.session.id,
|
||||
filename=var_name,
|
||||
)
|
||||
if not var_name:
|
||||
raise KeyError(f'Artifact {var_name} not found.')
|
||||
return str(artifact)
|
||||
else:
|
||||
if not _is_valid_state_name(var_name):
|
||||
return match.group()
|
||||
if var_name in context.session.state:
|
||||
return str(context.session.state[var_name])
|
||||
else:
|
||||
if optional:
|
||||
return ''
|
||||
else:
|
||||
raise KeyError(f'Context variable not found: `{var_name}`.')
|
||||
|
||||
return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template)
|
||||
|
||||
|
||||
def _is_valid_state_name(var_name):
|
||||
"""Checks if the variable name is a valid state name.
|
||||
|
||||
Valid state is either:
|
||||
- Valid identifier
|
||||
- <Valid prefix>:<Valid identifier>
|
||||
All the others will just return as it is.
|
||||
|
||||
Args:
|
||||
var_name: The variable name to check.
|
||||
|
||||
Returns:
|
||||
True if the variable name is a valid state name, False otherwise.
|
||||
"""
|
||||
parts = var_name.split(':')
|
||||
if len(parts) == 1:
|
||||
return var_name.isidentifier()
|
||||
|
||||
if len(parts) == 2:
|
||||
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
|
||||
if (parts[0] + ':') in prefixes:
|
||||
return parts[1].isidentifier()
|
||||
return False
|
||||
57
src/google/adk/flows/llm_flows/single_flow.py
Normal file
57
src/google/adk/flows/llm_flows/single_flow.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Implementation of single flow."""
|
||||
|
||||
import logging
|
||||
|
||||
from ...auth import auth_preprocessor
|
||||
from . import _code_execution
|
||||
from . import _nl_planning
|
||||
from . import basic
|
||||
from . import contents
|
||||
from . import identity
|
||||
from . import instructions
|
||||
from .base_llm_flow import BaseLlmFlow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SingleFlow(BaseLlmFlow):
|
||||
"""SingleFlow is the LLM flows that handles tools calls.
|
||||
|
||||
A single flow only consider an agent itself and tools.
|
||||
No sub-agents are allowed for single flow.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.request_processors += [
|
||||
basic.request_processor,
|
||||
auth_preprocessor.request_processor,
|
||||
instructions.request_processor,
|
||||
identity.request_processor,
|
||||
contents.request_processor,
|
||||
# Some implementations of NL Planning mark planning contents as thoughts
|
||||
# in the post processor. Since these need to be unmarked, NL Planning
|
||||
# should be after contents.
|
||||
_nl_planning.request_processor,
|
||||
# Code execution should be after the contents as it mutates the contents
|
||||
# to optimize data files.
|
||||
_code_execution.request_processor,
|
||||
]
|
||||
self.response_processors += [
|
||||
_nl_planning.response_processor,
|
||||
_code_execution.response_processor,
|
||||
]
|
||||
35
src/google/adk/memory/__init__.py
Normal file
35
src/google/adk/memory/__init__.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from .base_memory_service import BaseMemoryService
|
||||
from .in_memory_memory_service import InMemoryMemoryService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
'BaseMemoryService',
|
||||
'InMemoryMemoryService',
|
||||
]
|
||||
|
||||
try:
|
||||
from .vertex_ai_rag_memory_service import VertexAiRagMemoryService
|
||||
|
||||
__all__.append('VertexAiRagMemoryService')
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
'The Vertex sdk is not installed. If you want to use the'
|
||||
' VertexAiRagMemoryService please install it. If not, you can ignore this'
|
||||
' warning.'
|
||||
)
|
||||
74
src/google/adk/memory/base_memory_service.py
Normal file
74
src/google/adk/memory/base_memory_service.py
Normal file
@ -0,0 +1,74 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from ..events.event import Event
|
||||
from ..sessions.session import Session
|
||||
|
||||
|
||||
class MemoryResult(BaseModel):
|
||||
"""Represents a single memory retrieval result.
|
||||
|
||||
Attributes:
|
||||
session_id: The session id associated with the memory.
|
||||
events: A list of events in the session.
|
||||
"""
|
||||
session_id: str
|
||||
events: list[Event]
|
||||
|
||||
|
||||
class SearchMemoryResponse(BaseModel):
|
||||
"""Represents the response from a memory search.
|
||||
|
||||
Attributes:
|
||||
memories: A list of memory results matching the search query.
|
||||
"""
|
||||
memories: list[MemoryResult] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BaseMemoryService(abc.ABC):
|
||||
"""Base class for memory services.
|
||||
|
||||
The service provides functionalities to ingest sessions into memory so that
|
||||
the memory can be used for user queries.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def add_session_to_memory(self, session: Session):
|
||||
"""Adds a session to the memory service.
|
||||
|
||||
A session may be added multiple times during its lifetime.
|
||||
|
||||
Args:
|
||||
session: The session to add.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def search_memory(
|
||||
self, *, app_name: str, user_id: str, query: str
|
||||
) -> SearchMemoryResponse:
|
||||
"""Searches for sessions that match the query.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The id of the user.
|
||||
query: The query to search for.
|
||||
|
||||
Returns:
|
||||
A SearchMemoryResponse containing the matching memories.
|
||||
"""
|
||||
62
src/google/adk/memory/in_memory_memory_service.py
Normal file
62
src/google/adk/memory/in_memory_memory_service.py
Normal file
@ -0,0 +1,62 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..events.event import Event
|
||||
from ..sessions.session import Session
|
||||
from .base_memory_service import BaseMemoryService
|
||||
from .base_memory_service import MemoryResult
|
||||
from .base_memory_service import SearchMemoryResponse
|
||||
|
||||
|
||||
class InMemoryMemoryService(BaseMemoryService):
|
||||
"""An in-memory memory service for prototyping purpose only.
|
||||
|
||||
Uses keyword matching instead of semantic search.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.session_events: dict[str, list[Event]] = {}
|
||||
"""keys are app_name/user_id/session_id"""
|
||||
|
||||
def add_session_to_memory(self, session: Session):
|
||||
key = f'{session.app_name}/{session.user_id}/{session.id}'
|
||||
self.session_events[key] = [
|
||||
event for event in session.events if event.content
|
||||
]
|
||||
|
||||
def search_memory(
|
||||
self, *, app_name: str, user_id: str, query: str
|
||||
) -> SearchMemoryResponse:
|
||||
"""Prototyping purpose only."""
|
||||
keywords = set(query.lower().split())
|
||||
response = SearchMemoryResponse()
|
||||
for key, events in self.session_events.items():
|
||||
if not key.startswith(f'{app_name}/{user_id}/'):
|
||||
continue
|
||||
matched_events = []
|
||||
for event in events:
|
||||
if not event.content or not event.content.parts:
|
||||
continue
|
||||
parts = event.content.parts
|
||||
text = '\n'.join([part.text for part in parts if part.text]).lower()
|
||||
for keyword in keywords:
|
||||
if keyword in text:
|
||||
matched_events.append(event)
|
||||
break
|
||||
if matched_events:
|
||||
session_id = key.split('/')[-1]
|
||||
response.memories.append(
|
||||
MemoryResult(session_id=session_id, events=matched_events)
|
||||
)
|
||||
return response
|
||||
177
src/google/adk/memory/vertex_ai_rag_memory_service.py
Normal file
177
src/google/adk/memory/vertex_ai_rag_memory_service.py
Normal file
@ -0,0 +1,177 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
from vertexai.preview import rag
|
||||
|
||||
from ..events.event import Event
|
||||
from ..sessions.session import Session
|
||||
from .base_memory_service import BaseMemoryService
|
||||
from .base_memory_service import MemoryResult
|
||||
from .base_memory_service import SearchMemoryResponse
|
||||
|
||||
|
||||
class VertexAiRagMemoryService(BaseMemoryService):
|
||||
"""A memory service that uses Vertex AI RAG for storage and retrieval."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rag_corpus: str = None,
|
||||
similarity_top_k: int = None,
|
||||
vector_distance_threshold: float = 10,
|
||||
):
|
||||
"""Initializes a VertexAiRagMemoryService.
|
||||
|
||||
Args:
|
||||
rag_corpus: The name of the Vertex AI RAG corpus to use. Format:
|
||||
``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}``
|
||||
or ``{rag_corpus_id}``
|
||||
similarity_top_k: The number of contexts to retrieve.
|
||||
vector_distance_threshold: Only returns contexts with vector distance
|
||||
smaller than the threshold..
|
||||
"""
|
||||
self.vertex_rag_store = types.VertexRagStore(
|
||||
rag_resources=[rag.RagResource(rag_corpus=rag_corpus)],
|
||||
similarity_top_k=similarity_top_k,
|
||||
vector_distance_threshold=vector_distance_threshold,
|
||||
)
|
||||
|
||||
@override
|
||||
def add_session_to_memory(self, session: Session):
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".txt"
|
||||
) as temp_file:
|
||||
|
||||
output_lines = []
|
||||
for event in session.events:
|
||||
if not event.content or not event.content.parts:
|
||||
continue
|
||||
text_parts = [
|
||||
part.text.replace("\n", " ")
|
||||
for part in event.content.parts
|
||||
if part.text
|
||||
]
|
||||
if text_parts:
|
||||
output_lines.append(
|
||||
json.dumps({
|
||||
"author": event.author,
|
||||
"timestamp": event.timestamp,
|
||||
"text": ".".join(text_parts),
|
||||
})
|
||||
)
|
||||
output_string = "\n".join(output_lines)
|
||||
temp_file.write(output_string)
|
||||
temp_file_path = temp_file.name
|
||||
for rag_resource in self.vertex_rag_store.rag_resources:
|
||||
rag.upload_file(
|
||||
corpus_name=rag_resource.rag_corpus,
|
||||
path=temp_file_path,
|
||||
# this is the temp workaround as upload file does not support
|
||||
# adding metadata, thus use display_name to store the session info.
|
||||
display_name=f"{session.app_name}.{session.user_id}.{session.id}",
|
||||
)
|
||||
|
||||
os.remove(temp_file_path)
|
||||
|
||||
@override
|
||||
def search_memory(
|
||||
self, *, app_name: str, user_id: str, query: str
|
||||
) -> SearchMemoryResponse:
|
||||
"""Searches for sessions that match the query using rag.retrieval_query."""
|
||||
response = rag.retrieval_query(
|
||||
text=query,
|
||||
rag_resources=self.vertex_rag_store.rag_resources,
|
||||
rag_corpora=self.vertex_rag_store.rag_corpora,
|
||||
similarity_top_k=self.vertex_rag_store.similarity_top_k,
|
||||
vector_distance_threshold=self.vertex_rag_store.vector_distance_threshold,
|
||||
)
|
||||
|
||||
memory_results = []
|
||||
session_events_map = OrderedDict()
|
||||
for context in response.contexts.contexts:
|
||||
# filter out context that is not related
|
||||
# TODO: Add server side filtering by app_name and user_id.
|
||||
# if not context.source_display_name.startswith(f"{app_name}.{user_id}."):
|
||||
# continue
|
||||
session_id = context.source_display_name.split(".")[-1]
|
||||
events = []
|
||||
if context.text:
|
||||
lines = context.text.split("\n")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Try to parse as JSON
|
||||
event_data = json.loads(line)
|
||||
|
||||
author = event_data.get("author", "")
|
||||
timestamp = float(event_data.get("timestamp", 0))
|
||||
text = event_data.get("text", "")
|
||||
|
||||
content = types.Content(parts=[types.Part(text=text)])
|
||||
event = Event(author=author, timestamp=timestamp, content=content)
|
||||
events.append(event)
|
||||
except json.JSONDecodeError:
|
||||
# Not valid JSON, skip this line
|
||||
continue
|
||||
|
||||
if session_id in session_events_map:
|
||||
session_events_map[session_id].append(events)
|
||||
else:
|
||||
session_events_map[session_id] = [events]
|
||||
|
||||
# Remove overlap and combine events from the same session.
|
||||
for session_id, event_lists in session_events_map.items():
|
||||
for events in _merge_event_lists(event_lists):
|
||||
sorted_events = sorted(events, key=lambda e: e.timestamp)
|
||||
memory_results.append(
|
||||
MemoryResult(session_id=session_id, events=sorted_events)
|
||||
)
|
||||
return SearchMemoryResponse(memories=memory_results)
|
||||
|
||||
|
||||
def _merge_event_lists(event_lists: list[list[Event]]) -> list[list[Event]]:
|
||||
"""Merge event lists that have overlapping timestamps."""
|
||||
merged = []
|
||||
while event_lists:
|
||||
current = event_lists.pop(0)
|
||||
current_ts = {event.timestamp for event in current}
|
||||
merge_found = True
|
||||
|
||||
# Keep merging until no new overlap is found.
|
||||
while merge_found:
|
||||
merge_found = False
|
||||
remaining = []
|
||||
for other in event_lists:
|
||||
other_ts = {event.timestamp for event in other}
|
||||
# Overlap exists, so we merge and use the merged list to check again
|
||||
if current_ts & other_ts:
|
||||
new_events = [e for e in other if e.timestamp not in current_ts]
|
||||
current.extend(new_events)
|
||||
current_ts.update(e.timestamp for e in new_events)
|
||||
merge_found = True
|
||||
else:
|
||||
remaining.append(other)
|
||||
event_lists = remaining
|
||||
merged.append(current)
|
||||
return merged
|
||||
31
src/google/adk/models/__init__.py
Normal file
31
src/google/adk/models/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Defines the interface to support a model."""
|
||||
|
||||
from .base_llm import BaseLlm
|
||||
from .google_llm import Gemini
|
||||
from .llm_request import LlmRequest
|
||||
from .llm_response import LlmResponse
|
||||
from .registry import LLMRegistry
|
||||
|
||||
__all__ = [
|
||||
'BaseLlm',
|
||||
'Gemini',
|
||||
'LLMRegistry',
|
||||
]
|
||||
|
||||
|
||||
for regex in Gemini.supported_models():
|
||||
LLMRegistry.register(Gemini)
|
||||
243
src/google/adk/models/anthropic_llm.py
Normal file
243
src/google/adk/models/anthropic_llm.py
Normal file
@ -0,0 +1,243 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Anthropic integration for Claude models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
import logging
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
from typing import Iterable
|
||||
from typing import Literal
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from anthropic import AnthropicVertex
|
||||
from anthropic import NOT_GIVEN
|
||||
from anthropic import types as anthropic_types
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_llm import BaseLlm
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .llm_request import LlmRequest
|
||||
|
||||
__all__ = ["Claude"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_TOKEN = 1024
|
||||
|
||||
|
||||
class ClaudeRequest(BaseModel):
|
||||
system_instruction: str
|
||||
messages: Iterable[anthropic_types.MessageParam]
|
||||
tools: list[anthropic_types.ToolParam]
|
||||
|
||||
|
||||
def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
|
||||
if role in ["model", "assistant"]:
|
||||
return "assistant"
|
||||
return "user"
|
||||
|
||||
|
||||
def to_google_genai_finish_reason(
|
||||
anthropic_stop_reason: Optional[str],
|
||||
) -> types.FinishReason:
|
||||
if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
|
||||
return "STOP"
|
||||
if anthropic_stop_reason == "max_tokens":
|
||||
return "MAX_TOKENS"
|
||||
return "FINISH_REASON_UNSPECIFIED"
|
||||
|
||||
|
||||
def part_to_message_block(
|
||||
part: types.Part,
|
||||
) -> Union[
|
||||
anthropic_types.TextBlockParam,
|
||||
anthropic_types.ImageBlockParam,
|
||||
anthropic_types.ToolUseBlockParam,
|
||||
anthropic_types.ToolResultBlockParam,
|
||||
]:
|
||||
if part.text:
|
||||
return anthropic_types.TextBlockParam(text=part.text, type="text")
|
||||
if part.function_call:
|
||||
assert part.function_call.name
|
||||
|
||||
return anthropic_types.ToolUseBlockParam(
|
||||
id=part.function_call.id or "",
|
||||
name=part.function_call.name,
|
||||
input=part.function_call.args,
|
||||
type="tool_use",
|
||||
)
|
||||
if part.function_response:
|
||||
content = ""
|
||||
if (
|
||||
"result" in part.function_response.response
|
||||
and part.function_response.response["result"]
|
||||
):
|
||||
# Transformation is required because the content is a list of dict.
|
||||
# ToolResultBlockParam content doesn't support list of dict. Converting
|
||||
# to str to prevent anthropic.BadRequestError from being thrown.
|
||||
content = str(part.function_response.response["result"])
|
||||
return anthropic_types.ToolResultBlockParam(
|
||||
tool_use_id=part.function_response.id or "",
|
||||
type="tool_result",
|
||||
content=content,
|
||||
is_error=False,
|
||||
)
|
||||
raise NotImplementedError("Not supported yet.")
|
||||
|
||||
|
||||
def content_to_message_param(
|
||||
content: types.Content,
|
||||
) -> anthropic_types.MessageParam:
|
||||
return {
|
||||
"role": to_claude_role(content.role),
|
||||
"content": [part_to_message_block(part) for part in content.parts or []],
|
||||
}
|
||||
|
||||
|
||||
def content_block_to_part(
|
||||
content_block: anthropic_types.ContentBlock,
|
||||
) -> types.Part:
|
||||
if isinstance(content_block, anthropic_types.TextBlock):
|
||||
return types.Part.from_text(text=content_block.text)
|
||||
if isinstance(content_block, anthropic_types.ToolUseBlock):
|
||||
assert isinstance(content_block.input, dict)
|
||||
part = types.Part.from_function_call(
|
||||
name=content_block.name, args=content_block.input
|
||||
)
|
||||
part.function_call.id = content_block.id
|
||||
return part
|
||||
raise NotImplementedError("Not supported yet.")
|
||||
|
||||
|
||||
def message_to_generate_content_response(
|
||||
message: anthropic_types.Message,
|
||||
) -> LlmResponse:
|
||||
|
||||
return LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[content_block_to_part(cb) for cb in message.content],
|
||||
),
|
||||
# TODO: Deal with these later.
|
||||
# finish_reason=to_google_genai_finish_reason(message.stop_reason),
|
||||
# usage_metadata=types.GenerateContentResponseUsageMetadata(
|
||||
# prompt_token_count=message.usage.input_tokens,
|
||||
# candidates_token_count=message.usage.output_tokens,
|
||||
# total_token_count=(
|
||||
# message.usage.input_tokens + message.usage.output_tokens
|
||||
# ),
|
||||
# ),
|
||||
)
|
||||
|
||||
|
||||
def function_declaration_to_tool_param(
|
||||
function_declaration: types.FunctionDeclaration,
|
||||
) -> anthropic_types.ToolParam:
|
||||
assert function_declaration.name
|
||||
|
||||
properties = {}
|
||||
if (
|
||||
function_declaration.parameters
|
||||
and function_declaration.parameters.properties
|
||||
):
|
||||
for key, value in function_declaration.parameters.properties.items():
|
||||
value_dict = value.model_dump(exclude_none=True)
|
||||
if "type" in value_dict:
|
||||
value_dict["type"] = value_dict["type"].lower()
|
||||
properties[key] = value_dict
|
||||
|
||||
return anthropic_types.ToolParam(
|
||||
name=function_declaration.name,
|
||||
description=function_declaration.description or "",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class Claude(BaseLlm):
|
||||
model: str = "claude-3-5-sonnet-v2@20241022"
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def supported_models() -> list[str]:
|
||||
return [r"claude-3-.*"]
|
||||
|
||||
@override
|
||||
async def generate_content_async(
|
||||
self, llm_request: LlmRequest, stream: bool = False
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
messages = [
|
||||
content_to_message_param(content)
|
||||
for content in llm_request.contents or []
|
||||
]
|
||||
tools = NOT_GIVEN
|
||||
if (
|
||||
llm_request.config
|
||||
and llm_request.config.tools
|
||||
and llm_request.config.tools[0].function_declarations
|
||||
):
|
||||
tools = [
|
||||
function_declaration_to_tool_param(tool)
|
||||
for tool in llm_request.config.tools[0].function_declarations
|
||||
]
|
||||
tool_choice = (
|
||||
anthropic_types.ToolChoiceAutoParam(
|
||||
type="auto",
|
||||
# TODO: allow parallel tool use.
|
||||
disable_parallel_tool_use=True,
|
||||
)
|
||||
if llm_request.tools_dict
|
||||
else NOT_GIVEN
|
||||
)
|
||||
message = self._anthropic_client.messages.create(
|
||||
model=llm_request.model,
|
||||
system=llm_request.config.system_instruction,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
max_tokens=MAX_TOKEN,
|
||||
)
|
||||
logger.info(
|
||||
"Claude response: %s",
|
||||
message.model_dump_json(indent=2, exclude_none=True),
|
||||
)
|
||||
yield message_to_generate_content_response(message)
|
||||
|
||||
@cached_property
|
||||
def _anthropic_client(self) -> AnthropicVertex:
|
||||
if (
|
||||
"GOOGLE_CLOUD_PROJECT" not in os.environ
|
||||
or "GOOGLE_CLOUD_LOCATION" not in os.environ
|
||||
):
|
||||
raise ValueError(
|
||||
"GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
|
||||
" Anthropic on Vertex."
|
||||
)
|
||||
|
||||
return AnthropicVertex(
|
||||
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
|
||||
region=os.environ["GOOGLE_CLOUD_LOCATION"],
|
||||
)
|
||||
87
src/google/adk/models/base_llm.py
Normal file
87
src/google/adk/models/base_llm.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from .base_llm_connection import BaseLlmConnection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .llm_request import LlmRequest
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
|
||||
class BaseLlm(BaseModel):
|
||||
"""The BaseLLM class.
|
||||
|
||||
Attributes:
|
||||
model: The name of the LLM, e.g. gemini-1.5-flash or gemini-1.5-flash-001.
|
||||
model_config: The model config
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
# This allows us to use arbitrary types in the model. E.g. PIL.Image.
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
"""The model config."""
|
||||
|
||||
model: str
|
||||
"""The name of the LLM, e.g. gemini-1.5-flash or gemini-1.5-flash-001."""
|
||||
|
||||
@classmethod
|
||||
def supported_models(cls) -> list[str]:
|
||||
"""Returns a list of supported models in regex for LlmRegistry."""
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
async def generate_content_async(
|
||||
self, llm_request: LlmRequest, stream: bool = False
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Generates one content from the given contents and tools.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the LLM.
|
||||
stream: bool = False, whether to do streaming call.
|
||||
|
||||
Yields:
|
||||
a generator of types.Content.
|
||||
|
||||
For non-streaming call, it will only yield one Content.
|
||||
|
||||
For streaming call, it may yield more than one content, but all yielded
|
||||
contents should be treated as one content by merging the
|
||||
parts list.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f'Async generation is not supported for {self.model}.'
|
||||
)
|
||||
yield # AsyncGenerator requires a yield statement in function body.
|
||||
|
||||
def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
|
||||
"""Creates a live connection to the LLM.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the LLM.
|
||||
|
||||
Returns:
|
||||
BaseLlmConnection, the connection to the LLM.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f'Live connection is not supported for {self.model}.'
|
||||
)
|
||||
76
src/google/adk/models/base_llm_connection.py
Normal file
76
src/google/adk/models/base_llm_connection.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator
|
||||
from google.genai import types
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
|
||||
class BaseLlmConnection:
|
||||
"""The base class for a live model connection."""
|
||||
|
||||
@abstractmethod
|
||||
async def send_history(self, history: list[types.Content]):
|
||||
"""Sends the conversation history to the model.
|
||||
|
||||
You call this method right after setting up the model connection.
|
||||
The model will respond if the last content is from user, otherwise it will
|
||||
wait for new user input before responding.
|
||||
|
||||
Args:
|
||||
history: The conversation history to send to the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_content(self, content: types.Content):
|
||||
"""Sends a user content to the model.
|
||||
|
||||
The model will respond immediately upon receiving the content.
|
||||
If you send function responses, all parts in the content should be function
|
||||
responses.
|
||||
|
||||
Args:
|
||||
content: The content to send to the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_realtime(self, blob: types.Blob):
|
||||
"""Sends a chunk of audio or a frame of video to the model in realtime.
|
||||
|
||||
The model may not respond immediately upon receiving the blob. It will do
|
||||
voice activity detection and decide when to respond.
|
||||
|
||||
Args:
|
||||
blob: The blob to send to the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Receives the model response using the llm server connection.
|
||||
|
||||
Args: None.
|
||||
|
||||
Yields:
|
||||
LlmResponse: The model response.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
"""Closes the llm server connection."""
|
||||
pass
|
||||
200
src/google/adk/models/gemini_llm_connection.py
Normal file
200
src/google/adk/models/gemini_llm_connection.py
Normal file
@ -0,0 +1,200 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from google.genai import live
|
||||
from google.genai import types
|
||||
|
||||
from .base_llm_connection import BaseLlmConnection
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GeminiLlmConnection(BaseLlmConnection):
|
||||
"""The Gemini model connection."""
|
||||
|
||||
def __init__(self, gemini_session: live.AsyncSession):
|
||||
self._gemini_session = gemini_session
|
||||
|
||||
async def send_history(self, history: list[types.Content]):
|
||||
"""Sends the conversation history to the gemini model.
|
||||
|
||||
You call this method right after setting up the model connection.
|
||||
The model will respond if the last content is from user, otherwise it will
|
||||
wait for new user input before responding.
|
||||
|
||||
Args:
|
||||
history: The conversation history to send to the model.
|
||||
"""
|
||||
|
||||
# TODO: Remove this filter and translate unary contents to streaming
|
||||
# contents properly.
|
||||
|
||||
# We ignore any audio from user during the agent transfer phase
|
||||
contents = [
|
||||
content
|
||||
for content in history
|
||||
if content.parts and content.parts[0].text
|
||||
]
|
||||
|
||||
if contents:
|
||||
await self._gemini_session.send(
|
||||
input=types.LiveClientContent(
|
||||
turns=contents,
|
||||
turn_complete=contents[-1].role == 'user',
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.info('no content is sent')
|
||||
|
||||
async def send_content(self, content: types.Content):
|
||||
"""Sends a user content to the gemini model.
|
||||
|
||||
The model will respond immediately upon receiving the content.
|
||||
If you send function responses, all parts in the content should be function
|
||||
responses.
|
||||
|
||||
Args:
|
||||
content: The content to send to the model.
|
||||
"""
|
||||
|
||||
assert content.parts
|
||||
if content.parts[0].function_response:
|
||||
# All parts have to be function responses.
|
||||
function_responses = [part.function_response for part in content.parts]
|
||||
logger.debug('Sending LLM function response: %s', function_responses)
|
||||
await self._gemini_session.send(
|
||||
input=types.LiveClientToolResponse(
|
||||
function_responses=function_responses
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.debug('Sending LLM new content %s', content)
|
||||
await self._gemini_session.send(
|
||||
input=types.LiveClientContent(
|
||||
turns=[content],
|
||||
turn_complete=True,
|
||||
)
|
||||
)
|
||||
|
||||
async def send_realtime(self, blob: types.Blob):
|
||||
"""Sends a chunk of audio or a frame of video to the model in realtime.
|
||||
|
||||
Args:
|
||||
blob: The blob to send to the model.
|
||||
"""
|
||||
|
||||
input_blob = blob.model_dump()
|
||||
logger.debug('Sending LLM Blob: %s', input_blob)
|
||||
await self._gemini_session.send(input=input_blob)
|
||||
|
||||
def __build_full_text_response(self, text: str):
|
||||
"""Builds a full text response.
|
||||
|
||||
The text should not partial and the returned LlmResponse is not be
|
||||
partial.
|
||||
|
||||
Args:
|
||||
text: The text to be included in the response.
|
||||
|
||||
Returns:
|
||||
An LlmResponse containing the full text.
|
||||
"""
|
||||
return LlmResponse(
|
||||
content=types.Content(
|
||||
role='model',
|
||||
parts=[types.Part.from_text(text=text)],
|
||||
),
|
||||
)
|
||||
|
||||
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Receives the model response using the llm server connection.
|
||||
|
||||
Yields:
|
||||
LlmResponse: The model response.
|
||||
"""
|
||||
|
||||
text = ''
|
||||
async for message in self._gemini_session.receive():
|
||||
logger.debug('Got LLM Live message: %s', message)
|
||||
if message.server_content:
|
||||
content = message.server_content.model_turn
|
||||
if content and content.parts:
|
||||
llm_response = LlmResponse(
|
||||
content=content, interrupted=message.server_content.interrupted
|
||||
)
|
||||
if content.parts[0].text:
|
||||
text += content.parts[0].text
|
||||
llm_response.partial = True
|
||||
# don't yield the merged text event when receiving audio data
|
||||
elif text and not content.parts[0].inline_data:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield llm_response
|
||||
|
||||
if (
|
||||
message.server_content.output_transcription
|
||||
and message.server_content.output_transcription.text
|
||||
):
|
||||
# TODO: Right now, we just support output_transcription without
|
||||
# changing interface and data protocol. Later, we can consider to
|
||||
# support output_transcription as a separete field in LlmResponse.
|
||||
|
||||
# Transcription is always considered as partial event
|
||||
# We rely on other control signals to determine when to yield the
|
||||
# full text response(turn_complete, interrupted, or tool_call).
|
||||
text += message.server_content.output_transcription.text
|
||||
parts = [
|
||||
types.Part.from_text(
|
||||
text=message.server_content.output_transcription.text
|
||||
)
|
||||
]
|
||||
llm_response = LlmResponse(
|
||||
content=types.Content(role='model', parts=parts), partial=True
|
||||
)
|
||||
yield llm_response
|
||||
|
||||
if message.server_content.turn_complete:
|
||||
if text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield LlmResponse(
|
||||
turn_complete=True, interrupted=message.server_content.interrupted
|
||||
)
|
||||
break
|
||||
# in case of empty content or parts, we sill surface it
|
||||
# in case it's an interrupted message, we merge the previous partial
|
||||
# text. Other we don't merge. because content can be none when model
|
||||
# safty threshold is triggered
|
||||
if message.server_content.interrupted and text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield LlmResponse(interrupted=message.server_content.interrupted)
|
||||
if message.tool_call:
|
||||
if text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
parts = [
|
||||
types.Part(function_call=function_call)
|
||||
for function_call in message.tool_call.function_calls
|
||||
]
|
||||
yield LlmResponse(content=types.Content(role='model', parts=parts))
|
||||
|
||||
async def close(self):
|
||||
"""Closes the llm server connection."""
|
||||
|
||||
await self._gemini_session.close()
|
||||
331
src/google/adk/models/google_llm.py
Normal file
331
src/google/adk/models/google_llm.py
Normal file
@ -0,0 +1,331 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from functools import cached_property
|
||||
import logging
|
||||
import sys
|
||||
from typing import AsyncGenerator
|
||||
from typing import cast
|
||||
from typing import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import Client
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .. import version
|
||||
from .base_llm import BaseLlm
|
||||
from .base_llm_connection import BaseLlmConnection
|
||||
from .gemini_llm_connection import GeminiLlmConnection
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .llm_request import LlmRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NEW_LINE = '\n'
|
||||
_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
|
||||
|
||||
|
||||
class Gemini(BaseLlm):
|
||||
"""Integration for Gemini models.
|
||||
|
||||
Attributes:
|
||||
model: The name of the Gemini model.
|
||||
"""
|
||||
|
||||
model: str = 'gemini-1.5-flash'
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def supported_models() -> list[str]:
|
||||
"""Provides the list of supported models.
|
||||
|
||||
Returns:
|
||||
A list of supported models.
|
||||
"""
|
||||
|
||||
return [
|
||||
r'gemini-.*',
|
||||
# fine-tuned vertex endpoint pattern
|
||||
r'projects\/.+\/locations\/.+\/endpoints\/.+',
|
||||
# vertex gemini long name
|
||||
r'projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+',
|
||||
]
|
||||
|
||||
async def generate_content_async(
|
||||
self, llm_request: LlmRequest, stream: bool = False
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Sends a request to the Gemini model.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the Gemini model.
|
||||
stream: bool = False, whether to do streaming call.
|
||||
|
||||
Yields:
|
||||
LlmResponse: The model response.
|
||||
"""
|
||||
|
||||
self._maybe_append_user_content(llm_request)
|
||||
logger.info(
|
||||
'Sending out request, model: %s, backend: %s, stream: %s',
|
||||
llm_request.model,
|
||||
self._api_backend,
|
||||
stream,
|
||||
)
|
||||
logger.info(_build_request_log(llm_request))
|
||||
|
||||
if stream:
|
||||
responses = await self.api_client.aio.models.generate_content_stream(
|
||||
model=llm_request.model,
|
||||
contents=llm_request.contents,
|
||||
config=llm_request.config,
|
||||
)
|
||||
response = None
|
||||
text = ''
|
||||
# for sse, similar as bidi (see receive method in gemini_llm_connecton.py),
|
||||
# we need to mark those text content as partial and after all partial
|
||||
# contents are sent, we send an accumulated event which contains all the
|
||||
# previous partial content. The only difference is bidi rely on
|
||||
# complete_turn flag to detect end while sse depends on finish_reason.
|
||||
async for response in responses:
|
||||
logger.info(_build_response_log(response))
|
||||
llm_response = LlmResponse.create(response)
|
||||
if (
|
||||
llm_response.content
|
||||
and llm_response.content.parts
|
||||
and llm_response.content.parts[0].text
|
||||
):
|
||||
text += llm_response.content.parts[0].text
|
||||
llm_response.partial = True
|
||||
elif text and (
|
||||
not llm_response.content
|
||||
or not llm_response.content.parts
|
||||
# don't yield the merged text event when receiving audio data
|
||||
or not llm_response.content.parts[0].inline_data
|
||||
):
|
||||
yield LlmResponse(
|
||||
content=types.ModelContent(
|
||||
parts=[types.Part.from_text(text=text)],
|
||||
),
|
||||
)
|
||||
text = ''
|
||||
yield llm_response
|
||||
if (
|
||||
text
|
||||
and response
|
||||
and response.candidates
|
||||
and response.candidates[0].finish_reason == types.FinishReason.STOP
|
||||
):
|
||||
yield LlmResponse(
|
||||
content=types.ModelContent(
|
||||
parts=[types.Part.from_text(text=text)],
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
response = await self.api_client.aio.models.generate_content(
|
||||
model=llm_request.model,
|
||||
contents=llm_request.contents,
|
||||
config=llm_request.config,
|
||||
)
|
||||
logger.info(_build_response_log(response))
|
||||
yield LlmResponse.create(response)
|
||||
|
||||
@cached_property
|
||||
def api_client(self) -> Client:
|
||||
"""Provides the api client.
|
||||
|
||||
Returns:
|
||||
The api client.
|
||||
"""
|
||||
return Client(
|
||||
http_options=types.HttpOptions(headers=self._tracking_headers)
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _api_backend(self) -> str:
|
||||
return 'vertex' if self.api_client.vertexai else 'ml_dev'
|
||||
|
||||
@cached_property
|
||||
def _tracking_headers(self) -> dict[str, str]:
|
||||
framework_label = f'google-adk/{version.__version__}'
|
||||
language_label = 'gl-python/' + sys.version.split()[0]
|
||||
version_header_value = f'{framework_label} {language_label}'
|
||||
tracking_headers = {
|
||||
'x-goog-api-client': version_header_value,
|
||||
'user-agent': version_header_value,
|
||||
}
|
||||
return tracking_headers
|
||||
|
||||
@cached_property
|
||||
def _live_api_client(self) -> Client:
|
||||
if self._api_backend == 'vertex':
|
||||
# use default api version for vertex
|
||||
return Client(
|
||||
http_options=types.HttpOptions(headers=self._tracking_headers)
|
||||
)
|
||||
else:
|
||||
# use v1alpha for ml_dev
|
||||
api_version = 'v1alpha'
|
||||
return Client(
|
||||
http_options=types.HttpOptions(
|
||||
headers=self._tracking_headers, api_version=api_version
|
||||
)
|
||||
)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
|
||||
"""Connects to the Gemini model and returns an llm connection.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the Gemini model.
|
||||
|
||||
Yields:
|
||||
BaseLlmConnection, the connection to the Gemini model.
|
||||
"""
|
||||
|
||||
llm_request.live_connect_config.system_instruction = types.Content(
|
||||
role='system',
|
||||
parts=[
|
||||
types.Part.from_text(text=llm_request.config.system_instruction)
|
||||
],
|
||||
)
|
||||
llm_request.live_connect_config.tools = llm_request.config.tools
|
||||
async with self._live_api_client.aio.live.connect(
|
||||
model=llm_request.model, config=llm_request.live_connect_config
|
||||
) as live_session:
|
||||
yield GeminiLlmConnection(live_session)
|
||||
|
||||
def _maybe_append_user_content(self, llm_request: LlmRequest):
|
||||
"""Appends a user content, so that model can continue to output.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the Gemini model.
|
||||
"""
|
||||
# If no content is provided, append a user content to hint model response
|
||||
# using system instruction.
|
||||
if not llm_request.contents:
|
||||
llm_request.contents.append(
|
||||
types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(
|
||||
text=(
|
||||
'Handle the requests as specified in the System'
|
||||
' Instruction.'
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Insert a user content to preserve user intent and to avoid empty
|
||||
# model response.
|
||||
if llm_request.contents[-1].role != 'user':
|
||||
llm_request.contents.append(
|
||||
types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(
|
||||
text=(
|
||||
'Continue processing previous requests as instructed.'
|
||||
' Exit or provide a summary if no more outputs are'
|
||||
' needed.'
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _build_function_declaration_log(
|
||||
func_decl: types.FunctionDeclaration,
|
||||
) -> str:
|
||||
param_str = '{}'
|
||||
if func_decl.parameters and func_decl.parameters.properties:
|
||||
param_str = str({
|
||||
k: v.model_dump(exclude_none=True)
|
||||
for k, v in func_decl.parameters.properties.items()
|
||||
})
|
||||
return_str = 'None'
|
||||
if func_decl.response:
|
||||
return_str = str(func_decl.response.model_dump(exclude_none=True))
|
||||
return f'{func_decl.name}: {param_str} -> {return_str}'
|
||||
|
||||
|
||||
def _build_request_log(req: LlmRequest) -> str:
|
||||
function_decls: list[types.FunctionDeclaration] = cast(
|
||||
list[types.FunctionDeclaration],
|
||||
req.config.tools[0].function_declarations if req.config.tools else [],
|
||||
)
|
||||
function_logs = (
|
||||
[
|
||||
_build_function_declaration_log(func_decl)
|
||||
for func_decl in function_decls
|
||||
]
|
||||
if function_decls
|
||||
else []
|
||||
)
|
||||
contents_logs = [
|
||||
content.model_dump_json(
|
||||
exclude_none=True,
|
||||
exclude={
|
||||
'parts': {
|
||||
i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))
|
||||
}
|
||||
},
|
||||
)
|
||||
for content in req.contents
|
||||
]
|
||||
|
||||
return f"""
|
||||
LLM Request:
|
||||
-----------------------------------------------------------
|
||||
System Instruction:
|
||||
{req.config.system_instruction}
|
||||
-----------------------------------------------------------
|
||||
Contents:
|
||||
{_NEW_LINE.join(contents_logs)}
|
||||
-----------------------------------------------------------
|
||||
Functions:
|
||||
{_NEW_LINE.join(function_logs)}
|
||||
-----------------------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def _build_response_log(resp: types.GenerateContentResponse) -> str:
|
||||
function_calls_text = []
|
||||
if function_calls := resp.function_calls:
|
||||
for func_call in function_calls:
|
||||
function_calls_text.append(
|
||||
f'name: {func_call.name}, args: {func_call.args}'
|
||||
)
|
||||
return f"""
|
||||
LLM Response:
|
||||
-----------------------------------------------------------
|
||||
Text:
|
||||
{resp.text}
|
||||
-----------------------------------------------------------
|
||||
Function calls:
|
||||
{_NEW_LINE.join(function_calls_text)}
|
||||
-----------------------------------------------------------
|
||||
Raw response:
|
||||
{resp.model_dump_json(exclude_none=True)}
|
||||
-----------------------------------------------------------
|
||||
"""
|
||||
673
src/google/adk/models/lite_llm.py
Normal file
673
src/google/adk/models/lite_llm.py
Normal file
@ -0,0 +1,673 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Generator
|
||||
from typing import Iterable
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from litellm import acompletion
|
||||
from litellm import ChatCompletionAssistantMessage
|
||||
from litellm import ChatCompletionDeveloperMessage
|
||||
from litellm import ChatCompletionImageUrlObject
|
||||
from litellm import ChatCompletionMessageToolCall
|
||||
from litellm import ChatCompletionTextObject
|
||||
from litellm import ChatCompletionToolMessage
|
||||
from litellm import ChatCompletionUserMessage
|
||||
from litellm import ChatCompletionVideoUrlObject
|
||||
from litellm import completion
|
||||
from litellm import CustomStreamWrapper
|
||||
from litellm import Function
|
||||
from litellm import Message
|
||||
from litellm import ModelResponse
|
||||
from litellm import OpenAIMessageContent
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_llm import BaseLlm
|
||||
from .llm_request import LlmRequest
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NEW_LINE = "\n"
|
||||
_EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
|
||||
|
||||
|
||||
class FunctionChunk(BaseModel):
|
||||
id: Optional[str]
|
||||
name: Optional[str]
|
||||
args: Optional[str]
|
||||
|
||||
|
||||
class TextChunk(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class LiteLLMClient:
|
||||
"""Provides acompletion method (for better testability)."""
|
||||
|
||||
async def acompletion(
|
||||
self, model, messages, tools, **kwargs
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
"""Asynchronously calls acompletion.
|
||||
|
||||
Args:
|
||||
model: The model name.
|
||||
messages: The messages to send to the model.
|
||||
tools: The tools to use for the model.
|
||||
**kwargs: Additional arguments to pass to acompletion.
|
||||
|
||||
Returns:
|
||||
The model response as a message.
|
||||
"""
|
||||
|
||||
return await acompletion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def completion(
|
||||
self, model, messages, tools, stream=False, **kwargs
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
"""Synchronously calls completion. This is used for streaming only.
|
||||
|
||||
Args:
|
||||
model: The model to use.
|
||||
messages: The messages to send.
|
||||
tools: The tools to use for the model.
|
||||
stream: Whether to stream the response.
|
||||
**kwargs: Additional arguments to pass to completion.
|
||||
|
||||
Returns:
|
||||
The response from the model.
|
||||
"""
|
||||
|
||||
return completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _safe_json_serialize(obj) -> str:
|
||||
"""Convert any Python object to a JSON-serializable type or string.
|
||||
|
||||
Args:
|
||||
obj: The object to serialize.
|
||||
|
||||
Returns:
|
||||
The JSON-serialized object string or string.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Try direct JSON serialization first
|
||||
return json.dumps(obj)
|
||||
except (TypeError, OverflowError):
|
||||
return str(obj)
|
||||
|
||||
|
||||
def _content_to_message_param(
|
||||
content: types.Content,
|
||||
) -> Message:
|
||||
"""Converts a types.Content to a litellm Message.
|
||||
|
||||
Args:
|
||||
content: The content to convert.
|
||||
|
||||
Returns:
|
||||
The litellm Message.
|
||||
"""
|
||||
|
||||
if content.parts and content.parts[0].function_response:
|
||||
return ChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=content.parts[0].function_response.id,
|
||||
content=_safe_json_serialize(
|
||||
content.parts[0].function_response.response
|
||||
),
|
||||
)
|
||||
|
||||
role = _to_litellm_role(content.role)
|
||||
|
||||
if role == "user":
|
||||
return ChatCompletionUserMessage(
|
||||
role="user", content=_get_content(content.parts)
|
||||
)
|
||||
else:
|
||||
|
||||
tool_calls = [
|
||||
ChatCompletionMessageToolCall(
|
||||
type="function",
|
||||
id=part.function_call.id,
|
||||
function=Function(
|
||||
name=part.function_call.name,
|
||||
arguments=part.function_call.args,
|
||||
),
|
||||
)
|
||||
for part in content.parts
|
||||
if part.function_call
|
||||
]
|
||||
|
||||
return ChatCompletionAssistantMessage(
|
||||
role=role,
|
||||
content=_get_content(content.parts),
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
|
||||
|
||||
def _get_content(parts: Iterable[types.Part]) -> OpenAIMessageContent | str:
|
||||
"""Converts a list of parts to litellm content.
|
||||
|
||||
Args:
|
||||
parts: The parts to convert.
|
||||
|
||||
Returns:
|
||||
The litellm content.
|
||||
"""
|
||||
|
||||
content_objects = []
|
||||
for part in parts:
|
||||
if part.text:
|
||||
if len(parts) == 1:
|
||||
return part.text
|
||||
content_objects.append(
|
||||
ChatCompletionTextObject(
|
||||
type="text",
|
||||
text=part.text,
|
||||
)
|
||||
)
|
||||
elif (
|
||||
part.inline_data
|
||||
and part.inline_data.data
|
||||
and part.inline_data.mime_type
|
||||
):
|
||||
base64_string = base64.b64encode(part.inline_data.data).decode("utf-8")
|
||||
data_uri = f"data:{part.inline_data.mime_type};base64,{base64_string}"
|
||||
|
||||
if part.inline_data.mime_type.startswith("image"):
|
||||
content_objects.append(
|
||||
ChatCompletionImageUrlObject(
|
||||
type="image_url",
|
||||
image_url=data_uri,
|
||||
)
|
||||
)
|
||||
elif part.inline_data.mime_type.startswith("video"):
|
||||
content_objects.append(
|
||||
ChatCompletionVideoUrlObject(
|
||||
type="video_url",
|
||||
video_url=data_uri,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("LiteLlm(BaseLlm) does not support this content part.")
|
||||
|
||||
return content_objects
|
||||
|
||||
|
||||
def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]:
|
||||
"""Converts a types.Content role to a litellm role.
|
||||
|
||||
Args:
|
||||
role: The types.Content role.
|
||||
|
||||
Returns:
|
||||
The litellm role.
|
||||
"""
|
||||
|
||||
if role in ["model", "assistant"]:
|
||||
return "assistant"
|
||||
return "user"
|
||||
|
||||
|
||||
TYPE_LABELS = {
|
||||
"STRING": "string",
|
||||
"NUMBER": "number",
|
||||
"BOOLEAN": "boolean",
|
||||
"OBJECT": "object",
|
||||
"ARRAY": "array",
|
||||
"INTEGER": "integer",
|
||||
}
|
||||
|
||||
|
||||
def _schema_to_dict(schema: types.Schema) -> dict:
|
||||
"""Recursively converts a types.Schema to a dictionary.
|
||||
|
||||
Args:
|
||||
schema: The schema to convert.
|
||||
|
||||
Returns:
|
||||
The dictionary representation of the schema.
|
||||
"""
|
||||
|
||||
schema_dict = schema.model_dump(exclude_none=True)
|
||||
if "type" in schema_dict:
|
||||
schema_dict["type"] = schema_dict["type"].lower()
|
||||
if "items" in schema_dict:
|
||||
if isinstance(schema_dict["items"], dict):
|
||||
schema_dict["items"] = _schema_to_dict(
|
||||
types.Schema.model_validate(schema_dict["items"])
|
||||
)
|
||||
elif isinstance(schema_dict["items"]["type"], types.Type):
|
||||
schema_dict["items"]["type"] = TYPE_LABELS[
|
||||
schema_dict["items"]["type"].value
|
||||
]
|
||||
if "properties" in schema_dict:
|
||||
properties = {}
|
||||
for key, value in schema_dict["properties"].items():
|
||||
if isinstance(value, types.Schema):
|
||||
properties[key] = _schema_to_dict(value)
|
||||
else:
|
||||
properties[key] = value
|
||||
if "type" in properties[key]:
|
||||
properties[key]["type"] = properties[key]["type"].lower()
|
||||
schema_dict["properties"] = properties
|
||||
return schema_dict
|
||||
|
||||
|
||||
def _function_declaration_to_tool_param(
|
||||
function_declaration: types.FunctionDeclaration,
|
||||
) -> dict:
|
||||
"""Converts a types.FunctionDeclaration to a openapi spec dictionary.
|
||||
|
||||
Args:
|
||||
function_declaration: The function declaration to convert.
|
||||
|
||||
Returns:
|
||||
The openapi spec dictionary representation of the function declaration.
|
||||
"""
|
||||
|
||||
assert function_declaration.name
|
||||
|
||||
properties = {}
|
||||
if (
|
||||
function_declaration.parameters
|
||||
and function_declaration.parameters.properties
|
||||
):
|
||||
for key, value in function_declaration.parameters.properties.items():
|
||||
properties[key] = _schema_to_dict(value)
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_declaration.name,
|
||||
"description": function_declaration.description or "",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _model_response_to_chunk(
|
||||
response: ModelResponse,
|
||||
) -> Generator[
|
||||
Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None
|
||||
]:
|
||||
"""Converts a litellm message to text or function chunk.
|
||||
|
||||
Args:
|
||||
response: The response from the model.
|
||||
|
||||
Yields:
|
||||
A tuple of text or function chunk and finish reason.
|
||||
"""
|
||||
|
||||
message = None
|
||||
if response.get("choices", None):
|
||||
message = response["choices"][0].get("message", None)
|
||||
finish_reason = response["choices"][0].get("finish_reason", None)
|
||||
# check streaming delta
|
||||
if message is None and response["choices"][0].get("delta", None):
|
||||
message = response["choices"][0]["delta"]
|
||||
|
||||
if message.get("content", None):
|
||||
yield TextChunk(text=message.get("content")), finish_reason
|
||||
|
||||
if message.get("tool_calls", None):
|
||||
for tool_call in message.get("tool_calls"):
|
||||
# aggregate tool_call
|
||||
if tool_call.type == "function":
|
||||
yield FunctionChunk(
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
args=tool_call.function.arguments,
|
||||
), finish_reason
|
||||
|
||||
if finish_reason and not (
|
||||
message.get("content", None) or message.get("tool_calls", None)
|
||||
):
|
||||
yield None, finish_reason
|
||||
|
||||
if not message:
|
||||
yield None, None
|
||||
|
||||
|
||||
def _model_response_to_generate_content_response(
|
||||
response: ModelResponse,
|
||||
) -> LlmResponse:
|
||||
"""Converts a litellm response to LlmResponse.
|
||||
|
||||
Args:
|
||||
response: The model response.
|
||||
|
||||
Returns:
|
||||
The LlmResponse.
|
||||
"""
|
||||
|
||||
message = None
|
||||
if response.get("choices", None):
|
||||
message = response["choices"][0].get("message", None)
|
||||
|
||||
if not message:
|
||||
raise ValueError("No message in response")
|
||||
return _message_to_generate_content_response(message)
|
||||
|
||||
|
||||
def _message_to_generate_content_response(
|
||||
message: Message, is_partial: bool = False
|
||||
) -> LlmResponse:
|
||||
"""Converts a litellm message to LlmResponse.
|
||||
|
||||
Args:
|
||||
message: The message to convert.
|
||||
is_partial: Whether the message is partial.
|
||||
|
||||
Returns:
|
||||
The LlmResponse.
|
||||
"""
|
||||
|
||||
parts = []
|
||||
if message.get("content", None):
|
||||
parts.append(types.Part.from_text(text=message.get("content")))
|
||||
|
||||
if message.get("tool_calls", None):
|
||||
for tool_call in message.get("tool_calls"):
|
||||
if tool_call.type == "function":
|
||||
part = types.Part.from_function_call(
|
||||
name=tool_call.function.name,
|
||||
args=json.loads(tool_call.function.arguments or "{}"),
|
||||
)
|
||||
part.function_call.id = tool_call.id
|
||||
parts.append(part)
|
||||
|
||||
return LlmResponse(
|
||||
content=types.Content(role="model", parts=parts), partial=is_partial
|
||||
)
|
||||
|
||||
|
||||
def _get_completion_inputs(
|
||||
llm_request: LlmRequest,
|
||||
) -> tuple[Iterable[Message], Iterable[dict]]:
|
||||
"""Converts an LlmRequest to litellm inputs.
|
||||
|
||||
Args:
|
||||
llm_request: The LlmRequest to convert.
|
||||
|
||||
Returns:
|
||||
The litellm inputs (message list and tool dictionary).
|
||||
"""
|
||||
messages = [
|
||||
_content_to_message_param(content)
|
||||
for content in llm_request.contents or []
|
||||
]
|
||||
|
||||
if llm_request.config.system_instruction:
|
||||
messages.insert(
|
||||
0,
|
||||
ChatCompletionDeveloperMessage(
|
||||
role="developer",
|
||||
content=llm_request.config.system_instruction,
|
||||
),
|
||||
)
|
||||
|
||||
tools = None
|
||||
if (
|
||||
llm_request.config
|
||||
and llm_request.config.tools
|
||||
and llm_request.config.tools[0].function_declarations
|
||||
):
|
||||
tools = [
|
||||
_function_declaration_to_tool_param(tool)
|
||||
for tool in llm_request.config.tools[0].function_declarations
|
||||
]
|
||||
return messages, tools
|
||||
|
||||
|
||||
def _build_function_declaration_log(
|
||||
func_decl: types.FunctionDeclaration,
|
||||
) -> str:
|
||||
"""Builds a function declaration log.
|
||||
|
||||
Args:
|
||||
func_decl: The function declaration to convert.
|
||||
|
||||
Returns:
|
||||
The function declaration log.
|
||||
"""
|
||||
|
||||
param_str = "{}"
|
||||
if func_decl.parameters and func_decl.parameters.properties:
|
||||
param_str = str({
|
||||
k: v.model_dump(exclude_none=True)
|
||||
for k, v in func_decl.parameters.properties.items()
|
||||
})
|
||||
return_str = "None"
|
||||
if func_decl.response:
|
||||
return_str = str(func_decl.response.model_dump(exclude_none=True))
|
||||
return f"{func_decl.name}: {param_str} -> {return_str}"
|
||||
|
||||
|
||||
def _build_request_log(req: LlmRequest) -> str:
|
||||
"""Builds a request log.
|
||||
|
||||
Args:
|
||||
req: The request to convert.
|
||||
|
||||
Returns:
|
||||
The request log.
|
||||
"""
|
||||
|
||||
function_decls: list[types.FunctionDeclaration] = cast(
|
||||
list[types.FunctionDeclaration],
|
||||
req.config.tools[0].function_declarations if req.config.tools else [],
|
||||
)
|
||||
function_logs = (
|
||||
[
|
||||
_build_function_declaration_log(func_decl)
|
||||
for func_decl in function_decls
|
||||
]
|
||||
if function_decls
|
||||
else []
|
||||
)
|
||||
contents_logs = [
|
||||
content.model_dump_json(
|
||||
exclude_none=True,
|
||||
exclude={
|
||||
"parts": {
|
||||
i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))
|
||||
}
|
||||
},
|
||||
)
|
||||
for content in req.contents
|
||||
]
|
||||
|
||||
return f"""
|
||||
LLM Request:
|
||||
-----------------------------------------------------------
|
||||
System Instruction:
|
||||
{req.config.system_instruction}
|
||||
-----------------------------------------------------------
|
||||
Contents:
|
||||
{_NEW_LINE.join(contents_logs)}
|
||||
-----------------------------------------------------------
|
||||
Functions:
|
||||
{_NEW_LINE.join(function_logs)}
|
||||
-----------------------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
class LiteLlm(BaseLlm):
|
||||
"""Wrapper around litellm.
|
||||
|
||||
This wrapper can be used with any of the models supported by litellm. The
|
||||
environment variable(s) needed for authenticating with the model endpoint must
|
||||
be set prior to instantiating this class.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
os.environ["VERTEXAI_PROJECT"] = "your-gcp-project-id"
|
||||
os.environ["VERTEXAI_LOCATION"] = "your-gcp-location"
|
||||
|
||||
agent = Agent(
|
||||
model=LiteLlm(model="vertex_ai/claude-3-7-sonnet@20250219"),
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
Attributes:
|
||||
model: The name of the LiteLlm model.
|
||||
llm_client: The LLM client to use for the model.
|
||||
model_config: The model config.
|
||||
"""
|
||||
|
||||
llm_client: LiteLLMClient = Field(default_factory=LiteLLMClient)
|
||||
"""The LLM client to use for the model."""
|
||||
|
||||
_additional_args: Dict[str, Any] = None
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""Initializes the LiteLlm class.
|
||||
|
||||
Args:
|
||||
model: The name of the LiteLlm model.
|
||||
**kwargs: Additional arguments to pass to the litellm completion api.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
self._additional_args = kwargs
|
||||
# preventing generation call with llm_client
|
||||
# and overriding messages, tools and stream which are managed internally
|
||||
self._additional_args.pop("llm_client", None)
|
||||
self._additional_args.pop("messages", None)
|
||||
self._additional_args.pop("tools", None)
|
||||
# public api called from runner determines to stream or not
|
||||
self._additional_args.pop("stream", None)
|
||||
|
||||
async def generate_content_async(
|
||||
self, llm_request: LlmRequest, stream: bool = False
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Generates content asynchronously.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the LiteLlm model.
|
||||
stream: bool = False, whether to do streaming call.
|
||||
|
||||
Yields:
|
||||
LlmResponse: The model response.
|
||||
"""
|
||||
|
||||
logger.info(_build_request_log(llm_request))
|
||||
|
||||
messages, tools = _get_completion_inputs(llm_request)
|
||||
|
||||
completion_args = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
}
|
||||
completion_args.update(self._additional_args)
|
||||
|
||||
if stream:
|
||||
text = ""
|
||||
function_name = ""
|
||||
function_args = ""
|
||||
function_id = None
|
||||
completion_args["stream"] = True
|
||||
for part in self.llm_client.completion(**completion_args):
|
||||
for chunk, finish_reason in _model_response_to_chunk(part):
|
||||
if isinstance(chunk, FunctionChunk):
|
||||
if chunk.name:
|
||||
function_name += chunk.name
|
||||
if chunk.args:
|
||||
function_args += chunk.args
|
||||
function_id = chunk.id or function_id
|
||||
elif isinstance(chunk, TextChunk):
|
||||
text += chunk.text
|
||||
yield _message_to_generate_content_response(
|
||||
ChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=chunk.text,
|
||||
),
|
||||
is_partial=True,
|
||||
)
|
||||
if finish_reason == "tool_calls" and function_id:
|
||||
yield _message_to_generate_content_response(
|
||||
ChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
type="function",
|
||||
id=function_id,
|
||||
function=Function(
|
||||
name=function_name,
|
||||
arguments=function_args,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
function_name = ""
|
||||
function_args = ""
|
||||
function_id = None
|
||||
elif finish_reason == "stop" and text:
|
||||
yield _message_to_generate_content_response(
|
||||
ChatCompletionAssistantMessage(role="assistant", content=text)
|
||||
)
|
||||
text = ""
|
||||
|
||||
else:
|
||||
response = await self.llm_client.acompletion(**completion_args)
|
||||
yield _model_response_to_generate_content_response(response)
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def supported_models() -> list[str]:
|
||||
"""Provides the list of supported models.
|
||||
|
||||
LiteLlm supports all models supported by litellm. We do not keep track of
|
||||
these models here. So we return an empty list.
|
||||
|
||||
Returns:
|
||||
A list of supported models.
|
||||
"""
|
||||
|
||||
return []
|
||||
98
src/google/adk/models/llm_request.py
Normal file
98
src/google/adk/models/llm_request.py
Normal file
@ -0,0 +1,98 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from ..tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class LlmRequest(BaseModel):
|
||||
"""LLM request class that allows passing in tools, output schema and system
|
||||
|
||||
instructions to the model.
|
||||
|
||||
Attributes:
|
||||
model: The model name.
|
||||
contents: The contents to send to the model.
|
||||
config: Additional config for the generate content request.
|
||||
tools_dict: The tools dictionary.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
"""The model config."""
|
||||
|
||||
model: Optional[str] = None
|
||||
"""The model name."""
|
||||
|
||||
contents: list[types.Content] = Field(default_factory=list)
|
||||
"""The contents to send to the model."""
|
||||
|
||||
config: Optional[types.GenerateContentConfig] = None
|
||||
live_connect_config: types.LiveConnectConfig = types.LiveConnectConfig()
|
||||
"""Additional config for the generate content request.
|
||||
|
||||
tools in generate_content_config should not be set.
|
||||
"""
|
||||
tools_dict: dict[str, BaseTool] = Field(default_factory=dict, exclude=True)
|
||||
"""The tools dictionary."""
|
||||
|
||||
def append_instructions(self, instructions: list[str]) -> None:
|
||||
"""Appends instructions to the system instruction.
|
||||
|
||||
Args:
|
||||
instructions: The instructions to append.
|
||||
"""
|
||||
|
||||
if self.config.system_instruction:
|
||||
self.config.system_instruction += '\n\n' + '\n\n'.join(instructions)
|
||||
else:
|
||||
self.config.system_instruction = '\n\n'.join(instructions)
|
||||
|
||||
def append_tools(self, tools: list[BaseTool]) -> None:
|
||||
"""Appends tools to the request.
|
||||
|
||||
Args:
|
||||
tools: The tools to append.
|
||||
"""
|
||||
|
||||
if not tools:
|
||||
return
|
||||
declarations = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, BaseTool):
|
||||
declaration = tool._get_declaration()
|
||||
else:
|
||||
declaration = tool.get_declaration()
|
||||
if declaration:
|
||||
declarations.append(declaration)
|
||||
self.tools_dict[tool.name] = tool
|
||||
if declarations:
|
||||
self.config.tools.append(types.Tool(function_declarations=declarations))
|
||||
|
||||
def set_output_schema(self, base_model: type[BaseModel]) -> None:
|
||||
"""Sets the output schema for the request.
|
||||
|
||||
Args:
|
||||
base_model: The pydantic base model to set the output schema to.
|
||||
"""
|
||||
|
||||
self.config.response_schema = base_model
|
||||
self.config.response_mime_type = 'application/json'
|
||||
111
src/google/adk/models/llm_response.py
Normal file
111
src/google/adk/models/llm_response.py
Normal file
@ -0,0 +1,111 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class LlmResponse(BaseModel):
|
||||
"""LLM response class that provides the first candidate response from the
|
||||
|
||||
model if available. Otherwise, returns error code and message.
|
||||
|
||||
Attributes:
|
||||
content: The content of the response.
|
||||
grounding_metadata: The grounding metadata of the response.
|
||||
partial: Indicates whether the text content is part of a unfinished text
|
||||
stream. Only used for streaming mode and when the content is plain text.
|
||||
turn_complete: Indicates whether the response from the model is complete.
|
||||
Only used for streaming mode.
|
||||
error_code: Error code if the response is an error. Code varies by model.
|
||||
error_message: Error message if the response is an error.
|
||||
interrupted: Flag indicating that LLM was interrupted when generating the
|
||||
content. Usually it's due to user interruption during a bidi streaming.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra='forbid')
|
||||
"""The model config."""
|
||||
|
||||
content: Optional[types.Content] = None
|
||||
"""The content of the response."""
|
||||
|
||||
grounding_metadata: Optional[types.GroundingMetadata] = None
|
||||
"""The grounding metadata of the response."""
|
||||
|
||||
partial: Optional[bool] = None
|
||||
"""Indicates whether the text content is part of a unfinished text stream.
|
||||
|
||||
Only used for streaming mode and when the content is plain text.
|
||||
"""
|
||||
|
||||
turn_complete: Optional[bool] = None
|
||||
"""Indicates whether the response from the model is complete.
|
||||
|
||||
Only used for streaming mode.
|
||||
"""
|
||||
|
||||
error_code: Optional[str] = None
|
||||
"""Error code if the response is an error. Code varies by model."""
|
||||
|
||||
error_message: Optional[str] = None
|
||||
"""Error message if the response is an error."""
|
||||
|
||||
interrupted: Optional[bool] = None
|
||||
"""Flag indicating that LLM was interrupted when generating the content.
|
||||
Usually it's due to user interruption during a bidi streaming.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
generate_content_response: types.GenerateContentResponse,
|
||||
) -> 'LlmResponse':
|
||||
"""Creates an LlmResponse from a GenerateContentResponse.
|
||||
|
||||
Args:
|
||||
generate_content_response: The GenerateContentResponse to create the
|
||||
LlmResponse from.
|
||||
|
||||
Returns:
|
||||
The LlmResponse.
|
||||
"""
|
||||
|
||||
if generate_content_response.candidates:
|
||||
candidate = generate_content_response.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
return LlmResponse(
|
||||
content=candidate.content,
|
||||
grounding_metadata=candidate.grounding_metadata,
|
||||
)
|
||||
else:
|
||||
return LlmResponse(
|
||||
error_code=candidate.finish_reason,
|
||||
error_message=candidate.finish_message,
|
||||
)
|
||||
else:
|
||||
if generate_content_response.prompt_feedback:
|
||||
prompt_feedback = generate_content_response.prompt_feedback
|
||||
return LlmResponse(
|
||||
error_code=prompt_feedback.block_reason,
|
||||
error_message=prompt_feedback.block_reason_message,
|
||||
)
|
||||
else:
|
||||
return LlmResponse(
|
||||
error_code='UNKNOWN_ERROR',
|
||||
error_message='Unknown error.',
|
||||
)
|
||||
102
src/google/adk/models/registry.py
Normal file
102
src/google/adk/models/registry.py
Normal file
@ -0,0 +1,102 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The registry class for model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .base_llm import BaseLlm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_llm_registry_dict: dict[str, type[BaseLlm]] = {}
|
||||
"""Registry for LLMs.
|
||||
|
||||
Key is the regex that matches the model name.
|
||||
Value is the class that implements the model.
|
||||
"""
|
||||
|
||||
|
||||
class LLMRegistry:
|
||||
"""Registry for LLMs."""
|
||||
|
||||
@staticmethod
|
||||
def new_llm(model: str) -> BaseLlm:
|
||||
"""Creates a new LLM instance.
|
||||
|
||||
Args:
|
||||
model: The model name.
|
||||
|
||||
Returns:
|
||||
The LLM instance.
|
||||
"""
|
||||
|
||||
return LLMRegistry.resolve(model)(model=model)
|
||||
|
||||
@staticmethod
|
||||
def _register(model_name_regex: str, llm_cls: type[BaseLlm]):
|
||||
"""Registers a new LLM class.
|
||||
|
||||
Args:
|
||||
model_name_regex: The regex that matches the model name.
|
||||
llm_cls: The class that implements the model.
|
||||
"""
|
||||
|
||||
if model_name_regex in _llm_registry_dict:
|
||||
logger.info(
|
||||
'Updating LLM class for %s from %s to %s',
|
||||
model_name_regex,
|
||||
_llm_registry_dict[model_name_regex],
|
||||
llm_cls,
|
||||
)
|
||||
|
||||
_llm_registry_dict[model_name_regex] = llm_cls
|
||||
|
||||
@staticmethod
|
||||
def register(llm_cls: type[BaseLlm]):
|
||||
"""Registers a new LLM class.
|
||||
|
||||
Args:
|
||||
llm_cls: The class that implements the model.
|
||||
"""
|
||||
|
||||
for regex in llm_cls.supported_models():
|
||||
LLMRegistry._register(regex, llm_cls)
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=32)
|
||||
def resolve(model: str) -> type[BaseLlm]:
|
||||
"""Resolves the model to a BaseLlm subclass.
|
||||
|
||||
Args:
|
||||
model: The model name.
|
||||
|
||||
Returns:
|
||||
The BaseLlm subclass.
|
||||
Raises:
|
||||
ValueError: If the model is not found.
|
||||
"""
|
||||
|
||||
for regex, llm_class in _llm_registry_dict.items():
|
||||
if re.compile(regex).fullmatch(model):
|
||||
return llm_class
|
||||
|
||||
raise ValueError(f'Model {model} not found.')
|
||||
23
src/google/adk/planners/__init__.py
Normal file
23
src/google/adk/planners/__init__.py
Normal file
@ -0,0 +1,23 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_planner import BasePlanner
|
||||
from .built_in_planner import BuiltInPlanner
|
||||
from .plan_re_act_planner import PlanReActPlanner
|
||||
|
||||
__all__ = [
|
||||
'BasePlanner',
|
||||
'BuiltInPlanner',
|
||||
'PlanReActPlanner',
|
||||
]
|
||||
66
src/google/adk/planners/base_planner.py
Normal file
66
src/google/adk/planners/base_planner.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from abc import ABC
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
from ..agents.callback_context import CallbackContext
|
||||
from ..agents.readonly_context import ReadonlyContext
|
||||
from ..models.llm_request import LlmRequest
|
||||
|
||||
|
||||
class BasePlanner(ABC):
|
||||
"""Abstract base class for all planners.
|
||||
|
||||
The planner allows the agent to generate plans for the queries to guide its
|
||||
action.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_planning_instruction(
|
||||
self,
|
||||
readonly_context: ReadonlyContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> Optional[str]:
|
||||
"""Builds the system instruction to be appended to the LLM request for planning.
|
||||
|
||||
Args:
|
||||
readonly_context: The readonly context of the invocation.
|
||||
llm_request: The LLM request. Readonly.
|
||||
|
||||
Returns:
|
||||
The planning system instruction, or None if no instruction is needed.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def process_planning_response(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
response_parts: List[types.Part],
|
||||
) -> Optional[List[types.Part]]:
|
||||
"""Processes the LLM response for planning.
|
||||
|
||||
Args:
|
||||
callback_context: The callback context of the invocation.
|
||||
response_parts: The LLM response parts. Readonly.
|
||||
|
||||
Returns:
|
||||
The processed response parts, or None if no processing is needed.
|
||||
"""
|
||||
pass
|
||||
75
src/google/adk/planners/built_in_planner.py
Normal file
75
src/google/adk/planners/built_in_planner.py
Normal file
@ -0,0 +1,75 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.callback_context import CallbackContext
|
||||
from ..agents.readonly_context import ReadonlyContext
|
||||
from ..models.llm_request import LlmRequest
|
||||
from .base_planner import BasePlanner
|
||||
|
||||
|
||||
class BuiltInPlanner(BasePlanner):
|
||||
"""The built-in planner that uses model's built-in thinking features.
|
||||
|
||||
Attributes:
|
||||
thinking_config: Config for model built-in thinking features. An error
|
||||
will be returned if this field is set for models that don't support
|
||||
thinking.
|
||||
"""
|
||||
|
||||
thinking_config: types.ThinkingConfig
|
||||
"""
|
||||
Config for model built-in thinking features. An error will be returned if this
|
||||
field is set for models that don't support thinking.
|
||||
"""
|
||||
|
||||
def __init__(self, *, thinking_config: types.ThinkingConfig):
|
||||
"""Initializes the built-in planner.
|
||||
|
||||
Args:
|
||||
thinking_config: Config for model built-in thinking features. An error
|
||||
will be returned if this field is set for models that don't support
|
||||
thinking.
|
||||
"""
|
||||
self.thinking_config = thinking_config
|
||||
|
||||
def apply_thinking_config(self, llm_request: LlmRequest) -> None:
|
||||
"""Applies the thinking config to the LLM request.
|
||||
|
||||
Args:
|
||||
llm_request: The LLM request to apply the thinking config to.
|
||||
"""
|
||||
if self.thinking_config:
|
||||
llm_request.config.thinking_config = self.thinking_config
|
||||
|
||||
@override
|
||||
def build_planning_instruction(
|
||||
self,
|
||||
readonly_context: ReadonlyContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> Optional[str]:
|
||||
return
|
||||
|
||||
@override
|
||||
def process_planning_response(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
response_parts: List[types.Part],
|
||||
) -> Optional[List[types.Part]]:
|
||||
return
|
||||
208
src/google/adk/planners/plan_re_act_planner.py
Normal file
208
src/google/adk/planners/plan_re_act_planner.py
Normal file
@ -0,0 +1,208 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.callback_context import CallbackContext
|
||||
from ..agents.readonly_context import ReadonlyContext
|
||||
from ..models.llm_request import LlmRequest
|
||||
from .base_planner import BasePlanner
|
||||
|
||||
PLANNING_TAG = '/*PLANNING*/'
|
||||
REPLANNING_TAG = '/*REPLANNING*/'
|
||||
REASONING_TAG = '/*REASONING*/'
|
||||
ACTION_TAG = '/*ACTION*/'
|
||||
FINAL_ANSWER_TAG = '/*FINAL_ANSWER*/'
|
||||
|
||||
|
||||
class PlanReActPlanner(BasePlanner):
|
||||
"""Plan-Re-Act planner that constraints the LLM response to generate a plan before any action/observation.
|
||||
|
||||
Note: this planner does not require the model to support buil-in thinking
|
||||
features or setting the thinking config.
|
||||
"""
|
||||
|
||||
@override
|
||||
def build_planning_instruction(
|
||||
self,
|
||||
readonly_context: ReadonlyContext,
|
||||
llm_request: LlmRequest,
|
||||
) -> str:
|
||||
return self._build_nl_planner_instruction()
|
||||
|
||||
@override
|
||||
def process_planning_response(
|
||||
self,
|
||||
callback_context: CallbackContext,
|
||||
response_parts: List[types.Part],
|
||||
) -> Optional[List[types.Part]]:
|
||||
if not response_parts:
|
||||
return None
|
||||
|
||||
preserved_parts = []
|
||||
first_fc_part_index = -1
|
||||
for i in range(len(response_parts)):
|
||||
# Stop at the first (group of) function calls.
|
||||
if response_parts[i].function_call:
|
||||
# Ignore and filter out function calls with empty names.
|
||||
if not response_parts[i].function_call.name:
|
||||
continue
|
||||
preserved_parts.append(response_parts[i])
|
||||
first_fc_part_index = i
|
||||
break
|
||||
|
||||
# Split the response into reasoning and final answer parts.
|
||||
self._handle_non_function_call_parts(response_parts[i], preserved_parts)
|
||||
|
||||
if first_fc_part_index > 0:
|
||||
j = first_fc_part_index + 1
|
||||
while j < len(response_parts):
|
||||
if response_parts[j].function_call:
|
||||
preserved_parts.append(response_parts[j])
|
||||
j += 1
|
||||
else:
|
||||
break
|
||||
|
||||
return preserved_parts
|
||||
|
||||
def _split_by_last_pattern(self, text, separator):
|
||||
"""Splits the text by the last occurrence of the separator.
|
||||
|
||||
Args:
|
||||
text: The text to split.
|
||||
separator: The separator to split on.
|
||||
|
||||
Returns:
|
||||
A tuple containing the text before the last separator and the text after
|
||||
the last separator.
|
||||
"""
|
||||
index = text.rfind(separator)
|
||||
if index == -1:
|
||||
return text, ''
|
||||
return text[: index + len(separator)], text[index + len(separator) :]
|
||||
|
||||
def _handle_non_function_call_parts(
|
||||
self, response_part: types.Part, preserved_parts: list[types.Part]
|
||||
):
|
||||
"""Handles non-function-call parts of the response.
|
||||
|
||||
Args:
|
||||
response_part: The response part to handle.
|
||||
preserved_parts: The mutable list of parts to store the processed parts
|
||||
in.
|
||||
"""
|
||||
if response_part.text and FINAL_ANSWER_TAG in response_part.text:
|
||||
reasoning_text, final_answer_text = self._split_by_last_pattern(
|
||||
response_part.text, FINAL_ANSWER_TAG
|
||||
)
|
||||
if reasoning_text:
|
||||
reasoning_part = types.Part(text=reasoning_text)
|
||||
self._mark_as_thought(reasoning_part)
|
||||
preserved_parts.append(reasoning_part)
|
||||
if final_answer_text:
|
||||
preserved_parts.append(
|
||||
types.Part(
|
||||
text=final_answer_text,
|
||||
)
|
||||
)
|
||||
else:
|
||||
response_text = response_part.text or ''
|
||||
# If the part is a text part with a planning/reasoning/action tag,
|
||||
# label it as reasoning.
|
||||
if response_text and (
|
||||
any(
|
||||
response_text.startswith(tag)
|
||||
for tag in [
|
||||
PLANNING_TAG,
|
||||
REASONING_TAG,
|
||||
ACTION_TAG,
|
||||
REPLANNING_TAG,
|
||||
]
|
||||
)
|
||||
):
|
||||
self._mark_as_thought(response_part)
|
||||
preserved_parts.append(response_part)
|
||||
|
||||
def _mark_as_thought(self, response_part: types.Part):
|
||||
"""Marks the response part as thought.
|
||||
|
||||
Args:
|
||||
response_part: The mutable response part to mark as thought.
|
||||
"""
|
||||
if response_part.text:
|
||||
response_part.thought = True
|
||||
return
|
||||
|
||||
def _build_nl_planner_instruction(self) -> str:
|
||||
"""Builds the NL planner instruction for the Plan-Re-Act planner.
|
||||
|
||||
Returns:
|
||||
NL planner system instruction.
|
||||
"""
|
||||
|
||||
high_level_preamble = f"""
|
||||
When answering the question, try to leverage the available tools to gather the information instead of your memorized knowledge.
|
||||
|
||||
Follow this process when answering the question: (1) first come up with a plan in natural language text format; (2) Then use tools to execute the plan and provide reasoning between tool code snippets to make a summary of current state and next step. Tool code snippets and reasoning should be interleaved with each other. (3) In the end, return one final answer.
|
||||
|
||||
Follow this format when answering the question: (1) The planning part should be under {PLANNING_TAG}. (2) The tool code snippets should be under {ACTION_TAG}, and the reasoning parts should be under {REASONING_TAG}. (3) The final answer part should be under {FINAL_ANSWER_TAG}.
|
||||
"""
|
||||
|
||||
planning_preamble = f"""
|
||||
Below are the requirements for the planning:
|
||||
The plan is made to answer the user query if following the plan. The plan is coherent and covers all aspects of information from user query, and only involves the tools that are accessible by the agent. The plan contains the decomposed steps as a numbered list where each step should use one or multiple available tools. By reading the plan, you can intuitively know which tools to trigger or what actions to take.
|
||||
If the initial plan cannot be successfully executed, you should learn from previous execution results and revise your plan. The revised plan should be be under {REPLANNING_TAG}. Then use tools to follow the new plan.
|
||||
"""
|
||||
|
||||
reasoning_preamble = """
|
||||
Below are the requirements for the reasoning:
|
||||
The reasoning makes a summary of the current trajectory based on the user query and tool outputs. Based on the tool outputs and plan, the reasoning also comes up with instructions to the next steps, making the trajectory closer to the final answer.
|
||||
"""
|
||||
|
||||
final_answer_preamble = """
|
||||
Below are the requirements for the final answer:
|
||||
The final answer should be precise and follow query formatting requirements. Some queries may not be answerable with the available tools and information. In those cases, inform the user why you cannot process their query and ask for more information.
|
||||
"""
|
||||
|
||||
# Only contains the requirements for custom tool/libraries.
|
||||
tool_code_without_python_libraries_preamble = """
|
||||
Below are the requirements for the tool code:
|
||||
|
||||
**Custom Tools:** The available tools are described in the context and can be directly used.
|
||||
- Code must be valid self-contained Python snippets with no imports and no references to tools or Python libraries that are not in the context.
|
||||
- You cannot use any parameters or fields that are not explicitly defined in the APIs in the context.
|
||||
- The code snippets should be readable, efficient, and directly relevant to the user query and reasoning steps.
|
||||
- When using the tools, you should use the library name together with the function name, e.g., vertex_search.search().
|
||||
- If Python libraries are not provided in the context, NEVER write your own code other than the function calls using the provided tools.
|
||||
"""
|
||||
|
||||
user_input_preamble = """
|
||||
VERY IMPORTANT instruction that you MUST follow in addition to the above instructions:
|
||||
|
||||
You should ask for clarification if you need more information to answer the question.
|
||||
You should prefer using the information available in the context instead of repeated tool use.
|
||||
"""
|
||||
|
||||
return '\n\n'.join([
|
||||
high_level_preamble,
|
||||
planning_preamble,
|
||||
reasoning_preamble,
|
||||
final_answer_preamble,
|
||||
tool_code_without_python_libraries_preamble,
|
||||
user_input_preamble,
|
||||
])
|
||||
456
src/google/adk/runners.py
Normal file
456
src/google/adk/runners.py
Normal file
@ -0,0 +1,456 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
from typing import Optional
|
||||
|
||||
from deprecated import deprecated
|
||||
from google.genai import types
|
||||
|
||||
from .agents.active_streaming_tool import ActiveStreamingTool
|
||||
from .agents.base_agent import BaseAgent
|
||||
from .agents.invocation_context import InvocationContext
|
||||
from .agents.invocation_context import new_invocation_context_id
|
||||
from .agents.live_request_queue import LiveRequestQueue
|
||||
from .agents.llm_agent import LlmAgent
|
||||
from .agents.run_config import RunConfig
|
||||
from .agents.run_config import StreamingMode
|
||||
from .artifacts.base_artifact_service import BaseArtifactService
|
||||
from .artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from .events.event import Event
|
||||
from .memory.base_memory_service import BaseMemoryService
|
||||
from .memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from .sessions.base_session_service import BaseSessionService
|
||||
from .sessions.in_memory_session_service import InMemorySessionService
|
||||
from .sessions.session import Session
|
||||
from .telemetry import tracer
|
||||
from .tools.built_in_code_execution_tool import built_in_code_execution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Runner:
|
||||
"""The Runner class is used to run agents.
|
||||
|
||||
It manages the execution of an agent within a session, handling message
|
||||
processing, event generation, and interaction with various services like
|
||||
artifact storage, session management, and memory.
|
||||
|
||||
Attributes:
|
||||
app_name: The application name of the runner.
|
||||
agent: The root agent to run.
|
||||
artifact_service: The artifact service for the runner.
|
||||
session_service: The session service for the runner.
|
||||
memory_service: The memory service for the runner.
|
||||
"""
|
||||
|
||||
app_name: str
|
||||
"""The app name of the runner."""
|
||||
agent: BaseAgent
|
||||
"""The root agent to run."""
|
||||
artifact_service: Optional[BaseArtifactService] = None
|
||||
"""The artifact service for the runner."""
|
||||
session_service: BaseSessionService
|
||||
"""The session service for the runner."""
|
||||
memory_service: Optional[BaseMemoryService] = None
|
||||
"""The memory service for the runner."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
agent: BaseAgent,
|
||||
artifact_service: Optional[BaseArtifactService] = None,
|
||||
session_service: BaseSessionService,
|
||||
memory_service: Optional[BaseMemoryService] = None,
|
||||
):
|
||||
"""Initializes the Runner.
|
||||
|
||||
Args:
|
||||
app_name: The application name of the runner.
|
||||
agent: The root agent to run.
|
||||
artifact_service: The artifact service for the runner.
|
||||
session_service: The session service for the runner.
|
||||
memory_service: The memory service for the runner.
|
||||
"""
|
||||
self.app_name = app_name
|
||||
self.agent = agent
|
||||
self.artifact_service = artifact_service
|
||||
self.session_service = session_service
|
||||
self.memory_service = memory_service
|
||||
|
||||
def run(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
new_message: types.Content,
|
||||
run_config: RunConfig = RunConfig(),
|
||||
) -> Generator[Event, None, None]:
|
||||
"""Runs the agent.
|
||||
|
||||
NOTE: This sync interface is only for local testing and convenience purpose.
|
||||
Consider to use `run_async` for production usage.
|
||||
|
||||
Args:
|
||||
user_id: The user ID of the session.
|
||||
session_id: The session ID of the session.
|
||||
new_message: A new message to append to the session.
|
||||
run_config: The run config for the agent.
|
||||
|
||||
Yields:
|
||||
The events generated by the agent.
|
||||
"""
|
||||
event_queue = queue.Queue()
|
||||
|
||||
async def _invoke_run_async():
|
||||
try:
|
||||
async for event in self.run_async(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
new_message=new_message,
|
||||
run_config=run_config,
|
||||
):
|
||||
event_queue.put(event)
|
||||
finally:
|
||||
event_queue.put(None)
|
||||
|
||||
def _asyncio_thread_main():
|
||||
try:
|
||||
asyncio.run(_invoke_run_async())
|
||||
finally:
|
||||
event_queue.put(None)
|
||||
|
||||
thread = threading.Thread(target=_asyncio_thread_main)
|
||||
thread.start()
|
||||
|
||||
# consumes and re-yield the events from background thread.
|
||||
while True:
|
||||
event = event_queue.get()
|
||||
if event is None:
|
||||
break
|
||||
else:
|
||||
yield event
|
||||
|
||||
thread.join()
|
||||
|
||||
async def run_async(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
new_message: types.Content,
|
||||
run_config: RunConfig = RunConfig(),
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Main entry method to run the agent in this runner.
|
||||
|
||||
Args:
|
||||
user_id: The user ID of the session.
|
||||
session_id: The session ID of the session.
|
||||
new_message: A new message to append to the session.
|
||||
run_config: The run config for the agent.
|
||||
|
||||
Yields:
|
||||
The events generated by the agent.
|
||||
"""
|
||||
with tracer.start_as_current_span('invocation'):
|
||||
session = self.session_service.get_session(
|
||||
app_name=self.app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
raise ValueError(f'Session not found: {session_id}')
|
||||
|
||||
invocation_context = self._new_invocation_context(
|
||||
session,
|
||||
new_message=new_message,
|
||||
run_config=run_config,
|
||||
)
|
||||
root_agent = self.agent
|
||||
|
||||
if new_message:
|
||||
self._append_new_message_to_session(
|
||||
session,
|
||||
new_message,
|
||||
invocation_context,
|
||||
run_config.save_input_blobs_as_artifacts,
|
||||
)
|
||||
|
||||
invocation_context.agent = self._find_agent_to_run(session, root_agent)
|
||||
async for event in invocation_context.agent.run_async(invocation_context):
|
||||
if not event.partial:
|
||||
self.session_service.append_event(session=session, event=event)
|
||||
yield event
|
||||
|
||||
def _append_new_message_to_session(
|
||||
self,
|
||||
session: Session,
|
||||
new_message: types.Content,
|
||||
invocation_context: InvocationContext,
|
||||
save_input_blobs_as_artifacts: bool = False,
|
||||
):
|
||||
"""Appends a new message to the session.
|
||||
|
||||
Args:
|
||||
session: The session to append the message to.
|
||||
new_message: The new message to append.
|
||||
invocation_context: The invocation context for the message.
|
||||
save_input_blobs_as_artifacts: Whether to save input blobs as artifacts.
|
||||
"""
|
||||
if not new_message.parts:
|
||||
raise ValueError('No parts in the new_message.')
|
||||
|
||||
if self.artifact_service and save_input_blobs_as_artifacts:
|
||||
# The runner directly saves the artifacts (if applicable) in the
|
||||
# user message and replaces the artifact data with a file name
|
||||
# placeholder.
|
||||
for i, part in enumerate(new_message.parts):
|
||||
if part.inline_data is None:
|
||||
continue
|
||||
file_name = f'artifact_{invocation_context.invocation_id}_{i}'
|
||||
self.artifact_service.save_artifact(
|
||||
app_name=self.app_name,
|
||||
user_id=session.user_id,
|
||||
session_id=session.id,
|
||||
filename=file_name,
|
||||
artifact=part,
|
||||
)
|
||||
new_message.parts[i] = types.Part(
|
||||
text=f'Uploaded file: {file_name}. It is saved into artifacts'
|
||||
)
|
||||
# Appends only. We do not yield the event because it's not from the model.
|
||||
event = Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author='user',
|
||||
content=new_message,
|
||||
)
|
||||
self.session_service.append_event(session=session, event=event)
|
||||
|
||||
async def run_live(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
live_request_queue: LiveRequestQueue,
|
||||
run_config: RunConfig = RunConfig(),
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Runs the agent in live mode (experimental feature).
|
||||
|
||||
Args:
|
||||
session: The session to use.
|
||||
live_request_queue: The queue for live requests.
|
||||
run_config: The run config for the agent.
|
||||
|
||||
Yields:
|
||||
The events generated by the agent.
|
||||
"""
|
||||
# TODO: right now, only works for a single audio agent without FC.
|
||||
invocation_context = self._new_invocation_context_for_live(
|
||||
session,
|
||||
live_request_queue=live_request_queue,
|
||||
run_config=run_config,
|
||||
)
|
||||
|
||||
root_agent = self.agent
|
||||
invocation_context.agent = self._find_agent_to_run(session, root_agent)
|
||||
|
||||
invocation_context.active_streaming_tools = {}
|
||||
# TODO(hangfei): switch to use canonical_tools.
|
||||
for tool in invocation_context.agent.tools:
|
||||
# replicate a LiveRequestQueue for streaming tools that relis on
|
||||
# LiveRequestQueue
|
||||
from typing import get_type_hints
|
||||
|
||||
type_hints = get_type_hints(tool)
|
||||
for arg_type in type_hints.values():
|
||||
if arg_type is LiveRequestQueue:
|
||||
if not invocation_context.active_streaming_tools:
|
||||
invocation_context.active_streaming_tools = {}
|
||||
active_streaming_tools = ActiveStreamingTool(
|
||||
stream=LiveRequestQueue()
|
||||
)
|
||||
invocation_context.active_streaming_tools[tool.__name__] = (
|
||||
active_streaming_tools
|
||||
)
|
||||
|
||||
async for event in invocation_context.agent.run_live(invocation_context):
|
||||
self.session_service.append_event(session=session, event=event)
|
||||
yield event
|
||||
|
||||
def close_session(self, session: Session):
|
||||
"""Closes a session and adds it to the memory service (experimental feature).
|
||||
|
||||
Args:
|
||||
session: The session to close.
|
||||
"""
|
||||
if self.memory_service:
|
||||
self.memory_service.add_session_to_memory(session)
|
||||
self.session_service.close_session(session=session)
|
||||
|
||||
def _find_agent_to_run(
|
||||
self, session: Session, root_agent: BaseAgent
|
||||
) -> BaseAgent:
|
||||
"""Finds the agent to run to continue the session.
|
||||
|
||||
A qualified agent must be either of:
|
||||
- The root agent;
|
||||
- An LlmAgent who replied last and is capable to transfer to any other agent
|
||||
in the agent hierarchy.
|
||||
|
||||
Args:
|
||||
session: The session to find the agent for.
|
||||
root_agent: The root agent of the runner.
|
||||
|
||||
Returns:
|
||||
The agent of the last message in the session or the root agent.
|
||||
"""
|
||||
for event in filter(lambda e: e.author != 'user', reversed(session.events)):
|
||||
if event.author == root_agent.name:
|
||||
# Found root agent.
|
||||
return root_agent
|
||||
if not (agent := root_agent.find_sub_agent(event.author)):
|
||||
# Agent not found, continue looking.
|
||||
logger.warning(
|
||||
'Event from an unknown agent: %s, event id: %s',
|
||||
event.author,
|
||||
event.id,
|
||||
)
|
||||
continue
|
||||
if self._is_transferable_across_agent_tree(agent):
|
||||
return agent
|
||||
# Falls back to root agent if no suitable agents are found in the session.
|
||||
return root_agent
|
||||
|
||||
def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool:
|
||||
"""Whether the agent to run can transfer to any other agent in the agent tree.
|
||||
|
||||
This typically means all agent_to_run's parent through root agent can
|
||||
transfer to their parent_agent.
|
||||
|
||||
Args:
|
||||
agent_to_run: The agent to check for transferability.
|
||||
|
||||
Returns:
|
||||
True if the agent can transfer, False otherwise.
|
||||
"""
|
||||
agent = agent_to_run
|
||||
while agent:
|
||||
if not isinstance(agent, LlmAgent):
|
||||
# Only LLM-based Agent can provider agent transfer capability.
|
||||
return False
|
||||
if agent.disallow_transfer_to_parent:
|
||||
return False
|
||||
agent = agent.parent_agent
|
||||
return True
|
||||
|
||||
def _new_invocation_context(
|
||||
self,
|
||||
session: Session,
|
||||
*,
|
||||
new_message: Optional[types.Content] = None,
|
||||
live_request_queue: Optional[LiveRequestQueue] = None,
|
||||
run_config: RunConfig = RunConfig(),
|
||||
) -> InvocationContext:
|
||||
"""Creates a new invocation context.
|
||||
|
||||
Args:
|
||||
session: The session for the context.
|
||||
new_message: The new message for the context.
|
||||
live_request_queue: The live request queue for the context.
|
||||
run_config: The run config for the context.
|
||||
|
||||
Returns:
|
||||
The new invocation context.
|
||||
"""
|
||||
invocation_id = new_invocation_context_id()
|
||||
|
||||
if run_config.support_cfc and isinstance(self.agent, LlmAgent):
|
||||
model_name = self.agent.canonical_model.model
|
||||
if not model_name.startswith('gemini-2'):
|
||||
raise ValueError(
|
||||
f'CFC is not supported for model: {model_name} in agent:'
|
||||
f' {self.agent.name}'
|
||||
)
|
||||
if built_in_code_execution not in self.agent.canonical_tools:
|
||||
self.agent.tools.append(built_in_code_execution)
|
||||
|
||||
return InvocationContext(
|
||||
artifact_service=self.artifact_service,
|
||||
session_service=self.session_service,
|
||||
memory_service=self.memory_service,
|
||||
invocation_id=invocation_id,
|
||||
agent=self.agent,
|
||||
session=session,
|
||||
user_content=new_message,
|
||||
live_request_queue=live_request_queue,
|
||||
run_config=run_config,
|
||||
)
|
||||
|
||||
def _new_invocation_context_for_live(
|
||||
self,
|
||||
session: Session,
|
||||
*,
|
||||
live_request_queue: Optional[LiveRequestQueue] = None,
|
||||
run_config: RunConfig = RunConfig(),
|
||||
) -> InvocationContext:
|
||||
"""Creates a new invocation context for live multi-agent."""
|
||||
|
||||
# For live multi-agent, we need model's text transcription as context for
|
||||
# next agent.
|
||||
if self.agent.sub_agents and live_request_queue:
|
||||
if not run_config.response_modalities:
|
||||
# default
|
||||
run_config.response_modalities = ['AUDIO', 'TEXT']
|
||||
elif 'TEXT' not in run_config.response_modalities:
|
||||
run_config.response_modalities.append('TEXT')
|
||||
return self._new_invocation_context(
|
||||
session,
|
||||
live_request_queue=live_request_queue,
|
||||
run_config=run_config,
|
||||
)
|
||||
|
||||
|
||||
class InMemoryRunner(Runner):
|
||||
"""An in-memory Runner for testing and development.
|
||||
|
||||
This runner uses in-memory implementations for artifact, session, and memory
|
||||
services, providing a lightweight and self-contained environment for agent
|
||||
execution.
|
||||
|
||||
Attributes:
|
||||
agent: The root agent to run.
|
||||
app_name: The application name of the runner. Defaults to
|
||||
'InMemoryRunner'.
|
||||
"""
|
||||
|
||||
def __init__(self, agent: LlmAgent, *, app_name: str = 'InMemoryRunner'):
|
||||
"""Initializes the InMemoryRunner.
|
||||
|
||||
Args:
|
||||
agent: The root agent to run.
|
||||
app_name: The application name of the runner. Defaults to
|
||||
'InMemoryRunner'.
|
||||
"""
|
||||
super().__init__(
|
||||
app_name=app_name,
|
||||
agent=agent,
|
||||
artifact_service=InMemoryArtifactService(),
|
||||
session_service=InMemorySessionService(),
|
||||
memory_service=InMemoryMemoryService(),
|
||||
)
|
||||
41
src/google/adk/sessions/__init__.py
Normal file
41
src/google/adk/sessions/__init__.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from .base_session_service import BaseSessionService
|
||||
from .in_memory_session_service import InMemorySessionService
|
||||
from .session import Session
|
||||
from .state import State
|
||||
from .vertex_ai_session_service import VertexAiSessionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'BaseSessionService',
|
||||
'InMemorySessionService',
|
||||
'Session',
|
||||
'State',
|
||||
'VertexAiSessionService',
|
||||
]
|
||||
|
||||
try:
|
||||
from .database_session_service import DatabaseSessionService
|
||||
|
||||
__all__.append('DatabaseSessionService')
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
'DatabaseSessionService require sqlalchemy>=2.0, please ensure it is'
|
||||
' installed correctly.'
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user