structure saas with tools
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,590 @@
# Copyright 2014 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared helpers for Google Cloud packages.

This module is not part of the public API surface.
"""

from __future__ import absolute_import

import calendar
import datetime
import http.client
import os
import re
from threading import local as Local
from typing import Union

import google.auth
import google.auth.transport.requests
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2

try:
    import grpc
    import google.auth.transport.grpc
except ImportError:  # pragma: NO COVER
    grpc = None

# `google.cloud._helpers._NOW` is deprecated
_NOW = datetime.datetime.utcnow
UTC = datetime.timezone.utc  # Singleton instance to be used throughout.
_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)

_RFC3339_MICROS = "%Y-%m-%dT%H:%M:%S.%fZ"
_RFC3339_NO_FRACTION = "%Y-%m-%dT%H:%M:%S"
_TIMEONLY_W_MICROS = "%H:%M:%S.%f"
_TIMEONLY_NO_FRACTION = "%H:%M:%S"
# datetime.strptime cannot handle nanosecond precision: parse w/ regex
_RFC3339_NANOS = re.compile(
    r"""
    (?P<no_fraction>
        \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}  # YYYY-MM-DDTHH:MM:SS
    )
    (  # Optional decimal part
        \.  # decimal point
        (?P<nanos>\d{1,9})  # nanoseconds, maybe truncated
    )?
    Z  # Zulu
    """,
    re.VERBOSE,
)
# NOTE: Catching this ImportError is a workaround for GAE not supporting the
# "pwd" module which is imported lazily when "expanduser" is called.
_USER_ROOT: Union[str, None]
try:
    _USER_ROOT = os.path.expanduser("~")
except ImportError:  # pragma: NO COVER
    _USER_ROOT = None
_GCLOUD_CONFIG_FILE = os.path.join("gcloud", "configurations", "config_default")
_GCLOUD_CONFIG_SECTION = "core"
_GCLOUD_CONFIG_KEY = "project"
class _LocalStack(Local):
    """Manage a thread-local LIFO stack of resources.

    Intended for use in :class:`google.cloud.datastore.batch.Batch.__enter__`,
    :class:`google.cloud.storage.batch.Batch.__enter__`, etc.
    """

    def __init__(self):
        super(_LocalStack, self).__init__()
        self._stack = []

    def __iter__(self):
        """Iterate the stack in LIFO order."""
        return iter(reversed(self._stack))

    def push(self, resource):
        """Push a resource onto our stack."""
        self._stack.append(resource)

    def pop(self):
        """Pop a resource from our stack.

        :rtype: object
        :returns: the top-most resource, after removing it.
        :raises IndexError: if the stack is empty.
        """
        return self._stack.pop()

    @property
    def top(self):
        """Get the top-most resource

        :rtype: object
        :returns: the top-most item, or None if the stack is empty.
        """
        if self._stack:
            return self._stack[-1]
def _ensure_tuple_or_list(arg_name, tuple_or_list):
    """Ensures an input is a tuple or list.

    This effectively reduces the iterable types allowed to a very short
    allowlist: list and tuple.

    :type arg_name: str
    :param arg_name: Name of argument to use in error message.

    :type tuple_or_list: sequence of str
    :param tuple_or_list: Sequence to be verified.

    :rtype: list of str
    :returns: The ``tuple_or_list`` passed in cast to a ``list``.
    :raises TypeError: if the ``tuple_or_list`` is not a tuple or list.
    """
    if not isinstance(tuple_or_list, (tuple, list)):
        raise TypeError(
            "Expected %s to be a tuple or list. "
            "Received %r" % (arg_name, tuple_or_list)
        )
    return list(tuple_or_list)


def _determine_default_project(project=None):
    """Determine default project ID explicitly or implicitly as fall-back.

    See :func:`google.auth.default` for details on how the default project
    is determined.

    :type project: str
    :param project: Optional. The project name to use as default.

    :rtype: str or ``NoneType``
    :returns: Default project if it can be determined.
    """
    if project is None:
        _, project = google.auth.default()
    return project
def _millis(when):
    """Convert a zone-aware datetime to integer milliseconds.

    :type when: :class:`datetime.datetime`
    :param when: the datetime to convert

    :rtype: int
    :returns: milliseconds since epoch for ``when``
    """
    micros = _microseconds_from_datetime(when)
    return micros // 1000


def _datetime_from_microseconds(value):
    """Convert timestamp to datetime, assuming UTC.

    :type value: float
    :param value: The timestamp to convert

    :rtype: :class:`datetime.datetime`
    :returns: The datetime object created from the value.
    """
    return _EPOCH + datetime.timedelta(microseconds=value)


def _microseconds_from_datetime(value):
    """Convert non-none datetime to microseconds.

    :type value: :class:`datetime.datetime`
    :param value: The timestamp to convert.

    :rtype: int
    :returns: The timestamp, in microseconds.
    """
    if not value.tzinfo:
        value = value.replace(tzinfo=UTC)
    # Regardless of what timezone is on the value, convert it to UTC.
    value = value.astimezone(UTC)
    # Convert the datetime to a microsecond timestamp.
    return int(calendar.timegm(value.timetuple()) * 1e6) + value.microsecond


def _millis_from_datetime(value):
    """Convert non-none datetime to timestamp, assuming UTC.

    :type value: :class:`datetime.datetime`
    :param value: (Optional) the timestamp

    :rtype: int, or ``NoneType``
    :returns: the timestamp, in milliseconds, or None
    """
    if value is not None:
        return _millis(value)


def _date_from_iso8601_date(value):
    """Convert an ISO8601 date string to native datetime date

    :type value: str
    :param value: The date string to convert

    :rtype: :class:`datetime.date`
    :returns: A datetime date object created from the string
    """
    return datetime.datetime.strptime(value, "%Y-%m-%d").date()


def _time_from_iso8601_time_naive(value):
    """Convert a zoneless ISO8601 time string to naive datetime time

    :type value: str
    :param value: The time string to convert

    :rtype: :class:`datetime.time`
    :returns: A datetime time object created from the string
    :raises ValueError: if the value does not match a known format.
    """
    if len(value) == 8:  # HH:MM:SS
        fmt = _TIMEONLY_NO_FRACTION
    elif len(value) == 15:  # HH:MM:SS.micros
        fmt = _TIMEONLY_W_MICROS
    else:
        raise ValueError("Unknown time format: {}".format(value))
    return datetime.datetime.strptime(value, fmt).time()
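# Usage sketch (doctest-style; the values below are illustrative, chosen so
# the results can be checked by hand against the helpers above):
#
#   >>> _millis(_EPOCH + datetime.timedelta(seconds=1, microseconds=500))
#   1000
#   >>> _time_from_iso8601_time_naive("12:34:56.000123")
#   datetime.time(12, 34, 56, 123)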
def _rfc3339_to_datetime(dt_str):
    """Convert a microsecond-precision timestamp to a native datetime.

    :type dt_str: str
    :param dt_str: The string to convert.

    :rtype: :class:`datetime.datetime`
    :returns: The datetime object created from the string.
    """
    return datetime.datetime.strptime(dt_str, _RFC3339_MICROS).replace(tzinfo=UTC)


def _rfc3339_nanos_to_datetime(dt_str):
    """Convert a nanosecond-precision timestamp to a native datetime.

    .. note::

        Python datetimes do not support nanosecond precision; this function
        therefore truncates such values to microseconds.

    :type dt_str: str
    :param dt_str: The string to convert.

    :rtype: :class:`datetime.datetime`
    :returns: The datetime object created from the string.
    :raises ValueError: If the timestamp does not match the RFC 3339
        regular expression.
    """
    with_nanos = _RFC3339_NANOS.match(dt_str)
    if with_nanos is None:
        raise ValueError(
            "Timestamp: %r, does not match pattern: %r"
            % (dt_str, _RFC3339_NANOS.pattern)
        )
    bare_seconds = datetime.datetime.strptime(
        with_nanos.group("no_fraction"), _RFC3339_NO_FRACTION
    )
    fraction = with_nanos.group("nanos")
    if fraction is None:
        micros = 0
    else:
        scale = 9 - len(fraction)
        nanos = int(fraction) * (10**scale)
        micros = nanos // 1000
    return bare_seconds.replace(microsecond=micros, tzinfo=UTC)
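# Usage sketch (doctest-style): nanosecond digits beyond microsecond
# precision are truncated, as noted above.
#
#   >>> _rfc3339_nanos_to_datetime("2016-12-20T21:13:47.123456789Z")
#   datetime.datetime(2016, 12, 20, 21, 13, 47, 123456, tzinfo=datetime.timezone.utc)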
def _datetime_to_rfc3339(value, ignore_zone=True):
    """Convert a timestamp to a string.

    :type value: :class:`datetime.datetime`
    :param value: The datetime object to be converted to a string.

    :type ignore_zone: bool
    :param ignore_zone: If True, then the timezone (if any) of the datetime
        object is ignored.

    :rtype: str
    :returns: The string representing the datetime stamp.
    """
    if not ignore_zone and value.tzinfo is not None:
        # Convert to UTC and remove the time zone info.
        value = value.replace(tzinfo=None) - value.utcoffset()

    return value.strftime(_RFC3339_MICROS)


def _to_bytes(value, encoding="ascii"):
    """Converts a string value to bytes, if necessary.

    :type value: str / bytes or unicode
    :param value: The string/bytes value to be converted.

    :type encoding: str
    :param encoding: The encoding to use to convert unicode to bytes.
        Defaults to "ascii", which will not allow any characters from
        ordinals larger than 127. Other useful values are "latin-1",
        which will only allow byte ordinals (up to 255), and "utf-8",
        which will encode any unicode that needs to be.

    :rtype: str / bytes
    :returns: The original value converted to bytes (if unicode) or as
        passed in if it started out as bytes.
    :raises TypeError: if the value could not be converted to bytes.
    """
    result = value.encode(encoding) if isinstance(value, str) else value
    if isinstance(result, bytes):
        return result
    else:
        raise TypeError("%r could not be converted to bytes" % (value,))
def _bytes_to_unicode(value):
    """Converts bytes to a unicode value, if necessary.

    :type value: bytes
    :param value: bytes value to attempt string conversion on.

    :rtype: str
    :returns: The original value converted to unicode (if bytes) or as
        passed in if it started out as unicode.

    :raises ValueError: if the value could not be converted to unicode.
    """
    result = value.decode("utf-8") if isinstance(value, bytes) else value
    if isinstance(result, str):
        return result
    else:
        raise ValueError("%r could not be converted to unicode" % (value,))


def _from_any_pb(pb_type, any_pb):
    """Converts an Any protobuf to the specified message type

    Args:
        pb_type (type): the type of the message that any_pb stores an
            instance of.
        any_pb (google.protobuf.any_pb2.Any): the object to be converted.

    Returns:
        pb_type: An instance of the pb_type message.

    Raises:
        TypeError: if the message could not be converted.
    """
    msg = pb_type()
    if not any_pb.Unpack(msg):
        raise TypeError(
            "Could not convert {} to {}".format(
                any_pb.__class__.__name__, pb_type.__name__
            )
        )

    return msg


def _pb_timestamp_to_datetime(timestamp_pb):
    """Convert a Timestamp protobuf to a datetime object.

    :type timestamp_pb: :class:`google.protobuf.timestamp_pb2.Timestamp`
    :param timestamp_pb: A Google returned timestamp protobuf.

    :rtype: :class:`datetime.datetime`
    :returns: A UTC datetime object converted from a protobuf timestamp.
    """
    return _EPOCH + datetime.timedelta(
        seconds=timestamp_pb.seconds, microseconds=(timestamp_pb.nanos / 1000.0)
    )


def _pb_timestamp_to_rfc3339(timestamp_pb):
    """Convert a Timestamp protobuf to an RFC 3339 string.

    :type timestamp_pb: :class:`google.protobuf.timestamp_pb2.Timestamp`
    :param timestamp_pb: A Google returned timestamp protobuf.

    :rtype: str
    :returns: An RFC 3339 formatted timestamp string.
    """
    timestamp = _pb_timestamp_to_datetime(timestamp_pb)
    return _datetime_to_rfc3339(timestamp)


def _datetime_to_pb_timestamp(when):
    """Convert a datetime object to a Timestamp protobuf.

    :type when: :class:`datetime.datetime`
    :param when: the datetime to convert

    :rtype: :class:`google.protobuf.timestamp_pb2.Timestamp`
    :returns: A timestamp protobuf corresponding to the object.
    """
    ms_value = _microseconds_from_datetime(when)
    seconds, micros = divmod(ms_value, 10**6)
    nanos = micros * 10**3
    return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
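# Usage sketch (doctest-style): a microsecond-precision round trip through
# the protobuf Timestamp type, using only the helpers defined above.
#
#   >>> ts = _datetime_to_pb_timestamp(_EPOCH + datetime.timedelta(microseconds=1500))
#   >>> (ts.seconds, ts.nanos)
#   (0, 1500000)
#   >>> _pb_timestamp_to_datetime(ts) == _EPOCH + datetime.timedelta(microseconds=1500)
#   True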
def _timedelta_to_duration_pb(timedelta_val):
    """Convert a Python timedelta object to a duration protobuf.

    .. note::

        The Python timedelta has a granularity of microseconds while
        the protobuf duration type has a granularity of nanoseconds.

    :type timedelta_val: :class:`datetime.timedelta`
    :param timedelta_val: A timedelta object.

    :rtype: :class:`google.protobuf.duration_pb2.Duration`
    :returns: A duration object equivalent to the time delta.
    """
    duration_pb = duration_pb2.Duration()
    duration_pb.FromTimedelta(timedelta_val)
    return duration_pb


def _duration_pb_to_timedelta(duration_pb):
    """Convert a duration protobuf to a Python timedelta object.

    .. note::

        The Python timedelta has a granularity of microseconds while
        the protobuf duration type has a granularity of nanoseconds.

    :type duration_pb: :class:`google.protobuf.duration_pb2.Duration`
    :param duration_pb: A protobuf duration object.

    :rtype: :class:`datetime.timedelta`
    :returns: The converted timedelta object.
    """
    return datetime.timedelta(
        seconds=duration_pb.seconds, microseconds=(duration_pb.nanos / 1000.0)
    )
def _name_from_project_path(path, project, template):
    """Validate a URI path and get the leaf object's name.

    :type path: str
    :param path: URI path containing the name.

    :type project: str
    :param project: (Optional) The project associated with the request. It is
        included for validation purposes. If passed as None, disables
        validation.

    :type template: str
    :param template: Template regex describing the expected form of the path.
        The regex must have two named groups, 'project' and 'name'.

    :rtype: str
    :returns: Name parsed from ``path``.
    :raises ValueError: if the ``path`` is ill-formed or if the project from
        the ``path`` does not agree with the ``project`` passed in.
    """
    if isinstance(template, str):
        template = re.compile(template)

    match = template.match(path)

    if not match:
        raise ValueError(
            'path "%s" did not match expected pattern "%s"' % (path, template.pattern)
        )

    if project is not None:
        found_project = match.group("project")
        if found_project != project:
            raise ValueError(
                "Project from client (%s) should agree with "
                "project from resource (%s)." % (project, found_project)
            )

    return match.group("name")
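# Usage sketch (the template below is illustrative, not one shipped with this
# module):
#
#   >>> template = r"projects/(?P<project>[^/]+)/topics/(?P<name>[^/]+)"
#   >>> _name_from_project_path(
#   ...     "projects/my-project/topics/my-topic", "my-project", template
#   ... )
#   'my-topic'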
def make_secure_channel(credentials, user_agent, host, extra_options=()):
    """Makes a secure channel for an RPC service.

    Uses / depends on gRPC.

    :type credentials: :class:`google.auth.credentials.Credentials`
    :param credentials: The OAuth2 Credentials to use for creating
        access tokens.

    :type user_agent: str
    :param user_agent: The user agent to be used with API requests.

    :type host: str
    :param host: The host for the service.

    :type extra_options: tuple
    :param extra_options: (Optional) Extra gRPC options used when creating the
        channel.

    :rtype: :class:`grpc._channel.Channel`
    :returns: gRPC secure channel with credentials attached.
    """
    target = "%s:%d" % (host, http.client.HTTPS_PORT)
    http_request = google.auth.transport.requests.Request()

    user_agent_option = ("grpc.primary_user_agent", user_agent)
    options = (user_agent_option,) + extra_options
    return google.auth.transport.grpc.secure_authorized_channel(
        credentials, http_request, target, options=options
    )


def make_secure_stub(credentials, user_agent, stub_class, host, extra_options=()):
    """Makes a secure stub for an RPC service.

    Uses / depends on gRPC.

    :type credentials: :class:`google.auth.credentials.Credentials`
    :param credentials: The OAuth2 Credentials to use for creating
        access tokens.

    :type user_agent: str
    :param user_agent: The user agent to be used with API requests.

    :type stub_class: type
    :param stub_class: A gRPC stub type for a given service.

    :type host: str
    :param host: The host for the service.

    :type extra_options: tuple
    :param extra_options: (Optional) Extra gRPC options passed when creating
        the channel.

    :rtype: object, instance of ``stub_class``
    :returns: The stub object used to make gRPC requests to a given API.
    """
    channel = make_secure_channel(
        credentials, user_agent, host, extra_options=extra_options
    )
    return stub_class(channel)


def make_insecure_stub(stub_class, host, port=None):
    """Makes an insecure stub for an RPC service.

    Uses / depends on gRPC.

    :type stub_class: type
    :param stub_class: A gRPC stub type for a given service.

    :type host: str
    :param host: The host for the service. May also include the port
        if ``port`` is unspecified.

    :type port: int
    :param port: (Optional) The port for the service.

    :rtype: object, instance of ``stub_class``
    :returns: The stub object used to make gRPC requests to a given API.
    """
    if port is None:
        target = host
    else:
        # NOTE: This assumes port != http.client.HTTPS_PORT:
        target = "%s:%d" % (host, port)
    channel = grpc.insecure_channel(target)
    return stub_class(channel)
Binary file not shown.
@@ -0,0 +1,2 @@
# Marker file for PEP 561.
# This package uses inline types.
@@ -0,0 +1,499 @@
# Copyright 2014 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared implementation of connections to API servers."""

import collections
import collections.abc
import json
import os
import platform
from typing import Optional
from urllib.parse import urlencode
import warnings

from google.api_core.client_info import ClientInfo
from google.cloud import exceptions
from google.cloud import version


API_BASE_URL = "https://www.googleapis.com"
"""The base of the API call URL."""

DEFAULT_USER_AGENT = "gcloud-python/{0}".format(version.__version__)
"""The user agent for google-cloud-python requests."""

CLIENT_INFO_HEADER = "X-Goog-API-Client"
CLIENT_INFO_TEMPLATE = "gl-python/" + platform.python_version() + " gccl/{}"

_USER_AGENT_ALL_CAPS_DEPRECATED = """\
The 'USER_AGENT' class-level attribute is deprecated. Please use
'user_agent' instead.
"""

_EXTRA_HEADERS_ALL_CAPS_DEPRECATED = """\
The '_EXTRA_HEADERS' class-level attribute is deprecated. Please use
'extra_headers' instead.
"""

_DEFAULT_TIMEOUT = 60  # in seconds
class Connection(object):
    """A generic connection to Google Cloud Platform.

    :type client: :class:`~google.cloud.client.Client`
    :param client: The client that owns the current connection.

    :type client_info: :class:`~google.api_core.client_info.ClientInfo`
    :param client_info: (Optional) instance used to generate user agent.
    """

    _user_agent = DEFAULT_USER_AGENT

    def __init__(self, client, client_info=None):
        self._client = client

        if client_info is None:
            client_info = ClientInfo()

        self._client_info = client_info
        self._extra_headers = {}

    @property
    def USER_AGENT(self):
        """Deprecated: get / set user agent sent by connection.

        :rtype: str
        :returns: user agent
        """
        warnings.warn(_USER_AGENT_ALL_CAPS_DEPRECATED, DeprecationWarning, stacklevel=2)
        return self.user_agent

    @USER_AGENT.setter
    def USER_AGENT(self, value):
        warnings.warn(_USER_AGENT_ALL_CAPS_DEPRECATED, DeprecationWarning, stacklevel=2)
        self.user_agent = value

    @property
    def user_agent(self):
        """Get / set user agent sent by connection.

        :rtype: str
        :returns: user agent
        """
        return self._client_info.to_user_agent()

    @user_agent.setter
    def user_agent(self, value):
        self._client_info.user_agent = value

    @property
    def _EXTRA_HEADERS(self):
        """Deprecated: get / set extra headers sent by connection.

        :rtype: dict
        :returns: header keys / values
        """
        warnings.warn(
            _EXTRA_HEADERS_ALL_CAPS_DEPRECATED, DeprecationWarning, stacklevel=2
        )
        return self.extra_headers

    @_EXTRA_HEADERS.setter
    def _EXTRA_HEADERS(self, value):
        warnings.warn(
            _EXTRA_HEADERS_ALL_CAPS_DEPRECATED, DeprecationWarning, stacklevel=2
        )
        self.extra_headers = value

    @property
    def extra_headers(self):
        """Get / set extra headers sent by connection.

        :rtype: dict
        :returns: header keys / values
        """
        return self._extra_headers

    @extra_headers.setter
    def extra_headers(self, value):
        self._extra_headers = value

    @property
    def credentials(self):
        """Getter for current credentials.

        :rtype: :class:`google.auth.credentials.Credentials` or
            :class:`NoneType`
        :returns: The credentials object associated with this connection.
        """
        return self._client._credentials

    @property
    def http(self):
        """A getter for the HTTP transport used in talking to the API.

        Returns:
            google.auth.transport.requests.AuthorizedSession:
                A :class:`requests.Session` instance.
        """
        return self._client._http
class JSONConnection(Connection):
    """A connection to a Google JSON-based API.

    These APIs are discovery based. For reference:

        https://developers.google.com/discovery/

    This defines :meth:`api_request` for making a generic JSON
    API request and API requests are created elsewhere.

    * :attr:`API_BASE_URL`
    * :attr:`API_VERSION`
    * :attr:`API_URL_TEMPLATE`

    must be updated by subclasses.
    """

    API_BASE_URL: Optional[str] = None
    """The base of the API call URL."""

    API_BASE_MTLS_URL: Optional[str] = None
    """The base of the API call URL for mutual TLS."""

    ALLOW_AUTO_SWITCH_TO_MTLS_URL = False
    """Indicates if auto switch to mTLS url is allowed."""

    API_VERSION: Optional[str] = None
    """The version of the API, used in building the API call's URL."""

    API_URL_TEMPLATE: Optional[str] = None
    """A template for the URL of a particular API call."""

    def get_api_base_url_for_mtls(self, api_base_url=None):
        """Return the api base url for mutual TLS.

        Typically, you shouldn't need to use this method.

        The logic is as follows:

        If `api_base_url` is provided, just return this value; otherwise, the
        return value depends on the `GOOGLE_API_USE_MTLS_ENDPOINT` environment
        variable value.

        If the environment variable value is "always", return `API_BASE_MTLS_URL`.
        If the environment variable value is "never", return `API_BASE_URL`.
        Otherwise, if `ALLOW_AUTO_SWITCH_TO_MTLS_URL` is True and the underlying
        http is mTLS, then return `API_BASE_MTLS_URL`; otherwise return `API_BASE_URL`.

        :type api_base_url: str
        :param api_base_url: User provided api base url. It takes precedence over
            `API_BASE_URL` and `API_BASE_MTLS_URL`.

        :rtype: str
        :returns: The api base url used for mTLS.
        """
        if api_base_url:
            return api_base_url

        env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
        if env == "always":
            url_to_use = self.API_BASE_MTLS_URL
        elif env == "never":
            url_to_use = self.API_BASE_URL
        else:
            if self.ALLOW_AUTO_SWITCH_TO_MTLS_URL:
                url_to_use = (
                    self.API_BASE_MTLS_URL if self.http.is_mtls else self.API_BASE_URL
                )
            else:
                url_to_use = self.API_BASE_URL
        return url_to_use
    def build_api_url(
        self, path, query_params=None, api_base_url=None, api_version=None
    ):
        """Construct an API url given a few components, some optional.

        Typically, you shouldn't need to use this method.

        :type path: str
        :param path: The path to the resource (ie, ``'/b/bucket-name'``).

        :type query_params: dict or list
        :param query_params: A dictionary of keys and values (or list of
            key-value pairs) to insert into the query string of the URL.

        :type api_base_url: str
        :param api_base_url: The base URL for the API endpoint.
            Typically you won't have to provide this.

        :type api_version: str
        :param api_version: The version of the API to call.
            Typically you shouldn't provide this and instead
            use the default for the library.

        :rtype: str
        :returns: The URL assembled from the pieces provided.
        """
        url = self.API_URL_TEMPLATE.format(
            api_base_url=self.get_api_base_url_for_mtls(api_base_url),
            api_version=(api_version or self.API_VERSION),
            path=path,
        )

        query_params = query_params or {}

        if isinstance(query_params, collections.abc.Mapping):
            query_params = query_params.copy()
        else:
            query_params_dict = collections.defaultdict(list)
            for key, value in query_params:
                query_params_dict[key].append(value)
            query_params = query_params_dict

        query_params.setdefault("prettyPrint", "false")

        url += "?" + urlencode(query_params, doseq=True)

        return url
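    # Usage sketch for a hypothetical subclass (the subclass name, base URL,
    # and ``client`` below are illustrative assumptions, not part of this
    # module):
    #
    #   class _DemoConnection(JSONConnection):
    #       API_BASE_URL = "https://example.googleapis.com"
    #       API_VERSION = "v1"
    #       API_URL_TEMPLATE = "{api_base_url}/{api_version}{path}"
    #
    #   conn = _DemoConnection(client)
    #   conn.build_api_url("/b/bucket-name", query_params={"fields": "items"})
    #   # -> 'https://example.googleapis.com/v1/b/bucket-name?fields=items&prettyPrint=false'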
    def _make_request(
        self,
        method,
        url,
        data=None,
        content_type=None,
        headers=None,
        target_object=None,
        timeout=_DEFAULT_TIMEOUT,
        extra_api_info=None,
    ):
        """A low level method to send a request to the API.

        Typically, you shouldn't need to use this method.

        :type method: str
        :param method: The HTTP method to use in the request.

        :type url: str
        :param url: The URL to send the request to.

        :type data: str
        :param data: The data to send as the body of the request.

        :type content_type: str
        :param content_type: The proper MIME type of the data provided.

        :type headers: dict
        :param headers: (Optional) A dictionary of HTTP headers to send with
            the request. If passed, will be modified directly here with
            added headers.

        :type target_object: object
        :param target_object:
            (Optional) Argument to be used by library callers. This can allow
            custom behavior, for example, to defer an HTTP request and complete
            initialization of the object at a later time.

        :type timeout: float or tuple
        :param timeout: (optional) The amount of time, in seconds, to wait
            for the server response.

            Can also be passed as a tuple (connect_timeout, read_timeout).
            See :meth:`requests.Session.request` documentation for details.

        :type extra_api_info: string
        :param extra_api_info: (optional) Extra api info to be appended to
            the X-Goog-API-Client header

        :rtype: :class:`requests.Response`
        :returns: The HTTP response.
        """
        headers = headers or {}
        headers.update(self.extra_headers)
        headers["Accept-Encoding"] = "gzip"

        if content_type:
            headers["Content-Type"] = content_type

        if extra_api_info:
            headers[CLIENT_INFO_HEADER] = f"{self.user_agent} {extra_api_info}"
        else:
            headers[CLIENT_INFO_HEADER] = self.user_agent
        headers["User-Agent"] = self.user_agent

        return self._do_request(
            method, url, headers, data, target_object, timeout=timeout
        )
    def _do_request(
        self, method, url, headers, data, target_object, timeout=_DEFAULT_TIMEOUT
    ):  # pylint: disable=unused-argument
        """Low-level helper: perform the actual API request over HTTP.

        Allows batch context managers to override and defer a request.

        :type method: str
        :param method: The HTTP method to use in the request.

        :type url: str
        :param url: The URL to send the request to.

        :type headers: dict
        :param headers: A dictionary of HTTP headers to send with the request.

        :type data: str
        :param data: The data to send as the body of the request.

        :type target_object: object
        :param target_object:
            (Optional) Unused ``target_object`` here but may be used by a
            superclass.

        :type timeout: float or tuple
        :param timeout: (optional) The amount of time, in seconds, to wait
            for the server response.

            Can also be passed as a tuple (connect_timeout, read_timeout).
            See :meth:`requests.Session.request` documentation for details.

        :rtype: :class:`requests.Response`
        :returns: The HTTP response.
        """
        return self.http.request(
            url=url, method=method, headers=headers, data=data, timeout=timeout
        )
    def api_request(
        self,
        method,
        path,
        query_params=None,
        data=None,
        content_type=None,
        headers=None,
        api_base_url=None,
        api_version=None,
        expect_json=True,
        _target_object=None,
        timeout=_DEFAULT_TIMEOUT,
        extra_api_info=None,
    ):
        """Make a request over the HTTP transport to the API.

        You shouldn't need to use this method, but if you plan to
        interact with the API using these primitives, this is the
        correct one to use.

        :type method: str
        :param method: The HTTP method name (ie, ``GET``, ``POST``, etc).
            Required.

        :type path: str
        :param path: The path to the resource (ie, ``'/b/bucket-name'``).
            Required.

        :type query_params: dict or list
        :param query_params: A dictionary of keys and values (or list of
            key-value pairs) to insert into the query string of the URL.

        :type data: str
        :param data: The data to send as the body of the request. Default is
            the empty string.

        :type content_type: str
        :param content_type: The proper MIME type of the data provided. Default
            is None.

        :type headers: dict
        :param headers: extra HTTP headers to be sent with the request.

        :type api_base_url: str
        :param api_base_url: The base URL for the API endpoint.
            Typically you won't have to provide this.
            Default is the standard API base URL.

        :type api_version: str
        :param api_version: The version of the API to call. Typically
            you shouldn't provide this and instead use the default for
            the library. Default is the latest API version supported by
            google-cloud-python.

        :type expect_json: bool
        :param expect_json: If True, this method will try to parse the
            response as JSON and raise an exception if that cannot be
            done. Default is True.

        :type _target_object: :class:`object`
        :param _target_object:
            (Optional) Protected argument to be used by library callers. This
            can allow custom behavior, for example, to defer an HTTP request
            and complete initialization of the object at a later time.

        :type timeout: float or tuple
        :param timeout: (optional) The amount of time, in seconds, to wait
            for the server response.

            Can also be passed as a tuple (connect_timeout, read_timeout).
            See :meth:`requests.Session.request` documentation for details.

        :type extra_api_info: string
        :param extra_api_info: (optional) Extra api info to be appended to
            the X-Goog-API-Client header

        :raises ~google.cloud.exceptions.GoogleCloudError: if the response code
            is not 200 OK.
        :raises ValueError: if the response content type is not JSON.
        :rtype: dict or str
        :returns: The API response payload, either as a raw string or
            a dictionary if the response is valid JSON.
        """
        url = self.build_api_url(
            path=path,
            query_params=query_params,
            api_base_url=api_base_url,
            api_version=api_version,
        )

        # Making the executive decision that any dictionary
        # data will be sent properly as JSON.
        if data and isinstance(data, dict):
            data = json.dumps(data)
            content_type = "application/json"

        response = self._make_request(
            method=method,
            url=url,
            data=data,
            content_type=content_type,
            headers=headers,
            target_object=_target_object,
            timeout=timeout,
            extra_api_info=extra_api_info,
        )

        if not 200 <= response.status_code < 300:
            raise exceptions.from_http_response(response)

        if expect_json and response.content:
            return response.json()
        else:
            return response.content
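# Usage sketch, continuing the hypothetical ``conn`` above: ``api_request``
# builds the URL, sends the request, checks the status code, and parses JSON.
#
#   payload = conn.api_request(method="GET", path="/b/bucket-name")
#   # -> dict parsed from the JSON response, or raises
#   #    google.cloud.exceptions.GoogleCloudError on a non-2xx status.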
Binary file not shown.
@@ -0,0 +1,2 @@
# Marker file for PEP 561.
# This package uses inline types.
@@ -0,0 +1,121 @@
# Copyright 2014 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared testing utilities."""

from __future__ import absolute_import


class _Monkey(object):
    """Context-manager for replacing module names in the scope of a test."""

    def __init__(self, module, **kw):
        self.module = module
        if not kw:  # pragma: NO COVER
            raise ValueError("_Monkey was used with nothing to monkey-patch")
        self.to_restore = {key: getattr(module, key) for key in kw}
        for key, value in kw.items():
            setattr(module, key, value)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for key, value in self.to_restore.items():
            setattr(self.module, key, value)
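# Usage sketch (doctest-style; the patched module and attribute below are
# illustrative):
#
#   >>> import math
#   >>> with _Monkey(math, pi=3):
#   ...     math.pi
#   3
#   >>> math.pi
#   3.141592653589793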
class _NamedTemporaryFile(object):
    def __init__(self, suffix=""):
        import os
        import tempfile

        filehandle, self.name = tempfile.mkstemp(suffix=suffix)
        os.close(filehandle)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        import os

        os.remove(self.name)


def _tempdir_maker():
    import contextlib
    import shutil
    import tempfile

    @contextlib.contextmanager
    def _tempdir_mgr():
        temp_dir = tempfile.mkdtemp()
        yield temp_dir
        shutil.rmtree(temp_dir)

    return _tempdir_mgr


# pylint: disable=invalid-name
# Retain _tempdir as a constant for backwards compatibility despite
# being an invalid name.
_tempdir = _tempdir_maker()
del _tempdir_maker
# pylint: enable=invalid-name


class _GAXBaseAPI(object):
    _random_gax_error = False

    def __init__(self, **kw):
        self.__dict__.update(kw)

    @staticmethod
    def _make_grpc_error(status_code, trailing=None):
        from grpc._channel import _RPCState
        from google.cloud.exceptions import GrpcRendezvous

        details = "Some error details."
        exc_state = _RPCState((), None, trailing, status_code, details)
        return GrpcRendezvous(exc_state, None, None, None)

    def _make_grpc_not_found(self):
        from grpc import StatusCode

        return self._make_grpc_error(StatusCode.NOT_FOUND)

    def _make_grpc_failed_precondition(self):
        from grpc import StatusCode

        return self._make_grpc_error(StatusCode.FAILED_PRECONDITION)

    def _make_grpc_already_exists(self):
        from grpc import StatusCode

        return self._make_grpc_error(StatusCode.ALREADY_EXISTS)

    def _make_grpc_deadline_exceeded(self):
        from grpc import StatusCode

        return self._make_grpc_error(StatusCode.DEADLINE_EXCEEDED)


class _GAXPageIterator(object):
    def __init__(self, *pages, **kwargs):
        self._pages = iter(pages)
        self.page_token = kwargs.get("page_token")

    def __next__(self):
        """Iterate to the next page."""
        return next(self._pages)
Binary file not shown.
@@ -0,0 +1,2 @@
# Marker file for PEP 561.
# This package uses inline types.
@@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from google.cloud.aiplatform import version as aiplatform_version

__version__ = aiplatform_version.__version__


from google.cloud.aiplatform import initializer

from google.cloud.aiplatform.datasets import (
    ImageDataset,
    TabularDataset,
    TextDataset,
    TimeSeriesDataset,
    VideoDataset,
)
from google.cloud.aiplatform import explain
from google.cloud.aiplatform import gapic
from google.cloud.aiplatform import hyperparameter_tuning
from google.cloud.aiplatform.featurestore import (
    EntityType,
    Feature,
    Featurestore,
)
from google.cloud.aiplatform.matching_engine import (
    MatchingEngineIndex,
    MatchingEngineIndexEndpoint,
)
from google.cloud.aiplatform import metadata
from google.cloud.aiplatform.tensorboard import uploader_tracker
from google.cloud.aiplatform.models import DeploymentResourcePool
from google.cloud.aiplatform.models import Endpoint
from google.cloud.aiplatform.models import PrivateEndpoint
from google.cloud.aiplatform.models import Model
from google.cloud.aiplatform.models import ModelRegistry
from google.cloud.aiplatform.model_evaluation import ModelEvaluation
from google.cloud.aiplatform.jobs import (
    BatchPredictionJob,
    CustomJob,
    HyperparameterTuningJob,
    ModelDeploymentMonitoringJob,
)
from google.cloud.aiplatform.pipeline_jobs import PipelineJob
from google.cloud.aiplatform.pipeline_job_schedules import (
    PipelineJobSchedule,
)
from google.cloud.aiplatform.tensorboard import (
    Tensorboard,
    TensorboardExperiment,
    TensorboardRun,
    TensorboardTimeSeries,
)
from google.cloud.aiplatform.training_jobs import (
    CustomTrainingJob,
    CustomContainerTrainingJob,
    CustomPythonPackageTrainingJob,
    AutoMLTabularTrainingJob,
    AutoMLForecastingTrainingJob,
    SequenceToSequencePlusForecastingTrainingJob,
    TemporalFusionTransformerForecastingTrainingJob,
    TimeSeriesDenseEncoderForecastingTrainingJob,
    AutoMLImageTrainingJob,
    AutoMLTextTrainingJob,
    AutoMLVideoTrainingJob,
)

from google.cloud.aiplatform import helpers

"""
Usage:
from google.cloud import aiplatform

aiplatform.init(project='my_project')
"""
init = initializer.global_config.init

get_pipeline_df = metadata.metadata._LegacyExperimentService.get_pipeline_df

log_params = metadata.metadata._experiment_tracker.log_params
log_metrics = metadata.metadata._experiment_tracker.log_metrics
log_classification_metrics = (
    metadata.metadata._experiment_tracker.log_classification_metrics
)
log_model = metadata.metadata._experiment_tracker.log_model
get_experiment_df = metadata.metadata._experiment_tracker.get_experiment_df
start_run = metadata.metadata._experiment_tracker.start_run
autolog = metadata.metadata._experiment_tracker.autolog
start_execution = metadata.metadata._experiment_tracker.start_execution
log = metadata.metadata._experiment_tracker.log
log_time_series_metrics = metadata.metadata._experiment_tracker.log_time_series_metrics
end_run = metadata.metadata._experiment_tracker.end_run

upload_tb_log = uploader_tracker._tensorboard_tracker.upload_tb_log
start_upload_tb_log = uploader_tracker._tensorboard_tracker.start_upload_tb_log
end_upload_tb_log = uploader_tracker._tensorboard_tracker.end_upload_tb_log

save_model = metadata._models.save_model
get_experiment_model = metadata.schema.google.artifact_schema.ExperimentModel.get

Experiment = metadata.experiment_resources.Experiment
ExperimentRun = metadata.experiment_run_resource.ExperimentRun
Artifact = metadata.artifact.Artifact
Execution = metadata.execution.Execution
Context = metadata.context.Context
__all__ = (
    "end_run",
    "explain",
    "gapic",
    "init",
    "helpers",
    "hyperparameter_tuning",
    "log",
    "log_params",
    "log_metrics",
    "log_classification_metrics",
    "log_model",
    "log_time_series_metrics",
    "get_experiment_df",
    "get_pipeline_df",
    "start_run",
    "start_execution",
    "save_model",
    "get_experiment_model",
    "autolog",
    "upload_tb_log",
    "start_upload_tb_log",
    "end_upload_tb_log",
    "Artifact",
    "AutoMLImageTrainingJob",
    "AutoMLTabularTrainingJob",
    "AutoMLForecastingTrainingJob",
    "AutoMLTextTrainingJob",
    "AutoMLVideoTrainingJob",
    "BatchPredictionJob",
    "CustomJob",
    "CustomTrainingJob",
    "CustomContainerTrainingJob",
    "CustomPythonPackageTrainingJob",
    "DeploymentResourcePool",
    "Endpoint",
    "EntityType",
    "Execution",
    "Experiment",
    "ExperimentRun",
    "Feature",
    "Featurestore",
    "MatchingEngineIndex",
    "MatchingEngineIndexEndpoint",
    "ImageDataset",
    "HyperparameterTuningJob",
    "Model",
    "ModelRegistry",
    "ModelEvaluation",
    "ModelDeploymentMonitoringJob",
    "PipelineJob",
    "PipelineJobSchedule",
    "PrivateEndpoint",
    "SequenceToSequencePlusForecastingTrainingJob",
    "TabularDataset",
    "Tensorboard",
    "TensorboardExperiment",
    "TensorboardRun",
    "TensorboardTimeSeries",
    "TextDataset",
    "TemporalFusionTransformerForecastingTrainingJob",
    "TimeSeriesDataset",
    "TimeSeriesDenseEncoderForecastingTrainingJob",
    "VideoDataset",
)
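# Usage sketch for the experiment-tracking aliases exported above (the
# project, experiment, and run names are illustrative placeholders):
#
#   from google.cloud import aiplatform
#
#   aiplatform.init(project="my-project", experiment="my-experiment")
#   with aiplatform.start_run("run-1"):
#       aiplatform.log_params({"learning_rate": 0.01})
#       aiplatform.log_metrics({"accuracy": 0.9})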
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,471 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import defaultdict
from typing import Any, Dict, List, NamedTuple, Optional, Union

from mlflow import entities as mlflow_entities
from mlflow.store.tracking import abstract_store
from mlflow import exceptions as mlflow_exceptions

from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import execution as execution_v1

_LOGGER = base.Logger(__name__)

# MLFlow RunStatus:
# https://www.mlflow.org/docs/latest/python_api/mlflow.entities.html#mlflow.entities.RunStatus
_MLFLOW_RUN_TO_VERTEX_RUN_STATUS = {
    mlflow_entities.RunStatus.FINISHED: execution_v1.Execution.State.COMPLETE,
    mlflow_entities.RunStatus.FAILED: execution_v1.Execution.State.FAILED,
    mlflow_entities.RunStatus.RUNNING: execution_v1.Execution.State.RUNNING,
    mlflow_entities.RunStatus.KILLED: execution_v1.Execution.State.CANCELLED,
    mlflow_entities.RunStatus.SCHEDULED: execution_v1.Execution.State.NEW,
}
mlflow_to_vertex_run_default = defaultdict(
    lambda: execution_v1.Execution.State.STATE_UNSPECIFIED
)
for mlflow_status in _MLFLOW_RUN_TO_VERTEX_RUN_STATUS:
    mlflow_to_vertex_run_default[mlflow_status] = _MLFLOW_RUN_TO_VERTEX_RUN_STATUS[
        mlflow_status
    ]

# Mapping of Vertex run status to MLFlow run status (inverse of
# _MLFLOW_RUN_TO_VERTEX_RUN_STATUS)
_VERTEX_RUN_TO_MLFLOW_RUN_STATUS = {
    v: k for k, v in _MLFLOW_RUN_TO_VERTEX_RUN_STATUS.items()
}
vertex_run_to_mflow_default = defaultdict(lambda: mlflow_entities.RunStatus.FAILED)
for vertex_status in _VERTEX_RUN_TO_MLFLOW_RUN_STATUS:
    vertex_run_to_mflow_default[vertex_status] = _VERTEX_RUN_TO_MLFLOW_RUN_STATUS[
        vertex_status
    ]

_MLFLOW_TERMINAL_RUN_STATES = [
    mlflow_entities.RunStatus.FINISHED,
    mlflow_entities.RunStatus.FAILED,
    mlflow_entities.RunStatus.KILLED,
]
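# Illustrative sketch of the two default mappings defined above (enum reprs
# omitted for brevity):
#
#   mlflow_to_vertex_run_default[mlflow_entities.RunStatus.FINISHED]
#   # -> execution_v1.Execution.State.COMPLETE
#   vertex_run_to_mflow_default[execution_v1.Execution.State.COMPLETE]
#   # -> mlflow_entities.RunStatus.FINISHED
#   # Any unmapped status falls back to STATE_UNSPECIFIED / RunStatus.FAILED.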
class _RunTracker(NamedTuple):
    """Tracks the current Vertex ExperimentRun.

    Stores the current ExperimentRun the plugin is writing to and whether or
    not this run is autocreated.

    Attributes:
        autocreate (bool):
            Whether the Vertex ExperimentRun should be autocreated. If False,
            the plugin writes to the currently active run created via
            `aiplatform.start_run()`.
        experiment_run (aiplatform.ExperimentRun):
            The currently set ExperimentRun.
    """

    autocreate: bool
    experiment_run: "aiplatform.ExperimentRun"
class _VertexMlflowTracking(abstract_store.AbstractStore):
    """Vertex plugin implementation of MLFlow's AbstractStore class."""

    def _to_mlflow_metric(
        self,
        vertex_metrics: Dict[str, Union[float, int, str]],
    ) -> Optional[List[mlflow_entities.Metric]]:
        """Helper method to convert Vertex metrics to mlflow.entities.Metric type.

        Args:
            vertex_metrics (Dict[str, Union[float, int, str]]):
                Required. A dictionary of Vertex metrics returned from
                ExperimentRun.get_metrics()
        Returns:
            List[mlflow_entities.Metric] - A list of metrics converted to
            MLFlow's Metric type.
        """

        mlflow_metrics = []

        if vertex_metrics:
            for metric_key in vertex_metrics:
                mlflow_metric = mlflow_entities.Metric(
                    key=metric_key,
                    value=vertex_metrics[metric_key],
                    step=0,
                    timestamp=0,
                )
                mlflow_metrics.append(mlflow_metric)
        else:
            return None

        return mlflow_metrics

    def _to_mlflow_params(
        self, vertex_params: Dict[str, Union[float, int, str]]
    ) -> Optional[mlflow_entities.Param]:
        """Helper method to convert Vertex params to mlflow.entities.Param type.

        Args:
            vertex_params (Dict[str, Union[float, int, str]]):
                Required. A dictionary of Vertex params returned from
                ExperimentRun.get_params()
        Returns:
            List[mlflow_entities.Param] - A list of params converted to
            MLFlow's Param type.
        """

        mlflow_params = []

        if vertex_params:
            for param_key in vertex_params:
                mlflow_param = mlflow_entities.Param(
                    key=param_key, value=vertex_params[param_key]
                )
                mlflow_params.append(mlflow_param)
        else:
            return None

        return mlflow_params

    def _to_mlflow_entity(
        self,
        vertex_exp: "aiplatform.Experiment",
        vertex_run: "aiplatform.ExperimentRun",
    ) -> mlflow_entities.Run:
        """Helper method to convert data to required MLFlow type.

        This converts data into MLFlow's mlflow_entities.Run type, which is a
        required return type for some methods we're overriding in this plugin.

        Args:
            vertex_exp (aiplatform.Experiment):
                Required. The current Vertex Experiment.
            vertex_run (aiplatform.ExperimentRun):
                Required. The active Vertex ExperimentRun
        Returns:
            mlflow_entities.Run - The data from the currently active run
            converted to MLFLow's mlflow_entities.Run type.

            https://www.mlflow.org/docs/latest/python_api/mlflow.entities.html#mlflow.entities.Run
        """

        run_info = mlflow_entities.RunInfo(
            run_id=f"{vertex_exp.name}-{vertex_run.name}",
            run_uuid=f"{vertex_exp.name}-{vertex_run.name}",
            experiment_id=vertex_exp.name,
            user_id="",
            status=vertex_run_to_mflow_default[vertex_run.state],
            start_time=1,
            end_time=2,
            lifecycle_stage=mlflow_entities.LifecycleStage.ACTIVE,
            # The plugin will fail if artifact_uri is not set to a valid
            # filepath string.
            artifact_uri="file:///tmp/",
        )

        run_data = mlflow_entities.RunData(
            metrics=self._to_mlflow_metric(vertex_run.get_metrics()),
            params=self._to_mlflow_params(vertex_run.get_params()),
            tags={},
        )

        return mlflow_entities.Run(run_info=run_info, run_data=run_data)
def __init__(self, store_uri: Optional[str], artifact_uri: Optional[str]) -> None:
|
||||
"""Initializes the Vertex MLFlow plugin.
|
||||
|
||||
This plugin overrides MLFlow's AbstractStore class to write metrics and
|
||||
parameters from model training code to Vertex Experiments. This plugin
|
||||
is private and should not be instantiated outside the Vertex SDK.
|
||||
|
||||
The _run_map instance property is a dict mapping MLFlow run_id to an
|
||||
instance of _RunTracker with data on the corresponding Vertex
|
||||
ExperimentRun.
|
||||
|
||||
For example: {
|
||||
'sklearn-12345': _RunTracker(autocreate=True, experiment_run=aiplatform.ExperimentRun(...))
|
||||
}
|
||||
|
||||
Until autologging and Experiments support nested runs, _nested_run_tracker
is used to ensure the plugin logs a warning exactly once per model that
produces nested runs, such as sklearn GridSearchCV and RandomizedSearchCV
models. It is a mapping of parent_run_id to the number of child runs found
for that parent; the warning is logged when the first child run is found.
|
||||
|
||||
Args:
|
||||
store_uri (str):
|
||||
The tracking store uri used by MLFlow to write parameters and
|
||||
metrics for a run. This plugin ignores store_uri since we are
|
||||
writing data to Vertex Experiments. For this plugin, the value
|
||||
of store_uri will always be `vertex-mlflow-plugin://`.
|
||||
artifact_uri (str):
|
||||
The artifact uri used by MLFlow to write artifacts generated by
|
||||
a run. This plugin ignores artifact_uri since it doesn't write
|
||||
any artifacts to Vertex.
|
||||
"""
|
||||
|
||||
self._run_map = {}
|
||||
self._vertex_experiment = None
|
||||
self._nested_run_tracker = {}
|
||||
super(_VertexMlflowTracking, self).__init__()
|
||||
|
||||
@property
|
||||
def run_map(self) -> Dict[str, Any]:
|
||||
return self._run_map
|
||||
|
||||
@property
|
||||
def vertex_experiment(self) -> "aiplatform.Experiment":
|
||||
return self._vertex_experiment
|
||||
|
||||
def create_run(
|
||||
self,
|
||||
experiment_id: str,
|
||||
user_id: str,
|
||||
start_time: str,
|
||||
tags: List[mlflow_entities.RunTag],
|
||||
run_name: str,
|
||||
) -> mlflow_entities.Run:
|
||||
"""Creates a new ExperimentRun in Vertex if no run is active.
|
||||
|
||||
This overrides the behavior of MLFlow's `create_run()` method to check
|
||||
if there is a currently active ExperimentRun. If no ExperimentRun is
|
||||
active, a new Vertex ExperimentRun will be created with the name
|
||||
`<ml-framework>-<timestamp>`. If aiplatform.start_run() has been
|
||||
invoked and there is an active run, no run will be created and the
|
||||
currently active ExperimentRun will be returned as an MLFlow Run
|
||||
entity.
|
||||
|
||||
Args:
|
||||
experiment_id (str):
|
||||
The ID of the currently set MLFlow Experiment. Not used by this
|
||||
plugin.
|
||||
user_id (str):
|
||||
The ID of the MLFlow user. Not used by this plugin.
|
||||
start_time (int):
|
||||
The start time of the run, in milliseconds since the UNIX
|
||||
epoch. Not used by this plugin.
|
||||
tags (List[mlflow_entities.RunTag]):
|
||||
The tags provided by MLFlow. Only the `mlflow.autologging` tag
|
||||
is used by this plugin.
|
||||
run_name (str):
|
||||
The name of the MLFlow run. Not used by this plugin.
|
||||
Returns:
|
||||
mlflow_entities.Run - The created run returned as MLFlow's run
|
||||
type.
|
||||
Raises:
|
||||
RuntimeError:
|
||||
If a second model training call is made to a manually created
|
||||
run created via `aiplatform.start_run()` that has already been
|
||||
used to autolog metrics and parameters in this session.
|
||||
"""
|
||||
|
||||
self._vertex_experiment = (
|
||||
aiplatform.metadata.metadata._experiment_tracker.experiment
|
||||
)
|
||||
|
||||
currently_active_run = (
|
||||
aiplatform.metadata.metadata._experiment_tracker.experiment_run
|
||||
)
|
||||
|
||||
parent_run_id = None
|
||||
|
||||
for tag in tags:
|
||||
if tag.key == "mlflow.parentRunId" and tag.value is not None:
|
||||
parent_run_id = tag.value
|
||||
if parent_run_id in self._nested_run_tracker:
|
||||
self._nested_run_tracker[parent_run_id] += 1
|
||||
else:
|
||||
self._nested_run_tracker[parent_run_id] = 1
|
||||
_LOGGER.warning(
|
||||
f"This model creates nested runs. No additional ExperimentRun resources will be created for nested runs, summary metrics and parameters will be logged to the parent ExperimentRun: {parent_run_id}."
|
||||
)
|
||||
|
||||
if currently_active_run:
|
||||
if (
|
||||
f"{currently_active_run.resource_id}" in self._run_map
|
||||
and not parent_run_id
|
||||
):
|
||||
_LOGGER.warning(
|
||||
"Metrics and parameters have already been logged to this run. Call aiplatform.end_run() to end the current run before training a new model."
|
||||
)
|
||||
raise mlflow_exceptions.MlflowException(
|
||||
"Metrics and parameters have already been logged to this run. Call aiplatform.end_run() to end the current run before training a new model."
|
||||
)
|
||||
elif not parent_run_id:
|
||||
run_tracker = _RunTracker(
|
||||
autocreate=False, experiment_run=currently_active_run
|
||||
)
|
||||
current_run_id = currently_active_run.name
|
||||
|
||||
# nested run case
|
||||
else:
|
||||
raise mlflow_exceptions.MlflowException(
|
||||
f"This model creates nested runs. No additional ExperimentRun resources will be created for nested runs, summary metrics and parameters will be logged to the {parent_run_id}: ExperimentRun."
|
||||
)
|
||||
|
||||
# Create a new run if aiplatform.start_run() hasn't been called
|
||||
else:
|
||||
framework = ""
|
||||
|
||||
for tag in tags:
|
||||
if tag.key == "mlflow.autologging":
|
||||
framework = tag.value
|
||||
|
||||
current_run_id = f"{framework}-{utils.timestamped_unique_name()}"
|
||||
currently_active_run = aiplatform.start_run(run=current_run_id)
|
||||
run_tracker = _RunTracker(
|
||||
autocreate=True, experiment_run=currently_active_run
|
||||
)
|
||||
|
||||
self._run_map[currently_active_run.resource_id] = run_tracker
|
||||
|
||||
return self._to_mlflow_entity(
|
||||
vertex_exp=self._vertex_experiment,
|
||||
vertex_run=run_tracker.experiment_run,
|
||||
)
|
||||
|
||||
def update_run_info(
|
||||
self,
|
||||
run_id: str,
|
||||
run_status: mlflow_entities.RunStatus,
|
||||
end_time: int,
|
||||
run_name: str,
|
||||
) -> mlflow_entities.RunInfo:
|
||||
"""Updates the ExperimentRun status with the status provided by MLFlow.
|
||||
|
||||
Args:
|
||||
run_id (str):
|
||||
The ID of the currently set MLFlow run. This is mapped to the
|
||||
corresponding ExperimentRun in self._run_map.
|
||||
run_status (mlflow_entities.RunStatus):
|
||||
The run status provided by MLFlow.
|
||||
end_time (int):
|
||||
The end time of the run. Not used by this plugin.
|
||||
run_name (str):
|
||||
The name of the MLFlow run. Not used by this plugin.
|
||||
Returns:
|
||||
mlflow_entities.RunInfo - Info about the updated run in MLFlow's
|
||||
required RunInfo format.
|
||||
"""
|
||||
|
||||
# The if/elif block below does the following:
# - Ends an autocreated ExperimentRun when MLFlow reports a terminal RunStatus.
# - For other autocreated runs, or runs where MLFlow reports a non-terminal
#   RunStatus, it updates the ExperimentRun with the corresponding mapped
#   Vertex run state.
# - Non-autocreated ExperimentRuns with a terminal status are left as-is and
#   must be ended explicitly (e.g. via aiplatform.end_run()).
|
||||
|
||||
if (
|
||||
self._run_map[run_id].autocreate
|
||||
and run_status in _MLFLOW_TERMINAL_RUN_STATES
|
||||
and self._run_map[run_id].experiment_run
|
||||
is aiplatform.metadata.metadata._experiment_tracker.experiment_run
|
||||
):
|
||||
aiplatform.metadata.metadata._experiment_tracker.end_run(
|
||||
state=execution_v1.Execution.State.COMPLETE
|
||||
)
|
||||
elif (
|
||||
self._run_map[run_id].autocreate
|
||||
or run_status not in _MLFLOW_TERMINAL_RUN_STATES
|
||||
):
|
||||
self._run_map[run_id].experiment_run.update_state(
|
||||
state=mlflow_to_vertex_run_default[run_status]
|
||||
)
|
||||
|
||||
return mlflow_entities.RunInfo(
|
||||
run_uuid=run_id,
|
||||
run_id=run_id,
|
||||
status=run_status,
|
||||
end_time=end_time,
|
||||
experiment_id=self._vertex_experiment,
|
||||
user_id="",
|
||||
start_time=1,
|
||||
lifecycle_stage=mlflow_entities.LifecycleStage.ACTIVE,
|
||||
artifact_uri="file:///tmp/",
|
||||
)
|
||||
|
||||
def log_batch(
|
||||
self,
|
||||
run_id: str,
|
||||
metrics: List[mlflow_entities.Metric],
|
||||
params: List[mlflow_entities.Param],
|
||||
tags: List[mlflow_entities.RunTag],
|
||||
) -> None:
|
||||
"""The primary logging method used by MLFlow.
|
||||
|
||||
This plugin overrides this method to write the metrics and parameters
|
||||
provided by MLFlow to the active Vertex ExperimentRun.
|
||||
Args:
|
||||
run_id (str):
|
||||
The ID of the MLFlow run to write metrics to. This is mapped to
|
||||
the corresponding ExperimentRun in self._run_map.
|
||||
metrics (List[mlflow_entities.Metric]):
|
||||
A list of MLFlow metrics generated from the current model
|
||||
training run.
|
||||
params (List[mlflow_entities.Param]):
|
||||
A list of MLFlow params generated from the current model
|
||||
training run.
|
||||
tags (List[mlflow_entities.RunTag]):
|
||||
The tags provided by MLFlow. Not used by this plugin.
|
||||
"""
|
||||
|
||||
summary_metrics = {}
|
||||
summary_params = {}
|
||||
time_series_metrics = {}
|
||||
|
||||
# Get the run to write to
|
||||
vertex_run = self._run_map[run_id].experiment_run
|
||||
|
||||
for metric in metrics:
|
||||
if metric.step:
|
||||
if metric.step not in time_series_metrics:
|
||||
time_series_metrics[metric.step] = {metric.key: metric.value}
|
||||
else:
|
||||
time_series_metrics[metric.step][metric.key] = metric.value
|
||||
else:
|
||||
summary_metrics[metric.key] = metric.value
|
||||
|
||||
for param in params:
|
||||
summary_params[param.key] = param.value
|
||||
|
||||
if summary_metrics:
|
||||
vertex_run.log_metrics(metrics=summary_metrics)
|
||||
|
||||
if summary_params:
|
||||
vertex_run.log_params(params=summary_params)
|
||||
|
||||
# TODO(b/261722623): batch these calls
|
||||
if time_series_metrics:
|
||||
for step in time_series_metrics:
|
||||
vertex_run.log_time_series_metrics(time_series_metrics[step], step)
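# Illustrative sketch (metric values assumed): a batch such as
#   [Metric(key="loss", value=0.4, step=0, timestamp=0),
#    Metric(key="loss", value=0.3, step=1, timestamp=0),
#    Metric(key="accuracy", value=0.9, step=0, timestamp=0)]
# is split by the loop above into summary metrics {"loss": 0.4, "accuracy": 0.9}
# (step is falsy) and time-series metrics {1: {"loss": 0.3}}, which are written
# with log_metrics() and log_time_series_metrics() respectively.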
|
||||
|
||||
def get_run(self, run_id: str) -> mlflow_entities.Run:
|
||||
"""Gets the currently active run.
|
||||
|
||||
Args:
|
||||
run_id (str):
|
||||
The ID of the currently set MLFlow run. This is mapped to the
|
||||
corresponding ExperimentRun in self._run_map.
|
||||
Returns:
|
||||
mlflow_entities.Run - The currently active Vertex ExperimentRun,
|
||||
returned as MLFlow's run type.
|
||||
"""
|
||||
return self._to_mlflow_entity(
|
||||
vertex_exp=self._vertex_experiment,
|
||||
vertex_run=self._run_map[run_id].experiment_run,
|
||||
)
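# Illustrative usage sketch, not part of the original module; project, location,
# experiment, and model names are assumed. The plugin is exercised through MLFlow
# autologging with the `vertex-mlflow-plugin://` tracking URI noted in the
# __init__ docstring:
#
#   import mlflow
#   from google.cloud import aiplatform
#   from sklearn.linear_model import LinearRegression
#
#   aiplatform.init(
#       project="my-project", location="us-central1", experiment="my-experiment"
#   )
#   mlflow.set_tracking_uri("vertex-mlflow-plugin://")
#   mlflow.sklearn.autolog()
#
#   LinearRegression().fit(X_train, y_train)
#   # create_run() above either reuses the active ExperimentRun or autocreates
#   # one named "sklearn-<timestamp>", and log_batch() writes the metrics and
#   # params to it.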
|
||||
@@ -0,0 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from google.cloud.aiplatform._pipeline_based_service.pipeline_based_service import (
|
||||
_VertexAiPipelineBasedService,
|
||||
)
|
||||
|
||||
__all__ = ("_VertexAiPipelineBasedService",)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,431 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Optional,
|
||||
List,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud import aiplatform
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform import pipeline_jobs
|
||||
from google.cloud.aiplatform import utils
|
||||
from google.cloud.aiplatform.compat.types import (
|
||||
pipeline_state as gca_pipeline_state,
|
||||
)
|
||||
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
|
||||
|
||||
_PIPELINE_COMPLETE_STATES = pipeline_constants._PIPELINE_COMPLETE_STATES
|
||||
|
||||
|
||||
class _VertexAiPipelineBasedService(base.VertexAiStatefulResource):
|
||||
"""Base class for Vertex AI Pipeline based services."""
|
||||
|
||||
client_class = utils.PipelineJobClientWithOverride
|
||||
_resource_noun = "pipelineJob"
|
||||
_delete_method = "delete_pipeline_job"
|
||||
_getter_method = "get_pipeline_job"
|
||||
_list_method = "list_pipeline_jobs"
|
||||
_parse_resource_name_method = "parse_pipeline_job_path"
|
||||
_format_resource_name_method = "pipeline_job_path"
|
||||
|
||||
_valid_done_states = _PIPELINE_COMPLETE_STATES
|
||||
|
||||
@property
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def _template_ref(cls) -> FrozenSet[Tuple[str, str]]:
|
||||
"""A dictionary of the pipeline template URLs for this service.
|
||||
|
||||
The key is an identifier for that template and the value is the url of
|
||||
that pipeline template.
|
||||
|
||||
For example: {"tabular_classification": "gs://path/to/tabular/pipeline/template.json"}
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def _creation_log_message(cls) -> str:
|
||||
"""A log message to use when the Pipeline-based Service is created.
|
||||
|
||||
_VertexAiPipelineBasedService suppresses logs from PipelineJob creation
|
||||
to avoid duplication.
|
||||
|
||||
For example: 'Created PipelineJob for your Model Evaluation.'
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def _component_identifier(cls) -> str:
|
||||
"""A 'component_type' value unique to this service's pipeline execution metadata.
|
||||
|
||||
This is an identifier used by the _validate_pipeline_template_matches_service method
|
||||
to confirm the pipeline being instantiated belongs to this service. Use something
|
||||
specific to your service's PipelineJob.
|
||||
|
||||
For example: 'fpc-model-evaluation'
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def _template_name_identifier(cls) -> Optional[str]:
|
||||
"""An optional name identifier for the pipeline template.
|
||||
|
||||
This will validate on the Pipeline's PipelineSpec.PipelineInfo.name
|
||||
field. Setting this property will lead to an additional validation
|
||||
check on pipeline templates in _does_pipeline_template_match_service.
|
||||
If this property is present, the validation method will check for it
|
||||
after validating on `_component_identifier`.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def submit(self) -> "_VertexAiPipelineBasedService":
|
||||
"""Subclasses should implement this method to submit the underlying PipelineJob."""
|
||||
pass
|
||||
|
||||
# TODO (b/248582133): Consider updating this to return a list in the future to support multiple outputs
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _metadata_output_artifact(self) -> Optional[str]:
|
||||
"""The ML Metadata output artifact resource URI from the completed pipeline run."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def backing_pipeline_job(self) -> "pipeline_jobs.PipelineJob":
|
||||
"""The PipelineJob associated with the resource."""
|
||||
return pipeline_jobs.PipelineJob.get(resource_name=self.resource_name)
|
||||
|
||||
@property
|
||||
def pipeline_console_uri(self) -> Optional[str]:
|
||||
"""The console URI of the PipelineJob created by the service."""
|
||||
if self.backing_pipeline_job:
|
||||
return self.backing_pipeline_job._dashboard_uri()
|
||||
|
||||
@property
|
||||
def state(self) -> Optional[gca_pipeline_state.PipelineState]:
|
||||
"""The state of the Pipeline run associated with the service."""
|
||||
if self.backing_pipeline_job:
|
||||
return self.backing_pipeline_job.state
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _does_pipeline_template_match_service(
|
||||
cls, pipeline_job: "pipeline_jobs.PipelineJob"
|
||||
) -> bool:
|
||||
"""Checks whether the provided pipeline template matches the service.
|
||||
|
||||
Args:
|
||||
pipeline_job (aiplatform.PipelineJob):
|
||||
Required. The PipelineJob to validate with this Pipeline Based Service.
|
||||
|
||||
Returns:
|
||||
Boolean indicating whether the provided template matches the
|
||||
service it's trying to instantiate.
|
||||
"""
|
||||
|
||||
valid_schema_titles = ["system.Run", "system.DagExecution"]
|
||||
|
||||
# We get the Execution here because we want to allow instantiating
|
||||
# failed pipeline runs that match the service. The component_type is
|
||||
# present in the Execution metadata for both failed and successful
|
||||
# pipeline runs
|
||||
for component in pipeline_job.task_details:
|
||||
if not (
|
||||
"name" in component.execution
|
||||
and component.execution.schema_title in valid_schema_titles
|
||||
):
|
||||
continue
|
||||
|
||||
execution_resource = aiplatform.Execution.get(
|
||||
component.execution.name, credentials=pipeline_job.credentials
|
||||
)
|
||||
|
||||
# First validate on component_type
|
||||
if (
|
||||
"component_type" in execution_resource.metadata
|
||||
and execution_resource.metadata.get("component_type")
|
||||
== cls._component_identifier
|
||||
):
|
||||
# Then validate on _template_name_identifier if provided
|
||||
if cls._template_name_identifier is None or (
|
||||
pipeline_job.pipeline_spec is not None
|
||||
and cls._template_name_identifier
|
||||
== pipeline_job.pipeline_spec["pipelineInfo"]["name"]
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
# TODO (b/249153354): expose _template_ref in error message when artifact
|
||||
# registry support is added
|
||||
@classmethod
|
||||
def _validate_pipeline_template_matches_service(
|
||||
cls, pipeline_job: "pipeline_jobs.PipelineJob"
|
||||
):
|
||||
"""Validates the provided pipeline matches the template of the Pipeline Based Service.
|
||||
|
||||
Args:
|
||||
pipeline_job (aiplatform.PipelineJob):
|
||||
Required. The PipelineJob to validate with this Pipeline Based Service.
|
||||
|
||||
Raises:
|
||||
ValueError: if the provided pipeline ID doesn't match the pipeline service.
|
||||
"""
|
||||
|
||||
if not cls._does_pipeline_template_match_service(pipeline_job):
|
||||
raise ValueError(
|
||||
f"The provided pipeline template is not compatible with {cls.__name__}"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pipeline_job_name: str,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
):
|
||||
"""Retrieves an existing Pipeline Based Service given the ID of the pipeline execution.
|
||||
|
||||
Args:
|
||||
pipeline_job_name (str):
|
||||
Required. A fully-qualified PipelineJob resource name or pipeline job ID.
|
||||
Example: "projects/123/locations/us-central1/pipelineJobs/456" or
|
||||
"456" when project and location are initialized or passed.
|
||||
project (str):
|
||||
Optional. Project to retrieve pipeline job from. If not set, project
|
||||
set in aiplatform.init will be used.
|
||||
location (str):
|
||||
Optional. Location to retrieve pipeline job from. If not set, location
|
||||
set in aiplatform.init will be used.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials to use to retrieve this pipeline job. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
Raises:
|
||||
ValueError: if the pipeline template used in this PipelineJob is not
|
||||
consistent with the _template_ref defined on the subclass.
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
resource_name=pipeline_job_name,
|
||||
)
|
||||
|
||||
job_resource = pipeline_jobs.PipelineJob.get(
|
||||
resource_name=pipeline_job_name, credentials=credentials
|
||||
)
|
||||
|
||||
self._validate_pipeline_template_matches_service(job_resource)
|
||||
|
||||
self._gca_resource = job_resource._gca_resource
|
||||
|
||||
@classmethod
|
||||
def _create_and_submit_pipeline_job(
|
||||
cls,
|
||||
template_params: Dict[str, Any],
|
||||
template_path: str,
|
||||
pipeline_root: Optional[str] = None,
|
||||
display_name: Optional[str] = None,
|
||||
job_id: Optional[str] = None,
|
||||
service_account: Optional[str] = None,
|
||||
network: Optional[str] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
experiment: Optional[Union[str, "aiplatform.Experiment"]] = None,
|
||||
enable_caching: Optional[bool] = None,
|
||||
) -> "_VertexAiPipelineBasedService":
|
||||
"""Create a new PipelineJob using the provided template and parameters.
|
||||
|
||||
Args:
|
||||
template_params (Dict[str, Any]):
|
||||
Required. The parameters to pass to the given pipeline template.
|
||||
template_path (str):
|
||||
Required. The path of the pipeline template to use for this
|
||||
pipeline run.
|
||||
pipeline_root (str):
|
||||
Optional. The GCS directory to store the pipeline run output.
|
||||
If not set, the bucket set in `aiplatform.init(staging_bucket=...)`
|
||||
will be used.
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the PipelineJob created by
|
||||
this Pipeline Based Service.
|
||||
job_id (str):
|
||||
Optional. The unique ID of the job run.
|
||||
If not specified, pipeline name + timestamp will be used.
|
||||
service_account (str):
|
||||
Specifies the service account for workload run-as account.
|
||||
Users submitting jobs must have act-as permission on this run-as account.
|
||||
network (str):
|
||||
The full name of the Compute Engine network to which the job
|
||||
should be peered. For example, projects/12345/global/networks/myVPC.
|
||||
Private services access must already be configured for the network.
|
||||
If left unspecified, the job is not peered with any network.
|
||||
encryption_spec_key_name (str):
|
||||
Customer managed encryption key resource name.
|
||||
project (str):
|
||||
Optional. The project to run this PipelineJob in. If not set,
|
||||
the project set in aiplatform.init will be used.
|
||||
location (str):
|
||||
Optional. Location to create PipelineJob. If not set,
|
||||
location set in aiplatform.init will be used.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials to use to create the PipelineJob.
|
||||
Overrides credentials set in aiplatform.init.
|
||||
experiment (Union[str, experiments_resource.Experiment]):
|
||||
Optional. The Vertex AI experiment name or instance to associate
|
||||
to the PipelineJob executing this model evaluation job.
|
||||
enable_caching (bool):
|
||||
Optional. Whether to turn on caching for the run.
|
||||
|
||||
If this is not set, defaults to the compile time settings, which
|
||||
are True for all tasks by default, while users may specify
|
||||
different caching options for individual tasks.
|
||||
|
||||
If this is set, the setting applies to all tasks in the pipeline.
|
||||
|
||||
Overrides the compile time settings.
|
||||
Returns:
|
||||
(_VertexAiPipelineBasedService):
|
||||
Instantiated representation of a Vertex AI Pipeline based service.
|
||||
"""
|
||||
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
|
||||
self = cls._empty_constructor(
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
service_pipeline_job = pipeline_jobs.PipelineJob(
|
||||
display_name=display_name,
|
||||
template_path=template_path,
|
||||
job_id=job_id,
|
||||
pipeline_root=pipeline_root,
|
||||
parameter_values=template_params,
|
||||
encryption_spec_key_name=encryption_spec_key_name,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
enable_caching=enable_caching,
|
||||
)
|
||||
|
||||
# Suppresses logs from PipelineJob
|
||||
# The class implementing _VertexAiPipelineBasedService should define a
|
||||
# custom log message via `_creation_log_message`
|
||||
logging.getLogger("google.cloud.aiplatform.pipeline_jobs").setLevel(
|
||||
logging.WARNING
|
||||
)
|
||||
|
||||
service_pipeline_job.submit(
|
||||
service_account=service_account,
|
||||
network=network,
|
||||
experiment=experiment,
|
||||
)
|
||||
|
||||
logging.getLogger("google.cloud.aiplatform.pipeline_jobs").setLevel(
|
||||
logging.INFO
|
||||
)
|
||||
|
||||
self._gca_resource = service_pipeline_job.gca_resource
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def list(
|
||||
cls,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[str] = None,
|
||||
) -> List["_VertexAiPipelineBasedService"]:
|
||||
"""Lists all PipelineJob resources associated with this Pipeline Based service.
|
||||
|
||||
Args:
|
||||
project (str):
|
||||
Optional. The project to retrieve the Pipeline Based Services from.
|
||||
If not set, the project set in aiplatform.init will be used.
|
||||
location (str):
|
||||
Optional. Location to retrieve the Pipeline Based Services from.
|
||||
If not set, location set in aiplatform.init will be used.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials to use to retrieve the Pipeline Based
|
||||
Services from. Overrides credentials set in aiplatform.init.
|
||||
Returns:
|
||||
(List[_VertexAiPipelineBasedService]):
A list of instantiated Pipeline Based Service resources, one per matching PipelineJob.
|
||||
"""
|
||||
|
||||
filter_str = f"metadata.component_type.string_value={cls._component_identifier}"
|
||||
|
||||
filtered_pipeline_executions = aiplatform.Execution.list(
|
||||
filter=filter_str, credentials=credentials
|
||||
)
|
||||
|
||||
service_pipeline_jobs = []
|
||||
|
||||
for pipeline_execution in filtered_pipeline_executions:
|
||||
if "pipeline_job_resource_name" in pipeline_execution.metadata:
|
||||
# This is wrapped in a try/except for cases when both
|
||||
# `_component_identifier` and `_template_name_identifier` are
|
||||
# set. In that case, even though all pipelines returned by the
|
||||
# Execution.list() call will match the `_component_identifier`,
|
||||
# some may not match the `_template_name_identifier`
|
||||
try:
|
||||
service_pipeline_job = cls(
|
||||
pipeline_execution.metadata["pipeline_job_resource_name"],
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
service_pipeline_jobs.append(service_pipeline_job)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return service_pipeline_jobs
|
||||
|
||||
def wait(self):
|
||||
"""Wait for the PipelineJob to complete."""
|
||||
pipeline_run = self.backing_pipeline_job
|
||||
|
||||
if pipeline_run._latest_future is None:
|
||||
pipeline_run._block_until_complete()
|
||||
else:
|
||||
pipeline_run.wait()
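# Illustrative sketch of how a concrete service could subclass this base class;
# the template URL, component identifier, and log message are hypothetical and
# not part of the original module.
#
#   class _MyEvaluationService(_VertexAiPipelineBasedService):
#       _template_ref = {"default": "gs://my-bucket/templates/evaluation.json"}
#       _creation_log_message = "Created PipelineJob for your evaluation."
#       _component_identifier = "my-evaluation-component"
#       _template_name_identifier = None
#
#       @property
#       def _metadata_output_artifact(self):
#           ...  # resource URI of the evaluation output artifact
#
#       @classmethod
#       def submit(cls, **template_params):
#           return cls._create_and_submit_pipeline_job(
#               template_params=template_params,
#               template_path=cls._template_ref["default"],
#           )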
|
||||
@@ -0,0 +1,79 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
|
||||
class _PublisherModel(base.VertexAiResourceNoun):
|
||||
"""Publisher Model Resource for Vertex AI."""
|
||||
|
||||
client_class = utils.ModelGardenClientWithOverride
|
||||
|
||||
_resource_noun = "publisher_model"
|
||||
_getter_method = "get_publisher_model"
|
||||
_delete_method = None
|
||||
_parse_resource_name_method = "parse_publisher_model_path"
|
||||
_format_resource_name_method = "publisher_model_path"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
resource_name: str,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
):
|
||||
"""Retrieves an existing PublisherModel resource given a resource name or model garden id.
|
||||
|
||||
Args:
|
||||
resource_name (str):
|
||||
Required. A fully-qualified PublisherModel resource name or
|
||||
model garden id. Format:
|
||||
`publishers/{publisher}/models/{publisher_model}` or
|
||||
`{publisher}/{publisher_model}`.
|
||||
project (str):
|
||||
Optional. Project to retrieve the resource from. If not set,
|
||||
project set in aiplatform.init will be used.
|
||||
location (str):
|
||||
Optional. Location to retrieve the resource from. If not set,
|
||||
location set in aiplatform.init will be used.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials to use to retrieve the resource.
|
||||
Overrides credentials set in aiplatform.init.
|
||||
"""
|
||||
|
||||
super().__init__(project=project, location=location, credentials=credentials)
|
||||
|
||||
if self._parse_resource_name(resource_name):
|
||||
full_resource_name = resource_name
|
||||
else:
|
||||
m = re.match(r"^(?P<publisher>.+?)/(?P<model>.+?)$", resource_name)
|
||||
if m:
|
||||
full_resource_name = self._format_resource_name(**m.groupdict())
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`{resource_name}` is not a valid PublisherModel resource "
|
||||
"name or model garden id."
|
||||
)
|
||||
|
||||
self._gca_resource = getattr(self.api_client, self._getter_method)(
|
||||
name=full_resource_name, retry=base._DEFAULT_RETRY
|
||||
)
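# Illustrative note (publisher and model IDs are hypothetical): the constructor
# above accepts either form and normalizes the short one via the regex, so
#   _PublisherModel("publishers/google/models/some-model")
#   _PublisherModel("google/some-model")
# both resolve to the full resource name "publishers/google/models/some-model".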
|
||||
@@ -0,0 +1,252 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Streaming prediction functions."""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence
|
||||
|
||||
from google.cloud.aiplatform_v1.services import prediction_service
|
||||
from google.cloud.aiplatform_v1.types import (
|
||||
prediction_service as prediction_service_types,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.types import (
|
||||
types as aiplatform_types,
|
||||
)
|
||||
|
||||
|
||||
def value_to_tensor(value: Any) -> aiplatform_types.Tensor:
|
||||
"""Converts a Python value to `Tensor`.
|
||||
|
||||
Args:
|
||||
value: A value to convert
|
||||
|
||||
Returns:
|
||||
A `Tensor` object
|
||||
"""
|
||||
if value is None:
|
||||
return aiplatform_types.Tensor()
|
||||
elif isinstance(value, int):
|
||||
return aiplatform_types.Tensor(int_val=[value])
|
||||
elif isinstance(value, float):
|
||||
return aiplatform_types.Tensor(float_val=[value])
|
||||
elif isinstance(value, bool):
|
||||
return aiplatform_types.Tensor(bool_val=[value])
|
||||
elif isinstance(value, str):
|
||||
return aiplatform_types.Tensor(string_val=[value])
|
||||
elif isinstance(value, bytes):
|
||||
return aiplatform_types.Tensor(bytes_val=[value])
|
||||
elif isinstance(value, list):
|
||||
return aiplatform_types.Tensor(list_val=[value_to_tensor(x) for x in value])
|
||||
elif isinstance(value, dict):
|
||||
return aiplatform_types.Tensor(
|
||||
struct_val={k: value_to_tensor(v) for k, v in value.items()}
|
||||
)
|
||||
raise TypeError(f"Unsupported value type {type(value)}")
|
||||
|
||||
|
||||
def tensor_to_value(tensor_pb: aiplatform_types.Tensor) -> Any:
|
||||
"""Converts `Tensor` to a Python value.
|
||||
|
||||
Args:
|
||||
tensor_pb: A `Tensor` object
|
||||
|
||||
Returns:
|
||||
A corresponding Python object
|
||||
"""
|
||||
list_of_fields = tensor_pb.ListFields()
|
||||
if not list_of_fields:
|
||||
return None
|
||||
descriptor, value = tensor_pb.ListFields()[0]
|
||||
if descriptor.name == "list_val":
|
||||
return [tensor_to_value(x) for x in value]
|
||||
elif descriptor.name == "struct_val":
|
||||
return {k: tensor_to_value(v) for k, v in value.items()}
|
||||
if not isinstance(value, Sequence):
|
||||
raise TypeError(f"Unexpected non-list tensor value {value}")
|
||||
if len(value) == 1:
|
||||
return value[0]
|
||||
else:
|
||||
return value
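# Illustrative round trip (values assumed): value_to_tensor and tensor_to_value
# are inverses for the supported Python types, e.g.
#   t = value_to_tensor({"prompt": "hello", "temperature": 0.5})
#   tensor_to_value(t._pb)  # -> {"prompt": "hello", "temperature": 0.5}
# Scalars are wrapped in single-element repeated fields, which is why
# tensor_to_value() unwraps length-1 sequences.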
|
||||
|
||||
|
||||
def predict_stream_of_tensor_lists_from_single_tensor_list(
|
||||
prediction_service_client: prediction_service.PredictionServiceClient,
|
||||
endpoint_name: str,
|
||||
tensor_list: List[aiplatform_types.Tensor],
|
||||
parameters_tensor: Optional[aiplatform_types.Tensor] = None,
|
||||
) -> Iterator[List[aiplatform_types.Tensor]]:
|
||||
"""Predicts a stream of lists of `Tensor` objects from a single list of `Tensor` objects.
|
||||
|
||||
Args:
|
||||
tensor_list: Model input as a list of `Tensor` objects.
|
||||
parameters_tensor: Optional. Prediction parameters in `Tensor` form.
|
||||
prediction_service_client: A PredictionServiceClient object.
|
||||
endpoint_name: Resource name of Endpoint or PublisherModel.
|
||||
|
||||
Yields:
|
||||
A generator of model prediction `Tensor` lists.
|
||||
"""
|
||||
request = prediction_service_types.StreamingPredictRequest(
|
||||
endpoint=endpoint_name,
|
||||
inputs=tensor_list,
|
||||
parameters=parameters_tensor,
|
||||
)
|
||||
for response in prediction_service_client.server_streaming_predict(request=request):
|
||||
yield response.outputs
|
||||
|
||||
|
||||
async def predict_stream_of_tensor_lists_from_single_tensor_list_async(
|
||||
prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
|
||||
endpoint_name: str,
|
||||
tensor_list: List[aiplatform_types.Tensor],
|
||||
parameters_tensor: Optional[aiplatform_types.Tensor] = None,
|
||||
) -> AsyncIterator[List[aiplatform_types.Tensor]]:
|
||||
"""Asynchronously predicts a stream of lists of `Tensor` objects from a single list of `Tensor` objects.
|
||||
|
||||
Args:
|
||||
tensor_list: Model input as a list of `Tensor` objects.
|
||||
parameters_tensor: Optional. Prediction parameters in `Tensor` form.
|
||||
prediction_service_async_client: A PredictionServiceAsyncClient object.
|
||||
endpoint_name: Resource name of Endpoint or PublisherModel.
|
||||
|
||||
Yields:
|
||||
A generator of model prediction `Tensor` lists.
|
||||
"""
|
||||
request = prediction_service_types.StreamingPredictRequest(
|
||||
endpoint=endpoint_name,
|
||||
inputs=tensor_list,
|
||||
parameters=parameters_tensor,
|
||||
)
|
||||
async for response in await prediction_service_async_client.server_streaming_predict(
|
||||
request=request
|
||||
):
|
||||
yield response.outputs
|
||||
|
||||
|
||||
def predict_stream_of_dict_lists_from_single_dict_list(
|
||||
prediction_service_client: prediction_service.PredictionServiceClient,
|
||||
endpoint_name: str,
|
||||
dict_list: List[Dict[str, Any]],
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
) -> Iterator[List[Dict[str, Any]]]:
|
||||
"""Predicts a stream of lists of dicts from a stream of lists of dicts.
|
||||
|
||||
Args:
|
||||
dict_list: Model input as a list of `dict` objects.
|
||||
parameters: Optional. Prediction parameters in `dict` form.
|
||||
prediction_service_client: A PredictionServiceClient object.
|
||||
endpoint_name: Resource name of Endpoint or PublisherModel.
|
||||
|
||||
Yields:
|
||||
A generator of model prediction dict lists.
|
||||
"""
|
||||
tensor_list = [value_to_tensor(d) for d in dict_list]
|
||||
parameters_tensor = value_to_tensor(parameters) if parameters else None
|
||||
for tensor_list in predict_stream_of_tensor_lists_from_single_tensor_list(
|
||||
prediction_service_client=prediction_service_client,
|
||||
endpoint_name=endpoint_name,
|
||||
tensor_list=tensor_list,
|
||||
parameters_tensor=parameters_tensor,
|
||||
):
|
||||
yield [tensor_to_value(tensor._pb) for tensor in tensor_list]
|
||||
|
||||
|
||||
async def predict_stream_of_dict_lists_from_single_dict_list_async(
|
||||
prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
|
||||
endpoint_name: str,
|
||||
dict_list: List[Dict[str, Any]],
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
) -> AsyncIterator[List[Dict[str, Any]]]:
|
||||
"""Asynchronously predicts a stream of lists of dicts from a stream of lists of dicts.
|
||||
|
||||
Args:
|
||||
dict_list: Model input as a list of `dict` objects.
|
||||
parameters: Optional. Prediction parameters in `dict` form.
|
||||
prediction_service_async_client: A PredictionServiceAsyncClient object.
|
||||
endpoint_name: Resource name of Endpoint or PublisherModel.
|
||||
|
||||
Yields:
|
||||
A generator of model prediction dict lists.
|
||||
"""
|
||||
tensor_list = [value_to_tensor(d) for d in dict_list]
|
||||
parameters_tensor = value_to_tensor(parameters) if parameters else None
|
||||
async for tensor_list in predict_stream_of_tensor_lists_from_single_tensor_list_async(
|
||||
prediction_service_async_client=prediction_service_async_client,
|
||||
endpoint_name=endpoint_name,
|
||||
tensor_list=tensor_list,
|
||||
parameters_tensor=parameters_tensor,
|
||||
):
|
||||
yield [tensor_to_value(tensor._pb) for tensor in tensor_list]
|
||||
|
||||
|
||||
def predict_stream_of_dicts_from_single_dict(
|
||||
prediction_service_client: prediction_service.PredictionServiceClient,
|
||||
endpoint_name: str,
|
||||
instance: Dict[str, Any],
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
"""Predicts a stream of dicts from a single instance dict.
|
||||
|
||||
Args:
|
||||
instance: A single input instance `dict`.
|
||||
parameters: Optional. Prediction parameters as a `dict`.
|
||||
prediction_service_client: A PredictionServiceClient object.
|
||||
endpoint_name: Resource name of Endpoint or PublisherModel.
|
||||
|
||||
Yields:
|
||||
A generator of model prediction dicts.
|
||||
"""
|
||||
for dict_list in predict_stream_of_dict_lists_from_single_dict_list(
|
||||
prediction_service_client=prediction_service_client,
|
||||
endpoint_name=endpoint_name,
|
||||
dict_list=[instance],
|
||||
parameters=parameters,
|
||||
):
|
||||
if len(dict_list) > 1:
|
||||
raise ValueError(
|
||||
f"Expected to receive a single output, but got {dict_list}"
|
||||
)
|
||||
yield dict_list[0]
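# A minimal usage sketch, not part of the original module. The project number,
# endpoint ID, and instance fields below are assumptions; only the helper above
# and the v1 PredictionServiceClient are taken from this codebase.
def _example_stream_predictions() -> None:
    """Prints streamed predictions for a single (hypothetical) instance."""
    client = prediction_service.PredictionServiceClient()
    endpoint = "projects/123/locations/us-central1/endpoints/456"
    for prediction in predict_stream_of_dicts_from_single_dict(
        prediction_service_client=client,
        endpoint_name=endpoint,
        instance={"prompt": "Tell me a short story."},
        parameters={"temperature": 0.5},
    ):
        print(prediction)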
|
||||
|
||||
|
||||
async def predict_stream_of_dicts_from_single_dict_async(
|
||||
prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
|
||||
endpoint_name: str,
|
||||
instance: Dict[str, Any],
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Asynchronously predicts a stream of dicts from a single instance dict.
|
||||
|
||||
Args:
|
||||
instance: A single input instance `dict`.
|
||||
parameters: Optional. Prediction parameters as a `dict`.
|
||||
prediction_service_async_client: A PredictionServiceAsyncClient object.
|
||||
endpoint_name: Resource name of Endpoint or PublisherModel.
|
||||
|
||||
Yields:
|
||||
A generator of model prediction dicts.
|
||||
"""
|
||||
async for dict_list in predict_stream_of_dict_lists_from_single_dict_list_async(
|
||||
prediction_service_async_client=prediction_service_async_client,
|
||||
endpoint_name=endpoint_name,
|
||||
dict_list=[instance],
|
||||
parameters=parameters,
|
||||
):
|
||||
if len(dict_list) > 1:
|
||||
raise ValueError(
|
||||
f"Expected to receive a single output, but got {dict_list}"
|
||||
)
|
||||
yield dict_list[0]
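# Async counterpart of the sketch above, with the same assumptions about the
# endpoint and instance payload; run it with asyncio.run().
async def _example_stream_predictions_async() -> None:
    """Prints streamed predictions using the async client."""
    client = prediction_service.PredictionServiceAsyncClient()
    endpoint = "projects/123/locations/us-central1/endpoints/456"
    async for prediction in predict_stream_of_dicts_from_single_dict_async(
        prediction_service_async_client=client,
        endpoint_name=endpoint,
        instance={"prompt": "Tell me a short story."},
        parameters={"temperature": 0.5},
    ):
        print(prediction)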
|
||||
.venv/lib/python3.10/site-packages/google/cloud/aiplatform/base.py (new file, 1525 lines)
File diff suppressed because it is too large
@@ -0,0 +1,297 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from google.cloud.aiplatform.compat import services
|
||||
from google.cloud.aiplatform.compat import types
|
||||
|
||||
V1BETA1 = "v1beta1"
|
||||
V1 = "v1"
|
||||
|
||||
DEFAULT_VERSION = V1
|
||||
|
||||
if DEFAULT_VERSION == V1BETA1:
|
||||
|
||||
services.dataset_service_client = services.dataset_service_client_v1beta1
|
||||
services.deployment_resource_pool_service_client = (
|
||||
services.deployment_resource_pool_service_client_v1beta1
|
||||
)
|
||||
services.endpoint_service_client = services.endpoint_service_client_v1beta1
|
||||
services.feature_online_store_admin_service_client = (
|
||||
services.feature_online_store_admin_service_client_v1beta1
|
||||
)
|
||||
services.feature_online_store_service_client = (
|
||||
services.feature_online_store_service_client_v1beta1
|
||||
)
|
||||
services.feature_registry_service_client = (
|
||||
services.feature_registry_service_client_v1beta1
|
||||
)
|
||||
services.featurestore_online_serving_service_client = (
|
||||
services.featurestore_online_serving_service_client_v1beta1
|
||||
)
|
||||
services.featurestore_service_client = services.featurestore_service_client_v1beta1
|
||||
services.gen_ai_cache_service_client = services.gen_ai_cache_service_client_v1beta1
|
||||
services.job_service_client = services.job_service_client_v1beta1
|
||||
services.model_service_client = services.model_service_client_v1beta1
|
||||
services.model_garden_service_client = services.model_garden_service_client_v1beta1
|
||||
services.pipeline_service_client = services.pipeline_service_client_v1beta1
|
||||
services.prediction_service_client = services.prediction_service_client_v1beta1
|
||||
services.prediction_service_async_client = (
|
||||
services.prediction_service_async_client_v1beta1
|
||||
)
|
||||
services.schedule_service_client = services.schedule_service_client_v1beta1
|
||||
services.specialist_pool_service_client = (
|
||||
services.specialist_pool_service_client_v1beta1
|
||||
)
|
||||
services.match_service_client = services.match_service_client_v1beta1
|
||||
services.metadata_service_client = services.metadata_service_client_v1beta1
|
||||
services.tensorboard_service_client = services.tensorboard_service_client_v1beta1
|
||||
services.index_service_client = services.index_service_client_v1beta1
|
||||
services.index_endpoint_service_client = (
|
||||
services.index_endpoint_service_client_v1beta1
|
||||
)
|
||||
services.vizier_service_client = services.vizier_service_client_v1beta1
|
||||
|
||||
types.accelerator_type = types.accelerator_type_v1beta1
|
||||
types.annotation = types.annotation_v1beta1
|
||||
types.annotation_spec = types.annotation_spec_v1beta1
|
||||
types.artifact = types.artifact_v1beta1
|
||||
types.batch_prediction_job = types.batch_prediction_job_v1beta1
|
||||
types.cached_content = types.cached_content_v1beta1
|
||||
types.completion_stats = types.completion_stats_v1beta1
|
||||
types.context = types.context_v1beta1
|
||||
types.custom_job = types.custom_job_v1beta1
|
||||
types.data_item = types.data_item_v1beta1
|
||||
types.data_labeling_job = types.data_labeling_job_v1beta1
|
||||
types.dataset = types.dataset_v1beta1
|
||||
types.dataset_service = types.dataset_service_v1beta1
|
||||
types.deployed_model_ref = types.deployed_model_ref_v1beta1
|
||||
types.deployment_resource_pool = types.deployment_resource_pool_v1beta1
|
||||
types.deployment_resource_pool_service = (
|
||||
types.deployment_resource_pool_service_v1beta1
|
||||
)
|
||||
types.encryption_spec = types.encryption_spec_v1beta1
|
||||
types.endpoint = types.endpoint_v1beta1
|
||||
types.endpoint_service = types.endpoint_service_v1beta1
|
||||
types.entity_type = types.entity_type_v1beta1
|
||||
types.env_var = types.env_var_v1beta1
|
||||
types.event = types.event_v1beta1
|
||||
types.execution = types.execution_v1beta1
|
||||
types.explanation = types.explanation_v1beta1
|
||||
types.explanation_metadata = types.explanation_metadata_v1beta1
|
||||
types.feature = types.feature_v1beta1
|
||||
types.feature_group = types.feature_group_v1beta1
|
||||
types.feature_monitor = types.feature_monitor_v1beta1
|
||||
types.feature_monitor_job = types.feature_monitor_job_v1beta1
|
||||
types.feature_monitoring_stats = types.feature_monitoring_stats_v1beta1
|
||||
types.feature_online_store = types.feature_online_store_v1beta1
|
||||
types.feature_online_store_admin_service = (
|
||||
types.feature_online_store_admin_service_v1beta1
|
||||
)
|
||||
types.feature_registry_service = types.feature_registry_service_v1beta1
|
||||
types.feature_online_store_service = types.feature_online_store_service_v1beta1
|
||||
types.feature_selector = types.feature_selector_v1beta1
|
||||
types.feature_view = types.feature_view_v1beta1
|
||||
types.feature_view_sync = types.feature_view_sync_v1beta1
|
||||
types.featurestore = types.featurestore_v1beta1
|
||||
types.featurestore_monitoring = types.featurestore_monitoring_v1beta1
|
||||
types.featurestore_online_service = types.featurestore_online_service_v1beta1
|
||||
types.featurestore_service = types.featurestore_service_v1beta1
|
||||
types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1beta1
|
||||
types.index = types.index_v1beta1
|
||||
types.index_endpoint = types.index_endpoint_v1beta1
|
||||
types.index_service = types.index_service_v1beta1
|
||||
types.io = types.io_v1beta1
|
||||
types.job_service = types.job_service_v1beta1
|
||||
types.job_state = types.job_state_v1beta1
|
||||
types.lineage_subgraph = types.lineage_subgraph_v1beta1
|
||||
types.machine_resources = types.machine_resources_v1beta1
|
||||
types.manual_batch_tuning_parameters = types.manual_batch_tuning_parameters_v1beta1
|
||||
types.matching_engine_deployed_index_ref = (
|
||||
types.matching_engine_deployed_index_ref_v1beta1
|
||||
)
|
||||
types.matching_engine_index = types.index_v1beta1
|
||||
types.matching_engine_index_endpoint = types.index_endpoint_v1beta1
|
||||
types.metadata_service = types.metadata_service_v1beta1
|
||||
types.metadata_schema = types.metadata_schema_v1beta1
|
||||
types.metadata_store = types.metadata_store_v1beta1
|
||||
types.model = types.model_v1beta1
|
||||
types.model_evaluation = types.model_evaluation_v1beta1
|
||||
types.model_evaluation_slice = types.model_evaluation_slice_v1beta1
|
||||
types.model_deployment_monitoring_job = (
|
||||
types.model_deployment_monitoring_job_v1beta1
|
||||
)
|
||||
types.model_garden_service = types.model_garden_service_v1beta1
|
||||
types.model_monitoring = types.model_monitoring_v1beta1
|
||||
types.model_service = types.model_service_v1beta1
|
||||
types.service_networking = types.service_networking_v1beta1
|
||||
types.operation = types.operation_v1beta1
|
||||
types.pipeline_failure_policy = types.pipeline_failure_policy_v1beta1
|
||||
types.pipeline_job = types.pipeline_job_v1beta1
|
||||
types.pipeline_service = types.pipeline_service_v1beta1
|
||||
types.pipeline_state = types.pipeline_state_v1beta1
|
||||
types.prediction_service = types.prediction_service_v1beta1
|
||||
types.publisher_model = types.publisher_model_v1beta1
|
||||
types.schedule = types.schedule_v1beta1
|
||||
types.schedule_service = types.schedule_service_v1beta1
|
||||
types.specialist_pool = types.specialist_pool_v1beta1
|
||||
types.specialist_pool_service = types.specialist_pool_service_v1beta1
|
||||
types.study = types.study_v1beta1
|
||||
types.tensorboard = types.tensorboard_v1beta1
|
||||
types.tensorboard_service = types.tensorboard_service_v1beta1
|
||||
types.tensorboard_data = types.tensorboard_data_v1beta1
|
||||
types.tensorboard_experiment = types.tensorboard_experiment_v1beta1
|
||||
types.tensorboard_run = types.tensorboard_run_v1beta1
|
||||
types.tensorboard_service = types.tensorboard_service_v1beta1
|
||||
types.tensorboard_time_series = types.tensorboard_time_series_v1beta1
|
||||
types.training_pipeline = types.training_pipeline_v1beta1
|
||||
types.types = types.types_v1beta1
|
||||
types.vizier_service = types.vizier_service_v1beta1
|
||||
|
||||
if DEFAULT_VERSION == V1:
|
||||
|
||||
services.dataset_service_client = services.dataset_service_client_v1
|
||||
services.deployment_resource_pool_service_client = (
|
||||
services.deployment_resource_pool_service_client_v1
|
||||
)
|
||||
services.endpoint_service_client = services.endpoint_service_client_v1
|
||||
services.feature_online_store_admin_service_client = (
|
||||
services.feature_online_store_admin_service_client_v1
|
||||
)
|
||||
services.feature_registry_service_client = (
|
||||
services.feature_registry_service_client_v1
|
||||
)
|
||||
services.feature_online_store_service_client = (
|
||||
services.feature_online_store_service_client_v1
|
||||
)
|
||||
services.featurestore_online_serving_service_client = (
|
||||
services.featurestore_online_serving_service_client_v1
|
||||
)
|
||||
services.featurestore_service_client = services.featurestore_service_client_v1
|
||||
services.gen_ai_cache_service_client = services.gen_ai_cache_service_client_v1
|
||||
services.job_service_client = services.job_service_client_v1
|
||||
services.model_garden_service_client = services.model_garden_service_client_v1
|
||||
services.model_service_client = services.model_service_client_v1
|
||||
services.pipeline_service_client = services.pipeline_service_client_v1
|
||||
services.prediction_service_client = services.prediction_service_client_v1
|
||||
services.prediction_service_async_client = (
|
||||
services.prediction_service_async_client_v1
|
||||
)
|
||||
services.schedule_service_client = services.schedule_service_client_v1
|
||||
services.specialist_pool_service_client = services.specialist_pool_service_client_v1
|
||||
services.tensorboard_service_client = services.tensorboard_service_client_v1
|
||||
services.index_service_client = services.index_service_client_v1
|
||||
services.index_endpoint_service_client = services.index_endpoint_service_client_v1
|
||||
services.vizier_service_client = services.vizier_service_client_v1
|
||||
|
||||
types.accelerator_type = types.accelerator_type_v1
|
||||
types.annotation = types.annotation_v1
|
||||
types.annotation_spec = types.annotation_spec_v1
|
||||
types.artifact = types.artifact_v1
|
||||
types.batch_prediction_job = types.batch_prediction_job_v1
|
||||
types.cached_content = types.cached_content_v1
|
||||
types.completion_stats = types.completion_stats_v1
|
||||
types.context = types.context_v1
|
||||
types.custom_job = types.custom_job_v1
|
||||
types.data_item = types.data_item_v1
|
||||
types.data_labeling_job = types.data_labeling_job_v1
|
||||
types.dataset = types.dataset_v1
|
||||
types.dataset_service = types.dataset_service_v1
|
||||
types.deployed_model_ref = types.deployed_model_ref_v1
|
||||
types.deployment_resource_pool = types.deployment_resource_pool_v1
|
||||
types.deployment_resource_pool_service = types.deployment_resource_pool_service_v1
|
||||
types.encryption_spec = types.encryption_spec_v1
|
||||
types.endpoint = types.endpoint_v1
|
||||
types.endpoint_service = types.endpoint_service_v1
|
||||
types.entity_type = types.entity_type_v1
|
||||
types.env_var = types.env_var_v1
|
||||
types.event = types.event_v1
|
||||
types.execution = types.execution_v1
|
||||
types.explanation = types.explanation_v1
|
||||
types.explanation_metadata = types.explanation_metadata_v1
|
||||
types.feature = types.feature_v1
|
||||
types.feature_group = types.feature_group_v1
|
||||
# TODO(b/293184410): Temporary code. Switch to v1 once v1 is available.
|
||||
types.feature_monitor = types.feature_monitor_v1beta1
|
||||
types.feature_monitor_job = types.feature_monitor_job_v1beta1
|
||||
types.feature_monitoring_stats = types.feature_monitoring_stats_v1
|
||||
types.feature_online_store = types.feature_online_store_v1
|
||||
types.feature_online_store_admin_service = (
|
||||
types.feature_online_store_admin_service_v1
|
||||
)
|
||||
types.feature_registry_service = types.feature_registry_service_v1
|
||||
types.feature_online_store_service = types.feature_online_store_service_v1
|
||||
types.feature_selector = types.feature_selector_v1
|
||||
types.feature_view = types.feature_view_v1
|
||||
types.feature_view_sync = types.feature_view_sync_v1
|
||||
types.featurestore = types.featurestore_v1
|
||||
types.featurestore_online_service = types.featurestore_online_service_v1
|
||||
types.featurestore_service = types.featurestore_service_v1
|
||||
types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1
|
||||
types.index = types.index_v1
|
||||
types.index_endpoint = types.index_endpoint_v1
|
||||
types.index_service = types.index_service_v1
|
||||
types.io = types.io_v1
|
||||
types.job_service = types.job_service_v1
|
||||
types.job_state = types.job_state_v1
|
||||
types.lineage_subgraph = types.lineage_subgraph_v1
|
||||
types.machine_resources = types.machine_resources_v1
|
||||
types.manual_batch_tuning_parameters = types.manual_batch_tuning_parameters_v1
|
||||
types.matching_engine_deployed_index_ref = (
|
||||
types.matching_engine_deployed_index_ref_v1
|
||||
)
|
||||
types.matching_engine_index = types.index_v1
|
||||
types.matching_engine_index_endpoint = types.index_endpoint_v1
|
||||
types.metadata_service = types.metadata_service_v1
|
||||
types.metadata_schema = types.metadata_schema_v1
|
||||
types.metadata_store = types.metadata_store_v1
|
||||
types.model = types.model_v1
|
||||
types.model_evaluation = types.model_evaluation_v1
|
||||
types.model_evaluation_slice = types.model_evaluation_slice_v1
|
||||
types.model_deployment_monitoring_job = types.model_deployment_monitoring_job_v1
|
||||
types.model_monitoring = types.model_monitoring_v1
|
||||
types.model_service = types.model_service_v1
|
||||
types.service_networking = types.service_networking_v1
|
||||
types.operation = types.operation_v1
|
||||
types.pipeline_failure_policy = types.pipeline_failure_policy_v1
|
||||
types.pipeline_job = types.pipeline_job_v1
|
||||
types.pipeline_service = types.pipeline_service_v1
|
||||
types.pipeline_state = types.pipeline_state_v1
|
||||
types.prediction_service = types.prediction_service_v1
|
||||
types.publisher_model = types.publisher_model_v1
|
||||
types.schedule = types.schedule_v1
|
||||
types.schedule_service = types.schedule_service_v1
|
||||
types.specialist_pool = types.specialist_pool_v1
|
||||
types.specialist_pool_service = types.specialist_pool_service_v1
|
||||
types.study = types.study_v1
|
||||
types.tensorboard = types.tensorboard_v1
|
||||
types.tensorboard_service = types.tensorboard_service_v1
|
||||
types.tensorboard_data = types.tensorboard_data_v1
|
||||
types.tensorboard_experiment = types.tensorboard_experiment_v1
|
||||
types.tensorboard_run = types.tensorboard_run_v1
|
||||
types.tensorboard_service = types.tensorboard_service_v1
|
||||
types.tensorboard_time_series = types.tensorboard_time_series_v1
|
||||
types.training_pipeline = types.training_pipeline_v1
|
||||
types.types = types.types_v1
|
||||
types.vizier_service = types.vizier_service_v1
|
||||
|
||||
__all__ = (
|
||||
DEFAULT_VERSION,
|
||||
V1BETA1,
|
||||
V1,
|
||||
services,
|
||||
types,
|
||||
)
|
||||
Binary file not shown.
@@ -0,0 +1,264 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from google.cloud.aiplatform_v1beta1.services.dataset_service import (
|
||||
client as dataset_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.deployment_resource_pool_service import (
|
||||
client as deployment_resource_pool_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
|
||||
client as endpoint_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.example_store_service import (
|
||||
client as example_store_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.extension_execution_service import (
|
||||
client as extension_execution_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.extension_registry_service import (
|
||||
client as extension_registry_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.feature_online_store_service import (
|
||||
client as feature_online_store_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.feature_online_store_admin_service import (
|
||||
client as feature_online_store_admin_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.feature_registry_service import (
|
||||
client as feature_registry_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import (
|
||||
client as featurestore_online_serving_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.featurestore_service import (
|
||||
client as featurestore_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.gen_ai_cache_service import (
|
||||
client as gen_ai_cache_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.index_service import (
|
||||
client as index_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import (
|
||||
client as index_endpoint_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.job_service import (
|
||||
client as job_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.match_service import (
|
||||
client as match_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.metadata_service import (
|
||||
client as metadata_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.model_garden_service import (
|
||||
client as model_garden_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.model_service import (
|
||||
client as model_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.model_monitoring_service import (
|
||||
client as model_monitoring_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
|
||||
client as persistent_resource_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
|
||||
client as pipeline_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
|
||||
client as prediction_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
|
||||
async_client as prediction_service_async_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.reasoning_engine_service import (
|
||||
client as reasoning_engine_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.reasoning_engine_execution_service import (
|
||||
client as reasoning_engine_execution_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.schedule_service import (
|
||||
client as schedule_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
|
||||
client as specialist_pool_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.tensorboard_service import (
|
||||
client as tensorboard_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
|
||||
client as vertex_rag_data_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service import (
|
||||
async_client as vertex_rag_data_service_async_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.vertex_rag_service import (
|
||||
client as vertex_rag_service_client_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.services.vizier_service import (
|
||||
client as vizier_service_client_v1beta1,
|
||||
)
|
||||
|
||||
|
||||
from google.cloud.aiplatform_v1.services.dataset_service import (
|
||||
client as dataset_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.deployment_resource_pool_service import (
|
||||
client as deployment_resource_pool_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.endpoint_service import (
|
||||
client as endpoint_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.feature_online_store_service import (
|
||||
client as feature_online_store_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.feature_online_store_admin_service import (
|
||||
client as feature_online_store_admin_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.feature_registry_service import (
|
||||
client as feature_registry_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.featurestore_online_serving_service import (
|
||||
client as featurestore_online_serving_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.featurestore_service import (
|
||||
client as featurestore_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.gen_ai_cache_service import (
|
||||
client as gen_ai_cache_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.index_service import (
|
||||
client as index_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.index_endpoint_service import (
|
||||
client as index_endpoint_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.job_service import (
|
||||
client as job_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.metadata_service import (
|
||||
client as metadata_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.model_garden_service import (
|
||||
client as model_garden_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.model_service import (
|
||||
client as model_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.persistent_resource_service import (
|
||||
client as persistent_resource_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.pipeline_service import (
|
||||
client as pipeline_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.prediction_service import (
|
||||
client as prediction_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.prediction_service import (
|
||||
async_client as prediction_service_async_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.reasoning_engine_service import (
|
||||
client as reasoning_engine_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.reasoning_engine_execution_service import (
|
||||
client as reasoning_engine_execution_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.schedule_service import (
|
||||
client as schedule_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.specialist_pool_service import (
|
||||
client as specialist_pool_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.tensorboard_service import (
|
||||
client as tensorboard_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.vizier_service import (
|
||||
client as vizier_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.vertex_rag_service import (
|
||||
client as vertex_rag_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.vertex_rag_data_service import (
|
||||
client as vertex_rag_data_service_client_v1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.services.vertex_rag_data_service import (
|
||||
async_client as vertex_rag_data_service_async_client_v1,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
# v1
|
||||
dataset_service_client_v1,
|
||||
deployment_resource_pool_service_client_v1,
|
||||
endpoint_service_client_v1,
|
||||
feature_online_store_service_client_v1,
|
||||
feature_online_store_admin_service_client_v1,
|
||||
feature_registry_service_client_v1,
|
||||
featurestore_online_serving_service_client_v1,
|
||||
featurestore_service_client_v1,
|
||||
index_service_client_v1,
|
||||
index_endpoint_service_client_v1,
|
||||
job_service_client_v1,
|
||||
metadata_service_client_v1,
|
||||
model_garden_service_client_v1,
|
||||
model_service_client_v1,
|
||||
persistent_resource_service_client_v1,
|
||||
pipeline_service_client_v1,
|
||||
prediction_service_client_v1,
|
||||
prediction_service_async_client_v1,
|
||||
reasoning_engine_execution_service_client_v1,
|
||||
reasoning_engine_service_client_v1,
|
||||
schedule_service_client_v1,
|
||||
specialist_pool_service_client_v1,
|
||||
tensorboard_service_client_v1,
|
||||
vizier_service_client_v1,
|
||||
vertex_rag_data_service_async_client_v1,
|
||||
vertex_rag_data_service_client_v1,
|
||||
vertex_rag_service_client_v1,
|
||||
# v1beta1
|
||||
dataset_service_client_v1beta1,
|
||||
deployment_resource_pool_service_client_v1beta1,
|
||||
endpoint_service_client_v1beta1,
|
||||
example_store_service_client_v1beta1,
|
||||
feature_online_store_service_client_v1beta1,
|
||||
feature_online_store_admin_service_client_v1beta1,
|
||||
feature_registry_service_client_v1beta1,
|
||||
featurestore_online_serving_service_client_v1beta1,
|
||||
featurestore_service_client_v1beta1,
|
||||
index_service_client_v1beta1,
|
||||
index_endpoint_service_client_v1beta1,
|
||||
job_service_client_v1beta1,
|
||||
match_service_client_v1beta1,
|
||||
model_garden_service_client_v1beta1,
|
||||
model_monitoring_service_client_v1beta1,
|
||||
model_service_client_v1beta1,
|
||||
persistent_resource_service_client_v1beta1,
|
||||
pipeline_service_client_v1beta1,
|
||||
prediction_service_client_v1beta1,
|
||||
prediction_service_async_client_v1beta1,
|
||||
reasoning_engine_execution_service_client_v1beta1,
|
||||
reasoning_engine_service_client_v1beta1,
|
||||
schedule_service_client_v1beta1,
|
||||
specialist_pool_service_client_v1beta1,
|
||||
metadata_service_client_v1beta1,
|
||||
tensorboard_service_client_v1beta1,
|
||||
vertex_rag_service_client_v1beta1,
|
||||
vertex_rag_data_service_client_v1beta1,
|
||||
vertex_rag_data_service_async_client_v1beta1,
|
||||
vizier_service_client_v1beta1,
|
||||
)
|
||||
Binary file not shown.
@@ -0,0 +1,361 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
accelerator_type as accelerator_type_v1beta1,
|
||||
annotation as annotation_v1beta1,
|
||||
annotation_spec as annotation_spec_v1beta1,
|
||||
artifact as artifact_v1beta1,
|
||||
batch_prediction_job as batch_prediction_job_v1beta1,
|
||||
cached_content as cached_content_v1beta1,
|
||||
completion_stats as completion_stats_v1beta1,
|
||||
context as context_v1beta1,
|
||||
custom_job as custom_job_v1beta1,
|
||||
data_item as data_item_v1beta1,
|
||||
data_labeling_job as data_labeling_job_v1beta1,
|
||||
dataset as dataset_v1beta1,
|
||||
dataset_service as dataset_service_v1beta1,
|
||||
deployed_index_ref as matching_engine_deployed_index_ref_v1beta1,
|
||||
deployed_model_ref as deployed_model_ref_v1beta1,
|
||||
deployment_resource_pool as deployment_resource_pool_v1beta1,
|
||||
deployment_resource_pool_service as deployment_resource_pool_service_v1beta1,
|
||||
encryption_spec as encryption_spec_v1beta1,
|
||||
endpoint as endpoint_v1beta1,
|
||||
endpoint_service as endpoint_service_v1beta1,
|
||||
entity_type as entity_type_v1beta1,
|
||||
env_var as env_var_v1beta1,
|
||||
event as event_v1beta1,
|
||||
execution as execution_v1beta1,
|
||||
explanation as explanation_v1beta1,
|
||||
explanation_metadata as explanation_metadata_v1beta1,
|
||||
feature as feature_v1beta1,
|
||||
feature_group as feature_group_v1beta1,
|
||||
feature_monitor as feature_monitor_v1beta1,
|
||||
feature_monitor_job as feature_monitor_job_v1beta1,
|
||||
feature_monitoring_stats as feature_monitoring_stats_v1beta1,
|
||||
feature_online_store as feature_online_store_v1beta1,
|
||||
feature_online_store_admin_service as feature_online_store_admin_service_v1beta1,
|
||||
feature_online_store_service as feature_online_store_service_v1beta1,
|
||||
feature_registry_service as feature_registry_service_v1beta1,
|
||||
feature_selector as feature_selector_v1beta1,
|
||||
feature_view as feature_view_v1beta1,
|
||||
feature_view_sync as feature_view_sync_v1beta1,
|
||||
featurestore as featurestore_v1beta1,
|
||||
featurestore_monitoring as featurestore_monitoring_v1beta1,
|
||||
featurestore_online_service as featurestore_online_service_v1beta1,
|
||||
featurestore_service as featurestore_service_v1beta1,
|
||||
gen_ai_cache_service as gen_ai_cache_service_v1beta1,
|
||||
index as index_v1beta1,
|
||||
index_endpoint as index_endpoint_v1beta1,
|
||||
hyperparameter_tuning_job as hyperparameter_tuning_job_v1beta1,
|
||||
io as io_v1beta1,
|
||||
index_service as index_service_v1beta1,
|
||||
job_service as job_service_v1beta1,
|
||||
job_state as job_state_v1beta1,
|
||||
lineage_subgraph as lineage_subgraph_v1beta1,
|
||||
machine_resources as machine_resources_v1beta1,
|
||||
manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1beta1,
|
||||
match_service as match_service_v1beta1,
|
||||
metadata_schema as metadata_schema_v1beta1,
|
||||
metadata_service as metadata_service_v1beta1,
|
||||
metadata_store as metadata_store_v1beta1,
|
||||
model as model_v1beta1,
|
||||
model_evaluation as model_evaluation_v1beta1,
|
||||
model_evaluation_slice as model_evaluation_slice_v1beta1,
|
||||
model_deployment_monitoring_job as model_deployment_monitoring_job_v1beta1,
|
||||
model_garden_service as model_garden_service_v1beta1,
|
||||
model_service as model_service_v1beta1,
|
||||
model_monitor as model_monitor_v1beta1,
|
||||
model_monitoring as model_monitoring_v1beta1,
|
||||
model_monitoring_alert as model_monitoring_alert_v1beta1,
|
||||
model_monitoring_job as model_monitoring_job_v1beta1,
|
||||
model_monitoring_service as model_monitoring_service_v1beta1,
|
||||
model_monitoring_spec as model_monitoring_spec_v1beta1,
|
||||
model_monitoring_stats as model_monitoring_stats_v1beta1,
|
||||
operation as operation_v1beta1,
|
||||
persistent_resource as persistent_resource_v1beta1,
|
||||
persistent_resource_service as persistent_resource_service_v1beta1,
|
||||
pipeline_failure_policy as pipeline_failure_policy_v1beta1,
|
||||
pipeline_job as pipeline_job_v1beta1,
|
||||
pipeline_service as pipeline_service_v1beta1,
|
||||
pipeline_state as pipeline_state_v1beta1,
|
||||
prediction_service as prediction_service_v1beta1,
|
||||
publisher_model as publisher_model_v1beta1,
|
||||
reservation_affinity as reservation_affinity_v1beta1,
|
||||
service_networking as service_networking_v1beta1,
|
||||
schedule as schedule_v1beta1,
|
||||
schedule_service as schedule_service_v1beta1,
|
||||
specialist_pool as specialist_pool_v1beta1,
|
||||
specialist_pool_service as specialist_pool_service_v1beta1,
|
||||
study as study_v1beta1,
|
||||
tensorboard as tensorboard_v1beta1,
|
||||
tensorboard_data as tensorboard_data_v1beta1,
|
||||
tensorboard_experiment as tensorboard_experiment_v1beta1,
|
||||
tensorboard_run as tensorboard_run_v1beta1,
|
||||
tensorboard_service as tensorboard_service_v1beta1,
|
||||
tensorboard_time_series as tensorboard_time_series_v1beta1,
|
||||
training_pipeline as training_pipeline_v1beta1,
|
||||
types as types_v1beta1,
|
||||
vizier_service as vizier_service_v1beta1,
|
||||
)
|
||||
from google.cloud.aiplatform_v1.types import (
|
||||
accelerator_type as accelerator_type_v1,
|
||||
annotation as annotation_v1,
|
||||
annotation_spec as annotation_spec_v1,
|
||||
artifact as artifact_v1,
|
||||
batch_prediction_job as batch_prediction_job_v1,
|
||||
cached_content as cached_content_v1,
|
||||
completion_stats as completion_stats_v1,
|
||||
context as context_v1,
|
||||
custom_job as custom_job_v1,
|
||||
data_item as data_item_v1,
|
||||
data_labeling_job as data_labeling_job_v1,
|
||||
dataset as dataset_v1,
|
||||
dataset_service as dataset_service_v1,
|
||||
deployed_index_ref as matching_engine_deployed_index_ref_v1,
|
||||
deployed_model_ref as deployed_model_ref_v1,
|
||||
deployment_resource_pool as deployment_resource_pool_v1,
|
||||
deployment_resource_pool_service as deployment_resource_pool_service_v1,
|
||||
encryption_spec as encryption_spec_v1,
|
||||
endpoint as endpoint_v1,
|
||||
endpoint_service as endpoint_service_v1,
|
||||
entity_type as entity_type_v1,
|
||||
env_var as env_var_v1,
|
||||
event as event_v1,
|
||||
execution as execution_v1,
|
||||
explanation as explanation_v1,
|
||||
explanation_metadata as explanation_metadata_v1,
|
||||
feature as feature_v1,
|
||||
feature_group as feature_group_v1,
|
||||
feature_monitoring_stats as feature_monitoring_stats_v1,
|
||||
feature_online_store as feature_online_store_v1,
|
||||
feature_online_store_admin_service as feature_online_store_admin_service_v1,
|
||||
feature_online_store_service as feature_online_store_service_v1,
|
||||
feature_registry_service as feature_registry_service_v1,
|
||||
feature_selector as feature_selector_v1,
|
||||
feature_view as feature_view_v1,
|
||||
feature_view_sync as feature_view_sync_v1,
|
||||
featurestore as featurestore_v1,
|
||||
featurestore_online_service as featurestore_online_service_v1,
|
||||
featurestore_service as featurestore_service_v1,
|
||||
hyperparameter_tuning_job as hyperparameter_tuning_job_v1,
|
||||
index as index_v1,
|
||||
index_endpoint as index_endpoint_v1,
|
||||
index_service as index_service_v1,
|
||||
io as io_v1,
|
||||
job_service as job_service_v1,
|
||||
job_state as job_state_v1,
|
||||
lineage_subgraph as lineage_subgraph_v1,
|
||||
machine_resources as machine_resources_v1,
|
||||
manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1,
|
||||
metadata_service as metadata_service_v1,
|
||||
metadata_schema as metadata_schema_v1,
|
||||
metadata_store as metadata_store_v1,
|
||||
model as model_v1,
|
||||
model_evaluation as model_evaluation_v1,
|
||||
model_evaluation_slice as model_evaluation_slice_v1,
|
||||
model_deployment_monitoring_job as model_deployment_monitoring_job_v1,
|
||||
model_service as model_service_v1,
|
||||
model_monitoring as model_monitoring_v1,
|
||||
operation as operation_v1,
|
||||
persistent_resource as persistent_resource_v1,
|
||||
persistent_resource_service as persistent_resource_service_v1,
|
||||
pipeline_failure_policy as pipeline_failure_policy_v1,
|
||||
pipeline_job as pipeline_job_v1,
|
||||
pipeline_service as pipeline_service_v1,
|
||||
pipeline_state as pipeline_state_v1,
|
||||
prediction_service as prediction_service_v1,
|
||||
publisher_model as publisher_model_v1,
|
||||
reservation_affinity as reservation_affinity_v1,
|
||||
schedule as schedule_v1,
|
||||
schedule_service as schedule_service_v1,
|
||||
service_networking as service_networking_v1,
|
||||
specialist_pool as specialist_pool_v1,
|
||||
specialist_pool_service as specialist_pool_service_v1,
|
||||
study as study_v1,
|
||||
tensorboard as tensorboard_v1,
|
||||
tensorboard_data as tensorboard_data_v1,
|
||||
tensorboard_experiment as tensorboard_experiment_v1,
|
||||
tensorboard_run as tensorboard_run_v1,
|
||||
tensorboard_service as tensorboard_service_v1,
|
||||
tensorboard_time_series as tensorboard_time_series_v1,
|
||||
training_pipeline as training_pipeline_v1,
|
||||
types as types_v1,
|
||||
vizier_service as vizier_service_v1,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
# v1
|
||||
accelerator_type_v1,
|
||||
annotation_v1,
|
||||
annotation_spec_v1,
|
||||
artifact_v1,
|
||||
batch_prediction_job_v1,
|
||||
completion_stats_v1,
|
||||
context_v1,
|
||||
custom_job_v1,
|
||||
data_item_v1,
|
||||
data_labeling_job_v1,
|
||||
dataset_v1,
|
||||
dataset_service_v1,
|
||||
deployed_model_ref_v1,
|
||||
deployment_resource_pool_v1,
|
||||
deployment_resource_pool_service_v1,
|
||||
encryption_spec_v1,
|
||||
endpoint_v1,
|
||||
endpoint_service_v1,
|
||||
entity_type_v1,
|
||||
env_var_v1,
|
||||
event_v1,
|
||||
execution_v1,
|
||||
explanation_v1,
|
||||
explanation_metadata_v1,
|
||||
feature_v1,
|
||||
feature_monitoring_stats_v1,
|
||||
feature_selector_v1,
|
||||
featurestore_v1,
|
||||
featurestore_online_service_v1,
|
||||
featurestore_service_v1,
|
||||
hyperparameter_tuning_job_v1,
|
||||
io_v1,
|
||||
job_service_v1,
|
||||
job_state_v1,
|
||||
lineage_subgraph_v1,
|
||||
machine_resources_v1,
|
||||
manual_batch_tuning_parameters_v1,
|
||||
matching_engine_deployed_index_ref_v1,
|
||||
index_v1,
|
||||
index_endpoint_v1,
|
||||
index_service_v1,
|
||||
metadata_service_v1,
|
||||
metadata_schema_v1,
|
||||
metadata_store_v1,
|
||||
model_v1,
|
||||
model_evaluation_v1,
|
||||
model_evaluation_slice_v1,
|
||||
model_deployment_monitoring_job_v1,
|
||||
model_service_v1,
|
||||
model_monitoring_v1,
|
||||
operation_v1,
|
||||
persistent_resource_v1,
|
||||
persistent_resource_service_v1,
|
||||
pipeline_failure_policy_v1,
|
||||
pipeline_job_v1,
|
||||
pipeline_service_v1,
|
||||
pipeline_state_v1,
|
||||
prediction_service_v1,
|
||||
publisher_model_v1,
|
||||
reservation_affinity_v1,
|
||||
schedule_v1,
|
||||
schedule_service_v1,
|
||||
specialist_pool_v1,
|
||||
specialist_pool_service_v1,
|
||||
tensorboard_v1,
|
||||
tensorboard_data_v1,
|
||||
tensorboard_experiment_v1,
|
||||
tensorboard_run_v1,
|
||||
tensorboard_service_v1,
|
||||
tensorboard_time_series_v1,
|
||||
training_pipeline_v1,
|
||||
types_v1,
|
||||
study_v1,
|
||||
vizier_service_v1,
|
||||
# v1beta1
|
||||
accelerator_type_v1beta1,
|
||||
annotation_v1beta1,
|
||||
annotation_spec_v1beta1,
|
||||
artifact_v1beta1,
|
||||
batch_prediction_job_v1beta1,
|
||||
completion_stats_v1beta1,
|
||||
context_v1beta1,
|
||||
custom_job_v1beta1,
|
||||
data_item_v1beta1,
|
||||
data_labeling_job_v1beta1,
|
||||
dataset_v1beta1,
|
||||
dataset_service_v1beta1,
|
||||
deployment_resource_pool_v1beta1,
|
||||
deployment_resource_pool_service_v1beta1,
|
||||
deployed_model_ref_v1beta1,
|
||||
encryption_spec_v1beta1,
|
||||
endpoint_v1beta1,
|
||||
endpoint_service_v1beta1,
|
||||
entity_type_v1beta1,
|
||||
env_var_v1beta1,
|
||||
event_v1beta1,
|
||||
execution_v1beta1,
|
||||
explanation_v1beta1,
|
||||
explanation_metadata_v1beta1,
|
||||
feature_v1beta1,
|
||||
feature_monitoring_stats_v1beta1,
|
||||
feature_selector_v1beta1,
|
||||
featurestore_v1beta1,
|
||||
featurestore_monitoring_v1beta1,
|
||||
featurestore_online_service_v1beta1,
|
||||
featurestore_service_v1beta1,
|
||||
hyperparameter_tuning_job_v1beta1,
|
||||
io_v1beta1,
|
||||
job_service_v1beta1,
|
||||
job_state_v1beta1,
|
||||
lineage_subgraph_v1beta1,
|
||||
machine_resources_v1beta1,
|
||||
manual_batch_tuning_parameters_v1beta1,
|
||||
matching_engine_deployed_index_ref_v1beta1,
|
||||
index_v1beta1,
|
||||
index_endpoint_v1beta1,
|
||||
index_service_v1beta1,
|
||||
match_service_v1beta1,
|
||||
metadata_service_v1beta1,
|
||||
metadata_schema_v1beta1,
|
||||
metadata_store_v1beta1,
|
||||
model_v1beta1,
|
||||
model_evaluation_v1beta1,
|
||||
model_evaluation_slice_v1beta1,
|
||||
model_deployment_monitoring_job_v1beta1,
|
||||
model_garden_service_v1beta1,
|
||||
model_service_v1beta1,
|
||||
model_monitor_v1beta1,
|
||||
model_monitoring_v1beta1,
|
||||
model_monitoring_alert_v1beta1,
|
||||
model_monitoring_job_v1beta1,
|
||||
model_monitoring_service_v1beta1,
|
||||
model_monitoring_spec_v1beta1,
|
||||
model_monitoring_stats_v1beta1,
|
||||
operation_v1beta1,
|
||||
persistent_resource_v1beta1,
|
||||
persistent_resource_service_v1beta1,
|
||||
pipeline_failure_policy_v1beta1,
|
||||
pipeline_job_v1beta1,
|
||||
pipeline_service_v1beta1,
|
||||
pipeline_state_v1beta1,
|
||||
prediction_service_v1beta1,
|
||||
publisher_model_v1beta1,
|
||||
reservation_affinity_v1beta1,
|
||||
schedule_v1beta1,
|
||||
schedule_service_v1beta1,
|
||||
specialist_pool_v1beta1,
|
||||
specialist_pool_service_v1beta1,
|
||||
study_v1beta1,
|
||||
tensorboard_v1beta1,
|
||||
tensorboard_data_v1beta1,
|
||||
tensorboard_experiment_v1beta1,
|
||||
tensorboard_run_v1beta1,
|
||||
tensorboard_service_v1beta1,
|
||||
tensorboard_time_series_v1beta1,
|
||||
training_pipeline_v1beta1,
|
||||
types_v1beta1,
|
||||
vizier_service_v1beta1,
|
||||
)
|
||||
Binary file not shown.
@@ -0,0 +1,18 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.cloud.aiplatform.constants import base
|
||||
from google.cloud.aiplatform.constants import prediction
|
||||
|
||||
__all__ = ("base", "prediction")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,143 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from google.cloud.aiplatform import version as aiplatform_version
|
||||
|
||||
|
||||
DEFAULT_REGION = "us-central1"
|
||||
SUPPORTED_REGIONS = frozenset(
|
||||
{
|
||||
"africa-south1",
|
||||
"asia-east1",
|
||||
"asia-east2",
|
||||
"asia-northeast1",
|
||||
"asia-northeast2",
|
||||
"asia-northeast3",
|
||||
"asia-south1",
|
||||
"asia-southeast1",
|
||||
"asia-southeast2",
|
||||
"australia-southeast1",
|
||||
"australia-southeast2",
|
||||
"europe-central2",
|
||||
"europe-north1",
|
||||
"europe-southwest1",
|
||||
"europe-west1",
|
||||
"europe-west2",
|
||||
"europe-west3",
|
||||
"europe-west4",
|
||||
"europe-west6",
|
||||
"europe-west8",
|
||||
"europe-west9",
|
||||
"europe-west12",
|
||||
"global",
|
||||
"me-central1",
|
||||
"me-central2",
|
||||
"me-west1",
|
||||
"northamerica-northeast1",
|
||||
"northamerica-northeast2",
|
||||
"southamerica-east1",
|
||||
"southamerica-west1",
|
||||
"us-central1",
|
||||
"us-east1",
|
||||
"us-east4",
|
||||
"us-east5",
|
||||
"us-south1",
|
||||
"us-west1",
|
||||
"us-west2",
|
||||
"us-west3",
|
||||
"us-west4",
|
||||
}
|
||||
)
|
||||
|
||||
API_BASE_PATH = "aiplatform.googleapis.com"
|
||||
PREDICTION_API_BASE_PATH = API_BASE_PATH
|
||||
|
||||
# Batch Prediction
|
||||
BATCH_PREDICTION_INPUT_STORAGE_FORMATS = (
|
||||
"jsonl",
|
||||
"csv",
|
||||
"tf-record",
|
||||
"tf-record-gzip",
|
||||
"bigquery",
|
||||
"file-list",
|
||||
)
|
||||
BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS = ("jsonl", "csv", "bigquery")
|
||||
|
||||
MOBILE_TF_MODEL_TYPES = {
|
||||
"MOBILE_TF_LOW_LATENCY_1",
|
||||
"MOBILE_TF_VERSATILE_1",
|
||||
"MOBILE_TF_HIGH_ACCURACY_1",
|
||||
}
|
||||
|
||||
MODEL_GARDEN_ICN_MODEL_TYPES = {
|
||||
"EFFICIENTNET",
|
||||
"MAXVIT",
|
||||
"VIT",
|
||||
"COCA",
|
||||
}
|
||||
|
||||
MODEL_GARDEN_IOD_MODEL_TYPES = {
|
||||
"SPINENET",
|
||||
"YOLO",
|
||||
}
|
||||
|
||||
# TODO(b/177079208): Use EPCL Enums for validating Model Types
|
||||
# Defined by gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_*
|
||||
# Format: "prediction_type": set() of model_type's
|
||||
#
|
||||
# NOTE: When adding a new prediction_type's, ensure it fits the pattern
|
||||
# "automl_image_{prediction_type}_*" used by the YAML schemas on GCS
|
||||
AUTOML_IMAGE_PREDICTION_MODEL_TYPES = {
|
||||
"classification": {"CLOUD", "CLOUD_1"}
|
||||
| MOBILE_TF_MODEL_TYPES
|
||||
| MODEL_GARDEN_ICN_MODEL_TYPES,
|
||||
"object_detection": {"CLOUD_1", "CLOUD_HIGH_ACCURACY_1", "CLOUD_LOW_LATENCY_1"}
|
||||
| MOBILE_TF_MODEL_TYPES
|
||||
| MODEL_GARDEN_IOD_MODEL_TYPES,
|
||||
}
|
||||
|
||||
AUTOML_VIDEO_PREDICTION_MODEL_TYPES = {
|
||||
"classification": {"CLOUD"} | {"MOBILE_VERSATILE_1"},
|
||||
"action_recognition": {"CLOUD"} | {"MOBILE_VERSATILE_1"},
|
||||
"object_tracking": {"CLOUD"}
|
||||
| {
|
||||
"MOBILE_VERSATILE_1",
|
||||
"MOBILE_CORAL_VERSATILE_1",
|
||||
"MOBILE_CORAL_LOW_LATENCY_1",
|
||||
"MOBILE_JETSON_VERSATILE_1",
|
||||
"MOBILE_JETSON_LOW_LATENCY_1",
|
||||
},
|
||||
}
|
||||
|
||||
# Used in constructing the requests user_agent header for metrics reporting.
|
||||
USER_AGENT_PRODUCT = "model-builder"
|
||||
# This field is used to pass the name of the specific SDK method
|
||||
# that is being used for usage metrics tracking purposes.
|
||||
# For more details on go/oneplatform-api-analytics
|
||||
USER_AGENT_SDK_COMMAND = ""
|
||||
|
||||
# Needed for Endpoint.raw_predict
|
||||
DEFAULT_AUTHED_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
|
||||
|
||||
# Used in CustomJob.from_local_script for experiments integration in training
|
||||
AIPLATFORM_DEPENDENCY_PATH = (
|
||||
f"google-cloud-aiplatform=={aiplatform_version.__version__}"
|
||||
)
|
||||
|
||||
AIPLATFORM_AUTOLOG_DEPENDENCY_PATH = (
|
||||
f"google-cloud-aiplatform[autologging]=={aiplatform_version.__version__}"
|
||||
)
|
||||
@@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import re
|
||||
|
||||
from google.cloud.aiplatform.compat.types import (
|
||||
pipeline_state as gca_pipeline_state,
|
||||
)
|
||||
|
||||
_PIPELINE_COMPLETE_STATES = set(
|
||||
[
|
||||
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
|
||||
gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED,
|
||||
gca_pipeline_state.PipelineState.PIPELINE_STATE_CANCELLED,
|
||||
gca_pipeline_state.PipelineState.PIPELINE_STATE_PAUSED,
|
||||
]
|
||||
)
|
||||
|
||||
_PIPELINE_ERROR_STATES = set([gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED])
|
||||
|
||||
# Pattern for valid names used as a Vertex resource name.
|
||||
_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$", re.IGNORECASE)
|
||||
|
||||
# Pattern for an Artifact Registry URL.
|
||||
_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*", re.IGNORECASE)
|
||||
|
||||
# Pattern for any JSON or YAML file over HTTPS.
|
||||
_VALID_HTTPS_URL = re.compile(r"^https:\/\/([\.\/\w-]+)\/.*(json|yaml|yml)$")
|
||||
|
||||
# Fields to include in returned PipelineJob when enable_simple_view=True in PipelineJob.list()
|
||||
_READ_MASK_FIELDS = [
|
||||
"name",
|
||||
"state",
|
||||
"display_name",
|
||||
"pipeline_spec.pipeline_info",
|
||||
"create_time",
|
||||
"start_time",
|
||||
"end_time",
|
||||
"update_time",
|
||||
"labels",
|
||||
"template_uri",
|
||||
"template_metadata.version",
|
||||
"job_detail.pipeline_run_context",
|
||||
"job_detail.pipeline_context",
|
||||
]
|
||||
@@ -0,0 +1,304 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
# [region]-docker.pkg.dev/vertex-ai/prediction/[framework]-[accelerator].[version]:latest
|
||||
CONTAINER_URI_PATTERN = re.compile(
|
||||
r"(?P<region>[\w]+)\-docker\.pkg\.dev\/vertex\-ai\/prediction\/"
|
||||
r"(?P<framework>[\w]+)\-(?P<accelerator>[\w]+)\.(?P<version>[\d-]+):latest"
|
||||
)
|
||||
|
||||
CONTAINER_URI_REGEX = (
|
||||
r"^(us|europe|asia)-docker.pkg.dev/"
|
||||
r"vertex-ai/prediction/"
|
||||
r"(tf|sklearn|xgboost|pytorch).+$"
|
||||
)
|
||||
|
||||
SKLEARN = "sklearn"
|
||||
TF = "tf"
|
||||
TF2 = "tf2"
|
||||
XGBOOST = "xgboost"
|
||||
|
||||
XGBOOST_CONTAINER_URIS = [
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-1:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-1:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-1:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-0:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-0:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.2-0:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-7:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-7:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-7:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-6:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-6:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-6:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-5:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-5:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-5:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-4:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-4:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-4:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-3:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-3:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-3:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-2:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-2:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-2:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-1:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-1:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-1:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-90:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-90:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-90:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-82:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-82:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-82:latest",
|
||||
]
|
||||
|
||||
SKLEARN_CONTAINER_URIS = [
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-5:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-5:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-5:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-4:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-4:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-4:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-3:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-3:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-3:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-2:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-2:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-2:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-24:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-24:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-24:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-23:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-23:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-23:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-22:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-22:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-22:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-20:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-20:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-20:latest",
|
||||
]
|
||||
|
||||
TF_CONTAINER_URIS = [
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-15:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-15:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-15:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-15:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-15:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-15:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-14:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-14:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-14:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-14:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-14:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-14:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-13:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-13:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-13:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-13:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-13:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-13:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-12:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-12:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-12:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-12:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-12:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-12:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-11:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-11:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-11:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-11:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-11:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-11:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-10:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-10:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-10:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-10:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-10:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-10:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-9:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-9:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-9:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-9:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-9:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-9:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-8:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-8:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-8:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-8:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-8:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-8:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-7:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-7:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-7:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-7:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-7:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-7:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-6:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-6:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-6:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-5:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-5:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-5:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-4:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-4:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-4:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-4:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-4:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-4:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-3:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-3:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-3:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-3:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-3:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-3:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-2:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-2:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-2:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-1:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-1:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-1:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-1:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-1:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-1:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf-cpu.1-15:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf-cpu.1-15:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf-cpu.1-15:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/tf-gpu.1-15:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/tf-gpu.1-15:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/tf-gpu.1-15:latest",
|
||||
]
|
||||
|
||||
PYTORCH_CONTAINER_URIS = [
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-4:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-4:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-4:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-3:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-3:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-3:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-3:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-3:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-3:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-2:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-2:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-2:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-2:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-2:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-2:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-1:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-1:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-1:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-1:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-1:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-1:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-0:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-0:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-0:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-0:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-0:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.2-0:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-13:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-13:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-13:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-13:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-13:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-13:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-12:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-12:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-12:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-12:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-12:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-12:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-11:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-11:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.1-11:latest",
|
||||
"us-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-11:latest",
|
||||
"europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-11:latest",
|
||||
"asia-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-11:latest",
|
||||
]
|
||||
|
||||
SERVING_CONTAINER_URIS = (
|
||||
SKLEARN_CONTAINER_URIS
|
||||
+ TF_CONTAINER_URIS
|
||||
+ XGBOOST_CONTAINER_URIS
|
||||
+ PYTORCH_CONTAINER_URIS
|
||||
)
|
||||
|
||||
# Map of all first-party prediction containers
|
||||
d = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(str))))
|
||||
|
||||
for container_uri in SERVING_CONTAINER_URIS:
|
||||
m = CONTAINER_URI_PATTERN.match(container_uri)
|
||||
region, framework, accelerator, version = m[1], m[2], m[3], m[4]
|
||||
version = version.replace("-", ".")
|
||||
|
||||
if framework in (TF2, TF): # Store both `tf`, `tf2` as `tensorflow`
|
||||
framework = "tensorflow"
|
||||
|
||||
d[region][framework][accelerator][version] = container_uri
|
||||
|
||||
_SERVING_CONTAINER_URI_MAP = d
|
||||
|
||||
_SERVING_CONTAINER_DOCUMENTATION_URL = (
|
||||
"https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers"
|
||||
)
|
||||
|
||||
# Variables set by Vertex AI. For more details, please refer to
|
||||
# https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables
|
||||
DEFAULT_AIP_HTTP_PORT = 8080
|
||||
AIP_HTTP_PORT = "AIP_HTTP_PORT"
|
||||
AIP_HEALTH_ROUTE = "AIP_HEALTH_ROUTE"
|
||||
AIP_PREDICT_ROUTE = "AIP_PREDICT_ROUTE"
|
||||
AIP_STORAGE_URI = "AIP_STORAGE_URI"
|
||||
|
||||
# Default values for Prediction local experience.
|
||||
DEFAULT_LOCAL_PREDICT_ROUTE = "/predict"
|
||||
DEFAULT_LOCAL_HEALTH_ROUTE = "/health"
|
||||
DEFAULT_LOCAL_RUN_GPU_CAPABILITIES = [["utility", "compute"]]
|
||||
DEFAULT_LOCAL_RUN_GPU_COUNT = -1
|
||||
|
||||
CUSTOM_PREDICTION_ROUTINES = "custom-prediction-routines"
|
||||
CUSTOM_PREDICTION_ROUTINES_SERVER_ERROR_HEADER_KEY = "X-AIP-CPR-SYSTEM-ERROR"
|
||||
|
||||
# Headers' related constants for the handler usage.
|
||||
CONTENT_TYPE_HEADER_REGEX = re.compile("^[Cc]ontent-?[Tt]ype$")
|
||||
ACCEPT_HEADER_REGEX = re.compile("^[Aa]ccept$")
|
||||
ANY_ACCEPT_TYPE = "*/*"
|
||||
DEFAULT_ACCEPT_VALUE = "application/json"
|
||||
|
||||
# Model filenames.
|
||||
MODEL_FILENAME_BST = "model.bst"
|
||||
MODEL_FILENAME_JOBLIB = "model.joblib"
|
||||
MODEL_FILENAME_PKL = "model.pkl"
|
||||
@@ -0,0 +1,43 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from google.cloud.aiplatform.compat.types import (
|
||||
schedule as gca_schedule,
|
||||
)
|
||||
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
|
||||
|
||||
_SCHEDULE_COMPLETE_STATES = set(
|
||||
[
|
||||
gca_schedule.Schedule.State.PAUSED,
|
||||
gca_schedule.Schedule.State.COMPLETED,
|
||||
]
|
||||
)
|
||||
|
||||
_SCHEDULE_ERROR_STATES = set(
|
||||
[
|
||||
gca_schedule.Schedule.State.STATE_UNSPECIFIED,
|
||||
]
|
||||
)
|
||||
|
||||
# Pattern for valid names used as a Vertex resource name.
|
||||
_VALID_NAME_PATTERN = pipeline_constants._VALID_NAME_PATTERN
|
||||
|
||||
# Pattern for an Artifact Registry URL.
|
||||
_VALID_AR_URL = pipeline_constants._VALID_AR_URL
|
||||
|
||||
# Pattern for any JSON or YAML file over HTTPS.
|
||||
_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL
|
||||
@@ -0,0 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from google.cloud.aiplatform.datasets.dataset import _Dataset
|
||||
from google.cloud.aiplatform.datasets.column_names_dataset import _ColumnNamesDataset
|
||||
from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset
|
||||
from google.cloud.aiplatform.datasets.time_series_dataset import TimeSeriesDataset
|
||||
from google.cloud.aiplatform.datasets.image_dataset import ImageDataset
|
||||
from google.cloud.aiplatform.datasets.text_dataset import TextDataset
|
||||
from google.cloud.aiplatform.datasets.video_dataset import VideoDataset
|
||||
|
||||
|
||||
__all__ = (
|
||||
"_Dataset",
|
||||
"_ColumnNamesDataset",
|
||||
"TabularDataset",
|
||||
"TimeSeriesDataset",
|
||||
"ImageDataset",
|
||||
"TextDataset",
|
||||
"VideoDataset",
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,240 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import abc
|
||||
from typing import Optional, Dict, Sequence, Union
|
||||
from google.cloud.aiplatform import schema
|
||||
|
||||
from google.cloud.aiplatform.compat.types import (
|
||||
io as gca_io,
|
||||
dataset as gca_dataset,
|
||||
)
|
||||
|
||||
|
||||
class Datasource(abc.ABC):
|
||||
"""An abstract class that sets dataset_metadata."""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def dataset_metadata(self):
|
||||
"""Dataset Metadata."""
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceImportable(abc.ABC):
|
||||
"""An abstract class that sets import_data_config."""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def import_data_config(self):
|
||||
"""Import Data Config."""
|
||||
pass
|
||||
|
||||
|
||||
class TabularDatasource(Datasource):
|
||||
"""Datasource for creating a tabular dataset for Vertex AI."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
bq_source: Optional[str] = None,
|
||||
):
|
||||
"""Creates a tabular datasource.
|
||||
|
||||
Args:
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
Cloud Storage URI of one or more files. Only CSV files are supported.
|
||||
The first line of the CSV file is used as the header.
|
||||
If there are multiple files, the header is the first line of
|
||||
the lexicographically first file; the other files must either
|
||||
contain the exact same header or omit the header.
|
||||
examples:
|
||||
str: "gs://bucket/file.csv"
|
||||
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
|
||||
bq_source (str):
|
||||
The URI of a BigQuery table.
|
||||
example:
|
||||
"bq://project.dataset.table_name"
|
||||
|
||||
Raises:
|
||||
ValueError: If source configuration is not valid.
|
||||
"""
|
||||
|
||||
dataset_metadata = None
|
||||
|
||||
if gcs_source and isinstance(gcs_source, str):
|
||||
gcs_source = [gcs_source]
|
||||
|
||||
if gcs_source and bq_source:
|
||||
raise ValueError("Only one of gcs_source or bq_source can be set.")
|
||||
|
||||
if not any([gcs_source, bq_source]):
|
||||
raise ValueError("One of gcs_source or bq_source must be set.")
|
||||
|
||||
if gcs_source:
|
||||
dataset_metadata = {"inputConfig": {"gcsSource": {"uri": gcs_source}}}
|
||||
elif bq_source:
|
||||
dataset_metadata = {"inputConfig": {"bigquerySource": {"uri": bq_source}}}
|
||||
|
||||
self._dataset_metadata = dataset_metadata
|
||||
|
||||
@property
|
||||
def dataset_metadata(self) -> Optional[Dict]:
|
||||
"""Dataset Metadata."""
|
||||
return self._dataset_metadata
|
||||
|
||||
|
||||
class NonTabularDatasource(Datasource):
|
||||
"""Datasource for creating an empty non-tabular dataset for Vertex AI."""
|
||||
|
||||
@property
|
||||
def dataset_metadata(self) -> Optional[Dict]:
|
||||
return None
|
||||
|
||||
|
||||
class NonTabularDatasourceImportable(NonTabularDatasource, DatasourceImportable):
|
||||
"""Datasource for creating a non-tabular dataset for Vertex AI and
|
||||
importing data to the dataset."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gcs_source: Union[str, Sequence[str]],
|
||||
import_schema_uri: str,
|
||||
data_item_labels: Optional[Dict] = None,
|
||||
):
|
||||
"""Creates a non-tabular datasource.
|
||||
|
||||
Args:
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
Required. The Google Cloud Storage location for the input content.
|
||||
Google Cloud Storage URI(-s) to the input file(s).
|
||||
|
||||
Examples:
|
||||
str: "gs://bucket/file.csv"
|
||||
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
|
||||
import_schema_uri (str):
|
||||
Required. Points to a YAML file stored on Google Cloud
|
||||
Storage describing the import format. Validation will be
|
||||
done against the schema. The schema is defined as an
|
||||
`OpenAPI 3.0.2 Schema
Object <https://tinyurl.com/y538mdwt>`__.
|
||||
data_item_labels (Dict):
|
||||
Labels that will be applied to newly imported DataItems. If
|
||||
an identical DataItem as one being imported already exists
|
||||
in the Dataset, then these labels will be appended to those
of the already existing one, and if a label with an identical
key was imported before, the old label value will be
|
||||
overwritten. If two DataItems are identical in the same
|
||||
import data operation, the labels will be combined and if
|
||||
key collision happens in this case, one of the values will
|
||||
be picked randomly. Two DataItems are considered identical
|
||||
if their content bytes are identical (e.g. image bytes or
|
||||
pdf bytes). These labels will be overridden by Annotation
|
||||
labels specified inside the index file referenced by
|
||||
``import_schema_uri``,
|
||||
e.g. jsonl file.
|
||||
"""
|
||||
super().__init__()
|
||||
self._gcs_source = [gcs_source] if isinstance(gcs_source, str) else gcs_source
|
||||
self._import_schema_uri = import_schema_uri
|
||||
self._data_item_labels = data_item_labels
|
||||
|
||||
@property
|
||||
def import_data_config(self) -> gca_dataset.ImportDataConfig:
|
||||
"""Import Data Config."""
|
||||
return gca_dataset.ImportDataConfig(
|
||||
gcs_source=gca_io.GcsSource(uris=self._gcs_source),
|
||||
import_schema_uri=self._import_schema_uri,
|
||||
data_item_labels=self._data_item_labels,
|
||||
)
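# A minimal illustrative sketch (not part of this module) of the
# ImportDataConfig proto produced by NonTabularDatasourceImportable; the GCS
# path, import schema URI, and label values below are hypothetical
# placeholders.
def _example_import_data_config() -> gca_dataset.ImportDataConfig:
    source = NonTabularDatasourceImportable(
        gcs_source="gs://example-bucket/image-import.jsonl",
        import_schema_uri="gs://example-bucket/schema/import_format.yaml",
        data_item_labels={"aiplatform.googleapis.com/ml_use": "training"},
    )
    config = source.import_data_config
    # config.gcs_source.uris == ["gs://example-bucket/image-import.jsonl"]
    # config.import_schema_uri == "gs://example-bucket/schema/import_format.yaml"
    return config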
|
||||
|
||||
|
||||
def create_datasource(
|
||||
metadata_schema_uri: str,
|
||||
import_schema_uri: Optional[str] = None,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
bq_source: Optional[str] = None,
|
||||
data_item_labels: Optional[Dict] = None,
|
||||
) -> Datasource:
|
||||
"""Creates a datasource
|
||||
Args:
|
||||
metadata_schema_uri (str):
|
||||
Required. Points to a YAML file stored on Google Cloud Storage
|
||||
describing additional information about the Dataset. The schema
|
||||
is defined as an OpenAPI 3.0.2 Schema Object. The schema files
|
||||
that can be used here are found in gs://google-cloud-
|
||||
aiplatform/schema/dataset/metadata/.
|
||||
import_schema_uri (str):
|
||||
Points to a YAML file stored on Google Cloud
|
||||
Storage describing the import format. Validation will be
|
||||
done against the schema. The schema is defined as an
|
||||
`OpenAPI 3.0.2 Schema
Object <https://tinyurl.com/y538mdwt>`__.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
The Google Cloud Storage location for the input content.
|
||||
Google Cloud Storage URI(-s) to the input file(s).
|
||||
|
||||
Examples:
|
||||
str: "gs://bucket/file.csv"
|
||||
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
|
||||
bq_source (str):
|
||||
BigQuery URI to the input table.
|
||||
example:
|
||||
"bq://project.dataset.table_name"
|
||||
data_item_labels (Dict):
|
||||
Labels that will be applied to newly imported DataItems. If
|
||||
an identical DataItem as one being imported already exists
|
||||
in the Dataset, then these labels will be appended to those
of the already existing one, and if a label with an identical
key was imported before, the old label value will be
|
||||
overwritten. If two DataItems are identical in the same
|
||||
import data operation, the labels will be combined and if
|
||||
key collision happens in this case, one of the values will
|
||||
be picked randomly. Two DataItems are considered identical
|
||||
if their content bytes are identical (e.g. image bytes or
|
||||
pdf bytes). These labels will be overridden by Annotation
|
||||
labels specified inside the index file referenced by
|
||||
``import_schema_uri``,
|
||||
e.g. jsonl file.
|
||||
|
||||
Returns:
|
||||
datasource (Datasource)
|
||||
|
||||
Raises:
|
||||
ValueError: When any of the following occurs:
- import_schema_uri is specified when creating a TabularDatasource
- either import_schema_uri or gcs_source is missing when creating a NonTabularDatasourceImportable
|
||||
"""
|
||||
|
||||
if metadata_schema_uri == schema.dataset.metadata.tabular:
|
||||
if import_schema_uri:
|
||||
raise ValueError("tabular dataset does not support data import.")
|
||||
return TabularDatasource(gcs_source, bq_source)
|
||||
|
||||
if metadata_schema_uri == schema.dataset.metadata.time_series:
|
||||
if import_schema_uri:
|
||||
raise ValueError("time series dataset does not support data import.")
|
||||
return TabularDatasource(gcs_source, bq_source)
|
||||
|
||||
if not import_schema_uri and not gcs_source:
|
||||
return NonTabularDatasource()
|
||||
elif import_schema_uri and gcs_source:
|
||||
return NonTabularDatasourceImportable(
|
||||
gcs_source, import_schema_uri, data_item_labels
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"nontabular dataset requires both import_schema_uri and gcs_source for data import."
|
||||
)
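# A minimal illustrative sketch (not part of this module) of which Datasource
# subclass create_datasource returns for common argument combinations. The
# table, bucket, and import schema values below are hypothetical placeholders,
# and schema.dataset.metadata.image is assumed to be available alongside the
# tabular and time_series entries used above.
def _example_create_datasource_dispatch():
    # Tabular metadata schema with a BigQuery source -> TabularDatasource.
    tabular = create_datasource(
        metadata_schema_uri=schema.dataset.metadata.tabular,
        bq_source="bq://example-project.example_dataset.example_table",
    )

    # Non-tabular metadata schema with no source -> empty NonTabularDatasource.
    empty = create_datasource(metadata_schema_uri=schema.dataset.metadata.image)

    # Non-tabular metadata schema with both an import schema and a GCS source
    # -> NonTabularDatasourceImportable.
    importable = create_datasource(
        metadata_schema_uri=schema.dataset.metadata.image,
        import_schema_uri="gs://example-bucket/schema/import_format.yaml",
        gcs_source="gs://example-bucket/image-import.jsonl",
    )
    return tabular, empty, importable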
|
||||
@@ -0,0 +1,261 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import csv
|
||||
import logging
|
||||
from typing import List, Optional, Set, TYPE_CHECKING
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud import storage
|
||||
|
||||
from google.cloud.aiplatform import utils
|
||||
from google.cloud.aiplatform import datasets
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.cloud import bigquery
|
||||
|
||||
|
||||
class _ColumnNamesDataset(datasets._Dataset):
|
||||
@property
|
||||
def column_names(self) -> List[str]:
|
||||
"""Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
|
||||
Google BigQuery source.
|
||||
|
||||
Returns:
|
||||
List[str]
|
||||
A list of column names
|
||||
|
||||
Raises:
|
||||
RuntimeError: When no valid source is found.
|
||||
"""
|
||||
|
||||
self._assert_gca_resource_is_available()
|
||||
|
||||
metadata = self._gca_resource.metadata
|
||||
|
||||
if metadata is None:
|
||||
raise RuntimeError("No metadata found for dataset")
|
||||
|
||||
input_config = metadata.get("inputConfig")
|
||||
|
||||
if input_config is None:
|
||||
raise RuntimeError("No inputConfig found for dataset")
|
||||
|
||||
gcs_source = input_config.get("gcsSource")
|
||||
bq_source = input_config.get("bigquerySource")
|
||||
|
||||
if gcs_source:
|
||||
gcs_source_uris = gcs_source.get("uri")
|
||||
|
||||
if gcs_source_uris and len(gcs_source_uris) > 0:
|
||||
# Lexicographically sort the files
|
||||
gcs_source_uris.sort()
|
||||
|
||||
# Get the first file in sorted list
|
||||
# TODO(b/193044977): Return as Set instead of List
|
||||
return list(
|
||||
self._retrieve_gcs_source_columns(
|
||||
project=self.project,
|
||||
gcs_csv_file_path=gcs_source_uris[0],
|
||||
credentials=self.credentials,
|
||||
)
|
||||
)
|
||||
elif bq_source:
|
||||
bq_table_uri = bq_source.get("uri")
|
||||
if bq_table_uri:
|
||||
# TODO(b/193044977): Return as Set instead of List
|
||||
return list(
|
||||
self._retrieve_bq_source_columns(
|
||||
project=self.project,
|
||||
bq_table_uri=bq_table_uri,
|
||||
credentials=self.credentials,
|
||||
)
|
||||
)
|
||||
|
||||
raise RuntimeError("No valid CSV or BigQuery datasource found.")
|
||||
|
||||
@staticmethod
|
||||
def _retrieve_gcs_source_columns(
|
||||
project: str,
|
||||
gcs_csv_file_path: str,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> Set[str]:
|
||||
"""Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
|
||||
|
||||
Example Usage:
|
||||
|
||||
column_names = _retrieve_gcs_source_columns(
|
||||
"project_id",
|
||||
"gs://example-bucket/path/to/csv_file"
|
||||
)
|
||||
|
||||
# column_names = {"column_1", "column_2"}
|
||||
|
||||
Args:
|
||||
project (str):
|
||||
Required. Project to initiate the Google Cloud Storage client with.
|
||||
gcs_csv_file_path (str):
|
||||
Required. A full path to a CSV file stored on Google Cloud Storage.
|
||||
Must include "gs://" prefix.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Credentials to use with the GCS client.
|
||||
Returns:
|
||||
Set[str]
|
||||
A set of column names in the CSV file.
|
||||
|
||||
Raises:
|
||||
RuntimeError: When the retrieved CSV file is invalid.
|
||||
"""
|
||||
|
||||
gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
|
||||
gcs_csv_file_path
|
||||
)
|
||||
client = storage.Client(project=project, credentials=credentials)
|
||||
bucket = client.bucket(gcs_bucket)
|
||||
blob = bucket.blob(gcs_blob)
|
||||
|
||||
# Incrementally download the CSV file until the header is retrieved
|
||||
first_new_line_index = -1
|
||||
start_index = 0
|
||||
increment = 1000
|
||||
line = ""
|
||||
|
||||
try:
|
||||
logger = logging.getLogger("google.resumable_media._helpers")
|
||||
logging_warning_filter = utils.LoggingFilter(logging.INFO)
|
||||
logger.addFilter(logging_warning_filter)
|
||||
|
||||
while first_new_line_index == -1:
|
||||
line += blob.download_as_bytes(
|
||||
start=start_index, end=start_index + increment - 1
|
||||
).decode("utf-8")
|
||||
|
||||
first_new_line_index = line.find("\n")
|
||||
start_index += increment
|
||||
|
||||
header_line = line[:first_new_line_index]
|
||||
|
||||
# Split to make it an iterable
|
||||
header_line = header_line.split("\n")[:1]
|
||||
|
||||
csv_reader = csv.reader(header_line, delimiter=",")
|
||||
except (ValueError, RuntimeError) as err:
|
||||
raise RuntimeError(
|
||||
"There was a problem extracting the headers from the CSV file at '{}': {}".format(
|
||||
gcs_csv_file_path, err
|
||||
)
|
||||
) from err
|
||||
finally:
|
||||
logger.removeFilter(logging_warning_filter)
|
||||
|
||||
return set(next(csv_reader))
|
||||
|
||||
@staticmethod
|
||||
def _get_bq_schema_field_names_recursively(
|
||||
schema_field: "bigquery.SchemaField",
|
||||
) -> Set[str]:
|
||||
"""Retrieve the name for a schema field along with ancestor fields.
|
||||
Nested schema fields are flattened and concatenated with a ".".
|
||||
Schema fields with child fields are not included, but the children are.
|
||||
|
||||
Args:
|
||||
project (str):
|
||||
Required. Project to initiate the BigQuery client with.
|
||||
bq_table_uri (str):
|
||||
Required. A URI to a BigQuery table.
|
||||
Can include "bq://" prefix but not required.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Credentials to use with BQ Client.
|
||||
|
||||
Returns:
|
||||
Set[str]
|
||||
A set of columns names in the BigQuery table.
|
||||
"""
|
||||
|
||||
ancestor_names = {
|
||||
nested_field_name
|
||||
for field in schema_field.fields
|
||||
for nested_field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively(
|
||||
field
|
||||
)
|
||||
}
|
||||
|
||||
# Only return "leaf nodes", basically any field that doesn't have children
|
||||
if len(ancestor_names) == 0:
|
||||
return {schema_field.name}
|
||||
else:
|
||||
return {f"{schema_field.name}.{name}" for name in ancestor_names}
|
||||
|
||||
@staticmethod
|
||||
def _retrieve_bq_source_columns(
|
||||
project: str,
|
||||
bq_table_uri: str,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> Set[str]:
|
||||
"""Retrieve the column names from a table on Google BigQuery
|
||||
Nested schema fields are flattened and concatenated with a ".".
|
||||
Schema fields with child fields are not included, but the children are.
|
||||
|
||||
Example Usage:
|
||||
|
||||
column_names = _retrieve_bq_source_columns(
|
||||
"project_id",
|
||||
"bq://project_id.dataset.table"
|
||||
)
|
||||
|
||||
# column_names = {"column_1", "column_2", "column_3.nested_field"}
|
||||
|
||||
Args:
|
||||
project (str):
|
||||
Required. Project to initiate the BigQuery client with.
|
||||
bq_table_uri (str):
|
||||
Required. A URI to a BigQuery table.
|
||||
Can include "bq://" prefix but not required.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Credentials to use with BQ Client.
|
||||
|
||||
Returns:
|
||||
Set[str]
|
||||
A set of column names in the BigQuery table.
|
||||
"""
|
||||
|
||||
# Remove bq:// prefix
|
||||
prefix = "bq://"
|
||||
if bq_table_uri.startswith(prefix):
|
||||
bq_table_uri = bq_table_uri[len(prefix) :]
|
||||
|
||||
# The colon-based "project:dataset.table" format is no longer supported:
|
||||
# Invalid dataset ID "bigquery-public-data:chicago_taxi_trips".
|
||||
# Dataset IDs must be alphanumeric (plus underscores and dashes) and must be at most 1024 characters long.
|
||||
# Using dot-based "project.dataset.table" format instead.
|
||||
bq_table_uri = bq_table_uri.replace(":", ".")
|
||||
|
||||
# Loading bigquery lazily to avoid auto-loading it when importing vertexai
|
||||
from google.cloud import bigquery # pylint: disable=g-import-not-at-top
|
||||
|
||||
client = bigquery.Client(project=project, credentials=credentials)
|
||||
table = client.get_table(bq_table_uri)
|
||||
schema = table.schema
|
||||
|
||||
return {
|
||||
field_name
|
||||
for field in schema
|
||||
for field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively(
|
||||
field
|
||||
)
|
||||
}
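# A minimal illustrative sketch (not part of this module) of what the
# recursive schema flattening above produces for a nested BigQuery field; the
# field names below are hypothetical placeholders.
def _example_flatten_nested_bq_field():
    from google.cloud import bigquery  # pylint: disable=g-import-not-at-top

    address = bigquery.SchemaField(
        "address",
        "RECORD",
        fields=[
            bigquery.SchemaField("city", "STRING"),
            bigquery.SchemaField("zip_code", "STRING"),
        ],
    )
    names = _ColumnNamesDataset._get_bq_schema_field_names_recursively(address)
    # Only leaf fields are returned, prefixed with their parent field name:
    # names == {"address.city", "address.zip_code"}
    return names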
|
||||
@@ -0,0 +1,927 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from google.api_core import operation
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
from google.cloud.aiplatform.compat.services import dataset_service_client
|
||||
from google.cloud.aiplatform.compat.types import (
|
||||
dataset as gca_dataset,
|
||||
dataset_service as gca_dataset_service,
|
||||
encryption_spec as gca_encryption_spec,
|
||||
io as gca_io,
|
||||
)
|
||||
from google.cloud.aiplatform.datasets import _datasources
|
||||
from google.protobuf import field_mask_pb2
|
||||
from google.protobuf import json_format
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
|
||||
|
||||
class _Dataset(base.VertexAiResourceNounWithFutureManager):
|
||||
"""Managed dataset resource for Vertex AI."""
|
||||
|
||||
client_class = utils.DatasetClientWithOverride
|
||||
_resource_noun = "datasets"
|
||||
_getter_method = "get_dataset"
|
||||
_list_method = "list_datasets"
|
||||
_delete_method = "delete_dataset"
|
||||
_parse_resource_name_method = "parse_dataset_path"
|
||||
_format_resource_name_method = "dataset_path"
|
||||
|
||||
_supported_metadata_schema_uris: Tuple[str] = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_name: str,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
):
|
||||
"""Retrieves an existing managed dataset given a dataset name or ID.
|
||||
|
||||
Args:
|
||||
dataset_name (str):
|
||||
Required. A fully-qualified dataset resource name or dataset ID.
|
||||
Example: "projects/123/locations/us-central1/datasets/456" or
|
||||
"456" when project and location are initialized or passed.
|
||||
project (str):
|
||||
Optional project to retrieve dataset from. If not set, project
|
||||
set in aiplatform.init will be used.
|
||||
location (str):
|
||||
Optional location to retrieve dataset from. If not set, location
|
||||
set in aiplatform.init will be used.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Custom credentials to use to retrieve this Dataset. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
resource_name=dataset_name,
|
||||
)
|
||||
self._gca_resource = self._get_gca_resource(resource_name=dataset_name)
|
||||
self._validate_metadata_schema_uri()
|
||||
|
||||
@property
|
||||
def metadata_schema_uri(self) -> str:
|
||||
"""The metadata schema uri of this dataset resource."""
|
||||
self._assert_gca_resource_is_available()
|
||||
return self._gca_resource.metadata_schema_uri
|
||||
|
||||
def _validate_metadata_schema_uri(self) -> None:
|
||||
"""Validate the metadata_schema_uri of retrieved dataset resource.
|
||||
|
||||
Raises:
|
||||
ValueError: If the dataset type of the retrieved dataset resource is
|
||||
not supported by the class.
|
||||
"""
|
||||
if self._supported_metadata_schema_uris and (
|
||||
self.metadata_schema_uri not in self._supported_metadata_schema_uris
|
||||
):
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} class can not be used to retrieve "
|
||||
f"dataset resource {self.resource_name}, check the dataset type"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
# TODO(b/223262536): Make the display_name parameter optional in the next major release
|
||||
display_name: str,
|
||||
metadata_schema_uri: str,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
bq_source: Optional[str] = None,
|
||||
import_schema_uri: Optional[str] = None,
|
||||
data_item_labels: Optional[Dict] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> "_Dataset":
|
||||
"""Creates a new dataset and optionally imports data into dataset when
|
||||
source and import_schema_uri are passed.
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Required. The user-defined name of the Dataset.
|
||||
The name can be up to 128 characters long and can consist
|
||||
of any UTF-8 characters.
|
||||
metadata_schema_uri (str):
|
||||
Required. Points to a YAML file stored on Google Cloud Storage
|
||||
describing additional information about the Dataset. The schema
|
||||
is defined as an OpenAPI 3.0.2 Schema Object. The schema files
|
||||
that can be used here are found in gs://google-cloud-
|
||||
aiplatform/schema/dataset/metadata/.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
Google Cloud Storage URI(-s) to the
|
||||
input file(s). May contain wildcards. For more
|
||||
information on wildcards, see
|
||||
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
|
||||
examples:
|
||||
str: "gs://bucket/file.csv"
|
||||
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
|
||||
bq_source (str):
|
||||
BigQuery URI to the input table.
|
||||
example:
|
||||
"bq://project.dataset.table_name"
|
||||
import_schema_uri (str):
|
||||
Points to a YAML file stored on Google Cloud
|
||||
Storage describing the import format. Validation will be
|
||||
done against the schema. The schema is defined as an
|
||||
`OpenAPI 3.0.2 Schema
|
||||
Object <https://tinyurl.com/y538mdwt>`__.
|
||||
data_item_labels (Dict):
|
||||
Labels that will be applied to newly imported DataItems. If
|
||||
an identical DataItem as one being imported already exists
|
||||
in the Dataset, then these labels will be appended to those
of the already existing one, and if a label with an identical
key was imported before, the old label value will be
|
||||
overwritten. If two DataItems are identical in the same
|
||||
import data operation, the labels will be combined and if
|
||||
key collision happens in this case, one of the values will
|
||||
be picked randomly. Two DataItems are considered identical
|
||||
if their content bytes are identical (e.g. image bytes or
|
||||
pdf bytes). These labels will be overridden by Annotation
|
||||
labels specified inside index file referenced by
|
||||
``import_schema_uri``,
|
||||
e.g. jsonl file.
|
||||
This arg is not for specifying the annotation name or the
|
||||
training target of your data, but for some global labels of
|
||||
the dataset. E.g.,
|
||||
'data_item_labels={"aiplatform.googleapis.com/ml_use":"training"}'
|
||||
specifies that all the uploaded data are used for training.
|
||||
project (str):
|
||||
Project to upload this dataset to. Overrides project set in
|
||||
aiplatform.init.
|
||||
location (str):
|
||||
Location to upload this dataset to. Overrides location set in
|
||||
aiplatform.init.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Custom credentials to use to upload this dataset. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Strings which should be sent along with the request as metadata.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your datasets.
|
||||
Label keys and values can be no longer than 64 characters
|
||||
(Unicode codepoints), can only contain lowercase letters, numeric
|
||||
characters, underscores and dashes. International characters are allowed.
|
||||
No more than 64 user labels can be associated with one Dataset
|
||||
(System labels are excluded).
|
||||
See https://goo.gl/xmQnxf for more information and examples of labels.
|
||||
System reserved label keys are prefixed with "aiplatform.googleapis.com/"
|
||||
and are immutable.
|
||||
encryption_spec_key_name (Optional[str]):
|
||||
Optional. The Cloud KMS resource identifier of the customer
|
||||
managed encryption key used to protect the dataset. Has the
|
||||
form:
|
||||
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
|
||||
|
||||
Overrides encryption_spec_key_name set in aiplatform.init.
|
||||
sync (bool):
|
||||
Whether to execute this method synchronously. If False, this method
|
||||
will be executed in concurrent Future and any downstream object will
|
||||
be immediately returned and synced when the Future has completed.
|
||||
create_request_timeout (float):
|
||||
Optional. The timeout for the create request in seconds.
|
||||
|
||||
Returns:
|
||||
dataset (Dataset):
|
||||
Instantiated representation of the managed dataset resource.
|
||||
"""
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
utils.validate_display_name(display_name)
|
||||
if labels:
|
||||
utils.validate_labels(labels)
|
||||
|
||||
api_client = cls._instantiate_client(location=location, credentials=credentials)
|
||||
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
import_schema_uri=import_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
bq_source=bq_source,
|
||||
data_item_labels=data_item_labels,
|
||||
)
|
||||
|
||||
return cls._create_and_import(
|
||||
api_client=api_client,
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
),
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
project=project or initializer.global_config.project,
|
||||
location=location or initializer.global_config.location,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=initializer.global_config.get_encryption_spec(
|
||||
encryption_spec_key_name=encryption_spec_key_name
|
||||
),
|
||||
sync=sync,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@base.optional_sync()
|
||||
def _create_and_import(
|
||||
cls,
|
||||
api_client: dataset_service_client.DatasetServiceClient,
|
||||
parent: str,
|
||||
display_name: str,
|
||||
metadata_schema_uri: str,
|
||||
datasource: _datasources.Datasource,
|
||||
project: str,
|
||||
location: str,
|
||||
credentials: Optional[auth_credentials.Credentials],
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
import_request_timeout: Optional[float] = None,
|
||||
) -> "_Dataset":
|
||||
"""Creates a new dataset and optionally imports data into dataset when
|
||||
source and import_schema_uri are passed.
|
||||
|
||||
Args:
|
||||
api_client (dataset_service_client.DatasetServiceClient):
|
||||
An instance of DatasetServiceClient with the correct api_endpoint
|
||||
already set based on user's preferences.
|
||||
parent (str):
|
||||
Required. Also known as common location path, that usually contains the
|
||||
project and location that the user provided to the upstream method.
|
||||
Example: "projects/my-prj/locations/us-central1"
|
||||
display_name (str):
|
||||
Required. The user-defined name of the Dataset.
|
||||
The name can be up to 128 characters long and can consist
|
||||
of any UTF-8 characters.
|
||||
metadata_schema_uri (str):
|
||||
Required. Points to a YAML file stored on Google Cloud Storage
|
||||
describing additional information about the Dataset. The schema
|
||||
is defined as an OpenAPI 3.0.2 Schema Object. The schema files
|
||||
that can be used here are found in gs://google-cloud-
|
||||
aiplatform/schema/dataset/metadata/.
|
||||
datasource (_datasources.Datasource):
|
||||
Required. Datasource for creating a dataset for Vertex AI.
|
||||
project (str):
|
||||
Required. Project to upload this model to. Overrides project set in
|
||||
aiplatform.init.
|
||||
location (str):
|
||||
Required. Location to upload this model to. Overrides location set in
|
||||
aiplatform.init.
|
||||
credentials (Optional[auth_credentials.Credentials]):
|
||||
Custom credentials to use to upload this model. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Strings which should be sent along with the request as metadata.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your datasets.
|
||||
Label keys and values can be no longer than 64 characters
|
||||
(Unicode codepoints), can only contain lowercase letters, numeric
|
||||
characters, underscores and dashes. International characters are allowed.
|
||||
No more than 64 user labels can be associated with one Dataset
|
||||
(System labels are excluded).
|
||||
See https://goo.gl/xmQnxf for more information and examples of labels.
|
||||
System reserved label keys are prefixed with "aiplatform.googleapis.com/"
|
||||
and are immutable.
|
||||
encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]):
|
||||
Optional. The Cloud KMS customer managed encryption key used to protect the dataset.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
|
||||
sync (bool):
|
||||
Whether to execute this method synchronously. If False, this method
|
||||
will be executed in concurrent Future and any downstream object will
|
||||
be immediately returned and synced when the Future has completed.
|
||||
create_request_timeout (float):
|
||||
Optional. The timeout for the create request in seconds.
|
||||
import_request_timeout (float):
|
||||
Optional. The timeout for the import request in seconds.
|
||||
|
||||
Returns:
|
||||
dataset (Dataset):
|
||||
Instantiated representation of the managed dataset resource.
|
||||
"""
|
||||
|
||||
create_dataset_lro = cls._create(
|
||||
api_client=api_client,
|
||||
parent=parent,
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=encryption_spec,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
|
||||
|
||||
_LOGGER.log_create_with_lro(cls, create_dataset_lro)
|
||||
|
||||
created_dataset = create_dataset_lro.result(timeout=None)
|
||||
|
||||
_LOGGER.log_create_complete(cls, created_dataset, "ds")
|
||||
|
||||
dataset_obj = cls(
|
||||
dataset_name=created_dataset.name,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
# Import if import datasource is DatasourceImportable
|
||||
if isinstance(datasource, _datasources.DatasourceImportable):
|
||||
dataset_obj._import_and_wait(
|
||||
datasource, import_request_timeout=import_request_timeout
|
||||
)
|
||||
|
||||
return dataset_obj
|
||||
|
||||
def _import_and_wait(
|
||||
self,
|
||||
datasource,
|
||||
import_request_timeout: Optional[float] = None,
|
||||
):
|
||||
_LOGGER.log_action_start_against_resource(
|
||||
"Importing",
|
||||
"data",
|
||||
self,
|
||||
)
|
||||
|
||||
import_lro = self._import(
|
||||
datasource=datasource, import_request_timeout=import_request_timeout
|
||||
)
|
||||
|
||||
_LOGGER.log_action_started_against_resource_with_lro(
|
||||
"Import", "data", self.__class__, import_lro
|
||||
)
|
||||
|
||||
import_lro.result(timeout=None)
|
||||
|
||||
_LOGGER.log_action_completed_against_resource("data", "imported", self)
|
||||
|
||||
@classmethod
|
||||
def _create(
|
||||
cls,
|
||||
api_client: dataset_service_client.DatasetServiceClient,
|
||||
parent: str,
|
||||
display_name: str,
|
||||
metadata_schema_uri: str,
|
||||
datasource: _datasources.Datasource,
|
||||
request_metadata: Sequence[Tuple[str, str]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> operation.Operation:
|
||||
"""Creates a new managed dataset by directly calling API client.
|
||||
|
||||
Args:
|
||||
api_client (dataset_service_client.DatasetServiceClient):
|
||||
An instance of DatasetServiceClient with the correct api_endpoint
|
||||
already set based on user's preferences.
|
||||
parent (str):
|
||||
Required. Also known as common location path, that usually contains the
|
||||
project and location that the user provided to the upstream method.
|
||||
Example: "projects/my-prj/locations/us-central1"
|
||||
display_name (str):
|
||||
Required. The user-defined name of the Dataset.
|
||||
The name can be up to 128 characters long and can consist
|
||||
of any UTF-8 characters.
|
||||
metadata_schema_uri (str):
|
||||
Required. Points to a YAML file stored on Google Cloud Storage
|
||||
describing additional information about the Dataset. The schema
|
||||
is defined as an OpenAPI 3.0.2 Schema Object. The schema files
|
||||
that can be used here are found in gs://google-cloud-
|
||||
aiplatform/schema/dataset/metadata/.
|
||||
datasource (_datasources.Datasource):
|
||||
Required. Datasource for creating a dataset for Vertex AI.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Strings which should be sent along with the create_dataset
|
||||
request as metadata. Usually to specify special dataset config.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your datasets.
|
||||
Label keys and values can be no longer than 64 characters
|
||||
(Unicode codepoints), can only contain lowercase letters, numeric
|
||||
characters, underscores and dashes. International characters are allowed.
|
||||
No more than 64 user labels can be associated with one Dataset
|
||||
(System labels are excluded).
|
||||
See https://goo.gl/xmQnxf for more information and examples of labels.
|
||||
System reserved label keys are prefixed with "aiplatform.googleapis.com/"
|
||||
and are immutable.
|
||||
encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]):
|
||||
Optional. The Cloud KMS customer managed encryption key used to protect the dataset.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
|
||||
create_request_timeout (float):
|
||||
Optional. The timeout for the create request in seconds.
|
||||
Returns:
|
||||
operation (Operation):
|
||||
An object representing a long-running operation.
|
||||
"""
|
||||
|
||||
gapic_dataset = gca_dataset.Dataset(
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
metadata=datasource.dataset_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=encryption_spec,
|
||||
)
|
||||
|
||||
return api_client.create_dataset(
|
||||
parent=parent,
|
||||
dataset=gapic_dataset,
|
||||
metadata=request_metadata,
|
||||
timeout=create_request_timeout,
|
||||
)
|
||||
|
||||
def _import(
|
||||
self,
|
||||
datasource: _datasources.DatasourceImportable,
|
||||
import_request_timeout: Optional[float] = None,
|
||||
) -> operation.Operation:
|
||||
"""Imports data into managed dataset by directly calling API client.
|
||||
|
||||
Args:
|
||||
datasource (_datasources.DatasourceImportable):
|
||||
Required. Datasource for importing data to an existing dataset for Vertex AI.
|
||||
import_request_timeout (float):
|
||||
Optional. The timeout for the import request in seconds.
|
||||
|
||||
Returns:
|
||||
operation (Operation):
|
||||
An object representing a long-running operation.
|
||||
"""
|
||||
return self.api_client.import_data(
|
||||
name=self.resource_name,
|
||||
import_configs=[datasource.import_data_config],
|
||||
timeout=import_request_timeout,
|
||||
)
|
||||
|
||||
@base.optional_sync(return_input_arg="self")
|
||||
def import_data(
|
||||
self,
|
||||
gcs_source: Union[str, Sequence[str]],
|
||||
import_schema_uri: str,
|
||||
data_item_labels: Optional[Dict] = None,
|
||||
sync: bool = True,
|
||||
import_request_timeout: Optional[float] = None,
|
||||
) -> "_Dataset":
|
||||
"""Upload data to existing managed dataset.
|
||||
|
||||
Args:
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
Required. Google Cloud Storage URI(-s) to the
|
||||
input file(s). May contain wildcards. For more
|
||||
information on wildcards, see
|
||||
https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
|
||||
examples:
|
||||
str: "gs://bucket/file.csv"
|
||||
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
|
||||
import_schema_uri (str):
|
||||
Required. Points to a YAML file stored on Google Cloud
|
||||
Storage describing the import format. Validation will be
|
||||
done against the schema. The schema is defined as an
|
||||
`OpenAPI 3.0.2 Schema
|
||||
Object <https://tinyurl.com/y538mdwt>`__.
|
||||
data_item_labels (Dict):
|
||||
Labels that will be applied to newly imported DataItems. If
|
||||
an identical DataItem as one being imported already exists
|
||||
in the Dataset, then these labels will be appended to those
of the already existing one, and if a label with an identical
key was imported before, the old label value will be
|
||||
overwritten. If two DataItems are identical in the same
|
||||
import data operation, the labels will be combined and if
|
||||
key collision happens in this case, one of the values will
|
||||
be picked randomly. Two DataItems are considered identical
|
||||
if their content bytes are identical (e.g. image bytes or
|
||||
pdf bytes). These labels will be overridden by Annotation
|
||||
labels specified inside index file referenced by
|
||||
``import_schema_uri``,
|
||||
e.g. jsonl file.
|
||||
This arg is not for specifying the annotation name or the
|
||||
training target of your data, but for some global labels of
|
||||
the dataset. E.g.,
|
||||
'data_item_labels={"aiplatform.googleapis.com/ml_use":"training"}'
|
||||
specifies that all the uploaded data are used for training.
|
||||
sync (bool):
|
||||
Whether to execute this method synchronously. If False, this method
|
||||
will be executed in concurrent Future and any downstream object will
|
||||
be immediately returned and synced when the Future has completed.
|
||||
import_request_timeout (float):
|
||||
Optional. The timeout for the import request in seconds.
|
||||
|
||||
Returns:
|
||||
dataset (Dataset):
|
||||
Instantiated representation of the managed dataset resource.
|
||||
"""
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=self.metadata_schema_uri,
|
||||
import_schema_uri=import_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
data_item_labels=data_item_labels,
|
||||
)
|
||||
|
||||
self._import_and_wait(
|
||||
datasource=datasource, import_request_timeout=import_request_timeout
|
||||
)
|
||||
return self
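    # Illustrative usage sketch (not part of the class); the GCS path, import
    # schema URI and label values below are hypothetical placeholders:
    #
    #   ds = ds.import_data(
    #       gcs_source="gs://example-bucket/image-import.jsonl",
    #       import_schema_uri="gs://example-bucket/schema/import_format.yaml",
    #       data_item_labels={"aiplatform.googleapis.com/ml_use": "training"},
    #   )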
|
||||
|
||||
def _validate_and_convert_export_split(
|
||||
self,
|
||||
split: Union[Dict[str, str], Dict[str, float]],
|
||||
) -> Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]:
|
||||
"""
|
||||
Validates the split for data export. Valid splits are dicts
|
||||
encoding the contents of proto messages ExportFilterSplit or
|
||||
ExportFractionSplit. If the split is valid, this function returns
|
||||
the corresponding converted proto message.
|
||||
|
||||
Args:
split (Union[Dict[str, str], Dict[str, float]]):
The instructions for how the export data should be split between the
training, validation and test sets.
|
||||
"""
|
||||
if len(split) != 3:
|
||||
raise ValueError(
|
||||
"The provided split for data export does not provide enough"
|
||||
"information. It must have three fields, mapping to training,"
|
||||
"validation and test splits respectively."
|
||||
)
|
||||
|
||||
if not ("training_filter" in split or "training_fraction" in split):
|
||||
raise ValueError(
|
||||
"The provided filter for data export does not provide enough"
|
||||
"information. It must have three fields, mapping to training,"
|
||||
"validation and test respectively."
|
||||
)
|
||||
|
||||
if "training_filter" in split:
|
||||
if (
|
||||
"validation_filter" in split
|
||||
and "test_filter" in split
|
||||
and isinstance(split["training_filter"], str)
|
||||
and isinstance(split["validation_filter"], str)
|
||||
and isinstance(split["test_filter"], str)
|
||||
):
|
||||
return gca_dataset.ExportFilterSplit(
|
||||
training_filter=split["training_filter"],
|
||||
validation_filter=split["validation_filter"],
|
||||
test_filter=split["test_filter"],
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The provided ExportFilterSplit does not contain all"
|
||||
"three required fields: training_filter, "
|
||||
"validation_filter and test_filter."
|
||||
)
|
||||
else:
|
||||
if (
|
||||
"validation_fraction" in split
|
||||
and "test_fraction" in split
|
||||
and isinstance(split["training_fraction"], float)
|
||||
and isinstance(split["validation_fraction"], float)
|
||||
and isinstance(split["test_fraction"], float)
|
||||
):
|
||||
return gca_dataset.ExportFractionSplit(
|
||||
training_fraction=split["training_fraction"],
|
||||
validation_fraction=split["validation_fraction"],
|
||||
test_fraction=split["test_fraction"],
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The provided ExportFractionSplit does not contain all"
|
||||
"three required fields: training_fraction, "
|
||||
"validation_fraction and test_fraction."
|
||||
)
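    # Illustrative sketch (not part of the class) of the two accepted split
    # shapes and the protos they convert to; the filter expressions and
    # fractions mirror the examples in export_data_for_custom_training below:
    #
    #   {"training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
    #    "validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
    #    "test_filter": "labels.aiplatform.googleapis.com/ml_use=test"}
    #       -> gca_dataset.ExportFilterSplit
    #
    #   {"training_fraction": 0.7, "validation_fraction": 0.2, "test_fraction": 0.1}
    #       -> gca_dataset.ExportFractionSplit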
|
||||
|
||||
def _get_completed_export_data_operation(
|
||||
self,
|
||||
output_dir: str,
|
||||
export_use: Optional[gca_dataset.ExportDataConfig.ExportUse] = None,
|
||||
annotation_filter: Optional[str] = None,
|
||||
saved_query_id: Optional[str] = None,
|
||||
annotation_schema_uri: Optional[str] = None,
|
||||
split: Optional[
|
||||
Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]
|
||||
] = None,
|
||||
) -> gca_dataset_service.ExportDataResponse:
|
||||
self.wait()
|
||||
|
||||
# TODO(b/171311614): Add support for BigQuery export path
|
||||
export_data_config = gca_dataset.ExportDataConfig(
|
||||
gcs_destination=gca_io.GcsDestination(output_uri_prefix=output_dir)
|
||||
)
|
||||
if export_use is not None:
|
||||
export_data_config.export_use = export_use
|
||||
if annotation_filter is not None:
|
||||
export_data_config.annotation_filter = annotation_filter
|
||||
if saved_query_id is not None:
|
||||
export_data_config.saved_query_id = saved_query_id
|
||||
if annotation_schema_uri is not None:
|
||||
export_data_config.annotation_schema_uri = annotation_schema_uri
|
||||
if split is not None:
|
||||
if isinstance(split, gca_dataset.ExportFilterSplit):
|
||||
export_data_config.filter_split = split
|
||||
elif isinstance(split, gca_dataset.ExportFractionSplit):
|
||||
export_data_config.fraction_split = split
|
||||
|
||||
_LOGGER.log_action_start_against_resource("Exporting", "data", self)
|
||||
|
||||
export_lro = self.api_client.export_data(
|
||||
name=self.resource_name, export_config=export_data_config
|
||||
)
|
||||
|
||||
_LOGGER.log_action_started_against_resource_with_lro(
|
||||
"Export", "data", self.__class__, export_lro
|
||||
)
|
||||
|
||||
export_data_response = export_lro.result()
|
||||
|
||||
_LOGGER.log_action_completed_against_resource("data", "export", self)
|
||||
|
||||
return export_data_response
|
||||
|
||||
# TODO(b/174751568) add optional sync support
|
||||
def export_data(self, output_dir: str) -> Sequence[str]:
|
||||
"""Exports data to output dir to GCS.
|
||||
|
||||
Args:
|
||||
output_dir (str):
|
||||
Required. The Google Cloud Storage location where the output is to
|
||||
be written to. In the given directory a new directory will be
|
||||
created with name:
|
||||
``export-data-<dataset-display-name>-<timestamp-of-export-call>``
|
||||
where timestamp is in YYYYMMDDHHMMSS format. All export
|
||||
output will be written into that directory. Inside that
|
||||
directory, annotations with the same schema will be grouped
|
||||
into sub directories which are named with the corresponding
|
||||
annotations' schema title. Inside these sub directories, a
|
||||
schema.yaml will be created to describe the output format.
|
||||
|
||||
If the uri doesn't end with '/', a '/' will be automatically
|
||||
appended. The directory is created if it doesn't exist.
|
||||
|
||||
Returns:
|
||||
exported_files (Sequence[str]):
|
||||
All of the files that are exported in this export operation.
|
||||
"""
|
||||
return self._get_completed_export_data_operation(output_dir).exported_files
|
||||
|
||||
def export_data_for_custom_training(
|
||||
self,
|
||||
output_dir: str,
|
||||
annotation_filter: Optional[str] = None,
|
||||
saved_query_id: Optional[str] = None,
|
||||
annotation_schema_uri: Optional[str] = None,
|
||||
split: Optional[Union[Dict[str, str], Dict[str, float]]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Exports data to output dir to GCS for custom training use case.
|
||||
|
||||
Example annotation_schema_uri (image classification):
|
||||
gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml
|
||||
|
||||
Example split (filter split):
|
||||
{
|
||||
"training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
|
||||
"validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
|
||||
"test_filter": "labels.aiplatform.googleapis.com/ml_use=test",
|
||||
}
|
||||
Example split (fraction split):
|
||||
{
|
||||
"training_fraction": 0.7,
|
||||
"validation_fraction": 0.2,
|
||||
"test_fraction": 0.1,
|
||||
}
|
||||
|
||||
Args:
|
||||
output_dir (str):
|
||||
Required. The Google Cloud Storage location where the output is to
|
||||
be written to. In the given directory a new directory will be
|
||||
created with name:
|
||||
``export-data-<dataset-display-name>-<timestamp-of-export-call>``
|
||||
where timestamp is in YYYYMMDDHHMMSS format. All export
|
||||
output will be written into that directory. Inside that
|
||||
directory, annotations with the same schema will be grouped
|
||||
into sub directories which are named with the corresponding
|
||||
annotations' schema title. Inside these sub directories, a
|
||||
schema.yaml will be created to describe the output format.
|
||||
|
||||
If the uri doesn't end with '/', a '/' will be automatically
|
||||
appended. The directory is created if it doesn't exist.
|
||||
annotation_filter (str):
|
||||
Optional. An expression for filtering what part of the Dataset
|
||||
is to be exported.
|
||||
Only Annotations that match this filter will be exported.
|
||||
The filter syntax is the same as in
|
||||
[ListAnnotations][DatasetService.ListAnnotations].
|
||||
saved_query_id (str):
|
||||
Optional. The ID of a SavedQuery (annotation set) under this
|
||||
Dataset used for filtering Annotations for training.
|
||||
|
||||
Only used for custom training data export use cases.
|
||||
Only applicable to Datasets that have SavedQueries.
|
||||
|
||||
Only Annotations that are associated with this SavedQuery are
|
||||
used for training. When used in conjunction with
|
||||
annotations_filter, the Annotations used for training are
|
||||
filtered by both saved_query_id and annotations_filter.
|
||||
|
||||
Only one of saved_query_id and annotation_schema_uri should be
|
||||
specified as both of them represent the same thing: problem
|
||||
type.
|
||||
annotation_schema_uri (str):
|
||||
Optional. The Cloud Storage URI that points to a YAML file
|
||||
describing the annotation schema. The schema is defined as an
|
||||
OpenAPI 3.0.2 Schema Object. The schema files that can be used
|
||||
here are found in
|
||||
gs://google-cloud-aiplatform/schema/dataset/annotation/, note
|
||||
that the chosen schema must be consistent with
|
||||
metadata_schema_uri of this Dataset.
|
||||
|
||||
Only used for custom training data export use cases.
|
||||
Only applicable to Datasets that have DataItems and
|
||||
Annotations.
|
||||
|
||||
Only Annotations that both match this schema and belong to
|
||||
DataItems not ignored by the split method are used in
|
||||
respectively training, validation or test role, depending on the
|
||||
role of the DataItem they are on.
|
||||
|
||||
When used in conjunction with annotations_filter, the
|
||||
Annotations used for training are filtered by both
|
||||
annotations_filter and annotation_schema_uri.
|
||||
split (Union[Dict[str, str], Dict[str, float]]):
|
||||
The instructions how the export data should be split between the
|
||||
training, validation and test sets.
|
||||
|
||||
Returns:
|
||||
export_data_response (Dict):
|
||||
Response message for DatasetService.ExportData in Dictionary
|
||||
format.
|
||||
"""
|
||||
split = self._validate_and_convert_export_split(split)
|
||||
|
||||
return json_format.MessageToDict(
|
||||
self._get_completed_export_data_operation(
|
||||
output_dir,
|
||||
gca_dataset.ExportDataConfig.ExportUse.CUSTOM_CODE_TRAINING,
|
||||
annotation_filter,
|
||||
saved_query_id,
|
||||
annotation_schema_uri,
|
||||
split,
|
||||
)._pb
|
||||
)
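    # Illustrative usage sketch (not part of the class); the output directory
    # below is a hypothetical placeholder:
    #
    #   response = my_dataset.export_data_for_custom_training(
    #       output_dir="gs://example-bucket/exports/",
    #       split={
    #           "training_fraction": 0.7,
    #           "validation_fraction": 0.2,
    #           "test_fraction": 0.1,
    #       },
    #   )
    #   # The returned dict mirrors ExportDataResponse, e.g. its
    #   # "exportedFiles" entry lists the exported file URIs.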
|
||||
|
||||
def update(
|
||||
self,
|
||||
*,
|
||||
display_name: Optional[str] = None,
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
description: Optional[str] = None,
|
||||
update_request_timeout: Optional[float] = None,
|
||||
) -> "_Dataset":
|
||||
"""Update the dataset.
|
||||
Updatable fields:
|
||||
- ``display_name``
|
||||
- ``description``
|
||||
- ``labels``
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the Dataset.
|
||||
The name can be up to 128 characters long and can consist
|
||||
of any UTF-8 characters.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your datasets.
|
||||
Label keys and values can be no longer than 64 characters
|
||||
(Unicode codepoints), can only contain lowercase letters, numeric
|
||||
characters, underscores and dashes. International characters are allowed.
|
||||
No more than 64 user labels can be associated with one Dataset
|
||||
(System labels are excluded).
|
||||
See https://goo.gl/xmQnxf for more information and examples of labels.
|
||||
System reserved label keys are prefixed with "aiplatform.googleapis.com/"
|
||||
and are immutable.
|
||||
description (str):
|
||||
Optional. The description of the Dataset.
|
||||
update_request_timeout (float):
|
||||
Optional. The timeout for the update request in seconds.
|
||||
|
||||
Returns:
|
||||
dataset (Dataset):
|
||||
Updated dataset.
|
||||
"""
|
||||
|
||||
update_mask = field_mask_pb2.FieldMask()
|
||||
if display_name:
|
||||
update_mask.paths.append("display_name")
|
||||
|
||||
if labels:
|
||||
update_mask.paths.append("labels")
|
||||
|
||||
if description:
|
||||
update_mask.paths.append("description")
|
||||
|
||||
update_dataset = gca_dataset.Dataset(
|
||||
name=self.resource_name,
|
||||
display_name=display_name,
|
||||
description=description,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
self._gca_resource = self.api_client.update_dataset(
|
||||
dataset=update_dataset,
|
||||
update_mask=update_mask,
|
||||
timeout=update_request_timeout,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def list(
|
||||
cls,
|
||||
filter: Optional[str] = None,
|
||||
order_by: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> List[base.VertexAiResourceNoun]:
|
||||
"""List all instances of this Dataset resource.
|
||||
|
||||
Example Usage:
|
||||
|
||||
aiplatform.TabularDataset.list(
|
||||
filter='labels.my_key="my_value"',
|
||||
order_by='display_name'
|
||||
)
|
||||
|
||||
Args:
|
||||
filter (str):
|
||||
Optional. An expression for filtering the results of the request.
|
||||
For field names both snake_case and camelCase are supported.
|
||||
order_by (str):
|
||||
Optional. A comma-separated list of fields to order by, sorted in
|
||||
ascending order. Use "desc" after a field name for descending.
|
||||
Supported fields: `display_name`, `create_time`, `update_time`
|
||||
project (str):
|
||||
Optional. Project to retrieve list from. If not set, project
|
||||
set in aiplatform.init will be used.
|
||||
location (str):
|
||||
Optional. Location to retrieve list from. If not set, location
|
||||
set in aiplatform.init will be used.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials to use to retrieve list. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
|
||||
Returns:
|
||||
List[base.VertexAiResourceNoun] - A list of Dataset resource objects
|
||||
"""
|
||||
|
||||
dataset_subclass_filter = (
|
||||
lambda gapic_obj: gapic_obj.metadata_schema_uri
|
||||
in cls._supported_metadata_schema_uris
|
||||
)
|
||||
|
||||
return cls._list_with_local_order(
|
||||
cls_filter=dataset_subclass_filter,
|
||||
filter=filter,
|
||||
order_by=order_by,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
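# A minimal illustrative end-to-end sketch (not part of this module) of the
# dataset lifecycle exposed through the public subclasses of _Dataset. The
# project, bucket paths and display name below are hypothetical placeholders,
# and the import schema constant is assumed to be exposed under
# aiplatform.schema.dataset.ioformat.
def _example_dataset_lifecycle():
    from google.cloud import aiplatform  # pylint: disable=g-import-not-at-top

    aiplatform.init(project="example-project", location="us-central1")

    dataset = aiplatform.ImageDataset.create(
        display_name="example-image-dataset",
        gcs_source="gs://example-bucket/image-dataset.csv",
        import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
    )
    dataset.update(description="Example image dataset")
    exported_files = dataset.export_data(output_dir="gs://example-bucket/exports/")
    return exported_files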
|
||||
@@ -0,0 +1,198 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud.aiplatform import datasets
|
||||
from google.cloud.aiplatform.datasets import _datasources
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import schema
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
|
||||
class ImageDataset(datasets._Dataset):
|
||||
"""A managed image dataset resource for Vertex AI.
|
||||
|
||||
Use this class to work with a managed image dataset. To create a managed
|
||||
image dataset, you need a datasource file in CSV format and a schema file in
|
||||
YAML format. A schema is optional for a custom model. You put the CSV file
|
||||
and the schema into Cloud Storage buckets.
|
||||
|
||||
Use image data for the following objectives:
|
||||
|
||||
* Single-label classification. For more information, see
|
||||
[Prepare image training data for single-label classification](https://cloud.google.com/vertex-ai/docs/image-data/classification/prepare-data#single-label-classification).
|
||||
* Multi-label classification. For more information, see [Prepare image training data for multi-label classification](https://cloud.google.com/vertex-ai/docs/image-data/classification/prepare-data#multi-label-classification).
|
||||
* Object detection. For more information, see [Prepare image training data
|
||||
for object detection](https://cloud.google.com/vertex-ai/docs/image-data/object-detection/prepare-data).
|
||||
|
||||
The following code shows you how to create an image dataset by importing data from
|
||||
a CSV datasource file and a YAML schema file. The schema file you use
|
||||
depends on whether your image dataset is used for single-label
|
||||
classification, multi-label classification, or object detection.
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.ImageDataset.create(
|
||||
display_name="my-image-dataset",
|
||||
gcs_source=['gs://path/to/my/image-dataset.csv'],
|
||||
import_schema_uri=['gs://path/to/my/schema.yaml']
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
|
||||
schema.dataset.metadata.image,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
display_name: Optional[str] = None,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
import_schema_uri: Optional[str] = None,
|
||||
data_item_labels: Optional[Dict] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> "ImageDataset":
|
||||
"""Creates a new image dataset.
|
||||
|
||||
Optionally imports data into the dataset when a source and
|
||||
`import_schema_uri` are passed in.
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the dataset. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
Optional. The URI to one or more Google Cloud Storage buckets
|
||||
that contain your datasets. For example, `str:
|
||||
"gs://bucket/file.csv"` or `Sequence[str]:
|
||||
["gs://bucket/file1.csv", "gs://bucket/file2.csv"]`.
|
||||
import_schema_uri (str):
|
||||
Optional. A URI for a YAML file stored in Cloud Storage that
|
||||
describes the import schema used to validate the
|
||||
dataset. The schema is an
|
||||
[OpenAPI 3.0.2 Schema](https://tinyurl.com/y538mdwt) object.
|
||||
data_item_labels (Dict):
|
||||
Optional. A dictionary of label information. Each dictionary
|
||||
item contains a label and a label key. Each image in the dataset
|
||||
includes one dictionary of label information. If a data item is
|
||||
added or merged into a dataset, and that data item contains an
|
||||
image that's identical to an image that’s already in the
|
||||
dataset, then the data items are merged. If two identical labels
|
||||
are detected during the merge, each with a different label key,
|
||||
then one of the label and label key dictionary items is randomly
|
||||
chosen to be included in the merged data item. Images and documents are
compared using their binary data (bytes), not their content.
If annotation labels are referenced in a schema specified by the
`import_schema_uri` parameter, then the labels in the
`data_item_labels` dictionary are overridden by the annotations.
|
||||
project (str):
|
||||
Optional. The name of the Google Cloud project to which this
|
||||
`ImageDataset` is uploaded. This overrides the project that
|
||||
was set by `aiplatform.init`.
|
||||
location (str):
|
||||
Optional. The Google Cloud region where this dataset is uploaded. This
|
||||
region overrides the region that was set by `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. The credentials that are used to upload the
|
||||
`ImageDataset`. These credentials override the credentials set
|
||||
by `aiplatform.init`.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Optional. Strings that contain metadata that's sent with the request.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your
|
||||
Vertex AI datasets. The maximum length of a key and of a
value is 64 unicode characters. Labels and keys can contain only
lowercase letters, numeric characters, underscores, and dashes.
International characters are allowed. No more than 64 user
labels can be associated with one dataset (system labels are
|
||||
excluded). For more information and examples of using labels, see
|
||||
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
|
||||
System reserved label keys are prefixed with
|
||||
`aiplatform.googleapis.com/` and are immutable.
|
||||
encryption_spec_key_name (Optional[str]):
|
||||
Optional. The Cloud KMS resource identifier of the customer
|
||||
managed encryption key that's used to protect the dataset. The
|
||||
format of the key is
|
||||
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If `encryption_spec_key_name` is set, this image dataset and
|
||||
all of its sub-resources are secured by this key.
|
||||
|
||||
This `encryption_spec_key_name` overrides the
|
||||
`encryption_spec_key_name` set by `aiplatform.init`.
|
||||
sync (bool):
|
||||
If `true`, the `create` method creates an image dataset
|
||||
synchronously. If `false`, the `create` method creates an image
|
||||
dataset asynchronously.
|
||||
create_request_timeout (float):
|
||||
Optional. The number of seconds for the timeout of the create
|
||||
request.
|
||||
|
||||
Returns:
|
||||
image_dataset (ImageDataset):
|
||||
An instantiated representation of the managed `ImageDataset`
|
||||
resource.
|
||||
"""
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
|
||||
utils.validate_display_name(display_name)
|
||||
if labels:
|
||||
utils.validate_labels(labels)
|
||||
|
||||
api_client = cls._instantiate_client(location=location, credentials=credentials)
|
||||
|
||||
metadata_schema_uri = schema.dataset.metadata.image
|
||||
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
import_schema_uri=import_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
data_item_labels=data_item_labels,
|
||||
)
|
||||
|
||||
return cls._create_and_import(
|
||||
api_client=api_client,
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
),
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
project=project or initializer.global_config.project,
|
||||
location=location or initializer.global_config.location,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=initializer.global_config.get_encryption_spec(
|
||||
encryption_spec_key_name=encryption_spec_key_name
|
||||
),
|
||||
sync=sync,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
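A hedged sketch of an asynchronous variant of the class docstring example; the bucket path is a placeholder and the `schema.dataset.ioformat.image.single_label_classification` constant is assumed to resolve to the single-label classification import schema:

```py
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholders

ds = aiplatform.ImageDataset.create(
    display_name="flowers",
    gcs_source=["gs://my-bucket/flowers/import.csv"],  # placeholder path
    import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
    sync=False,  # return immediately; creation and import continue in the background
)
ds.wait()  # block until the underlying long-running operations finish
```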
|
||||
@@ -0,0 +1,318 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union, TYPE_CHECKING
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud.aiplatform import base
|
||||
from google.cloud.aiplatform import datasets
|
||||
from google.cloud.aiplatform.datasets import _datasources
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import schema
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.cloud import bigquery
|
||||
|
||||
_AUTOML_TRAINING_MIN_ROWS = 1000
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
|
||||
|
||||
class TabularDataset(datasets._ColumnNamesDataset):
|
||||
"""A managed tabular dataset resource for Vertex AI.
|
||||
|
||||
Use this class to work with tabular datasets. You can use a CSV file, BigQuery, or a pandas
|
||||
[`DataFrame`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)
|
||||
to create a tabular dataset. For more information about paging through
|
||||
BigQuery data, see [Read data with BigQuery API using
|
||||
pagination](https://cloud.google.com/bigquery/docs/paging-results). For more
|
||||
information about tabular data, see [Tabular
|
||||
data](https://cloud.google.com/vertex-ai/docs/training-overview#tabular_data).
|
||||
|
||||
The following code shows you how to create and import a tabular
|
||||
dataset with a CSV file.
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.TabularDataset.create(
|
||||
display_name="my-dataset", gcs_source=['gs://path/to/my/dataset.csv'])
|
||||
```
|
||||
Unlike unstructured datasets, a tabular dataset is created and imported in
a single step.
|
||||
|
||||
If you create a tabular dataset with a pandas
|
||||
[`DataFrame`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html),
|
||||
you need to use a BigQuery table to stage the data for Vertex AI:
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.TabularDataset.create_from_dataframe(
|
||||
df_source=my_pandas_dataframe,
|
||||
staging_path=f"bq://{bq_dataset_id}.table-unique"
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
|
||||
schema.dataset.metadata.tabular,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
display_name: Optional[str] = None,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
bq_source: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> "TabularDataset":
|
||||
"""Creates a tabular dataset.
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the dataset. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
Optional. The URI to one or more Google Cloud Storage buckets that contain
|
||||
your datasets. For example, `str: "gs://bucket/file.csv"` or
|
||||
`Sequence[str]: ["gs://bucket/file1.csv",
|
||||
"gs://bucket/file2.csv"]`. Either `gcs_source` or `bq_source` must be specified.
|
||||
bq_source (str):
|
||||
Optional. The URI to a BigQuery table that's used as an input source. For
|
||||
example, `bq://project.dataset.table_name`. Either `gcs_source`
|
||||
or `bq_source` must be specified.
|
||||
project (str):
|
||||
Optional. The name of the Google Cloud project to which this
|
||||
`TabularDataset` is uploaded. This overrides the project that
|
||||
was set by `aiplatform.init`.
|
||||
location (str):
|
||||
Optional. The Google Cloud region where this dataset is uploaded. This
|
||||
region overrides the region that was set by `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. The credentials that are used to upload the `TabularDataset`.
|
||||
These credentials override the credentials set by
|
||||
`aiplatform.init`.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Optional. Strings that contain metadata that's sent with the request.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your
|
||||
Vertex AI datasets. The maximum length of a key and of a
value is 64 unicode characters. Labels and keys can contain only
lowercase letters, numeric characters, underscores, and dashes.
International characters are allowed. No more than 64 user
labels can be associated with one dataset (system labels are
|
||||
excluded). For more information and examples of using labels, see
|
||||
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
|
||||
System reserved label keys are prefixed with
|
||||
`aiplatform.googleapis.com/` and are immutable.
|
||||
encryption_spec_key_name (Optional[str]):
|
||||
Optional. The Cloud KMS resource identifier of the customer
|
||||
managed encryption key that's used to protect the dataset. The
|
||||
format of the key is
|
||||
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If `encryption_spec_key_name` is set, this `TabularDataset` and
|
||||
all of its sub-resources are secured by this key.
|
||||
|
||||
This `encryption_spec_key_name` overrides the
|
||||
`encryption_spec_key_name` set by `aiplatform.init`.
|
||||
sync (bool):
|
||||
If `true`, the `create` method creates a tabular dataset
|
||||
synchronously. If `false`, the `create` method creates a tabular
|
||||
dataset asynchronously.
|
||||
create_request_timeout (float):
|
||||
Optional. The number of seconds for the timeout of the create
|
||||
request.
|
||||
|
||||
Returns:
|
||||
tabular_dataset (TabularDataset):
|
||||
An instantiated representation of the managed `TabularDataset` resource.
|
||||
"""
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
utils.validate_display_name(display_name)
|
||||
if labels:
|
||||
utils.validate_labels(labels)
|
||||
|
||||
api_client = cls._instantiate_client(location=location, credentials=credentials)
|
||||
|
||||
metadata_schema_uri = schema.dataset.metadata.tabular
|
||||
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
bq_source=bq_source,
|
||||
)
|
||||
|
||||
return cls._create_and_import(
|
||||
api_client=api_client,
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
),
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
project=project or initializer.global_config.project,
|
||||
location=location or initializer.global_config.location,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=initializer.global_config.get_encryption_spec(
|
||||
encryption_spec_key_name=encryption_spec_key_name
|
||||
),
|
||||
sync=sync,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
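A short sketch of the BigQuery path through `create`; exactly one of `gcs_source` or `bq_source` is passed, and the project and table names are placeholders:

```py
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholders

ds = aiplatform.TabularDataset.create(
    display_name="sales",
    bq_source="bq://my-project.sales_dataset.sales_table",  # placeholder table
    create_request_timeout=300.0,
)
```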
|
||||
|
||||
@classmethod
|
||||
def create_from_dataframe(
|
||||
cls,
|
||||
df_source: "pd.DataFrame", # noqa: F821 - skip check for undefined name 'pd'
|
||||
staging_path: str,
|
||||
bq_schema: Optional[Union[str, "bigquery.SchemaField"]] = None,
|
||||
display_name: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> "TabularDataset":
|
||||
"""Creates a new tabular dataset from a pandas `DataFrame`.
|
||||
|
||||
Args:
|
||||
df_source (pd.DataFrame):
|
||||
Required. A pandas
|
||||
[`DataFrame`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)
|
||||
containing the source data for ingestion as a `TabularDataset`.
|
||||
This method uses the data types from the provided `DataFrame`
|
||||
when the `TabularDataset` is created.
|
||||
staging_path (str):
|
||||
Required. The BigQuery table used to stage the data for Vertex
|
||||
AI. Because Vertex AI maintains a reference to this source to
|
||||
create the `TabularDataset`, you shouldn't delete this BigQuery
|
||||
table. For example: `bq://my-project.my-dataset.my-table`.
|
||||
If the specified BigQuery table doesn't exist, then the table is
|
||||
created for you. If the provided BigQuery table already exists,
|
||||
and the schemas of the BigQuery table and your DataFrame match,
|
||||
then the data in your local `DataFrame` is appended to the table.
|
||||
The location of the BigQuery table must conform to the
|
||||
[BigQuery location requirements](https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations).
|
||||
bq_schema (Optional[Union[str, bigquery.SchemaField]]):
|
||||
Optional. If not set, BigQuery autodetects the schema using the
|
||||
column types of your `DataFrame`. If set, BigQuery uses the
|
||||
schema you provide when the staging table is created. For more
|
||||
information,
|
||||
see the BigQuery
|
||||
[`LoadJobConfig.schema`](https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.LoadJobConfig#google_cloud_bigquery_job_LoadJobConfig_schema)
|
||||
property.
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the `Dataset`. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
project (str):
|
||||
Optional. The project to upload this dataset to. This overrides
|
||||
the project set using `aiplatform.init`.
|
||||
location (str):
|
||||
Optional. The location to upload this dataset to. This overrides
|
||||
the location set using `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. The custom credentials used to upload this dataset.
|
||||
This overrides credentials set using `aiplatform.init`.
|
||||
Returns:
|
||||
tabular_dataset (TabularDataset):
|
||||
An instantiated representation of the managed `TabularDataset` resource.
|
||||
"""
|
||||
|
||||
if staging_path.startswith("bq://"):
|
||||
bq_staging_path = staging_path[len("bq://") :]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only BigQuery staging paths are supported. Provide a staging path in the format `bq://your-project.your-dataset.your-table`."
|
||||
)
|
||||
|
||||
try:
|
||||
import pyarrow # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery'
|
||||
except ImportError:
|
||||
raise ImportError(
"Pyarrow is not installed, and is required to use the BigQuery client. "
'Please install the SDK using "pip install google-cloud-aiplatform[datasets]"'
|
||||
)
|
||||
import pandas.api.types as pd_types
|
||||
|
||||
if any(
|
||||
[
|
||||
pd_types.is_datetime64_any_dtype(df_source[column])
|
||||
for column in df_source.columns
|
||||
]
|
||||
):
|
||||
_LOGGER.info(
|
||||
"Received datetime-like column in the dataframe. Please note that the column could be interpreted differently in BigQuery depending on which major version you are using. For more information, please reference the BigQuery v3 release notes here: https://github.com/googleapis/python-bigquery/releases/tag/v3.0.0"
|
||||
)
|
||||
|
||||
if len(df_source) < _AUTOML_TRAINING_MIN_ROWS:
|
||||
_LOGGER.info(
|
||||
"Your DataFrame has %s rows and AutoML requires %s rows to train on tabular data. You can still train a custom model once your dataset has been uploaded to Vertex, but you will not be able to use AutoML for training."
|
||||
% (len(df_source), _AUTOML_TRAINING_MIN_ROWS),
|
||||
)
|
||||
|
||||
# Loading bigquery lazily to avoid auto-loading it when importing vertexai
|
||||
from google.cloud import bigquery # pylint: disable=g-import-not-at-top
|
||||
|
||||
bigquery_client = bigquery.Client(
|
||||
project=project or initializer.global_config.project,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
)
|
||||
|
||||
try:
|
||||
parquet_options = bigquery.format_options.ParquetOptions()
|
||||
parquet_options.enable_list_inference = True
|
||||
|
||||
job_config = bigquery.LoadJobConfig(
|
||||
source_format=bigquery.SourceFormat.PARQUET,
|
||||
parquet_options=parquet_options,
|
||||
)
|
||||
|
||||
if bq_schema:
|
||||
job_config.schema = bq_schema
|
||||
|
||||
job = bigquery_client.load_table_from_dataframe(
|
||||
dataframe=df_source, destination=bq_staging_path, job_config=job_config
|
||||
)
|
||||
|
||||
job.result()
|
||||
|
||||
finally:
|
||||
dataset_from_dataframe = cls.create(
|
||||
display_name=display_name,
|
||||
bq_source=staging_path,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
return dataset_from_dataframe
|
||||
|
||||
def import_data(self):
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} class does not support 'import_data'"
|
||||
)
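A minimal sketch of the `create_from_dataframe` flow defined above; it assumes `pyarrow` is installed (as required by the import check), and the project, dataset, and staging table names are placeholders:

```py
import pandas as pd

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholders

df = pd.DataFrame(
    {
        "date": pd.to_datetime(["2024-01-01", "2024-01-02"]),
        "units_sold": [10, 12],
    }
)

# The staging table is created if it does not exist; do not delete it
# afterwards, because the dataset keeps a reference to it.
ds = aiplatform.TabularDataset.create_from_dataframe(
    df_source=df,
    staging_path="bq://my-project.staging_dataset.sales_staging",  # placeholder
    display_name="sales-from-dataframe",
)
```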
|
||||
@@ -0,0 +1,207 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud.aiplatform import datasets
|
||||
from google.cloud.aiplatform.datasets import _datasources
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import schema
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
|
||||
class TextDataset(datasets._Dataset):
|
||||
"""A managed text dataset resource for Vertex AI.
|
||||
|
||||
Use this class to work with a managed text dataset. To create a managed
|
||||
text dataset, you need a datasource file in CSV format and a schema file in
|
||||
YAML format. A schema is optional for a custom model. The CSV file and the
|
||||
schema are accessed in Cloud Storage buckets.
|
||||
|
||||
Use text data for the following objectives:
|
||||
|
||||
* Classification. For more information, see
|
||||
[Prepare text training data for classification](https://cloud.google.com/vertex-ai/docs/text-data/classification/prepare-data).
|
||||
* Entity extraction. For more information, see
|
||||
[Prepare text training data for entity extraction](https://cloud.google.com/vertex-ai/docs/text-data/entity-extraction/prepare-data).
|
||||
* Sentiment analysis. For more information, see
|
||||
[Prepare text training data for sentiment analysis](https://cloud.google.com/vertex-ai/docs/text-data/sentiment-analysis/prepare-data).
|
||||
|
||||
The following code shows you how to create and import a text dataset with
|
||||
a CSV datasource file and a YAML schema file. The schema file you use
|
||||
depends on whether your text dataset is used for classification,
entity extraction, or sentiment analysis.
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.TextDataset.create(
|
||||
display_name="my-text-dataset",
|
||||
gcs_source=['gs://path/to/my/text-dataset.csv'],
|
||||
import_schema_uri=['gs://path/to/my/schema.yaml'],
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
|
||||
schema.dataset.metadata.text,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
display_name: Optional[str] = None,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
import_schema_uri: Optional[str] = None,
|
||||
data_item_labels: Optional[Dict] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> "TextDataset":
|
||||
"""Creates a new text dataset.
|
||||
|
||||
Optionally imports data into this dataset when a source and
|
||||
`import_schema_uri` are passed in. The following is an example of how
|
||||
this method is used:
|
||||
|
||||
```py
|
||||
ds = aiplatform.TextDataset.create(
|
||||
display_name='my-dataset',
|
||||
gcs_source='gs://my-bucket/dataset.csv',
|
||||
import_schema_uri=aiplatform.schema.dataset.ioformat.text.multi_label_classification
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the dataset. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
Optional. The URI to one or more Google Cloud Storage buckets
|
||||
that contain your datasets. For example, `str:
|
||||
"gs://bucket/file.csv"` or `Sequence[str]:
|
||||
["gs://bucket/file1.csv", "gs://bucket/file2.csv"]`.
|
||||
import_schema_uri (str):
|
||||
Optional. A URI for a YAML file stored in Cloud Storage that
|
||||
describes the import schema used to validate the
|
||||
dataset. The schema is an
|
||||
[OpenAPI 3.0.2 Schema](https://tinyurl.com/y538mdwt) object.
|
||||
data_item_labels (Dict):
|
||||
Optional. A dictionary of label information. Each dictionary
|
||||
item contains a label and a label key. Each item in the dataset
|
||||
includes one dictionary of label information. If a data item is
|
||||
added or merged into a dataset, and that data item contains an
|
||||
image that's identical to an image that’s already in the
|
||||
dataset, then the data items are merged. If two identical labels
|
||||
are detected during the merge, each with a different label key,
|
||||
then one of the label and label key dictionary items is randomly
|
||||
chosen to be included in the merged data item. Data items are
compared using their binary data (bytes), not their content.
If annotation labels are referenced in a schema specified by the
`import_schema_uri` parameter, then the labels in the
`data_item_labels` dictionary are overridden by the annotations.
|
||||
project (str):
|
||||
Optional. The name of the Google Cloud project to which this
|
||||
`TextDataset` is uploaded. This overrides the project that
|
||||
was set by `aiplatform.init`.
|
||||
location (str):
|
||||
Optional. The Google Cloud region where this dataset is uploaded. This
|
||||
region overrides the region that was set by `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. The credentials that are used to upload the `TextDataset`.
|
||||
These credentials override the credentials set by
|
||||
`aiplatform.init`.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Optional. Strings that contain metadata that's sent with the request.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your
|
||||
Vertex AI datasets. The maximum length of a key and of a
value is 64 unicode characters. Labels and keys can contain only
lowercase letters, numeric characters, underscores, and dashes.
International characters are allowed. No more than 64 user
labels can be associated with one dataset (system labels are
|
||||
excluded). For more information and examples of using labels, see
|
||||
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
|
||||
System reserved label keys are prefixed with
|
||||
`aiplatform.googleapis.com/` and are immutable.
|
||||
encryption_spec_key_name (Optional[str]):
|
||||
Optional. The Cloud KMS resource identifier of the customer
|
||||
managed encryption key that's used to protect the dataset. The
|
||||
format of the key is
|
||||
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If `encryption_spec_key_name` is set, this `TextDataset` and
|
||||
all of its sub-resources are secured by this key.
|
||||
|
||||
This `encryption_spec_key_name` overrides the
|
||||
`encryption_spec_key_name` set by `aiplatform.init`.
|
||||
sync (bool):
|
||||
If `true`, the `create` method creates a text dataset
|
||||
synchronously. If `false`, the `create` method creates a text
|
||||
dataset asynchronously.
|
||||
create_request_timeout (float):
|
||||
Optional. The number of seconds for the timeout of the create
|
||||
request.
|
||||
|
||||
Returns:
|
||||
text_dataset (TextDataset):
|
||||
An instantiated representation of the managed `TextDataset`
|
||||
resource.
|
||||
"""
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
utils.validate_display_name(display_name)
|
||||
if labels:
|
||||
utils.validate_labels(labels)
|
||||
|
||||
api_client = cls._instantiate_client(location=location, credentials=credentials)
|
||||
|
||||
metadata_schema_uri = schema.dataset.metadata.text
|
||||
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
import_schema_uri=import_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
data_item_labels=data_item_labels,
|
||||
)
|
||||
|
||||
return cls._create_and_import(
|
||||
api_client=api_client,
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
),
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
project=project or initializer.global_config.project,
|
||||
location=location or initializer.global_config.location,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=initializer.global_config.get_encryption_spec(
|
||||
encryption_spec_key_name=encryption_spec_key_name
|
||||
),
|
||||
sync=sync,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
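A sketch of a single-label text classification dataset, mirroring the docstring example; the bucket path is a placeholder and the `single_label_classification` schema constant is assumed to exist alongside the multi-label one used above:

```py
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholders

ds = aiplatform.TextDataset.create(
    display_name="support-tickets",
    gcs_source="gs://my-bucket/tickets/import.csv",  # placeholder path
    import_schema_uri=aiplatform.schema.dataset.ioformat.text.single_label_classification,
)
```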
|
||||
@@ -0,0 +1,186 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud.aiplatform import datasets
|
||||
from google.cloud.aiplatform.datasets import _datasources
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import schema
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
|
||||
class TimeSeriesDataset(datasets._ColumnNamesDataset):
|
||||
"""A managed time series dataset resource for Vertex AI.
|
||||
|
||||
Use this class to work with time series datasets. A time series is a dataset
|
||||
that contains data recorded at different time intervals. The dataset
|
||||
includes time and at least one variable that's dependent on time. You use a
|
||||
time series dataset for forecasting predictions. For more information, see
|
||||
[Forecasting overview](https://cloud.google.com/vertex-ai/docs/tabular-data/forecasting/overview).
|
||||
|
||||
You can create a managed time series dataset from CSV files in a Cloud
|
||||
Storage bucket or from a BigQuery table.
|
||||
|
||||
The following code shows you how to create a `TimeSeriesDataset` with a CSV
|
||||
file that contains the time series data:
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.TimeSeriesDataset.create(
|
||||
display_name="my-dataset",
|
||||
gcs_source=['gs://path/to/my/dataset.csv'],
|
||||
)
|
||||
```
|
||||
|
||||
The following code shows you how to create a `TimeSeriesDataset` from a
BigQuery table that contains the time series data:
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.TimeSeriesDataset.create(
|
||||
display_name="my-dataset",
|
||||
bq_source=['bq://path/to/my/bigquerydataset.train'],
|
||||
)
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
|
||||
schema.dataset.metadata.time_series,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
display_name: Optional[str] = None,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
bq_source: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> "TimeSeriesDataset":
|
||||
"""Creates a new time series dataset.
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the dataset. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
The URI to one or more Google Cloud Storage buckets that contain
|
||||
your datasets. For example, `str: "gs://bucket/file.csv"` or
|
||||
`Sequence[str]: ["gs://bucket/file1.csv",
|
||||
"gs://bucket/file2.csv"]`.
|
||||
bq_source (str):
|
||||
A BigQuery URI for the input table. For example,
|
||||
`bq://project.dataset.table_name`.
|
||||
project (str):
|
||||
The name of the Google Cloud project to which this
|
||||
`TimeSeriesDataset` is uploaded. This overrides the project that
|
||||
was set by `aiplatform.init`.
|
||||
location (str):
|
||||
The Google Cloud region where this dataset is uploaded. This
|
||||
region overrides the region that was set by `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
The credentials that are used to upload the `TimeSeriesDataset`.
|
||||
These credentials override the credentials set by
|
||||
`aiplatform.init`.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Strings that contain metadata that's sent with the request.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your
|
||||
Vertex AI datasets. The maximum length of a key and of a
value is 64 unicode characters. Labels and keys can contain only
lowercase letters, numeric characters, underscores, and dashes.
International characters are allowed. No more than 64 user
labels can be associated with one dataset (system labels are
|
||||
excluded). For more information and examples of using labels, see
|
||||
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
|
||||
System reserved label keys are prefixed with
|
||||
`aiplatform.googleapis.com/` and are immutable.
|
||||
encryption_spec_key_name (Optional[str]):
|
||||
Optional. The Cloud KMS resource identifier of the customer
|
||||
managed encryption key that's used to protect the dataset. The
|
||||
format of the key is
|
||||
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If `encryption_spec_key_name` is set, this time series dataset
|
||||
and all of its sub-resources are secured by this key.
|
||||
|
||||
This `encryption_spec_key_name` overrides the
|
||||
`encryption_spec_key_name` set by `aiplatform.init`.
|
||||
create_request_timeout (float):
|
||||
Optional. The number of seconds for the timeout of the create
|
||||
request.
|
||||
sync (bool):
|
||||
If `true`, the `create` method creates a time series dataset
|
||||
synchronously. If `false`, the `create` method creates a time
|
||||
series dataset asynchronously.
|
||||
|
||||
Returns:
|
||||
time_series_dataset (TimeSeriesDataset):
|
||||
An instantiated representation of the managed
|
||||
`TimeSeriesDataset` resource.
|
||||
|
||||
"""
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
utils.validate_display_name(display_name)
|
||||
if labels:
|
||||
utils.validate_labels(labels)
|
||||
|
||||
api_client = cls._instantiate_client(location=location, credentials=credentials)
|
||||
|
||||
metadata_schema_uri = schema.dataset.metadata.time_series
|
||||
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
bq_source=bq_source,
|
||||
)
|
||||
|
||||
return cls._create_and_import(
|
||||
api_client=api_client,
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
),
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
project=project or initializer.global_config.project,
|
||||
location=location or initializer.global_config.location,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=initializer.global_config.get_encryption_spec(
|
||||
encryption_spec_key_name=encryption_spec_key_name
|
||||
),
|
||||
sync=sync,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
|
||||
|
||||
def import_data(self):
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} class does not support 'import_data'"
|
||||
)
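A hedged sketch combining the BigQuery source with labels and a customer-managed encryption key; every identifier below is a placeholder, and the CMEK key must live in the same region as the dataset:

```py
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholders

ds = aiplatform.TimeSeriesDataset.create(
    display_name="demand-history",
    bq_source="bq://my-project.forecasting.demand",  # placeholder table
    labels={"env": "dev"},
    encryption_spec_key_name=(
        "projects/my-project/locations/us-central1/"
        "keyRings/my-kr/cryptoKeys/my-key"  # placeholder CMEK key
    ),
)
```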
|
||||
@@ -0,0 +1,199 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
|
||||
from google.cloud.aiplatform import datasets
|
||||
from google.cloud.aiplatform.datasets import _datasources
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import schema
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
|
||||
class VideoDataset(datasets._Dataset):
|
||||
"""A managed video dataset resource for Vertex AI.
|
||||
|
||||
Use this class to work with a managed video dataset. To create a video
|
||||
dataset, you need a datasource in CSV format and a schema in YAML format.
|
||||
The CSV file and the schema are accessed in Cloud Storage buckets.
|
||||
|
||||
Use video data for the following objectives:
|
||||
|
||||
* Classification. For more information, see Classification schema files.
* Action recognition. For more information, see Action recognition schema files.
* Object tracking. For more information, see Object tracking schema files.

The following code shows you how to create and import a dataset to
train a video classification model. The schema file you use depends on
whether you use your video dataset for classification, action recognition,
or object tracking.
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.VideoDataset.create(
|
||||
gcs_source=['gs://path/to/my/dataset.csv'],
|
||||
import_schema_uri=['gs://aip.schema.dataset.ioformat.video.classification.yaml']
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
|
||||
schema.dataset.metadata.video,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
display_name: Optional[str] = None,
|
||||
gcs_source: Optional[Union[str, Sequence[str]]] = None,
|
||||
import_schema_uri: Optional[str] = None,
|
||||
data_item_labels: Optional[Dict] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
|
||||
labels: Optional[Dict[str, str]] = None,
|
||||
encryption_spec_key_name: Optional[str] = None,
|
||||
sync: bool = True,
|
||||
create_request_timeout: Optional[float] = None,
|
||||
) -> "VideoDataset":
|
||||
"""Creates a new video dataset.
|
||||
|
||||
Optionally imports data into the dataset when a source and
|
||||
`import_schema_uri` are passed in. The following is an example of how
|
||||
this method is used:
|
||||
|
||||
```py
|
||||
my_dataset = aiplatform.VideoDataset.create(
|
||||
gcs_source=['gs://path/to/my/dataset.csv'],
|
||||
import_schema_uri=['gs://aip.schema.dataset.ioformat.video.classification.yaml']
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
display_name (str):
|
||||
Optional. The user-defined name of the dataset. The name must
|
||||
contain 128 or fewer UTF-8 characters.
|
||||
gcs_source (Union[str, Sequence[str]]):
|
||||
The URI to one or more Google Cloud Storage buckets that contain
|
||||
your datasets. For example, `str: "gs://bucket/file.csv"` or
|
||||
`Sequence[str]: ["gs://bucket/file1.csv",
|
||||
"gs://bucket/file2.csv"]`.
|
||||
import_schema_uri (str):
|
||||
A URI for a YAML file stored in Cloud Storage that
|
||||
describes the import schema used to validate the
|
||||
dataset. The schema is an
|
||||
[OpenAPI 3.0.2 Schema](https://tinyurl.com/y538mdwt) object.
|
||||
data_item_labels (Dict):
|
||||
Optional. A dictionary of label information. Each dictionary
|
||||
item contains a label and a label key. Each item in the dataset
|
||||
includes one dictionary of label information. If a data item is
|
||||
added or merged into a dataset, and that data item contains an
|
||||
image that's identical to an image that’s already in the
|
||||
dataset, then the data items are merged. If two identical labels
|
||||
are detected during the merge, each with a different label key,
|
||||
then one of the label and label key dictionary items is randomly
|
||||
chosen to be included in the merged data item. Dataset items are
compared using their binary data (bytes), not their content.
If annotation labels are referenced in a schema specified by the
`import_schema_uri` parameter, then the labels in the
`data_item_labels` dictionary are overridden by the annotations.
|
||||
project (str):
|
||||
The name of the Google Cloud project to which this
|
||||
`VideoDataset` is uploaded. This overrides the project that
|
||||
was set by `aiplatform.init`.
|
||||
location (str):
|
||||
The Google Cloud region where this dataset is uploaded. This
|
||||
region overrides the region that was set by `aiplatform.init`.
|
||||
credentials (auth_credentials.Credentials):
|
||||
The credentials that are used to upload the `VideoDataset`.
|
||||
These credentials override the credentials set by
|
||||
`aiplatform.init`.
|
||||
request_metadata (Sequence[Tuple[str, str]]):
|
||||
Strings that contain metadata that's sent with the request.
|
||||
labels (Dict[str, str]):
|
||||
Optional. Labels with user-defined metadata to organize your
|
||||
Vertex AI datasets. The maximum length of a key and of a
value is 64 unicode characters. Labels and keys can contain only
lowercase letters, numeric characters, underscores, and dashes.
International characters are allowed. No more than 64 user
labels can be associated with one dataset (system labels are
|
||||
excluded). For more information and examples of using labels, see
|
||||
[Using labels to organize Google Cloud Platform resources](https://goo.gl/xmQnxf).
|
||||
System reserved label keys are prefixed with
|
||||
`aiplatform.googleapis.com/` and are immutable.
|
||||
encryption_spec_key_name (Optional[str]):
|
||||
Optional. The Cloud KMS resource identifier of the customer
|
||||
managed encryption key that's used to protect the dataset. The
|
||||
format of the key is
|
||||
`projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key`.
|
||||
The key needs to be in the same region as where the compute
|
||||
resource is created.
|
||||
|
||||
If `encryption_spec_key_name` is set, this `VideoDataset` and
|
||||
all of its sub-resources are secured by this key.
|
||||
|
||||
This `encryption_spec_key_name` overrides the
|
||||
`encryption_spec_key_name` set by `aiplatform.init`.
|
||||
sync (bool):
|
||||
If `true`, the `create` method creates a video dataset
|
||||
synchronously. If `false`, the `create` method creates a video
|
||||
dataset asynchronously.
|
||||
create_request_timeout (float):
|
||||
Optional. The number of seconds for the timeout of the create
|
||||
request.
|
||||
Returns:
|
||||
video_dataset (VideoDataset):
|
||||
An instantiated representation of the managed
|
||||
`VideoDataset` resource.
|
||||
"""
|
||||
if not display_name:
|
||||
display_name = cls._generate_display_name()
|
||||
utils.validate_display_name(display_name)
|
||||
if labels:
|
||||
utils.validate_labels(labels)
|
||||
|
||||
api_client = cls._instantiate_client(location=location, credentials=credentials)
|
||||
|
||||
metadata_schema_uri = schema.dataset.metadata.video
|
||||
|
||||
datasource = _datasources.create_datasource(
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
import_schema_uri=import_schema_uri,
|
||||
gcs_source=gcs_source,
|
||||
data_item_labels=data_item_labels,
|
||||
)
|
||||
|
||||
return cls._create_and_import(
|
||||
api_client=api_client,
|
||||
parent=initializer.global_config.common_location_path(
|
||||
project=project, location=location
|
||||
),
|
||||
display_name=display_name,
|
||||
metadata_schema_uri=metadata_schema_uri,
|
||||
datasource=datasource,
|
||||
project=project or initializer.global_config.project,
|
||||
location=location or initializer.global_config.location,
|
||||
credentials=credentials or initializer.global_config.credentials,
|
||||
request_metadata=request_metadata,
|
||||
labels=labels,
|
||||
encryption_spec=initializer.global_config.get_encryption_spec(
|
||||
encryption_spec_key_name=encryption_spec_key_name
|
||||
),
|
||||
sync=sync,
|
||||
create_request_timeout=create_request_timeout,
|
||||
)
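A sketch of an object-tracking variant of the docstring example; the bucket path is a placeholder, and the `schema.dataset.ioformat.video.object_tracking` constant is assumed to resolve to the corresponding YAML schema URI (the full `gs://` URI form shown in the class docstring works as well):

```py
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")  # placeholders

ds = aiplatform.VideoDataset.create(
    display_name="traffic-cams",
    gcs_source="gs://my-bucket/video/import.csv",  # placeholder path
    import_schema_uri=aiplatform.schema.dataset.ioformat.video.object_tracking,
)
```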
|
||||
@@ -0,0 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,558 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import textwrap
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from shlex import quote
|
||||
|
||||
from google.cloud.aiplatform.docker_utils import local_util
|
||||
from google.cloud.aiplatform.docker_utils.errors import DockerError
|
||||
from google.cloud.aiplatform.docker_utils.utils import (
|
||||
DEFAULT_HOME,
|
||||
DEFAULT_WORKDIR,
|
||||
Image,
|
||||
Package,
|
||||
)
|
||||
from google.cloud.aiplatform.utils import path_utils
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _generate_copy_command(
|
||||
from_path: str, to_path: str, comment: Optional[str] = None
|
||||
) -> str:
|
||||
"""Returns a Dockerfile entry that copies a file from host to container.
|
||||
|
||||
Args:
|
||||
from_path (str):
|
||||
Required. The path of the source in host.
|
||||
to_path (str):
|
||||
Required. The path to the destination in the container.
|
||||
comment (str):
|
||||
Optional. A comment explaining the copy operation.
|
||||
|
||||
Returns:
|
||||
The generated copy command used in Dockerfile.
|
||||
"""
|
||||
cmd = "COPY {}".format(json.dumps([from_path, to_path]))
|
||||
|
||||
if comment is not None:
|
||||
formatted_comment = "\n# ".join(comment.split("\n"))
|
||||
return textwrap.dedent(
|
||||
"""
|
||||
# {}
|
||||
{}
|
||||
""".format(
|
||||
formatted_comment,
|
||||
cmd,
|
||||
)
|
||||
)
|
||||
|
||||
return cmd
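A self-contained illustration of why the helper serializes the two paths with `json.dumps`: the exec-style JSON list keeps paths containing spaces intact in the generated `COPY` directive. The paths below are hypothetical.

```py
import json

# Same construction as _generate_copy_command above, with hypothetical paths.
print("COPY {}".format(json.dumps(["training dir/task.py", "./task.py"])))
# COPY ["training dir/task.py", "./task.py"]
```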
|
||||
|
||||
|
||||
def _prepare_dependency_entries(
|
||||
setup_path: Optional[str] = None,
|
||||
requirements_path: Optional[str] = None,
|
||||
extra_packages: Optional[List[str]] = None,
|
||||
extra_requirements: Optional[List[str]] = None,
|
||||
extra_dirs: Optional[List[str]] = None,
|
||||
force_reinstall: bool = False,
|
||||
pip_command: str = "pip",
|
||||
) -> str:
|
||||
"""Returns the Dockerfile entries required to install dependencies.
|
||||
|
||||
Args:
|
||||
setup_path (str):
|
||||
Optional. The path that points to a setup.py.
|
||||
requirements_path (str):
|
||||
Optional. The path that points to a requirements.txt file.
|
||||
extra_packages (List[str]):
|
||||
Optional. The list of user custom dependency packages to install.
|
||||
extra_requirements (List[str]):
|
||||
Optional. The list of required dependencies to be installed from remote resource archives.
|
||||
extra_dirs (List[str]):
|
||||
Optional. Additional directories, beyond the work_dir, that are required in the container.
|
||||
force_reinstall (bool):
|
||||
Required. Whether to force-reinstall all packages even if they are already up-to-date.
|
||||
pip_command (str):
|
||||
Required. The pip command used to install packages.
|
||||
|
||||
Returns:
|
||||
The dependency installation command used in Dockerfile.
|
||||
"""
|
||||
ret = ""
|
||||
|
||||
if setup_path is not None:
|
||||
ret += _generate_copy_command(
|
||||
setup_path,
|
||||
"./setup.py",
|
||||
comment="setup.py file specified, thus copy it to the docker container.",
|
||||
) + textwrap.dedent(
|
||||
"""
|
||||
RUN {} install --no-cache-dir {} .
|
||||
""".format(
|
||||
pip_command,
|
||||
"--force-reinstall" if force_reinstall else "",
|
||||
)
|
||||
)
|
||||
|
||||
if requirements_path is not None:
|
||||
ret += textwrap.dedent(
|
||||
"""
|
||||
RUN {} install --no-cache-dir {} -r {}
|
||||
""".format(
|
||||
pip_command,
|
||||
"--force-reinstall" if force_reinstall else "",
|
||||
requirements_path,
|
||||
)
|
||||
)
|
||||
|
||||
if extra_packages is not None:
|
||||
for package in extra_packages:
|
||||
ret += textwrap.dedent(
|
||||
"""
|
||||
RUN {} install --no-cache-dir {} {}
|
||||
""".format(
|
||||
pip_command,
|
||||
"--force-reinstall" if force_reinstall else "",
|
||||
quote(package),
|
||||
)
|
||||
)
|
||||
|
||||
if extra_requirements is not None:
|
||||
for requirement in extra_requirements:
|
||||
ret += textwrap.dedent(
|
||||
"""
|
||||
RUN {} install --no-cache-dir {} {}
|
||||
""".format(
|
||||
pip_command,
|
||||
"--force-reinstall" if force_reinstall else "",
|
||||
quote(requirement),
|
||||
)
|
||||
)
|
||||
|
||||
if extra_dirs is not None:
|
||||
for directory in extra_dirs:
|
||||
ret += "\n{}\n".format(_generate_copy_command(directory, directory))
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def _prepare_entrypoint(package: Package, python_command: str = "python") -> str:
|
||||
"""Generates dockerfile entry to set the container entrypoint.
|
||||
|
||||
Args:
|
||||
package (Package):
|
||||
Required. The main application copied to the container.
|
||||
python_command (str):
|
||||
Required. The python command used for running python code.
|
||||
|
||||
Returns:
|
||||
A string with Dockerfile directives to set ENTRYPOINT or "".
|
||||
"""
|
||||
exec_str = ""
|
||||
# Needs to use json so that quotes print as double quotes, not single quotes.
|
||||
if package.python_module is not None:
|
||||
exec_str = json.dumps([python_command, "-m", package.python_module])
|
||||
elif package.script is not None:
|
||||
_, ext = os.path.splitext(package.script)
|
||||
executable = [python_command] if ext == ".py" else ["/bin/bash"]
|
||||
exec_str = json.dumps(executable + [package.script])
|
||||
|
||||
if not exec_str:
|
||||
return ""
|
||||
return "\nENTRYPOINT {}\n".format(exec_str)
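A self-contained illustration of the double-quote requirement noted in the comment above: Docker's exec-form `ENTRYPOINT` needs a JSON array, which `json.dumps` produces, whereas `str(list)` would emit single quotes. The module name is hypothetical.

```py
import json

# Same construction as _prepare_entrypoint, with a hypothetical module name.
print("ENTRYPOINT {}".format(json.dumps(["python", "-m", "trainer.task"])))
# ENTRYPOINT ["python", "-m", "trainer.task"]
```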
|
||||
|
||||
|
||||
def _copy_source_directory() -> str:
|
||||
"""Returns the Dockerfile entry required to copy the package to the image.
|
||||
|
||||
The Docker build context has been changed to host_workdir. We copy all
|
||||
the files to the working directory of images.
|
||||
|
||||
Returns:
|
||||
The generated package related copy command used in Dockerfile.
|
||||
"""
|
||||
copy_code = _generate_copy_command(
".", # Dockerfile context location has been changed to host_workdir
|
||||
".", # Copy all the files to the working directory of images.
|
||||
comment="Copy the source directory into the docker container.",
|
||||
)
|
||||
|
||||
return "\n{}\n".format(copy_code)
|
||||
|
||||
|
||||
def _prepare_exposed_ports(exposed_ports: Optional[List[int]] = None) -> str:
|
||||
"""Returns the Dockerfile entries required to expose ports in containers.
|
||||
|
||||
Args:
|
||||
exposed_ports (List[int]):
|
||||
Optional. The exposed ports that the container listens on at runtime.
|
||||
|
||||
Returns:
|
||||
The generated EXPOSE commands used in the Dockerfile.
|
||||
"""
|
||||
ret = ""
|
||||
|
||||
if exposed_ports is None:
|
||||
return ret
|
||||
|
||||
for port in exposed_ports:
|
||||
ret += "\nEXPOSE {}\n".format(port)
|
||||
return ret
|
||||
|
||||
|
||||
def _prepare_environment_variables(
|
||||
environment_variables: Optional[Dict[str, str]] = None
|
||||
) -> str:
|
||||
"""Returns the Dockerfile entries required to set environment variables in containers.
|
||||
|
||||
Args:
|
||||
environment_variables (Dict[str, str]):
|
||||
Optional. The environment variables to be set in the container.
|
||||
|
||||
Returns:
|
||||
The generated environment variable commands used in the Dockerfile.
|
||||
"""
|
||||
ret = ""
|
||||
|
||||
if environment_variables is None:
|
||||
return ret
|
||||
|
||||
for key, value in environment_variables.items():
|
||||
ret += f"\nENV {key}={value}\n"
|
||||
return ret
|
||||
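# Illustrative usage sketch: combines the two helpers above. The port number
# and environment variable below are hypothetical examples.
def _example_ports_and_envs() -> str:
    # Produces roughly:
    #   EXPOSE 8080
    #   ENV MODEL_NAME=my-model
    return _prepare_exposed_ports([8080]) + _prepare_environment_variables(
        {"MODEL_NAME": "my-model"}
    )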
|
||||
|
||||
def _get_relative_path_to_workdir(
|
||||
workdir: str,
|
||||
path: Optional[str] = None,
|
||||
value_name: str = "value",
|
||||
) -> Optional[str]:
|
||||
"""Returns the relative path to the workdir.
|
||||
|
||||
Args:
|
||||
workdir (str):
|
||||
Required. The directory that the retrieved path is relative to.
|
||||
path (str):
|
||||
Optional. The path for which to compute the relative path to the workdir.
|
||||
value_name (str):
|
||||
Required. The variable name specified in the exception message.
|
||||
|
||||
Returns:
|
||||
The relative path to the workdir or None if path is None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path does not exist or is not relative to the workdir.
|
||||
"""
|
||||
if path is None:
|
||||
return None
|
||||
|
||||
if not Path(path).is_file():
|
||||
raise ValueError(f'The {value_name} "{path}" must exist.')
|
||||
if not path_utils._is_relative_to(path, workdir):
|
||||
raise ValueError(f'The {value_name} "{path}" must be in "{workdir}".')
|
||||
abs_path = Path(path).expanduser().resolve()
|
||||
abs_workdir = Path(workdir).expanduser().resolve()
|
||||
return Path(abs_path).relative_to(abs_workdir).as_posix()
|
||||
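# Illustrative usage sketch: resolves a file against the workdir the same way
# build_image does below. The "/src" layout is a hypothetical example; the
# file must exist for the call to succeed.
def _example_relative_path() -> str:
    # Returns "deps/requirements.txt" for an existing /src/deps/requirements.txt;
    # raises ValueError if the file is missing or lies outside /src.
    return _get_relative_path_to_workdir(
        "/src",
        path="/src/deps/requirements.txt",
        value_name="requirements_path",
    )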
|
||||
|
||||
def make_dockerfile(
|
||||
base_image: str,
|
||||
main_package: Package,
|
||||
container_workdir: str,
|
||||
container_home: str,
|
||||
requirements_path: Optional[str] = None,
|
||||
setup_path: Optional[str] = None,
|
||||
extra_requirements: Optional[List[str]] = None,
|
||||
extra_packages: Optional[List[str]] = None,
|
||||
extra_dirs: Optional[List[str]] = None,
|
||||
exposed_ports: Optional[List[int]] = None,
|
||||
environment_variables: Optional[Dict[str, str]] = None,
|
||||
pip_command: str = "pip",
|
||||
python_command: str = "python",
|
||||
) -> str:
|
||||
"""Generates a Dockerfile for building an image.
|
||||
|
||||
It builds on a specified base image to create a container that:
|
||||
- installs any dependency specified in a requirements.txt or a setup.py file,
|
||||
and any specified dependency packages existing locally or found from PyPI
|
||||
- copies all source needed by the main module, and potentially injects an
|
||||
entrypoint that, on run, will run that main module
|
||||
|
||||
Args:
|
||||
base_image (str):
|
||||
Required. The ID or name of the base image to initialize the build stage.
|
||||
main_package (Package):
|
||||
Required. The main application to execute.
|
||||
container_workdir (str):
|
||||
Required. The working directory in the container.
|
||||
container_home (str):
|
||||
Required. The $HOME directory in the container.
|
||||
requirements_path (str):
|
||||
Optional. The path to a local requirements.txt file.
|
||||
setup_path (str):
|
||||
Optional. The path to a local setup.py file.
|
||||
extra_requirements (List[str]):
|
||||
Optional. The list of required dependencies to install from PyPI.
|
||||
extra_packages (List[str]):
|
||||
Optional. The list of user custom dependency packages to install.
|
||||
extra_dirs (List[str]):
|
||||
Optional. The directories, other than the work_dir, required to be in the container.
|
||||
exposed_ports (List[int]):
|
||||
Optional. The exposed ports that the container listens on at runtime.
|
||||
environment_variables (Dict[str, str]):
|
||||
Optional. The environment variables to be set in the container.
|
||||
pip_command (str):
|
||||
Required. The pip command used for installing packages.
|
||||
python_command (str):
|
||||
Required. The python command used for running python code.
|
||||
|
||||
Returns:
|
||||
A string that represents the content of a Dockerfile.
|
||||
"""
|
||||
dockerfile = textwrap.dedent(
|
||||
"""
|
||||
FROM {base_image}
|
||||
|
||||
# Keeps Python from generating .pyc files in the container
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
""".format(
|
||||
base_image=base_image,
|
||||
)
|
||||
)
|
||||
|
||||
dockerfile += _prepare_exposed_ports(exposed_ports)
|
||||
|
||||
dockerfile += _prepare_entrypoint(main_package, python_command=python_command)
|
||||
|
||||
dockerfile += textwrap.dedent(
|
||||
"""
|
||||
# The directory is created by root. This sets permissions so that any user can
|
||||
# access the folder.
|
||||
RUN mkdir -m 777 -p {workdir} {container_home}
|
||||
WORKDIR {workdir}
|
||||
ENV HOME={container_home}
|
||||
""".format(
|
||||
workdir=quote(container_workdir),
|
||||
container_home=quote(container_home),
|
||||
)
|
||||
)
|
||||
|
||||
# Installs extra requirements which do not involve user source code.
|
||||
dockerfile += _prepare_dependency_entries(
|
||||
requirements_path=None,
|
||||
setup_path=None,
|
||||
extra_requirements=extra_requirements,
|
||||
extra_packages=None,
|
||||
extra_dirs=None,
|
||||
force_reinstall=True,
|
||||
pip_command=pip_command,
|
||||
)
|
||||
|
||||
dockerfile += _prepare_environment_variables(
|
||||
environment_variables=environment_variables
|
||||
)
|
||||
|
||||
# Copies user code to the image.
|
||||
dockerfile += _copy_source_directory()
|
||||
|
||||
# Installs packages from requirements_path.
|
||||
dockerfile += _prepare_dependency_entries(
|
||||
requirements_path=requirements_path,
|
||||
setup_path=None,
|
||||
extra_requirements=None,
|
||||
extra_packages=None,
|
||||
extra_dirs=None,
|
||||
force_reinstall=True,
|
||||
pip_command=pip_command,
|
||||
)
|
||||
|
||||
# Installs additional packages from user code.
|
||||
dockerfile += _prepare_dependency_entries(
|
||||
requirements_path=None,
|
||||
setup_path=setup_path,
|
||||
extra_requirements=None,
|
||||
extra_packages=extra_packages,
|
||||
extra_dirs=extra_dirs,
|
||||
force_reinstall=True,
|
||||
pip_command=pip_command,
|
||||
)
|
||||
|
||||
return dockerfile
|
||||
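# Illustrative usage sketch: assembles a Dockerfile string for a module-based
# trainer. The base image, module name, and requirements path are hypothetical
# examples.
def _example_dockerfile() -> str:
    return make_dockerfile(
        base_image="python:3.10",
        main_package=Package(
            script=None, package_path=".", python_module="trainer.task"
        ),
        container_workdir=DEFAULT_WORKDIR,
        container_home=DEFAULT_HOME,
        requirements_path="requirements.txt",
        exposed_ports=[8080],
        environment_variables={"MODEL_NAME": "my-model"},
    )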
|
||||
|
||||
def build_image(
|
||||
base_image: str,
|
||||
host_workdir: str,
|
||||
output_image_name: str,
|
||||
python_module: Optional[str] = None,
|
||||
requirements_path: Optional[str] = None,
|
||||
extra_requirements: Optional[List[str]] = None,
|
||||
setup_path: Optional[str] = None,
|
||||
extra_packages: Optional[List[str]] = None,
|
||||
container_workdir: Optional[str] = None,
|
||||
container_home: Optional[str] = None,
|
||||
extra_dirs: Optional[List[str]] = None,
|
||||
exposed_ports: Optional[List[int]] = None,
|
||||
pip_command: str = "pip",
|
||||
python_command: str = "python",
|
||||
no_cache: bool = True,
|
||||
platform: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Image:
|
||||
"""Builds a Docker image.
|
||||
|
||||
Generates a Dockerfile and passes it to `docker build` via stdin.
|
||||
All output from the `docker build` process prints to stdout.
|
||||
|
||||
Args:
|
||||
base_image (str):
|
||||
Required. The ID or name of the base image to initialize the build stage.
|
||||
host_workdir (str):
|
||||
Required. The path indicating where all the required sources are located.
|
||||
output_image_name (str):
|
||||
Required. The name of the built image.
|
||||
python_module (str):
|
||||
Optional. The executable main script in the form of a python module, if applicable.
|
||||
requirements_path (str):
|
||||
Optional. The path to a local file including required dependencies to install from PyPI.
|
||||
extra_requirements (List[str]):
|
||||
Optional. The list of required dependencies to install from PyPI.
|
||||
setup_path (str):
|
||||
Optional. The path to a local setup.py used for installing packages.
|
||||
extra_packages (List[str]):
|
||||
Optional. The list of user custom dependency packages to install.
|
||||
container_workdir (str):
|
||||
Optional. The working directory in the container.
|
||||
container_home (str):
|
||||
Optional. The $HOME directory in the container.
|
||||
extra_dirs (List[str]):
|
||||
Optional. The directories, other than the work_dir, that are required.
|
||||
exposed_ports (List[int]):
|
||||
Optional. The exposed ports that the container listens on at runtime.
|
||||
pip_command (str):
|
||||
Required. The pip command used for installing packages.
|
||||
python_command (str):
|
||||
Required. The python command used for running python scripts.
|
||||
no_cache (bool):
|
||||
Required. Do not use cache when building the image. Using build cache usually
|
||||
reduces the image building time. See
|
||||
https://docs.docker.com/develop/develop-images/dockerfile_best-practices/#leverage-build-cache
|
||||
for more details.
|
||||
platform (str):
|
||||
Optional. The target platform for the Docker image build. See
|
||||
https://docs.docker.com/build/building/multi-platform/#building-multi-platform-images
|
||||
for more details.
|
||||
**kwargs:
|
||||
Other arguments to pass to the underlying method that generates the Dockerfile.
|
||||
|
||||
Returns:
|
||||
An Image instance that contains info about the built image.
|
||||
|
||||
Raises:
|
||||
DockerError: If an error occurs while executing `docker build`.
|
||||
ValueError: If the needed code is not relative to the host workdir.
|
||||
"""
|
||||
|
||||
tag_options = ["-t", output_image_name]
|
||||
cache_args = ["--no-cache"] if no_cache else []
|
||||
platform_args = ["--platform", platform] if platform is not None else []
|
||||
|
||||
command = (
|
||||
["docker", "build"]
|
||||
+ cache_args
|
||||
+ platform_args
|
||||
+ tag_options
|
||||
+ ["--rm", "-f-", host_workdir]
|
||||
)
|
||||
|
||||
requirements_relative_path = _get_relative_path_to_workdir(
|
||||
host_workdir,
|
||||
path=requirements_path,
|
||||
value_name="requirements_path",
|
||||
)
|
||||
|
||||
setup_relative_path = _get_relative_path_to_workdir(
|
||||
host_workdir,
|
||||
path=setup_path,
|
||||
value_name="setup_path",
|
||||
)
|
||||
|
||||
extra_packages_relative_paths = (
|
||||
None
|
||||
if extra_packages is None
|
||||
else [
|
||||
_get_relative_path_to_workdir(
|
||||
host_workdir, path=extra_package, value_name="extra_packages"
|
||||
)
|
||||
for extra_package in extra_packages
|
||||
if extra_package is not None
|
||||
]
|
||||
)
|
||||
|
||||
home_dir = container_home or DEFAULT_HOME
|
||||
work_dir = container_workdir or DEFAULT_WORKDIR
|
||||
|
||||
# The package will be used in Docker, thus norm it to POSIX path format.
|
||||
main_package = Package(
|
||||
script=None,
|
||||
package_path=host_workdir,
|
||||
python_module=python_module,
|
||||
)
|
||||
|
||||
dockerfile = make_dockerfile(
|
||||
base_image,
|
||||
main_package,
|
||||
work_dir,
|
||||
home_dir,
|
||||
requirements_path=requirements_relative_path,
|
||||
setup_path=setup_relative_path,
|
||||
extra_requirements=extra_requirements,
|
||||
extra_packages=extra_packages_relative_paths,
|
||||
extra_dirs=extra_dirs,
|
||||
exposed_ports=exposed_ports,
|
||||
pip_command=pip_command,
|
||||
python_command=python_command,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
joined_command = " ".join(command)
|
||||
_logger.info("Running command: {}".format(joined_command))
|
||||
|
||||
return_code = local_util.execute_command(
|
||||
command,
|
||||
input_str=dockerfile,
|
||||
)
|
||||
if return_code == 0:
|
||||
return Image(output_image_name, home_dir, work_dir)
|
||||
else:
|
||||
error_msg = textwrap.dedent(
|
||||
"""
|
||||
Docker failed with error code {code}.
|
||||
Command: {cmd}
|
||||
""".format(
|
||||
code=return_code, cmd=joined_command
|
||||
)
|
||||
)
|
||||
raise DockerError(error_msg, command, return_code)
|
||||
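# Illustrative usage sketch: builds a local image from a source directory.
# The paths, image name, and module name are hypothetical examples; the
# requirements file must live inside host_workdir.
def _example_build_image() -> Image:
    return build_image(
        base_image="python:3.10",
        host_workdir="/path/to/src",
        output_image_name="my-training-image:latest",
        python_module="trainer.task",
        requirements_path="/path/to/src/requirements.txt",
        exposed_ports=[8080],
    )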
@@ -0,0 +1,60 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import textwrap
|
||||
from typing import List, NoReturn
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
"""A base exception for all user recoverable errors."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize an Error."""
|
||||
self.exit_code = kwargs.get("exit_code", 1)
|
||||
|
||||
|
||||
class DockerError(Error):
|
||||
"""Exception that passes info on a failed Docker command."""
|
||||
|
||||
def __init__(self, message, cmd, exit_code):
|
||||
super(DockerError, self).__init__(message)
|
||||
self.message = message
|
||||
self.cmd = cmd
|
||||
self.exit_code = exit_code
|
||||
|
||||
|
||||
def raise_docker_error_with_command(command: List[str], return_code: int) -> NoReturn:
|
||||
"""Raises DockerError with the given command and return code.
|
||||
|
||||
Args:
|
||||
command (List(str)):
|
||||
Required. The docker command that fails.
|
||||
return_code (int):
|
||||
Required. The return code from the command.
|
||||
|
||||
Raises:
|
||||
DockerError: with an error message populated from the given command and return code.
|
||||
"""
|
||||
error_msg = textwrap.dedent(
|
||||
"""
|
||||
Docker failed with error code {code}.
|
||||
Command: {cmd}
|
||||
""".format(
|
||||
code=return_code, cmd=" ".join(command)
|
||||
)
|
||||
)
|
||||
raise DockerError(error_msg, command, return_code)
|
||||
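# Illustrative usage sketch: how a caller might surface a failed docker
# command via the helper above. The command list is a hypothetical example.
def _example_handle_failure(return_code: int) -> None:
    cmd = ["docker", "build", "-t", "my-image", "."]
    if return_code != 0:
        # Raises DockerError with a message built from the command and code.
        raise_docker_error_with_command(cmd, return_code)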
@@ -0,0 +1,62 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import io
|
||||
import logging
|
||||
import subprocess
|
||||
from typing import List, Optional
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def execute_command(
|
||||
cmd: List[str],
|
||||
input_str: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Executes commands in subprocess.
|
||||
|
||||
Executes the supplied command with the supplied standard input string, streams
|
||||
the output to stdout, and returns the process's return code.
|
||||
|
||||
Args:
|
||||
cmd (List[str]):
|
||||
Required. The strings to send in as the command.
|
||||
input_str (str):
|
||||
Optional. If supplied, it will be passed as stdin to the supplied command.
|
||||
If None, stdin will get closed immediately.
|
||||
|
||||
Returns:
|
||||
Return code of the process.
|
||||
"""
|
||||
with subprocess.Popen(
|
||||
cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
universal_newlines=False,
|
||||
bufsize=1,
|
||||
) as p:
|
||||
if input_str:
|
||||
p.stdin.write(input_str.encode("utf-8"))
|
||||
p.stdin.close()
|
||||
|
||||
out = io.TextIOWrapper(p.stdout, newline="", encoding="utf-8", errors="replace")
|
||||
|
||||
for line in out:
|
||||
_logger.info(line)
|
||||
|
||||
return p.returncode
|
||||
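# Illustrative usage sketch: feeding a generated Dockerfile to `docker build`
# over stdin, mirroring how build_image uses this helper. The Dockerfile
# content below is a hypothetical example.
def _example_execute_command() -> int:
    dockerfile = "FROM python:3.10\nRUN pip install --no-cache-dir requests\n"
    return execute_command(
        ["docker", "build", "--rm", "-f-", "."],
        input_str=dockerfile,
    )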
@@ -0,0 +1,279 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
try:
|
||||
import docker
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Docker is not installed and is required to run containers. "
|
||||
'Please install the SDK using `pip install "google-cloud-aiplatform[prediction]>=1.16.0"`.'
|
||||
)
|
||||
|
||||
from google.cloud.aiplatform.constants import prediction
|
||||
from google.cloud.aiplatform.docker_utils.utils import DEFAULT_MOUNTED_MODEL_DIRECTORY
|
||||
from google.cloud.aiplatform.utils import prediction_utils
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_CONTAINER_CRED_KEY_PATH = "/tmp/keys/cred_key.json"
|
||||
_ADC_ENVIRONMENT_VARIABLE = "GOOGLE_APPLICATION_CREDENTIALS"
|
||||
|
||||
CONTAINER_RUNNING_STATUS = "running"
|
||||
|
||||
|
||||
def _get_adc_environment_variable() -> Optional[str]:
|
||||
"""Gets the value of the ADC environment variable.
|
||||
|
||||
Returns:
|
||||
The value of the environment variable or None if unset.
|
||||
"""
|
||||
return os.environ.get(_ADC_ENVIRONMENT_VARIABLE)
|
||||
|
||||
|
||||
def _replace_env_var_reference(
|
||||
target: str,
|
||||
env_vars: Dict[str, str],
|
||||
) -> str:
|
||||
"""Replaces the environment variable reference in the given string.
|
||||
|
||||
Variable references $(VAR_NAME) are expanded using the container's environment.
|
||||
If a variable cannot be resolved, the reference in the input string will be unchanged.
|
||||
The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped
|
||||
references will never be expanded, regardless of whether the variable exists or not.
|
||||
More info:
|
||||
https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#running-a-command-in-a-shell
|
||||
|
||||
Args:
|
||||
target (str):
|
||||
Required. The string in which environment variable references will be expanded.
|
||||
env_vars (Dict[str, str]):
|
||||
Required. The environment variables used for reference.
|
||||
|
||||
Returns:
|
||||
The updated string.
|
||||
"""
|
||||
# Replace env var references with env vars.
|
||||
for key, value in env_vars.items():
|
||||
target = re.sub(rf"(?<!\$)\$\({key}\)", str(value), target)
|
||||
# Replace $$ with $.
|
||||
target = re.sub(r"\$\$", "$", target)
|
||||
return target
|
||||
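# Illustrative usage sketch for the expansion and escaping rules above. The
# variable name and value are hypothetical examples.
def _example_env_var_reference() -> str:
    env_vars = {"AIP_HTTP_PORT": "8080"}
    # Returns "serve --port 8080 --literal $(AIP_HTTP_PORT)": the first
    # reference is expanded, the escaped $$(...) one is kept verbatim.
    return _replace_env_var_reference(
        "serve --port $(AIP_HTTP_PORT) --literal $$(AIP_HTTP_PORT)", env_vars
    )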
|
||||
|
||||
def run_prediction_container(
|
||||
serving_container_image_uri: str,
|
||||
artifact_uri: Optional[str] = None,
|
||||
serving_container_predict_route: Optional[str] = None,
|
||||
serving_container_health_route: Optional[str] = None,
|
||||
serving_container_command: Optional[Sequence[str]] = None,
|
||||
serving_container_args: Optional[Sequence[str]] = None,
|
||||
serving_container_environment_variables: Optional[Dict[str, str]] = None,
|
||||
serving_container_ports: Optional[Sequence[int]] = None,
|
||||
credential_path: Optional[str] = None,
|
||||
host_port: Optional[int] = None,
|
||||
gpu_count: Optional[int] = None,
|
||||
gpu_device_ids: Optional[List[str]] = None,
|
||||
gpu_capabilities: Optional[List[List[str]]] = None,
|
||||
) -> docker.models.containers.Container:
|
||||
"""Runs a prediction container locally.
|
||||
|
||||
Args:
|
||||
serving_container_image_uri (str):
|
||||
Required. The URI of the Model serving container.
|
||||
artifact_uri (str):
|
||||
Optional. The path to the directory containing the Model artifact and any of its
|
||||
supporting files. The path is either a GCS uri or the path to a local directory.
|
||||
If this parameter is set to a GCS uri:
|
||||
(1) `credential_path` must be specified for local prediction.
|
||||
(2) The GCS uri will be passed directly to `Predictor.load`.
|
||||
If this parameter is a local directory:
|
||||
(1) The directory will be mounted to a default temporary model path.
|
||||
(2) The mounted path will be passed to `Predictor.load`.
|
||||
serving_container_predict_route (str):
|
||||
Optional. An HTTP path to send prediction requests to the container, and
|
||||
which must be supported by it. If not specified a default HTTP path will
|
||||
be used by Vertex AI.
|
||||
serving_container_health_route (str):
|
||||
Optional. An HTTP path to send health check requests to the container, and which
|
||||
must be supported by it. If not specified a standard HTTP path will be
|
||||
used by Vertex AI.
|
||||
serving_container_command (Sequence[str]):
|
||||
Optional. The command with which the container is run. Not executed within a
|
||||
shell. The Docker image's ENTRYPOINT is used if this is not provided.
|
||||
Variable references $(VAR_NAME) are expanded using the container's
|
||||
environment. If a variable cannot be resolved, the reference in the
|
||||
input string will be unchanged. The $(VAR_NAME) syntax can be escaped
|
||||
with a double $$, ie: $$(VAR_NAME). Escaped references will never be
|
||||
expanded, regardless of whether the variable exists or not.
|
||||
serving_container_args: (Sequence[str]):
|
||||
Optional. The arguments to the command. The Docker image's CMD is used if this is
|
||||
not provided. Variable references $(VAR_NAME) are expanded using the
|
||||
container's environment. If a variable cannot be resolved, the reference
|
||||
in the input string will be unchanged. The $(VAR_NAME) syntax can be
|
||||
escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
|
||||
never be expanded, regardless of whether the variable exists or not.
|
||||
serving_container_environment_variables (Dict[str, str]):
|
||||
Optional. The environment variables that are to be present in the container.
|
||||
Should be a dictionary where keys are environment variable names
|
||||
and values are environment variable values for those names.
|
||||
serving_container_ports (Sequence[int]):
|
||||
Optional. Declaration of ports that are exposed by the container. This field is
|
||||
primarily informational, it gives Vertex AI information about the
|
||||
network connections the container uses. Whether or not a port is listed here has
|
||||
no impact on whether the port is actually exposed; any port listening on
|
||||
the default "0.0.0.0" address inside a container will be accessible from
|
||||
the network.
|
||||
credential_path (str):
|
||||
Optional. The path to the credential key that will be mounted to the container.
|
||||
If it's unset, the environment variable, GOOGLE_APPLICATION_CREDENTIALS, will
|
||||
be used if set.
|
||||
host_port (int):
|
||||
Optional. The port on the host that the port, AIP_HTTP_PORT, inside the container
|
||||
will be exposed as. If it's unset, a random host port will be assigned.
|
||||
gpu_count (int):
|
||||
Optional. Number of devices to request. Set to -1 to request all available devices.
|
||||
To use GPU, set either `gpu_count` or `gpu_device_ids`.
|
||||
gpu_device_ids (List[str]):
|
||||
Optional. This parameter corresponds to `NVIDIA_VISIBLE_DEVICES` in the NVIDIA
|
||||
Runtime.
|
||||
To use GPU, set either `gpu_count` or `gpu_device_ids`.
|
||||
gpu_capabilities (List[List[str]]):
|
||||
Optional. This parameter corresponds to `NVIDIA_DRIVER_CAPABILITIES` in the NVIDIA
|
||||
Runtime. This must be set to use GPU. The outer list acts like an OR, and each
|
||||
sub-list acts like an AND. The driver will try to satisfy one of the sub-lists.
|
||||
Available capabilities for the NVIDIA driver can be found in
|
||||
https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#driver-capabilities.
|
||||
|
||||
Returns:
|
||||
The container object running in the background.
|
||||
|
||||
Raises:
|
||||
ValueError: If artifact_uri is a path to a local directory that does not exist,
|
||||
or if credential_path or the file pointed to by the environment variable
|
||||
GOOGLE_APPLICATION_CREDENTIALS does not exist.
|
||||
docker.errors.ImageNotFound: If the specified image does not exist.
|
||||
docker.errors.APIError: If the server returns an error.
|
||||
"""
|
||||
client = docker.from_env()
|
||||
|
||||
envs = {}
|
||||
if serving_container_environment_variables:
|
||||
for key, value in serving_container_environment_variables.items():
|
||||
envs[key] = _replace_env_var_reference(value, envs)
|
||||
|
||||
port = prediction_utils.get_prediction_aip_http_port(serving_container_ports)
|
||||
envs[prediction.AIP_HTTP_PORT] = port
|
||||
|
||||
envs[prediction.AIP_HEALTH_ROUTE] = serving_container_health_route
|
||||
envs[prediction.AIP_PREDICT_ROUTE] = serving_container_predict_route
|
||||
|
||||
volumes = []
|
||||
envs[prediction.AIP_STORAGE_URI] = artifact_uri or ""
|
||||
if artifact_uri and not artifact_uri.startswith(prediction_utils.GCS_URI_PREFIX):
|
||||
artifact_uri_on_host = Path(artifact_uri).expanduser().resolve()
|
||||
if not artifact_uri_on_host.exists():
|
||||
raise ValueError(
|
||||
"artifact_uri should be specified as either a GCS uri which starts with "
|
||||
f"`{prediction_utils.GCS_URI_PREFIX}` or a path to a local directory. "
|
||||
f'However, "{artifact_uri_on_host}" does not exist.'
|
||||
)
|
||||
|
||||
for mounted_path in artifact_uri_on_host.rglob("*"):
|
||||
relative_mounted_path = mounted_path.relative_to(artifact_uri_on_host)
|
||||
volumes += [
|
||||
f"{mounted_path}:{os.path.join(DEFAULT_MOUNTED_MODEL_DIRECTORY, relative_mounted_path)}"
|
||||
]
|
||||
envs[prediction.AIP_STORAGE_URI] = DEFAULT_MOUNTED_MODEL_DIRECTORY
|
||||
|
||||
credential_from_adc_env = credential_path is None
|
||||
credential_path = credential_path or _get_adc_environment_variable()
|
||||
if credential_path:
|
||||
credential_path_on_host = Path(credential_path).expanduser().resolve()
|
||||
if not credential_path_on_host.exists() and credential_from_adc_env:
|
||||
raise ValueError(
|
||||
f"The file from the environment variable {_ADC_ENVIRONMENT_VARIABLE} does "
|
||||
f'not exist: "{credential_path}".'
|
||||
)
|
||||
elif not credential_path_on_host.exists() and not credential_from_adc_env:
|
||||
raise ValueError(f'credential_path does not exist: "{credential_path}".')
|
||||
credential_mount_path = _DEFAULT_CONTAINER_CRED_KEY_PATH
|
||||
volumes = volumes + [f"{credential_path_on_host}:{credential_mount_path}"]
|
||||
envs[_ADC_ENVIRONMENT_VARIABLE] = credential_mount_path
|
||||
|
||||
entrypoint = [
|
||||
_replace_env_var_reference(i, envs) for i in serving_container_command or []
|
||||
]
|
||||
command = [
|
||||
_replace_env_var_reference(i, envs) for i in serving_container_args or []
|
||||
]
|
||||
|
||||
device_requests = None
|
||||
if gpu_count or gpu_device_ids or gpu_capabilities:
|
||||
device_requests = [
|
||||
docker.types.DeviceRequest(
|
||||
count=gpu_count,
|
||||
device_ids=gpu_device_ids,
|
||||
capabilities=gpu_capabilities,
|
||||
)
|
||||
]
|
||||
|
||||
container = client.containers.run(
|
||||
serving_container_image_uri,
|
||||
command=command if len(command) > 0 else None,
|
||||
entrypoint=entrypoint if len(entrypoint) > 0 else None,
|
||||
ports={port: host_port},
|
||||
environment=envs,
|
||||
volumes=volumes,
|
||||
device_requests=device_requests,
|
||||
detach=True,
|
||||
)
|
||||
|
||||
return container
|
||||
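# Illustrative usage sketch: runs a local prediction container with a locally
# mounted model directory. The image URI, routes, paths, and port are
# hypothetical examples; the artifact directory must exist.
def _example_run_prediction_container() -> docker.models.containers.Container:
    return run_prediction_container(
        serving_container_image_uri="us-docker.pkg.dev/my-project/my-repo/my-cpr-image",
        artifact_uri="./model_artifacts",
        serving_container_predict_route="/predict",
        serving_container_health_route="/health",
        host_port=8080,
    )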
|
||||
|
||||
def print_container_logs(
|
||||
container: docker.models.containers.Container,
|
||||
start_index: Optional[int] = None,
|
||||
message: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Prints container logs.
|
||||
|
||||
Args:
|
||||
container (docker.models.containers.Container):
|
||||
Required. The container object whose logs will be printed.
|
||||
start_index (int):
|
||||
Optional. The index of the log entry at which to start printing.
|
||||
message (str):
|
||||
Optional. The message to be printed before printing the logs.
|
||||
|
||||
Returns:
|
||||
The total number of log entries.
|
||||
"""
|
||||
if message is not None:
|
||||
_logger.info(message)
|
||||
|
||||
logs = container.logs().decode("utf-8").strip().split("\n")
|
||||
start_index = 0 if start_index is None else start_index
|
||||
for i in range(start_index, len(logs)):
|
||||
_logger.info(logs[i])
|
||||
return len(logs)
|
||||
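# Illustrative usage sketch: tails a container's logs incrementally by feeding
# the returned length back in as start_index, so earlier entries are not
# reprinted. The polling loop below is a hypothetical example.
def _example_tail_logs(container: docker.models.containers.Container) -> None:
    import time

    printed = print_container_logs(container, message="Container logs:")
    for _ in range(3):
        time.sleep(1)
        printed = print_container_logs(container, start_index=printed)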
@@ -0,0 +1,51 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import collections
|
||||
|
||||
try:
|
||||
import docker
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Docker is not installed and is required to run containers. "
|
||||
'Please install the SDK using `pip install "google-cloud-aiplatform[prediction]>=1.16.0"`.'
|
||||
)
|
||||
|
||||
|
||||
Package = collections.namedtuple("Package", ["script", "package_path", "python_module"])
|
||||
Image = collections.namedtuple("Image", ["name", "default_home", "default_workdir"])
|
||||
DEFAULT_HOME = "/home"
|
||||
DEFAULT_WORKDIR = "/usr/app"
|
||||
DEFAULT_MOUNTED_MODEL_DIRECTORY = "/tmp_cpr_local_model"
|
||||
|
||||
|
||||
def check_image_exists_locally(image_name: str) -> bool:
|
||||
"""Checks if an image exists locally.
|
||||
|
||||
Args:
|
||||
image_name (str):
|
||||
Required. The name of the image.
|
||||
|
||||
Returns:
|
||||
Whether the image exists locally.
|
||||
"""
|
||||
client = docker.from_env()
|
||||
try:
|
||||
_ = client.images.get(image_name)
|
||||
return True
|
||||
except (docker.errors.ImageNotFound, docker.errors.APIError):
|
||||
return False
|
||||
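# Illustrative usage sketch: a guard built on the check above. The image name
# is a hypothetical example; the rebuild step is only indicated by a comment.
def _example_ensure_image(image_name: str = "my-image:latest") -> bool:
    if not check_image_exists_locally(image_name):
        # Build or pull the image here before attempting to run it.
        return False
    return True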
@@ -0,0 +1,65 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from google.cloud.aiplatform.compat.types import (
|
||||
explanation as explanation_compat,
|
||||
explanation_metadata as explanation_metadata_compat,
|
||||
)
|
||||
|
||||
ExplanationMetadata = explanation_metadata_compat.ExplanationMetadata
|
||||
|
||||
# ExplanationMetadata subclasses
|
||||
InputMetadata = ExplanationMetadata.InputMetadata
|
||||
OutputMetadata = ExplanationMetadata.OutputMetadata
|
||||
|
||||
# InputMetadata subclasses
|
||||
Encoding = InputMetadata.Encoding
|
||||
FeatureValueDomain = InputMetadata.FeatureValueDomain
|
||||
Visualization = InputMetadata.Visualization
|
||||
|
||||
|
||||
ExplanationParameters = explanation_compat.ExplanationParameters
|
||||
FeatureNoiseSigma = explanation_compat.FeatureNoiseSigma
|
||||
|
||||
ExplanationSpec = explanation_compat.ExplanationSpec
|
||||
|
||||
# Classes used by ExplanationParameters
|
||||
IntegratedGradientsAttribution = explanation_compat.IntegratedGradientsAttribution
|
||||
SampledShapleyAttribution = explanation_compat.SampledShapleyAttribution
|
||||
SmoothGradConfig = explanation_compat.SmoothGradConfig
|
||||
XraiAttribution = explanation_compat.XraiAttribution
|
||||
Presets = explanation_compat.Presets
|
||||
Examples = explanation_compat.Examples
|
||||
|
||||
|
||||
__all__ = (
|
||||
"Encoding",
|
||||
"ExplanationSpec",
|
||||
"ExplanationMetadata",
|
||||
"ExplanationParameters",
|
||||
"FeatureNoiseSigma",
|
||||
"FeatureValueDomain",
|
||||
"InputMetadata",
|
||||
"IntegratedGradientsAttribution",
|
||||
"OutputMetadata",
|
||||
"SampledShapleyAttribution",
|
||||
"SmoothGradConfig",
|
||||
"Visualization",
|
||||
"XraiAttribution",
|
||||
"Presets",
|
||||
"Examples",
|
||||
)
|
||||
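# Illustrative sketch: a minimal ExplanationSpec assembled from the aliases
# above. The tensor names, feature keys, and path count are hypothetical
# examples.
def _example_explanation_spec() -> ExplanationSpec:
    return ExplanationSpec(
        parameters=ExplanationParameters(
            sampled_shapley_attribution=SampledShapleyAttribution(path_count=10)
        ),
        metadata=ExplanationMetadata(
            inputs={"features": InputMetadata(input_tensor_name="dense_input")},
            outputs={
                "prediction": OutputMetadata(output_tensor_name="dense_output")
            },
        ),
    )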
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,482 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from google.cloud import aiplatform
|
||||
from typing import Dict, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
try:
|
||||
from lit_nlp.api import dataset as lit_dataset
|
||||
from lit_nlp.api import dtypes as lit_dtypes
|
||||
from lit_nlp.api import model as lit_model
|
||||
from lit_nlp.api import types as lit_types
|
||||
from lit_nlp import notebook
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"LIT is not installed and is required to get Dataset as the return format. "
|
||||
'Please install the SDK using "pip install google-cloud-aiplatform[lit]"'
|
||||
)
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Tensorflow is not installed and is required to load saved model. "
|
||||
'Please install the SDK using "pip install google-cloud-aiplatform[lit]"'
|
||||
)
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Pandas is not installed and is required to read the dataset. "
|
||||
'Please install Pandas using "pip install google-cloud-aiplatform[lit]"'
|
||||
)
|
||||
|
||||
|
||||
class _VertexLitDataset(lit_dataset.Dataset):
|
||||
"""LIT dataset class for the Vertex LIT integration.
|
||||
|
||||
This is used in the create_lit_dataset function.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: pd.DataFrame,
|
||||
column_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
|
||||
):
|
||||
"""Construct a VertexLitDataset.
|
||||
Args:
|
||||
dataset:
|
||||
Required. A Pandas DataFrame that includes feature column names and data.
|
||||
column_types:
|
||||
Required. An OrderedDict of string names matching the columns of the dataset
|
||||
as the key, and the associated LitType of the column.
|
||||
"""
|
||||
self._examples = dataset.to_dict(orient="records")
|
||||
self._column_types = column_types
|
||||
|
||||
def spec(self):
|
||||
"""Return a spec describing dataset elements."""
|
||||
return dict(self._column_types)
|
||||
|
||||
|
||||
class _EndpointLitModel(lit_model.Model):
|
||||
"""LIT model class for the Vertex LIT integration with a model deployed to an endpoint.
|
||||
|
||||
This is used in the create_lit_model function.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: Union[str, aiplatform.Endpoint],
|
||||
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
|
||||
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
|
||||
model_id: Optional[str] = None,
|
||||
):
|
||||
"""Construct a VertexLitModel.
|
||||
Args:
|
||||
endpoint:
|
||||
Required. The name of the Endpoint resource or an Endpoint instance. Name format:
|
||||
``projects/{project}/locations/{location}/endpoints/{endpoint}``
|
||||
input_types:
|
||||
Required. An OrderedDict of string names matching the features of the model
|
||||
as the key, and the associated LitType of the feature.
|
||||
output_types:
|
||||
Required. An OrderedDict of string names matching the labels of the model
|
||||
as the key, and the associated LitType of the label.
|
||||
model_id:
|
||||
Optional. A string of the specific model in the endpoint to create the
|
||||
LIT model from. If this is not set, any usable model in the endpoint is
|
||||
used to create the LIT model.
|
||||
Raises:
|
||||
ValueError if the model_id was not found in the endpoint.
|
||||
"""
|
||||
if isinstance(endpoint, str):
|
||||
self._endpoint = aiplatform.Endpoint(endpoint)
|
||||
else:
|
||||
self._endpoint = endpoint
|
||||
self._model_id = model_id
|
||||
self._input_types = input_types
|
||||
self._output_types = output_types
|
||||
# Check if the model with the model ID has explanation enabled
|
||||
if model_id:
|
||||
deployed_model = next(
|
||||
filter(
|
||||
lambda model: model.id == model_id, self._endpoint.list_models()
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not deployed_model:
|
||||
raise ValueError(
|
||||
"A model with id {model_id} was not found in the endpoint {endpoint}.".format(
|
||||
model_id=model_id, endpoint=endpoint
|
||||
)
|
||||
)
|
||||
self._explanation_enabled = bool(deployed_model.explanation_spec)
|
||||
# Check if all models in the endpoint have explanation enabled
|
||||
else:
|
||||
self._explanation_enabled = all(
|
||||
model.explanation_spec for model in self._endpoint.list_models()
|
||||
)
|
||||
|
||||
def predict_minibatch(
|
||||
self, inputs: List[lit_types.JsonDict]
|
||||
) -> List[lit_types.JsonDict]:
|
||||
"""Retun predictions based on a batch of inputs.
|
||||
Args:
|
||||
inputs: Required. A list of instances to predict on based on the input spec.
|
||||
Returns:
|
||||
A list of predictions based on the output spec.
|
||||
"""
|
||||
instances = []
|
||||
for input in inputs:
|
||||
instance = [input[feature] for feature in self._input_types]
|
||||
instances.append(instance)
|
||||
if self._explanation_enabled:
|
||||
prediction_object = self._endpoint.explain(instances)
|
||||
else:
|
||||
prediction_object = self._endpoint.predict(instances)
|
||||
outputs = []
|
||||
for prediction in prediction_object.predictions:
|
||||
if isinstance(prediction, Mapping):
|
||||
outputs.append({key: prediction[key] for key in self._output_types})
|
||||
else:
|
||||
outputs.append(
|
||||
{key: prediction[i] for i, key in enumerate(self._output_types)}
|
||||
)
|
||||
if self._explanation_enabled:
|
||||
for i, explanation in enumerate(prediction_object.explanations):
|
||||
attributions = explanation.attributions
|
||||
outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
|
||||
attributions
|
||||
)
|
||||
return outputs
|
||||
|
||||
def input_spec(self) -> lit_types.Spec:
|
||||
"""Return a spec describing model inputs."""
|
||||
return dict(self._input_types)
|
||||
|
||||
def output_spec(self) -> lit_types.Spec:
|
||||
"""Return a spec describing model outputs."""
|
||||
output_spec_dict = dict(self._output_types)
|
||||
if self._explanation_enabled:
|
||||
output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
|
||||
signed=True
|
||||
)
|
||||
return output_spec_dict
|
||||
|
||||
|
||||
class _TensorFlowLitModel(lit_model.Model):
|
||||
"""LIT model class for the Vertex LIT integration with a TensorFlow saved model.
|
||||
|
||||
This is used in the create_lit_model function.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
|
||||
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
|
||||
attribution_method: str = "sampled_shapley",
|
||||
):
|
||||
"""Construct a VertexLitModel.
|
||||
Args:
|
||||
model:
|
||||
Required. A string reference to a local TensorFlow saved model directory.
|
||||
The model must have exactly one input and one output tensor.
|
||||
input_types:
|
||||
Required. An OrderedDict of string names matching the features of the model
|
||||
as the key, and the associated LitType of the feature.
|
||||
output_types:
|
||||
Required. An OrderedDict of string names matching the labels of the model
|
||||
as the key, and the associated LitType of the label.
|
||||
attribution_method:
|
||||
Optional. A string to choose what attribution configuration to
|
||||
set up the explainer with. Valid options are 'sampled_shapley'
|
||||
or 'integrated_gradients'.
|
||||
"""
|
||||
self._load_model(model)
|
||||
self._input_types = input_types
|
||||
self._output_types = output_types
|
||||
self._input_tensor_name = next(iter(self._kwargs_signature))
|
||||
self._attribution_explainer = None
|
||||
if os.environ.get("LIT_PROXY_URL"):
|
||||
self._set_up_attribution_explainer(model, attribution_method)
|
||||
|
||||
@property
|
||||
def attribution_explainer(
|
||||
self,
|
||||
) -> Optional["AttributionExplainer"]: # noqa: F821
|
||||
"""Gets the attribution explainer property if set."""
|
||||
return self._attribution_explainer
|
||||
|
||||
def predict_minibatch(
|
||||
self, inputs: List[lit_types.JsonDict]
|
||||
) -> List[lit_types.JsonDict]:
|
||||
"""Retun predictions based on a batch of inputs.
|
||||
Args:
|
||||
inputs: Required. A list of instances to predict on based on the input spec.
|
||||
Returns:
|
||||
A list of predictions based on the output spec.
|
||||
"""
|
||||
instances = []
|
||||
for input in inputs:
|
||||
instance = [input[feature] for feature in self._input_types]
|
||||
instances.append(instance)
|
||||
prediction_input_dict = {
|
||||
self._input_tensor_name: tf.convert_to_tensor(instances)
|
||||
}
|
||||
prediction_dict = self._loaded_model.signatures[
|
||||
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||||
](**prediction_input_dict)
|
||||
predictions = prediction_dict[next(iter(self._output_signature))].numpy()
|
||||
outputs = []
|
||||
for prediction in predictions:
|
||||
outputs.append(
|
||||
{
|
||||
label: value
|
||||
for label, value in zip(self._output_types.keys(), prediction)
|
||||
}
|
||||
)
|
||||
# Get feature attributions
|
||||
if self.attribution_explainer:
|
||||
attributions = self.attribution_explainer.explain(
|
||||
[{self._input_tensor_name: i} for i in instances]
|
||||
)
|
||||
for i, attribution in enumerate(attributions):
|
||||
outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
|
||||
attribution.feature_importance()
|
||||
)
|
||||
return outputs
|
||||
|
||||
def input_spec(self) -> lit_types.Spec:
|
||||
"""Return a spec describing model inputs."""
|
||||
return dict(self._input_types)
|
||||
|
||||
def output_spec(self) -> lit_types.Spec:
|
||||
"""Return a spec describing model outputs."""
|
||||
output_spec_dict = dict(self._output_types)
|
||||
if self.attribution_explainer:
|
||||
output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
|
||||
signed=True
|
||||
)
|
||||
return output_spec_dict
|
||||
|
||||
def _load_model(self, model: str):
|
||||
"""Loads a TensorFlow saved model and populates the input and output signature attributes of the class.
|
||||
Args:
|
||||
model: Required. A string reference to a TensorFlow saved model directory.
|
||||
Raises:
|
||||
ValueError if the model has more than one input tensor or more than one output tensor.
|
||||
"""
|
||||
self._loaded_model = tf.saved_model.load(model)
|
||||
serving_default = self._loaded_model.signatures[
|
||||
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||||
]
|
||||
_, self._kwargs_signature = serving_default.structured_input_signature
|
||||
self._output_signature = serving_default.structured_outputs
|
||||
|
||||
if len(self._kwargs_signature) != 1:
|
||||
raise ValueError("Please use a model with only one input tensor.")
|
||||
|
||||
if len(self._output_signature) != 1:
|
||||
raise ValueError("Please use a model with only one output tensor.")
|
||||
|
||||
def _set_up_attribution_explainer(
|
||||
self, model: str, attribution_method: str = "integrated_gradients"
|
||||
):
|
||||
"""Populates the attribution explainer attribute of the class.
|
||||
Args:
|
||||
model: Required. A string reference to a TensorFlow saved model directory.
|
||||
attribution_method:
|
||||
Optional. A string to choose what attribution configuration to
|
||||
set up the explainer with. Valid options are 'sampled_shapley'
|
||||
or 'integrated_gradients'.
|
||||
"""
|
||||
try:
|
||||
import explainable_ai_sdk
|
||||
from explainable_ai_sdk.metadata.tf.v2 import SavedModelMetadataBuilder
|
||||
except ImportError:
|
||||
logging.info(
|
||||
"Skipping explanations because the Explainable AI SDK is not installed."
|
||||
'Please install the SDK using "pip install explainable-ai-sdk"'
|
||||
)
|
||||
return
|
||||
|
||||
builder = SavedModelMetadataBuilder(model)
|
||||
builder.get_metadata()
|
||||
builder.set_numeric_metadata(
|
||||
self._input_tensor_name,
|
||||
index_feature_mapping=list(self._input_types.keys()),
|
||||
)
|
||||
builder.save_metadata(model)
|
||||
if attribution_method == "integrated_gradients":
|
||||
explainer_config = explainable_ai_sdk.IntegratedGradientsConfig()
|
||||
else:
|
||||
explainer_config = explainable_ai_sdk.SampledShapleyConfig()
|
||||
|
||||
self._attribution_explainer = explainable_ai_sdk.load_model_from_local_path(
|
||||
model, explainer_config
|
||||
)
|
||||
self._load_model(model)
|
||||
|
||||
|
||||
def create_lit_dataset(
|
||||
dataset: pd.DataFrame,
|
||||
column_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
|
||||
) -> lit_dataset.Dataset:
|
||||
"""Creates a LIT Dataset object.
|
||||
Args:
|
||||
dataset:
|
||||
Required. A Pandas DataFrame that includes feature column names and data.
|
||||
column_types:
|
||||
Required. An OrderedDict of string names matching the columns of the dataset
|
||||
as the key, and the associated LitType of the column.
|
||||
Returns:
|
||||
A LIT Dataset object that has the data from the dataset provided.
|
||||
"""
|
||||
return _VertexLitDataset(dataset, column_types)
|
||||
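# Illustrative usage sketch: builds a tiny LIT dataset from an in-memory
# DataFrame. The column names and LIT types are hypothetical examples for a
# simple tabular case.
def _example_lit_dataset() -> lit_dataset.Dataset:
    import collections

    df = pd.DataFrame({"age": [29.0, 41.0], "income": [38000.0, 52000.0]})
    column_types = collections.OrderedDict(
        [("age", lit_types.Scalar()), ("income", lit_types.Scalar())]
    )
    return create_lit_dataset(df, column_types)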
|
||||
|
||||
def create_lit_model_from_endpoint(
|
||||
endpoint: Union[str, aiplatform.Endpoint],
|
||||
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
|
||||
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
|
||||
model_id: Optional[str] = None,
|
||||
) -> lit_model.Model:
|
||||
"""Creates a LIT Model object.
|
||||
Args:
|
||||
endpoint:
|
||||
Required. The name of the Endpoint resource or an Endpoint instance.
|
||||
Endpoint name format: ``projects/{project}/locations/{location}/endpoints/{endpoint}``
|
||||
input_types:
|
||||
Required. An OrderedDict of string names matching the features of the model
|
||||
as the key, and the associated LitType of the feature.
|
||||
output_types:
|
||||
Required. An OrderedDict of string names matching the labels of the model
|
||||
as the key, and the associated LitType of the label.
|
||||
model_id:
|
||||
Optional. A string of the specific model in the endpoint to create the
|
||||
LIT model from. If this is not set, any usable model in the endpoint is
|
||||
used to create the LIT model.
|
||||
Returns:
|
||||
A LIT Model object that has the same functionality as the model provided.
|
||||
"""
|
||||
return _EndpointLitModel(endpoint, input_types, output_types, model_id)
|
||||
|
||||
|
||||
def create_lit_model(
|
||||
model: str,
|
||||
input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
|
||||
output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
|
||||
attribution_method: str = "sampled_shapley",
|
||||
) -> lit_model.Model:
|
||||
"""Creates a LIT Model object.
|
||||
Args:
|
||||
model:
|
||||
Required. A string reference to a local TensorFlow saved model directory.
|
||||
The model must have exactly one input and one output tensor.
|
||||
input_types:
|
||||
Required. An OrderedDict of string names matching the features of the model
|
||||
as the key, and the associated LitType of the feature.
|
||||
output_types:
|
||||
Required. An OrderedDict of string names matching the labels of the model
|
||||
as the key, and the associated LitType of the label.
|
||||
attribution_method:
|
||||
Optional. A string to choose what attribution configuration to
|
||||
set up the explainer with. Valid options are 'sampled_shapley'
|
||||
or 'integrated_gradients'.
|
||||
Returns:
|
||||
A LIT Model object that has the same functionality as the model provided.
|
||||
"""
|
||||
return _TensorFlowLitModel(model, input_types, output_types, attribution_method)
|
||||
|
||||
|
||||
def open_lit(
|
||||
models: Dict[str, lit_model.Model],
|
||||
datasets: Dict[str, lit_dataset.Dataset],
|
||||
open_in_new_tab: bool = True,
|
||||
):
|
||||
"""Open LIT from the provided models and datasets.
|
||||
Args:
|
||||
models:
|
||||
Required. A dict of LIT models to open LIT with.
|
||||
datasets:
|
||||
Required. A dict of LIT datasets to open LIT with.
|
||||
open_in_new_tab:
|
||||
Optional. A boolean to choose whether LIT opens in a new tab or not.
|
||||
Raises:
|
||||
ImportError if LIT is not installed.
|
||||
"""
|
||||
widget = notebook.LitWidget(models, datasets)
|
||||
widget.render(open_in_new_tab=open_in_new_tab)
|
||||
|
||||
|
||||
def set_up_and_open_lit(
|
||||
dataset: Union[pd.DataFrame, lit_dataset.Dataset],
|
||||
column_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
|
||||
model: Union[str, lit_model.Model],
|
||||
input_types: Union[List[str], Dict[str, lit_types.LitType]],
|
||||
output_types: Union[str, List[str], Dict[str, lit_types.LitType]],
|
||||
attribution_method: str = "sampled_shapley",
|
||||
open_in_new_tab: bool = True,
|
||||
) -> Tuple[lit_dataset.Dataset, lit_model.Model]:
|
||||
"""Creates a LIT dataset and model and opens LIT.
|
||||
Args:
|
||||
dataset:
|
||||
Required. A Pandas DataFrame that includes feature column names and data.
|
||||
column_types:
|
||||
Required. An OrderedDict of string names matching the columns of the dataset
|
||||
as the key, and the associated LitType of the column.
|
||||
model:
|
||||
Required. A string reference to a TensorFlow saved model directory.
|
||||
The model must have exactly one input and one output tensor.
|
||||
input_types:
|
||||
Required. An OrderedDict of string names matching the features of the model
|
||||
as the key, and the associated LitType of the feature.
|
||||
output_types:
|
||||
Required. An OrderedDict of string names matching the labels of the model
|
||||
as the key, and the associated LitType of the label.
|
||||
attribution_method:
|
||||
Optional. A string to choose what attribution configuration to
|
||||
set up the explainer with. Valid options are 'sampled_shapley'
|
||||
or 'integrated_gradients'.
|
||||
open_in_new_tab:
|
||||
Optional. A boolean to choose whether LIT opens in a new tab or not.
|
||||
Returns:
|
||||
A Tuple of the LIT dataset and model created.
|
||||
Raises:
|
||||
ImportError if LIT or TensorFlow is not installed.
|
||||
ValueError if the model does not have exactly one input and one output tensor.
|
||||
"""
|
||||
if not isinstance(dataset, lit_dataset.Dataset):
|
||||
dataset = create_lit_dataset(dataset, column_types)
|
||||
|
||||
if not isinstance(model, lit_model.Model):
|
||||
model = create_lit_model(
|
||||
model, input_types, output_types, attribution_method=attribution_method
|
||||
)
|
||||
|
||||
open_lit(
|
||||
{"model": model},
|
||||
{"dataset": dataset},
|
||||
open_in_new_tab=open_in_new_tab,
|
||||
)
|
||||
|
||||
return dataset, model
|
||||
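# Illustrative usage sketch: wraps a DataFrame and a local TensorFlow
# SavedModel and opens them in LIT. The model path, feature names, and label
# name are hypothetical examples.
def _example_set_up_and_open_lit():
    import collections

    df = pd.DataFrame({"age": [29.0, 41.0], "income": [38000.0, 52000.0]})
    column_types = collections.OrderedDict(
        [("age", lit_types.Scalar()), ("income", lit_types.Scalar())]
    )
    input_types = collections.OrderedDict(
        [("age", lit_types.Scalar()), ("income", lit_types.Scalar())]
    )
    output_types = collections.OrderedDict([("score", lit_types.RegressionScore())])
    return set_up_and_open_lit(
        df,
        column_types,
        "./saved_model",  # hypothetical local SavedModel directory
        input_types,
        output_types,
        attribution_method="integrated_gradients",
        open_in_new_tab=False,
    )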
@@ -0,0 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,34 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""Base abstract class for metadata builders."""
|
||||
|
||||
import abc
|
||||
|
||||
_ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()})
|
||||
|
||||
|
||||
class MetadataBuilder(_ABC):
|
||||
"""Abstract base class for metadata builders."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_metadata(self):
|
||||
"""Returns the current metadata as a dictionary."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_metadata_protobuf(self):
|
||||
"""Returns the current metadata as ExplanationMetadata protobuf"""
|
||||
@@ -0,0 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
Binary file not shown.
@@ -0,0 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.protobuf import json_format
from typing import Any, Dict, List, Optional

from google.cloud.aiplatform.compat.types import explanation_metadata
from google.cloud.aiplatform.explain.metadata import metadata_builder


class SavedModelMetadataBuilder(metadata_builder.MetadataBuilder):
    """Metadata builder class that accepts a TF1 saved model."""

    def __init__(
        self,
        model_path: str,
        tags: Optional[List[str]] = None,
        signature_name: Optional[str] = None,
        outputs_to_explain: Optional[List[str]] = None,
    ) -> None:
        """Initializes a SavedModelMetadataBuilder object.

        Args:
            model_path:
                Required. Local or GCS path to load the saved model from.
            tags:
                Optional. Tags to identify the model graph. If None or empty,
                TensorFlow's default serving tag will be used.
            signature_name:
                Optional. Name of the signature to be explained. Inputs and
                outputs of this signature will be written to the metadata. If
                not provided, the default serving signature is used.
            outputs_to_explain:
                Optional. List of output names to explain. Only a single
                output is supported for now, so the list must contain exactly
                one element. This parameter is required if the signature
                (selected via signature_name) specifies multiple outputs.

        Raises:
            ValueError: If outputs_to_explain contains more than one element,
                or if no output is selected and the signature contains
                multiple outputs.
        """
        if outputs_to_explain:
            if len(outputs_to_explain) > 1:
                raise ValueError(
                    "Only one output is supported at the moment. "
                    f"Received: {outputs_to_explain}."
                )
            self._output_to_explain = next(iter(outputs_to_explain))

        try:
            import tensorflow.compat.v1 as tf
        except ImportError:
            raise ImportError(
                "TensorFlow is not installed and is required to load the "
                "saved model. Please install it, for example with "
                "\"pip install 'tensorflow>=1.15,<2.0'\"."
            )

        if not signature_name:
            signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        self._tags = tags or [tf.saved_model.tag_constants.SERVING]
        self._graph = tf.Graph()

        # Load the SavedModel into a dedicated graph/session so tensor names
        # from the chosen signature can be resolved.
        with self.graph.as_default():
            self._session = tf.Session(graph=self.graph)
            self._metagraph_def = tf.saved_model.loader.load(
                sess=self.session, tags=self._tags, export_dir=model_path
            )
        if signature_name not in self._metagraph_def.signature_def:
            raise ValueError(
                f"Serving sigdef key {signature_name} not in the signature def."
            )
        serving_sigdef = self._metagraph_def.signature_def[signature_name]
        if not outputs_to_explain:
            if len(serving_sigdef.outputs) > 1:
                raise ValueError(
                    "The signature contains multiple outputs. Specify "
                    'an output via the "outputs_to_explain" parameter.'
                )
            self._output_to_explain = next(iter(serving_sigdef.outputs.keys()))

        self._inputs = _create_input_metadata_from_signature(serving_sigdef.inputs)
        self._outputs = _create_output_metadata_from_signature(
            serving_sigdef.outputs, self._output_to_explain
        )

    @property
    def graph(self) -> "tf.Graph":  # noqa: F821
        return self._graph

    @property
    def session(self) -> "tf.Session":  # noqa: F821
        return self._session

    def get_metadata(self) -> Dict[str, Any]:
        """Returns the current metadata as a dictionary.

        Returns:
            JSON format of the explanation metadata.
        """
        return json_format.MessageToDict(self.get_metadata_protobuf()._pb)

    def get_metadata_protobuf(self) -> explanation_metadata.ExplanationMetadata:
        """Returns the current metadata as a protobuf object.

        Returns:
            ExplanationMetadata object format of the explanation metadata.
        """
        return explanation_metadata.ExplanationMetadata(
            inputs=self._inputs,
            outputs=self._outputs,
        )


def _create_input_metadata_from_signature(
    signature_inputs: Dict[str, "tf.Tensor"]  # noqa: F821
) -> Dict[str, explanation_metadata.ExplanationMetadata.InputMetadata]:
    """Creates InputMetadata from signature inputs.

    Args:
        signature_inputs:
            Required. Inputs of the signature to be explained.

    Returns:
        Inferred input metadata from the model.
    """
    input_mds = {}
    for key, tensor in signature_inputs.items():
        input_mds[key] = explanation_metadata.ExplanationMetadata.InputMetadata(
            input_tensor_name=tensor.name
        )
    return input_mds


def _create_output_metadata_from_signature(
    signature_outputs: Dict[str, "tf.Tensor"],  # noqa: F821
    output_to_explain: Optional[str] = None,
) -> Dict[str, explanation_metadata.ExplanationMetadata.OutputMetadata]:
    """Creates OutputMetadata from signature outputs.

    Args:
        signature_outputs:
            Required. Outputs of the signature to be explained.
        output_to_explain:
            Optional. Output name to explain; if omitted, metadata is created
            for every output in the signature.

    Returns:
        Inferred output metadata from the model.
    """
    output_mds = {}
    for key, tensor in signature_outputs.items():
        if not output_to_explain or output_to_explain == key:
            output_mds[key] = explanation_metadata.ExplanationMetadata.OutputMetadata(
                output_tensor_name=tensor.name
            )
    return output_mds
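A hedged usage sketch (not part of this commit): the model path and output name are placeholders, and the import path is assumed from the package layout implied by the imports above rather than shown in this diff.

    from google.cloud.aiplatform.explain.metadata.tf.v1 import (
        saved_model_metadata_builder,  # assumed module path
    )

    builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
        model_path="gs://my-bucket/tf1-saved-model",  # local or GCS SavedModel directory
        outputs_to_explain=["probabilities"],         # only needed for multi-output signatures
    )

    # Roughly: {"inputs": {<name>: {"inputTensorName": ...}}, "outputs": {...}}
    metadata_dict = builder.get_metadata()
    # Same content as an ExplanationMetadata proto.
    metadata_proto = builder.get_metadata_protobuf()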
Some files were not shown because too many files have changed in this diff