structure saas with tools
This commit is contained in:
68
.venv/lib/python3.10/site-packages/vertex_ray/__init__.py
Normal file
68
.venv/lib/python3.10/site-packages/vertex_ray/__init__.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Ray on Vertex AI."""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import sys
|
||||
|
||||
from google.cloud.aiplatform.vertex_ray.bigquery_datasource import (
|
||||
_BigQueryDatasource,
|
||||
)
|
||||
from google.cloud.aiplatform.vertex_ray.client_builder import (
|
||||
VertexRayClientBuilder as ClientBuilder,
|
||||
)
|
||||
|
||||
from google.cloud.aiplatform.vertex_ray.cluster_init import (
|
||||
create_ray_cluster,
|
||||
delete_ray_cluster,
|
||||
get_ray_cluster,
|
||||
list_ray_clusters,
|
||||
update_ray_cluster,
|
||||
)
|
||||
|
||||
from google.cloud.aiplatform.vertex_ray import data
|
||||
|
||||
from google.cloud.aiplatform.vertex_ray.util.resources import (
|
||||
AutoscalingSpec,
|
||||
Resources,
|
||||
NodeImages,
|
||||
PscIConfig,
|
||||
)
|
||||
|
||||
from google.cloud.aiplatform.vertex_ray.dashboard_sdk import (
|
||||
get_job_submission_client_cluster_info,
|
||||
)
|
||||
|
||||
# Ray on Vertex AI only supports Python 3.10 and 3.11 clients. Compare the
# (major, minor) pair rather than the minor number alone, so an interpreter
# with a different major version is never mistaken for a supported one.
if sys.version_info[:2] not in ((3, 10), (3, 11)):
    print(
        "[Ray on Vertex]: The client environment with Python version 3.10 or 3.11 is required."
    )
|
||||
|
||||
# Public names re-exported by the package (order kept as originally declared).
__all__ = (
    # Data access
    "_BigQueryDatasource",
    "data",
    # Client connectivity
    "ClientBuilder",
    "get_job_submission_client_cluster_info",
    # Cluster lifecycle
    "create_ray_cluster",
    "delete_ray_cluster",
    "get_ray_cluster",
    "list_ray_clusters",
    "update_ray_cluster",
    # Resource specifications
    "AutoscalingSpec",
    "Resources",
    "NodeImages",
    "PscIConfig",
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,161 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Iterable, Optional
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from google.api_core import client_info
|
||||
from google.api_core import exceptions
|
||||
from google.cloud import bigquery
|
||||
from google.cloud.aiplatform import initializer
|
||||
|
||||
import ray
|
||||
from ray.data._internal.execution.interfaces import TaskContext
|
||||
from ray.data._internal.remote_fn import cached_remote_fn
|
||||
from ray.data.block import Block, BlockAccessor
|
||||
|
||||
try:
|
||||
from ray.data.datasource.datasink import Datasink
|
||||
except ImportError:
|
||||
# If datasink cannot be imported, Ray >=2.9.3 is not installed
|
||||
Datasink = None
|
||||
|
||||
|
||||
# Default number of retries for a block load that hits a rate-limit error.
DEFAULT_MAX_RETRY_CNT = 10
# Seconds to sleep before retrying a rate-limited block load.
RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11

# Client info passed to every BigQuery client created by this module; the
# version string is tagged with "+vertex_ray".
_BQ_GAPIC_VERSION = bigquery.__version__ + "+vertex_ray"
bq_info = client_info.ClientInfo(
    gapic_version=_BQ_GAPIC_VERSION,
    user_agent=f"ray-on-vertex/{_BQ_GAPIC_VERSION}",
)
|
||||
|
||||
|
||||
# BigQuery write for Ray 2.42.0, 2.33.0, and 2.9.3
|
||||
if Datasink is None:
    # The Datasink base class could not be imported, so no sink is provided.
    _BigQueryDatasink = None
else:

    class _BigQueryDatasink(Datasink):
        """Ray Datasink that loads dataset blocks into a BigQuery table.

        Each block is converted to Arrow, written to a temporary Parquet file
        and appended to the target table with a BigQuery load job.
        """

        def __init__(
            self,
            dataset: str,
            project_id: Optional[str] = None,
            max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
            overwrite_table: Optional[bool] = True,
        ) -> None:
            """Store the write configuration.

            Args:
                dataset: Target of the write; the part before the first "."
                    is used as the BigQuery dataset id.
                project_id: Project to write to. Falls back to the project in
                    the aiplatform global config when not given.
                max_retry_cnt: Maximum number of retries per block after a
                    rate-limit (Forbidden) error.
                overwrite_table: When truthy, the target table is deleted
                    before writing; otherwise the write appends.
            """
            self.project_id = project_id or initializer.global_config.project
            self.dataset = dataset
            self.overwrite_table = overwrite_table
            self.max_retry_cnt = max_retry_cnt

        def on_write_start(self) -> None:
            """Ensure the dataset exists and apply the overwrite policy."""
            bq_client = bigquery.Client(
                project=self.project_id, client_info=bq_info
            )
            dataset_name = self.dataset.split(".", 1)[0]

            # Create the dataset on demand if it is missing.
            try:
                bq_client.get_dataset(dataset_name)
            except exceptions.NotFound:
                bq_client.create_dataset(
                    f"{self.project_id}.{dataset_name}", timeout=30
                )
                print("[Ray on Vertex AI]: Created dataset " + dataset_name)

            if not self.overwrite_table:
                # Append mode: nothing to clean up, just tell the user.
                print(
                    "[Ray on Vertex AI]: The write will append to table "
                    + f"{self.dataset} if it already exists "
                    + "since kwarg overwrite_table = False."
                )
                return

            # Overwrite mode: drop the table (if any) before blocks are loaded.
            print(
                f"[Ray on Vertex AI]: Attempting to delete table {self.dataset}"
                + " if it already exists since kwarg overwrite_table = True."
            )
            bq_client.delete_table(
                f"{self.project_id}.{self.dataset}", not_found_ok=True
            )

        def write(
            self,
            blocks: Iterable[Block],
            ctx: TaskContext,
        ) -> Any:
            """Write every block to BigQuery, one remote task per block.

            Args:
                blocks: Blocks to load into the target table.
                ctx: Task context supplied by Ray; not used here.

            Returns:
                The string "ok" once all block writes have finished.

            Raises:
                RuntimeError: From a block task when its load job is still
                    rate limited after ``max_retry_cnt`` retries.
            """

            def _write_single_block(
                block: Block, project_id: str, dataset: str
            ) -> None:
                arrow_table = BlockAccessor.for_block(block).to_arrow()

                bq_client = bigquery.Client(
                    project=project_id, client_info=bq_info
                )
                load_config = bigquery.LoadJobConfig(autodetect=True)
                load_config.source_format = bigquery.SourceFormat.PARQUET
                load_config.write_disposition = (
                    bigquery.WriteDisposition.WRITE_APPEND
                )

                with tempfile.TemporaryDirectory() as scratch_dir:
                    parquet_path = os.path.join(
                        scratch_dir, f"block_{uuid.uuid4()}.parquet"
                    )
                    pq.write_table(arrow_table, parquet_path, compression="SNAPPY")

                    attempts = 0
                    while attempts <= self.max_retry_cnt:
                        # Re-open the file on every attempt so the upload
                        # always starts from the beginning.
                        with open(parquet_path, "rb") as parquet_file:
                            load_job = bq_client.load_table_from_file(
                                parquet_file, dataset, job_config=load_config
                            )
                        try:
                            logging.info(load_job.result())
                            break
                        except exceptions.Forbidden as rate_limit_error:
                            attempts += 1
                            if attempts > self.max_retry_cnt:
                                break
                            print(
                                "[Ray on Vertex AI]: A block write encountered"
                                + f" a rate limit exceeded error {attempts} time(s)."
                                + " Sleeping to try again."
                            )
                            logging.debug(rate_limit_error)
                            time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)

                    # Out of retries: report it and fail this block task.
                    if attempts > self.max_retry_cnt:
                        print(
                            f"[Ray on Vertex AI]: Maximum ({self.max_retry_cnt}) retry count exceeded."
                            + " Ray will attempt to retry the block write via fault tolerance."
                            + " For more information, see https://docs.ray.io/en/latest/ray-core/fault_tolerance/tasks.html"
                        )
                        raise RuntimeError(
                            f"[Ray on Vertex AI]: Write failed due to {attempts}"
                            + " repeated API rate limit exceeded responses. Consider"
                            + " specifiying the max_retry_cnt kwarg with a higher value."
                        )

            _write_single_block = cached_remote_fn(_write_single_block)

            # Fan out one remote task per block and wait for all of them.
            pending_writes = [
                _write_single_block.remote(block, self.project_id, self.dataset)
                for block in blocks
            ]
            ray.get(pending_writes)

            return "ok"
|
||||
@@ -0,0 +1,151 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from google.api_core import client_info
|
||||
from google.api_core import exceptions
|
||||
from google.api_core.gapic_v1 import client_info as v1_client_info
|
||||
from google.cloud import bigquery
|
||||
from google.cloud import bigquery_storage
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.bigquery_storage import types
|
||||
from ray.data.block import Block
|
||||
from ray.data.block import BlockMetadata
|
||||
from ray.data.datasource.datasource import Datasource
|
||||
from ray.data.datasource.datasource import ReadTask
|
||||
|
||||
|
||||
# Version strings reported to the BigQuery and BigQuery Storage APIs, tagged
# with "+vertex_ray".
_BQ_GAPIC_VERSION = bigquery.__version__ + "+vertex_ray"
_BQS_GAPIC_VERSION = bigquery_storage.__version__ + "+vertex_ray"

# Client info for bigquery.Client instances created in this module.
bq_info = client_info.ClientInfo(
    gapic_version=_BQ_GAPIC_VERSION,
    user_agent=f"ray-on-vertex/{_BQ_GAPIC_VERSION}",
)
# Client info for BigQueryReadClient instances created in this module.
bqstorage_info = v1_client_info.ClientInfo(
    gapic_version=_BQS_GAPIC_VERSION,
    user_agent=f"ray-on-vertex/{_BQS_GAPIC_VERSION}",
)

# NOTE(review): not referenced in the visible part of this module; presumably
# kept for parity with the write path — confirm before removing.
DEFAULT_MAX_RETRY_CNT = 10
RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11
|
||||
|
||||
|
||||
class _BigQueryDatasource(Datasource):
    """Ray Datasource that reads a BigQuery table or the result of a query.

    Data is fetched through the BigQuery Storage Read API in Arrow format,
    with one Ray read task per stream of the created read session.
    """

    def __init__(
        self,
        project_id: Optional[str] = None,
        dataset: Optional[str] = None,
        query: Optional[str] = None,
    ):
        """Store the read configuration.

        Args:
            project_id: Project to read from. Falls back to the project in
                the aiplatform global config when not given.
            dataset: Table to read, as "<dataset_id>.<table_id>".
            query: Query whose result table is read instead of ``dataset``.

        Raises:
            ValueError: If both ``query`` and ``dataset`` are provided.
        """
        self._project_id = project_id or initializer.global_config.project
        self._dataset = dataset
        self._query = query

        if query is not None and dataset is not None:
            raise ValueError(
                "[Ray on Vertex AI]: Query and dataset kwargs cannot both be provided (must be mutually exclusive)."
            )

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Create one read task per BigQuery Storage read stream.

        Args:
            parallelism: Requested maximum number of streams. ``-1`` means no
                explicit limit is passed to the read session.

        Returns:
            A list of ``ReadTask`` objects, each reading a single stream.

        Raises:
            ValueError: If reading from ``dataset`` and the dataset or table
                does not exist.
        """

        # Executed by a worker node
        def _read_single_partition(stream) -> Block:
            client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
            reader = client.read_rows(stream.name)
            return reader.to_arrow()

        if self._query:
            # Run the query and read from the table its results landed in.
            query_client = bigquery.Client(
                project=self._project_id, client_info=bq_info
            )
            query_job = query_client.query(self._query)
            query_job.result()
            destination = str(query_job.destination)
            dataset_id = destination.split(".")[-2]
            table_id = destination.split(".")[-1]
        else:
            self._validate_dataset_table_exist(self._project_id, self._dataset)
            dataset_id = self._dataset.split(".")[0]
            table_id = self._dataset.split(".")[1]

        bqs_client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
        table = f"projects/{self._project_id}/datasets/{dataset_id}/tables/{table_id}"

        # -1 means "no explicit parallelism": leave the stream count unset.
        max_stream_count = None if parallelism == -1 else parallelism
        requested_session = types.ReadSession(
            table=table,
            data_format=types.DataFormat.ARROW,
        )
        read_session = bqs_client.create_read_session(
            parent=f"projects/{self._project_id}",
            read_session=requested_session,
            max_stream_count=max_stream_count,
        )

        read_tasks = []
        stream_count = len(read_session.streams)
        print("[Ray on Vertex AI]: Created streams:", stream_count)
        # Only compare against an explicit request; comparing an int with
        # None (the parallelism == -1 case) would raise TypeError.
        if max_stream_count is not None and stream_count < max_stream_count:
            print(
                "[Ray on Vertex AI]: The number of streams created by the "
                + "BigQuery Storage Read API is less than the requested "
                + "parallelism due to the size of the dataset."
            )

        for stream in read_session.streams:
            # Create a metadata block object to store schema, etc.
            metadata = BlockMetadata(
                num_rows=None,
                size_bytes=None,
                schema=None,
                input_files=None,
                exec_stats=None,
            )

            # Create a no-arg wrapper read function which returns a block.
            # The default argument binds the current stream to this lambda.
            read_single_partition = lambda stream=stream: [  # noqa: E731
                _read_single_partition(stream)
            ]

            # Create the read task and pass the wrapper and metadata in
            read_task = ReadTask(read_single_partition, metadata)
            read_tasks.append(read_task)

        return read_tasks

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return the estimated in-memory size; currently always ``None``."""
        # TODO(b/281891467): Implement this method
        return None

    def _validate_dataset_table_exist(self, project_id: str, dataset: str) -> None:
        """Check that the dataset and the table both exist.

        Args:
            project_id: Project used to create the BigQuery client.
            dataset: Table reference as "<dataset_id>.<table_id>".

        Raises:
            ValueError: If the dataset or the table is not found.
        """
        client = bigquery.Client(project=project_id, client_info=bq_info)
        dataset_id = dataset.split(".")[0]
        try:
            client.get_dataset(dataset_id)
        except exceptions.NotFound as err:
            raise ValueError(
                f"[Ray on Vertex AI]: Dataset {dataset_id} is not found. Please ensure that it exists."
            ) from err

        try:
            client.get_table(dataset)
        except exceptions.NotFound as err:
            raise ValueError(
                f"[Ray on Vertex AI]: Table {dataset} is not found. Please ensure that it exists."
            ) from err
|
||||
201
.venv/lib/python3.10/site-packages/vertex_ray/client_builder.py
Normal file
201
.venv/lib/python3.10/site-packages/vertex_ray/client_builder.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import grpc
|
||||
import logging
|
||||
import ray
|
||||
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from google.cloud import aiplatform
|
||||
from google.cloud.aiplatform import initializer
|
||||
from ray import client_builder
|
||||
from .render import VertexRayTemplate
|
||||
from .util import _validation_utils
|
||||
from .util import _gapic_utils
|
||||
|
||||
|
||||
# Version of the imported `aiplatform` package; stored on the client context
# and passed to the context table template.
VERTEX_SDK_VERSION = aiplatform.__version__
|
||||
|
||||
|
||||
class _VertexRayClientContext(client_builder.ClientContext):
    """Custom ClientContext.

    Wraps an existing Ray client context and adds the persistent resource id,
    the Vertex SDK version and the head node interactive shell URI.
    """

    def __init__(
        self,
        persistent_resource_id: str,
        ray_head_uris: Dict[str, str],
        ray_client_context: client_builder.ClientContext,
    ) -> None:
        """Build the context from a connected Ray client context.

        Args:
            persistent_resource_id: Id of the persistent resource backing the
                cluster.
            ray_head_uris: Access URIs of the cluster; must contain
                "RAY_DASHBOARD_URI".
            ray_client_context: Context returned by the Ray client connection
                whose fields are copied into this one.

        Raises:
            ValueError: If no dashboard URI is available.
            ImportError: If the installed Ray version is not supported.
        """
        dashboard_uri = ray_head_uris.get("RAY_DASHBOARD_URI")
        if dashboard_uri is None:
            # A single formatted message; passing several positional args to
            # ValueError would render the message as a tuple.
            raise ValueError(
                f"Ray Cluster {persistent_resource_id}"
                " failed to start Head node properly."
            )

        # Only the 2.9.3 branch forwards protocol_version to the parent
        # constructor; the other supported versions do not.
        if ray.__version__ in ("2.42.0", "2.33.0"):
            version_specific_kwargs = {}
        elif ray.__version__ == "2.9.3":
            version_specific_kwargs = {
                "protocol_version": ray_client_context.protocol_version
            }
        else:
            raise ImportError(
                f"[Ray on Vertex AI]: Unsupported version {ray.__version__}. "
                "Only 2.42.0, 2.33.0, and 2.9.3 are supported."
            )

        super().__init__(
            dashboard_url=dashboard_uri,
            python_version=ray_client_context.python_version,
            ray_version=ray_client_context.ray_version,
            ray_commit=ray_client_context.ray_commit,
            _num_clients=ray_client_context._num_clients,
            _context_to_restore=ray_client_context._context_to_restore,
            **version_specific_kwargs,
        )
        self.persistent_resource_id = persistent_resource_id
        self.vertex_sdk_version = str(VERTEX_SDK_VERSION)
        self.shell_uri = ray_head_uris.get("RAY_HEAD_NODE_INTERACTIVE_SHELL_URI")

    def _context_table_template(self):
        """Render the HTML table describing this context.

        The shell URI row is rendered only when a shell URI is available.
        """
        shell_uri_row = None
        if self.shell_uri is not None:
            shell_uri_row = VertexRayTemplate("context_shellurirow.html.j2").render(
                shell_uri=self.shell_uri
            )

        return VertexRayTemplate("context_table.html.j2").render(
            python_version=self.python_version,
            ray_version=self.ray_version,
            vertex_sdk_version=self.vertex_sdk_version,
            dashboard_url=self.dashboard_url,
            persistent_resource_id=self.persistent_resource_id,
            shell_uri_row=shell_uri_row,
        )
|
||||
|
||||
|
||||
class VertexRayClientBuilder(client_builder.ClientBuilder):
    """Class to initialize a Ray client with vertex on ray capabilities."""

    def __init__(self, address: Optional[str]) -> None:
        """Resolve the head node address of a Ray cluster on Vertex AI.

        Args:
            address: Cluster identifier; it is passed through
                ``maybe_reconstruct_resource_name`` and then validated as a
                resource name.

        Raises:
            ValueError: If the cluster has only a private address but uses a
                custom service account, if no head node address is available,
                or if the persistent resource cannot be converted to a
                cluster.
        """
        address = _validation_utils.maybe_reconstruct_resource_name(address)
        _validation_utils.valid_resource_name(address)
        self._credentials = None
        self._metadata = None
        self.vertex_address = address
        logging.info(
            "[Ray on Vertex AI]: Using cluster resource name to access head address with GAPIC API"
        )

        self.resource_name = address

        self.response = _gapic_utils.get_persistent_resource(self.resource_name)
        private_address = self.response.resource_runtime.access_uris.get(
            "RAY_HEAD_NODE_INTERNAL_IP"
        )
        public_address = self.response.resource_runtime.access_uris.get(
            "RAY_CLIENT_ENDPOINT"
        )
        service_account = (
            self.response.resource_runtime_spec.service_account_spec.service_account
        )

        if public_address is None:
            address = private_address
            if service_account:
                # One formatted message; several positional args to
                # ValueError would render the message as a tuple.
                raise ValueError(
                    f"[Ray on Vertex AI]: Ray Cluster {address}"
                    " failed to start Head node properly because custom service"
                    " account isn't supported in peered VPC network. Use public"
                    " endpoint instead (create a cluster without specifying"
                    " VPC network)."
                )
        else:
            address = public_address

        if address is None:
            persistent_resource_id = self.resource_name.split("/")[5]
            raise ValueError(
                f"[Ray on Vertex AI]: Ray Cluster {persistent_resource_id}"
                " Head node is not reachable. Please ensure that a valid VPC network has been specified."
            )

        logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)
        cluster = _gapic_utils.persistent_resource_to_cluster(
            persistent_resource=self.response
        )
        if cluster is None:
            raise ValueError(
                "[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)."
            )

        # Tell the user about a local/cluster Ray version mismatch; this is
        # informational only and does not stop the connection.
        local_ray_version = _validation_utils.get_local_ray_version()
        if cluster.ray_version != local_ray_version:
            if cluster.head_node_type.custom_image is None:
                install_ray_version = _validation_utils.SUPPORTED_RAY_VERSIONS.get(
                    cluster.ray_version
                )
                logging.info(
                    "[Ray on Vertex]: Local runtime has Ray version %s"
                    ", but the requested cluster runtime has %s. Please "
                    "ensure that the Ray versions match for client connectivity. You may "
                    '"pip install --user --force-reinstall ray[default]==%s"'
                    " and restart runtime before cluster connection.",
                    local_ray_version,
                    cluster.ray_version,
                    install_ray_version,
                )
            else:
                logging.info(
                    "[Ray on Vertex]: Local runtime has Ray version %s. "
                    "Please ensure that the Ray versions match for client connectivity.",
                    local_ray_version,
                )
        super().__init__(address)

    def connect(self) -> _VertexRayClientContext:
        """Connect to the cluster and return a Vertex-specific client context.

        When the cluster exposes a public endpoint and no private address,
        SSL channel credentials and bearer-token metadata are set before the
        parent ``connect`` is called.

        Returns:
            A ``_VertexRayClientContext`` wrapping the Ray client context.
        """
        # Can send any other params to ray cluster here
        logging.info("[Ray on Vertex AI]: Connecting...")

        public_address = self.response.resource_runtime.access_uris.get(
            "RAY_CLIENT_ENDPOINT"
        )
        private_address = self.response.resource_runtime.access_uris.get(
            "RAY_HEAD_NODE_INTERNAL_IP"
        )
        if public_address and not private_address:
            # NOTE(review): presumably read by the parent connect() when it
            # opens the channel — confirm against ray.client_builder.
            self._credentials = grpc.ssl_channel_credentials()
            bearer_token = _validation_utils.get_bearer_token()
            self._metadata = [
                ("authorization", f"Bearer {bearer_token}"),
                ("x-goog-user-project", f"{initializer.global_config.project}"),
            ]
        ray_client_context = super().connect()
        ray_head_uris = self.response.resource_runtime.access_uris

        # Valid resource name (reference public doc for public release):
        # "projects/<project_num>/locations/<region>/persistentResources/<pr_id>"
        persistent_resource_id = self.resource_name.split("/")[5]

        return _VertexRayClientContext(
            persistent_resource_id=persistent_resource_id,
            ray_head_uris=ray_head_uris,
            ray_client_context=ray_client_context,
        )
|
||||
575
.venv/lib/python3.10/site-packages/vertex_ray/cluster_init.py
Normal file
575
.venv/lib/python3.10/site-packages/vertex_ray/cluster_init.py
Normal file
@@ -0,0 +1,575 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
import warnings
|
||||
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import utils
|
||||
from google.cloud.aiplatform.utils import resource_manager_utils
|
||||
from google.cloud.aiplatform_v1beta1.types import persistent_resource_service
|
||||
from google.cloud.aiplatform_v1beta1.types.machine_resources import NfsMount
|
||||
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
|
||||
PersistentResource,
|
||||
RayLogsSpec,
|
||||
RaySpec,
|
||||
RayMetricSpec,
|
||||
ResourcePool,
|
||||
ResourceRuntimeSpec,
|
||||
ServiceAccountSpec,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.types.service_networking import (
|
||||
PscInterfaceConfig,
|
||||
)
|
||||
from google.cloud.aiplatform.vertex_ray.util import (
|
||||
_gapic_utils,
|
||||
_validation_utils,
|
||||
resources,
|
||||
)
|
||||
|
||||
from google.protobuf import field_mask_pb2 # type: ignore
|
||||
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
|
||||
_V2_4_WARNING_MESSAGE,
|
||||
_V2_9_WARNING_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
def create_ray_cluster(
    head_node_type: Optional[resources.Resources] = resources.Resources(),
    python_version: Optional[str] = "3.10",
    ray_version: Optional[str] = "2.42",
    network: Optional[str] = None,
    service_account: Optional[str] = None,
    cluster_name: Optional[str] = None,
    worker_node_types: Optional[List[resources.Resources]] = [resources.Resources()],
    custom_images: Optional[resources.NodeImages] = None,
    enable_metrics_collection: Optional[bool] = True,
    enable_logging: Optional[bool] = True,
    psc_interface_config: Optional[resources.PscIConfig] = None,
    reserved_ip_ranges: Optional[List[str]] = None,
    nfs_mounts: Optional[List[resources.NfsMount]] = None,
    labels: Optional[Dict[str, str]] = None,
) -> str:
    """Create a ray cluster on the Vertex AI.

    Sample usage:

    from vertex_ray import Resources

    head_node_type = Resources(
        machine_type="n1-standard-8",
        node_count=1,
        accelerator_type="NVIDIA_TESLA_T4",
        accelerator_count=1,
        custom_image="us-docker.pkg.dev/my-project/ray-cpu-image.2.33:latest",  # Optional
    )

    worker_node_types = [Resources(
        machine_type="n1-standard-8",
        node_count=2,
        accelerator_type="NVIDIA_TESLA_T4",
        accelerator_count=1,
        custom_image="us-docker.pkg.dev/my-project/ray-gpu-image.2.33:latest",  # Optional
    )]

    cluster_resource_name = vertex_ray.create_ray_cluster(
        head_node_type=head_node_type,
        network="projects/my-project-number/global/networks/my-vpc-name",  # Optional
        service_account="my-service-account@my-project-number.iam.gserviceaccount.com",  # Optional
        cluster_name="my-cluster-name",  # Optional
        worker_node_types=worker_node_types,
        ray_version="2.33",
    )

    After a ray cluster is set up, you can call
    `ray.init(f"vertex_ray://{cluster_resource_name}", runtime_env=...)` without
    specifying ray cluster address to connect to the cluster. To shut down the
    cluster you can call `vertex_ray.delete_ray_cluster(cluster_resource_name)`.
    Note: If the active ray cluster has not finished shutting down, you cannot
    create a new ray cluster with the same cluster_name.

    Args:
        head_node_type: The head node resource. Resources.node_count must be 1.
            If not set (or None), default value of Resources() class will be
            used.
        python_version: Python version for the ray cluster.
        ray_version: Ray version for the ray cluster. Default is 2.42.0.
        network: Virtual private cloud (VPC) network. For Ray Client, VPC
            peering is required to connect to the Ray Cluster managed in the
            Vertex API service. For Ray Job API, VPC network is not required
            because Ray Cluster connection can be accessed through dashboard
            address.
        service_account: Service account to be used for running Ray programs on
            the cluster.
        cluster_name: This value may be up to 63 characters, and valid
            characters are `[a-z0-9_-]`. The first character cannot be a number
            or hyphen. If not set, a timestamped "ray-cluster-..." name is
            generated.
        worker_node_types: The list of Resources of the worker nodes. The same
            Resources object should not appear multiple times in the list.
        custom_images: The NodeImages which specifies head node and worker nodes
            images. All the workers will share the same image. If each Resource
            has a specific custom image, use `Resources.custom_image` for
            head/worker_node_type(s). Note that configuring `Resources.custom_image`
            will override `custom_images` here. Allowlist only.
        enable_metrics_collection: Enable Ray metrics collection for visualization.
        enable_logging: Enable exporting Ray logs to Cloud Logging.
        psc_interface_config: PSC-I config.
        reserved_ip_ranges: A list of names for the reserved IP ranges under
            the VPC network that can be used for this cluster. If set, we will
            deploy the cluster within the provided IP ranges. Otherwise, the
            cluster is deployed to any IP ranges under the provided VPC network.
            Example: ["vertex-ai-ip-range"].
        nfs_mounts: NFS mounts to attach to the cluster. Each entry supplies
            `server`, `path` and `mount_point`.
        labels:
            The labels with user-defined metadata to organize Ray cluster.

            Label keys and values can be no longer than 64 characters (Unicode
            codepoints), can only contain lowercase letters, numeric characters,
            underscores and dashes. International characters are allowed.

            See https://goo.gl/xmQnxf for more information and examples of labels.

    Returns:
        The cluster_resource_name of the initiated Ray cluster on Vertex.

    Raises:
        ValueError: If the node configuration is invalid or the cluster is not
            created successfully.
        RuntimeError: If the ray_version is 2.4.
    """
    if network is None:
        logging.info(
            "[Ray on Vertex]: No VPC network configured. It is required for client connection."
        )
    if ray_version == "2.4":
        raise RuntimeError(_V2_4_WARNING_MESSAGE)
    if ray_version == "2.9.3":
        # stacklevel=2 attributes the warning to the caller so that default
        # warning filters do not hide it as library-internal.
        warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=2)

    # The parameter is Optional: honour the documented fallback instead of
    # failing with AttributeError on the attribute accesses below.
    if head_node_type is None:
        head_node_type = resources.Resources()

    local_ray_version = _validation_utils.get_local_ray_version()
    if ray_version != local_ray_version:
        if custom_images is None and head_node_type.custom_image is None:
            install_ray_version = "2.42.0"
            logging.info(
                "[Ray on Vertex]: Local runtime has Ray version %s"
                ", but the requested cluster runtime has %s. Please "
                "ensure that the Ray versions match for client connectivity. You may "
                '"pip install --user --force-reinstall ray[default]==%s"'
                " and restart runtime before cluster connection.",
                local_ray_version,
                ray_version,
                install_ray_version,
            )
        else:
            logging.info(
                "[Ray on Vertex]: Local runtime has Ray version %s. "
                "Please ensure that the Ray versions match for client connectivity.",
                local_ray_version,
            )

    if cluster_name is None:
        cluster_name = "ray-cluster-" + utils.timestamped_unique_name()

    # Validate the head node before building any resource pool.
    if head_node_type.node_count != 1:
        raise ValueError(
            "[Ray on Vertex AI]: For head_node_type, "
            + "Resources.node_count must be 1."
        )
    if head_node_type.autoscaling_spec is not None:
        raise ValueError(
            "[Ray on Vertex AI]: For head_node_type, "
            + "Resources.autoscaling_spec must be None."
        )
    if head_node_type.accelerator_type is None and head_node_type.accelerator_count > 0:
        raise ValueError(
            "[Ray on Vertex]: accelerator_type must be specified when"
            + " accelerator_count is set to a value other than 0."
        )

    # Maps resource pool id -> container image URI used by that pool.
    resource_pool_images = {}

    # head node
    resource_pool_0 = ResourcePool()
    resource_pool_0.id = "head-node"
    resource_pool_0.replica_count = head_node_type.node_count
    resource_pool_0.machine_spec.machine_type = head_node_type.machine_type
    resource_pool_0.machine_spec.accelerator_count = head_node_type.accelerator_count
    resource_pool_0.machine_spec.accelerator_type = head_node_type.accelerator_type
    resource_pool_0.disk_spec.boot_disk_type = head_node_type.boot_disk_type
    resource_pool_0.disk_spec.boot_disk_size_gb = head_node_type.boot_disk_size_gb

    # Image precedence: per-node custom image, then the default image for the
    # requested versions, then the cluster-wide `custom_images`.
    enable_cuda = head_node_type.accelerator_count > 0
    if head_node_type.custom_image is not None:
        image_uri = head_node_type.custom_image
    elif custom_images is None:
        image_uri = _validation_utils.get_image_uri(
            ray_version, python_version, enable_cuda
        )
    elif custom_images.head is not None and custom_images.worker is not None:
        image_uri = custom_images.head
    else:
        raise ValueError(
            "[Ray on Vertex AI]: custom_images.head and custom_images.worker must be specified when custom_images is set."
        )

    resource_pool_images[resource_pool_0.id] = image_uri

    worker_pools = []
    # The pool id is derived from the position in `worker_node_types`, so the
    # index advances even for worker types that get merged into the head pool.
    for index, worker_node_type in enumerate(worker_node_types or []):
        if (
            worker_node_type.accelerator_type is None
            and worker_node_type.accelerator_count > 0
        ):
            raise ValueError(
                "[Ray on Vertex]: accelerator_type must be specified when"
                + " accelerator_count is set to a value other than 0."
            )
        additional_replica_count = resources._check_machine_spec_identical(
            head_node_type, worker_node_type
        )
        if worker_node_type.autoscaling_spec is None:
            # Worker and head share the same MachineSpec, merge them into the
            # same ResourcePool
            resource_pool_0.replica_count = (
                resource_pool_0.replica_count + additional_replica_count
            )
        elif additional_replica_count > 0:
            # Autoscaling for single ResourcePool (homogeneous cluster).
            resource_pool_0.replica_count = None
            resource_pool_0.autoscaling_spec.min_replica_count = (
                worker_node_type.autoscaling_spec.min_replica_count
            )
            resource_pool_0.autoscaling_spec.max_replica_count = (
                worker_node_type.autoscaling_spec.max_replica_count
            )

        if additional_replica_count == 0:
            # Nothing was merged into the head pool: this worker type gets its
            # own pool.
            resource_pool = ResourcePool()
            resource_pool.id = f"worker-pool{index + 1}"
            if worker_node_type.autoscaling_spec is None:
                resource_pool.replica_count = worker_node_type.node_count
            else:
                # Autoscaling for worker ResourcePool.
                resource_pool.autoscaling_spec.min_replica_count = (
                    worker_node_type.autoscaling_spec.min_replica_count
                )
                resource_pool.autoscaling_spec.max_replica_count = (
                    worker_node_type.autoscaling_spec.max_replica_count
                )
            resource_pool.machine_spec.machine_type = worker_node_type.machine_type
            resource_pool.machine_spec.accelerator_count = (
                worker_node_type.accelerator_count
            )
            resource_pool.machine_spec.accelerator_type = (
                worker_node_type.accelerator_type
            )
            resource_pool.disk_spec.boot_disk_type = worker_node_type.boot_disk_type
            resource_pool.disk_spec.boot_disk_size_gb = (
                worker_node_type.boot_disk_size_gb
            )
            worker_pools.append(resource_pool)

            enable_cuda = worker_node_type.accelerator_count > 0
            if worker_node_type.custom_image is not None:
                image_uri = worker_node_type.custom_image
            elif custom_images is None:
                image_uri = _validation_utils.get_image_uri(
                    ray_version, python_version, enable_cuda
                )
            else:
                image_uri = custom_images.worker

            resource_pool_images[resource_pool.id] = image_uri

    resource_pools = [resource_pool_0] + worker_pools

    # The API expresses these switches as "disabled" flags.
    ray_metric_spec = RayMetricSpec(disabled=not enable_metrics_collection)
    ray_logs_spec = RayLogsSpec(disabled=not enable_logging)

    ray_spec = RaySpec(
        resource_pool_images=resource_pool_images,
        ray_metric_spec=ray_metric_spec,
        ray_logs_spec=ray_logs_spec,
    )
    if nfs_mounts:
        ray_spec.nfs_mounts = [
            NfsMount(
                server=nfs_mount.server,
                path=nfs_mount.path,
                mount_point=nfs_mount.mount_point,
            )
            for nfs_mount in nfs_mounts
        ]

    if service_account:
        service_account_spec = ServiceAccountSpec(
            enable_custom_service_account=True,
            service_account=service_account,
        )
        resource_runtime_spec = ResourceRuntimeSpec(
            ray_spec=ray_spec,
            service_account_spec=service_account_spec,
        )
    else:
        resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec)

    if psc_interface_config:
        gapic_psc_interface_config = PscInterfaceConfig(
            network_attachment=psc_interface_config.network_attachment,
        )
    else:
        gapic_psc_interface_config = None

    persistent_resource = PersistentResource(
        resource_pools=resource_pools,
        network=network,
        labels=labels,
        resource_runtime_spec=resource_runtime_spec,
        psc_interface_config=gapic_psc_interface_config,
        reserved_ip_ranges=reserved_ip_ranges,
    )

    location = initializer.global_config.location
    project_id = initializer.global_config.project
    project_number = resource_manager_utils.get_project_number(project_id)

    parent = f"projects/{project_number}/locations/{location}"
    request = persistent_resource_service.CreatePersistentResourceRequest(
        parent=parent,
        persistent_resource=persistent_resource,
        persistent_resource_id=cluster_name,
    )

    client = _gapic_utils.create_persistent_resource_client()
    try:
        _ = client.create_persistent_resource(request)
    except Exception as e:
        raise ValueError("Failed in cluster creation due to: ", e) from e

    # Get persistent resource
    cluster_resource_name = f"{parent}/persistentResources/{cluster_name}"
    response = _gapic_utils.get_persistent_resource(
        persistent_resource_name=cluster_resource_name,
        tolerance=1,  # allow 1 retry to avoid get request before creation
    )
    return response.name
|
||||
|
||||
|
||||
def delete_ray_cluster(cluster_resource_name: str) -> None:
    """Delete the Ray cluster identified by `cluster_resource_name`.

    A confirmation line is printed once the delete call returns.

    Args:
        cluster_resource_name: Cluster resource name.

    Raises:
        ValueError: If the delete call fails for any reason. The underlying
            exception is chained as the cause.
    """
    resource_client = _gapic_utils.create_persistent_resource_client()
    delete_request = persistent_resource_service.DeletePersistentResourceRequest(
        name=cluster_resource_name
    )

    try:
        resource_client.delete_persistent_resource(delete_request)
        print("[Ray on Vertex AI]: Successfully deleted the cluster.")
    except Exception as error:
        # Every failure is surfaced to the caller as ValueError.
        raise ValueError(
            "[Ray on Vertex AI]: Failed in cluster deletion due to: ", error
        ) from error
|
||||
|
||||
|
||||
def get_ray_cluster(cluster_resource_name: str) -> resources.Cluster:
    """Fetch a Ray cluster and return it as a `Cluster` object.

    Args:
        cluster_resource_name: Cluster resource name.

    Returns:
        A Cluster object.

    Raises:
        ValueError: If the get call fails, or if the fetched resource cannot
            be converted into a Cluster.
    """
    resource_client = _gapic_utils.create_persistent_resource_client()
    get_request = persistent_resource_service.GetPersistentResourceRequest(
        name=cluster_resource_name
    )
    try:
        fetched_resource = resource_client.get_persistent_resource(get_request)
    except Exception as error:
        raise ValueError(
            "[Ray on Vertex AI]: Failed in getting the cluster due to: ", error
        ) from error

    ray_cluster = _gapic_utils.persistent_resource_to_cluster(
        persistent_resource=fetched_resource
    )
    # A falsy conversion result means the resource is not usable as a cluster.
    if not ray_cluster:
        raise ValueError(
            "[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)."
        )
    return ray_cluster
|
||||
|
||||
|
||||
def list_ray_clusters() -> List[resources.Cluster]:
    """List the Ray clusters in the globally configured project and location.

    Returns:
        List of Cluster objects that exists in the current authorized project.
        Resources that cannot be converted into a Cluster are left out.

    Raises:
        ValueError: If the list call fails.
    """
    global_config = initializer.global_config
    location = global_config.location
    project_number = resource_manager_utils.get_project_number(global_config.project)
    list_request = persistent_resource_service.ListPersistentResourcesRequest(
        parent=f"projects/{project_number}/locations/{location}",
    )
    resource_client = _gapic_utils.create_persistent_resource_client()
    try:
        listed_resources = resource_client.list_persistent_resources(list_request)
    except Exception as error:
        raise ValueError(
            "[Ray on Vertex AI]: Failed in listing the clusters due to: ", error
        ) from error

    # Convert lazily, then keep only the truthy conversion results.
    converted = (
        _gapic_utils.persistent_resource_to_cluster(persistent_resource=item)
        for item in listed_resources
    )
    return [ray_cluster for ray_cluster in converted if ray_cluster]
|
||||
|
||||
|
||||
def update_ray_cluster(
    cluster_resource_name: str, worker_node_types: List[resources.Resources]
) -> str:
    """Update Ray Cluster (currently support resizing node counts for worker nodes).

    Sample usage:

    my_cluster = vertex_ray.get_ray_cluster(
        cluster_resource_name=my_existing_cluster_resource_name,
    )

    # Declaration to resize all the worker_node_type to node_count=1
    new_worker_node_types = []
    for worker_node_type in my_cluster.worker_node_types:
        worker_node_type.node_count = 1
        new_worker_node_types.append(worker_node_type)

    # Execution to update new node_count (block until complete)
    vertex_ray.update_ray_cluster(
        cluster_resource_name=my_cluster.cluster_resource_name,
        worker_node_types=new_worker_node_types,
    )

    Args:
        cluster_resource_name: Resource name of the cluster to update.
        worker_node_types: The list of Resources of the resized worker nodes.
            The same Resources object should not appear multiple times in the list.

    Returns:
        The cluster_resource_name of the Ray cluster on Vertex.

    Raises:
        ValueError: If `worker_node_types` contains duplicate machine specs,
            if its length differs from the existing worker node types, or if
            the update request fails.
    """
    # worker_node_types should not be duplicated. Both orderings of each pair
    # are compared because the helper's result is taken from its second
    # argument.
    for i, first in enumerate(worker_node_types):
        for j, second in enumerate(worker_node_types):
            additional_replica_count = resources._check_machine_spec_identical(
                first, second
            )
            if additional_replica_count > 0 and i != j:
                raise ValueError(
                    "[Ray on Vertex AI]: Worker_node_types have duplicate "
                    + f"machine specs: {first} "
                    + f"and {second}"
                )

    persistent_resource = _gapic_utils.get_persistent_resource(
        persistent_resource_name=cluster_resource_name
    )

    # Edit a copy; start from a head-only pool and add merged workers below.
    current_persistent_resource = copy.deepcopy(persistent_resource)
    current_persistent_resource.resource_pools[0].replica_count = 1

    previous_ray_cluster = get_ray_cluster(cluster_resource_name)
    head_node_type = previous_ray_cluster.head_node_type
    previous_worker_node_types = previous_ray_cluster.worker_node_types

    # new worker_node_types and previous_worker_node_types should be the same length.
    if len(worker_node_types) != len(previous_worker_node_types):
        # Interpolate the counts into the message; passing them as extra
        # exception args left the "%i" placeholders unformatted.
        raise ValueError(
            "[Ray on Vertex AI]: Desired number of worker_node_types "
            f"({len(worker_node_types)}) does not match the number of the "
            f"existing worker_node_type({len(previous_worker_node_types)})."
        )

    # Merge worker_node_type and head_node_type if they share
    # the same machine spec.
    not_merged = 1
    for i, worker_node_type in enumerate(worker_node_types):
        additional_replica_count = resources._check_machine_spec_identical(
            head_node_type, worker_node_type
        )
        if additional_replica_count != 0 or worker_node_type.node_count == 0:
            # Merge the 1st duplicated worker with head, allow scale down to 0 worker
            current_persistent_resource.resource_pools[0].replica_count = (
                1 + additional_replica_count
            )
            # Reset not_merged
            not_merged = 0
        else:
            # No duplication w/ head node, write the 2nd worker node to the 2nd resource pool.
            current_persistent_resource.resource_pools[
                i + not_merged
            ].replica_count = worker_node_type.node_count
            # New worker_node_type.node_count should be >=1 unless the worker_node_type
            # and head_node_type are merged due to the same machine specs.
            # NOTE(review): this branch is only reached when node_count != 0,
            # so the check below looks unreachable as written - confirm the
            # intended placement.
            if worker_node_type.node_count == 0:
                raise ValueError(
                    "[Ray on Vertex AI]: Worker_node_type "
                    + f"({worker_node_type}) must update to >= 1 nodes",
                )

    request = persistent_resource_service.UpdatePersistentResourceRequest(
        persistent_resource=current_persistent_resource,
        update_mask=field_mask_pb2.FieldMask(paths=["resource_pools.replica_count"]),
    )
    client = _gapic_utils.create_persistent_resource_client()
    try:
        operation_future = client.update_persistent_resource(request)
    except Exception as e:
        raise ValueError(
            "[Ray on Vertex AI]: Failed in updating the cluster due to: ", e
        ) from e

    # block before returning
    start_time = time.time()
    response = operation_future.result()
    duration = (time.time() - start_time) // 60
    print(
        "[Ray on Vertex AI]: Successfully updated the cluster ({} minutes elapsed).".format(
            duration
        )
    )
    return response.name
|
||||
@@ -0,0 +1,79 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""Utility to interact with Ray-on-Vertex dashboard."""
|
||||
|
||||
from ray.dashboard.modules import dashboard_sdk as oss_dashboard_sdk
|
||||
|
||||
from .util import _gapic_utils
|
||||
from .util import _validation_utils
|
||||
|
||||
|
||||
def get_job_submission_client_cluster_info(
    address: str, *args, **kwargs
) -> oss_dashboard_sdk.ClusterInfo:
    """Resolve the dashboard address and delegate to Ray's cluster-info helper.

    Implements
    https://github.com/ray-project/ray/blob/ray-2.42.0/python/ray/dashboard/modules/dashboard_sdk.py#L84
    This will be called in from Ray Job API Python client.

    Args:
        address: Address without the module prefix `vertex_ray` but otherwise
            the same format as passed to ray.init(address="vertex_ray://...").
            It is used directly when it is a valid dashboard address;
            otherwise it is treated as a cluster resource name and the
            dashboard address is looked up from that resource.
        *args: Remainder of positional args that might be passed down from
            the framework.
        **kwargs: Remainder of keyword args that might be passed down from
            the framework. If no `headers` entry is present, JSON
            content-type and bearer-token headers are added.

    Returns:
        An instance of ClusterInfo that contains address, cookies and
        metadata for SubmissionClient to use.

    Raises:
        RuntimeError: If no dashboard address could be determined.
    """
    if _validation_utils.valid_dashboard_address(address):
        dashboard_address = address
    else:
        resource_name = _validation_utils.maybe_reconstruct_resource_name(address)
        _validation_utils.valid_resource_name(resource_name)

        persistent_resource = _gapic_utils.get_persistent_resource(resource_name)
        access_uris = persistent_resource.resource_runtime.access_uris
        dashboard_address = access_uris.get("RAY_DASHBOARD_URI")

    if dashboard_address is None:
        raise RuntimeError(
            "[Ray on Vertex AI]: Unable to obtain a response from the backend."
        )

    # If passing the dashboard uri, programmatically get headers
    token = _validation_utils.get_bearer_token()
    if kwargs.get("headers") is None:
        kwargs["headers"] = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        }

    return oss_dashboard_sdk.get_job_submission_client_cluster_info(
        *args,
        address=dashboard_address,
        _use_tls=True,
        **kwargs,
    )
|
||||
192
.venv/lib/python3.10/site-packages/vertex_ray/data.py
Normal file
192
.venv/lib/python3.10/site-packages/vertex_ray/data.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import warnings
|
||||
import ray.data
|
||||
from ray.data.dataset import Dataset
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from google.cloud.aiplatform.vertex_ray.bigquery_datasource import (
|
||||
_BigQueryDatasource,
|
||||
)
|
||||
|
||||
try:
|
||||
from google.cloud.aiplatform.vertex_ray.bigquery_datasink import (
|
||||
_BigQueryDatasink,
|
||||
)
|
||||
except ImportError:
|
||||
_BigQueryDatasink = None
|
||||
|
||||
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
|
||||
_V2_4_WARNING_MESSAGE,
|
||||
_V2_9_WARNING_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
def read_bigquery(
    project_id: Optional[str] = None,
    dataset: Optional[str] = None,
    query: Optional[str] = None,
    *,
    parallelism: int = -1,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    concurrency: Optional[int] = None,
    override_num_blocks: Optional[int] = None,
) -> Dataset:
    """Create a dataset from BigQuery.

    The data to read from is specified via the ``project_id``, ``dataset``
    and/or ``query`` parameters.

    Args:
        project_id: The name of the associated Google Cloud Project that hosts
            the dataset to read.
        dataset: The name of the dataset hosted in BigQuery in the format of
            ``dataset_id.table_id``. Both the dataset_id and table_id must exist
            otherwise an exception will be raised.
        query: The query to execute. The dataset is created from the results of
            executing the query if provided. Otherwise, the entire dataset is read.
            For query syntax guidelines, see
            https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
        parallelism: 2.33.0, 2.42.0: This argument is deprecated. Use
            ``override_num_blocks`` argument. 2.9.3: The requested parallelism of
            the read. If -1, it will be automatically chosen based on the available
            cluster resources and estimated in-memory data size.
        ray_remote_args: kwargs passed to ray.remote in the read tasks.
        concurrency: Supported for 2.33.0 and 2.42.0 only: The maximum number of
            Ray tasks to run concurrently. Set this to control number of tasks to
            run concurrently. This doesn't change the total number of tasks run or
            the total number of output blocks. By default, concurrency is
            dynamically decided based on the available resources.
        override_num_blocks: Supported for 2.33.0 and 2.42.0 only: Override the
            number of output blocks from all read tasks. By default, the number of
            output blocks is dynamically decided based on input data size and
            available resources. You shouldn't manually set this value in most
            cases.

    Returns:
        Dataset producing rows from the results of executing the query
        or reading the entire dataset on the specified BigQuery dataset.

    Raises:
        ImportError: If the installed Ray version is not 2.9.3, 2.33.0 or
            2.42.0.
    """
    datasource = _BigQueryDatasource(
        project_id=project_id,
        dataset=dataset,
        query=query,
    )

    if ray.__version__ == "2.9.3":
        # stacklevel=2 attributes the warning to the caller so that default
        # warning filters do not hide it as library-internal.
        warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=2)
        # Concurrency and override_num_blocks are not supported in 2.9.3
        return ray.data.read_datasource(
            datasource=datasource,
            parallelism=parallelism,
            ray_remote_args=ray_remote_args,
        )

    if ray.__version__ in ("2.33.0", "2.42.0"):
        return ray.data.read_datasource(
            datasource=datasource,
            parallelism=parallelism,
            ray_remote_args=ray_remote_args,
            concurrency=concurrency,
            override_num_blocks=override_num_blocks,
        )

    raise ImportError(
        f"[Ray on Vertex AI]: Unsupported version {ray.__version__}. "
        + "Only 2.42.0, 2.33.0, and 2.9.3 are supported."
    )
|
||||
|
||||
|
||||
def write_bigquery(
    ds: Dataset,
    project_id: Optional[str] = None,
    dataset: Optional[str] = None,
    max_retry_cnt: int = 10,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    overwrite_table: Optional[bool] = True,
    concurrency: Optional[int] = None,
) -> Any:
    """Write the dataset to a BigQuery dataset table.

    Args:
        ds: The dataset to write.
        project_id: The name of the associated Google Cloud Project that hosts
            the dataset table to write to.
        dataset: The name of the dataset table hosted in BigQuery in the format of
            ``dataset_id.table_id``.
            The dataset table is created if it doesn't already exist.
            In 2.9.3, the table_id is overwritten if it exists.
        max_retry_cnt: The maximum number of retries that an individual block write
            is retried due to BigQuery rate limiting errors.
            The default number of retries is 10.
        ray_remote_args: kwargs passed to ray.remote in the write tasks. The
            caller's dict is not modified; ``max_retries=0`` is added to a
            copy when the key is absent or already 0.
        overwrite_table: Not supported in 2.9.3.
            2.33.0, 2.42.0: Whether the write will overwrite the table if it already
            exists. The default behavior is to overwrite the table.
            If false, will append to the table if it exists.
        concurrency: Not supported in 2.9.3.
            2.33.0, 2.42.0: The maximum number of Ray tasks to run concurrently. Set this
            to control number of tasks to run concurrently. This doesn't change the
            total number of tasks run or the total number of output blocks. By default,
            concurrency is dynamically decided based on the available resources.

    Returns:
        Whatever ``ds.write_datasink`` returns.

    Raises:
        RuntimeError: If the installed Ray version is 2.4.0.
        ImportError: If the installed Ray version is not 2.9.3, 2.33.0 or
            2.42.0.
    """
    if ray.__version__ == "2.4.0":
        raise RuntimeError(_V2_4_WARNING_MESSAGE)

    if ray.__version__ not in ("2.9.3", "2.33.0", "2.42.0"):
        raise ImportError(
            f"[Ray on Vertex AI]: Unsupported version {ray.__version__}. "
            + "Only 2.42.0, 2.33.0 and 2.9.3 are supported."
        )

    is_ray_2_9 = ray.__version__ == "2.9.3"
    if is_ray_2_9:
        # stacklevel=2 attributes the warning to the caller so that default
        # warning filters do not hide it as library-internal.
        warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=2)

    # Work on a copy so the caller's dict is never mutated.
    ray_remote_args = {} if ray_remote_args is None else dict(ray_remote_args)

    # Each write task will launch individual remote tasks to write each block
    # To avoid duplicate block writes, the write task should not be retried
    if ray_remote_args.get("max_retries", 0) != 0:
        print(
            "[Ray on Vertex AI]: The max_retries of a BigQuery Write "
            "Task should be set to 0 to avoid duplicate writes."
        )
    else:
        ray_remote_args["max_retries"] = 0

    if is_ray_2_9:
        # Concurrency and overwrite_table are not supported in 2.9.3
        datasink = _BigQueryDatasink(
            project_id=project_id,
            dataset=dataset,
            max_retry_cnt=max_retry_cnt,
        )
        return ds.write_datasink(
            datasink=datasink,
            ray_remote_args=ray_remote_args,
        )

    datasink = _BigQueryDatasink(
        project_id=project_id,
        dataset=dataset,
        max_retry_cnt=max_retry_cnt,
        overwrite_table=overwrite_table,
    )
    return ds.write_datasink(
        datasink=datasink,
        ray_remote_args=ray_remote_args,
        concurrency=concurrency,
    )
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Ray on Vertex AI Prediction."""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
Binary file not shown.
@@ -0,0 +1,22 @@
|
||||
"""Ray on Vertex AI Prediction Sklearn."""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from .register import register_sklearn
|
||||
|
||||
__all__ = ("register_sklearn",)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,173 @@
|
||||
"""Register Scikit Learn for Ray on Vertex AI."""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import warnings
|
||||
import ray
|
||||
import ray.cloudpickle as cpickle
|
||||
import tempfile
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from google.cloud import aiplatform
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import utils
|
||||
from google.cloud.aiplatform.utils import gcs_utils
|
||||
from google.cloud.aiplatform.vertex_ray.predict.util import constants
|
||||
from google.cloud.aiplatform.vertex_ray.predict.util import (
|
||||
predict_utils,
|
||||
)
|
||||
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
|
||||
_V2_4_WARNING_MESSAGE,
|
||||
_V2_9_WARNING_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
try:
    from ray.train import sklearn as ray_sklearn

    if TYPE_CHECKING:
        import sklearn

except ImportError as ie:
    # Compare the (major, minor) release numerically. A plain string
    # comparison orders "2.9.3" after "2.42.0", which would silently swallow
    # the import failure on older Ray releases instead of reporting it.
    if tuple(int(part) for part in ray.__version__.split(".")[:2]) < (2, 42):
        raise ModuleNotFoundError("Sklearn isn't installed.") from ie
    else:
        sklearn = None
|
||||
|
||||
|
||||
def register_sklearn(
    checkpoint: "ray_sklearn.SklearnCheckpoint",
    artifact_uri: Optional[str] = None,
    display_name: Optional[str] = None,
    **kwargs,
) -> aiplatform.Model:
    """Uploads a Ray Sklearn Checkpoint as Sklearn Model to Model Registry.

    Example usage:
        from vertex_ray.predict import sklearn
        from ray.train.sklearn import SklearnCheckpoint

        trainer = SklearnTrainer(estimator=RandomForestClassifier, ...)
        result = trainer.fit()
        sklearn_checkpoint = SklearnCheckpoint.from_checkpoint(result.checkpoint)

        my_model = sklearn.register_sklearn(
            checkpoint=sklearn_checkpoint,
            artifact_uri="gs://{gcs-bucket-name}/path/to/store"
        )

    Args:
        checkpoint: SklearnCheckpoint instance.
        artifact_uri (str):
            Optional. The path to the directory where Model Artifacts will be
            saved. If not set, will use staging bucket set in aiplatform.init().
        display_name (str):
            Optional. The display name of the Model. The name can be up to 128
            characters long and can consist of any UTF-8 characters.
        **kwargs:
            Any kwargs will be passed to aiplatform.Model registration.

    Returns:
        model (aiplatform.Model):
            Instantiated representation of the uploaded model resource.

    Raises:
        ValueError: `artifact_uri` (or the staging bucket) is not a gs:// URI.
        RuntimeError: Only Ray version 2.9.3 is supported.
    """
    ray_version = ray.__version__
    if ray_version != "2.9.3":
        raise RuntimeError(
            f"Ray version {ray_version} is not supported to upload Sklearn"
            " model to Vertex Model Registry yet. Please use Ray 2.9.3."
        )
    # Only 2.9.3 reaches this point, so the deprecation warning always applies.
    warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)

    artifact_uri = artifact_uri or initializer.global_config.staging_bucket
    predict_utils.validate_artifact_uri(artifact_uri)
    display_model_name = (
        (f"ray-on-vertex-registered-sklearn-model-{utils.timestamped_unique_name()}")
        if display_name is None
        else display_name
    )
    estimator = _get_estimator_from(checkpoint)

    model_dir = os.path.join(artifact_uri, display_model_name)
    file_path = os.path.join(model_dir, constants._PICKLE_FILE_NAME)

    with tempfile.NamedTemporaryFile(suffix=constants._PICKLE_EXTENTION) as temp_file:
        pickle.dump(estimator, temp_file)
        # The file is handed over by name below, so buffered bytes must reach
        # disk first; otherwise a truncated or empty pickle could be uploaded.
        temp_file.flush()
        gcs_utils.upload_to_gcs(temp_file.name, file_path)
        return aiplatform.Model.upload_scikit_learn_model_file(
            model_file_path=temp_file.name, display_name=display_model_name, **kwargs
        )
|
||||
|
||||
|
||||
def _get_estimator_from(
    checkpoint: "ray_sklearn.SklearnCheckpoint",
) -> "sklearn.base.BaseEstimator":
    """Converts a SklearnCheckpoint to sklearn estimator.

    Args:
        checkpoint: SklearnCheckpoint instance.

    Returns:
        A Sklearn BaseEstimator

    Raises:
        RuntimeError: Model not found.
        RuntimeError: Ray version 2.4 is not supported.
        RuntimeError: Only Ray version 2.9.3 is supported.
    """

    ray_version = ray.__version__
    if ray_version == "2.4.0":
        raise RuntimeError(_V2_4_WARNING_MESSAGE)
    if ray_version != "2.9.3":
        raise RuntimeError(
            f"Ray version {ray_version} is not supported to convert a Sklearn"
            " checkpoint to sklearn estimator on Vertex yet. Please use Ray 2.9.3."
        )

    try:
        return checkpoint.get_model()
    except AttributeError:
        # The checkpoint has no get_model(); fall back to reading the model file.
        model_file_name = ray.train.sklearn.SklearnCheckpoint.MODEL_FILENAME

    model_path = os.path.join(checkpoint.path, model_file_name)

    # NOTE(review): unpickling executes arbitrary code embedded in the file;
    # only load checkpoints that come from a trusted source.
    if os.path.exists(model_path):
        with open(model_path, mode="rb") as f:
            obj = pickle.load(f)
    else:
        try:
            # Download from GCS to temp and then load_model
            with tempfile.TemporaryDirectory() as temp_dir:
                gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
                with open(f"{temp_dir}/{model_file_name}", mode="rb") as f:
                    obj = cpickle.load(f)
        except Exception as e:
            # Chain explicitly so the original failure stays attached as the cause.
            raise RuntimeError(
                f"{model_file_name} not found in this checkpoint due to: {e}."
            ) from e
    return obj
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Ray on Vertex AI Prediction Tensorflow."""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from .register import register_tensorflow
|
||||
|
||||
__all__ = ("register_tensorflow",)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,161 @@
|
||||
"""Register Tensorflow for Ray on Vertex AI."""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import logging
|
||||
import ray
|
||||
from typing import Callable, Optional, Union, TYPE_CHECKING
|
||||
import warnings
|
||||
from google.cloud import aiplatform
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import utils
|
||||
from google.cloud.aiplatform.vertex_ray.predict.util import constants
|
||||
from google.cloud.aiplatform.vertex_ray.predict.util import (
|
||||
predict_utils,
|
||||
)
|
||||
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
|
||||
_V2_4_WARNING_MESSAGE,
|
||||
_V2_9_WARNING_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from ray.train import tensorflow as ray_tensorflow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import tensorflow as tf
|
||||
|
||||
except ModuleNotFoundError as mnfe:
|
||||
raise ModuleNotFoundError("Tensorflow isn't installed.") from mnfe
|
||||
|
||||
|
||||
def register_tensorflow(
    checkpoint: ray_tensorflow.TensorflowCheckpoint,
    artifact_uri: Optional[str] = None,
    _model: Optional[Union["tf.keras.Model", Callable[[], "tf.keras.Model"]]] = None,
    display_name: Optional[str] = None,
    tensorflow_version: Optional[str] = None,
    **kwargs,
) -> aiplatform.Model:
    """Uploads a Ray Tensorflow Checkpoint as Tensorflow Model to Model Registry.

    Example usage:
        from vertex_ray.predict import tensorflow

        def create_model():
            model = tf.keras.Sequential(...)
            ...
            return model

        result = trainer.fit()
        my_model = tensorflow.register_tensorflow(
            checkpoint=result.checkpoint,
            _model=create_model,
            artifact_uri="gs://{gcs-bucket-name}/path/to/store",
            use_gpu=True
        )

        1. `use_gpu` will be passed to aiplatform.Model.upload_tensorflow_saved_model()
        2. The `create_model` provides the model_definition which is required if
            you create the TensorflowCheckpoint using `from_model` method.
            More here, https://docs.ray.io/en/latest/train/api/doc/ray.train.tensorflow.TensorflowCheckpoint.get_model.html#ray.train.tensorflow.TensorflowCheckpoint.get_model

    Args:
        checkpoint: TensorflowCheckpoint instance.
        artifact_uri (str):
            Optional. The path to the directory where Model Artifacts will be
            saved. If not set, will use staging bucket set in aiplatform.init().
        _model: Tensorflow Model Definition. Refer
            https://docs.ray.io/en/latest/train/api/doc/ray.train.tensorflow.TensorflowCheckpoint.get_model.html#ray.train.tensorflow.TensorflowCheckpoint.get_model
        display_name (str):
            Optional. The display name of the Model. The name can be up to 128
            characters long and can consist of any UTF-8 characters.
        tensorflow_version (str):
            Optional. The version of the Tensorflow serving container.
            Supported versions:
            https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
            If the version is not specified, the package default
            (constants._TENSORFLOW_VERSION) is used.
        **kwargs:
            Any kwargs will be passed to aiplatform.Model registration.

    Returns:
        model (aiplatform.Model):
            Instantiated representation of the uploaded model resource.

    Raises:
        ValueError: Invalid Argument.
    """
    if ray.__version__ == "2.9.3":
        warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)

    serving_version = (
        constants._TENSORFLOW_VERSION
        if tensorflow_version is None
        else tensorflow_version
    )

    artifact_uri = artifact_uri or initializer.global_config.staging_bucket
    predict_utils.validate_artifact_uri(artifact_uri)

    name_prefix = "ray-on-vertex-registered-tensorflow-model"
    if display_name is None:
        display_model_name = f"{name_prefix}-{utils.timestamped_unique_name()}"
    else:
        display_model_name = display_name

    loaded_model = _get_tensorflow_model_from(checkpoint, model=_model)
    saved_model_dir = os.path.join(artifact_uri, name_prefix)

    try:
        import tensorflow as tf

        tf.saved_model.save(loaded_model, saved_model_dir)
    except ImportError:
        # Best effort: the upload below is still attempted without TensorFlow.
        logging.warning("TensorFlow must be installed to save the trained model.")

    return aiplatform.Model.upload_tensorflow_saved_model(
        saved_model_dir=saved_model_dir,
        display_name=display_model_name,
        tensorflow_version=serving_version,
        **kwargs,
    )
|
||||
|
||||
|
||||
def _get_tensorflow_model_from(
    checkpoint: ray_tensorflow.TensorflowCheckpoint,
    model: Optional[Union["tf.keras.Model", Callable[[], "tf.keras.Model"]]] = None,
) -> "tf.keras.Model":
    """Converts a TensorflowCheckpoint to Tensorflow Model.

    Args:
        checkpoint: TensorflowCheckpoint instance.
        model: Tensorflow Model Definition. Not read by this function.

    Returns:
        A Tensorflow Native Framework Model, or None (after logging a warning)
        when TensorFlow cannot be imported.

    Raises:
        RuntimeError: Ray version 2.4.0 is not supported.
    """
    if ray.__version__ == "2.4.0":
        raise RuntimeError(_V2_4_WARNING_MESSAGE)

    try:
        import tensorflow as tf

        try:
            # First treat the checkpoint path as directly loadable.
            loaded = tf.saved_model.load(checkpoint.path)
        except OSError:
            # Otherwise retry with the same path as a GCS URI.
            loaded = tf.saved_model.load("gs://" + checkpoint.path)
        return loaded
    except ImportError:
        logging.warning("TensorFlow must be installed to load the trained model.")
        return None
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Ray on Vertex AI Prediction Torch."""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from .register import get_pytorch_model_from
|
||||
|
||||
__all__ = ("get_pytorch_model_from",)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,112 @@
|
||||
"""Register Torch for Ray on Vertex AI."""
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
import ray
|
||||
from ray.air._internal.torch_utils import load_torch_model
|
||||
import tempfile
|
||||
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
|
||||
_V2_4_WARNING_MESSAGE,
|
||||
_V2_9_WARNING_MESSAGE,
|
||||
)
|
||||
from google.cloud.aiplatform.utils import gcs_utils
|
||||
from typing import Optional
|
||||
|
||||
|
||||
try:
|
||||
from ray.train import torch as ray_torch
|
||||
import torch
|
||||
except ModuleNotFoundError as mnfe:
|
||||
raise ModuleNotFoundError("Torch isn't installed.") from mnfe
|
||||
|
||||
|
||||
def get_pytorch_model_from(
    checkpoint: ray_torch.TorchCheckpoint,
    model: Optional[torch.nn.Module] = None,
) -> torch.nn.Module:
    """Converts a TorchCheckpoint to Pytorch Model.

    Example:
        from vertex_ray.predict import torch
        result = TorchTrainer.fit(...)

        pytorch_model = torch.get_pytorch_model_from(
            checkpoint=result.checkpoint
        )

    Args:
        checkpoint: TorchCheckpoint instance.
        model: If the checkpoint contains a model state dict, and not the model
            itself, then the state dict will be loaded to this `model`.
            Otherwise, the model will be discarded.

    Returns:
        A Pytorch Native Framework Model.

    Raises:
        ModuleNotFoundError: PyTorch isn't installed.
        RuntimeError: Model not found.
        RuntimeError: Ray version 2.4 is not supported.
        RuntimeError: Only Ray version 2.9.3 is supported.
    """
    ray_version = ray.__version__
    if ray_version == "2.4.0":
        raise RuntimeError(_V2_4_WARNING_MESSAGE)
    if ray_version != "2.9.3":
        raise RuntimeError(
            f"Ray on Vertex does not support Ray version {ray_version} to"
            " convert PyTorch model artifacts yet. Please use Ray 2.9.3."
        )
    # Only 2.9.3 reaches this point, so the deprecation warning always applies.
    warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)

    try:
        return checkpoint.get_model()
    except AttributeError:
        # The checkpoint has no get_model(); fall back to reading the model file.
        model_file_name = ray.train.torch.TorchCheckpoint.MODEL_FILENAME

    model_path = os.path.join(checkpoint.path, model_file_name)

    try:
        import torch

    except ModuleNotFoundError as mnfe:
        raise ModuleNotFoundError("PyTorch isn't installed.") from mnfe

    if os.path.exists(model_path):
        model_or_state_dict = torch.load(
            model_path, map_location="cpu", weights_only=True
        )
    else:
        try:
            # Download from GCS to temp and then load_model
            with tempfile.TemporaryDirectory() as temp_dir:
                gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
                model_or_state_dict = torch.load(
                    f"{temp_dir}/{model_file_name}",
                    map_location="cpu",
                    weights_only=True,
                )
        except Exception as e:
            # Chain explicitly so the original failure stays attached as the cause.
            raise RuntimeError(
                f"{model_file_name} not found in this checkpoint due to: {e}."
            ) from e

    model = load_torch_model(saved_model=model_or_state_dict, model_definition=model)
    return model
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Constants."""
|
||||
|
||||
# Required Names for model files are specified here
|
||||
# https://cloud.google.com/vertex-ai/docs/training/exporting-model-artifacts#framework-specific_requirements
|
||||
_PICKLE_FILE_NAME = "model.pkl"
|
||||
_PICKLE_EXTENTION = ".pkl"
|
||||
|
||||
_XGBOOST_VERSION = "1.6"
|
||||
# TensorFlow 2.13 requires typing_extensions<4.6 and will cause errors in Ray.
|
||||
# https://github.com/tensorflow/tensorflow/blob/v2.13.0/tensorflow/tools/pip_package/setup.py#L100
|
||||
# 2.13 is the latest supported version of Vertex prebuilt prediction container.
|
||||
# Set 2.12 as the default here since 2.13 causes errors.
|
||||
_TENSORFLOW_VERSION = "2.12"
|
||||
@@ -0,0 +1,24 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""Predict Utils.
|
||||
"""
|
||||
|
||||
|
||||
def validate_artifact_uri(artifact_uri: str) -> None:
    """Ensure `artifact_uri` is set and uses the ``gs://`` scheme.

    Args:
        artifact_uri: Location to check.

    Raises:
        ValueError: If `artifact_uri` is None or does not start with "gs://".
    """
    has_gcs_scheme = artifact_uri is not None and artifact_uri.startswith("gs://")
    if not has_gcs_scheme:
        raise ValueError("Argument 'artifact_uri' should start with 'gs://'.")
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Ray on Vertex AI Prediction XGBoost."""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from .register import register_xgboost
|
||||
|
||||
__all__ = ("register_xgboost",)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,189 @@
|
||||
"""Register XGBoost for Ray on Vertex AI."""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import ray
|
||||
import tempfile
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
import warnings
|
||||
from google.cloud import aiplatform
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform import utils
|
||||
from google.cloud.aiplatform.utils import gcs_utils
|
||||
from google.cloud.aiplatform.vertex_ray.predict.util import constants
|
||||
from google.cloud.aiplatform.vertex_ray.predict.util import (
|
||||
predict_utils,
|
||||
)
|
||||
from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
|
||||
_V2_4_WARNING_MESSAGE,
|
||||
_V2_9_WARNING_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from ray.train import xgboost as ray_xgboost
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgboost
|
||||
|
||||
except ModuleNotFoundError as mnfe:
|
||||
if ray.__version__ == "2.9.3":
|
||||
raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
|
||||
else:
|
||||
xgboost = None
|
||||
|
||||
|
||||
def register_xgboost(
    checkpoint: "ray_xgboost.XGBoostCheckpoint",
    artifact_uri: Optional[str] = None,
    display_name: Optional[str] = None,
    xgboost_version: Optional[str] = None,
    **kwargs,
) -> aiplatform.Model:
    """Uploads a Ray XGBoost Checkpoint as XGBoost Model to Model Registry.

    Example usage:
        from vertex_ray.predict import xgboost
        from ray.train.xgboost import XGBoostCheckpoint

        trainer = XGBoostTrainer(...)
        result = trainer.fit()
        xgboost_checkpoint = XGBoostCheckpoint.from_checkpoint(result.checkpoint)

        my_model = xgboost.register_xgboost(
            checkpoint=xgboost_checkpoint,
            artifact_uri="gs://{gcs-bucket-name}/path/to/store",
            display_name="my-ray-on-vertex-xgboost-model",
        )

    Args:
        checkpoint: XGBoostCheckpoint instance.
        artifact_uri (str):
            Optional. The path to the directory where Model Artifacts will be
            saved. If not set, will use staging bucket set in aiplatform.init().
        display_name (str):
            Optional. The display name of the Model. The name can be up to 128
            characters long and can consist of any UTF-8 characters.
        xgboost_version (str): Optional. The version of the XGBoost serving container.
            Supported versions:
            https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
            If the version is not specified, the package default
            (constants._XGBOOST_VERSION) is used.
        **kwargs:
            Any kwargs will be passed to aiplatform.Model registration.

    Returns:
        model (aiplatform.Model):
            Instantiated representation of the uploaded model resource.

    Raises:
        ValueError: `artifact_uri` (or the staging bucket) is not a gs:// URI.
        RuntimeError: Only Ray version 2.9.3 is supported.
    """
    ray_version = ray.__version__
    if ray_version != "2.9.3":
        raise RuntimeError(
            f"Ray version {ray_version} is not supported to upload XGBoost"
            " model to Vertex Model Registry yet. Please use Ray 2.9.3."
        )
    # Only 2.9.3 reaches this point, so the deprecation warning always applies.
    warnings.warn(_V2_9_WARNING_MESSAGE, DeprecationWarning, stacklevel=1)
    artifact_uri = artifact_uri or initializer.global_config.staging_bucket
    predict_utils.validate_artifact_uri(artifact_uri)
    display_model_name = (
        (f"ray-on-vertex-registered-xgboost-model-{utils.timestamped_unique_name()}")
        if display_name is None
        else display_name
    )
    model = _get_xgboost_model_from(checkpoint)

    model_dir = os.path.join(artifact_uri, display_model_name)
    file_path = os.path.join(model_dir, constants._PICKLE_FILE_NAME)
    if xgboost_version is None:
        xgboost_version = constants._XGBOOST_VERSION

    with tempfile.NamedTemporaryFile(suffix=constants._PICKLE_EXTENTION) as temp_file:
        pickle.dump(model, temp_file)
        # The file is handed over by name below, so buffered bytes must reach
        # disk first; otherwise a truncated or empty pickle could be uploaded.
        temp_file.flush()
        gcs_utils.upload_to_gcs(temp_file.name, file_path)
        return aiplatform.Model.upload_xgboost_model_file(
            model_file_path=temp_file.name,
            display_name=display_model_name,
            xgboost_version=xgboost_version,
            **kwargs,
        )
|
||||
|
||||
|
||||
def _get_xgboost_model_from(
    checkpoint: "ray_xgboost.XGBoostCheckpoint",
) -> "xgboost.Booster":
    """Converts a XGBoostCheckpoint to XGBoost model.

    Args:
        checkpoint: XGBoostCheckpoint instance.

    Returns:
        A XGBoost core Booster

    Raises:
        ModuleNotFoundError: XGBoost isn't installed.
        RuntimeError: Model not found.
        RuntimeError: Ray version 2.4 is not supported.
        RuntimeError: Only Ray version 2.9.3 is supported.
    """
    ray_version = ray.__version__
    if ray_version == "2.4.0":
        raise RuntimeError(_V2_4_WARNING_MESSAGE)
    if ray_version != "2.9.3":
        raise RuntimeError(
            f"Ray version {ray_version} is not supported to convert a XGBoost"
            " checkpoint to XGBoost model on Vertex yet. Please use Ray 2.9.3."
        )

    try:
        # This works for Ray v2.5
        return checkpoint.get_model()
    except AttributeError:
        # This works for Ray v2.9
        model_file_name = ray.train.xgboost.XGBoostCheckpoint.MODEL_FILENAME

    model_path = os.path.join(checkpoint.path, model_file_name)

    try:
        import xgboost

    except ModuleNotFoundError as mnfe:
        raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe

    booster = xgboost.Booster()
    if os.path.exists(model_path):
        booster.load_model(model_path)
        return booster

    try:
        # Download from GCS to temp and then load_model
        with tempfile.TemporaryDirectory() as temp_dir:
            gcs_utils.download_from_gcs("gs://" + checkpoint.path, temp_dir)
            booster.load_model(f"{temp_dir}/{model_file_name}")
            return booster
    except Exception as e:
        # Chain explicitly so the original failure stays attached as the cause.
        raise RuntimeError(
            f"{model_file_name} not found in this checkpoint due to: {e}."
        ) from e
|
||||
27
.venv/lib/python3.10/site-packages/vertex_ray/render.py
Normal file
27
.venv/lib/python3.10/site-packages/vertex_ray/render.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pathlib
|
||||
from ray.widgets import Template
|
||||
|
||||
|
||||
class VertexRayTemplate(Template):
    """HTML template whose text is read from this package's templates folder."""

    def __init__(self, file: str):
        # Resolve `file` against the "templates" directory beside this module.
        template_path = pathlib.Path(__file__).parent / "templates" / file
        self.template = template_path.read_text()
|
||||
@@ -0,0 +1,4 @@
|
||||
<tr>
|
||||
<td style="text-align: left"><b>Interactive Terminal Uri:</b></td>
|
||||
<td style="text-align: left"><b><a href="https://{{ shell_uri }}" target="_blank">{{ shell_uri }}</a></b></td>
|
||||
</tr>
|
||||
@@ -0,0 +1,24 @@
|
||||
|
||||
<table class="jp-RenderedHTMLCommon" style="border-collapse: collapse;color: var(--jp-ui-font-color1);font-size: var(--jp-ui-font-size1);">
|
||||
<tr>
|
||||
<td style="text-align: left"><b>Python version:</b></td>
|
||||
<td style="text-align: left"><b>{{ python_version }}</b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: left"><b>Ray version:</b></td>
|
||||
<td style="text-align: left"><b> {{ ray_version }}</b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: left"><b>Vertex SDK version:</b></td>
|
||||
<td style="text-align: left"><b> {{ vertex_sdk_version }}</b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: left"><b>Dashboard:</b></td>
|
||||
<td style="text-align: left"><b><a href="https://{{ dashboard_url }}" target="_blank">{{ dashboard_url }}</a></b></td>
|
||||
</tr>
|
||||
{{ shell_uri_row }}
|
||||
<tr>
|
||||
<td style="text-align: left"><b>Cluster Name:</b></td>
|
||||
<td style="text-align: left"><b> {{ persistent_resource_id }}</b></td>
|
||||
</tr>
|
||||
</table>
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,288 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from google.api_core import exceptions
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform.utils import (
|
||||
PersistentResourceClientWithOverride,
|
||||
)
|
||||
from google.cloud.aiplatform.vertex_ray.util import _validation_utils
|
||||
from google.cloud.aiplatform.vertex_ray.util.resources import (
|
||||
AutoscalingSpec,
|
||||
Cluster,
|
||||
PscIConfig,
|
||||
Resources,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
|
||||
PersistentResource,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import (
|
||||
GetPersistentResourceRequest,
|
||||
)
|
||||
|
||||
|
||||
# Image-URI fragment that persistent_resource_to_cluster treats as an outdated
# image when the version cannot be parsed from the URI.
_PRIVATE_PREVIEW_IMAGE = "-docker.pkg.dev/vertex-ai/training/tf-"
# Image-URI fragment that persistent_resource_to_cluster treats as an official
# (i.e. non-custom) image.
_OFFICIAL_IMAGE = "-docker.pkg.dev/vertex-ai/training/ray-"
|
||||
|
||||
|
||||
def create_persistent_resource_client():
    """Build a persistent-resource client pinned to the ``v1beta1`` API.

    Returns:
        The result of ``select_version("v1beta1")`` on the client created by
        ``initializer.global_config.create_client``.
    """
    # The location is inherited from the global configuration set by
    # aiplatform.init().
    client_with_override = initializer.global_config.create_client(
        client_class=PersistentResourceClientWithOverride,
        appended_gapic_version="vertex_ray",
    )
    return client_with_override.select_version("v1beta1")
|
||||
|
||||
|
||||
def polling_delay(num_attempts: int, time_scale: float) -> datetime.timedelta:
    """Compute how long to wait before polling the Vertex service again.

    The first wait is ``max(time_scale, 30)`` seconds. Every further attempt
    scales that value by another factor of 0.765, so polling becomes more
    frequent, until the sixth attempt; after that the wait stays constant.

    Args:
        num_attempts: Number of polls already made that did not yield the
            desired result.
        time_scale: The initial (longest) polling interval in seconds. Values
            below 30 are raised to 30.

    Returns:
        The recommended delay as a ``datetime.timedelta``.
    """
    floor_seconds = 30.0
    # The decay stops compounding once six attempts have been made.
    decay_steps = min(num_attempts, 6)
    base_seconds = max(time_scale, floor_seconds)
    return datetime.timedelta(seconds=base_seconds * 0.765 ** decay_steps)
|
||||
|
||||
|
||||
def get_persistent_resource(
    persistent_resource_name: str, tolerance: Optional[int] = 0
):
    """Poll the service until the persistent resource reaches RUNNING.

    Args:
        persistent_resource_name:
            "projects/<project_num>/locations/<region>/persistentResources/<pr_id>".
        tolerance: A 404 (NotFound) response is tolerated only while the
            number of polls made so far is below this value; otherwise a
            ValueError is raised. ``None`` is treated as 0.

    Returns:
        The PersistentResource returned by the service once its state is
        RUNNING.

    Raises:
        ValueError: Invalid cluster resource name.
        RuntimeError: Service returns error.
        RuntimeError: Cluster resource state is STOPPING.
        RuntimeError: Cluster resource state is ERROR.
    """
    # The signature allows None; comparing an int against None would raise
    # TypeError inside the NotFound handler, so normalise it up front.
    not_found_tolerance = 0 if tolerance is None else tolerance

    client = create_persistent_resource_client()
    request = GetPersistentResourceRequest(name=persistent_resource_name)

    # TODO(b/277117901): Add test cases for polling and error handling
    num_attempts = 0
    while True:
        try:
            response = client.get_persistent_resource(request)
        except exceptions.NotFound:
            response = None
            if num_attempts >= not_found_tolerance:
                raise ValueError(
                    "[Ray on Vertex AI]: Invalid cluster_resource_name (404 not found)."
                )
        if response:
            if response.error.message:
                logging.error("[Ray on Vertex AI]: %s", response.error.message)
                raise RuntimeError("[Ray on Vertex AI]: Cluster returned an error.")

            print("[Ray on Vertex AI]: Cluster State =", response.state)
            if response.state == PersistentResource.State.RUNNING:
                return response
            elif response.state == PersistentResource.State.STOPPING:
                raise RuntimeError("[Ray on Vertex AI]: The cluster is stopping.")
            elif response.state == PersistentResource.State.ERROR:
                raise RuntimeError(
                    "[Ray on Vertex AI]: The cluster encountered an error."
                )
        # Any other state (or a tolerated 404) means "not ready yet": wait
        # for a progressively shorter interval and poll again.
        sleep_time = polling_delay(num_attempts=num_attempts, time_scale=150.0)
        num_attempts += 1
        print(
            "Waiting for cluster provisioning; attempt {}; sleeping for {} seconds".format(
                num_attempts, sleep_time
            )
        )
        time.sleep(sleep_time.total_seconds())
|
||||
|
||||
|
||||
def _accelerator_name_or_none(machine_spec):
    """Return the accelerator enum member's name, or None when its value is 0."""
    accelerator = machine_spec.accelerator_type
    if accelerator.value == 0:
        return None
    return accelerator.name


def _custom_image_or_none(image_uri):
    """Return ``image_uri`` unless it contains the official-image fragment."""
    return None if _OFFICIAL_IMAGE in image_uri else image_uri


def _resources_from_pool(resource_pool, accelerator_type, node_count, custom_image):
    """Build a Resources entry from one resource pool's machine and disk specs."""
    machine_spec = resource_pool.machine_spec
    disk_spec = resource_pool.disk_spec
    return Resources(
        machine_type=machine_spec.machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=machine_spec.accelerator_count,
        boot_disk_type=disk_spec.boot_disk_type,
        boot_disk_size_gb=disk_spec.boot_disk_size_gb,
        node_count=node_count,
        custom_image=custom_image,
    )


def _autoscaling_from_pool(resource_pool):
    """Copy a resource pool's replica bounds into an AutoscalingSpec."""
    pool_autoscaling = resource_pool.autoscaling_spec
    return AutoscalingSpec(
        min_replica_count=pool_autoscaling.min_replica_count,
        max_replica_count=pool_autoscaling.max_replica_count,
    )


def persistent_resource_to_cluster(
    persistent_resource: PersistentResource,
) -> Optional[Cluster]:
    """Convert a PersistentResource into a Cluster.

    Args:
        persistent_resource: PersistentResource.

    Returns:
        The populated Cluster, or None when the resource has no Ray spec or
        when its head image is an outdated private-preview image whose
        versions cannot be parsed.
    """
    runtime_spec = persistent_resource.resource_runtime_spec
    cluster = Cluster(
        cluster_resource_name=persistent_resource.name,
        network=persistent_resource.network,
        reserved_ip_ranges=persistent_resource.reserved_ip_ranges,
        state=persistent_resource.state.name,
        labels=persistent_resource.labels,
        dashboard_address=persistent_resource.resource_runtime.access_uris.get(
            "RAY_DASHBOARD_URI"
        ),
    )

    ray_spec = runtime_spec.ray_spec
    if not ray_spec:
        # skip PersistentResource without RaySpec
        no_ray_message = (
            "[Ray on Vertex AI]: Cluster %s does not have Ray installed."
            % persistent_resource.name
        )
        logging.info(no_ray_message)
        return None

    if persistent_resource.psc_interface_config:
        cluster.psc_interface_config = PscIConfig(
            network_attachment=persistent_resource.psc_interface_config.network_attachment
        )

    resource_pools = persistent_resource.resource_pools
    # The first resource pool is the one that hosts the head node.
    head_resource_pool = resource_pools[0]
    head_image_uri = ray_spec.resource_pool_images[head_resource_pool.id]

    service_account = runtime_spec.service_account_spec.service_account
    if service_account:
        cluster.service_account = service_account

    if not head_image_uri:
        # No per-pool image recorded: fall back to the cluster-wide image.
        head_image_uri = ray_spec.image_uri

    try:
        python_version, ray_version = _validation_utils.get_versions_from_image_uri(
            head_image_uri
        )
    except IndexError:
        if _PRIVATE_PREVIEW_IMAGE in head_image_uri:
            # If using outdated images
            outdated_message = (
                "[Ray on Vertex AI]: The image of cluster %s is outdated."
                " It is recommended to delete and recreate the cluster to obtain"
                " the latest image." % persistent_resource.name
            )
            logging.info(outdated_message)
            return None
        # Custom image might also cause IndexError
        python_version = None
        ray_version = None

    cluster.python_version = python_version
    cluster.ray_version = ray_version
    cluster.ray_metric_enabled = not ray_spec.ray_metric_spec.disabled
    cluster.ray_logs_enabled = not ray_spec.ray_logs_spec.disabled

    head_accelerator = _accelerator_name_or_none(head_resource_pool.machine_spec)
    # Official training image is not custom
    head_custom_image = _custom_image_or_none(head_image_uri)
    head_node_type = _resources_from_pool(
        head_resource_pool,
        accelerator_type=head_accelerator,
        node_count=1,
        custom_image=head_custom_image,
    )

    worker_node_types = []
    if head_resource_pool.replica_count > 1:
        # head_node_type.node_count must be 1. Any further replicas of the
        # first resource pool are reported as worker nodes of the same shape.
        head_pool_workers = _resources_from_pool(
            head_resource_pool,
            accelerator_type=head_accelerator,
            node_count=head_resource_pool.replica_count - 1,
            custom_image=head_custom_image,
        )
        worker_node_types.append(head_pool_workers)
        if head_resource_pool.autoscaling_spec:
            head_pool_workers.autoscaling_spec = _autoscaling_from_pool(
                head_resource_pool
            )

    # Every pool after the first one is a dedicated worker pool.
    for pool_index in range(1, len(resource_pools)):
        worker_pool = resource_pools[pool_index]
        worker_accelerator = _accelerator_name_or_none(worker_pool.machine_spec)
        # Official training image is not custom
        worker_custom_image = _custom_image_or_none(
            ray_spec.resource_pool_images[worker_pool.id]
        )
        worker_resources = _resources_from_pool(
            worker_pool,
            accelerator_type=worker_accelerator,
            node_count=worker_pool.replica_count,
            custom_image=worker_custom_image,
        )
        if worker_pool.autoscaling_spec:
            worker_resources.autoscaling_spec = _autoscaling_from_pool(worker_pool)
        worker_node_types.append(worker_resources)

    cluster.head_node_type = head_node_type
    cluster.worker_node_types = worker_node_types

    return cluster
|
||||
@@ -0,0 +1,167 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import google.auth
|
||||
import google.auth.transport.requests
|
||||
import logging
|
||||
import ray
|
||||
import re
|
||||
from immutabledict import immutabledict
|
||||
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform.utils import resource_manager_utils
|
||||
|
||||
# Maps each supported Ray "major.minor" version to its full release number.
SUPPORTED_RAY_VERSIONS = immutabledict(
    {"2.9": "2.9.3", "2.33": "2.33.0", "2.42": "2.42.0"}
)
# Maps each supported Python "major.minor" version to the tuple of Ray
# "major.minor" versions available for it.
SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS = immutabledict(
    {
        "3.10": ("2.9", "2.33", "2.42"),
        # The trailing comma matters: ("2.42") is a plain string, which would
        # turn every `in` membership test into a substring match
        # (e.g. "2.4" in "2.42" is True).
        "3.11": ("2.42",),
    }
)
|
||||
# Warning text for Ray 2.4.
# NOTE(review): not referenced by any function in this file - confirm callers.
_V2_4_WARNING_MESSAGE = (
    "After google-cloud-aiplatform>1.53.0, using Ray version = 2.4 will result"
    " in an error. Please use Ray version = 2.33.0 or 2.42.0 (default) instead."
)
# Warning text for Ray 2.9.
# NOTE(review): not referenced by any function in this file - confirm callers.
_V2_9_WARNING_MESSAGE = (
    "In March 2025, using Ray version = 2.9 will result in an error. "
    "Please use Ray version = 2.33.0 or 2.42.0 (default) instead."
)


# Artifact Repository available regions.
_AVAILABLE_REGIONS = ["us", "europe", "asia"]
# If region is not available, assume using the default region.
_DEFAULT_REGION = "us"

# Format arguments, in order: project number, location, persistent resource id.
_PERSISTENT_RESOURCE_NAME_PATTERN = "projects/{}/locations/{}/persistentResources/{}"
# Pattern for a bare cluster name; maybe_reconstruct_resource_name anchors it
# with ^...$ before matching.
_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}"
# Suffix that valid_dashboard_address requires at the end of an address.
_DASHBOARD_URI_SUFFIX = "aiplatform-training.googleusercontent.com"
|
||||
|
||||
|
||||
def valid_resource_name(resource_name):
    """Verify that ``resource_name`` is a full persistent resource name.

    Args:
        resource_name: Candidate name, expected to look like
            "projects/<project_num>/locations/<region>/persistentResources/<pr_id>".

    Raises:
        ValueError: If the name does not consist of exactly six "/"-separated
            segments with the fixed keywords in the 1st, 3rd and 5th position.
    """
    segments = resource_name.split("/")
    expected_keywords = ("projects", "locations", "persistentResources")
    # Even positions carry the fixed keywords; the values in between are not
    # inspected.
    if len(segments) == 6 and tuple(segments[0::2]) == expected_keywords:
        return
    raise ValueError(
        "[Ray on Vertex AI]: Address must be in the following "
        "format: vertex_ray://projects/<project_num>/locations/<region>/persistentResources/<pr_id> "
        "or vertex_ray://<pr_id>."
    )
|
||||
|
||||
|
||||
def maybe_reconstruct_resource_name(address) -> str:
    """Expand a bare cluster name into a full persistent resource name.

    Args:
        address: Either a cluster name (persistent resource id) or any other
            address string.

    Returns:
        The full resource name built from the globally configured project and
        location when ``address`` looks like a bare cluster name; otherwise
        ``address`` unchanged.
    """
    id_only_pattern = "^{}$".format(_VALID_RESOURCE_NAME_REGEX)
    if not re.match(id_only_pattern, address):
        # Not a bare id: hand the address back untouched.
        return address

    # Assume only cluster name (persistent resource id) was given.
    logging.info(
        "[Ray on Vertex AI]: Cluster name was given as address, reconstructing full resource name"
    )
    project_number = resource_manager_utils.get_project_number(
        initializer.global_config.project
    )
    location = initializer.global_config.location
    return _PERSISTENT_RESOURCE_NAME_PATTERN.format(project_number, location, address)
|
||||
|
||||
|
||||
def get_local_ray_version():
    """Return the locally installed Ray version.

    Returns:
        "major.minor" when ``ray.__version__`` has exactly three
        dot-separated parts; otherwise the version string unchanged.
    """
    parts = ray.__version__.split(".")
    kept_parts = parts[:2] if len(parts) == 3 else parts
    return ".".join(kept_parts)
|
||||
|
||||
|
||||
def get_image_uri(ray_version, python_version, enable_cuda):
    """Build the image URI for a given Ray version and Python version.

    Args:
        ray_version: Ray "major.minor" version; must be a key of
            SUPPORTED_RAY_VERSIONS.
        python_version: Python "major.minor" version; must be a key of
            SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS.
        enable_cuda: When truthy the GPU image is returned, otherwise the CPU
            image.

    Returns:
        "<region>-docker.pkg.dev/vertex-ai/training/ray-<cpu|gpu>.<ray>.py<python>:latest",
        where <region> is derived from the globally configured location.

    Raises:
        ValueError: If the Ray version is unsupported, the Python version is
            unsupported, or the Ray version is not offered for that Python
            version.
    """
    if ray_version not in SUPPORTED_RAY_VERSIONS:
        # Build the list from the table itself so the message names every
        # supported version and cannot go stale when entries are added.
        supported = ", ".join(
            "%s (%s)" % (short, full) for short, full in SUPPORTED_RAY_VERSIONS.items()
        )
        raise ValueError(
            "[Ray on Vertex AI]: The supported Ray versions are %s." % supported
        )
    if python_version not in SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS:
        raise ValueError(
            "[Ray on Vertex AI]: The supported Python versions are 3.10 or 3.11."
        )

    if ray_version not in SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS[python_version]:
        raise ValueError(
            "[Ray on Vertex AI]: The supported Ray version(s) for Python version %s: %s."
            % (
                python_version,
                SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS[python_version],
            )
        )

    location = initializer.global_config.location
    # Only the leading part of the location (before the first "-") selects
    # the repository; unknown values fall back to the default region.
    region = location.split("-")[0]
    if region not in _AVAILABLE_REGIONS:
        region = _DEFAULT_REGION
    ray_version = ray_version.replace(".", "-")
    python_version = python_version.replace(".", "")
    if enable_cuda:
        return f"{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.{ray_version}.py{python_version}:latest"
    else:
        return f"{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.{ray_version}.py{python_version}:latest"
|
||||
|
||||
|
||||
def get_versions_from_image_uri(image_uri):
    """Extract the Python and Ray versions encoded in an image URI.

    Args:
        image_uri: Image URI whose last path component (without the tag) is
            expected to look like "ray-cpu.2-33.py310".

    Returns:
        ``(python_version, ray_version)`` when the parsed pair appears in
        SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS, otherwise
        ``(None, None)``.

    Raises:
        IndexError: If the image label is shorter than three characters or
            contains no "." separator.
    """
    logging.info(f"[Ray on Vertex AI]: Getting versions from image uri: {image_uri}")
    last_path_component = image_uri.split("/")[-1]
    image_label = last_path_component.split(":")[0]
    # The final three characters hold the Python version without its dot,
    # e.g. "310" -> "3.10".
    py_version = image_label[-3] + "." + image_label[-2:]
    # The second "."-separated piece is the Ray version with "-" for ".".
    ray_version = image_label.split(".")[1].replace("-", ".")

    if py_version not in SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS:
        # May not parse custom image and get the versions correctly
        return None, None
    if ray_version not in SUPPORTED_RAY_VERSIONS_FROM_PYTHON_VERSIONS[py_version]:
        return None, None
    return py_version, ray_version
|
||||
|
||||
|
||||
def valid_dashboard_address(address):
    """Report whether ``address`` ends with the dashboard URI suffix.

    Args:
        address: Address string to check.

    Returns:
        True when ``address`` ends with ``_DASHBOARD_URI_SUFFIX``.
    """
    # NOTE(review): this is a plain suffix test, so a host that merely ends
    # with the suffix text (no "." boundary) also passes - confirm intended.
    required_suffix = _DASHBOARD_URI_SUFFIX
    return address.endswith(required_suffix)
|
||||
|
||||
|
||||
def get_bearer_token():
    """Obtain a bearer token via Application Default Credentials.

    Returns:
        The ``token`` attribute of the refreshed credentials.
    """
    cloud_platform_scope = "https://www.googleapis.com/auth/cloud-platform"
    credentials, _ = google.auth.default(scopes=[cloud_platform_scope])

    # creds.valid is False, and creds.token is None
    # Need to refresh credentials to populate those
    refresh_request = google.auth.transport.requests.Request()
    credentials.refresh(refresh_request)
    return credentials.token
|
||||
217
.venv/lib/python3.10/site-packages/vertex_ray/util/resources.py
Normal file
217
.venv/lib/python3.10/site-packages/vertex_ray/util/resources.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import dataclasses
|
||||
from typing import Dict, List, Optional
|
||||
from google.cloud.aiplatform_v1beta1.types import PersistentResource
|
||||
|
||||
|
||||
@dataclasses.dataclass
class AutoscalingSpec:
    """Autoscaling spec for a ray cluster node.

    Attributes:
        min_replica_count: The minimum number of replicas in the cluster
            (default 1).
        max_replica_count: The maximum number of replicas in the cluster
            (default 2).
    """

    min_replica_count: int = 1
    max_replica_count: int = 2
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Resources:
    """Resources for a ray cluster node.

    Attributes:
        machine_type: See the list of machine types:
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types
            (default is "n1-standard-16").
        node_count: This argument represents how many nodes to start for the
            ray cluster (default is 1).
        accelerator_type: e.g. "NVIDIA_TESLA_P4".
            Vertex AI supports the following types of GPU:
            https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus
        accelerator_count: The number of accelerators to attach to the machine
            (default is 0).
        boot_disk_type: Type of the boot disk (default is "pd-ssd").
            Valid values: "pd-ssd" (Persistent Disk Solid State Drive) or
            "pd-standard" (Persistent Disk Hard Disk Drive).
        boot_disk_size_gb: Size in GB of the boot disk (default is 100GB). Must
            be either unspecified or within the range of [100, 64000].
        custom_image: Custom image for this resource (e.g.
            us-docker.pkg.dev/my-project/ray-gpu.2-9.py310-tf:latest).
        autoscaling_spec: Autoscaling spec for this resource.
    """

    machine_type: Optional[str] = "n1-standard-16"
    node_count: Optional[int] = 1
    accelerator_type: Optional[str] = None
    accelerator_count: Optional[int] = 0
    boot_disk_type: Optional[str] = "pd-ssd"
    boot_disk_size_gb: Optional[int] = 100
    custom_image: Optional[str] = None
    autoscaling_spec: Optional[AutoscalingSpec] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class NodeImages:
    """Custom images for a ray cluster.

    We currently support Ray v2.9, v2.33, v2.42 and python v3.10.
    We also support python v3.11 for Ray v2.42.
    The custom images must be extended from the following base images:
    "{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.2-9.py310:latest",
    "{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.2-9.py310:latest",
    "{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.2-33.py310:latest",
    "{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.2-33.py310:latest",
    "{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.2-42.py310:latest",
    "{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.2-42.py310:latest",
    "{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.2-42.py311:latest", or
    "{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.2-42.py311:latest".
    In order to use custom images, both the head and the worker image need to
    be specified.

    Attributes:
        head: image for head node (eg. us-docker.pkg.dev/my-project/ray-cpu.2-33.py310-tf:latest).
        worker: image for all worker nodes (eg. us-docker.pkg.dev/my-project/ray-gpu.2-33.py310-tf:latest).
    """

    # Both fields default to None, so the annotations must admit None.
    head: Optional[str] = None
    worker: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class PscIConfig:
    """PSC-I config.

    Attributes:
        network_attachment: Optional. The name or full name of the Compute
            Engine network attachment
            (https://cloud.google.com/vpc/docs/about-network-attachments)
            to attach to the resource. It has a format:
            ``projects/{project}/regions/{region}/networkAttachments/{networkAttachment}``.
            Where {project} is a project number, as in ``12345``, and
            {networkAttachment} is a network attachment name. To specify
            this field, you must have already created a network attachment
            (https://cloud.google.com/vpc/docs/create-manage-network-attachments#create-network-attachments).
            This field is only used for resources using PSC-I. Make sure you
            do not specify the network here for VPC peering.
    """

    # Defaults to None, so the annotation must admit None.
    network_attachment: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class NfsMount:
    """NFS mount.

    The fields are described as required, but each one defaults to None and
    nothing in this class enforces that a value is supplied.

    Attributes:
        server: Required. IP address of the NFS server.
        path: Required. Source path exported from NFS server. Has to start
            with '/', and combined with the ip address, it indicates the
            source mount path in the form of ``server:path``.
        mount_point: Required. Destination mount path. The NFS will be mounted
            for the user under /mnt/nfs/<mount_point>.
    """

    # All fields default to None, so the annotations must admit None.
    server: Optional[str] = None
    path: Optional[str] = None
    mount_point: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Cluster:
    """Ray cluster (output only).

    Attributes:
        cluster_resource_name: It has a format:
            "projects/<project_num>/locations/<region>/persistentResources/<pr_id>".
        network: Virtual private cloud (VPC) network. It has a format:
            "projects/<project_num>/global/networks/<network_name>".
            For Ray Client, VPC peering is required to connect to the cluster
            managed in the Vertex API service. For Ray Job API, VPC network is
            not required because cluster connection can be accessed through
            dashboard address.
        reserved_ip_ranges: A list of names for the reserved IP ranges under
            the VPC network that can be used for this cluster. If set, we will
            deploy the cluster within the provided IP ranges. Otherwise, the
            cluster is deployed to any IP ranges under the provided VPC network.
            Example: ["vertex-ai-ip-range"].
        service_account: Service account to be used for running Ray programs on
            the cluster.
        state: Describes the cluster state (defined in PersistentResource.State).
        python_version: Python version for the ray cluster (e.g. "3.10").
        ray_version: Ray version for the ray cluster (e.g. "2.33").
        head_node_type: The head node resource. Resources.node_count must be 1.
            (The default machine_type of Resources is n1-standard-16.)
        worker_node_types: The list of Resources of the worker nodes. Should not
            duplicate the elements in the list.
        dashboard_address: For Ray Job API (JobSubmissionClient), with this
            cluster connection doesn't require VPC peering.
        ray_metric_enabled: Whether Ray metrics are enabled (default True).
        ray_logs_enabled: Whether Ray logs are enabled (default True).
        psc_interface_config: PSC-I config of the cluster, if any.
        labels:
            The labels with user-defined metadata to organize Ray cluster.

            Label keys and values can be no longer than 64 characters (Unicode
            codepoints), can only contain lowercase letters, numeric characters,
            underscores and dashes. International characters are allowed.

            See https://goo.gl/xmQnxf for more information and examples of labels.
    """

    # Every field that defaults to None is annotated Optional[...] so the
    # declared type matches the actual default.
    cluster_resource_name: Optional[str] = None
    network: Optional[str] = None
    reserved_ip_ranges: Optional[List[str]] = None
    service_account: Optional[str] = None
    # NOTE(review): the persistent_resource_to_cluster converter stores the
    # enum member's name (a str) here - confirm the intended type.
    state: Optional[PersistentResource.State] = None
    python_version: Optional[str] = None
    ray_version: Optional[str] = None
    head_node_type: Optional[Resources] = None
    worker_node_types: Optional[List[Resources]] = None
    dashboard_address: Optional[str] = None
    ray_metric_enabled: bool = True
    ray_logs_enabled: bool = True
    psc_interface_config: Optional[PscIConfig] = None
    labels: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
def _check_machine_spec_identical(
    node_type_1: Resources,
    node_type_2: Resources,
) -> int:
    """Compare the machine specs of two node types.

    Args:
        node_type_1: First node type (the error messages refer to it as the
            head node).
        node_type_2: Second node type (the error messages refer to it as the
            worker).

    Returns:
        ``node_type_2.node_count`` when machine_type, accelerator_type and
        accelerator_count all match; otherwise 0.

    Raises:
        ValueError: If the machine specs match but the boot disk type or boot
            disk size differs.
    """
    same_machine_spec = (
        node_type_1.machine_type == node_type_2.machine_type
        and node_type_1.accelerator_type == node_type_2.accelerator_type
        and node_type_1.accelerator_count == node_type_2.accelerator_count
    )
    if not same_machine_spec:
        # Different machine specs: nothing to merge, no extra replicas.
        return 0

    if node_type_1.boot_disk_type != node_type_2.boot_disk_type:
        raise ValueError(
            "Worker disk type must match the head node's disk type if"
            " sharing the same machine_type, accelerator_type, and"
            " accelerator_count"
        )
    if node_type_1.boot_disk_size_gb != node_type_2.boot_disk_size_gb:
        raise ValueError(
            "Worker disk size must match the head node's disk size if"
            " sharing the same machine_type, accelerator_type, and"
            " accelerator_count"
        )
    return node_type_2.node_count
|
||||
Reference in New Issue
Block a user