structure saas with tools
@@ -0,0 +1,63 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from google.cloud.aiplatform import explain
|
||||
from google.cloud.aiplatform.compat.types import (
|
||||
endpoint as gca_endpoint_compat,
|
||||
)
|
||||
|
||||
|
||||
def create_and_validate_explanation_spec(
|
||||
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
|
||||
explanation_parameters: Optional[explain.ExplanationParameters] = None,
|
||||
) -> Optional[explain.ExplanationSpec]:
|
||||
"""Validates the parameters needed to create explanation_spec and creates it.
|
||||
|
||||
Args:
|
||||
explanation_metadata (explain.ExplanationMetadata):
|
||||
Optional. Metadata describing the Model's input and output for
|
||||
explanation. `explanation_metadata` is optional while
|
||||
`explanation_parameters` must be specified when used.
|
||||
For more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
|
||||
explanation_parameters (explain.ExplanationParameters):
|
||||
Optional. Parameters to configure explaining for Model's
|
||||
predictions.
|
||||
For more details, see `Ref docs <http://tinyurl.com/1an4zake>`
|
||||
|
||||
Returns:
|
||||
explanation_spec: Specification of Model explanation.
|
||||
|
||||
Raises:
|
||||
ValueError: If `explanation_metadata` is given, but
|
||||
`explanation_parameters` is omitted. `explanation_metadata` is optional
|
||||
while `explanation_parameters` must be specified when used.
|
||||
"""
|
||||
if bool(explanation_metadata) and not bool(explanation_parameters):
|
||||
raise ValueError(
|
||||
"To get model explanation, `explanation_parameters` must be specified."
|
||||
)
|
||||
|
||||
if explanation_parameters:
|
||||
explanation_spec = gca_endpoint_compat.explanation.ExplanationSpec()
|
||||
explanation_spec.parameters = explanation_parameters
|
||||
if explanation_metadata:
|
||||
explanation_spec.metadata = explanation_metadata
|
||||
return explanation_spec
|
||||
|
||||
return None
|
||||
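As a rough usage sketch for the helper above (not part of the commit): the parameter values below are illustrative, assuming a Sampled Shapley configuration.

# Hypothetical usage; path_count and the attribution choice are placeholders.
from google.cloud.aiplatform import explain

parameters = explain.ExplanationParameters(
    {"sampled_shapley_attribution": {"path_count": 10}}
)
spec = create_and_validate_explanation_spec(
    explanation_metadata=None,          # metadata is optional
    explanation_parameters=parameters,  # required whenever explanations are requested
)
# Passing metadata without parameters raises ValueError, as enforced above.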
@@ -0,0 +1,284 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
import typing
|
||||
import urllib
|
||||
from uuid import uuid4
|
||||
from typing import Optional
|
||||
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from google.cloud.aiplatform.metadata import experiment_resources
|
||||
from google.cloud.aiplatform.metadata import experiment_run_resource
|
||||
from google.cloud.aiplatform import model_evaluation
|
||||
from vertexai.preview.tuning import sft
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
|
||||
|
||||
def _get_ipython_shell_name() -> str:
|
||||
if "IPython" in sys.modules:
|
||||
from IPython import get_ipython
|
||||
|
||||
return get_ipython().__class__.__name__
|
||||
return ""
|
||||
|
||||
|
||||
def is_ipython_available() -> bool:
|
||||
return _get_ipython_shell_name() != ""
|
||||
|
||||
|
||||
def _get_styles() -> str:
|
||||
"""Returns the HTML style markup to support custom buttons."""
|
||||
return """
|
||||
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
|
||||
<style>
|
||||
.view-vertex-resource,
|
||||
.view-vertex-resource:hover,
|
||||
.view-vertex-resource:visited {
|
||||
position: relative;
|
||||
display: inline-flex;
|
||||
flex-direction: row;
|
||||
height: 32px;
|
||||
padding: 0 12px;
|
||||
margin: 4px 18px;
|
||||
gap: 4px;
|
||||
border-radius: 4px;
|
||||
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
background-color: rgb(255, 255, 255);
|
||||
color: rgb(51, 103, 214);
|
||||
|
||||
font-family: Roboto,"Helvetica Neue",sans-serif;
|
||||
font-size: 13px;
|
||||
font-weight: 500;
|
||||
text-transform: uppercase;
|
||||
text-decoration: none !important;
|
||||
|
||||
transition: box-shadow 280ms cubic-bezier(0.4, 0, 0.2, 1) 0s;
|
||||
box-shadow: 0px 3px 1px -2px rgba(0,0,0,0.2), 0px 2px 2px 0px rgba(0,0,0,0.14), 0px 1px 5px 0px rgba(0,0,0,0.12);
|
||||
}
|
||||
.view-vertex-resource:active {
|
||||
box-shadow: 0px 5px 5px -3px rgba(0,0,0,0.2),0px 8px 10px 1px rgba(0,0,0,0.14),0px 3px 14px 2px rgba(0,0,0,0.12);
|
||||
}
|
||||
.view-vertex-resource:active .view-vertex-ripple::before {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
border-radius: 4px;
|
||||
pointer-events: none;
|
||||
|
||||
content: '';
|
||||
background-color: rgb(51, 103, 214);
|
||||
opacity: 0.12;
|
||||
}
|
||||
.view-vertex-icon {
|
||||
font-size: 18px;
|
||||
}
|
||||
</style>
|
||||
"""
|
||||
|
||||
|
||||
def display_link(text: str, url: str, icon: Optional[str] = "open_in_new") -> None:
|
||||
"""Creates and displays the link to open the Vertex resource
|
||||
|
||||
Args:
|
||||
text: The text displayed on the clickable button.
|
||||
url: The url that the button will lead to.
|
||||
            Only Cloud Console and cloud.google.com URLs are allowed.
|
||||
icon: The icon name on the button (from material-icons library)
|
||||
|
||||
Returns:
|
||||
        None. The link is rendered directly in the notebook output.
|
||||
"""
|
||||
CLOUD_UI_URL = "https://console.cloud.google.com"
|
||||
CLOUD_DOCS_URL = "https://cloud.google.com"
|
||||
if not (url.startswith(CLOUD_UI_URL) or url.startswith(CLOUD_DOCS_URL)):
|
||||
raise ValueError(
|
||||
f"Only urls starting with {CLOUD_UI_URL} or {CLOUD_DOCS_URL} are allowed."
|
||||
)
|
||||
|
||||
button_id = f"view-vertex-resource-{str(uuid4())}"
|
||||
|
||||
# Add the markup for the CSS and link component
|
||||
html = f"""
|
||||
{_get_styles()}
|
||||
<a class="view-vertex-resource" id="{button_id}" href="#view-{button_id}">
|
||||
<span class="material-icons view-vertex-icon">{icon}</span>
|
||||
<span>{text}</span>
|
||||
</a>
|
||||
"""
|
||||
|
||||
# Add the click handler for the link
|
||||
html += f"""
|
||||
<script>
|
||||
(function () {{
|
||||
const link = document.getElementById('{button_id}');
|
||||
link.addEventListener('click', (e) => {{
|
||||
if (window.google?.colab?.openUrl) {{
|
||||
window.google.colab.openUrl('{url}');
|
||||
}} else {{
|
||||
window.open('{url}', '_blank');
|
||||
}}
|
||||
e.stopPropagation();
|
||||
e.preventDefault();
|
||||
}});
|
||||
}})();
|
||||
</script>
|
||||
"""
|
||||
|
||||
from IPython.display import display
|
||||
from IPython.display import HTML
|
||||
|
||||
display(HTML(html))
|
||||
|
||||
|
||||
def display_experiment_button(experiment: "experiment_resources.Experiment") -> None:
|
||||
"""Function to generate a link bound to the Vertex experiment"""
|
||||
if not is_ipython_available():
|
||||
return
|
||||
try:
|
||||
project = experiment._metadata_context.project
|
||||
location = experiment._metadata_context.location
|
||||
experiment_name = experiment._metadata_context.name
|
||||
if experiment_name is None or project is None or location is None:
|
||||
return
|
||||
except AttributeError:
|
||||
_LOGGER.warning("Unable to fetch experiment metadata")
|
||||
return
|
||||
|
||||
uri = (
|
||||
"https://console.cloud.google.com/vertex-ai/experiments/locations/"
|
||||
+ f"{location}/experiments/{experiment_name}/"
|
||||
+ f"runs?project={project}"
|
||||
)
|
||||
display_link("View Experiment", uri, "science")
|
||||
|
||||
|
||||
def display_experiment_run_button(
|
||||
experiment_run: "experiment_run_resource.ExperimentRun",
|
||||
) -> None:
|
||||
"""Function to generate a link bound to the Vertex experiment run"""
|
||||
if not is_ipython_available():
|
||||
return
|
||||
try:
|
||||
project = experiment_run.project
|
||||
location = experiment_run.location
|
||||
experiment_name = experiment_run._experiment._metadata_context.name
|
||||
run_name = experiment_run.name
|
||||
if (
|
||||
run_name is None
|
||||
or experiment_name is None
|
||||
or project is None
|
||||
or location is None
|
||||
):
|
||||
return
|
||||
except AttributeError:
|
||||
_LOGGER.warning("Unable to fetch experiment run metadata")
|
||||
return
|
||||
|
||||
uri = (
|
||||
"https://console.cloud.google.com/vertex-ai/experiments/locations/"
|
||||
+ f"{location}/experiments/{experiment_name}/"
|
||||
+ f"runs/{experiment_name}-{run_name}?project={project}"
|
||||
)
|
||||
display_link("View Experiment Run", uri, "science")
|
||||
|
||||
|
||||
def display_model_evaluation_button(
|
||||
evaluation: "model_evaluation.ModelEvaluation",
|
||||
) -> None:
|
||||
"""Function to generate a link bound to the Vertex model evaluation"""
|
||||
if not is_ipython_available():
|
||||
return
|
||||
|
||||
try:
|
||||
resource_name = evaluation.resource_name
|
||||
fields = evaluation._parse_resource_name(resource_name)
|
||||
project = fields["project"]
|
||||
location = fields["location"]
|
||||
model_id = fields["model"]
|
||||
evaluation_id = fields["evaluation"]
|
||||
except AttributeError:
|
||||
_LOGGER.warning("Unable to parse model evaluation metadata")
|
||||
return
|
||||
|
||||
if "@" in model_id:
|
||||
model_id, version_id = model_id.split("@")
|
||||
else:
|
||||
version_id = "default"
|
||||
|
||||
uri = (
|
||||
"https://console.cloud.google.com/vertex-ai/models/locations/"
|
||||
+ f"{location}/models/{model_id}/versions/{version_id}/evaluations/"
|
||||
+ f"{evaluation_id}?project={project}"
|
||||
)
|
||||
display_link("View Model Evaluation", uri, "lightbulb")
|
||||
|
||||
|
||||
def display_model_tuning_button(tuning_job: "sft.SupervisedTuningJob") -> None:
|
||||
"""Function to generate a link bound to the Vertex model tuning job."""
|
||||
if not is_ipython_available():
|
||||
return
|
||||
|
||||
try:
|
||||
resource_name = tuning_job.resource_name
|
||||
fields = tuning_job._parse_resource_name(resource_name)
|
||||
project = fields["project"]
|
||||
location = fields["location"]
|
||||
tuning_job_id = fields["tuning_job"]
|
||||
except AttributeError:
|
||||
_LOGGER.warning("Unable to parse tuning job metadata")
|
||||
return
|
||||
|
||||
uri = (
|
||||
"https://console.cloud.google.com/vertex-ai/generative/language/"
|
||||
+ f"locations/{location}/tuning/tuningJob/{tuning_job_id}"
|
||||
+ f"?project={project}"
|
||||
)
|
||||
display_link("View Tuning Job", uri, "tune")
|
||||
|
||||
|
||||
def display_browse_prebuilt_metrics_button() -> None:
|
||||
"""Function to generate a link to the Gen AI Evaluation pre-built metrics page."""
|
||||
if not is_ipython_available():
|
||||
return
|
||||
|
||||
uri = (
|
||||
"https://cloud.google.com/vertex-ai/generative-ai/docs/models/metrics-templates"
|
||||
)
|
||||
display_link("Browse pre-built metrics", uri, "list")
|
||||
|
||||
|
||||
def display_gen_ai_evaluation_results_button(
|
||||
gcs_file_path: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Function to generate a link bound to the Gen AI evaluation run."""
|
||||
if not is_ipython_available():
|
||||
return
|
||||
|
||||
uri = "https://cloud.google.com/vertex-ai/generative-ai/docs/models/view-evaluation"
|
||||
if gcs_file_path is not None:
|
||||
gcs_file_path = urllib.parse.quote(gcs_file_path)
|
||||
uri = f"https://console.cloud.google.com/storage/browser/_details/{gcs_file_path};colab_enterprise=gen_ai_evaluation"
|
||||
|
||||
display_link("View evaluation results", uri, "bar_chart")
|
||||
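A brief sketch of how these helpers combine in a notebook (not part of the commit); the label and URL below are placeholders, though the URL must pass the console/cloud.google.com check above.

# Hypothetical notebook usage: render a button linking to a Cloud Console page.
if is_ipython_available():
    display_link(
        text="Open Vertex AI",
        url="https://console.cloud.google.com/vertex-ai",
        icon="open_in_new",
    )
# URLs outside console.cloud.google.com / cloud.google.com raise ValueError.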
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
def _is_autologging_enabled() -> bool:
|
||||
try:
|
||||
import mlflow
|
||||
|
||||
if mlflow.get_tracking_uri() == "vertex-mlflow-plugin://":
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except ImportError:
|
||||
return False
|
||||
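A small sketch of how callers might consult this flag (illustrative only, not part of the commit):

# Hypothetical caller: only emit autologging metrics when the plugin is active.
if _is_autologging_enabled():
    print("Vertex MLflow plugin is the active tracking backend; autologging metrics.")
else:
    print("Autologging disabled or mlflow not installed; skipping.")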
@@ -0,0 +1,102 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from google.cloud.aiplatform import datasets
|
||||
|
||||
|
||||
def get_default_column_transformations(
|
||||
dataset: datasets._ColumnNamesDataset,
|
||||
target_column: str,
|
||||
) -> Tuple[List[Dict[str, Dict[str, str]]], List[str]]:
|
||||
"""Get default column transformations from the column names, while omitting the target column.
|
||||
|
||||
Args:
|
||||
dataset (_ColumnNamesDataset):
|
||||
            Required. The dataset whose column names are used.
|
||||
target_column (str):
|
||||
            Required. The name of the column whose values the Model is to predict.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Dict[str, str]]], List[str]]:
|
||||
The default column transformations and the default column names.
|
||||
"""
|
||||
|
||||
column_names = [
|
||||
column_name
|
||||
for column_name in dataset.column_names
|
||||
if column_name != target_column
|
||||
]
|
||||
column_transformations = [
|
||||
{"auto": {"column_name": column_name}} for column_name in column_names
|
||||
]
|
||||
|
||||
return (column_transformations, column_names)
|
||||
|
||||
|
||||
def validate_and_get_column_transformations(
|
||||
column_specs: Optional[Dict[str, str]] = None,
|
||||
column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
|
||||
) -> Optional[List[Dict[str, Dict[str, str]]]]:
|
||||
"""Validates column specs and transformations, then returns processed transformations.
|
||||
|
||||
Args:
|
||||
column_specs (Dict[str, str]):
|
||||
Optional. Alternative to column_transformations where the keys of the dict
|
||||
are column names and their respective values are one of
|
||||
AutoMLTabularTrainingJob.column_data_types.
|
||||
When creating transformation for BigQuery Struct column, the column
|
||||
should be flattened using "." as the delimiter. Only columns with no child
|
||||
should have a transformation.
|
||||
If an input column has no transformations on it, such a column is
|
||||
ignored by the training, except for the targetColumn, which should have
|
||||
            no transformations defined on it.
|
||||
Only one of column_transformations or column_specs should be passed.
|
||||
column_transformations (List[Dict[str, Dict[str, str]]]):
|
||||
Optional. Transformations to apply to the input columns (i.e. columns other
|
||||
than the targetColumn). Each transformation may produce multiple
|
||||
result values from the column's value, and all are used for training.
|
||||
When creating transformation for BigQuery Struct column, the column
|
||||
should be flattened using "." as the delimiter. Only columns with no child
|
||||
should have a transformation.
|
||||
If an input column has no transformations on it, such a column is
|
||||
ignored by the training, except for the targetColumn, which should have
|
||||
            no transformations defined on it.
|
||||
Only one of column_transformations or column_specs should be passed.
|
||||
            Consider using column_specs, as column_transformations will eventually be deprecated.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Dict[str, str]]]:
|
||||
The column transformations.
|
||||
|
||||
Raises:
|
||||
ValueError: If both column_transformations and column_specs were provided.
|
||||
"""
|
||||
# user populated transformations
|
||||
if column_transformations is not None and column_specs is not None:
|
||||
raise ValueError(
|
||||
"Both column_transformations and column_specs were passed. Only "
|
||||
"one is allowed."
|
||||
)
|
||||
elif column_specs is not None:
|
||||
return [
|
||||
{transformation: {"column_name": column_name}}
|
||||
for column_name, transformation in column_specs.items()
|
||||
]
|
||||
else:
|
||||
return column_transformations
|
||||
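A sketch of the two helpers above on a toy schema (the column names and types below are made up; not part of the commit):

# Hypothetical column specs: every column except the target gets one entry.
specs = {"age": "numeric", "city": "categorical"}
transformations = validate_and_get_column_transformations(column_specs=specs)
# -> [{"numeric": {"column_name": "age"}}, {"categorical": {"column_name": "city"}}]
# Passing both column_specs and column_transformations raises ValueError.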
@@ -0,0 +1,36 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from google.cloud.aiplatform import jobs
|
||||
from google.cloud.aiplatform import tensorboard
|
||||
|
||||
|
||||
def custom_job_console_uri(custom_job_resource_name: str) -> str:
|
||||
"""Helper method to create console uri from custom job resource name."""
|
||||
fields = jobs.CustomJob._parse_resource_name(custom_job_resource_name)
|
||||
return f"https://console.cloud.google.com/ai/platform/locations/{fields['location']}/training/{fields['custom_job']}?project={fields['project']}"
|
||||
|
||||
|
||||
def custom_job_tensorboard_console_uri(
|
||||
tensorboard_resource_name: str, custom_job_resource_name: str
|
||||
) -> str:
|
||||
"""Helper method to create console uri to tensorboard from custom job resource."""
|
||||
# projects+40556267596+locations+us-central1+tensorboards+740208820004847616+experiments+2214368039829241856
|
||||
fields = tensorboard.Tensorboard._parse_resource_name(tensorboard_resource_name)
|
||||
experiment_resource_name = f"{tensorboard_resource_name}/experiments/{custom_job_resource_name.split('/')[-1]}"
|
||||
uri_experiment_resource_name = experiment_resource_name.replace("/", "+")
|
||||
return f"https://{fields['location']}.tensorboard.googleusercontent.com/experiment/{uri_experiment_resource_name}"
|
||||
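For reference, a sketch of how the console URI helper resolves a resource name (the numeric IDs are placeholders; not part of the commit):

# Hypothetical custom job resource name.
job_name = "projects/123/locations/us-central1/customJobs/456"
print(custom_job_console_uri(job_name))
# https://console.cloud.google.com/ai/platform/locations/us-central1/training/456?project=123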
@@ -0,0 +1,13 @@
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,72 @@
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from google.cloud.aiplatform.utils.enhanced_library import value_converter
|
||||
|
||||
from proto.marshal import Marshal
|
||||
from proto.marshal.rules.struct import ValueRule
|
||||
from google.protobuf.struct_pb2 import Value
|
||||
|
||||
|
||||
class ConversionValueRule(ValueRule):
|
||||
def to_python(self, value, *, absent: bool = None):
|
||||
return super().to_python(value, absent=absent)
|
||||
|
||||
def to_proto(self, value):
|
||||
|
||||
# Need to check whether value is an instance
|
||||
# of an enhanced type
|
||||
if callable(getattr(value, "to_value", None)):
|
||||
return value.to_value()
|
||||
else:
|
||||
return super().to_proto(value)
|
||||
|
||||
|
||||
def _add_methods_to_classes_in_package(pkg):
|
||||
classes = dict(
|
||||
[(name, cls) for name, cls in pkg.__dict__.items() if isinstance(cls, type)]
|
||||
)
|
||||
|
||||
for class_name, cls in classes.items():
|
||||
# Add to_value() method to class with docstring
|
||||
setattr(cls, "to_value", value_converter.to_value)
|
||||
cls.to_value.__doc__ = value_converter.to_value.__doc__
|
||||
|
||||
# Add from_value() method to class with docstring
|
||||
setattr(cls, "from_value", _add_from_value_to_class(cls))
|
||||
cls.from_value.__doc__ = value_converter.from_value.__doc__
|
||||
|
||||
# Add from_map() method to class with docstring
|
||||
setattr(cls, "from_map", _add_from_map_to_class(cls))
|
||||
cls.from_map.__doc__ = value_converter.from_map.__doc__
|
||||
|
||||
|
||||
def _add_from_value_to_class(cls):
|
||||
def _from_value(value):
|
||||
return value_converter.from_value(cls, value)
|
||||
|
||||
return _from_value
|
||||
|
||||
|
||||
def _add_from_map_to_class(cls):
|
||||
def _from_map(map_):
|
||||
return value_converter.from_map(cls, map_)
|
||||
|
||||
return _from_map
|
||||
|
||||
|
||||
marshal = Marshal(name="google.cloud.aiplatform.v1beta1")
|
||||
marshal.register(Value, ConversionValueRule(marshal=marshal))
|
||||
marshal = Marshal(name="google.cloud.aiplatform.v1")
|
||||
marshal.register(Value, ConversionValueRule(marshal=marshal))
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from google.protobuf.struct_pb2 import Value
|
||||
from google.protobuf import json_format
|
||||
from proto.marshal.collections.maps import MapComposite
|
||||
from proto.marshal import Marshal
|
||||
from proto import Message
|
||||
from proto.message import MessageMeta
|
||||
|
||||
|
||||
def to_value(self: Message) -> Value:
|
||||
"""Converts a message type to a :class:`~google.protobuf.struct_pb2.Value` object.
|
||||
|
||||
Args:
|
||||
message: the message to convert
|
||||
|
||||
Returns:
|
||||
the message as a :class:`~google.protobuf.struct_pb2.Value` object
|
||||
"""
|
||||
tmp_dict = json_format.MessageToDict(self._pb)
|
||||
return json_format.ParseDict(tmp_dict, Value())
|
||||
|
||||
|
||||
def from_value(cls: MessageMeta, value: Value) -> Message:
|
||||
"""Creates instance of class from a :class:`~google.protobuf.struct_pb2.Value` object.
|
||||
|
||||
Args:
|
||||
value: a :class:`~google.protobuf.struct_pb2.Value` object
|
||||
|
||||
Returns:
|
||||
Instance of class
|
||||
"""
|
||||
value_dict = json_format.MessageToDict(value)
|
||||
return json_format.ParseDict(value_dict, cls()._pb)
|
||||
|
||||
|
||||
def from_map(cls: MessageMeta, map_: MapComposite) -> Message:
|
||||
"""Creates instance of class from a :class:`~proto.marshal.collections.maps.MapComposite` object.
|
||||
|
||||
Args:
|
||||
map_: a :class:`~proto.marshal.collections.maps.MapComposite` object
|
||||
|
||||
Returns:
|
||||
Instance of class
|
||||
"""
|
||||
marshal = Marshal(name="marshal")
|
||||
pb = marshal.to_proto(Value, map_)
|
||||
return from_value(cls, pb)
|
||||
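A sketch of the intended round-trip, assuming a proto-plus message class `SomeMessage` from one of the patched compat packages (the class name is a placeholder; not part of the commit):

# Hypothetical usage once SomeMessage has been given the helpers above:
#   value = msg.to_value()                 # proto-plus message -> struct_pb2.Value
#   pb    = SomeMessage.from_value(value)  # struct_pb2.Value   -> underlying _pb
#   pb    = SomeMessage.from_map(map_)     # MapComposite       -> underlying _pb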
@@ -0,0 +1,177 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import re
|
||||
from typing import Dict, NamedTuple, Optional
|
||||
|
||||
from google.cloud.aiplatform.compat.services import featurestore_service_client
|
||||
from google.cloud.aiplatform.compat.types import (
|
||||
feature as gca_feature,
|
||||
featurestore_service as gca_featurestore_service,
|
||||
)
|
||||
from google.cloud.aiplatform import utils
|
||||
|
||||
CompatFeaturestoreServiceClient = featurestore_service_client.FeaturestoreServiceClient
|
||||
|
||||
RESOURCE_ID_PATTERN_REGEX = r"[a-z_][a-z0-9_]{0,59}"
|
||||
GCS_SOURCE_TYPE = {"csv", "avro"}
|
||||
GCS_DESTINATION_TYPE = {"csv", "tfrecord"}
|
||||
|
||||
_FEATURE_VALUE_TYPE_UNSPECIFIED = "VALUE_TYPE_UNSPECIFIED"
|
||||
|
||||
FEATURE_STORE_VALUE_TYPE_TO_BQ_DATA_TYPE_MAP = {
|
||||
"BOOL": {"field_type": "BOOL"},
|
||||
"BOOL_ARRAY": {"field_type": "BOOL", "mode": "REPEATED"},
|
||||
"DOUBLE": {"field_type": "FLOAT64"},
|
||||
"DOUBLE_ARRAY": {"field_type": "FLOAT64", "mode": "REPEATED"},
|
||||
"INT64": {"field_type": "INT64"},
|
||||
"INT64_ARRAY": {"field_type": "INT64", "mode": "REPEATED"},
|
||||
"STRING": {"field_type": "STRING"},
|
||||
"STRING_ARRAY": {"field_type": "STRING", "mode": "REPEATED"},
|
||||
"BYTES": {"field_type": "BYTES"},
|
||||
}
|
||||
|
||||
|
||||
def validate_id(resource_id: str) -> None:
|
||||
"""Validates feature store resource ID pattern.
|
||||
|
||||
Args:
|
||||
resource_id (str):
|
||||
Required. Feature Store resource ID.
|
||||
|
||||
Raises:
|
||||
ValueError if resource_id is invalid.
|
||||
"""
|
||||
if not re.compile(r"^" + RESOURCE_ID_PATTERN_REGEX + r"$").match(resource_id):
|
||||
raise ValueError("Resource ID {resource_id} is not a valid resource id.")
|
||||
|
||||
|
||||
def validate_feature_id(feature_id: str) -> None:
|
||||
"""Validates feature ID.
|
||||
|
||||
Args:
|
||||
feature_id (str):
|
||||
Required. Feature resource ID.
|
||||
|
||||
Raises:
|
||||
ValueError if feature_id is invalid.
|
||||
"""
|
||||
match = re.compile(r"^" + RESOURCE_ID_PATTERN_REGEX + r"$").match(feature_id)
|
||||
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"The value of feature_id may be up to 60 characters, and valid characters are `[a-z0-9_]`. "
|
||||
f"The first character cannot be a number. Instead, get {feature_id}."
|
||||
)
|
||||
|
||||
reserved_words = ["entity_id", "feature_timestamp", "arrival_timestamp"]
|
||||
if feature_id.lower() in reserved_words:
|
||||
raise ValueError(
|
||||
"The feature_id can not be any of the reserved_words: `%s`"
|
||||
% ("`, `".join(reserved_words))
|
||||
)
|
||||
|
||||
|
||||
def validate_value_type(value_type: str) -> None:
|
||||
"""Validates user provided feature value_type string.
|
||||
|
||||
Args:
|
||||
value_type (str):
|
||||
Required. Immutable. Type of Feature value.
|
||||
One of BOOL, BOOL_ARRAY, DOUBLE, DOUBLE_ARRAY, INT64, INT64_ARRAY, STRING, STRING_ARRAY, BYTES.
|
||||
|
||||
Raises:
|
||||
ValueError if value_type is invalid or unspecified.
|
||||
"""
|
||||
if getattr(gca_feature.Feature.ValueType, value_type, None) in (
|
||||
gca_feature.Feature.ValueType.VALUE_TYPE_UNSPECIFIED,
|
||||
None,
|
||||
):
|
||||
raise ValueError(
|
||||
f"Given value_type `{value_type}` invalid or unspecified. "
|
||||
f"Choose one of {gca_feature.Feature.ValueType._member_names_} except `{_FEATURE_VALUE_TYPE_UNSPECIFIED}`"
|
||||
)
|
||||
|
||||
|
||||
class _FeatureConfig(NamedTuple):
|
||||
"""Configuration for feature creation.
|
||||
|
||||
Usage:
|
||||
|
||||
config = _FeatureConfig(
|
||||
feature_id='my_feature_id',
|
||||
value_type='int64',
|
||||
description='my description',
|
||||
labels={'my_key': 'my_value'},
|
||||
)
|
||||
"""
|
||||
|
||||
feature_id: str
|
||||
value_type: str = _FEATURE_VALUE_TYPE_UNSPECIFIED
|
||||
description: Optional[str] = None
|
||||
labels: Optional[Dict[str, str]] = None
|
||||
|
||||
def _get_feature_id(self) -> str:
|
||||
"""Validates and returns the feature_id.
|
||||
|
||||
Returns:
|
||||
str - valid feature ID.
|
||||
|
||||
        Raises:
|
||||
ValueError if feature_id is invalid
|
||||
"""
|
||||
|
||||
# Raises ValueError if invalid feature_id
|
||||
validate_feature_id(feature_id=self.feature_id)
|
||||
|
||||
return self.feature_id
|
||||
|
||||
def _get_value_type_enum(self) -> int:
|
||||
"""Validates value_type and returns the enum of the value type.
|
||||
|
||||
Returns:
|
||||
int - valid value type enum.
|
||||
"""
|
||||
|
||||
# Raises ValueError if invalid value_type
|
||||
validate_value_type(value_type=self.value_type)
|
||||
|
||||
value_type_enum = getattr(gca_feature.Feature.ValueType, self.value_type)
|
||||
|
||||
return value_type_enum
|
||||
|
||||
def get_create_feature_request(
|
||||
self,
|
||||
) -> gca_featurestore_service.CreateFeatureRequest:
|
||||
"""Return create feature request."""
|
||||
|
||||
gapic_feature = gca_feature.Feature(
|
||||
value_type=self._get_value_type_enum(),
|
||||
)
|
||||
|
||||
if self.labels:
|
||||
utils.validate_labels(self.labels)
|
||||
gapic_feature.labels = self.labels
|
||||
|
||||
if self.description:
|
||||
gapic_feature.description = self.description
|
||||
|
||||
create_feature_request = gca_featurestore_service.CreateFeatureRequest(
|
||||
feature=gapic_feature, feature_id=self._get_feature_id()
|
||||
)
|
||||
|
||||
return create_feature_request
|
||||
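The NamedTuple's own docstring shows the intended construction; a slightly fuller sketch of turning a config into a request (the IDs and labels below are illustrative; not part of the commit):

# Hypothetical feature definition; feature_id, description and labels are placeholders.
config = _FeatureConfig(
    feature_id="view_count",
    value_type="INT64",
    description="Number of views in the last 7 days",
    labels={"team": "growth"},
)
request = config.get_create_feature_request()
# request.feature_id == "view_count"; request.feature.value_type is the INT64 enum.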
@@ -0,0 +1,398 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import datetime
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from google.auth import credentials as auth_credentials
|
||||
from google.cloud import storage
|
||||
|
||||
from google.cloud.aiplatform import initializer
|
||||
from google.cloud.aiplatform.utils import resource_manager_utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upload_to_gcs(
|
||||
source_path: str,
|
||||
destination_uri: str,
|
||||
project: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
):
|
||||
"""Uploads local files to GCS.
|
||||
|
||||
    After the upload, the `destination_uri` will contain the same data as the `source_path`.
|
||||
|
||||
Args:
|
||||
source_path: Required. Path of the local data to copy to GCS.
|
||||
destination_uri: Required. GCS URI where the data should be uploaded.
|
||||
project: Optional. Google Cloud Project that contains the staging bucket.
|
||||
credentials: The custom credentials to use when making API calls.
|
||||
If not provided, default credentials will be used.
|
||||
|
||||
Raises:
|
||||
RuntimeError: When source_path does not exist.
|
||||
GoogleCloudError: When the upload process fails.
|
||||
"""
|
||||
source_path_obj = pathlib.Path(source_path)
|
||||
if not source_path_obj.exists():
|
||||
raise RuntimeError(f"Source path does not exist: {source_path}")
|
||||
|
||||
project = project or initializer.global_config.project
|
||||
credentials = credentials or initializer.global_config.credentials
|
||||
|
||||
storage_client = storage.Client(project=project, credentials=credentials)
|
||||
if source_path_obj.is_dir():
|
||||
source_file_paths = glob.glob(
|
||||
pathname=str(source_path_obj / "**"), recursive=True
|
||||
)
|
||||
for source_file_path in source_file_paths:
|
||||
source_file_path_obj = pathlib.Path(source_file_path)
|
||||
if source_file_path_obj.is_dir():
|
||||
continue
|
||||
source_file_relative_path_obj = source_file_path_obj.relative_to(
|
||||
source_path_obj
|
||||
)
|
||||
source_file_relative_posix_path = source_file_relative_path_obj.as_posix()
|
||||
destination_file_uri = (
|
||||
destination_uri.rstrip("/") + "/" + source_file_relative_posix_path
|
||||
)
|
||||
_logger.debug(f'Uploading "{source_file_path}" to "{destination_file_uri}"')
|
||||
destination_blob = storage.Blob.from_string(
|
||||
destination_file_uri, client=storage_client
|
||||
)
|
||||
destination_blob.upload_from_filename(filename=source_file_path)
|
||||
else:
|
||||
source_file_path = source_path
|
||||
destination_file_uri = destination_uri
|
||||
_logger.debug(f'Uploading "{source_file_path}" to "{destination_file_uri}"')
|
||||
destination_blob = storage.Blob.from_string(
|
||||
destination_file_uri, client=storage_client
|
||||
)
|
||||
destination_blob.upload_from_filename(filename=source_file_path)
|
||||
|
||||
|
||||
def stage_local_data_in_gcs(
|
||||
data_path: str,
|
||||
staging_gcs_dir: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> str:
|
||||
"""Stages a local data in GCS.
|
||||
|
||||
    The data is copied into a "vertex_ai_auto_staging/{timestamp}/" subdirectory of the staging location, keeping the original file name.
|
||||
|
||||
Args:
|
||||
data_path: Required. Path of the local data to copy to GCS.
|
||||
staging_gcs_dir:
|
||||
Optional. Google Cloud Storage bucket to be used for data staging.
|
||||
project: Optional. Google Cloud Project that contains the staging bucket.
|
||||
location: Optional. Google Cloud location to use for the staging bucket.
|
||||
credentials: The custom credentials to use when making API calls.
|
||||
If not provided, default credentials will be used.
|
||||
|
||||
Returns:
|
||||
Google Cloud Storage URI of the staged data.
|
||||
|
||||
Raises:
|
||||
RuntimeError: When source_path does not exist.
|
||||
GoogleCloudError: When the upload process fails.
|
||||
"""
|
||||
data_path_obj = pathlib.Path(data_path)
|
||||
|
||||
if not data_path_obj.exists():
|
||||
raise RuntimeError(f"Local data does not exist: data_path='{data_path}'")
|
||||
|
||||
staging_gcs_dir = staging_gcs_dir or initializer.global_config.staging_bucket
|
||||
if not staging_gcs_dir:
|
||||
project = project or initializer.global_config.project
|
||||
location = location or initializer.global_config.location
|
||||
credentials = credentials or initializer.global_config.credentials
|
||||
# Creating the bucket if it does not exist.
|
||||
# Currently we only do this when staging_gcs_dir is not specified.
|
||||
# The buckets that we create are regional.
|
||||
        # This prevents errors when a service requires a regional bucket.
|
||||
# E.g. "FailedPrecondition: 400 The Cloud Storage bucket of `gs://...` is in location `us`. It must be in the same regional location as the service location `us-central1`."
|
||||
# We are making the bucket name region-specific since the bucket is regional.
|
||||
staging_bucket_name = project + "-vertex-staging-" + location
|
||||
client = storage.Client(project=project, credentials=credentials)
|
||||
staging_bucket = storage.Bucket(client=client, name=staging_bucket_name)
|
||||
if not staging_bucket.exists():
|
||||
_logger.info(f'Creating staging GCS bucket "{staging_bucket_name}"')
|
||||
staging_bucket = client.create_bucket(
|
||||
bucket_or_name=staging_bucket,
|
||||
project=project,
|
||||
location=location,
|
||||
)
|
||||
staging_gcs_dir = "gs://" + staging_bucket_name
|
||||
|
||||
timestamp = datetime.datetime.now().isoformat(sep="-", timespec="milliseconds")
|
||||
staging_gcs_subdir = (
|
||||
staging_gcs_dir.rstrip("/") + "/vertex_ai_auto_staging/" + timestamp
|
||||
)
|
||||
|
||||
staged_data_uri = staging_gcs_subdir
|
||||
if data_path_obj.is_file():
|
||||
staged_data_uri = staging_gcs_subdir + "/" + data_path_obj.name
|
||||
|
||||
_logger.info(f'Uploading "{data_path}" to "{staged_data_uri}"')
|
||||
upload_to_gcs(
|
||||
source_path=data_path,
|
||||
destination_uri=staged_data_uri,
|
||||
project=project,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
return staged_data_uri
|
||||
|
||||
|
||||
def generate_gcs_directory_for_pipeline_artifacts(
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
):
|
||||
"""Gets or creates the GCS directory for Vertex Pipelines artifacts.
|
||||
|
||||
Args:
|
||||
project: Optional. Google Cloud Project that contains the staging bucket.
|
||||
location: Optional. Google Cloud location to use for the staging bucket.
|
||||
|
||||
Returns:
|
||||
        Google Cloud Storage URI for the pipeline output artifacts.
|
||||
"""
|
||||
project = project or initializer.global_config.project
|
||||
location = location or initializer.global_config.location
|
||||
|
||||
pipelines_bucket_name = project + "-vertex-pipelines-" + location
|
||||
output_artifacts_gcs_dir = "gs://" + pipelines_bucket_name + "/output_artifacts/"
|
||||
return output_artifacts_gcs_dir
|
||||
|
||||
|
||||
def create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
|
||||
output_artifacts_gcs_dir: Optional[str] = None,
|
||||
service_account: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
):
|
||||
"""Gets or creates the GCS directory for Vertex Pipelines artifacts.
|
||||
|
||||
Args:
|
||||
output_artifacts_gcs_dir: Optional. The GCS location for the pipeline outputs.
|
||||
It will be generated if not specified.
|
||||
service_account: Optional. Google Cloud service account that will be used
|
||||
to run the pipelines. If this function creates a new bucket it will give
|
||||
permission to the specified service account to access the bucket.
|
||||
If not provided, the Google Cloud Compute Engine service account will be used.
|
||||
project: Optional. Google Cloud Project that contains the staging bucket.
|
||||
location: Optional. Google Cloud location to use for the staging bucket.
|
||||
credentials: The custom credentials to use when making API calls.
|
||||
If not provided, default credentials will be used.
|
||||
|
||||
Returns:
|
||||
        Google Cloud Storage URI for the pipeline output artifacts.
|
||||
"""
|
||||
project = project or initializer.global_config.project
|
||||
location = location or initializer.global_config.location
|
||||
service_account = service_account or initializer.global_config.service_account
|
||||
credentials = credentials or initializer.global_config.credentials
|
||||
|
||||
output_artifacts_gcs_dir = (
|
||||
output_artifacts_gcs_dir
|
||||
or generate_gcs_directory_for_pipeline_artifacts(
|
||||
project=project,
|
||||
location=location,
|
||||
)
|
||||
)
|
||||
|
||||
# Creating the bucket if needed
|
||||
storage_client = storage.Client(
|
||||
project=project,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
pipelines_bucket = storage.Bucket.from_string(
|
||||
uri=output_artifacts_gcs_dir,
|
||||
client=storage_client,
|
||||
)
|
||||
|
||||
if not pipelines_bucket.exists():
|
||||
_logger.info(
|
||||
f'Creating GCS bucket for Vertex Pipelines: "{pipelines_bucket.name}"'
|
||||
)
|
||||
pipelines_bucket = storage_client.create_bucket(
|
||||
bucket_or_name=pipelines_bucket,
|
||||
project=project,
|
||||
location=location,
|
||||
)
|
||||
# Giving the service account read and write access to the new bucket
|
||||
# Workaround for error: "Failed to create pipeline job. Error: Service account `NNNNNNNN-compute@developer.gserviceaccount.com`
|
||||
# does not have `[storage.objects.get, storage.objects.create]` IAM permission(s) to the bucket `xxxxxxxx-vertex-pipelines-us-central1`.
|
||||
# Please either copy the files to the Google Cloud Storage bucket owned by your project, or grant the required IAM permission(s) to the service account."
|
||||
if not service_account:
|
||||
# Getting the project number to use in service account
|
||||
project_number = resource_manager_utils.get_project_number(project)
|
||||
service_account = f"{project_number}-compute@developer.gserviceaccount.com"
|
||||
bucket_iam_policy = pipelines_bucket.get_iam_policy()
|
||||
bucket_iam_policy.setdefault("roles/storage.objectCreator", set()).add(
|
||||
f"serviceAccount:{service_account}"
|
||||
)
|
||||
bucket_iam_policy.setdefault("roles/storage.objectViewer", set()).add(
|
||||
f"serviceAccount:{service_account}"
|
||||
)
|
||||
pipelines_bucket.set_iam_policy(bucket_iam_policy)
|
||||
return output_artifacts_gcs_dir
|
||||
|
||||
|
||||
def download_file_from_gcs(
|
||||
source_file_uri: str,
|
||||
destination_file_path: str,
|
||||
project: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
):
|
||||
"""Downloads a GCS file to local path.
|
||||
|
||||
Args:
|
||||
source_file_uri (str):
|
||||
Required. GCS URI of the file to download.
|
||||
destination_file_path (str):
|
||||
            Required. Local path where the data should be downloaded.
|
||||
project (str):
|
||||
Optional. Google Cloud Project that contains the staging bucket.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. The custom credentials to use when making API calls.
|
||||
If not provided, default credentials will be used.
|
||||
|
||||
Raises:
|
||||
RuntimeError: When destination_path does not exist.
|
||||
GoogleCloudError: When the download process fails.
|
||||
"""
|
||||
project = project or initializer.global_config.project
|
||||
credentials = credentials or initializer.global_config.credentials
|
||||
|
||||
storage_client = storage.Client(project=project, credentials=credentials)
|
||||
source_blob = storage.Blob.from_string(source_file_uri, client=storage_client)
|
||||
|
||||
_logger.debug(f'Downloading "{source_file_uri}" to "{destination_file_path}"')
|
||||
|
||||
source_blob.download_to_filename(filename=destination_file_path)
|
||||
|
||||
|
||||
def download_from_gcs(
|
||||
source_uri: str,
|
||||
destination_path: str,
|
||||
project: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
):
|
||||
"""Downloads GCS files to local path.
|
||||
|
||||
Args:
|
||||
source_uri (str):
|
||||
Required. GCS URI(or prefix) of the file(s) to download.
|
||||
destination_path (str):
|
||||
            Required. Local path where the data should be downloaded.
|
||||
If provided a file path, then `source_uri` must refer to a file.
|
||||
If provided a directory path, then `source_uri` must refer to a prefix.
|
||||
project (str):
|
||||
Optional. Google Cloud Project that contains the staging bucket.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. The custom credentials to use when making API calls.
|
||||
If not provided, default credentials will be used.
|
||||
|
||||
Raises:
|
||||
GoogleCloudError: When the download process fails.
|
||||
"""
|
||||
project = project or initializer.global_config.project
|
||||
credentials = credentials or initializer.global_config.credentials
|
||||
|
||||
storage_client = storage.Client(project=project, credentials=credentials)
|
||||
|
||||
validate_gcs_path(source_uri)
|
||||
bucket_name, prefix = source_uri.replace("gs://", "").split("/", maxsplit=1)
|
||||
|
||||
blobs = storage_client.list_blobs(bucket_or_name=bucket_name, prefix=prefix)
|
||||
for blob in blobs:
|
||||
# In SDK 2.0 remote training, we'll create some empty files.
|
||||
        # These files end with '/', and we'll skip them.
|
||||
if not blob.name.endswith("/"):
|
||||
rel_path = os.path.relpath(blob.name, prefix)
|
||||
filename = (
|
||||
destination_path
|
||||
if rel_path == "."
|
||||
else os.path.join(destination_path, rel_path)
|
||||
)
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
blob.download_to_filename(filename=filename)
|
||||
|
||||
|
||||
def _upload_pandas_df_to_gcs(
|
||||
df: "pandas.DataFrame", upload_gcs_path: str, file_format: str = "jsonl"
|
||||
) -> None:
|
||||
"""Uploads the provided Pandas DataFrame to a GCS bucket.
|
||||
|
||||
Args:
|
||||
df (pandas.DataFrame):
|
||||
Required. The Pandas DataFrame to upload.
|
||||
upload_gcs_path (str):
|
||||
Required. The GCS path to upload the data file.
|
||||
file_format (str):
|
||||
Required. The format to export the DataFrame to. Currently
|
||||
only JSONL is supported.
|
||||
|
||||
Raises:
|
||||
ValueError: When a file format other than JSONL is provided.
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
local_dataset_path = os.path.join(temp_dir, "dataset.jsonl")
|
||||
|
||||
if file_format == "jsonl":
|
||||
df.to_json(path_or_buf=local_dataset_path, orient="records", lines=True)
|
||||
else:
|
||||
raise ValueError(f"Unsupported file format: {file_format}")
|
||||
|
||||
storage_client = storage.Client(
|
||||
project=initializer.global_config.project,
|
||||
credentials=initializer.global_config.credentials,
|
||||
)
|
||||
storage.Blob.from_string(
|
||||
uri=upload_gcs_path, client=storage_client
|
||||
).upload_from_filename(filename=local_dataset_path)
|
||||
|
||||
|
||||
def validate_gcs_path(gcs_path: str) -> None:
|
||||
"""Validates a GCS path.
|
||||
|
||||
Args:
|
||||
gcs_path (str):
|
||||
Required. A GCS path to validate.
|
||||
Raises:
|
||||
ValueError if gcs_path is invalid.
|
||||
"""
|
||||
if not gcs_path.startswith("gs://"):
|
||||
raise ValueError(
|
||||
f"Invalid GCS path {gcs_path}. Please provide a valid GCS path starting with 'gs://'"
|
||||
)
|
||||
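A sketch of the staging flow, assuming `aiplatform.init` (or default credentials) has been configured; the local path and bucket below are placeholders (not part of the commit):

# Hypothetical staging call; paths and bucket names are placeholders.
staged_uri = stage_local_data_in_gcs(
    data_path="./data/train.csv",
    staging_gcs_dir="gs://my-staging-bucket",  # omit to fall back to the default staging bucket
)
# e.g. "gs://my-staging-bucket/vertex_ai_auto_staging/<timestamp>/train.csv"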
@@ -0,0 +1,41 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _is_relative_to(path: str, to_path: str) -> bool:
|
||||
"""Returns whether or not this path is relative to another path.
|
||||
|
||||
    This function can be replaced by Path.is_relative_to, which is available in Python 3.9+.
|
||||
|
||||
Args:
|
||||
path (str):
|
||||
Required. The path to check whether it is relative to the other path.
|
||||
to_path (str):
|
||||
Required. The path to check whether the other path is relative to it.
|
||||
|
||||
Returns:
|
||||
Whether the path is relative to another path.
|
||||
"""
|
||||
try:
|
||||
Path(path).expanduser().resolve().relative_to(
|
||||
Path(to_path).expanduser().resolve()
|
||||
)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
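A behaviour sketch for the path helper (the paths are illustrative; not part of the commit):

# Hypothetical checks; both paths are expanded and resolved before comparison.
_is_relative_to("~/project/data/train.csv", "~/project")   # True
_is_relative_to("/etc/passwd", "~/project")                # False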
@@ -0,0 +1,293 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Dict, Mapping, Optional, Union
|
||||
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
|
||||
import packaging.version
|
||||
|
||||
|
||||
class PipelineRuntimeConfigBuilder(object):
|
||||
"""Pipeline RuntimeConfig builder.
|
||||
|
||||
Constructs a RuntimeConfig spec with pipeline_root and parameter overrides.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pipeline_root: str,
|
||||
schema_version: str,
|
||||
parameter_types: Mapping[str, str],
|
||||
parameter_values: Optional[Dict[str, Any]] = None,
|
||||
input_artifacts: Optional[Dict[str, str]] = None,
|
||||
failure_policy: Optional[pipeline_failure_policy.PipelineFailurePolicy] = None,
|
||||
default_runtime: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Creates a PipelineRuntimeConfigBuilder object.
|
||||
|
||||
Args:
|
||||
pipeline_root (str):
|
||||
Required. The root of the pipeline outputs.
|
||||
schema_version (str):
|
||||
Required. Schema version of the IR. This field determines the fields supported in current version of IR.
|
||||
parameter_types (Mapping[str, str]):
|
||||
Required. The mapping from pipeline parameter name to its type.
|
||||
parameter_values (Dict[str, Any]):
|
||||
Optional. The mapping from runtime parameter name to its value.
|
||||
input_artifacts (Dict[str, str]):
|
||||
Optional. The mapping from the runtime parameter name for this artifact to its resource id.
|
||||
failure_policy (pipeline_failure_policy.PipelineFailurePolicy):
|
||||
Optional. Represents the failure policy of a pipeline. Currently, the
|
||||
default of a pipeline is that the pipeline will continue to
|
||||
run until no more tasks can be executed, also known as
|
||||
PIPELINE_FAILURE_POLICY_FAIL_SLOW. However, if a pipeline is
|
||||
set to PIPELINE_FAILURE_POLICY_FAIL_FAST, it will stop
|
||||
scheduling any new tasks when a task has failed. Any
|
||||
scheduled tasks will continue to completion.
|
||||
default_runtime (Dict[str, Any]):
|
||||
Optional. The default runtime config for the pipeline.
|
||||
"""
|
||||
self._pipeline_root = pipeline_root
|
||||
self._schema_version = schema_version
|
||||
self._parameter_types = parameter_types
|
||||
self._parameter_values = copy.deepcopy(parameter_values or {})
|
||||
self._input_artifacts = copy.deepcopy(input_artifacts or {})
|
||||
self._failure_policy = failure_policy
|
||||
self._default_runtime = default_runtime
|
||||
|
||||
@classmethod
|
||||
def from_job_spec_json(
|
||||
cls,
|
||||
job_spec: Mapping[str, Any],
|
||||
) -> "PipelineRuntimeConfigBuilder":
|
||||
"""Creates a PipelineRuntimeConfigBuilder object from PipelineJob json spec.
|
||||
|
||||
Args:
|
||||
job_spec (Mapping[str, Any]):
|
||||
Required. The PipelineJob spec.
|
||||
|
||||
Returns:
|
||||
A PipelineRuntimeConfigBuilder object.
|
||||
"""
|
||||
runtime_config_spec = job_spec["runtimeConfig"]
|
||||
parameter_input_definitions = (
|
||||
job_spec["pipelineSpec"]["root"]
|
||||
.get("inputDefinitions", {})
|
||||
.get("parameters", {})
|
||||
)
|
||||
schema_version = job_spec["pipelineSpec"]["schemaVersion"]
|
||||
|
||||
        # 'type' is deprecated in the IR and was changed to 'parameterType'.
|
||||
parameter_types = {
|
||||
k: v.get("parameterType") or v.get("type")
|
||||
for k, v in parameter_input_definitions.items()
|
||||
}
|
||||
|
||||
pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
|
||||
parameter_values = _parse_runtime_parameters(runtime_config_spec)
|
||||
failure_policy = runtime_config_spec.get("failurePolicy")
|
||||
return cls(
|
||||
pipeline_root=pipeline_root,
|
||||
schema_version=schema_version,
|
||||
parameter_types=parameter_types,
|
||||
parameter_values=parameter_values,
|
||||
failure_policy=failure_policy,
|
||||
)
|
||||
|
||||
def update_pipeline_root(self, pipeline_root: Optional[str]) -> None:
|
||||
"""Updates pipeline_root value.
|
||||
|
||||
Args:
|
||||
pipeline_root (str):
|
||||
Optional. The root of the pipeline outputs.
|
||||
"""
|
||||
if pipeline_root:
|
||||
self._pipeline_root = pipeline_root
|
||||
|
||||
def update_runtime_parameters(
|
||||
self, parameter_values: Optional[Mapping[str, Any]] = None
|
||||
) -> None:
|
||||
"""Merges runtime parameter values.
|
||||
|
||||
Args:
|
||||
parameter_values (Mapping[str, Any]):
|
||||
Optional. The mapping from runtime parameter names to its values.
|
||||
"""
|
||||
if parameter_values:
|
||||
parameters = dict(parameter_values)
|
||||
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
|
||||
"2.0.0"
|
||||
):
|
||||
for k, v in parameter_values.items():
|
||||
if isinstance(v, (dict, list, bool)):
|
||||
parameters[k] = json.dumps(v)
|
||||
self._parameter_values.update(parameters)
|
||||
|
||||
def update_input_artifacts(
|
||||
self, input_artifacts: Optional[Mapping[str, str]]
|
||||
) -> None:
|
||||
"""Merges runtime input artifacts.
|
||||
|
||||
Args:
|
||||
input_artifacts (Mapping[str, str]):
|
||||
Optional. The mapping from the runtime parameter name for this artifact to its resource id.
|
||||
"""
|
||||
if input_artifacts:
|
||||
self._input_artifacts.update(input_artifacts)
|
||||
|
||||
def update_failure_policy(self, failure_policy: Optional[str] = None) -> None:
|
||||
"""Merges runtime failure policy.
|
||||
|
||||
Args:
|
||||
failure_policy (str):
|
||||
Optional. The failure policy - "slow" or "fast".
|
||||
|
||||
Raises:
|
||||
ValueError: if failure_policy is not valid.
|
||||
"""
|
||||
if failure_policy:
|
||||
if failure_policy in _FAILURE_POLICY_TO_ENUM_VALUE:
|
||||
self._failure_policy = _FAILURE_POLICY_TO_ENUM_VALUE[failure_policy]
|
||||
else:
|
||||
raise ValueError(
|
||||
f'failure_policy should be either "slow" or "fast", but got: "{failure_policy}".'
|
||||
)
|
||||
|
||||
def update_default_runtime(self, default_runtime: Dict[str, Any]) -> None:
|
||||
"""Merges default runtime.
|
||||
|
||||
Args:
|
||||
default_runtime (Dict[str, Any]):
|
||||
default runtime config for the pipeline.
|
||||
"""
|
||||
if default_runtime:
|
||||
self._default_runtime = default_runtime
|
||||
|
||||
def build(self) -> Dict[str, Any]:
|
||||
"""Build a RuntimeConfig proto.
|
||||
|
||||
Raises:
|
||||
ValueError: if the pipeline root is not specified.
|
||||
"""
|
||||
if not self._pipeline_root:
|
||||
raise ValueError(
|
||||
"Pipeline root must be specified, either during "
|
||||
"compile time, or when calling the service."
|
||||
)
|
||||
if packaging.version.parse(self._schema_version) > packaging.version.parse(
|
||||
"2.0.0"
|
||||
):
|
||||
parameter_values_key = "parameterValues"
|
||||
else:
|
||||
parameter_values_key = "parameters"
|
||||
|
||||
runtime_config = {
|
||||
"gcsOutputDirectory": self._pipeline_root,
|
||||
parameter_values_key: {
|
||||
k: self._get_vertex_value(k, v)
|
||||
for k, v in self._parameter_values.items()
|
||||
if v is not None
|
||||
},
|
||||
"inputArtifacts": {
|
||||
k: {"artifactId": v} for k, v in self._input_artifacts.items()
|
||||
},
|
||||
}
|
||||
if self._default_runtime:
|
||||
# Only v1beta1 supports DefaultRuntime. For other cases, the field is
|
||||
# None and we clear the defaultRuntime field.
|
||||
runtime_config["defaultRuntime"] = self._default_runtime
|
||||
|
||||
if self._failure_policy:
|
||||
runtime_config["failurePolicy"] = self._failure_policy
|
||||
|
||||
return runtime_config
|
||||
|
||||
def _get_vertex_value(
|
||||
self, name: str, value: Union[int, float, str, bool, list, dict]
|
||||
) -> Union[int, float, str, bool, list, dict]:
|
||||
"""Converts primitive values into Vertex pipeline Value proto message.
|
||||
|
||||
Args:
|
||||
name (str):
|
||||
Required. The name of the pipeline parameter.
|
||||
value (Union[int, float, str, bool, list, dict]):
|
||||
Required. The value of the pipeline parameter.
|
||||
|
||||
Returns:
|
||||
A dictionary represents the Vertex pipeline Value proto message.
|
||||
|
||||
Raises:
|
||||
ValueError: if the parameter name is not found in pipeline root
|
||||
inputs, or value is none.
|
||||
"""
|
||||
if value is None:
|
||||
raise ValueError("None values should be filtered out.")
|
||||
|
||||
if name not in self._parameter_types:
|
||||
raise ValueError(
|
||||
"The pipeline parameter {} is not found in the "
|
||||
"pipeline job input definitions.".format(name)
|
||||
)
|
||||
|
||||
if packaging.version.parse(self._schema_version) <= packaging.version.parse(
|
||||
"2.0.0"
|
||||
):
|
||||
result = {}
|
||||
if self._parameter_types[name] == "INT":
|
||||
result["intValue"] = value
|
||||
elif self._parameter_types[name] == "DOUBLE":
|
||||
result["doubleValue"] = value
|
||||
elif self._parameter_types[name] == "STRING":
|
||||
result["stringValue"] = value
|
||||
else:
|
||||
raise TypeError("Got unknown type of value: {}".format(value))
|
||||
return result
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
def _parse_runtime_parameters(
|
||||
runtime_config_spec: Mapping[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Extracts runtime parameters from runtime config json spec.
|
||||
|
||||
Raises:
|
||||
TypeError: if the parameter type is not one of 'INT', 'DOUBLE', 'STRING'.
|
||||
"""
|
||||
# 'parameters' are deprecated in IR and changed to 'parameterValues'.
|
||||
if runtime_config_spec.get("parameterValues") is not None:
|
||||
return runtime_config_spec.get("parameterValues")
|
||||
|
||||
if runtime_config_spec.get("parameters") is not None:
|
||||
result = {}
|
||||
for name, value in runtime_config_spec.get("parameters").items():
|
||||
if "intValue" in value:
|
||||
result[name] = int(value["intValue"])
|
||||
elif "doubleValue" in value:
|
||||
result[name] = float(value["doubleValue"])
|
||||
elif "stringValue" in value:
|
||||
result[name] = value["stringValue"]
|
||||
else:
|
||||
raise TypeError("Got unknown type of value: {}".format(value))
|
||||
return result
|
||||
|
||||
|
||||
_FAILURE_POLICY_TO_ENUM_VALUE = {
|
||||
"slow": pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_SLOW,
|
||||
"fast": pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_FAST,
|
||||
None: pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_UNSPECIFIED,
|
||||
}
|
||||
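As a quick orientation, this is roughly how the builder above is driven end to end; it is a minimal sketch, and the `job_spec` dict plus the parameter names and bucket are illustrative assumptions, not values from this change:

# Hypothetical job_spec, e.g. loaded from a compiled pipeline file via load_yaml.
builder = PipelineRuntimeConfigBuilder.from_job_spec_json(job_spec)
builder.update_pipeline_root("gs://my-bucket/pipeline-root")  # assumed bucket
builder.update_runtime_parameters({"learning_rate": 0.01, "tags": ["a", "b"]})
builder.update_failure_policy("fast")
runtime_config = builder.build()
# For schema versions > 2.0.0 the values land under "parameterValues"; older
# specs use "parameters" with the typed wrappers produced by _get_vertex_value.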
@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import shutil
import inspect
import logging
import os
from pathlib import Path
import re
from typing import Any, Optional, Sequence, Tuple, Type

from google.cloud import storage
from google.cloud.aiplatform.constants import prediction
from google.cloud.aiplatform.utils import path_utils

_logger = logging.getLogger(__name__)


REGISTRY_REGEX = re.compile(r"^([\w\-]+\-docker\.pkg\.dev|([\w]+\.|)gcr\.io)")
GCS_URI_PREFIX = "gs://"


def inspect_source_from_class(
    custom_class: Type[Any],
    src_dir: str,
) -> Tuple[str, str]:
    """Inspects the source file from a custom class and returns its import path.

    Args:
        custom_class (Type[Any]):
            Required. The custom class to be inspected for the source file.
        src_dir (str):
            Required. The path to the local directory including all needed files.
            The source file of the custom class must be in this directory.

    Returns:
        (import_from, class_name): the source file path in python import format
        and the custom class name.

    Raises:
        ValueError: If the source file of the custom class is not in the source
            directory.
    """
    src_dir_abs_path = Path(src_dir).expanduser().resolve()

    custom_class_name = custom_class.__name__

    custom_class_path = Path(inspect.getsourcefile(custom_class)).resolve()
    if not path_utils._is_relative_to(custom_class_path, src_dir_abs_path):
        raise ValueError(
            f'The file implementing "{custom_class_name}" must be in "{src_dir}".'
        )

    custom_class_import_path = custom_class_path.relative_to(src_dir_abs_path)
    custom_class_import_path = custom_class_import_path.with_name(
        custom_class_import_path.stem
    )
    custom_class_import = custom_class_import_path.as_posix().replace(os.sep, ".")

    return custom_class_import, custom_class_name


def is_registry_uri(image_uri: str) -> bool:
    """Checks whether the image uri is in container registry or artifact registry.

    Args:
        image_uri (str):
            The image uri to check if it is in container registry or artifact registry.

    Returns:
        True if the image uri is in container registry or artifact registry.
    """
    return REGISTRY_REGEX.match(image_uri) is not None


def get_prediction_aip_http_port(
    serving_container_ports: Optional[Sequence[int]] = None,
) -> int:
    """Gets the prediction container port to use from serving container ports.

    If containerSpec.ports is specified during Model or LocalModel creation time, retrieve
    the first entry in this field. Otherwise use the default value of 8080. The environment
    variable AIP_HTTP_PORT will be set to this value.
    See https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements
    for more details.

    Args:
        serving_container_ports (Sequence[int]):
            Optional. Declaration of ports that are exposed by the container. This field is
            primarily informational, it gives Vertex AI information about the
            network connections the container uses. Listing or not a port here has
            no impact on whether the port is actually exposed, any port listening on
            the default "0.0.0.0" address inside a container will be accessible from
            the network.

    Returns:
        The first element in serving_container_ports. If it contains no values,
        return the default HTTP port.
    """
    return (
        serving_container_ports[0]
        if serving_container_ports is not None and len(serving_container_ports) > 0
        else prediction.DEFAULT_AIP_HTTP_PORT
    )


def download_model_artifacts(artifact_uri: str) -> None:
    """Prepares model artifacts in the current working directory.

    If artifact_uri is a GCS uri, the model artifacts will be downloaded to the current
    working directory.
    If artifact_uri is a local directory, the model artifacts will be copied to the current
    working directory.

    Args:
        artifact_uri (str):
            Required. The artifact uri that includes model artifacts.
    """
    if artifact_uri.startswith(GCS_URI_PREFIX):
        matches = re.match(f"{GCS_URI_PREFIX}(.*?)/(.*)", artifact_uri)
        bucket_name, prefix = matches.groups()

        gcs_client = storage.Client()
        blobs = gcs_client.list_blobs(bucket_name, prefix=prefix)
        for blob in blobs:
            name_without_prefix = blob.name[len(prefix) :]
            name_without_prefix = (
                name_without_prefix[1:]
                if name_without_prefix.startswith("/")
                else name_without_prefix
            )
            file_split = name_without_prefix.split("/")
            directory = "/".join(file_split[0:-1])
            Path(directory).mkdir(parents=True, exist_ok=True)
            if name_without_prefix and not name_without_prefix.endswith("/"):
                blob.download_to_filename(name_without_prefix)
    else:
        # Copy files to the current working directory.
        shutil.copytree(artifact_uri, ".", dirs_exist_ok=True)
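A minimal sketch of how `inspect_source_from_class` and `get_prediction_aip_http_port` fit together when preparing a custom prediction container; `MyPredictor` and the `./src` layout are assumed for illustration only:

# Hypothetical: MyPredictor is a class defined in ./src/predictor.py.
import_from, class_name = inspect_source_from_class(MyPredictor, "./src")
# -> ("predictor", "MyPredictor"), usable as f"from {import_from} import {class_name}"

port = get_prediction_aip_http_port([7080, 7081])  # -> 7080
default_port = get_prediction_aip_http_port(None)  # -> prediction.DEFAULT_AIP_HTTP_PORT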
@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from typing import Optional

from google.auth import credentials as auth_credentials
from google.cloud import resourcemanager

from google.cloud.aiplatform import initializer


def get_project_id(
    project_number: str,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> str:
    """Gets the project ID given the project number.

    Args:
        project_number (str):
            Required. The automatically generated unique identifier for your GCP project.
        credentials: The custom credentials to use when making API calls.
            Optional. If not provided, default credentials will be used.

    Returns:
        str - The unique string used to differentiate your GCP project from all others in Google Cloud.
    """

    credentials = credentials or initializer.global_config.credentials

    projects_client = resourcemanager.ProjectsClient(credentials=credentials)

    project = projects_client.get_project(name=f"projects/{project_number}")

    return project.project_id


def get_project_number(
    project_id: str,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> str:
    """Gets the project number given the project ID.

    Args:
        project_id (str):
            Required. Google Cloud project unique ID.
        credentials: The custom credentials to use when making API calls.
            Optional. If not provided, default credentials will be used.

    Returns:
        str - The automatically generated unique numerical identifier for your GCP project.
    """

    credentials = credentials or initializer.global_config.credentials

    projects_client = resourcemanager.ProjectsClient(credentials=credentials)

    project = projects_client.get_project(name=f"projects/{project_id}")
    project_number = project.name.split("/", 1)[1]

    return project_number
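The two helpers above are inverses of each other; a small sketch with placeholder identifiers (the values shown are assumptions, not real projects):

project_id = get_project_id("123456789012")        # e.g. "my-project"
project_number = get_project_number("my-project")  # e.g. "123456789012"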
@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from google.cloud.aiplatform import base


def make_gcp_resource_rest_url(resource: base.VertexAiResourceNoun) -> str:
    """Helper function to format the GCP resource url for google.X metadata schemas.

    Args:
        resource (base.VertexAiResourceNoun): Required. A Vertex resource instance.
    Returns:
        The formatted url of the resource.
    """
    try:
        resource_name = resource.versioned_resource_name
    except AttributeError:
        resource_name = resource.resource_name
    version = resource.api_client._default_version
    api_uri = resource.api_client.api_endpoint

    return f"https://{api_uri}/{version}/{resource_name}"
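For example, given an already-instantiated Vertex resource (any `VertexAiResourceNoun`, such as a Model), the helper is expected to produce a REST-style URL; the resource and output shown here are assumptions for illustration:

# Hypothetical: `model` is an existing aiplatform.Model instance.
url = make_gcp_resource_rest_url(model)
# Roughly: "https://<region>-aiplatform.googleapis.com/v1/projects/.../models/<id>"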
@@ -0,0 +1,248 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import functools
import os
import pathlib
import shutil
import subprocess
import sys
import tempfile
from typing import Optional, Sequence, Callable

from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils

_LOGGER = base.Logger(__name__)


def _get_python_executable() -> str:
    """Returns the Python executable.

    Returns:
        Python executable to use for setuptools packaging.
    Raises:
        EnvironmentError: If the Python executable is not found.
    """

    python_executable = sys.executable

    if not python_executable:
        raise EnvironmentError("Cannot find Python executable for packaging.")
    return python_executable


class _TrainingScriptPythonPackager:
    """Converts a Python script into a Python package suitable for aiplatform
    training.

    Copies the script to the specified location.

    Class Attributes:
        _TRAINER_FOLDER: Constant folder name to build package.
        _ROOT_MODULE: Constant root name of module.
        _TEST_MODULE_NAME: Constant name of module that will store script.
        _SETUP_PY_VERSION: Constant version of this created python package.
        _SETUP_PY_TEMPLATE: Constant template used to generate setup.py file.
        _SETUP_PY_SOURCE_DISTRIBUTION_CMD:
            Constant command to generate the source distribution package.

    Attributes:
        script_path: local path of script or folder to package
        requirements: list of Python dependencies to add to package

    Usage:

    packager = TrainingScriptPythonPackager('my_script.py', ['pandas', 'pytorch'])
    gcs_path = packager.package_and_copy_to_gcs(
        gcs_staging_dir='my-bucket',
        project='my-project')
    module_name = packager.module_name

    After installation, the package can be executed as:
    python -m aiplatform_custom_trainer_script.task
    """

    _TRAINER_FOLDER = "trainer"
    _ROOT_MODULE = "aiplatform_custom_trainer_script"
    _SETUP_PY_VERSION = "0.1"

    _SETUP_PY_TEMPLATE = """from setuptools import find_packages
from setuptools import setup

setup(
    name='{name}',
    version='{version}',
    packages=find_packages(),
    install_requires=({requirements}),
    include_package_data=True,
    description='My training application.'
)"""

    _SETUP_PY_SOURCE_DISTRIBUTION_CMD = "setup.py sdist --formats=gztar"

    def __init__(
        self,
        script_path: str,
        task_module_name: str = "task",
        requirements: Optional[Sequence[str]] = None,
    ):
        """Initializes the packager.

        Args:
            script_path (str): Required. Local path to script.
            requirements (Sequence[str]):
                List of python package dependencies of the script.
        """

        self.script_path = script_path
        self.task_module_name = task_module_name
        self.requirements = requirements or []

    @property
    def module_name(self) -> str:
        # Module name that can be executed during training, e.g. python -m
        return f"{self._ROOT_MODULE}.{self.task_module_name}"

    def make_package(self, package_directory: str) -> str:
        """Converts the script into a Python package suitable for python module
        execution.

        Args:
            package_directory (str): Directory to build the package in.
        Returns:
            source_distribution_path (str): Path to the built package.
        Raises:
            RuntimeError: If package creation fails.
        """
        # The root folder to build the package in
        package_path = pathlib.Path(package_directory)

        # Root directory of the package
        trainer_root_path = package_path / self._TRAINER_FOLDER

        # The root module of the python package
        trainer_path = trainer_root_path / self._ROOT_MODULE

        # __init__.py path in root module
        init_path = trainer_path / "__init__.py"

        # The path to setup.py in the package.
        setup_py_path = trainer_root_path / "setup.py"

        # The path to the generated source distribution.
        source_distribution_path = (
            trainer_root_path
            / "dist"
            / f"{self._ROOT_MODULE}-{self._SETUP_PY_VERSION}.tar.gz"
        )

        trainer_root_path.mkdir()
        trainer_path.mkdir()

        # Make empty __init__.py
        with init_path.open("w"):
            pass

        # Format the setup.py file.
        setup_py_output = self._SETUP_PY_TEMPLATE.format(
            name=self._ROOT_MODULE,
            requirements=",".join(f'"{r}"' for r in self.requirements),
            version=self._SETUP_PY_VERSION,
        )

        # Write setup.py
        with setup_py_path.open("w") as fp:
            fp.write(setup_py_output)

        if os.path.isdir(self.script_path):
            # Remove destination path if it already exists
            shutil.rmtree(trainer_path)

            # Copy folder recursively
            shutil.copytree(src=self.script_path, dst=trainer_path)
        else:
            # The module that will contain the script
            script_out_path = trainer_path / f"{self.task_module_name}.py"

            # Copy script as module of python package.
            shutil.copy(self.script_path, script_out_path)

        # Run setup.py to create the source distribution.
        setup_cmd = [
            _get_python_executable()
        ] + self._SETUP_PY_SOURCE_DISTRIBUTION_CMD.split()

        p = subprocess.Popen(
            args=setup_cmd,
            cwd=trainer_root_path,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        output, error = p.communicate()

        # Raise informative error if packaging fails.
        if p.returncode != 0:
            raise RuntimeError(
                "Packaging of training script failed with code %d\n%s \n%s"
                % (p.returncode, output.decode(), error.decode())
            )

        return str(source_distribution_path)

    def package_and_copy(self, copy_method: Callable[[str], str]) -> str:
        """Packages the script and executes the copy with the given copy_method.

        Args:
            copy_method (Callable[[str], str]):
                Takes a string path, copies to a desired location, and returns the
                output path location.
        Returns:
            output_path (str): Location of the copied package.
        """

        with tempfile.TemporaryDirectory() as tmpdirname:
            source_distribution_path = self.make_package(tmpdirname)
            output_location = copy_method(source_distribution_path)
            _LOGGER.info("Training script copied to:\n%s." % output_location)
            return output_location

    def package_and_copy_to_gcs(
        self,
        gcs_staging_dir: str,
        project: str = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> str:
        """Packages the script in a Python package and copies the package to a GCS bucket.

        Args:
            gcs_staging_dir (str): Required. GCS Staging directory.
            project (str): Required. Project where the GCS Staging bucket is located.
            credentials (auth_credentials.Credentials):
                Optional credentials used with the GCS client.
        Returns:
            GCS location of the Python package.
        """

        copy_method = functools.partial(
            utils._timestamped_copy_to_gcs,
            gcs_dir=gcs_staging_dir,
            project=project,
            credentials=credentials,
        )
        return self.package_and_copy(copy_method=copy_method)
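The `copy_method` contract in `package_and_copy` is just "take the local sdist path, put it somewhere, return the new location"; a minimal sketch with a purely local destination, where the script name and `/tmp/staging` directory are assumptions:

import shutil

packager = _TrainingScriptPythonPackager("my_script.py", requirements=["pandas"])

def copy_to_local_dir(sdist_path: str) -> str:
    # Any callable that moves the sdist and returns its new location works here.
    # /tmp/staging is assumed to exist.
    return shutil.copy(sdist_path, "/tmp/staging")

local_package = packager.package_and_copy(copy_method=copy_to_local_dir)
# package_and_copy_to_gcs wires in utils._timestamped_copy_to_gcs the same way.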
@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Sequence, Dict
from google.cloud.aiplatform_v1beta1.services.tensorboard_service.client import (
    TensorboardServiceClient,
)

_SERVING_DOMAIN = "tensorboard.googleusercontent.com"


def _parse_experiment_name(experiment_name: str) -> Dict[str, str]:
    """Parses an experiment_name into its component segments.

    Args:
        experiment_name: Resource name of the TensorboardExperiment. E.g.
            "projects/123/locations/asia-east1/tensorboards/456/experiments/exp1"

    Returns:
        Components of the experiment name.

    Raises:
        ValueError: If the experiment_name is invalid.
    """
    matched = TensorboardServiceClient.parse_tensorboard_experiment_path(
        experiment_name
    )
    if not matched:
        raise ValueError(f"Invalid experiment name: {experiment_name}.")
    return matched


def get_experiment_url(experiment_name: str) -> str:
    """Gets the URL for a single TensorBoard experiment.

    Args:
        experiment_name: Resource name of the TensorboardExperiment. E.g.
            "projects/123/locations/asia-east1/tensorboards/456/experiments/exp1"

    Returns:
        URL for the tensorboard web app.
    """
    location = _parse_experiment_name(experiment_name)["location"]
    name_for_url = experiment_name.replace("/", "+")
    return f"https://{location}.{_SERVING_DOMAIN}/experiment/{name_for_url}"


def get_experiments_compare_url(experiment_names: Sequence[str]) -> str:
    """Gets the URL for comparing experiments.

    Args:
        experiment_names: Resource names of the TensorboardExperiments that need to
            be compared.

    Returns:
        URL for the tensorboard web app.
    """
    if len(experiment_names) < 2:
        raise ValueError("At least two experiment_names are required.")

    locations = {
        _parse_experiment_name(experiment_name)["location"]
        for experiment_name in experiment_names
    }
    if len(locations) != 1:
        raise ValueError(
            f"Got experiments from different locations: {', '.join(locations)}."
        )
    location = locations.pop()

    experiment_url_segments = []
    for idx, experiment_name in enumerate(experiment_names):
        name_segments = _parse_experiment_name(experiment_name)
        experiment_url_segments.append(
            "{cnt}-{experiment}:{project}+{location}+{tensorboard}+{experiment}".format(
                cnt=idx + 1, **name_segments
            )
        )
    encoded_names = ",".join(experiment_url_segments)
    return f"https://{location}.{_SERVING_DOMAIN}/compare/{encoded_names}"
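A short sketch of both URL helpers; the experiment resource names are hypothetical and both must live in the same location for the compare URL to be accepted:

exp_a = "projects/123/locations/us-central1/tensorboards/456/experiments/run-a"
exp_b = "projects/123/locations/us-central1/tensorboards/456/experiments/run-b"

single_url = get_experiment_url(exp_a)
compare_url = get_experiments_compare_url([exp_a, exp_b])
# Both point at the us-central1 tensorboard.googleusercontent.com frontend.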
@@ -0,0 +1,314 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import NamedTuple, Optional, Dict, Union, List, Literal

from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import (
    accelerator_type as gca_accelerator_type_compat,
)

# `_SPEC_ORDERS` contains each worker pool spec type and its order in the
# `_DistributedTrainingSpec`. The `server_spec` supports either reduction server
# or parameter server, each with different configuration for its `container_spec`.
# This mapping will be used during configuration of `container_spec` for all
# worker pool specs.
_SPEC_ORDERS = {
    "chief_spec": 0,
    "worker_spec": 1,
    "server_spec": 2,
    "evaluator_spec": 3,
}


class _WorkerPoolSpec(NamedTuple):
    """Specification container for Worker Pool specs used for distributed training.

    Usage:

    spec = _WorkerPoolSpec(
        replica_count=10,
        machine_type='n1-standard-4',
        accelerator_count=2,
        accelerator_type='NVIDIA_TESLA_K80',
        boot_disk_type='pd-ssd',
        boot_disk_size_gb=100,
        reservation_affinity_type=reservation_affinity_type,
        reservation_affinity_key=reservation_affinity_key,
        reservation_affinity_values=reservation_affinity_values,
    )

    Note that container and python package specs are not stored with this spec.
    """

    replica_count: int = 0
    machine_type: str = "n1-standard-4"
    accelerator_count: int = 0
    accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED"
    boot_disk_type: str = "pd-ssd"
    boot_disk_size_gb: int = 100
    tpu_topology: Optional[str] = None
    reservation_affinity_type: Optional[
        Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"]
    ] = None
    reservation_affinity_key: Optional[str] = None
    reservation_affinity_values: Optional[List[str]] = None

    def _get_accelerator_type(self) -> Optional[str]:
        """Validates accelerator_type and returns the name of the accelerator.

        Returns:
            None if no accelerator, or the valid accelerator name.

        Raises:
            ValueError: If the accelerator type is invalid.
        """

        # Raises ValueError if invalid accelerator_type
        utils.validate_accelerator_type(self.accelerator_type)

        accelerator_enum = getattr(
            gca_accelerator_type_compat.AcceleratorType, self.accelerator_type
        )

        if (
            accelerator_enum
            != gca_accelerator_type_compat.AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED
        ):
            return self.accelerator_type

    @property
    def spec_dict(self) -> Dict[str, Union[int, str, Dict[str, Union[int, str]]]]:
        """Returns the specification as a Dict."""
        spec = {
            "machine_spec": {"machine_type": self.machine_type},
            "replica_count": self.replica_count,
            "disk_spec": {
                "boot_disk_type": self.boot_disk_type,
                "boot_disk_size_gb": self.boot_disk_size_gb,
            },
        }

        accelerator_type = self._get_accelerator_type()
        if accelerator_type and self.accelerator_count:
            spec["machine_spec"]["accelerator_type"] = accelerator_type
            spec["machine_spec"]["accelerator_count"] = self.accelerator_count

        if self.tpu_topology:
            spec["machine_spec"]["tpu_topology"] = self.tpu_topology

        if self.reservation_affinity_type:
            spec["machine_spec"]["reservation_affinity"] = {
                "reservation_affinity_type": self.reservation_affinity_type,
            }
            if self.reservation_affinity_type == "SPECIFIC_RESERVATION":
                spec["machine_spec"]["reservation_affinity"][
                    "key"
                ] = self.reservation_affinity_key
                spec["machine_spec"]["reservation_affinity"][
                    "values"
                ] = self.reservation_affinity_values

        return spec

    @property
    def is_empty(self) -> bool:
        """Returns True if replica_count <= 0, False otherwise."""
        return self.replica_count <= 0


class _DistributedTrainingSpec(NamedTuple):
    """Configuration for distributed training worker pool specs.

    Vertex AI Training expects configuration in this order:
    [
        chief spec, # can only have one replica
        worker spec,
        parameter server spec,
        evaluator spec
    ]

    Usage:

    dist_training_spec = _DistributedTrainingSpec(
        chief_spec = _WorkerPoolSpec(
            replica_count=1,
            machine_type='n1-standard-4',
            accelerator_count=2,
            accelerator_type='NVIDIA_TESLA_K80',
            boot_disk_type='pd-ssd',
            boot_disk_size_gb=100,
        ),
        worker_spec = _WorkerPoolSpec(
            replica_count=10,
            machine_type='n1-standard-4',
            accelerator_count=2,
            accelerator_type='NVIDIA_TESLA_K80',
            boot_disk_type='pd-ssd',
            boot_disk_size_gb=100,
        ),
    )
    """

    chief_spec: _WorkerPoolSpec = _WorkerPoolSpec()
    worker_spec: _WorkerPoolSpec = _WorkerPoolSpec()
    server_spec: _WorkerPoolSpec = _WorkerPoolSpec()
    evaluator_spec: _WorkerPoolSpec = _WorkerPoolSpec()

    @property
    def pool_specs(
        self,
    ) -> List[Dict[str, Union[int, str, Dict[str, Union[int, str]]]]]:
        """Returns each pool's spec in the correct order for Vertex AI as a list of
        dicts.

        Also removes specs if they are empty, but leaves specs in if there are unusual
        specifications, so as not to break the ordering in Vertex AI Training.
        ie. 0 chief replicas, 10 worker replicas, 3 ps replicas

        Returns:
            Ordered list of worker pool specs suitable for Vertex AI Training.
        """
        if self.chief_spec.replica_count > 1:
            raise ValueError("Chief spec replica count cannot be greater than 1.")

        spec_order = [
            self.chief_spec,
            self.worker_spec,
            self.server_spec,
            self.evaluator_spec,
        ]
        specs = [{} if s.is_empty else s.spec_dict for s in spec_order]
        for i in reversed(range(len(spec_order))):
            if spec_order[i].is_empty:
                specs.pop()
            else:
                break
        return specs

    @classmethod
    def chief_worker_pool(
        cls,
        replica_count: int = 0,
        machine_type: str = "n1-standard-4",
        accelerator_count: int = 0,
        accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
        boot_disk_type: str = "pd-ssd",
        boot_disk_size_gb: int = 100,
        reduction_server_replica_count: int = 0,
        reduction_server_machine_type: str = None,
        tpu_topology: str = None,
        reservation_affinity_type: Optional[
            Literal["NO_RESERVATION", "ANY_RESERVATION", "SPECIFIC_RESERVATION"]
        ] = None,
        reservation_affinity_key: Optional[str] = None,
        reservation_affinity_values: Optional[List[str]] = None,
    ) -> "_DistributedTrainingSpec":
        """Parametrizes the config to support only chief with worker replicas.

        One replica is assigned to the chief and the remainder to workers. All specs
        have the same machine type, accelerator count, and accelerator type.

        Args:
            replica_count (int):
                The number of worker replicas. Assigns 1 chief replica and
                replica_count - 1 worker replicas.
            machine_type (str):
                The type of machine to use for training.
            accelerator_type (str):
                Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
                NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
                NVIDIA_TESLA_T4
            accelerator_count (int):
                The number of accelerators to attach to a worker replica.
            boot_disk_type (str):
                Type of the boot disk (default is `pd-ssd`).
                Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
                `pd-standard` (Persistent Disk Hard Disk Drive).
            boot_disk_size_gb (int):
                Size in GB of the boot disk (default is 100GB).
                Boot disk size must be within the range of [100, 64000].
            reduction_server_replica_count (int):
                The number of reduction server replicas, default is 0.
            reduction_server_machine_type (str):
                The type of machine to use for the reduction server, default is `n1-highcpu-16`.
            tpu_topology (str):
                TPU topology for the TPU type. This field is
                required for the TPU v5 versions. This field is only passed to the
                chief replica as TPU jobs only allow 1 replica.
            reservation_affinity_type (str):
                Optional. The type of reservation affinity. One of:
                    * "NO_RESERVATION" : No reservation is used.
                    * "ANY_RESERVATION" : Any reservation that matches machine spec
                      can be used.
                    * "SPECIFIC_RESERVATION" : A specific reservation must be
                      used. See reservation_affinity_key and
                      reservation_affinity_values for how to specify the reservation.
            reservation_affinity_key (str):
                Optional. Corresponds to the label key of a reservation resource.
                To target a SPECIFIC_RESERVATION by name, use
                `compute.googleapis.com/reservation-name` as the key
                and specify the name of your reservation as its value.
            reservation_affinity_values (List[str]):
                Optional. Corresponds to the label values of a reservation resource.
                This must be the full resource name of the reservation.
                Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}'

        Returns:
            _DistributedTrainingSpec representing one chief and n workers all of
            the same type, optionally with reduction server(s). If replica_count <= 0
            then an empty spec is returned.
        """
        if replica_count <= 0:
            return cls()

        chief_spec = _WorkerPoolSpec(
            replica_count=1,
            machine_type=machine_type,
            accelerator_count=accelerator_count,
            accelerator_type=accelerator_type,
            boot_disk_type=boot_disk_type,
            boot_disk_size_gb=boot_disk_size_gb,
            tpu_topology=tpu_topology,
            reservation_affinity_type=reservation_affinity_type,
            reservation_affinity_key=reservation_affinity_key,
            reservation_affinity_values=reservation_affinity_values,
        )

        worker_spec = _WorkerPoolSpec(
            replica_count=replica_count - 1,
            machine_type=machine_type,
            accelerator_count=accelerator_count,
            accelerator_type=accelerator_type,
            boot_disk_type=boot_disk_type,
            boot_disk_size_gb=boot_disk_size_gb,
            reservation_affinity_type=reservation_affinity_type,
            reservation_affinity_key=reservation_affinity_key,
            reservation_affinity_values=reservation_affinity_values,
        )

        reduction_server_spec = _WorkerPoolSpec(
            replica_count=reduction_server_replica_count,
            machine_type=reduction_server_machine_type,
            reservation_affinity_type=reservation_affinity_type,
            reservation_affinity_key=reservation_affinity_key,
            reservation_affinity_values=reservation_affinity_values,
        )

        return cls(
            chief_spec=chief_spec,
            worker_spec=worker_spec,
            server_spec=reduction_server_spec,
        )
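A minimal sketch of the chief/worker split described above; the cluster shape is an assumption chosen for illustration:

# Hypothetical cluster shape: 1 chief + 3 workers, no reduction server.
spec = _DistributedTrainingSpec.chief_worker_pool(
    replica_count=4,
    machine_type="n1-standard-8",
    accelerator_count=1,
    accelerator_type="NVIDIA_TESLA_T4",
)
worker_pool_specs = spec.pool_specs
# -> two dicts: a chief spec with replica_count=1 and a worker spec with
#    replica_count=3; trailing empty pools (server, evaluator) are trimmed.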
@@ -0,0 +1,145 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import ModuleType
from typing import Any, Dict, Optional
from urllib import request

from google.auth import credentials as auth_credentials
from google.auth import transport
from google.cloud import storage
from google.cloud.aiplatform.constants import pipeline as pipeline_constants

# Pattern for an Artifact Registry URL.
_VALID_AR_URL = pipeline_constants._VALID_AR_URL

# Pattern for any JSON or YAML file over HTTPS.
_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL


def load_yaml(
    path: str,
    project: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
    """Loads data from a YAML document.

    Args:
        path (str):
            Required. The path of the YAML document. It can be a local path, a
            Google Cloud Storage URI, an Artifact Registry URI, or an HTTPS URI.
        project (str):
            Optional. Project to initiate the Storage client with.
        credentials (auth_credentials.Credentials):
            Optional. Credentials to use with the Storage Client.

    Returns:
        A Dict object representing the YAML document.
    """
    if path.startswith("gs://"):
        return _load_yaml_from_gs_uri(path, project, credentials)
    elif path.startswith("http://") or path.startswith("https://"):
        if _VALID_AR_URL.match(path):
            return _load_yaml_from_https_uri(path, credentials)
        elif _VALID_HTTPS_URL.match(path):
            return _load_yaml_from_https_uri(path)
        else:
            raise ValueError(
                "Invalid HTTPS URI. If not using Artifact Registry, please "
                "ensure the URI ends with .json, .yaml, or .yml."
            )
    else:
        return _load_yaml_from_local_file(path)


def _maybe_import_yaml() -> ModuleType:
    """Tries to import the PyYAML module."""
    try:
        import yaml
    except ImportError:
        raise ImportError(
            "PyYAML is not installed and is required to parse PipelineJob or "
            'PipelineSpec files. Please install the SDK using "pip install '
            'google-cloud-aiplatform[pipelines]"'
        )
    return yaml


def _load_yaml_from_gs_uri(
    uri: str,
    project: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
    """Loads data from a YAML document referenced by a GCS URI.

    Args:
        uri (str):
            Required. GCS URI for the YAML document.
        project (str):
            Optional. Project to initiate the Storage client with.
        credentials (auth_credentials.Credentials):
            Optional. Credentials to use with the Storage Client.

    Returns:
        A Dict object representing the YAML document.
    """
    yaml = _maybe_import_yaml()
    storage_client = storage.Client(project=project, credentials=credentials)
    blob = storage.Blob.from_string(uri, storage_client)
    return yaml.safe_load(blob.download_as_bytes())


def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]:
    """Loads data from a local YAML file.

    Args:
        file_path (str):
            Required. The local file path of the YAML document.

    Returns:
        A Dict object representing the YAML document.
    """
    yaml = _maybe_import_yaml()
    with open(file_path) as f:
        return yaml.safe_load(f)


def _load_yaml_from_https_uri(
    uri: str,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
    """Loads data from a YAML document referenced by an HTTPS URI, such as an
    Artifact Registry URI.

    Args:
        uri (str):
            Required. HTTPS URI for the YAML document.
        credentials (auth_credentials.Credentials):
            Optional. Credentials to use with Artifact Registry.

    Returns:
        A Dict object representing the YAML document.
    """
    yaml = _maybe_import_yaml()
    req = request.Request(uri)

    if credentials:
        if not credentials.valid:
            credentials.refresh(transport.requests.Request())
        if credentials.token:
            req.add_header("Authorization", "Bearer " + credentials.token)
    response = request.urlopen(req)

    return yaml.safe_load(response.read().decode("utf-8"))
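A brief sketch of the three dispatch branches in `load_yaml`; the paths, project, and `creds` object are assumptions, and whether a given HTTPS URL matches the Artifact Registry pattern depends on the regex defined in pipeline_constants:

local_spec = load_yaml("pipeline.yaml")
gcs_spec = load_yaml("gs://my-bucket/pipeline.yaml", project="my-project")
remote_spec = load_yaml(
    "https://example.com/specs/pipeline.yaml",  # assumed to match _VALID_HTTPS_URL
    credentials=creds,  # assumed pre-built google.auth credentials, used only for AR URLs
)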