Files
evo-ai/.venv/lib/python3.10/site-packages/vertex_ray/client_builder.py
2025-04-25 15:30:54 -03:00

202 lines
8.2 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import grpc
import logging
import ray
from typing import Dict
from typing import Optional
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
from ray import client_builder
from .render import VertexRayTemplate
from .util import _validation_utils
from .util import _gapic_utils
# Version string of the installed Vertex AI SDK (google-cloud-aiplatform).
# _VertexRayClientContext stores it as `vertex_sdk_version` and renders it in
# the client context HTML table.
VERTEX_SDK_VERSION = aiplatform.__version__
class _VertexRayClientContext(client_builder.ClientContext):
    """Custom ClientContext.

    Copies the fields of a Ray ``ClientContext`` and adds Vertex AI specific
    information: the persistent resource ID, the installed Vertex AI SDK
    version, and (when reported) the head node's interactive shell URI. These
    are shown in the rendered context table.
    """

    def __init__(
        self,
        persistent_resource_id: str,
        ray_head_uris: Dict[str, str],
        ray_client_context: client_builder.ClientContext,
    ) -> None:
        """Build a Vertex context from an existing Ray client context.

        Args:
            persistent_resource_id: ID of the Vertex AI persistent resource
                backing the Ray cluster.
            ray_head_uris: Access URIs reported for the head node. Must contain
                ``RAY_DASHBOARD_URI``; may contain
                ``RAY_HEAD_NODE_INTERACTIVE_SHELL_URI``.
            ray_client_context: Context returned by the parent
                ``ClientBuilder.connect``; its fields are copied into this one.

        Raises:
            ValueError: If ``RAY_DASHBOARD_URI`` is missing from the head URIs.
            ImportError: If the installed Ray version is not one of the
                versions whose ``ClientContext`` signature is handled here.
        """
        dashboard_uri = ray_head_uris.get("RAY_DASHBOARD_URI")
        if dashboard_uri is None:
            # Build one message string; passing several positional args to
            # ValueError would render the message as a tuple.
            raise ValueError(
                f"Ray Cluster {persistent_resource_id} failed to start Head node"
                " properly."
            )
        if ray.__version__ not in ("2.42.0", "2.33.0", "2.9.3"):
            raise ImportError(
                f"[Ray on Vertex AI]: Unsupported version {ray.__version__}."
                " Only 2.42.0, 2.33.0, and 2.9.3 are supported."
            )
        context_kwargs = {
            "dashboard_url": dashboard_uri,
            "python_version": ray_client_context.python_version,
            "ray_version": ray_client_context.ray_version,
            "ray_commit": ray_client_context.ray_commit,
            "_num_clients": ray_client_context._num_clients,
            "_context_to_restore": ray_client_context._context_to_restore,
        }
        if ray.__version__ == "2.9.3":
            # Only the Ray 2.9.3 ClientContext is constructed with an extra
            # protocol_version field.
            context_kwargs["protocol_version"] = ray_client_context.protocol_version
        super().__init__(**context_kwargs)
        self.persistent_resource_id = persistent_resource_id
        self.vertex_sdk_version = str(VERTEX_SDK_VERSION)
        self.shell_uri = ray_head_uris.get("RAY_HEAD_NODE_INTERACTIVE_SHELL_URI")

    def _context_table_template(self):
        """Render the HTML context table, adding a shell-URI row when known."""
        shell_uri_row = None
        if self.shell_uri is not None:
            shell_uri_row = VertexRayTemplate("context_shellurirow.html.j2").render(
                shell_uri=self.shell_uri
            )
        return VertexRayTemplate("context_table.html.j2").render(
            python_version=self.python_version,
            ray_version=self.ray_version,
            vertex_sdk_version=self.vertex_sdk_version,
            dashboard_url=self.dashboard_url,
            persistent_resource_id=self.persistent_resource_id,
            shell_uri_row=shell_uri_row,
        )
class VertexRayClientBuilder(client_builder.ClientBuilder):
    """Class to initialize a Ray client with vertex on ray capabilities.

    Resolves a Vertex AI persistent resource (Ray cluster) into a Ray client
    address through the GAPIC API. On ``connect``, when the cluster exposes
    only a public endpoint, it sets TLS channel credentials plus bearer-token
    and project metadata before delegating to Ray's ``ClientBuilder.connect``.
    """

    def __init__(self, address: Optional[str]) -> None:
        """Look up the cluster and resolve the Ray head node address.

        Args:
            address: Cluster resource name, or a form accepted by
                ``_validation_utils.maybe_reconstruct_resource_name``.

        Raises:
            ValueError: If the head node has no reachable address, if a custom
                service account is used on a private-only (peered VPC)
                cluster, or if the resource cannot be converted to a Ray
                cluster.
        """
        address = _validation_utils.maybe_reconstruct_resource_name(address)
        _validation_utils.valid_resource_name(address)
        self._credentials = None
        self._metadata = None
        self.vertex_address = address
        logging.info(
            "[Ray on Vertex AI]: Using cluster resource name to access head address with GAPIC API"
        )
        self.resource_name = address
        self.response = _gapic_utils.get_persistent_resource(self.resource_name)

        address = self._resolve_head_address()
        logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)

        cluster = _gapic_utils.persistent_resource_to_cluster(
            persistent_resource=self.response
        )
        if cluster is None:
            raise ValueError(
                "[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)."
            )
        self._log_ray_version_mismatch(cluster)
        super().__init__(address)

    def _persistent_resource_id(self) -> str:
        """Return the persistent resource ID parsed from ``self.resource_name``."""
        # Valid resource name (reference public doc for public release):
        # "projects/<project_num>/locations/<region>/persistentResources/<pr_id>"
        return self.resource_name.split("/")[5]

    def _resolve_head_address(self) -> str:
        """Pick the Ray client address from the resource's access URIs.

        Uses the public ``RAY_CLIENT_ENDPOINT`` when present, otherwise the
        private ``RAY_HEAD_NODE_INTERNAL_IP``.

        Raises:
            ValueError: If only a private address exists and a custom service
                account is configured, or if no address exists at all.
        """
        access_uris = self.response.resource_runtime.access_uris
        private_address = access_uris.get("RAY_HEAD_NODE_INTERNAL_IP")
        public_address = access_uris.get("RAY_CLIENT_ENDPOINT")
        service_account = (
            self.response.resource_runtime_spec.service_account_spec.service_account
        )
        if public_address is None:
            address = private_address
            if service_account:
                raise ValueError(
                    f"[Ray on Vertex AI]: Ray Cluster {address} failed to start"
                    " Head node properly because custom service account isn't"
                    " supported in peered VPC network. Use public endpoint"
                    " instead (create a cluster without specifying VPC network)."
                )
        else:
            address = public_address
        if address is None:
            raise ValueError(
                f"[Ray on Vertex AI]: Ray Cluster {self._persistent_resource_id()}"
                " Head node is not reachable. Please ensure that a valid VPC"
                " network has been specified."
            )
        return address

    @staticmethod
    def _log_ray_version_mismatch(cluster) -> None:
        """Log installation advice when local and cluster Ray versions differ."""
        local_ray_version = _validation_utils.get_local_ray_version()
        if cluster.ray_version == local_ray_version:
            return
        if cluster.head_node_type.custom_image is None:
            install_ray_version = _validation_utils.SUPPORTED_RAY_VERSIONS.get(
                cluster.ray_version
            )
            logging.info(
                "[Ray on Vertex]: Local runtime has Ray version %s"
                ", but the requested cluster runtime has %s. Please "
                "ensure that the Ray versions match for client connectivity. You may "
                '"pip install --user --force-reinstall ray[default]==%s"'
                " and restart runtime before cluster connection.",
                local_ray_version,
                cluster.ray_version,
                install_ray_version,
            )
        else:
            # Custom images: the required Ray version is not derivable here,
            # so only a generic reminder is logged.
            logging.info(
                "[Ray on Vertex]: Local runtime has Ray version %s. "
                "Please ensure that the Ray versions match for client connectivity.",
                local_ray_version,
            )

    def connect(self) -> _VertexRayClientContext:
        """Connect to the Ray cluster and return a Vertex-aware client context.

        Returns:
            A ``_VertexRayClientContext`` wrapping the context returned by the
            parent ``ClientBuilder.connect``.
        """
        # Can send any other params to ray cluster here
        logging.info("[Ray on Vertex AI]: Connecting...")
        access_uris = self.response.resource_runtime.access_uris
        public_address = access_uris.get("RAY_CLIENT_ENDPOINT")
        private_address = access_uris.get("RAY_HEAD_NODE_INTERNAL_IP")
        if public_address and not private_address:
            # Public-endpoint-only cluster: use a TLS channel and send a bearer
            # token plus the quota project. NOTE(review): assumes the parent
            # connect() reads self._credentials / self._metadata.
            self._credentials = grpc.ssl_channel_credentials()
            bearer_token = _validation_utils.get_bearer_token()
            self._metadata = [
                ("authorization", f"Bearer {bearer_token}"),
                ("x-goog-user-project", f"{initializer.global_config.project}"),
            ]
        ray_client_context = super().connect()
        return _VertexRayClientContext(
            persistent_resource_id=self._persistent_resource_id(),
            ray_head_uris=access_uris,
            ray_client_context=ray_client_context,
        )