structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

File diff suppressed because it is too large.

View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
from google.cloud.aiplatform import explain
from google.cloud.aiplatform.compat.types import (
endpoint as gca_endpoint_compat,
)
def create_and_validate_explanation_spec(
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
explanation_parameters: Optional[explain.ExplanationParameters] = None,
) -> Optional[explain.ExplanationSpec]:
"""Validates the parameters needed to create explanation_spec and creates it.
Args:
explanation_metadata (explain.ExplanationMetadata):
Optional. Metadata describing the Model's input and output for
explanation. `explanation_metadata` is optional while
`explanation_parameters` must be specified when used.
For more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
explanation_parameters (explain.ExplanationParameters):
Optional. Parameters to configure explaining for Model's
predictions.
For more details, see `Ref docs <http://tinyurl.com/1an4zake>`
Returns:
explanation_spec: Specification of Model explanation.
Raises:
ValueError: If `explanation_metadata` is given, but
`explanation_parameters` is omitted. `explanation_metadata` is optional
while `explanation_parameters` must be specified when used.
"""
if bool(explanation_metadata) and not bool(explanation_parameters):
raise ValueError(
"To get model explanation, `explanation_parameters` must be specified."
)
if explanation_parameters:
explanation_spec = gca_endpoint_compat.explanation.ExplanationSpec()
explanation_spec.parameters = explanation_parameters
if explanation_metadata:
explanation_spec.metadata = explanation_metadata
return explanation_spec
return None
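A minimal usage sketch for the helper above; the Sampled Shapley settings are illustrative and any valid ExplanationParameters value behaves the same way:
parameters = explain.ExplanationParameters(
    {"sampled_shapley_attribution": {"path_count": 10}}
)
spec = create_and_validate_explanation_spec(explanation_parameters=parameters)
# spec.parameters is populated; spec.metadata stays unset unless
# explanation_metadata is also passed. Passing only explanation_metadata
# raises ValueError, per the check above.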

View File

@@ -0,0 +1,284 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import typing
import urllib
from uuid import uuid4
from typing import Optional
from google.cloud.aiplatform import base
if typing.TYPE_CHECKING:
from google.cloud.aiplatform.metadata import experiment_resources
from google.cloud.aiplatform.metadata import experiment_run_resource
from google.cloud.aiplatform import model_evaluation
from vertexai.preview.tuning import sft
_LOGGER = base.Logger(__name__)
def _get_ipython_shell_name() -> str:
if "IPython" in sys.modules:
from IPython import get_ipython
return get_ipython().__class__.__name__
return ""
def is_ipython_available() -> bool:
return _get_ipython_shell_name() != ""
def _get_styles() -> str:
"""Returns the HTML style markup to support custom buttons."""
return """
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
<style>
.view-vertex-resource,
.view-vertex-resource:hover,
.view-vertex-resource:visited {
position: relative;
display: inline-flex;
flex-direction: row;
height: 32px;
padding: 0 12px;
margin: 4px 18px;
gap: 4px;
border-radius: 4px;
align-items: center;
justify-content: center;
background-color: rgb(255, 255, 255);
color: rgb(51, 103, 214);
font-family: Roboto,"Helvetica Neue",sans-serif;
font-size: 13px;
font-weight: 500;
text-transform: uppercase;
text-decoration: none !important;
transition: box-shadow 280ms cubic-bezier(0.4, 0, 0.2, 1) 0s;
box-shadow: 0px 3px 1px -2px rgba(0,0,0,0.2), 0px 2px 2px 0px rgba(0,0,0,0.14), 0px 1px 5px 0px rgba(0,0,0,0.12);
}
.view-vertex-resource:active {
box-shadow: 0px 5px 5px -3px rgba(0,0,0,0.2),0px 8px 10px 1px rgba(0,0,0,0.14),0px 3px 14px 2px rgba(0,0,0,0.12);
}
.view-vertex-resource:active .view-vertex-ripple::before {
position: absolute;
top: 0;
bottom: 0;
left: 0;
right: 0;
border-radius: 4px;
pointer-events: none;
content: '';
background-color: rgb(51, 103, 214);
opacity: 0.12;
}
.view-vertex-icon {
font-size: 18px;
}
</style>
"""
def display_link(text: str, url: str, icon: Optional[str] = "open_in_new") -> None:
"""Creates and displays the link to open the Vertex resource
Args:
text: The text displayed on the clickable button.
url: The url that the button will lead to.
Only cloud console URIs are allowed.
icon: The icon name on the button (from material-icons library)
Returns:
None. The button is rendered in the notebook output via IPython display.
"""
CLOUD_UI_URL = "https://console.cloud.google.com"
CLOUD_DOCS_URL = "https://cloud.google.com"
if not (url.startswith(CLOUD_UI_URL) or url.startswith(CLOUD_DOCS_URL)):
raise ValueError(
f"Only urls starting with {CLOUD_UI_URL} or {CLOUD_DOCS_URL} are allowed."
)
button_id = f"view-vertex-resource-{str(uuid4())}"
# Add the markup for the CSS and link component
html = f"""
{_get_styles()}
<a class="view-vertex-resource" id="{button_id}" href="#view-{button_id}">
<span class="material-icons view-vertex-icon">{icon}</span>
<span>{text}</span>
</a>
"""
# Add the click handler for the link
html += f"""
<script>
(function () {{
const link = document.getElementById('{button_id}');
link.addEventListener('click', (e) => {{
if (window.google?.colab?.openUrl) {{
window.google.colab.openUrl('{url}');
}} else {{
window.open('{url}', '_blank');
}}
e.stopPropagation();
e.preventDefault();
}});
}})();
</script>
"""
from IPython.display import display
from IPython.display import HTML
display(HTML(html))
def display_experiment_button(experiment: "experiment_resources.Experiment") -> None:
"""Function to generate a link bound to the Vertex experiment"""
if not is_ipython_available():
return
try:
project = experiment._metadata_context.project
location = experiment._metadata_context.location
experiment_name = experiment._metadata_context.name
if experiment_name is None or project is None or location is None:
return
except AttributeError:
_LOGGER.warning("Unable to fetch experiment metadata")
return
uri = (
"https://console.cloud.google.com/vertex-ai/experiments/locations/"
+ f"{location}/experiments/{experiment_name}/"
+ f"runs?project={project}"
)
display_link("View Experiment", uri, "science")
def display_experiment_run_button(
experiment_run: "experiment_run_resource.ExperimentRun",
) -> None:
"""Function to generate a link bound to the Vertex experiment run"""
if not is_ipython_available():
return
try:
project = experiment_run.project
location = experiment_run.location
experiment_name = experiment_run._experiment._metadata_context.name
run_name = experiment_run.name
if (
run_name is None
or experiment_name is None
or project is None
or location is None
):
return
except AttributeError:
_LOGGER.warning("Unable to fetch experiment run metadata")
return
uri = (
"https://console.cloud.google.com/vertex-ai/experiments/locations/"
+ f"{location}/experiments/{experiment_name}/"
+ f"runs/{experiment_name}-{run_name}?project={project}"
)
display_link("View Experiment Run", uri, "science")
def display_model_evaluation_button(
evaluation: "model_evaluation.ModelEvaluation",
) -> None:
"""Function to generate a link bound to the Vertex model evaluation"""
if not is_ipython_available():
return
try:
resource_name = evaluation.resource_name
fields = evaluation._parse_resource_name(resource_name)
project = fields["project"]
location = fields["location"]
model_id = fields["model"]
evaluation_id = fields["evaluation"]
except AttributeError:
_LOGGER.warning("Unable to parse model evaluation metadata")
return
if "@" in model_id:
model_id, version_id = model_id.split("@")
else:
version_id = "default"
uri = (
"https://console.cloud.google.com/vertex-ai/models/locations/"
+ f"{location}/models/{model_id}/versions/{version_id}/evaluations/"
+ f"{evaluation_id}?project={project}"
)
display_link("View Model Evaluation", uri, "lightbulb")
def display_model_tuning_button(tuning_job: "sft.SupervisedTuningJob") -> None:
"""Function to generate a link bound to the Vertex model tuning job."""
if not is_ipython_available():
return
try:
resource_name = tuning_job.resource_name
fields = tuning_job._parse_resource_name(resource_name)
project = fields["project"]
location = fields["location"]
tuning_job_id = fields["tuning_job"]
except AttributeError:
_LOGGER.warning("Unable to parse tuning job metadata")
return
uri = (
"https://console.cloud.google.com/vertex-ai/generative/language/"
+ f"locations/{location}/tuning/tuningJob/{tuning_job_id}"
+ f"?project={project}"
)
display_link("View Tuning Job", uri, "tune")
def display_browse_prebuilt_metrics_button() -> None:
"""Function to generate a link to the Gen AI Evaluation pre-built metrics page."""
if not is_ipython_available():
return
uri = (
"https://cloud.google.com/vertex-ai/generative-ai/docs/models/metrics-templates"
)
display_link("Browse pre-built metrics", uri, "list")
def display_gen_ai_evaluation_results_button(
gcs_file_path: Optional[str] = None,
) -> None:
"""Function to generate a link bound to the Gen AI evaluation run."""
if not is_ipython_available():
return
uri = "https://cloud.google.com/vertex-ai/generative-ai/docs/models/view-evaluation"
if gcs_file_path is not None:
gcs_file_path = urllib.parse.quote(gcs_file_path)
uri = f"https://console.cloud.google.com/storage/browser/_details/{gcs_file_path};colab_enterprise=gen_ai_evaluation"
display_link("View evaluation results", uri, "bar_chart")

View File

@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def _is_autologging_enabled() -> bool:
try:
import mlflow
if mlflow.get_tracking_uri() == "vertex-mlflow-plugin://":
return True
else:
return False
except ImportError:
return False
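A small sketch of how the check behaves when MLflow is installed; the local tracking URI is illustrative:
import mlflow
mlflow.set_tracking_uri("vertex-mlflow-plugin://")
assert _is_autologging_enabled()
mlflow.set_tracking_uri("file:./mlruns")
assert not _is_autologging_enabled()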

View File

@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, List, Optional, Tuple
from google.cloud.aiplatform import datasets
def get_default_column_transformations(
dataset: datasets._ColumnNamesDataset,
target_column: str,
) -> Tuple[List[Dict[str, Dict[str, str]]], List[str]]:
"""Get default column transformations from the column names, while omitting the target column.
Args:
dataset (_ColumnNamesDataset):
Required. The dataset.
target_column (str):
Required. The name of the column whose values the Model is to predict.
Returns:
Tuple[List[Dict[str, Dict[str, str]]], List[str]]:
The default column transformations and the default column names.
"""
column_names = [
column_name
for column_name in dataset.column_names
if column_name != target_column
]
column_transformations = [
{"auto": {"column_name": column_name}} for column_name in column_names
]
return (column_transformations, column_names)
def validate_and_get_column_transformations(
column_specs: Optional[Dict[str, str]] = None,
column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
) -> Optional[List[Dict[str, Dict[str, str]]]]:
"""Validates column specs and transformations, then returns processed transformations.
Args:
column_specs (Dict[str, str]):
Optional. Alternative to column_transformations where the keys of the dict
are column names and their respective values are one of
AutoMLTabularTrainingJob.column_data_types.
When creating transformation for BigQuery Struct column, the column
should be flattened using "." as the delimiter. Only columns with no child
should have a transformation.
If an input column has no transformations on it, it is ignored by
the training, except for the targetColumn, which should have no
transformations defined on it.
Only one of column_transformations or column_specs should be passed.
column_transformations (List[Dict[str, Dict[str, str]]]):
Optional. Transformations to apply to the input columns (i.e. columns other
than the targetColumn). Each transformation may produce multiple
result values from the column's value, and all are used for training.
When creating transformation for BigQuery Struct column, the column
should be flattened using "." as the delimiter. Only columns with no child
should have a transformation.
If an input column has no transformations on it, it is ignored by
the training, except for the targetColumn, which should have no
transformations defined on it.
Only one of column_transformations or column_specs should be passed.
Consider using column_specs, as column_transformations will eventually be deprecated.
Returns:
List[Dict[str, Dict[str, str]]]:
The column transformations.
Raises:
ValueError: If both column_transformations and column_specs were provided.
"""
# user populated transformations
if column_transformations is not None and column_specs is not None:
raise ValueError(
"Both column_transformations and column_specs were passed. Only "
"one is allowed."
)
elif column_specs is not None:
return [
{transformation: {"column_name": column_name}}
for column_name, transformation in column_specs.items()
]
else:
return column_transformations
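A usage sketch for the validation helper above; the column names and transformation types are illustrative:
specs = {"age": "numeric", "city": "categorical"}
transformations = validate_and_get_column_transformations(column_specs=specs)
# -> [{"numeric": {"column_name": "age"}},
#     {"categorical": {"column_name": "city"}}]
# Passing both column_specs and column_transformations raises ValueError.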

View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform import tensorboard
def custom_job_console_uri(custom_job_resource_name: str) -> str:
"""Helper method to create console uri from custom job resource name."""
fields = jobs.CustomJob._parse_resource_name(custom_job_resource_name)
return f"https://console.cloud.google.com/ai/platform/locations/{fields['location']}/training/{fields['custom_job']}?project={fields['project']}"
def custom_job_tensorboard_console_uri(
tensorboard_resource_name: str, custom_job_resource_name: str
) -> str:
"""Helper method to create console uri to tensorboard from custom job resource."""
# projects+40556267596+locations+us-central1+tensorboards+740208820004847616+experiments+2214368039829241856
fields = tensorboard.Tensorboard._parse_resource_name(tensorboard_resource_name)
experiment_resource_name = f"{tensorboard_resource_name}/experiments/{custom_job_resource_name.split('/')[-1]}"
uri_experiment_resource_name = experiment_resource_name.replace("/", "+")
return f"https://{fields['location']}.tensorboard.googleusercontent.com/experiment/{uri_experiment_resource_name}"

View File

@@ -0,0 +1,13 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,72 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from google.cloud.aiplatform.utils.enhanced_library import value_converter
from proto.marshal import Marshal
from proto.marshal.rules.struct import ValueRule
from google.protobuf.struct_pb2 import Value
class ConversionValueRule(ValueRule):
def to_python(self, value, *, absent: bool = None):
return super().to_python(value, absent=absent)
def to_proto(self, value):
# Need to check whether value is an instance
# of an enhanced type
if callable(getattr(value, "to_value", None)):
return value.to_value()
else:
return super().to_proto(value)
def _add_methods_to_classes_in_package(pkg):
classes = dict(
[(name, cls) for name, cls in pkg.__dict__.items() if isinstance(cls, type)]
)
for class_name, cls in classes.items():
# Add to_value() method to class with docstring
setattr(cls, "to_value", value_converter.to_value)
cls.to_value.__doc__ = value_converter.to_value.__doc__
# Add from_value() method to class with docstring
setattr(cls, "from_value", _add_from_value_to_class(cls))
cls.from_value.__doc__ = value_converter.from_value.__doc__
# Add from_map() method to class with docstring
setattr(cls, "from_map", _add_from_map_to_class(cls))
cls.from_map.__doc__ = value_converter.from_map.__doc__
def _add_from_value_to_class(cls):
def _from_value(value):
return value_converter.from_value(cls, value)
return _from_value
def _add_from_map_to_class(cls):
def _from_map(map_):
return value_converter.from_map(cls, map_)
return _from_map
marshal = Marshal(name="google.cloud.aiplatform.v1beta1")
marshal.register(Value, ConversionValueRule(marshal=marshal))
marshal = Marshal(name="google.cloud.aiplatform.v1")
marshal.register(Value, ConversionValueRule(marshal=marshal))

View File

@@ -0,0 +1,60 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from google.protobuf.struct_pb2 import Value
from google.protobuf import json_format
from proto.marshal.collections.maps import MapComposite
from proto.marshal import Marshal
from proto import Message
from proto.message import MessageMeta
def to_value(self: Message) -> Value:
"""Converts a message type to a :class:`~google.protobuf.struct_pb2.Value` object.
Args:
message: the message to convert
Returns:
the message as a :class:`~google.protobuf.struct_pb2.Value` object
"""
tmp_dict = json_format.MessageToDict(self._pb)
return json_format.ParseDict(tmp_dict, Value())
def from_value(cls: MessageMeta, value: Value) -> Message:
"""Creates instance of class from a :class:`~google.protobuf.struct_pb2.Value` object.
Args:
value: a :class:`~google.protobuf.struct_pb2.Value` object
Returns:
Instance of class
"""
value_dict = json_format.MessageToDict(value)
return json_format.ParseDict(value_dict, cls()._pb)
def from_map(cls: MessageMeta, map_: MapComposite) -> Message:
"""Creates instance of class from a :class:`~proto.marshal.collections.maps.MapComposite` object.
Args:
map_: a :class:`~proto.marshal.collections.maps.MapComposite` object
Returns:
Instance of class
"""
marshal = Marshal(name="marshal")
pb = marshal.to_proto(Value, map_)
return from_value(cls, pb)
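A round-trip sketch using a compat message type that also appears elsewhere in this change set; any proto-plus message behaves the same way:
from google.cloud.aiplatform.compat.types import feature as gca_feature
msg = gca_feature.Feature(description="user age in years")
value = to_value(msg)  # google.protobuf.struct_pb2.Value holding a struct
restored = from_value(gca_feature.Feature, value)  # plain protobuf Feature
assert restored.description == "user age in years"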

View File

@@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from typing import Dict, NamedTuple, Optional
from google.cloud.aiplatform.compat.services import featurestore_service_client
from google.cloud.aiplatform.compat.types import (
feature as gca_feature,
featurestore_service as gca_featurestore_service,
)
from google.cloud.aiplatform import utils
CompatFeaturestoreServiceClient = featurestore_service_client.FeaturestoreServiceClient
RESOURCE_ID_PATTERN_REGEX = r"[a-z_][a-z0-9_]{0,59}"
GCS_SOURCE_TYPE = {"csv", "avro"}
GCS_DESTINATION_TYPE = {"csv", "tfrecord"}
_FEATURE_VALUE_TYPE_UNSPECIFIED = "VALUE_TYPE_UNSPECIFIED"
FEATURE_STORE_VALUE_TYPE_TO_BQ_DATA_TYPE_MAP = {
"BOOL": {"field_type": "BOOL"},
"BOOL_ARRAY": {"field_type": "BOOL", "mode": "REPEATED"},
"DOUBLE": {"field_type": "FLOAT64"},
"DOUBLE_ARRAY": {"field_type": "FLOAT64", "mode": "REPEATED"},
"INT64": {"field_type": "INT64"},
"INT64_ARRAY": {"field_type": "INT64", "mode": "REPEATED"},
"STRING": {"field_type": "STRING"},
"STRING_ARRAY": {"field_type": "STRING", "mode": "REPEATED"},
"BYTES": {"field_type": "BYTES"},
}
def validate_id(resource_id: str) -> None:
"""Validates feature store resource ID pattern.
Args:
resource_id (str):
Required. Feature Store resource ID.
Raises:
ValueError if resource_id is invalid.
"""
if not re.compile(r"^" + RESOURCE_ID_PATTERN_REGEX + r"$").match(resource_id):
raise ValueError("Resource ID {resource_id} is not a valid resource id.")
def validate_feature_id(feature_id: str) -> None:
"""Validates feature ID.
Args:
feature_id (str):
Required. Feature resource ID.
Raises:
ValueError if feature_id is invalid.
"""
match = re.compile(r"^" + RESOURCE_ID_PATTERN_REGEX + r"$").match(feature_id)
if not match:
raise ValueError(
f"The value of feature_id may be up to 60 characters, and valid characters are `[a-z0-9_]`. "
f"The first character cannot be a number. Instead, get {feature_id}."
)
reserved_words = ["entity_id", "feature_timestamp", "arrival_timestamp"]
if feature_id.lower() in reserved_words:
raise ValueError(
"The feature_id can not be any of the reserved_words: `%s`"
% ("`, `".join(reserved_words))
)
def validate_value_type(value_type: str) -> None:
"""Validates user provided feature value_type string.
Args:
value_type (str):
Required. Immutable. Type of Feature value.
One of BOOL, BOOL_ARRAY, DOUBLE, DOUBLE_ARRAY, INT64, INT64_ARRAY, STRING, STRING_ARRAY, BYTES.
Raises:
ValueError if value_type is invalid or unspecified.
"""
if getattr(gca_feature.Feature.ValueType, value_type, None) in (
gca_feature.Feature.ValueType.VALUE_TYPE_UNSPECIFIED,
None,
):
raise ValueError(
f"Given value_type `{value_type}` invalid or unspecified. "
f"Choose one of {gca_feature.Feature.ValueType._member_names_} except `{_FEATURE_VALUE_TYPE_UNSPECIFIED}`"
)
class _FeatureConfig(NamedTuple):
"""Configuration for feature creation.
Usage:
config = _FeatureConfig(
feature_id='my_feature_id',
value_type='INT64',
description='my description',
labels={'my_key': 'my_value'},
)
"""
feature_id: str
value_type: str = _FEATURE_VALUE_TYPE_UNSPECIFIED
description: Optional[str] = None
labels: Optional[Dict[str, str]] = None
def _get_feature_id(self) -> str:
"""Validates and returns the feature_id.
Returns:
str - valid feature ID.
Raise:
ValueError if feature_id is invalid
"""
# Raises ValueError if invalid feature_id
validate_feature_id(feature_id=self.feature_id)
return self.feature_id
def _get_value_type_enum(self) -> int:
"""Validates value_type and returns the enum of the value type.
Returns:
int - valid value type enum.
"""
# Raises ValueError if invalid value_type
validate_value_type(value_type=self.value_type)
value_type_enum = getattr(gca_feature.Feature.ValueType, self.value_type)
return value_type_enum
def get_create_feature_request(
self,
) -> gca_featurestore_service.CreateFeatureRequest:
"""Return create feature request."""
gapic_feature = gca_feature.Feature(
value_type=self._get_value_type_enum(),
)
if self.labels:
utils.validate_labels(self.labels)
gapic_feature.labels = self.labels
if self.description:
gapic_feature.description = self.description
create_feature_request = gca_featurestore_service.CreateFeatureRequest(
feature=gapic_feature, feature_id=self._get_feature_id()
)
return create_feature_request
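A sketch of building a CreateFeatureRequest from the config above; the feature ID, description, and labels are illustrative:
config = _FeatureConfig(
    feature_id="user_age",
    value_type="INT64",
    description="Age of the user in years",
    labels={"team": "growth"},
)
request = config.get_create_feature_request()
# request.feature_id == "user_age" and request.feature.value_type is the
# INT64 enum; an invalid feature_id or value_type raises ValueError first.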

View File

@@ -0,0 +1,398 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import glob
import logging
import os
import pathlib
import tempfile
from typing import Optional, TYPE_CHECKING
from google.auth import credentials as auth_credentials
from google.cloud import storage
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.utils import resource_manager_utils
if TYPE_CHECKING:
import pandas
_logger = logging.getLogger(__name__)
def upload_to_gcs(
source_path: str,
destination_uri: str,
project: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Uploads local files to GCS.
After the upload, `destination_uri` will contain the same data as `source_path`.
Args:
source_path: Required. Path of the local data to copy to GCS.
destination_uri: Required. GCS URI where the data should be uploaded.
project: Optional. Google Cloud Project that contains the staging bucket.
credentials: The custom credentials to use when making API calls.
If not provided, default credentials will be used.
Raises:
RuntimeError: When source_path does not exist.
GoogleCloudError: When the upload process fails.
"""
source_path_obj = pathlib.Path(source_path)
if not source_path_obj.exists():
raise RuntimeError(f"Source path does not exist: {source_path}")
project = project or initializer.global_config.project
credentials = credentials or initializer.global_config.credentials
storage_client = storage.Client(project=project, credentials=credentials)
if source_path_obj.is_dir():
source_file_paths = glob.glob(
pathname=str(source_path_obj / "**"), recursive=True
)
for source_file_path in source_file_paths:
source_file_path_obj = pathlib.Path(source_file_path)
if source_file_path_obj.is_dir():
continue
source_file_relative_path_obj = source_file_path_obj.relative_to(
source_path_obj
)
source_file_relative_posix_path = source_file_relative_path_obj.as_posix()
destination_file_uri = (
destination_uri.rstrip("/") + "/" + source_file_relative_posix_path
)
_logger.debug(f'Uploading "{source_file_path}" to "{destination_file_uri}"')
destination_blob = storage.Blob.from_string(
destination_file_uri, client=storage_client
)
destination_blob.upload_from_filename(filename=source_file_path)
else:
source_file_path = source_path
destination_file_uri = destination_uri
_logger.debug(f'Uploading "{source_file_path}" to "{destination_file_uri}"')
destination_blob = storage.Blob.from_string(
destination_file_uri, client=storage_client
)
destination_blob.upload_from_filename(filename=source_file_path)
def stage_local_data_in_gcs(
data_path: str,
staging_gcs_dir: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> str:
"""Stages a local data in GCS.
The file copied to GCS is the name of the local file prepended with an
"aiplatform-{timestamp}-" string.
Args:
data_path: Required. Path of the local data to copy to GCS.
staging_gcs_dir:
Optional. Google Cloud Storage bucket to be used for data staging.
project: Optional. Google Cloud Project that contains the staging bucket.
location: Optional. Google Cloud location to use for the staging bucket.
credentials: The custom credentials to use when making API calls.
If not provided, default credentials will be used.
Returns:
Google Cloud Storage URI of the staged data.
Raises:
RuntimeError: When source_path does not exist.
GoogleCloudError: When the upload process fails.
"""
data_path_obj = pathlib.Path(data_path)
if not data_path_obj.exists():
raise RuntimeError(f"Local data does not exist: data_path='{data_path}'")
staging_gcs_dir = staging_gcs_dir or initializer.global_config.staging_bucket
if not staging_gcs_dir:
project = project or initializer.global_config.project
location = location or initializer.global_config.location
credentials = credentials or initializer.global_config.credentials
# Creating the bucket if it does not exist.
# Currently we only do this when staging_gcs_dir is not specified.
# The buckets that we create are regional.
# This prevents errors when a service requires a regional bucket.
# E.g. "FailedPrecondition: 400 The Cloud Storage bucket of `gs://...` is in location `us`. It must be in the same regional location as the service location `us-central1`."
# We are making the bucket name region-specific since the bucket is regional.
staging_bucket_name = project + "-vertex-staging-" + location
client = storage.Client(project=project, credentials=credentials)
staging_bucket = storage.Bucket(client=client, name=staging_bucket_name)
if not staging_bucket.exists():
_logger.info(f'Creating staging GCS bucket "{staging_bucket_name}"')
staging_bucket = client.create_bucket(
bucket_or_name=staging_bucket,
project=project,
location=location,
)
staging_gcs_dir = "gs://" + staging_bucket_name
timestamp = datetime.datetime.now().isoformat(sep="-", timespec="milliseconds")
staging_gcs_subdir = (
staging_gcs_dir.rstrip("/") + "/vertex_ai_auto_staging/" + timestamp
)
staged_data_uri = staging_gcs_subdir
if data_path_obj.is_file():
staged_data_uri = staging_gcs_subdir + "/" + data_path_obj.name
_logger.info(f'Uploading "{data_path}" to "{staged_data_uri}"')
upload_to_gcs(
source_path=data_path,
destination_uri=staged_data_uri,
project=project,
credentials=credentials,
)
return staged_data_uri
def generate_gcs_directory_for_pipeline_artifacts(
project: Optional[str] = None,
location: Optional[str] = None,
):
"""Gets or creates the GCS directory for Vertex Pipelines artifacts.
Args:
project: Optional. Google Cloud Project that contains the staging bucket.
location: Optional. Google Cloud location to use for the staging bucket.
Returns:
Google Cloud Storage URI of the directory for pipeline output artifacts.
"""
project = project or initializer.global_config.project
location = location or initializer.global_config.location
pipelines_bucket_name = project + "-vertex-pipelines-" + location
output_artifacts_gcs_dir = "gs://" + pipelines_bucket_name + "/output_artifacts/"
return output_artifacts_gcs_dir
def create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
output_artifacts_gcs_dir: Optional[str] = None,
service_account: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Gets or creates the GCS directory for Vertex Pipelines artifacts.
Args:
output_artifacts_gcs_dir: Optional. The GCS location for the pipeline outputs.
It will be generated if not specified.
service_account: Optional. Google Cloud service account that will be used
to run the pipelines. If this function creates a new bucket it will give
permission to the specified service account to access the bucket.
If not provided, the Google Cloud Compute Engine service account will be used.
project: Optional. Google Cloud Project that contains the staging bucket.
location: Optional. Google Cloud location to use for the staging bucket.
credentials: The custom credentials to use when making API calls.
If not provided, default credentials will be used.
Returns:
Google Cloud Storage URI of the directory for pipeline output artifacts.
"""
project = project or initializer.global_config.project
location = location or initializer.global_config.location
service_account = service_account or initializer.global_config.service_account
credentials = credentials or initializer.global_config.credentials
output_artifacts_gcs_dir = (
output_artifacts_gcs_dir
or generate_gcs_directory_for_pipeline_artifacts(
project=project,
location=location,
)
)
# Creating the bucket if needed
storage_client = storage.Client(
project=project,
credentials=credentials,
)
pipelines_bucket = storage.Bucket.from_string(
uri=output_artifacts_gcs_dir,
client=storage_client,
)
if not pipelines_bucket.exists():
_logger.info(
f'Creating GCS bucket for Vertex Pipelines: "{pipelines_bucket.name}"'
)
pipelines_bucket = storage_client.create_bucket(
bucket_or_name=pipelines_bucket,
project=project,
location=location,
)
# Giving the service account read and write access to the new bucket
# Workaround for error: "Failed to create pipeline job. Error: Service account `NNNNNNNN-compute@developer.gserviceaccount.com`
# does not have `[storage.objects.get, storage.objects.create]` IAM permission(s) to the bucket `xxxxxxxx-vertex-pipelines-us-central1`.
# Please either copy the files to the Google Cloud Storage bucket owned by your project, or grant the required IAM permission(s) to the service account."
if not service_account:
# Getting the project number to use in service account
project_number = resource_manager_utils.get_project_number(project)
service_account = f"{project_number}-compute@developer.gserviceaccount.com"
bucket_iam_policy = pipelines_bucket.get_iam_policy()
bucket_iam_policy.setdefault("roles/storage.objectCreator", set()).add(
f"serviceAccount:{service_account}"
)
bucket_iam_policy.setdefault("roles/storage.objectViewer", set()).add(
f"serviceAccount:{service_account}"
)
pipelines_bucket.set_iam_policy(bucket_iam_policy)
return output_artifacts_gcs_dir
def download_file_from_gcs(
source_file_uri: str,
destination_file_path: str,
project: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Downloads a GCS file to local path.
Args:
source_file_uri (str):
Required. GCS URI of the file to download.
destination_file_path (str):
Required. Local path where the data should be downloaded.
project (str):
Optional. Google Cloud Project that contains the staging bucket.
credentials (auth_credentials.Credentials):
Optional. The custom credentials to use when making API calls.
If not provided, default credentials will be used.
Raises:
RuntimeError: When destination_path does not exist.
GoogleCloudError: When the download process fails.
"""
project = project or initializer.global_config.project
credentials = credentials or initializer.global_config.credentials
storage_client = storage.Client(project=project, credentials=credentials)
source_blob = storage.Blob.from_string(source_file_uri, client=storage_client)
_logger.debug(f'Downloading "{source_file_uri}" to "{destination_file_path}"')
source_blob.download_to_filename(filename=destination_file_path)
def download_from_gcs(
source_uri: str,
destination_path: str,
project: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Downloads GCS files to local path.
Args:
source_uri (str):
Required. GCS URI(or prefix) of the file(s) to download.
destination_path (str):
Required. Local path where the data should be downloaded.
If a file path is provided, then `source_uri` must refer to a file.
If a directory path is provided, then `source_uri` must refer to a prefix.
project (str):
Optional. Google Cloud Project that contains the staging bucket.
credentials (auth_credentials.Credentials):
Optional. The custom credentials to use when making API calls.
If not provided, default credentials will be used.
Raises:
GoogleCloudError: When the download process fails.
"""
project = project or initializer.global_config.project
credentials = credentials or initializer.global_config.credentials
storage_client = storage.Client(project=project, credentials=credentials)
validate_gcs_path(source_uri)
bucket_name, prefix = source_uri.replace("gs://", "").split("/", maxsplit=1)
blobs = storage_client.list_blobs(bucket_or_name=bucket_name, prefix=prefix)
for blob in blobs:
# In SDK 2.0 remote training, we'll create some empty files.
# These files end with '/', and we'll skip them.
if not blob.name.endswith("/"):
rel_path = os.path.relpath(blob.name, prefix)
filename = (
destination_path
if rel_path == "."
else os.path.join(destination_path, rel_path)
)
os.makedirs(os.path.dirname(filename), exist_ok=True)
blob.download_to_filename(filename=filename)
def _upload_pandas_df_to_gcs(
df: "pandas.DataFrame", upload_gcs_path: str, file_format: str = "jsonl"
) -> None:
"""Uploads the provided Pandas DataFrame to a GCS bucket.
Args:
df (pandas.DataFrame):
Required. The Pandas DataFrame to upload.
upload_gcs_path (str):
Required. The GCS path to upload the data file.
file_format (str):
Optional. The format to export the DataFrame to. Currently
only JSONL is supported.
Raises:
ValueError: When a file format other than JSONL is provided.
"""
with tempfile.TemporaryDirectory() as temp_dir:
local_dataset_path = os.path.join(temp_dir, "dataset.jsonl")
if file_format == "jsonl":
df.to_json(path_or_buf=local_dataset_path, orient="records", lines=True)
else:
raise ValueError(f"Unsupported file format: {file_format}")
storage_client = storage.Client(
project=initializer.global_config.project,
credentials=initializer.global_config.credentials,
)
storage.Blob.from_string(
uri=upload_gcs_path, client=storage_client
).upload_from_filename(filename=local_dataset_path)
def validate_gcs_path(gcs_path: str) -> None:
"""Validates a GCS path.
Args:
gcs_path (str):
Required. A GCS path to validate.
Raises:
ValueError if gcs_path is invalid.
"""
if not gcs_path.startswith("gs://"):
raise ValueError(
f"Invalid GCS path {gcs_path}. Please provide a valid GCS path starting with 'gs://'"
)
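A hedged sketch of staging a local file; the path and bucket are illustrative, and either aiplatform.init() must have configured a project or staging_gcs_dir must be passed explicitly:
staged_uri = stage_local_data_in_gcs(
    data_path="./data/train.csv",
    staging_gcs_dir="gs://my-staging-bucket",
)
# e.g. "gs://my-staging-bucket/vertex_ai_auto_staging/<timestamp>/train.csv"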

View File

@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pathlib import Path
def _is_relative_to(path: str, to_path: str) -> bool:
"""Returns whether or not this path is relative to another path.
This function can be replaced by Path.is_relative_to, which is available in Python 3.9+.
Args:
path (str):
Required. The path to check whether it is relative to the other path.
to_path (str):
Required. The path to check whether the other path is relative to it.
Returns:
Whether the path is relative to another path.
"""
try:
Path(path).expanduser().resolve().relative_to(
Path(to_path).expanduser().resolve()
)
return True
except ValueError:
return False
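A quick sketch of the helper; both paths are resolved first, so "~" and relative segments are handled (the paths are illustrative):
assert _is_relative_to("~/project/src/model.py", "~/project")
assert not _is_relative_to("/tmp/other.py", "~/project")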

View File

@@ -0,0 +1,293 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import json
from typing import Any, Dict, Mapping, Optional, Union
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
import packaging.version
class PipelineRuntimeConfigBuilder(object):
"""Pipeline RuntimeConfig builder.
Constructs a RuntimeConfig spec with pipeline_root and parameter overrides.
"""
def __init__(
self,
pipeline_root: str,
schema_version: str,
parameter_types: Mapping[str, str],
parameter_values: Optional[Dict[str, Any]] = None,
input_artifacts: Optional[Dict[str, str]] = None,
failure_policy: Optional[pipeline_failure_policy.PipelineFailurePolicy] = None,
default_runtime: Optional[Dict[str, Any]] = None,
):
"""Creates a PipelineRuntimeConfigBuilder object.
Args:
pipeline_root (str):
Required. The root of the pipeline outputs.
schema_version (str):
Required. Schema version of the IR. This field determines the fields supported in current version of IR.
parameter_types (Mapping[str, str]):
Required. The mapping from pipeline parameter name to its type.
parameter_values (Dict[str, Any]):
Optional. The mapping from runtime parameter name to its value.
input_artifacts (Dict[str, str]):
Optional. The mapping from the runtime parameter name for this artifact to its resource id.
failure_policy (pipeline_failure_policy.PipelineFailurePolicy):
Optional. Represents the failure policy of a pipeline. Currently, the
default of a pipeline is that the pipeline will continue to
run until no more tasks can be executed, also known as
PIPELINE_FAILURE_POLICY_FAIL_SLOW. However, if a pipeline is
set to PIPELINE_FAILURE_POLICY_FAIL_FAST, it will stop
scheduling any new tasks when a task has failed. Any
scheduled tasks will continue to completion.
default_runtime (Dict[str, Any]):
Optional. The default runtime config for the pipeline.
"""
self._pipeline_root = pipeline_root
self._schema_version = schema_version
self._parameter_types = parameter_types
self._parameter_values = copy.deepcopy(parameter_values or {})
self._input_artifacts = copy.deepcopy(input_artifacts or {})
self._failure_policy = failure_policy
self._default_runtime = default_runtime
@classmethod
def from_job_spec_json(
cls,
job_spec: Mapping[str, Any],
) -> "PipelineRuntimeConfigBuilder":
"""Creates a PipelineRuntimeConfigBuilder object from PipelineJob json spec.
Args:
job_spec (Mapping[str, Any]):
Required. The PipelineJob spec.
Returns:
A PipelineRuntimeConfigBuilder object.
"""
runtime_config_spec = job_spec["runtimeConfig"]
parameter_input_definitions = (
job_spec["pipelineSpec"]["root"]
.get("inputDefinitions", {})
.get("parameters", {})
)
schema_version = job_spec["pipelineSpec"]["schemaVersion"]
# 'type' is deprecated in IR and changed to 'parameterType'.
parameter_types = {
k: v.get("parameterType") or v.get("type")
for k, v in parameter_input_definitions.items()
}
pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
parameter_values = _parse_runtime_parameters(runtime_config_spec)
failure_policy = runtime_config_spec.get("failurePolicy")
return cls(
pipeline_root=pipeline_root,
schema_version=schema_version,
parameter_types=parameter_types,
parameter_values=parameter_values,
failure_policy=failure_policy,
)
def update_pipeline_root(self, pipeline_root: Optional[str]) -> None:
"""Updates pipeline_root value.
Args:
pipeline_root (str):
Optional. The root of the pipeline outputs.
"""
if pipeline_root:
self._pipeline_root = pipeline_root
def update_runtime_parameters(
self, parameter_values: Optional[Mapping[str, Any]] = None
) -> None:
"""Merges runtime parameter values.
Args:
parameter_values (Mapping[str, Any]):
Optional. The mapping from runtime parameter names to their values.
"""
if parameter_values:
parameters = dict(parameter_values)
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
"2.0.0"
):
for k, v in parameter_values.items():
if isinstance(v, (dict, list, bool)):
parameters[k] = json.dumps(v)
self._parameter_values.update(parameters)
def update_input_artifacts(
self, input_artifacts: Optional[Mapping[str, str]]
) -> None:
"""Merges runtime input artifacts.
Args:
input_artifacts (Mapping[str, str]):
Optional. The mapping from the runtime parameter name for this artifact to its resource id.
"""
if input_artifacts:
self._input_artifacts.update(input_artifacts)
def update_failure_policy(self, failure_policy: Optional[str] = None) -> None:
"""Merges runtime failure policy.
Args:
failure_policy (str):
Optional. The failure policy - "slow" or "fast".
Raises:
ValueError: if failure_policy is not valid.
"""
if failure_policy:
if failure_policy in _FAILURE_POLICY_TO_ENUM_VALUE:
self._failure_policy = _FAILURE_POLICY_TO_ENUM_VALUE[failure_policy]
else:
raise ValueError(
f'failure_policy should be either "slow" or "fast", but got: "{failure_policy}".'
)
def update_default_runtime(self, default_runtime: Dict[str, Any]) -> None:
"""Merges default runtime.
Args:
default_runtime (Dict[str, Any]):
default runtime config for the pipeline.
"""
if default_runtime:
self._default_runtime = default_runtime
def build(self) -> Dict[str, Any]:
"""Build a RuntimeConfig proto.
Raises:
ValueError: if the pipeline root is not specified.
"""
if not self._pipeline_root:
raise ValueError(
"Pipeline root must be specified, either during "
"compile time, or when calling the service."
)
if packaging.version.parse(self._schema_version) > packaging.version.parse(
"2.0.0"
):
parameter_values_key = "parameterValues"
else:
parameter_values_key = "parameters"
runtime_config = {
"gcsOutputDirectory": self._pipeline_root,
parameter_values_key: {
k: self._get_vertex_value(k, v)
for k, v in self._parameter_values.items()
if v is not None
},
"inputArtifacts": {
k: {"artifactId": v} for k, v in self._input_artifacts.items()
},
}
if self._default_runtime:
# Only v1beta1 supports DefaultRuntime. For other cases, the field is
# None and we clear the defaultRuntime field.
runtime_config["defaultRuntime"] = self._default_runtime
if self._failure_policy:
runtime_config["failurePolicy"] = self._failure_policy
return runtime_config
def _get_vertex_value(
self, name: str, value: Union[int, float, str, bool, list, dict]
) -> Union[int, float, str, bool, list, dict]:
"""Converts primitive values into Vertex pipeline Value proto message.
Args:
name (str):
Required. The name of the pipeline parameter.
value (Union[int, float, str, bool, list, dict]):
Required. The value of the pipeline parameter.
Returns:
A dictionary representing the Vertex pipeline Value proto message.
Raises:
ValueError: if the parameter name is not found in pipeline root
inputs, or the value is None.
"""
if value is None:
raise ValueError("None values should be filtered out.")
if name not in self._parameter_types:
raise ValueError(
"The pipeline parameter {} is not found in the "
"pipeline job input definitions.".format(name)
)
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
"2.0.0"
):
result = {}
if self._parameter_types[name] == "INT":
result["intValue"] = value
elif self._parameter_types[name] == "DOUBLE":
result["doubleValue"] = value
elif self._parameter_types[name] == "STRING":
result["stringValue"] = value
else:
raise TypeError("Got unknown type of value: {}".format(value))
return result
else:
return value
def _parse_runtime_parameters(
runtime_config_spec: Mapping[str, Any]
) -> Optional[Dict[str, Any]]:
"""Extracts runtime parameters from runtime config json spec.
Raises:
TypeError: if the parameter type is not one of 'INT', 'DOUBLE', 'STRING'.
"""
# 'parameters' are deprecated in IR and changed to 'parameterValues'.
if runtime_config_spec.get("parameterValues") is not None:
return runtime_config_spec.get("parameterValues")
if runtime_config_spec.get("parameters") is not None:
result = {}
for name, value in runtime_config_spec.get("parameters").items():
if "intValue" in value:
result[name] = int(value["intValue"])
elif "doubleValue" in value:
result[name] = float(value["doubleValue"])
elif "stringValue" in value:
result[name] = value["stringValue"]
else:
raise TypeError("Got unknown type of value: {}".format(value))
return result
_FAILURE_POLICY_TO_ENUM_VALUE = {
"slow": pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_SLOW,
"fast": pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_FAST,
None: pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_UNSPECIFIED,
}
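A sketch of driving the builder directly; the job spec below contains only the fields the builder reads, and the values are illustrative:
job_spec = {
    "runtimeConfig": {"gcsOutputDirectory": "gs://my-bucket/pipeline-root"},
    "pipelineSpec": {
        "schemaVersion": "2.1.0",
        "root": {
            "inputDefinitions": {
                "parameters": {"learning_rate": {"parameterType": "NUMBER_DOUBLE"}}
            }
        },
    },
}
builder = PipelineRuntimeConfigBuilder.from_job_spec_json(job_spec)
builder.update_runtime_parameters({"learning_rate": 0.01})
runtime_config = builder.build()
# {"gcsOutputDirectory": "gs://my-bucket/pipeline-root",
#  "parameterValues": {"learning_rate": 0.01}, "inputArtifacts": {}}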

View File

@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import shutil
import inspect
import logging
import os
from pathlib import Path
import re
from typing import Any, Optional, Sequence, Tuple, Type
from google.cloud import storage
from google.cloud.aiplatform.constants import prediction
from google.cloud.aiplatform.utils import path_utils
_logger = logging.getLogger(__name__)
REGISTRY_REGEX = re.compile(r"^([\w\-]+\-docker\.pkg\.dev|([\w]+\.|)gcr\.io)")
GCS_URI_PREFIX = "gs://"
def inspect_source_from_class(
custom_class: Type[Any],
src_dir: str,
) -> Tuple[str, str]:
"""Inspects the source file from a custom class and returns its import path.
Args:
custom_class (Type[Any]):
Required. The custom class whose source file needs to be inspected.
src_dir (str):
Required. The path to the local directory including all needed files.
The source file of the custom class must be in this directory.
Returns:
(import_from, class_name): the source file path in python import format
and the custom class name.
Raises:
ValueError: If the source file of the custom class is not in the source
directory.
"""
src_dir_abs_path = Path(src_dir).expanduser().resolve()
custom_class_name = custom_class.__name__
custom_class_path = Path(inspect.getsourcefile(custom_class)).resolve()
if not path_utils._is_relative_to(custom_class_path, src_dir_abs_path):
raise ValueError(
f'The file implementing "{custom_class_name}" must be in "{src_dir}".'
)
custom_class_import_path = custom_class_path.relative_to(src_dir_abs_path)
custom_class_import_path = custom_class_import_path.with_name(
custom_class_import_path.stem
)
custom_class_import = custom_class_import_path.as_posix().replace("/", ".")
return custom_class_import, custom_class_name
def is_registry_uri(image_uri: str) -> bool:
"""Checks whether the image uri is in container registry or artifact registry.
Args:
image_uri (str):
The image uri to check if it is in container registry or artifact registry.
Returns:
True if the image uri is in container registry or artifact registry.
"""
return REGISTRY_REGEX.match(image_uri) is not None
def get_prediction_aip_http_port(
serving_container_ports: Optional[Sequence[int]] = None,
) -> int:
"""Gets the used prediction container port from serving container ports.
If containerSpec.ports is specified during Model or LocalModel creation time, retrieve
the first entry in this field. Otherwise use the default value of 8080. The environment
variable AIP_HTTP_PORT will be set to this value.
See https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements
for more details.
Args:
serving_container_ports (Sequence[int]):
Optional. Declaration of ports that are exposed by the container. This field is
primarily informational; it gives Vertex AI information about the
network connections the container uses. Whether or not a port is listed here
has no impact on whether the port is actually exposed; any port listening on
the default "0.0.0.0" address inside a container will be accessible from
the network.
Returns:
The first element of serving_container_ports, or the default HTTP port
if no ports are provided.
"""
return (
serving_container_ports[0]
if serving_container_ports is not None and len(serving_container_ports) > 0
else prediction.DEFAULT_AIP_HTTP_PORT
)
def download_model_artifacts(artifact_uri: str) -> None:
"""Prepares model artifacts in the current working directory.
If artifact_uri is a GCS uri, the model artifacts will be downloaded to the current
working directory.
If artifact_uri is a local directory, the model artifacts will be copied to the current
working directory.
Args:
artifact_uri (str):
Required. The artifact uri that includes model artifacts.
"""
if artifact_uri.startswith(GCS_URI_PREFIX):
matches = re.match(f"{GCS_URI_PREFIX}(.*?)/(.*)", artifact_uri)
bucket_name, prefix = matches.groups()
gcs_client = storage.Client()
blobs = gcs_client.list_blobs(bucket_name, prefix=prefix)
for blob in blobs:
name_without_prefix = blob.name[len(prefix) :]
name_without_prefix = (
name_without_prefix[1:]
if name_without_prefix.startswith("/")
else name_without_prefix
)
file_split = name_without_prefix.split("/")
directory = "/".join(file_split[0:-1])
Path(directory).mkdir(parents=True, exist_ok=True)
if name_without_prefix and not name_without_prefix.endswith("/"):
blob.download_to_filename(name_without_prefix)
else:
# Copy files to the current working directory.
shutil.copytree(artifact_uri, ".", dirs_exist_ok=True)
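A few quick checks illustrating the smaller helpers above; the image URIs and ports are illustrative:
assert is_registry_uri("us-docker.pkg.dev/my-project/my-repo/my-image")
assert not is_registry_uri("docker.io/library/python:3.10")
assert get_prediction_aip_http_port([7080, 7081]) == 7080
assert get_prediction_aip_http_port() == prediction.DEFAULT_AIP_HTTP_PORT  # 8080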

View File

@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional
from google.auth import credentials as auth_credentials
from google.cloud import resourcemanager
from google.cloud.aiplatform import initializer
def get_project_id(
project_number: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> str:
"""Gets project ID given the project number
Args:
project_number (str):
Required. The automatically generated unique identifier for your GCP project.
credentials: The custom credentials to use when making API calls.
Optional. If not provided, default credentials will be used.
Returns:
str - The unique string used to differentiate your GCP project from all others in Google Cloud.
"""
credentials = credentials or initializer.global_config.credentials
projects_client = resourcemanager.ProjectsClient(credentials=credentials)
project = projects_client.get_project(name=f"projects/{project_number}")
return project.project_id
def get_project_number(
project_id: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> str:
"""Gets project ID given the project number
Args:
project_id (str):
Required. Google Cloud project unique ID.
credentials: The custom credentials to use when making API calls.
Optional. If not provided, default credentials will be used.
Returns:
str - The automatically generated unique numerical identifier for your GCP project.
"""
credentials = credentials or initializer.global_config.credentials
projects_client = resourcemanager.ProjectsClient(credentials=credentials)
project = projects_client.get_project(name=f"projects/{project_id}")
project_number = project.name.split("/", 1)[1]
return project_number
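# --- Usage sketch (not part of the original module) ---
# Illustrates the round trip between a project number and a project ID.
# The project number below is a placeholder; a real call needs credentials
# with resourcemanager.projects.get permission on the project.
if __name__ == "__main__":
    project_id = get_project_id(project_number="123456789012")
    assert get_project_number(project_id=project_id) == "123456789012"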

View File

@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from google.cloud.aiplatform import base
def make_gcp_resource_rest_url(resource: base.VertexAiResourceNoun) -> str:
"""Helper function to format the GCP resource url for google.X metadata schemas.
Args:
resource (base.VertexAiResourceNoun): Required. A Vertex resource instance.
Returns:
The formatted url of resource.
"""
try:
resource_name = resource.versioned_resource_name
except AttributeError:
resource_name = resource.resource_name
version = resource.api_client._default_version
api_uri = resource.api_client.api_endpoint
return f"https://{api_uri}/{version}/{resource_name}"

View File

@@ -0,0 +1,248 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import os
import pathlib
import shutil
import subprocess
import sys
import tempfile
from typing import Optional, Sequence, Callable
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
_LOGGER = base.Logger(__name__)
def _get_python_executable() -> str:
"""Returns Python executable.
Returns:
Python executable to use for setuptools packaging.
Raises:
EnvironmentError: If Python executable is not found.
"""
python_executable = sys.executable
if not python_executable:
raise EnvironmentError("Cannot find Python executable for packaging.")
return python_executable
class _TrainingScriptPythonPackager:
"""Converts a Python script into Python package suitable for aiplatform
training.
Copies the script to specified location.
Class Attributes:
_TRAINER_FOLDER: Constant folder name to build package.
_ROOT_MODULE: Constant root name of module.
_TEST_MODULE_NAME: Constant name of module that will store script.
_SETUP_PY_VERSION: Constant version of this created python package.
_SETUP_PY_TEMPLATE: Constant template used to generate setup.py file.
_SETUP_PY_SOURCE_DISTRIBUTION_CMD:
Constant command to generate the source distribution package.
Attributes:
script_path: local path of script or folder to package
requirements: list of Python dependencies to add to package
Usage:
    packager = _TrainingScriptPythonPackager('my_script.py', ['pandas', 'pytorch'])
    gcs_path = packager.package_and_copy_to_gcs(
        gcs_staging_dir='my-bucket',
        project='my-project')
    module_name = packager.module_name
    After installation, the package can be executed as:
    python -m aiplatform_custom_trainer_script.task
"""
_TRAINER_FOLDER = "trainer"
_ROOT_MODULE = "aiplatform_custom_trainer_script"
_SETUP_PY_VERSION = "0.1"
_SETUP_PY_TEMPLATE = """from setuptools import find_packages
from setuptools import setup
setup(
name='{name}',
version='{version}',
packages=find_packages(),
install_requires=({requirements}),
include_package_data=True,
description='My training application.'
)"""
_SETUP_PY_SOURCE_DISTRIBUTION_CMD = "setup.py sdist --formats=gztar"
def __init__(
self,
script_path: str,
task_module_name: str = "task",
requirements: Optional[Sequence[str]] = None,
):
"""Initializes packager.
Args:
script_path (str): Required. Local path to script.
            task_module_name (str):
                Optional. Name of the module in the package that will contain the script.
            requirements (Sequence[str]):
                List of Python package dependencies of the script.
"""
self.script_path = script_path
self.task_module_name = task_module_name
self.requirements = requirements or []
@property
def module_name(self) -> str:
        # Module name that can be executed during training, i.e. python -m
return f"{self._ROOT_MODULE}.{self.task_module_name}"
def make_package(self, package_directory: str) -> str:
"""Converts script into a Python package suitable for python module
execution.
Args:
package_directory (str): Directory to build package in.
Returns:
source_distribution_path (str): Path to built package.
Raises:
            RuntimeError: If package creation fails.
"""
        # The root folder to build the package in
package_path = pathlib.Path(package_directory)
# Root directory of the package
trainer_root_path = package_path / self._TRAINER_FOLDER
# The root module of the python package
trainer_path = trainer_root_path / self._ROOT_MODULE
# __init__.py path in root module
init_path = trainer_path / "__init__.py"
# The path to setup.py in the package.
setup_py_path = trainer_root_path / "setup.py"
# The path to the generated source distribution.
source_distribution_path = (
trainer_root_path
/ "dist"
/ f"{self._ROOT_MODULE}-{self._SETUP_PY_VERSION}.tar.gz"
)
trainer_root_path.mkdir()
trainer_path.mkdir()
# Make empty __init__.py
with init_path.open("w"):
pass
# Format the setup.py file.
setup_py_output = self._SETUP_PY_TEMPLATE.format(
name=self._ROOT_MODULE,
requirements=",".join(f'"{r}"' for r in self.requirements),
version=self._SETUP_PY_VERSION,
)
# Write setup.py
with setup_py_path.open("w") as fp:
fp.write(setup_py_output)
if os.path.isdir(self.script_path):
# Remove destination path if it already exists
shutil.rmtree(trainer_path)
# Copy folder recursively
shutil.copytree(src=self.script_path, dst=trainer_path)
else:
# The module that will contain the script
script_out_path = trainer_path / f"{self.task_module_name}.py"
# Copy script as module of python package.
shutil.copy(self.script_path, script_out_path)
# Run setup.py to create the source distribution.
setup_cmd = [
_get_python_executable()
] + self._SETUP_PY_SOURCE_DISTRIBUTION_CMD.split()
p = subprocess.Popen(
args=setup_cmd,
cwd=trainer_root_path,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
output, error = p.communicate()
# Raise informative error if packaging fails.
if p.returncode != 0:
raise RuntimeError(
"Packaging of training script failed with code %d\n%s \n%s"
% (p.returncode, output.decode(), error.decode())
)
return str(source_distribution_path)
def package_and_copy(self, copy_method: Callable[[str], str]) -> str:
"""Packages the script and executes copy with given copy_method.
Args:
            copy_method (Callable[[str], str]):
                Takes a string path, copies it to a desired location, and returns
                the output path location.
        Returns:
            output_path (str): Location of the copied package.
"""
with tempfile.TemporaryDirectory() as tmpdirname:
source_distribution_path = self.make_package(tmpdirname)
output_location = copy_method(source_distribution_path)
_LOGGER.info("Training script copied to:\n%s." % output_location)
return output_location
def package_and_copy_to_gcs(
self,
gcs_staging_dir: str,
        project: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> str:
"""Packages script in Python package and copies package to GCS bucket.
        Args:
gcs_staging_dir (str): Required. GCS Staging directory.
project (str): Required. Project where GCS Staging bucket is located.
credentials (auth_credentials.Credentials):
Optional credentials used with GCS client.
Returns:
GCS location of Python package.
"""
copy_method = functools.partial(
utils._timestamped_copy_to_gcs,
gcs_dir=gcs_staging_dir,
project=project,
credentials=credentials,
)
return self.package_and_copy(copy_method=copy_method)
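# --- Usage sketch (not part of the original module) ---
# Packages a local training script and stages it in GCS, mirroring the class
# docstring above. The script path, bucket, and project are placeholders; the
# staging bucket must exist and be writable with the active credentials.
if __name__ == "__main__":
    packager = _TrainingScriptPythonPackager(
        script_path="train.py", requirements=["pandas", "scikit-learn"]
    )
    package_gcs_uri = packager.package_and_copy_to_gcs(
        gcs_staging_dir="my-bucket", project="my-project"
    )
    # The staged package can then be executed in training as:
    #   python -m aiplatform_custom_trainer_script.task
    print(package_gcs_uri, packager.module_name)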

View File

@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Sequence, Dict
from google.cloud.aiplatform_v1beta1.services.tensorboard_service.client import (
TensorboardServiceClient,
)
_SERVING_DOMAIN = "tensorboard.googleusercontent.com"
def _parse_experiment_name(experiment_name: str) -> Dict[str, str]:
"""Parses an experiment_name into its component segments.
Args:
experiment_name: Resource name of the TensorboardExperiment. E.g.
"projects/123/locations/asia-east1/tensorboards/456/experiments/exp1"
Returns:
Components of the experiment name.
Raises:
ValueError: If the experiment_name is invalid.
"""
matched = TensorboardServiceClient.parse_tensorboard_experiment_path(
experiment_name
)
if not matched:
raise ValueError(f"Invalid experiment name: {experiment_name}.")
return matched
def get_experiment_url(experiment_name: str) -> str:
"""Get URL for comparing experiments.
Args:
experiment_name: Resource name of the TensorboardExperiment. E.g.
"projects/123/locations/asia-east1/tensorboards/456/experiments/exp1"
Returns:
URL for the tensorboard web app.
"""
location = _parse_experiment_name(experiment_name)["location"]
name_for_url = experiment_name.replace("/", "+")
return f"https://{location}.{_SERVING_DOMAIN}/experiment/{name_for_url}"
def get_experiments_compare_url(experiment_names: Sequence[str]) -> str:
"""Get URL for comparing experiments.
Args:
        experiment_names: Resource names of the TensorboardExperiments that need
            to be compared.
Returns:
URL for the tensorboard web app.
"""
if len(experiment_names) < 2:
raise ValueError("At least two experiment_names are required.")
locations = {
_parse_experiment_name(experiment_name)["location"]
for experiment_name in experiment_names
}
if len(locations) != 1:
raise ValueError(
f"Got experiments from different locations: {', '.join(locations)}."
)
location = locations.pop()
experiment_url_segments = []
for idx, experiment_name in enumerate(experiment_names):
name_segments = _parse_experiment_name(experiment_name)
experiment_url_segments.append(
"{cnt}-{experiment}:{project}+{location}+{tensorboard}+{experiment}".format(
cnt=idx + 1, **name_segments
)
)
encoded_names = ",".join(experiment_url_segments)
return f"https://{location}.{_SERVING_DOMAIN}/compare/{encoded_names}"

View File

@@ -0,0 +1,314 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import NamedTuple, Optional, Dict, Union, List, Literal
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import (
accelerator_type as gca_accelerator_type_compat,
)
# `_SPEC_ORDERS` contains the worker pool spec type and its order in the `_WorkerPoolSpec`.
# The `server_spec` supports either reduction server or parameter server, each
# with different configuration for its `container_spec`. This mapping will be
# used during configuration of `container_spec` for all worker pool specs.
_SPEC_ORDERS = {
"chief_spec": 0,
"worker_spec": 1,
"server_spec": 2,
"evaluator_spec": 3,
}
class _WorkerPoolSpec(NamedTuple):
"""Specification container for Worker Pool specs used for distributed training.
Usage:
spec = _WorkerPoolSpec(
replica_count=10,
machine_type='n1-standard-4',
accelerator_count=2,
accelerator_type='NVIDIA_TESLA_K80',
boot_disk_type='pd-ssd',
boot_disk_size_gb=100,
reservation_affinity_type=reservation_affinity_type,
reservation_affinity_key=reservation_affinity_key,
reservation_affinity_values=reservation_affinity_values,
)
Note that container and python package specs are not stored with this spec.
"""
replica_count: int = 0
machine_type: str = "n1-standard-4"
accelerator_count: int = 0
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED"
boot_disk_type: str = "pd-ssd"
boot_disk_size_gb: int = 100
tpu_topology: Optional[str] = None
reservation_affinity_type: Optional[
Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"]
] = None
reservation_affinity_key: Optional[str] = None
reservation_affinity_values: Optional[List[str]] = None
def _get_accelerator_type(self) -> Optional[str]:
"""Validates accelerator_type and returns the name of the accelerator.
        Returns:
            None if no accelerator was specified, otherwise the validated
            accelerator name.
        Raises:
            ValueError: If accelerator_type is invalid.
"""
# Raises ValueError if invalid accelerator_type
utils.validate_accelerator_type(self.accelerator_type)
accelerator_enum = getattr(
gca_accelerator_type_compat.AcceleratorType, self.accelerator_type
)
if (
accelerator_enum
!= gca_accelerator_type_compat.AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED
):
return self.accelerator_type
@property
def spec_dict(self) -> Dict[str, Union[int, str, Dict[str, Union[int, str]]]]:
"""Return specification as a Dict."""
spec = {
"machine_spec": {"machine_type": self.machine_type},
"replica_count": self.replica_count,
"disk_spec": {
"boot_disk_type": self.boot_disk_type,
"boot_disk_size_gb": self.boot_disk_size_gb,
},
}
accelerator_type = self._get_accelerator_type()
if accelerator_type and self.accelerator_count:
spec["machine_spec"]["accelerator_type"] = accelerator_type
spec["machine_spec"]["accelerator_count"] = self.accelerator_count
if self.tpu_topology:
spec["machine_spec"]["tpu_topology"] = self.tpu_topology
if self.reservation_affinity_type:
spec["machine_spec"]["reservation_affinity"] = {
"reservation_affinity_type": self.reservation_affinity_type,
}
if self.reservation_affinity_type == "SPECIFIC_RESERVATION":
spec["machine_spec"]["reservation_affinity"][
"key"
] = self.reservation_affinity_key
spec["machine_spec"]["reservation_affinity"][
"values"
] = self.reservation_affinity_values
return spec
@property
def is_empty(self) -> bool:
"""Returns True is replica_count > 0 False otherwise."""
return self.replica_count <= 0
class _DistributedTrainingSpec(NamedTuple):
"""Configuration for distributed training worker pool specs.
Vertex AI Training expects configuration in this order:
[
chief spec, # can only have one replica
worker spec,
parameter server spec,
evaluator spec
]
Usage:
dist_training_spec = _DistributedTrainingSpec(
chief_spec = _WorkerPoolSpec(
replica_count=1,
machine_type='n1-standard-4',
accelerator_count=2,
accelerator_type='NVIDIA_TESLA_K80',
boot_disk_type='pd-ssd',
boot_disk_size_gb=100,
),
worker_spec = _WorkerPoolSpec(
replica_count=10,
machine_type='n1-standard-4',
accelerator_count=2,
accelerator_type='NVIDIA_TESLA_K80',
boot_disk_type='pd-ssd',
boot_disk_size_gb=100,
),
)
"""
chief_spec: _WorkerPoolSpec = _WorkerPoolSpec()
worker_spec: _WorkerPoolSpec = _WorkerPoolSpec()
server_spec: _WorkerPoolSpec = _WorkerPoolSpec()
evaluator_spec: _WorkerPoolSpec = _WorkerPoolSpec()
@property
def pool_specs(
self,
) -> List[Dict[str, Union[int, str, Dict[str, Union[int, str]]]]]:
"""Return each pools spec in correct order for Vertex AI as a list of
dicts.
Also removes specs if they are empty but leaves specs in if there unusual
specifications to not break the ordering in Vertex AI Training.
ie. 0 chief replica, 10 worker replica, 3 ps replica
Returns:
Order list of worker pool specs suitable for Vertex AI Training.
"""
if self.chief_spec.replica_count > 1:
raise ValueError("Chief spec replica count cannot be greater than 1.")
spec_order = [
self.chief_spec,
self.worker_spec,
self.server_spec,
self.evaluator_spec,
]
specs = [{} if s.is_empty else s.spec_dict for s in spec_order]
for i in reversed(range(len(spec_order))):
if spec_order[i].is_empty:
specs.pop()
else:
break
return specs
@classmethod
def chief_worker_pool(
cls,
replica_count: int = 0,
machine_type: str = "n1-standard-4",
accelerator_count: int = 0,
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
boot_disk_type: str = "pd-ssd",
boot_disk_size_gb: int = 100,
reduction_server_replica_count: int = 0,
        reduction_server_machine_type: Optional[str] = None,
        tpu_topology: Optional[str] = None,
reservation_affinity_type: Optional[
Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"]
] = None,
reservation_affinity_key: Optional[str] = None,
reservation_affinity_values: Optional[List[str]] = None,
) -> "_DistributedTrainingSpec":
"""Parametrizes Config to support only chief with worker replicas.
For replica is assigned to chief and the remainder to workers. All spec have the
same machine type, accelerator count, and accelerator type.
Args:
replica_count (int):
The number of worker replicas. Assigns 1 chief replica and
replica_count - 1 worker replicas.
machine_type (str):
The type of machine to use for training.
accelerator_type (str):
Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
NVIDIA_TESLA_T4
accelerator_count (int):
The number of accelerators to attach to a worker replica.
boot_disk_type (str):
Type of the boot disk (default is `pd-ssd`).
Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
`pd-standard` (Persistent Disk Hard Disk Drive).
boot_disk_size_gb (int):
Size in GB of the boot disk (default is 100GB).
                Boot disk size must be within the range of [100, 64000].
reduction_server_replica_count (int):
The number of reduction server replicas, default is 0.
reduction_server_machine_type (str):
The type of machine to use for reduction server, default is `n1-highcpu-16`.
tpu_topology (str):
TPU topology for the TPU type. This field is
required for the TPU v5 versions. This field is only passed to the
chief replica as TPU jobs only allow 1 replica.
reservation_affinity_type (str):
Optional. The type of reservation affinity. One of:
* "NO_RESERVATION" : No reservation is used.
* "ANY_RESERVATION" : Any reservation that matches machine spec
can be used.
* "SPECIFIC_RESERVATION" : A specific reservation must be use
used. See reservation_affinity_key and
reservation_affinity_values for how to specify the reservation.
reservation_affinity_key (str):
Optional. Corresponds to the label key of a reservation resource.
To target a SPECIFIC_RESERVATION by name, use
`compute.googleapis.com/reservation-name` as the key
and specify the name of your reservation as its value.
reservation_affinity_values (List[str]):
Optional. Corresponds to the label values of a reservation resource.
This must be the full resource name of the reservation.
Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}'
Returns:
            _DistributedTrainingSpec representing one chief and n workers, all of
            the same type, optionally with reduction server(s). If replica_count <= 0
            then an empty spec is returned.
"""
if replica_count <= 0:
return cls()
chief_spec = _WorkerPoolSpec(
replica_count=1,
machine_type=machine_type,
accelerator_count=accelerator_count,
accelerator_type=accelerator_type,
boot_disk_type=boot_disk_type,
boot_disk_size_gb=boot_disk_size_gb,
tpu_topology=tpu_topology,
reservation_affinity_type=reservation_affinity_type,
reservation_affinity_key=reservation_affinity_key,
reservation_affinity_values=reservation_affinity_values,
)
worker_spec = _WorkerPoolSpec(
replica_count=replica_count - 1,
machine_type=machine_type,
accelerator_count=accelerator_count,
accelerator_type=accelerator_type,
boot_disk_type=boot_disk_type,
boot_disk_size_gb=boot_disk_size_gb,
reservation_affinity_type=reservation_affinity_type,
reservation_affinity_key=reservation_affinity_key,
reservation_affinity_values=reservation_affinity_values,
)
reduction_server_spec = _WorkerPoolSpec(
replica_count=reduction_server_replica_count,
machine_type=reduction_server_machine_type,
reservation_affinity_type=reservation_affinity_type,
reservation_affinity_key=reservation_affinity_key,
reservation_affinity_values=reservation_affinity_values,
)
return cls(
chief_spec=chief_spec,
worker_spec=worker_spec,
server_spec=reduction_server_spec,
)
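# --- Usage sketch (not part of the original module) ---
# Shows how a chief-plus-workers configuration expands into the ordered list of
# worker pool spec dicts that Vertex AI Training expects. The machine and
# accelerator values are placeholders; no API calls are made.
if __name__ == "__main__":
    dist_spec = _DistributedTrainingSpec.chief_worker_pool(
        replica_count=4,
        machine_type="n1-standard-8",
        accelerator_count=1,
        accelerator_type="NVIDIA_TESLA_T4",
    )
    # Expect two entries: one chief replica followed by three worker replicas.
    for pool_spec in dist_spec.pool_specs:
        print(pool_spec)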

View File

@@ -0,0 +1,145 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from types import ModuleType
from typing import Any, Dict, Optional
from urllib import request
from google.auth import credentials as auth_credentials
from google.auth import transport
from google.cloud import storage
from google.cloud.aiplatform.constants import pipeline as pipeline_constants
# Pattern for an Artifact Registry URL.
_VALID_AR_URL = pipeline_constants._VALID_AR_URL
# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL
def load_yaml(
path: str,
project: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
"""Loads data from a YAML document.
Args:
path (str):
Required. The path of the YAML document. It can be a local path, a
Google Cloud Storage URI, an Artifact Registry URI, or an HTTPS URI.
project (str):
Optional. Project to initiate the Storage client with.
credentials (auth_credentials.Credentials):
Optional. Credentials to use with Storage Client.
Returns:
A Dict object representing the YAML document.
"""
if path.startswith("gs://"):
return _load_yaml_from_gs_uri(path, project, credentials)
elif path.startswith("http://") or path.startswith("https://"):
if _VALID_AR_URL.match(path):
return _load_yaml_from_https_uri(path, credentials)
elif _VALID_HTTPS_URL.match(path):
return _load_yaml_from_https_uri(path)
else:
raise ValueError(
"Invalid HTTPS URI. If not using Artifact Registry, please "
"ensure the URI ends with .json, .yaml, or .yml."
)
else:
return _load_yaml_from_local_file(path)
def _maybe_import_yaml() -> ModuleType:
"""Tries to import the PyYAML module."""
try:
import yaml
except ImportError:
raise ImportError(
"PyYAML is not installed and is required to parse PipelineJob or "
'PipelineSpec files. Please install the SDK using "pip install '
'google-cloud-aiplatform[pipelines]"'
)
return yaml
def _load_yaml_from_gs_uri(
uri: str,
project: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
"""Loads data from a YAML document referenced by a GCS URI.
Args:
        uri (str):
            Required. GCS URI of the YAML document.
project (str):
Optional. Project to initiate the Storage client with.
credentials (auth_credentials.Credentials):
Optional. Credentials to use with Storage Client.
Returns:
A Dict object representing the YAML document.
"""
yaml = _maybe_import_yaml()
storage_client = storage.Client(project=project, credentials=credentials)
blob = storage.Blob.from_string(uri, storage_client)
return yaml.safe_load(blob.download_as_bytes())
def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]:
"""Loads data from a YAML local file.
Args:
file_path (str):
Required. The local file path of the YAML document.
Returns:
A Dict object representing the YAML document.
"""
yaml = _maybe_import_yaml()
with open(file_path) as f:
return yaml.safe_load(f)
def _load_yaml_from_https_uri(
uri: str,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
"""Loads data from a YAML document referenced by a Artifact Registry URI.
Args:
uri (str):
Required. Artifact Registry URI for YAML document.
credentials (auth_credentials.Credentials):
Optional. Credentials to use with Artifact Registry.
Returns:
A Dict object representing the YAML document.
"""
yaml = _maybe_import_yaml()
req = request.Request(uri)
if credentials:
if not credentials.valid:
credentials.refresh(transport.requests.Request())
if credentials.token:
req.add_header("Authorization", "Bearer " + credentials.token)
response = request.urlopen(req)
return yaml.safe_load(response.read().decode("utf-8"))
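# --- Usage sketch (not part of the original module) ---
# Loads a pipeline spec from a local YAML file and from a GCS URI. Both paths
# are placeholders; PyYAML must be installed, and the GCS object must be
# readable with the active credentials.
if __name__ == "__main__":
    local_spec = load_yaml("pipeline_spec.yaml")
    remote_spec = load_yaml("gs://example-bucket/pipelines/pipeline_spec.yaml")
    print(sorted(local_spec), sorted(remote_spec))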