structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions


@@ -0,0 +1,590 @@
# Copyright 2014 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared helpers for Google Cloud packages.
This module is not part of the public API surface.
"""
from __future__ import absolute_import
import calendar
import datetime
import http.client
import os
import re
from threading import local as Local
from typing import Union
import google.auth
import google.auth.transport.requests
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
try:
import grpc
import google.auth.transport.grpc
except ImportError: # pragma: NO COVER
grpc = None
# `google.cloud._helpers._NOW` is deprecated
_NOW = datetime.datetime.utcnow
UTC = datetime.timezone.utc # Singleton instance to be used throughout.
_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
_RFC3339_MICROS = "%Y-%m-%dT%H:%M:%S.%fZ"
_RFC3339_NO_FRACTION = "%Y-%m-%dT%H:%M:%S"
_TIMEONLY_W_MICROS = "%H:%M:%S.%f"
_TIMEONLY_NO_FRACTION = "%H:%M:%S"
# datetime.strptime cannot handle nanosecond precision: parse w/ regex
_RFC3339_NANOS = re.compile(
r"""
(?P<no_fraction>
\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2} # YYYY-MM-DDTHH:MM:SS
)
( # Optional decimal part
\. # decimal point
(?P<nanos>\d{1,9}) # nanoseconds, maybe truncated
)?
Z # Zulu
""",
re.VERBOSE,
)
# NOTE: Catching this ImportError is a workaround for GAE not supporting the
# "pwd" module which is imported lazily when "expanduser" is called.
_USER_ROOT: Union[str, None]
try:
_USER_ROOT = os.path.expanduser("~")
except ImportError: # pragma: NO COVER
_USER_ROOT = None
_GCLOUD_CONFIG_FILE = os.path.join("gcloud", "configurations", "config_default")
_GCLOUD_CONFIG_SECTION = "core"
_GCLOUD_CONFIG_KEY = "project"
class _LocalStack(Local):
"""Manage a thread-local LIFO stack of resources.
Intended for use in :class:`google.cloud.datastore.batch.Batch.__enter__`,
:class:`google.cloud.storage.batch.Batch.__enter__`, etc.
"""
def __init__(self):
super(_LocalStack, self).__init__()
self._stack = []
def __iter__(self):
"""Iterate the stack in LIFO order."""
return iter(reversed(self._stack))
def push(self, resource):
"""Push a resource onto our stack."""
self._stack.append(resource)
def pop(self):
"""Pop a resource from our stack.
:rtype: object
:returns: the top-most resource, after removing it.
:raises IndexError: if the stack is empty.
"""
return self._stack.pop()
@property
def top(self):
"""Get the top-most resource
:rtype: object
:returns: the top-most item, or None if the stack is empty.
"""
if self._stack:
return self._stack[-1]
def _ensure_tuple_or_list(arg_name, tuple_or_list):
"""Ensures an input is a tuple or list.
This effectively reduces the iterable types allowed to a very short
allowlist: list and tuple.
:type arg_name: str
:param arg_name: Name of argument to use in error message.
:type tuple_or_list: sequence of str
:param tuple_or_list: Sequence to be verified.
:rtype: list of str
:returns: The ``tuple_or_list`` passed in cast to a ``list``.
:raises TypeError: if the ``tuple_or_list`` is not a tuple or list.
"""
if not isinstance(tuple_or_list, (tuple, list)):
raise TypeError(
"Expected %s to be a tuple or list. "
"Received %r" % (arg_name, tuple_or_list)
)
return list(tuple_or_list)
def _determine_default_project(project=None):
"""Determine default project ID explicitly or implicitly as fall-back.
See :func:`google.auth.default` for details on how the default project
is determined.
:type project: str
:param project: Optional. The project name to use as default.
:rtype: str or ``NoneType``
:returns: Default project if it can be determined.
"""
if project is None:
_, project = google.auth.default()
return project
def _millis(when):
"""Convert a zone-aware datetime to integer milliseconds.
:type when: :class:`datetime.datetime`
:param when: the datetime to convert
:rtype: int
:returns: milliseconds since epoch for ``when``
"""
micros = _microseconds_from_datetime(when)
return micros // 1000
def _datetime_from_microseconds(value):
"""Convert timestamp to datetime, assuming UTC.
:type value: float
:param value: The timestamp to convert
:rtype: :class:`datetime.datetime`
:returns: The datetime object created from the value.
"""
return _EPOCH + datetime.timedelta(microseconds=value)
def _microseconds_from_datetime(value):
"""Convert non-none datetime to microseconds.
:type value: :class:`datetime.datetime`
:param value: The timestamp to convert.
:rtype: int
:returns: The timestamp, in microseconds.
"""
if not value.tzinfo:
value = value.replace(tzinfo=UTC)
# Regardless of what timezone is on the value, convert it to UTC.
value = value.astimezone(UTC)
# Convert the datetime to a microsecond timestamp.
return int(calendar.timegm(value.timetuple()) * 1e6) + value.microsecond
def _millis_from_datetime(value):
"""Convert non-none datetime to timestamp, assuming UTC.
:type value: :class:`datetime.datetime`
:param value: (Optional) the timestamp
:rtype: int, or ``NoneType``
:returns: the timestamp, in milliseconds, or None
"""
if value is not None:
return _millis(value)
def _date_from_iso8601_date(value):
"""Convert a ISO8601 date string to native datetime date
:type value: str
:param value: The date string to convert
:rtype: :class:`datetime.date`
:returns: A datetime date object created from the string
"""
return datetime.datetime.strptime(value, "%Y-%m-%d").date()
def _time_from_iso8601_time_naive(value):
"""Convert a zoneless ISO8601 time string to naive datetime time
:type value: str
:param value: The time string to convert
:rtype: :class:`datetime.time`
:returns: A datetime time object created from the string
:raises ValueError: if the value does not match a known format.
"""
if len(value) == 8: # HH:MM:SS
fmt = _TIMEONLY_NO_FRACTION
elif len(value) == 15: # HH:MM:SS.micros
fmt = _TIMEONLY_W_MICROS
else:
raise ValueError("Unknown time format: {}".format(value))
return datetime.datetime.strptime(value, fmt).time()
def _rfc3339_to_datetime(dt_str):
"""Convert a microsecond-precision timestamp to a native datetime.
:type dt_str: str
:param dt_str: The string to convert.
:rtype: :class:`datetime.datetime`
:returns: The datetime object created from the string.
"""
return datetime.datetime.strptime(dt_str, _RFC3339_MICROS).replace(tzinfo=UTC)
def _rfc3339_nanos_to_datetime(dt_str):
"""Convert a nanosecond-precision timestamp to a native datetime.
.. note::
Python datetimes do not support nanosecond precision; this function
therefore truncates such values to microseconds.
:type dt_str: str
:param dt_str: The string to convert.
:rtype: :class:`datetime.datetime`
:returns: The datetime object created from the string.
:raises ValueError: If the timestamp does not match the RFC 3339
regular expression.
"""
with_nanos = _RFC3339_NANOS.match(dt_str)
if with_nanos is None:
raise ValueError(
"Timestamp: %r, does not match pattern: %r"
% (dt_str, _RFC3339_NANOS.pattern)
)
bare_seconds = datetime.datetime.strptime(
with_nanos.group("no_fraction"), _RFC3339_NO_FRACTION
)
fraction = with_nanos.group("nanos")
if fraction is None:
micros = 0
else:
scale = 9 - len(fraction)
nanos = int(fraction) * (10**scale)
micros = nanos // 1000
return bare_seconds.replace(microsecond=micros, tzinfo=UTC)
def _datetime_to_rfc3339(value, ignore_zone=True):
"""Convert a timestamp to a string.
:type value: :class:`datetime.datetime`
:param value: The datetime object to be converted to a string.
:type ignore_zone: bool
:param ignore_zone: If True, then the timezone (if any) of the datetime
object is ignored.
:rtype: str
:returns: The string representing the datetime stamp.
"""
if not ignore_zone and value.tzinfo is not None:
# Convert to UTC and remove the time zone info.
value = value.replace(tzinfo=None) - value.utcoffset()
return value.strftime(_RFC3339_MICROS)
def _to_bytes(value, encoding="ascii"):
"""Converts a string value to bytes, if necessary.
:type value: str / bytes or unicode
:param value: The string/bytes value to be converted.
:type encoding: str
:param encoding: The encoding to use to convert unicode to bytes. Defaults
to "ascii", which will not allow any characters from
ordinals larger than 127. Other useful values are
"latin-1", which which will only allows byte ordinals
(up to 255) and "utf-8", which will encode any unicode
that needs to be.
:rtype: str / bytes
:returns: The original value converted to bytes (if unicode) or as passed
in if it started out as bytes.
:raises TypeError: if the value could not be converted to bytes.
"""
result = value.encode(encoding) if isinstance(value, str) else value
if isinstance(result, bytes):
return result
else:
raise TypeError("%r could not be converted to bytes" % (value,))
def _bytes_to_unicode(value):
"""Converts bytes to a unicode value, if necessary.
:type value: bytes
:param value: bytes value to attempt string conversion on.
:rtype: str
:returns: The original value converted to unicode (if bytes) or as passed
in if it started out as unicode.
:raises ValueError: if the value could not be converted to unicode.
"""
result = value.decode("utf-8") if isinstance(value, bytes) else value
if isinstance(result, str):
return result
else:
raise ValueError("%r could not be converted to unicode" % (value,))
def _from_any_pb(pb_type, any_pb):
"""Converts an Any protobuf to the specified message type
Args:
pb_type (type): the type of the message that any_pb stores an instance
of.
any_pb (google.protobuf.any_pb2.Any): the object to be converted.
Returns:
pb_type: An instance of the pb_type message.
Raises:
TypeError: if the message could not be converted.
"""
msg = pb_type()
if not any_pb.Unpack(msg):
raise TypeError(
"Could not convert {} to {}".format(
any_pb.__class__.__name__, pb_type.__name__
)
)
return msg
def _pb_timestamp_to_datetime(timestamp_pb):
"""Convert a Timestamp protobuf to a datetime object.
:type timestamp_pb: :class:`google.protobuf.timestamp_pb2.Timestamp`
:param timestamp_pb: A Google returned timestamp protobuf.
:rtype: :class:`datetime.datetime`
:returns: A UTC datetime object converted from a protobuf timestamp.
"""
return _EPOCH + datetime.timedelta(
seconds=timestamp_pb.seconds, microseconds=(timestamp_pb.nanos / 1000.0)
)
def _pb_timestamp_to_rfc3339(timestamp_pb):
"""Convert a Timestamp protobuf to an RFC 3339 string.
:type timestamp_pb: :class:`google.protobuf.timestamp_pb2.Timestamp`
:param timestamp_pb: A Google returned timestamp protobuf.
:rtype: str
:returns: An RFC 3339 formatted timestamp string.
"""
timestamp = _pb_timestamp_to_datetime(timestamp_pb)
return _datetime_to_rfc3339(timestamp)
def _datetime_to_pb_timestamp(when):
"""Convert a datetime object to a Timestamp protobuf.
:type when: :class:`datetime.datetime`
:param when: the datetime to convert
:rtype: :class:`google.protobuf.timestamp_pb2.Timestamp`
:returns: A timestamp protobuf corresponding to the object.
"""
ms_value = _microseconds_from_datetime(when)
seconds, micros = divmod(ms_value, 10**6)
nanos = micros * 10**3
return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
def _timedelta_to_duration_pb(timedelta_val):
"""Convert a Python timedelta object to a duration protobuf.
.. note::
The Python timedelta has a granularity of microseconds while
the protobuf duration type has a granularity of nanoseconds.
:type timedelta_val: :class:`datetime.timedelta`
:param timedelta_val: A timedelta object.
:rtype: :class:`google.protobuf.duration_pb2.Duration`
:returns: A duration object equivalent to the time delta.
"""
duration_pb = duration_pb2.Duration()
duration_pb.FromTimedelta(timedelta_val)
return duration_pb
def _duration_pb_to_timedelta(duration_pb):
"""Convert a duration protobuf to a Python timedelta object.
.. note::
The Python timedelta has a granularity of microseconds while
the protobuf duration type has a granularity of nanoseconds.
:type duration_pb: :class:`google.protobuf.duration_pb2.Duration`
:param duration_pb: A protobuf duration object.
:rtype: :class:`datetime.timedelta`
:returns: The converted timedelta object.
"""
return datetime.timedelta(
seconds=duration_pb.seconds, microseconds=(duration_pb.nanos / 1000.0)
)
def _name_from_project_path(path, project, template):
"""Validate a URI path and get the leaf object's name.
:type path: str
:param path: URI path containing the name.
:type project: str
:param project: (Optional) The project associated with the request. It is
included for validation purposes. If passed as None,
disables validation.
:type template: str
:param template: Template regex describing the expected form of the path.
The regex must have two named groups, 'project' and
'name'.
:rtype: str
:returns: Name parsed from ``path``.
:raises ValueError: if the ``path`` is ill-formed or if the project from
the ``path`` does not agree with the ``project``
passed in.
"""
if isinstance(template, str):
template = re.compile(template)
match = template.match(path)
if not match:
raise ValueError(
'path "%s" did not match expected pattern "%s"' % (path, template.pattern)
)
if project is not None:
found_project = match.group("project")
if found_project != project:
raise ValueError(
"Project from client (%s) should agree with "
"project from resource(%s)." % (project, found_project)
)
return match.group("name")
def make_secure_channel(credentials, user_agent, host, extra_options=()):
"""Makes a secure channel for an RPC service.
Uses / depends on gRPC.
:type credentials: :class:`google.auth.credentials.Credentials`
:param credentials: The OAuth2 Credentials to use for creating
access tokens.
:type user_agent: str
:param user_agent: The user agent to be used with API requests.
:type host: str
:param host: The host for the service.
:type extra_options: tuple
:param extra_options: (Optional) Extra gRPC options used when creating the
channel.
:rtype: :class:`grpc._channel.Channel`
:returns: gRPC secure channel with credentials attached.
"""
target = "%s:%d" % (host, http.client.HTTPS_PORT)
http_request = google.auth.transport.requests.Request()
user_agent_option = ("grpc.primary_user_agent", user_agent)
options = (user_agent_option,) + extra_options
return google.auth.transport.grpc.secure_authorized_channel(
credentials, http_request, target, options=options
)
def make_secure_stub(credentials, user_agent, stub_class, host, extra_options=()):
"""Makes a secure stub for an RPC service.
Uses / depends on gRPC.
:type credentials: :class:`google.auth.credentials.Credentials`
:param credentials: The OAuth2 Credentials to use for creating
access tokens.
:type user_agent: str
:param user_agent: The user agent to be used with API requests.
:type stub_class: type
:param stub_class: A gRPC stub type for a given service.
:type host: str
:param host: The host for the service.
:type extra_options: tuple
:param extra_options: (Optional) Extra gRPC options passed when creating
the channel.
:rtype: object, instance of ``stub_class``
:returns: The stub object used to make gRPC requests to a given API.
"""
channel = make_secure_channel(
credentials, user_agent, host, extra_options=extra_options
)
return stub_class(channel)
def make_insecure_stub(stub_class, host, port=None):
"""Makes an insecure stub for an RPC service.
Uses / depends on gRPC.
:type stub_class: type
:param stub_class: A gRPC stub type for a given service.
:type host: str
:param host: The host for the service. May also include the port
if ``port`` is unspecified.
:type port: int
:param port: (Optional) The port for the service.
:rtype: object, instance of ``stub_class``
:returns: The stub object used to make gRPC requests to a given API.
"""
if port is None:
target = host
else:
# NOTE: This assumes port != http.client.HTTPS_PORT:
target = "%s:%d" % (host, port)
channel = grpc.insecure_channel(target)
return stub_class(channel)
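
A minimal, purely illustrative sketch of the RFC 3339 helpers above in use (it assumes this module is importable as google.cloud._helpers, as in google-cloud-core):

import datetime

from google.cloud._helpers import (
    _datetime_to_rfc3339,
    _microseconds_from_datetime,
    _rfc3339_nanos_to_datetime,
)

# Nanosecond-precision input: trailing nanos are truncated to microseconds.
parsed = _rfc3339_nanos_to_datetime("2023-01-02T03:04:05.123456789Z")
assert parsed.microsecond == 123456

# Round-trip back to the microsecond-precision RFC 3339 form used above.
assert _datetime_to_rfc3339(parsed) == "2023-01-02T03:04:05.123456Z"

# A timezone-aware datetime converts to an integer microsecond timestamp.
aware = datetime.datetime(2023, 1, 2, 3, 4, 5, tzinfo=datetime.timezone.utc)
print(_microseconds_from_datetime(aware))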


@@ -0,0 +1,2 @@
# Marker file for PEP 561.
# This package uses inline types.


@@ -0,0 +1,499 @@
# Copyright 2014 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared implementation of connections to API servers."""
import collections
import collections.abc
import json
import os
import platform
from typing import Optional
from urllib.parse import urlencode
import warnings
from google.api_core.client_info import ClientInfo
from google.cloud import exceptions
from google.cloud import version
API_BASE_URL = "https://www.googleapis.com"
"""The base of the API call URL."""
DEFAULT_USER_AGENT = "gcloud-python/{0}".format(version.__version__)
"""The user agent for google-cloud-python requests."""
CLIENT_INFO_HEADER = "X-Goog-API-Client"
CLIENT_INFO_TEMPLATE = "gl-python/" + platform.python_version() + " gccl/{}"
_USER_AGENT_ALL_CAPS_DEPRECATED = """\
The 'USER_AGENT' class-level attribute is deprecated. Please use
'user_agent' instead.
"""
_EXTRA_HEADERS_ALL_CAPS_DEPRECATED = """\
The '_EXTRA_HEADERS' class-level attribute is deprecated. Please use
'extra_headers' instead.
"""
_DEFAULT_TIMEOUT = 60 # in seconds
class Connection(object):
"""A generic connection to Google Cloud Platform.
:type client: :class:`~google.cloud.client.Client`
:param client: The client that owns the current connection.
:type client_info: :class:`~google.api_core.client_info.ClientInfo`
:param client_info: (Optional) instance used to generate user agent.
"""
_user_agent = DEFAULT_USER_AGENT
def __init__(self, client, client_info=None):
self._client = client
if client_info is None:
client_info = ClientInfo()
self._client_info = client_info
self._extra_headers = {}
@property
def USER_AGENT(self):
"""Deprecated: get / set user agent sent by connection.
:rtype: str
:returns: user agent
"""
warnings.warn(_USER_AGENT_ALL_CAPS_DEPRECATED, DeprecationWarning, stacklevel=2)
return self.user_agent
@USER_AGENT.setter
def USER_AGENT(self, value):
warnings.warn(_USER_AGENT_ALL_CAPS_DEPRECATED, DeprecationWarning, stacklevel=2)
self.user_agent = value
@property
def user_agent(self):
"""Get / set user agent sent by connection.
:rtype: str
:returns: user agent
"""
return self._client_info.to_user_agent()
@user_agent.setter
def user_agent(self, value):
self._client_info.user_agent = value
@property
def _EXTRA_HEADERS(self):
"""Deprecated: get / set extra headers sent by connection.
:rtype: dict
:returns: header keys / values
"""
warnings.warn(
_EXTRA_HEADERS_ALL_CAPS_DEPRECATED, DeprecationWarning, stacklevel=2
)
return self.extra_headers
@_EXTRA_HEADERS.setter
def _EXTRA_HEADERS(self, value):
warnings.warn(
_EXTRA_HEADERS_ALL_CAPS_DEPRECATED, DeprecationWarning, stacklevel=2
)
self.extra_headers = value
@property
def extra_headers(self):
"""Get / set extra headers sent by connection.
:rtype: dict
:returns: header keys / values
"""
return self._extra_headers
@extra_headers.setter
def extra_headers(self, value):
self._extra_headers = value
@property
def credentials(self):
"""Getter for current credentials.
:rtype: :class:`google.auth.credentials.Credentials` or
:class:`NoneType`
:returns: The credentials object associated with this connection.
"""
return self._client._credentials
@property
def http(self):
"""A getter for the HTTP transport used in talking to the API.
Returns:
google.auth.transport.requests.AuthorizedSession:
A :class:`requests.Session` instance.
"""
return self._client._http
class JSONConnection(Connection):
"""A connection to a Google JSON-based API.
These APIs are discovery based. For reference:
https://developers.google.com/discovery/
This defines :meth:`api_request` for making a generic JSON
API request, while the concrete API requests are created elsewhere. The class constants
* :attr:`API_BASE_URL`
* :attr:`API_VERSION`
* :attr:`API_URL_TEMPLATE`
must be updated by subclasses.
"""
API_BASE_URL: Optional[str] = None
"""The base of the API call URL."""
API_BASE_MTLS_URL: Optional[str] = None
"""The base of the API call URL for mutual TLS."""
ALLOW_AUTO_SWITCH_TO_MTLS_URL = False
"""Indicates if auto switch to mTLS url is allowed."""
API_VERSION: Optional[str] = None
"""The version of the API, used in building the API call's URL."""
API_URL_TEMPLATE: Optional[str] = None
"""A template for the URL of a particular API call."""
def get_api_base_url_for_mtls(self, api_base_url=None):
"""Return the api base url for mutual TLS.
Typically, you shouldn't need to use this method.
The logic is as follows:
If `api_base_url` is provided, just return this value; otherwise, the
return value depends on the `GOOGLE_API_USE_MTLS_ENDPOINT` environment
variable value.
If the environment variable value is "always", return `API_BASE_MTLS_URL`.
If the environment variable value is "never", return `API_BASE_URL`.
Otherwise, if `ALLOW_AUTO_SWITCH_TO_MTLS_URL` is True and the underlying
http is mTLS, then return `API_BASE_MTLS_URL`; otherwise return `API_BASE_URL`.
:type api_base_url: str
:param api_base_url: User provided api base url. It takes precedence over
`API_BASE_URL` and `API_BASE_MTLS_URL`.
:rtype: str
:returns: The api base url used for mTLS.
"""
if api_base_url:
return api_base_url
env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
if env == "always":
url_to_use = self.API_BASE_MTLS_URL
elif env == "never":
url_to_use = self.API_BASE_URL
else:
if self.ALLOW_AUTO_SWITCH_TO_MTLS_URL:
url_to_use = (
self.API_BASE_MTLS_URL if self.http.is_mtls else self.API_BASE_URL
)
else:
url_to_use = self.API_BASE_URL
return url_to_use
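# Illustrative decision table for get_api_base_url_for_mtls above (editorial
# note, not part of the original module). Assuming no api_base_url override
# is passed and both base URLs are configured:
#
#   GOOGLE_API_USE_MTLS_ENDPOINT  ALLOW_AUTO_SWITCH  http.is_mtls  result
#   "always"                      any                any           API_BASE_MTLS_URL
#   "never"                       any                any           API_BASE_URL
#   anything else ("auto")        True               True          API_BASE_MTLS_URL
#   anything else ("auto")        True               False         API_BASE_URL
#   anything else ("auto")        False              any           API_BASE_URL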
def build_api_url(
self, path, query_params=None, api_base_url=None, api_version=None
):
"""Construct an API url given a few components, some optional.
Typically, you shouldn't need to use this method.
:type path: str
:param path: The path to the resource (ie, ``'/b/bucket-name'``).
:type query_params: dict or list
:param query_params: A dictionary of keys and values (or list of
key-value pairs) to insert into the query
string of the URL.
:type api_base_url: str
:param api_base_url: The base URL for the API endpoint.
Typically you won't have to provide this.
:type api_version: str
:param api_version: The version of the API to call.
Typically you shouldn't provide this and instead
use the default for the library.
:rtype: str
:returns: The URL assembled from the pieces provided.
"""
url = self.API_URL_TEMPLATE.format(
api_base_url=self.get_api_base_url_for_mtls(api_base_url),
api_version=(api_version or self.API_VERSION),
path=path,
)
query_params = query_params or {}
if isinstance(query_params, collections.abc.Mapping):
query_params = query_params.copy()
else:
query_params_dict = collections.defaultdict(list)
for key, value in query_params:
query_params_dict[key].append(value)
query_params = query_params_dict
query_params.setdefault("prettyPrint", "false")
url += "?" + urlencode(query_params, doseq=True)
return url
def _make_request(
self,
method,
url,
data=None,
content_type=None,
headers=None,
target_object=None,
timeout=_DEFAULT_TIMEOUT,
extra_api_info=None,
):
"""A low level method to send a request to the API.
Typically, you shouldn't need to use this method.
:type method: str
:param method: The HTTP method to use in the request.
:type url: str
:param url: The URL to send the request to.
:type data: str
:param data: The data to send as the body of the request.
:type content_type: str
:param content_type: The proper MIME type of the data provided.
:type headers: dict
:param headers: (Optional) A dictionary of HTTP headers to send with
the request. If passed, will be modified directly
here with added headers.
:type target_object: object
:param target_object:
(Optional) Argument to be used by library callers. This can allow
custom behavior, for example, to defer an HTTP request and complete
initialization of the object at a later time.
:type timeout: float or tuple
:param timeout: (optional) The amount of time, in seconds, to wait
for the server response.
Can also be passed as a tuple (connect_timeout, read_timeout).
See :meth:`requests.Session.request` documentation for details.
:type extra_api_info: string
:param extra_api_info: (optional) Extra api info to be appended to
the X-Goog-API-Client header
:rtype: :class:`requests.Response`
:returns: The HTTP response.
"""
headers = headers or {}
headers.update(self.extra_headers)
headers["Accept-Encoding"] = "gzip"
if content_type:
headers["Content-Type"] = content_type
if extra_api_info:
headers[CLIENT_INFO_HEADER] = f"{self.user_agent} {extra_api_info}"
else:
headers[CLIENT_INFO_HEADER] = self.user_agent
headers["User-Agent"] = self.user_agent
return self._do_request(
method, url, headers, data, target_object, timeout=timeout
)
def _do_request(
self, method, url, headers, data, target_object, timeout=_DEFAULT_TIMEOUT
): # pylint: disable=unused-argument
"""Low-level helper: perform the actual API request over HTTP.
Allows batch context managers to override and defer a request.
:type method: str
:param method: The HTTP method to use in the request.
:type url: str
:param url: The URL to send the request to.
:type headers: dict
:param headers: A dictionary of HTTP headers to send with the request.
:type data: str
:param data: The data to send as the body of the request.
:type target_object: object
:param target_object:
(Optional) Unused ``target_object`` here but may be used by a
superclass.
:type timeout: float or tuple
:param timeout: (optional) The amount of time, in seconds, to wait
for the server response.
Can also be passed as a tuple (connect_timeout, read_timeout).
See :meth:`requests.Session.request` documentation for details.
:rtype: :class:`requests.Response`
:returns: The HTTP response.
"""
return self.http.request(
url=url, method=method, headers=headers, data=data, timeout=timeout
)
def api_request(
self,
method,
path,
query_params=None,
data=None,
content_type=None,
headers=None,
api_base_url=None,
api_version=None,
expect_json=True,
_target_object=None,
timeout=_DEFAULT_TIMEOUT,
extra_api_info=None,
):
"""Make a request over the HTTP transport to the API.
You shouldn't need to use this method, but if you plan to
interact with the API using these primitives, this is the
correct one to use.
:type method: str
:param method: The HTTP method name (ie, ``GET``, ``POST``, etc).
Required.
:type path: str
:param path: The path to the resource (ie, ``'/b/bucket-name'``).
Required.
:type query_params: dict or list
:param query_params: A dictionary of keys and values (or list of
key-value pairs) to insert into the query
string of the URL.
:type data: str
:param data: The data to send as the body of the request. Default is
the empty string.
:type content_type: str
:param content_type: The proper MIME type of the data provided. Default
is None.
:type headers: dict
:param headers: extra HTTP headers to be sent with the request.
:type api_base_url: str
:param api_base_url: The base URL for the API endpoint.
Typically you won't have to provide this.
Default is the standard API base URL.
:type api_version: str
:param api_version: The version of the API to call. Typically
you shouldn't provide this and instead use
the default for the library. Default is the
latest API version supported by
google-cloud-python.
:type expect_json: bool
:param expect_json: If True, this method will try to parse the
response as JSON and raise an exception if
that cannot be done. Default is True.
:type _target_object: :class:`object`
:param _target_object:
(Optional) Protected argument to be used by library callers. This
can allow custom behavior, for example, to defer an HTTP request
and complete initialization of the object at a later time.
:type timeout: float or tuple
:param timeout: (optional) The amount of time, in seconds, to wait
for the server response.
Can also be passed as a tuple (connect_timeout, read_timeout).
See :meth:`requests.Session.request` documentation for details.
:type extra_api_info: string
:param extra_api_info: (optional) Extra api info to be appended to
the X-Goog-API-Client header
:raises ~google.cloud.exceptions.GoogleCloudError: if the response code
is not 200 OK.
:raises ValueError: if the response content type is not JSON.
:rtype: dict or str
:returns: The API response payload, either as a raw string or
a dictionary if the response is valid JSON.
"""
url = self.build_api_url(
path=path,
query_params=query_params,
api_base_url=api_base_url,
api_version=api_version,
)
# Making the executive decision that any dictionary
# data will be sent properly as JSON.
if data and isinstance(data, dict):
data = json.dumps(data)
content_type = "application/json"
response = self._make_request(
method=method,
url=url,
data=data,
content_type=content_type,
headers=headers,
target_object=_target_object,
timeout=timeout,
extra_api_info=extra_api_info,
)
if not 200 <= response.status_code < 300:
raise exceptions.from_http_response(response)
if expect_json and response.content:
return response.json()
else:
return response.content
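
For orientation, a hypothetical subclass sketch showing how the JSONConnection class constants above are typically filled in (the class name, host, and URL template are invented for illustration; it assumes the module is importable as google.cloud._http, as in google-cloud-core):

from google.cloud import _http as cloud_http


class _WidgetsConnection(cloud_http.JSONConnection):
    """Hypothetical connection to a JSON API named 'widgets'."""

    API_BASE_URL = "https://widgets.googleapis.com"  # invented endpoint
    API_VERSION = "v1"
    API_URL_TEMPLATE = "{api_base_url}/widgets/{api_version}{path}"


# Given a client carrying credentials and an authorized HTTP session:
# conn = _WidgetsConnection(client)
# payload = conn.api_request(method="GET", path="/items/item-name")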


@@ -0,0 +1,2 @@
# Marker file for PEP 561.
# This package uses inline types.


@@ -0,0 +1,121 @@
# Copyright 2014 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared testing utilities."""
from __future__ import absolute_import
class _Monkey(object):
"""Context-manager for replacing module names in the scope of a test."""
def __init__(self, module, **kw):
self.module = module
if not kw: # pragma: NO COVER
raise ValueError("_Monkey was used with nothing to monkey-patch")
self.to_restore = {key: getattr(module, key) for key in kw}
for key, value in kw.items():
setattr(module, key, value)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
for key, value in self.to_restore.items():
setattr(self.module, key, value)
class _NamedTemporaryFile(object):
def __init__(self, suffix=""):
import os
import tempfile
filehandle, self.name = tempfile.mkstemp(suffix=suffix)
os.close(filehandle)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
import os
os.remove(self.name)
def _tempdir_maker():
import contextlib
import shutil
import tempfile
@contextlib.contextmanager
def _tempdir_mgr():
temp_dir = tempfile.mkdtemp()
yield temp_dir
shutil.rmtree(temp_dir)
return _tempdir_mgr
# pylint: disable=invalid-name
# Retain _tempdir as a constant for backwards compatibility despite
# being an invalid name.
_tempdir = _tempdir_maker()
del _tempdir_maker
# pylint: enable=invalid-name
class _GAXBaseAPI(object):
_random_gax_error = False
def __init__(self, **kw):
self.__dict__.update(kw)
@staticmethod
def _make_grpc_error(status_code, trailing=None):
from grpc._channel import _RPCState
from google.cloud.exceptions import GrpcRendezvous
details = "Some error details."
exc_state = _RPCState((), None, trailing, status_code, details)
return GrpcRendezvous(exc_state, None, None, None)
def _make_grpc_not_found(self):
from grpc import StatusCode
return self._make_grpc_error(StatusCode.NOT_FOUND)
def _make_grpc_failed_precondition(self):
from grpc import StatusCode
return self._make_grpc_error(StatusCode.FAILED_PRECONDITION)
def _make_grpc_already_exists(self):
from grpc import StatusCode
return self._make_grpc_error(StatusCode.ALREADY_EXISTS)
def _make_grpc_deadline_exceeded(self):
from grpc import StatusCode
return self._make_grpc_error(StatusCode.DEADLINE_EXCEEDED)
class _GAXPageIterator(object):
def __init__(self, *pages, **kwargs):
self._pages = iter(pages)
self.page_token = kwargs.get("page_token")
def __next__(self):
"""Iterate to the next page."""
return next(self._pages)
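
An illustrative test using the _Monkey helper above (the patched module and the test are made up; it assumes the helpers are importable as google.cloud._testing):

import math

from google.cloud._testing import _Monkey


def _fake_sqrt(value):
    # Deliberately wrong, so the patch is easy to observe.
    return 42.0


def test_sqrt_is_patched_and_restored():
    with _Monkey(math, sqrt=_fake_sqrt):
        assert math.sqrt(9) == 42.0
    # On exit, the original attribute is restored.
    assert math.sqrt(9) == 3.0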


@@ -0,0 +1,2 @@
# Marker file for PEP 561.
# This package uses inline types.


@@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform import version as aiplatform_version
__version__ = aiplatform_version.__version__
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.datasets import (
ImageDataset,
TabularDataset,
TextDataset,
TimeSeriesDataset,
VideoDataset,
)
from google.cloud.aiplatform import explain
from google.cloud.aiplatform import gapic
from google.cloud.aiplatform import hyperparameter_tuning
from google.cloud.aiplatform.featurestore import (
EntityType,
Feature,
Featurestore,
)
from google.cloud.aiplatform.matching_engine import (
MatchingEngineIndex,
MatchingEngineIndexEndpoint,
)
from google.cloud.aiplatform import metadata
from google.cloud.aiplatform.tensorboard import uploader_tracker
from google.cloud.aiplatform.models import DeploymentResourcePool
from google.cloud.aiplatform.models import Endpoint
from google.cloud.aiplatform.models import PrivateEndpoint
from google.cloud.aiplatform.models import Model
from google.cloud.aiplatform.models import ModelRegistry
from google.cloud.aiplatform.model_evaluation import ModelEvaluation
from google.cloud.aiplatform.jobs import (
BatchPredictionJob,
CustomJob,
HyperparameterTuningJob,
ModelDeploymentMonitoringJob,
)
from google.cloud.aiplatform.pipeline_jobs import PipelineJob
from google.cloud.aiplatform.pipeline_job_schedules import (
PipelineJobSchedule,
)
from google.cloud.aiplatform.tensorboard import (
Tensorboard,
TensorboardExperiment,
TensorboardRun,
TensorboardTimeSeries,
)
from google.cloud.aiplatform.training_jobs import (
CustomTrainingJob,
CustomContainerTrainingJob,
CustomPythonPackageTrainingJob,
AutoMLTabularTrainingJob,
AutoMLForecastingTrainingJob,
SequenceToSequencePlusForecastingTrainingJob,
TemporalFusionTransformerForecastingTrainingJob,
TimeSeriesDenseEncoderForecastingTrainingJob,
AutoMLImageTrainingJob,
AutoMLTextTrainingJob,
AutoMLVideoTrainingJob,
)
from google.cloud.aiplatform import helpers
"""
Usage:
from google.cloud import aiplatform
aiplatform.init(project='my_project')
"""
init = initializer.global_config.init
get_pipeline_df = metadata.metadata._LegacyExperimentService.get_pipeline_df
log_params = metadata.metadata._experiment_tracker.log_params
log_metrics = metadata.metadata._experiment_tracker.log_metrics
log_classification_metrics = (
metadata.metadata._experiment_tracker.log_classification_metrics
)
log_model = metadata.metadata._experiment_tracker.log_model
get_experiment_df = metadata.metadata._experiment_tracker.get_experiment_df
start_run = metadata.metadata._experiment_tracker.start_run
autolog = metadata.metadata._experiment_tracker.autolog
start_execution = metadata.metadata._experiment_tracker.start_execution
log = metadata.metadata._experiment_tracker.log
log_time_series_metrics = metadata.metadata._experiment_tracker.log_time_series_metrics
end_run = metadata.metadata._experiment_tracker.end_run
upload_tb_log = uploader_tracker._tensorboard_tracker.upload_tb_log
start_upload_tb_log = uploader_tracker._tensorboard_tracker.start_upload_tb_log
end_upload_tb_log = uploader_tracker._tensorboard_tracker.end_upload_tb_log
save_model = metadata._models.save_model
get_experiment_model = metadata.schema.google.artifact_schema.ExperimentModel.get
Experiment = metadata.experiment_resources.Experiment
ExperimentRun = metadata.experiment_run_resource.ExperimentRun
Artifact = metadata.artifact.Artifact
Execution = metadata.execution.Execution
Context = metadata.context.Context
__all__ = (
"end_run",
"explain",
"gapic",
"init",
"helpers",
"hyperparameter_tuning",
"log",
"log_params",
"log_metrics",
"log_classification_metrics",
"log_model",
"log_time_series_metrics",
"get_experiment_df",
"get_pipeline_df",
"start_run",
"start_execution",
"save_model",
"get_experiment_model",
"autolog",
"upload_tb_log",
"start_upload_tb_log",
"end_upload_tb_log",
"Artifact",
"AutoMLImageTrainingJob",
"AutoMLTabularTrainingJob",
"AutoMLForecastingTrainingJob",
"AutoMLTextTrainingJob",
"AutoMLVideoTrainingJob",
"BatchPredictionJob",
"CustomJob",
"CustomTrainingJob",
"CustomContainerTrainingJob",
"CustomPythonPackageTrainingJob",
"DeploymentResourcePool",
"Endpoint",
"EntityType",
"Execution",
"Experiment",
"ExperimentRun",
"Feature",
"Featurestore",
"MatchingEngineIndex",
"MatchingEngineIndexEndpoint",
"ImageDataset",
"HyperparameterTuningJob",
"Model",
"ModelRegistry",
"ModelEvaluation",
"ModelDeploymentMonitoringJob",
"PipelineJob",
"PipelineJobSchedule",
"PrivateEndpoint",
"SequenceToSequencePlusForecastingTrainingJob",
"TabularDataset",
"Tensorboard",
"TensorboardExperiment",
"TensorboardRun",
"TensorboardTimeSeries",
"TextDataset",
"TemporalFusionTransformerForecastingTrainingJob",
"TimeSeriesDataset",
"TimeSeriesDenseEncoderForecastingTrainingJob",
"VideoDataset",
)
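
A short sketch of the experiment-tracking surface re-exported above (project, location, experiment, and run names are placeholders):

from google.cloud import aiplatform

aiplatform.init(
    project="my-project",        # placeholder project ID
    location="us-central1",
    experiment="my-experiment",  # placeholder experiment name
)

aiplatform.start_run("run-1")
aiplatform.log_params({"learning_rate": 0.01, "epochs": 10})
aiplatform.log_metrics({"accuracy": 0.95})
aiplatform.end_run()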


@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


@@ -0,0 +1,471 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import defaultdict
from typing import Any, Dict, List, NamedTuple, Optional, Union
from mlflow import entities as mlflow_entities
from mlflow.store.tracking import abstract_store
from mlflow import exceptions as mlflow_exceptions
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import execution as execution_v1
_LOGGER = base.Logger(__name__)
# MLFlow RunStatus:
# https://www.mlflow.org/docs/latest/python_api/mlflow.entities.html#mlflow.entities.RunStatus
_MLFLOW_RUN_TO_VERTEX_RUN_STATUS = {
mlflow_entities.RunStatus.FINISHED: execution_v1.Execution.State.COMPLETE,
mlflow_entities.RunStatus.FAILED: execution_v1.Execution.State.FAILED,
mlflow_entities.RunStatus.RUNNING: execution_v1.Execution.State.RUNNING,
mlflow_entities.RunStatus.KILLED: execution_v1.Execution.State.CANCELLED,
mlflow_entities.RunStatus.SCHEDULED: execution_v1.Execution.State.NEW,
}
mlflow_to_vertex_run_default = defaultdict(
lambda: execution_v1.Execution.State.STATE_UNSPECIFIED
)
for mlflow_status in _MLFLOW_RUN_TO_VERTEX_RUN_STATUS:
mlflow_to_vertex_run_default[mlflow_status] = _MLFLOW_RUN_TO_VERTEX_RUN_STATUS[
mlflow_status
]
# Mapping of Vertex run status to MLFlow run status (inverse of _MLFLOW_RUN_TO_VERTEX_RUN_STATUS)
_VERTEX_RUN_TO_MLFLOW_RUN_STATUS = {
v: k for k, v in _MLFLOW_RUN_TO_VERTEX_RUN_STATUS.items()
}
vertex_run_to_mflow_default = defaultdict(lambda: mlflow_entities.RunStatus.FAILED)
for vertex_status in _VERTEX_RUN_TO_MLFLOW_RUN_STATUS:
vertex_run_to_mflow_default[vertex_status] = _VERTEX_RUN_TO_MLFLOW_RUN_STATUS[
vertex_status
]
_MLFLOW_TERMINAL_RUN_STATES = [
mlflow_entities.RunStatus.FINISHED,
mlflow_entities.RunStatus.FAILED,
mlflow_entities.RunStatus.KILLED,
]
class _RunTracker(NamedTuple):
"""Tracks the current Vertex ExperimentRun.
Stores the current ExperimentRun the plugin is writing to and whether or
not this run is autocreated.
Attributes:
autocreate (bool):
Whether the Vertex ExperimentRun should be autocreated. If False,
the plugin writes to the currently active run created via
`aiplatform.start_run()`.
experiment_run (aiplatform.ExperimentRun):
The currently set ExperimentRun.
"""
autocreate: bool
experiment_run: "aiplatform.ExperimentRun"
class _VertexMlflowTracking(abstract_store.AbstractStore):
"""Vertex plugin implementation of MLFlow's AbstractStore class."""
def _to_mlflow_metric(
self,
vertex_metrics: Dict[str, Union[float, int, str]],
) -> Optional[List[mlflow_entities.Metric]]:
"""Helper method to convert Vertex metrics to mlflow.entities.Metric type.
Args:
vertex_metrics (Dict[str, Union[float, int, str]]):
Required. A dictionary of Vertex metrics returned from
ExperimentRun.get_metrics()
Returns:
List[mlflow_entities.Metric] - A list of metrics converted to MLFlow's
Metric type.
"""
mlflow_metrics = []
if vertex_metrics:
for metric_key in vertex_metrics:
mlflow_metric = mlflow_entities.Metric(
key=metric_key,
value=vertex_metrics[metric_key],
step=0,
timestamp=0,
)
mlflow_metrics.append(mlflow_metric)
else:
return None
return mlflow_metrics
def _to_mlflow_params(
self, vertex_params: Dict[str, Union[float, int, str]]
) -> Optional[List[mlflow_entities.Param]]:
"""Helper method to convert Vertex params to mlflow.entities.Param type.
Args:
vertex_params (Dict[str, Union[float, int, str]]):
Required. A dictionary of Vertex params returned from
ExperimentRun.get_params()
Returns:
List[mlflow_entities.Param] - A list of params converted to MLFlow's
Param type.
"""
mlflow_params = []
if vertex_params:
for param_key in vertex_params:
mlflow_param = mlflow_entities.Param(
key=param_key, value=vertex_params[param_key]
)
mlflow_params.append(mlflow_param)
else:
return None
return mlflow_params
def _to_mlflow_entity(
self,
vertex_exp: "aiplatform.Experiment",
vertex_run: "aiplatform.ExperimentRun",
) -> mlflow_entities.Run:
"""Helper method to convert data to required MLFlow type.
This converts data into MLFlow's mlflow_entities.Run type, which is a
required return type for some methods we're overriding in this plugin.
Args:
vertex_exp (aiplatform.Experiment):
Required. The current Vertex Experiment.
vertex_run (aiplatform.ExperimentRun):
Required. The active Vertex ExperimentRun
Returns:
mlflow_entities.Run - The data from the currently active run
converted to MLFLow's mlflow_entities.Run type.
https://www.mlflow.org/docs/latest/python_api/mlflow.entities.html#mlflow.entities.Run
"""
run_info = mlflow_entities.RunInfo(
run_id=f"{vertex_exp.name}-{vertex_run.name}",
run_uuid=f"{vertex_exp.name}-{vertex_run.name}",
experiment_id=vertex_exp.name,
user_id="",
status=vertex_run_to_mflow_default[vertex_run.state],
start_time=1,
end_time=2,
lifecycle_stage=mlflow_entities.LifecycleStage.ACTIVE,
artifact_uri="file:///tmp/", # The plugin will fail if artifact_uri is not set to a valid filepath string
)
run_data = mlflow_entities.RunData(
metrics=self._to_mlflow_metric(vertex_run.get_metrics()),
params=self._to_mlflow_params(vertex_run.get_params()),
tags={},
)
return mlflow_entities.Run(run_info=run_info, run_data=run_data)
def __init__(self, store_uri: Optional[str], artifact_uri: Optional[str]) -> None:
"""Initializes the Vertex MLFlow plugin.
This plugin overrides MLFlow's AbstractStore class to write metrics and
parameters from model training code to Vertex Experiments. This plugin
is private and should not be instantiated outside the Vertex SDK.
The _run_map instance property is a dict mapping MLFlow run_id to an
instance of _RunTracker with data on the corresponding Vertex
ExperimentRun.
For example: {
'sklearn-12345': _RunTracker(autocreate=True, experiment_run=aiplatform.ExperimentRun(...))
}
Until autologging and Experiments support nested runs, _nested_run_tracker
is used to ensure the plugin shows a warning log exactly once every time it
encounters a model that produces nested runs, like sklearn GridSearchCV and
RandomizedSearchCV models. It is a mapping of parent_run_id to the number of
child runs for that parent. When exactly 1 child run is found, the warning
log is shown.
Args:
store_uri (str):
The tracking store uri used by MLFlow to write parameters and
metrics for a run. This plugin ignores store_uri since we are
writing data to Vertex Experiments. For this plugin, the value
of store_uri will always be `vertex-mlflow-plugin://`.
artifact_uri (str):
The artifact uri used by MLFlow to write artifacts generated by
a run. This plugin ignores artifact_uri since it doesn't write
any artifacts to Vertex.
"""
self._run_map = {}
self._vertex_experiment = None
self._nested_run_tracker = {}
super(_VertexMlflowTracking, self).__init__()
@property
def run_map(self) -> Dict[str, Any]:
return self._run_map
@property
def vertex_experiment(self) -> "aiplatform.Experiment":
return self._vertex_experiment
def create_run(
self,
experiment_id: str,
user_id: str,
start_time: str,
tags: List[mlflow_entities.RunTag],
run_name: str,
) -> mlflow_entities.Run:
"""Creates a new ExperimentRun in Vertex if no run is active.
This overrides the behavior of MLFlow's `create_run()` method to check
if there is a currently active ExperimentRun. If no ExperimentRun is
active, a new Vertex ExperimentRun will be created with the name
`<ml-framework>-<timestamp>`. If aiplatform.start_run() has been
invoked and there is an active run, no run will be created and the
currently active ExperimentRun will be returned as an MLFlow Run
entity.
Args:
experiment_id (str):
The ID of the currently set MLFlow Experiment. Not used by this
plugin.
user_id (str):
The ID of the MLFlow user. Not used by this plugin.
start_time (int):
The start time of the run, in milliseconds since the UNIX
epoch. Not used by this plugin.
tags (List[mlflow_entities.RunTag]):
The tags provided by MLFlow. Only the `mlflow.autologging` tag
is used by this plugin.
run_name (str):
The name of the MLFlow run. Not used by this plugin.
Returns:
mlflow_entities.Run - The created run returned as MLFLow's run
type.
Raises:
RuntimeError:
If a second model training call is made to a manually created
run created via `aiplatform.start_run()` that has already been
used to autolog metrics and parameters in this session.
"""
self._vertex_experiment = (
aiplatform.metadata.metadata._experiment_tracker.experiment
)
currently_active_run = (
aiplatform.metadata.metadata._experiment_tracker.experiment_run
)
parent_run_id = None
for tag in tags:
if tag.key == "mlflow.parentRunId" and tag.value is not None:
parent_run_id = tag.value
if parent_run_id in self._nested_run_tracker:
self._nested_run_tracker[parent_run_id] += 1
else:
self._nested_run_tracker[parent_run_id] = 1
_LOGGER.warning(
f"This model creates nested runs. No additional ExperimentRun resources will be created for nested runs, summary metrics and parameters will be logged to the parent ExperimentRun: {parent_run_id}."
)
if currently_active_run:
if (
f"{currently_active_run.resource_id}" in self._run_map
and not parent_run_id
):
_LOGGER.warning(
"Metrics and parameters have already been logged to this run. Call aiplatform.end_run() to end the current run before training a new model."
)
raise mlflow_exceptions.MlflowException(
"Metrics and parameters have already been logged to this run. Call aiplatform.end_run() to end the current run before training a new model."
)
elif not parent_run_id:
run_tracker = _RunTracker(
autocreate=False, experiment_run=currently_active_run
)
current_run_id = currently_active_run.name
# nested run case
else:
raise mlflow_exceptions.MlflowException(
f"This model creates nested runs. No additional ExperimentRun resources will be created for nested runs, summary metrics and parameters will be logged to the {parent_run_id}: ExperimentRun."
)
# Create a new run if aiplatform.start_run() hasn't been called
else:
framework = ""
for tag in tags:
if tag.key == "mlflow.autologging":
framework = tag.value
current_run_id = f"{framework}-{utils.timestamped_unique_name()}"
currently_active_run = aiplatform.start_run(run=current_run_id)
run_tracker = _RunTracker(
autocreate=True, experiment_run=currently_active_run
)
self._run_map[currently_active_run.resource_id] = run_tracker
return self._to_mlflow_entity(
vertex_exp=self._vertex_experiment,
vertex_run=run_tracker.experiment_run,
)
def update_run_info(
self,
run_id: str,
run_status: mlflow_entities.RunStatus,
end_time: int,
run_name: str,
) -> mlflow_entities.RunInfo:
"""Updates the ExperimentRun status with the status provided by MLFlow.
Args:
run_id (str):
The ID of the currently set MLFlow run. This is mapped to the
corresponding ExperimentRun in self._run_map.
run_status (mlflow_entities.RunStatus):
The run status provided by MLFlow.
end_time (int):
The end time of the run. Not used by this plugin.
run_name (str):
The name of the MLFlow run. Not used by this plugin.
Returns:
mlflow_entities.RunInfo - Info about the updated run in MLFlow's
required RunInfo format.
"""
# The if block below does the following:
# - Ends autocreated ExperimentRuns when MLFlow returns a terminal RunStatus.
# - For other autocreated runs or runs where MLFlow returns a non-terminal
# RunStatus, this updates the ExperimentRun with the corresponding
# _MLFLOW_RUN_TO_VERTEX_RUN_STATUS.
# - Non-autocreated ExperimentRuns with a terminal status are not ended.
if (
self._run_map[run_id].autocreate
and run_status in _MLFLOW_TERMINAL_RUN_STATES
and self._run_map[run_id].experiment_run
is aiplatform.metadata.metadata._experiment_tracker.experiment_run
):
aiplatform.metadata.metadata._experiment_tracker.end_run(
state=execution_v1.Execution.State.COMPLETE
)
elif (
self._run_map[run_id].autocreate
or run_status not in _MLFLOW_TERMINAL_RUN_STATES
):
self._run_map[run_id].experiment_run.update_state(
state=mlflow_to_vertex_run_default[run_status]
)
return mlflow_entities.RunInfo(
run_uuid=run_id,
run_id=run_id,
status=run_status,
end_time=end_time,
experiment_id=self._vertex_experiment,
user_id="",
start_time=1,
lifecycle_stage=mlflow_entities.LifecycleStage.ACTIVE,
artifact_uri="file:///tmp/",
)
def log_batch(
self,
run_id: str,
metrics: List[mlflow_entities.Metric],
params: List[mlflow_entities.Param],
tags: List[mlflow_entities.RunTag],
) -> None:
"""The primary logging method used by MLFlow.
This plugin overrides this method to write the metrics and parameters
provided by MLFlow to the active Vertex ExperimentRun.
Args:
run_id (str):
The ID of the MLFlow run to write metrics to. This is mapped to
the corresponding ExperimentRun in self._run_map.
metrics (List[mlflow_entities.Metric]):
A list of MLFlow metrics generated from the current model
training run.
params (List[mlflow_entities.Param]):
A list of MLFlow params generated from the current model
training run.
tags (List[mlflow_entities.RunTag]):
The tags provided by MLFlow. Not used by this plugin.
"""
summary_metrics = {}
summary_params = {}
time_series_metrics = {}
# Get the run to write to
vertex_run = self._run_map[run_id].experiment_run
for metric in metrics:
if metric.step:
if metric.step not in time_series_metrics:
time_series_metrics[metric.step] = {metric.key: metric.value}
else:
time_series_metrics[metric.step][metric.key] = metric.value
else:
summary_metrics[metric.key] = metric.value
for param in params:
summary_params[param.key] = param.value
if summary_metrics:
vertex_run.log_metrics(metrics=summary_metrics)
if summary_params:
vertex_run.log_params(params=summary_params)
# TODO(b/261722623): batch these calls
if time_series_metrics:
for step in time_series_metrics:
vertex_run.log_time_series_metrics(time_series_metrics[step], step)
def get_run(self, run_id: str) -> mlflow_entities.Run:
"""Gets the currently active run.
Args:
run_id (str):
The ID of the currently set MLFlow run. This is mapped to the
corresponding ExperimentRun in self._run_map.
Returns:
mlflow_entities.Run - The currently active Vertex ExperimentRun,
returned as MLFLow's run type.
"""
return self._to_mlflow_entity(
vertex_exp=self._vertex_experiment,
vertex_run=self._run_map[run_id].experiment_run,
)
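
For context, a rough sketch of how this tracking-store plugin typically comes into play through autologging (project and experiment names are placeholders; the training code is elided):

from google.cloud import aiplatform

aiplatform.init(
    project="my-project",
    location="us-central1",
    experiment="my-experiment",
)

# autolog() enables MLFlow autologging backed by the Vertex tracking store
# above, so framework-captured metrics and params land in ExperimentRuns.
aiplatform.autolog()

# ... train a model with a supported framework (e.g. scikit-learn) here ...

# If a run was started manually with aiplatform.start_run(...), close it
# with aiplatform.end_run() once training is finished.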


@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform._pipeline_based_service.pipeline_based_service import (
_VertexAiPipelineBasedService,
)
__all__ = ("_VertexAiPipelineBasedService",)


@@ -0,0 +1,431 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import logging
from typing import (
Any,
Dict,
FrozenSet,
Optional,
List,
Tuple,
Union,
)
from google.auth import credentials as auth_credentials
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import pipeline_jobs
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import (
pipeline_state as gca_pipeline_state,
)
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
_PIPELINE_COMPLETE_STATES = pipeline_constants._PIPELINE_COMPLETE_STATES
class _VertexAiPipelineBasedService(base.VertexAiStatefulResource):
"""Base class for Vertex AI Pipeline based services."""
client_class = utils.PipelineJobClientWithOverride
_resource_noun = "pipelineJob"
_delete_method = "delete_pipeline_job"
_getter_method = "get_pipeline_job"
_list_method = "list_pipeline_jobs"
_parse_resource_name_method = "parse_pipeline_job_path"
_format_resource_name_method = "pipeline_job_path"
_valid_done_states = _PIPELINE_COMPLETE_STATES
@property
@classmethod
@abc.abstractmethod
def _template_ref(cls) -> FrozenSet[Tuple[str, str]]:
"""A dictionary of the pipeline template URLs for this service.
The key is an identifier for that template and the value is the url of
that pipeline template.
For example: {"tabular_classification": "gs://path/to/tabular/pipeline/template.json"}
"""
pass
@property
@classmethod
@abc.abstractmethod
def _creation_log_message(cls) -> str:
"""A log message to use when the Pipeline-based Service is created.
_VertexAiPipelineBasedService suppresses logs from PipelineJob creation
to avoid duplication.
For example: 'Created PipelineJob for your Model Evaluation.'
"""
pass
@property
@classmethod
@abc.abstractmethod
def _component_identifier(cls) -> str:
"""A 'component_type' value unique to this service's pipeline execution metadata.
This is an identifier used by the _validate_pipeline_template_matches_service method
to confirm the pipeline being instantiated belongs to this service. Use something
specific to your service's PipelineJob.
For example: 'fpc-model-evaluation'
"""
pass
@property
@classmethod
@abc.abstractmethod
def _template_name_identifier(cls) -> Optional[str]:
"""An optional name identifier for the pipeline template.
This will validate on the Pipeline's PipelineSpec.PipelineInfo.name
field. Setting this property will lead to an additional validation
check on pipeline templates in _does_pipeline_template_match_service.
If this property is present, the validation method will check for it
after validating on `_component_identifier`.
"""
pass
@classmethod
@abc.abstractmethod
    def submit(cls) -> "_VertexAiPipelineBasedService":
"""Subclasses should implement this method to submit the underlying PipelineJob."""
pass
# TODO (b/248582133): Consider updating this to return a list in the future to support multiple outputs
@property
@abc.abstractmethod
def _metadata_output_artifact(self) -> Optional[str]:
"""The ML Metadata output artifact resource URI from the completed pipeline run."""
pass
@property
def backing_pipeline_job(self) -> "pipeline_jobs.PipelineJob":
"""The PipelineJob associated with the resource."""
return pipeline_jobs.PipelineJob.get(resource_name=self.resource_name)
@property
def pipeline_console_uri(self) -> Optional[str]:
"""The console URI of the PipelineJob created by the service."""
if self.backing_pipeline_job:
return self.backing_pipeline_job._dashboard_uri()
@property
def state(self) -> Optional[gca_pipeline_state.PipelineState]:
"""The state of the Pipeline run associated with the service."""
if self.backing_pipeline_job:
return self.backing_pipeline_job.state
return None
@classmethod
def _does_pipeline_template_match_service(
cls, pipeline_job: "pipeline_jobs.PipelineJob"
) -> bool:
"""Checks whether the provided pipeline template matches the service.
Args:
pipeline_job (aiplatform.PipelineJob):
Required. The PipelineJob to validate with this Pipeline Based Service.
Returns:
Boolean indicating whether the provided template matches the
service it's trying to instantiate.
"""
valid_schema_titles = ["system.Run", "system.DagExecution"]
# We get the Execution here because we want to allow instantiating
# failed pipeline runs that match the service. The component_type is
# present in the Execution metadata for both failed and successful
# pipeline runs
for component in pipeline_job.task_details:
if not (
"name" in component.execution
and component.execution.schema_title in valid_schema_titles
):
continue
execution_resource = aiplatform.Execution.get(
component.execution.name, credentials=pipeline_job.credentials
)
# First validate on component_type
if (
"component_type" in execution_resource.metadata
and execution_resource.metadata.get("component_type")
== cls._component_identifier
):
# Then validate on _template_name_identifier if provided
if cls._template_name_identifier is None or (
pipeline_job.pipeline_spec is not None
and cls._template_name_identifier
== pipeline_job.pipeline_spec["pipelineInfo"]["name"]
):
return True
return False
# TODO (b/249153354): expose _template_ref in error message when artifact
# registry support is added
@classmethod
def _validate_pipeline_template_matches_service(
cls, pipeline_job: "pipeline_jobs.PipelineJob"
):
"""Validates the provided pipeline matches the template of the Pipeline Based Service.
Args:
pipeline_job (aiplatform.PipelineJob):
Required. The PipelineJob to validate with this Pipeline Based Service.
Raises:
            ValueError: if the provided pipeline template doesn't match this
                Pipeline Based Service.
"""
if not cls._does_pipeline_template_match_service(pipeline_job):
raise ValueError(
f"The provided pipeline template is not compatible with {cls.__name__}"
)
def __init__(
self,
pipeline_job_name: str,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Retrieves an existing Pipeline Based Service given the ID of the pipeline execution.
Args:
pipeline_job_name (str):
                Required. A fully-qualified PipelineJob resource name or ID.
Example: "projects/123/locations/us-central1/pipelineJobs/456" or
"456" when project and location are initialized or passed.
project (str):
Optional. Project to retrieve pipeline job from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve pipeline job from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve this pipeline job. Overrides
credentials set in aiplatform.init.
Raises:
ValueError: if the pipeline template used in this PipelineJob is not
consistent with the _template_ref defined on the subclass.
"""
super().__init__(
project=project,
location=location,
credentials=credentials,
resource_name=pipeline_job_name,
)
job_resource = pipeline_jobs.PipelineJob.get(
resource_name=pipeline_job_name, credentials=credentials
)
self._validate_pipeline_template_matches_service(job_resource)
self._gca_resource = job_resource._gca_resource
@classmethod
def _create_and_submit_pipeline_job(
cls,
template_params: Dict[str, Any],
template_path: str,
pipeline_root: Optional[str] = None,
display_name: Optional[str] = None,
job_id: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
encryption_spec_key_name: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
experiment: Optional[Union[str, "aiplatform.Experiment"]] = None,
enable_caching: Optional[bool] = None,
) -> "_VertexAiPipelineBasedService":
"""Create a new PipelineJob using the provided template and parameters.
Args:
template_params (Dict[str, Any]):
Required. The parameters to pass to the given pipeline template.
template_path (str):
Required. The path of the pipeline template to use for this
pipeline run.
pipeline_root (str):
Optional. The GCS directory to store the pipeline run output.
If not set, the bucket set in `aiplatform.init(staging_bucket=...)`
will be used.
display_name (str):
Optional. The user-defined name of the PipelineJob created by
this Pipeline Based Service.
job_id (str):
Optional. The unique ID of the job run.
If not specified, pipeline name + timestamp will be used.
service_account (str):
Specifies the service account for workload run-as account.
Users submitting jobs must have act-as permission on this run-as account.
network (str):
The full name of the Compute Engine network to which the job
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
encryption_spec_key_name (str):
Customer managed encryption key resource name.
project (str):
Optional. The project to run this PipelineJob in. If not set,
the project set in aiplatform.init will be used.
location (str):
Optional. Location to create PipelineJob. If not set,
location set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to create the PipelineJob.
Overrides credentials set in aiplatform.init.
            experiment (Union[str, aiplatform.Experiment]):
                Optional. The Vertex AI experiment name or instance to associate
                with the PipelineJob executing this service.
enable_caching (bool):
Optional. Whether to turn on caching for the run.
If this is not set, defaults to the compile time settings, which
are True for all tasks by default, while users may specify
different caching options for individual tasks.
If this is set, the setting applies to all tasks in the pipeline.
Overrides the compile time settings.
Returns:
            (_VertexAiPipelineBasedService):
Instantiated representation of a Vertex AI Pipeline based service.
"""
if not display_name:
display_name = cls._generate_display_name()
self = cls._empty_constructor(
project=project,
location=location,
credentials=credentials,
)
service_pipeline_job = pipeline_jobs.PipelineJob(
display_name=display_name,
template_path=template_path,
job_id=job_id,
pipeline_root=pipeline_root,
parameter_values=template_params,
encryption_spec_key_name=encryption_spec_key_name,
project=project,
location=location,
credentials=credentials,
enable_caching=enable_caching,
)
# Suppresses logs from PipelineJob
# The class implementing _VertexAiPipelineBasedService should define a
# custom log message via `_creation_log_message`
logging.getLogger("google.cloud.aiplatform.pipeline_jobs").setLevel(
logging.WARNING
)
service_pipeline_job.submit(
service_account=service_account,
network=network,
experiment=experiment,
)
logging.getLogger("google.cloud.aiplatform.pipeline_jobs").setLevel(
logging.INFO
)
self._gca_resource = service_pipeline_job.gca_resource
return self
@classmethod
def list(
cls,
project: Optional[str] = None,
location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
) -> List["_VertexAiPipelineBasedService"]:
"""Lists all PipelineJob resources associated with this Pipeline Based service.
Args:
project (str):
Optional. The project to retrieve the Pipeline Based Services from.
If not set, the project set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve the Pipeline Based Services from.
If not set, location set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve the Pipeline Based
Services from. Overrides credentials set in aiplatform.init.
Returns:
            (List[_VertexAiPipelineBasedService]):
                A list of Pipeline Based Service resources matching this
                service's component identifier.
"""
filter_str = f"metadata.component_type.string_value={cls._component_identifier}"
filtered_pipeline_executions = aiplatform.Execution.list(
filter=filter_str, credentials=credentials
)
service_pipeline_jobs = []
for pipeline_execution in filtered_pipeline_executions:
if "pipeline_job_resource_name" in pipeline_execution.metadata:
# This is wrapped in a try/except for cases when both
                # `_component_identifier` and `_template_name_identifier` are
# set. In that case, even though all pipelines returned by the
# Execution.list() call will match the `_component_identifier`,
# some may not match the `_template_name_identifier`
try:
service_pipeline_job = cls(
pipeline_execution.metadata["pipeline_job_resource_name"],
project=project,
location=location,
credentials=credentials,
)
service_pipeline_jobs.append(service_pipeline_job)
except ValueError:
continue
return service_pipeline_jobs
def wait(self):
"""Wait for the PipelineJob to complete."""
pipeline_run = self.backing_pipeline_job
if pipeline_run._latest_future is None:
pipeline_run._block_until_complete()
else:
pipeline_run.wait()
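# --- Illustrative sketch (not part of the library) ---------------------------
# A minimal example of how a concrete service might satisfy the abstract
# surface above. The class name, template reference, component identifier, and
# template path are hypothetical placeholders; real subclasses (such as the
# model evaluation service) define their own values and expose a richer
# `submit()` that forwards user-facing arguments to
# `_create_and_submit_pipeline_job`.
class _ExamplePipelineBasedService(_VertexAiPipelineBasedService):
    _template_ref = {"example": "gs://example-bucket/pipelines/template.json"}
    _creation_log_message = "Created PipelineJob for the example service."
    _component_identifier = "example-component"
    _template_name_identifier = None
    @property
    def _metadata_output_artifact(self) -> Optional[str]:
        # A real service would resolve this from the completed pipeline's
        # task details; None keeps the sketch minimal.
        return None
    @classmethod
    def submit(
        cls,
        template_params: Dict[str, Any],
        **kwargs,
    ) -> "_ExamplePipelineBasedService":
        # Delegates to the shared helper defined above.
        return cls._create_and_submit_pipeline_job(
            template_params=template_params,
            template_path=cls._template_ref["example"],
            **kwargs,
        )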

View File

@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from typing import Optional
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
class _PublisherModel(base.VertexAiResourceNoun):
"""Publisher Model Resource for Vertex AI."""
client_class = utils.ModelGardenClientWithOverride
_resource_noun = "publisher_model"
_getter_method = "get_publisher_model"
_delete_method = None
_parse_resource_name_method = "parse_publisher_model_path"
_format_resource_name_method = "publisher_model_path"
def __init__(
self,
resource_name: str,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Retrieves an existing PublisherModel resource given a resource name or model garden id.
Args:
resource_name (str):
Required. A fully-qualified PublisherModel resource name or
model garden id. Format:
`publishers/{publisher}/models/{publisher_model}` or
`{publisher}/{publisher_model}`.
project (str):
Optional. Project to retrieve the resource from. If not set,
project set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve the resource from. If not set,
location set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve the resource.
Overrides credentials set in aiplatform.init.
"""
super().__init__(project=project, location=location, credentials=credentials)
if self._parse_resource_name(resource_name):
full_resource_name = resource_name
else:
m = re.match(r"^(?P<publisher>.+?)/(?P<model>.+?)$", resource_name)
if m:
full_resource_name = self._format_resource_name(**m.groupdict())
else:
raise ValueError(
f"`{resource_name}` is not a valid PublisherModel resource "
"name or model garden id."
)
self._gca_resource = getattr(self.api_client, self._getter_method)(
name=full_resource_name, retry=base._DEFAULT_RETRY
)
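# --- Illustrative sketch (not part of the library) ----------------------------
# Demonstrates, in isolation, the name normalization the constructor above
# performs: a short "<publisher>/<model>" Model Garden id becomes the
# fully-qualified PublisherModel resource name. The ids shown are hypothetical.
def _example_normalize_publisher_model_name(resource_name: str) -> str:
    m = re.match(r"^(?P<publisher>.+?)/(?P<model>.+?)$", resource_name)
    if not m:
        raise ValueError(f"`{resource_name}` is not a valid model garden id.")
    return "publishers/{publisher}/models/{model}".format(**m.groupdict())
# _example_normalize_publisher_model_name("google/text-bison")
# -> "publishers/google/models/text-bison"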

View File

@@ -0,0 +1,252 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Streaming prediction functions."""
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence
from google.cloud.aiplatform_v1.services import prediction_service
from google.cloud.aiplatform_v1.types import (
prediction_service as prediction_service_types,
)
from google.cloud.aiplatform_v1.types import (
types as aiplatform_types,
)
def value_to_tensor(value: Any) -> aiplatform_types.Tensor:
"""Converts a Python value to `Tensor`.
Args:
value: A value to convert
Returns:
A `Tensor` object
"""
if value is None:
return aiplatform_types.Tensor()
    # Check bool before int: bool is a subclass of int, so booleans would
    # otherwise be encoded as int_val.
    elif isinstance(value, bool):
        return aiplatform_types.Tensor(bool_val=[value])
    elif isinstance(value, int):
        return aiplatform_types.Tensor(int_val=[value])
    elif isinstance(value, float):
        return aiplatform_types.Tensor(float_val=[value])
elif isinstance(value, str):
return aiplatform_types.Tensor(string_val=[value])
elif isinstance(value, bytes):
return aiplatform_types.Tensor(bytes_val=[value])
elif isinstance(value, list):
return aiplatform_types.Tensor(list_val=[value_to_tensor(x) for x in value])
elif isinstance(value, dict):
return aiplatform_types.Tensor(
struct_val={k: value_to_tensor(v) for k, v in value.items()}
)
raise TypeError(f"Unsupported value type {type(value)}")
def tensor_to_value(tensor_pb: aiplatform_types.Tensor) -> Any:
"""Converts `Tensor` to a Python value.
Args:
tensor_pb: A `Tensor` object
Returns:
A corresponding Python object
"""
list_of_fields = tensor_pb.ListFields()
if not list_of_fields:
return None
    descriptor, value = list_of_fields[0]
if descriptor.name == "list_val":
return [tensor_to_value(x) for x in value]
elif descriptor.name == "struct_val":
return {k: tensor_to_value(v) for k, v in value.items()}
if not isinstance(value, Sequence):
raise TypeError(f"Unexpected non-list tensor value {value}")
if len(value) == 1:
return value[0]
else:
return value
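# --- Illustrative sketch (not part of the library) ----------------------------
# Round-trips a small payload through the two converters above. Note that
# `Tensor.float_val` is a 32-bit field, so Python floats may lose precision on
# the way back; the sample payload sticks to lossless types.
def _example_tensor_round_trip() -> None:
    payload = {"prompt": "hello", "max_tokens": 32, "stop": ["\n", "###"]}
    tensor = value_to_tensor(payload)
    assert tensor_to_value(tensor._pb) == payload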
def predict_stream_of_tensor_lists_from_single_tensor_list(
prediction_service_client: prediction_service.PredictionServiceClient,
endpoint_name: str,
tensor_list: List[aiplatform_types.Tensor],
parameters_tensor: Optional[aiplatform_types.Tensor] = None,
) -> Iterator[List[aiplatform_types.Tensor]]:
"""Predicts a stream of lists of `Tensor` objects from a single list of `Tensor` objects.
Args:
tensor_list: Model input as a list of `Tensor` objects.
parameters_tensor: Optional. Prediction parameters in `Tensor` form.
prediction_service_client: A PredictionServiceClient object.
endpoint_name: Resource name of Endpoint or PublisherModel.
Yields:
A generator of model prediction `Tensor` lists.
"""
request = prediction_service_types.StreamingPredictRequest(
endpoint=endpoint_name,
inputs=tensor_list,
parameters=parameters_tensor,
)
for response in prediction_service_client.server_streaming_predict(request=request):
yield response.outputs
async def predict_stream_of_tensor_lists_from_single_tensor_list_async(
prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
endpoint_name: str,
tensor_list: List[aiplatform_types.Tensor],
parameters_tensor: Optional[aiplatform_types.Tensor] = None,
) -> AsyncIterator[List[aiplatform_types.Tensor]]:
"""Asynchronously predicts a stream of lists of `Tensor` objects from a single list of `Tensor` objects.
Args:
tensor_list: Model input as a list of `Tensor` objects.
parameters_tensor: Optional. Prediction parameters in `Tensor` form.
prediction_service_async_client: A PredictionServiceAsyncClient object.
endpoint_name: Resource name of Endpoint or PublisherModel.
Yields:
A generator of model prediction `Tensor` lists.
"""
request = prediction_service_types.StreamingPredictRequest(
endpoint=endpoint_name,
inputs=tensor_list,
parameters=parameters_tensor,
)
async for response in await prediction_service_async_client.server_streaming_predict(
request=request
):
yield response.outputs
def predict_stream_of_dict_lists_from_single_dict_list(
prediction_service_client: prediction_service.PredictionServiceClient,
endpoint_name: str,
dict_list: List[Dict[str, Any]],
parameters: Optional[Dict[str, Any]] = None,
) -> Iterator[List[Dict[str, Any]]]:
"""Predicts a stream of lists of dicts from a stream of lists of dicts.
Args:
dict_list: Model input as a list of `dict` objects.
parameters: Optional. Prediction parameters `dict` form.
prediction_service_client: A PredictionServiceClient object.
endpoint_name: Resource name of Endpoint or PublisherModel.
Yields:
A generator of model prediction dict lists.
"""
tensor_list = [value_to_tensor(d) for d in dict_list]
parameters_tensor = value_to_tensor(parameters) if parameters else None
for tensor_list in predict_stream_of_tensor_lists_from_single_tensor_list(
prediction_service_client=prediction_service_client,
endpoint_name=endpoint_name,
tensor_list=tensor_list,
parameters_tensor=parameters_tensor,
):
yield [tensor_to_value(tensor._pb) for tensor in tensor_list]
async def predict_stream_of_dict_lists_from_single_dict_list_async(
prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
endpoint_name: str,
dict_list: List[Dict[str, Any]],
parameters: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[List[Dict[str, Any]]]:
"""Asynchronously predicts a stream of lists of dicts from a stream of lists of dicts.
Args:
dict_list: Model input as a list of `dict` objects.
parameters: Optional. Prediction parameters `dict` form.
prediction_service_async_client: A PredictionServiceAsyncClient object.
endpoint_name: Resource name of Endpoint or PublisherModel.
Yields:
A generator of model prediction dict lists.
"""
tensor_list = [value_to_tensor(d) for d in dict_list]
parameters_tensor = value_to_tensor(parameters) if parameters else None
async for tensor_list in predict_stream_of_tensor_lists_from_single_tensor_list_async(
prediction_service_async_client=prediction_service_async_client,
endpoint_name=endpoint_name,
tensor_list=tensor_list,
parameters_tensor=parameters_tensor,
):
yield [tensor_to_value(tensor._pb) for tensor in tensor_list]
def predict_stream_of_dicts_from_single_dict(
prediction_service_client: prediction_service.PredictionServiceClient,
endpoint_name: str,
instance: Dict[str, Any],
parameters: Optional[Dict[str, Any]] = None,
) -> Iterator[Dict[str, Any]]:
"""Predicts a stream of dicts from a single instance dict.
Args:
instance: A single input instance `dict`.
        parameters: Optional. Prediction parameters as a `dict`.
prediction_service_client: A PredictionServiceClient object.
endpoint_name: Resource name of Endpoint or PublisherModel.
Yields:
A generator of model prediction dicts.
"""
for dict_list in predict_stream_of_dict_lists_from_single_dict_list(
prediction_service_client=prediction_service_client,
endpoint_name=endpoint_name,
dict_list=[instance],
parameters=parameters,
):
if len(dict_list) > 1:
raise ValueError(
f"Expected to receive a single output, but got {dict_list}"
)
yield dict_list[0]
async def predict_stream_of_dicts_from_single_dict_async(
prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
endpoint_name: str,
instance: Dict[str, Any],
parameters: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""Asynchronously predicts a stream of dicts from a single instance dict.
Args:
instance: A single input instance `dict`.
        parameters: Optional. Prediction parameters as a `dict`.
prediction_service_async_client: A PredictionServiceAsyncClient object.
endpoint_name: Resource name of Endpoint or PublisherModel.
Yields:
A generator of model prediction dicts.
"""
async for dict_list in predict_stream_of_dict_lists_from_single_dict_list_async(
prediction_service_async_client=prediction_service_async_client,
endpoint_name=endpoint_name,
dict_list=[instance],
parameters=parameters,
):
if len(dict_list) > 1:
raise ValueError(
f"Expected to receive a single output, but got {dict_list}"
)
yield dict_list[0]
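# --- Illustrative usage sketch (not part of the library) ----------------------
# Streams predictions for a single instance. The project, endpoint id, instance
# payload, and parameter names below are hypothetical; a real caller supplies a
# client configured with appropriate credentials and a deployed endpoint.
def _example_stream_predictions() -> None:
    client = prediction_service.PredictionServiceClient()
    endpoint = "projects/example-project/locations/us-central1/endpoints/123"
    for prediction in predict_stream_of_dicts_from_single_dict(
        prediction_service_client=client,
        endpoint_name=endpoint,
        instance={"prompt": "Tell me a short story."},
        parameters={"maxOutputTokens": 128},
    ):
        print(prediction)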

File diff suppressed because it is too large

View File

@@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform.compat import services
from google.cloud.aiplatform.compat import types
V1BETA1 = "v1beta1"
V1 = "v1"
DEFAULT_VERSION = V1
if DEFAULT_VERSION == V1BETA1:
services.dataset_service_client = services.dataset_service_client_v1beta1
services.deployment_resource_pool_service_client = (
services.deployment_resource_pool_service_client_v1beta1
)
services.endpoint_service_client = services.endpoint_service_client_v1beta1
services.feature_online_store_admin_service_client = (
services.feature_online_store_admin_service_client_v1beta1
)
services.feature_online_store_service_client = (
services.feature_online_store_service_client_v1beta1
)
services.feature_registry_service_client = (
services.feature_registry_service_client_v1beta1
)
services.featurestore_online_serving_service_client = (
services.featurestore_online_serving_service_client_v1beta1
)
services.featurestore_service_client = services.featurestore_service_client_v1beta1
services.gen_ai_cache_service_client = services.gen_ai_cache_service_client_v1beta1
services.job_service_client = services.job_service_client_v1beta1
services.model_service_client = services.model_service_client_v1beta1
services.model_garden_service_client = services.model_garden_service_client_v1beta1
services.pipeline_service_client = services.pipeline_service_client_v1beta1
services.prediction_service_client = services.prediction_service_client_v1beta1
services.prediction_service_async_client = (
services.prediction_service_async_client_v1beta1
)
services.schedule_service_client = services.schedule_service_client_v1beta1
services.specialist_pool_service_client = (
services.specialist_pool_service_client_v1beta1
)
services.match_service_client = services.match_service_client_v1beta1
services.metadata_service_client = services.metadata_service_client_v1beta1
services.tensorboard_service_client = services.tensorboard_service_client_v1beta1
services.index_service_client = services.index_service_client_v1beta1
services.index_endpoint_service_client = (
services.index_endpoint_service_client_v1beta1
)
services.vizier_service_client = services.vizier_service_client_v1beta1
types.accelerator_type = types.accelerator_type_v1beta1
types.annotation = types.annotation_v1beta1
types.annotation_spec = types.annotation_spec_v1beta1
types.artifact = types.artifact_v1beta1
types.batch_prediction_job = types.batch_prediction_job_v1beta1
types.cached_content = types.cached_content_v1beta1
types.completion_stats = types.completion_stats_v1beta1
types.context = types.context_v1beta1
types.custom_job = types.custom_job_v1beta1
types.data_item = types.data_item_v1beta1
types.data_labeling_job = types.data_labeling_job_v1beta1
types.dataset = types.dataset_v1beta1
types.dataset_service = types.dataset_service_v1beta1
types.deployed_model_ref = types.deployed_model_ref_v1beta1
types.deployment_resource_pool = types.deployment_resource_pool_v1beta1
types.deployment_resource_pool_service = (
types.deployment_resource_pool_service_v1beta1
)
types.encryption_spec = types.encryption_spec_v1beta1
types.endpoint = types.endpoint_v1beta1
types.endpoint_service = types.endpoint_service_v1beta1
types.entity_type = types.entity_type_v1beta1
types.env_var = types.env_var_v1beta1
types.event = types.event_v1beta1
types.execution = types.execution_v1beta1
types.explanation = types.explanation_v1beta1
types.explanation_metadata = types.explanation_metadata_v1beta1
types.feature = types.feature_v1beta1
types.feature_group = types.feature_group_v1beta1
types.feature_monitor = types.feature_monitor_v1beta1
types.feature_monitor_job = types.feature_monitor_job_v1beta1
types.feature_monitoring_stats = types.feature_monitoring_stats_v1beta1
types.feature_online_store = types.feature_online_store_v1beta1
types.feature_online_store_admin_service = (
types.feature_online_store_admin_service_v1beta1
)
types.feature_registry_service = types.feature_registry_service_v1beta1
types.feature_online_store_service = types.feature_online_store_service_v1beta1
types.feature_selector = types.feature_selector_v1beta1
types.feature_view = types.feature_view_v1beta1
types.feature_view_sync = types.feature_view_sync_v1beta1
types.featurestore = types.featurestore_v1beta1
types.featurestore_monitoring = types.featurestore_monitoring_v1beta1
types.featurestore_online_service = types.featurestore_online_service_v1beta1
types.featurestore_service = types.featurestore_service_v1beta1
types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1beta1
types.index = types.index_v1beta1
types.index_endpoint = types.index_endpoint_v1beta1
types.index_service = types.index_service_v1beta1
types.io = types.io_v1beta1
types.job_service = types.job_service_v1beta1
types.job_state = types.job_state_v1beta1
types.lineage_subgraph = types.lineage_subgraph_v1beta1
types.machine_resources = types.machine_resources_v1beta1
types.manual_batch_tuning_parameters = types.manual_batch_tuning_parameters_v1beta1
types.matching_engine_deployed_index_ref = (
types.matching_engine_deployed_index_ref_v1beta1
)
types.matching_engine_index = types.index_v1beta1
types.matching_engine_index_endpoint = types.index_endpoint_v1beta1
types.metadata_service = types.metadata_service_v1beta1
types.metadata_schema = types.metadata_schema_v1beta1
types.metadata_store = types.metadata_store_v1beta1
types.model = types.model_v1beta1
types.model_evaluation = types.model_evaluation_v1beta1
types.model_evaluation_slice = types.model_evaluation_slice_v1beta1
types.model_deployment_monitoring_job = (
types.model_deployment_monitoring_job_v1beta1
)
types.model_garden_service = types.model_garden_service_v1beta1
types.model_monitoring = types.model_monitoring_v1beta1
types.model_service = types.model_service_v1beta1
types.service_networking = types.service_networking_v1beta1
types.operation = types.operation_v1beta1
types.pipeline_failure_policy = types.pipeline_failure_policy_v1beta1
types.pipeline_job = types.pipeline_job_v1beta1
types.pipeline_service = types.pipeline_service_v1beta1
types.pipeline_state = types.pipeline_state_v1beta1
types.prediction_service = types.prediction_service_v1beta1
types.publisher_model = types.publisher_model_v1beta1
types.schedule = types.schedule_v1beta1
types.schedule_service = types.schedule_service_v1beta1
types.specialist_pool = types.specialist_pool_v1beta1
types.specialist_pool_service = types.specialist_pool_service_v1beta1
types.study = types.study_v1beta1
types.tensorboard = types.tensorboard_v1beta1
types.tensorboard_service = types.tensorboard_service_v1beta1
types.tensorboard_data = types.tensorboard_data_v1beta1
types.tensorboard_experiment = types.tensorboard_experiment_v1beta1
types.tensorboard_run = types.tensorboard_run_v1beta1
types.tensorboard_service = types.tensorboard_service_v1beta1
types.tensorboard_time_series = types.tensorboard_time_series_v1beta1
types.training_pipeline = types.training_pipeline_v1beta1
types.types = types.types_v1beta1
types.vizier_service = types.vizier_service_v1beta1
if DEFAULT_VERSION == V1:
services.dataset_service_client = services.dataset_service_client_v1
services.deployment_resource_pool_service_client = (
services.deployment_resource_pool_service_client_v1
)
services.endpoint_service_client = services.endpoint_service_client_v1
services.feature_online_store_admin_service_client = (
services.feature_online_store_admin_service_client_v1
)
services.feature_registry_service_client = (
services.feature_registry_service_client_v1
)
services.feature_online_store_service_client = (
services.feature_online_store_service_client_v1
)
services.featurestore_online_serving_service_client = (
services.featurestore_online_serving_service_client_v1
)
services.featurestore_service_client = services.featurestore_service_client_v1
services.gen_ai_cache_service_client = services.gen_ai_cache_service_client_v1
services.job_service_client = services.job_service_client_v1
services.model_garden_service_client = services.model_garden_service_client_v1
services.model_service_client = services.model_service_client_v1
services.pipeline_service_client = services.pipeline_service_client_v1
services.prediction_service_client = services.prediction_service_client_v1
services.prediction_service_async_client = (
services.prediction_service_async_client_v1
)
services.schedule_service_client = services.schedule_service_client_v1
services.specialist_pool_service_client = services.specialist_pool_service_client_v1
services.tensorboard_service_client = services.tensorboard_service_client_v1
services.index_service_client = services.index_service_client_v1
services.index_endpoint_service_client = services.index_endpoint_service_client_v1
services.vizier_service_client = services.vizier_service_client_v1
types.accelerator_type = types.accelerator_type_v1
types.annotation = types.annotation_v1
types.annotation_spec = types.annotation_spec_v1
types.artifact = types.artifact_v1
types.batch_prediction_job = types.batch_prediction_job_v1
types.cached_content = types.cached_content_v1
types.completion_stats = types.completion_stats_v1
types.context = types.context_v1
types.custom_job = types.custom_job_v1
types.data_item = types.data_item_v1
types.data_labeling_job = types.data_labeling_job_v1
types.dataset = types.dataset_v1
types.dataset_service = types.dataset_service_v1
types.deployed_model_ref = types.deployed_model_ref_v1
types.deployment_resource_pool = types.deployment_resource_pool_v1
types.deployment_resource_pool_service = types.deployment_resource_pool_service_v1
types.encryption_spec = types.encryption_spec_v1
types.endpoint = types.endpoint_v1
types.endpoint_service = types.endpoint_service_v1
types.entity_type = types.entity_type_v1
types.env_var = types.env_var_v1
types.event = types.event_v1
types.execution = types.execution_v1
types.explanation = types.explanation_v1
types.explanation_metadata = types.explanation_metadata_v1
types.feature = types.feature_v1
types.feature_group = types.feature_group_v1
# TODO(b/293184410): Temporary code. Switch to v1 once v1 is available.
types.feature_monitor = types.feature_monitor_v1beta1
types.feature_monitor_job = types.feature_monitor_job_v1beta1
types.feature_monitoring_stats = types.feature_monitoring_stats_v1
types.feature_online_store = types.feature_online_store_v1
types.feature_online_store_admin_service = (
types.feature_online_store_admin_service_v1
)
types.feature_registry_service = types.feature_registry_service_v1
types.feature_online_store_service = types.feature_online_store_service_v1
types.feature_selector = types.feature_selector_v1
types.feature_view = types.feature_view_v1
types.feature_view_sync = types.feature_view_sync_v1
types.featurestore = types.featurestore_v1
types.featurestore_online_service = types.featurestore_online_service_v1
types.featurestore_service = types.featurestore_service_v1
types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1
types.index = types.index_v1
types.index_endpoint = types.index_endpoint_v1
types.index_service = types.index_service_v1
types.io = types.io_v1
types.job_service = types.job_service_v1
types.job_state = types.job_state_v1
types.lineage_subgraph = types.lineage_subgraph_v1
types.machine_resources = types.machine_resources_v1
types.manual_batch_tuning_parameters = types.manual_batch_tuning_parameters_v1
types.matching_engine_deployed_index_ref = (
types.matching_engine_deployed_index_ref_v1
)
types.matching_engine_index = types.index_v1
types.matching_engine_index_endpoint = types.index_endpoint_v1
types.metadata_service = types.metadata_service_v1
types.metadata_schema = types.metadata_schema_v1
types.metadata_store = types.metadata_store_v1
types.model = types.model_v1
types.model_evaluation = types.model_evaluation_v1
types.model_evaluation_slice = types.model_evaluation_slice_v1
types.model_deployment_monitoring_job = types.model_deployment_monitoring_job_v1
types.model_monitoring = types.model_monitoring_v1
types.model_service = types.model_service_v1
types.service_networking = types.service_networking_v1
types.operation = types.operation_v1
types.pipeline_failure_policy = types.pipeline_failure_policy_v1
types.pipeline_job = types.pipeline_job_v1
types.pipeline_service = types.pipeline_service_v1
types.pipeline_state = types.pipeline_state_v1
types.prediction_service = types.prediction_service_v1
types.publisher_model = types.publisher_model_v1
types.schedule = types.schedule_v1
types.schedule_service = types.schedule_service_v1
types.specialist_pool = types.specialist_pool_v1
types.specialist_pool_service = types.specialist_pool_service_v1
types.study = types.study_v1
types.tensorboard = types.tensorboard_v1
types.tensorboard_service = types.tensorboard_service_v1
types.tensorboard_data = types.tensorboard_data_v1
types.tensorboard_experiment = types.tensorboard_experiment_v1
types.tensorboard_run = types.tensorboard_run_v1
types.tensorboard_service = types.tensorboard_service_v1
types.tensorboard_time_series = types.tensorboard_time_series_v1
types.training_pipeline = types.training_pipeline_v1
types.types = types.types_v1
types.vizier_service = types.vizier_service_v1
__all__ = (
DEFAULT_VERSION,
V1BETA1,
V1,
services,
types,
)
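# Illustrative usage sketch (not part of this module): downstream SDK code
# imports these version-agnostic aliases instead of pinning a GAPIC version,
# for example:
#
#   from google.cloud.aiplatform.compat import services, types
#
#   pipeline_client_cls = services.pipeline_service_client.PipelineServiceClient
#   succeeded = types.pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED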

View File

@@ -0,0 +1,264 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform_v1beta1.services.dataset_service import (
client as dataset_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.deployment_resource_pool_service import (
client as deployment_resource_pool_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
client as endpoint_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.example_store_service import (
client as example_store_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.extension_execution_service import (
client as extension_execution_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.extension_registry_service import (
client as extension_registry_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.feature_online_store_service import (
client as feature_online_store_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.feature_online_store_admin_service import (
client as feature_online_store_admin_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.feature_registry_service import (
client as feature_registry_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import (
client as featurestore_online_serving_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.featurestore_service import (
client as featurestore_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.gen_ai_cache_service import (
client as gen_ai_cache_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.index_service import (
client as index_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import (
client as index_endpoint_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.job_service import (
client as job_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.match_service import (
client as match_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.metadata_service import (
client as metadata_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.model_garden_service import (
client as model_garden_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.model_service import (
client as model_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.model_monitoring_service import (
client as model_monitoring_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
client as persistent_resource_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
client as pipeline_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
client as prediction_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
async_client as prediction_service_async_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.reasoning_engine_service import (
client as reasoning_engine_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.reasoning_engine_execution_service import (
client as reasoning_engine_execution_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.schedule_service import (
client as schedule_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
client as specialist_pool_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.tensorboard_service import (
client as tensorboard_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
client as vertex_rag_data_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
async_client as vertex_rag_data_service_async_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.vertex_rag_service import (
client as vertex_rag_service_client_v1beta1,
)
from google.cloud.aiplatform_v1beta1.services.vizier_service import (
client as vizier_service_client_v1beta1,
)
from google.cloud.aiplatform_v1.services.dataset_service import (
client as dataset_service_client_v1,
)
from google.cloud.aiplatform_v1.services.deployment_resource_pool_service import (
client as deployment_resource_pool_service_client_v1,
)
from google.cloud.aiplatform_v1.services.endpoint_service import (
client as endpoint_service_client_v1,
)
from google.cloud.aiplatform_v1.services.feature_online_store_service import (
client as feature_online_store_service_client_v1,
)
from google.cloud.aiplatform_v1.services.feature_online_store_admin_service import (
client as feature_online_store_admin_service_client_v1,
)
from google.cloud.aiplatform_v1.services.feature_registry_service import (
client as feature_registry_service_client_v1,
)
from google.cloud.aiplatform_v1.services.featurestore_online_serving_service import (
client as featurestore_online_serving_service_client_v1,
)
from google.cloud.aiplatform_v1.services.featurestore_service import (
client as featurestore_service_client_v1,
)
from google.cloud.aiplatform_v1.services.gen_ai_cache_service import (
client as gen_ai_cache_service_client_v1,
)
from google.cloud.aiplatform_v1.services.index_service import (
client as index_service_client_v1,
)
from google.cloud.aiplatform_v1.services.index_endpoint_service import (
client as index_endpoint_service_client_v1,
)
from google.cloud.aiplatform_v1.services.job_service import (
client as job_service_client_v1,
)
from google.cloud.aiplatform_v1.services.metadata_service import (
client as metadata_service_client_v1,
)
from google.cloud.aiplatform_v1.services.model_garden_service import (
client as model_garden_service_client_v1,
)
from google.cloud.aiplatform_v1.services.model_service import (
client as model_service_client_v1,
)
from google.cloud.aiplatform_v1.services.persistent_resource_service import (
client as persistent_resource_service_client_v1,
)
from google.cloud.aiplatform_v1.services.pipeline_service import (
client as pipeline_service_client_v1,
)
from google.cloud.aiplatform_v1.services.prediction_service import (
client as prediction_service_client_v1,
)
from google.cloud.aiplatform_v1.services.prediction_service import (
async_client as prediction_service_async_client_v1,
)
from google.cloud.aiplatform_v1.services.reasoning_engine_service import (
client as reasoning_engine_service_client_v1,
)
from google.cloud.aiplatform_v1.services.reasoning_engine_execution_service import (
client as reasoning_engine_execution_service_client_v1,
)
from google.cloud.aiplatform_v1.services.schedule_service import (
client as schedule_service_client_v1,
)
from google.cloud.aiplatform_v1.services.specialist_pool_service import (
client as specialist_pool_service_client_v1,
)
from google.cloud.aiplatform_v1.services.tensorboard_service import (
client as tensorboard_service_client_v1,
)
from google.cloud.aiplatform_v1.services.vizier_service import (
client as vizier_service_client_v1,
)
from google.cloud.aiplatform_v1.services.vertex_rag_service import (
client as vertex_rag_service_client_v1,
)
from google.cloud.aiplatform_v1.services.vertex_rag_data_service import (
client as vertex_rag_data_service_client_v1,
)
from google.cloud.aiplatform_v1.services.vertex_rag_data_service import (
async_client as vertex_rag_data_service_async_client_v1,
)
__all__ = (
# v1
dataset_service_client_v1,
deployment_resource_pool_service_client_v1,
endpoint_service_client_v1,
feature_online_store_service_client_v1,
feature_online_store_admin_service_client_v1,
feature_registry_service_client_v1,
featurestore_online_serving_service_client_v1,
featurestore_service_client_v1,
index_service_client_v1,
index_endpoint_service_client_v1,
job_service_client_v1,
metadata_service_client_v1,
model_garden_service_client_v1,
model_service_client_v1,
persistent_resource_service_client_v1,
pipeline_service_client_v1,
prediction_service_client_v1,
prediction_service_async_client_v1,
reasoning_engine_execution_service_client_v1,
reasoning_engine_service_client_v1,
schedule_service_client_v1,
specialist_pool_service_client_v1,
tensorboard_service_client_v1,
vizier_service_client_v1,
vertex_rag_data_service_async_client_v1,
vertex_rag_data_service_client_v1,
vertex_rag_service_client_v1,
# v1beta1
dataset_service_client_v1beta1,
deployment_resource_pool_service_client_v1beta1,
endpoint_service_client_v1beta1,
example_store_service_client_v1beta1,
feature_online_store_service_client_v1beta1,
feature_online_store_admin_service_client_v1beta1,
feature_registry_service_client_v1beta1,
featurestore_online_serving_service_client_v1beta1,
featurestore_service_client_v1beta1,
index_service_client_v1beta1,
index_endpoint_service_client_v1beta1,
job_service_client_v1beta1,
match_service_client_v1beta1,
model_garden_service_client_v1beta1,
model_monitoring_service_client_v1beta1,
model_service_client_v1beta1,
persistent_resource_service_client_v1beta1,
pipeline_service_client_v1beta1,
prediction_service_client_v1beta1,
prediction_service_async_client_v1beta1,
reasoning_engine_execution_service_client_v1beta1,
reasoning_engine_service_client_v1beta1,
schedule_service_client_v1beta1,
specialist_pool_service_client_v1beta1,
metadata_service_client_v1beta1,
tensorboard_service_client_v1beta1,
vertex_rag_service_client_v1beta1,
vertex_rag_data_service_client_v1beta1,
vertex_rag_data_service_async_client_v1beta1,
vizier_service_client_v1beta1,
)

View File

@@ -0,0 +1,361 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform_v1beta1.types import (
accelerator_type as accelerator_type_v1beta1,
annotation as annotation_v1beta1,
annotation_spec as annotation_spec_v1beta1,
artifact as artifact_v1beta1,
batch_prediction_job as batch_prediction_job_v1beta1,
cached_content as cached_content_v1beta1,
completion_stats as completion_stats_v1beta1,
context as context_v1beta1,
custom_job as custom_job_v1beta1,
data_item as data_item_v1beta1,
data_labeling_job as data_labeling_job_v1beta1,
dataset as dataset_v1beta1,
dataset_service as dataset_service_v1beta1,
deployed_index_ref as matching_engine_deployed_index_ref_v1beta1,
deployed_model_ref as deployed_model_ref_v1beta1,
deployment_resource_pool as deployment_resource_pool_v1beta1,
deployment_resource_pool_service as deployment_resource_pool_service_v1beta1,
encryption_spec as encryption_spec_v1beta1,
endpoint as endpoint_v1beta1,
endpoint_service as endpoint_service_v1beta1,
entity_type as entity_type_v1beta1,
env_var as env_var_v1beta1,
event as event_v1beta1,
execution as execution_v1beta1,
explanation as explanation_v1beta1,
explanation_metadata as explanation_metadata_v1beta1,
feature as feature_v1beta1,
feature_group as feature_group_v1beta1,
feature_monitor as feature_monitor_v1beta1,
feature_monitor_job as feature_monitor_job_v1beta1,
feature_monitoring_stats as feature_monitoring_stats_v1beta1,
feature_online_store as feature_online_store_v1beta1,
feature_online_store_admin_service as feature_online_store_admin_service_v1beta1,
feature_online_store_service as feature_online_store_service_v1beta1,
feature_registry_service as feature_registry_service_v1beta1,
feature_selector as feature_selector_v1beta1,
feature_view as feature_view_v1beta1,
feature_view_sync as feature_view_sync_v1beta1,
featurestore as featurestore_v1beta1,
featurestore_monitoring as featurestore_monitoring_v1beta1,
featurestore_online_service as featurestore_online_service_v1beta1,
featurestore_service as featurestore_service_v1beta1,
gen_ai_cache_service as gen_ai_cache_service_v1beta1,
index as index_v1beta1,
index_endpoint as index_endpoint_v1beta1,
hyperparameter_tuning_job as hyperparameter_tuning_job_v1beta1,
io as io_v1beta1,
index_service as index_service_v1beta1,
job_service as job_service_v1beta1,
job_state as job_state_v1beta1,
lineage_subgraph as lineage_subgraph_v1beta1,
machine_resources as machine_resources_v1beta1,
manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1beta1,
match_service as match_service_v1beta1,
metadata_schema as metadata_schema_v1beta1,
metadata_service as metadata_service_v1beta1,
metadata_store as metadata_store_v1beta1,
model as model_v1beta1,
model_evaluation as model_evaluation_v1beta1,
model_evaluation_slice as model_evaluation_slice_v1beta1,
model_deployment_monitoring_job as model_deployment_monitoring_job_v1beta1,
model_garden_service as model_garden_service_v1beta1,
model_service as model_service_v1beta1,
model_monitor as model_monitor_v1beta1,
model_monitoring as model_monitoring_v1beta1,
model_monitoring_alert as model_monitoring_alert_v1beta1,
model_monitoring_job as model_monitoring_job_v1beta1,
model_monitoring_service as model_monitoring_service_v1beta1,
model_monitoring_spec as model_monitoring_spec_v1beta1,
model_monitoring_stats as model_monitoring_stats_v1beta1,
operation as operation_v1beta1,
persistent_resource as persistent_resource_v1beta1,
persistent_resource_service as persistent_resource_service_v1beta1,
pipeline_failure_policy as pipeline_failure_policy_v1beta1,
pipeline_job as pipeline_job_v1beta1,
pipeline_service as pipeline_service_v1beta1,
pipeline_state as pipeline_state_v1beta1,
prediction_service as prediction_service_v1beta1,
publisher_model as publisher_model_v1beta1,
reservation_affinity as reservation_affinity_v1beta1,
service_networking as service_networking_v1beta1,
schedule as schedule_v1beta1,
schedule_service as schedule_service_v1beta1,
specialist_pool as specialist_pool_v1beta1,
specialist_pool_service as specialist_pool_service_v1beta1,
study as study_v1beta1,
tensorboard as tensorboard_v1beta1,
tensorboard_data as tensorboard_data_v1beta1,
tensorboard_experiment as tensorboard_experiment_v1beta1,
tensorboard_run as tensorboard_run_v1beta1,
tensorboard_service as tensorboard_service_v1beta1,
tensorboard_time_series as tensorboard_time_series_v1beta1,
training_pipeline as training_pipeline_v1beta1,
types as types_v1beta1,
vizier_service as vizier_service_v1beta1,
)
from google.cloud.aiplatform_v1.types import (
accelerator_type as accelerator_type_v1,
annotation as annotation_v1,
annotation_spec as annotation_spec_v1,
artifact as artifact_v1,
batch_prediction_job as batch_prediction_job_v1,
cached_content as cached_content_v1,
completion_stats as completion_stats_v1,
context as context_v1,
custom_job as custom_job_v1,
data_item as data_item_v1,
data_labeling_job as data_labeling_job_v1,
dataset as dataset_v1,
dataset_service as dataset_service_v1,
deployed_index_ref as matching_engine_deployed_index_ref_v1,
deployed_model_ref as deployed_model_ref_v1,
deployment_resource_pool as deployment_resource_pool_v1,
deployment_resource_pool_service as deployment_resource_pool_service_v1,
encryption_spec as encryption_spec_v1,
endpoint as endpoint_v1,
endpoint_service as endpoint_service_v1,
entity_type as entity_type_v1,
env_var as env_var_v1,
event as event_v1,
execution as execution_v1,
explanation as explanation_v1,
explanation_metadata as explanation_metadata_v1,
feature as feature_v1,
feature_group as feature_group_v1,
feature_monitoring_stats as feature_monitoring_stats_v1,
feature_online_store as feature_online_store_v1,
feature_online_store_admin_service as feature_online_store_admin_service_v1,
feature_online_store_service as feature_online_store_service_v1,
feature_registry_service as feature_registry_service_v1,
feature_selector as feature_selector_v1,
feature_view as feature_view_v1,
feature_view_sync as feature_view_sync_v1,
featurestore as featurestore_v1,
featurestore_online_service as featurestore_online_service_v1,
featurestore_service as featurestore_service_v1,
hyperparameter_tuning_job as hyperparameter_tuning_job_v1,
index as index_v1,
index_endpoint as index_endpoint_v1,
index_service as index_service_v1,
io as io_v1,
job_service as job_service_v1,
job_state as job_state_v1,
lineage_subgraph as lineage_subgraph_v1,
machine_resources as machine_resources_v1,
manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1,
metadata_service as metadata_service_v1,
metadata_schema as metadata_schema_v1,
metadata_store as metadata_store_v1,
model as model_v1,
model_evaluation as model_evaluation_v1,
model_evaluation_slice as model_evaluation_slice_v1,
model_deployment_monitoring_job as model_deployment_monitoring_job_v1,
model_service as model_service_v1,
model_monitoring as model_monitoring_v1,
operation as operation_v1,
persistent_resource as persistent_resource_v1,
persistent_resource_service as persistent_resource_service_v1,
pipeline_failure_policy as pipeline_failure_policy_v1,
pipeline_job as pipeline_job_v1,
pipeline_service as pipeline_service_v1,
pipeline_state as pipeline_state_v1,
prediction_service as prediction_service_v1,
publisher_model as publisher_model_v1,
reservation_affinity as reservation_affinity_v1,
schedule as schedule_v1,
schedule_service as schedule_service_v1,
service_networking as service_networking_v1,
specialist_pool as specialist_pool_v1,
specialist_pool_service as specialist_pool_service_v1,
study as study_v1,
tensorboard as tensorboard_v1,
tensorboard_data as tensorboard_data_v1,
tensorboard_experiment as tensorboard_experiment_v1,
tensorboard_run as tensorboard_run_v1,
tensorboard_service as tensorboard_service_v1,
tensorboard_time_series as tensorboard_time_series_v1,
training_pipeline as training_pipeline_v1,
types as types_v1,
vizier_service as vizier_service_v1,
)
__all__ = (
# v1
accelerator_type_v1,
annotation_v1,
annotation_spec_v1,
artifact_v1,
batch_prediction_job_v1,
completion_stats_v1,
context_v1,
custom_job_v1,
data_item_v1,
data_labeling_job_v1,
dataset_v1,
dataset_service_v1,
deployed_model_ref_v1,
deployment_resource_pool_v1,
deployment_resource_pool_service_v1,
encryption_spec_v1,
endpoint_v1,
endpoint_service_v1,
entity_type_v1,
env_var_v1,
event_v1,
execution_v1,
explanation_v1,
explanation_metadata_v1,
feature_v1,
feature_monitoring_stats_v1,
feature_selector_v1,
featurestore_v1,
featurestore_online_service_v1,
featurestore_service_v1,
hyperparameter_tuning_job_v1,
io_v1,
job_service_v1,
job_state_v1,
lineage_subgraph_v1,
machine_resources_v1,
manual_batch_tuning_parameters_v1,
matching_engine_deployed_index_ref_v1,
index_v1,
index_endpoint_v1,
index_service_v1,
metadata_service_v1,
metadata_schema_v1,
metadata_store_v1,
model_v1,
model_evaluation_v1,
model_evaluation_slice_v1,
model_deployment_monitoring_job_v1,
model_service_v1,
model_monitoring_v1,
operation_v1,
persistent_resource_v1,
persistent_resource_service_v1,
pipeline_failure_policy_v1,
pipeline_job_v1,
pipeline_service_v1,
pipeline_state_v1,
prediction_service_v1,
publisher_model_v1,
reservation_affinity_v1,
schedule_v1,
schedule_service_v1,
specialist_pool_v1,
specialist_pool_service_v1,
tensorboard_v1,
tensorboard_data_v1,
tensorboard_experiment_v1,
tensorboard_run_v1,
tensorboard_service_v1,
tensorboard_time_series_v1,
training_pipeline_v1,
types_v1,
study_v1,
vizier_service_v1,
# v1beta1
accelerator_type_v1beta1,
annotation_v1beta1,
annotation_spec_v1beta1,
artifact_v1beta1,
batch_prediction_job_v1beta1,
completion_stats_v1beta1,
context_v1beta1,
custom_job_v1beta1,
data_item_v1beta1,
data_labeling_job_v1beta1,
dataset_v1beta1,
dataset_service_v1beta1,
deployment_resource_pool_v1beta1,
deployment_resource_pool_service_v1beta1,
deployed_model_ref_v1beta1,
encryption_spec_v1beta1,
endpoint_v1beta1,
endpoint_service_v1beta1,
entity_type_v1beta1,
env_var_v1beta1,
event_v1beta1,
execution_v1beta1,
explanation_v1beta1,
explanation_metadata_v1beta1,
feature_v1beta1,
feature_monitoring_stats_v1beta1,
feature_selector_v1beta1,
featurestore_v1beta1,
featurestore_monitoring_v1beta1,
featurestore_online_service_v1beta1,
featurestore_service_v1beta1,
hyperparameter_tuning_job_v1beta1,
io_v1beta1,
job_service_v1beta1,
job_state_v1beta1,
lineage_subgraph_v1beta1,
machine_resources_v1beta1,
manual_batch_tuning_parameters_v1beta1,
matching_engine_deployed_index_ref_v1beta1,
index_v1beta1,
index_endpoint_v1beta1,
index_service_v1beta1,
match_service_v1beta1,
metadata_service_v1beta1,
metadata_schema_v1beta1,
metadata_store_v1beta1,
model_v1beta1,
model_evaluation_v1beta1,
model_evaluation_slice_v1beta1,
model_deployment_monitoring_job_v1beta1,
model_garden_service_v1beta1,
model_service_v1beta1,
model_monitor_v1beta1,
model_monitoring_v1beta1,
model_monitoring_alert_v1beta1,
model_monitoring_job_v1beta1,
model_monitoring_service_v1beta1,
model_monitoring_spec_v1beta1,
model_monitoring_stats_v1beta1,
operation_v1beta1,
persistent_resource_v1beta1,
persistent_resource_service_v1beta1,
pipeline_failure_policy_v1beta1,
pipeline_job_v1beta1,
pipeline_service_v1beta1,
pipeline_state_v1beta1,
prediction_service_v1beta1,
publisher_model_v1beta1,
reservation_affinity_v1beta1,
schedule_v1beta1,
schedule_service_v1beta1,
specialist_pool_v1beta1,
specialist_pool_service_v1beta1,
study_v1beta1,
tensorboard_v1beta1,
tensorboard_data_v1beta1,
tensorboard_experiment_v1beta1,
tensorboard_run_v1beta1,
tensorboard_service_v1beta1,
tensorboard_time_series_v1beta1,
training_pipeline_v1beta1,
types_v1beta1,
vizier_service_v1beta1,
)

View File

@@ -0,0 +1,18 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.cloud.aiplatform.constants import base
from google.cloud.aiplatform.constants import prediction
__all__ = ("base", "prediction")

View File

@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform import version as aiplatform_version
DEFAULT_REGION = "us-central1"
SUPPORTED_REGIONS = frozenset(
{
"africa-south1",
"asia-east1",
"asia-east2",
"asia-northeast1",
"asia-northeast2",
"asia-northeast3",
"asia-south1",
"asia-southeast1",
"asia-southeast2",
"australia-southeast1",
"australia-southeast2",
"europe-central2",
"europe-north1",
"europe-southwest1",
"europe-west1",
"europe-west2",
"europe-west3",
"europe-west4",
"europe-west6",
"europe-west8",
"europe-west9",
"europe-west12",
"global",
"me-central1",
"me-central2",
"me-west1",
"northamerica-northeast1",
"northamerica-northeast2",
"southamerica-east1",
"southamerica-west1",
"us-central1",
"us-east1",
"us-east4",
"us-east5",
"us-south1",
"us-west1",
"us-west2",
"us-west3",
"us-west4",
}
)
API_BASE_PATH = "aiplatform.googleapis.com"
PREDICTION_API_BASE_PATH = API_BASE_PATH
# Batch Prediction
BATCH_PREDICTION_INPUT_STORAGE_FORMATS = (
"jsonl",
"csv",
"tf-record",
"tf-record-gzip",
"bigquery",
"file-list",
)
BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS = ("jsonl", "csv", "bigquery")
MOBILE_TF_MODEL_TYPES = {
"MOBILE_TF_LOW_LATENCY_1",
"MOBILE_TF_VERSATILE_1",
"MOBILE_TF_HIGH_ACCURACY_1",
}
MODEL_GARDEN_ICN_MODEL_TYPES = {
"EFFICIENTNET",
"MAXVIT",
"VIT",
"COCA",
}
MODEL_GARDEN_IOD_MODEL_TYPES = {
"SPINENET",
"YOLO",
}
# TODO(b/177079208): Use EPCL Enums for validating Model Types
# Defined by gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_*
# Format: "prediction_type": set() of model_type's
#
# NOTE: When adding a new prediction_type's, ensure it fits the pattern
# "automl_image_{prediction_type}_*" used by the YAML schemas on GCS
AUTOML_IMAGE_PREDICTION_MODEL_TYPES = {
"classification": {"CLOUD", "CLOUD_1"}
| MOBILE_TF_MODEL_TYPES
| MODEL_GARDEN_ICN_MODEL_TYPES,
"object_detection": {"CLOUD_1", "CLOUD_HIGH_ACCURACY_1", "CLOUD_LOW_LATENCY_1"}
| MOBILE_TF_MODEL_TYPES
| MODEL_GARDEN_IOD_MODEL_TYPES,
}
AUTOML_VIDEO_PREDICTION_MODEL_TYPES = {
"classification": {"CLOUD"} | {"MOBILE_VERSATILE_1"},
"action_recognition": {"CLOUD"} | {"MOBILE_VERSATILE_1"},
"object_tracking": {"CLOUD"}
| {
"MOBILE_VERSATILE_1",
"MOBILE_CORAL_VERSATILE_1",
"MOBILE_CORAL_LOW_LATENCY_1",
"MOBILE_JETSON_VERSATILE_1",
"MOBILE_JETSON_LOW_LATENCY_1",
},
}
# Used in constructing the requests user_agent header for metrics reporting.
USER_AGENT_PRODUCT = "model-builder"
# This field is used to pass the name of the specific SDK method
# that is being used for usage metrics tracking purposes.
# For more details, see go/oneplatform-api-analytics
USER_AGENT_SDK_COMMAND = ""
# Needed for Endpoint.raw_predict
DEFAULT_AUTHED_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
# Used in CustomJob.from_local_script for experiments integration in training
AIPLATFORM_DEPENDENCY_PATH = (
f"google-cloud-aiplatform=={aiplatform_version.__version__}"
)
AIPLATFORM_AUTOLOG_DEPENDENCY_PATH = (
f"google-cloud-aiplatform[autologging]=={aiplatform_version.__version__}"
)
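# Illustrative usage sketch (not part of the upstream module; the helper name
# below is hypothetical): one way the constants above might be consumed.
#
#   from google.cloud.aiplatform.constants import base
#
#   def resolve_region(region: str = "") -> str:
#       """Validate a user-supplied region, falling back to the default."""
#       if region and region not in base.SUPPORTED_REGIONS:
#           raise ValueError(f"Unsupported region: {region}")
#       return region or base.DEFAULT_REGION
#
#   resolve_region("europe-west4")  # -> "europe-west4"
#   resolve_region()                # -> "us-central1"
#
#   # Pinning the SDK itself as a training dependency:
#   packages = [base.AIPLATFORM_DEPENDENCY_PATH]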

View File

@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from google.cloud.aiplatform.compat.types import (
pipeline_state as gca_pipeline_state,
)
_PIPELINE_COMPLETE_STATES = set(
[
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED,
gca_pipeline_state.PipelineState.PIPELINE_STATE_CANCELLED,
gca_pipeline_state.PipelineState.PIPELINE_STATE_PAUSED,
]
)
_PIPELINE_ERROR_STATES = set([gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED])
# Pattern for valid names used as a Vertex resource name.
_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$", re.IGNORECASE)
# Pattern for an Artifact Registry URL.
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*", re.IGNORECASE)
# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = re.compile(r"^https:\/\/([\.\/\w-]+)\/.*(json|yaml|yml)$")
# Fields to include in returned PipelineJob when enable_simple_view=True in PipelineJob.list()
_READ_MASK_FIELDS = [
"name",
"state",
"display_name",
"pipeline_spec.pipeline_info",
"create_time",
"start_time",
"end_time",
"update_time",
"labels",
"template_uri",
"template_metadata.version",
"job_detail.pipeline_run_context",
"job_detail.pipeline_context",
]
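# Illustrative usage sketch (not part of the upstream module; the helper name
# below is hypothetical): how the patterns above might classify a pipeline
# template location.
#
#   def _is_remote_template(url: str) -> bool:
#       """True for Artifact Registry or plain HTTPS JSON/YAML templates."""
#       return bool(_VALID_AR_URL.match(url) or _VALID_HTTPS_URL.match(url))
#
#   _is_remote_template("https://us-kfp.pkg.dev/my-project/my-repo/my-pipeline/v1")  # True
#   _is_remote_template("gs://my-bucket/pipeline_spec.json")                         # False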

View File

@@ -0,0 +1,304 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from collections import defaultdict
# [region]-docker.pkg.dev/vertex-ai/prediction/[framework]-[accelerator].[version]:latest
CONTAINER_URI_PATTERN = re.compile(
r"(?P<region>[\w]+)\-docker\.pkg\.dev\/vertex\-ai\/prediction\/"
r"(?P<framework>[\w]+)\-(?P<accelerator>[\w]+)\.(?P<version>[\d-]+):latest"
)
CONTAINER_URI_REGEX = (
r"^(us|europe|asia)-docker.pkg.dev/"
r"vertex-ai/prediction/"
r"(tf|sklearn|xgboost|pytorch).+$"
)
SKLEARN = "sklearn"
TF = "tf"
TF2 = "tf2"
XGBOOST = "xgboost"
XGBOOST_CONTAINER_URIS = [
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-1:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-1:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-1:latest",
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-0:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-0:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-0:latest",
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-7:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-7:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-7:latest",
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-6:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-6:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-6:latest",
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-5:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-5:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-5:latest",
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-4:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-4:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-4:latest",
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-3:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-3:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-3:latest",
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-2:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-2:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-2:latest",
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-1:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-1:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-1:latest",
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-90:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-90:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-90:latest",
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-82:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-82:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-82:latest",
]
SKLEARN_CONTAINER_URIS = [
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-5:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-5:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-5:latest",
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-4:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-4:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-4:latest",
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-3:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-3:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-3:latest",
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-2:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-2:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-2:latest",
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-24:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-24:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-24:latest",
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-23:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-23:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-23:latest",
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-22:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-22:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-22:latest",
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-20:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-20:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-20:latest",
]
TF_CONTAINER_URIS = [
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-15:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-15:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-15:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-15:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-15:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-15:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-14:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-14:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-14:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-14:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-14:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-14:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-13:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-13:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-13:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-13:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-13:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-13:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-12:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-12:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-12:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-12:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-12:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-12:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-11:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-11:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-11:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-11:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-11:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-11:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-10:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-10:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-10:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-10:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-10:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-10:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-9:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-9:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-9:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-9:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-9:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-9:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-8:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-8:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-8:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-8:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-8:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-8:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-7:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-7:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-7:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-7:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-7:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-7:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-6:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-6:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-6:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-5:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-5:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-5:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-4:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-4:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-4:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-4:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-4:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-4:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-3:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-3:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-3:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-3:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-3:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-3:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-2:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-2:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-2:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-1:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-1:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-1:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-1:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-1:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-1:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf-cpu.1-15:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf-cpu.1-15:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf-cpu.1-15:latest",
"us-docker.pkg.dev/vertex-ai/prediction/tf-gpu.1-15:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/tf-gpu.1-15:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/tf-gpu.1-15:latest",
]
PYTORCH_CONTAINER_URIS = [
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-4:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-4:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-4:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-3:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-3:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-3:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-3:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-3:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-3:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-2:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-2:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-2:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-2:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-2:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-2:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-1:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-1:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-1:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-1:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-1:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-1:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-0:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-0:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-0:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-0:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-0:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-0:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-13:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-13:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-13:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-13:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-13:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-13:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-12:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-12:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-12:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-12:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-12:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-12:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-11:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-11:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-11:latest",
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-11:latest",
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-11:latest",
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-11:latest",
]
SERVING_CONTAINER_URIS = (
SKLEARN_CONTAINER_URIS
+ TF_CONTAINER_URIS
+ XGBOOST_CONTAINER_URIS
+ PYTORCH_CONTAINER_URIS
)
# Map of all first-party prediction containers
d = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(str))))
for container_uri in SERVING_CONTAINER_URIS:
m = CONTAINER_URI_PATTERN.match(container_uri)
region, framework, accelerator, version = m[1], m[2], m[3], m[4]
version = version.replace("-", ".")
if framework in (TF2, TF): # Store both `tf`, `tf2` as `tensorflow`
framework = "tensorflow"
d[region][framework][accelerator][version] = container_uri
_SERVING_CONTAINER_URI_MAP = d
_SERVING_CONTAINER_DOCUMENTATION_URL = (
"https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers"
)
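# Illustrative lookup sketch (not part of the upstream module): the map built
# above is keyed region -> framework -> accelerator -> version, with "tf" and
# "tf2" folded into "tensorflow", so a lookup might read:
#
#   _SERVING_CONTAINER_URI_MAP["us"]["sklearn"]["cpu"]["1.0"]
#   # -> "us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest"
#
#   _SERVING_CONTAINER_URI_MAP["europe"]["tensorflow"]["gpu"]["2.15"]
#   # -> "europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-15:latest"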
# Variables set by Vertex AI. For more details, please refer to
# https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables
DEFAULT_AIP_HTTP_PORT = 8080
AIP_HTTP_PORT = "AIP_HTTP_PORT"
AIP_HEALTH_ROUTE = "AIP_HEALTH_ROUTE"
AIP_PREDICT_ROUTE = "AIP_PREDICT_ROUTE"
AIP_STORAGE_URI = "AIP_STORAGE_URI"
# Default values for Prediction local experience.
DEFAULT_LOCAL_PREDICT_ROUTE = "/predict"
DEFAULT_LOCAL_HEALTH_ROUTE = "/health"
DEFAULT_LOCAL_RUN_GPU_CAPABILITIES = [["utility", "compute"]]
DEFAULT_LOCAL_RUN_GPU_COUNT = -1
CUSTOM_PREDICTION_ROUTINES = "custom-prediction-routines"
CUSTOM_PREDICTION_ROUTINES_SERVER_ERROR_HEADER_KEY = "X-AIP-CPR-SYSTEM-ERROR"
# Headers' related constants for the handler usage.
CONTENT_TYPE_HEADER_REGEX = re.compile("^[Cc]ontent-?[Tt]ype$")
ACCEPT_HEADER_REGEX = re.compile("^[Aa]ccept$")
ANY_ACCEPT_TYPE = "*/*"
DEFAULT_ACCEPT_VALUE = "application/json"
# Model filenames.
MODEL_FILENAME_BST = "model.bst"
MODEL_FILENAME_JOBLIB = "model.joblib"
MODEL_FILENAME_PKL = "model.pkl"

View File

@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform.compat.types import (
schedule as gca_schedule,
)
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
_SCHEDULE_COMPLETE_STATES = set(
[
gca_schedule.Schedule.State.PAUSED,
gca_schedule.Schedule.State.COMPLETED,
]
)
_SCHEDULE_ERROR_STATES = set(
[
gca_schedule.Schedule.State.STATE_UNSPECIFIED,
]
)
# Pattern for valid names used as a Vertex resource name.
_VALID_NAME_PATTERN = pipeline_constants._VALID_NAME_PATTERN
# Pattern for an Artifact Registry URL.
_VALID_AR_URL = pipeline_constants._VALID_AR_URL
# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL
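# Illustrative usage sketch (not part of the upstream module; the helper name
# below is hypothetical): the state sets above might drive a polling loop.
#
#   def _schedule_is_done(state: gca_schedule.Schedule.State) -> bool:
#       return state in _SCHEDULE_COMPLETE_STATES
#
#   _schedule_is_done(gca_schedule.Schedule.State.COMPLETED)          # True
#   _schedule_is_done(gca_schedule.Schedule.State.STATE_UNSPECIFIED)  # False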

View File

@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform.datasets.dataset import _Dataset
from google.cloud.aiplatform.datasets.column_names_dataset import _ColumnNamesDataset
from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset
from google.cloud.aiplatform.datasets.time_series_dataset import TimeSeriesDataset
from google.cloud.aiplatform.datasets.image_dataset import ImageDataset
from google.cloud.aiplatform.datasets.text_dataset import TextDataset
from google.cloud.aiplatform.datasets.video_dataset import VideoDataset
__all__ = (
"_Dataset",
"_ColumnNamesDataset",
"TabularDataset",
"TimeSeriesDataset",
"ImageDataset",
"TextDataset",
"VideoDataset",
)

View File

@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
from typing import Optional, Dict, Sequence, Union
from google.cloud.aiplatform import schema
from google.cloud.aiplatform.compat.types import (
io as gca_io,
dataset as gca_dataset,
)
class Datasource(abc.ABC):
"""An abstract class that sets dataset_metadata."""
@property
@abc.abstractmethod
def dataset_metadata(self):
"""Dataset Metadata."""
pass
class DatasourceImportable(abc.ABC):
"""An abstract class that sets import_data_config."""
@property
@abc.abstractmethod
def import_data_config(self):
"""Import Data Config."""
pass
class TabularDatasource(Datasource):
"""Datasource for creating a tabular dataset for Vertex AI."""
def __init__(
self,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
bq_source: Optional[str] = None,
):
"""Creates a tabular datasource.
Args:
gcs_source (Union[str, Sequence[str]]):
Cloud Storage URI of one or more files. Only CSV files are supported.
The first line of the CSV file is used as the header.
If there are multiple files, the header is the first line of
the lexicographically first file; the other files must either
contain the exact same header or omit the header.
examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
bq_source (str):
The URI of a BigQuery table.
example:
"bq://project.dataset.table_name"
Raises:
ValueError: If source configuration is not valid.
"""
dataset_metadata = None
if gcs_source and isinstance(gcs_source, str):
gcs_source = [gcs_source]
if gcs_source and bq_source:
raise ValueError("Only one of gcs_source or bq_source can be set.")
if not any([gcs_source, bq_source]):
raise ValueError("One of gcs_source or bq_source must be set.")
if gcs_source:
dataset_metadata = {"inputConfig": {"gcsSource": {"uri": gcs_source}}}
elif bq_source:
dataset_metadata = {"inputConfig": {"bigquerySource": {"uri": bq_source}}}
self._dataset_metadata = dataset_metadata
@property
def dataset_metadata(self) -> Optional[Dict]:
"""Dataset Metadata."""
return self._dataset_metadata
class NonTabularDatasource(Datasource):
"""Datasource for creating an empty non-tabular dataset for Vertex AI."""
@property
def dataset_metadata(self) -> Optional[Dict]:
return None
class NonTabularDatasourceImportable(NonTabularDatasource, DatasourceImportable):
"""Datasource for creating a non-tabular dataset for Vertex AI and
importing data to the dataset."""
def __init__(
self,
gcs_source: Union[str, Sequence[str]],
import_schema_uri: str,
data_item_labels: Optional[Dict] = None,
):
"""Creates a non-tabular datasource.
Args:
gcs_source (Union[str, Sequence[str]]):
Required. The Google Cloud Storage location for the input content.
Google Cloud Storage URI(-s) to the input file(s).
Examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
import_schema_uri (str):
Required. Points to a YAML file stored on Google Cloud
Storage describing the import format. Validation will be
done against the schema. The schema is defined as an
`OpenAPI 3.0.2 Schema Object <https://tinyurl.com/y538mdwt>`__.
data_item_labels (Dict):
Labels that will be applied to newly imported DataItems. If
a DataItem identical to one being imported already exists
in the Dataset, these labels will be appended to those of
the existing DataItem, and if a label with an identical key
was imported before, the old label value will be
overwritten. If two DataItems are identical within the same
import data operation, their labels will be combined, and if
a key collision happens in this case, one of the values will
be picked at random. Two DataItems are considered identical
if their content bytes are identical (e.g. image bytes or
pdf bytes). These labels will be overridden by Annotation
labels specified inside the index file referenced by
``import_schema_uri``,
e.g. jsonl file.
"""
super().__init__()
self._gcs_source = [gcs_source] if isinstance(gcs_source, str) else gcs_source
self._import_schema_uri = import_schema_uri
self._data_item_labels = data_item_labels
@property
def import_data_config(self) -> gca_dataset.ImportDataConfig:
"""Import Data Config."""
return gca_dataset.ImportDataConfig(
gcs_source=gca_io.GcsSource(uris=self._gcs_source),
import_schema_uri=self._import_schema_uri,
data_item_labels=self._data_item_labels,
)
def create_datasource(
metadata_schema_uri: str,
import_schema_uri: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
bq_source: Optional[str] = None,
data_item_labels: Optional[Dict] = None,
) -> Datasource:
"""Creates a datasource
Args:
metadata_schema_uri (str):
Required. Points to a YAML file stored on Google Cloud Storage
describing additional information about the Dataset. The schema
is defined as an OpenAPI 3.0.2 Schema Object. The schema files
that can be used here are found in gs://google-cloud-
aiplatform/schema/dataset/metadata/.
import_schema_uri (str):
Points to a YAML file stored on Google Cloud
Storage describing the import format. Validation will be
done against the schema. The schema is defined as an
`OpenAPI 3.0.2 Schema Object <https://tinyurl.com/y538mdwt>`__.
gcs_source (Union[str, Sequence[str]]):
The Google Cloud Storage location for the input content.
Google Cloud Storage URI(-s) to the input file(s).
Examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
bq_source (str):
BigQuery URI to the input table.
example:
"bq://project.dataset.table_name"
data_item_labels (Dict):
Labels that will be applied to newly imported DataItems. If
a DataItem identical to one being imported already exists
in the Dataset, these labels will be appended to those of
the existing DataItem, and if a label with an identical key
was imported before, the old label value will be
overwritten. If two DataItems are identical within the same
import data operation, their labels will be combined, and if
a key collision happens in this case, one of the values will
be picked at random. Two DataItems are considered identical
if their content bytes are identical (e.g. image bytes or
pdf bytes). These labels will be overridden by Annotation
labels specified inside the index file referenced by
``import_schema_uri``,
e.g. jsonl file.
Returns:
datasource (Datasource)
Raises:
ValueError: When any of the below scenarios happen:
- import_schema_uri is specified when creating a TabularDatasource
- either import_schema_uri or gcs_source is missing when creating a NonTabularDatasourceImportable
"""
if metadata_schema_uri == schema.dataset.metadata.tabular:
if import_schema_uri:
raise ValueError("tabular dataset does not support data import.")
return TabularDatasource(gcs_source, bq_source)
if metadata_schema_uri == schema.dataset.metadata.time_series:
if import_schema_uri:
raise ValueError("time series dataset does not support data import.")
return TabularDatasource(gcs_source, bq_source)
if not import_schema_uri and not gcs_source:
return NonTabularDatasource()
elif import_schema_uri and gcs_source:
return NonTabularDatasourceImportable(
gcs_source, import_schema_uri, data_item_labels
)
else:
raise ValueError(
"nontabular dataset requires both import_schema_uri and gcs_source for data import."
)
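# Illustrative usage sketch (not part of the upstream module): how the factory
# above dispatches on the metadata schema. The project, bucket, and any
# `schema` attribute paths beyond the two referenced above are placeholder
# assumptions.
#
#   tabular = create_datasource(
#       metadata_schema_uri=schema.dataset.metadata.tabular,
#       bq_source="bq://my-project.my_dataset.my_table",
#   )
#   tabular.dataset_metadata
#   # -> {"inputConfig": {"bigquerySource": {"uri": "bq://my-project.my_dataset.my_table"}}}
#
#   nontabular = create_datasource(
#       metadata_schema_uri=schema.dataset.metadata.image,
#       gcs_source="gs://my-bucket/annotations.jsonl",
#       import_schema_uri=schema.dataset.ioformat.image.single_label_classification,
#   )
#   # -> NonTabularDatasourceImportable carrying an ImportDataConfig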

View File

@@ -0,0 +1,261 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import csv
import logging
from typing import List, Optional, Set, TYPE_CHECKING
from google.auth import credentials as auth_credentials
from google.cloud import storage
from google.cloud.aiplatform import utils
from google.cloud.aiplatform import datasets
if TYPE_CHECKING:
from google.cloud import bigquery
class _ColumnNamesDataset(datasets._Dataset):
@property
def column_names(self) -> List[str]:
"""Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
Google BigQuery source.
Returns:
List[str]
A list of column names
Raises:
RuntimeError: When no valid source is found.
"""
self._assert_gca_resource_is_available()
metadata = self._gca_resource.metadata
if metadata is None:
raise RuntimeError("No metadata found for dataset")
input_config = metadata.get("inputConfig")
if input_config is None:
raise RuntimeError("No inputConfig found for dataset")
gcs_source = input_config.get("gcsSource")
bq_source = input_config.get("bigquerySource")
if gcs_source:
gcs_source_uris = gcs_source.get("uri")
if gcs_source_uris and len(gcs_source_uris) > 0:
# Lexicographically sort the files
gcs_source_uris.sort()
# Get the first file in sorted list
# TODO(b/193044977): Return as Set instead of List
return list(
self._retrieve_gcs_source_columns(
project=self.project,
gcs_csv_file_path=gcs_source_uris[0],
credentials=self.credentials,
)
)
elif bq_source:
bq_table_uri = bq_source.get("uri")
if bq_table_uri:
# TODO(b/193044977): Return as Set instead of List
return list(
self._retrieve_bq_source_columns(
project=self.project,
bq_table_uri=bq_table_uri,
credentials=self.credentials,
)
)
raise RuntimeError("No valid CSV or BigQuery datasource found.")
@staticmethod
def _retrieve_gcs_source_columns(
project: str,
gcs_csv_file_path: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Set[str]:
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
Example Usage:
column_names = _retrieve_gcs_source_columns(
"project_id",
"gs://example-bucket/path/to/csv_file"
)
# column_names = {"column_1", "column_2"}
Args:
project (str):
Required. Project to initiate the Google Cloud Storage client with.
gcs_csv_file_path (str):
Required. A full path to a CSV file stored on Google Cloud Storage.
Must include "gs://" prefix.
credentials (auth_credentials.Credentials):
Credentials to use to with GCS Client.
Returns:
Set[str]
A set of column names in the CSV file.
Raises:
RuntimeError: When the retrieved CSV file is invalid.
"""
gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
gcs_csv_file_path
)
client = storage.Client(project=project, credentials=credentials)
bucket = client.bucket(gcs_bucket)
blob = bucket.blob(gcs_blob)
# Incrementally download the CSV file until the header is retrieved
first_new_line_index = -1
start_index = 0
increment = 1000
line = ""
try:
logger = logging.getLogger("google.resumable_media._helpers")
logging_warning_filter = utils.LoggingFilter(logging.INFO)
logger.addFilter(logging_warning_filter)
while first_new_line_index == -1:
line += blob.download_as_bytes(
start=start_index, end=start_index + increment - 1
).decode("utf-8")
first_new_line_index = line.find("\n")
start_index += increment
header_line = line[:first_new_line_index]
# Split to make it an iterable
header_line = header_line.split("\n")[:1]
csv_reader = csv.reader(header_line, delimiter=",")
except (ValueError, RuntimeError) as err:
raise RuntimeError(
"There was a problem extracting the headers from the CSV file at '{}': {}".format(
gcs_csv_file_path, err
)
) from err
finally:
logger.removeFilter(logging_warning_filter)
return set(next(csv_reader))
@staticmethod
def _get_bq_schema_field_names_recursively(
schema_field: "bigquery.SchemaField",
) -> Set[str]:
"""Retrieve the name for a schema field along with ancestor fields.
Nested schema fields are flattened and concatenated with a ".".
Schema fields with child fields are not included, but the children are.
Args:
schema_field (bigquery.SchemaField):
Required. The schema field to retrieve flattened field names for.
Returns:
Set[str]
A set of flattened field names for the given schema field and
its nested fields.
"""
ancestor_names = {
nested_field_name
for field in schema_field.fields
for nested_field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively(
field
)
}
# Only return "leaf nodes", basically any field that doesn't have children
if len(ancestor_names) == 0:
return {schema_field.name}
else:
return {f"{schema_field.name}.{name}" for name in ancestor_names}
@staticmethod
def _retrieve_bq_source_columns(
project: str,
bq_table_uri: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Set[str]:
"""Retrieve the column names from a table on Google BigQuery
Nested schema fields are flattened and concatenated with a ".".
Schema fields with child fields are not included, but the children are.
Example Usage:
column_names = _retrieve_bq_source_columns(
"project_id",
"bq://project_id.dataset.table"
)
# column_names = {"column_1", "column_2", "column_3.nested_field"}
Args:
project (str):
Required. Project to initiate the BigQuery client with.
bq_table_uri (str):
Required. A URI to a BigQuery table.
Can include "bq://" prefix but not required.
credentials (auth_credentials.Credentials):
Credentials to use with BQ Client.
Returns:
Set[str]
A set of column names in the BigQuery table.
"""
# Remove bq:// prefix
prefix = "bq://"
if bq_table_uri.startswith(prefix):
bq_table_uri = bq_table_uri[len(prefix) :]
# The colon-based "project:dataset.table" format is no longer supported:
# Invalid dataset ID "bigquery-public-data:chicago_taxi_trips".
# Dataset IDs must be alphanumeric (plus underscores and dashes) and must be at most 1024 characters long.
# Using dot-based "project.dataset.table" format instead.
bq_table_uri = bq_table_uri.replace(":", ".")
# Loading bigquery lazily to avoid auto-loading it when importing vertexai
from google.cloud import bigquery # pylint: disable=g-import-not-at-top
client = bigquery.Client(project=project, credentials=credentials)
table = client.get_table(bq_table_uri)
schema = table.schema
return {
field_name
for field in schema
for field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively(
field
)
}

View File

@@ -0,0 +1,927 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from google.api_core import operation
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.services import dataset_service_client
from google.cloud.aiplatform.compat.types import (
dataset as gca_dataset,
dataset_service as gca_dataset_service,
encryption_spec as gca_encryption_spec,
io as gca_io,
)
from google.cloud.aiplatform.datasets import _datasources
from google.protobuf import field_mask_pb2
from google.protobuf import json_format
_LOGGER = base.Logger(__name__)
class _Dataset(base.VertexAiResourceNounWithFutureManager):
"""Managed dataset resource for Vertex AI."""
client_class = utils.DatasetClientWithOverride
_resource_noun = "datasets"
_getter_method = "get_dataset"
_list_method = "list_datasets"
_delete_method = "delete_dataset"
_parse_resource_name_method = "parse_dataset_path"
_format_resource_name_method = "dataset_path"
_supported_metadata_schema_uris: Tuple[str] = ()
def __init__(
self,
dataset_name: str,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Retrieves an existing managed dataset given a dataset name or ID.
Args:
dataset_name (str):
Required. A fully-qualified dataset resource name or dataset ID.
Example: "projects/123/locations/us-central1/datasets/456" or
"456" when project and location are initialized or passed.
project (str):
Optional project to retrieve dataset from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional location to retrieve dataset from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Custom credentials to use to retrieve this Dataset. Overrides
credentials set in aiplatform.init.
"""
super().__init__(
project=project,
location=location,
credentials=credentials,
resource_name=dataset_name,
)
self._gca_resource = self._get_gca_resource(resource_name=dataset_name)
self._validate_metadata_schema_uri()
@property
def metadata_schema_uri(self) -> str:
"""The metadata schema uri of this dataset resource."""
self._assert_gca_resource_is_available()
return self._gca_resource.metadata_schema_uri
def _validate_metadata_schema_uri(self) -> None:
"""Validate the metadata_schema_uri of retrieved dataset resource.
Raises:
ValueError: If the dataset type of the retrieved dataset resource is
not supported by the class.
"""
if self._supported_metadata_schema_uris and (
self.metadata_schema_uri not in self._supported_metadata_schema_uris
):
raise ValueError(
f"{self.__class__.__name__} class can not be used to retrieve "
f"dataset resource {self.resource_name}, check the dataset type"
)
@classmethod
def create(
cls,
# TODO(b/223262536): Make the display_name parameter optional in the next major release
display_name: str,
metadata_schema_uri: str,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
bq_source: Optional[str] = None,
import_schema_uri: Optional[str] = None,
data_item_labels: Optional[Dict] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
) -> "_Dataset":
"""Creates a new dataset and optionally imports data into dataset when
source and import_schema_uri are passed.
Args:
display_name (str):
Required. The user-defined name of the Dataset.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
metadata_schema_uri (str):
Required. Points to a YAML file stored on Google Cloud Storage
describing additional information about the Dataset. The schema
is defined as an OpenAPI 3.0.2 Schema Object. The schema files
that can be used here are found in gs://google-cloud-
aiplatform/schema/dataset/metadata/.
gcs_source (Union[str, Sequence[str]]):
Google Cloud Storage URI(-s) to the
input file(s). May contain wildcards. For more
information on wildcards, see
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
bq_source (str):
BigQuery URI to the input table.
example:
"bq://project.dataset.table_name"
import_schema_uri (str):
Points to a YAML file stored on Google Cloud
Storage describing the import format. Validation will be
done against the schema. The schema is defined as an
`OpenAPI 3.0.2 Schema
Object <https://tinyurl.com/y538mdwt>`__.
data_item_labels (Dict):
Labels that will be applied to newly imported DataItems. If
a DataItem identical to one being imported already exists
in the Dataset, these labels will be appended to those of
the existing DataItem, and if a label with an identical key
was imported before, the old label value will be
overwritten. If two DataItems are identical within the same
import data operation, their labels will be combined, and if
a key collision happens in this case, one of the values will
be picked at random. Two DataItems are considered identical
if their content bytes are identical (e.g. image bytes or
pdf bytes). These labels will be overridden by Annotation
labels specified inside the index file referenced by
``import_schema_uri``,
e.g. jsonl file.
This arg is not for specifying the annotation name or the
training target of your data, but for some global labels of
the dataset. E.g.,
'data_item_labels={"aiplatform.googleapis.com/ml_use":"training"}'
specifies that all the uploaded data are used for training.
project (str):
Project to upload this dataset to. Overrides project set in
aiplatform.init.
location (str):
Location to upload this dataset to. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials to use to upload this dataset. Overrides
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Strings which should be sent along with the request as metadata.
labels (Dict[str, str]):
Optional. Labels with user-defined metadata to organize your datasets.
Label keys and values can be no longer than 64 characters
(Unicode codepoints), can only contain lowercase letters, numeric
characters, underscores and dashes. International characters are allowed.
No more than 64 user labels can be associated with one Dataset
(System labels are excluded).
See https://goo.gl/xmQnxf for more information and examples of labels.
System reserved label keys are prefixed with "aiplatform.googleapis.com/"
and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the dataset. Has the
form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
Returns:
dataset (Dataset):
Instantiated representation of the managed dataset resource.
"""
if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
datasource = _datasources.create_datasource(
metadata_schema_uri=metadata_schema_uri,
import_schema_uri=import_schema_uri,
gcs_source=gcs_source,
bq_source=bq_source,
data_item_labels=data_item_labels,
)
return cls._create_and_import(
api_client=api_client,
parent=initializer.global_config.common_location_path(
project=project, location=location
),
display_name=display_name,
metadata_schema_uri=metadata_schema_uri,
datasource=datasource,
project=project or initializer.global_config.project,
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
create_request_timeout=create_request_timeout,
)
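# Illustrative usage sketch (not part of the upstream module): a typical call
# into a concrete subclass of this base class; project, bucket, and the schema
# constant shown here are placeholder assumptions based on public SDK usage.
#
#   from google.cloud import aiplatform
#
#   aiplatform.init(project="my-project", location="us-central1")
#   ds = aiplatform.ImageDataset.create(
#       display_name="flowers",
#       gcs_source="gs://my-bucket/flowers/annotations.jsonl",
#       import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
#   )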
@classmethod
@base.optional_sync()
def _create_and_import(
cls,
api_client: dataset_service_client.DatasetServiceClient,
parent: str,
display_name: str,
metadata_schema_uri: str,
datasource: _datasources.Datasource,
project: str,
location: str,
credentials: Optional[auth_credentials.Credentials],
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
labels: Optional[Dict[str, str]] = None,
encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
import_request_timeout: Optional[float] = None,
) -> "_Dataset":
"""Creates a new dataset and optionally imports data into dataset when
source and import_schema_uri are passed.
Args:
api_client (dataset_service_client.DatasetServiceClient):
An instance of DatasetServiceClient with the correct api_endpoint
already set based on user's preferences.
parent (str):
Required. Also known as the common location path, which usually contains the
project and location that the user provided to the upstream method.
Example: "projects/my-prj/locations/us-central1"
display_name (str):
Required. The user-defined name of the Dataset.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
metadata_schema_uri (str):
Required. Points to a YAML file stored on Google Cloud Storage
describing additional information about the Dataset. The schema
is defined as an OpenAPI 3.0.2 Schema Object. The schema files
that can be used here are found in gs://google-cloud-
aiplatform/schema/dataset/metadata/.
datasource (_datasources.Datasource):
Required. Datasource for creating a dataset for Vertex AI.
project (str):
Required. Project to upload this dataset to. Overrides project set in
aiplatform.init.
location (str):
Required. Location to upload this dataset to. Overrides location set in
aiplatform.init.
credentials (Optional[auth_credentials.Credentials]):
Custom credentials to use to upload this dataset. Overrides
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Strings which should be sent along with the request as metadata.
labels (Dict[str, str]):
Optional. Labels with user-defined metadata to organize your datasets.
Label keys and values can be no longer than 64 characters
(Unicode codepoints), can only contain lowercase letters, numeric
characters, underscores and dashes. International characters are allowed.
No more than 64 user labels can be associated with one Dataset
(System labels are excluded).
See https://goo.gl/xmQnxf for more information and examples of labels.
System reserved label keys are prefixed with "aiplatform.googleapis.com/"
and are immutable.
encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]):
Optional. The Cloud KMS customer managed encryption key used to protect the dataset.
The key needs to be in the same region as where the compute
resource is created.
If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
import_request_timeout (float):
Optional. The timeout for the import request in seconds.
Returns:
dataset (Dataset):
Instantiated representation of the managed dataset resource.
"""
create_dataset_lro = cls._create(
api_client=api_client,
parent=parent,
display_name=display_name,
metadata_schema_uri=metadata_schema_uri,
datasource=datasource,
request_metadata=request_metadata,
labels=labels,
encryption_spec=encryption_spec,
create_request_timeout=create_request_timeout,
)
_LOGGER.log_create_with_lro(cls, create_dataset_lro)
created_dataset = create_dataset_lro.result(timeout=None)
_LOGGER.log_create_complete(cls, created_dataset, "ds")
dataset_obj = cls(
dataset_name=created_dataset.name,
project=project,
location=location,
credentials=credentials,
)
# Import if import datasource is DatasourceImportable
if isinstance(datasource, _datasources.DatasourceImportable):
dataset_obj._import_and_wait(
datasource, import_request_timeout=import_request_timeout
)
return dataset_obj
def _import_and_wait(
self,
datasource,
import_request_timeout: Optional[float] = None,
):
_LOGGER.log_action_start_against_resource(
"Importing",
"data",
self,
)
import_lro = self._import(
datasource=datasource, import_request_timeout=import_request_timeout
)
_LOGGER.log_action_started_against_resource_with_lro(
"Import", "data", self.__class__, import_lro
)
import_lro.result(timeout=None)
_LOGGER.log_action_completed_against_resource("data", "imported", self)
@classmethod
def _create(
cls,
api_client: dataset_service_client.DatasetServiceClient,
parent: str,
display_name: str,
metadata_schema_uri: str,
datasource: _datasources.Datasource,
request_metadata: Sequence[Tuple[str, str]] = (),
labels: Optional[Dict[str, str]] = None,
encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None,
create_request_timeout: Optional[float] = None,
) -> operation.Operation:
"""Creates a new managed dataset by directly calling API client.
Args:
api_client (dataset_service_client.DatasetServiceClient):
An instance of DatasetServiceClient with the correct api_endpoint
already set based on user's preferences.
parent (str):
Required. Also known as the common location path, which usually contains the
project and location that the user provided to the upstream method.
Example: "projects/my-prj/locations/us-central1"
display_name (str):
Required. The user-defined name of the Dataset.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
metadata_schema_uri (str):
Required. Points to a YAML file stored on Google Cloud Storage
describing additional information about the Dataset. The schema
is defined as an OpenAPI 3.0.2 Schema Object. The schema files
that can be used here are found in gs://google-cloud-
aiplatform/schema/dataset/metadata/.
datasource (_datasources.Datasource):
Required. Datasource for creating a dataset for Vertex AI.
request_metadata (Sequence[Tuple[str, str]]):
Strings which should be sent along with the create_dataset
request as metadata. Usually to specify special dataset config.
labels (Dict[str, str]):
Optional. Labels with user-defined metadata to organize your datasets.
Label keys and values can be no longer than 64 characters
(Unicode codepoints), can only contain lowercase letters, numeric
characters, underscores and dashes. International characters are allowed.
No more than 64 user labels can be associated with one dataset
(System labels are excluded).
See https://goo.gl/xmQnxf for more information and examples of labels.
System reserved label keys are prefixed with "aiplatform.googleapis.com/"
and are immutable.
encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]):
Optional. The Cloud KMS customer managed encryption key used to protect the dataset.
The key needs to be in the same region as where the compute
resource is created.
If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
Returns:
operation (Operation):
An object representing a long-running operation.
"""
gapic_dataset = gca_dataset.Dataset(
display_name=display_name,
metadata_schema_uri=metadata_schema_uri,
metadata=datasource.dataset_metadata,
labels=labels,
encryption_spec=encryption_spec,
)
return api_client.create_dataset(
parent=parent,
dataset=gapic_dataset,
metadata=request_metadata,
timeout=create_request_timeout,
)
def _import(
self,
datasource: _datasources.DatasourceImportable,
import_request_timeout: Optional[float] = None,
) -> operation.Operation:
"""Imports data into managed dataset by directly calling API client.
Args:
datasource (_datasources.DatasourceImportable):
Required. Datasource for importing data to an existing dataset for Vertex AI.
import_request_timeout (float):
Optional. The timeout for the import request in seconds.
Returns:
operation (Operation):
An object representing a long-running operation.
"""
return self.api_client.import_data(
name=self.resource_name,
import_configs=[datasource.import_data_config],
timeout=import_request_timeout,
)
@base.optional_sync(return_input_arg="self")
def import_data(
self,
gcs_source: Union[str, Sequence[str]],
import_schema_uri: str,
data_item_labels: Optional[Dict] = None,
sync: bool = True,
import_request_timeout: Optional[float] = None,
) -> "_Dataset":
"""Upload data to existing managed dataset.
Args:
gcs_source (Union[str, Sequence[str]]):
Required. Google Cloud Storage URI(-s) to the
input file(s). May contain wildcards. For more
information on wildcards, see
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
import_schema_uri (str):
Required. Points to a YAML file stored on Google Cloud
Storage describing the import format. Validation will be
done against the schema. The schema is defined as an
`OpenAPI 3.0.2 Schema
Object <https://tinyurl.com/y538mdwt>`__.
data_item_labels (Dict):
Labels that will be applied to newly imported DataItems. If
a DataItem identical to one being imported already exists
in the Dataset, then these labels will be appended to those
of the already existing one, and if a label with an identical
key was imported before, the old label value will be
overwritten. If two DataItems are identical in the same
import data operation, the labels will be combined and if
key collision happens in this case, one of the values will
be picked randomly. Two DataItems are considered identical
if their content bytes are identical (e.g. image bytes or
pdf bytes). These labels will be overridden by Annotation
labels specified inside index file referenced by
``import_schema_uri``,
e.g. jsonl file.
This arg is not for specifying the annotation name or the
training target of your data, but for some global labels of
the dataset. E.g.,
'data_item_labels={"aiplatform.googleapis.com/ml_use":"training"}'
specifies that all the uploaded data are used for training.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in a concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
import_request_timeout (float):
Optional. The timeout for the import request in seconds.
Returns:
dataset (Dataset):
Instantiated representation of the managed dataset resource.
"""
datasource = _datasources.create_datasource(
metadata_schema_uri=self.metadata_schema_uri,
import_schema_uri=import_schema_uri,
gcs_source=gcs_source,
data_item_labels=data_item_labels,
)
self._import_and_wait(
datasource=datasource, import_request_timeout=import_request_timeout
)
return self
def _validate_and_convert_export_split(
self,
split: Union[Dict[str, str], Dict[str, float]],
) -> Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]:
"""
Validates the split for data export. Valid splits are dicts
encoding the contents of proto messages ExportFilterSplit or
ExportFractionSplit. If the split is valid, this function returns
the corresponding converted proto message.
split (Union[Dict[str, str], Dict[str, float]]):
Instructions for how the exported data should be split between the
training, validation and test sets.
"""
if len(split) != 3:
raise ValueError(
"The provided split for data export does not provide enough"
"information. It must have three fields, mapping to training,"
"validation and test splits respectively."
)
if not ("training_filter" in split or "training_fraction" in split):
raise ValueError(
"The provided filter for data export does not provide enough"
"information. It must have three fields, mapping to training,"
"validation and test respectively."
)
if "training_filter" in split:
if (
"validation_filter" in split
and "test_filter" in split
and isinstance(split["training_filter"], str)
and isinstance(split["validation_filter"], str)
and isinstance(split["test_filter"], str)
):
return gca_dataset.ExportFilterSplit(
training_filter=split["training_filter"],
validation_filter=split["validation_filter"],
test_filter=split["test_filter"],
)
else:
raise ValueError(
"The provided ExportFilterSplit does not contain all"
"three required fields: training_filter, "
"validation_filter and test_filter."
)
else:
if (
"validation_fraction" in split
and "test_fraction" in split
and isinstance(split["training_fraction"], float)
and isinstance(split["validation_fraction"], float)
and isinstance(split["test_fraction"], float)
):
return gca_dataset.ExportFractionSplit(
training_fraction=split["training_fraction"],
validation_fraction=split["validation_fraction"],
test_fraction=split["test_fraction"],
)
else:
raise ValueError(
"The provided ExportFractionSplit does not contain all"
"three required fields: training_fraction, "
"validation_fraction and test_fraction."
)
def _get_completed_export_data_operation(
self,
output_dir: str,
export_use: Optional[gca_dataset.ExportDataConfig.ExportUse] = None,
annotation_filter: Optional[str] = None,
saved_query_id: Optional[str] = None,
annotation_schema_uri: Optional[str] = None,
split: Optional[
Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]
] = None,
) -> gca_dataset_service.ExportDataResponse:
self.wait()
# TODO(b/171311614): Add support for BigQuery export path
export_data_config = gca_dataset.ExportDataConfig(
gcs_destination=gca_io.GcsDestination(output_uri_prefix=output_dir)
)
if export_use is not None:
export_data_config.export_use = export_use
if annotation_filter is not None:
export_data_config.annotation_filter = annotation_filter
if saved_query_id is not None:
export_data_config.saved_query_id = saved_query_id
if annotation_schema_uri is not None:
export_data_config.annotation_schema_uri = annotation_schema_uri
if split is not None:
if isinstance(split, gca_dataset.ExportFilterSplit):
export_data_config.filter_split = split
elif isinstance(split, gca_dataset.ExportFractionSplit):
export_data_config.fraction_split = split
_LOGGER.log_action_start_against_resource("Exporting", "data", self)
export_lro = self.api_client.export_data(
name=self.resource_name, export_config=export_data_config
)
_LOGGER.log_action_started_against_resource_with_lro(
"Export", "data", self.__class__, export_lro
)
export_data_response = export_lro.result()
_LOGGER.log_action_completed_against_resource("data", "export", self)
return export_data_response
# TODO(b/174751568) add optional sync support
def export_data(self, output_dir: str) -> Sequence[str]:
"""Exports data to output dir to GCS.
Args:
output_dir (str):
Required. The Google Cloud Storage location where the output is to
be written to. In the given directory a new directory will be
created with name:
``export-data-<dataset-display-name>-<timestamp-of-export-call>``
where timestamp is in YYYYMMDDHHMMSS format. All export
output will be written into that directory. Inside that
directory, annotations with the same schema will be grouped
into sub directories which are named with the corresponding
annotations' schema title. Inside these sub directories, a
schema.yaml will be created to describe the output format.
If the uri doesn't end with '/', a '/' will be automatically
appended. The directory is created if it doesn't exist.
Returns:
exported_files (Sequence[str]):
All of the files that are exported in this export operation.
"""
return self._get_completed_export_data_operation(output_dir).exported_files
def export_data_for_custom_training(
self,
output_dir: str,
annotation_filter: Optional[str] = None,
saved_query_id: Optional[str] = None,
annotation_schema_uri: Optional[str] = None,
split: Optional[Union[Dict[str, str], Dict[str, float]]] = None,
) -> Dict[str, Any]:
"""Exports data to output dir to GCS for custom training use case.
Example annotation_schema_uri (image classification):
gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml
Example split (filter split):
{
"training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
"validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
"test_filter": "labels.aiplatform.googleapis.com/ml_use=test",
}
Example split (fraction split):
{
"training_fraction": 0.7,
"validation_fraction": 0.2,
"test_fraction": 0.1,
}
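Example call (a minimal sketch; ``my_dataset`` and the bucket path are
illustrative placeholders, and the split mirrors the fraction split
example above):
    export_response = my_dataset.export_data_for_custom_training(
        output_dir='gs://my-bucket/exports',
        split={
            "training_fraction": 0.7,
            "validation_fraction": 0.2,
            "test_fraction": 0.1,
        },
    )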
Args:
output_dir (str):
Required. The Google Cloud Storage location where the output is to
be written to. In the given directory a new directory will be
created with name:
``export-data-<dataset-display-name>-<timestamp-of-export-call>``
where timestamp is in YYYYMMDDHHMMSS format. All export
output will be written into that directory. Inside that
directory, annotations with the same schema will be grouped
into sub directories which are named with the corresponding
annotations' schema title. Inside these sub directories, a
schema.yaml will be created to describe the output format.
If the uri doesn't end with '/', a '/' will be automatically
appended. The directory is created if it doesn't exist.
annotation_filter (str):
Optional. An expression for filtering what part of the Dataset
is to be exported.
Only Annotations that match this filter will be exported.
The filter syntax is the same as in
[ListAnnotations][DatasetService.ListAnnotations].
saved_query_id (str):
Optional. The ID of a SavedQuery (annotation set) under this
Dataset used for filtering Annotations for training.
Only used for custom training data export use cases.
Only applicable to Datasets that have SavedQueries.
Only Annotations that are associated with this SavedQuery are
used in training. When used in conjunction with
annotation_filter, the Annotations used for training are
filtered by both saved_query_id and annotation_filter.
Only one of saved_query_id and annotation_schema_uri should be
specified as both of them represent the same thing: problem
type.
annotation_schema_uri (str):
Optional. The Cloud Storage URI that points to a YAML file
describing the annotation schema. The schema is defined as an
OpenAPI 3.0.2 Schema Object. The schema files that can be used
here are found in
gs://google-cloud-aiplatform/schema/dataset/annotation/, note
that the chosen schema must be consistent with
metadata_schema_uri of this Dataset.
Only used for custom training data export use cases.
Only applicable to Datasets that have DataItems and
Annotations.
Only Annotations that both match this schema and belong to
DataItems not ignored by the split method are used in the
training, validation, or test role, respectively, depending on
the role of the DataItem they are on.
When used in conjunction with annotation_filter, the
Annotations used for training are filtered by both
annotation_filter and annotation_schema_uri.
split (Union[Dict[str, str], Dict[str, float]]):
Instructions for how the exported data should be split between the
training, validation and test sets.
Returns:
export_data_response (Dict):
Response message for DatasetService.ExportData in Dictionary
format.
"""
split = self._validate_and_convert_export_split(split)
return json_format.MessageToDict(
self._get_completed_export_data_operation(
output_dir,
gca_dataset.ExportDataConfig.ExportUse.CUSTOM_CODE_TRAINING,
annotation_filter,
saved_query_id,
annotation_schema_uri,
split,
)._pb
)
def update(
self,
*,
display_name: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
description: Optional[str] = None,
update_request_timeout: Optional[float] = None,
) -> "_Dataset":
"""Update the dataset.
Updatable fields:
- ``display_name``
- ``description``
- ``labels``
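Example usage (a minimal sketch; ``my_dataset`` and the new values are
illustrative placeholders):
    my_dataset = my_dataset.update(
        display_name='my-renamed-dataset',
        description='My updated description.',
    )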
Args:
display_name (str):
Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
labels (Dict[str, str]):
Optional. Labels with user-defined metadata to organize your datasets.
Label keys and values can be no longer than 64 characters
(Unicode codepoints), can only contain lowercase letters, numeric
characters, underscores and dashes. International characters are allowed.
No more than 64 user labels can be associated with one dataset
(System labels are excluded).
See https://goo.gl/xmQnxf for more information and examples of labels.
System reserved label keys are prefixed with "aiplatform.googleapis.com/"
and are immutable.
description (str):
Optional. The description of the Dataset.
update_request_timeout (float):
Optional. The timeout for the update request in seconds.
Returns:
dataset (Dataset):
Updated dataset.
"""
update_mask = field_mask_pb2.FieldMask()
if display_name:
update_mask.paths.append("display_name")
if labels:
update_mask.paths.append("labels")
if description:
update_mask.paths.append("description")
update_dataset = gca_dataset.Dataset(
name=self.resource_name,
display_name=display_name,
description=description,
labels=labels,
)
self._gca_resource = self.api_client.update_dataset(
dataset=update_dataset,
update_mask=update_mask,
timeout=update_request_timeout,
)
return self
@classmethod
def list(
cls,
filter: Optional[str] = None,
order_by: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List[base.VertexAiResourceNoun]:
"""List all instances of this Dataset resource.
Example Usage:
aiplatform.TabularDataset.list(
filter='labels.my_key="my_value"',
order_by='display_name'
)
Args:
filter (str):
Optional. An expression for filtering the results of the request.
For field names both snake_case and camelCase are supported.
order_by (str):
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve list from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve list. Overrides
credentials set in aiplatform.init.
Returns:
List[base.VertexAiResourceNoun] - A list of Dataset resource objects
"""
dataset_subclass_filter = (
lambda gapic_obj: gapic_obj.metadata_schema_uri
in cls._supported_metadata_schema_uris
)
return cls._list_with_local_order(
cls_filter=dataset_subclass_filter,
filter=filter,
order_by=order_by,
project=project,
location=location,
credentials=credentials,
)

View File

@@ -0,0 +1,198 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Optional, Sequence, Tuple, Union
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils
class ImageDataset(datasets._Dataset):
"""A managed image dataset resource for Vertex AI.
Use this class to work with a managed image dataset. To create a managed
image dataset, you need a datasource file in CSV format and a schema file in
YAML format. A schema is optional for a custom model. You put the CSV file
and the schema into Cloud Storage buckets.
Use image data for the following objectives:
* Single-label classification. For more information, see
[Prepare image training data for single-label classification](https://cloud.google.com/vertex-ai/docs/image-data/classification/prepare-data#single-label-classification).
* Multi-label classification. For more information, see [Prepare image training data for multi-label classification](https://cloud.google.com/vertex-ai/docs/image-data/classification/prepare-data#multi-label-classification).
* Object detection. For more information, see [Prepare image training data
for object detection](https://cloud.google.com/vertex-ai/docs/image-data/object-detection/prepare-data).
The following code shows you how to create an image dataset by importing data from
a CSV datasource file and a YAML schema file. The schema file you use
depends on whether your image dataset is used for single-label
classification, multi-label classification, or object detection.
```py
my_dataset = aiplatform.ImageDataset.create(
display_name="my-image-dataset",
gcs_source=['gs://path/to/my/image-dataset.csv'],
import_schema_uri='gs://path/to/my/schema.yaml'
)
```
"""
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
schema.dataset.metadata.image,
)
@classmethod
def create(
cls,
display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
import_schema_uri: Optional[str] = None,
data_item_labels: Optional[Dict] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
) -> "ImageDataset":
"""Creates a new image dataset.
Optionally imports data into the dataset when a source and
`import_schema_uri` are passed in.
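The following is a minimal sketch of creating an empty image dataset and
importing data into it later; the display name, bucket path, and schema
URI are illustrative placeholders:
```py
my_dataset = aiplatform.ImageDataset.create(display_name="my-image-dataset")
my_dataset.import_data(
    gcs_source=['gs://path/to/my/image-dataset.csv'],
    import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
)
```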
Args:
display_name (str):
Optional. The user-defined name of the dataset. The name must
contain 128 or fewer UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Optional. The URI to one or more Google Cloud Storage buckets
that contain your datasets. For example, `str:
"gs://bucket/file.csv"` or `Sequence[str]:
["gs://bucket/file1.csv", "gs://bucket/file2.csv"]`.
import_schema_uri (str):
Optional. A URI for a YAML file stored in Cloud Storage that
describes the import schema used to validate the
dataset. The schema is an
[OpenAPI 3.0.2 Schema](https://tinyurl.com/y538mdwt) object.
data_item_labels (Dict):
Optional. A dictionary of label information. Each dictionary
item contains a label and a label key. Each image in the dataset
includes one dictionary of label information. If a data item is
added or merged into a dataset, and that data item contains an
image that's identical to an image that's already in the
dataset, then the data items are merged. If two identical labels
are detected during the merge, each with a different label key,
then one of the label and label key dictionary items is randomly
chosen to be included in the merged data item. Images and documents are
compared using their binary data (bytes), not their content.
If annotation labels are referenced in a schema specified by the
`import_schema_uri` parameter, then the labels in the
`data_item_labels` dictionary are overridden by the annotations.
project (str):
Optional. The name of the Google Cloud project to which this
`ImageDataset` is uploaded. This overrides the project that
was set by `aiplatform.init`.
location (str):
Optional. The Google Cloud region where this dataset is uploaded. This
region overrides the region that was set by `aiplatform.init`.
credentials (auth_credentials.Credentials):
Optional. The credentials that are used to upload the
`ImageDataset`. These credentials override the credentials set
by `aiplatform.init`.
request_metadata (Sequence[Tuple[str, str]]):
Optional. Strings that contain metadata that's sent with the request.
labels (Dict[str, str]):
Optional. Labels with user-defined metadata to organize your
Vertex AI datasets. The maximum length of a key and of a
value is 64 unicode characters. Labels and keys can contain only
lowercase letters, numeric characters, underscores, and dashes.
International characters are allowed. No more than 64 user
labels can be associated with one dataset (system labels are
excluded). For more information and examples of using labels, see
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
System reserved label keys are prefixed with
`aiplatform.googleapis.com/` and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key that's used to protect the dataset. The
format of the key is
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
The key needs to be in the same region as where the compute
resource is created.
If `encryption_spec_key_name` is set, this image dataset and
all of its sub-resources are secured by this key.
This `encryption_spec_key_name` overrides the
`encryption_spec_key_name` set by `aiplatform.init`.
sync (bool):
If `true`, the `create` method creates an image dataset
synchronously. If `false`, the `create` method creates an image
dataset asynchronously.
create_request_timeout (float):
Optional. The number of seconds for the timeout of the create
request.
Returns:
image_dataset (ImageDataset):
An instantiated representation of the managed `ImageDataset`
resource.
"""
if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
metadata_schema_uri = schema.dataset.metadata.image
datasource = _datasources.create_datasource(
metadata_schema_uri=metadata_schema_uri,
import_schema_uri=import_schema_uri,
gcs_source=gcs_source,
data_item_labels=data_item_labels,
)
return cls._create_and_import(
api_client=api_client,
parent=initializer.global_config.common_location_path(
project=project, location=location
),
display_name=display_name,
metadata_schema_uri=metadata_schema_uri,
datasource=datasource,
project=project or initializer.global_config.project,
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
create_request_timeout=create_request_timeout,
)

View File

@@ -0,0 +1,318 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Optional, Sequence, Tuple, Union, TYPE_CHECKING
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils
if TYPE_CHECKING:
from google.cloud import bigquery
_AUTOML_TRAINING_MIN_ROWS = 1000
_LOGGER = base.Logger(__name__)
class TabularDataset(datasets._ColumnNamesDataset):
"""A managed tabular dataset resource for Vertex AI.
Use this class to work with tabular datasets. You can use a CSV file, BigQuery, or a pandas
[`DataFrame`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)
to create a tabular dataset. For more information about paging through
BigQuery data, see [Read data with BigQuery API using
pagination](https://cloud.google.com/bigquery/docs/paging-results). For more
information about tabular data, see [Tabular
data](https://cloud.google.com/vertex-ai/docs/training-overview#tabular_data).
The following code shows you how to create and import a tabular
dataset with a CSV file.
```py
my_dataset = aiplatform.TabularDataset.create(
display_name="my-dataset", gcs_source=['gs://path/to/my/dataset.csv'])
```
Unlike unstructured datasets, creating and importing a tabular dataset
is done in a single step.
If you create a tabular dataset with a pandas
[`DataFrame`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html),
you need to use a BigQuery table to stage the data for Vertex AI:
```py
my_dataset = aiplatform.TabularDataset.create_from_dataframe(
df_source=my_pandas_dataframe,
staging_path=f"bq://{bq_dataset_id}.table-unique"
)
```
"""
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
schema.dataset.metadata.tabular,
)
@classmethod
def create(
cls,
display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
bq_source: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
) -> "TabularDataset":
"""Creates a tabular dataset.
Args:
display_name (str):
Optional. The user-defined name of the dataset. The name must
contain 128 or fewer UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Optional. The URI to one or more Google Cloud Storage buckets that contain
your datasets. For example, `str: "gs://bucket/file.csv"` or
`Sequence[str]: ["gs://bucket/file1.csv",
"gs://bucket/file2.csv"]`. Either `gcs_source` or `bq_source` must be specified.
bq_source (str):
Optional. The URI to a BigQuery table that's used as an input source. For
example, `bq://project.dataset.table_name`. Either `gcs_source`
or `bq_source` must be specified.
project (str):
Optional. The name of the Google Cloud project to which this
`TabularDataset` is uploaded. This overrides the project that
was set by `aiplatform.init`.
location (str):
Optional. The Google Cloud region where this dataset is uploaded. This
region overrides the region that was set by `aiplatform.init`.
credentials (auth_credentials.Credentials):
Optional. The credentials that are used to upload the `TabularDataset`.
These credentials override the credentials set by
`aiplatform.init`.
request_metadata (Sequence[Tuple[str, str]]):
Optional. Strings that contain metadata that's sent with the request.
labels (Dict[str, str]):
Optional. Labels with user-defined metadata to organize your
Vertex AI datasets. The maximum length of a key and of a
value is 64 unicode characters. Labels and keys can contain only
lowercase letters, numeric characters, underscores, and dashes.
International characters are allowed. No more than 64 user
labels can be associated with one dataset (system labels are
excluded). For more information and examples of using labels, see
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
System reserved label keys are prefixed with
`aiplatform.googleapis.com/` and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key that's used to protect the dataset. The
format of the key is
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
The key needs to be in the same region as where the compute
resource is created.
If `encryption_spec_key_name` is set, this `TabularDataset` and
all of its sub-resources are secured by this key.
This `encryption_spec_key_name` overrides the
`encryption_spec_key_name` set by `aiplatform.init`.
sync (bool):
If `true`, the `create` method creates a tabular dataset
synchronously. If `false`, the `create` method creates a tabular
dataset asynchronously.
create_request_timeout (float):
Optional. The number of seconds for the timeout of the create
request.
Returns:
tabular_dataset (TabularDataset):
An instantiated representation of the managed `TabularDataset` resource.
"""
if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
metadata_schema_uri = schema.dataset.metadata.tabular
datasource = _datasources.create_datasource(
metadata_schema_uri=metadata_schema_uri,
gcs_source=gcs_source,
bq_source=bq_source,
)
return cls._create_and_import(
api_client=api_client,
parent=initializer.global_config.common_location_path(
project=project, location=location
),
display_name=display_name,
metadata_schema_uri=metadata_schema_uri,
datasource=datasource,
project=project or initializer.global_config.project,
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
create_request_timeout=create_request_timeout,
)
@classmethod
def create_from_dataframe(
cls,
df_source: "pd.DataFrame", # noqa: F821 - skip check for undefined name 'pd'
staging_path: str,
bq_schema: Optional[Union[str, "bigquery.SchemaField"]] = None,
display_name: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "TabularDataset":
"""Creates a new tabular dataset from a pandas `DataFrame`.
Args:
df_source (pd.DataFrame):
Required. A pandas
[`DataFrame`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)
containing the source data for ingestion as a `TabularDataset`.
This method uses the data types from the provided `DataFrame`
when the `TabularDataset` is created.
staging_path (str):
Required. The BigQuery table used to stage the data for Vertex
AI. Because Vertex AI maintains a reference to this source to
create the `TabularDataset`, you shouldn't delete this BigQuery
table. For example: `bq://my-project.my-dataset.my-table`.
If the specified BigQuery table doesn't exist, then the table is
created for you. If the provided BigQuery table already exists,
and the schemas of the BigQuery table and your DataFrame match,
then the data in your local `DataFrame` is appended to the table.
The location of the BigQuery table must conform to the
[BigQuery location requirements](https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations).
bq_schema (Optional[Union[str, bigquery.SchemaField]]):
Optional. If not set, BigQuery autodetects the schema using the
column types of your `DataFrame`. If set, BigQuery uses the
schema you provide when the staging table is created. For more
information,
see the BigQuery
[`LoadJobConfig.schema`](https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.LoadJobConfig#google_cloud_bigquery_job_LoadJobConfig_schema)
property.
display_name (str):
Optional. The user-defined name of the `Dataset`. The name must
contain 128 or fewer UTF-8 characters.
project (str):
Optional. The project to upload this dataset to. This overrides
the project set using `aiplatform.init`.
location (str):
Optional. The location to upload this dataset to. This overrides
the location set using `aiplatform.init`.
credentials (auth_credentials.Credentials):
Optional. The custom credentials used to upload this dataset.
This overrides credentials set using `aiplatform.init`.
Returns:
tabular_dataset (TabularDataset):
An instantiated representation of the managed `TabularDataset` resource.
"""
if staging_path.startswith("bq://"):
bq_staging_path = staging_path[len("bq://") :]
else:
raise ValueError(
"Only BigQuery staging paths are supported. Provide a staging path in the format `bq://your-project.your-dataset.your-table`."
)
try:
import pyarrow # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery'
except ImportError:
raise ImportError(
"Pyarrow is not installed, and is required to use the BigQuery client."
'Please install the SDK using "pip install google-cloud-aiplatform[datasets]"'
)
import pandas.api.types as pd_types
if any(
[
pd_types.is_datetime64_any_dtype(df_source[column])
for column in df_source.columns
]
):
_LOGGER.info(
"Received datetime-like column in the dataframe. Please note that the column could be interpreted differently in BigQuery depending on which major version you are using. For more information, please reference the BigQuery v3 release notes here: https://github.com/googleapis/python-bigquery/releases/tag/v3.0.0"
)
if len(df_source) < _AUTOML_TRAINING_MIN_ROWS:
_LOGGER.info(
"Your DataFrame has %s rows and AutoML requires %s rows to train on tabular data. You can still train a custom model once your dataset has been uploaded to Vertex, but you will not be able to use AutoML for training."
% (len(df_source), _AUTOML_TRAINING_MIN_ROWS),
)
# Loading bigquery lazily to avoid auto-loading it when importing vertexai
from google.cloud import bigquery # pylint: disable=g-import-not-at-top
bigquery_client = bigquery.Client(
project=project or initializer.global_config.project,
credentials=credentials or initializer.global_config.credentials,
)
try:
parquet_options = bigquery.format_options.ParquetOptions()
parquet_options.enable_list_inference = True
job_config = bigquery.LoadJobConfig(
source_format=bigquery.SourceFormat.PARQUET,
parquet_options=parquet_options,
)
if bq_schema:
job_config.schema = bq_schema
job = bigquery_client.load_table_from_dataframe(
dataframe=df_source, destination=bq_staging_path, job_config=job_config
)
job.result()
finally:
dataset_from_dataframe = cls.create(
display_name=display_name,
bq_source=staging_path,
project=project,
location=location,
credentials=credentials,
)
return dataset_from_dataframe
def import_data(self):
raise NotImplementedError(
f"{self.__class__.__name__} class does not support 'import_data'"
)

View File

@@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Optional, Sequence, Tuple, Union
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils
class TextDataset(datasets._Dataset):
"""A managed text dataset resource for Vertex AI.
Use this class to work with a managed text dataset. To create a managed
text dataset, you need a datasource file in CSV format and a schema file in
YAML format. A schema is optional for a custom model. The CSV file and the
schema are accessed in Cloud Storage buckets.
Use text data for the following objectives:
* Classification. For more information, see
[Prepare text training data for classification](https://cloud.google.com/vertex-ai/docs/text-data/classification/prepare-data).
* Entity extraction. For more information, see
[Prepare text training data for entity extraction](https://cloud.google.com/vertex-ai/docs/text-data/entity-extraction/prepare-data).
* Sentiment analysis. For more information, see
[Prepare text training data for sentiment analysis](https://cloud.google.com/vertex-ai/docs/text-data/sentiment-analysis/prepare-data).
The following code shows you how to create and import a text dataset with
a CSV datasource file and a YAML schema file. The schema file you use
depends on whether your text dataset is used for classification,
entity extraction, or sentiment analysis.
```py
my_dataset = aiplatform.TextDataset.create(
display_name="my-text-dataset",
gcs_source=['gs://path/to/my/text-dataset.csv'],
import_schema_uri='gs://path/to/my/schema.yaml',
)
```
"""
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
schema.dataset.metadata.text,
)
@classmethod
def create(
cls,
display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
import_schema_uri: Optional[str] = None,
data_item_labels: Optional[Dict] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
) -> "TextDataset":
"""Creates a new text dataset.
Optionally imports data into this dataset when a source and
`import_schema_uri` are passed in. The following is an example of how
this method is used:
```py
ds = aiplatform.TextDataset.create(
display_name='my-dataset',
gcs_source='gs://my-bucket/dataset.csv',
import_schema_uri=aiplatform.schema.dataset.ioformat.text.multi_label_classification
)
```
Args:
display_name (str):
Optional. The user-defined name of the dataset. The name must
contain 128 or fewer UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Optional. The URI to one or more Google Cloud Storage buckets
that contain your datasets. For example, `str:
"gs://bucket/file.csv"` or `Sequence[str]:
["gs://bucket/file1.csv", "gs://bucket/file2.csv"]`.
import_schema_uri (str):
Optional. A URI for a YAML file stored in Cloud Storage that
describes the import schema used to validate the
dataset. The schema is an
[OpenAPI 3.0.2 Schema](https://tinyurl.com/y538mdwt) object.
data_item_labels (Dict):
Optional. A dictionary of label information. Each dictionary
item contains a label and a label key. Each item in the dataset
includes one dictionary of label information. If a data item is
added or merged into a dataset, and that data item contains an
item that's identical to an item that's already in the
dataset, then the data items are merged. If two identical labels
are detected during the merge, each with a different label key,
then one of the label and label key dictionary items is randomly
chosen to be included in the merged data item. Data items are
compared using their binary data (bytes), not their content.
If annotation labels are referenced in a schema specified by the
`import_schema_uri` parameter, then the labels in the
`data_item_labels` dictionary are overridden by the annotations.
project (str):
Optional. The name of the Google Cloud project to which this
`TextDataset` is uploaded. This overrides the project that
was set by `aiplatform.init`.
location (str):
Optional. The Google Cloud region where this dataset is uploaded. This
region overrides the region that was set by `aiplatform.init`.
credentials (auth_credentials.Credentials):
Optional. The credentials that are used to upload the `TextDataset`.
These credentials override the credentials set by
`aiplatform.init`.
request_metadata (Sequence[Tuple[str, str]]):
Optional. Strings that contain metadata that's sent with the request.
labels (Dict[str, str]):
Optional. Labels with user-defined metadata to organize your
Vertex AI datasets. The maximum length of a key and of a
value is 64 unicode characters. Labels and keys can contain only
lowercase letters, numeric characters, underscores, and dashes.
International characters are allowed. No more than 64 user
labels can be associated with one dataset (system labels are
excluded). For more information and examples of using labels, see
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
System reserved label keys are prefixed with
`aiplatform.googleapis.com/` and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key that's used to protect the dataset. The
format of the key is
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
The key needs to be in the same region as where the compute
resource is created.
If `encryption_spec_key_name` is set, this `TextDataset` and
all of its sub-resources are secured by this key.
This `encryption_spec_key_name` overrides the
`encryption_spec_key_name` set by `aiplatform.init`.
sync (bool):
If `true`, the `create` method creates a text dataset
synchronously. If `false`, the `create` method creates a text
dataset asynchronously.
create_request_timeout (float):
Optional. The number of seconds for the timeout of the create
request.
Returns:
text_dataset (TextDataset):
An instantiated representation of the managed `TextDataset`
resource.
"""
if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
metadata_schema_uri = schema.dataset.metadata.text
datasource = _datasources.create_datasource(
metadata_schema_uri=metadata_schema_uri,
import_schema_uri=import_schema_uri,
gcs_source=gcs_source,
data_item_labels=data_item_labels,
)
return cls._create_and_import(
api_client=api_client,
parent=initializer.global_config.common_location_path(
project=project, location=location
),
display_name=display_name,
metadata_schema_uri=metadata_schema_uri,
datasource=datasource,
project=project or initializer.global_config.project,
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
create_request_timeout=create_request_timeout,
)

View File

@@ -0,0 +1,186 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Optional, Sequence, Tuple, Union
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils
class TimeSeriesDataset(datasets._ColumnNamesDataset):
"""A managed time series dataset resource for Vertex AI.
Use this class to work with time series datasets. A time series is a dataset
that contains data recorded at different time intervals. The dataset
includes time and at least one variable that's dependent on time. You use a
time series dataset for forecasting predictions. For more information, see
[Forecasting overview](https://cloud.google.com/vertex-ai/docs/tabular-data/forecasting/overview).
You can create a managed time series dataset from CSV files in a Cloud
Storage bucket or from a BigQuery table.
The following code shows you how to create a `TimeSeriesDataset` with a CSV
file that contains the time series data:
```py
my_dataset = aiplatform.TimeSeriesDataset.create(
display_name="my-dataset",
gcs_source=['gs://path/to/my/dataset.csv'],
)
```
The following code shows you how to create a `TimeSeriesDataset` with a
BigQuery table that contains the time series data:
```py
my_dataset = aiplatform.TimeSeriesDataset.create(
display_name="my-dataset",
bq_source='bq://path/to/my/bigquerydataset.train',
)
```
"""
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
schema.dataset.metadata.time_series,
)
@classmethod
def create(
cls,
display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
bq_source: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
) -> "TimeSeriesDataset":
"""Creates a new time series dataset.
Args:
display_name (str):
Optional. The user-defined name of the dataset. The name must
contain 128 or fewer UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
The URI to one or more Google Cloud Storage buckets that contain
your datasets. For example, `str: "gs://bucket/file.csv"` or
`Sequence[str]: ["gs://bucket/file1.csv",
"gs://bucket/file2.csv"]`.
bq_source (str):
A BigQuery URI for the input table. For example,
`bq://project.dataset.table_name`.
project (str):
The name of the Google Cloud project to which this
`TimeSeriesDataset` is uploaded. This overrides the project that
was set by `aiplatform.init`.
location (str):
The Google Cloud region where this dataset is uploaded. This
region overrides the region that was set by `aiplatform.init`.
credentials (auth_credentials.Credentials):
The credentials that are used to upload the `TimeSeriesDataset`.
These credentials override the credentials set by
`aiplatform.init`.
request_metadata (Sequence[Tuple[str, str]]):
Strings that contain metadata that's sent with the request.
labels (Dict[str, str]):
Optional. Labels with user-defined metadata to organize your
Vertex AI datasets. The maximum length of a key and of a
value is 64 unicode characters. Labels and keys can contain only
lowercase letters, numeric characters, underscores, and dashes.
International characters are allowed. No more than 64 user
labels can be associated with one dataset (system labels are
excluded). For more information and examples of using labels, see
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
System reserved label keys are prefixed with
`aiplatform.googleapis.com/` and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key that's used to protect the dataset. The
format of the key is
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
The key needs to be in the same region as where the compute
resource is created.
If `encryption_spec_key_name` is set, this time series dataset
and all of its sub-resources are secured by this key.
This `encryption_spec_key_name` overrides the
`encryption_spec_key_name` set by `aiplatform.init`.
create_request_timeout (float):
Optional. The number of seconds for the timeout of the create
request.
sync (bool):
If `true`, the `create` method creates a time series dataset
synchronously. If `false`, the `create` method creates a time
series dataset asynchronously.
Returns:
time_series_dataset (TimeSeriesDataset):
An instantiated representation of the managed
`TimeSeriesDataset` resource.
"""
if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
metadata_schema_uri = schema.dataset.metadata.time_series
datasource = _datasources.create_datasource(
metadata_schema_uri=metadata_schema_uri,
gcs_source=gcs_source,
bq_source=bq_source,
)
return cls._create_and_import(
api_client=api_client,
parent=initializer.global_config.common_location_path(
project=project, location=location
),
display_name=display_name,
metadata_schema_uri=metadata_schema_uri,
datasource=datasource,
project=project or initializer.global_config.project,
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
create_request_timeout=create_request_timeout,
)
def import_data(self):
raise NotImplementedError(
f"{self.__class__.__name__} class does not support 'import_data'"
)

View File

@@ -0,0 +1,199 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Optional, Sequence, Tuple, Union
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils
class VideoDataset(datasets._Dataset):
"""A managed video dataset resource for Vertex AI.
Use this class to work with a managed video dataset. To create a video
dataset, you need a datasource in CSV format and a schema in YAML format.
The CSV file and the schema are accessed in Cloud Storage buckets.
Use video data for the following objectives:
* Classification. For more information, see Classification schema files.
* Action recognition. For more information, see Action recognition schema files.
* Object tracking. For more information, see Object tracking schema files.
The following code shows you how to create and import a dataset to
train a video classification model. The schema file you use depends on
whether you use your video dataset for classification, action recognition,
or object tracking.
```py
my_dataset = aiplatform.VideoDataset.create(
gcs_source=['gs://path/to/my/dataset.csv'],
import_schema_uri='gs://aip.schema.dataset.ioformat.video.classification.yaml'
)
```
"""
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
schema.dataset.metadata.video,
)
@classmethod
def create(
cls,
display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
import_schema_uri: Optional[str] = None,
data_item_labels: Optional[Dict] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
) -> "VideoDataset":
"""Creates a new video dataset.
Optionally imports data into the dataset when a source and
`import_schema_uri` are passed in. The following is an example of how
this method is used:
```py
my_dataset = aiplatform.VideoDataset.create(
gcs_source=['gs://path/to/my/dataset.csv'],
import_schema_uri='gs://aip.schema.dataset.ioformat.video.classification.yaml'
)
```
Args:
display_name (str):
Optional. The user-defined name of the dataset. The name must
contain 128 or fewer UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
The URI to one or more Google Cloud Storage buckets that contain
your datasets. For example, `str: "gs://bucket/file.csv"` or
`Sequence[str]: ["gs://bucket/file1.csv",
"gs://bucket/file2.csv"]`.
import_schema_uri (str):
A URI for a YAML file stored in Cloud Storage that
describes the import schema used to validate the
dataset. The schema is an
[OpenAPI 3.0.2 Schema](https://tinyurl.com/y538mdwt) object.
data_item_labels (Dict):
Optional. A dictionary of label information. Each dictionary
item contains a label and a label key. Each item in the dataset
includes one dictionary of label information. If a data item is
added or merged into a dataset, and that data item contains an
item that's identical to an item that's already in the
dataset, then the data items are merged. If two identical labels
are detected during the merge, each with a different label key,
then one of the label and label key dictionary items is randomly
chosen to be included in the merged data item. Dataset items are
compared using their binary data (bytes), not their content.
If annotation labels are referenced in a schema specified by the
`import_schema_uri` parameter, then the labels in the
`data_item_labels` dictionary are overridden by the annotations.
project (str):
The name of the Google Cloud project to which this
`VideoDataset` is uploaded. This overrides the project that
was set by `aiplatform.init`.
location (str):
The Google Cloud region where this dataset is uploaded. This
region overrides the region that was set by `aiplatform.init`.
credentials (auth_credentials.Credentials):
The credentials that are used to upload the `VideoDataset`.
These credentials override the credentials set by
`aiplatform.init`.
request_metadata (Sequence[Tuple[str, str]]):
Strings that contain metadata that's sent with the request.
labels (Dict[str, str]):
Optional. Labels with user-defined metadata to organize your
Vertex AI datasets. The maximum length of a key and of a
value is 64 unicode characters. Labels and keys can contain only
lowercase letters, numeric characters, underscores, and dashes.
International characters are allowed. No more than 64 user
labels can be associated with one dataset (system labels are
excluded). For more information and examples of using labels, see
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
System reserved label keys are prefixed with
`aiplatform.googleapis.com/` and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key that's used to protect the dataset. The
format of the key is
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
The key needs to be in the same region as where the compute
resource is created.
If `encryption_spec_key_name` is set, this `VideoDataset` and
all of its sub-resources are secured by this key.
This `encryption_spec_key_name` overrides the
`encryption_spec_key_name` set by `aiplatform.init`.
sync (bool):
If `true`, the `create` method creates a video dataset
                synchronously. If `false`, the `create` method creates a video
dataset asynchronously.
create_request_timeout (float):
Optional. The number of seconds for the timeout of the create
request.
Returns:
video_dataset (VideoDataset):
An instantiated representation of the managed
`VideoDataset` resource.
"""
if not display_name:
display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
metadata_schema_uri = schema.dataset.metadata.video
datasource = _datasources.create_datasource(
metadata_schema_uri=metadata_schema_uri,
import_schema_uri=import_schema_uri,
gcs_source=gcs_source,
data_item_labels=data_item_labels,
)
return cls._create_and_import(
api_client=api_client,
parent=initializer.global_config.common_location_path(
project=project, location=location
),
display_name=display_name,
metadata_schema_uri=metadata_schema_uri,
datasource=datasource,
project=project or initializer.global_config.project,
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
create_request_timeout=create_request_timeout,
)
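For illustration, here is a slightly fuller sketch that exercises several of the optional arguments documented above; the project, region, bucket paths, and KMS key name are placeholders, and the `aiplatform.schema.dataset.ioformat.video.classification` constant is assumed to resolve to the video classification import schema URI.
```py
# A minimal sketch, assuming placeholder project, region, GCS paths, and KMS key.
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

my_dataset = aiplatform.VideoDataset.create(
    display_name="my-video-dataset",
    gcs_source=["gs://my-bucket/videos.csv"],
    import_schema_uri=aiplatform.schema.dataset.ioformat.video.classification,
    data_item_labels={"source": "manual-upload"},
    labels={"team": "research"},
    encryption_spec_key_name=(
        "projects/my-project/locations/us-central1/keyRings/my-kr/cryptoKeys/my-key"
    ),
    sync=True,
)
```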

View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View File

@@ -0,0 +1,558 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
from pathlib import Path
import textwrap
from typing import Dict, List, Optional
from shlex import quote
from google.cloud.aiplatform.docker_utils import local_util
from google.cloud.aiplatform.docker_utils.errors import DockerError
from google.cloud.aiplatform.docker_utils.utils import (
DEFAULT_HOME,
DEFAULT_WORKDIR,
Image,
Package,
)
from google.cloud.aiplatform.utils import path_utils
_logger = logging.getLogger(__name__)
def _generate_copy_command(
from_path: str, to_path: str, comment: Optional[str] = None
) -> str:
"""Returns a Dockerfile entry that copies a file from host to container.
Args:
from_path (str):
            Required. The path of the source on the host.
to_path (str):
Required. The path to the destination in the container.
comment (str):
Optional. A comment explaining the copy operation.
Returns:
The generated copy command used in Dockerfile.
"""
cmd = "COPY {}".format(json.dumps([from_path, to_path]))
if comment is not None:
formatted_comment = "\n# ".join(comment.split("\n"))
return textwrap.dedent(
"""
# {}
{}
""".format(
formatted_comment,
cmd,
)
)
return cmd
def _prepare_dependency_entries(
setup_path: Optional[str] = None,
requirements_path: Optional[str] = None,
extra_packages: Optional[List[str]] = None,
extra_requirements: Optional[List[str]] = None,
extra_dirs: Optional[List[str]] = None,
force_reinstall: bool = False,
pip_command: str = "pip",
) -> str:
"""Returns the Dockerfile entries required to install dependencies.
Args:
setup_path (str):
Optional. The path that points to a setup.py.
requirements_path (str):
Optional. The path that points to a requirements.txt file.
extra_packages (List[str]):
Optional. The list of user custom dependency packages to install.
extra_requirements (List[str]):
Optional. The list of required dependencies to be installed from remote resource archives.
extra_dirs (List[str]):
Optional. The directories other than the work_dir required.
force_reinstall (bool):
            Required. Whether or not to force reinstall all packages even if they are already up-to-date.
        pip_command (str):
            Required. The pip command used for installing packages.
Returns:
The dependency installation command used in Dockerfile.
"""
ret = ""
if setup_path is not None:
ret += _generate_copy_command(
setup_path,
"./setup.py",
comment="setup.py file specified, thus copy it to the docker container.",
) + textwrap.dedent(
"""
RUN {} install --no-cache-dir {} .
""".format(
pip_command,
"--force-reinstall" if force_reinstall else "",
)
)
if requirements_path is not None:
ret += textwrap.dedent(
"""
RUN {} install --no-cache-dir {} -r {}
""".format(
pip_command,
"--force-reinstall" if force_reinstall else "",
requirements_path,
)
)
if extra_packages is not None:
for package in extra_packages:
ret += textwrap.dedent(
"""
RUN {} install --no-cache-dir {} {}
""".format(
pip_command,
"--force-reinstall" if force_reinstall else "",
quote(package),
)
)
if extra_requirements is not None:
for requirement in extra_requirements:
ret += textwrap.dedent(
"""
RUN {} install --no-cache-dir {} {}
""".format(
pip_command,
"--force-reinstall" if force_reinstall else "",
quote(requirement),
)
)
if extra_dirs is not None:
for directory in extra_dirs:
ret += "\n{}\n".format(_generate_copy_command(directory, directory))
return ret
def _prepare_entrypoint(package: Package, python_command: str = "python") -> str:
"""Generates dockerfile entry to set the container entrypoint.
Args:
package (Package):
Required. The main application copied to the container.
python_command (str):
Required. The python command used for running python code.
Returns:
A string with Dockerfile directives to set ENTRYPOINT or "".
"""
exec_str = ""
# Needs to use json so that quotes print as double quotes, not single quotes.
if package.python_module is not None:
exec_str = json.dumps([python_command, "-m", package.python_module])
elif package.script is not None:
_, ext = os.path.splitext(package.script)
executable = [python_command] if ext == ".py" else ["/bin/bash"]
exec_str = json.dumps(executable + [package.script])
if not exec_str:
return ""
return "\nENTRYPOINT {}\n".format(exec_str)
def _copy_source_directory() -> str:
"""Returns the Dockerfile entry required to copy the package to the image.
    The Docker build context has been changed to host_workdir, so we copy all
    of its files to the working directory of the image.
Returns:
The generated package related copy command used in Dockerfile.
"""
copy_code = _generate_copy_command(
".", # Dockefile context location has been changed to host_workdir
".", # Copy all the files to the working directory of images.
comment="Copy the source directory into the docker container.",
)
return "\n{}\n".format(copy_code)
def _prepare_exposed_ports(exposed_ports: Optional[List[int]] = None) -> str:
"""Returns the Dockerfile entries required to expose ports in containers.
Args:
exposed_ports (List[int]):
Optional. The exposed ports that the container listens on at runtime.
Returns:
The generated port expose command used in Dockerfile.
"""
ret = ""
if exposed_ports is None:
return ret
for port in exposed_ports:
ret += "\nEXPOSE {}\n".format(port)
return ret
def _prepare_environment_variables(
environment_variables: Optional[Dict[str, str]] = None
) -> str:
"""Returns the Dockerfile entries required to set environment variables in containers.
Args:
environment_variables (Dict[str, str]):
Optional. The environment variables to be set in the container.
Returns:
The generated environment variable commands used in Dockerfile.
"""
ret = ""
if environment_variables is None:
return ret
for key, value in environment_variables.items():
ret += f"\nENV {key}={value}\n"
return ret
def _get_relative_path_to_workdir(
workdir: str,
path: Optional[str] = None,
value_name: str = "value",
) -> str:
"""Returns the relative path to the workdir.
Args:
workdir (str):
            Required. The directory that the retrieved path is relative to.
        path (str):
            Optional. The path to convert to a path relative to the workdir.
value_name (str):
Required. The variable name specified in the exception message.
Returns:
The relative path to the workdir or None if path is None.
Raises:
ValueError: If the path does not exist or is not relative to the workdir.
"""
if path is None:
return None
if not Path(path).is_file():
raise ValueError(f'The {value_name} "{path}" must exist.')
if not path_utils._is_relative_to(path, workdir):
raise ValueError(f'The {value_name} "{path}" must be in "{workdir}".')
abs_path = Path(path).expanduser().resolve()
abs_workdir = Path(workdir).expanduser().resolve()
return Path(abs_path).relative_to(abs_workdir).as_posix()
def make_dockerfile(
base_image: str,
main_package: Package,
container_workdir: str,
container_home: str,
requirements_path: Optional[str] = None,
setup_path: Optional[str] = None,
extra_requirements: Optional[List[str]] = None,
extra_packages: Optional[List[str]] = None,
extra_dirs: Optional[List[str]] = None,
exposed_ports: Optional[List[int]] = None,
environment_variables: Optional[Dict[str, str]] = None,
pip_command: str = "pip",
python_command: str = "python",
) -> str:
"""Generates a Dockerfile for building an image.
It builds on a specified base image to create a container that:
- installs any dependency specified in a requirements.txt or a setup.py file,
and any specified dependency packages existing locally or found from PyPI
- copies all source needed by the main module, and potentially injects an
entrypoint that, on run, will run that main module
Args:
base_image (str):
Required. The ID or name of the base image to initialize the build stage.
main_package (Package):
Required. The main application to execute.
container_workdir (str):
Required. The working directory in the container.
container_home (str):
Required. The $HOME directory in the container.
requirements_path (str):
Optional. The path to a local requirements.txt file.
setup_path (str):
Optional. The path to a local setup.py file.
extra_requirements (List[str]):
Optional. The list of required dependencies to install from PyPI.
extra_packages (List[str]):
Optional. The list of user custom dependency packages to install.
extra_dirs: (List[str]):
Optional. The directories other than the work_dir required to be in the container.
exposed_ports (List[int]):
Optional. The exposed ports that the container listens on at runtime.
environment_variables (Dict[str, str]):
Optional. The environment variables to be set in the container.
pip_command (str):
            Required. The pip command used for installing packages.
python_command (str):
Required. The python command used for running python code.
Returns:
A string that represents the content of a Dockerfile.
"""
dockerfile = textwrap.dedent(
"""
FROM {base_image}
# Keeps Python from generating .pyc files in the container
ENV PYTHONDONTWRITEBYTECODE=1
""".format(
base_image=base_image,
)
)
dockerfile += _prepare_exposed_ports(exposed_ports)
dockerfile += _prepare_entrypoint(main_package, python_command=python_command)
dockerfile += textwrap.dedent(
"""
# The directory is created by root. This sets permissions so that any user can
# access the folder.
RUN mkdir -m 777 -p {workdir} {container_home}
WORKDIR {workdir}
ENV HOME={container_home}
""".format(
workdir=quote(container_workdir),
container_home=quote(container_home),
)
)
# Installs extra requirements which do not involve user source code.
dockerfile += _prepare_dependency_entries(
requirements_path=None,
setup_path=None,
extra_requirements=extra_requirements,
extra_packages=None,
extra_dirs=None,
force_reinstall=True,
pip_command=pip_command,
)
dockerfile += _prepare_environment_variables(
environment_variables=environment_variables
)
# Copies user code to the image.
dockerfile += _copy_source_directory()
# Installs packages from requirements_path.
dockerfile += _prepare_dependency_entries(
requirements_path=requirements_path,
setup_path=None,
extra_requirements=None,
extra_packages=None,
extra_dirs=None,
force_reinstall=True,
pip_command=pip_command,
)
# Installs additional packages from user code.
dockerfile += _prepare_dependency_entries(
requirements_path=None,
setup_path=setup_path,
extra_requirements=None,
extra_packages=extra_packages,
extra_dirs=extra_dirs,
force_reinstall=True,
pip_command=pip_command,
)
return dockerfile
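To make the composition above concrete, here is a short sketch that generates a Dockerfile for a package whose entrypoint is a Python module and prints the result; `trainer.task`, the base image, and the requirements file name are assumptions for illustration, and `make_dockerfile` is assumed to be in scope (it is defined above).
```py
# A minimal sketch, assuming make_dockerfile (defined above) is in scope and that
# "trainer.task", the base image, and the requirements file are placeholders.
from google.cloud.aiplatform.docker_utils.utils import (
    DEFAULT_HOME,
    DEFAULT_WORKDIR,
    Package,
)

main_package = Package(script=None, package_path=".", python_module="trainer.task")

dockerfile = make_dockerfile(
    base_image="python:3.10-slim",
    main_package=main_package,
    container_workdir=DEFAULT_WORKDIR,
    container_home=DEFAULT_HOME,
    requirements_path="requirements.txt",
    exposed_ports=[8080],
    environment_variables={"MODEL_NAME": "demo"},
)
print(dockerfile)
```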
def build_image(
base_image: str,
host_workdir: str,
output_image_name: str,
python_module: Optional[str] = None,
requirements_path: Optional[str] = None,
extra_requirements: Optional[List[str]] = None,
setup_path: Optional[str] = None,
extra_packages: Optional[List[str]] = None,
container_workdir: Optional[str] = None,
container_home: Optional[str] = None,
extra_dirs: Optional[List[str]] = None,
exposed_ports: Optional[List[int]] = None,
pip_command: str = "pip",
python_command: str = "python",
no_cache: bool = True,
platform: Optional[str] = None,
**kwargs,
) -> Image:
"""Builds a Docker image.
    Generates a Dockerfile and passes it to `docker build` via stdin.
    All output from the `docker build` process is streamed through the module logger.
Args:
base_image (str):
Required. The ID or name of the base image to initialize the build stage.
host_workdir (str):
            Required. The path to the directory where all the required sources are located.
output_image_name (str):
Required. The name of the built image.
python_module (str):
Optional. The executable main script in form of a python module, if applicable.
requirements_path (str):
Optional. The path to a local file including required dependencies to install from PyPI.
extra_requirements (List[str]):
Optional. The list of required dependencies to install from PyPI.
setup_path (str):
Optional. The path to a local setup.py used for installing packages.
extra_packages (List[str]):
Optional. The list of user custom dependency packages to install.
container_workdir (str):
Optional. The working directory in the container.
container_home (str):
Optional. The $HOME directory in the container.
extra_dirs (List[str]):
Optional. The directories other than the work_dir required.
exposed_ports (List[int]):
Optional. The exposed ports that the container listens on at runtime.
pip_command (str):
Required. The pip command used for installing packages.
python_command (str):
Required. The python command used for running python scripts.
no_cache (bool):
Required. Do not use cache when building the image. Using build cache usually
reduces the image building time. See
https://docs.docker.com/develop/develop-images/dockerfile_best-practices/#leverage-build-cache
for more details.
platform (str):
Optional. The target platform for the Docker image build. See
https://docs.docker.com/build/building/multi-platform/#building-multi-platform-images
for more details.
**kwargs:
Other arguments to pass to underlying method that generates the Dockerfile.
Returns:
        An Image instance that contains info about the built image.
    Raises:
        DockerError: An error occurred when executing `docker build`.
ValueError: If the needed code is not relative to the host workdir.
"""
tag_options = ["-t", output_image_name]
cache_args = ["--no-cache"] if no_cache else []
platform_args = ["--platform", platform] if platform is not None else []
command = (
["docker", "build"]
+ cache_args
+ platform_args
+ tag_options
+ ["--rm", "-f-", host_workdir]
)
requirements_relative_path = _get_relative_path_to_workdir(
host_workdir,
path=requirements_path,
value_name="requirements_path",
)
setup_relative_path = _get_relative_path_to_workdir(
host_workdir,
path=setup_path,
value_name="setup_path",
)
extra_packages_relative_paths = (
None
if extra_packages is None
else [
_get_relative_path_to_workdir(
host_workdir, path=extra_package, value_name="extra_packages"
)
for extra_package in extra_packages
if extra_package is not None
]
)
home_dir = container_home or DEFAULT_HOME
work_dir = container_workdir or DEFAULT_WORKDIR
# The package will be used in Docker, thus norm it to POSIX path format.
main_package = Package(
script=None,
package_path=host_workdir,
python_module=python_module,
)
dockerfile = make_dockerfile(
base_image,
main_package,
work_dir,
home_dir,
requirements_path=requirements_relative_path,
setup_path=setup_relative_path,
extra_requirements=extra_requirements,
extra_packages=extra_packages_relative_paths,
extra_dirs=extra_dirs,
exposed_ports=exposed_ports,
pip_command=pip_command,
python_command=python_command,
**kwargs,
)
joined_command = " ".join(command)
_logger.info("Running command: {}".format(joined_command))
return_code = local_util.execute_command(
command,
input_str=dockerfile,
)
if return_code == 0:
return Image(output_image_name, home_dir, work_dir)
else:
error_msg = textwrap.dedent(
"""
Docker failed with error code {code}.
Command: {cmd}
""".format(
code=return_code, cmd=joined_command
)
)
raise DockerError(error_msg, command, return_code)
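A usage sketch for `build_image`; it assumes a local Docker installation and that `build_image` (defined above) is in scope, and the host work directory, image tag, and module name are placeholders.
```py
# A minimal sketch, assuming build_image (defined above) is in scope, Docker is
# installed locally, and the work directory, image tag, and module are placeholders.
image = build_image(
    base_image="python:3.10-slim",
    host_workdir="./my_training_code",
    output_image_name="local/trainer:latest",
    python_module="trainer.task",
    requirements_path="./my_training_code/requirements.txt",
    exposed_ports=[8080],
)
print(image.name, image.default_workdir)
```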

View File

@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import textwrap
from typing import List, NoReturn
class Error(Exception):
"""A base exception for all user recoverable errors."""
def __init__(self, *args, **kwargs):
"""Initialize an Error."""
self.exit_code = kwargs.get("exit_code", 1)
class DockerError(Error):
"""Exception that passes info on a failed Docker command."""
def __init__(self, message, cmd, exit_code):
super(DockerError, self).__init__(message)
self.message = message
self.cmd = cmd
self.exit_code = exit_code
def raise_docker_error_with_command(command: List[str], return_code: int) -> NoReturn:
"""Raises DockerError with the given command and return code.
Args:
command (List(str)):
            Required. The docker command that failed.
return_code (int):
Required. The return code from the command.
Raises:
        DockerError, with an error message populated from the given command and return code.
"""
error_msg = textwrap.dedent(
"""
Docker failed with error code {code}.
Command: {cmd}
""".format(
code=return_code, cmd=" ".join(command)
)
)
raise DockerError(error_msg, command, return_code)
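A short sketch of surfacing a failed command with this helper; the command and exit code are illustrative only, and the names defined above are assumed to be in scope.
```py
# A minimal sketch; the command and return code are illustrative only.
failed_cmd = ["docker", "build", "-t", "local/image:latest", "."]

try:
    raise_docker_error_with_command(failed_cmd, return_code=1)
except DockerError as err:
    print(err.message)
    print("command:", err.cmd)
    print("exit code:", err.exit_code)
```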

View File

@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
import logging
import subprocess
from typing import List, Optional
_logger = logging.getLogger(__name__)
def execute_command(
cmd: List[str],
input_str: Optional[str] = None,
) -> int:
"""Executes commands in subprocess.
    Executes the supplied command with the supplied standard input string, streams
    the output through the module logger, and returns the process's return code.
Args:
cmd (List[str]):
Required. The strings to send in as the command.
input_str (str):
Optional. If supplied, it will be passed as stdin to the supplied command.
If None, stdin will get closed immediately.
Returns:
Return code of the process.
"""
with subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=False,
bufsize=1,
) as p:
if input_str:
p.stdin.write(input_str.encode("utf-8"))
p.stdin.close()
out = io.TextIOWrapper(p.stdout, newline="", encoding="utf-8", errors="replace")
for line in out:
_logger.info(line)
return p.returncode
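A small sketch of the streaming behavior, assuming a POSIX system where `cat` is available and that `execute_command` (defined above) is in scope; the command simply echoes the provided stdin back through the module logger.
```py
# A minimal sketch, assuming execute_command (defined above) is in scope;
# basic logging is configured so the streamed lines are visible.
import logging

logging.basicConfig(level=logging.INFO)

return_code = execute_command(["cat"], input_str="hello from execute_command\n")
print("process exited with", return_code)
```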

View File

@@ -0,0 +1,279 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
from pathlib import Path
import re
from typing import Dict, List, Optional, Sequence
try:
import docker
except ImportError:
raise ImportError(
"Docker is not installed and is required to run containers. "
'Please install the SDK using `pip install "google-cloud-aiplatform[prediction]>=1.16.0"`.'
)
from google.cloud.aiplatform.constants import prediction
from google.cloud.aiplatform.docker_utils.utils import DEFAULT_MOUNTED_MODEL_DIRECTORY
from google.cloud.aiplatform.utils import prediction_utils
_logger = logging.getLogger(__name__)
_DEFAULT_CONTAINER_CRED_KEY_PATH = "/tmp/keys/cred_key.json"
_ADC_ENVIRONMENT_VARIABLE = "GOOGLE_APPLICATION_CREDENTIALS"
CONTAINER_RUNNING_STATUS = "running"
def _get_adc_environment_variable() -> Optional[str]:
"""Gets the value of the ADC environment variable.
Returns:
The value of the environment variable or None if unset.
"""
return os.environ.get(_ADC_ENVIRONMENT_VARIABLE)
def _replace_env_var_reference(
target: str,
env_vars: Dict[str, str],
) -> str:
"""Replaces the environment variable reference in the given string.
Variable references $(VAR_NAME) are expanded using the container's environment.
If a variable cannot be resolved, the reference in the input string will be unchanged.
The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped
references will never be expanded, regardless of whether the variable exists or not.
More info:
https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell
Args:
target (str):
            Required. The string in which environment variable references will be replaced.
env_vars (Dict[str, str]):
Required. The environment variables used for reference.
Returns:
The updated string.
"""
# Replace env var references with env vars.
for key, value in env_vars.items():
target = re.sub(rf"(?<!\$)\$\({key}\)", str(value), target)
# Replace $$ with $.
target = re.sub(r"\$\$", "$", target)
return target
def run_prediction_container(
serving_container_image_uri: str,
artifact_uri: Optional[str] = None,
serving_container_predict_route: Optional[str] = None,
serving_container_health_route: Optional[str] = None,
serving_container_command: Optional[Sequence[str]] = None,
serving_container_args: Optional[Sequence[str]] = None,
serving_container_environment_variables: Optional[Dict[str, str]] = None,
serving_container_ports: Optional[Sequence[int]] = None,
credential_path: Optional[str] = None,
host_port: Optional[int] = None,
gpu_count: Optional[int] = None,
gpu_device_ids: Optional[List[str]] = None,
gpu_capabilities: Optional[List[List[str]]] = None,
) -> docker.models.containers.Container:
"""Runs a prediction container locally.
Args:
serving_container_image_uri (str):
Required. The URI of the Model serving container.
artifact_uri (str):
Optional. The path to the directory containing the Model artifact and any of its
supporting files. The path is either a GCS uri or the path to a local directory.
If this parameter is set to a GCS uri:
(1) `credential_path` must be specified for local prediction.
(2) The GCS uri will be passed directly to `Predictor.load`.
If this parameter is a local directory:
(1) The directory will be mounted to a default temporary model path.
(2) The mounted path will be passed to `Predictor.load`.
serving_container_predict_route (str):
Optional. An HTTP path to send prediction requests to the container, and
which must be supported by it. If not specified a default HTTP path will
be used by Vertex AI.
serving_container_health_route (str):
Optional. An HTTP path to send health check requests to the container, and which
must be supported by it. If not specified a standard HTTP path will be
used by Vertex AI.
serving_container_command (Sequence[str]):
Optional. The command with which the container is run. Not executed within a
shell. The Docker image's ENTRYPOINT is used if this is not provided.
Variable references $(VAR_NAME) are expanded using the container's
environment. If a variable cannot be resolved, the reference in the
input string will be unchanged. The $(VAR_NAME) syntax can be escaped
with a double $$, ie: $$(VAR_NAME). Escaped references will never be
expanded, regardless of whether the variable exists or not.
serving_container_args: (Sequence[str]):
Optional. The arguments to the command. The Docker image's CMD is used if this is
not provided. Variable references $(VAR_NAME) are expanded using the
container's environment. If a variable cannot be resolved, the reference
in the input string will be unchanged. The $(VAR_NAME) syntax can be
escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
never be expanded, regardless of whether the variable exists or not.
serving_container_environment_variables (Dict[str, str]):
Optional. The environment variables that are to be present in the container.
Should be a dictionary where keys are environment variable names
and values are environment variable values for those names.
serving_container_ports (Sequence[int]):
Optional. Declaration of ports that are exposed by the container. This field is
            primarily informational; it gives Vertex AI information about the
            network connections the container uses. Whether or not a port is listed here has
            no impact on whether the port is actually exposed; any port listening on
the default "0.0.0.0" address inside a container will be accessible from
the network.
credential_path (str):
Optional. The path to the credential key that will be mounted to the container.
If it's unset, the environment variable, GOOGLE_APPLICATION_CREDENTIALS, will
be used if set.
host_port (int):
Optional. The port on the host that the port, AIP_HTTP_PORT, inside the container
will be exposed as. If it's unset, a random host port will be assigned.
gpu_count (int):
Optional. Number of devices to request. Set to -1 to request all available devices.
To use GPU, set either `gpu_count` or `gpu_device_ids`.
gpu_device_ids (List[str]):
Optional. This parameter corresponds to `NVIDIA_VISIBLE_DEVICES` in the NVIDIA
Runtime.
To use GPU, set either `gpu_count` or `gpu_device_ids`.
gpu_capabilities (List[List[str]]):
Optional. This parameter corresponds to `NVIDIA_DRIVER_CAPABILITIES` in the NVIDIA
Runtime. This must be set to use GPU. The outer list acts like an OR, and each
sub-list acts like an AND. The driver will try to satisfy one of the sub-lists.
Available capabilities for the NVIDIA driver can be found in
https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#driver-capabilities.
Returns:
The container object running in the background.
Raises:
        ValueError: If artifact_uri is a path to a local directory that does not exist,
            or if credential_path or the file pointed to by the environment variable
            GOOGLE_APPLICATION_CREDENTIALS does not exist.
docker.errors.ImageNotFound: If the specified image does not exist.
docker.errors.APIError: If the server returns an error.
"""
client = docker.from_env()
envs = {}
if serving_container_environment_variables:
for key, value in serving_container_environment_variables.items():
envs[key] = _replace_env_var_reference(value, envs)
port = prediction_utils.get_prediction_aip_http_port(serving_container_ports)
envs[prediction.AIP_HTTP_PORT] = port
envs[prediction.AIP_HEALTH_ROUTE] = serving_container_health_route
envs[prediction.AIP_PREDICT_ROUTE] = serving_container_predict_route
volumes = []
envs[prediction.AIP_STORAGE_URI] = artifact_uri or ""
if artifact_uri and not artifact_uri.startswith(prediction_utils.GCS_URI_PREFIX):
artifact_uri_on_host = Path(artifact_uri).expanduser().resolve()
if not artifact_uri_on_host.exists():
raise ValueError(
"artifact_uri should be specified as either a GCS uri which starts with "
f"`{prediction_utils.GCS_URI_PREFIX}` or a path to a local directory. "
f'However, "{artifact_uri_on_host}" does not exist.'
)
for mounted_path in artifact_uri_on_host.rglob("*"):
relative_mounted_path = mounted_path.relative_to(artifact_uri_on_host)
volumes += [
f"{mounted_path}:{os.path.join(DEFAULT_MOUNTED_MODEL_DIRECTORY, relative_mounted_path)}"
]
envs[prediction.AIP_STORAGE_URI] = DEFAULT_MOUNTED_MODEL_DIRECTORY
credential_from_adc_env = credential_path is None
credential_path = credential_path or _get_adc_environment_variable()
if credential_path:
credential_path_on_host = Path(credential_path).expanduser().resolve()
if not credential_path_on_host.exists() and credential_from_adc_env:
raise ValueError(
f"The file from the environment variable {_ADC_ENVIRONMENT_VARIABLE} does "
f'not exist: "{credential_path}".'
)
elif not credential_path_on_host.exists() and not credential_from_adc_env:
raise ValueError(f'credential_path does not exist: "{credential_path}".')
credential_mount_path = _DEFAULT_CONTAINER_CRED_KEY_PATH
volumes = volumes + [f"{credential_path_on_host}:{credential_mount_path}"]
envs[_ADC_ENVIRONMENT_VARIABLE] = credential_mount_path
entrypoint = [
_replace_env_var_reference(i, envs) for i in serving_container_command or []
]
command = [
_replace_env_var_reference(i, envs) for i in serving_container_args or []
]
device_requests = None
if gpu_count or gpu_device_ids or gpu_capabilities:
device_requests = [
docker.types.DeviceRequest(
count=gpu_count,
device_ids=gpu_device_ids,
capabilities=gpu_capabilities,
)
]
container = client.containers.run(
serving_container_image_uri,
command=command if len(command) > 0 else None,
entrypoint=entrypoint if len(entrypoint) > 0 else None,
ports={port: host_port},
environment=envs,
volumes=volumes,
device_requests=device_requests,
detach=True,
)
return container
def print_container_logs(
container: docker.models.containers.Container,
start_index: Optional[int] = None,
message: Optional[str] = None,
) -> int:
"""Prints container logs.
Args:
container (docker.models.containers.Container):
Required. The container object to print the logs.
start_index (int):
Optional. The index of log entries to start printing.
message (str):
Optional. The message to be printed before printing the logs.
Returns:
The total number of log entries.
"""
if message is not None:
_logger.info(message)
logs = container.logs().decode("utf-8").strip().split("\n")
start_index = 0 if start_index is None else start_index
for i in range(start_index, len(logs)):
_logger.info(logs[i])
return len(logs)
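A local-testing sketch that combines the two helpers above (both assumed to be in scope); the serving image URI, artifact directory, routes, and port are placeholders, and a running local Docker daemon is assumed.
```py
# A minimal sketch; the image URI, local artifact directory, routes, and port
# are placeholders, and Docker must be running locally.
container = run_prediction_container(
    serving_container_image_uri="us-docker.pkg.dev/my-project/my-repo/my-server:latest",
    artifact_uri="./model_artifacts",
    serving_container_predict_route="/predict",
    serving_container_health_route="/health",
    host_port=8080,
)

# Print whatever the container has logged so far, then clean up.
print_container_logs(container, message="Container output so far:")
container.stop()
container.remove()
```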

View File

@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
try:
import docker
except ImportError:
raise ImportError(
"Docker is not installed and is required to run containers. "
'Please install the SDK using `pip install "google-cloud-aiplatform[prediction]>=1.16.0"`.'
)
Package = collections.namedtuple("Package", ["script", "package_path", "python_module"])
Image = collections.namedtuple("Image", ["name", "default_home", "default_workdir"])
DEFAULT_HOME = "/home"
DEFAULT_WORKDIR = "/usr/app"
DEFAULT_MOUNTED_MODEL_DIRECTORY = "/tmp_cpr_local_model"
def check_image_exists_locally(image_name: str) -> bool:
"""Checks if an image exists locally.
Args:
image_name (str):
Required. The name of the image.
Returns:
Whether the image exists locally.
"""
client = docker.from_env()
try:
_ = client.images.get(image_name)
return True
except (docker.errors.ImageNotFound, docker.errors.APIError):
return False
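For example, a caller could use this check to skip a rebuild; the image name is a placeholder and `check_image_exists_locally` (defined above) is assumed to be in scope.
```py
# A minimal sketch; the image name is a placeholder.
if check_image_exists_locally("local/trainer:latest"):
    print("Image already present locally; skipping build.")
else:
    print("Image not found locally; a build or pull is needed.")
```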

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform.compat.types import (
explanation as explanation_compat,
explanation_metadata as explanation_metadata_compat,
)
ExplanationMetadata = explanation_metadata_compat.ExplanationMetadata
# ExplanationMetadata subclasses
InputMetadata = ExplanationMetadata.InputMetadata
OutputMetadata = ExplanationMetadata.OutputMetadata
# InputMetadata subclasses
Encoding = InputMetadata.Encoding
FeatureValueDomain = InputMetadata.FeatureValueDomain
Visualization = InputMetadata.Visualization
ExplanationParameters = explanation_compat.ExplanationParameters
FeatureNoiseSigma = explanation_compat.FeatureNoiseSigma
ExplanationSpec = explanation_compat.ExplanationSpec
# Classes used by ExplanationParameters
IntegratedGradientsAttribution = explanation_compat.IntegratedGradientsAttribution
SampledShapleyAttribution = explanation_compat.SampledShapleyAttribution
SmoothGradConfig = explanation_compat.SmoothGradConfig
XraiAttribution = explanation_compat.XraiAttribution
Presets = explanation_compat.Presets
Examples = explanation_compat.Examples
__all__ = (
"Encoding",
"ExplanationSpec",
"ExplanationMetadata",
"ExplanationParameters",
"FeatureNoiseSigma",
"FeatureValueDomain",
"InputMetadata",
"IntegratedGradientsAttribution",
"OutputMetadata",
"SampledShapleyAttribution",
"SmoothGradConfig",
"Visualization",
"XraiAttribution",
"Presets",
"Examples",
)
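As an illustration of how these aliases compose into an explanation configuration, here is a sketch that uses the names re-exported above; the tensor names are placeholders specific to your model.
```py
# A minimal sketch; the tensor names below are placeholders for your model, and
# the aliases (ExplanationParameters, ExplanationMetadata, ...) come from above.
parameters = ExplanationParameters(
    sampled_shapley_attribution=SampledShapleyAttribution(path_count=10)
)
metadata = ExplanationMetadata(
    inputs={"features": InputMetadata(input_tensor_name="dense_input")},
    outputs={"score": OutputMetadata(output_tensor_name="dense_output")},
)
explanation_spec = ExplanationSpec(parameters=parameters, metadata=metadata)
```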

View File

@@ -0,0 +1,482 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from google.cloud import aiplatform
from typing import Dict, List, Mapping, Optional, Tuple, Union
try:
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import dtypes as lit_dtypes
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp import notebook
except ImportError:
raise ImportError(
"LIT is not installed and is required to get Dataset as the return format. "
'Please install the SDK using "pip install google-cloud-aiplatform[lit]"'
)
try:
import tensorflow as tf
except ImportError:
raise ImportError(
"Tensorflow is not installed and is required to load saved model. "
'Please install the SDK using "pip install google-cloud-aiplatform[lit]"'
)
try:
import pandas as pd
except ImportError:
raise ImportError(
"Pandas is not installed and is required to read the dataset. "
'Please install Pandas using "pip install google-cloud-aiplatform[lit]"'
)
class _VertexLitDataset(lit_dataset.Dataset):
"""LIT dataset class for the Vertex LIT integration.
This is used in the create_lit_dataset function.
"""
def __init__(
self,
dataset: pd.DataFrame,
column_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
):
"""Construct a VertexLitDataset.
Args:
dataset:
Required. A Pandas DataFrame that includes feature column names and data.
column_types:
Required. An OrderedDict of string names matching the columns of the dataset
as the key, and the associated LitType of the column.
"""
self._examples = dataset.to_dict(orient="records")
self._column_types = column_types
def spec(self):
"""Return a spec describing dataset elements."""
return dict(self._column_types)
class _EndpointLitModel(lit_model.Model):
"""LIT model class for the Vertex LIT integration with a model deployed to an endpoint.
This is used in the create_lit_model function.
"""
def __init__(
self,
endpoint: Union[str, aiplatform.Endpoint],
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
model_id: Optional[str] = None,
):
"""Construct a VertexLitModel.
Args:
            endpoint:
                Required. The name of the Endpoint resource or an Endpoint instance.
                Endpoint name format:
                ``projects/{project}/locations/{location}/endpoints/{endpoint}``
input_types:
Required. An OrderedDict of string names matching the features of the model
as the key, and the associated LitType of the feature.
output_types:
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
model_id:
Optional. A string of the specific model in the endpoint to create the
LIT model from. If this is not set, any usable model in the endpoint is
used to create the LIT model.
Raises:
ValueError if the model_id was not found in the endpoint.
"""
if isinstance(endpoint, str):
self._endpoint = aiplatform.Endpoint(endpoint)
else:
self._endpoint = endpoint
self._model_id = model_id
self._input_types = input_types
self._output_types = output_types
# Check if the model with the model ID has explanation enabled
if model_id:
deployed_model = next(
filter(
lambda model: model.id == model_id, self._endpoint.list_models()
),
None,
)
if not deployed_model:
raise ValueError(
"A model with id {model_id} was not found in the endpoint {endpoint}.".format(
model_id=model_id, endpoint=endpoint
)
)
self._explanation_enabled = bool(deployed_model.explanation_spec)
# Check if all models in the endpoint have explanation enabled
else:
self._explanation_enabled = all(
model.explanation_spec for model in self._endpoint.list_models()
)
def predict_minibatch(
self, inputs: List[lit_types.JsonDict]
) -> List[lit_types.JsonDict]:
"""Retun predictions based on a batch of inputs.
Args:
inputs: Requred. a List of instances to predict on based on the input spec.
Returns:
A list of predictions based on the output spec.
"""
instances = []
for input in inputs:
instance = [input[feature] for feature in self._input_types]
instances.append(instance)
if self._explanation_enabled:
prediction_object = self._endpoint.explain(instances)
else:
prediction_object = self._endpoint.predict(instances)
outputs = []
for prediction in prediction_object.predictions:
if isinstance(prediction, Mapping):
outputs.append({key: prediction[key] for key in self._output_types})
else:
outputs.append(
{key: prediction[i] for i, key in enumerate(self._output_types)}
)
if self._explanation_enabled:
for i, explanation in enumerate(prediction_object.explanations):
attributions = explanation.attributions
outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
attributions
)
return outputs
def input_spec(self) -> lit_types.Spec:
"""Return a spec describing model inputs."""
return dict(self._input_types)
def output_spec(self) -> lit_types.Spec:
"""Return a spec describing model outputs."""
output_spec_dict = dict(self._output_types)
if self._explanation_enabled:
output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
signed=True
)
return output_spec_dict
class _TensorFlowLitModel(lit_model.Model):
"""LIT model class for the Vertex LIT integration with a TensorFlow saved model.
This is used in the create_lit_model function.
"""
def __init__(
self,
model: str,
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
attribution_method: str = "sampled_shapley",
):
"""Construct a VertexLitModel.
Args:
model:
Required. A string reference to a local TensorFlow saved model directory.
                The model must have exactly one input tensor and one output tensor.
input_types:
Required. An OrderedDict of string names matching the features of the model
as the key, and the associated LitType of the feature.
output_types:
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
attribution_method:
Optional. A string to choose what attribution configuration to
set up the explainer with. Valid options are 'sampled_shapley'
or 'integrated_gradients'.
"""
self._load_model(model)
self._input_types = input_types
self._output_types = output_types
self._input_tensor_name = next(iter(self._kwargs_signature))
self._attribution_explainer = None
if os.environ.get("LIT_PROXY_URL"):
self._set_up_attribution_explainer(model, attribution_method)
@property
def attribution_explainer(
self,
) -> Optional["AttributionExplainer"]: # noqa: F821
"""Gets the attribution explainer property if set."""
return self._attribution_explainer
def predict_minibatch(
self, inputs: List[lit_types.JsonDict]
) -> List[lit_types.JsonDict]:
"""Retun predictions based on a batch of inputs.
Args:
inputs: Requred. a List of instances to predict on based on the input spec.
Returns:
A list of predictions based on the output spec.
"""
instances = []
for input in inputs:
instance = [input[feature] for feature in self._input_types]
instances.append(instance)
prediction_input_dict = {
self._input_tensor_name: tf.convert_to_tensor(instances)
}
prediction_dict = self._loaded_model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
](**prediction_input_dict)
predictions = prediction_dict[next(iter(self._output_signature))].numpy()
outputs = []
for prediction in predictions:
outputs.append(
{
label: value
for label, value in zip(self._output_types.keys(), prediction)
}
)
# Get feature attributions
if self.attribution_explainer:
attributions = self.attribution_explainer.explain(
[{self._input_tensor_name: i} for i in instances]
)
for i, attribution in enumerate(attributions):
outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
attribution.feature_importance()
)
return outputs
def input_spec(self) -> lit_types.Spec:
"""Return a spec describing model inputs."""
return dict(self._input_types)
def output_spec(self) -> lit_types.Spec:
"""Return a spec describing model outputs."""
output_spec_dict = dict(self._output_types)
if self.attribution_explainer:
output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
signed=True
)
return output_spec_dict
def _load_model(self, model: str):
"""Loads a TensorFlow saved model and populates the input and output signature attributes of the class.
Args:
model: Required. A string reference to a TensorFlow saved model directory.
Raises:
ValueError if the model has more than one input tensor or more than one output tensor.
"""
self._loaded_model = tf.saved_model.load(model)
serving_default = self._loaded_model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
]
_, self._kwargs_signature = serving_default.structured_input_signature
self._output_signature = serving_default.structured_outputs
if len(self._kwargs_signature) != 1:
raise ValueError("Please use a model with only one input tensor.")
if len(self._output_signature) != 1:
raise ValueError("Please use a model with only one output tensor.")
def _set_up_attribution_explainer(
self, model: str, attribution_method: str = "integrated_gradients"
):
"""Populates the attribution explainer attribute of the class.
Args:
model: Required. A string reference to a TensorFlow saved model directory.
attribution_method:
Optional. A string to choose what attribution configuration to
set up the explainer with. Valid options are 'sampled_shapley'
or 'integrated_gradients'.
"""
try:
import explainable_ai_sdk
from explainable_ai_sdk.metadata.tf.v2 import SavedModelMetadataBuilder
except ImportError:
logging.info(
"Skipping explanations because the Explainable AI SDK is not installed."
'Please install the SDK using "pip install explainable-ai-sdk"'
)
return
builder = SavedModelMetadataBuilder(model)
builder.get_metadata()
builder.set_numeric_metadata(
self._input_tensor_name,
index_feature_mapping=list(self._input_types.keys()),
)
builder.save_metadata(model)
if attribution_method == "integrated_gradients":
explainer_config = explainable_ai_sdk.IntegratedGradientsConfig()
else:
explainer_config = explainable_ai_sdk.SampledShapleyConfig()
self._attribution_explainer = explainable_ai_sdk.load_model_from_local_path(
model, explainer_config
)
self._load_model(model)
def create_lit_dataset(
dataset: pd.DataFrame,
column_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
) -> lit_dataset.Dataset:
"""Creates a LIT Dataset object.
Args:
dataset:
Required. A Pandas DataFrame that includes feature column names and data.
column_types:
Required. An OrderedDict of string names matching the columns of the dataset
as the key, and the associated LitType of the column.
Returns:
A LIT Dataset object that has the data from the dataset provided.
"""
return _VertexLitDataset(dataset, column_types)
def create_lit_model_from_endpoint(
endpoint: Union[str, aiplatform.Endpoint],
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
model_id: Optional[str] = None,
) -> lit_model.Model:
"""Creates a LIT Model object.
Args:
        endpoint:
Required. The name of the Endpoint resource or an Endpoint instance.
Endpoint name format: ``projects/{project}/locations/{location}/endpoints/{endpoint}``
input_types:
Required. An OrderedDict of string names matching the features of the model
as the key, and the associated LitType of the feature.
output_types:
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
model_id:
Optional. A string of the specific model in the endpoint to create the
LIT model from. If this is not set, any usable model in the endpoint is
used to create the LIT model.
Returns:
A LIT Model object that has the same functionality as the model provided.
"""
return _EndpointLitModel(endpoint, input_types, output_types, model_id)
def create_lit_model(
model: str,
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
attribution_method: str = "sampled_shapley",
) -> lit_model.Model:
"""Creates a LIT Model object.
Args:
model:
Required. A string reference to a local TensorFlow saved model directory.
            The model must have exactly one input tensor and one output tensor.
input_types:
Required. An OrderedDict of string names matching the features of the model
as the key, and the associated LitType of the feature.
output_types:
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
attribution_method:
Optional. A string to choose what attribution configuration to
set up the explainer with. Valid options are 'sampled_shapley'
or 'integrated_gradients'.
Returns:
A LIT Model object that has the same functionality as the model provided.
"""
return _TensorFlowLitModel(model, input_types, output_types, attribution_method)
def open_lit(
models: Dict[str, lit_model.Model],
datasets: Dict[str, lit_dataset.Dataset],
open_in_new_tab: bool = True,
):
"""Open LIT from the provided models and datasets.
Args:
        models:
            Required. A dict of LIT models to open LIT with.
        datasets:
            Required. A dict of LIT datasets to open LIT with.
        open_in_new_tab:
            Optional. A boolean to choose whether LIT opens in a new tab or not.
Raises:
ImportError if LIT is not installed.
"""
widget = notebook.LitWidget(models, datasets)
widget.render(open_in_new_tab=open_in_new_tab)
def set_up_and_open_lit(
dataset: Union[pd.DataFrame, lit_dataset.Dataset],
column_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
model: Union[str, lit_model.Model],
input_types: Union[List[str], Dict[str, lit_types.LitType]],
output_types: Union[str, List[str], Dict[str, lit_types.LitType]],
attribution_method: str = "sampled_shapley",
open_in_new_tab: bool = True,
) -> Tuple[lit_dataset.Dataset, lit_model.Model]:
"""Creates a LIT dataset and model and opens LIT.
Args:
dataset:
Required. A Pandas DataFrame that includes feature column names and data.
column_types:
Required. An OrderedDict of string names matching the columns of the dataset
as the key, and the associated LitType of the column.
model:
Required. A string reference to a TensorFlow saved model directory.
            The model must have exactly one input tensor and one output tensor.
input_types:
Required. An OrderedDict of string names matching the features of the model
as the key, and the associated LitType of the feature.
output_types:
Required. An OrderedDict of string names matching the labels of the model
as the key, and the associated LitType of the label.
attribution_method:
Optional. A string to choose what attribution configuration to
set up the explainer with. Valid options are 'sampled_shapley'
or 'integrated_gradients'.
open_in_new_tab:
            Optional. A boolean to choose whether LIT opens in a new tab or not.
Returns:
A Tuple of the LIT dataset and model created.
Raises:
ImportError if LIT or TensorFlow is not installed.
        ValueError if the model does not have exactly one input tensor and one output tensor.
"""
if not isinstance(dataset, lit_dataset.Dataset):
dataset = create_lit_dataset(dataset, column_types)
if not isinstance(model, lit_model.Model):
model = create_lit_model(
model, input_types, output_types, attribution_method=attribution_method
)
open_lit(
{"model": model},
{"dataset": dataset},
open_in_new_tab=open_in_new_tab,
)
return dataset, model
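A sketch of setting up LIT from a small tabular DataFrame and a locally exported TensorFlow model, with `set_up_and_open_lit` (defined above) in scope; the column names, LIT types, and model directory are assumptions for illustration.
```py
# A minimal sketch; "saved_model_dir", the column names, and the LIT types
# are placeholders for your own model and data.
import collections

import pandas as pd
from lit_nlp.api import types as lit_types

df = pd.DataFrame(
    {"age": [39.0, 52.0], "hours_per_week": [40.0, 30.0], "income": [0.0, 1.0]}
)
column_types = collections.OrderedDict(
    [
        ("age", lit_types.Scalar()),
        ("hours_per_week", lit_types.Scalar()),
        ("income", lit_types.Scalar()),
    ]
)
input_types = collections.OrderedDict(
    [("age", lit_types.Scalar()), ("hours_per_week", lit_types.Scalar())]
)
output_types = collections.OrderedDict([("income", lit_types.Scalar())])

lit_dataset_obj, lit_model_obj = set_up_and_open_lit(
    dataset=df,
    column_types=column_types,
    model="saved_model_dir",
    input_types=input_types,
    output_types=output_types,
)
```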

View File

@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base abstract class for metadata builders."""
import abc
_ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()})
class MetadataBuilder(_ABC):
"""Abstract base class for metadata builders."""
@abc.abstractmethod
def get_metadata(self):
"""Returns the current metadata as a dictionary."""
@abc.abstractmethod
def get_metadata_protobuf(self):
"""Returns the current metadata as ExplanationMetadata protobuf"""

View File

@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.protobuf import json_format
from typing import Any, Dict, List, Optional
from google.cloud.aiplatform.compat.types import explanation_metadata
from google.cloud.aiplatform.explain.metadata import metadata_builder
class SavedModelMetadataBuilder(metadata_builder.MetadataBuilder):
"""Metadata builder class that accepts a TF1 saved model."""
def __init__(
self,
model_path: str,
tags: Optional[List[str]] = None,
signature_name: Optional[str] = None,
outputs_to_explain: Optional[List[str]] = None,
) -> None:
"""Initializes a SavedModelMetadataBuilder object.
Args:
model_path:
Required. Local or GCS path to load the saved model from.
tags:
Optional. Tags to identify the model graph. If None or empty,
TensorFlow's default serving tag will be used.
signature_name:
Optional. Name of the signature to be explained. Inputs and
outputs of this signature will be written in the metadata. If not
provided, the default signature will be used.
outputs_to_explain:
Optional. List of output names to explain. Only single output is
supported for now. Hence, the list should contain one element.
This parameter is required if the model signature (provided via
signature_name) specifies multiple outputs.
Raises:
ValueError: If outputs_to_explain contains more than 1 element or
signature contains multiple outputs.
"""
if outputs_to_explain:
if len(outputs_to_explain) > 1:
raise ValueError(
"Only one output is supported at the moment. "
f"Received: {outputs_to_explain}."
)
self._output_to_explain = next(iter(outputs_to_explain))
try:
import tensorflow.compat.v1 as tf
except ImportError:
raise ImportError(
"Tensorflow is not installed and is required to load saved model. "
                'Please install TensorFlow using "pip install \'tensorflow>=1.15,<2.0\'".'
)
if not signature_name:
signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
self._tags = tags or [tf.saved_model.tag_constants.SERVING]
self._graph = tf.Graph()
with self.graph.as_default():
self._session = tf.Session(graph=self.graph)
self._metagraph_def = tf.saved_model.loader.load(
sess=self.session, tags=self._tags, export_dir=model_path
)
if signature_name not in self._metagraph_def.signature_def:
raise ValueError(
f"Serving sigdef key {signature_name} not in the signature def."
)
serving_sigdef = self._metagraph_def.signature_def[signature_name]
if not outputs_to_explain:
if len(serving_sigdef.outputs) > 1:
raise ValueError(
"The signature contains multiple outputs. Specify "
'an output via "outputs_to_explain" parameter.'
)
self._output_to_explain = next(iter(serving_sigdef.outputs.keys()))
self._inputs = _create_input_metadata_from_signature(serving_sigdef.inputs)
self._outputs = _create_output_metadata_from_signature(
serving_sigdef.outputs, self._output_to_explain
)
@property
def graph(self) -> "tf.Graph": # noqa: F821
return self._graph
@property
def session(self) -> "tf.Session": # noqa: F821
return self._session
def get_metadata(self) -> Dict[str, Any]:
"""Returns the current metadata as a dictionary.
Returns:
Json format of the explanation metadata.
"""
return json_format.MessageToDict(self.get_metadata_protobuf()._pb)
def get_metadata_protobuf(self) -> explanation_metadata.ExplanationMetadata:
"""Returns the current metadata as a Protobuf object.
Returns:
ExplanationMetadata object format of the explanation metadata.
"""
return explanation_metadata.ExplanationMetadata(
inputs=self._inputs,
outputs=self._outputs,
)
def _create_input_metadata_from_signature(
signature_inputs: Dict[str, "tf.Tensor"] # noqa: F821
) -> Dict[str, explanation_metadata.ExplanationMetadata.InputMetadata]:
"""Creates InputMetadata from signature inputs.
Args:
signature_inputs:
            Required. Inputs of the signature to be explained.
Returns:
Inferred input metadata from the model.
"""
input_mds = {}
for key, tensor in signature_inputs.items():
input_mds[key] = explanation_metadata.ExplanationMetadata.InputMetadata(
input_tensor_name=tensor.name
)
return input_mds
def _create_output_metadata_from_signature(
signature_outputs: Dict[str, "tf.Tensor"], # noqa: F821
output_to_explain: Optional[str] = None,
) -> Dict[str, explanation_metadata.ExplanationMetadata.OutputMetadata]:
"""Creates OutputMetadata from signature inputs.
Args:
signature_outputs:
Required. Inputs of the signature to be explained. If not provided,
the default signature will be used.
output_to_explain:
Optional. Output name to explain.
Returns:
Inferred output metadata from the model.
"""
output_mds = {}
for key, tensor in signature_outputs.items():
if not output_to_explain or output_to_explain == key:
output_mds[key] = explanation_metadata.ExplanationMetadata.OutputMetadata(
output_tensor_name=tensor.name
)
return output_mds
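A usage sketch for the builder above, assuming a TF1-style saved model exported to a local directory (the path is a placeholder) and that `SavedModelMetadataBuilder` is in scope.
```py
# A minimal sketch; "tf1_saved_model_dir" is a placeholder for a local TF1
# saved model directory.
builder = SavedModelMetadataBuilder("tf1_saved_model_dir")

# Inspect the inferred explanation metadata as a plain dictionary...
metadata_dict = builder.get_metadata()
print(metadata_dict)

# ...or as an ExplanationMetadata protobuf for use with the SDK.
metadata_proto = builder.get_metadata_protobuf()
```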

Some files were not shown because too many files have changed in this diff.