structure saas with tools
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
"""Ray on Vertex AI Prediction Tensorflow."""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from .register import get_pytorch_model_from
|
||||
|
||||
__all__ = ("get_pytorch_model_from",)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,112 @@
|
||||
"""Regsiter Torch for Ray on Vertex AI."""
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
import ray
|
||||
from ray.air._internal.torch_utils import load_torch_model
|
||||
import tempfile
|
||||
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
|
||||
_V2_4_WARNING_MESSAGE,
|
||||
_V2_9_WARNING_MESSAGE,
|
||||
)
|
||||
from google.cloud.aiplatform.utils import gcs_utils
|
||||
from typing import Optional
|
||||
|
||||
|
||||
try:
|
||||
from ray.train import torch as ray_torch
|
||||
import torch
|
||||
except ModuleNotFoundError as mnfe:
|
||||
raise ModuleNotFoundError("Torch isn't installed.") from mnfe
|
||||
|
||||
|
||||
def get_pytorch_model_from(
|
||||
checkpoint: ray_torch.TorchCheckpoint,
|
||||
model: Optional[torch.nn.Module] = None,
|
||||
) -> torch.nn.Module:
|
||||
"""Converts a TorchCheckpoint to Pytorch Model.
|
||||
|
||||
Example:
|
||||
from vertex_ray.predict import torch
|
||||
result = TorchTrainer.fit(...)
|
||||
|
||||
pytorch_model = torch.get_pytorch_model_from(
|
||||
checkpoint=result.checkpoint
|
||||
)
|
||||
|
||||
Args:
|
||||
checkpoint: TorchCheckpoint instance.
|
||||
model: If the checkpoint contains a model state dict, and not the model
|
||||
itself, then the state dict will be loaded to this `model`. Otherwise,
|
||||
the model will be discarded.
|
||||
|
||||
Returns:
|
||||
A Pytorch Native Framework Model.
|
||||
|
||||
Raises:
|
||||
ValueError: Invalid Argument.
|
||||
ModuleNotFoundError: PyTorch isn't installed.
|
||||
RuntimeError: Model not found.
|
||||
RuntimeError: Ray version 2.4 is not supported.
|
||||
RuntimeError: Only Ray version 2.9.3 is supported.
|
||||
"""
|
||||
ray_version = ray.__version__
|
||||
if ray_version == "2.4.0":
|
||||
raise RuntimeError(_V2_4_WARNING_MESSAGE)
|
||||
if ray_version != "2.9.3":
|
||||
raise RuntimeError(
|
||||
f"Ray on Vertex does not support Ray version {ray_version} to"
|
||||
" convert PyTorch model artifacts yet. Please use Ray 2.9.3."
|
||||
)
|
||||
if ray_version == "2.9.3":
|
||||
warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)
|
||||
|
||||
try:
|
||||
return checkpoint.get_model()
|
||||
except AttributeError:
|
||||
model_file_name = ray.train.torch.TorchCheckpoint.MODEL_FILENAME
|
||||
|
||||
model_path = os.path.join(checkpoint.path, model_file_name)
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
except ModuleNotFoundError as mnfe:
|
||||
raise ModuleNotFoundError("PyTorch isn't installed.") from mnfe
|
||||
|
||||
if os.path.exists(model_path):
|
||||
model_or_state_dict = torch.load(
|
||||
model_path, map_location="cpu", weights_only=True
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Download from GCS to temp and then load_model
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
|
||||
model_or_state_dict = torch.load(
|
||||
f"{temp_dir}/{model_file_name}",
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"{model_file_name} not found in this checkpoint due to: {e}."
|
||||
)
|
||||
|
||||
model = load_torch_model(saved_model=model_or_state_dict, model_definition=model)
|
||||
return model
|
||||
Reference in New Issue
Block a user