structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,61 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Google Media Downloads and Resumable Uploads.
This package has some general purposes modules, e.g.
:mod:`~google.resumable_media.common`, but the majority of the
public interface will be contained in subpackages.
===========
Subpackages
===========
Each subpackage is tailored to a specific transport library:
* the :mod:`~google.resumable_media.requests` subpackage uses the ``requests``
transport library.
.. _requests: http://docs.python-requests.org/
==========
Installing
==========
To install with `pip`_:
.. code-block:: console
$ pip install --upgrade google-resumable-media
.. _pip: https://pip.pypa.io/
"""
from google.resumable_media.common import DataCorruption
from google.resumable_media.common import InvalidResponse
from google.resumable_media.common import PERMANENT_REDIRECT
from google.resumable_media.common import RetryStrategy
from google.resumable_media.common import TOO_MANY_REQUESTS
from google.resumable_media.common import UPLOAD_CHUNK_SIZE
__all__ = [
"DataCorruption",
"InvalidResponse",
"PERMANENT_REDIRECT",
"RetryStrategy",
"TOO_MANY_REQUESTS",
"UPLOAD_CHUNK_SIZE",
]

View File

@@ -0,0 +1,550 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Virtual bases classes for downloading media from Google APIs."""
import http.client
import re
from google._async_resumable_media import _helpers
from google.resumable_media import common
_CONTENT_RANGE_RE = re.compile(
r"bytes (?P<start_byte>\d+)-(?P<end_byte>\d+)/(?P<total_bytes>\d+)",
flags=re.IGNORECASE,
)
_ACCEPTABLE_STATUS_CODES = (http.client.OK, http.client.PARTIAL_CONTENT)
_GET = "GET"
_ZERO_CONTENT_RANGE_HEADER = "bytes */0"
class DownloadBase(object):
"""Base class for download helpers.
Defines core shared behavior across different download types.
Args:
media_url (str): The URL containing the media to be downloaded.
stream (IO[bytes]): A write-able stream (i.e. file-like object) that
the downloaded resource can be written to.
start (int): The first byte in a range to be downloaded.
end (int): The last byte in a range to be downloaded.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with the request, e.g. headers for encrypted data.
Attributes:
media_url (str): The URL containing the media to be downloaded.
start (Optional[int]): The first byte in a range to be downloaded.
end (Optional[int]): The last byte in a range to be downloaded.
"""
def __init__(self, media_url, stream=None, start=None, end=None, headers=None):
self.media_url = media_url
self._stream = stream
self.start = start
self.end = end
if headers is None:
headers = {}
self._headers = headers
self._finished = False
self._retry_strategy = common.RetryStrategy()
@property
def finished(self):
"""bool: Flag indicating if the download has completed."""
return self._finished
@staticmethod
def _get_status_code(response):
"""Access the status code from an HTTP response.
Args:
response (object): The HTTP response object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
@staticmethod
def _get_headers(response):
"""Access the headers from an HTTP response.
Args:
response (object): The HTTP response object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
@staticmethod
def _get_body(response):
"""Access the response body from an HTTP response.
Args:
response (object): The HTTP response object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
class Download(DownloadBase):
"""Helper to manage downloading a resource from a Google API.
"Slices" of the resource can be retrieved by specifying a range
with ``start`` and / or ``end``. However, in typical usage, neither
``start`` nor ``end`` is expected to be provided.
Args:
media_url (str): The URL containing the media to be downloaded.
stream (IO[bytes]): A write-able stream (i.e. file-like object) that
the downloaded resource can be written to.
start (int): The first byte in a range to be downloaded. If not
provided, but ``end`` is provided, will download from the
beginning to ``end`` of the media.
end (int): The last byte in a range to be downloaded. If not
provided, but ``start`` is provided, will download from the
``start`` to the end of the media.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with the request, e.g. headers for encrypted data.
checksum Optional([str]): The type of checksum to compute to verify
the integrity of the object. The response headers must contain
a checksum of the requested type. If the headers lack an
appropriate checksum (for instance in the case of transcoded or
ranged downloads where the remote service does not know the
correct checksum) an INFO-level log will be emitted. Supported
values are "md5", "crc32c" and None.
"""
def __init__(
self, media_url, stream=None, start=None, end=None, headers=None, checksum="md5"
):
super(Download, self).__init__(
media_url, stream=stream, start=start, end=end, headers=headers
)
self.checksum = checksum
def _prepare_request(self):
"""Prepare the contents of an HTTP request.
This is everything that must be done before a request that doesn't
require network I/O (or other I/O). This is based on the `sans-I/O`_
philosophy.
Returns:
Tuple[str, str, NoneType, Mapping[str, str]]: The quadruple
* HTTP verb for the request (always GET)
* the URL for the request
* the body of the request (always :data:`None`)
* headers for the request
Raises:
ValueError: If the current :class:`Download` has already
finished.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
if self.finished:
raise ValueError("A download can only be used once.")
add_bytes_range(self.start, self.end, self._headers)
return _GET, self.media_url, None, self._headers
def _process_response(self, response):
"""Process the response from an HTTP request.
This is everything that must be done after a request that doesn't
require network I/O (or other I/O). This is based on the `sans-I/O`_
philosophy.
Args:
response (object): The HTTP response object.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
# Tombstone the current Download so it cannot be used again.
self._finished = True
_helpers.require_status_code(
response, _ACCEPTABLE_STATUS_CODES, self._get_status_code
)
def consume(self, transport, timeout=None):
"""Consume the resource to be downloaded.
If a ``stream`` is attached to this download, then the downloaded
resource will be written to the stream.
Args:
transport (object): An object which can make authenticated
requests.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
class ChunkedDownload(DownloadBase):
"""Download a resource in chunks from a Google API.
Args:
media_url (str): The URL containing the media to be downloaded.
chunk_size (int): The number of bytes to be retrieved in each
request.
stream (IO[bytes]): A write-able stream (i.e. file-like object) that
will be used to concatenate chunks of the resource as they are
downloaded.
start (int): The first byte in a range to be downloaded. If not
provided, defaults to ``0``.
end (int): The last byte in a range to be downloaded. If not
provided, will download to the end of the media.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with each request, e.g. headers for data encryption
key headers.
Attributes:
media_url (str): The URL containing the media to be downloaded.
start (Optional[int]): The first byte in a range to be downloaded.
end (Optional[int]): The last byte in a range to be downloaded.
chunk_size (int): The number of bytes to be retrieved in each request.
Raises:
ValueError: If ``start`` is negative.
"""
def __init__(self, media_url, chunk_size, stream, start=0, end=None, headers=None):
if start < 0:
raise ValueError(
"On a chunked download the starting " "value cannot be negative."
)
super(ChunkedDownload, self).__init__(
media_url, stream=stream, start=start, end=end, headers=headers
)
self.chunk_size = chunk_size
self._bytes_downloaded = 0
self._total_bytes = None
self._invalid = False
@property
def bytes_downloaded(self):
"""int: Number of bytes that have been downloaded."""
return self._bytes_downloaded
@property
def total_bytes(self):
"""Optional[int]: The total number of bytes to be downloaded."""
return self._total_bytes
@property
def invalid(self):
"""bool: Indicates if the download is in an invalid state.
This will occur if a call to :meth:`consume_next_chunk` fails.
"""
return self._invalid
def _get_byte_range(self):
"""Determines the byte range for the next request.
Returns:
Tuple[int, int]: The pair of begin and end byte for the next
chunked request.
"""
curr_start = self.start + self.bytes_downloaded
curr_end = curr_start + self.chunk_size - 1
# Make sure ``curr_end`` does not exceed ``end``.
if self.end is not None:
curr_end = min(curr_end, self.end)
# Make sure ``curr_end`` does not exceed ``total_bytes - 1``.
if self.total_bytes is not None:
curr_end = min(curr_end, self.total_bytes - 1)
return curr_start, curr_end
def _prepare_request(self):
"""Prepare the contents of an HTTP request.
This is everything that must be done before a request that doesn't
require network I/O (or other I/O). This is based on the `sans-I/O`_
philosophy.
.. note:
This method will be used multiple times, so ``headers`` will
be mutated in between requests. However, we don't make a copy
since the same keys are being updated.
Returns:
Tuple[str, str, NoneType, Mapping[str, str]]: The quadruple
* HTTP verb for the request (always GET)
* the URL for the request
* the body of the request (always :data:`None`)
* headers for the request
Raises:
ValueError: If the current download has finished.
ValueError: If the current download is invalid.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
if self.finished:
raise ValueError("Download has finished.")
if self.invalid:
raise ValueError("Download is invalid and cannot be re-used.")
curr_start, curr_end = self._get_byte_range()
add_bytes_range(curr_start, curr_end, self._headers)
return _GET, self.media_url, None, self._headers
def _make_invalid(self):
"""Simple setter for ``invalid``.
This is intended to be passed along as a callback to helpers that
raise an exception so they can mark this instance as invalid before
raising.
"""
self._invalid = True
async def _process_response(self, response):
"""Process the response from an HTTP request.
This is everything that must be done after a request that doesn't
require network I/O. This is based on the `sans-I/O`_ philosophy.
For the time being, this **does require** some form of I/O to write
a chunk to ``stream``. However, this will (almost) certainly not be
network I/O.
Updates the current state after consuming a chunk. First,
increments ``bytes_downloaded`` by the number of bytes in the
``content-length`` header.
If ``total_bytes`` is already set, this assumes (but does not check)
that we already have the correct value and doesn't bother to check
that it agrees with the headers.
We expect the **total** length to be in the ``content-range`` header,
but this header is only present on requests which sent the ``range``
header. This response header should be of the form
``bytes {start}-{end}/{total}`` and ``{end} - {start} + 1``
should be the same as the ``Content-Length``.
Args:
response (object): The HTTP response object (need headers).
Raises:
~google.resumable_media.common.InvalidResponse: If the number
of bytes in the body doesn't match the content length header.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
# Verify the response before updating the current instance.
if _check_for_zero_content_range(
response, self._get_status_code, self._get_headers
):
self._finished = True
return
_helpers.require_status_code(
response,
_ACCEPTABLE_STATUS_CODES,
self._get_status_code,
callback=self._make_invalid,
)
headers = self._get_headers(response)
response_body = await self._get_body(response)
start_byte, end_byte, total_bytes = get_range_info(
response, self._get_headers, callback=self._make_invalid
)
transfer_encoding = headers.get("transfer-encoding")
if transfer_encoding is None:
content_length = _helpers.header_required(
response,
"content-length",
self._get_headers,
callback=self._make_invalid,
)
num_bytes = int(content_length)
if len(response_body) != num_bytes:
self._make_invalid()
raise common.InvalidResponse(
response,
"Response is different size than content-length",
"Expected",
num_bytes,
"Received",
len(response_body),
)
else:
# 'content-length' header not allowed with chunked encoding.
num_bytes = end_byte - start_byte + 1
# First update ``bytes_downloaded``.
self._bytes_downloaded += num_bytes
# If the end byte is past ``end`` or ``total_bytes - 1`` we are done.
if self.end is not None and end_byte >= self.end:
self._finished = True
elif end_byte >= total_bytes - 1:
self._finished = True
# NOTE: We only use ``total_bytes`` if not already known.
if self.total_bytes is None:
self._total_bytes = total_bytes
# Write the response body to the stream.
self._stream.write(response_body)
def consume_next_chunk(self, transport, timeout=None):
"""Consume the next chunk of the resource to be downloaded.
Args:
transport (object): An object which can make authenticated
requests.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
def add_bytes_range(start, end, headers):
"""Add a bytes range to a header dictionary.
Some possible inputs and the corresponding bytes ranges::
>>> headers = {}
>>> add_bytes_range(None, None, headers)
>>> headers
{}
>>> add_bytes_range(500, 999, headers)
>>> headers['range']
'bytes=500-999'
>>> add_bytes_range(None, 499, headers)
>>> headers['range']
'bytes=0-499'
>>> add_bytes_range(-500, None, headers)
>>> headers['range']
'bytes=-500'
>>> add_bytes_range(9500, None, headers)
>>> headers['range']
'bytes=9500-'
Args:
start (Optional[int]): The first byte in a range. Can be zero,
positive, negative or :data:`None`.
end (Optional[int]): The last byte in a range. Assumed to be
positive.
headers (Mapping[str, str]): A headers mapping which can have the
bytes range added if at least one of ``start`` or ``end``
is not :data:`None`.
"""
if start is None:
if end is None:
# No range to add.
return
else:
# NOTE: This assumes ``end`` is non-negative.
bytes_range = "0-{:d}".format(end)
else:
if end is None:
if start < 0:
bytes_range = "{:d}".format(start)
else:
bytes_range = "{:d}-".format(start)
else:
# NOTE: This is invalid if ``start < 0``.
bytes_range = "{:d}-{:d}".format(start, end)
headers[_helpers.RANGE_HEADER] = "bytes=" + bytes_range
def get_range_info(response, get_headers, callback=_helpers.do_nothing):
"""Get the start, end and total bytes from a content range header.
Args:
response (object): An HTTP response object.
get_headers (Callable[Any, Mapping[str, str]]): Helper to get headers
from an HTTP response.
callback (Optional[Callable]): A callback that takes no arguments,
to be executed when an exception is being raised.
Returns:
Tuple[int, int, int]: The start byte, end byte and total bytes.
Raises:
~google.resumable_media.common.InvalidResponse: If the
``Content-Range`` header is not of the form
``bytes {start}-{end}/{total}``.
"""
content_range = _helpers.header_required(
response, _helpers.CONTENT_RANGE_HEADER, get_headers, callback=callback
)
match = _CONTENT_RANGE_RE.match(content_range)
if match is None:
callback()
raise common.InvalidResponse(
response,
"Unexpected content-range header",
content_range,
'Expected to be of the form "bytes {start}-{end}/{total}"',
)
return (
int(match.group("start_byte")),
int(match.group("end_byte")),
int(match.group("total_bytes")),
)
def _check_for_zero_content_range(response, get_status_code, get_headers):
"""Validate if response status code is 416 and content range is zero.
This is the special case for handling zero bytes files.
Args:
response (object): An HTTP response object.
get_status_code (Callable[Any, int]): Helper to get a status code
from a response.
get_headers (Callable[Any, Mapping[str, str]]): Helper to get headers
from an HTTP response.
Returns:
bool: True if content range total bytes is zero, false otherwise.
"""
if get_status_code(response) == http.client.REQUESTED_RANGE_NOT_SATISFIABLE:
content_range = _helpers.header_required(
response,
_helpers.CONTENT_RANGE_HEADER,
get_headers,
callback=_helpers.do_nothing,
)
if content_range == _ZERO_CONTENT_RANGE_HEADER:
return True
return False

View File

@@ -0,0 +1,197 @@
# Copyright 2020 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities used by both downloads and uploads."""
import logging
import random
import time
from google.resumable_media import common
RANGE_HEADER = "range"
CONTENT_RANGE_HEADER = "content-range"
_SLOW_CRC32C_WARNING = (
"Currently using crcmod in pure python form. This is a slow "
"implementation. Python 3 has a faster implementation, `google-crc32c`, "
"which will be used if it is installed."
)
_HASH_HEADER = "x-goog-hash"
_MISSING_CHECKSUM = """\
No {checksum_type} checksum was returned from the service while downloading {}
(which happens for composite objects), so client-side content integrity
checking is not being performed."""
_LOGGER = logging.getLogger(__name__)
def do_nothing():
"""Simple default callback."""
def header_required(response, name, get_headers, callback=do_nothing):
"""Checks that a specific header is in a headers dictionary.
Args:
response (object): An HTTP response object, expected to have a
``headers`` attribute that is a ``Mapping[str, str]``.
name (str): The name of a required header.
get_headers (Callable[Any, Mapping[str, str]]): Helper to get headers
from an HTTP response.
callback (Optional[Callable]): A callback that takes no arguments,
to be executed when an exception is being raised.
Returns:
str: The desired header.
Raises:
~google.resumable_media.common.InvalidResponse: If the header
is missing.
"""
headers = get_headers(response)
if name not in headers:
callback()
raise common.InvalidResponse(
response, "Response headers must contain header", name
)
return headers[name]
def require_status_code(response, status_codes, get_status_code, callback=do_nothing):
"""Require a response has a status code among a list.
Args:
response (object): The HTTP response object.
status_codes (tuple): The acceptable status codes.
get_status_code (Callable[Any, int]): Helper to get a status code
from a response.
callback (Optional[Callable]): A callback that takes no arguments,
to be executed when an exception is being raised.
Returns:
int: The status code.
Raises:
~google.resumable_media.common.InvalidResponse: If the status code
is not one of the values in ``status_codes``.
"""
status_code = get_status_code(response)
if status_code not in status_codes:
callback()
raise common.InvalidResponse(
response,
"Request failed with status code",
status_code,
"Expected one of",
*status_codes
)
return status_code
def calculate_retry_wait(base_wait, max_sleep):
"""Calculate the amount of time to wait before a retry attempt.
Wait time grows exponentially with the number of attempts, until
``max_sleep``.
A random amount of jitter (between 0 and 1 seconds) is added to spread out
retry attempts from different clients.
Args:
base_wait (float): The "base" wait time (i.e. without any jitter)
that will be doubled until it reaches the maximum sleep.
max_sleep (float): Maximum value that a sleep time is allowed to be.
Returns:
Tuple[float, float]: The new base wait time as well as the wait time
to be applied (with a random amount of jitter between 0 and 1 seconds
added).
"""
new_base_wait = 2.0 * base_wait
if new_base_wait > max_sleep:
new_base_wait = max_sleep
jitter_ms = random.randint(0, 1000)
return new_base_wait, new_base_wait + 0.001 * jitter_ms
async def wait_and_retry(func, get_status_code, retry_strategy):
"""Attempts to retry a call to ``func`` until success.
Expects ``func`` to return an HTTP response and uses ``get_status_code``
to check if the response is retry-able.
Will retry until :meth:`~.RetryStrategy.retry_allowed` (on the current
``retry_strategy``) returns :data:`False`. Uses
:func:`calculate_retry_wait` to double the wait time (with jitter) after
each attempt.
Args:
func (Callable): A callable that takes no arguments and produces
an HTTP response which will be checked as retry-able.
get_status_code (Callable[Any, int]): Helper to get a status code
from a response.
retry_strategy (~google.resumable_media.common.RetryStrategy): The
strategy to use if the request fails and must be retried.
Returns:
object: The return value of ``func``.
"""
total_sleep = 0.0
num_retries = 0
base_wait = 0.5 # When doubled will give 1.0
while True: # return on success or when retries exhausted.
error = None
try:
response = await func()
except ConnectionError as e:
error = e
else:
if get_status_code(response) not in common.RETRYABLE:
return response
if not retry_strategy.retry_allowed(total_sleep, num_retries):
# Retries are exhausted and no acceptable response was received. Raise the
# retriable_error or return the unacceptable response.
if error:
raise error
return response
base_wait, wait_time = calculate_retry_wait(base_wait, retry_strategy.max_sleep)
num_retries += 1
total_sleep += wait_time
time.sleep(wait_time)
class _DoNothingHash(object):
"""Do-nothing hash object.
Intended as a stand-in for ``hashlib.md5`` or a crc32c checksum
implementation in cases where it isn't necessary to compute the hash.
"""
def update(self, unused_chunk):
"""Do-nothing ``update`` method.
Intended to match the interface of ``hashlib.md5`` and other checksums.
Args:
unused_chunk (bytes): A chunk of data.
"""

View File

@@ -0,0 +1,976 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Virtual bases classes for uploading media via Google APIs.
Supported here are:
* simple (media) uploads
* multipart uploads that contain both metadata and a small file as payload
* resumable uploads (with metadata as well)
"""
import http.client
import json
import os
import random
import sys
from google import _async_resumable_media
from google._async_resumable_media import _helpers
from google.resumable_media import _helpers as sync_helpers
from google.resumable_media import _upload as sync_upload
from google.resumable_media import common
from google.resumable_media._upload import (
_CONTENT_TYPE_HEADER,
_CONTENT_RANGE_TEMPLATE,
_RANGE_UNKNOWN_TEMPLATE,
_EMPTY_RANGE_TEMPLATE,
_BOUNDARY_FORMAT,
_MULTIPART_SEP,
_CRLF,
_MULTIPART_BEGIN,
_RELATED_HEADER,
_BYTES_RANGE_RE,
_STREAM_ERROR_TEMPLATE,
_POST,
_PUT,
_UPLOAD_CHECKSUM_MISMATCH_MESSAGE,
_UPLOAD_METADATA_NO_APPROPRIATE_CHECKSUM_MESSAGE,
)
class UploadBase(object):
"""Base class for upload helpers.
Defines core shared behavior across different upload types.
Args:
upload_url (str): The URL where the content will be uploaded.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with the request, e.g. headers for encrypted data.
Attributes:
upload_url (str): The URL where the content will be uploaded.
"""
def __init__(self, upload_url, headers=None):
self.upload_url = upload_url
if headers is None:
headers = {}
self._headers = headers
self._finished = False
self._retry_strategy = common.RetryStrategy()
@property
def finished(self):
"""bool: Flag indicating if the upload has completed."""
return self._finished
def _process_response(self, response):
"""Process the response from an HTTP request.
This is everything that must be done after a request that doesn't
require network I/O (or other I/O). This is based on the `sans-I/O`_
philosophy.
Args:
response (object): The HTTP response object.
Raises:
~google.resumable_media.common.InvalidResponse: If the status
code is not 200.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
# Tombstone the current upload so it cannot be used again (in either
# failure or success).
self._finished = True
_helpers.require_status_code(response, (http.client.OK,), self._get_status_code)
@staticmethod
def _get_status_code(response):
"""Access the status code from an HTTP response.
Args:
response (object): The HTTP response object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
@staticmethod
def _get_headers(response):
"""Access the headers from an HTTP response.
Args:
response (object): The HTTP response object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
@staticmethod
def _get_body(response):
"""Access the response body from an HTTP response.
Args:
response (object): The HTTP response object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
class SimpleUpload(UploadBase):
"""Upload a resource to a Google API.
A **simple** media upload sends no metadata and completes the upload
in a single request.
Args:
upload_url (str): The URL where the content will be uploaded.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with the request, e.g. headers for encrypted data.
Attributes:
upload_url (str): The URL where the content will be uploaded.
"""
def _prepare_request(self, data, content_type):
"""Prepare the contents of an HTTP request.
This is everything that must be done before a request that doesn't
require network I/O (or other I/O). This is based on the `sans-I/O`_
philosophy.
.. note:
This method will be used only once, so ``headers`` will be
mutated by having a new key added to it.
Args:
data (bytes): The resource content to be uploaded.
content_type (str): The content type for the request.
Returns:
Tuple[str, str, bytes, Mapping[str, str]]: The quadruple
* HTTP verb for the request (always POST)
* the URL for the request
* the body of the request
* headers for the request
Raises:
ValueError: If the current upload has already finished.
TypeError: If ``data`` isn't bytes.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
if self.finished:
raise ValueError("An upload can only be used once.")
if not isinstance(data, bytes):
raise TypeError("`data` must be bytes, received", type(data))
self._headers[_CONTENT_TYPE_HEADER] = content_type
return _POST, self.upload_url, data, self._headers
def transmit(self, transport, data, content_type, timeout=None):
"""Transmit the resource to be uploaded.
Args:
transport (object): An object which can make authenticated
requests.
data (bytes): The resource content to be uploaded.
content_type (str): The content type of the resource, e.g. a JPEG
image has content type ``image/jpeg``.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
class MultipartUpload(UploadBase):
"""Upload a resource with metadata to a Google API.
A **multipart** upload sends both metadata and the resource in a single
(multipart) request.
Args:
upload_url (str): The URL where the content will be uploaded.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with the request, e.g. headers for encrypted data.
checksum Optional([str]): The type of checksum to compute to verify
the integrity of the object. The request metadata will be amended
to include the computed value. Using this option will override a
manually-set checksum value. Supported values are "md5", "crc32c"
and None. The default is None.
Attributes:
upload_url (str): The URL where the content will be uploaded.
"""
def __init__(self, upload_url, headers=None, checksum=None):
super(MultipartUpload, self).__init__(upload_url, headers=headers)
self._checksum_type = checksum
def _prepare_request(self, data, metadata, content_type):
"""Prepare the contents of an HTTP request.
This is everything that must be done before a request that doesn't
require network I/O (or other I/O). This is based on the `sans-I/O`_
philosophy.
.. note:
This method will be used only once, so ``headers`` will be
mutated by having a new key added to it.
Args:
data (bytes): The resource content to be uploaded.
metadata (Mapping[str, str]): The resource metadata, such as an
ACL list.
content_type (str): The content type of the resource, e.g. a JPEG
image has content type ``image/jpeg``.
Returns:
Tuple[str, str, bytes, Mapping[str, str]]: The quadruple
* HTTP verb for the request (always POST)
* the URL for the request
* the body of the request
* headers for the request
Raises:
ValueError: If the current upload has already finished.
TypeError: If ``data`` isn't bytes.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
if self.finished:
raise ValueError("An upload can only be used once.")
if not isinstance(data, bytes):
raise TypeError("`data` must be bytes, received", type(data))
checksum_object = sync_helpers._get_checksum_object(self._checksum_type)
if checksum_object is not None:
checksum_object.update(data)
actual_checksum = sync_helpers.prepare_checksum_digest(
checksum_object.digest()
)
metadata_key = sync_helpers._get_metadata_key(self._checksum_type)
metadata[metadata_key] = actual_checksum
content, multipart_boundary = construct_multipart_request(
data, metadata, content_type
)
multipart_content_type = _RELATED_HEADER + multipart_boundary + b'"'
self._headers[_CONTENT_TYPE_HEADER] = multipart_content_type
return _POST, self.upload_url, content, self._headers
def transmit(self, transport, data, metadata, content_type, timeout=None):
"""Transmit the resource to be uploaded.
Args:
transport (object): An object which can make authenticated
requests.
data (bytes): The resource content to be uploaded.
metadata (Mapping[str, str]): The resource metadata, such as an
ACL list.
content_type (str): The content type of the resource, e.g. a JPEG
image has content type ``image/jpeg``.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
class ResumableUpload(UploadBase, sync_upload.ResumableUpload):
"""Initiate and fulfill a resumable upload to a Google API.
A **resumable** upload sends an initial request with the resource metadata
and then gets assigned an upload ID / upload URL to send bytes to.
Using the upload URL, the upload is then done in chunks (determined by
the user) until all bytes have been uploaded.
Args:
upload_url (str): The URL where the resumable upload will be initiated.
chunk_size (int): The size of each chunk used to upload the resource.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with the :meth:`initiate` request, e.g. headers for
encrypted data. These **will not** be sent with
:meth:`transmit_next_chunk` or :meth:`recover` requests.
checksum Optional([str]): The type of checksum to compute to verify
the integrity of the object. After the upload is complete, the
server-computed checksum of the resulting object will be read
and google.resumable_media.common.DataCorruption will be raised on
a mismatch. The corrupted file will not be deleted from the remote
host automatically. Supported values are "md5", "crc32c" and None.
The default is None.
Attributes:
upload_url (str): The URL where the content will be uploaded.
Raises:
ValueError: If ``chunk_size`` is not a multiple of
:data:`.UPLOAD_CHUNK_SIZE`.
"""
def __init__(self, upload_url, chunk_size, checksum=None, headers=None):
super(ResumableUpload, self).__init__(upload_url, headers=headers)
if chunk_size % _async_resumable_media.UPLOAD_CHUNK_SIZE != 0:
raise ValueError(
"{} KB must divide chunk size".format(
_async_resumable_media.UPLOAD_CHUNK_SIZE / 1024
)
)
self._chunk_size = chunk_size
self._stream = None
self._content_type = None
self._bytes_uploaded = 0
self._bytes_checksummed = 0
self._checksum_type = checksum
self._checksum_object = None
self._total_bytes = None
self._resumable_url = None
self._invalid = False
@property
def invalid(self):
"""bool: Indicates if the upload is in an invalid state.
This will occur if a call to :meth:`transmit_next_chunk` fails.
To recover from such a failure, call :meth:`recover`.
"""
return self._invalid
@property
def chunk_size(self):
"""int: The size of each chunk used to upload the resource."""
return self._chunk_size
@property
def resumable_url(self):
"""Optional[str]: The URL of the in-progress resumable upload."""
return self._resumable_url
@property
def bytes_uploaded(self):
"""int: Number of bytes that have been uploaded."""
return self._bytes_uploaded
@property
def total_bytes(self):
"""Optional[int]: The total number of bytes to be uploaded.
If this upload is initiated (via :meth:`initiate`) with
``stream_final=True``, this value will be populated based on the size
of the ``stream`` being uploaded. (By default ``stream_final=True``.)
If this upload is initiated with ``stream_final=False``,
:attr:`total_bytes` will be :data:`None` since it cannot be
determined from the stream.
"""
return self._total_bytes
def _prepare_initiate_request(
self, stream, metadata, content_type, total_bytes=None, stream_final=True
):
"""Prepare the contents of HTTP request to initiate upload.
This is everything that must be done before a request that doesn't
require network I/O (or other I/O). This is based on the `sans-I/O`_
philosophy.
Args:
stream (IO[bytes]): The stream (i.e. file-like object) that will
be uploaded. The stream **must** be at the beginning (i.e.
``stream.tell() == 0``).
metadata (Mapping[str, str]): The resource metadata, such as an
ACL list.
content_type (str): The content type of the resource, e.g. a JPEG
image has content type ``image/jpeg``.
total_bytes (Optional[int]): The total number of bytes to be
uploaded. If specified, the upload size **will not** be
determined from the stream (even if ``stream_final=True``).
stream_final (Optional[bool]): Indicates if the ``stream`` is
"final" (i.e. no more bytes will be added to it). In this case
we determine the upload size from the size of the stream. If
``total_bytes`` is passed, this argument will be ignored.
Returns:
Tuple[str, str, bytes, Mapping[str, str]]: The quadruple
* HTTP verb for the request (always POST)
* the URL for the request
* the body of the request
* headers for the request
Raises:
ValueError: If the current upload has already been initiated.
ValueError: If ``stream`` is not at the beginning.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
if self.resumable_url is not None:
raise ValueError("This upload has already been initiated.")
if stream.tell() != 0:
raise ValueError("Stream must be at beginning.")
self._stream = stream
self._content_type = content_type
headers = {
_CONTENT_TYPE_HEADER: "application/json; charset=UTF-8",
"x-upload-content-type": content_type,
}
# Set the total bytes if possible.
if total_bytes is not None:
self._total_bytes = total_bytes
elif stream_final:
self._total_bytes = get_total_bytes(stream)
# Add the total bytes to the headers if set.
if self._total_bytes is not None:
content_length = "{:d}".format(self._total_bytes)
headers["x-upload-content-length"] = content_length
headers.update(self._headers)
payload = json.dumps(metadata).encode("utf-8")
return _POST, self.upload_url, payload, headers
def _process_initiate_response(self, response):
"""Process the response from an HTTP request that initiated upload.
This is everything that must be done after a request that doesn't
require network I/O (or other I/O). This is based on the `sans-I/O`_
philosophy.
This method takes the URL from the ``Location`` header and stores it
for future use. Within that URL, we assume the ``upload_id`` query
parameter has been included, but we do not check.
Args:
response (object): The HTTP response object (need headers).
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
_helpers.require_status_code(
response,
(http.client.OK,),
self._get_status_code,
callback=self._make_invalid,
)
self._resumable_url = _helpers.header_required(
response, "location", self._get_headers
)
def initiate(
self,
transport,
stream,
metadata,
content_type,
total_bytes=None,
stream_final=True,
timeout=None,
):
"""Initiate a resumable upload.
By default, this method assumes your ``stream`` is in a "final"
state ready to transmit. However, ``stream_final=False`` can be used
to indicate that the size of the resource is not known. This can happen
if bytes are being dynamically fed into ``stream``, e.g. if the stream
is attached to application logs.
If ``stream_final=False`` is used, :attr:`chunk_size` bytes will be
read from the stream every time :meth:`transmit_next_chunk` is called.
If one of those reads produces strictly fewer bites than the chunk
size, the upload will be concluded.
Args:
transport (object): An object which can make authenticated
requests.
stream (IO[bytes]): The stream (i.e. file-like object) that will
be uploaded. The stream **must** be at the beginning (i.e.
``stream.tell() == 0``).
metadata (Mapping[str, str]): The resource metadata, such as an
ACL list.
content_type (str): The content type of the resource, e.g. a JPEG
image has content type ``image/jpeg``.
total_bytes (Optional[int]): The total number of bytes to be
uploaded. If specified, the upload size **will not** be
determined from the stream (even if ``stream_final=True``).
stream_final (Optional[bool]): Indicates if the ``stream`` is
"final" (i.e. no more bytes will be added to it). In this case
we determine the upload size from the size of the stream. If
``total_bytes`` is passed, this argument will be ignored.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
def _prepare_request(self):
"""Prepare the contents of HTTP request to upload a chunk.
This is everything that must be done before a request that doesn't
require network I/O. This is based on the `sans-I/O`_ philosophy.
For the time being, this **does require** some form of I/O to read
a chunk from ``stream`` (via :func:`get_next_chunk`). However, this
will (almost) certainly not be network I/O.
Returns:
Tuple[str, str, bytes, Mapping[str, str]]: The quadruple
* HTTP verb for the request (always PUT)
* the URL for the request
* the body of the request
* headers for the request
The headers **do not** incorporate the ``_headers`` on the
current instance.
Raises:
ValueError: If the current upload has finished.
ValueError: If the current upload is in an invalid state.
ValueError: If the current upload has not been initiated.
ValueError: If the location in the stream (i.e. ``stream.tell()``)
does not agree with ``bytes_uploaded``.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
if self.finished:
raise ValueError("Upload has finished.")
if self.invalid:
raise ValueError(
"Upload is in an invalid state. To recover call `recover()`."
)
if self.resumable_url is None:
raise ValueError(
"This upload has not been initiated. Please call "
"initiate() before beginning to transmit chunks."
)
start_byte, payload, content_range = get_next_chunk(
self._stream, self._chunk_size, self._total_bytes
)
if start_byte != self.bytes_uploaded:
msg = _STREAM_ERROR_TEMPLATE.format(start_byte, self.bytes_uploaded)
raise ValueError(msg)
self._update_checksum(start_byte, payload)
headers = {
_CONTENT_TYPE_HEADER: self._content_type,
_helpers.CONTENT_RANGE_HEADER: content_range,
}
return _PUT, self.resumable_url, payload, headers
def _make_invalid(self):
"""Simple setter for ``invalid``.
This is intended to be passed along as a callback to helpers that
raise an exception so they can mark this instance as invalid before
raising.
"""
self._invalid = True
async def _process_resumable_response(self, response, bytes_sent):
"""Process the response from an HTTP request.
This is everything that must be done after a request that doesn't
require network I/O (or other I/O). This is based on the `sans-I/O`_
philosophy.
Args:
response (object): The HTTP response object.
bytes_sent (int): The number of bytes sent in the request that
``response`` was returned for.
Raises:
~google.resumable_media.common.InvalidResponse: If the status
code is 308 and the ``range`` header is not of the form
``bytes 0-{end}``.
~google.resumable_media.common.InvalidResponse: If the status
code is not 200 or 308.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
status_code = _helpers.require_status_code(
response,
(http.client.OK, http.client.PERMANENT_REDIRECT),
self._get_status_code,
callback=self._make_invalid,
)
if status_code == http.client.OK:
# NOTE: We use the "local" information of ``bytes_sent`` to update
# ``bytes_uploaded``, but do not verify this against other
# state. However, there may be some other information:
#
# * a ``size`` key in JSON response body
# * the ``total_bytes`` attribute (if set)
# * ``stream.tell()`` (relying on fact that ``initiate()``
# requires stream to be at the beginning)
self._bytes_uploaded = self._bytes_uploaded + bytes_sent
# Tombstone the current upload so it cannot be used again.
self._finished = True
# Validate the checksum. This can raise an exception on failure.
await self._validate_checksum(response)
else:
bytes_range = _helpers.header_required(
response,
_helpers.RANGE_HEADER,
self._get_headers,
callback=self._make_invalid,
)
match = _BYTES_RANGE_RE.match(bytes_range)
if match is None:
self._make_invalid()
raise common.InvalidResponse(
response,
'Unexpected "range" header',
bytes_range,
'Expected to be of the form "bytes=0-{end}"',
)
self._bytes_uploaded = int(match.group("end_byte")) + 1
async def _validate_checksum(self, response):
"""Check the computed checksum, if any, against the response headers.
Args:
response (object): The HTTP response object.
Raises:
~google.resumable_media.common.DataCorruption: If the checksum
computed locally and the checksum reported by the remote host do
not match.
"""
if self._checksum_type is None:
return
metadata_key = sync_helpers._get_metadata_key(self._checksum_type)
metadata = await response.json()
remote_checksum = metadata.get(metadata_key)
if remote_checksum is None:
raise common.InvalidResponse(
response,
_UPLOAD_METADATA_NO_APPROPRIATE_CHECKSUM_MESSAGE.format(metadata_key),
self._get_headers(response),
)
local_checksum = sync_helpers.prepare_checksum_digest(
self._checksum_object.digest()
)
if local_checksum != remote_checksum:
raise common.DataCorruption(
response,
_UPLOAD_CHECKSUM_MISMATCH_MESSAGE.format(
self._checksum_type.upper(), local_checksum, remote_checksum
),
)
def transmit_next_chunk(self, transport, timeout=None):
"""Transmit the next chunk of the resource to be uploaded.
If the current upload was initiated with ``stream_final=False``,
this method will dynamically determine if the upload has completed.
The upload will be considered complete if the stream produces
fewer than :attr:`chunk_size` bytes when a chunk is read from it.
Args:
transport (object): An object which can make authenticated
requests.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
def _prepare_recover_request(self):
"""Prepare the contents of HTTP request to recover from failure.
This is everything that must be done before a request that doesn't
require network I/O. This is based on the `sans-I/O`_ philosophy.
We assume that the :attr:`resumable_url` is set (i.e. the only way
the upload can end up :attr:`invalid` is if it has been initiated.
Returns:
Tuple[str, str, NoneType, Mapping[str, str]]: The quadruple
* HTTP verb for the request (always PUT)
* the URL for the request
* the body of the request (always :data:`None`)
* headers for the request
The headers **do not** incorporate the ``_headers`` on the
current instance.
Raises:
ValueError: If the current upload is not in an invalid state.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
if not self.invalid:
raise ValueError("Upload is not in invalid state, no need to recover.")
headers = {_helpers.CONTENT_RANGE_HEADER: "bytes */*"}
return _PUT, self.resumable_url, None, headers
def _process_recover_response(self, response):
"""Process the response from an HTTP request to recover from failure.
This is everything that must be done after a request that doesn't
require network I/O (or other I/O). This is based on the `sans-I/O`_
philosophy.
Args:
response (object): The HTTP response object.
Raises:
~google.resumable_media.common.InvalidResponse: If the status
code is not 308.
~google.resumable_media.common.InvalidResponse: If the status
code is 308 and the ``range`` header is not of the form
``bytes 0-{end}``.
.. _sans-I/O: https://sans-io.readthedocs.io/
"""
_helpers.require_status_code(
response,
(http.client.PERMANENT_REDIRECT,),
self._get_status_code,
)
headers = self._get_headers(response)
if _helpers.RANGE_HEADER in headers:
bytes_range = headers[_helpers.RANGE_HEADER]
match = _BYTES_RANGE_RE.match(bytes_range)
if match is None:
raise common.InvalidResponse(
response,
'Unexpected "range" header',
bytes_range,
'Expected to be of the form "bytes=0-{end}"',
)
self._bytes_uploaded = int(match.group("end_byte")) + 1
else:
# In this case, the upload has not "begun".
self._bytes_uploaded = 0
self._stream.seek(self._bytes_uploaded)
self._invalid = False
def recover(self, transport):
"""Recover from a failure.
This method should be used when a :class:`ResumableUpload` is in an
:attr:`~ResumableUpload.invalid` state due to a request failure.
This will verify the progress with the server and make sure the
current upload is in a valid state before :meth:`transmit_next_chunk`
can be used again.
Args:
transport (object): An object which can make authenticated
requests.
Raises:
NotImplementedError: Always, since virtual.
"""
raise NotImplementedError("This implementation is virtual.")
def get_boundary():
"""Get a random boundary for a multipart request.
Returns:
bytes: The boundary used to separate parts of a multipart request.
"""
random_int = random.randrange(sys.maxsize)
boundary = _BOUNDARY_FORMAT.format(random_int)
# NOTE: Neither % formatting nor .format() are available for byte strings
# in Python 3.4, so we must use unicode strings as templates.
return boundary.encode("utf-8")
def construct_multipart_request(data, metadata, content_type):
"""Construct a multipart request body.
Args:
data (bytes): The resource content (UTF-8 encoded as bytes)
to be uploaded.
metadata (Mapping[str, str]): The resource metadata, such as an
ACL list.
content_type (str): The content type of the resource, e.g. a JPEG
image has content type ``image/jpeg``.
Returns:
Tuple[bytes, bytes]: The multipart request body and the boundary used
between each part.
"""
multipart_boundary = get_boundary()
json_bytes = json.dumps(metadata).encode("utf-8")
content_type = content_type.encode("utf-8")
# Combine the two parts into a multipart payload.
# NOTE: We'd prefer a bytes template but are restricted by Python 3.4.
boundary_sep = _MULTIPART_SEP + multipart_boundary
content = (
boundary_sep
+ _MULTIPART_BEGIN
+ json_bytes
+ _CRLF
+ boundary_sep
+ _CRLF
+ b"content-type: "
+ content_type
+ _CRLF
+ _CRLF
+ data # Empty line between headers and body.
+ _CRLF
+ boundary_sep
+ _MULTIPART_SEP
)
return content, multipart_boundary
def get_total_bytes(stream):
"""Determine the total number of bytes in a stream.
Args:
stream (IO[bytes]): The stream (i.e. file-like object).
Returns:
int: The number of bytes.
"""
current_position = stream.tell()
# NOTE: ``.seek()`` **should** return the same value that ``.tell()``
# returns, but in Python 2, ``file`` objects do not.
stream.seek(0, os.SEEK_END)
end_position = stream.tell()
# Go back to the initial position.
stream.seek(current_position)
return end_position
def get_next_chunk(stream, chunk_size, total_bytes):
"""Get a chunk from an I/O stream.
The ``stream`` may have fewer bytes remaining than ``chunk_size``
so it may not always be the case that
``end_byte == start_byte + chunk_size - 1``.
Args:
stream (IO[bytes]): The stream (i.e. file-like object).
chunk_size (int): The size of the chunk to be read from the ``stream``.
total_bytes (Optional[int]): The (expected) total number of bytes
in the ``stream``.
Returns:
Tuple[int, bytes, str]: Triple of:
* the start byte index
* the content in between the start and end bytes (inclusive)
* content range header for the chunk (slice) that has been read
Raises:
ValueError: If ``total_bytes == 0`` but ``stream.read()`` yields
non-empty content.
ValueError: If there is no data left to consume. This corresponds
exactly to the case ``end_byte < start_byte``, which can only
occur if ``end_byte == start_byte - 1``.
"""
start_byte = stream.tell()
if total_bytes is not None and start_byte + chunk_size >= total_bytes > 0:
payload = stream.read(total_bytes - start_byte)
else:
payload = stream.read(chunk_size)
end_byte = stream.tell() - 1
num_bytes_read = len(payload)
if total_bytes is None:
if num_bytes_read < chunk_size:
# We now **KNOW** the total number of bytes.
total_bytes = end_byte + 1
elif total_bytes == 0:
# NOTE: We also expect ``start_byte == 0`` here but don't check
# because ``_prepare_initiate_request()`` requires the
# stream to be at the beginning.
if num_bytes_read != 0:
raise ValueError(
"Stream specified as empty, but produced non-empty content."
)
else:
if num_bytes_read == 0:
raise ValueError(
"Stream is already exhausted. There is no content remaining."
)
content_range = get_content_range(start_byte, end_byte, total_bytes)
return start_byte, payload, content_range
def get_content_range(start_byte, end_byte, total_bytes):
"""Convert start, end and total into content range header.
If ``total_bytes`` is not known, uses "bytes {start}-{end}/*".
If we are dealing with an empty range (i.e. ``end_byte < start_byte``)
then "bytes */{total}" is used.
This function **ASSUMES** that if the size is not known, the caller will
not also pass an empty range.
Args:
start_byte (int): The start (inclusive) of the byte range.
end_byte (int): The end (inclusive) of the byte range.
total_bytes (Optional[int]): The number of bytes in the byte
range (if known).
Returns:
str: The content range header.
"""
if total_bytes is None:
return _RANGE_UNKNOWN_TEMPLATE.format(start_byte, end_byte)
elif end_byte < start_byte:
return _EMPTY_RANGE_TEMPLATE.format(total_bytes)
else:
return _CONTENT_RANGE_TEMPLATE.format(start_byte, end_byte, total_bytes)

View File

@@ -0,0 +1,682 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``requests`` utilities for Google Media Downloads and Resumable Uploads.
This sub-package assumes callers will use the `requests`_ library
as transport and `google-auth`_ for sending authenticated HTTP traffic
with ``requests``.
.. _requests: http://docs.python-requests.org/
.. _google-auth: https://google-auth.readthedocs.io/
====================
Authorized Transport
====================
To use ``google-auth`` and ``requests`` to create an authorized transport
that has read-only access to Google Cloud Storage (GCS):
.. testsetup:: get-credentials
import google.auth
import google.auth.credentials as creds_mod
import mock
def mock_default(scopes=None):
credentials = mock.Mock(spec=creds_mod.Credentials)
return credentials, 'mock-project'
# Patch the ``default`` function on the module.
original_default = google.auth.default
google.auth.default = mock_default
.. doctest:: get-credentials
>>> import google.auth
>>> import google.auth.transport.requests as tr_requests
>>>
>>> ro_scope = 'https://www.googleapis.com/auth/devstorage.read_only'
>>> credentials, _ = google.auth.default(scopes=(ro_scope,))
>>> transport = tr_requests.AuthorizedSession(credentials)
>>> transport
<google.auth.transport.requests.AuthorizedSession object at 0x...>
.. testcleanup:: get-credentials
# Put back the correct ``default`` function on the module.
google.auth.default = original_default
================
Simple Downloads
================
To download an object from Google Cloud Storage, construct the media URL
for the GCS object and download it with an authorized transport that has
access to the resource:
.. testsetup:: basic-download
import mock
import requests
import http.client
bucket = 'bucket-foo'
blob_name = 'file.txt'
fake_response = requests.Response()
fake_response.status_code = int(http.client.OK)
fake_response.headers['Content-Length'] = '1364156'
fake_content = mock.MagicMock(spec=['__len__'])
fake_content.__len__.return_value = 1364156
fake_response._content = fake_content
get_method = mock.Mock(return_value=fake_response, spec=[])
transport = mock.Mock(request=get_method, spec=['request'])
.. doctest:: basic-download
>>> from google.resumable_media.requests import Download
>>>
>>> url_template = (
... 'https://www.googleapis.com/download/storage/v1/b/'
... '{bucket}/o/{blob_name}?alt=media')
>>> media_url = url_template.format(
... bucket=bucket, blob_name=blob_name)
>>>
>>> download = Download(media_url)
>>> response = download.consume(transport)
>>> download.finished
True
>>> response
<Response [200]>
>>> response.headers['Content-Length']
'1364156'
>>> len(response.content)
1364156
To download only a portion of the bytes in the object,
specify ``start`` and ``end`` byte positions (both optional):
.. testsetup:: basic-download-with-slice
import mock
import requests
import http.client
from google.resumable_media.requests import Download
media_url = 'http://test.invalid'
start = 4096
end = 8191
slice_size = end - start + 1
fake_response = requests.Response()
fake_response.status_code = int(http.client.PARTIAL_CONTENT)
fake_response.headers['Content-Length'] = '{:d}'.format(slice_size)
content_range = 'bytes {:d}-{:d}/1364156'.format(start, end)
fake_response.headers['Content-Range'] = content_range
fake_content = mock.MagicMock(spec=['__len__'])
fake_content.__len__.return_value = slice_size
fake_response._content = fake_content
get_method = mock.Mock(return_value=fake_response, spec=[])
transport = mock.Mock(request=get_method, spec=['request'])
.. doctest:: basic-download-with-slice
>>> download = Download(media_url, start=4096, end=8191)
>>> response = download.consume(transport)
>>> download.finished
True
>>> response
<Response [206]>
>>> response.headers['Content-Length']
'4096'
>>> response.headers['Content-Range']
'bytes 4096-8191/1364156'
>>> len(response.content)
4096
=================
Chunked Downloads
=================
For very large objects or objects of unknown size, it may make more sense
to download the object in chunks rather than all at once. This can be done
to avoid dropped connections with a poor internet connection or can allow
multiple chunks to be downloaded in parallel to speed up the total
download.
A :class:`.ChunkedDownload` uses the same media URL and authorized
transport that a basic :class:`.Download` would use, but also
requires a chunk size and a write-able byte ``stream``. The chunk size is used
to determine how much of the resouce to consume with each request and the
stream is to allow the resource to be written out (e.g. to disk) without
having to fit in memory all at once.
.. testsetup:: chunked-download
import io
import mock
import requests
import http.client
media_url = 'http://test.invalid'
fifty_mb = 50 * 1024 * 1024
one_gb = 1024 * 1024 * 1024
fake_response = requests.Response()
fake_response.status_code = int(http.client.PARTIAL_CONTENT)
fake_response.headers['Content-Length'] = '{:d}'.format(fifty_mb)
content_range = 'bytes 0-{:d}/{:d}'.format(fifty_mb - 1, one_gb)
fake_response.headers['Content-Range'] = content_range
fake_content_begin = b'The beginning of the chunk...'
fake_content = fake_content_begin + b'1' * (fifty_mb - 29)
fake_response._content = fake_content
get_method = mock.Mock(return_value=fake_response, spec=[])
transport = mock.Mock(request=get_method, spec=['request'])
.. doctest:: chunked-download
>>> from google.resumable_media.requests import ChunkedDownload
>>>
>>> chunk_size = 50 * 1024 * 1024 # 50MB
>>> stream = io.BytesIO()
>>> download = ChunkedDownload(
... media_url, chunk_size, stream)
>>> # Check the state of the download before starting.
>>> download.bytes_downloaded
0
>>> download.total_bytes is None
True
>>> response = download.consume_next_chunk(transport)
>>> # Check the state of the download after consuming one chunk.
>>> download.finished
False
>>> download.bytes_downloaded # chunk_size
52428800
>>> download.total_bytes # 1GB
1073741824
>>> response
<Response [206]>
>>> response.headers['Content-Length']
'52428800'
>>> response.headers['Content-Range']
'bytes 0-52428799/1073741824'
>>> len(response.content) == chunk_size
True
>>> stream.seek(0)
0
>>> stream.read(29)
b'The beginning of the chunk...'
The download will change it's ``finished`` status to :data:`True`
once the final chunk is consumed. In some cases, the final chunk may
not be the same size as the other chunks:
.. testsetup:: chunked-download-end
import mock
import requests
import http.client
from google.resumable_media.requests import ChunkedDownload
media_url = 'http://test.invalid'
fifty_mb = 50 * 1024 * 1024
one_gb = 1024 * 1024 * 1024
stream = mock.Mock(spec=['write'])
download = ChunkedDownload(media_url, fifty_mb, stream)
download._bytes_downloaded = 20 * fifty_mb
download._total_bytes = one_gb
fake_response = requests.Response()
fake_response.status_code = int(http.client.PARTIAL_CONTENT)
slice_size = one_gb - 20 * fifty_mb
fake_response.headers['Content-Length'] = '{:d}'.format(slice_size)
content_range = 'bytes {:d}-{:d}/{:d}'.format(
20 * fifty_mb, one_gb - 1, one_gb)
fake_response.headers['Content-Range'] = content_range
fake_content = mock.MagicMock(spec=['__len__'])
fake_content.__len__.return_value = slice_size
fake_response._content = fake_content
get_method = mock.Mock(return_value=fake_response, spec=[])
transport = mock.Mock(request=get_method, spec=['request'])
.. doctest:: chunked-download-end
>>> # The state of the download in progress.
>>> download.finished
False
>>> download.bytes_downloaded # 20 chunks at 50MB
1048576000
>>> download.total_bytes # 1GB
1073741824
>>> response = download.consume_next_chunk(transport)
>>> # The state of the download after consuming the final chunk.
>>> download.finished
True
>>> download.bytes_downloaded == download.total_bytes
True
>>> response
<Response [206]>
>>> response.headers['Content-Length']
'25165824'
>>> response.headers['Content-Range']
'bytes 1048576000-1073741823/1073741824'
>>> len(response.content) < download.chunk_size
True
In addition, a :class:`.ChunkedDownload` can also take optional
``start`` and ``end`` byte positions.
Usually, no checksum is returned with a chunked download. Even if one is returned,
it is not validated. If you need to validate the checksum, you can do so
by buffering the chunks and validating the checksum against the completed download.
==============
Simple Uploads
==============
Among the three supported upload classes, the simplest is
:class:`.SimpleUpload`. A simple upload should be used when the resource
being uploaded is small and when there is no metadata (other than the name)
associated with the resource.
.. testsetup:: simple-upload
import json
import mock
import requests
import http.client
bucket = 'some-bucket'
blob_name = 'file.txt'
fake_response = requests.Response()
fake_response.status_code = int(http.client.OK)
payload = {
'bucket': bucket,
'contentType': 'text/plain',
'md5Hash': 'M0XLEsX9/sMdiI+4pB4CAQ==',
'name': blob_name,
'size': '27',
}
fake_response._content = json.dumps(payload).encode('utf-8')
post_method = mock.Mock(return_value=fake_response, spec=[])
transport = mock.Mock(request=post_method, spec=['request'])
.. doctest:: simple-upload
:options: +NORMALIZE_WHITESPACE
>>> from google.resumable_media.requests import SimpleUpload
>>>
>>> url_template = (
... 'https://www.googleapis.com/upload/storage/v1/b/{bucket}/o?'
... 'uploadType=media&'
... 'name={blob_name}')
>>> upload_url = url_template.format(
... bucket=bucket, blob_name=blob_name)
>>>
>>> upload = SimpleUpload(upload_url)
>>> data = b'Some not too large content.'
>>> content_type = 'text/plain'
>>> response = upload.transmit(transport, data, content_type)
>>> upload.finished
True
>>> response
<Response [200]>
>>> json_response = response.json()
>>> json_response['bucket'] == bucket
True
>>> json_response['name'] == blob_name
True
>>> json_response['contentType'] == content_type
True
>>> json_response['md5Hash']
'M0XLEsX9/sMdiI+4pB4CAQ=='
>>> int(json_response['size']) == len(data)
True
In the rare case that an upload fails, an :exc:`.InvalidResponse`
will be raised:
.. testsetup:: simple-upload-fail
import time
import mock
import requests
import http.client
from google import resumable_media
from google.resumable_media import _helpers
from google.resumable_media.requests import SimpleUpload as constructor
upload_url = 'http://test.invalid'
data = b'Some not too large content.'
content_type = 'text/plain'
fake_response = requests.Response()
fake_response.status_code = int(http.client.SERVICE_UNAVAILABLE)
post_method = mock.Mock(return_value=fake_response, spec=[])
transport = mock.Mock(request=post_method, spec=['request'])
time_sleep = time.sleep
def dont_sleep(seconds):
raise RuntimeError('No sleep', seconds)
def SimpleUpload(*args, **kwargs):
upload = constructor(*args, **kwargs)
# Mock the cumulative sleep to avoid retries (and `time.sleep()`).
upload._retry_strategy = resumable_media.RetryStrategy(
max_cumulative_retry=-1.0)
return upload
time.sleep = dont_sleep
.. doctest:: simple-upload-fail
:options: +NORMALIZE_WHITESPACE
>>> upload = SimpleUpload(upload_url)
>>> error = None
>>> try:
... upload.transmit(transport, data, content_type)
... except resumable_media.InvalidResponse as caught_exc:
... error = caught_exc
...
>>> error
InvalidResponse('Request failed with status code', 503,
'Expected one of', <HTTPStatus.OK: 200>)
>>> error.response
<Response [503]>
>>>
>>> upload.finished
True
.. testcleanup:: simple-upload-fail
# Put back the correct ``sleep`` function on the ``time`` module.
time.sleep = time_sleep
Even in the case of failure, we see that the upload is
:attr:`~.SimpleUpload.finished`, i.e. it cannot be re-used.
=================
Multipart Uploads
=================
After the simple upload, the :class:`.MultipartUpload` can be used to
achieve essentially the same task. However, a multipart upload allows some
metadata about the resource to be sent along as well. (This is the "multi":
we send a first part with the metadata and a second part with the actual
bytes in the resource.)
Usage is similar to the simple upload, but :meth:`~.MultipartUpload.transmit`
accepts an extra required argument: ``metadata``.
.. testsetup:: multipart-upload
import json
import mock
import requests
import http.client
bucket = 'some-bucket'
blob_name = 'file.txt'
data = b'Some not too large content.'
content_type = 'text/plain'
fake_response = requests.Response()
fake_response.status_code = int(http.client.OK)
payload = {
'bucket': bucket,
'name': blob_name,
'metadata': {'color': 'grurple'},
}
fake_response._content = json.dumps(payload).encode('utf-8')
post_method = mock.Mock(return_value=fake_response, spec=[])
transport = mock.Mock(request=post_method, spec=['request'])
.. doctest:: multipart-upload
>>> from google.resumable_media.requests import MultipartUpload
>>>
>>> url_template = (
... 'https://www.googleapis.com/upload/storage/v1/b/{bucket}/o?'
... 'uploadType=multipart')
>>> upload_url = url_template.format(bucket=bucket)
>>>
>>> upload = MultipartUpload(upload_url)
>>> metadata = {
... 'name': blob_name,
... 'metadata': {
... 'color': 'grurple',
... },
... }
>>> response = upload.transmit(transport, data, metadata, content_type)
>>> upload.finished
True
>>> response
<Response [200]>
>>> json_response = response.json()
>>> json_response['bucket'] == bucket
True
>>> json_response['name'] == blob_name
True
>>> json_response['metadata'] == metadata['metadata']
True
As with the simple upload, in the case of failure an :exc:`.InvalidResponse`
is raised, enclosing the :attr:`~.InvalidResponse.response` that caused
the failure and the ``upload`` object cannot be re-used after a failure.
=================
Resumable Uploads
=================
A :class:`.ResumableUpload` deviates from the other two upload classes:
it transmits a resource over the course of multiple requests. This
is intended to be used in cases where:
* the size of the resource is not known (i.e. it is generated on the fly)
* requests must be short-lived
* the client has request **size** limitations
* the resource is too large to fit into memory
In general, a resource should be sent in a **single** request to avoid
latency and reduce QPS. See `GCS best practices`_ for more things to
consider when using a resumable upload.
.. _GCS best practices: https://cloud.google.com/storage/docs/\
best-practices#uploading
After creating a :class:`.ResumableUpload` instance, a
**resumable upload session** must be initiated to let the server know that
a series of chunked upload requests will be coming and to obtain an
``upload_id`` for the session. In contrast to the other two upload classes,
:meth:`~.ResumableUpload.initiate` takes a byte ``stream`` as input rather
than raw bytes as ``data``. This can be a file object, a :class:`~io.BytesIO`
object or any other stream implementing the same interface.
.. testsetup:: resumable-initiate
import io
import mock
import requests
import http.client
bucket = 'some-bucket'
blob_name = 'file.txt'
data = b'Some resumable bytes.'
content_type = 'text/plain'
fake_response = requests.Response()
fake_response.status_code = int(http.client.OK)
fake_response._content = b''
upload_id = 'ABCdef189XY_super_serious'
resumable_url_template = (
'https://www.googleapis.com/upload/storage/v1/b/{bucket}'
'/o?uploadType=resumable&upload_id={upload_id}')
resumable_url = resumable_url_template.format(
bucket=bucket, upload_id=upload_id)
fake_response.headers['location'] = resumable_url
fake_response.headers['x-guploader-uploadid'] = upload_id
post_method = mock.Mock(return_value=fake_response, spec=[])
transport = mock.Mock(request=post_method, spec=['request'])
.. doctest:: resumable-initiate
>>> from google.resumable_media.requests import ResumableUpload
>>>
>>> url_template = (
... 'https://www.googleapis.com/upload/storage/v1/b/{bucket}/o?'
... 'uploadType=resumable')
>>> upload_url = url_template.format(bucket=bucket)
>>>
>>> chunk_size = 1024 * 1024 # 1MB
>>> upload = ResumableUpload(upload_url, chunk_size)
>>> stream = io.BytesIO(data)
>>> # The upload doesn't know how "big" it is until seeing a stream.
>>> upload.total_bytes is None
True
>>> metadata = {'name': blob_name}
>>> response = upload.initiate(transport, stream, metadata, content_type)
>>> response
<Response [200]>
>>> upload.resumable_url == response.headers['Location']
True
>>> upload.total_bytes == len(data)
True
>>> upload_id = response.headers['X-GUploader-UploadID']
>>> upload_id
'ABCdef189XY_super_serious'
>>> upload.resumable_url == upload_url + '&upload_id=' + upload_id
True
Once a :class:`.ResumableUpload` has been initiated, the resource is
transmitted in chunks until completion:
.. testsetup:: resumable-transmit
import io
import json
import mock
import requests
import http.client
from google import resumable_media
import google.resumable_media.requests.upload as upload_mod
data = b'01234567891'
stream = io.BytesIO(data)
# Create an "already initiated" upload.
upload_url = 'http://test.invalid'
chunk_size = 256 * 1024 # 256KB
upload = upload_mod.ResumableUpload(upload_url, chunk_size)
upload._resumable_url = 'http://test.invalid?upload_id=mocked'
upload._stream = stream
upload._content_type = 'text/plain'
upload._total_bytes = len(data)
# After-the-fact update the chunk size so that len(data)
# is split into three.
upload._chunk_size = 4
# Make three fake responses.
fake_response0 = requests.Response()
fake_response0.status_code = http.client.PERMANENT_REDIRECT
fake_response0.headers['range'] = 'bytes=0-3'
fake_response1 = requests.Response()
fake_response1.status_code = http.client.PERMANENT_REDIRECT
fake_response1.headers['range'] = 'bytes=0-7'
fake_response2 = requests.Response()
fake_response2.status_code = int(http.client.OK)
bucket = 'some-bucket'
blob_name = 'file.txt'
payload = {
'bucket': bucket,
'name': blob_name,
'size': '{:d}'.format(len(data)),
}
fake_response2._content = json.dumps(payload).encode('utf-8')
# Use the fake responses to mock a transport.
responses = [fake_response0, fake_response1, fake_response2]
put_method = mock.Mock(side_effect=responses, spec=[])
transport = mock.Mock(request=put_method, spec=['request'])
.. doctest:: resumable-transmit
>>> response0 = upload.transmit_next_chunk(transport)
>>> response0
<Response [308]>
>>> upload.finished
False
>>> upload.bytes_uploaded == upload.chunk_size
True
>>>
>>> response1 = upload.transmit_next_chunk(transport)
>>> response1
<Response [308]>
>>> upload.finished
False
>>> upload.bytes_uploaded == 2 * upload.chunk_size
True
>>>
>>> response2 = upload.transmit_next_chunk(transport)
>>> response2
<Response [200]>
>>> upload.finished
True
>>> upload.bytes_uploaded == upload.total_bytes
True
>>> json_response = response2.json()
>>> json_response['bucket'] == bucket
True
>>> json_response['name'] == blob_name
True
"""
from google._async_resumable_media.requests.download import ChunkedDownload
from google._async_resumable_media.requests.download import Download
from google._async_resumable_media.requests.upload import MultipartUpload
from google._async_resumable_media.requests.download import RawChunkedDownload
from google._async_resumable_media.requests.download import RawDownload
from google._async_resumable_media.requests.upload import ResumableUpload
from google._async_resumable_media.requests.upload import SimpleUpload
__all__ = [
"ChunkedDownload",
"Download",
"MultipartUpload",
"RawChunkedDownload",
"RawDownload",
"ResumableUpload",
"SimpleUpload",
]

View File

@@ -0,0 +1,155 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities used by both downloads and uploads.
This utilities are explicitly catered to ``requests``-like transports.
"""
import functools
from google._async_resumable_media import _helpers
from google.resumable_media import common
from google.auth.transport import _aiohttp_requests as aiohttp_requests # type: ignore
import aiohttp # type: ignore
_DEFAULT_RETRY_STRATEGY = common.RetryStrategy()
_SINGLE_GET_CHUNK_SIZE = 8192
# The number of seconds to wait to establish a connection
# (connect() call on socket). Avoid setting this to a multiple of 3 to not
# Align with TCP Retransmission timing. (typically 2.5-3s)
_DEFAULT_CONNECT_TIMEOUT = 61
# The number of seconds to wait between bytes sent from the server.
_DEFAULT_READ_TIMEOUT = 60
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(
connect=_DEFAULT_CONNECT_TIMEOUT, sock_read=_DEFAULT_READ_TIMEOUT
)
class RequestsMixin(object):
"""Mix-in class implementing ``requests``-specific behavior.
These are methods that are more general purpose, with implementations
specific to the types defined in ``requests``.
"""
@staticmethod
def _get_status_code(response):
"""Access the status code from an HTTP response.
Args:
response (~requests.Response): The HTTP response object.
Returns:
int: The status code.
"""
return response.status
@staticmethod
def _get_headers(response):
"""Access the headers from an HTTP response.
Args:
response (~requests.Response): The HTTP response object.
Returns:
~requests.structures.CaseInsensitiveDict: The header mapping (keys
are case-insensitive).
"""
# For Async testing,`_headers` is modified instead of headers
# access via the internal field.
return response._headers
@staticmethod
async def _get_body(response):
"""Access the response body from an HTTP response.
Args:
response (~requests.Response): The HTTP response object.
Returns:
bytes: The body of the ``response``.
"""
wrapped_response = aiohttp_requests._CombinedResponse(response)
content = await wrapped_response.data.read()
return content
class RawRequestsMixin(RequestsMixin):
@staticmethod
async def _get_body(response):
"""Access the response body from an HTTP response.
Args:
response (~requests.Response): The HTTP response object.
Returns:
bytes: The body of the ``response``.
"""
wrapped_response = aiohttp_requests._CombinedResponse(response)
content = await wrapped_response.raw_content()
return content
async def http_request(
transport,
method,
url,
data=None,
headers=None,
retry_strategy=_DEFAULT_RETRY_STRATEGY,
**transport_kwargs
):
"""Make an HTTP request.
Args:
transport (~requests.Session): A ``requests`` object which can make
authenticated requests via a ``request()`` method. This method
must accept an HTTP method, an upload URL, a ``data`` keyword
argument and a ``headers`` keyword argument.
method (str): The HTTP method for the request.
url (str): The URL for the request.
data (Optional[bytes]): The body of the request.
headers (Mapping[str, str]): The headers for the request (``transport``
may also add additional headers).
retry_strategy (~google.resumable_media.common.RetryStrategy): The
strategy to use if the request fails and must be retried.
transport_kwargs (Dict[str, str]): Extra keyword arguments to be
passed along to ``transport.request``.
Returns:
~requests.Response: The return value of ``transport.request()``.
"""
# NOTE(asyncio/aiohttp): Sync versions use a tuple for two timeouts,
# default connect timeout and read timeout. Since async requests only
# accepts a single value, this is using the connect timeout. This logic
# diverges from the sync implementation.
if "timeout" not in transport_kwargs:
timeout = _DEFAULT_TIMEOUT
transport_kwargs["timeout"] = timeout
func = functools.partial(
transport.request, method, url, data=data, headers=headers, **transport_kwargs
)
resp = await _helpers.wait_and_retry(
func, RequestsMixin._get_status_code, retry_strategy
)
return resp

View File

@@ -0,0 +1,465 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for downloading media from Google APIs."""
import urllib3.response # type: ignore
import http
from google._async_resumable_media import _download
from google._async_resumable_media import _helpers
from google._async_resumable_media.requests import _request_helpers
from google.resumable_media import common
from google.resumable_media import _helpers as sync_helpers
from google.resumable_media.requests import download
_CHECKSUM_MISMATCH = download._CHECKSUM_MISMATCH
class Download(_request_helpers.RequestsMixin, _download.Download):
"""Helper to manage downloading a resource from a Google API.
"Slices" of the resource can be retrieved by specifying a range
with ``start`` and / or ``end``. However, in typical usage, neither
``start`` nor ``end`` is expected to be provided.
Args:
media_url (str): The URL containing the media to be downloaded.
stream (IO[bytes]): A write-able stream (i.e. file-like object) that
the downloaded resource can be written to.
start (int): The first byte in a range to be downloaded. If not
provided, but ``end`` is provided, will download from the
beginning to ``end`` of the media.
end (int): The last byte in a range to be downloaded. If not
provided, but ``start`` is provided, will download from the
``start`` to the end of the media.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with the request, e.g. headers for encrypted data.
checksum Optional([str]): The type of checksum to compute to verify
the integrity of the object. The response headers must contain
a checksum of the requested type. If the headers lack an
appropriate checksum (for instance in the case of transcoded or
ranged downloads where the remote service does not know the
correct checksum) an INFO-level log will be emitted. Supported
values are "md5", "crc32c" and None. The default is "md5".
Attributes:
media_url (str): The URL containing the media to be downloaded.
start (Optional[int]): The first byte in a range to be downloaded.
end (Optional[int]): The last byte in a range to be downloaded.
"""
async def _write_to_stream(self, response):
"""Write response body to a write-able stream.
.. note:
This method assumes that the ``_stream`` attribute is set on the
current download.
Args:
response (~requests.Response): The HTTP response object.
Raises:
~google.resumable_media.common.DataCorruption: If the download's
checksum doesn't agree with server-computed checksum.
"""
# `_get_expected_checksum()` may return None even if a checksum was
# requested, in which case it will emit an info log _MISSING_CHECKSUM.
# If an invalid checksum type is specified, this will raise ValueError.
expected_checksum, checksum_object = sync_helpers._get_expected_checksum(
response, self._get_headers, self.media_url, checksum_type=self.checksum
)
local_checksum_object = _add_decoder(response, checksum_object)
async for chunk in response.content.iter_chunked(
_request_helpers._SINGLE_GET_CHUNK_SIZE
):
self._stream.write(chunk)
local_checksum_object.update(chunk)
# Don't validate the checksum for partial responses.
if (
expected_checksum is not None
and response.status != http.client.PARTIAL_CONTENT
):
actual_checksum = sync_helpers.prepare_checksum_digest(
checksum_object.digest()
)
if actual_checksum != expected_checksum:
msg = _CHECKSUM_MISMATCH.format(
self.media_url,
expected_checksum,
actual_checksum,
checksum_type=self.checksum.upper(),
)
raise common.DataCorruption(response, msg)
async def consume(self, transport, timeout=_request_helpers._DEFAULT_TIMEOUT):
"""Consume the resource to be downloaded.
If a ``stream`` is attached to this download, then the downloaded
resource will be written to the stream.
Args:
transport (~requests.Session): A ``requests`` object which can
make authenticated requests.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Returns:
~requests.Response: The HTTP response returned by ``transport``.
Raises:
~google.resumable_media.common.DataCorruption: If the download's
checksum doesn't agree with server-computed checksum.
ValueError: If the current :class:`Download` has already
finished.
"""
method, url, payload, headers = self._prepare_request()
# NOTE: We assume "payload is None" but pass it along anyway.
request_kwargs = {
"data": payload,
"headers": headers,
"retry_strategy": self._retry_strategy,
"timeout": timeout,
}
if self._stream is not None:
request_kwargs["stream"] = True
result = await _request_helpers.http_request(
transport, method, url, **request_kwargs
)
self._process_response(result)
if self._stream is not None:
await self._write_to_stream(result)
return result
class RawDownload(_request_helpers.RawRequestsMixin, _download.Download):
"""Helper to manage downloading a raw resource from a Google API.
"Slices" of the resource can be retrieved by specifying a range
with ``start`` and / or ``end``. However, in typical usage, neither
``start`` nor ``end`` is expected to be provided.
Args:
media_url (str): The URL containing the media to be downloaded.
stream (IO[bytes]): A write-able stream (i.e. file-like object) that
the downloaded resource can be written to.
start (int): The first byte in a range to be downloaded. If not
provided, but ``end`` is provided, will download from the
beginning to ``end`` of the media.
end (int): The last byte in a range to be downloaded. If not
provided, but ``start`` is provided, will download from the
``start`` to the end of the media.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with the request, e.g. headers for encrypted data.
checksum Optional([str]): The type of checksum to compute to verify
the integrity of the object. The response headers must contain
a checksum of the requested type. If the headers lack an
appropriate checksum (for instance in the case of transcoded or
ranged downloads where the remote service does not know the
correct checksum) an INFO-level log will be emitted. Supported
values are "md5", "crc32c" and None. The default is "md5".
Attributes:
media_url (str): The URL containing the media to be downloaded.
start (Optional[int]): The first byte in a range to be downloaded.
end (Optional[int]): The last byte in a range to be downloaded.
"""
async def _write_to_stream(self, response):
"""Write response body to a write-able stream.
.. note:
This method assumes that the ``_stream`` attribute is set on the
current download.
Args:
response (~requests.Response): The HTTP response object.
Raises:
~google.resumable_media.common.DataCorruption: If the download's
checksum doesn't agree with server-computed checksum.
"""
# `_get_expected_checksum()` may return None even if a checksum was
# requested, in which case it will emit an info log _MISSING_CHECKSUM.
# If an invalid checksum type is specified, this will raise ValueError.
expected_checksum, checksum_object = sync_helpers._get_expected_checksum(
response, self._get_headers, self.media_url, checksum_type=self.checksum
)
async for chunk in response.content.iter_chunked(
_request_helpers._SINGLE_GET_CHUNK_SIZE
):
self._stream.write(chunk)
checksum_object.update(chunk)
# Don't validate the checksum for partial responses.
if (
expected_checksum is not None
and response.status != http.client.PARTIAL_CONTENT
):
actual_checksum = sync_helpers.prepare_checksum_digest(
checksum_object.digest()
)
if actual_checksum != expected_checksum:
msg = _CHECKSUM_MISMATCH.format(
self.media_url,
expected_checksum,
actual_checksum,
checksum_type=self.checksum.upper(),
)
raise common.DataCorruption(response, msg)
async def consume(self, transport, timeout=_request_helpers._DEFAULT_TIMEOUT):
"""Consume the resource to be downloaded.
If a ``stream`` is attached to this download, then the downloaded
resource will be written to the stream.
Args:
transport (~requests.Session): A ``requests`` object which can
make authenticated requests.
timeout (Optional[Union[float, Tuple[float, float]]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as a tuple (connect_timeout, read_timeout).
See :meth:`requests.Session.request` documentation for details.
Returns:
~requests.Response: The HTTP response returned by ``transport``.
Raises:
~google.resumable_media.common.DataCorruption: If the download's
checksum doesn't agree with server-computed checksum.
ValueError: If the current :class:`Download` has already
finished.
"""
method, url, payload, headers = self._prepare_request()
# NOTE: We assume "payload is None" but pass it along anyway.
result = await _request_helpers.http_request(
transport,
method,
url,
data=payload,
headers=headers,
retry_strategy=self._retry_strategy,
)
self._process_response(result)
if self._stream is not None:
await self._write_to_stream(result)
return result
class ChunkedDownload(_request_helpers.RequestsMixin, _download.ChunkedDownload):
"""Download a resource in chunks from a Google API.
Args:
media_url (str): The URL containing the media to be downloaded.
chunk_size (int): The number of bytes to be retrieved in each
request.
stream (IO[bytes]): A write-able stream (i.e. file-like object) that
will be used to concatenate chunks of the resource as they are
downloaded.
start (int): The first byte in a range to be downloaded. If not
provided, defaults to ``0``.
end (int): The last byte in a range to be downloaded. If not
provided, will download to the end of the media.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with each request, e.g. headers for data encryption
key headers.
Attributes:
media_url (str): The URL containing the media to be downloaded.
start (Optional[int]): The first byte in a range to be downloaded.
end (Optional[int]): The last byte in a range to be downloaded.
chunk_size (int): The number of bytes to be retrieved in each request.
Raises:
ValueError: If ``start`` is negative.
"""
async def consume_next_chunk(
self, transport, timeout=_request_helpers._DEFAULT_TIMEOUT
):
"""
Consume the next chunk of the resource to be downloaded.
Args:
transport (~requests.Session): A ``requests`` object which can
make authenticated requests.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Returns:
~requests.Response: The HTTP response returned by ``transport``.
Raises:
ValueError: If the current download has finished.
"""
method, url, payload, headers = self._prepare_request()
# NOTE: We assume "payload is None" but pass it along anyway.
result = await _request_helpers.http_request(
transport,
method,
url,
data=payload,
headers=headers,
retry_strategy=self._retry_strategy,
timeout=timeout,
)
await self._process_response(result)
return result
class RawChunkedDownload(_request_helpers.RawRequestsMixin, _download.ChunkedDownload):
"""Download a raw resource in chunks from a Google API.
Args:
media_url (str): The URL containing the media to be downloaded.
chunk_size (int): The number of bytes to be retrieved in each
request.
stream (IO[bytes]): A write-able stream (i.e. file-like object) that
will be used to concatenate chunks of the resource as they are
downloaded.
start (int): The first byte in a range to be downloaded. If not
provided, defaults to ``0``.
end (int): The last byte in a range to be downloaded. If not
provided, will download to the end of the media.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with each request, e.g. headers for data encryption
key headers.
Attributes:
media_url (str): The URL containing the media to be downloaded.
start (Optional[int]): The first byte in a range to be downloaded.
end (Optional[int]): The last byte in a range to be downloaded.
chunk_size (int): The number of bytes to be retrieved in each request.
Raises:
ValueError: If ``start`` is negative.
"""
async def consume_next_chunk(
self, transport, timeout=_request_helpers._DEFAULT_TIMEOUT
):
"""Consume the next chunk of the resource to be downloaded.
Args:
transport (~requests.Session): A ``requests`` object which can
make authenticated requests.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Returns:
~requests.Response: The HTTP response returned by ``transport``.
Raises:
ValueError: If the current download has finished.
"""
method, url, payload, headers = self._prepare_request()
# NOTE: We assume "payload is None" but pass it along anyway.
result = await _request_helpers.http_request(
transport,
method,
url,
data=payload,
headers=headers,
retry_strategy=self._retry_strategy,
timeout=timeout,
)
await self._process_response(result)
return result
def _add_decoder(response_raw, checksum):
"""Patch the ``_decoder`` on a ``urllib3`` response.
This is so that we can intercept the compressed bytes before they are
decoded.
Only patches if the content encoding is ``gzip``.
Args:
response_raw (urllib3.response.HTTPResponse): The raw response for
an HTTP request.
checksum (object):
A checksum which will be updated with compressed bytes.
Returns:
object: Either the original ``checksum`` if ``_decoder`` is not
patched, or a ``_DoNothingHash`` if the decoder is patched, since the
caller will no longer need to hash to decoded bytes.
"""
encoding = response_raw.headers.get("content-encoding", "").lower()
if encoding != "gzip":
return checksum
response_raw._decoder = _GzipDecoder(checksum)
return _helpers._DoNothingHash()
class _GzipDecoder(urllib3.response.GzipDecoder):
"""Custom subclass of ``urllib3`` decoder for ``gzip``-ed bytes.
Allows a checksum function to see the compressed bytes before they are
decoded. This way the checksum of the compressed value can be computed.
Args:
checksum (object):
A checksum which will be updated with compressed bytes.
"""
def __init__(self, checksum):
super(_GzipDecoder, self).__init__()
self._checksum = checksum
def decompress(self, data):
"""Decompress the bytes.
Args:
data (bytes): The compressed bytes to be decompressed.
Returns:
bytes: The decompressed bytes from ``data``.
"""
self._checksum.update(data)
return super(_GzipDecoder, self).decompress(data)

View File

@@ -0,0 +1,515 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for resumable uploads.
Also supported here are simple (media) uploads and multipart
uploads that contain both metadata and a small file as payload.
"""
from google._async_resumable_media import _upload
from google._async_resumable_media.requests import _request_helpers
class SimpleUpload(_request_helpers.RequestsMixin, _upload.SimpleUpload):
"""Upload a resource to a Google API.
A **simple** media upload sends no metadata and completes the upload
in a single request.
Args:
upload_url (str): The URL where the content will be uploaded.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with the request, e.g. headers for encrypted data.
Attributes:
upload_url (str): The URL where the content will be uploaded.
"""
async def transmit(
self,
transport,
data,
content_type,
timeout=_request_helpers._DEFAULT_TIMEOUT,
):
"""Transmit the resource to be uploaded.
Args:
transport (~requests.Session): A ``requests`` object which can
make authenticated requests.
data (bytes): The resource content to be uploaded.
content_type (str): The content type of the resource, e.g. a JPEG
image has content type ``image/jpeg``.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Returns:
~requests.Response: The HTTP response returned by ``transport``.
"""
method, url, payload, headers = self._prepare_request(data, content_type)
response = await _request_helpers.http_request(
transport,
method,
url,
data=payload,
headers=headers,
retry_strategy=self._retry_strategy,
timeout=timeout,
)
self._process_response(response)
return response
class MultipartUpload(_request_helpers.RequestsMixin, _upload.MultipartUpload):
"""Upload a resource with metadata to a Google API.
A **multipart** upload sends both metadata and the resource in a single
(multipart) request.
Args:
upload_url (str): The URL where the content will be uploaded.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with the request, e.g. headers for encrypted data.
checksum Optional([str]): The type of checksum to compute to verify
the integrity of the object. The request metadata will be amended
to include the computed value. Using this option will override a
manually-set checksum value. Supported values are "md5",
"crc32c" and None. The default is None.
Attributes:
upload_url (str): The URL where the content will be uploaded.
"""
async def transmit(
self,
transport,
data,
metadata,
content_type,
timeout=_request_helpers._DEFAULT_TIMEOUT,
):
"""Transmit the resource to be uploaded.
Args:
transport (~requests.Session): A ``requests`` object which can
make authenticated requests.
data (bytes): The resource content to be uploaded.
metadata (Mapping[str, str]): The resource metadata, such as an
ACL list.
content_type (str): The content type of the resource, e.g. a JPEG
image has content type ``image/jpeg``.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Returns:
~requests.Response: The HTTP response returned by ``transport``.
"""
method, url, payload, headers = self._prepare_request(
data, metadata, content_type
)
response = await _request_helpers.http_request(
transport,
method,
url,
data=payload,
headers=headers,
retry_strategy=self._retry_strategy,
timeout=timeout,
)
self._process_response(response)
return response
class ResumableUpload(_request_helpers.RequestsMixin, _upload.ResumableUpload):
"""Initiate and fulfill a resumable upload to a Google API.
A **resumable** upload sends an initial request with the resource metadata
and then gets assigned an upload ID / upload URL to send bytes to.
Using the upload URL, the upload is then done in chunks (determined by
the user) until all bytes have been uploaded.
When constructing a resumable upload, only the resumable upload URL and
the chunk size are required:
.. testsetup:: resumable-constructor
bucket = 'bucket-foo'
.. doctest:: resumable-constructor
>>> from google.resumable_media.requests import ResumableUpload
>>>
>>> url_template = (
... 'https://www.googleapis.com/upload/storage/v1/b/{bucket}/o?'
... 'uploadType=resumable')
>>> upload_url = url_template.format(bucket=bucket)
>>>
>>> chunk_size = 3 * 1024 * 1024 # 3MB
>>> upload = ResumableUpload(upload_url, chunk_size)
When initiating an upload (via :meth:`initiate`), the caller is expected
to pass the resource being uploaded as a file-like ``stream``. If the size
of the resource is explicitly known, it can be passed in directly:
.. testsetup:: resumable-explicit-size
import os
import tempfile
import mock
import requests
import http.client
from google.resumable_media.requests import ResumableUpload
upload_url = 'http://test.invalid'
chunk_size = 3 * 1024 * 1024 # 3MB
upload = ResumableUpload(upload_url, chunk_size)
file_desc, filename = tempfile.mkstemp()
os.close(file_desc)
data = b'some bytes!'
with open(filename, 'wb') as file_obj:
file_obj.write(data)
fake_response = requests.Response()
fake_response.status_code = int(http.client.OK)
fake_response._content = b''
resumable_url = 'http://test.invalid?upload_id=7up'
fake_response.headers['location'] = resumable_url
post_method = mock.Mock(return_value=fake_response, spec=[])
transport = mock.Mock(request=post_method, spec=['request'])
.. doctest:: resumable-explicit-size
>>> import os
>>>
>>> upload.total_bytes is None
True
>>>
>>> stream = open(filename, 'rb')
>>> total_bytes = os.path.getsize(filename)
>>> metadata = {'name': filename}
>>> response = upload.initiate(
... transport, stream, metadata, 'text/plain',
... total_bytes=total_bytes)
>>> response
<Response [200]>
>>>
>>> upload.total_bytes == total_bytes
True
.. testcleanup:: resumable-explicit-size
os.remove(filename)
If the stream is in a "final" state (i.e. it won't have any more bytes
written to it), the total number of bytes can be determined implicitly
from the ``stream`` itself:
.. testsetup:: resumable-implicit-size
import io
import mock
import requests
import http.client
from google.resumable_media.requests import ResumableUpload
upload_url = 'http://test.invalid'
chunk_size = 3 * 1024 * 1024 # 3MB
upload = ResumableUpload(upload_url, chunk_size)
fake_response = requests.Response()
fake_response.status_code = int(http.client.OK)
fake_response._content = b''
resumable_url = 'http://test.invalid?upload_id=7up'
fake_response.headers['location'] = resumable_url
post_method = mock.Mock(return_value=fake_response, spec=[])
transport = mock.Mock(request=post_method, spec=['request'])
data = b'some MOAR bytes!'
metadata = {'name': 'some-file.jpg'}
content_type = 'image/jpeg'
.. doctest:: resumable-implicit-size
>>> stream = io.BytesIO(data)
>>> response = upload.initiate(
... transport, stream, metadata, content_type)
>>>
>>> upload.total_bytes == len(data)
True
If the size of the resource is **unknown** when the upload is initiated,
the ``stream_final`` argument can be used. This might occur if the
resource is being dynamically created on the client (e.g. application
logs). To use this argument:
.. testsetup:: resumable-unknown-size
import io
import mock
import requests
import http.client
from google.resumable_media.requests import ResumableUpload
upload_url = 'http://test.invalid'
chunk_size = 3 * 1024 * 1024 # 3MB
upload = ResumableUpload(upload_url, chunk_size)
fake_response = requests.Response()
fake_response.status_code = int(http.client.OK)
fake_response._content = b''
resumable_url = 'http://test.invalid?upload_id=7up'
fake_response.headers['location'] = resumable_url
post_method = mock.Mock(return_value=fake_response, spec=[])
transport = mock.Mock(request=post_method, spec=['request'])
metadata = {'name': 'some-file.jpg'}
content_type = 'application/octet-stream'
stream = io.BytesIO(b'data')
.. doctest:: resumable-unknown-size
>>> response = upload.initiate(
... transport, stream, metadata, content_type,
... stream_final=False)
>>>
>>> upload.total_bytes is None
True
Args:
upload_url (str): The URL where the resumable upload will be initiated.
chunk_size (int): The size of each chunk used to upload the resource.
headers (Optional[Mapping[str, str]]): Extra headers that should
be sent with the :meth:`initiate` request, e.g. headers for
encrypted data. These **will not** be sent with
:meth:`transmit_next_chunk` or :meth:`recover` requests.
checksum Optional([str]): The type of checksum to compute to verify
the integrity of the object. After the upload is complete, the
server-computed checksum of the resulting object will be checked
and google.resumable_media.common.DataCorruption will be raised on
a mismatch. The corrupted file will not be deleted from the remote
host automatically. Supported values are "md5", "crc32c" and None.
The default is None.
Attributes:
upload_url (str): The URL where the content will be uploaded.
Raises:
ValueError: If ``chunk_size`` is not a multiple of
:data:`.UPLOAD_CHUNK_SIZE`.
"""
async def initiate(
self,
transport,
stream,
metadata,
content_type,
total_bytes=None,
stream_final=True,
timeout=_request_helpers._DEFAULT_TIMEOUT,
):
"""Initiate a resumable upload.
By default, this method assumes your ``stream`` is in a "final"
state ready to transmit. However, ``stream_final=False`` can be used
to indicate that the size of the resource is not known. This can happen
if bytes are being dynamically fed into ``stream``, e.g. if the stream
is attached to application logs.
If ``stream_final=False`` is used, :attr:`chunk_size` bytes will be
read from the stream every time :meth:`transmit_next_chunk` is called.
If one of those reads produces strictly fewer bites than the chunk
size, the upload will be concluded.
Args:
transport (~requests.Session): A ``requests`` object which can
make authenticated requests.
stream (IO[bytes]): The stream (i.e. file-like object) that will
be uploaded. The stream **must** be at the beginning (i.e.
``stream.tell() == 0``).
metadata (Mapping[str, str]): The resource metadata, such as an
ACL list.
content_type (str): The content type of the resource, e.g. a JPEG
image has content type ``image/jpeg``.
total_bytes (Optional[int]): The total number of bytes to be
uploaded. If specified, the upload size **will not** be
determined from the stream (even if ``stream_final=True``).
stream_final (Optional[bool]): Indicates if the ``stream`` is
"final" (i.e. no more bytes will be added to it). In this case
we determine the upload size from the size of the stream. If
``total_bytes`` is passed, this argument will be ignored.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Returns:
~requests.Response: The HTTP response returned by ``transport``.
"""
method, url, payload, headers = self._prepare_initiate_request(
stream,
metadata,
content_type,
total_bytes=total_bytes,
stream_final=stream_final,
)
response = await _request_helpers.http_request(
transport,
method,
url,
data=payload,
headers=headers,
retry_strategy=self._retry_strategy,
timeout=timeout,
)
self._process_initiate_response(response)
return response
async def transmit_next_chunk(
self, transport, timeout=_request_helpers._DEFAULT_TIMEOUT
):
"""Transmit the next chunk of the resource to be uploaded.
If the current upload was initiated with ``stream_final=False``,
this method will dynamically determine if the upload has completed.
The upload will be considered complete if the stream produces
fewer than :attr:`chunk_size` bytes when a chunk is read from it.
In the case of failure, an exception is thrown that preserves the
failed response:
.. testsetup:: bad-response
import io
import mock
import requests
import http.client
from google import resumable_media
import google.resumable_media.requests.upload as upload_mod
transport = mock.Mock(spec=['request'])
fake_response = requests.Response()
fake_response.status_code = int(http.client.BAD_REQUEST)
transport.request.return_value = fake_response
upload_url = 'http://test.invalid'
upload = upload_mod.ResumableUpload(
upload_url, resumable_media.UPLOAD_CHUNK_SIZE)
# Fake that the upload has been initiate()-d
data = b'data is here'
upload._stream = io.BytesIO(data)
upload._total_bytes = len(data)
upload._resumable_url = 'http://test.invalid?upload_id=nope'
.. doctest:: bad-response
:options: +NORMALIZE_WHITESPACE
>>> error = None
>>> try:
... upload.transmit_next_chunk(transport)
... except resumable_media.InvalidResponse as caught_exc:
... error = caught_exc
...
>>> error
InvalidResponse('Request failed with status code', 400,
'Expected one of', <HTTPStatus.OK: 200>, 308)
>>> error.response
<Response [400]>
Args:
transport (~requests.Session): A ``requests`` object which can
make authenticated requests.
timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
The number of seconds to wait for the server response.
Depending on the retry strategy, a request may be repeated
several times using the same timeout each time.
Can also be passed as an `aiohttp.ClientTimeout` object.
Returns:
~requests.Response: The HTTP response returned by ``transport``.
Raises:
~google.resumable_media.common.InvalidResponse: If the status
code is not 200 or 308.
~google.resumable_media.common.DataCorruption: If this is the final
chunk, a checksum validation was requested, and the checksum
does not match or is not available.
"""
method, url, payload, headers = self._prepare_request()
response = await _request_helpers.http_request(
transport,
method,
url,
data=payload,
headers=headers,
retry_strategy=self._retry_strategy,
timeout=timeout,
)
await self._process_resumable_response(response, len(payload))
return response
async def recover(self, transport):
"""Recover from a failure.
This method should be used when a :class:`ResumableUpload` is in an
:attr:`~ResumableUpload.invalid` state due to a request failure.
This will verify the progress with the server and make sure the
current upload is in a valid state before :meth:`transmit_next_chunk`
can be used again.
Args:
transport (~requests.Session): A ``requests`` object which can
make authenticated requests.
Returns:
~requests.Response: The HTTP response returned by ``transport``.
"""
method, url, payload, headers = self._prepare_recover_request()
# NOTE: We assume "payload is None" but pass it along anyway.
response = await _request_helpers.http_request(
transport,
method,
url,
data=payload,
headers=headers,
retry_strategy=self._retry_strategy,
)
self._process_recover_response(response)
return response

View File

@@ -0,0 +1,20 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import version
from .agents.llm_agent import Agent
from .runners import Runner
__version__ = version.__version__
__all__ = ["Agent", "Runner"]

View File

@@ -0,0 +1,32 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_agent import BaseAgent
from .live_request_queue import LiveRequest
from .live_request_queue import LiveRequestQueue
from .llm_agent import Agent
from .llm_agent import LlmAgent
from .loop_agent import LoopAgent
from .parallel_agent import ParallelAgent
from .run_config import RunConfig
from .sequential_agent import SequentialAgent
__all__ = [
'Agent',
'BaseAgent',
'LlmAgent',
'LoopAgent',
'ParallelAgent',
'SequentialAgent',
]

View File

@@ -0,0 +1,38 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
from typing import Optional
from pydantic import BaseModel
from pydantic import ConfigDict
from .live_request_queue import LiveRequestQueue
class ActiveStreamingTool(BaseModel):
"""Manages streaming tool related resources during invocation."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra='forbid',
)
task: Optional[asyncio.Task] = None
"""The active task of this streaming tool."""
stream: Optional[LiveRequestQueue] = None
"""The active (input) streams of this streaming tool."""

View File

@@ -0,0 +1,345 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from typing import AsyncGenerator
from typing import Callable
from typing import final
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
from opentelemetry import trace
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
from typing_extensions import override
from ..events.event import Event
from .callback_context import CallbackContext
if TYPE_CHECKING:
from .invocation_context import InvocationContext
tracer = trace.get_tracer('gcp.vertex.agent')
BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
"""Callback signature that is invoked before the agent run.
Args:
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be returned to user.
"""
AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
"""Callback signature that is invoked after the agent run.
Args:
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be appended to event history as agent response.
"""
class BaseAgent(BaseModel):
"""Base class for all agents in Agent Development Kit."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra='forbid',
)
name: str
"""The agent's name.
Agent name must be a Python identifier and unique within the agent tree.
Agent name cannot be "user", since it's reserved for end-user's input.
"""
description: str = ''
"""Description about the agent's capability.
The model uses this to determine whether to delegate control to the agent.
One-line description is enough and preferred.
"""
parent_agent: Optional[BaseAgent] = Field(default=None, init=False)
"""The parent agent of this agent.
Note that an agent can ONLY be added as sub-agent once.
If you want to add one agent twice as sub-agent, consider to create two agent
instances with identical config, but with different name and add them to the
agent tree.
"""
sub_agents: list[BaseAgent] = Field(default_factory=list)
"""The sub-agents of this agent."""
before_agent_callback: Optional[BeforeAgentCallback] = None
"""Callback signature that is invoked before the agent run.
Args:
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be returned to user.
"""
after_agent_callback: Optional[AfterAgentCallback] = None
"""Callback signature that is invoked after the agent run.
Args:
callback_context: MUST be named 'callback_context' (enforced).
Returns:
The content to return to the user. When set, the agent run will skipped and
the provided content will be appended to event history as agent response.
"""
@final
async def run_async(
self,
parent_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""Entry method to run an agent via text-based conversation.
Args:
parent_context: InvocationContext, the invocation context of the parent
agent.
Yields:
Event: the events generated by the agent.
"""
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
ctx = self._create_invocation_context(parent_context)
if event := self.__handle_before_agent_callback(ctx):
yield event
if ctx.end_invocation:
return
async for event in self._run_async_impl(ctx):
yield event
if ctx.end_invocation:
return
if event := self.__handle_after_agent_callback(ctx):
yield event
@final
async def run_live(
self,
parent_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""Entry method to run an agent via video/audio-based conversation.
Args:
parent_context: InvocationContext, the invocation context of the parent
agent.
Yields:
Event: the events generated by the agent.
"""
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
ctx = self._create_invocation_context(parent_context)
# TODO(hangfei): support before/after_agent_callback
async for event in self._run_live_impl(ctx):
yield event
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
"""Core logic to run this agent via text-based conversation.
Args:
ctx: InvocationContext, the invocation context for this agent.
Yields:
Event: the events generated by the agent.
"""
raise NotImplementedError(
f'_run_async_impl for {type(self)} is not implemented.'
)
yield # AsyncGenerator requires having at least one yield statement
async def _run_live_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
"""Core logic to run this agent via video/audio-based conversation.
Args:
ctx: InvocationContext, the invocation context for this agent.
Yields:
Event: the events generated by the agent.
"""
raise NotImplementedError(
f'_run_live_impl for {type(self)} is not implemented.'
)
yield # AsyncGenerator requires having at least one yield statement
@property
def root_agent(self) -> BaseAgent:
"""Gets the root agent of this agent."""
root_agent = self
while root_agent.parent_agent is not None:
root_agent = root_agent.parent_agent
return root_agent
def find_agent(self, name: str) -> Optional[BaseAgent]:
"""Finds the agent with the given name in this agent and its descendants.
Args:
name: The name of the agent to find.
Returns:
The agent with the matching name, or None if no such agent is found.
"""
if self.name == name:
return self
return self.find_sub_agent(name)
def find_sub_agent(self, name: str) -> Optional[BaseAgent]:
"""Finds the agent with the given name in this agent's descendants.
Args:
name: The name of the agent to find.
Returns:
The agent with the matching name, or None if no such agent is found.
"""
for sub_agent in self.sub_agents:
if result := sub_agent.find_agent(name):
return result
return None
def _create_invocation_context(
self, parent_context: InvocationContext
) -> InvocationContext:
"""Creates a new invocation context for this agent."""
invocation_context = parent_context.model_copy(update={'agent': self})
if parent_context.branch:
invocation_context.branch = f'{parent_context.branch}.{self.name}'
return invocation_context
def __handle_before_agent_callback(
self, ctx: InvocationContext
) -> Optional[Event]:
"""Runs the before_agent_callback if it exists.
Returns:
Optional[Event]: an event if callback provides content or changed state.
"""
ret_event = None
if not isinstance(self.before_agent_callback, Callable):
return ret_event
callback_context = CallbackContext(ctx)
before_agent_callback_content = self.before_agent_callback(
callback_context=callback_context
)
if before_agent_callback_content:
ret_event = Event(
invocation_id=ctx.invocation_id,
author=self.name,
branch=ctx.branch,
content=before_agent_callback_content,
actions=callback_context._event_actions,
)
ctx.end_invocation = True
return ret_event
if callback_context.state.has_delta():
ret_event = Event(
invocation_id=ctx.invocation_id,
author=self.name,
branch=ctx.branch,
actions=callback_context._event_actions,
)
return ret_event
def __handle_after_agent_callback(
self, invocation_context: InvocationContext
) -> Optional[Event]:
"""Runs the after_agent_callback if it exists.
Returns:
Optional[Event]: an event if callback provides content or changed state.
"""
ret_event = None
if not isinstance(self.after_agent_callback, Callable):
return ret_event
callback_context = CallbackContext(invocation_context)
after_agent_callback_content = self.after_agent_callback(
callback_context=callback_context
)
if after_agent_callback_content or callback_context.state.has_delta():
ret_event = Event(
invocation_id=invocation_context.invocation_id,
author=self.name,
branch=invocation_context.branch,
content=after_agent_callback_content,
actions=callback_context._event_actions,
)
return ret_event
@override
def model_post_init(self, __context: Any) -> None:
self.__set_parent_agent_for_sub_agents()
@field_validator('name', mode='after')
@classmethod
def __validate_name(cls, value: str):
if not value.isidentifier():
raise ValueError(
f'Found invalid agent name: `{value}`.'
' Agent name must be a valid identifier. It should start with a'
' letter (a-z, A-Z) or an underscore (_), and can only contain'
' letters, digits (0-9), and underscores.'
)
if value == 'user':
raise ValueError(
"Agent name cannot be `user`. `user` is reserved for end-user's"
' input.'
)
return value
def __set_parent_agent_for_sub_agents(self) -> BaseAgent:
for sub_agent in self.sub_agents:
if sub_agent.parent_agent is not None:
raise ValueError(
f'Agent `{sub_agent.name}` already has a parent agent, current'
f' parent: `{sub_agent.parent_agent.name}`, trying to add:'
f' `{self.name}`'
)
sub_agent.parent_agent = self
return self

View File

@@ -0,0 +1,111 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
from typing_extensions import override
from .readonly_context import ReadonlyContext
if TYPE_CHECKING:
from google.genai import types
from ..events.event_actions import EventActions
from ..sessions.state import State
from .invocation_context import InvocationContext
class CallbackContext(ReadonlyContext):
"""The context of various callbacks within an agent run."""
def __init__(
self,
invocation_context: InvocationContext,
*,
event_actions: Optional[EventActions] = None,
) -> None:
super().__init__(invocation_context)
from ..events.event_actions import EventActions
from ..sessions.state import State
# TODO(weisun): make this public for Agent Development Kit, but private for
# users.
self._event_actions = event_actions or EventActions()
self._state = State(
value=invocation_context.session.state,
delta=self._event_actions.state_delta,
)
@property
@override
def state(self) -> State:
"""The delta-aware state of the current session.
For any state change, you can mutate this object directly,
e.g. `ctx.state['foo'] = 'bar'`
"""
return self._state
@property
def user_content(self) -> Optional[types.Content]:
"""The user content that started this invocation. READONLY field."""
return self._invocation_context.user_content
def load_artifact(
self, filename: str, version: Optional[int] = None
) -> Optional[types.Part]:
"""Loads an artifact attached to the current session.
Args:
filename: The filename of the artifact.
version: The version of the artifact. If None, the latest version will be
returned.
Returns:
The artifact.
"""
if self._invocation_context.artifact_service is None:
raise ValueError("Artifact service is not initialized.")
return self._invocation_context.artifact_service.load_artifact(
app_name=self._invocation_context.app_name,
user_id=self._invocation_context.user_id,
session_id=self._invocation_context.session.id,
filename=filename,
version=version,
)
def save_artifact(self, filename: str, artifact: types.Part) -> int:
"""Saves an artifact and records it as delta for the current session.
Args:
filename: The filename of the artifact.
artifact: The artifact to save.
Returns:
The version of the artifact.
"""
if self._invocation_context.artifact_service is None:
raise ValueError("Artifact service is not initialized.")
version = self._invocation_context.artifact_service.save_artifact(
app_name=self._invocation_context.app_name,
user_id=self._invocation_context.user_id,
session_id=self._invocation_context.session.id,
filename=filename,
artifact=artifact,
)
self._event_actions.artifact_delta[filename] = version
return version

View File

@@ -0,0 +1,181 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional
import uuid
from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict
from ..artifacts.base_artifact_service import BaseArtifactService
from ..memory.base_memory_service import BaseMemoryService
from ..sessions.base_session_service import BaseSessionService
from ..sessions.session import Session
from .active_streaming_tool import ActiveStreamingTool
from .base_agent import BaseAgent
from .live_request_queue import LiveRequestQueue
from .run_config import RunConfig
from .transcription_entry import TranscriptionEntry
class LlmCallsLimitExceededError(Exception):
"""Error thrown when the number of LLM calls exceed the limit."""
class _InvocationCostManager(BaseModel):
"""A container to keep track of the cost of invocation.
While we don't expected the metrics captured here to be a direct
representatative of monetary cost incurred in executing the current
invocation, but they, in someways have an indirect affect.
"""
_number_of_llm_calls: int = 0
"""A counter that keeps track of number of llm calls made."""
def increment_and_enforce_llm_calls_limit(
self, run_config: Optional[RunConfig]
):
"""Increments _number_of_llm_calls and enforces the limit."""
# We first increment the counter and then check the conditions.
self._number_of_llm_calls += 1
if (
run_config
and run_config.max_llm_calls > 0
and self._number_of_llm_calls > run_config.max_llm_calls
):
# We only enforce the limit if the limit is a positive number.
raise LlmCallsLimitExceededError(
"Max number of llm calls limit of"
f" `{run_config.max_llm_calls}` exceeded"
)
class InvocationContext(BaseModel):
"""An invocation context represents the data of a single invocation of an agent.
An invocation:
1. Starts with a user message and ends with a final response.
2. Can contain one or multiple agent calls.
3. Is handled by runner.run_async().
An invocation runs an agent until it does not request to transfer to another
agent.
An agent call:
1. Is handled by agent.run().
2. Ends when agent.run() ends.
An LLM agent call is an agent with a BaseLLMFlow.
An LLM agent call can contain one or multiple steps.
An LLM agent runs steps in a loop until:
1. A final response is generated.
2. The agent transfers to another agent.
3. The end_invocation is set to true by any callbacks or tools.
A step:
1. Calls the LLM only once and yields its response.
2. Calls the tools and yields their responses if requested.
The summarization of the function response is considered another step, since
it is another llm call.
A step ends when it's done calling llm and tools, or if the end_invocation
is set to true at any time.
```
┌─────────────────────── invocation ──────────────────────────┐
┌──────────── llm_agent_call_1 ────────────┐ ┌─ agent_call_2 ─┐
┌──── step_1 ────────┐ ┌───── step_2 ──────┐
[call_llm] [call_tool] [call_llm] [transfer]
```
"""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
artifact_service: Optional[BaseArtifactService] = None
session_service: BaseSessionService
memory_service: Optional[BaseMemoryService] = None
invocation_id: str
"""The id of this invocation context. Readonly."""
branch: Optional[str] = None
"""The branch of the invocation context.
The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of
agent_2, and agent_2 is the parent of agent_3.
Branch is used when multiple sub-agents shouldn't see their peer agents'
conversation history.
"""
agent: BaseAgent
"""The current agent of this invocation context. Readonly."""
user_content: Optional[types.Content] = None
"""The user content that started this invocation. Readonly."""
session: Session
"""The current session of this invocation context. Readonly."""
end_invocation: bool = False
"""Whether to end this invocation.
Set to True in callbacks or tools to terminate this invocation."""
live_request_queue: Optional[LiveRequestQueue] = None
"""The queue to receive live requests."""
active_streaming_tools: Optional[dict[str, ActiveStreamingTool]] = None
"""The running streaming tools of this invocation."""
transcription_cache: Optional[list[TranscriptionEntry]] = None
"""Caches necessary, data audio or contents, that are needed by transcription."""
run_config: Optional[RunConfig] = None
"""Configurations for live agents under this invocation."""
_invocation_cost_manager: _InvocationCostManager = _InvocationCostManager()
"""A container to keep track of different kinds of costs incurred as a part
of this invocation.
"""
def increment_llm_call_count(
self,
):
"""Tracks number of llm calls made.
Raises:
LlmCallsLimitExceededError: If number of llm calls made exceed the set
threshold.
"""
self._invocation_cost_manager.increment_and_enforce_llm_calls_limit(
self.run_config
)
@property
def app_name(self) -> str:
return self.session.app_name
@property
def user_id(self) -> str:
return self.session.user_id
def new_invocation_context_id() -> str:
return "e-" + str(uuid.uuid4())

View File

@@ -0,0 +1,140 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import AsyncGenerator
from typing import Union
from google.genai import types
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.runnables.config import RunnableConfig
from langgraph.graph.graph import CompiledGraph
from pydantic import ConfigDict
from typing_extensions import override
from ..events.event import Event
from .base_agent import BaseAgent
from .invocation_context import InvocationContext
def _get_last_human_messages(events: list[Event]) -> list[HumanMessage]:
"""Extracts last human messages from given list of events.
Args:
events: the list of events
Returns:
list of last human messages
"""
messages = []
for event in reversed(events):
if messages and event.author != 'user':
break
if event.author == 'user' and event.content and event.content.parts:
messages.append(HumanMessage(content=event.content.parts[0].text))
return list(reversed(messages))
class LangGraphAgent(BaseAgent):
"""Currently a concept implementation, supports single and multi-turn."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
graph: CompiledGraph
instruction: str = ''
@override
async def _run_async_impl(
self,
ctx: InvocationContext,
) -> AsyncGenerator[Event, None]:
# Needed for langgraph checkpointer (for subsequent invocations; multi-turn)
config: RunnableConfig = {'configurable': {'thread_id': ctx.session.id}}
# Add instruction as SystemMessage if graph state is empty
current_graph_state = self.graph.get_state(config)
graph_messages = (
current_graph_state.values.get('messages', [])
if current_graph_state.values
else []
)
messages = (
[SystemMessage(content=self.instruction)]
if self.instruction and not graph_messages
else []
)
# Add events to messages (evaluating the memory used; parent agent vs checkpointer)
messages += self._get_messages(ctx.session.events)
# Use the Runnable
final_state = self.graph.invoke({'messages': messages}, config)
result = final_state['messages'][-1].content
result_event = Event(
invocation_id=ctx.invocation_id,
author=self.name,
branch=ctx.branch,
content=types.Content(
role='model',
parts=[types.Part.from_text(text=result)],
),
)
yield result_event
def _get_messages(
self, events: list[Event]
) -> list[Union[HumanMessage, AIMessage]]:
"""Extracts messages from given list of events.
If the developer provides their own memory within langgraph, we return the
last user messages only. Otherwise, we return all messages between the user
and the agent.
Args:
events: the list of events
Returns:
list of messages
"""
if self.graph.checkpointer:
return _get_last_human_messages(events)
else:
return self._get_conversation_with_agent(events)
def _get_conversation_with_agent(
self, events: list[Event]
) -> list[Union[HumanMessage, AIMessage]]:
"""Extracts messages from given list of events.
Args:
events: the list of events
Returns:
list of messages
"""
messages = []
for event in events:
if not event.content or not event.content.parts:
continue
if event.author == 'user':
messages.append(HumanMessage(content=event.content.parts[0].text))
elif event.author == self.name:
messages.append(AIMessage(content=event.content.parts[0].text))
return messages

View File

@@ -0,0 +1,64 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Optional
from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict
class LiveRequest(BaseModel):
"""Request send to live agents."""
model_config = ConfigDict(ser_json_bytes='base64', val_json_bytes='base64')
content: Optional[types.Content] = None
"""If set, send the content to the model in turn-by-turn mode."""
blob: Optional[types.Blob] = None
"""If set, send the blob to the model in realtime mode."""
close: bool = False
"""If set, close the queue. queue.shutdown() is only supported in Python 3.13+."""
class LiveRequestQueue:
"""Queue used to send LiveRequest in a live(bidirectional streaming) way."""
def __init__(self):
# Ensure there's an event loop available in this thread
try:
asyncio.get_running_loop()
except RuntimeError:
# No running loop, create one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Now create the queue (it will use the event loop we just ensured exists)
self._queue = asyncio.Queue()
def close(self):
self._queue.put_nowait(LiveRequest(close=True))
def send_content(self, content: types.Content):
self._queue.put_nowait(LiveRequest(content=content))
def send_realtime(self, blob: types.Blob):
self._queue.put_nowait(LiveRequest(blob=blob))
def send(self, req: LiveRequest):
self._queue.put_nowait(req)
async def get(self) -> LiveRequest:
return await self._queue.get()

View File

@@ -0,0 +1,376 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import Any
from typing import AsyncGenerator
from typing import Callable
from typing import Literal
from typing import Optional
from typing import Union
from google.genai import types
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from typing_extensions import override
from typing_extensions import TypeAlias
from ..code_executors.base_code_executor import BaseCodeExecutor
from ..events.event import Event
from ..examples.base_example_provider import BaseExampleProvider
from ..examples.example import Example
from ..flows.llm_flows.auto_flow import AutoFlow
from ..flows.llm_flows.base_llm_flow import BaseLlmFlow
from ..flows.llm_flows.single_flow import SingleFlow
from ..models.base_llm import BaseLlm
from ..models.llm_request import LlmRequest
from ..models.llm_response import LlmResponse
from ..models.registry import LLMRegistry
from ..planners.base_planner import BasePlanner
from ..tools.base_tool import BaseTool
from ..tools.function_tool import FunctionTool
from ..tools.tool_context import ToolContext
from .base_agent import BaseAgent
from .callback_context import CallbackContext
from .invocation_context import InvocationContext
from .readonly_context import ReadonlyContext
logger = logging.getLogger(__name__)
BeforeModelCallback: TypeAlias = Callable[
[CallbackContext, LlmRequest], Optional[LlmResponse]
]
AfterModelCallback: TypeAlias = Callable[
[CallbackContext, LlmResponse],
Optional[LlmResponse],
]
BeforeToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext],
Optional[dict],
]
AfterToolCallback: TypeAlias = Callable[
[BaseTool, dict[str, Any], ToolContext, dict],
Optional[dict],
]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
ToolUnion: TypeAlias = Union[Callable, BaseTool]
ExamplesUnion = Union[list[Example], BaseExampleProvider]
def _convert_tool_union_to_tool(
tool_union: ToolUnion,
) -> BaseTool:
return (
tool_union
if isinstance(tool_union, BaseTool)
else FunctionTool(tool_union)
)
class LlmAgent(BaseAgent):
"""LLM-based Agent."""
model: Union[str, BaseLlm] = ''
"""The model to use for the agent.
When not set, the agent will inherit the model from its ancestor.
"""
instruction: Union[str, InstructionProvider] = ''
"""Instructions for the LLM model, guiding the agent's behavior."""
global_instruction: Union[str, InstructionProvider] = ''
"""Instructions for all the agents in the entire agent tree.
global_instruction ONLY takes effect in root agent.
For example: use global_instruction to make all agents have a stable identity
or personality.
"""
tools: list[ToolUnion] = Field(default_factory=list)
"""Tools available to this agent."""
generate_content_config: Optional[types.GenerateContentConfig] = None
"""The additional content generation configurations.
NOTE: not all fields are usable, e.g. tools must be configured via `tools`,
thinking_config must be configured via `planner` in LlmAgent.
For example: use this config to adjust model temperature, configure safety
settings, etc.
"""
# LLM-based agent transfer configs - Start
disallow_transfer_to_parent: bool = False
"""Disallows LLM-controlled transferring to the parent agent."""
disallow_transfer_to_peers: bool = False
"""Disallows LLM-controlled transferring to the peer agents."""
# LLM-based agent transfer configs - End
include_contents: Literal['default', 'none'] = 'default'
"""Whether to include contents in the model request.
When set to 'none', the model request will not include any contents, such as
user messages, tool results, etc.
"""
# Controlled input/output configurations - Start
input_schema: Optional[type[BaseModel]] = None
"""The input schema when agent is used as a tool."""
output_schema: Optional[type[BaseModel]] = None
"""The output schema when agent replies.
NOTE: when this is set, agent can ONLY reply and CANNOT use any tools, such as
function tools, RAGs, agent transfer, etc.
"""
output_key: Optional[str] = None
"""The key in session state to store the output of the agent.
Typically use cases:
- Extracts agent reply for later use, such as in tools, callbacks, etc.
- Connects agents to coordinate with each other.
"""
# Controlled input/output configurations - End
# Advance features - Start
planner: Optional[BasePlanner] = None
"""Instructs the agent to make a plan and execute it step by step.
NOTE: to use model's built-in thinking features, set the `thinking_config`
field in `google.adk.planners.built_in_planner`.
"""
code_executor: Optional[BaseCodeExecutor] = None
"""Allow agent to execute code blocks from model responses using the provided
CodeExecutor.
Check out available code executions in `google.adk.code_executor` package.
NOTE: to use model's built-in code executor, don't set this field, add
`google.adk.tools.built_in_code_execution` to tools instead.
"""
# Advance features - End
# TODO: remove below fields after migration. - Start
# These fields are added back for easier migration.
examples: Optional[ExamplesUnion] = None
# TODO: remove above fields after migration. - End
# Callbacks - Start
before_model_callback: Optional[BeforeModelCallback] = None
"""Called before calling the LLM.
Args:
callback_context: CallbackContext,
llm_request: LlmRequest, The raw model request. Callback can mutate the
request.
Returns:
The content to return to the user. When present, the model call will be
skipped and the provided content will be returned to user.
"""
after_model_callback: Optional[AfterModelCallback] = None
"""Called after calling LLM.
Args:
callback_context: CallbackContext,
llm_response: LlmResponse, the actual model response.
Returns:
The content to return to the user. When present, the actual model response
will be ignored and the provided content will be returned to user.
"""
before_tool_callback: Optional[BeforeToolCallback] = None
"""Called before the tool is called.
Args:
tool: The tool to be called.
args: The arguments to the tool.
tool_context: ToolContext,
Returns:
The tool response. When present, the returned tool response will be used and
the framework will skip calling the actual tool.
"""
after_tool_callback: Optional[AfterToolCallback] = None
"""Called after the tool is called.
Args:
tool: The tool to be called.
args: The arguments to the tool.
tool_context: ToolContext,
tool_response: The response from the tool.
Returns:
When present, the returned dict will be used as tool result.
"""
# Callbacks - End
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
async for event in self._llm_flow.run_async(ctx):
self.__maybe_save_output_to_state(event)
yield event
@override
async def _run_live_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
async for event in self._llm_flow.run_live(ctx):
self.__maybe_save_output_to_state(event)
yield event
if ctx.end_invocation:
return
@property
def canonical_model(self) -> BaseLlm:
"""The resolved self.model field as BaseLlm.
This method is only for use by Agent Development Kit.
"""
if isinstance(self.model, BaseLlm):
return self.model
elif self.model: # model is non-empty str
return LLMRegistry.new_llm(self.model)
else: # find model from ancestors.
ancestor_agent = self.parent_agent
while ancestor_agent is not None:
if isinstance(ancestor_agent, LlmAgent):
return ancestor_agent.canonical_model
ancestor_agent = ancestor_agent.parent_agent
raise ValueError(f'No model found for {self.name}.')
def canonical_instruction(self, ctx: ReadonlyContext) -> str:
"""The resolved self.instruction field to construct instruction for this agent.
This method is only for use by Agent Development Kit.
"""
if isinstance(self.instruction, str):
return self.instruction
else:
return self.instruction(ctx)
def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
"""The resolved self.instruction field to construct global instruction.
This method is only for use by Agent Development Kit.
"""
if isinstance(self.global_instruction, str):
return self.global_instruction
else:
return self.global_instruction(ctx)
@property
def canonical_tools(self) -> list[BaseTool]:
"""The resolved self.tools field as a list of BaseTool.
This method is only for use by Agent Development Kit.
"""
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
@property
def _llm_flow(self) -> BaseLlmFlow:
if (
self.disallow_transfer_to_parent
and self.disallow_transfer_to_peers
and not self.sub_agents
):
return SingleFlow()
else:
return AutoFlow()
def __maybe_save_output_to_state(self, event: Event):
"""Saves the model output to state if needed."""
if (
self.output_key
and event.is_final_response()
and event.content
and event.content.parts
):
result = ''.join(
[part.text if part.text else '' for part in event.content.parts]
)
if self.output_schema:
result = self.output_schema.model_validate_json(result).model_dump(
exclude_none=True
)
event.actions.state_delta[self.output_key] = result
@model_validator(mode='after')
def __model_validator_after(self) -> LlmAgent:
self.__check_output_schema()
return self
def __check_output_schema(self):
if not self.output_schema:
return
if (
not self.disallow_transfer_to_parent
or not self.disallow_transfer_to_peers
):
logger.warning(
'Invalid config for agent %s: output_schema cannot co-exist with'
' agent transfer configurations. Setting'
' disallow_transfer_to_parent=True, disallow_transfer_to_peers=True',
self.name,
)
self.disallow_transfer_to_parent = True
self.disallow_transfer_to_peers = True
if self.sub_agents:
raise ValueError(
f'Invalid config for agent {self.name}: if output_schema is set,'
' sub_agents must be empty to disable agent transfer.'
)
if self.tools:
raise ValueError(
f'Invalid config for agent {self.name}: if output_schema is set,'
' tools must be empty'
)
@field_validator('generate_content_config', mode='after')
@classmethod
def __validate_generate_content_config(
cls, generate_content_config: Optional[types.GenerateContentConfig]
) -> types.GenerateContentConfig:
if not generate_content_config:
return types.GenerateContentConfig()
if generate_content_config.thinking_config:
raise ValueError('Thinking config should be set via LlmAgent.planner.')
if generate_content_config.tools:
raise ValueError('All tools must be set via LlmAgent.tools.')
if generate_content_config.system_instruction:
raise ValueError(
'System instruction must be set via LlmAgent.instruction.'
)
if generate_content_config.response_schema:
raise ValueError(
'Response schema must be set via LlmAgent.output_schema.'
)
return generate_content_config
Agent: TypeAlias = LlmAgent

View File

@@ -0,0 +1,62 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loop agent implementation."""
from __future__ import annotations
from typing import AsyncGenerator
from typing import Optional
from typing_extensions import override
from ..agents.invocation_context import InvocationContext
from ..events.event import Event
from .base_agent import BaseAgent
class LoopAgent(BaseAgent):
"""A shell agent that run its sub-agents in a loop.
When sub-agent generates an event with escalate or max_iterations are
reached, the loop agent will stop.
"""
max_iterations: Optional[int] = None
"""The maximum number of iterations to run the loop agent.
If not set, the loop agent will run indefinitely until a sub-agent
escalates.
"""
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
times_looped = 0
while not self.max_iterations or times_looped < self.max_iterations:
for sub_agent in self.sub_agents:
async for event in sub_agent.run_async(ctx):
yield event
if event.actions.escalate:
return
times_looped += 1
return
@override
async def _run_live_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
raise NotImplementedError('The behavior for run_live is not defined yet.')
yield # AsyncGenerator requires having at least one yield statement

View File

@@ -0,0 +1,96 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parallel agent implementation."""
from __future__ import annotations
import asyncio
from typing import AsyncGenerator
from typing_extensions import override
from ..agents.invocation_context import InvocationContext
from ..events.event import Event
from .base_agent import BaseAgent
def _set_branch_for_current_agent(
current_agent: BaseAgent, invocation_context: InvocationContext
):
invocation_context.branch = (
f"{invocation_context.branch}.{current_agent.name}"
if invocation_context.branch
else current_agent.name
)
async def _merge_agent_run(
agent_runs: list[AsyncGenerator[Event, None]],
) -> AsyncGenerator[Event, None]:
"""Merges the agent run event generator.
This implementation guarantees for each agent, it won't move on until the
generated event is processed by upstream runner.
Args:
agent_runs: A list of async generators that yield events from each agent.
Yields:
Event: The next event from the merged generator.
"""
tasks = [
asyncio.create_task(events_for_one_agent.__anext__())
for events_for_one_agent in agent_runs
]
pending_tasks = set(tasks)
while pending_tasks:
done, pending_tasks = await asyncio.wait(
pending_tasks, return_when=asyncio.FIRST_COMPLETED
)
for task in done:
try:
yield task.result()
# Find the generator that produced this event and move it on.
for i, original_task in enumerate(tasks):
if task == original_task:
new_task = asyncio.create_task(agent_runs[i].__anext__())
tasks[i] = new_task
pending_tasks.add(new_task)
break # stop iterating once found
except StopAsyncIteration:
continue
class ParallelAgent(BaseAgent):
"""A shell agent that run its sub-agents in parallel in isolated manner.
This approach is beneficial for scenarios requiring multiple perspectives or
attempts on a single task, such as:
- Running different algorithms simultaneously.
- Generating multiple responses for review by a subsequent evaluation agent.
"""
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
_set_branch_for_current_agent(self, ctx)
agent_runs = [agent.run_async(ctx) for agent in self.sub_agents]
async for event in _merge_agent_run(agent_runs):
yield event

View File

@@ -0,0 +1,46 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from types import MappingProxyType
from typing import Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .invocation_context import InvocationContext
class ReadonlyContext:
def __init__(
self,
invocation_context: InvocationContext,
) -> None:
self._invocation_context = invocation_context
@property
def invocation_id(self) -> str:
"""The current invocation id."""
return self._invocation_context.invocation_id
@property
def agent_name(self) -> str:
"""The name of the agent that is currently running."""
return self._invocation_context.agent.name
@property
def state(self) -> MappingProxyType[str, Any]:
"""The state of the current session. READONLY field."""
return MappingProxyType(self._invocation_context.session.state)

View File

@@ -0,0 +1,50 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import AsyncGenerator
from pydantic import Field
import requests
from typing_extensions import override
from ..events.event import Event
from .base_agent import BaseAgent
from .invocation_context import InvocationContext
class RemoteAgent(BaseAgent):
"""Experimental, do not use."""
url: str
sub_agents: list[BaseAgent] = Field(
default_factory=list, init=False, frozen=True
)
"""Sub-agent is disabled in RemoteAgent."""
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
data = {
'invocation_id': ctx.invocation_id,
'session': ctx.session.model_dump(exclude_none=True),
}
events = requests.post(self.url, data=json.dumps(data), timeout=120)
events.raise_for_status()
for event in events.json():
e = Event.model_validate(event)
e.author = self.name
yield e

View File

@@ -0,0 +1,91 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
import logging
import sys
from typing import Optional
from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import field_validator
logger = logging.getLogger(__name__)
class StreamingMode(Enum):
NONE = None
SSE = 'sse'
BIDI = 'bidi'
class RunConfig(BaseModel):
"""Configs for runtime behavior of agents."""
model_config = ConfigDict(
extra='forbid',
)
speech_config: Optional[types.SpeechConfig] = None
"""Speech configuration for the live agent."""
response_modalities: Optional[list[str]] = None
"""The output modalities. If not set, it's default to AUDIO."""
save_input_blobs_as_artifacts: bool = False
"""Whether or not to save the input blobs as artifacts."""
support_cfc: bool = False
"""
Whether to support CFC (Compositional Function Calling). Only applicable for
StreamingMode.SSE. If it's true. the LIVE API will be invoked. Since only LIVE
API supports CFC
.. warning::
This feature is **experimental** and its API or behavior may change
in future releases.
"""
streaming_mode: StreamingMode = StreamingMode.NONE
"""Streaming mode, None or StreamingMode.SSE or StreamingMode.BIDI."""
output_audio_transcription: Optional[types.AudioTranscriptionConfig] = None
"""Output transcription for live agents with audio response."""
max_llm_calls: int = 500
"""
A limit on the total number of llm calls for a given run.
Valid Values:
- More than 0 and less than sys.maxsize: The bound on the number of llm
calls is enforced, if the value is set in this range.
- Less than or equal to 0: This allows for unbounded number of llm calls.
"""
@field_validator('max_llm_calls', mode='after')
@classmethod
def validate_max_llm_calls(cls, value: int) -> int:
if value == sys.maxsize:
raise ValueError(f'max_llm_calls should be less than {sys.maxsize}.')
elif value <= 0:
logger.warning(
'max_llm_calls is less than or equal to 0. This will result in'
' no enforcement on total number of llm calls that will be made for a'
' run. This may not be ideal, as this could result in a never'
' ending communication between the model and the agent in certain'
' cases.',
)
return value

View File

@@ -0,0 +1,45 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sequential agent implementation."""
from __future__ import annotations
from typing import AsyncGenerator
from typing_extensions import override
from ..agents.invocation_context import InvocationContext
from ..events.event import Event
from .base_agent import BaseAgent
class SequentialAgent(BaseAgent):
"""A shell agent that run its sub-agents in sequence."""
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
for sub_agent in self.sub_agents:
async for event in sub_agent.run_async(ctx):
yield event
@override
async def _run_live_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
for sub_agent in self.sub_agents:
async for event in sub_agent.run_live(ctx):
yield event

View File

@@ -0,0 +1,34 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict
class TranscriptionEntry(BaseModel):
"""Store the data that can be used for transcription."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra='forbid',
)
role: str
"""The role that created this data, typically "user" or "model"""
data: Union[types.Blob, types.Content]
"""The data that can be used for transcription"""

View File

@@ -0,0 +1,23 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_artifact_service import BaseArtifactService
from .gcs_artifact_service import GcsArtifactService
from .in_memory_artifact_service import InMemoryArtifactService
__all__ = [
'BaseArtifactService',
'GcsArtifactService',
'InMemoryArtifactService',
]

View File

@@ -0,0 +1,128 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base class for artifact services."""
from abc import ABC
from abc import abstractmethod
from typing import Optional
from google.genai import types
class BaseArtifactService(ABC):
"""Abstract base class for artifact services."""
@abstractmethod
def save_artifact(
self,
*,
app_name: str,
user_id: str,
session_id: str,
filename: str,
artifact: types.Part,
) -> int:
"""Saves an artifact to the artifact service storage.
The artifact is a file identified by the app name, user ID, session ID, and
filename. After saving the artifact, a revision ID is returned to identify
the artifact version.
Args:
app_name: The app name.
user_id: The user ID.
session_id: The session ID.
filename: The filename of the artifact.
artifact: The artifact to save.
Returns:
The revision ID. The first version of the artifact has a revision ID of 0.
This is incremented by 1 after each successful save.
"""
@abstractmethod
def load_artifact(
self,
*,
app_name: str,
user_id: str,
session_id: str,
filename: str,
version: Optional[int] = None,
) -> Optional[types.Part]:
"""Gets an artifact from the artifact service storage.
The artifact is a file identified by the app name, user ID, session ID, and
filename.
Args:
app_name: The app name.
user_id: The user ID.
session_id: The session ID.
filename: The filename of the artifact.
version: The version of the artifact. If None, the latest version will be
returned.
Returns:
The artifact or None if not found.
"""
pass
@abstractmethod
def list_artifact_keys(
self, *, app_name: str, user_id: str, session_id: str
) -> list[str]:
"""Lists all the artifact filenames within a session.
Args:
app_name: The name of the application.
user_id: The ID of the user.
session_id: The ID of the session.
Returns:
A list of all artifact filenames within a session.
"""
pass
@abstractmethod
def delete_artifact(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> None:
"""Deletes an artifact.
Args:
app_name: The name of the application.
user_id: The ID of the user.
session_id: The ID of the session.
filename: The name of the artifact file.
"""
pass
@abstractmethod
def list_versions(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> list[int]:
"""Lists all versions of an artifact.
Args:
app_name: The name of the application.
user_id: The ID of the user.
session_id: The ID of the session.
filename: The name of the artifact file.
Returns:
A list of all available versions of the artifact.
"""
pass

View File

@@ -0,0 +1,195 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An artifact service implementation using Google Cloud Storage (GCS)."""
import logging
from typing import Optional
from google.cloud import storage
from google.genai import types
from typing_extensions import override
from .base_artifact_service import BaseArtifactService
logger = logging.getLogger(__name__)
class GcsArtifactService(BaseArtifactService):
"""An artifact service implementation using Google Cloud Storage (GCS)."""
def __init__(self, bucket_name: str, **kwargs):
"""Initializes the GcsArtifactService.
Args:
bucket_name: The name of the bucket to use.
**kwargs: Keyword arguments to pass to the Google Cloud Storage client.
"""
self.bucket_name = bucket_name
self.storage_client = storage.Client(**kwargs)
self.bucket = self.storage_client.bucket(self.bucket_name)
def _file_has_user_namespace(self, filename: str) -> bool:
"""Checks if the filename has a user namespace.
Args:
filename: The filename to check.
Returns:
True if the filename has a user namespace (starts with "user:"),
False otherwise.
"""
return filename.startswith("user:")
def _get_blob_name(
self,
app_name: str,
user_id: str,
session_id: str,
filename: str,
version: int,
) -> str:
"""Constructs the blob name in GCS.
Args:
app_name: The name of the application.
user_id: The ID of the user.
session_id: The ID of the session.
filename: The name of the artifact file.
version: The version of the artifact.
Returns:
The constructed blob name in GCS.
"""
if self._file_has_user_namespace(filename):
return f"{app_name}/{user_id}/user/{filename}/{version}"
return f"{app_name}/{user_id}/{session_id}/{filename}/{version}"
@override
def save_artifact(
self,
*,
app_name: str,
user_id: str,
session_id: str,
filename: str,
artifact: types.Part,
) -> int:
versions = self.list_versions(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
)
version = 0 if not versions else max(versions) + 1
blob_name = self._get_blob_name(
app_name, user_id, session_id, filename, version
)
blob = self.bucket.blob(blob_name)
blob.upload_from_string(
data=artifact.inline_data.data,
content_type=artifact.inline_data.mime_type,
)
return version
@override
def load_artifact(
self,
*,
app_name: str,
user_id: str,
session_id: str,
filename: str,
version: Optional[int] = None,
) -> Optional[types.Part]:
if version is None:
versions = self.list_versions(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
)
if not versions:
return None
version = max(versions)
blob_name = self._get_blob_name(
app_name, user_id, session_id, filename, version
)
blob = self.bucket.blob(blob_name)
artifact_bytes = blob.download_as_bytes()
if not artifact_bytes:
return None
artifact = types.Part.from_bytes(
data=artifact_bytes, mime_type=blob.content_type
)
return artifact
@override
def list_artifact_keys(
self, *, app_name: str, user_id: str, session_id: str
) -> list[str]:
filenames = set()
session_prefix = f"{app_name}/{user_id}/{session_id}/"
session_blobs = self.storage_client.list_blobs(
self.bucket, prefix=session_prefix
)
for blob in session_blobs:
_, _, _, filename, _ = blob.name.split("/")
filenames.add(filename)
user_namespace_prefix = f"{app_name}/{user_id}/user/"
user_namespace_blobs = self.storage_client.list_blobs(
self.bucket, prefix=user_namespace_prefix
)
for blob in user_namespace_blobs:
_, _, _, filename, _ = blob.name.split("/")
filenames.add(filename)
return sorted(list(filenames))
@override
def delete_artifact(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> None:
versions = self.list_versions(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
)
for version in versions:
blob_name = self._get_blob_name(
app_name, user_id, session_id, filename, version
)
blob = self.bucket.blob(blob_name)
blob.delete()
return
@override
def list_versions(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> list[int]:
prefix = self._get_blob_name(app_name, user_id, session_id, filename, "")
blobs = self.storage_client.list_blobs(self.bucket, prefix=prefix)
versions = []
for blob in blobs:
_, _, _, _, version = blob.name.split("/")
versions.append(int(version))
return versions

View File

@@ -0,0 +1,133 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An in-memory implementation of the artifact service."""
import logging
from typing import Optional
from google.genai import types
from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override
from .base_artifact_service import BaseArtifactService
logger = logging.getLogger(__name__)
class InMemoryArtifactService(BaseArtifactService, BaseModel):
"""An in-memory implementation of the artifact service."""
artifacts: dict[str, list[types.Part]] = Field(default_factory=dict)
def _file_has_user_namespace(self, filename: str) -> bool:
"""Checks if the filename has a user namespace.
Args:
filename: The filename to check.
Returns:
True if the filename has a user namespace (starts with "user:"),
False otherwise.
"""
return filename.startswith("user:")
def _artifact_path(
self, app_name: str, user_id: str, session_id: str, filename: str
) -> str:
"""Constructs the artifact path.
Args:
app_name: The name of the application.
user_id: The ID of the user.
session_id: The ID of the session.
filename: The name of the artifact file.
Returns:
The constructed artifact path.
"""
if self._file_has_user_namespace(filename):
return f"{app_name}/{user_id}/user/{filename}"
return f"{app_name}/{user_id}/{session_id}/{filename}"
@override
def save_artifact(
self,
*,
app_name: str,
user_id: str,
session_id: str,
filename: str,
artifact: types.Part,
) -> int:
path = self._artifact_path(app_name, user_id, session_id, filename)
if path not in self.artifacts:
self.artifacts[path] = []
version = len(self.artifacts[path])
self.artifacts[path].append(artifact)
return version
@override
def load_artifact(
self,
*,
app_name: str,
user_id: str,
session_id: str,
filename: str,
version: Optional[int] = None,
) -> Optional[types.Part]:
path = self._artifact_path(app_name, user_id, session_id, filename)
versions = self.artifacts.get(path)
if not versions:
return None
if version is None:
version = -1
return versions[version]
@override
def list_artifact_keys(
self, *, app_name: str, user_id: str, session_id: str
) -> list[str]:
session_prefix = f"{app_name}/{user_id}/{session_id}/"
usernamespace_prefix = f"{app_name}/{user_id}/user/"
filenames = []
for path in self.artifacts:
if path.startswith(session_prefix):
filename = path.removeprefix(session_prefix)
filenames.append(filename)
elif path.startswith(usernamespace_prefix):
filename = path.removeprefix(usernamespace_prefix)
filenames.append(filename)
return sorted(filenames)
@override
def delete_artifact(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> None:
path = self._artifact_path(app_name, user_id, session_id, filename)
if not self.artifacts.get(path):
return None
self.artifacts.pop(path, None)
@override
def list_versions(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> list[int]:
path = self._artifact_path(app_name, user_id, session_id, filename)
versions = self.artifacts.get(path)
if not versions:
return []
return list(range(len(versions)))

View File

@@ -0,0 +1,22 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .auth_credential import AuthCredential
from .auth_credential import AuthCredentialTypes
from .auth_credential import OAuth2Auth
from .auth_handler import AuthHandler
from .auth_schemes import AuthScheme
from .auth_schemes import AuthSchemeType
from .auth_schemes import OpenIdConnectWithConfig
from .auth_tool import AuthConfig

View File

@@ -0,0 +1,221 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from pydantic import BaseModel
from pydantic import Field
class BaseModelWithConfig(BaseModel):
model_config = {"extra": "allow"}
class HttpCredentials(BaseModelWithConfig):
"""Represents the secret token value for HTTP authentication, like user name, password, oauth token, etc."""
username: Optional[str] = None
password: Optional[str] = None
token: Optional[str] = None
@classmethod
def model_validate(cls, data: Dict[str, Any]) -> "HttpCredentials":
return cls(
username=data.get("username"),
password=data.get("password"),
token=data.get("token"),
)
class HttpAuth(BaseModelWithConfig):
"""The credentials and metadata for HTTP authentication."""
# The name of the HTTP Authorization scheme to be used in the Authorization
# header as defined in RFC7235. The values used SHOULD be registered in the
# IANA Authentication Scheme registry.
# Examples: 'basic', 'bearer'
scheme: str
credentials: HttpCredentials
class OAuth2Auth(BaseModelWithConfig):
"""Represents credential value and its metadata for a OAuth2 credential."""
client_id: Optional[str] = None
client_secret: Optional[str] = None
# tool or adk can generate the auth_uri with the state info thus client
# can verify the state
auth_uri: Optional[str] = None
state: Optional[str] = None
# tool or adk can decide the redirect_uri if they don't want client to decide
redirect_uri: Optional[str] = None
auth_response_uri: Optional[str] = None
auth_code: Optional[str] = None
access_token: Optional[str] = None
refresh_token: Optional[str] = None
class ServiceAccountCredential(BaseModelWithConfig):
"""Represents Google Service Account configuration.
Attributes:
type: The type should be "service_account".
project_id: The project ID.
private_key_id: The ID of the private key.
private_key: The private key.
client_email: The client email.
client_id: The client ID.
auth_uri: The authorization URI.
token_uri: The token URI.
auth_provider_x509_cert_url: URL for auth provider's X.509 cert.
client_x509_cert_url: URL for the client's X.509 cert.
universe_domain: The universe domain.
Example:
config = ServiceAccountCredential(
type_="service_account",
project_id="your_project_id",
private_key_id="your_private_key_id",
private_key="-----BEGIN PRIVATE KEY-----...",
client_email="...@....iam.gserviceaccount.com",
client_id="your_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/...",
universe_domain="googleapis.com"
)
config = ServiceAccountConfig.model_construct(**{
...service account config dict
})
"""
type_: str = Field("", alias="type")
project_id: str
private_key_id: str
private_key: str
client_email: str
client_id: str
auth_uri: str
token_uri: str
auth_provider_x509_cert_url: str
client_x509_cert_url: str
universe_domain: str
class ServiceAccount(BaseModelWithConfig):
"""Represents Google Service Account configuration."""
service_account_credential: Optional[ServiceAccountCredential] = None
scopes: List[str]
use_default_credential: Optional[bool] = False
class AuthCredentialTypes(str, Enum):
"""Represents the type of authentication credential."""
# API Key credential:
# https://swagger.io/docs/specification/v3_0/authentication/api-keys/
API_KEY = "apiKey"
# Credentials for HTTP Auth schemes:
# https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml
HTTP = "http"
# OAuth2 credentials:
# https://swagger.io/docs/specification/v3_0/authentication/oauth2/
OAUTH2 = "oauth2"
# OpenID Connect credentials:
# https://swagger.io/docs/specification/v3_0/authentication/openid-connect-discovery/
OPEN_ID_CONNECT = "openIdConnect"
# Service Account credentials:
# https://cloud.google.com/iam/docs/service-account-creds
SERVICE_ACCOUNT = "serviceAccount"
class AuthCredential(BaseModelWithConfig):
"""Data class representing an authentication credential.
To exchange for the actual credential, please use
CredentialExchanger.exchange_credential().
Examples: API Key Auth
AuthCredential(
auth_type=AuthCredentialTypes.API_KEY,
api_key="1234",
)
Example: HTTP Auth
AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="basic",
credentials=HttpCredentials(username="user", password="password"),
),
)
Example: OAuth2 Bearer Token in HTTP Header
AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token="eyAkaknabna...."),
),
)
Example: OAuth2 Auth with Authorization Code Flow
AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id="1234",
client_secret="secret",
),
)
Example: OpenID Connect Auth
AuthCredential(
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
oauth2=OAuth2Auth(
client_id="1234",
client_secret="secret",
redirect_uri="https://example.com",
scopes=["scope1", "scope2"],
),
)
Example: Auth with resource reference
AuthCredential(
auth_type=AuthCredentialTypes.API_KEY,
resource_ref="projects/1234/locations/us-central1/resources/resource1",
)
"""
auth_type: AuthCredentialTypes
# Resource reference for the credential.
# This will be supported in the future.
resource_ref: Optional[str] = None
api_key: Optional[str] = None
http: Optional[HttpAuth] = None
service_account: Optional[ServiceAccount] = None
oauth2: Optional[OAuth2Auth] = None

View File

@@ -0,0 +1,272 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import SecurityBase
from .auth_credential import AuthCredential
from .auth_credential import AuthCredentialTypes
from .auth_credential import OAuth2Auth
from .auth_schemes import AuthSchemeType
from .auth_schemes import OAuthGrantType
from .auth_schemes import OpenIdConnectWithConfig
from .auth_tool import AuthConfig
if TYPE_CHECKING:
from ..sessions.state import State
try:
from authlib.integrations.requests_client import OAuth2Session
SUPPORT_TOKEN_EXCHANGE = True
except ImportError:
SUPPORT_TOKEN_EXCHANGE = False
class AuthHandler:
def __init__(self, auth_config: AuthConfig):
self.auth_config = auth_config
def exchange_auth_token(
self,
) -> AuthCredential:
"""Generates an auth token from the authorization response.
Returns:
An AuthCredential object containing the access token.
Raises:
ValueError: If the token endpoint is not configured in the auth
scheme.
AuthCredentialMissingError: If the access token cannot be retrieved
from the token endpoint.
"""
auth_scheme = self.auth_config.auth_scheme
auth_credential = self.auth_config.exchanged_auth_credential
if not SUPPORT_TOKEN_EXCHANGE:
return auth_credential
if isinstance(auth_scheme, OpenIdConnectWithConfig):
if not hasattr(auth_scheme, "token_endpoint"):
return self.auth_config.exchanged_auth_credential
token_endpoint = auth_scheme.token_endpoint
scopes = auth_scheme.scopes
elif isinstance(auth_scheme, OAuth2):
if (
not auth_scheme.flows.authorizationCode
or not auth_scheme.flows.authorizationCode.tokenUrl
):
return self.auth_config.exchanged_auth_credential
token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
else:
return self.auth_config.exchanged_auth_credential
if (
not auth_credential
or not auth_credential.oauth2
or not auth_credential.oauth2.client_id
or not auth_credential.oauth2.client_secret
or auth_credential.oauth2.access_token
or auth_credential.oauth2.refresh_token
):
return self.auth_config.exchanged_auth_credential
client = OAuth2Session(
auth_credential.oauth2.client_id,
auth_credential.oauth2.client_secret,
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
state=auth_credential.oauth2.state,
)
tokens = client.fetch_token(
token_endpoint,
authorization_response=auth_credential.oauth2.auth_response_uri,
code=auth_credential.oauth2.auth_code,
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
)
updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
access_token=tokens.get("access_token"),
refresh_token=tokens.get("refresh_token"),
),
)
return updated_credential
def parse_and_store_auth_response(self, state: State) -> None:
credential_key = self.get_credential_key()
state[credential_key] = self.auth_config.exchanged_auth_credential
if not isinstance(
self.auth_config.auth_scheme, SecurityBase
) or self.auth_config.auth_scheme.type_ not in (
AuthSchemeType.oauth2,
AuthSchemeType.openIdConnect,
):
return
state[credential_key] = self.exchange_auth_token()
def _validate(self) -> None:
if not self.auth_scheme:
raise ValueError("auth_scheme is empty.")
def get_auth_response(self, state: State) -> AuthCredential:
credential_key = self.get_credential_key()
return state.get(credential_key, None)
def generate_auth_request(self) -> AuthConfig:
if not isinstance(
self.auth_config.auth_scheme, SecurityBase
) or self.auth_config.auth_scheme.type_ not in (
AuthSchemeType.oauth2,
AuthSchemeType.openIdConnect,
):
return self.auth_config.model_copy(deep=True)
# auth_uri already in exchanged credential
if (
self.auth_config.exchanged_auth_credential
and self.auth_config.exchanged_auth_credential.oauth2
and self.auth_config.exchanged_auth_credential.oauth2.auth_uri
):
return self.auth_config.model_copy(deep=True)
# Check if raw_auth_credential exists
if not self.auth_config.raw_auth_credential:
raise ValueError(
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires"
" auth_credential."
)
# Check if oauth2 exists in raw_auth_credential
if not self.auth_config.raw_auth_credential.oauth2:
raise ValueError(
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires oauth2 in"
" auth_credential."
)
# auth_uri in raw credential
if self.auth_config.raw_auth_credential.oauth2.auth_uri:
return AuthConfig(
auth_scheme=self.auth_config.auth_scheme,
raw_auth_credential=self.auth_config.raw_auth_credential,
exchanged_auth_credential=self.auth_config.raw_auth_credential.model_copy(
deep=True
),
)
# Check for client_id and client_secret
if (
not self.auth_config.raw_auth_credential.oauth2.client_id
or not self.auth_config.raw_auth_credential.oauth2.client_secret
):
raise ValueError(
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires both"
" client_id and client_secret in auth_credential.oauth2."
)
# Generate new auth URI
exchanged_credential = self.generate_auth_uri()
return AuthConfig(
auth_scheme=self.auth_config.auth_scheme,
raw_auth_credential=self.auth_config.raw_auth_credential,
exchanged_auth_credential=exchanged_credential,
)
def get_credential_key(self) -> str:
"""Generates a unique key for the given auth scheme and credential."""
auth_scheme = self.auth_config.auth_scheme
auth_credential = self.auth_config.raw_auth_credential
if auth_scheme.model_extra:
auth_scheme = auth_scheme.model_copy(deep=True)
auth_scheme.model_extra.clear()
scheme_name = (
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
if auth_scheme
else ""
)
if auth_credential.model_extra:
auth_credential = auth_credential.model_copy(deep=True)
auth_credential.model_extra.clear()
credential_name = (
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
if auth_credential
else ""
)
return f"temp:adk_{scheme_name}_{credential_name}"
def generate_auth_uri(
self,
) -> AuthCredential:
"""Generates an response containing the auth uri for user to sign in.
Returns:
An AuthCredential object containing the auth URI and state.
Raises:
ValueError: If the authorization endpoint is not configured in the auth
scheme.
"""
auth_scheme = self.auth_config.auth_scheme
auth_credential = self.auth_config.raw_auth_credential
if isinstance(auth_scheme, OpenIdConnectWithConfig):
authorization_endpoint = auth_scheme.authorization_endpoint
scopes = auth_scheme.scopes
else:
authorization_endpoint = (
auth_scheme.flows.implicit
and auth_scheme.flows.implicit.authorizationUrl
or auth_scheme.flows.authorizationCode
and auth_scheme.flows.authorizationCode.authorizationUrl
or auth_scheme.flows.clientCredentials
and auth_scheme.flows.clientCredentials.tokenUrl
or auth_scheme.flows.password
and auth_scheme.flows.password.tokenUrl
)
scopes = (
auth_scheme.flows.implicit
and auth_scheme.flows.implicit.scopes
or auth_scheme.flows.authorizationCode
and auth_scheme.flows.authorizationCode.scopes
or auth_scheme.flows.clientCredentials
and auth_scheme.flows.clientCredentials.scopes
or auth_scheme.flows.password
and auth_scheme.flows.password.scopes
)
scopes = list(scopes.keys())
client = OAuth2Session(
auth_credential.oauth2.client_id,
auth_credential.oauth2.client_secret,
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
)
uri, state = client.create_authorization_url(
url=authorization_endpoint, access_type="offline", prompt="consent"
)
exchanged_auth_credential = auth_credential.model_copy(deep=True)
exchanged_auth_credential.oauth2.auth_uri = uri
exchanged_auth_credential.oauth2.state = state
return exchanged_auth_credential

View File

@@ -0,0 +1,119 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import AsyncGenerator
from typing import TYPE_CHECKING
from typing_extensions import override
from ..agents.invocation_context import InvocationContext
from ..events.event import Event
from ..flows.llm_flows import functions
from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor
from ..flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
from ..models.llm_request import LlmRequest
from .auth_handler import AuthHandler
from .auth_tool import AuthConfig
from .auth_tool import AuthToolArguments
if TYPE_CHECKING:
from ..agents.llm_agent import LlmAgent
class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
"""Handles auth information to build the LLM request."""
@override
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ..agents.llm_agent import LlmAgent
agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
return
events = invocation_context.session.events
if not events:
return
request_euc_function_call_ids = set()
for k in range(len(events) - 1, -1, -1):
event = events[k]
# look for first event authored by user
if not event.author or event.author != 'user':
continue
responses = event.get_function_responses()
if not responses:
return
for function_call_response in responses:
if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME:
continue
# found the function call response for the system long running request euc
# function call
request_euc_function_call_ids.add(function_call_response.id)
auth_config = AuthConfig.model_validate(function_call_response.response)
AuthHandler(auth_config=auth_config).parse_and_store_auth_response(
state=invocation_context.session.state
)
break
if not request_euc_function_call_ids:
return
for i in range(len(events) - 2, -1, -1):
event = events[i]
# looking for the system long running request euc function call
function_calls = event.get_function_calls()
if not function_calls:
continue
tools_to_resume = set()
for function_call in function_calls:
if function_call.id not in request_euc_function_call_ids:
continue
args = AuthToolArguments.model_validate(function_call.args)
tools_to_resume.add(args.function_call_id)
if not tools_to_resume:
continue
# found the the system long running request euc function call
# looking for original function call that requests euc
for j in range(i - 1, -1, -1):
event = events[j]
function_calls = event.get_function_calls()
if not function_calls:
continue
for function_call in function_calls:
function_response_event = None
if function_call.id in tools_to_resume:
function_response_event = await functions.handle_function_calls_async(
invocation_context,
event,
{tool.name: tool for tool in agent.canonical_tools},
# there could be parallel function calls that require auth
# auth response would be a dict keyed by function call id
tools_to_resume,
)
if function_response_event:
yield function_response_event
return
return
request_processor = _AuthLlmRequestProcessor()

View File

@@ -0,0 +1,67 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import List
from typing import Optional
from typing import Union
from fastapi.openapi.models import OAuthFlows
from fastapi.openapi.models import SecurityBase
from fastapi.openapi.models import SecurityScheme
from fastapi.openapi.models import SecuritySchemeType
from pydantic import Field
class OpenIdConnectWithConfig(SecurityBase):
type_: SecuritySchemeType = Field(
default=SecuritySchemeType.openIdConnect, alias="type"
)
authorization_endpoint: str
token_endpoint: str
userinfo_endpoint: Optional[str] = None
revocation_endpoint: Optional[str] = None
token_endpoint_auth_methods_supported: Optional[List[str]] = None
grant_types_supported: Optional[List[str]] = None
scopes: Optional[List[str]] = None
# AuthSchemes contains SecuritySchemes from OpenAPI 3.0 and an extra flattened OpenIdConnectWithConfig.
AuthScheme = Union[SecurityScheme, OpenIdConnectWithConfig]
class OAuthGrantType(str, Enum):
"""Represents the OAuth2 flow (or grant type)."""
CLIENT_CREDENTIALS = "client_credentials"
AUTHORIZATION_CODE = "authorization_code"
IMPLICIT = "implicit"
PASSWORD = "password"
@staticmethod
def from_flow(flow: OAuthFlows) -> "OAuthGrantType":
"""Converts an OAuthFlows object to a OAuthGrantType."""
if flow.clientCredentials:
return OAuthGrantType.CLIENT_CREDENTIALS
if flow.authorizationCode:
return OAuthGrantType.AUTHORIZATION_CODE
if flow.implicit:
return OAuthGrantType.IMPLICIT
if flow.password:
return OAuthGrantType.PASSWORD
return None
# AuthSchemeType re-exports SecuritySchemeType from OpenAPI 3.0.
AuthSchemeType = SecuritySchemeType

View File

@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel
from .auth_credential import AuthCredential
from .auth_schemes import AuthScheme
class AuthConfig(BaseModel):
"""The auth config sent by tool asking client to collect auth credentials and
adk and client will help to fill in the response
"""
auth_scheme: AuthScheme
"""The auth scheme used to collect credentials"""
raw_auth_credential: AuthCredential = None
"""The raw auth credential used to collect credentials. The raw auth
credentials are used in some auth scheme that needs to exchange auth
credentials. e.g. OAuth2 and OIDC. For other auth scheme, it could be None.
"""
exchanged_auth_credential: AuthCredential = None
"""The exchanged auth credential used to collect credentials. adk and client
will work together to fill it. For those auth scheme that doesn't need to
exchange auth credentials, e.g. API key, service account etc. It's filled by
client directly. For those auth scheme that need to exchange auth credentials,
e.g. OAuth2 and OIDC, it's first filled by adk. If the raw credentials
passed by tool only has client id and client credential, adk will help to
generate the corresponding authorization uri and state and store the processed
credential in this field. If the raw credentials passed by tool already has
authorization uri, state, etc. then it's copied to this field. Client will use
this field to guide the user through the OAuth2 flow and fill auth response in
this field"""
class AuthToolArguments(BaseModel):
"""the arguments for the special long running function tool that is used to
request end user credentials.
"""
function_call_id: str
auth_config: AuthConfig

View File

@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cli_tools_click import main

View File

@@ -0,0 +1,18 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cli_tools_click import main
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,148 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import Union
import graphviz
from ..agents import BaseAgent
from ..agents.llm_agent import LlmAgent
from ..tools.agent_tool import AgentTool
from ..tools.base_tool import BaseTool
from ..tools.function_tool import FunctionTool
logger = logging.getLogger(__name__)
try:
from ..tools.retrieval.base_retrieval_tool import BaseRetrievalTool
except ModuleNotFoundError:
retrieval_tool_module_loaded = False
else:
retrieval_tool_module_loaded = True
def build_graph(graph, agent: BaseAgent, highlight_pairs):
dark_green = '#0F5223'
light_green = '#69CB87'
light_gray = '#cccccc'
def get_node_name(tool_or_agent: Union[BaseAgent, BaseTool]):
if isinstance(tool_or_agent, BaseAgent):
return tool_or_agent.name
elif isinstance(tool_or_agent, BaseTool):
return tool_or_agent.name
else:
raise ValueError(f'Unsupported tool type: {tool_or_agent}')
def get_node_caption(tool_or_agent: Union[BaseAgent, BaseTool]):
if isinstance(tool_or_agent, BaseAgent):
return '🤖 ' + tool_or_agent.name
elif retrieval_tool_module_loaded and isinstance(
tool_or_agent, BaseRetrievalTool
):
return '🔎 ' + tool_or_agent.name
elif isinstance(tool_or_agent, FunctionTool):
return '🔧 ' + tool_or_agent.name
elif isinstance(tool_or_agent, AgentTool):
return '🤖 ' + tool_or_agent.name
elif isinstance(tool_or_agent, BaseTool):
return '🔧 ' + tool_or_agent.name
else:
logger.warning(
'Unsupported tool, type: %s, obj: %s',
type(tool_or_agent),
tool_or_agent,
)
return f'❓ Unsupported tool type: {type(tool_or_agent)}'
def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]):
if isinstance(tool_or_agent, BaseAgent):
return 'ellipse'
elif retrieval_tool_module_loaded and isinstance(
tool_or_agent, BaseRetrievalTool
):
return 'cylinder'
elif isinstance(tool_or_agent, FunctionTool):
return 'box'
elif isinstance(tool_or_agent, BaseTool):
return 'box'
else:
logger.warning(
'Unsupported tool, type: %s, obj: %s',
type(tool_or_agent),
tool_or_agent,
)
return 'cylinder'
def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
name = get_node_name(tool_or_agent)
shape = get_node_shape(tool_or_agent)
caption = get_node_caption(tool_or_agent)
if highlight_pairs:
for highlight_tuple in highlight_pairs:
if name in highlight_tuple:
graph.node(
name,
caption,
style='filled,rounded',
fillcolor=dark_green,
color=dark_green,
shape=shape,
fontcolor=light_gray,
)
return
# if not in highlight, draw non-highliht node
graph.node(
name,
caption,
shape=shape,
style='rounded',
color=light_gray,
fontcolor=light_gray,
)
def draw_edge(from_name, to_name):
if highlight_pairs:
for highlight_from, highlight_to in highlight_pairs:
if from_name == highlight_from and to_name == highlight_to:
graph.edge(from_name, to_name, color=light_green)
return
elif from_name == highlight_to and to_name == highlight_from:
graph.edge(from_name, to_name, color=light_green, dir='back')
return
# if no need to highlight, color gray
graph.edge(from_name, to_name, arrowhead='none', color=light_gray)
draw_node(agent)
for sub_agent in agent.sub_agents:
build_graph(graph, sub_agent, highlight_pairs)
draw_edge(agent.name, sub_agent.name)
if isinstance(agent, LlmAgent):
for tool in agent.canonical_tools:
draw_node(tool)
draw_edge(agent.name, get_node_name(tool))
def get_agent_graph(root_agent, highlights_pairs, image=False):
print('build graph')
graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'})
build_graph(graph, root_agent, highlights_pairs)
if image:
return graph.pipe(format='png')
else:
return graph

View File

@@ -0,0 +1,17 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_756_3354)">
<path fill-rule="evenodd" clip-rule="evenodd" d="M8.69139 10.1458C8.89799 10.3937 8.8645 10.7622 8.61657 10.9688L7.07351 12.2547L8.61657 13.5406C8.8645 13.7472 8.89799 14.1157 8.69139 14.3636C8.48478 14.6115 8.11631 14.645 7.86838 14.4384L5.82029 12.7317C5.52243 12.4834 5.52242 12.026 5.82029 11.7777L7.86838 10.071C8.11631 9.86438 8.48478 9.89788 8.69139 10.1458Z" fill="#EA4335"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.4459 10.1458C11.2393 10.3937 11.2728 10.7622 11.5207 10.9688L13.0638 12.2547L11.5207 13.5406C11.2728 13.7472 11.2393 14.1157 11.4459 14.3636C11.6525 14.6115 12.021 14.645 12.2689 14.4384L14.317 12.7317C14.6149 12.4834 14.6149 12.026 14.317 11.7777L12.2689 10.071C12.021 9.86438 11.6525 9.89788 11.4459 10.1458Z" fill="#EA4335"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M5.94165 2.19288C4.44903 2.19288 3.23902 3.40289 3.23902 4.89551C3.23902 6.38813 4.44903 7.59814 5.94165 7.59814H8.60776V8.76685H5.94165C3.80357 8.76685 2.07031 7.03359 2.07031 4.89551C2.07031 2.75743 3.80357 1.02417 5.94165 1.02417H9.73995C10.0627 1.02417 10.3243 1.28579 10.3243 1.60852C10.3243 1.93125 10.0627 2.19288 9.73995 2.19288H5.94165Z" fill="#4285F4"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M10.6895 2.19288C12.1821 2.19288 13.3922 3.40289 13.3922 4.89551C13.3922 6.38813 12.1821 7.59814 10.6895 7.59814H6.89123C6.5685 7.59814 6.30687 7.85977 6.30687 8.1825C6.30687 8.50523 6.5685 8.76685 6.89123 8.76685H10.6895C12.8276 8.76685 14.5609 7.03359 14.5609 4.89551C14.5609 2.75743 12.8276 1.02417 10.6895 1.02417H6.89123C6.5685 1.02417 6.30687 1.28579 6.30687 1.60852C6.30687 1.93125 6.5685 2.19288 6.89123 2.19288H10.6895Z" fill="#34A853"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.23902 10.739H4.18859C4.51132 10.739 4.77295 10.4774 4.77295 10.1547C4.77295 9.83196 4.51132 9.57033 4.18859 9.57033H3.01989C2.49545 9.57033 2.07031 9.99547 2.07031 10.5199V14.026C2.07031 14.5505 2.49545 14.9756 3.01989 14.9756H4.18859C4.51132 14.9756 4.77295 14.714 4.77295 14.3912C4.77295 14.0685 4.51132 13.8069 4.18859 13.8069H3.23902V10.739Z" fill="#FBBC04"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M10.9452 8.1825C10.9452 7.85977 10.6836 7.59814 10.3608 7.59814H6.89123C6.5685 7.59814 6.30687 7.85977 6.30687 8.1825C6.30687 8.50523 6.5685 8.76685 6.89123 8.76685H10.3608C10.6836 8.76685 10.9452 8.50523 10.9452 8.1825Z" fill="#4285F4"/>
<path d="M6.74514 4.89551C6.74514 5.25858 6.45081 5.55291 6.08774 5.55291C5.72467 5.55291 5.43034 5.25858 5.43034 4.89551C5.43034 4.53244 5.72467 4.23811 6.08774 4.23811C6.45081 4.23811 6.74514 4.53244 6.74514 4.89551Z" fill="#4285F4"/>
<path d="M11.2739 4.89551C11.2739 5.25858 10.9795 5.55291 10.6165 5.55291C10.2534 5.55291 9.95908 5.25858 9.95908 4.89551C9.95908 4.53244 10.2534 4.23811 10.6165 4.23811C10.9795 4.23811 11.2739 4.53244 11.2739 4.89551Z" fill="#4285F4"/>
</g>
<defs>
<clipPath id="clip0_756_3354">
<rect width="12.6294" height="14" fill="white" transform="translate(2 1)"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 3.1 KiB

View File

@@ -0,0 +1,51 @@
/**
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
class AudioProcessor extends AudioWorkletProcessor {
constructor() {
super();
this.targetSampleRate = 22000; // Change to your desired rate
this.originalSampleRate = sampleRate; // Browser's sample rate
this.resampleRatio = this.originalSampleRate / this.targetSampleRate;
}
process(inputs, outputs, parameters) {
const input = inputs[0];
if (input.length > 0) {
let audioData = input[0]; // Get first channel's data
if (this.resampleRatio !== 1) {
audioData = this.resample(audioData);
}
this.port.postMessage(audioData);
}
return true; // Keep processor alive
}
resample(audioData) {
const newLength = Math.round(audioData.length / this.resampleRatio);
const resampled = new Float32Array(newLength);
for (let i = 0; i < newLength; i++) {
const srcIndex = Math.floor(i * this.resampleRatio);
resampled[i] = audioData[srcIndex]; // Nearest neighbor resampling
}
return resampled;
}
}
registerProcessor('audio-processor', AudioProcessor);

View File

@@ -0,0 +1,3 @@
{
"backendUrl": ""
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,17 @@
/**
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
html{color-scheme:dark}html{--mat-sys-background: light-dark(#fcf9f8, #131314);--mat-sys-error: light-dark(#ba1a1a, #ffb4ab);--mat-sys-error-container: light-dark(#ffdad6, #93000a);--mat-sys-inverse-on-surface: light-dark(#f3f0f0, #313030);--mat-sys-inverse-primary: light-dark(#c1c7cd, #595f65);--mat-sys-inverse-surface: light-dark(#313030, #e5e2e2);--mat-sys-on-background: light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-error: light-dark(#ffffff, #690005);--mat-sys-on-error-container: light-dark(#410002, #ffdad6);--mat-sys-on-primary: light-dark(#ffffff, #2b3136);--mat-sys-on-primary-container: light-dark(#161c21, #dde3e9);--mat-sys-on-primary-fixed: light-dark(#161c21, #161c21);--mat-sys-on-primary-fixed-variant: light-dark(#41474d, #41474d);--mat-sys-on-secondary: light-dark(#ffffff, #003061);--mat-sys-on-secondary-container: light-dark(#001b3c, #d5e3ff);--mat-sys-on-secondary-fixed: light-dark(#001b3c, #001b3c);--mat-sys-on-secondary-fixed-variant: light-dark(#0f4784, #0f4784);--mat-sys-on-surface: light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-surface-variant: light-dark(#44474a, #e1e2e6);--mat-sys-on-tertiary: light-dark(#ffffff, #2b3136);--mat-sys-on-tertiary-container: light-dark(#161c21, #dde3e9);--mat-sys-on-tertiary-fixed: light-dark(#161c21, #161c21);--mat-sys-on-tertiary-fixed-variant: light-dark(#41474d, #41474d);--mat-sys-outline: light-dark(#74777b, #8e9194);--mat-sys-outline-variant: light-dark(#c4c7ca, #44474a);--mat-sys-primary: light-dark(#595f65, #c1c7cd);--mat-sys-primary-container: light-dark(#dde3e9, #41474d);--mat-sys-primary-fixed: light-dark(#dde3e9, #dde3e9);--mat-sys-primary-fixed-dim: light-dark(#c1c7cd, #c1c7cd);--mat-sys-scrim: light-dark(#000000, #000000);--mat-sys-secondary: light-dark(#305f9d, #a7c8ff);--mat-sys-secondary-container: light-dark(#d5e3ff, #0f4784);--mat-sys-secondary-fixed: light-dark(#d5e3ff, #d5e3ff);--mat-sys-secondary-fixed-dim: light-dark(#a7c8ff, #a7c8ff);--mat-sys-shadow: light-dark(#000000, #000000);--mat-sys-surface: light-dark(#fcf9f8, #131314);--mat-sys-surface-bright: light-dark(#fcf9f8, #393939);--mat-sys-surface-container: light-dark(#f0eded, #201f20);--mat-sys-surface-container-high: light-dark(#eae7e7, #2a2a2a);--mat-sys-surface-container-highest: light-dark(#e5e2e2, #393939);--mat-sys-surface-container-low: light-dark(#f6f3f3, #1c1b1c);--mat-sys-surface-container-lowest: light-dark(#ffffff, #0e0e0e);--mat-sys-surface-dim: light-dark(#dcd9d9, #131314);--mat-sys-surface-tint: light-dark(#595f65, #c1c7cd);--mat-sys-surface-variant: light-dark(#e1e2e6, #44474a);--mat-sys-tertiary: light-dark(#595f65, #c1c7cd);--mat-sys-tertiary-container: light-dark(#dde3e9, #41474d);--mat-sys-tertiary-fixed: light-dark(#dde3e9, #dde3e9);--mat-sys-tertiary-fixed-dim: light-dark(#c1c7cd, #c1c7cd);--mat-sys-neutral-variant20: #2d3134;--mat-sys-neutral10: #1c1b1c}html{--mat-sys-level0: 0px 0px 0px 0px rgba(0, 0, 0, .2), 0px 0px 0px 0px rgba(0, 0, 0, .14), 0px 0px 0px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level1: 0px 2px 1px -1px rgba(0, 0, 0, .2), 0px 1px 1px 0px rgba(0, 0, 0, .14), 0px 1px 3px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level2: 0px 3px 3px -2px rgba(0, 0, 0, .2), 0px 3px 4px 0px rgba(0, 0, 0, .14), 0px 1px 8px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level3: 0px 3px 5px -1px rgba(0, 0, 0, .2), 0px 6px 10px 0px rgba(0, 0, 0, .14), 0px 1px 18px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level4: 0px 5px 5px -3px rgba(0, 0, 0, .2), 0px 8px 10px 1px rgba(0, 0, 0, .14), 0px 3px 14px 2px rgba(0, 0, 0, .12)}html{--mat-sys-level5: 0px 7px 8px -4px rgba(0, 0, 0, .2), 0px 12px 17px 2px rgba(0, 0, 0, .14), 0px 5px 22px 4px rgba(0, 0, 0, .12)}html{--mat-sys-corner-extra-large: 28px;--mat-sys-corner-extra-large-top: 28px 28px 0 0;--mat-sys-corner-extra-small: 4px;--mat-sys-corner-extra-small-top: 4px 4px 0 0;--mat-sys-corner-full: 9999px;--mat-sys-corner-large: 16px;--mat-sys-corner-large-end: 0 16px 16px 0;--mat-sys-corner-large-start: 16px 0 0 16px;--mat-sys-corner-large-top: 16px 16px 0 0;--mat-sys-corner-medium: 12px;--mat-sys-corner-none: 0;--mat-sys-corner-small: 8px}html{--mat-sys-dragged-state-layer-opacity: .16;--mat-sys-focus-state-layer-opacity: .12;--mat-sys-hover-state-layer-opacity: .08;--mat-sys-pressed-state-layer-opacity: .12}html{font-family:Google Sans,Helvetica Neue,sans-serif!important}body{height:100vh;margin:0}markdown p{margin-block-start:.5em;margin-block-end:.5em}:root{--mat-sys-primary: black;--mdc-checkbox-selected-icon-color: white;--mat-sys-background: #131314;--mat-tab-header-active-label-text-color: #8AB4F8;--mat-tab-header-active-hover-label-text-color: #8AB4F8;--mat-tab-header-active-focus-label-text-color: #8AB4F8;--mat-tab-header-label-text-weight: 500;--mdc-text-button-label-text-color: #89b4f8}:root{--mdc-dialog-container-color: #2b2b2f}:root{--mdc-dialog-subhead-color: white}:root{--mdc-circular-progress-active-indicator-color: #a8c7fa}:root{--mdc-circular-progress-size: 80}

View File

@@ -0,0 +1,183 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime
import importlib
import os
import sys
from typing import Optional
import click
from google.genai import types
from pydantic import BaseModel
from ..agents.llm_agent import LlmAgent
from ..artifacts import BaseArtifactService
from ..artifacts import InMemoryArtifactService
from ..runners import Runner
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from .utils import envs
class InputFile(BaseModel):
state: dict[str, object]
queries: list[str]
async def run_input_file(
app_name: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
input_path: str,
) -> None:
runner = Runner(
app_name=app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
with open(input_path, 'r', encoding='utf-8') as f:
input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now()
session.state = input_file.state
for query in input_file.queries:
click.echo(f'user: {query}')
content = types.Content(role='user', parts=[types.Part(text=query)])
async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
):
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
async def run_interactively(
app_name: str,
root_agent: LlmAgent,
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
) -> None:
runner = Runner(
app_name=app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
while True:
query = input('user: ')
if not query or not query.strip():
continue
if query == 'exit':
break
async for event in runner.run_async(
user_id=session.user_id,
session_id=session.id,
new_message=types.Content(role='user', parts=[types.Part(text=query)]),
):
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
async def run_cli(
*,
agent_parent_dir: str,
agent_folder_name: str,
json_file_path: Optional[str] = None,
save_session: bool,
) -> None:
"""Runs an interactive CLI for a certain agent.
Args:
agent_parent_dir: str, the absolute path of the parent folder of the agent
folder.
agent_folder_name: str, the name of the agent folder.
json_file_path: Optional[str], the absolute path to the json file, either
*.input.json or *.session.json.
save_session: bool, whether to save the session on exit.
"""
if agent_parent_dir not in sys.path:
sys.path.append(agent_parent_dir)
artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService()
session = session_service.create_session(
app_name=agent_folder_name, user_id='test_user'
)
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name)
root_agent = agent_module.agent.root_agent
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
if json_file_path:
if json_file_path.endswith('.input.json'):
await run_input_file(
app_name=agent_folder_name,
root_agent=root_agent,
artifact_service=artifact_service,
session=session,
session_service=session_service,
input_path=json_file_path,
)
elif json_file_path.endswith('.session.json'):
with open(json_file_path, 'r') as f:
session = Session.model_validate_json(f.read())
for content in session.get_contents():
if content.role == 'user':
print('user: ', content.parts[0].text)
else:
print(content.parts[0].text)
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,
session_service,
)
else:
print(f'Unsupported file type: {json_file_path}')
exit(1)
else:
print(f'Running agent {root_agent.name}, type exit to exit.')
await run_interactively(
agent_folder_name,
root_agent,
artifact_service,
session,
session_service,
)
if save_session:
if json_file_path:
session_path = json_file_path.replace('.input.json', '.session.json')
else:
session_id = input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json'
# Fetch the session again to get all the details.
session = session_service.get_session(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
)
with open(session_path, 'w') as f:
f.write(session.model_dump_json(indent=2, exclude_none=True))
print('Session saved to', session_path)

View File

@@ -0,0 +1,279 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from typing import Optional
from typing import Tuple
import click
_INIT_PY_TEMPLATE = """\
from . import agent
"""
_AGENT_PY_TEMPLATE = """\
from google.adk.agents import Agent
root_agent = Agent(
model='{model_name}',
name='root_agent',
description='A helpful assistant for user questions.',
instruction='Answer user questions to the best of your knowledge',
)
"""
_GOOGLE_API_MSG = """
Don't have API Key? Create one in AI Studio: https://aistudio.google.com/apikey
"""
_GOOGLE_CLOUD_SETUP_MSG = """
You need an existing Google Cloud account and project, check out this link for details:
https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai
"""
_OTHER_MODEL_MSG = """
Please see below guide to configure other models:
https://google.github.io/adk-docs/agents/models
"""
_SUCCESS_MSG = """
Agent created in {agent_folder}:
- .env
- __init__.py
- agent.py
"""
def _get_gcp_project_from_gcloud() -> str:
"""Uses gcloud to get default project."""
try:
result = subprocess.run(
["gcloud", "config", "get-value", "project"],
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip()
except (subprocess.CalledProcessError, FileNotFoundError):
return ""
def _get_gcp_region_from_gcloud() -> str:
"""Uses gcloud to get default region."""
try:
result = subprocess.run(
["gcloud", "config", "get-value", "compute/region"],
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip()
except (subprocess.CalledProcessError, FileNotFoundError):
return ""
def _prompt_str(
prompt_prefix: str,
*,
prior_msg: Optional[str] = None,
default_value: Optional[str] = None,
) -> str:
if prior_msg:
click.secho(prior_msg, fg="green")
while True:
value: str = click.prompt(
prompt_prefix, default=default_value or None, type=str
)
if value and value.strip():
return value.strip()
def _prompt_for_google_cloud(
google_cloud_project: Optional[str],
) -> str:
"""Prompts user for Google Cloud project ID."""
google_cloud_project = (
google_cloud_project
or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
or _get_gcp_project_from_gcloud()
)
google_cloud_project = _prompt_str(
"Enter Google Cloud project ID", default_value=google_cloud_project
)
return google_cloud_project
def _prompt_for_google_cloud_region(
google_cloud_region: Optional[str],
) -> str:
"""Prompts user for Google Cloud region."""
google_cloud_region = (
google_cloud_region
or os.environ.get("GOOGLE_CLOUD_LOCATION", None)
or _get_gcp_region_from_gcloud()
)
google_cloud_region = _prompt_str(
"Enter Google Cloud region",
default_value=google_cloud_region or "us-central1",
)
return google_cloud_region
def _prompt_for_google_api_key(
google_api_key: Optional[str],
) -> str:
"""Prompts user for Google API key."""
google_api_key = google_api_key or os.environ.get("GOOGLE_API_KEY", None)
google_api_key = _prompt_str(
"Enter Google API key",
prior_msg=_GOOGLE_API_MSG,
default_value=google_api_key,
)
return google_api_key
def _generate_files(
agent_folder: str,
*,
google_api_key: Optional[str] = None,
google_cloud_project: Optional[str] = None,
google_cloud_region: Optional[str] = None,
model: Optional[str] = None,
):
"""Generates a folder name for the agent."""
os.makedirs(agent_folder, exist_ok=True)
dotenv_file_path = os.path.join(agent_folder, ".env")
init_file_path = os.path.join(agent_folder, "__init__.py")
agent_file_path = os.path.join(agent_folder, "agent.py")
with open(dotenv_file_path, "w", encoding="utf-8") as f:
lines = []
if google_api_key:
lines.append("GOOGLE_GENAI_USE_VERTEXAI=0")
elif google_cloud_project and google_cloud_region:
lines.append("GOOGLE_GENAI_USE_VERTEXAI=1")
if google_api_key:
lines.append(f"GOOGLE_API_KEY={google_api_key}")
if google_cloud_project:
lines.append(f"GOOGLE_CLOUD_PROJECT={google_cloud_project}")
if google_cloud_region:
lines.append(f"GOOGLE_CLOUD_LOCATION={google_cloud_region}")
f.write("\n".join(lines))
with open(init_file_path, "w", encoding="utf-8") as f:
f.write(_INIT_PY_TEMPLATE)
with open(agent_file_path, "w", encoding="utf-8") as f:
f.write(_AGENT_PY_TEMPLATE.format(model_name=model))
click.secho(
_SUCCESS_MSG.format(agent_folder=agent_folder),
fg="green",
)
def _prompt_for_model() -> str:
model_choice = click.prompt(
"""\
Choose a model for the root agent:
1. gemini-2.0-flash-001
2. Other models (fill later)
Choose model""",
type=click.Choice(["1", "2"]),
)
if model_choice == "1":
return "gemini-2.0-flash-001"
else:
click.secho(_OTHER_MODEL_MSG, fg="green")
return "<FILL_IN_MODEL>"
def _prompt_to_choose_backend(
google_api_key: Optional[str],
google_cloud_project: Optional[str],
google_cloud_region: Optional[str],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
"""Prompts user to choose backend.
Returns:
A tuple of (google_api_key, google_cloud_project, google_cloud_region).
"""
backend_choice = click.prompt(
"1. Google AI\n2. Vertex AI\nChoose a backend",
type=click.Choice(["1", "2"]),
)
if backend_choice == "1":
google_api_key = _prompt_for_google_api_key(google_api_key)
elif backend_choice == "2":
click.secho(_GOOGLE_CLOUD_SETUP_MSG, fg="green")
google_cloud_project = _prompt_for_google_cloud(google_cloud_project)
google_cloud_region = _prompt_for_google_cloud_region(google_cloud_region)
return google_api_key, google_cloud_project, google_cloud_region
def run_cmd(
agent_name: str,
*,
model: Optional[str],
google_api_key: Optional[str],
google_cloud_project: Optional[str],
google_cloud_region: Optional[str],
):
"""Runs `adk create` command to create agent template.
Args:
agent_name: str, The name of the agent.
google_api_key: Optional[str], The Google API key for using Google AI as
backend.
google_cloud_project: Optional[str], The Google Cloud project for using
VertexAI as backend.
google_cloud_region: Optional[str], The Google Cloud region for using
VertexAI as backend.
"""
agent_folder = os.path.join(os.getcwd(), agent_name)
# check folder doesn't exist or it's empty. Otherwise, throw
if os.path.exists(agent_folder) and os.listdir(agent_folder):
# Prompt user whether to override existing files using click
if not click.confirm(
f"Non-empty folder already exist: '{agent_folder}'\n"
"Override existing content?",
default=False,
):
raise click.Abort()
if not model:
model = _prompt_for_model()
if not google_api_key and not (google_cloud_project and google_cloud_region):
if model.startswith("gemini"):
google_api_key, google_cloud_project, google_cloud_region = (
_prompt_to_choose_backend(
google_api_key, google_cloud_project, google_cloud_region
)
)
_generate_files(
agent_folder,
google_api_key=google_api_key,
google_cloud_project=google_cloud_project,
google_cloud_region=google_cloud_region,
model=model,
)

View File

@@ -0,0 +1,188 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import subprocess
from typing import Optional
import click
_DOCKERFILE_TEMPLATE = """
FROM python:3.11-slim
WORKDIR /app
# Create a non-root user
RUN adduser --disabled-password --gecos "" myuser
# Change ownership of /app to myuser
RUN chown -R myuser:myuser /app
# Switch to the non-root user
USER myuser
# Set up environment variables - Start
ENV PATH="/home/myuser/.local/bin:$PATH"
ENV GOOGLE_GENAI_USE_VERTEXAI=1
ENV GOOGLE_CLOUD_PROJECT={gcp_project_id}
ENV GOOGLE_CLOUD_LOCATION={gcp_region}
# Set up environment variables - End
# Install ADK - Start
RUN pip install google-adk
# Install ADK - End
# Copy agent - Start
COPY "agents/{app_name}/" "/app/agents/{app_name}/"
{install_agent_deps}
# Copy agent - End
EXPOSE {port}
CMD adk {command} --port={port} {session_db_option} {trace_to_cloud_option} "/app/agents"
"""
def _resolve_project(project_in_option: Optional[str]) -> str:
if project_in_option:
return project_in_option
result = subprocess.run(
['gcloud', 'config', 'get-value', 'project'],
check=True,
capture_output=True,
text=True,
)
project = result.stdout.strip()
click.echo(f'Use default project: {project}')
return project
def to_cloud_run(
*,
agent_folder: str,
project: Optional[str],
region: Optional[str],
service_name: str,
app_name: str,
temp_folder: str,
port: int,
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
session_db_url: str,
):
"""Deploys an agent to Google Cloud Run.
`agent_folder` should contain the following files:
- __init__.py
- agent.py
- requirements.txt (optional, for additional dependencies)
- ... (other required source files)
The folder structure of temp_folder will be
* dist/[google_adk wheel file]
* agents/[app_name]/
* agent source code from `agent_folder`
Args:
agent_folder: The folder (absolute path) containing the agent source code.
project: Google Cloud project id.
region: Google Cloud region.
service_name: The service name in Cloud Run.
app_name: The name of the app, by default, it's basename of `agent_folder`.
temp_folder: The temp folder for the generated Cloud Run source files.
port: The port of the ADK api server.
trace_to_cloud: Whether to enable Cloud Trace.
with_ui: Whether to deploy with UI.
verbosity: The verbosity level of the CLI.
session_db_url: The database URL to connect the session.
"""
app_name = app_name or os.path.basename(agent_folder)
click.echo(f'Start generating Cloud Run source files in {temp_folder}')
# remove temp_folder if exists
if os.path.exists(temp_folder):
click.echo('Removing existing files')
shutil.rmtree(temp_folder)
try:
# copy agent source code
click.echo('Copying agent source code...')
agent_src_path = os.path.join(temp_folder, 'agents', app_name)
shutil.copytree(agent_folder, agent_src_path)
requirements_txt_path = os.path.join(agent_src_path, 'requirements.txt')
install_agent_deps = (
f'RUN pip install -r "/app/agents/{app_name}/requirements.txt"'
if os.path.exists(requirements_txt_path)
else ''
)
click.echo('Copying agent source code complete.')
# create Dockerfile
click.echo('Creating Dockerfile...')
dockerfile_content = _DOCKERFILE_TEMPLATE.format(
gcp_project_id=project,
gcp_region=region,
app_name=app_name,
port=port,
command='web' if with_ui else 'api_server',
install_agent_deps=install_agent_deps,
session_db_option=f'--session_db_url={session_db_url}'
if session_db_url
else '',
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
)
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
os.makedirs(temp_folder, exist_ok=True)
with open(dockerfile_path, 'w', encoding='utf-8') as f:
f.write(
dockerfile_content,
)
click.echo(f'Creating Dockerfile complete: {dockerfile_path}')
# Deploy to Cloud Run
click.echo('Deploying to Cloud Run...')
region_options = ['--region', region] if region else []
project = _resolve_project(project)
subprocess.run(
[
'gcloud',
'run',
'deploy',
service_name,
'--source',
temp_folder,
'--project',
project,
*region_options,
'--port',
str(port),
'--verbosity',
verbosity,
'--labels',
'created-by=adk',
],
check=True,
)
finally:
click.echo(f'Cleaning up the temp folder: {temp_folder}')
shutil.rmtree(temp_folder)

View File

@@ -0,0 +1,282 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
import importlib.util
import json
import logging
import os
import sys
import traceback
from typing import Any
from typing import Generator
from typing import Optional
import uuid
from pydantic import BaseModel
from ..agents import Agent
logger = logging.getLogger(__name__)
class EvalStatus(Enum):
PASSED = 1
FAILED = 2
NOT_EVALUATED = 3
class EvalMetric(BaseModel):
metric_name: str
threshold: float
class EvalMetricResult(BaseModel):
score: Optional[float]
eval_status: EvalStatus
class EvalResult(BaseModel):
eval_set_file: str
eval_id: str
final_eval_status: EvalStatus
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
session_id: str
MISSING_EVAL_DEPENDENCIES_MESSAGE = (
"Eval module is not installed, please install via `pip install"
" google-adk[eval]`."
)
TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score"
RESPONSE_MATCH_SCORE_KEY = "response_match_score"
# This evaluation is not very stable.
# This is always optional unless explicitly specified.
RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"
EVAL_SESSION_ID_PREFIX = "___eval___session___"
DEFAULT_CRITERIA = {
TOOL_TRAJECTORY_SCORE_KEY: 1.0, # 1-point scale; 1.0 is perfect.
RESPONSE_MATCH_SCORE_KEY: 0.8,
}
def _import_from_path(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def _get_agent_module(agent_module_file_path: str):
file_path = os.path.join(agent_module_file_path, "__init__.py")
module_name = "agent"
return _import_from_path(module_name, file_path)
def get_evaluation_criteria_or_default(
eval_config_file_path: str,
) -> dict[str, float]:
"""Returns evaluation criteria from the config file, if present.
Otherwise a default one is returned.
"""
if eval_config_file_path:
with open(eval_config_file_path, "r", encoding="utf-8") as f:
config_data = json.load(f)
if "criteria" in config_data and isinstance(config_data["criteria"], dict):
evaluation_criteria = config_data["criteria"]
else:
raise ValueError(
f"Invalid format for test_config.json at {eval_config_file_path}."
" Expected a 'criteria' dictionary."
)
else:
logger.info("No config file supplied. Using default criteria.")
evaluation_criteria = DEFAULT_CRITERIA
return evaluation_criteria
def get_root_agent(agent_module_file_path: str) -> Agent:
"""Returns root agent given the agent module."""
agent_module = _get_agent_module(agent_module_file_path)
root_agent = agent_module.agent.root_agent
return root_agent
def try_get_reset_func(agent_module_file_path: str) -> Any:
"""Returns reset function for the agent, if present, given the agent module."""
agent_module = _get_agent_module(agent_module_file_path)
reset_func = getattr(agent_module.agent, "reset_data", None)
return reset_func
def parse_and_get_evals_to_run(
eval_set_file_path: tuple[str],
) -> dict[str, list[str]]:
"""Returns a dictionary of eval sets to evals that should be run."""
eval_set_to_evals = {}
for input_eval_set in eval_set_file_path:
evals = []
if ":" not in input_eval_set:
eval_set_file = input_eval_set
else:
eval_set_file = input_eval_set.split(":")[0]
evals = input_eval_set.split(":")[1].split(",")
if eval_set_file not in eval_set_to_evals:
eval_set_to_evals[eval_set_file] = []
eval_set_to_evals[eval_set_file].extend(evals)
return eval_set_to_evals
def run_evals(
eval_set_to_evals: dict[str, list[str]],
root_agent: Agent,
reset_func: Optional[Any],
eval_metrics: list[EvalMetric],
session_service=None,
artifact_service=None,
print_detailed_results=False,
) -> Generator[EvalResult, None, None]:
try:
from ..evaluation.agent_evaluator import EvaluationGenerator
from ..evaluation.response_evaluator import ResponseEvaluator
from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
except ModuleNotFoundError as e:
raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
"""Returns a summary of eval runs."""
for eval_set_file, evals_to_run in eval_set_to_evals.items():
with open(eval_set_file, "r", encoding="utf-8") as file:
eval_items = json.load(file) # Load JSON into a list
assert eval_items, f"No eval data found in eval set file: {eval_set_file}"
for eval_item in eval_items:
eval_name = eval_item["name"]
eval_data = eval_item["data"]
initial_session = eval_item.get("initial_session", {})
if evals_to_run and eval_name not in evals_to_run:
continue
try:
print(f"Running Eval: {eval_set_file}:{eval_name}")
session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"
scrape_result = EvaluationGenerator._process_query_with_root_agent(
data=eval_data,
root_agent=root_agent,
reset_func=reset_func,
initial_session=initial_session,
session_id=session_id,
session_service=session_service,
artifact_service=artifact_service,
)
eval_metric_results = []
for eval_metric in eval_metrics:
eval_metric_result = None
if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
score = TrajectoryEvaluator.evaluate(
[scrape_result], print_detailed_results=print_detailed_results
)
eval_metric_result = _get_eval_metric_result(eval_metric, score)
elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
score = ResponseEvaluator.evaluate(
[scrape_result],
[RESPONSE_MATCH_SCORE_KEY],
print_detailed_results=print_detailed_results,
)
eval_metric_result = _get_eval_metric_result(
eval_metric, score["rouge_1/mean"].item()
)
elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
score = ResponseEvaluator.evaluate(
[scrape_result],
[RESPONSE_EVALUATION_SCORE_KEY],
print_detailed_results=print_detailed_results,
)
eval_metric_result = _get_eval_metric_result(
eval_metric, score["coherence/mean"].item()
)
else:
logger.warning("`%s` is not supported.", eval_metric.metric_name)
eval_metric_results.append((
eval_metric,
EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED),
))
eval_metric_results.append((
eval_metric,
eval_metric_result,
))
_print_eval_metric_result(eval_metric, eval_metric_result)
final_eval_status = EvalStatus.NOT_EVALUATED
# Go over the all the eval statuses and mark the final eval status as
# passed if all of them pass, otherwise mark the final eval status to
# failed.
for eval_metric_result in eval_metric_results:
eval_status = eval_metric_result[1].eval_status
if eval_status == EvalStatus.PASSED:
final_eval_status = EvalStatus.PASSED
elif eval_status == EvalStatus.NOT_EVALUATED:
continue
elif eval_status == EvalStatus.FAILED:
final_eval_status = EvalStatus.FAILED
break
else:
raise ValueError("Unknown eval status.")
yield EvalResult(
eval_set_file=eval_set_file,
eval_id=eval_name,
final_eval_status=final_eval_status,
eval_metric_results=eval_metric_results,
session_id=session_id,
)
if final_eval_status == EvalStatus.PASSED:
result = "✅ Passed"
else:
result = "❌ Failed"
print(f"Result: {result}\n")
except Exception as e:
print(f"Error: {e}")
logger.info("Error: %s", str(traceback.format_exc()))
def _get_eval_metric_result(eval_metric, score):
eval_status = (
EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
)
return EvalMetricResult(score=score, eval_status=eval_status)
def _print_eval_metric_result(eval_metric, eval_metric_result):
print(
f"Metric: {eval_metric.metric_name}\tStatus:"
f" {eval_metric_result.eval_status}\tScore:"
f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
)

View File

@@ -0,0 +1,600 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
from datetime import datetime
import logging
import os
import tempfile
from typing import Optional
import click
from fastapi import FastAPI
import uvicorn
from . import cli_create
from . import cli_deploy
from .cli import run_cli
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
from .fast_api import get_fast_api_app
from .utils import envs
from .utils import logs
logger = logging.getLogger(__name__)
@click.group(context_settings={"max_content_width": 240})
def main():
"""Agent Development Kit CLI tools."""
pass
@main.group()
def deploy():
"""Deploys agent to hosted environments."""
pass
@main.command("create")
@click.option(
"--model",
type=str,
help="Optional. The model used for the root agent.",
)
@click.option(
"--api_key",
type=str,
help=(
"Optional. The API Key needed to access the model, e.g. Google AI API"
" Key."
),
)
@click.option(
"--project",
type=str,
help="Optional. The Google Cloud Project for using VertexAI as backend.",
)
@click.option(
"--region",
type=str,
help="Optional. The Google Cloud Region for using VertexAI as backend.",
)
@click.argument("app_name", type=str, required=True)
def cli_create_cmd(
app_name: str,
model: Optional[str],
api_key: Optional[str],
project: Optional[str],
region: Optional[str],
):
"""Creates a new app in the current folder with prepopulated agent template.
APP_NAME: required, the folder of the agent source code.
Example:
adk create path/to/my_app
"""
cli_create.run_cmd(
app_name,
model=model,
google_api_key=api_key,
google_cloud_project=project,
google_cloud_region=region,
)
@main.command("run")
@click.option(
"--save_session",
type=bool,
is_flag=True,
show_default=True,
default=False,
help="Optional. Whether to save the session to a json file on exit.",
)
@click.argument(
"agent",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
)
def cli_run(agent: str, save_session: bool):
"""Runs an interactive CLI for a certain agent.
AGENT: The path to the agent source code folder.
Example:
adk run path/to/my_agent
"""
logs.log_to_tmp_folder()
agent_parent_folder = os.path.dirname(agent)
agent_folder_name = os.path.basename(agent)
asyncio.run(
run_cli(
agent_parent_dir=agent_parent_folder,
agent_folder_name=agent_folder_name,
save_session=save_session,
)
)
@main.command("eval")
@click.argument(
"agent_module_file_path",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
)
@click.argument("eval_set_file_path", nargs=-1)
@click.option("--config_file_path", help="Optional. The path to config file.")
@click.option(
"--print_detailed_results",
is_flag=True,
show_default=True,
default=False,
help="Optional. Whether to print detailed results on console or not.",
)
def cli_eval(
agent_module_file_path: str,
eval_set_file_path: tuple[str],
config_file_path: str,
print_detailed_results: bool,
):
"""Evaluates an agent given the eval sets.
AGENT_MODULE_FILE_PATH: The path to the __init__.py file that contains a
module by the name "agent". "agent" module contains a root_agent.
EVAL_SET_FILE_PATH: You can specify one or more eval set file paths.
For each file, all evals will be run by default.
If you want to run only specific evals from a eval set, first create a comma
separated list of eval names and then add that as a suffix to the eval set
file name, demarcated by a `:`.
For example,
sample_eval_set_file.json:eval_1,eval_2,eval_3
This will only run eval_1, eval_2 and eval_3 from sample_eval_set_file.json.
CONFIG_FILE_PATH: The path to config file.
PRINT_DETAILED_RESULTS: Prints detailed results on the console.
"""
envs.load_dotenv_for_agent(agent_module_file_path, ".")
try:
from .cli_eval import EvalMetric
from .cli_eval import EvalResult
from .cli_eval import EvalStatus
from .cli_eval import get_evaluation_criteria_or_default
from .cli_eval import get_root_agent
from .cli_eval import parse_and_get_evals_to_run
from .cli_eval import run_evals
from .cli_eval import try_get_reset_func
except ModuleNotFoundError:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
evaluation_criteria = get_evaluation_criteria_or_default(config_file_path)
eval_metrics = []
for metric_name, threshold in evaluation_criteria.items():
eval_metrics.append(
EvalMetric(metric_name=metric_name, threshold=threshold)
)
print(f"Using evaluation creiteria: {evaluation_criteria}")
root_agent = get_root_agent(agent_module_file_path)
reset_func = try_get_reset_func(agent_module_file_path)
eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
try:
eval_results = list(
run_evals(
eval_set_to_evals,
root_agent,
reset_func,
eval_metrics,
print_detailed_results=print_detailed_results,
)
)
except ModuleNotFoundError:
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
print("*********************************************************************")
eval_run_summary = {}
for eval_result in eval_results:
eval_result: EvalResult
if eval_result.eval_set_file not in eval_run_summary:
eval_run_summary[eval_result.eval_set_file] = [0, 0]
if eval_result.final_eval_status == EvalStatus.PASSED:
eval_run_summary[eval_result.eval_set_file][0] += 1
else:
eval_run_summary[eval_result.eval_set_file][1] += 1
print("Eval Run Summary")
for eval_set_file, pass_fail_count in eval_run_summary.items():
print(
f"{eval_set_file}:\n Tests passed: {pass_fail_count[0]}\n Tests"
f" failed: {pass_fail_count[1]}"
)
@main.command("web")
@click.option(
"--session_db_url",
help=(
"""Optional. The database URL to store the session.
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.option(
"--port",
type=int,
help="Optional. The port of the server",
default=8000,
)
@click.option(
"--allow_origins",
help="Optional. Any additional origins to allow for CORS.",
multiple=True,
)
@click.option(
"--log_level",
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
default="INFO",
help="Optional. Set the logging level",
)
@click.option(
"--log_to_tmp",
is_flag=True,
show_default=True,
default=False,
help=(
"Optional. Whether to log to system temp folder instead of console."
" This is useful for local debugging."
),
)
@click.option(
"--trace_to_cloud",
is_flag=True,
show_default=True,
default=False,
help="Optional. Whether to enable cloud trace for telemetry.",
)
@click.argument(
"agents_dir",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
default=os.getcwd,
)
def cli_web(
agents_dir: str,
log_to_tmp: bool,
session_db_url: str = "",
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
port: int = 8000,
trace_to_cloud: bool = False,
):
"""Starts a FastAPI server with Web UI for agents.
AGENTS_DIR: The directory of agents, where each sub-directory is a single
agent, containing at least `__init__.py` and `agent.py` files.
Example:
adk web --session_db_url=[db_url] --port=[port] path/to/agents_dir
"""
if log_to_tmp:
logs.log_to_tmp_folder()
else:
logs.log_to_stderr()
logging.getLogger().setLevel(log_level)
@asynccontextmanager
async def _lifespan(app: FastAPI):
click.secho(
f"""
+-----------------------------------------------------------------------------+
| ADK Web Server started |
| |
| For local testing, access at http://localhost:{port}.{" "*(29 - len(str(port)))}|
+-----------------------------------------------------------------------------+
""",
fg="green",
)
yield # Startup is done, now app is running
click.secho(
"""
+-----------------------------------------------------------------------------+
| ADK Web Server shutting down... |
+-----------------------------------------------------------------------------+
""",
fg="green",
)
app = get_fast_api_app(
agent_dir=agents_dir,
session_db_url=session_db_url,
allow_origins=allow_origins,
web=True,
trace_to_cloud=trace_to_cloud,
lifespan=_lifespan,
)
config = uvicorn.Config(
app,
host="0.0.0.0",
port=port,
reload=True,
)
server = uvicorn.Server(config)
server.run()
@main.command("api_server")
@click.option(
"--session_db_url",
help=(
"""Optional. The database URL to store the session.
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.option(
"--port",
type=int,
help="Optional. The port of the server",
default=8000,
)
@click.option(
"--allow_origins",
help="Optional. Any additional origins to allow for CORS.",
multiple=True,
)
@click.option(
"--log_level",
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
default="INFO",
help="Optional. Set the logging level",
)
@click.option(
"--log_to_tmp",
is_flag=True,
show_default=True,
default=False,
help=(
"Optional. Whether to log to system temp folder instead of console."
" This is useful for local debugging."
),
)
@click.option(
"--trace_to_cloud",
is_flag=True,
show_default=True,
default=False,
help="Optional. Whether to enable cloud trace for telemetry.",
)
# The directory of agents, where each sub-directory is a single agent.
# By default, it is the current working directory
@click.argument(
"agents_dir",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
default=os.getcwd(),
)
def cli_api_server(
agents_dir: str,
log_to_tmp: bool,
session_db_url: str = "",
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
port: int = 8000,
trace_to_cloud: bool = False,
):
"""Starts a FastAPI server for agents.
AGENTS_DIR: The directory of agents, where each sub-directory is a single
agent, containing at least `__init__.py` and `agent.py` files.
Example:
adk api_server --session_db_url=[db_url] --port=[port] path/to/agents_dir
"""
if log_to_tmp:
logs.log_to_tmp_folder()
else:
logs.log_to_stderr()
logging.getLogger().setLevel(log_level)
config = uvicorn.Config(
get_fast_api_app(
agent_dir=agents_dir,
session_db_url=session_db_url,
allow_origins=allow_origins,
web=False,
trace_to_cloud=trace_to_cloud,
),
host="0.0.0.0",
port=port,
reload=True,
)
server = uvicorn.Server(config)
server.run()
@deploy.command("cloud_run")
@click.option(
"--project",
type=str,
help=(
"Required. Google Cloud project to deploy the agent. When absent,"
" default project from gcloud config is used."
),
)
@click.option(
"--region",
type=str,
help=(
"Required. Google Cloud region to deploy the agent. When absent,"
" gcloud run deploy will prompt later."
),
)
@click.option(
"--service_name",
type=str,
default="adk-default-service-name",
help=(
"Optional. The service name to use in Cloud Run (default:"
" 'adk-default-service-name')."
),
)
@click.option(
"--app_name",
type=str,
default="",
help=(
"Optional. App name of the ADK API server (default: the folder name"
" of the AGENT source code)."
),
)
@click.option(
"--port",
type=int,
default=8000,
help="Optional. The port of the ADK API server (default: 8000).",
)
@click.option(
"--trace_to_cloud",
type=bool,
is_flag=True,
show_default=True,
default=False,
help="Optional. Whether to enable Cloud Trace for cloud run.",
)
@click.option(
"--with_ui",
type=bool,
is_flag=True,
show_default=True,
default=False,
help=(
"Optional. Deploy ADK Web UI if set. (default: deploy ADK API server"
" only)"
),
)
@click.option(
"--temp_folder",
type=str,
default=os.path.join(
tempfile.gettempdir(),
"cloud_run_deploy_src",
datetime.now().strftime("%Y%m%d_%H%M%S"),
),
help=(
"Optional. Temp folder for the generated Cloud Run source files"
" (default: a timestamped folder in the system temp directory)."
),
)
@click.option(
"--verbosity",
type=click.Choice(
["debug", "info", "warning", "error", "critical"], case_sensitive=False
),
default="WARNING",
help="Optional. Override the default verbosity level.",
)
@click.option(
"--session_db_url",
help=(
"""Optional. The database URL to store the session.
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.argument(
"agent",
type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
)
def cli_deploy_cloud_run(
agent: str,
project: Optional[str],
region: Optional[str],
service_name: str,
app_name: str,
temp_folder: str,
port: int,
trace_to_cloud: bool,
with_ui: bool,
verbosity: str,
session_db_url: str,
):
"""Deploys an agent to Cloud Run.
AGENT: The path to the agent source code folder.
Example:
adk deploy cloud_run --project=[project] --region=[region] path/to/my_agent
"""
try:
cli_deploy.to_cloud_run(
agent_folder=agent,
project=project,
region=region,
service_name=service_name,
app_name=app_name,
temp_folder=temp_folder,
port=port,
trace_to_cloud=trace_to_cloud,
with_ui=with_ui,
verbosity=verbosity,
session_db_url=session_db_url,
)
except Exception as e:
click.secho(f"Deploy failed: {e}", fg="red", err=True)

View File

@@ -0,0 +1,822 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
import importlib
import inspect
import json
import logging
import os
from pathlib import Path
import re
import sys
import traceback
import typing
from typing import Any
from typing import List
from typing import Literal
from typing import Optional
import click
from click import Tuple
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.websockets import WebSocket
from fastapi.websockets import WebSocketDisconnect
from google.genai import types
import graphviz
from opentelemetry import trace
from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
from opentelemetry.sdk.trace import export
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace import TracerProvider
from pydantic import BaseModel
from pydantic import ValidationError
from starlette.types import Lifespan
from ..agents import RunConfig
from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent
from ..agents.run_config import StreamingMode
from ..artifacts import InMemoryArtifactService
from ..events.event import Event
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..runners import Runner
from ..sessions.database_session_service import DatabaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from ..sessions.vertex_ai_session_service import VertexAiSessionService
from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalMetric
from .cli_eval import EvalMetricResult
from .cli_eval import EvalStatus
from .utils import create_empty_state
from .utils import envs
from .utils import evals
logger = logging.getLogger(__name__)
_EVAL_SET_FILE_EXTENSION = ".evalset.json"
class ApiServerSpanExporter(export.SpanExporter):
def __init__(self, trace_dict):
self.trace_dict = trace_dict
def export(
self, spans: typing.Sequence[ReadableSpan]
) -> export.SpanExportResult:
for span in spans:
if (
span.name == "call_llm"
or span.name == "send_data"
or span.name.startswith("tool_response")
):
attributes = dict(span.attributes)
attributes["trace_id"] = span.get_span_context().trace_id
attributes["span_id"] = span.get_span_context().span_id
if attributes.get("gcp.vertex.agent.event_id", None):
self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes
return export.SpanExportResult.SUCCESS
def force_flush(self, timeout_millis: int = 30000) -> bool:
return True
class AgentRunRequest(BaseModel):
app_name: str
user_id: str
session_id: str
new_message: types.Content
streaming: bool = False
class AddSessionToEvalSetRequest(BaseModel):
eval_id: str
session_id: str
user_id: str
class RunEvalRequest(BaseModel):
eval_ids: list[str] # if empty, then all evals in the eval set are run.
eval_metrics: list[EvalMetric]
class RunEvalResult(BaseModel):
eval_set_id: str
eval_id: str
final_eval_status: EvalStatus
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
session_id: str
def get_fast_api_app(
*,
agent_dir: str,
session_db_url: str = "",
allow_origins: Optional[list[str]] = None,
web: bool,
trace_to_cloud: bool = False,
lifespan: Optional[Lifespan[FastAPI]] = None,
) -> FastAPI:
# InMemory tracing dict.
trace_dict: dict[str, Any] = {}
# Set up tracing in the FastAPI server.
provider = TracerProvider()
provider.add_span_processor(
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
)
if trace_to_cloud:
envs.load_dotenv_for_agent("", agent_dir)
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
processor = export.BatchSpanProcessor(
CloudTraceSpanExporter(project_id=project_id)
)
provider.add_span_processor(processor)
else:
logging.warning(
"GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
" not be enabled."
)
trace.set_tracer_provider(provider)
exit_stacks = []
@asynccontextmanager
async def internal_lifespan(app: FastAPI):
if lifespan:
async with lifespan(app) as lifespan_context:
yield
if exit_stacks:
for stack in exit_stacks:
await stack.aclose()
else:
yield
# Run the FastAPI server.
app = FastAPI(lifespan=internal_lifespan)
if allow_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=allow_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
if agent_dir not in sys.path:
sys.path.append(agent_dir)
runner_dict = {}
root_agent_dict = {}
# Build the Artifact service
artifact_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()
# Build the Session service
agent_engine_id = ""
if session_db_url:
if session_db_url.startswith("agentengine://"):
# Create vertex session service
agent_engine_id = session_db_url.split("://")[1]
if not agent_engine_id:
raise click.ClickException("Agent engine id can not be empty.")
envs.load_dotenv_for_agent("", agent_dir)
session_service = VertexAiSessionService(
os.environ["GOOGLE_CLOUD_PROJECT"],
os.environ["GOOGLE_CLOUD_LOCATION"],
)
else:
session_service = DatabaseSessionService(db_url=session_db_url)
else:
session_service = InMemorySessionService()
@app.get("/list-apps")
def list_apps() -> list[str]:
base_path = Path.cwd() / agent_dir
if not base_path.exists():
raise HTTPException(status_code=404, detail="Path not found")
if not base_path.is_dir():
raise HTTPException(status_code=400, detail="Not a directory")
agent_names = [
x
for x in os.listdir(base_path)
if os.path.isdir(os.path.join(base_path, x))
and not x.startswith(".")
and x != "__pycache__"
]
agent_names.sort()
return agent_names
@app.get("/debug/trace/{event_id}")
def get_trace_dict(event_id: str) -> Any:
event_dict = trace_dict.get(event_id, None)
if event_dict is None:
raise HTTPException(status_code=404, detail="Trace not found")
return event_dict
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
def get_session(app_name: str, user_id: str, session_id: str) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
return session
@app.get(
"/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True,
)
def list_sessions(app_name: str, user_id: str) -> list[Session]:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
return [
session
for session in session_service.list_sessions(
app_name=app_name, user_id=user_id
).sessions
# Remove sessions that were generated as a part of Eval.
if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
]
@app.post(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
def create_session_with_id(
app_name: str,
user_id: str,
session_id: str,
state: Optional[dict[str, Any]] = None,
) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
if (
session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
is not None
):
logger.warning("Session already exists: %s", session_id)
raise HTTPException(
status_code=400, detail=f"Session already exists: {session_id}"
)
logger.info("New session created: %s", session_id)
return session_service.create_session(
app_name=app_name, user_id=user_id, state=state, session_id=session_id
)
@app.post(
"/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True,
)
def create_session(
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
logger.info("New session created")
return session_service.create_session(
app_name=app_name, user_id=user_id, state=state
)
def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
return os.path.join(
agent_dir,
app_name,
eval_set_id + _EVAL_SET_FILE_EXTENSION,
)
@app.post(
"/apps/{app_name}/eval_sets/{eval_set_id}",
response_model_exclude_none=True,
)
def create_eval_set(
app_name: str,
eval_set_id: str,
):
"""Creates an eval set, given the id."""
pattern = r"^[a-zA-Z0-9_]+$"
if not bool(re.fullmatch(pattern, eval_set_id)):
raise HTTPException(
status_code=400,
detail=(
f"Invalid eval set id. Eval set id should have the `{pattern}`"
" format"
),
)
# Define the file path
new_eval_set_path = _get_eval_set_file_path(
app_name, agent_dir, eval_set_id
)
logger.info("Creating eval set file `%s`", new_eval_set_path)
if not os.path.exists(new_eval_set_path):
# Write the JSON string to the file
logger.info("Eval set file doesn't exist, we will create a new one.")
with open(new_eval_set_path, "w") as f:
empty_content = json.dumps([], indent=2)
f.write(empty_content)
@app.get(
"/apps/{app_name}/eval_sets",
response_model_exclude_none=True,
)
def list_eval_sets(app_name: str) -> list[str]:
"""Lists all eval sets for the given app."""
eval_set_file_path = os.path.join(agent_dir, app_name)
eval_sets = []
for file in os.listdir(eval_set_file_path):
if file.endswith(_EVAL_SET_FILE_EXTENSION):
eval_sets.append(
os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
)
return sorted(eval_sets)
@app.post(
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
response_model_exclude_none=True,
)
async def add_session_to_eval_set(
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
):
pattern = r"^[a-zA-Z0-9_]+$"
if not bool(re.fullmatch(pattern, req.eval_id)):
raise HTTPException(
status_code=400,
detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
)
# Get the session
session = session_service.get_session(
app_name=app_name, user_id=req.user_id, session_id=req.session_id
)
assert session, "Session not found."
# Load the eval set file data
eval_set_file_path = _get_eval_set_file_path(
app_name, agent_dir, eval_set_id
)
with open(eval_set_file_path, "r") as file:
eval_set_data = json.load(file) # Load JSON into a list
if [x for x in eval_set_data if x["name"] == req.eval_id]:
raise HTTPException(
status_code=400,
detail=(
f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
" eval set."
),
)
# Convert the session data to evaluation format
test_data = evals.convert_session_to_eval_format(session)
# Populate the session with initial session state.
initial_session_state = create_empty_state(
await _get_root_agent_async(app_name)
)
eval_set_data.append({
"name": req.eval_id,
"data": test_data,
"initial_session": {
"state": initial_session_state,
"app_name": app_name,
"user_id": req.user_id,
},
})
# Serialize the test data to JSON and write to the eval set file.
with open(eval_set_file_path, "w") as f:
f.write(json.dumps(eval_set_data, indent=2))
@app.get(
"/apps/{app_name}/eval_sets/{eval_set_id}/evals",
response_model_exclude_none=True,
)
def list_evals_in_eval_set(
app_name: str,
eval_set_id: str,
) -> list[str]:
"""Lists all evals in an eval set."""
# Load the eval set file data
eval_set_file_path = _get_eval_set_file_path(
app_name, agent_dir, eval_set_id
)
with open(eval_set_file_path, "r") as file:
eval_set_data = json.load(file) # Load JSON into a list
return sorted([x["name"] for x in eval_set_data])
@app.post(
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
response_model_exclude_none=True,
)
async def run_eval(
app_name: str, eval_set_id: str, req: RunEvalRequest
) -> list[RunEvalResult]:
from .cli_eval import run_evals
"""Runs an eval given the details in the eval request."""
# Create a mapping from eval set file to all the evals that needed to be
# run.
eval_set_file_path = _get_eval_set_file_path(
app_name, agent_dir, eval_set_id
)
eval_set_to_evals = {eval_set_file_path: req.eval_ids}
if not req.eval_ids:
logger.info(
"Eval ids to run list is empty. We will all evals in the eval set."
)
root_agent = await _get_root_agent_async(app_name)
eval_results = list(
run_evals(
eval_set_to_evals,
root_agent,
getattr(root_agent, "reset_data", None),
req.eval_metrics,
session_service=session_service,
artifact_service=artifact_service,
)
)
run_eval_results = []
for eval_result in eval_results:
run_eval_results.append(
RunEvalResult(
app_name=app_name,
eval_set_id=eval_set_id,
eval_id=eval_result.eval_id,
final_eval_status=eval_result.final_eval_status,
eval_metric_results=eval_result.eval_metric_results,
session_id=eval_result.session_id,
)
)
return run_eval_results
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
def delete_session(app_name: str, user_id: str, session_id: str):
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
response_model_exclude_none=True,
)
def load_artifact(
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
version: Optional[int] = Query(None),
) -> Optional[types.Part]:
app_name = agent_engine_id if agent_engine_id else app_name
artifact = artifact_service.load_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
version=version,
)
if not artifact:
raise HTTPException(status_code=404, detail="Artifact not found")
return artifact
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
response_model_exclude_none=True,
)
def load_artifact_version(
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
version_id: int,
) -> Optional[types.Part]:
app_name = agent_engine_id if agent_engine_id else app_name
artifact = artifact_service.load_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
version=version_id,
)
if not artifact:
raise HTTPException(status_code=404, detail="Artifact not found")
return artifact
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
response_model_exclude_none=True,
)
def list_artifact_names(
app_name: str, user_id: str, session_id: str
) -> list[str]:
app_name = agent_engine_id if agent_engine_id else app_name
return artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id, session_id=session_id
)
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
response_model_exclude_none=True,
)
def list_artifact_versions(
app_name: str, user_id: str, session_id: str, artifact_name: str
) -> list[int]:
app_name = agent_engine_id if agent_engine_id else app_name
return artifact_service.list_versions(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
)
@app.delete(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
)
def delete_artifact(
app_name: str, user_id: str, session_id: str, artifact_name: str
):
app_name = agent_engine_id if agent_engine_id else app_name
artifact_service.delete_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
)
@app.post("/run", response_model_exclude_none=True)
async def agent_run(req: AgentRunRequest) -> list[Event]:
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
session = session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
runner = await _get_runner_async(req.app_name)
events = [
event
async for event in runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
)
]
logger.info("Generated %s events in agent run: %s", len(events), events)
return events
@app.post("/run_sse")
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
# SSE endpoint
session = session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
# Convert the events to properly formatted SSE
async def event_generator():
try:
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
runner = await _get_runner_async(req.app_name)
async for event in runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
run_config=RunConfig(streaming_mode=stream_mode),
):
# Format as SSE data
sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
logger.info("Generated event in agent run streaming: %s", sse_event)
yield f"data: {sse_event}\n\n"
except Exception as e:
logger.exception("Error in event_generator: %s", e)
# You might want to yield an error event here
yield f'data: {{"error": "{str(e)}"}}\n\n'
# Returns a streaming response with the proper media type for SSE
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
)
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
response_model_exclude_none=True,
)
async def get_event_graph(
app_name: str, user_id: str, session_id: str, event_id: str
):
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
)
session_events = session.events if session else []
event = next((x for x in session_events if x.id == event_id), None)
if not event:
return {}
from . import agent_graph
function_calls = event.get_function_calls()
function_responses = event.get_function_responses()
root_agent = await _get_root_agent_async(app_name)
dot_graph = None
if function_calls:
function_call_highlights = []
for function_call in function_calls:
from_name = event.author
to_name = function_call.name
function_call_highlights.append((from_name, to_name))
dot_graph = agent_graph.get_agent_graph(
root_agent, function_call_highlights
)
elif function_responses:
function_responses_highlights = []
for function_response in function_responses:
from_name = function_response.name
to_name = event.author
function_responses_highlights.append((from_name, to_name))
dot_graph = agent_graph.get_agent_graph(
root_agent, function_responses_highlights
)
else:
from_name = event.author
to_name = ""
dot_graph = agent_graph.get_agent_graph(
root_agent, [(from_name, to_name)]
)
if dot_graph and isinstance(dot_graph, graphviz.Digraph):
return {"dot_src": dot_graph.source}
else:
return {}
@app.websocket("/run_live")
async def agent_live_run(
websocket: WebSocket,
app_name: str,
user_id: str,
session_id: str,
modalities: List[Literal["TEXT", "AUDIO"]] = Query(
default=["TEXT", "AUDIO"]
), # Only allows "TEXT" or "AUDIO"
) -> None:
await websocket.accept()
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
)
if not session:
# Accept first so that the client is aware of connection establishment,
# then close with a specific code.
await websocket.close(code=1002, reason="Session not found")
return
live_request_queue = LiveRequestQueue()
async def forward_events():
runner = await _get_runner_async(app_name)
async for event in runner.run_live(
session=session, live_request_queue=live_request_queue
):
await websocket.send_text(
event.model_dump_json(exclude_none=True, by_alias=True)
)
async def process_messages():
try:
while True:
data = await websocket.receive_text()
# Validate and send the received message to the live queue.
live_request_queue.send(LiveRequest.model_validate_json(data))
except ValidationError as ve:
logger.error("Validation error in process_messages: %s", ve)
# Run both tasks concurrently and cancel all if one fails.
tasks = [
asyncio.create_task(forward_events()),
asyncio.create_task(process_messages()),
]
done, pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_EXCEPTION
)
try:
# This will re-raise any exception from the completed tasks.
for task in done:
task.result()
except WebSocketDisconnect:
logger.info("Client disconnected during process_messages.")
except Exception as e:
logger.exception("Error during live websocket communication: %s", e)
traceback.print_exc()
WEBSOCKET_INTERNAL_ERROR_CODE = 1011
WEBSOCKET_MAX_BYTES_FOR_REASON = 123
await websocket.close(
code=WEBSOCKET_INTERNAL_ERROR_CODE,
reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
)
finally:
for task in pending:
task.cancel()
async def _get_root_agent_async(app_name: str) -> Agent:
"""Returns the root agent for the given app."""
if app_name in root_agent_dict:
return root_agent_dict[app_name]
agent_module = importlib.import_module(app_name)
if getattr(agent_module.agent, "root_agent"):
root_agent = agent_module.agent.root_agent
else:
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
# Handle an awaitable root agent and await for the actual agent.
if inspect.isawaitable(root_agent):
try:
agent, exit_stack = await root_agent
exit_stacks.append(exit_stack)
root_agent = agent
except Exception as e:
raise RuntimeError(f"error getting root agent, {e}") from e
root_agent_dict[app_name] = root_agent
return root_agent
async def _get_runner_async(app_name: str) -> Runner:
"""Returns the runner for the given app."""
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
if app_name in runner_dict:
return runner_dict[app_name]
root_agent = await _get_root_agent_async(app_name)
runner = Runner(
app_name=agent_engine_id if agent_engine_id else app_name,
agent=root_agent,
artifact_service=artifact_service,
session_service=session_service,
memory_service=memory_service,
)
runner_dict[app_name] = runner
return runner
if web:
BASE_DIR = Path(__file__).parent.resolve()
ANGULAR_DIST_PATH = BASE_DIR / "browser"
@app.get("/")
async def redirect_to_dev_ui():
return RedirectResponse("/dev-ui")
@app.get("/dev-ui")
async def dev_ui():
return FileResponse(BASE_DIR / "browser/index.html")
app.mount(
"/", StaticFiles(directory=ANGULAR_DIST_PATH, html=True), name="static"
)
return app

View File

@@ -0,0 +1,49 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Any
from typing import Optional
from ...agents.base_agent import BaseAgent
from ...agents.llm_agent import LlmAgent
__all__ = [
'create_empty_state',
]
def _create_empty_state(agent: BaseAgent, all_state: dict[str, Any]):
for sub_agent in agent.sub_agents:
_create_empty_state(sub_agent, all_state)
if (
isinstance(agent, LlmAgent)
and agent.instruction
and isinstance(agent.instruction, str)
):
for key in re.findall(r'{([\w]+)}', agent.instruction):
all_state[key] = ''
def create_empty_state(
agent: BaseAgent, initialized_states: Optional[dict[str, Any]] = None
) -> dict[str, Any]:
"""Creates empty str for non-initialized states."""
non_initialized_states = {}
_create_empty_state(agent, non_initialized_states)
for key in initialized_states or {}:
if key in non_initialized_states:
del non_initialized_states[key]
return non_initialized_states

Some files were not shown because too many files have changed in this diff Show More