structure saas with tools
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from google.cloud.aiplatform.prediction.handler import (
    Handler,
    PredictionHandler,
)
from google.cloud.aiplatform.prediction.local_endpoint import LocalEndpoint
from google.cloud.aiplatform.prediction.local_model import (
    DEFAULT_HEALTH_ROUTE,
    DEFAULT_HTTP_PORT,
    DEFAULT_PREDICT_ROUTE,
    LocalModel,
)
from google.cloud.aiplatform.prediction.predictor import Predictor
from google.cloud.aiplatform.prediction.serializer import (
    DefaultSerializer,
    Serializer,
)

__all__ = (
    "DEFAULT_HEALTH_ROUTE",
    "DEFAULT_HTTP_PORT",
    "DEFAULT_PREDICT_ROUTE",
    "DefaultSerializer",
    "Handler",
    "LocalEndpoint",
    "LocalModel",
    "PredictionHandler",
    "Predictor",
    "Serializer",
)
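# Illustrative only, not part of the committed file: the exports above form the public
# surface of the CPR (custom prediction routine) workflow. A typical import looks like
# the commented sketch below; the user predictor module and class names are hypothetical.
#
# from google.cloud.aiplatform.prediction import LocalModel
# from my_package.predictor import MyPredictor  # hypothetical user-defined Predictor
#
# local_model = LocalModel.build_cpr_model(
#     "./src", "us-docker.pkg.dev/project/repo/image", predictor=MyPredictor
# )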
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,138 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABC, abstractmethod
import logging
from typing import Optional, Type
import traceback

try:
    from fastapi import HTTPException
    from fastapi import Request
    from fastapi import Response
except ImportError:
    raise ImportError(
        "FastAPI is not installed and is required to build model servers. "
        'Please install the SDK using `pip install "google-cloud-aiplatform[prediction]>=1.16.0"`.'
    )

from google.cloud.aiplatform.prediction import handler_utils
from google.cloud.aiplatform.prediction.predictor import Predictor
from google.cloud.aiplatform.prediction.serializer import DefaultSerializer


class Handler(ABC):
    """Interface for Handler class to handle prediction requests."""

    @abstractmethod
    def __init__(
        self,
        artifacts_uri: str,
        predictor: Optional[Type[Predictor]] = None,
    ):
        """Initializes a Handler instance.

        Args:
            artifacts_uri (str):
                Required. The value of the environment variable AIP_STORAGE_URI.
            predictor (Type[Predictor]):
                Optional. The Predictor class this handler uses to instantiate a
                predictor instance, if given.
        """
        pass

    @abstractmethod
    def handle(self, request: Request) -> Response:
        """Handles a prediction request.

        Args:
            request (Request):
                The request sent to the application.

        Returns:
            The response of the prediction request.
        """
        pass


class PredictionHandler(Handler):
    """Default prediction handler for the prediction requests sent to the application."""

    def __init__(
        self,
        artifacts_uri: str,
        predictor: Optional[Type[Predictor]] = None,
    ):
        """Initializes a PredictionHandler instance.

        Args:
            artifacts_uri (str):
                Required. The value of the environment variable AIP_STORAGE_URI.
            predictor (Type[Predictor]):
                Optional. The Predictor class this handler uses to instantiate a
                predictor instance, if given.

        Raises:
            ValueError: If predictor is None.
        """
        if predictor is None:
            raise ValueError(
                "PredictionHandler must have a predictor class passed to the init function."
            )

        self._predictor = predictor()
        self._predictor.load(artifacts_uri)

    async def handle(self, request: Request) -> Response:
        """Handles a prediction request.

        Args:
            request (Request):
                Required. The prediction request sent to the application.

        Returns:
            The response of the prediction request.

        Raises:
            HTTPException: If any exception is thrown from the predictor object.
        """
        request_body = await request.body()
        content_type = handler_utils.get_content_type_from_headers(request.headers)
        prediction_input = DefaultSerializer.deserialize(request_body, content_type)

        try:
            prediction_results = self._predictor.postprocess(
                self._predictor.predict(self._predictor.preprocess(prediction_input))
            )
        except HTTPException:
            raise
        except Exception as exception:
            error_message = (
                "The following exception has occurred: {}. Arguments: {}.".format(
                    type(exception).__name__, exception.args
                )
            )
            logging.info(
                "{}\nTraceback: {}".format(error_message, traceback.format_exc())
            )

            # Converts all other exceptions to HTTPException.
            raise HTTPException(status_code=500, detail=error_message)

        accept = handler_utils.get_accept_from_headers(request.headers)
        data = DefaultSerializer.serialize(prediction_results, accept)
        return Response(content=data, media_type=accept)
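# Illustrative sketch, not part of the committed file: a minimal custom Handler that
# follows the interface defined above. It assumes no artifacts are needed and simply
# echoes the request body back; the class name is hypothetical. A real handler would
# load artifacts in __init__ and run inference in handle().
class _EchoHandler(Handler):
    """Hypothetical handler used only to illustrate the Handler interface."""

    def __init__(self, artifacts_uri: str, predictor: Optional[Type[Predictor]] = None):
        # This sketch ignores artifacts; a real handler would load them from artifacts_uri.
        self._artifacts_uri = artifacts_uri

    async def handle(self, request: Request) -> Response:
        # Echo the raw request body back, using the Accept header as the media type.
        body = await request.body()
        accept = handler_utils.get_accept_from_headers(request.headers)
        return Response(content=body, media_type=accept)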
@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Dict, Optional

try:
    import starlette
except ImportError:
    raise ImportError(
        "Starlette is not installed and is required to build model servers. "
        'Please install the SDK using `pip install "google-cloud-aiplatform[prediction]>=1.16.0"`.'
    )

from google.cloud.aiplatform.constants import prediction


def _remove_parameter(value: Optional[str]) -> Optional[str]:
    """Removes the parameter part from the header value.

    Referring to https://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7.

    Args:
        value (str):
            Optional. The original full header value.

    Returns:
        The value without the parameter, or None.
    """
    if value is None:
        return None

    return value.split(";")[0]


def get_content_type_from_headers(
    headers: Optional[starlette.datastructures.Headers],
) -> Optional[str]:
    """Gets the content type from headers.

    Args:
        headers (starlette.datastructures.Headers):
            Optional. The headers that the content type is retrieved from.

    Returns:
        The content type, or None.
    """
    if headers is not None:
        for key, value in headers.items():
            if prediction.CONTENT_TYPE_HEADER_REGEX.match(key):
                return _remove_parameter(value)

    return None


def get_accept_from_headers(
    headers: Optional[starlette.datastructures.Headers],
) -> str:
    """Gets the accept value from headers.

    Defaults to "application/json" if it is unset.

    Args:
        headers (starlette.datastructures.Headers):
            Optional. The headers that the accept value is retrieved from.

    Returns:
        The accept value.
    """
    if headers is not None:
        for key, value in headers.items():
            if prediction.ACCEPT_HEADER_REGEX.match(key):
                return value

    return prediction.DEFAULT_ACCEPT_VALUE


def parse_accept_header(accept_header: Optional[str]) -> Dict[str, float]:
    """Parses the accept header with quality factors.

    Referring to https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html.

    The default quality factor is 1.

    Args:
        accept_header (str):
            Optional. The accept header.

    Returns:
        A dictionary mapping media types to their quality factors.
    """
    if not accept_header:
        return {}

    all_accepts = accept_header.split(",")
    results = {}

    for media_type in all_accepts:
        if media_type.split(";")[0] == media_type:
            # No q parameter => q = 1.
            results[media_type.strip()] = 1.0
        else:
            q = media_type.split(";")[1].split("=")[1]
            results[media_type.split(";")[0].strip()] = float(q)

    return results
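# Illustrative only, not part of the committed file: a worked example of the helpers above,
# based on their docstrings. The header string is a hypothetical input.
def _example_parse_accept_header() -> None:
    """Sketch: media types without a q parameter default to a quality factor of 1."""
    parsed = parse_accept_header("application/json;q=0.8, text/html")
    # parsed == {"application/json": 0.8, "text/html": 1.0}
    assert parsed["application/json"] == 0.8
    assert parsed["text/html"] == 1.0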
@@ -0,0 +1,469 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
from pathlib import Path
import requests
import time
from typing import Any, Dict, List, Optional, Sequence

from google.auth.exceptions import GoogleAuthError

from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.constants import prediction
from google.cloud.aiplatform.docker_utils import run
from google.cloud.aiplatform.docker_utils.errors import DockerError
from google.cloud.aiplatform.utils import prediction_utils

_logger = logging.getLogger(__name__)

_DEFAULT_CONTAINER_READY_TIMEOUT = 300
_DEFAULT_CONTAINER_READY_CHECK_INTERVAL = 1

_GCLOUD_PROJECT_ENV = "GOOGLE_CLOUD_PROJECT"


class LocalEndpoint:
    """Class that represents a local endpoint."""

    def __init__(
        self,
        serving_container_image_uri: str,
        artifact_uri: Optional[str] = None,
        serving_container_predict_route: Optional[str] = None,
        serving_container_health_route: Optional[str] = None,
        serving_container_command: Optional[Sequence[str]] = None,
        serving_container_args: Optional[Sequence[str]] = None,
        serving_container_environment_variables: Optional[Dict[str, str]] = None,
        serving_container_ports: Optional[Sequence[int]] = None,
        credential_path: Optional[str] = None,
        host_port: Optional[str] = None,
        gpu_count: Optional[int] = None,
        gpu_device_ids: Optional[List[str]] = None,
        gpu_capabilities: Optional[List[List[str]]] = None,
        container_ready_timeout: Optional[int] = None,
        container_ready_check_interval: Optional[int] = None,
    ):
        """Creates a local endpoint instance.

        Args:
            serving_container_image_uri (str):
                Required. The URI of the Model serving container.
            artifact_uri (str):
                Optional. The path to the directory containing the Model artifact and any of its
                supporting files. The path is either a GCS uri or the path to a local directory.
                If this parameter is set to a GCS uri:
                (1) ``credential_path`` must be specified for local prediction.
                (2) The GCS uri will be passed directly to ``Predictor.load``.
                If this parameter is a local directory:
                (1) The directory will be mounted to a default temporary model path.
                (2) The mounted path will be passed to ``Predictor.load``.
            serving_container_predict_route (str):
                Optional. An HTTP path to send prediction requests to the container, and
                which must be supported by it. If not specified a default HTTP path will
                be used by Vertex AI.
            serving_container_health_route (str):
                Optional. An HTTP path to send health check requests to the container, and which
                must be supported by it. If not specified a standard HTTP path will be
                used by Vertex AI.
            serving_container_command (Sequence[str]):
                Optional. The command with which the container is run. Not executed within a
                shell. The Docker image's ENTRYPOINT is used if this is not provided.
                Variable references $(VAR_NAME) are expanded using the container's
                environment. If a variable cannot be resolved, the reference in the
                input string will be unchanged. The $(VAR_NAME) syntax can be escaped
                with a double $$, ie: $$(VAR_NAME). Escaped references will never be
                expanded, regardless of whether the variable exists or not.
            serving_container_args (Sequence[str]):
                Optional. The arguments to the command. The Docker image's CMD is used if this is
                not provided. Variable references $(VAR_NAME) are expanded using the
                container's environment. If a variable cannot be resolved, the reference
                in the input string will be unchanged. The $(VAR_NAME) syntax can be
                escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
                never be expanded, regardless of whether the variable exists or not.
            serving_container_environment_variables (Dict[str, str]):
                Optional. The environment variables that are to be present in the container.
                Should be a dictionary where keys are environment variable names
                and values are environment variable values for those names.
            serving_container_ports (Sequence[int]):
                Optional. Declaration of ports that are exposed by the container. This field is
                primarily informational, it gives Vertex AI information about the
                network connections the container uses. Listing or not a port here has
                no impact on whether the port is actually exposed, any port listening on
                the default "0.0.0.0" address inside a container will be accessible from
                the network.
            credential_path (str):
                Optional. The path to the credential key that will be mounted to the container.
                If it's unset, the environment variable, ``GOOGLE_APPLICATION_CREDENTIALS``, will
                be used if set.
            host_port (str):
                Optional. The port on the host that the port, ``AIP_HTTP_PORT``, inside the container
                will be exposed as. If it's unset, a random host port will be assigned.
            gpu_count (int):
                Optional. Number of devices to request. Set to -1 to request all available devices.
                To use GPU, set either ``gpu_count`` or ``gpu_device_ids``.
                The default value is -1 if ``gpu_capabilities`` is set but both ``gpu_count`` and
                ``gpu_device_ids`` are not set.
            gpu_device_ids (List[str]):
                Optional. This parameter corresponds to ``NVIDIA_VISIBLE_DEVICES`` in the NVIDIA
                Runtime.
                To use GPU, set either ``gpu_count`` or ``gpu_device_ids``.
            gpu_capabilities (List[List[str]]):
                Optional. This parameter corresponds to ``NVIDIA_DRIVER_CAPABILITIES`` in the NVIDIA
                Runtime. The outer list acts like an OR, and each sub-list acts like an AND. The
                driver will try to satisfy one of the sub-lists.
                Available capabilities for the NVIDIA driver can be found in
                https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#driver-capabilities.
                The default value is ``[["utility", "compute"]]`` if ``gpu_count`` or ``gpu_device_ids`` is
                set.
            container_ready_timeout (int):
                Optional. The timeout in seconds for the container to start or for the
                first health check to succeed.
            container_ready_check_interval (int):
                Optional. The time interval in seconds between checks of whether the container
                is ready or the first health check succeeds.

        Raises:
            ValueError: If both ``gpu_count`` and ``gpu_device_ids`` are set.
        """
        self.container = None
        self.container_is_running = False
        self.log_start_index = 0
        self.serving_container_image_uri = serving_container_image_uri
        self.artifact_uri = artifact_uri
        self.serving_container_predict_route = (
            serving_container_predict_route or prediction.DEFAULT_LOCAL_PREDICT_ROUTE
        )
        self.serving_container_health_route = (
            serving_container_health_route or prediction.DEFAULT_LOCAL_HEALTH_ROUTE
        )
        self.serving_container_command = serving_container_command
        self.serving_container_args = serving_container_args
        self.serving_container_environment_variables = (
            serving_container_environment_variables
        )
        self.serving_container_ports = serving_container_ports
        self.container_port = prediction_utils.get_prediction_aip_http_port(
            serving_container_ports
        )

        self.credential_path = credential_path
        self.host_port = host_port
        # assigned_host_port will be updated according to the running container
        # if host_port is None.
        self.assigned_host_port = host_port

        self.gpu_count = gpu_count
        self.gpu_device_ids = gpu_device_ids
        self.gpu_capabilities = gpu_capabilities

        if self.gpu_count and self.gpu_device_ids:
            raise ValueError(
                "At most one gpu_count or gpu_device_ids can be set but both are set."
            )
        if (self.gpu_count or self.gpu_device_ids) and self.gpu_capabilities is None:
            self.gpu_capabilities = prediction.DEFAULT_LOCAL_RUN_GPU_CAPABILITIES
        if self.gpu_capabilities and not self.gpu_count and not self.gpu_device_ids:
            self.gpu_count = prediction.DEFAULT_LOCAL_RUN_GPU_COUNT

        self.container_ready_timeout = (
            container_ready_timeout or _DEFAULT_CONTAINER_READY_TIMEOUT
        )
        self.container_ready_check_interval = (
            container_ready_check_interval or _DEFAULT_CONTAINER_READY_CHECK_INTERVAL
        )

    def __enter__(self):
        """Enters the runtime context related to this object."""
        try:
            self.serve()
        except Exception as exception:
            _logger.error(f"Exception during entering a context: {exception}.")
            raise
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Exits the runtime context related to this object.

        Args:
            exc_type:
                Optional. Class of the exception.
            exc_value:
                Optional. Instance of the exception.
            exc_traceback:
                Optional. Traceback that has the information of the exception.
        """
        self.stop()

    def __del__(self):
        """Stops the container when the instance is about to be destroyed."""
        self.stop()

    def serve(self):
        """Starts running the container and serves the traffic locally.

        An environment variable, ``GOOGLE_CLOUD_PROJECT``, will be set to the project in the
        global config. This is required if the credentials file does not have a project
        specified, and is used by the Cloud Storage client to recognize the project.

        Raises:
            DockerError: If the container is not ready or health checks do not succeed after the
                timeout.
        """
        if self.container and self.container_is_running:
            _logger.warning(
                "The local endpoint has started serving traffic. "
                "No need to call `serve()` again."
            )
            return

        try:
            try:
                project_id = initializer.global_config.project
                _logger.info(
                    f"Got the project id from the global config: {project_id}."
                )
            except (GoogleAuthError, ValueError):
                project_id = None

            envs = (
                dict(self.serving_container_environment_variables)
                if self.serving_container_environment_variables is not None
                else {}
            )
            if project_id is not None:
                envs[_GCLOUD_PROJECT_ENV] = project_id

            self.container = run.run_prediction_container(
                self.serving_container_image_uri,
                artifact_uri=self.artifact_uri,
                serving_container_predict_route=self.serving_container_predict_route,
                serving_container_health_route=self.serving_container_health_route,
                serving_container_command=self.serving_container_command,
                serving_container_args=self.serving_container_args,
                serving_container_environment_variables=envs,
                serving_container_ports=self.serving_container_ports,
                credential_path=self.credential_path,
                host_port=self.host_port,
                gpu_count=self.gpu_count,
                gpu_device_ids=self.gpu_device_ids,
                gpu_capabilities=self.gpu_capabilities,
            )

            # Retrieves the assigned host port.
            self._wait_until_container_runs()
            if self.host_port is None:
                self.container.reload()
                self.assigned_host_port = self.container.ports[
                    f"{self.container_port}/tcp"
                ][0]["HostPort"]
            self.container_is_running = True
            # Waits until the model server starts.
            self._wait_until_health_check_succeeds()
        except Exception as exception:
            _logger.error(f"Exception during starting serving: {exception}.")
            self._stop_container_if_exists()
            self.container_is_running = False
            raise

    def stop(self) -> None:
        """Explicitly stops the container."""
        self._stop_container_if_exists()
        self.container_is_running = False

    def _wait_until_container_runs(self) -> None:
        """Waits until the container is in running status or the timeout is reached.

        Raises:
            DockerError: If the timeout is reached.
        """
        elapsed_time = 0
        while (
            self.get_container_status() != run.CONTAINER_RUNNING_STATUS
            and elapsed_time < self.container_ready_timeout
        ):
            time.sleep(self.container_ready_check_interval)
            elapsed_time += self.container_ready_check_interval

        if elapsed_time >= self.container_ready_timeout:
            raise DockerError("The container never starts running.", "", 1)

    def _wait_until_health_check_succeeds(self):
        """Waits until a health check succeeds or the timeout is reached.

        Raises:
            DockerError: If the container exits or the timeout is reached.
        """
        elapsed_time = 0
        try:
            response = self.run_health_check(verbose=False)
        except requests.exceptions.RequestException:
            response = None

        while elapsed_time < self.container_ready_timeout and (
            response is None or response.status_code != 200
        ):
            time.sleep(self.container_ready_check_interval)
            elapsed_time += self.container_ready_check_interval
            try:
                response = self.run_health_check(verbose=False)
            except requests.exceptions.RequestException:
                response = None

        if self.get_container_status() != run.CONTAINER_RUNNING_STATUS:
            self.print_container_logs(
                show_all=True,
                message="Container already exited, all container logs:",
            )
            raise DockerError(
                "Container exited before the first health check succeeded.", "", 1
            )

        if elapsed_time >= self.container_ready_timeout:
            self.print_container_logs(
                show_all=True,
                message="Health check never succeeds, all container logs:",
            )
            raise DockerError("The health check never succeeded.", "", 1)

    def _stop_container_if_exists(self):
        """Stops the container if the container exists."""
        if self.container is not None:
            self.container.stop()

    def predict(
        self,
        request: Optional[Any] = None,
        request_file: Optional[str] = None,
        headers: Optional[Dict] = None,
        verbose: bool = True,
    ) -> requests.models.Response:
        """Executes a prediction.

        Args:
            request (Any):
                Optional. The request sent to the container.
            request_file (str):
                Optional. The path to a request file sent to the container.
            headers (Dict):
                Optional. The headers in the prediction request.
            verbose (bool):
                Required. Whether or not to print logs, if any.

        Returns:
            The prediction response.

        Raises:
            RuntimeError: If the local endpoint has been stopped.
            ValueError: If both ``request`` and ``request_file`` are specified, both
                ``request`` and ``request_file`` are not provided, or ``request_file``
                is specified but does not exist.
            requests.exception.RequestException: If the request fails with an exception.
        """
        if self.container_is_running is False:
            raise RuntimeError(
                "The local endpoint is not serving traffic. Please call `serve()`."
            )

        if request is not None and request_file is not None:
            raise ValueError(
                "request and request_file can not be specified at the same time."
            )
        if request is None and request_file is None:
            raise ValueError("One of request and request_file needs to be specified.")

        try:
            url = f"http://localhost:{self.assigned_host_port}{self.serving_container_predict_route}"
            if request is not None:
                response = requests.post(url, data=request, headers=headers)
            elif request_file is not None:
                if not Path(request_file).expanduser().resolve().exists():
                    raise ValueError(f"request_file does not exist: {request_file}.")
                with open(request_file) as data:
                    response = requests.post(url, data=data, headers=headers)
            return response
        except requests.exceptions.RequestException as exception:
            if verbose:
                _logger.warning(f"Exception during prediction: {exception}")
            raise

    def run_health_check(self, verbose: bool = True) -> requests.models.Response:
        """Runs a health check.

        Args:
            verbose (bool):
                Required. Whether or not to print logs, if any.

        Returns:
            The health check response.

        Raises:
            RuntimeError: If the local endpoint has been stopped.
            requests.exception.RequestException: If the request fails with an exception.
        """
        if self.container_is_running is False:
            raise RuntimeError(
                "The local endpoint is not serving traffic. Please call `serve()`."
            )

        try:
            url = f"http://localhost:{self.assigned_host_port}{self.serving_container_health_route}"
            response = requests.get(url)
            return response
        except requests.exceptions.RequestException as exception:
            if verbose:
                _logger.warning(f"Exception during health check: {exception}")
            raise

    def print_container_logs(
        self, show_all: bool = False, message: Optional[str] = None
    ) -> None:
        """Prints container logs.

        Args:
            show_all (bool):
                Required. If True, prints all logs since the container started.
            message (str):
                Optional. The message to be printed before printing the logs.
        """
        start_index = None if show_all else self.log_start_index
        self.log_start_index = run.print_container_logs(
            self.container, start_index=start_index, message=message
        )

    def print_container_logs_if_container_is_not_running(
        self, show_all: bool = False, message: Optional[str] = None
    ) -> None:
        """Prints container logs if the container is not in "running" status.

        Args:
            show_all (bool):
                Required. If True, prints all logs since the container started.
            message (str):
                Optional. The message to be printed before printing the logs.
        """
        if self.get_container_status() != run.CONTAINER_RUNNING_STATUS:
            self.print_container_logs(show_all=show_all, message=message)

    def get_container_status(self) -> str:
        """Gets the container status.

        Returns:
            The container status. One of restarting, running, paused, exited.
        """
        self.container.reload()
        return self.container.status
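# Illustrative only, not part of the committed file: a minimal sketch of driving a
# LocalEndpoint directly, based on the class above. The image uri, artifact uri,
# credential path, and request payload are hypothetical placeholders.
def _example_local_endpoint_usage() -> None:
    """Sketch: serve a container locally, health check it, and send one prediction."""
    with LocalEndpoint(
        serving_container_image_uri="us-docker.pkg.dev/project/repo/image:latest",
        artifact_uri="gs://bucket/path/to/model",
        credential_path="/path/to/credentials.json",
    ) as endpoint:
        # __enter__ calls serve(), which blocks until the first health check succeeds
        # or container_ready_timeout is reached.
        health = endpoint.run_health_check()
        response = endpoint.predict(
            request='{"instances": [[1, 2, 3, 4]]}',
            headers={"Content-Type": "application/json"},
        )
        print(health.status_code, response.content)
    # __exit__ calls stop(), which stops the underlying container.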
@@ -0,0 +1,635 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from copy import copy
|
||||
from typing import Dict, List, Optional, Sequence, Type
|
||||
|
||||
from google.cloud import aiplatform
|
||||
from google.cloud.aiplatform import helpers
|
||||
|
||||
from google.cloud.aiplatform.compat.types import (
|
||||
model as gca_model_compat,
|
||||
env_var as gca_env_var_compat,
|
||||
)
|
||||
|
||||
from google.cloud.aiplatform.docker_utils import build
|
||||
from google.cloud.aiplatform.docker_utils import errors
|
||||
from google.cloud.aiplatform.docker_utils import local_util
|
||||
from google.cloud.aiplatform.docker_utils import utils
|
||||
from google.cloud.aiplatform.prediction import LocalEndpoint
|
||||
from google.cloud.aiplatform.prediction.handler import Handler
|
||||
from google.cloud.aiplatform.prediction.handler import PredictionHandler
|
||||
from google.cloud.aiplatform.prediction.predictor import Predictor
|
||||
from google.cloud.aiplatform.utils import prediction_utils
|
||||
|
||||
from google.protobuf import duration_pb2
|
||||
|
||||
DEFAULT_PREDICT_ROUTE = "/predict"
|
||||
DEFAULT_HEALTH_ROUTE = "/health"
|
||||
DEFAULT_HTTP_PORT = 8080
|
||||
_DEFAULT_SDK_REQUIREMENTS = ["google-cloud-aiplatform[prediction]>=1.27.0"]
|
||||
_DEFAULT_HANDLER_MODULE = "google.cloud.aiplatform.prediction.handler"
|
||||
_DEFAULT_HANDLER_CLASS = "PredictionHandler"
|
||||
_DEFAULT_PYTHON_MODULE = "google.cloud.aiplatform.prediction.model_server"
|
||||
|
||||
|
||||
class LocalModel:
|
||||
"""Class that represents a local model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
serving_container_spec: Optional[aiplatform.gapic.ModelContainerSpec] = None,
|
||||
serving_container_image_uri: Optional[str] = None,
|
||||
serving_container_predict_route: Optional[str] = None,
|
||||
serving_container_health_route: Optional[str] = None,
|
||||
serving_container_command: Optional[Sequence[str]] = None,
|
||||
serving_container_args: Optional[Sequence[str]] = None,
|
||||
serving_container_environment_variables: Optional[Dict[str, str]] = None,
|
||||
serving_container_ports: Optional[Sequence[int]] = None,
|
||||
serving_container_grpc_ports: Optional[Sequence[int]] = None,
|
||||
serving_container_deployment_timeout: Optional[int] = None,
|
||||
serving_container_shared_memory_size_mb: Optional[int] = None,
|
||||
serving_container_startup_probe_exec: Optional[Sequence[str]] = None,
|
||||
serving_container_startup_probe_period_seconds: Optional[int] = None,
|
||||
serving_container_startup_probe_timeout_seconds: Optional[int] = None,
|
||||
serving_container_health_probe_exec: Optional[Sequence[str]] = None,
|
||||
serving_container_health_probe_period_seconds: Optional[int] = None,
|
||||
serving_container_health_probe_timeout_seconds: Optional[int] = None,
|
||||
):
|
||||
"""Creates a local model instance.
|
||||
|
||||
Args:
|
||||
serving_container_spec (aiplatform.gapic.ModelContainerSpec):
|
||||
Optional. The container spec of the LocalModel instance.
|
||||
serving_container_image_uri (str):
|
||||
Optional. The URI of the Model serving container.
|
||||
serving_container_predict_route (str):
|
||||
Optional. An HTTP path to send prediction requests to the container, and
|
||||
which must be supported by it. If not specified a default HTTP path will
|
||||
be used by Vertex AI.
|
||||
serving_container_health_route (str):
|
||||
Optional. An HTTP path to send health check requests to the container, and which
|
||||
must be supported by it. If not specified a standard HTTP path will be
|
||||
used by Vertex AI.
|
||||
serving_container_command (Sequence[str]):
|
||||
Optional. The command with which the container is run. Not executed within a
|
||||
shell. The Docker image's ENTRYPOINT is used if this is not provided.
|
||||
Variable references $(VAR_NAME) are expanded using the container's
|
||||
environment. If a variable cannot be resolved, the reference in the
|
||||
input string will be unchanged. The $(VAR_NAME) syntax can be escaped
|
||||
with a double $$, ie: $$(VAR_NAME). Escaped references will never be
|
||||
expanded, regardless of whether the variable exists or not.
|
||||
serving_container_args: (Sequence[str]):
|
||||
Optional. The arguments to the command. The Docker image's CMD is used if this is
|
||||
not provided. Variable references $(VAR_NAME) are expanded using the
|
||||
container's environment. If a variable cannot be resolved, the reference
|
||||
in the input string will be unchanged. The $(VAR_NAME) syntax can be
|
||||
escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
|
||||
never be expanded, regardless of whether the variable exists or not.
|
||||
serving_container_environment_variables (Dict[str, str]):
|
||||
Optional. The environment variables that are to be present in the container.
|
||||
Should be a dictionary where keys are environment variable names
|
||||
and values are environment variable values for those names.
|
||||
serving_container_ports (Sequence[int]):
|
||||
Optional. Declaration of ports that are exposed by the container. This field is
|
||||
primarily informational, it gives Vertex AI information about the
|
||||
network connections the container uses. Listing or not a port here has
|
||||
no impact on whether the port is actually exposed, any port listening on
|
||||
the default "0.0.0.0" address inside a container will be accessible from
|
||||
the network.
|
||||
serving_container_grpc_ports: Optional[Sequence[int]]=None,
|
||||
Declaration of ports that are exposed by the container. Vertex AI sends gRPC
|
||||
prediction requests that it receives to the first port on this list. Vertex
|
||||
AI also sends liveness and health checks to this port.
|
||||
If you do not specify this field, gRPC requests to the container will be
|
||||
disabled.
|
||||
Vertex AI does not use ports other than the first one listed. This field
|
||||
corresponds to the `ports` field of the Kubernetes Containers v1 core API.
|
||||
serving_container_deployment_timeout (int):
|
||||
Optional. Deployment timeout in seconds.
|
||||
serving_container_shared_memory_size_mb (int):
|
||||
Optional. The amount of the VM memory to reserve as the shared
|
||||
memory for the model in megabytes.
|
||||
serving_container_startup_probe_exec (Sequence[str]):
|
||||
Optional. Exec specifies the action to take. Used by startup
|
||||
probe. An example of this argument would be
|
||||
["cat", "/tmp/healthy"]
|
||||
serving_container_startup_probe_period_seconds (int):
|
||||
Optional. How often (in seconds) to perform the startup probe.
|
||||
Default to 10 seconds. Minimum value is 1.
|
||||
serving_container_startup_probe_timeout_seconds (int):
|
||||
Optional. Number of seconds after which the startup probe times
|
||||
out. Defaults to 1 second. Minimum value is 1.
|
||||
serving_container_health_probe_exec (Sequence[str]):
|
||||
Optional. Exec specifies the action to take. Used by health
|
||||
probe. An example of this argument would be
|
||||
["cat", "/tmp/healthy"]
|
||||
serving_container_health_probe_period_seconds (int):
|
||||
Optional. How often (in seconds) to perform the health probe.
|
||||
Default to 10 seconds. Minimum value is 1.
|
||||
serving_container_health_probe_timeout_seconds (int):
|
||||
Optional. Number of seconds after which the health probe times
|
||||
out. Defaults to 1 second. Minimum value is 1.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``serving_container_spec`` is specified but ``serving_container_spec.image_uri``
|
||||
is ``None``. Also if ``serving_container_spec`` is None but ``serving_container_image_uri`` is
|
||||
``None``.
|
||||
"""
|
||||
if serving_container_spec:
|
||||
if not serving_container_spec.image_uri:
|
||||
raise ValueError(
|
||||
"Image uri is required for the serving container spec to initialize a LocalModel instance."
|
||||
)
|
||||
|
||||
self.serving_container_spec = serving_container_spec
|
||||
else:
|
||||
if not serving_container_image_uri:
|
||||
raise ValueError(
|
||||
"Serving container image uri is required to initialize a LocalModel instance."
|
||||
)
|
||||
|
||||
env = None
|
||||
ports = None
|
||||
grpc_ports = None
|
||||
deployment_timeout = (
|
||||
duration_pb2.Duration(seconds=serving_container_deployment_timeout)
|
||||
if serving_container_deployment_timeout
|
||||
else None
|
||||
)
|
||||
startup_probe = None
|
||||
health_probe = None
|
||||
|
||||
if serving_container_environment_variables:
|
||||
env = [
|
||||
gca_env_var_compat.EnvVar(name=str(key), value=str(value))
|
||||
for key, value in serving_container_environment_variables.items()
|
||||
]
|
||||
if serving_container_ports:
|
||||
ports = [
|
||||
gca_model_compat.Port(container_port=port)
|
||||
for port in serving_container_ports
|
||||
]
|
||||
if serving_container_grpc_ports:
|
||||
grpc_ports = [
|
||||
gca_model_compat.Port(container_port=port)
|
||||
for port in serving_container_grpc_ports
|
||||
]
|
||||
if (
|
||||
serving_container_startup_probe_exec
|
||||
or serving_container_startup_probe_period_seconds
|
||||
or serving_container_startup_probe_timeout_seconds
|
||||
):
|
||||
startup_probe_exec = None
|
||||
if serving_container_startup_probe_exec:
|
||||
startup_probe_exec = gca_model_compat.Probe.ExecAction(
|
||||
command=serving_container_startup_probe_exec
|
||||
)
|
||||
startup_probe = gca_model_compat.Probe(
|
||||
exec=startup_probe_exec,
|
||||
period_seconds=serving_container_startup_probe_period_seconds,
|
||||
timeout_seconds=serving_container_startup_probe_timeout_seconds,
|
||||
)
|
||||
if (
|
||||
serving_container_health_probe_exec
|
||||
or serving_container_health_probe_period_seconds
|
||||
or serving_container_health_probe_timeout_seconds
|
||||
):
|
||||
health_probe_exec = None
|
||||
if serving_container_health_probe_exec:
|
||||
health_probe_exec = gca_model_compat.Probe.ExecAction(
|
||||
command=serving_container_health_probe_exec
|
||||
)
|
||||
health_probe = gca_model_compat.Probe(
|
||||
exec=health_probe_exec,
|
||||
period_seconds=serving_container_health_probe_period_seconds,
|
||||
timeout_seconds=serving_container_health_probe_timeout_seconds,
|
||||
)
|
||||
|
||||
self.serving_container_spec = gca_model_compat.ModelContainerSpec(
|
||||
image_uri=serving_container_image_uri,
|
||||
command=serving_container_command,
|
||||
args=serving_container_args,
|
||||
env=env,
|
||||
ports=ports,
|
||||
grpc_ports=grpc_ports,
|
||||
predict_route=serving_container_predict_route,
|
||||
health_route=serving_container_health_route,
|
||||
deployment_timeout=deployment_timeout,
|
||||
shared_memory_size_mb=serving_container_shared_memory_size_mb,
|
||||
startup_probe=startup_probe,
|
||||
health_probe=health_probe,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_cpr_model(
|
||||
cls,
|
||||
src_dir: str,
|
||||
output_image_uri: str,
|
||||
predictor: Optional[Type[Predictor]] = None,
|
||||
handler: Type[Handler] = PredictionHandler,
|
||||
base_image: str = "python:3.10",
|
||||
requirements_path: Optional[str] = None,
|
||||
extra_packages: Optional[List[str]] = None,
|
||||
no_cache: bool = False,
|
||||
platform: Optional[str] = None,
|
||||
) -> "LocalModel":
|
||||
"""Builds a local model from a custom predictor.
|
||||
|
||||
This method builds a docker image to include user-provided predictor, and handler.
|
||||
|
||||
Sample ``src_dir`` contents (e.g. ``./user_src_dir``):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
user_src_dir/
|
||||
|-- predictor.py
|
||||
|-- requirements.txt
|
||||
|-- user_code/
|
||||
| |-- utils.py
|
||||
| |-- custom_package.tar.gz
|
||||
| |-- ...
|
||||
|-- ...
|
||||
|
||||
To build a custom container:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
local_model = LocalModel.build_cpr_model(
|
||||
"./user_src_dir",
|
||||
"us-docker.pkg.dev/$PROJECT/$REPOSITORY/$IMAGE_NAME$",
|
||||
predictor=$CUSTOM_PREDICTOR_CLASS,
|
||||
requirements_path="./user_src_dir/requirements.txt",
|
||||
extra_packages=["./user_src_dir/user_code/custom_package.tar.gz"],
|
||||
platform="linux/amd64", # i.e., if you're building on a non-x86 machine
|
||||
)
|
||||
|
||||
In the built image, user provided files will be copied as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
container_workdir/
|
||||
|-- predictor.py
|
||||
|-- requirements.txt
|
||||
|-- user_code/
|
||||
| |-- utils.py
|
||||
| |-- custom_package.tar.gz
|
||||
| |-- ...
|
||||
|-- ...
|
||||
|
||||
To exclude files and directories from being copied into the built container images, create a
|
||||
``.dockerignore`` file in the ``src_dir``. See
|
||||
https://docs.docker.com/engine/reference/builder/#dockerignore-file for more details about
|
||||
usage.
|
||||
|
||||
In order to save and restore class instances transparently with Pickle, the class definition
|
||||
must be importable and live in the same module as when the object was stored. If you want to
|
||||
use Pickle, you must save your objects right under the ``src_dir`` you provide.
|
||||
|
||||
The created CPR images default the number of model server workers to the number of cores.
|
||||
Depending on the characteristics of your model, you may need to adjust the number of workers.
|
||||
You can set the number of workers with the following environment variables:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
VERTEX_CPR_WEB_CONCURRENCY:
|
||||
The number of the workers. This will overwrite the number calculated by the other
|
||||
variables, min(VERTEX_CPR_WORKERS_PER_CORE * number_of_cores, VERTEX_CPR_MAX_WORKERS).
|
||||
VERTEX_CPR_WORKERS_PER_CORE:
|
||||
The number of the workers per core. The default is 1.
|
||||
VERTEX_CPR_MAX_WORKERS:
|
||||
The maximum number of workers can be used given the value of VERTEX_CPR_WORKERS_PER_CORE
|
||||
and the number of cores.
|
||||
|
||||
If you hit the error showing "model server container out of memory" when you deploy models
|
||||
to endpoints, you should decrease the number of workers.
|
||||
|
||||
Args:
|
||||
src_dir (str):
|
||||
Required. The path to the local directory including all needed files such as
|
||||
predictor. The whole directory will be copied to the image.
|
||||
output_image_uri (str):
|
||||
Required. The image uri of the built image.
|
||||
predictor (Type[Predictor]):
|
||||
Optional. The custom predictor class consumed by handler to do prediction.
|
||||
handler (Type[Handler]):
|
||||
Required. The handler class to handle requests in the model server.
|
||||
base_image (str):
|
||||
Required. The base image used to build the custom images. The base image must
|
||||
have python and pip installed where the two commands ``python`` and ``pip`` must be
|
||||
available.
|
||||
requirements_path (str):
|
||||
Optional. The path to the local requirements.txt file. This file will be copied
|
||||
to the image and the needed packages listed in it will be installed.
|
||||
extra_packages (List[str]):
|
||||
Optional. The list of user custom dependency packages to install.
|
||||
no_cache (bool):
|
||||
Required. Do not use cache when building the image. Using build cache usually
|
||||
reduces the image building time. See
|
||||
https://docs.docker.com/develop/develop-images/dockerfile_best-practices/#leverage-build-cache
|
||||
for more details.
|
||||
platform (str):
|
||||
Optional. The target platform for the Docker image build. See
|
||||
https://docs.docker.com/build/building/multi-platform/#building-multi-platform-images
|
||||
for more details.
|
||||
|
||||
Returns:
|
||||
local model: Instantiated representation of the local model.
|
||||
|
||||
Raises:
|
||||
ValueError: If handler is ``None`` or if handler is ``PredictionHandler`` but predictor is ``None``.
|
||||
"""
|
||||
handler_module = _DEFAULT_HANDLER_MODULE
|
||||
handler_class = _DEFAULT_HANDLER_CLASS
|
||||
if handler is None:
|
||||
raise ValueError("A handler must be provided but handler is None.")
|
||||
elif handler == PredictionHandler:
|
||||
if predictor is None:
|
||||
raise ValueError(
|
||||
"PredictionHandler must have a predictor class but predictor is None."
|
||||
)
|
||||
else:
|
||||
handler_module, handler_class = prediction_utils.inspect_source_from_class(
|
||||
handler, src_dir
|
||||
)
|
||||
environment_variables = {
|
||||
"HANDLER_MODULE": handler_module,
|
||||
"HANDLER_CLASS": handler_class,
|
||||
}
|
||||
|
||||
predictor_module = None
|
||||
predictor_class = None
|
||||
if predictor is not None:
|
||||
(
|
||||
predictor_module,
|
||||
predictor_class,
|
||||
) = prediction_utils.inspect_source_from_class(predictor, src_dir)
|
||||
environment_variables["PREDICTOR_MODULE"] = predictor_module
|
||||
environment_variables["PREDICTOR_CLASS"] = predictor_class
|
||||
|
||||
is_prebuilt_prediction_image = helpers.is_prebuilt_prediction_container_uri(
|
||||
base_image
|
||||
)
|
||||
_ = build.build_image(
|
||||
base_image,
|
||||
src_dir,
|
||||
output_image_uri,
|
||||
python_module=_DEFAULT_PYTHON_MODULE,
|
||||
requirements_path=requirements_path,
|
||||
extra_requirements=_DEFAULT_SDK_REQUIREMENTS,
|
||||
extra_packages=extra_packages,
|
||||
exposed_ports=[DEFAULT_HTTP_PORT],
|
||||
environment_variables=environment_variables,
|
||||
pip_command="pip3" if is_prebuilt_prediction_image else "pip",
|
||||
python_command="python3" if is_prebuilt_prediction_image else "python",
|
||||
no_cache=no_cache,
|
||||
platform=platform,
|
||||
)
|
||||
|
||||
container_spec = gca_model_compat.ModelContainerSpec(
|
||||
image_uri=output_image_uri,
|
||||
predict_route=DEFAULT_PREDICT_ROUTE,
|
||||
health_route=DEFAULT_HEALTH_ROUTE,
|
||||
)
|
||||
|
||||
return cls(serving_container_spec=container_spec)
|
||||
|
||||
def deploy_to_local_endpoint(
|
||||
self,
|
||||
artifact_uri: Optional[str] = None,
|
||||
credential_path: Optional[str] = None,
|
||||
host_port: Optional[str] = None,
|
||||
gpu_count: Optional[int] = None,
|
||||
gpu_device_ids: Optional[List[str]] = None,
|
||||
gpu_capabilities: Optional[List[List[str]]] = None,
|
||||
container_ready_timeout: Optional[int] = None,
|
||||
container_ready_check_interval: Optional[int] = None,
|
||||
) -> LocalEndpoint:
|
||||
"""Deploys the local model instance to a local endpoint.
|
||||
|
||||
An environment variable, ``GOOGLE_CLOUD_PROJECT``, will be set to the project in the global config.
|
||||
This is required if the credentials file does not have project specified and used to
|
||||
recognize the project by the Cloud Storage client.
|
||||
|
||||
Example 1:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with local_model.deploy_to_local_endpoint(
|
||||
artifact_uri="gs://path/to/your/model",
|
||||
credential_path="local/path/to/your/credentials",
|
||||
) as local_endpoint:
|
||||
health_check_response = local_endpoint.run_health_check()
|
||||
print(health_check_response, health_check_response.content)
|
||||
|
||||
predict_response = local_endpoint.predict(
|
||||
request='{"instances": [[1, 2, 3, 4]]}',
|
||||
headers={"header-key": "header-value"},
|
||||
)
|
||||
print(predict_response, predict_response.content)
|
||||
|
||||
local_endpoint.print_container_logs()
|
||||
|
||||
Example 2:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
local_endpoint = local_model.deploy_to_local_endpoint(
|
||||
artifact_uri="gs://path/to/your/model",
|
||||
credential_path="local/path/to/your/credentials",
|
||||
)
|
||||
local_endpoint.serve()
|
||||
|
||||
health_check_response = local_endpoint.run_health_check()
|
||||
print(health_check_response, health_check_response.content)
|
||||
|
||||
predict_response = local_endpoint.predict(
|
||||
request='{"instances": [[1, 2, 3, 4]]}',
|
||||
headers={"header-key": "header-value"},
|
||||
)
|
||||
print(predict_response, predict_response.content)
|
||||
|
||||
local_endpoint.print_container_logs()
|
||||
local_endpoint.stop()
|
||||
|
||||
Args:
|
||||
artifact_uri (str):
|
||||
Optional. The path to the directory containing the Model artifact and any of its
|
||||
supporting files. The path is either a GCS uri or the path to a local directory.
|
||||
If this parameter is set to a GCS uri:
|
||||
(1) ``credential_path`` must be specified for local prediction.
|
||||
(2) The GCS uri will be passed directly to ``Predictor.load``.
|
||||
If this parameter is a local directory:
|
||||
(1) The directory will be mounted to a default temporary model path.
|
||||
(2) The mounted path will be passed to ``Predictor.load``.
|
||||
credential_path (str):
|
||||
Optional. The path to the credential key that will be mounted to the container.
|
||||
If it's unset, the environment variable, ``GOOGLE_APPLICATION_CREDENTIALS``, will
|
||||
be used if set.
|
||||
host_port (str):
|
||||
Optional. The port on the host that the port, ``AIP_HTTP_PORT``, inside the container
|
||||
will be exposed as. If it's unset, a random host port will be assigned.
|
||||
gpu_count (int):
|
||||
Optional. Number of devices to request. Set to -1 to request all available devices.
|
||||
To use GPU, set either ``gpu_count`` or ``gpu_device_ids``.
|
||||
The default value is -1 if ``gpu_capabilities`` is set but both ``gpu_count`` and
|
||||
``gpu_device_ids`` are not set.
|
||||
gpu_device_ids (List[str]):
|
||||
Optional. This parameter corresponds to ``NVIDIA_VISIBLE_DEVICES`` in the NVIDIA
|
||||
Runtime.
|
||||
To use GPU, set either ``gpu_count`` or ``gpu_device_ids``.
|
||||
gpu_capabilities (List[List[str]]):
|
||||
Optional. This parameter corresponds to ``NVIDIA_DRIVER_CAPABILITIES`` in the NVIDIA
|
||||
Runtime. The outer list acts like an OR, and each sub-list acts like an AND. The
|
||||
driver will try to satisfy one of the sub-lists.
|
||||
Available capabilities for the NVIDIA driver can be found in
|
||||
https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#driver-capabilities.
|
||||
The default value is ``[["utility", "compute"]]`` if ``gpu_count`` or ``gpu_device_ids`` is
|
||||
set.
|
||||
container_ready_timeout (int):
|
||||
Optional. The timeout in second used for starting the container or succeeding the
|
||||
first health check.
|
||||
container_ready_check_interval (int):
|
||||
Optional. The time interval in second to check if the container is ready or the
|
||||
first health check succeeds.
|
||||
|
||||
Returns:
|
||||
A the local endpoint object.
|
||||
"""
|
||||
envs = {env.name: env.value for env in self.serving_container_spec.env}
|
||||
ports = [port.container_port for port in self.serving_container_spec.ports]
|
||||
|
||||
return LocalEndpoint(
|
||||
serving_container_image_uri=self.serving_container_spec.image_uri,
|
||||
artifact_uri=artifact_uri,
|
||||
            serving_container_predict_route=self.serving_container_spec.predict_route,
            serving_container_health_route=self.serving_container_spec.health_route,
            serving_container_command=self.serving_container_spec.command,
            serving_container_args=self.serving_container_spec.args,
            serving_container_environment_variables=envs,
            serving_container_ports=ports,
            credential_path=credential_path,
            host_port=host_port,
            gpu_count=gpu_count,
            gpu_device_ids=gpu_device_ids,
            gpu_capabilities=gpu_capabilities,
            container_ready_timeout=container_ready_timeout,
            container_ready_check_interval=container_ready_check_interval,
        )

    def get_serving_container_spec(self) -> aiplatform.gapic.ModelContainerSpec:
        """Returns the container spec for the image.

        Returns:
            The serving container spec of this local model instance.
        """
        return self.serving_container_spec

    def copy_image(self, dst_image_uri: str) -> "LocalModel":
        """Copies the image to another image uri.

        Args:
            dst_image_uri (str):
                The destination image uri to copy the image to.

        Returns:
            local model: Instantiated representation of the local model with the copied
            image.

        Raises:
            DockerError: If the command fails.
        """
        self.pull_image_if_not_exists()

        command = [
            "docker",
            "tag",
            f"{self.serving_container_spec.image_uri}",
            f"{dst_image_uri}",
        ]
        return_code = local_util.execute_command(command)
        if return_code != 0:
            errors.raise_docker_error_with_command(command, return_code)

        new_container_spec = copy(self.serving_container_spec)
        new_container_spec.image_uri = dst_image_uri

        return LocalModel(new_container_spec)

    def push_image(self) -> None:
        """Pushes the image to a registry.

        If you hit permission errors while calling this function, please refer to
        https://cloud.google.com/artifact-registry/docs/docker/authentication to set
        up the authentication.

        For Artifact Registry, the repository must be created before you are able to
        push images to it. Otherwise, you will hit the error, "Repository {REPOSITORY} not found".
        To create Artifact Registry repositories, use the UI or call the following gcloud command.

        .. code-block:: bash

            gcloud artifacts repositories create {REPOSITORY} \
                --project {PROJECT} \
                --location {REGION} \
                --repository-format docker

        See https://cloud.google.com/artifact-registry/docs/manage-repos#create for more details.

        If you hit a "Permission artifactregistry.repositories.uploadArtifacts denied" error,
        set up authentication for Docker.

        .. code-block:: bash

            gcloud auth configure-docker {REPOSITORY}

        See https://cloud.google.com/artifact-registry/docs/docker/authentication for more details.

        Raises:
            ValueError: If the image uri is not a container registry or artifact registry
                uri.
            DockerError: If the command fails.
        """
        if (
            prediction_utils.is_registry_uri(self.serving_container_spec.image_uri)
            is False
        ):
            raise ValueError(
                "The image uri must be a container registry or artifact registry "
                f"uri but it is: {self.serving_container_spec.image_uri}."
            )

        command = ["docker", "push", f"{self.serving_container_spec.image_uri}"]
        return_code = local_util.execute_command(command)
        if return_code != 0:
            errors.raise_docker_error_with_command(command, return_code)

    def pull_image_if_not_exists(self):
        """Pulls the image if the image does not exist locally.

        Raises:
            DockerError: If the command fails.
        """
        if not utils.check_image_exists_locally(self.serving_container_spec.image_uri):
            command = [
                "docker",
                "pull",
                f"{self.serving_container_spec.image_uri}",
            ]
            return_code = local_util.execute_command(command)
            if return_code != 0:
                errors.raise_docker_error_with_command(command, return_code)
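For orientation, a minimal sketch of how copy_image and push_image above are meant to be used together. The image URIs and the Artifact Registry destination are illustrative assumptions; the destination repository is assumed to already exist and Docker authentication to be configured, as the docstring describes.

from google.cloud import aiplatform
from google.cloud.aiplatform.prediction import LocalModel

# Hypothetical image URIs, for illustration only.
spec = aiplatform.gapic.ModelContainerSpec(image_uri="cpr-model:latest")
local_model = LocalModel(spec)

# Tag the local image with a registry URI; this returns a new LocalModel whose
# container spec points at the destination image.
dst_image_uri = "us-central1-docker.pkg.dev/my-project/my-repo/cpr-model:latest"
registry_model = local_model.copy_image(dst_image_uri)

# Push the tagged image so Vertex AI can pull it when the model is uploaded.
registry_model.push_image()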
@@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import importlib
import logging
import multiprocessing
import os
import sys
import traceback

try:
    from fastapi import FastAPI
    from fastapi import HTTPException
    from fastapi import Request
    from fastapi import Response
except ImportError:
    raise ImportError(
        "FastAPI is not installed and is required to run model servers. "
        'Please install the SDK using `pip install "google-cloud-aiplatform[prediction]>=1.16.0"`.'
    )

try:
    import uvicorn
except ImportError:
    raise ImportError(
        "Uvicorn is not installed and is required to run fastapi applications. "
        'Please install the SDK using `pip install "google-cloud-aiplatform[prediction]>=1.16.0"`.'
    )

from google.cloud.aiplatform.constants import prediction
from google.cloud.aiplatform import version


class CprModelServer:
    """Model server to do custom prediction routines."""

    def __init__(self):
        """Initializes a FastAPI application and sets the configs.

        Raises:
            ValueError: If either HANDLER_MODULE or HANDLER_CLASS is not set in the
                environment variables, or if any of AIP_HTTP_PORT, AIP_HEALTH_ROUTE,
                and AIP_PREDICT_ROUTE is not set in the environment variables.
        """
        self._init_logging()

        if "HANDLER_MODULE" not in os.environ or "HANDLER_CLASS" not in os.environ:
            raise ValueError(
                "Both of the environment variables, HANDLER_MODULE and HANDLER_CLASS, "
                "need to be specified."
            )
        handler_module = importlib.import_module(os.environ.get("HANDLER_MODULE"))
        handler_class = getattr(handler_module, os.environ.get("HANDLER_CLASS"))
        # Compare by module name; importlib.import_module returns a module object.
        self.is_default_handler = (
            handler_module.__name__ == "google.cloud.aiplatform.prediction.handler"
        )

        predictor_class = None
        if "PREDICTOR_MODULE" in os.environ:
            predictor_module = importlib.import_module(
                os.environ.get("PREDICTOR_MODULE")
            )
            predictor_class = getattr(
                predictor_module, os.environ.get("PREDICTOR_CLASS")
            )

        self.handler = handler_class(
            os.environ.get("AIP_STORAGE_URI"), predictor=predictor_class
        )

        if "AIP_HTTP_PORT" not in os.environ:
            raise ValueError(
                "The environment variable AIP_HTTP_PORT needs to be specified."
            )
        if (
            "AIP_HEALTH_ROUTE" not in os.environ
            or "AIP_PREDICT_ROUTE" not in os.environ
        ):
            raise ValueError(
                "Both of the environment variables AIP_HEALTH_ROUTE and "
                "AIP_PREDICT_ROUTE need to be specified."
            )
        self.http_port = int(os.environ.get("AIP_HTTP_PORT"))
        self.health_route = os.environ.get("AIP_HEALTH_ROUTE")
        self.predict_route = os.environ.get("AIP_PREDICT_ROUTE")

        self.app = FastAPI()
        self.app.add_api_route(
            path=self.health_route,
            endpoint=self.health,
            methods=["GET"],
        )
        self.app.add_api_route(
            path=self.predict_route,
            endpoint=self.predict,
            methods=["POST"],
        )

    async def __call__(self, scope, receive, send):
        await self.app(scope, receive, send)

    def _init_logging(self):
        """Initializes the logging config."""
        logging.basicConfig(
            format="%(asctime)s: %(message)s",
            datefmt="%m/%d/%Y %I:%M:%S %p",
            level=logging.INFO,
            stream=sys.stdout,
        )

    def health(self):
        """Executes a health check."""
        return {}

    async def predict(self, request: Request) -> Response:
        """Executes a prediction.

        Args:
            request (Request):
                Required. The prediction request.

        Returns:
            The response containing prediction results.

        Raises:
            HTTPException: If the handle function of the handler raises any exceptions.
        """
        try:
            return await self.handler.handle(request)
        except HTTPException:
            # Re-raises the exception if it is already an HTTPException.
            raise
        except Exception as exception:
            error_message = "An exception {} occurred. Arguments: {}.".format(
                type(exception).__name__, exception.args
            )
            logging.info(
                "{}\nTraceback: {}".format(error_message, traceback.format_exc())
            )

            # Converts all other exceptions to HTTPException.
            if self.is_default_handler:
                raise HTTPException(
                    status_code=500,
                    detail=error_message,
                    headers={
                        prediction.CUSTOM_PREDICTION_ROUTINES_SERVER_ERROR_HEADER_KEY: version.__version__
                    },
                )
            raise HTTPException(status_code=500, detail=error_message)


def set_number_of_workers_from_env() -> None:
    """Sets the number of model server workers used by Uvicorn in the environment variables.

    The number of model server workers is set as WEB_CONCURRENCY in the environment
    variables.
    The default number of model server workers is the number of cores.
    The following environment variables adjust the number of workers:
    VERTEX_CPR_WEB_CONCURRENCY:
        The number of workers. This overrides the number calculated by the other
        variables, min(VERTEX_CPR_WORKERS_PER_CORE * number_of_cores, VERTEX_CPR_MAX_WORKERS).
    VERTEX_CPR_WORKERS_PER_CORE:
        The number of workers per core. The default is 1.
    VERTEX_CPR_MAX_WORKERS:
        The maximum number of workers that can be used, given the value of
        VERTEX_CPR_WORKERS_PER_CORE and the number of cores.
    """
    workers_per_core_str = os.getenv("VERTEX_CPR_WORKERS_PER_CORE", "1")
    max_workers_str = os.getenv("VERTEX_CPR_MAX_WORKERS")
    use_max_workers = None
    if max_workers_str:
        use_max_workers = int(max_workers_str)
    web_concurrency_str = os.getenv("VERTEX_CPR_WEB_CONCURRENCY")

    if not web_concurrency_str:
        cores = multiprocessing.cpu_count()
        workers_per_core = float(workers_per_core_str)
        default_web_concurrency = workers_per_core * cores
        web_concurrency = max(int(default_web_concurrency), 2)
        if use_max_workers:
            web_concurrency = min(web_concurrency, use_max_workers)
        web_concurrency_str = str(web_concurrency)
    os.environ["WEB_CONCURRENCY"] = web_concurrency_str
    logging.warning(
        f'Set the number of model server workers to {os.environ["WEB_CONCURRENCY"]}.'
    )


if __name__ == "__main__":
    set_number_of_workers_from_env()
    uvicorn.run(
        "google.cloud.aiplatform.prediction.model_server:CprModelServer",
        host="0.0.0.0",
        port=int(os.environ.get("AIP_HTTP_PORT")),
        factory=True,
    )
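Below is a rough sketch of the environment a CprModelServer expects before it is constructed, useful when exercising the server outside of Vertex AI. The predictor module/class names, bucket path, and routes are illustrative assumptions, not values required by the SDK.

import os

# Handler to load: dotted module path plus class name.
os.environ["HANDLER_MODULE"] = "google.cloud.aiplatform.prediction.handler"
os.environ["HANDLER_CLASS"] = "PredictionHandler"

# Optional custom predictor (hypothetical names).
os.environ["PREDICTOR_MODULE"] = "my_package.predictor"
os.environ["PREDICTOR_CLASS"] = "MyPredictor"

# Serving configuration read in CprModelServer.__init__ (illustrative values).
os.environ["AIP_STORAGE_URI"] = "gs://my-bucket/model-artifacts"
os.environ["AIP_HTTP_PORT"] = "8080"
os.environ["AIP_HEALTH_ROUTE"] = "/health"
os.environ["AIP_PREDICT_ROUTE"] = "/predict"

With these variables set, the __main__ block above computes the worker count and starts Uvicorn. For example, on a 4-core machine with the defaults, WEB_CONCURRENCY would be max(int(1 * 4), 2) = 4; setting VERTEX_CPR_MAX_WORKERS=2 would cap it at 2.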
@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABC, abstractmethod
from typing import Any


class Predictor(ABC):
    """Interface of the Predictor class for Custom Prediction Routines.

    The Predictor is responsible for the ML logic for processing a prediction request.
    Specifically, the Predictor must define:
    (1) How to load all model artifacts used during prediction into memory.
    (2) The logic that should be executed at predict time.

    When using the default ``PredictionHandler``, the ``Predictor`` will be invoked as
    follows:

    .. code-block:: python

        predictor.postprocess(predictor.predict(predictor.preprocess(prediction_input)))

    """

    def __init__(self):
        return

    @abstractmethod
    def load(self, artifacts_uri: str) -> None:
        """Loads the model artifact.

        Args:
            artifacts_uri (str):
                Required. The value of the environment variable AIP_STORAGE_URI.
        """
        pass

    def preprocess(self, prediction_input: Any) -> Any:
        """Preprocesses the prediction input before doing the prediction.

        Args:
            prediction_input (Any):
                Required. The prediction input that needs to be preprocessed.

        Returns:
            The preprocessed prediction input.
        """
        return prediction_input

    @abstractmethod
    def predict(self, instances: Any) -> Any:
        """Performs prediction.

        Args:
            instances (Any):
                Required. The instance(s) used for performing prediction.

        Returns:
            Prediction results.
        """
        pass

    def postprocess(self, prediction_results: Any) -> Any:
        """Postprocesses the prediction results.

        Args:
            prediction_results (Any):
                Required. The prediction results.

        Returns:
            The postprocessed prediction results.
        """
        return prediction_results
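To make the interface concrete, here is a minimal sketch of a custom Predictor. The artifact filename scale.txt and the scaling logic are purely illustrative; only the load/preprocess/predict/postprocess hooks come from the interface above.

from typing import List

from google.cloud.aiplatform.prediction.predictor import Predictor
from google.cloud.aiplatform.utils import prediction_utils


class ScalingPredictor(Predictor):
    """Toy predictor that multiplies every feature by a constant read from a file."""

    def load(self, artifacts_uri: str) -> None:
        # Download the artifacts next to the server, then read the constant.
        prediction_utils.download_model_artifacts(artifacts_uri)
        with open("scale.txt") as f:  # hypothetical artifact name
            self._scale = float(f.read())

    def preprocess(self, prediction_input: dict) -> List[List[float]]:
        return prediction_input["instances"]

    def predict(self, instances: List[List[float]]) -> List[List[float]]:
        return [[value * self._scale for value in instance] for instance in instances]

    def postprocess(self, prediction_results: List[List[float]]) -> dict:
        return {"predictions": prediction_results}

With the default ``PredictionHandler``, this class is exercised exactly as the docstring describes: predictor.postprocess(predictor.predict(predictor.preprocess(prediction_input))).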
@@ -0,0 +1,144 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABC, abstractmethod
import json
from typing import Any, Optional

try:
    from fastapi import HTTPException
except ImportError:
    raise ImportError(
        "FastAPI is not installed and is required to build model servers. "
        'Please install the SDK using `pip install "google-cloud-aiplatform[prediction]>=1.16.0"`.'
    )

from google.cloud.aiplatform.constants import prediction as prediction_constants
from google.cloud.aiplatform.prediction import handler_utils


APPLICATION_JSON = "application/json"


class Serializer(ABC):
    """Interface to implement serialization and deserialization for prediction."""

    @staticmethod
    @abstractmethod
    def deserialize(data: Any, content_type: Optional[str]) -> Any:
        """Deserializes the request data. Invoked before predict.

        Args:
            data (Any):
                Required. The request data sent to the application.
            content_type (str):
                Optional. The specified content type of the request.
        """
        pass

    @staticmethod
    @abstractmethod
    def serialize(prediction: Any, accept: Optional[str]) -> Any:
        """Serializes the prediction results. Invoked after predict.

        Args:
            prediction (Any):
                Required. The generated prediction to be sent back to clients.
            accept (str):
                Optional. The specified content type of the response.
        """
        pass


class DefaultSerializer(Serializer):
    """Default serializer for serialization and deserialization for prediction."""

    @staticmethod
    def deserialize(data: Any, content_type: Optional[str]) -> Any:
        """Deserializes the request data. Invoked before predict.

        Args:
            data (Any):
                Required. The request data sent to the application.
            content_type (str):
                Optional. The specified content type of the request.

        Raises:
            HTTPException: If JSON deserialization fails or the specified content type is
                not supported.
        """
        if content_type == APPLICATION_JSON:
            try:
                return json.loads(data)
            except json.JSONDecodeError:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        f"JSON deserialization failed for the request data: {data}.\n"
                        'To specify a different type, please set the "content-type" header '
                        "in the request.\nCurrently supported content-type in DefaultSerializer: "
                        f'"{APPLICATION_JSON}".'
                    ),
                )
        else:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Unsupported content type of the request: {content_type}.\n"
                    f'Currently supported content-type in DefaultSerializer: "{APPLICATION_JSON}".'
                ),
            )

    @staticmethod
    def serialize(prediction: Any, accept: Optional[str]) -> Any:
        """Serializes the prediction results. Invoked after predict.

        Args:
            prediction (Any):
                Required. The generated prediction to be sent back to clients.
            accept (str):
                Optional. The specified content type of the response.

        Raises:
            HTTPException: If JSON serialization fails or the specified accept type is not
                supported.
        """
        accept_dict = handler_utils.parse_accept_header(accept)

        if (
            APPLICATION_JSON in accept_dict
            or prediction_constants.ANY_ACCEPT_TYPE in accept_dict
        ):
            try:
                return json.dumps(prediction)
            except TypeError:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        f"JSON serialization failed for the prediction result: {prediction}.\n"
                        'To specify a different type, please set the "accept" header '
                        "in the request.\nCurrently supported accept in DefaultSerializer: "
                        f'"{APPLICATION_JSON}".'
                    ),
                )
        else:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Unsupported accept of the response: {accept}.\n"
                    f'Currently supported accept in DefaultSerializer: "{APPLICATION_JSON}".'
                ),
            )
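A short sketch of the DefaultSerializer round trip, assuming a plain JSON request body; the literal values are illustrative.

from google.cloud.aiplatform.prediction.serializer import DefaultSerializer

# Deserialize an incoming JSON request body into a Python object.
body = b'{"instances": [[1.0, 2.0], [3.0, 4.0]]}'
prediction_input = DefaultSerializer.deserialize(body, content_type="application/json")

# ... run the predictor on prediction_input ...
prediction = {"predictions": [0.1, 0.9]}

# Serialize the prediction back to JSON for the response body.
response_body = DefaultSerializer.serialize(prediction, accept="application/json")

Anything that is not valid JSON, or an accept header that includes neither application/json nor the any-type wildcard, results in a 400 HTTPException as shown above.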
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from google.cloud.aiplatform.prediction.sklearn.predictor import SklearnPredictor

__all__ = ("SklearnPredictor",)
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import joblib
import numpy as np
import os
import pickle

from google.cloud.aiplatform.constants import prediction
from google.cloud.aiplatform.utils import prediction_utils
from google.cloud.aiplatform.prediction.predictor import Predictor


class SklearnPredictor(Predictor):
    """Default Predictor implementation for Sklearn models."""

    def __init__(self):
        return

    def load(self, artifacts_uri: str) -> None:
        """Loads the model artifact.

        Args:
            artifacts_uri (str):
                Required. The value of the environment variable AIP_STORAGE_URI.

        Raises:
            ValueError: If none of the supported model files is provided in the
                artifacts uri.
        """
        prediction_utils.download_model_artifacts(artifacts_uri)
        if os.path.exists(prediction.MODEL_FILENAME_JOBLIB):
            self._model = joblib.load(prediction.MODEL_FILENAME_JOBLIB)
        elif os.path.exists(prediction.MODEL_FILENAME_PKL):
            self._model = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb"))
        else:
            valid_filenames = [
                prediction.MODEL_FILENAME_JOBLIB,
                prediction.MODEL_FILENAME_PKL,
            ]
            raise ValueError(
                f"One of the following model files must be provided: {valid_filenames}."
            )

    def preprocess(self, prediction_input: dict) -> np.ndarray:
        """Converts the request body to a numpy array before prediction.

        Args:
            prediction_input (dict):
                Required. The prediction input that needs to be preprocessed.

        Returns:
            The preprocessed prediction input.
        """
        instances = prediction_input["instances"]
        return np.asarray(instances)

    def predict(self, instances: np.ndarray) -> np.ndarray:
        """Performs prediction.

        Args:
            instances (np.ndarray):
                Required. The instance(s) used for performing prediction.

        Returns:
            Prediction results.
        """
        return self._model.predict(instances)

    def postprocess(self, prediction_results: np.ndarray) -> dict:
        """Converts numpy array to a dict.

        Args:
            prediction_results (np.ndarray):
                Required. The prediction results.

        Returns:
            The postprocessed prediction results.
        """
        return {"predictions": prediction_results.tolist()}
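A quick offline sketch of the SklearnPredictor flow. On Vertex AI, load() would be called with the AIP_STORAGE_URI value and read the downloaded artifact named by prediction.MODEL_FILENAME_JOBLIB or MODEL_FILENAME_PKL; here the trained model is assigned directly so the snippet stays self-contained, which is a shortcut for illustration only.

from sklearn.linear_model import LogisticRegression

from google.cloud.aiplatform.prediction.sklearn.predictor import SklearnPredictor

# Train a toy model; in a real flow it would be dumped with joblib or pickle and
# uploaded to the artifact location referenced by AIP_STORAGE_URI.
model = LogisticRegression().fit([[0.0], [1.0], [2.0], [3.0]], [0, 0, 1, 1])

predictor = SklearnPredictor()
predictor._model = model  # stands in for predictor.load(artifacts_uri)

request = {"instances": [[0.5], [2.5]]}
response = predictor.postprocess(predictor.predict(predictor.preprocess(request)))
print(response)  # e.g. {"predictions": [0, 1]}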
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from google.cloud.aiplatform.prediction.xgboost.predictor import XgboostPredictor

__all__ = ("XgboostPredictor",)
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import joblib
import logging
import os
import pickle

import numpy as np
import xgboost as xgb

from google.cloud.aiplatform.constants import prediction
from google.cloud.aiplatform.utils import prediction_utils
from google.cloud.aiplatform.prediction.predictor import Predictor


class XgboostPredictor(Predictor):
    """Default Predictor implementation for Xgboost models."""

    def __init__(self):
        return

    def load(self, artifacts_uri: str) -> None:
        """Loads the model artifact.

        Args:
            artifacts_uri (str):
                Required. The value of the environment variable AIP_STORAGE_URI.

        Raises:
            ValueError: If none of the supported model files is provided in the
                artifacts uri.
        """
        prediction_utils.download_model_artifacts(artifacts_uri)
        if os.path.exists(prediction.MODEL_FILENAME_BST):
            booster = xgb.Booster(model_file=prediction.MODEL_FILENAME_BST)
        elif os.path.exists(prediction.MODEL_FILENAME_JOBLIB):
            try:
                booster = joblib.load(prediction.MODEL_FILENAME_JOBLIB)
            except KeyError:
                logging.info(
                    "Loading model using joblib failed. "
                    "Loading model using xgboost.Booster instead."
                )
                booster = xgb.Booster()
                booster.load_model(prediction.MODEL_FILENAME_JOBLIB)
        elif os.path.exists(prediction.MODEL_FILENAME_PKL):
            booster = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb"))
        else:
            valid_filenames = [
                prediction.MODEL_FILENAME_BST,
                prediction.MODEL_FILENAME_JOBLIB,
                prediction.MODEL_FILENAME_PKL,
            ]
            raise ValueError(
                f"One of the following model files must be provided: {valid_filenames}."
            )
        self._booster = booster

    def preprocess(self, prediction_input: dict) -> xgb.DMatrix:
        """Converts the request body to a DMatrix before prediction.

        Args:
            prediction_input (dict):
                Required. The prediction input that needs to be preprocessed.

        Returns:
            The preprocessed prediction input.
        """
        instances = prediction_input["instances"]
        return xgb.DMatrix(instances)

    def predict(self, instances: xgb.DMatrix) -> np.ndarray:
        """Performs prediction.

        Args:
            instances (xgb.DMatrix):
                Required. The instance(s) used for performing prediction.

        Returns:
            Prediction results.
        """
        return self._booster.predict(instances)

    def postprocess(self, prediction_results: np.ndarray) -> dict:
        """Converts numpy array to a dict.

        Args:
            prediction_results (np.ndarray):
                Required. The prediction results.

        Returns:
            The postprocessed prediction results.
        """
        return {"predictions": prediction_results.tolist()}
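And the analogous offline sketch for XgboostPredictor. As above, assigning the booster directly stands in for load(), which on Vertex AI would read the downloaded .bst/joblib/pkl artifact; the tiny training set is illustrative.

import numpy as np
import xgboost as xgb

from google.cloud.aiplatform.prediction.xgboost.predictor import XgboostPredictor

# Train a tiny booster for illustration.
data = xgb.DMatrix(np.array([[0.0], [1.0], [2.0], [3.0]]), label=np.array([0, 0, 1, 1]))
booster = xgb.train({"objective": "binary:logistic"}, data, num_boost_round=2)

predictor = XgboostPredictor()
predictor._booster = booster  # stands in for predictor.load(artifacts_uri)

# A numpy array is used for "instances" to keep the DMatrix construction in
# preprocess() straightforward.
request = {"instances": np.array([[0.5], [2.5]])}
response = predictor.postprocess(predictor.predict(predictor.preprocess(request)))
print(response)  # {"predictions": [...]}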