# -*- coding: utf-8 -*-

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import os
import tempfile
import time
import uuid
from typing import Any, Iterable, Optional

import pyarrow.parquet as pq

from google.api_core import client_info
from google.api_core import exceptions
from google.cloud import bigquery
from google.cloud.aiplatform import initializer

import ray
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import Block, BlockAccessor

try:
    from ray.data.datasource.datasink import Datasink
except ImportError:
    # If Datasink cannot be imported, Ray >= 2.9.3 is not installed
    Datasink = None

DEFAULT_MAX_RETRY_CNT = 10
RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11

_BQ_GAPIC_VERSION = bigquery.__version__ + "+vertex_ray"
bq_info = client_info.ClientInfo(
    gapic_version=_BQ_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQ_GAPIC_VERSION}"
)


# BigQuery write for Ray 2.42.0, 2.33.0, and 2.9.3
if Datasink is None:
    _BigQueryDatasink = None
else:

    class _BigQueryDatasink(Datasink):
        def __init__(
            self,
            dataset: str,
            project_id: Optional[str] = None,
            max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
            overwrite_table: Optional[bool] = True,
        ) -> None:
            self.dataset = dataset
            self.project_id = project_id or initializer.global_config.project
            self.max_retry_cnt = max_retry_cnt
            self.overwrite_table = overwrite_table

        def on_write_start(self) -> None:
            # Set up datasets to write
            client = bigquery.Client(project=self.project_id, client_info=bq_info)
            dataset_id = self.dataset.split(".", 1)[0]
            try:
                client.get_dataset(dataset_id)
            except exceptions.NotFound:
                client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30)
                print("[Ray on Vertex AI]: Created dataset " + dataset_id)

            # Delete table if overwrite_table is True
            if self.overwrite_table:
                print(
                    f"[Ray on Vertex AI]: Attempting to delete table {self.dataset}"
                    + " if it already exists since kwarg overwrite_table = True."
                )
                client.delete_table(
                    f"{self.project_id}.{self.dataset}", not_found_ok=True
                )
            else:
                print(
                    "[Ray on Vertex AI]: The write will append to table "
                    + f"{self.dataset} if it already exists "
                    + "since kwarg overwrite_table = False."
                )

        def write(
            self,
            blocks: Iterable[Block],
            ctx: TaskContext,
        ) -> Any:
            def _write_single_block(
                block: Block, project_id: str, dataset: str
            ) -> None:
                block = BlockAccessor.for_block(block).to_arrow()

                client = bigquery.Client(project=project_id, client_info=bq_info)
                job_config = bigquery.LoadJobConfig(autodetect=True)
                job_config.source_format = bigquery.SourceFormat.PARQUET
                job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND

                with tempfile.TemporaryDirectory() as temp_dir:
                    fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
                    pq.write_table(block, fp, compression="SNAPPY")

                    retry_cnt = 0
                    while retry_cnt <= self.max_retry_cnt:
                        with open(fp, "rb") as source_file:
                            job = client.load_table_from_file(
                                source_file, dataset, job_config=job_config
                            )
                        try:
                            logging.info(job.result())
                            break
                        except exceptions.Forbidden as e:
                            retry_cnt += 1
                            if retry_cnt > self.max_retry_cnt:
                                break
                            print(
                                "[Ray on Vertex AI]: A block write encountered"
                                + f" a rate limit exceeded error {retry_cnt} time(s)."
                                + " Sleeping to try again."
                            )
                            logging.debug(e)
                            time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)

                    # Raise exception if retry_cnt exceeds max_retry_cnt
                    if retry_cnt > self.max_retry_cnt:
                        print(
                            f"[Ray on Vertex AI]: Maximum ({self.max_retry_cnt}) retry count exceeded."
                            + " Ray will attempt to retry the block write via fault tolerance."
                            + " For more information, see https://docs.ray.io/en/latest/ray-core/fault_tolerance/tasks.html"
                        )
                        raise RuntimeError(
                            f"[Ray on Vertex AI]: Write failed due to {retry_cnt}"
                            + " repeated API rate limit exceeded responses. Consider"
                            + " specifying the max_retry_cnt kwarg with a higher value."
                        )

            _write_single_block = cached_remote_fn(_write_single_block)

            # Launch a remote task for each block within this write task
            ray.get(
                [
                    _write_single_block.remote(block, self.project_id, self.dataset)
                    for block in blocks
                ]
            )

            return "ok"
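
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library surface): a minimal
# example of how this datasink could be plugged into a Ray Data pipeline via
# Dataset.write_datasink(). The dataset/table name and project ID below are
# assumptions for demonstration only, not values the library itself uses.
# ---------------------------------------------------------------------------
if __name__ == "__main__" and _BigQueryDatasink is not None:
    # Build a tiny in-memory dataset and write it to an assumed BigQuery table.
    example_ds = ray.data.from_items([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}])
    sink = _BigQueryDatasink(
        dataset="example_dataset.example_table",  # assumed "<dataset>.<table>" name
        project_id="example-project",  # assumed GCP project ID
        overwrite_table=True,
    )
    # write_datasink hands each block of the dataset to the sink's write() method.
    example_ds.write_datasink(sink)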