structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,68 @@
"""Ray on Vertex AI."""
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from google.cloud.aiplatform.vertex_ray.bigquery_datasource import (
_BigQueryDatasource,
)
from google.cloud.aiplatform.vertex_ray.client_builder import (
VertexRayClientBuilder as ClientBuilder,
)
from google.cloud.aiplatform.vertex_ray.cluster_init import (
create_ray_cluster,
delete_ray_cluster,
get_ray_cluster,
list_ray_clusters,
update_ray_cluster,
)
from google.cloud.aiplatform.vertex_ray import data
from google.cloud.aiplatform.vertex_ray.util.resources import (
AutoscalingSpec,
Resources,
NodeImages,
PscIConfig,
)
from google.cloud.aiplatform.vertex_ray.dashboard_sdk import (
get_job_submission_client_cluster_info,
)
if sys.version_info[:2] not in ((3, 10), (3, 11)):
print(
"[Ray on Vertex]: A client environment with Python 3.10 or 3.11 is required."
)
__all__ = (
"_BigQueryDatasource",
"data",
"ClientBuilder",
"get_job_submission_client_cluster_info",
"create_ray_cluster",
"delete_ray_cluster",
"get_ray_cluster",
"list_ray_clusters",
"update_ray_cluster",
"AutoscalingSpec",
"Resources",
"NodeImages",
"PscIConfig",
)
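# A minimal end-to-end usage sketch of this public surface (illustrative
# only; the project, region, and machine shapes below are placeholder
# assumptions, not part of this module):
#
#   from google.cloud import aiplatform
#   from google.cloud.aiplatform import vertex_ray
#
#   aiplatform.init(project="my-project", location="us-central1")
#   cluster_name = vertex_ray.create_ray_cluster(
#       head_node_type=vertex_ray.Resources(machine_type="n1-standard-8"),
#       worker_node_types=[vertex_ray.Resources(machine_type="n1-standard-8")],
#   )
#   print(vertex_ray.get_ray_cluster(cluster_name))
#   vertex_ray.delete_ray_cluster(cluster_name)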

View File

@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import tempfile
import time
import uuid
from typing import Any, Iterable, Optional
import pyarrow.parquet as pq
from google.api_core import client_info
from google.api_core import exceptions
from google.cloud import bigquery
from google.cloud.aiplatform import initializer
import ray
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import Block, BlockAccessor
try:
from ray.data.datasource.datasink import Datasink
except ImportError:
# Datasink requires Ray >= 2.9.3; fall back when an older Ray is installed.
Datasink = None
DEFAULT_MAX_RETRY_CNT = 10
RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11
_BQ_GAPIC_VERSION = bigquery.__version__ + "+vertex_ray"
bq_info = client_info.ClientInfo(
gapic_version=_BQ_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQ_GAPIC_VERSION}"
)
# BigQuery write for Ray 2.42.0, 2.33.0, and 2.9.3
if Datasink is None:
_BigQueryDatasink = None
else:
class _BigQueryDatasink(Datasink):
def __init__(
self,
dataset: str,
project_id: Optional[str] = None,
max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
overwrite_table: Optional[bool] = True,
) -> None:
self.dataset = dataset
self.project_id = project_id or initializer.global_config.project
self.max_retry_cnt = max_retry_cnt
self.overwrite_table = overwrite_table
def on_write_start(self) -> None:
# Set up datasets to write
client = bigquery.Client(project=self.project_id, client_info=bq_info)
dataset_id = self.dataset.split(".", 1)[0]
try:
client.get_dataset(dataset_id)
except exceptions.NotFound:
client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30)
print("[Ray on Vertex AI]: Created dataset " + dataset_id)
# Delete table if overwrite_table is True
if self.overwrite_table:
print(
f"[Ray on Vertex AI]: Attempting to delete table {self.dataset}"
+ " if it already exists since kwarg overwrite_table = True."
)
client.delete_table(
f"{self.project_id}.{self.dataset}", not_found_ok=True
)
else:
print(
"[Ray on Vertex AI]: The write will append to table "
+ f"{self.dataset} if it already exists "
+ "since kwarg overwrite_table = False."
)
def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> Any:
def _write_single_block(
block: Block, project_id: str, dataset: str
) -> None:
block = BlockAccessor.for_block(block).to_arrow()
client = bigquery.Client(project=project_id, client_info=bq_info)
job_config = bigquery.LoadJobConfig(autodetect=True)
job_config.source_format = bigquery.SourceFormat.PARQUET
job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND
with tempfile.TemporaryDirectory() as temp_dir:
fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
pq.write_table(block, fp, compression="SNAPPY")
retry_cnt = 0
while retry_cnt <= self.max_retry_cnt:
with open(fp, "rb") as source_file:
job = client.load_table_from_file(
source_file, dataset, job_config=job_config
)
try:
logging.info(job.result())
break
except exceptions.Forbidden as e:
retry_cnt += 1
if retry_cnt > self.max_retry_cnt:
break
print(
"[Ray on Vertex AI]: A block write encountered"
+ f" a rate limit exceeded error {retry_cnt} time(s)."
+ " Sleeping to try again."
)
logging.debug(e)
time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)
# Raise exception if retry_cnt exceeds max_retry_cnt
if retry_cnt > self.max_retry_cnt:
print(
f"[Ray on Vertex AI]: Maximum ({self.max_retry_cnt}) retry count exceeded."
+ " Ray will attempt to retry the block write via fault tolerance."
+ " For more information, see https://docs.ray.io/en/latest/ray-core/fault_tolerance/tasks.html"
)
raise RuntimeError(
f"[Ray on Vertex AI]: Write failed due to {retry_cnt}"
+ " repeated API rate limit exceeded responses. Consider"
+ " specifiying the max_retry_cnt kwarg with a higher value."
)
_write_single_block = cached_remote_fn(_write_single_block)
# Launch a remote task for each block within this write task
ray.get(
[
_write_single_block.remote(block, self.project_id, self.dataset)
for block in blocks
]
)
return "ok"

View File

@@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional
from google.api_core import client_info
from google.api_core import exceptions
from google.api_core.gapic_v1 import client_info as v1_client_info
from google.cloud import bigquery
from google.cloud import bigquery_storage
from google.cloud.aiplatform import initializer
from google.cloud.bigquery_storage import types
from ray.data.block import Block
from ray.data.block import BlockMetadata
from ray.data.datasource.datasource import Datasource
from ray.data.datasource.datasource import ReadTask
_BQ_GAPIC_VERSION = bigquery.__version__ + "+vertex_ray"
_BQS_GAPIC_VERSION = bigquery_storage.__version__ + "+vertex_ray"
bq_info = client_info.ClientInfo(
gapic_version=_BQ_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQ_GAPIC_VERSION}"
)
bqstorage_info = v1_client_info.ClientInfo(
gapic_version=_BQS_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQS_GAPIC_VERSION}"
)
DEFAULT_MAX_RETRY_CNT = 10
RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11
class _BigQueryDatasource(Datasource):
def __init__(
self,
project_id: Optional[str] = None,
dataset: Optional[str] = None,
query: Optional[str] = None,
):
self._project_id = project_id or initializer.global_config.project
self._dataset = dataset
self._query = query
if query is not None and dataset is not None:
raise ValueError(
"[Ray on Vertex AI]: Query and dataset kwargs cannot both be provided (must be mutually exclusive)."
)
def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
# Executed by a worker node
def _read_single_partition(stream) -> Block:
client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
reader = client.read_rows(stream.name)
return reader.to_arrow()
if self._query:
query_client = bigquery.Client(
project=self._project_id, client_info=bq_info
)
query_job = query_client.query(self._query)
query_job.result()
destination = str(query_job.destination)
dataset_id = destination.split(".")[-2]
table_id = destination.split(".")[-1]
else:
self._validate_dataset_table_exist(self._project_id, self._dataset)
dataset_id = self._dataset.split(".")[0]
table_id = self._dataset.split(".")[1]
bqs_client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
table = f"projects/{self._project_id}/datasets/{dataset_id}/tables/{table_id}"
if parallelism == -1:
parallelism = None
requested_session = types.ReadSession(
table=table,
data_format=types.DataFormat.ARROW,
)
read_session = bqs_client.create_read_session(
parent=f"projects/{self._project_id}",
read_session=requested_session,
max_stream_count=parallelism,
)
read_tasks = []
print("[Ray on Vertex AI]: Created streams:", len(read_session.streams))
if parallelism is not None and len(read_session.streams) < parallelism:
print(
"[Ray on Vertex AI]: The number of streams created by the "
+ "BigQuery Storage Read API is less than the requested "
+ "parallelism due to the size of the dataset."
)
for stream in read_session.streams:
# Create a metadata block object to store schema, etc.
metadata = BlockMetadata(
num_rows=None,
size_bytes=None,
schema=None,
input_files=None,
exec_stats=None,
)
# Create a no-arg wrapper read function which returns a block
read_single_partition = lambda stream=stream: [ # noqa: E731
_read_single_partition(stream)
]
# Create the read task and pass the wrapper and metadata in
read_task = ReadTask(read_single_partition, metadata)
read_tasks.append(read_task)
return read_tasks
def estimate_inmemory_data_size(self) -> Optional[int]:
# TODO(b/281891467): Implement this method
return None
def _validate_dataset_table_exist(self, project_id: str, dataset: str) -> None:
client = bigquery.Client(project=project_id, client_info=bq_info)
dataset_id = dataset.split(".")[0]
try:
client.get_dataset(dataset_id)
except exceptions.NotFound:
raise ValueError(
"[Ray on Vertex AI]: Dataset {} is not found. Please ensure that it exists.".format(
dataset_id
)
)
try:
client.get_table(dataset)
except exceptions.NotFound:
raise ValueError(
"[Ray on Vertex AI]: Table {} is not found. Please ensure that it exists.".format(
dataset
)
)
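# Illustrative read sketch (this datasource is normally consumed through
# vertex_ray.data.read_bigquery; the dataset name below is a placeholder):
#
#   import ray
#
#   ds = ray.data.read_datasource(
#       _BigQueryDatasource(project_id="my-project", dataset="my_dataset.my_table"),
#       parallelism=4,
#   )
#   print(ds.schema())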

View File

@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import grpc
import logging
import ray
from typing import Dict
from typing import Optional
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
from ray import client_builder
from .render import VertexRayTemplate
from .util import _validation_utils
from .util import _gapic_utils
VERTEX_SDK_VERSION = aiplatform.__version__
class _VertexRayClientContext(client_builder.ClientContext):
"""Custom ClientContext."""
def __init__(
self,
persistent_resource_id: str,
ray_head_uris: Dict[str, str],
ray_client_context: client_builder.ClientContext,
) -> None:
dashboard_uri = ray_head_uris.get("RAY_DASHBOARD_URI")
if dashboard_uri is None:
raise ValueError(
f"Ray Cluster {persistent_resource_id} failed to start the"
" head node properly."
)
if ray.__version__ in ("2.42.0", "2.33.0"):
super().__init__(
dashboard_url=dashboard_uri,
python_version=ray_client_context.python_version,
ray_version=ray_client_context.ray_version,
ray_commit=ray_client_context.ray_commit,
_num_clients=ray_client_context._num_clients,
_context_to_restore=ray_client_context._context_to_restore,
)
elif ray.__version__ == "2.9.3":
super().__init__(
dashboard_url=dashboard_uri,
python_version=ray_client_context.python_version,
ray_version=ray_client_context.ray_version,
ray_commit=ray_client_context.ray_commit,
protocol_version=ray_client_context.protocol_version,
_num_clients=ray_client_context._num_clients,
_context_to_restore=ray_client_context._context_to_restore,
)
else:
raise ImportError(
f"[Ray on Vertex AI]: Unsupported version {ray.__version__}."
+ "Only 2.42.0, 2.33.0, and 2.9.3 are supported."
)
self.persistent_resource_id = persistent_resource_id
self.vertex_sdk_version = str(VERTEX_SDK_VERSION)
self.shell_uri = ray_head_uris.get("RAY_HEAD_NODE_INTERACTIVE_SHELL_URI")
def _context_table_template(self):
shell_uri_row = None
if self.shell_uri is not None:
shell_uri_row = VertexRayTemplate("context_shellurirow.html.j2").render(
shell_uri=self.shell_uri
)
return VertexRayTemplate("context_table.html.j2").render(
python_version=self.python_version,
ray_version=self.ray_version,
vertex_sdk_version=self.vertex_sdk_version,
dashboard_url=self.dashboard_url,
persistent_resource_id=self.persistent_resource_id,
shell_uri_row=shell_uri_row,
)
class VertexRayClientBuilder(client_builder.ClientBuilder):
"""Class to initialize a Ray client with vertex on ray capabilities."""
def __init__(self, address: Optional[str]) -> None:
address = _validation_utils.maybe_reconstruct_resource_name(address)
_validation_utils.valid_resource_name(address)
self._credentials = None
self._metadata = None
self.vertex_address = address
logging.info(
"[Ray on Vertex AI]: Using cluster resource name to access head address with GAPIC API"
)
self.resource_name = address
self.response = _gapic_utils.get_persistent_resource(self.resource_name)
private_address = self.response.resource_runtime.access_uris.get(
"RAY_HEAD_NODE_INTERNAL_IP"
)
public_address = self.response.resource_runtime.access_uris.get(
"RAY_CLIENT_ENDPOINT"
)
service_account = (
self.response.resource_runtime_spec.service_account_spec.service_account
)
if public_address is None:
address = private_address
if service_account:
raise ValueError(
f"[Ray on Vertex AI]: Ray Cluster {address} failed to start the"
" head node properly because a custom service account isn't"
" supported in a peered VPC network. Use a public endpoint"
" instead (create a cluster without specifying a VPC network)."
)
else:
address = public_address
if address is None:
persistent_resource_id = self.resource_name.split("/")[5]
raise ValueError(
f"[Ray on Vertex AI]: Ray Cluster {persistent_resource_id} head"
" node is not reachable. Please ensure that a valid VPC network"
" has been specified."
)
logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)
cluster = _gapic_utils.persistent_resource_to_cluster(
persistent_resource=self.response
)
if cluster is None:
raise ValueError(
"[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)."
)
local_ray_version = _validation_utils.get_local_ray_version()
if cluster.ray_version != local_ray_version:
if cluster.head_node_type.custom_image is None:
install_ray_version = _validation_utils.SUPPORTED_RAY_VERSIONS.get(
cluster.ray_version
)
logging.info(
"[Ray on Vertex]: Local runtime has Ray version %s"
", but the requested cluster runtime has %s. Please "
"ensure that the Ray versions match for client connectivity. You may "
'"pip install --user --force-reinstall ray[default]==%s"'
" and restart runtime before cluster connection."
% (local_ray_version, cluster.ray_version, install_ray_version)
)
else:
logging.info(
"[Ray on Vertex]: Local runtime has Ray version %s."
"Please ensure that the Ray versions match for client connectivity."
% local_ray_verion
)
super().__init__(address)
def connect(self) -> _VertexRayClientContext:
# Can send any other params to ray cluster here
logging.info("[Ray on Vertex AI]: Connecting...")
public_address = self.response.resource_runtime.access_uris.get(
"RAY_CLIENT_ENDPOINT"
)
private_address = self.response.resource_runtime.access_uris.get(
"RAY_HEAD_NODE_INTERNAL_IP"
)
if public_address and not private_address:
self._credentials = grpc.ssl_channel_credentials()
bearer_token = _validation_utils.get_bearer_token()
self._metadata = [
("authorization", "Bearer {}".format(bearer_token)),
("x-goog-user-project", "{}".format(initializer.global_config.project)),
]
ray_client_context = super().connect()
ray_head_uris = self.response.resource_runtime.access_uris
# Valid resource name (reference public doc for public release):
# "projects/<project_num>/locations/<region>/persistentResources/<pr_id>"
persistent_resource_id = self.resource_name.split("/")[5]
return _VertexRayClientContext(
persistent_resource_id=persistent_resource_id,
ray_head_uris=ray_head_uris,
ray_client_context=ray_client_context,
)
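# Connection sketch: passing a "vertex_ray://" address to ray.init() routes
# through this builder via Ray's module-prefix mechanism (the resource name
# below is a placeholder):
#
#   import ray
#
#   ctx = ray.init(
#       "vertex_ray://projects/123/locations/us-central1/persistentResources/my-cluster"
#   )
#   print(ctx.dashboard_url)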

View File

@@ -0,0 +1,575 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import logging
import time
from typing import Dict, List, Optional
import warnings
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import resource_manager_utils
from google.cloud.aiplatform_v1beta1.types import persistent_resource_service
from google.cloud.aiplatform_v1beta1.types.machine_resources import NfsMount
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
PersistentResource,
RayLogsSpec,
RaySpec,
RayMetricSpec,
ResourcePool,
ResourceRuntimeSpec,
ServiceAccountSpec,
)
from google.cloud.aiplatform_v1beta1.types.service_networking import (
PscInterfaceConfig,
)
from google.cloud.aiplatform.vertex_ray.util import (
_gapic_utils,
_validation_utils,
resources,
)
from google.protobuf import field_mask_pb2 # type: ignore
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
_V2_9_WARNING_MESSAGE,
)
def create_ray_cluster(
head_node_type: Optional[resources.Resources] = resources.Resources(),
python_version: Optional[str] = "3.10",
ray_version: Optional[str] = "2.42",
network: Optional[str] = None,
service_account: Optional[str] = None,
cluster_name: Optional[str] = None,
worker_node_types: Optional[List[resources.Resources]] = [resources.Resources()],
custom_images: Optional[resources.NodeImages] = None,
enable_metrics_collection: Optional[bool] = True,
enable_logging: Optional[bool] = True,
psc_interface_config: Optional[resources.PscIConfig] = None,
reserved_ip_ranges: Optional[List[str]] = None,
nfs_mounts: Optional[List[resources.NfsMount]] = None,
labels: Optional[Dict[str, str]] = None,
) -> str:
"""Create a ray cluster on the Vertex AI.
Sample usage:
from vertex_ray import Resources
head_node_type = Resources(
machine_type="n1-standard-8",
node_count=1,
accelerator_type="NVIDIA_TESLA_T4",
accelerator_count=1,
custom_image="us-docker.pkg.dev/my-project/ray-cpu-image.2.33:latest", # Optional
)
worker_node_types = [Resources(
machine_type="n1-standard-8",
node_count=2,
accelerator_type="NVIDIA_TESLA_T4",
accelerator_count=1,
custom_image="us-docker.pkg.dev/my-project/ray-gpu-image.2.33:latest", # Optional
)]
cluster_resource_name = vertex_ray.create_ray_cluster(
head_node_type=head_node_type,
network="projects/my-project-number/global/networks/my-vpc-name", # Optional
service_account="my-service-account@my-project-number.iam.gserviceaccount.com", # Optional
cluster_name="my-cluster-name", # Optional
worker_node_types=worker_node_types,
ray_version="2.33",
)
After a Ray cluster is set up, you can call
`ray.init(f"vertex_ray://{cluster_resource_name}", runtime_env=...)` to
connect to the cluster without specifying the head node address. To shut
down the cluster, you can call `vertex_ray.delete_ray_cluster()`.
Note: If the active ray cluster has not finished shutting down, you cannot
create a new ray cluster with the same cluster_name.
Args:
head_node_type: The head node resource. Resources.node_count must be 1.
If not set, default value of Resources() class will be used.
python_version: Python version for the ray cluster.
ray_version: Ray version for the ray cluster. Defaults to 2.42.
network: Virtual private cloud (VPC) network. For Ray Client, VPC
peering is required to connect to the Ray Cluster managed in the
Vertex API service. For Ray Job API, VPC network is not required
because Ray Cluster connection can be accessed through dashboard
address.
service_account: Service account to be used for running Ray programs on
the cluster.
cluster_name: This value may be up to 63 characters, and valid
characters are `[a-z0-9_-]`. The first character cannot be a number
or hyphen.
worker_node_types: The list of Resources of the worker nodes. The same
Resources object should not appear multiple times in the list.
custom_images: The NodeImages which specifies head node and worker nodes
images. All the workers will share the same image. If each Resource
has a specific custom image, use `Resources.custom_image` for
head/worker_node_type(s). Note that configuring `Resources.custom_image`
will override `custom_images` here. Allowlist only.
enable_metrics_collection: Enable Ray metrics collection for visualization.
enable_logging: Enable exporting Ray logs to Cloud Logging.
psc_interface_config: PSC-I config.
reserved_ip_ranges: A list of names for the reserved IP ranges under
the VPC network that can be used for this cluster. If set, we will
deploy the cluster within the provided IP ranges. Otherwise, the
cluster is deployed to any IP ranges under the provided VPC network.
Example: ["vertex-ai-ip-range"].
labels:
The labels with user-defined metadata to organize Ray cluster.
Label keys and values can be no longer than 64 characters (Unicode
codepoints), can only contain lowercase letters, numeric characters,
underscores and dashes. International characters are allowed.
See https://goo.gl/xmQnxf for more information and examples of labels.
Returns:
The cluster_resource_name of the initiated Ray cluster on Vertex.
Raises:
ValueError: If the cluster is not created successfully.
RuntimeError: If the ray_version is 2.4.
"""
if network is None:
logging.info(
"[Ray on Vertex]: No VPC network configured. It is required for client connection."
)
if ray_version == "2.4":
raise RuntimeError(_V2_4_WARNING_MESSAGE)
if ray_version == "2.9.3":
warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)
local_ray_version = _validation_utils.get_local_ray_version()
if ray_version != local_ray_version:
if custom_images is None and head_node_type.custom_image is None:
install_ray_version = "2.42.0"
logging.info(
"[Ray on Vertex]: Local runtime has Ray version %s"
", but the requested cluster runtime has %s. Please "
"ensure that the Ray versions match for client connectivity. You may "
'"pip install --user --force-reinstall ray[default]==%s"'
" and restart runtime before cluster connection."
% (local_ray_version, ray_version, install_ray_version)
)
else:
logging.info(
"[Ray on Vertex]: Local runtime has Ray version %s."
"Please ensure that the Ray versions match for client connectivity."
% local_ray_verion
)
if cluster_name is None:
cluster_name = "ray-cluster-" + utils.timestamped_unique_name()
if head_node_type:
if head_node_type.node_count != 1:
raise ValueError(
"[Ray on Vertex AI]: For head_node_type, "
+ "Resources.node_count must be 1."
)
if head_node_type.autoscaling_spec is not None:
raise ValueError(
"[Ray on Vertex AI]: For head_node_type, "
+ "Resources.autoscaling_spec must be None."
)
if (
head_node_type.accelerator_type is None
and head_node_type.accelerator_count > 0
):
raise ValueError(
"[Ray on Vertex]: accelerator_type must be specified when"
+ " accelerator_count is set to a value other than 0."
)
resource_pool_images = {}
# head node
resource_pool_0 = ResourcePool()
resource_pool_0.id = "head-node"
resource_pool_0.replica_count = head_node_type.node_count
resource_pool_0.machine_spec.machine_type = head_node_type.machine_type
resource_pool_0.machine_spec.accelerator_count = head_node_type.accelerator_count
resource_pool_0.machine_spec.accelerator_type = head_node_type.accelerator_type
resource_pool_0.disk_spec.boot_disk_type = head_node_type.boot_disk_type
resource_pool_0.disk_spec.boot_disk_size_gb = head_node_type.boot_disk_size_gb
enable_cuda = True if head_node_type.accelerator_count > 0 else False
if head_node_type.custom_image is not None:
image_uri = head_node_type.custom_image
elif custom_images is None:
image_uri = _validation_utils.get_image_uri(
ray_version, python_version, enable_cuda
)
elif custom_images.head is not None and custom_images.worker is not None:
image_uri = custom_images.head
else:
raise ValueError(
"[Ray on Vertex AI]: custom_images.head and custom_images.worker must be specified when custom_images is set."
)
resource_pool_images[resource_pool_0.id] = image_uri
worker_pools = []
i = 0
if worker_node_types:
for worker_node_type in worker_node_types:
if (
worker_node_type.accelerator_type is None
and worker_node_type.accelerator_count > 0
):
raise ValueError(
"[Ray on Vertex]: accelerator_type must be specified when"
+ " accelerator_count is set to a value other than 0."
)
additional_replica_count = resources._check_machine_spec_identical(
head_node_type, worker_node_type
)
if worker_node_type.autoscaling_spec is None:
# Worker and head share the same MachineSpec, merge them into the
# same ResourcePool
resource_pool_0.replica_count = (
resource_pool_0.replica_count + additional_replica_count
)
else:
if additional_replica_count > 0:
# Autoscaling for single ResourcePool (homogeneous cluster).
resource_pool_0.replica_count = None
resource_pool_0.autoscaling_spec.min_replica_count = (
worker_node_type.autoscaling_spec.min_replica_count
)
resource_pool_0.autoscaling_spec.max_replica_count = (
worker_node_type.autoscaling_spec.max_replica_count
)
if additional_replica_count == 0:
resource_pool = ResourcePool()
resource_pool.id = f"worker-pool{i+1}"
if worker_node_type.autoscaling_spec is None:
resource_pool.replica_count = worker_node_type.node_count
else:
# Autoscaling for worker ResourcePool.
resource_pool.autoscaling_spec.min_replica_count = (
worker_node_type.autoscaling_spec.min_replica_count
)
resource_pool.autoscaling_spec.max_replica_count = (
worker_node_type.autoscaling_spec.max_replica_count
)
resource_pool.machine_spec.machine_type = worker_node_type.machine_type
resource_pool.machine_spec.accelerator_count = (
worker_node_type.accelerator_count
)
resource_pool.machine_spec.accelerator_type = (
worker_node_type.accelerator_type
)
resource_pool.disk_spec.boot_disk_type = worker_node_type.boot_disk_type
resource_pool.disk_spec.boot_disk_size_gb = (
worker_node_type.boot_disk_size_gb
)
worker_pools.append(resource_pool)
enable_cuda = True if worker_node_type.accelerator_count > 0 else False
if worker_node_type.custom_image is not None:
image_uri = worker_node_type.custom_image
elif custom_images is None:
image_uri = _validation_utils.get_image_uri(
ray_version, python_version, enable_cuda
)
else:
image_uri = custom_images.worker
resource_pool_images[resource_pool.id] = image_uri
i += 1
resource_pools = [resource_pool_0] + worker_pools
metrics_collection_disabled = not enable_metrics_collection
ray_metric_spec = RayMetricSpec(disabled=metrics_collection_disabled)
logging_disabled = not enable_logging
ray_logs_spec = RayLogsSpec(disabled=logging_disabled)
ray_spec = RaySpec(
resource_pool_images=resource_pool_images,
ray_metric_spec=ray_metric_spec,
ray_logs_spec=ray_logs_spec,
)
if nfs_mounts:
gapic_nfs_mounts = []
for nfs_mount in nfs_mounts:
gapic_nfs_mounts.append(
NfsMount(
server=nfs_mount.server,
path=nfs_mount.path,
mount_point=nfs_mount.mount_point,
)
)
ray_spec.nfs_mounts = gapic_nfs_mounts
if service_account:
service_account_spec = ServiceAccountSpec(
enable_custom_service_account=True,
service_account=service_account,
)
resource_runtime_spec = ResourceRuntimeSpec(
ray_spec=ray_spec,
service_account_spec=service_account_spec,
)
else:
resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec)
if psc_interface_config:
gapic_psc_interface_config = PscInterfaceConfig(
network_attachment=psc_interface_config.network_attachment,
)
else:
gapic_psc_interface_config = None
persistent_resource = PersistentResource(
resource_pools=resource_pools,
network=network,
labels=labels,
resource_runtime_spec=resource_runtime_spec,
psc_interface_config=gapic_psc_interface_config,
reserved_ip_ranges=reserved_ip_ranges,
)
location = initializer.global_config.location
project_id = initializer.global_config.project
project_number = resource_manager_utils.get_project_number(project_id)
parent = f"projects/{project_number}/locations/{location}"
request = persistent_resource_service.CreatePersistentResourceRequest(
parent=parent,
persistent_resource=persistent_resource,
persistent_resource_id=cluster_name,
)
client = _gapic_utils.create_persistent_resource_client()
try:
_ = client.create_persistent_resource(request)
except Exception as e:
raise ValueError("Failed in cluster creation due to: ", e) from e
# Get the persistent resource
cluster_resource_name = f"{parent}/persistentResources/{cluster_name}"
response = _gapic_utils.get_persistent_resource(
persistent_resource_name=cluster_resource_name,
tolerance=1, # allow 1 retry to avoid get request before creation
)
return response.name
def delete_ray_cluster(cluster_resource_name: str) -> None:
"""Delete Ray Cluster.
Args:
cluster_resource_name: Cluster resource name.
Raises:
FailedPrecondition: If the cluster is deleted already.
"""
client = _gapic_utils.create_persistent_resource_client()
request = persistent_resource_service.DeletePersistentResourceRequest(
name=cluster_resource_name
)
try:
client.delete_persistent_resource(request)
print("[Ray on Vertex AI]: Successfully deleted the cluster.")
except Exception as e:
raise ValueError(
f"[Ray on Vertex AI]: Failed in cluster deletion due to: {e}"
) from e
def get_ray_cluster(cluster_resource_name: str) -> resources.Cluster:
"""Get Ray Cluster.
Args:
cluster_resource_name: Cluster resource name.
Returns:
A Cluster object.
"""
client = _gapic_utils.create_persistent_resource_client()
request = persistent_resource_service.GetPersistentResourceRequest(
name=cluster_resource_name
)
try:
response = client.get_persistent_resource(request)
except Exception as e:
raise ValueError(
f"[Ray on Vertex AI]: Failed in getting the cluster due to: {e}"
) from e
cluster = _gapic_utils.persistent_resource_to_cluster(persistent_resource=response)
if cluster:
return cluster
raise ValueError(
"[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)."
)
def list_ray_clusters() -> List[resources.Cluster]:
"""List Ray Clusters under the currently authenticated project.
Returns:
List of Cluster objects that exist in the currently authorized project.
"""
location = initializer.global_config.location
project_id = initializer.global_config.project
project_number = resource_manager_utils.get_project_number(project_id)
parent = f"projects/{project_number}/locations/{location}"
request = persistent_resource_service.ListPersistentResourcesRequest(
parent=parent,
)
client = _gapic_utils.create_persistent_resource_client()
try:
response = client.list_persistent_resources(request)
except Exception as e:
raise ValueError(
f"[Ray on Vertex AI]: Failed in listing the clusters due to: {e}"
) from e
ray_clusters = []
for persistent_resource in response:
ray_cluster = _gapic_utils.persistent_resource_to_cluster(
persistent_resource=persistent_resource
)
if ray_cluster:
ray_clusters.append(ray_cluster)
return ray_clusters
def update_ray_cluster(
cluster_resource_name: str, worker_node_types: List[resources.Resources]
) -> str:
"""Update Ray Cluster (currently support resizing node counts for worker nodes).
Sample usage:
my_cluster = vertex_ray.get_ray_cluster(
cluster_resource_name=my_existing_cluster_resource_name,
)
# Declaration to resize all the worker_node_type to node_count=1
new_worker_node_types = []
for worker_node_type in my_cluster.worker_node_types:
worker_node_type.node_count = 1
new_worker_node_types.append(worker_node_type)
# Execution to update new node_count (block until complete)
vertex_ray.update_ray_cluster(
cluster_resource_name=my_cluster.cluster_resource_name,
worker_node_types=new_worker_node_types,
)
Args:
cluster_resource_name: The resource name of the existing Ray cluster to update.
worker_node_types: The list of Resources of the resized worker nodes.
The same Resources object should not appear multiple times in the list.
Returns:
The cluster_resource_name of the Ray cluster on Vertex.
"""
# worker_node_types should not be duplicated.
for i in range(len(worker_node_types)):
for j in range(len(worker_node_types)):
additional_replica_count = resources._check_machine_spec_identical(
worker_node_types[i], worker_node_types[j]
)
if additional_replica_count > 0 and i != j:
raise ValueError(
"[Ray on Vertex AI]: Worker_node_types have duplicate "
+ f"machine specs: {worker_node_types[i]} "
+ f"and {worker_node_types[j]}"
)
persistent_resource = _gapic_utils.get_persistent_resource(
persistent_resource_name=cluster_resource_name
)
current_persistent_resource = copy.deepcopy(persistent_resource)
current_persistent_resource.resource_pools[0].replica_count = 1
previous_ray_cluster = get_ray_cluster(cluster_resource_name)
head_node_type = previous_ray_cluster.head_node_type
previous_worker_node_types = previous_ray_cluster.worker_node_types
# new worker_node_types and previous_worker_node_types should be the same length.
if len(worker_node_types) != len(previous_worker_node_types):
raise ValueError(
"[Ray on Vertex AI]: Desired number of worker_node_types "
+ f"({len(worker_node_types)}) does not match the number of the "
+ f"existing worker_node_types ({len(previous_worker_node_types)})."
)
# Merge worker_node_type and head_node_type if they share
# the same machine spec.
not_merged = 1
for i in range(len(worker_node_types)):
additional_replica_count = resources._check_machine_spec_identical(
head_node_type, worker_node_types[i]
)
if additional_replica_count != 0 or (
additional_replica_count == 0 and worker_node_types[i].node_count == 0
):
# Merge the first duplicated worker pool with the head pool; this allows scaling down to 0 workers.
current_persistent_resource.resource_pools[0].replica_count = (
1 + additional_replica_count
)
# Reset not_merged
not_merged = 0
else:
# No duplication with the head node; write this worker pool to its own resource pool.
current_persistent_resource.resource_pools[
i + not_merged
].replica_count = worker_node_types[i].node_count
# New worker_node_type.node_count should be >=1 unless the worker_node_type
# and head_node_type are merged due to the same machine specs.
if worker_node_types[i].node_count == 0:
raise ValueError(
"[Ray on Vertex AI]: Worker_node_type "
+ f"({worker_node_types[i]}) must update to >= 1 nodes",
)
request = persistent_resource_service.UpdatePersistentResourceRequest(
persistent_resource=current_persistent_resource,
update_mask=field_mask_pb2.FieldMask(paths=["resource_pools.replica_count"]),
)
client = _gapic_utils.create_persistent_resource_client()
try:
operation_future = client.update_persistent_resource(request)
except Exception as e:
raise ValueError(
f"[Ray on Vertex AI]: Failed in updating the cluster due to: {e}"
) from e
# block before returning
start_time = time.time()
response = operation_future.result()
duration = (time.time() - start_time) // 60
print(
"[Ray on Vertex AI]: Successfully updated the cluster ({} mininutes elapsed).".format(
duration
)
)
return response.name
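# Resize sketch using the helpers above (the resource name is a placeholder;
# this assumes the cluster has a single worker pool):
#
#   cluster = get_ray_cluster(
#       "projects/123/locations/us-central1/persistentResources/my-cluster"
#   )
#   for worker_node_type in cluster.worker_node_types:
#       worker_node_type.node_count = 3
#   update_ray_cluster(
#       cluster_resource_name=cluster.cluster_resource_name,
#       worker_node_types=cluster.worker_node_types,
#   )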

View File

@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utility to interact with Ray-on-Vertex dashboard."""
from ray.dashboard.modules import dashboard_sdk as oss_dashboard_sdk
from .util import _gapic_utils
from .util import _validation_utils
def get_job_submission_client_cluster_info(
address: str, *args, **kwargs
) -> oss_dashboard_sdk.ClusterInfo:
"""A vertex_ray implementation of get_job_submission_client_cluster_info().
Implements
https://github.com/ray-project/ray/blob/ray-2.42.0/python/ray/dashboard/modules/dashboard_sdk.py#L84
This will be called from the Ray Job API Python client.
Args:
address: Address without the module prefix `vertex_ray` but otherwise
the same format as passed to ray.init(address="vertex_ray://...").
*args: Remainder of positional args that might be passed down from
the framework.
**kwargs: Remainder of keyword args that might be passed down from
the framework.
Returns:
An instance of ClusterInfo that contains address, cookies and
metadata for SubmissionClient to use.
Raises:
RuntimeError: If the dashboard address cannot be obtained from the backend.
"""
if _validation_utils.valid_dashboard_address(address):
dashboard_address = address
else:
address = _validation_utils.maybe_reconstruct_resource_name(address)
_validation_utils.valid_resource_name(address)
resource_name = address
response = _gapic_utils.get_persistent_resource(resource_name)
dashboard_address = response.resource_runtime.access_uris.get(
"RAY_DASHBOARD_URI", None
)
if dashboard_address is None:
raise RuntimeError(
"[Ray on Vertex AI]: Unable to obtain a response from the backend."
)
# If passing the dashboard uri, programmatically get headers
bearer_token = _validation_utils.get_bearer_token()
if kwargs.get("headers", None) is None:
kwargs["headers"] = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(bearer_token),
}
return oss_dashboard_sdk.get_job_submission_client_cluster_info(
address=dashboard_address,
_use_tls=True,
*args,
**kwargs,
)
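# Job-submission sketch: Ray's JobSubmissionClient resolves "vertex_ray://"
# addresses through the hook above (the resource name and entrypoint are
# placeholders):
#
#   from ray.job_submission import JobSubmissionClient
#
#   client = JobSubmissionClient(
#       "vertex_ray://projects/123/locations/us-central1/persistentResources/my-cluster"
#   )
#   job_id = client.submit_job(entrypoint="python my_script.py")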

View File

@@ -0,0 +1,192 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import warnings
import ray.data
from ray.data.dataset import Dataset
from typing import Any, Dict, Optional
from google.cloud.aiplatform.vertex_ray.bigquery_datasource import (
_BigQueryDatasource,
)
try:
from google.cloud.aiplatform.vertex_ray.bigquery_datasink import (
_BigQueryDatasink,
)
except ImportError:
_BigQueryDatasink = None
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
_V2_9_WARNING_MESSAGE,
)
def read_bigquery(
project_id: Optional[str] = None,
dataset: Optional[str] = None,
query: Optional[str] = None,
*,
parallelism: int = -1,
ray_remote_args: Optional[Dict[str, Any]] = None,
concurrency: Optional[int] = None,
override_num_blocks: Optional[int] = None,
) -> Dataset:
"""Create a dataset from BigQuery.
The data to read from is specified via the ``project_id``, ``dataset``
and/or ``query`` parameters.
Args:
project_id: The name of the associated Google Cloud Project that hosts
the dataset to read.
dataset: The name of the dataset hosted in BigQuery in the format of
``dataset_id.table_id``. Both the dataset_id and table_id must exist;
otherwise, an exception will be raised.
query: The query to execute. The dataset is created from the results of
executing the query if provided. Otherwise, the entire dataset is read.
For query syntax guidelines, see
https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
parallelism: 2.33.0, 2.42.0: This argument is deprecated. Use
``override_num_blocks`` argument. 2.9.3: The requested parallelism of
the read. If -1, it will be automatically chosen based on the available
cluster resources and estimated in-memory data size.
ray_remote_args: kwargs passed to ray.remote in the read tasks.
concurrency: Supported for 2.33.0 and 2.42.0 only: The maximum number of
Ray tasks to run concurrently. Set this to control number of tasks to
run concurrently. This doesn't change the total number of tasks run or
the total number of output blocks. By default, concurrency is
dynamically decided based on the available resources.
override_num_blocks: Supported for 2.33.0 and 2.42.0 only: Override the
number of output blocks from all read tasks. By default, the number of
output blocks is dynamically decided based on input data size and
available resources. You shouldn't manually set this value in most
cases.
Returns:
Dataset producing rows from the results of executing the query
or reading the entire dataset on the specified BigQuery dataset.
"""
datasource = _BigQueryDatasource(
project_id=project_id,
dataset=dataset,
query=query,
)
if ray.__version__ == "2.9.3":
warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)
# Concurrency and override_num_blocks are not supported in 2.9.3
return ray.data.read_datasource(
datasource=datasource,
parallelism=parallelism,
ray_remote_args=ray_remote_args,
)
elif ray.__version__ in ("2.33.0", "2.42.0"):
return ray.data.read_datasource(
datasource=datasource,
parallelism=parallelism,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
override_num_blocks=override_num_blocks,
)
else:
raise ImportError(
f"[Ray on Vertex AI]: Unsupported version {ray.__version__}."
+ "Only 2.42.0, 2.33.0, and 2.9.3 are supported."
)
def write_bigquery(
ds: Dataset,
project_id: Optional[str] = None,
dataset: Optional[str] = None,
max_retry_cnt: int = 10,
ray_remote_args: Optional[Dict[str, Any]] = None,
overwrite_table: Optional[bool] = True,
concurrency: Optional[int] = None,
) -> Any:
"""Write the dataset to a BigQuery dataset table.
Args:
ds: The dataset to write.
project_id: The name of the associated Google Cloud Project that hosts
the dataset table to write to.
dataset: The name of the dataset table hosted in BigQuery in the format of
``dataset_id.table_id``.
The dataset table is created if it doesn't already exist.
In 2.9.3, the table_id is overwritten if it exists.
max_retry_cnt: The maximum number of retries that an individual block write
is retried due to BigQuery rate limiting errors.
The default number of retries is 10.
ray_remote_args: kwargs passed to ray.remote in the write tasks.
overwrite_table: Not supported in 2.9.3.
2.33.0, 2.42.0: Whether the write will overwrite the table if it already
exists. The default behavior is to overwrite the table.
If false, will append to the table if it exists.
concurrency: Not supported in 2.9.3.
2.33.0, 2.42.0: The maximum number of Ray tasks to run concurrently. Set this
to control number of tasks to run concurrently. This doesn't change the
total number of tasks run or the total number of output blocks. By default,
concurrency is dynamically decided based on the available resources.
"""
if ray.__version__ == "2.4.0":
raise RuntimeError(_V2_4_WARNING_MESSAGE)
elif ray.__version__ in ("2.9.3", "2.33.0", "2.42.0"):
if ray.__version__ == "2.9.3":
warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)
if ray_remote_args is None:
ray_remote_args = {}
# Each write task will launch individual remote tasks to write each block
# To avoid duplicate block writes, the write task should not be retried
if ray_remote_args.get("max_retries", 0) != 0:
print(
"[Ray on Vertex AI]: The max_retries of a BigQuery Write "
"Task should be set to 0 to avoid duplicate writes."
)
else:
ray_remote_args["max_retries"] = 0
if ray.__version__ == "2.9.3":
# Concurrency and overwrite_table are not supported in 2.9.3
datasink = _BigQueryDatasink(
project_id=project_id,
dataset=dataset,
max_retry_cnt=max_retry_cnt,
)
return ds.write_datasink(
datasink=datasink,
ray_remote_args=ray_remote_args,
)
elif ray.__version__ in ("2.33.0", "2.42.0"):
datasink = _BigQueryDatasink(
project_id=project_id,
dataset=dataset,
max_retry_cnt=max_retry_cnt,
overwrite_table=overwrite_table,
)
return ds.write_datasink(
datasink=datasink,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
)
else:
raise ImportError(
f"[Ray on Vertex AI]: Unsupported version {ray.__version__}."
+ "Only 2.42.0, 2.33.0 and 2.9.3 are supported."
)
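# Read/write sketch for the functions above (project and dataset names are
# placeholders):
#
#   ds = read_bigquery(
#       project_id="my-project",
#       query="SELECT * FROM `my-project.my_dataset.my_table` LIMIT 1000",
#   )
#   write_bigquery(
#       ds,
#       project_id="my-project",
#       dataset="my_dataset.my_output_table",
#       overwrite_table=True,
#   )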

View File

@@ -0,0 +1,18 @@
"""Ray on Vertex AI Prediction."""
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View File

@@ -0,0 +1,22 @@
"""Ray on Vertex AI Prediction Tensorflow."""
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .register import register_sklearn
__all__ = ("register_sklearn",)

View File

@@ -0,0 +1,173 @@
"""Regsiter Scikit Learn for Ray on Vertex AI."""
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pickle
import warnings
import ray
import ray.cloudpickle as cpickle
import tempfile
from typing import Optional, TYPE_CHECKING
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import gcs_utils
from google.cloud.aiplatform.vertex_ray.predict.util import constants
from google.cloud.aiplatform.vertex_ray.predict.util import (
predict_utils,
)
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
_V2_9_WARNING_MESSAGE,
)
try:
from ray.train import sklearn as ray_sklearn
if TYPE_CHECKING:
import sklearn
except ImportError as ie:
if ray.__version__ < "2.42.0":
raise ModuleNotFoundError("Sklearn isn't installed.") from ie
else:
sklearn = None
def register_sklearn(
checkpoint: "ray_sklearn.SklearnCheckpoint",
artifact_uri: Optional[str] = None,
display_name: Optional[str] = None,
**kwargs,
) -> aiplatform.Model:
"""Uploads a Ray Sklearn Checkpoint as Sklearn Model to Model Registry.
Example usage:
from vertex_ray.predict import sklearn
from ray.train.sklearn import SklearnCheckpoint
trainer = SklearnTrainer(estimator=RandomForestClassifier, ...)
result = trainer.fit()
sklearn_checkpoint = SklearnCheckpoint.from_checkpoint(result.checkpoint)
my_model = sklearn.register_sklearn(
checkpoint=sklearn_checkpoint,
artifact_uri="gs://{gcs-bucket-name}/path/to/store"
)
Args:
checkpoint: SklearnCheckpoint instance.
artifact_uri (str):
Optional. The path to the directory where Model Artifacts will be saved. If
not set, will use staging bucket set in aiplatform.init().
display_name (str):
Optional. The display name of the Model. The name can be up to 128
characters long and can consist of any UTF-8 characters.
**kwargs:
Any kwargs will be passed to aiplatform.Model registration.
Returns:
model (aiplatform.Model):
Instantiated representation of the uploaded model resource.
Raises:
ValueError: Invalid Argument.
RuntimeError: Only Ray version 2.9.3 is supported.
"""
ray_version = ray.__version__
if ray_version != "2.9.3":
raise RuntimeError(
f"Ray version {ray_version} is not supported to upload Sklearn"
" model to Vertex Model Registry yet. Please use Ray 2.9.3."
)
if ray_version == "2.9.3":
warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)
artifact_uri = artifact_uri or initializer.global_config.staging_bucket
predict_utils.validate_artifact_uri(artifact_uri)
display_model_name = (
(f"ray-on-vertex-registered-sklearn-model-{utils.timestamped_unique_name()}")
if display_name is None
else display_name
)
estimator = _get_estimator_from(checkpoint)
model_dir = os.path.join(artifact_uri, display_model_name)
file_path = os.path.join(model_dir, constants._PICKLE_FILE_NAME)
with tempfile.NamedTemporaryFile(suffix=constants._PICKLE_EXTENTION) as temp_file:
pickle.dump(estimator, temp_file)
temp_file.flush()  # Ensure the pickled bytes hit disk before the upload reads the file.
gcs_utils.upload_to_gcs(temp_file.name, file_path)
return aiplatform.Model.upload_scikit_learn_model_file(
model_file_path=temp_file.name, display_name=display_model_name, **kwargs
)
def _get_estimator_from(
checkpoint: "ray_sklearn.SklearnCheckpoint",
) -> "sklearn.base.BaseEstimator":
"""Converts a SklearnCheckpoint to sklearn estimator.
Args:
checkpoint: SklearnCheckpoint instance.
Returns:
A Sklearn BaseEstimator
Raises:
ValueError: Invalid Argument.
RuntimeError: Model not found.
RuntimeError: Ray version 2.4 is not supported.
RuntimeError: Only Ray version 2.9.3 is supported.
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
raise RuntimeError(_V2_4_WARNING_MESSAGE)
if ray_version != "2.9.3":
raise RuntimeError(
f"Ray version {ray_version} is not supported to convert a Sklearn"
" checkpoint to sklearn estimator on Vertex yet. Please use Ray 2.9.3."
)
try:
return checkpoint.get_model()
except AttributeError:
model_file_name = ray.train.sklearn.SklearnCheckpoint.MODEL_FILENAME
model_path = os.path.join(checkpoint.path, model_file_name)
if os.path.exists(model_path):
with open(model_path, mode="rb") as f:
obj = pickle.load(f)
else:
try:
# Download from GCS to temp and then load_model
with tempfile.TemporaryDirectory() as temp_dir:
gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
with open(f"{temp_dir}/{model_file_name}", mode="rb") as f:
obj = cpickle.load(f)
except Exception as e:
raise RuntimeError(
f"{model_file_name} not found in this checkpoint due to: {e}."
)
return obj
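# Registration sketch, assuming Ray 2.9.3 and a SklearnCheckpoint from a
# fitted trainer (project and bucket names are placeholders):
#
#   from google.cloud import aiplatform
#
#   aiplatform.init(project="my-project", staging_bucket="gs://my-bucket")
#   model = register_sklearn(
#       checkpoint=sklearn_checkpoint,
#       display_name="my-sklearn-model",
#   )
#   endpoint = model.deploy(machine_type="n1-standard-4")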

View File

@@ -0,0 +1,22 @@
"""Ray on Vertex AI Prediction Tensorflow."""
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .register import register_tensorflow
__all__ = ("register_tensorflow",)

View File

@@ -0,0 +1,161 @@
"""Regsiter Tensorflow for Ray on Vertex AI."""
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import logging
import ray
from typing import Callable, Optional, Union, TYPE_CHECKING
import warnings
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.vertex_ray.predict.util import constants
from google.cloud.aiplatform.vertex_ray.predict.util import (
predict_utils,
)
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
_V2_9_WARNING_MESSAGE,
)
try:
from ray.train import tensorflow as ray_tensorflow
if TYPE_CHECKING:
import tensorflow as tf
except ModuleNotFoundError as mnfe:
raise ModuleNotFoundError("Tensorflow isn't installed.") from mnfe
def register_tensorflow(
checkpoint: ray_tensorflow.TensorflowCheckpoint,
artifact_uri: Optional[str] = None,
_model: Optional[Union["tf.keras.Model", Callable[[], "tf.keras.Model"]]] = None,
display_name: Optional[str] = None,
tensorflow_version: Optional[str] = None,
**kwargs,
) -> aiplatform.Model:
"""Uploads a Ray Tensorflow Checkpoint as Tensorflow Model to Model Registry.
Example usage:
from vertex_ray.predict import tensorflow
def create_model():
model = tf.keras.Sequential(...)
...
return model
result = trainer.fit()
my_model = tensorflow.register_tensorflow(
checkpoint=result.checkpoint,
_model=create_model,
artifact_uri="gs://{gcs-bucket-name}/path/to/store",
use_gpu=True
)
1. `use_gpu` will be passed to aiplatform.Model.upload_tensorflow_saved_model()
2. The `create_model` function provides the model definition, which is
required if you created the TensorflowCheckpoint using the `from_model`
method. See
https://docs.ray.io/en/latest/train/api/doc/ray.train.tensorflow.TensorflowCheckpoint.get_model.html#ray.train.tensorflow.TensorflowCheckpoint.get_model
Args:
checkpoint: TensorflowCheckpoint instance.
artifact_uri (str):
Optional. The path to the directory where Model Artifacts will be saved. If
not set, will use staging bucket set in aiplatform.init().
_model: Tensorflow Model Definition. Refer
https://docs.ray.io/en/latest/train/api/doc/ray.train.tensorflow.TensorflowCheckpoint.get_model.html#ray.train.tensorflow.TensorflowCheckpoint.get_model
display_name (str):
Optional. The display name of the Model. The name can be up to 128
characters long and can consist of any UTF-8 characters.
tensorflow_version (str):
Optional. The version of the Tensorflow serving container.
Supported versions:
https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
If the version is not specified, the latest version is used.
**kwargs:
Any kwargs will be passed to aiplatform.Model registration.
Returns:
model (aiplatform.Model):
Instantiated representation of the uploaded model resource.
Raises:
ValueError: Invalid Argument.
"""
if ray.__version__ == "2.9.3":
warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)
if tensorflow_version is None:
tensorflow_version = constants._TENSORFLOW_VERSION
artifact_uri = artifact_uri or initializer.global_config.staging_bucket
predict_utils.validate_artifact_uri(artifact_uri)
prefix = "ray-on-vertex-registered-tensorflow-model"
display_model_name = (
(f"{prefix}-{utils.timestamped_unique_name()}")
if display_name is None
else display_name
)
tf_model = _get_tensorflow_model_from(checkpoint, model=_model)
model_dir = os.path.join(artifact_uri, prefix)
try:
import tensorflow as tf
tf.saved_model.save(tf_model, model_dir)
except ImportError:
logging.warning("TensorFlow must be installed to save the trained model.")
return aiplatform.Model.upload_tensorflow_saved_model(
saved_model_dir=model_dir,
display_name=display_model_name,
tensorflow_version=tensorflow_version,
**kwargs,
)
def _get_tensorflow_model_from(
checkpoint: ray_tensorflow.TensorflowCheckpoint,
model: Optional[Union["tf.keras.Model", Callable[[], "tf.keras.Model"]]] = None,
) -> "tf.keras.Model":
"""Converts a TensorflowCheckpoint to Tensorflow Model.
Args:
checkpoint: TensorflowCheckpoint instance.
model: Tensorflow Model Defination.
Returns:
A Tensorflow Native Framework Model.
Raises:
ValueError: Invalid Argument.
RuntimeError: Ray version 2.4.0 is not supported.
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
raise RuntimeError(_V2_4_WARNING_MESSAGE)
try:
import tensorflow as tf
try:
return tf.saved_model.load(checkpoint.path)
except OSError:
return tf.saved_model.load("gs://" + checkpoint.path)
except ImportError:
logging.warning("TensorFlow must be installed to load the trained model.")

View File
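
A minimal usage sketch for the `register_tensorflow` helper above, assuming aiplatform.init() has been called with a staging bucket; `trainer` and `create_model` stand in for the caller's own Ray training code (both names are illustrative, not part of this module):

from google.cloud import aiplatform
from google.cloud.aiplatform.vertex_ray.predict import tensorflow

aiplatform.init(
    project="my-project",  # illustrative project id
    location="us-central1",
    staging_bucket="gs://my-staging-bucket",
)

result = trainer.fit()  # `trainer` is a ray.train.tensorflow TensorflowTrainer
registered = tensorflow.register_tensorflow(
    checkpoint=result.checkpoint,
    _model=create_model,  # required for checkpoints built via from_model()
    display_name="my-tf-model",
)
print(registered.resource_name)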

@@ -0,0 +1,22 @@
"""Ray on Vertex AI Prediction Tensorflow."""
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .register import get_pytorch_model_from
__all__ = ("get_pytorch_model_from",)

View File

@@ -0,0 +1,112 @@
"""Regsiter Torch for Ray on Vertex AI."""
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import ray
from ray.air._internal.torch_utils import load_torch_model
import tempfile
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
_V2_9_WARNING_MESSAGE,
)
from google.cloud.aiplatform.utils import gcs_utils
from typing import Optional
try:
from ray.train import torch as ray_torch
import torch
except ModuleNotFoundError as mnfe:
raise ModuleNotFoundError("Torch isn't installed.") from mnfe
def get_pytorch_model_from(
checkpoint: ray_torch.TorchCheckpoint,
model: Optional[torch.nn.Module] = None,
) -> torch.nn.Module:
"""Converts a TorchCheckpoint to Pytorch Model.
Example:
from vertex_ray.predict import torch
result = TorchTrainer.fit(...)
pytorch_model = torch.get_pytorch_model_from(
checkpoint=result.checkpoint
)
Args:
checkpoint: TorchCheckpoint instance.
model: If the checkpoint contains a model state dict, and not the model
itself, then the state dict will be loaded to this `model`. Otherwise,
the model will be discarded.
Returns:
A Pytorch Native Framework Model.
Raises:
ValueError: Invalid Argument.
ModuleNotFoundError: PyTorch isn't installed.
RuntimeError: Model not found.
RuntimeError: Ray version 2.4 is not supported.
RuntimeError: Only Ray version 2.9.3 is supported.
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
raise RuntimeError(_V2_4_WARNING_MESSAGE)
if ray_version != "2.9.3":
raise RuntimeError(
f"Ray on Vertex does not support Ray version {ray_version} to"
" convert PyTorch model artifacts yet. Please use Ray 2.9.3."
)
if ray_version == "2.9.3":
warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)
try:
return checkpoint.get_model()
except AttributeError:
model_file_name = ray.train.torch.TorchCheckpoint.MODEL_FILENAME
model_path = os.path.join(checkpoint.path, model_file_name)
try:
import torch
except ModuleNotFoundError as mnfe:
raise ModuleNotFoundError("PyTorch isn't installed.") from mnfe
if os.path.exists(model_path):
model_or_state_dict = torch.load(
model_path, map_location="cpu", weights_only=True
)
else:
try:
# Download from GCS to temp and then load_model
with tempfile.TemporaryDirectory() as temp_dir:
gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
model_or_state_dict = torch.load(
f"{temp_dir}/{model_file_name}",
map_location="cpu",
weights_only=True,
)
except Exception as e:
raise RuntimeError(
f"{model_file_name} not found in this checkpoint due to: {e}."
)
model = load_torch_model(saved_model=model_or_state_dict, model_definition=model)
return model

View File
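
A hedged sketch of `get_pytorch_model_from` above, run under Ray 2.9.3 per the version guard in that function; `result` is assumed to come from a completed ray.train.torch.TorchTrainer run, and the input shape is illustrative:

import torch
from google.cloud.aiplatform.vertex_ray.predict import torch as vertex_torch

pytorch_model = vertex_torch.get_pytorch_model_from(checkpoint=result.checkpoint)
pytorch_model.eval()  # switch to inference mode
with torch.no_grad():
    prediction = pytorch_model(torch.randn(1, 10))  # input shape is illustrative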

@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Constants."""
# Required Names for model files are specified here
# https://cloud.google.com/vertex-ai/docs/training/exporting-model-artifacts#framework-specific_requirements
_PICKLE_FILE_NAME = "model.pkl"
_PICKLE_EXTENTION = ".pkl"
_XGBOOST_VERSION = "1.6"
# TensorFlow 2.13 requires typing_extensions<4.6 and will cause errors in Ray.
# https://github.com/tensorflow/tensorflow/blob/v2.13.0/tensorflow/tools/pip_package/setup.py#L100
# 2.13 is the latest supported version of Vertex prebuilt prediction container.
# Set 2.12 as default here since 2.13 cause errors.
_TENSORFLOW_VERSION = "2.12"

View File

@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Predict Utils.
"""
def validate_artifact_uri(artifact_uri: str) -> None:
if artifact_uri is None or not artifact_uri.startswith("gs://"):
raise ValueError("Argument 'artifact_uri' should start with 'gs://'.")

View File
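
A quick behavior sketch for `validate_artifact_uri` above: only gs:// URIs pass; anything else, including None, raises:

from google.cloud.aiplatform.vertex_ray.predict.util import predict_utils

predict_utils.validate_artifact_uri("gs://my-bucket/models")  # passes silently
try:
    predict_utils.validate_artifact_uri("/local/path")  # not a gs:// URI
except ValueError as e:
    print(e)  # Argument 'artifact_uri' should start with 'gs://'.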

@@ -0,0 +1,22 @@
"""Ray on Vertex AI Prediction Tensorflow."""
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .register import register_xgboost
__all__ = ("register_xgboost",)

View File

@@ -0,0 +1,189 @@
"""Regsiter XGBoost for Ray on Vertex AI."""
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pickle
import ray
import tempfile
from typing import Optional, TYPE_CHECKING
import warnings
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import gcs_utils
from google.cloud.aiplatform.vertex_ray.predict.util import constants
from google.cloud.aiplatform.vertex_ray.predict.util import (
predict_utils,
)
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
_V2_9_WARNING_MESSAGE,
)
try:
from ray.train import xgboost as ray_xgboost
if TYPE_CHECKING:
import xgboost
except ModuleNotFoundError as mnfe:
if ray.__version__ == "2.9.3":
raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
else:
xgboost = None
def register_xgboost(
checkpoint: "ray_xgboost.XGBoostCheckpoint",
artifact_uri: Optional[str] = None,
display_name: Optional[str] = None,
xgboost_version: Optional[str] = None,
**kwargs,
) -> aiplatform.Model:
"""Uploads a Ray XGBoost Checkpoint as XGBoost Model to Model Registry.
Example usage:
from vertex_ray.predict import xgboost
from ray.train.xgboost import XGBoostCheckpoint
trainer = XGBoostTrainer(...)
result = trainer.fit()
xgboost_checkpoint = XGBoostCheckpoint.from_checkpoint(result.checkpoint)
my_model = xgboost.register_xgboost(
checkpoint=xgboost_checkpoint,
artifact_uri="gs://{gcs-bucket-name}/path/to/store",
display_name="my-ray-on-vertex-xgboost-model",
)
Args:
checkpoint: XGBoostCheckpoint instance.
artifact_uri (str):
Optional. The path to the directory where Model Artifacts will be saved.
If not set, will use the staging bucket set in aiplatform.init().
display_name (str):
Optional. The display name of the Model. The name can be up to 128
characters long and can consist of any UTF-8 characters.
xgboost_version (str): Optional. The version of the XGBoost serving container.
Supported versions:
https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
If the version is not specified, version 1.6 is used.
**kwargs:
Any kwargs will be passed to aiplatform.Model registration.
Returns:
model (aiplatform.Model):
Instantiated representation of the uploaded model resource.
Raises:
ValueError: Invalid Argument.
RuntimeError: Only Ray version 2.9.3 is supported.
"""
ray_version = ray.__version__
if ray_version != "2.9.3":
raise RuntimeError(
f"Ray version {ray_version} is not supported to upload XGBoost"
" model to Vertex Model Registry yet. Please use Ray 2.9.3."
)
if ray_version == "2.9.3":
warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)
artifact_uri = artifact_uri or initializer.global_config.staging_bucket
predict_utils.validate_artifact_uri(artifact_uri)
display_model_name = (
(f"ray-on-vertex-registered-xgboost-model-{utils.timestamped_unique_name()}")
if display_name is None
else display_name
)
model = _get_xgboost_model_from(checkpoint)
model_dir = os.path.join(artifact_uri, display_model_name)
file_path = os.path.join(model_dir, constants._PICKLE_FILE_NAME)
if xgboost_version is None:
xgboost_version = constants._XGBOOST_VERSION
with tempfile.NamedTemporaryFile(suffix=constants._PICKLE_EXTENTION) as temp_file:
pickle.dump(model, temp_file)
# Flush buffered bytes so the upload by file name sees the full pickle.
temp_file.flush()
gcs_utils.upload_to_gcs(temp_file.name, file_path)
return aiplatform.Model.upload_xgboost_model_file(
model_file_path=temp_file.name,
display_name=display_model_name,
xgboost_version=xgboost_version,
**kwargs,
)
def _get_xgboost_model_from(
checkpoint: "ray_xgboost.XGBoostCheckpoint",
) -> "xgboost.Booster":
"""Converts a XGBoostCheckpoint to XGBoost model.
Args:
checkpoint: XGBoostCheckpoint instance.
Returns:
A XGBoost core Booster
Raises:
ValueError: Invalid Argument.
ModuleNotFoundError: XGBoost isn't installed.
RuntimeError: Model not found.
RuntimeError: Ray version 2.4 is not supported.
RuntimeError: Only Ray version 2.9.3 is supported.
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
raise RuntimeError(_V2_4_WARNING_MESSAGE)
if ray_version != "2.9.3":
raise RuntimeError(
f"Ray version {ray_version} is not supported to convert a XGBoost"
" checkpoint to XGBoost model on Vertex yet. Please use Ray 2.9.3."
)
try:
# This works for Ray v2.5
return checkpoint.get_model()
except AttributeError:
# This works for Ray v2.9
model_file_name = ray.train.xgboost.XGBoostCheckpoint.MODEL_FILENAME
model_path = os.path.join(checkpoint.path, model_file_name)
try:
import xgboost
except ModuleNotFoundError as mnfe:
raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
booster = xgboost.Booster()
if os.path.exists(model_path):
booster.load_model(model_path)
return booster
try:
# Download from GCS to temp and then load_model
with tempfile.TemporaryDirectory() as temp_dir:
gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
booster.load_model(f"{temp_dir}/{model_file_name}")
return booster
except Exception as e:
raise RuntimeError(
f"{model_file_name} not found in this checkpoint due to: {e}."
)

View File
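
A hedged sketch of converting a fitted checkpoint to a native Booster for local scoring, run under Ray 2.9.3 per the version guard above; `result` comes from an XGBoostTrainer run elsewhere, the feature width must match the trained model, and the 4x8 matrix is illustrative:

import numpy as np
import xgboost as xgb
from ray.train.xgboost import XGBoostCheckpoint
from google.cloud.aiplatform.vertex_ray.predict.xgboost.register import (
    _get_xgboost_model_from,  # module-private helper defined above
)

ckpt = XGBoostCheckpoint.from_checkpoint(result.checkpoint)
booster = _get_xgboost_model_from(ckpt)
preds = booster.predict(xgb.DMatrix(np.random.rand(4, 8)))  # 8 features, illustrative
print(preds)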

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pathlib
from ray.widgets import Template
class VertexRayTemplate(Template):
"""Class which provides basic HTML templating."""
def __init__(self, file: str):
with open(pathlib.Path(__file__).parent / "templates" / file, "r") as f:
self.template = f.read()

View File

@@ -0,0 +1,4 @@
<tr>
<td style="text-align: left"><b>Interactive Terminal Uri:</b></td>
<td style="text-align: left"><b><a href="https://{{ shell_uri }}" target="_blank">{{ shell_uri }}</a></b></td>
</tr>

View File

@@ -0,0 +1,24 @@
<table class="jp-RenderedHTMLCommon" style="border-collapse: collapse;color: var(--jp-ui-font-color1);font-size: var(--jp-ui-font-size1);">
<tr>
<td style="text-align: left"><b>Python version:</b></td>
<td style="text-align: left"><b>{{ python_version }}</b></td>
</tr>
<tr>
<td style="text-align: left"><b>Ray version:</b></td>
<td style="text-align: left"><b> {{ ray_version }}</b></td>
</tr>
<tr>
<td style="text-align: left"><b>Vertex SDK version:</b></td>
<td style="text-align: left"><b> {{ vertex_sdk_version }}</b></td>
</tr>
<tr>
<td style="text-align: left"><b>Dashboard:</b></td>
<td style="text-align: left"><b><a href="https://{{ dashboard_url }}" target="_blank">{{ dashboard_url }}</a></b></td>
</tr>
{{ shell_uri_row }}
<tr>
<td style="text-align: left"><b>Cluster Name:</b></td>
<td style="text-align: left"><b> {{ persistent_resource_id }}</b></td>
</tr>
</table>

View File

@@ -0,0 +1,288 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import logging
import time
from typing import Optional
from google.api_core import exceptions
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.utils import (
PersistentResourceClientWithOverride,
)
from google.cloud.aiplatform.vertex_ray.util import _validation_utils
from google.cloud.aiplatform.vertex_ray.util.resources import (
AutoscalingSpec,
Cluster,
PscIConfig,
Resources,
)
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
PersistentResource,
)
from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import (
GetPersistentResourceRequest,
)
_PRIVATE_PREVIEW_IMAGE = "-docker.pkg.dev/vertex-ai/training/tf-"
_OFFICIAL_IMAGE = "-docker.pkg.dev/vertex-ai/training/ray-"
def create_persistent_resource_client():
# location is inherited from the global configuration at aiplatform.init().
return initializer.global_config.create_client(
client_class=PersistentResourceClientWithOverride,
appended_gapic_version="vertex_ray",
).select_version("v1beta1")
def polling_delay(num_attempts: int, time_scale: float) -> datetime.timedelta:
"""Computes a delay to the next attempt to poll the Vertex service.
This does bounded exponential backoff, starting with $time_scale.
If $time_scale == 0, it starts with a small time interval, less than
1 second.
Args:
num_attempts: The number of times have we polled and found that the
desired result was not yet available.
time_scale: The shortest polling interval, in seconds, or zero. Zero is
treated as a small interval, less than 1 second.
Returns:
A recommended delay interval, in seconds.
"""
# The polling schedule is slow initially, and then gets faster until 6
# attempts (after that the sleeping time remains the same).
small_interval = 30.0 # Seconds
interval = max(time_scale, small_interval) * 0.765 ** min(num_attempts, 6)
return datetime.timedelta(seconds=interval)
def get_persistent_resource(
persistent_resource_name: str, tolerance: Optional[int] = 0
):
"""Get persistent resource.
Args:
persistent_resource_name:
"projects/<project_num>/locations/<region>/persistentResources/<pr_id>".
tolerance: number of attempts to get the persistent resource.
Returns:
aiplatform_v1.PersistentResource if state is RUNNING.
Raises:
ValueError: Invalid cluster resource name.
RuntimeError: Service returns error.
RuntimeError: Cluster resource state is STOPPING.
RuntimeError: Cluster resource state is ERROR.
"""
client = create_persistent_resource_client()
request = GetPersistentResourceRequest(name=persistent_resource_name)
# TODO(b/277117901): Add test cases for polling and error handling
num_attempts = 0
while True:
try:
response = client.get_persistent_resource(request)
except exceptions.NotFound:
response = None
if num_attempts >= tolerance:
raise ValueError(
"[Ray on Vertex AI]: Invalid cluster_resource_name (404 not found)."
)
if response:
if response.error.message:
logging.error("[Ray on Vertex AI]: %s" % response.error.message)
raise RuntimeError("[Ray on Vertex AI]: Cluster returned an error.")
print("[Ray on Vertex AI]: Cluster State =", response.state)
if response.state == PersistentResource.State.RUNNING:
return response
elif response.state == PersistentResource.State.STOPPING:
raise RuntimeError("[Ray on Vertex AI]: The cluster is stopping.")
elif response.state == PersistentResource.State.ERROR:
raise RuntimeError(
"[Ray on Vertex AI]: The cluster encountered an error."
)
# Polling decay
sleep_time = polling_delay(num_attempts=num_attempts, time_scale=150.0)
num_attempts += 1
print(
"Waiting for cluster provisioning; attempt {}; sleeping for {} seconds".format(
num_attempts, sleep_time.total_seconds()
)
)
time.sleep(sleep_time.total_seconds())
def persistent_resource_to_cluster(
persistent_resource: PersistentResource,
) -> Optional[Cluster]:
"""Format a PersistentResource to a dictionary.
Args:
persistent_resource: PersistentResource.
Returns:
Cluster.
"""
dashboard_address = persistent_resource.resource_runtime.access_uris.get(
"RAY_DASHBOARD_URI"
)
cluster = Cluster(
cluster_resource_name=persistent_resource.name,
network=persistent_resource.network,
reserved_ip_ranges=persistent_resource.reserved_ip_ranges,
state=persistent_resource.state.name,
labels=persistent_resource.labels,
dashboard_address=dashboard_address,
)
if not persistent_resource.resource_runtime_spec.ray_spec:
# skip PersistentResource without RaySpec
logging.info(
"[Ray on Vertex AI]: Cluster %s does not have Ray installed."
% persistent_resource.name,
)
return
if persistent_resource.psc_interface_config:
cluster.psc_interface_config = PscIConfig(
network_attachment=persistent_resource.psc_interface_config.network_attachment
)
resource_pools = persistent_resource.resource_pools
head_resource_pool = resource_pools[0]
head_id = head_resource_pool.id
head_image_uri = (
persistent_resource.resource_runtime_spec.ray_spec.resource_pool_images[head_id]
)
if persistent_resource.resource_runtime_spec.service_account_spec.service_account:
cluster.service_account = (
persistent_resource.resource_runtime_spec.service_account_spec.service_account
)
if not head_image_uri:
head_image_uri = persistent_resource.resource_runtime_spec.ray_spec.image_uri
try:
python_version, ray_version = _validation_utils.get_versions_from_image_uri(
head_image_uri
)
except IndexError:
if _PRIVATE_PREVIEW_IMAGE in head_image_uri:
# If using outdated images
logging.info(
"[Ray on Vertex AI]: The image of cluster %s is outdated."
" It is recommended to delete and recreate the cluster to obtain"
" the latest image." % persistent_resource.name
)
return None
else:
# Custom image might also cause IndexError
python_version = None
ray_version = None
cluster.python_version = python_version
cluster.ray_version = ray_version
cluster.ray_metric_enabled = not (
persistent_resource.resource_runtime_spec.ray_spec.ray_metric_spec.disabled
)
cluster.ray_logs_enabled = not (
persistent_resource.resource_runtime_spec.ray_spec.ray_logs_spec.disabled
)
accelerator_type = head_resource_pool.machine_spec.accelerator_type
if accelerator_type.value != 0:
accelerator_type = accelerator_type.name
else:
accelerator_type = None
if _OFFICIAL_IMAGE in head_image_uri:
# Official training image is not custom
head_image_uri = None
head_node_type = Resources(
machine_type=head_resource_pool.machine_spec.machine_type,
accelerator_type=accelerator_type,
accelerator_count=head_resource_pool.machine_spec.accelerator_count,
boot_disk_type=head_resource_pool.disk_spec.boot_disk_type,
boot_disk_size_gb=head_resource_pool.disk_spec.boot_disk_size_gb,
node_count=1,
custom_image=head_image_uri,
)
worker_node_types = []
if head_resource_pool.replica_count > 1:
# head_node_type.node_count must be 1. If the head_resource_pool (the first
# resource pool) has replica_count > 1, the remaining replicas are worker nodes.
worker_node_count = head_resource_pool.replica_count - 1
worker_node_types.append(
Resources(
machine_type=head_resource_pool.machine_spec.machine_type,
accelerator_type=accelerator_type,
accelerator_count=head_resource_pool.machine_spec.accelerator_count,
boot_disk_type=head_resource_pool.disk_spec.boot_disk_type,
boot_disk_size_gb=head_resource_pool.disk_spec.boot_disk_size_gb,
node_count=worker_node_count,
custom_image=head_image_uri,
)
)
if head_resource_pool.autoscaling_spec:
worker_node_types[0].autoscaling_spec = AutoscalingSpec(
min_replica_count=head_resource_pool.autoscaling_spec.min_replica_count,
max_replica_count=head_resource_pool.autoscaling_spec.max_replica_count,
)
for i in range(len(resource_pools) - 1):
# Convert the second and subsequent resource pools to vertex_ray.Resources,
# and append them to worker_node_types.
accelerator_type = resource_pools[i + 1].machine_spec.accelerator_type
if accelerator_type.value != 0:
accelerator_type = accelerator_type.name
else:
accelerator_type = None
worker_image_uri = (
persistent_resource.resource_runtime_spec.ray_spec.resource_pool_images[
resource_pools[i + 1].id
]
)
if _OFFICIAL_IMAGE in worker_image_uri:
# Official training image is not custom
worker_image_uri = None
resource = Resources(
machine_type=resource_pools[i + 1].machine_spec.machine_type,
accelerator_type=accelerator_type,
accelerator_count=resource_pools[i + 1].machine_spec.accelerator_count,
boot_disk_type=resource_pools[i + 1].disk_spec.boot_disk_type,
boot_disk_size_gb=resource_pools[i + 1].disk_spec.boot_disk_size_gb,
node_count=resource_pools[i + 1].replica_count,
custom_image=worker_image_uri,
)
if resource_pools[i + 1].autoscaling_spec:
resource.autoscaling_spec = AutoscalingSpec(
min_replica_count=resource_pools[
i + 1
].autoscaling_spec.min_replica_count,
max_replica_count=resource_pools[
i + 1
].autoscaling_spec.max_replica_count,
)
worker_node_types.append(resource)
cluster.head_node_type = head_node_type
cluster.worker_node_types = worker_node_types
return cluster

View File
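
A worked example of the `polling_delay` schedule above: with time_scale=150.0 (the value used in `get_persistent_resource`) the delay shrinks by a factor of 0.765 per attempt and flattens after six attempts:

for attempt in range(8):
    delay = polling_delay(num_attempts=attempt, time_scale=150.0)
    print(attempt, delay.total_seconds())
# attempt 0 -> 150.0s; attempt 1 -> ~114.8s; ...; attempts >= 6 -> ~30.1s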

@@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import google.auth
import google.auth.transport.requests
import logging
import ray
import re
from immutabledict import immutabledict
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.utils import resource_manager_utils
SUPPORTED_RAY_VERSIONS = immutabledict(
{"2.9": "2.9.3", "2.33": "2.33.0", "2.42": "2.42.0"}
)
SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS = immutabledict(
{
"3.10": ("2.9", "2.33", "2.42"),
"3.11": ("2.42"),
}
)
_V2_4_WARNING_MESSAGE = (
"After google-cloud-aiplatform>1.53.0, using Ray version = 2.4 will result"
" in an error. Please use Ray version = 2.33.0 or 2.42.0 (default) instead."
)
_V2_9_WARNING_MESSAGE = (
"In March 2025, using Ray version = 2.9 will result in an error. "
"Please use Ray version = 2.33.0 or 2.42.0 (default) instead."
)
# Artifact Repository available regions.
_AVAILABLE_REGIONS = ["us", "europe", "asia"]
# If region is not available, assume using the default region.
_DEFAULT_REGION = "us"
_PERSISTENT_RESOURCE_NAME_PATTERN = "projects/{}/locations/{}/persistentResources/{}"
_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}"
_DASHBOARD_URI_SUFFIX = "aiplatform-training.googleusercontent.com"
def valid_resource_name(resource_name):
"""Check if address is a valid resource name."""
resource_name_split = resource_name.split("/")
if not (
len(resource_name_split) == 6
and resource_name_split[0] == "projects"
and resource_name_split[2] == "locations"
and resource_name_split[4] == "persistentResources"
):
raise ValueError(
"[Ray on Vertex AI]: Address must be in the following "
"format: vertex_ray://projects/<project_num>/locations/<region>/persistentResources/<pr_id> "
"or vertex_ray://<pr_id>."
)
def maybe_reconstruct_resource_name(address) -> str:
"""Reconstruct full persistent resource name if only id was given."""
if re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), address):
# Assume only cluster name (persistent resource id) was given.
logging.info(
"[Ray on Vertex AI]: Cluster name was given as address, reconstructing full resource name"
)
return _PERSISTENT_RESOURCE_NAME_PATTERN.format(
resource_manager_utils.get_project_number(
initializer.global_config.project
),
initializer.global_config.location,
address,
)
return address
def get_local_ray_version():
ray_version = ray.__version__.split(".")
if len(ray_version) == 3:
ray_version = ray_version[:2]
return ".".join(ray_version)
def get_image_uri(ray_version, python_version, enable_cuda):
"""Image uri for a given ray version and python version."""
if ray_version not in SUPPORTED_RAY_VERSIONS:
raise ValueError(
"[Ray on Vertex AI]: The supported Ray versions are %s (%s) and %s (%s)."
% (
list(SUPPORTED_RAY_VERSIONS.keys())[0],
list(SUPPORTED_RAY_VERSIONS.values())[0],
list(SUPPORTED_RAY_VERSIONS.keys())[1],
list(SUPPORTED_RAY_VERSIONS.values())[1],
)
)
if python_version not in SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS:
raise ValueError(
"[Ray on Vertex AI]: The supported Python versions are 3.10 or 3.11."
)
if ray_version not in SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS[python_version]:
raise ValueError(
"[Ray on Vertex AI]: The supported Ray version(s) for Python version %s: %s."
% (
python_version,
SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS[python_version],
)
)
location = initializer.global_config.location
region = location.split("-")[0]
if region not in _AVAILABLE_REGIONS:
region = _DEFAULT_REGION
ray_version = ray_version.replace(".", "-")
python_version = python_version.replace(".", "")
if enable_cuda:
return f"{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.{ray_version}.py{python_version}:latest"
else:
return f"{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.{ray_version}.py{python_version}:latest"
def get_versions_from_image_uri(image_uri):
"""Get ray version and python version from image uri."""
logging.info(f"[Ray on Vertex AI]: Getting versions from image uri: {image_uri}")
image_label = image_uri.split("/")[-1].split(":")[0]
py_version = image_label[-3] + "." + image_label[-2:]
ray_version = image_label.split(".")[1].replace("-", ".")
if (
py_version in SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS
and ray_version in SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS[py_version]
):
return py_version, ray_version
else:
# May not parse custom image and get the versions correctly
return None, None
def valid_dashboard_address(address):
"""Check if address is a valid dashboard uri."""
return address.endswith(_DASHBOARD_URI_SUFFIX)
def get_bearer_token():
"""Get bearer token through Application Default Credentials."""
creds, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
# creds.valid is False, and creds.token is None
# Need to refresh credentials to populate those
auth_req = google.auth.transport.requests.Request()
creds.refresh(auth_req)
return creds.token

View File
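
A worked example of `get_versions_from_image_uri` above: for an official image the label "ray-cpu.2-33.py310" parses to Python 3.10 and Ray 2.33, while unrecognized custom images fall back to (None, None):

py_version, ray_version = get_versions_from_image_uri(
    "us-docker.pkg.dev/vertex-ai/training/ray-cpu.2-33.py310:latest"
)
assert (py_version, ray_version) == ("3.10", "2.33")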

@@ -0,0 +1,217 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
from typing import Dict, List, Optional
from google.cloud.aiplatform_v1beta1.types import PersistentResource
@dataclasses.dataclass
class AutoscalingSpec:
"""Autoscaling spec for a ray cluster node.
Attributes:
min_replica_count: The minimum number of replicas in the cluster.
max_replica_count: The maximum number of replicas in the cluster.
"""
min_replica_count: int = 1
max_replica_count: int = 2
@dataclasses.dataclass
class Resources:
"""Resources for a ray cluster node.
Attributes:
machine_type: See the list of machine types:
https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types
node_count: This argument represents how many nodes to start for the
ray cluster.
accelerator_type: e.g. "NVIDIA_TESLA_P4".
Vertex AI supports the following types of GPU:
https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus
accelerator_count: The number of accelerators to attach to the machine.
boot_disk_type: Type of the boot disk (default is "pd-ssd").
Valid values: "pd-ssd" (Persistent Disk Solid State Drive) or
"pd-standard" (Persistent Disk Hard Disk Drive).
boot_disk_size_gb: Size in GB of the boot disk (default is 100GB). Must
be either unspecified or within the range of [100, 64000].
custom_image: Custom image for this resource (e.g.
us-docker.pkg.dev/my-project/ray-gpu.2-9.py310-tf:latest).
autoscaling_spec: Autoscaling spec for this resource.
"""
machine_type: Optional[str] = "n1-standard-16"
node_count: Optional[int] = 1
accelerator_type: Optional[str] = None
accelerator_count: Optional[int] = 0
boot_disk_type: Optional[str] = "pd-ssd"
boot_disk_size_gb: Optional[int] = 100
custom_image: Optional[str] = None
autoscaling_spec: Optional[AutoscalingSpec] = None
@dataclasses.dataclass
class NodeImages:
"""Custom images for a ray cluster.
We currently support Ray v2.9, v2.33, and v2.42 with Python v3.10.
We also support Python v3.11 for Ray v2.42.
The custom images must be extended from the following base images:
"{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.2-9.py310:latest",
"{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.2-9.py310:latest",
"{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.2-33.py310:latest",
"{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.2-33.py310:latest",
"{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.2-42.py310:latest",
"{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.2-42.py310:latest",
"{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.2-42.py311:latest", or
"{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.2-42.py311:latest". In
order to use custom images, need to specify both head and worker images.
Attributes:
head: image for head node (e.g. us-docker.pkg.dev/my-project/ray-cpu.2-33.py310-tf:latest).
worker: image for all worker nodes (e.g. us-docker.pkg.dev/my-project/ray-gpu.2-33.py310-tf:latest).
"""
head: str = None
worker: str = None
@dataclasses.dataclass
class PscIConfig:
"""PSC-I config.
Attributes:
network_attachment: Optional. The name or full name of the Compute Engine
`network attachment <https://cloud.google.com/vpc/docs/about-network-attachments>`
to attach to the resource. It has a format:
``projects/{project}/regions/{region}/networkAttachments/{networkAttachment}``.
Where {project} is a project number, as in ``12345``, and
{networkAttachment} is a network attachment name. To specify
this field, you must have already [created a network
attachment]
(https://cloud.google.com/vpc/docs/create-manage-network-attachments#create-network-attachments).
This field is only used for resources using PSC-I. Make sure you do not
specify the network here for VPC peering.
"""
network_attachment: str = None
@dataclasses.dataclass
class NfsMount:
"""NFS mount.
Attributes:
server: Required. IP address of the NFS server.
path: Required. Source path exported from NFS server. Has to start
with '/', and combined with the ip address, it indicates the
source mount path in the form of ``server:path``.
mount_point: Required. Destination mount path. The NFS will be mounted
for the user under /mnt/nfs/<mount_point>.
"""
server: str = None
path: str = None
mount_point: str = None
@dataclasses.dataclass
class Cluster:
"""Ray cluster (output only).
Attributes:
cluster_resource_name: It has a format:
"projects/<project_num>/locations/<region>/persistentResources/<pr_id>".
network: Virtual private cloud (VPC) network. It has a format:
"projects/<project_num>/global/networks/<network_name>".
For Ray Client, VPC peering is required to connect to the cluster
managed in the Vertex API service. For Ray Job API, VPC network is
not required because cluster connection can be accessed through
dashboard address.
reserved_ip_ranges: A list of names for the reserved IP ranges under
the VPC network that can be used for this cluster. If set, we will
deploy the cluster within the provided IP ranges. Otherwise, the
cluster is deployed to any IP ranges under the provided VPC network.
Example: ["vertex-ai-ip-range"].
service_account: Service account to be used for running Ray programs on
the cluster.
state: Describes the cluster state (defined in PersistentResource.State).
python_version: Python version for the ray cluster (e.g. "3.10").
ray_version: Ray version for the ray cluster (e.g. "2.33").
head_node_type: The head node resource. Resources.node_count must be 1.
If not set, by default it is a CPU node with machine_type of n1-standard-8.
worker_node_types: The list of Resources of the worker nodes. Should not
duplicate the elements in the list.
dashboard_address: For the Ray Job API (JobSubmissionClient); connecting
to the cluster through this address does not require VPC peering.
labels:
The labels with user-defined metadata to organize the Ray cluster.
Label keys and values can be no longer than 64 characters (Unicode
codepoints), can only contain lowercase letters, numeric characters,
underscores and dashes. International characters are allowed.
See https://goo.gl/xmQnxf for more information and examples of labels.
"""
cluster_resource_name: str = None
network: str = None
reserved_ip_ranges: List[str] = None
service_account: str = None
state: PersistentResource.State = None
python_version: str = None
ray_version: str = None
head_node_type: Resources = None
worker_node_types: List[Resources] = None
dashboard_address: str = None
ray_metric_enabled: bool = True
ray_logs_enabled: bool = True
psc_interface_config: PscIConfig = None
labels: Dict[str, str] = None
def _check_machine_spec_identical(
node_type_1: Resources,
node_type_2: Resources,
) -> int:
"""Check if node_type_1 and node_type_2 have the same machine_spec.
If they are identical, return additional_replica_count."""
additional_replica_count = 0
# Check if machine_spec are the same
if (
node_type_1.machine_type == node_type_2.machine_type
and node_type_1.accelerator_type == node_type_2.accelerator_type
and node_type_1.accelerator_count == node_type_2.accelerator_count
):
if node_type_1.boot_disk_type != node_type_2.boot_disk_type:
raise ValueError(
"Worker disk type must match the head node's disk type if"
" sharing the same machine_type, accelerator_type, and"
" accelerator_count"
)
if node_type_1.boot_disk_size_gb != node_type_2.boot_disk_size_gb:
raise ValueError(
"Worker disk size must match the head node's disk size if"
" sharing the same machine_type, accelerator_type, and"
" accelerator_count"
)
additional_replica_count = node_type_2.node_count
return additional_replica_count
return additional_replica_count
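
A hedged sketch of composing the dataclasses above into a cluster request; `create_ray_cluster` is exported by `vertex_ray` (see the package __init__), though the exact keyword names here are assumptions rather than confirmed by this file:

from google.cloud.aiplatform import vertex_ray
from google.cloud.aiplatform.vertex_ray.util.resources import (
    AutoscalingSpec,
    Resources,
)

head = Resources(machine_type="n1-standard-16", node_count=1)  # head must be 1 node
workers = [
    Resources(
        machine_type="g2-standard-8",
        accelerator_type="NVIDIA_L4",
        accelerator_count=1,
        autoscaling_spec=AutoscalingSpec(min_replica_count=1, max_replica_count=4),
    )
]
cluster_name = vertex_ray.create_ray_cluster(
    head_node_type=head,
    worker_node_types=workers,
)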