structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,22 @@
"""Ray on Vertex AI Prediction Tensorflow."""
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .register import get_pytorch_model_from
__all__ = ("get_pytorch_model_from",)

View File

@@ -0,0 +1,112 @@
"""Regsiter Torch for Ray on Vertex AI."""
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import ray
from ray.air._internal.torch_utils import load_torch_model
import tempfile
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
_V2_9_WARNING_MESSAGE,
)
from google.cloud.aiplatform.utils import gcs_utils
from typing import Optional
try:
from ray.train import torch as ray_torch
import torch
except ModuleNotFoundError as mnfe:
raise ModuleNotFoundError("Torch isn't installed.") from mnfe
def get_pytorch_model_from(
checkpoint: ray_torch.TorchCheckpoint,
model: Optional[torch.nn.Module] = None,
) -> torch.nn.Module:
"""Converts a TorchCheckpoint to Pytorch Model.
Example:
from vertex_ray.predict import torch
result = TorchTrainer.fit(...)
pytorch_model = torch.get_pytorch_model_from(
checkpoint=result.checkpoint
)
Args:
checkpoint: TorchCheckpoint instance.
model: If the checkpoint contains a model state dict, and not the model
itself, then the state dict will be loaded to this `model`. Otherwise,
the model will be discarded.
Returns:
A Pytorch Native Framework Model.
Raises:
ValueError: Invalid Argument.
ModuleNotFoundError: PyTorch isn't installed.
RuntimeError: Model not found.
RuntimeError: Ray version 2.4 is not supported.
RuntimeError: Only Ray version 2.9.3 is supported.
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
raise RuntimeError(_V2_4_WARNING_MESSAGE)
if ray_version != "2.9.3":
raise RuntimeError(
f"Ray on Vertex does not support Ray version {ray_version} to"
" convert PyTorch model artifacts yet. Please use Ray 2.9.3."
)
if ray_version == "2.9.3":
warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)
try:
return checkpoint.get_model()
except AttributeError:
model_file_name = ray.train.torch.TorchCheckpoint.MODEL_FILENAME
model_path = os.path.join(checkpoint.path, model_file_name)
try:
import torch
except ModuleNotFoundError as mnfe:
raise ModuleNotFoundError("PyTorch isn't installed.") from mnfe
if os.path.exists(model_path):
model_or_state_dict = torch.load(
model_path, map_location="cpu", weights_only=True
)
else:
try:
# Download from GCS to temp and then load_model
with tempfile.TemporaryDirectory() as temp_dir:
gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
model_or_state_dict = torch.load(
f"{temp_dir}/{model_file_name}",
map_location="cpu",
weights_only=True,
)
except Exception as e:
raise RuntimeError(
f"{model_file_name} not found in this checkpoint due to: {e}."
)
model = load_torch_model(saved_model=model_or_state_dict, model_definition=model)
return model