190 lines
6.4 KiB
Python
190 lines
6.4 KiB
Python
"""Register XGBoost for Ray on Vertex AI."""
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright 2023 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import os
|
|
import pickle
|
|
import ray
|
|
import tempfile
|
|
from typing import Optional, TYPE_CHECKING
|
|
import warnings
|
|
from google.cloud import aiplatform
|
|
from google.cloud.aiplatform import initializer
|
|
from google.cloud.aiplatform import utils
|
|
from google.cloud.aiplatform.utils import gcs_utils
|
|
from google.cloud.aiplatform.vertex_ray.predict.util import constants
|
|
from google.cloud.aiplatform.vertex_ray.predict.util import (
|
|
predict_utils,
|
|
)
|
|
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
|
|
_V2_4_WARNING_MESSAGE,
|
|
_V2_9_WARNING_MESSAGE,
|
|
)
|
|
|
|
|
|
try:
    from ray.train import xgboost as ray_xgboost

    if TYPE_CHECKING:
        # Only static type checkers need the real module; at runtime it is
        # referenced solely inside string annotations.
        import xgboost
except ModuleNotFoundError as mnfe:
    # Ray 2.9.3 is the only release this module supports, so a failed import
    # there is fatal. For any other Ray version the failure is tolerated at
    # import time; register_xgboost and _get_xgboost_model_from reject those
    # versions with a RuntimeError when called.
    if ray.__version__ != "2.9.3":
        xgboost = None
    else:
        raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
|
|
|
|
|
|
def register_xgboost(
    checkpoint: "ray_xgboost.XGBoostCheckpoint",
    artifact_uri: Optional[str] = None,
    display_name: Optional[str] = None,
    xgboost_version: Optional[str] = None,
    **kwargs,
) -> aiplatform.Model:
    """Uploads a Ray XGBoost Checkpoint as XGBoost Model to Model Registry.

    The Booster extracted from the checkpoint is pickled to a local temporary
    file. That file is copied to
    ``<artifact_uri>/<display name>/<constants._PICKLE_FILE_NAME>`` and then
    registered through ``aiplatform.Model.upload_xgboost_model_file``.

    Example usage:
        from vertex_ray.predict import xgboost
        from ray.train.xgboost import XGBoostCheckpoint

        trainer = XGBoostTrainer(...)
        result = trainer.fit()
        xgboost_checkpoint = XGBoostCheckpoint.from_checkpoint(result.checkpoint)

        my_model = xgboost.register_xgboost(
            checkpoint=xgboost_checkpoint,
            artifact_uri="gs://{gcs-bucket-name}/path/to/store",
            display_name="my-ray-on-vertex-xgboost-model",
        )

    Args:
        checkpoint: XGBoostCheckpoint instance.
        artifact_uri (str):
            The path to the directory where Model Artifacts will be saved. If
            not set, will use staging bucket set in aiplatform.init().
        display_name (str):
            Optional. The display name of the Model. The name can be up to 128
            characters long and can consist of any UTF-8 characters. If not
            set, a name of the form
            ``ray-on-vertex-registered-xgboost-model-<timestamped unique id>``
            is generated.
        xgboost_version (str): Optional. The version of the XGBoost serving container.
            Supported versions:
            https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
            If the version is not specified, ``constants._XGBOOST_VERSION``
            is used.
        **kwargs:
            Any kwargs will be passed to aiplatform.Model registration.

    Returns:
        model (aiplatform.Model):
            Instantiated representation of the uploaded model resource.

    Raises:
        ValueError: Invalid Argument.
        RuntimeError: Only Ray version 2.9.3 is supported.
    """
    ray_version = ray.__version__
    if ray_version != "2.9.3":
        raise RuntimeError(
            f"Ray version {ray_version} is not supported to upload XGBoost"
            " model to Vertex Model Registry yet. Please use Ray 2.9.3."
        )
    # Only Ray 2.9.3 gets past the check above; emit its deprecation warning.
    warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)

    artifact_uri = artifact_uri or initializer.global_config.staging_bucket
    predict_utils.validate_artifact_uri(artifact_uri)
    display_model_name = (
        f"ray-on-vertex-registered-xgboost-model-{utils.timestamped_unique_name()}"
        if display_name is None
        else display_name
    )
    model = _get_xgboost_model_from(checkpoint)

    model_dir = os.path.join(artifact_uri, display_model_name)
    file_path = os.path.join(model_dir, constants._PICKLE_FILE_NAME)
    if xgboost_version is None:
        xgboost_version = constants._XGBOOST_VERSION

    with tempfile.NamedTemporaryFile(suffix=constants._PICKLE_EXTENTION) as temp_file:
        pickle.dump(model, temp_file)
        # NamedTemporaryFile is buffered, and both uploads below receive only
        # the file's path. Flush first, or they may read a truncated/empty
        # pickle.
        temp_file.flush()
        gcs_utils.upload_to_gcs(temp_file.name, file_path)
        return aiplatform.Model.upload_xgboost_model_file(
            model_file_path=temp_file.name,
            display_name=display_model_name,
            xgboost_version=xgboost_version,
            **kwargs,
        )
|
|
|
|
|
|
def _get_xgboost_model_from(
    checkpoint: "ray_xgboost.XGBoostCheckpoint",
) -> "xgboost.Booster":
    """Converts a XGBoostCheckpoint to XGBoost model.

    The function first tries ``checkpoint.get_model()``. If the checkpoint has
    no such attribute, it loads the file named
    ``XGBoostCheckpoint.MODEL_FILENAME`` into a new ``xgboost.Booster``:

    - from the local ``checkpoint.path`` directory, if that file exists
      locally;
    - otherwise, by downloading ``checkpoint.path`` from GCS into a temporary
      directory. A ``gs://`` scheme is prepended only when missing.

    Args:
        checkpoint: XGBoostCheckpoint instance.

    Returns:
        A XGBoost core Booster

    Raises:
        ModuleNotFoundError: XGBoost isn't installed.
        RuntimeError: Model not found.
        RuntimeError: Ray version 2.4 is not supported.
        RuntimeError: Only Ray version 2.9.3 is supported.
    """
    ray_version = ray.__version__
    if ray_version == "2.4.0":
        raise RuntimeError(_V2_4_WARNING_MESSAGE)
    if ray_version != "2.9.3":
        raise RuntimeError(
            f"Ray version {ray_version} is not supported to convert a XGBoost"
            " checkpoint to XGBoost model on Vertex yet. Please use Ray 2.9.3."
        )

    try:
        # Preferred path: let the checkpoint build the Booster itself.
        return checkpoint.get_model()
    except AttributeError:
        # Fallback: the checkpoint lacks get_model(), so load the serialized
        # model file from the checkpoint directory manually.
        model_file_name = ray.train.xgboost.XGBoostCheckpoint.MODEL_FILENAME

    model_path = os.path.join(checkpoint.path, model_file_name)

    try:
        import xgboost
    except ModuleNotFoundError as mnfe:
        raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe

    booster = xgboost.Booster()
    if os.path.exists(model_path):
        booster.load_model(model_path)
        return booster

    # Not on local disk: treat checkpoint.path as a GCS location. Avoid
    # producing "gs://gs://..." when the path already carries the scheme.
    gcs_uri = (
        checkpoint.path
        if checkpoint.path.startswith("gs://")
        else "gs://" + checkpoint.path
    )
    try:
        # Download from GCS to a temp dir, then load the model file from there.
        with tempfile.TemporaryDirectory() as temp_dir:
            gcs_utils.download_from_gcs(gcs_uri, temp_dir)
            booster.load_model(os.path.join(temp_dir, model_file_name))
            return booster
    except Exception as e:
        # Keep the original failure attached as __cause__ for debugging.
        raise RuntimeError(
            f"{model_file_name} not found in this checkpoint due to: {e}."
        ) from e
|