structure saas with tools
This commit is contained in:
201
.venv/lib/python3.10/site-packages/vertex_ray/client_builder.py
Normal file
201
.venv/lib/python3.10/site-packages/vertex_ray/client_builder.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import grpc
|
||||
import logging
|
||||
import ray
|
||||
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from google.cloud import aiplatform
|
||||
from google.cloud.aiplatform import initializer
|
||||
from ray import client_builder
|
||||
from .render import VertexRayTemplate
|
||||
from .util import _validation_utils
|
||||
from .util import _gapic_utils
|
||||
|
||||
|
||||
VERTEX_SDK_VERSION = aiplatform.__version__
|
||||
|
||||
|
||||
class _VertexRayClientContext(client_builder.ClientContext):
|
||||
"""Custom ClientContext."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
persistent_resource_id: str,
|
||||
ray_head_uris: Dict[str, str],
|
||||
ray_client_context: client_builder.ClientContext,
|
||||
) -> None:
|
||||
dashboard_uri = ray_head_uris.get("RAY_DASHBOARD_URI")
|
||||
if dashboard_uri is None:
|
||||
raise ValueError(
|
||||
"Ray Cluster ",
|
||||
persistent_resource_id,
|
||||
" failed to start Head node properly.",
|
||||
)
|
||||
if ray.__version__ in ("2.42.0", "2.33.0"):
|
||||
super().__init__(
|
||||
dashboard_url=dashboard_uri,
|
||||
python_version=ray_client_context.python_version,
|
||||
ray_version=ray_client_context.ray_version,
|
||||
ray_commit=ray_client_context.ray_commit,
|
||||
_num_clients=ray_client_context._num_clients,
|
||||
_context_to_restore=ray_client_context._context_to_restore,
|
||||
)
|
||||
elif ray.__version__ == "2.9.3":
|
||||
super().__init__(
|
||||
dashboard_url=dashboard_uri,
|
||||
python_version=ray_client_context.python_version,
|
||||
ray_version=ray_client_context.ray_version,
|
||||
ray_commit=ray_client_context.ray_commit,
|
||||
protocol_version=ray_client_context.protocol_version,
|
||||
_num_clients=ray_client_context._num_clients,
|
||||
_context_to_restore=ray_client_context._context_to_restore,
|
||||
)
|
||||
else:
|
||||
raise ImportError(
|
||||
f"[Ray on Vertex AI]: Unsupported version {ray.__version__}."
|
||||
+ "Only 2.42.0, 2.33.0, and 2.9.3 are supported."
|
||||
)
|
||||
self.persistent_resource_id = persistent_resource_id
|
||||
self.vertex_sdk_version = str(VERTEX_SDK_VERSION)
|
||||
self.shell_uri = ray_head_uris.get("RAY_HEAD_NODE_INTERACTIVE_SHELL_URI")
|
||||
|
||||
def _context_table_template(self):
|
||||
shell_uri_row = None
|
||||
if self.shell_uri is not None:
|
||||
shell_uri_row = VertexRayTemplate("context_shellurirow.html.j2").render(
|
||||
shell_uri=self.shell_uri
|
||||
)
|
||||
|
||||
return VertexRayTemplate("context_table.html.j2").render(
|
||||
python_version=self.python_version,
|
||||
ray_version=self.ray_version,
|
||||
vertex_sdk_version=self.vertex_sdk_version,
|
||||
dashboard_url=self.dashboard_url,
|
||||
persistent_resource_id=self.persistent_resource_id,
|
||||
shell_uri_row=shell_uri_row,
|
||||
)
|
||||
|
||||
|
||||
class VertexRayClientBuilder(client_builder.ClientBuilder):
|
||||
"""Class to initialize a Ray client with vertex on ray capabilities."""
|
||||
|
||||
def __init__(self, address: Optional[str]) -> None:
|
||||
address = _validation_utils.maybe_reconstruct_resource_name(address)
|
||||
_validation_utils.valid_resource_name(address)
|
||||
self._credentials = None
|
||||
self._metadata = None
|
||||
self.vertex_address = address
|
||||
logging.info(
|
||||
"[Ray on Vertex AI]: Using cluster resource name to access head address with GAPIC API"
|
||||
)
|
||||
|
||||
self.resource_name = address
|
||||
|
||||
self.response = _gapic_utils.get_persistent_resource(self.resource_name)
|
||||
private_address = self.response.resource_runtime.access_uris.get(
|
||||
"RAY_HEAD_NODE_INTERNAL_IP"
|
||||
)
|
||||
public_address = self.response.resource_runtime.access_uris.get(
|
||||
"RAY_CLIENT_ENDPOINT"
|
||||
)
|
||||
service_account = (
|
||||
self.response.resource_runtime_spec.service_account_spec.service_account
|
||||
)
|
||||
|
||||
if public_address is None:
|
||||
address = private_address
|
||||
if service_account:
|
||||
raise ValueError(
|
||||
"[Ray on Vertex AI]: Ray Cluster ",
|
||||
address,
|
||||
" failed to start Head node properly because custom service"
|
||||
" account isn't supported in peered VPC network. Use public"
|
||||
" endpoint instead (createa a cluster without specifying"
|
||||
" VPC network).",
|
||||
)
|
||||
else:
|
||||
address = public_address
|
||||
|
||||
if address is None:
|
||||
persistent_resource_id = self.resource_name.split("/")[5]
|
||||
raise ValueError(
|
||||
"[Ray on Vertex AI]: Ray Cluster ",
|
||||
persistent_resource_id,
|
||||
" Head node is not reachable. Please ensure that a valid VPC network has been specified.",
|
||||
)
|
||||
|
||||
logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)
|
||||
cluster = _gapic_utils.persistent_resource_to_cluster(
|
||||
persistent_resource=self.response
|
||||
)
|
||||
if cluster is None:
|
||||
raise ValueError(
|
||||
"[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)."
|
||||
)
|
||||
local_ray_verion = _validation_utils.get_local_ray_version()
|
||||
if cluster.ray_version != local_ray_verion:
|
||||
if cluster.head_node_type.custom_image is None:
|
||||
install_ray_version = _validation_utils.SUPPORTED_RAY_VERSIONS.get(
|
||||
cluster.ray_version
|
||||
)
|
||||
logging.info(
|
||||
"[Ray on Vertex]: Local runtime has Ray version %s"
|
||||
", but the requested cluster runtime has %s. Please "
|
||||
"ensure that the Ray versions match for client connectivity. You may "
|
||||
'"pip install --user --force-reinstall ray[default]==%s"'
|
||||
" and restart runtime before cluster connection."
|
||||
% (local_ray_verion, cluster.ray_version, install_ray_version)
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
"[Ray on Vertex]: Local runtime has Ray version %s."
|
||||
"Please ensure that the Ray versions match for client connectivity."
|
||||
% local_ray_verion
|
||||
)
|
||||
super().__init__(address)
|
||||
|
||||
def connect(self) -> _VertexRayClientContext:
|
||||
# Can send any other params to ray cluster here
|
||||
logging.info("[Ray on Vertex AI]: Connecting...")
|
||||
|
||||
public_address = self.response.resource_runtime.access_uris.get(
|
||||
"RAY_CLIENT_ENDPOINT"
|
||||
)
|
||||
private_address = self.response.resource_runtime.access_uris.get(
|
||||
"RAY_HEAD_NODE_INTERNAL_IP"
|
||||
)
|
||||
if public_address and not private_address:
|
||||
self._credentials = grpc.ssl_channel_credentials()
|
||||
bearer_token = _validation_utils.get_bearer_token()
|
||||
self._metadata = [
|
||||
("authorization", "Bearer {}".format(bearer_token)),
|
||||
("x-goog-user-project", "{}".format(initializer.global_config.project)),
|
||||
]
|
||||
ray_client_context = super().connect()
|
||||
ray_head_uris = self.response.resource_runtime.access_uris
|
||||
|
||||
# Valid resource name (reference public doc for public release):
|
||||
# "projects/<project_num>/locations/<region>/persistentResources/<pr_id>"
|
||||
persistent_resource_id = self.resource_name.split("/")[5]
|
||||
|
||||
return _VertexRayClientContext(
|
||||
persistent_resource_id=persistent_resource_id,
|
||||
ray_head_uris=ray_head_uris,
|
||||
ray_client_context=ray_client_context,
|
||||
)
|
||||
Reference in New Issue
Block a user