structure saas with tools
This commit is contained in:
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,900 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import importlib
import os
import pickle
import tempfile
from typing import Any, Dict, Optional, Sequence, Union

from google.auth import credentials as auth_credentials
from google.cloud import storage
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import explain
from google.cloud.aiplatform import helpers
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.metadata.schema import utils as schema_utils
from google.cloud.aiplatform.metadata.schema.google import (
    artifact_schema as google_artifact_schema,
)
from google.cloud.aiplatform.utils import gcs_utils


_LOGGER = base.Logger(__name__)

_PICKLE_PROTOCOL = 4
_MAX_INPUT_EXAMPLE_ROWS = 5


def _save_sklearn_model(
    model: "sklearn.base.BaseEstimator",  # noqa: F821
    path: str,
) -> str:
    """Saves a sklearn model.

    Args:
        model (sklearn.base.BaseEstimator):
            Required. A sklearn model.
        path (str):
            Required. The local path to save the model.

    Returns:
        A string that represents the model class.
    """
    with open(path, "wb") as f:
        pickle.dump(model, f, protocol=_PICKLE_PROTOCOL)
    return f"{model.__class__.__module__}.{model.__class__.__name__}"


def _save_xgboost_model(
    model: Union["xgb.Booster", "xgb.XGBModel"],  # noqa: F821
    path: str,
) -> str:
    """Saves an xgboost model.

    Args:
        model (Union[xgb.Booster, xgb.XGBModel]):
            Required. An xgboost model.
        path (str):
            Required. The local path to save the model.

    Returns:
        A string that represents the model class.
    """
    model.save_model(path)
    return f"{model.__class__.__module__}.{model.__class__.__name__}"


def _save_tensorflow_model(
    model: "tf.Module",  # noqa: F821
    path: str,
    tf_save_model_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
    """Saves a tensorflow model.

    Args:
        model (tf.Module):
            Required. A tensorflow model.
        path (str):
            Required. The local path to save the model.
        tf_save_model_kwargs (Dict[str, Any]):
            Optional. A dict of kwargs to pass to the model's save method.
            If saving a tf module, these will be passed to the
            "tf.saved_model.save" method; if saving a keras model, to the
            "tf.keras.Model.save" method.

    Returns:
        A string that represents the model's base class.
    """
    try:
        import tensorflow as tf
    except ImportError:
        raise ImportError(
            "tensorflow is not installed and is required for saving models."
        ) from None

    tf_save_model_kwargs = tf_save_model_kwargs or {}
    # Check keras first: a keras model is also a tf.Module.
    if isinstance(model, tf.keras.Model):
        model.save(path, **tf_save_model_kwargs)
        return "tensorflow.keras.Model"
    elif isinstance(model, tf.Module):
        tf.saved_model.save(model, path, **tf_save_model_kwargs)
        return "tensorflow.Module"


def _load_sklearn_model(
    model_file: str,
    model_artifact: google_artifact_schema.ExperimentModel,
) -> "sklearn.base.BaseEstimator":  # noqa: F821
    """Loads a sklearn model from a local path.

    Args:
        model_file (str):
            Required. A local model file to load.
        model_artifact (google_artifact_schema.ExperimentModel):
            Required. The artifact that saved the model.
    Returns:
        The sklearn model instance.

    Raises:
        ImportError: if sklearn is not installed.
    """
    try:
        import sklearn
    except ImportError:
        raise ImportError(
            "sklearn is not installed and is required for loading models."
        ) from None

    if sklearn.__version__ < model_artifact.framework_version:
        _LOGGER.warning(
            f"The original model was saved via sklearn {model_artifact.framework_version}. "
            f"You are using sklearn {sklearn.__version__}. "
            "Attempting to load model..."
        )
    with open(model_file, "rb") as f:
        sk_model = pickle.load(f)

    return sk_model


def _load_xgboost_model(
    model_file: str,
    model_artifact: google_artifact_schema.ExperimentModel,
) -> Union["xgb.Booster", "xgb.XGBModel"]:  # noqa: F821
    """Loads an xgboost model from a local path.

    Args:
        model_file (str):
            Required. A local model file to load.
        model_artifact (google_artifact_schema.ExperimentModel):
            Required. The artifact that saved the model.
    Returns:
        The xgboost model instance.

    Raises:
        ImportError: if xgboost is not installed.
    """
    try:
        import xgboost as xgb
    except ImportError:
        raise ImportError(
            "xgboost is not installed and is required for loading models."
        ) from None

    if xgb.__version__ < model_artifact.framework_version:
        _LOGGER.warning(
            f"The original model was saved via xgboost {model_artifact.framework_version}. "
            f"You are using xgboost {xgb.__version__}. "
            "Attempting to load model..."
        )

    # Re-instantiate the exact model class recorded at save time, then
    # restore its parameters from the model file.
    module, class_name = model_artifact.model_class.rsplit(".", maxsplit=1)
    xgb_model = getattr(importlib.import_module(module), class_name)()
    xgb_model.load_model(model_file)

    return xgb_model


def _load_tensorflow_model(
    model_file: str,
    model_artifact: google_artifact_schema.ExperimentModel,
) -> "tf.Module":  # noqa: F821
    """Loads a tensorflow model from a path.

    Args:
        model_file (str):
            Required. A path to load the model from.
        model_artifact (google_artifact_schema.ExperimentModel):
            Required. The artifact that saved the model.
    Returns:
        The tensorflow model instance.

    Raises:
        ImportError: if tensorflow is not installed.
    """
    try:
        import tensorflow as tf
    except ImportError:
        raise ImportError(
            "tensorflow is not installed and is required for loading models."
        ) from None

    if tf.__version__ < model_artifact.framework_version:
        _LOGGER.warning(
            f"The original model was saved via tensorflow {model_artifact.framework_version}. "
            f"You are using tensorflow {tf.__version__}. "
            "Attempting to load model..."
        )

    if model_artifact.model_class == "tensorflow.keras.Model":
        tf_model = tf.keras.models.load_model(model_file)
    elif model_artifact.model_class == "tensorflow.Module":
        tf_model = tf.saved_model.load(model_file)
    else:
        raise ValueError(f"Unsupported model class: {model_artifact.model_class}")

    return tf_model


def _save_input_example(
    input_example: Union[list, dict, "pd.DataFrame", "np.ndarray"],  # noqa: F821
    path: str,
):
    """Saves an input example into a yaml file in the given path.

    Supported example formats: list, dict, np.ndarray, pd.DataFrame.

    Args:
        input_example (Union[list, dict, np.ndarray, pd.DataFrame]):
            Required. An input example to save. The value inside a list must be
            a scalar or list. The value inside a dict must be a scalar, list, or
            np.ndarray.
        path (str):
            Required. The directory that the example is saved to.

    Raises:
        ImportError: if PyYAML or numpy is not installed.
        ValueError: if input_example is in an invalid format.
    """
    try:
        import numpy as np
    except ImportError:
        raise ImportError(
            "numpy is not installed and is required for saving input examples. "
            "Please install google-cloud-aiplatform[metadata]."
        ) from None

    try:
        import yaml
    except ImportError:
        raise ImportError(
            "PyYAML is not installed and is required for saving input examples."
        ) from None

    example = {}
    if isinstance(input_example, list):
        if all(isinstance(x, list) for x in input_example):
            example = {
                "type": "list",
                "data": input_example[:_MAX_INPUT_EXAMPLE_ROWS],
            }
        elif all(np.isscalar(x) for x in input_example):
            example = {
                "type": "list",
                "data": input_example,
            }
        else:
            raise ValueError("The value inside a list must be a scalar or list.")

    if isinstance(input_example, dict):
        if all(isinstance(x, list) for x in input_example.values()):
            example = {
                "type": "dict",
                "data": {
                    k: v[:_MAX_INPUT_EXAMPLE_ROWS] for k, v in input_example.items()
                },
            }
        elif all(isinstance(x, np.ndarray) for x in input_example.values()):
            example = {
                "type": "dict",
                "data": {
                    k: v[:_MAX_INPUT_EXAMPLE_ROWS].tolist()
                    for k, v in input_example.items()
                },
            }
        elif all(np.isscalar(x) for x in input_example.values()):
            example = {"type": "dict", "data": input_example}
        else:
            raise ValueError(
                "The value inside a dictionary must be a scalar, list, or np.ndarray."
            )

    if isinstance(input_example, np.ndarray):
        example = {
            "type": "numpy.ndarray",
            "data": input_example[:_MAX_INPUT_EXAMPLE_ROWS].tolist(),
        }

    try:
        import pandas as pd

        if isinstance(input_example, pd.DataFrame):
            example = {
                "type": "pandas.DataFrame",
                "data": input_example.head(_MAX_INPUT_EXAMPLE_ROWS).to_dict("list"),
            }
    except ImportError:
        pass

    if not example:
        raise ValueError(
            "Input example type not supported. "
            "A valid example must be a list, dict, np.ndarray, or pd.DataFrame."
        )

    example_file = os.path.join(path, "instance.yaml")
    with open(example_file, "w") as file:
        yaml.dump(
            {"input_example": example}, file, default_flow_style=None, sort_keys=False
        )


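# Illustrative sketch (not part of the original change): given the logic
# above, saving a small dict example should produce an "instance.yaml"
# roughly like the commented output below. The file name is real; the
# concrete values are assumed.
#
#   _save_input_example({"age": [31, 42], "name": ["a", "b"]}, "/tmp/example_dir")
#
#   # /tmp/example_dir/instance.yaml
#   # input_example:
#   #   type: dict
#   #   data: {age: [31, 42], name: [a, b]}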
_FRAMEWORK_SPECS = {
    "sklearn": {
        "save_method": _save_sklearn_model,
        "load_method": _load_sklearn_model,
        "model_file": "model.pkl",
    },
    "xgboost": {
        "save_method": _save_xgboost_model,
        "load_method": _load_xgboost_model,
        "model_file": "model.bst",
    },
    "tensorflow": {
        "save_method": _save_tensorflow_model,
        "load_method": _load_tensorflow_model,
        "model_file": "saved_model",
    },
}


def save_model(
    model: Union[
        "sklearn.base.BaseEstimator", "xgb.Booster", "tf.Module"  # noqa: F821
    ],
    artifact_id: Optional[str] = None,
    *,
    uri: Optional[str] = None,
    input_example: Union[list, dict, "pd.DataFrame", "np.ndarray"] = None,  # noqa: F821
    tf_save_model_kwargs: Optional[Dict[str, Any]] = None,
    display_name: Optional[str] = None,
    metadata_store_id: Optional[str] = "default",
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> google_artifact_schema.ExperimentModel:
    """Saves an ML model into an MLMD artifact.

    Supported model frameworks: sklearn, xgboost, tensorflow.

    Example usage:
        aiplatform.init(project="my-project", location="my-location", staging_bucket="gs://my-bucket")
        model = LinearRegression()
        model.fit(X, y)
        aiplatform.save_model(model, "my-sklearn-model")

    Args:
        model (Union["sklearn.base.BaseEstimator", "xgb.Booster", "tf.Module"]):
            Required. A machine learning model.
        artifact_id (str):
            Optional. The resource id of the artifact. This id must be globally unique
            in a metadataStore. It may be up to 63 characters, and valid characters
            are `[a-z0-9_-]`. The first character cannot be a number or hyphen.
        uri (str):
            Optional. A gcs directory to save the model file. If not provided,
            `gs://default-bucket/timestamp-uuid-frameworkName-model` will be used.
            If the default staging bucket is not set, a new bucket will be created.
        input_example (Union[list, dict, pd.DataFrame, np.ndarray]):
            Optional. An example of a valid model input. Will be stored as a yaml file
            in the gcs uri. Accepts list, dict, pd.DataFrame, and np.ndarray.
            The value inside a list must be a scalar or list. The value inside
            a dict must be a scalar, list, or np.ndarray.
        tf_save_model_kwargs (Dict[str, Any]):
            Optional. A dict of kwargs to pass to the model's save method.
            If saving a tf module, these will be passed to the
            "tf.saved_model.save" method; if saving a keras model, to the
            "tf.keras.Model.save" method.
        display_name (str):
            Optional. The display name of the artifact.
        metadata_store_id (str):
            Optional. The <metadata_store_id> portion of the resource name with
            the format:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>
            If not provided, the MetadataStore's ID will be set to "default".
        project (str):
            Optional. Project used to create this Artifact. Overrides project set in
            aiplatform.init.
        location (str):
            Optional. Location used to create this Artifact. Overrides location set in
            aiplatform.init.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials used to create this Artifact. Overrides
            credentials set in aiplatform.init.

    Returns:
        An ExperimentModel instance.

    Raises:
        ValueError: if the model type is not supported.
    """
    framework_name = framework_version = ""
    try:
        import sklearn
    except ImportError:
        pass
    else:
        # An instance of sklearn.base.BaseEstimator might be a sklearn model
        # or a xgboost/lightgbm model implemented on top of sklearn.
        if isinstance(
            model, sklearn.base.BaseEstimator
        ) and model.__class__.__module__.startswith("sklearn"):
            framework_name = "sklearn"
            framework_version = sklearn.__version__

    try:
        import xgboost as xgb
    except ImportError:
        pass
    else:
        if isinstance(model, (xgb.Booster, xgb.XGBModel)):
            framework_name = "xgboost"
            framework_version = xgb.__version__

    try:
        import tensorflow as tf
    except ImportError:
        pass
    else:
        if isinstance(model, tf.Module):
            framework_name = "tensorflow"
            framework_version = tf.__version__

    if framework_name not in _FRAMEWORK_SPECS:
        raise ValueError(
            f"Model type {model.__class__.__module__}.{model.__class__.__name__} not supported."
        )

    save_method = _FRAMEWORK_SPECS[framework_name]["save_method"]
    model_file = _FRAMEWORK_SPECS[framework_name]["model_file"]

    if not uri:
        staging_bucket = initializer.global_config.staging_bucket
        # TODO(b/264196887)
        if not staging_bucket:
            project = project or initializer.global_config.project
            location = location or initializer.global_config.location
            credentials = credentials or initializer.global_config.credentials

            staging_bucket_name = project + "-vertex-staging-" + location
            client = storage.Client(project=project, credentials=credentials)
            staging_bucket = storage.Bucket(client=client, name=staging_bucket_name)
            if not staging_bucket.exists():
                _LOGGER.info(f'Creating staging bucket "{staging_bucket_name}"')
                staging_bucket = client.create_bucket(
                    bucket_or_name=staging_bucket,
                    project=project,
                    location=location,
                )
            staging_bucket = f"gs://{staging_bucket_name}"

        unique_name = utils.timestamped_unique_name()
        uri = f"{staging_bucket}/{unique_name}-{framework_name}-model"

    with tempfile.TemporaryDirectory() as temp_dir:
        # Tensorflow models can be saved directly to gcs.
        if framework_name == "tensorflow":
            path = os.path.join(uri, model_file)
            model_class = save_method(model, path, tf_save_model_kwargs)
        # Other models will be saved to a temp path and uploaded to gcs.
        else:
            path = os.path.join(temp_dir, model_file)
            model_class = save_method(model, path)

        if input_example is not None:
            _save_input_example(input_example, temp_dir)
            predict_schemata = schema_utils.PredictSchemata(
                instance_schema_uri=os.path.join(uri, "instance.yaml")
            )
        else:
            predict_schemata = None
        gcs_utils.upload_to_gcs(temp_dir, uri)

    model_artifact = google_artifact_schema.ExperimentModel(
        framework_name=framework_name,
        framework_version=framework_version,
        model_file=model_file,
        model_class=model_class,
        predict_schemata=predict_schemata,
        artifact_id=artifact_id,
        uri=uri,
        display_name=display_name,
    )
    model_artifact.create(
        metadata_store_id=metadata_store_id,
        project=project,
        location=location,
        credentials=credentials,
    )

    return model_artifact


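# Illustrative sketch (not part of the original change): saving an xgboost
# model with an input example, using the function defined above. The training
# data and artifact id are assumed.
#
#   import xgboost as xgb
#   from google.cloud import aiplatform
#
#   aiplatform.init(project="my-project", location="us-central1",
#                   staging_bucket="gs://my-bucket")
#   bst = xgb.XGBClassifier().fit(X_train, y_train)
#   experiment_model = aiplatform.save_model(
#       bst, artifact_id="my-xgb-model", input_example=X_train[:5]
#   )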
def load_model(
    model: Union[str, google_artifact_schema.ExperimentModel]
) -> Union["sklearn.base.BaseEstimator", "xgb.Booster", "tf.Module"]:  # noqa: F821
    """Retrieves the original ML model from an ExperimentModel resource.

    Args:
        model (Union[str, google_artifact_schema.ExperimentModel]):
            Required. The id or ExperimentModel instance for the model.

    Returns:
        The original ML model.

    Raises:
        ValueError: if model type is not supported.
    """
    if isinstance(model, str):
        model = aiplatform.get_experiment_model(model)
    framework_name = model.framework_name

    if framework_name not in _FRAMEWORK_SPECS:
        raise ValueError(f"Model type {framework_name} not supported.")

    load_method = _FRAMEWORK_SPECS[framework_name]["load_method"]
    model_file = _FRAMEWORK_SPECS[framework_name]["model_file"]

    source_file_uri = os.path.join(model.uri, model_file)
    # Tensorflow models can be loaded directly from gcs.
    if framework_name == "tensorflow":
        loaded_model = load_method(source_file_uri, model)
    # Other models need to be downloaded to a local path, then loaded.
    else:
        with tempfile.TemporaryDirectory() as temp_dir:
            destination_file_path = os.path.join(temp_dir, model_file)
            gcs_utils.download_file_from_gcs(source_file_uri, destination_file_path)
            loaded_model = load_method(destination_file_path, model)

    return loaded_model


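# Illustrative sketch (not part of the original change): round-tripping a
# saved model with the function above. The artifact id is assumed.
#
#   bst = load_model("my-xgb-model")   # original xgb.XGBClassifier instance
#   preds = bst.predict(X_test)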
# TODO(b/264893283)
def register_model(
    model: Union[str, google_artifact_schema.ExperimentModel],
    *,
    model_id: Optional[str] = None,
    parent_model: Optional[str] = None,
    use_gpu: bool = False,
    is_default_version: bool = True,
    version_aliases: Optional[Sequence[str]] = None,
    version_description: Optional[str] = None,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    serving_container_image_uri: Optional[str] = None,
    serving_container_predict_route: Optional[str] = None,
    serving_container_health_route: Optional[str] = None,
    serving_container_command: Optional[Sequence[str]] = None,
    serving_container_args: Optional[Sequence[str]] = None,
    serving_container_environment_variables: Optional[Dict[str, str]] = None,
    serving_container_ports: Optional[Sequence[int]] = None,
    instance_schema_uri: Optional[str] = None,
    parameters_schema_uri: Optional[str] = None,
    prediction_schema_uri: Optional[str] = None,
    explanation_metadata: Optional[explain.ExplanationMetadata] = None,
    explanation_parameters: Optional[explain.ExplanationParameters] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
    encryption_spec_key_name: Optional[str] = None,
    staging_bucket: Optional[str] = None,
    sync: Optional[bool] = True,
    upload_request_timeout: Optional[float] = None,
) -> models.Model:
    """Registers an ExperimentModel to Model Registry and returns a Model representing the registered Model resource.

    Args:
        model (Union[str, google_artifact_schema.ExperimentModel]):
            Required. The id or ExperimentModel instance for the model.
        model_id (str):
            Optional. The ID to use for the registered Model, which will
            become the final component of the model resource name.
            This value may be up to 63 characters, and valid characters
            are `[a-z0-9_-]`. The first character cannot be a number or hyphen.
        parent_model (str):
            Optional. The resource name or model ID of an existing model that the
            newly-registered model will be a version of.
            Only set this field when uploading a new version of an existing model.
        use_gpu (bool):
            Optional. Whether or not to use GPUs for the serving container. Only
            specify this argument when registering a Tensorflow model and
            'serving_container_image_uri' is not specified.
        is_default_version (bool):
            Optional. When set to True, the newly registered model version will
            automatically have alias "default" included. Subsequent uses of
            this model without a version specified will use this "default" version.

            When set to False, the "default" alias will not be moved.
            Actions targeting the newly-registered model version will need
            to specifically reference this version by ID or alias.

            New model uploads, i.e. version 1, will always be "default" aliased.
        version_aliases (Sequence[str]):
            Optional. User provided version aliases so that a model version
            can be referenced via alias instead of auto-generated version ID.
            A default version alias will be created for the first version of the model.

            The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
        version_description (str):
            Optional. The description of the model version being uploaded.
        display_name (str):
            Optional. The display name of the Model. The name can be up to 128
            characters long and can consist of any UTF-8 characters.
        description (str):
            Optional. The description of the model.
        labels (Dict[str, str]):
            Optional. The labels with user-defined metadata to
            organize your Models.
            Label keys and values can be no longer than 64
            characters (Unicode codepoints), can only
            contain lowercase letters, numeric characters,
            underscores and dashes. International characters
            are allowed.
            See https://goo.gl/xmQnxf for more information
            and examples of labels.
        serving_container_image_uri (str):
            Optional. The URI of the Model serving container. A pre-built container
            <https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers>
            is automatically chosen based on the model's framework. Set this field to
            override the default pre-built container.
        serving_container_predict_route (str):
            Optional. An HTTP path to send prediction requests to the container, and
            which must be supported by it. If not specified a default HTTP path will
            be used by Vertex AI.
        serving_container_health_route (str):
            Optional. An HTTP path to send health check requests to the container, and which
            must be supported by it. If not specified a standard HTTP path will be
            used by Vertex AI.
        serving_container_command (Sequence[str]):
            Optional. The command with which the container is run. Not executed within a
            shell. The Docker image's ENTRYPOINT is used if this is not provided.
            Variable references $(VAR_NAME) are expanded using the container's
            environment. If a variable cannot be resolved, the reference in the
            input string will be unchanged. The $(VAR_NAME) syntax can be escaped
            with a double $$, i.e. $$(VAR_NAME). Escaped references will never be
            expanded, regardless of whether the variable exists or not.
        serving_container_args (Sequence[str]):
            Optional. The arguments to the command. The Docker image's CMD is used if this is
            not provided. Variable references $(VAR_NAME) are expanded using the
            container's environment. If a variable cannot be resolved, the reference
            in the input string will be unchanged. The $(VAR_NAME) syntax can be
            escaped with a double $$, i.e. $$(VAR_NAME). Escaped references will
            never be expanded, regardless of whether the variable exists or not.
        serving_container_environment_variables (Dict[str, str]):
            Optional. The environment variables that are to be present in the container.
            Should be a dictionary where keys are environment variable names
            and values are environment variable values for those names.
        serving_container_ports (Sequence[int]):
            Optional. Declaration of ports that are exposed by the container. This field is
            primarily informational; it gives Vertex AI information about the
            network connections the container uses. Listing or omitting a port here has
            no impact on whether the port is actually exposed; any port listening on
            the default "0.0.0.0" address inside a container will be accessible from
            the network.
        instance_schema_uri (str):
            Optional. Points to a YAML file stored on Google Cloud
            Storage describing the format of a single instance, which
            is used in
            ``PredictRequest.instances``,
            ``ExplainRequest.instances``
            and
            ``BatchPredictionJob.input_config``.
            The schema is defined as an OpenAPI 3.0.2 `Schema
            Object <https://tinyurl.com/y538mdwt#schema-object>`__.
            AutoML Models always have this field populated by AI
            Platform. Note: The URI given on output will be immutable
            and probably different, including the URI scheme, than the
            one given on input. The output URI will point to a location
            where the user only has read access.
        parameters_schema_uri (str):
            Optional. Points to a YAML file stored on Google Cloud
            Storage describing the parameters of prediction and
            explanation via
            ``PredictRequest.parameters``,
            ``ExplainRequest.parameters``
            and
            ``BatchPredictionJob.model_parameters``.
            The schema is defined as an OpenAPI 3.0.2 `Schema
            Object <https://tinyurl.com/y538mdwt#schema-object>`__.
            AutoML Models always have this field populated by AI
            Platform; if no parameters are supported, it is set to an
            empty string. Note: The URI given on output will be
            immutable and probably different, including the URI scheme,
            than the one given on input. The output URI will point to a
            location where the user only has read access.
        prediction_schema_uri (str):
            Optional. Points to a YAML file stored on Google Cloud
            Storage describing the format of a single prediction
            produced by this Model, which is returned via
            ``PredictResponse.predictions``,
            ``ExplainResponse.explanations``,
            and
            ``BatchPredictionJob.output_config``.
            The schema is defined as an OpenAPI 3.0.2 `Schema
            Object <https://tinyurl.com/y538mdwt#schema-object>`__.
            AutoML Models always have this field populated by AI
            Platform. Note: The URI given on output will be immutable
            and probably different, including the URI scheme, than the
            one given on input. The output URI will point to a location
            where the user only has read access.
        explanation_metadata (aiplatform.explain.ExplanationMetadata):
            Optional. Metadata describing the Model's input and output for explanation.
            `explanation_metadata` is optional while `explanation_parameters` must be
            specified when used.
            For more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
        explanation_parameters (aiplatform.explain.ExplanationParameters):
            Optional. Parameters to configure explaining for Model's predictions.
            For more details, see `Ref docs <http://tinyurl.com/1an4zake>`
        project (str):
            Optional. Project to upload this model to. Overrides project set in
            aiplatform.init.
        location (str):
            Optional. Location to upload this model to. Overrides location set in
            aiplatform.init.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials to use to upload this model. Overrides
            credentials set in aiplatform.init.
        encryption_spec_key_name (Optional[str]):
            Optional. The Cloud KMS resource identifier of the customer
            managed encryption key used to protect the model. Has the
            form
            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
            The key needs to be in the same region as where the compute
            resource is created.

            If set, this Model and all sub-resources of this Model will be secured by this key.

            Overrides encryption_spec_key_name set in aiplatform.init.
        staging_bucket (str):
            Optional. Bucket to stage local model artifacts. Overrides
            staging_bucket set in aiplatform.init.
        sync (bool):
            Optional. Whether to execute this method synchronously. If False,
            this method will unblock and it will be executed in a concurrent Future.
        upload_request_timeout (float):
            Optional. The timeout for the upload request in seconds.

    Returns:
        model (aiplatform.Model):
            Instantiated representation of the registered model resource.

    Raises:
        ValueError: If the model doesn't have a pre-built container that is
            suitable for its framework and 'serving_container_image_uri'
            is not set.
    """
    if isinstance(model, str):
        model = aiplatform.get_experiment_model(model)

    project = project or model.project
    location = location or model.location
    credentials = credentials or model.credentials

    framework_name = model.framework_name
    framework_version = model.framework_version
    # Tensorflow models are stored in a "saved_model" sub-directory.
    artifact_uri = (
        f"{model.uri}/saved_model" if framework_name == "tensorflow" else model.uri
    )

    if not serving_container_image_uri:
        if framework_name == "tensorflow" and use_gpu:
            accelerator = "gpu"
        else:
            accelerator = "cpu"
        serving_container_image_uri = helpers._get_closest_match_prebuilt_container_uri(
            framework=framework_name,
            framework_version=framework_version,
            region=location,
            accelerator=accelerator,
        )

    if not display_name:
        display_name = models.Model._generate_display_name(f"{framework_name} model")

    return models.Model.upload(
        serving_container_image_uri=serving_container_image_uri,
        artifact_uri=artifact_uri,
        model_id=model_id,
        parent_model=parent_model,
        is_default_version=is_default_version,
        version_aliases=version_aliases,
        version_description=version_description,
        display_name=display_name,
        description=description,
        labels=labels,
        serving_container_predict_route=serving_container_predict_route,
        serving_container_health_route=serving_container_health_route,
        serving_container_command=serving_container_command,
        serving_container_args=serving_container_args,
        serving_container_environment_variables=serving_container_environment_variables,
        serving_container_ports=serving_container_ports,
        instance_schema_uri=instance_schema_uri,
        parameters_schema_uri=parameters_schema_uri,
        prediction_schema_uri=prediction_schema_uri,
        explanation_metadata=explanation_metadata,
        explanation_parameters=explanation_parameters,
        project=project,
        location=location,
        credentials=credentials,
        encryption_spec_key_name=encryption_spec_key_name,
        staging_bucket=staging_bucket,
        sync=sync,
        upload_request_timeout=upload_request_timeout,
    )


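# Illustrative sketch (not part of the original change): registering a saved
# experiment model so it can be deployed. The artifact id and deployment flow
# are assumed.
#
#   registered = register_model("my-xgb-model", display_name="churn-model")
#   endpoint = registered.deploy(machine_type="n1-standard-4")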
def get_experiment_model_info(
    model: Union[str, google_artifact_schema.ExperimentModel]
) -> Dict[str, Any]:
    """Gets the model's info from an experiment model artifact.

    Args:
        model (Union[str, google_artifact_schema.ExperimentModel]):
            Required. The id or ExperimentModel instance for the model.

    Returns:
        A dict of the model's info. This includes the model's class name,
        framework name, framework version, and input example.
    """
    if isinstance(model, str):
        model = aiplatform.get_experiment_model(model)

    model_info = {
        "model_class": model.model_class,
        "framework_name": model.framework_name,
        "framework_version": model.framework_version,
    }

    # Try to get the input example, if one exists.
    input_example = None
    source_file = f"{model.uri}/instance.yaml"
    with tempfile.TemporaryDirectory() as temp_dir:
        destination_file = os.path.join(temp_dir, "instance.yaml")
        try:
            gcs_utils.download_file_from_gcs(source_file, destination_file)
        except Exception:
            pass
        else:
            try:
                import yaml
            except ImportError:
                raise ImportError(
                    "PyYAML is not installed and is required for loading input examples."
                ) from None

            with open(destination_file, "r") as f:
                input_example = yaml.safe_load(f)["input_example"]

    if input_example:
        model_info["input_example"] = input_example

    return model_info
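# Illustrative sketch (not part of the original change): inspecting a saved
# model before loading it. The artifact id is assumed; the dict shape follows
# the code above.
#
#   info = get_experiment_model_info("my-xgb-model")
#   # {'model_class': 'xgboost.sklearn.XGBClassifier',
#   #  'framework_name': 'xgboost',
#   #  'framework_version': '1.7.0',
#   #  'input_example': {...}}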
@@ -0,0 +1,564 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional, Dict, Union

import proto
import threading

from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base
from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.compat.types import (
    metadata_service as gca_metadata_service,
)
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import metadata_store
from google.cloud.aiplatform.metadata import resource
from google.cloud.aiplatform.metadata import utils as metadata_utils
from google.cloud.aiplatform.utils import rest_utils


_LOGGER = base.Logger(__name__)


class Artifact(resource._Resource):
    """Metadata Artifact resource for Vertex AI."""

    def __init__(
        self,
        artifact_name: str,
        *,
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """Retrieves an existing Metadata Artifact given a resource name or ID.

        Args:
            artifact_name (str):
                Required. A fully-qualified resource name or resource ID of the Artifact.
                Example: "projects/123/locations/us-central1/metadataStores/default/artifacts/my-resource"
                or "my-resource" when project and location are initialized or passed.
            metadata_store_id (str):
                Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default".
                If artifact_name is a fully-qualified resource, its metadata_store_id overrides this one.
            project (str):
                Optional. Project to retrieve the artifact from. If not set, project
                set in aiplatform.init will be used.
            location (str):
                Optional. Location to retrieve the Artifact from. If not set, location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials to use to retrieve this Artifact. Overrides
                credentials set in aiplatform.init.
        """

        super().__init__(
            resource_name=artifact_name,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )

    _resource_noun = "artifacts"
    _getter_method = "get_artifact"
    _delete_method = "delete_artifact"
    _parse_resource_name_method = "parse_artifact_path"
    _format_resource_name_method = "artifact_path"
    _list_method = "list_artifacts"

    @classmethod
    def _create_resource(
        cls,
        client: utils.MetadataClientWithOverride,
        parent: str,
        resource_id: str,
        schema_title: str,
        uri: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        state: gca_artifact.Artifact.State = gca_artifact.Artifact.State.LIVE,
    ) -> gca_artifact.Artifact:
        gapic_artifact = gca_artifact.Artifact(
            uri=uri,
            schema_title=schema_title,
            schema_version=schema_version,
            display_name=display_name,
            description=description,
            metadata=metadata if metadata else {},
            state=state,
        )
        return client.create_artifact(
            parent=parent,
            artifact=gapic_artifact,
            artifact_id=resource_id,
        )

    # TODO() refactor code to move _create to _Resource class.
    @classmethod
    def _create(
        cls,
        resource_id: str,
        schema_title: str,
        uri: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        state: gca_artifact.Artifact.State = gca_artifact.Artifact.State.LIVE,
        metadata_store_id: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "Artifact":
        """Creates a new Metadata resource.

        Args:
            resource_id (str):
                Required. The <resource_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
            schema_title (str):
                Required. schema_title identifies the schema title used by the resource.
            display_name (str):
                Optional. The user-defined name of the resource.
            schema_version (str):
                Optional. schema_version specifies the version used by the resource.
                If not set, defaults to use the latest version.
            description (str):
                Optional. Describes the purpose of the resource to be created.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the resource.
            state (google.cloud.gapic.types.Artifact.State):
                Optional. The state of this Artifact. This is a
                property of the Artifact, and does not imply or
                capture any ongoing process. This property is
                managed by clients (such as Vertex AI
                Pipelines), and the system does not prescribe or
                check the validity of state transitions.
            metadata_store_id (str):
                The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Project used to create this resource. Overrides project set in
                aiplatform.init.
            location (str):
                Location used to create this resource. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials used to create this resource. Overrides
                credentials set in aiplatform.init.

        Returns:
            resource (_Resource):
                Instantiated representation of the managed Metadata resource.

        """
        appended_user_agent = []
        if base_constants.USER_AGENT_SDK_COMMAND:
            appended_user_agent = [
                f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}"
            ]
            # Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls.
            base_constants.USER_AGENT_SDK_COMMAND = ""

        api_client = cls._instantiate_client(
            location=location,
            credentials=credentials,
            appended_user_agent=appended_user_agent,
        )

        parent = utils.full_resource_name(
            resource_name=metadata_store_id,
            resource_noun=metadata_store._MetadataStore._resource_noun,
            parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name,
            format_resource_name_method=metadata_store._MetadataStore._format_resource_name,
            project=project,
            location=location,
        )

        resource = cls._create_resource(
            client=api_client,
            parent=parent,
            resource_id=resource_id,
            schema_title=schema_title,
            uri=uri,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=metadata,
            state=state,
        )

        self = cls._empty_constructor(
            project=project, location=location, credentials=credentials
        )
        self._gca_resource = resource
        self._threading_lock = threading.Lock()

        return self

    @classmethod
    def _update_resource(
        cls,
        client: utils.MetadataClientWithOverride,
        resource: proto.Message,
    ) -> proto.Message:
        """Updates an Artifact with the given input.

        Args:
            client (utils.MetadataClientWithOverride):
                Required. The client used to send requests to the Metadata Service.
            resource (proto.Message):
                Required. The proto.Message which contains the update information for the resource.
        """

        return client.update_artifact(artifact=resource)

    @classmethod
    def _list_resources(
        cls,
        client: utils.MetadataClientWithOverride,
        parent: str,
        filter: Optional[str] = None,  # pylint: disable=redefined-builtin
        order_by: Optional[str] = None,
    ):
        """Lists artifacts in the parent path that match the filter.

        Args:
            client (utils.MetadataClientWithOverride):
                Required. The client used to send requests to the Metadata Service.
            parent (str):
                Required. The path where Artifacts are stored.
            filter (str):
                Optional. A filter string to restrict the returned list.
            order_by (str):
                Optional. How the list of messages is ordered. Specify the
                values to order by and an ordering operation. The default sorting
                order is ascending. To specify descending order for a field, users
                append a " desc" suffix; for example: "foo desc, bar". Subfields
                are specified with a ``.`` character, such as foo.bar. See
                https://google.aip.dev/132#ordering for more details.

        Returns:
            List of artifacts.
        """
        list_request = gca_metadata_service.ListArtifactsRequest(
            parent=parent,
            filter=filter,
            order_by=order_by,
        )
        return client.list_artifacts(request=list_request)

    @classmethod
    def create(
        cls,
        schema_title: str,
        *,
        resource_id: Optional[str] = None,
        uri: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        state: gca_artifact.Artifact.State = gca_artifact.Artifact.State.LIVE,
        metadata_store_id: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "Artifact":
        """Creates a new Metadata Artifact.

        Args:
            schema_title (str):
                Required. schema_title identifies the schema title used by the Artifact.

                Please reference https://cloud.google.com/vertex-ai/docs/ml-metadata/system-schemas.
            resource_id (str):
                Optional. The <resource_id> portion of the Artifact name, which is
                globally unique in a metadataStore, with the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
            uri (str):
                Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
                artifact file.
            display_name (str):
                Optional. The user-defined name of the Artifact.
            schema_version (str):
                Optional. schema_version specifies the version used by the Artifact.
                If not set, defaults to use the latest version.
            description (str):
                Optional. Describes the purpose of the Artifact to be created.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the Artifact.
            state (google.cloud.gapic.types.Artifact.State):
                Optional. The state of this Artifact. This is a
                property of the Artifact, and does not imply or
                capture any ongoing process. This property is
                managed by clients (such as Vertex AI
                Pipelines), and the system does not prescribe or
                check the validity of state transitions.
            metadata_store_id (str):
                Optional. The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Optional. Project used to create this Artifact. Overrides project set in
                aiplatform.init.
            location (str):
                Optional. Location used to create this Artifact. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials used to create this Artifact. Overrides
                credentials set in aiplatform.init.

        Returns:
            Artifact: Instantiated representation of the managed Metadata Artifact.
        """
        # Add a User Agent Header for metrics tracking if one is not specified.
        # If one is already specified, this call was initiated by a subclass.
        if not base_constants.USER_AGENT_SDK_COMMAND:
            base_constants.USER_AGENT_SDK_COMMAND = (
                "aiplatform.metadata.artifact.Artifact.create"
            )

        if metadata_store_id == "default":
            metadata_store._MetadataStore.ensure_default_metadata_store_exists(
                project=project, location=location, credentials=credentials
            )

        return cls._create(
            resource_id=resource_id,
            schema_title=schema_title,
            uri=uri,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=metadata,
            state=state,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )

    @property
    def uri(self) -> Optional[str]:
        """Uri for this Artifact."""
        return self._gca_resource.uri

    @property
    def state(self) -> Optional[gca_artifact.Artifact.State]:
        """The State for this Artifact."""
        return self._gca_resource.state

    @classmethod
    def get_with_uri(
        cls,
        uri: str,
        *,
        metadata_store_id: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "Artifact":
        """Gets an Artifact by its uri.

        If more than one Artifact with this uri is in the metadata store, then the Artifact with the latest
        create_time is returned.

        Args:
            uri (str):
                Required. Uri of the Artifact to retrieve.
            metadata_store_id (str):
                Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default".
                If artifact_name is a fully-qualified resource, its metadata_store_id overrides this one.
            project (str):
                Optional. Project to retrieve the artifact from. If not set, project
                set in aiplatform.init will be used.
            location (str):
                Optional. Location to retrieve the Artifact from. If not set, location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials to use to retrieve this Artifact. Overrides
                credentials set in aiplatform.init.
        Returns:
            Artifact: Artifact with given uri.
        Raises:
            ValueError: If no Artifact exists with the provided uri.

        """

        matched_artifacts = cls.list(
            filter=f'uri = "{uri}"',
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )

        if not matched_artifacts:
            raise ValueError(
                f"No artifact with uri {uri} is in the `{metadata_store_id}` MetadataStore."
            )

        if len(matched_artifacts) > 1:
            matched_artifacts.sort(key=lambda a: a.create_time, reverse=True)
            resource_names = "\n".join(a.resource_name for a in matched_artifacts)
            _LOGGER.warning(
                f"Multiple artifacts with uri {uri} were found: {resource_names}"
            )
            _LOGGER.warning(f"Returning {matched_artifacts[0].resource_name}")

        return matched_artifacts[0]

    @property
    def lineage_console_uri(self) -> str:
        """Cloud console uri to view this Artifact Lineage."""
        metadata_store = self._parse_resource_name(self.resource_name)["metadata_store"]
        return f"https://console.cloud.google.com/vertex-ai/locations/{self.location}/metadata-stores/{metadata_store}/artifacts/{self.name}?project={self.project}"

    def __repr__(self) -> str:
        if self._gca_resource:
            return f"{object.__repr__(self)} \nresource name: {self.resource_name}\nuri: {self.uri}\nschema_title: {self.gca_resource.schema_title}"

        return base.FutureManager.__repr__(self)


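# Illustrative sketch (not part of the original change): creating and looking
# up an Artifact with the class above. "system.Artifact" is a real system
# schema title; the ids and uris are assumed.
#
#   artifact = Artifact.create(
#       schema_title="system.Artifact",
#       resource_id="my-dataset-artifact",
#       uri="gs://my-bucket/data/train.csv",
#       display_name="training data",
#   )
#   same_artifact = Artifact.get_with_uri("gs://my-bucket/data/train.csv")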
class _VertexResourceArtifactResolver:
|
||||
|
||||
# TODO(b/235594717) Add support for managed datasets
|
||||
_resource_to_artifact_type = {models.Model: "google.VertexModel"}
|
||||
|
||||
@classmethod
|
||||
def supports_metadata(cls, resource: base.VertexAiResourceNoun) -> bool:
|
||||
"""Returns True if Vertex resource is supported in Vertex Metadata otherwise False.
|
||||
|
||||
Args:
|
||||
resource (base.VertexAiResourceNoun):
|
||||
Requried. Instance of Vertex AI Resource.
|
||||
Returns:
|
||||
True if Vertex resource is supported in Vertex Metadata otherwise False.
|
||||
"""
|
||||
return type(resource) in cls._resource_to_artifact_type
|
||||
|
||||
@classmethod
|
||||
def validate_resource_supports_metadata(cls, resource: base.VertexAiResourceNoun):
|
||||
"""Validates Vertex resource is supported in Vertex Metadata.
|
||||
|
||||
Args:
|
||||
resource (base.VertexAiResourceNoun):
|
||||
Required. Instance of Vertex AI Resource.
|
||||
Raises:
|
||||
ValueError: If Vertex AI Resource is not support in Vertex Metadata.
|
||||
"""
|
||||
if not cls.supports_metadata(resource):
|
||||
raise ValueError(
|
||||
f"Vertex {type(resource)} is not yet supported in Vertex Metadata."
|
||||
f"Only {list(cls._resource_to_artifact_type.keys())} are supported"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def resolve_vertex_resource(
|
||||
cls, resource: Union[models.Model]
|
||||
) -> Optional[Artifact]:
|
||||
"""Resolves Vertex Metadata Artifact that represents this Vertex Resource.
|
||||
|
||||
If there are multiple Artifacts in the metadata store that represent the provided resource. The one with the
|
||||
latest create_time is returned.
|
||||
|
||||
Args:
|
||||
resource (base.VertexAiResourceNoun):
|
||||
Required. Instance of Vertex AI Resource.
|
||||
Returns:
|
||||
Artifact: Artifact that represents this Vertex Resource. None if Resource not found in Metadata store.
|
||||
"""
|
||||
cls.validate_resource_supports_metadata(resource)
|
||||
resource.wait()
|
||||
metadata_type = cls._resource_to_artifact_type[type(resource)]
|
||||
uri = rest_utils.make_gcp_resource_rest_url(resource=resource)
|
||||
|
||||
artifacts = Artifact.list(
|
||||
filter=metadata_utils._make_filter_string(
|
||||
schema_title=metadata_type,
|
||||
uri=uri,
|
||||
),
|
||||
project=resource.project,
|
||||
location=resource.location,
|
||||
credentials=resource.credentials,
|
||||
)
|
||||
|
||||
artifacts.sort(key=lambda a: a.create_time, reverse=True)
|
||||
if artifacts:
|
||||
# most recent
|
||||
return artifacts[0]
|
||||
|
||||
@classmethod
|
||||
def create_vertex_resource_artifact(cls, resource: Union[models.Model]) -> Artifact:
|
||||
"""Creates Vertex Metadata Artifact that represents this Vertex Resource.
|
||||
|
||||
Args:
|
||||
resource (base.VertexAiResourceNoun):
|
||||
Required. Instance of Vertex AI Resource.
|
||||
Returns:
|
||||
Artifact: Artifact that represents this Vertex Resource.
|
||||
"""
|
||||
cls.validate_resource_supports_metadata(resource)
|
||||
resource.wait()
|
||||
|
||||
metadata_type = cls._resource_to_artifact_type[type(resource)]
|
||||
uri = rest_utils.make_gcp_resource_rest_url(resource=resource)
|
||||
|
||||
return Artifact.create(
|
||||
schema_title=metadata_type,
|
||||
display_name=getattr(resource.gca_resource, "display_name", None),
|
||||
uri=uri,
|
||||
# Note that support for non-versioned resources requires
|
||||
# change to reference `resource_name` please update if
|
||||
# supporting resource other than Model
|
||||
metadata={"resourceName": resource.versioned_resource_name},
|
||||
project=resource.project,
|
||||
location=resource.location,
|
||||
credentials=resource.credentials,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def resolve_or_create_resource_artifact(
|
||||
cls, resource: Union[models.Model]
|
||||
) -> Artifact:
|
||||
"""Create of gets Vertex Metadata Artifact that represents this Vertex Resource.
|
||||
|
||||
Args:
|
||||
resource (base.VertexAiResourceNoun):
|
||||
Required. Instance of Vertex AI Resource.
|
||||
Returns:
|
||||
Artifact: Artifact that represents this Vertex Resource.
|
||||
"""
|
||||
artifact = cls.resolve_vertex_resource(resource=resource)
|
||||
if artifact:
|
||||
return artifact
|
||||
return cls.create_vertex_resource_artifact(resource=resource)
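A short sketch of how the resolver above might be used to tie a registered Model to lineage. Hedged: `_VertexResourceArtifactResolver` is internal API, and the model resource name is a placeholder.

```py
from google.cloud import aiplatform
from google.cloud.aiplatform.metadata import artifact

model = aiplatform.Model("projects/123/locations/us-central1/models/456")

# Reuses an existing "google.VertexModel" Artifact for this model if one
# exists, otherwise creates it.
model_artifact = artifact._VertexResourceArtifactResolver.resolve_or_create_resource_artifact(
    resource=model
)
print(model_artifact.uri)
```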
@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Constants used by Metadata and Vertex Experiments."""

from google.cloud.aiplatform.compat.types import artifact

SYSTEM_RUN = "system.Run"
SYSTEM_EXPERIMENT = "system.Experiment"
SYSTEM_EXPERIMENT_RUN = "system.ExperimentRun"
SYSTEM_PIPELINE = "system.Pipeline"
SYSTEM_PIPELINE_RUN = "system.PipelineRun"
SYSTEM_METRICS = "system.Metrics"
GOOGLE_CLASSIFICATION_METRICS = "google.ClassificationMetrics"
GOOGLE_REGRESSION_METRICS = "google.RegressionMetrics"
GOOGLE_FORECASTING_METRICS = "google.ForecastingMetrics"
GOOGLE_EXPERIMENT_MODEL = "google.ExperimentModel"
_EXPERIMENTS_V2_TENSORBOARD_RUN = "google.VertexTensorboardRun"

_DEFAULT_SCHEMA_VERSION = "0.0.1"

SCHEMA_VERSIONS = {
    SYSTEM_RUN: _DEFAULT_SCHEMA_VERSION,
    SYSTEM_EXPERIMENT: _DEFAULT_SCHEMA_VERSION,
    SYSTEM_EXPERIMENT_RUN: _DEFAULT_SCHEMA_VERSION,
    SYSTEM_PIPELINE: _DEFAULT_SCHEMA_VERSION,
    SYSTEM_METRICS: _DEFAULT_SCHEMA_VERSION,
}

_BACKING_TENSORBOARD_RESOURCE_KEY = "backing_tensorboard_resource"

_CUSTOM_JOB_KEY = "_custom_jobs"
_CUSTOM_JOB_RESOURCE_NAME = "custom_job_resource_name"
_CUSTOM_JOB_CONSOLE_URI = "custom_job_console_uri"

_PARAM_KEY = "_params"
_METRIC_KEY = "_metrics"
_STATE_KEY = "_state"

_PARAM_PREFIX = "param"
_METRIC_PREFIX = "metric"
_TIME_SERIES_METRIC_PREFIX = "time_series_metric"

# This is currently used to filter in the Console.
EXPERIMENT_METADATA = {"experiment_deleted": False}

PIPELINE_PARAM_PREFIX = "input:"

TENSORBOARD_CUSTOM_JOB_EXPERIMENT_FIELD = "tensorboard_link"

GCP_ARTIFACT_RESOURCE_NAME_KEY = "resourceName"

# constant to mark an Experiment context as originating from the SDK
# TODO(b/235593750) Remove this field
_VERTEX_EXPERIMENT_TRACKING_LABEL = "vertex_experiment_tracking"

_TENSORBOARD_RUN_REFERENCE_ARTIFACT = artifact.Artifact(
    name="google-vertex-tensorboard-run-v0-0-1",
    schema_title=_EXPERIMENTS_V2_TENSORBOARD_RUN,
    schema_version="0.0.1",
    metadata={_VERTEX_EXPERIMENT_TRACKING_LABEL: True},
)

_TB_RUN_ARTIFACT_POST_FIX_ID = "-tb-run"
_EXPERIMENT_RUN_MAX_LENGTH = 128 - len(_TB_RUN_ARTIFACT_POST_FIX_ID)

# Label used to identify TensorboardExperiment as created from Vertex
# Experiments
_VERTEX_EXPERIMENT_TB_EXPERIMENT_LABEL = {
    "vertex_tensorboard_experiment_source": "vertex_experiment"
}

ENV_EXPERIMENT_KEY = "AIP_EXPERIMENT_NAME"
ENV_EXPERIMENT_RUN_KEY = "AIP_EXPERIMENT_RUN_NAME"
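The run-name length cap above exists because the backing TensorBoard-run artifact ID is the run name plus a fixed suffix, and Metadata resource IDs max out at 128 characters. A tiny standalone illustration (the run name is a made-up placeholder):

```py
_TB_RUN_ARTIFACT_POST_FIX_ID = "-tb-run"
_EXPERIMENT_RUN_MAX_LENGTH = 128 - len(_TB_RUN_ARTIFACT_POST_FIX_ID)  # 121

run_name = "my-training-run"
assert len(run_name) <= _EXPERIMENT_RUN_MAX_LENGTH
tb_run_artifact_id = f"{run_name}{_TB_RUN_ARTIFACT_POST_FIX_ID}"
print(tb_run_artifact_id)  # my-training-run-tb-run
```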
@@ -0,0 +1,421 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional, Dict, List, Sequence

import proto
import re
import threading

from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import utils as metadata_utils
from google.cloud.aiplatform.compat.types import context as gca_context
from google.cloud.aiplatform.compat.types import (
    lineage_subgraph as gca_lineage_subgraph,
)
from google.cloud.aiplatform.compat.types import (
    metadata_service as gca_metadata_service,
)
from google.cloud.aiplatform.metadata import artifact
from google.cloud.aiplatform.metadata import execution
from google.cloud.aiplatform.metadata import metadata_store
from google.cloud.aiplatform.metadata import resource
from google.api_core.exceptions import Aborted

_ETAG_ERROR_MAX_RETRY_COUNT = 5
_ETAG_ERROR_REGEX = re.compile(
    r"Specified Context \`etag\`: \`(\d+)\` does not match server \`etag\`: \`(\d+)\`"
)


class Context(resource._Resource):
    """Metadata Context resource for Vertex AI."""

    _resource_noun = "contexts"
    _getter_method = "get_context"
    _delete_method = "delete_context"
    _parse_resource_name_method = "parse_context_path"
    _format_resource_name_method = "context_path"
    _list_method = "list_contexts"

    @property
    def parent_contexts(self) -> Sequence[str]:
        """The parent context resource names of this context."""
        return self.gca_resource.parent_contexts

    def add_artifacts_and_executions(
        self,
        artifact_resource_names: Optional[Sequence[str]] = None,
        execution_resource_names: Optional[Sequence[str]] = None,
    ):
        """Associates Executions and attributes Artifacts to a given Context.

        Args:
            artifact_resource_names (Sequence[str]):
                Optional. The full resource names of Artifacts to attribute to the Context.
            execution_resource_names (Sequence[str]):
                Optional. The full resource names of Executions to associate with the Context.
        """
        self.api_client.add_context_artifacts_and_executions(
            context=self.resource_name,
            artifacts=artifact_resource_names,
            executions=execution_resource_names,
        )

    def get_artifacts(self) -> List[artifact.Artifact]:
        """Returns all Artifacts attributed to this Context.

        Returns:
            artifacts (List[Artifact]): All Artifacts under this Context.
        """
        return artifact.Artifact.list(
            filter=metadata_utils._make_filter_string(in_context=[self.resource_name]),
            project=self.project,
            location=self.location,
            credentials=self.credentials,
        )

    @classmethod
    def create(
        cls,
        schema_title: str,
        *,
        resource_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        metadata_store_id: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "Context":
        """Creates a new Metadata Context.

        Args:
            schema_title (str):
                Required. schema_title identifies the schema title used by the Context.
                Please reference https://cloud.google.com/vertex-ai/docs/ml-metadata/system-schemas.
            resource_id (str):
                Optional. The <resource_id> portion of the Context name with the
                following format; this is globally unique in a metadataStore:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
            display_name (str):
                Optional. The user-defined name of the Context.
            schema_version (str):
                Optional. schema_version specifies the version used by the Context.
                If not set, defaults to use the latest version.
            description (str):
                Optional. Describes the purpose of the Context to be created.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the Context.
            metadata_store_id (str):
                Optional. The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Optional. Project used to create this Context. Overrides project set in
                aiplatform.init.
            location (str):
                Optional. Location used to create this Context. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials used to create this Context. Overrides
                credentials set in aiplatform.init.

        Returns:
            Context: Instantiated representation of the managed Metadata Context.
        """
        # Add User Agent Header for metrics tracking if one is not specified.
        # If one is already specified this call was initiated by a subclass.
        if not base_constants.USER_AGENT_SDK_COMMAND:
            base_constants.USER_AGENT_SDK_COMMAND = (
                "aiplatform.metadata.context.Context.create"
            )

        return cls._create(
            resource_id=resource_id,
            schema_title=schema_title,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=metadata,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )
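A hedged usage sketch for `create`; the project, location, and IDs are placeholders, and the schema title mirrors the `system.Experiment` usage seen elsewhere in this commit.

```py
from google.cloud import aiplatform
from google.cloud.aiplatform.metadata import context

aiplatform.init(project="my-project", location="us-central1")

# Creates a Context in the "default" MetadataStore.
experiment_context = context.Context.create(
    schema_title="system.Experiment",
    resource_id="my-experiment",
    display_name="my-experiment",
    metadata={"experiment_deleted": False},
)
print(experiment_context.resource_name)
```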

    # TODO() refactor code to move _create to _Resource class.
    @classmethod
    def _create(
        cls,
        resource_id: str,
        schema_title: str,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        metadata_store_id: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "Context":
        """Creates a new Metadata resource.

        Args:
            resource_id (str):
                Required. The <resource_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
            schema_title (str):
                Required. schema_title identifies the schema title used by the resource.
            display_name (str):
                Optional. The user-defined name of the resource.
            schema_version (str):
                Optional. schema_version specifies the version used by the resource.
                If not set, defaults to use the latest version.
            description (str):
                Optional. Describes the purpose of the resource to be created.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the resource.
            metadata_store_id (str):
                Optional. The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Optional. Project used to create this resource. Overrides project set in
                aiplatform.init.
            location (str):
                Optional. Location used to create this resource. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials used to create this resource. Overrides
                credentials set in aiplatform.init.

        Returns:
            resource (_Resource):
                Instantiated representation of the managed Metadata resource.
        """
        appended_user_agent = []
        if base_constants.USER_AGENT_SDK_COMMAND:
            appended_user_agent = [
                f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}"
            ]
            # Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls.
            base_constants.USER_AGENT_SDK_COMMAND = ""

        api_client = cls._instantiate_client(
            location=location,
            credentials=credentials,
            appended_user_agent=appended_user_agent,
        )

        parent = utils.full_resource_name(
            resource_name=metadata_store_id,
            resource_noun=metadata_store._MetadataStore._resource_noun,
            parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name,
            format_resource_name_method=metadata_store._MetadataStore._format_resource_name,
            project=project,
            location=location,
        )

        resource = cls._create_resource(
            client=api_client,
            parent=parent,
            resource_id=resource_id,
            schema_title=schema_title,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=metadata,
        )

        self = cls._empty_constructor(
            project=project, location=location, credentials=credentials
        )
        self._gca_resource = resource
        self._threading_lock = threading.Lock()

        return self

    @classmethod
    def _create_resource(
        cls,
        client: utils.MetadataClientWithOverride,
        parent: str,
        resource_id: str,
        schema_title: str,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
    ) -> proto.Message:
        gapic_context = gca_context.Context(
            schema_title=schema_title,
            schema_version=schema_version,
            display_name=display_name,
            description=description,
            metadata=metadata if metadata else {},
        )
        return client.create_context(
            parent=parent,
            context=gapic_context,
            context_id=resource_id,
        )

    def update(
        self,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        location: Optional[str] = None,
    ):
        """Updates an existing Metadata Context with new metadata.

        This is implemented with retry on etag errors, up to
        _ETAG_ERROR_MAX_RETRY_COUNT times.

        Args:
            metadata (Dict):
                Optional. metadata contains the updated metadata information.
            description (str):
                Optional. Description describes the resource to be updated.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials to use to update this resource. Overrides
                credentials set in aiplatform.init.
            location (str):
                Optional. Location to send the update request to. Overrides
                location set in aiplatform.init.
        """
        for _ in range(_ETAG_ERROR_MAX_RETRY_COUNT - 1):
            try:
                super().update(
                    metadata=metadata,
                    description=description,
                    credentials=credentials,
                    location=location,
                )
                return
            except Aborted as aborted_exception:
                regex_match = _ETAG_ERROR_REGEX.match(aborted_exception.message)
                if regex_match:
                    # Compare the etags numerically; as digit strings,
                    # lexicographic comparison would be wrong (e.g. "9" < "12").
                    local_etag = int(regex_match.group(1))
                    server_etag = int(regex_match.group(2))
                    if local_etag < server_etag:
                        self.sync_resource()
                        continue
                raise aborted_exception

        # Expose result/exception directly in the last retry.
        super().update(
            metadata=metadata,
            description=description,
            credentials=credentials,
            location=location,
        )
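A self-contained sketch of the optimistic-concurrency check used above. The error message is a hand-written stand-in shaped to match `_ETAG_ERROR_REGEX`, not real server output:

```py
import re

_ETAG_ERROR_REGEX = re.compile(
    r"Specified Context \`etag\`: \`(\d+)\` does not match server \`etag\`: \`(\d+)\`"
)

message = "Specified Context `etag`: `9` does not match server `etag`: `12`"
match = _ETAG_ERROR_REGEX.match(message)
assert match is not None

# Compare numerically: as strings, "9" < "12" would be False.
local_etag, server_etag = int(match.group(1)), int(match.group(2))
if local_etag < server_etag:
    pass  # local copy is stale: re-sync and retry the update
```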

    @classmethod
    def _update_resource(
        cls,
        client: utils.MetadataClientWithOverride,
        resource: proto.Message,
    ) -> proto.Message:
        """Updates a Context with the given input.

        Args:
            client (utils.MetadataClientWithOverride):
                Required. The client to send requests to the Metadata Service.
            resource (proto.Message):
                Required. The proto.Message which contains the update information for the resource.
        """

        return client.update_context(context=resource)

    @classmethod
    def _list_resources(
        cls,
        client: utils.MetadataClientWithOverride,
        parent: str,
        filter: Optional[str] = None,  # pylint: disable=redefined-builtin
        order_by: Optional[str] = None,
    ):
        """Lists Contexts in the parent path that match the filter.

        Args:
            client (utils.MetadataClientWithOverride):
                Required. The client to send requests to the Metadata Service.
            parent (str):
                Required. The path where Contexts are stored.
            filter (str):
                Optional. A filter string to restrict the list result.
            order_by (str):
                Optional. How the list of messages is ordered. Specify the
                values to order by and an ordering operation. The default sorting
                order is ascending. To specify descending order for a field, users
                append a " desc" suffix; for example: "foo desc, bar". Subfields
                are specified with a ``.`` character, such as foo.bar. See
                https://google.aip.dev/132#ordering for more details.

        Returns:
            List of Contexts.
        """

        list_request = gca_metadata_service.ListContextsRequest(
            parent=parent,
            filter=filter,
            order_by=order_by,
        )
        return client.list_contexts(request=list_request)
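A hedged sketch of listing Contexts through the public `list` classmethod this helper backs (inherited from `resource._Resource`); the filter mirrors how `Experiment.list` builds one later in this commit, and the project/location are placeholders:

```py
from google.cloud.aiplatform.metadata import context
from google.cloud.aiplatform.metadata import utils as metadata_utils

filter_str = metadata_utils._make_filter_string(schema_title="system.Experiment")

experiment_contexts = context.Context.list(
    filter=filter_str,
    project="my-project",
    location="us-central1",
)
for ctx in experiment_contexts:
    print(ctx.name, ctx.schema_title)
```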

    def add_context_children(self, contexts: List["Context"]):
        """Adds the provided Contexts as children of this Context.

        Args:
            contexts (List["Context"]): Contexts to add as children.
        """
        self.api_client.add_context_children(
            context=self.resource_name,
            child_contexts=[c.resource_name for c in contexts],
        )
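A small sketch of wiring an experiment-run Context under an experiment Context with the method above (context names are placeholders and assume `aiplatform.init` has been called):

```py
from google.cloud.aiplatform.metadata import context

experiment = context.Context("my-experiment")
run = context.Context("my-experiment-run-1")

# After this call, `run.parent_contexts` will include the experiment's
# resource name.
experiment.add_context_children([run])
```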

    def query_lineage_subgraph(self) -> gca_lineage_subgraph.LineageSubgraph:
        """Queries the lineage subgraph of this Context.

        Returns:
            lineage_subgraph (gca_lineage_subgraph.LineageSubgraph): Lineage subgraph of this Context.
        """

        return self.api_client.query_context_lineage_subgraph(
            context=self.resource_name, retry=base._DEFAULT_RETRY
        )

    def get_executions(self) -> List[execution.Execution]:
        """Returns Executions associated with this Context.

        Returns:
            executions (List[Execution]): Executions associated with this Context.
        """
        return execution.Execution.list(
            filter=metadata_utils._make_filter_string(in_context=[self.resource_name]),
            project=self.project,
            location=self.location,
            credentials=self.credentials,
        )
@@ -0,0 +1,533 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

import proto
from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import event as gca_event
from google.cloud.aiplatform.compat.types import execution as gca_execution
from google.cloud.aiplatform.compat.types import (
    metadata_service as gca_metadata_service,
)
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import artifact
from google.cloud.aiplatform.metadata import metadata_store
from google.cloud.aiplatform.metadata import resource


class Execution(resource._Resource):
    """Metadata Execution resource for Vertex AI."""

    _resource_noun = "executions"
    _getter_method = "get_execution"
    _delete_method = "delete_execution"
    _parse_resource_name_method = "parse_execution_path"
    _format_resource_name_method = "execution_path"
    _list_method = "list_executions"

    def __init__(
        self,
        execution_name: str,
        *,
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """Retrieves an existing Metadata Execution given a resource name or ID.

        Args:
            execution_name (str):
                Required. A fully-qualified resource name or resource ID of the Execution.
                Example: "projects/123/locations/us-central1/metadataStores/default/executions/my-resource"
                or "my-resource" when project and location are initialized or passed.
            metadata_store_id (str):
                Optional. MetadataStore to retrieve the Execution from. If not set, metadata_store_id is set to "default".
                If execution_name is a fully-qualified resource, its metadata_store_id overrides this one.
            project (str):
                Optional. Project to retrieve the Execution from. If not set, the project
                set in aiplatform.init will be used.
            location (str):
                Optional. Location to retrieve the Execution from. If not set, the location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials to use to retrieve this Execution. Overrides
                credentials set in aiplatform.init.
        """

        super().__init__(
            resource_name=execution_name,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )

    @property
    def state(self) -> gca_execution.Execution.State:
        """State of this Execution."""
        return self._gca_resource.state

    @classmethod
    def create(
        cls,
        schema_title: str,
        *,
        state: gca_execution.Execution.State = gca_execution.Execution.State.RUNNING,
        resource_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        description: Optional[str] = None,
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "Execution":
        """Creates a new Metadata Execution.

        Args:
            schema_title (str):
                Required. schema_title identifies the schema title used by the Execution.
            state (gca_execution.Execution.State):
                Optional. State of this Execution. Defaults to RUNNING.
            resource_id (str):
                Optional. The <resource_id> portion of the Execution name with the
                following format; this is globally unique in a metadataStore:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
            display_name (str):
                Optional. The user-defined name of the Execution.
            schema_version (str):
                Optional. schema_version specifies the version used by the Execution.
                If not set, defaults to use the latest version.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the Execution.
            description (str):
                Optional. Describes the purpose of the Execution to be created.
            metadata_store_id (str):
                Optional. The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Optional. Project used to create this Execution. Overrides project set in
                aiplatform.init.
            location (str):
                Optional. Location used to create this Execution. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials used to create this Execution. Overrides
                credentials set in aiplatform.init.

        Returns:
            Execution: Instantiated representation of the managed Metadata Execution.
        """
        # Add User Agent Header for metrics tracking if one is not specified.
        # If one is already specified this call was initiated by a subclass.
        if not base_constants.USER_AGENT_SDK_COMMAND:
            base_constants.USER_AGENT_SDK_COMMAND = (
                "aiplatform.metadata.execution.Execution.create"
            )

        return cls._create(
            resource_id=resource_id,
            schema_title=schema_title,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=metadata,
            state=state,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )

    # TODO() refactor code to move _create to _Resource class.
    @classmethod
    def _create(
        cls,
        schema_title: str,
        *,
        state: gca_execution.Execution.State = gca_execution.Execution.State.RUNNING,
        resource_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        description: Optional[str] = None,
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "Execution":
        """Creates a new Metadata Execution.

        Args:
            schema_title (str):
                Required. schema_title identifies the schema title used by the Execution.
            state (gca_execution.Execution.State):
                Optional. State of this Execution. Defaults to RUNNING.
            resource_id (str):
                Optional. The <resource_id> portion of the Execution name with the
                following format; this is globally unique in a metadataStore:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
            display_name (str):
                Optional. The user-defined name of the Execution.
            schema_version (str):
                Optional. schema_version specifies the version used by the Execution.
                If not set, defaults to use the latest version.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the Execution.
            description (str):
                Optional. Describes the purpose of the Execution to be created.
            metadata_store_id (str):
                Optional. The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Optional. Project used to create this Execution. Overrides project set in
                aiplatform.init.
            location (str):
                Optional. Location used to create this Execution. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials used to create this Execution. Overrides
                credentials set in aiplatform.init.

        Returns:
            Execution: Instantiated representation of the managed Metadata Execution.
        """
        appended_user_agent = []
        if base_constants.USER_AGENT_SDK_COMMAND:
            appended_user_agent = [
                f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}"
            ]
            # Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls.
            base_constants.USER_AGENT_SDK_COMMAND = ""

        api_client = cls._instantiate_client(
            location=location,
            credentials=credentials,
            appended_user_agent=appended_user_agent,
        )

        parent = utils.full_resource_name(
            resource_name=metadata_store_id,
            resource_noun=metadata_store._MetadataStore._resource_noun,
            parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name,
            format_resource_name_method=metadata_store._MetadataStore._format_resource_name,
            project=project,
            location=location,
        )

        resource = Execution._create_resource(
            client=api_client,
            parent=parent,
            schema_title=schema_title,
            resource_id=resource_id,
            metadata=metadata,
            description=description,
            display_name=display_name,
            schema_version=schema_version,
            state=state,
        )
        self = cls._empty_constructor(
            project=project, location=location, credentials=credentials
        )
        self._gca_resource = resource

        return self

    def __enter__(self):
        if self.state is not gca_execution.Execution.State.RUNNING:
            self.update(state=gca_execution.Execution.State.RUNNING)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        state = (
            gca_execution.Execution.State.FAILED
            if exc_type
            else gca_execution.Execution.State.COMPLETE
        )
        self.update(state=state)
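A hedged sketch of the context-manager behavior above. The schema title and display name are placeholders; on normal exit the state becomes COMPLETE, on an exception FAILED:

```py
from google.cloud.aiplatform.metadata import execution

with execution.Execution.create(
    schema_title="system.ContainerExecution",
    display_name="training-step",
) as exc:
    # ... do work; the Execution is in the RUNNING state here ...
    print(exc.state)
# On exiting the block, state is updated to COMPLETE (or FAILED on error).
```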

    def assign_input_artifacts(
        self, artifacts: List[Union[artifact.Artifact, models.Model]]
    ):
        """Assigns Artifacts as inputs to this Execution.

        Args:
            artifacts (List[Union[artifact.Artifact, models.Model]]):
                Required. Artifacts to assign as input.
        """
        self._add_artifact(artifacts=artifacts, input=True)

    def assign_output_artifacts(
        self, artifacts: List[Union[artifact.Artifact, models.Model]]
    ):
        """Assigns Artifacts as outputs to this Execution.

        Args:
            artifacts (List[Union[artifact.Artifact, models.Model]]):
                Required. Artifacts to assign as output.
        """
        self._add_artifact(artifacts=artifacts, input=False)

    def _add_artifact(
        self,
        artifacts: List[Union[artifact.Artifact, models.Model]],
        input: bool,
    ):
        """Connects Artifacts to a given Execution.

        Args:
            artifacts (List[Union[artifact.Artifact, models.Model]]):
                Required. The Artifacts to connect to the Execution through Events.
                Vertex resources, such as Model, are resolved to (or created as)
                their representing Metadata Artifacts.
            input (bool):
                Required. Whether the Artifacts are input events to the Execution or not.
        """

        artifact_resource_names = []
        for a in artifacts:
            if isinstance(a, artifact.Artifact):
                artifact_resource_names.append(a.resource_name)
            else:
                artifact_resource_names.append(
                    artifact._VertexResourceArtifactResolver.resolve_or_create_resource_artifact(
                        a
                    ).resource_name
                )

        events = [
            gca_event.Event(
                artifact=artifact_resource_name,
                type_=gca_event.Event.Type.INPUT
                if input
                else gca_event.Event.Type.OUTPUT,
            )
            for artifact_resource_name in artifact_resource_names
        ]

        self.api_client.add_execution_events(
            execution=self.resource_name,
            events=events,
        )
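A hedged sketch tying the lineage pieces together; resource names are placeholders, and passing an `aiplatform.Model` relies on the resolver behavior documented above:

```py
from google.cloud import aiplatform
from google.cloud.aiplatform.metadata import artifact, execution

dataset_artifact = artifact.Artifact(
    "projects/123/locations/us-central1/metadataStores/default/artifacts/my-dataset"
)
my_model = aiplatform.Model("projects/123/locations/us-central1/models/456")

run = execution.Execution.create(schema_title="system.ContainerExecution")

# Vertex resources like Model are transparently resolved to Artifacts.
run.assign_input_artifacts([dataset_artifact])
run.assign_output_artifacts([my_model])

print([a.uri for a in run.get_output_artifacts()])
```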

    def _get_artifacts(
        self, event_type: gca_event.Event.Type
    ) -> List[artifact.Artifact]:
        """Gets this Execution's input or output Artifacts.

        Args:
            event_type (gca_event.Event.Type):
                Required. The Event type, input or output.
        Returns:
            List of Artifacts.
        """
        subgraph = self.api_client.query_execution_inputs_and_outputs(
            execution=self.resource_name
        )

        artifact_map = {
            artifact_metadata.name: artifact_metadata
            for artifact_metadata in subgraph.artifacts
        }

        gca_artifacts = [
            artifact_map[event.artifact]
            for event in subgraph.events
            if event.type_ == event_type
        ]

        artifacts = []
        for gca_artifact in gca_artifacts:
            this_artifact = artifact.Artifact._empty_constructor(
                project=self.project,
                location=self.location,
                credentials=self.credentials,
            )
            this_artifact._gca_resource = gca_artifact
            artifacts.append(this_artifact)

        return artifacts

    def get_input_artifacts(self) -> List[artifact.Artifact]:
        """Gets the input Artifacts of this Execution.

        Returns:
            List of input Artifacts.
        """
        return self._get_artifacts(event_type=gca_event.Event.Type.INPUT)

    def get_output_artifacts(self) -> List[artifact.Artifact]:
        """Gets the output Artifacts of this Execution.

        Returns:
            List of output Artifacts.
        """
        return self._get_artifacts(event_type=gca_event.Event.Type.OUTPUT)

    @classmethod
    def _create_resource(
        cls,
        client: utils.MetadataClientWithOverride,
        parent: str,
        schema_title: str,
        state: gca_execution.Execution.State = gca_execution.Execution.State.RUNNING,
        resource_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
    ) -> gca_execution.Execution:
        """Creates a new Metadata Execution.

        Args:
            client (utils.MetadataClientWithOverride):
                Required. Instantiated Metadata Service Client.
            parent (str):
                Required. MetadataStore parent in which to create this Execution.
            schema_title (str):
                Required. schema_title identifies the schema title used by the Execution.
            state (gca_execution.Execution.State):
                Optional. State of this Execution. Defaults to RUNNING.
            resource_id (str):
                Optional. The {execution} portion of the resource name with the
                format:
                ``projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution}``
                If not provided, the Execution's ID will be a UUID generated
                by the service. Must be 4-128 characters in length. Valid
                characters are ``/[a-z][0-9]-/``. Must be unique across all
                Executions in the parent MetadataStore. (Otherwise the
                request will fail with ALREADY_EXISTS, or PERMISSION_DENIED
                if the caller can't view the preexisting Execution.)
            display_name (str):
                Optional. The user-defined name of the Execution.
            schema_version (str):
                Optional. schema_version specifies the version used by the Execution.
                If not set, defaults to use the latest version.
            description (str):
                Optional. Describes the purpose of the Execution to be created.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the Execution.

        Returns:
            Execution: Instantiated representation of the managed Metadata Execution.
        """
        gapic_execution = gca_execution.Execution(
            schema_title=schema_title,
            schema_version=schema_version,
            display_name=display_name,
            description=description,
            metadata=metadata if metadata else {},
            state=state,
        )
        return client.create_execution(
            parent=parent,
            execution=gapic_execution,
            execution_id=resource_id,
        )

    @classmethod
    def _list_resources(
        cls,
        client: utils.MetadataClientWithOverride,
        parent: str,
        filter: Optional[str] = None,  # pylint: disable=redefined-builtin
        order_by: Optional[str] = None,
    ):
        """Lists Executions in the parent path that match the filter.

        Args:
            client (utils.MetadataClientWithOverride):
                Required. The client to send requests to the Metadata Service.
            parent (str):
                Required. The path where Executions are stored.
            filter (str):
                Optional. A filter string to restrict the list result.
            order_by (str):
                Optional. How the list of messages is ordered. Specify the
                values to order by and an ordering operation. The default sorting
                order is ascending. To specify descending order for a field, users
                append a " desc" suffix; for example: "foo desc, bar". Subfields
                are specified with a ``.`` character, such as foo.bar. See
                https://google.aip.dev/132#ordering for more details.
        Returns:
            List of Executions.
        """

        list_request = gca_metadata_service.ListExecutionsRequest(
            parent=parent,
            filter=filter,
            order_by=order_by,
        )
        return client.list_executions(request=list_request)

    @classmethod
    def _update_resource(
        cls,
        client: utils.MetadataClientWithOverride,
        resource: proto.Message,
    ) -> proto.Message:
        """Updates an Execution with the given input.

        Args:
            client (utils.MetadataClientWithOverride):
                Required. The client to send requests to the Metadata Service.
            resource (proto.Message):
                Required. The proto.Message which contains the update information for the resource.
        """

        return client.update_execution(execution=resource)

    def update(
        self,
        state: Optional[gca_execution.Execution.State] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Updates this Execution.

        Args:
            state (gca_execution.Execution.State):
                Optional. State of this Execution.
            description (str):
                Optional. Describes the purpose of the Execution to be updated.
            metadata (Dict[str, Any]):
                Optional. Contains the metadata information that will be stored in the Execution.
        """

        gca_resource = deepcopy(self._gca_resource)
        if state:
            gca_resource.state = state
        if description:
            gca_resource.description = description
        self._nested_update_metadata(gca_resource=gca_resource, metadata=metadata)
        self._gca_resource = self._update_resource(
            self.api_client, resource=gca_resource
        )
@@ -0,0 +1,827 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import abc
import concurrent.futures
from dataclasses import dataclass
import logging
from typing import Dict, List, NamedTuple, Optional, Tuple, Type, Union

from google.api_core import exceptions
from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base
from google.cloud.aiplatform.metadata import artifact
from google.cloud.aiplatform.metadata import constants
from google.cloud.aiplatform.metadata import context
from google.cloud.aiplatform.metadata import execution
from google.cloud.aiplatform.metadata import metadata
from google.cloud.aiplatform.metadata import metadata_store
from google.cloud.aiplatform.metadata import resource
from google.cloud.aiplatform.metadata import utils as metadata_utils
from google.cloud.aiplatform.tensorboard import tensorboard_resource

_LOGGER = base.Logger(__name__)
_HIGH_RUN_COUNT_THRESHOLD = 100  # Used in get_data_frame to make a suggestion to the user


@dataclass
class _ExperimentRow:
    """Class representing a run row in an Experiments Dataframe.

    Attributes:
        params (Dict[str, Union[float, int, str]]): Optional. The parameters of this run.
        metrics (Dict[str, Union[float, int, str]]): Optional. The metrics of this run.
        time_series_metrics (Dict[str, float]): Optional. The latest time series metrics of this run.
        experiment_run_type (Optional[str]): Optional. The type of this run.
        name (str): Optional. The name of this run.
        state (str): Optional. The state of this run.
    """

    params: Optional[Dict[str, Union[float, int, str]]] = None
    metrics: Optional[Dict[str, Union[float, int, str]]] = None
    time_series_metrics: Optional[Dict[str, float]] = None
    experiment_run_type: Optional[str] = None
    name: Optional[str] = None
    state: Optional[str] = None

    def to_dict(self) -> Dict[str, Union[float, int, str]]:
        """Converts this experiment row into a dictionary.

        Returns:
            Row as a dictionary.
        """
        result = {
            "run_type": self.experiment_run_type,
            "run_name": self.name,
            "state": self.state,
        }
        for prefix, field in [
            (constants._PARAM_PREFIX, self.params),
            (constants._METRIC_PREFIX, self.metrics),
            (constants._TIME_SERIES_METRIC_PREFIX, self.time_series_metrics),
        ]:
            if field:
                result.update(
                    {f"{prefix}.{key}": value for key, value in field.items()}
                )
        return result
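A tiny illustration of the column naming `to_dict` produces; the values are made up, and the `param.` / `metric.` prefixes come from `constants._PARAM_PREFIX` and `constants._METRIC_PREFIX`:

```py
row = _ExperimentRow(
    name="run-1",
    experiment_run_type="system.ExperimentRun",
    state="COMPLETE",
    params={"lr": 0.01},
    metrics={"accuracy": 0.9},
)
print(row.to_dict())
# {'run_type': 'system.ExperimentRun', 'run_name': 'run-1', 'state': 'COMPLETE',
#  'param.lr': 0.01, 'metric.accuracy': 0.9}
```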
|
||||
|
||||
|
||||
class Experiment:
|
||||
"""Represents a Vertex AI Experiment resource."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str,
|
||||
*,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
```py
|
||||
my_experiment = aiplatform.Experiment('my-experiment')
|
||||
```
|
||||
|
||||
Args:
|
||||
experiment_name (str):
|
||||
Required. The name or resource name of this experiment.
|
||||
|
||||
Resource name is of the format:
|
||||
`projects/123/locations/us-central1/metadataStores/default/contexts/my-experiment`
|
||||
project (str):
|
||||
Optional. Project where this experiment is located. Overrides
|
||||
project set in aiplatform.init.
|
||||
location (str):
|
||||
Optional. Location where this experiment is located. Overrides
|
||||
location set in aiplatform.init.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials used to retrieve this experiment.
|
||||
Overrides credentials set in aiplatform.init.
|
||||
"""
|
||||
|
||||
metadata_args = dict(
|
||||
resource_name=experiment_name,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
with _SetLoggerLevel(resource):
|
||||
experiment_context = context.Context(**metadata_args)
|
||||
self._validate_experiment_context(experiment_context)
|
||||
|
||||
self._metadata_context = experiment_context
|
||||
|
||||
@staticmethod
|
||||
def _validate_experiment_context(experiment_context: context.Context):
|
||||
"""Validates this context is an experiment context.
|
||||
|
||||
Args:
|
||||
experiment_context (context._Context): Metadata context.
|
||||
Raises:
|
||||
ValueError: If Metadata context is not an experiment context or a TensorboardExperiment.
|
||||
"""
|
||||
if experiment_context.schema_title != constants.SYSTEM_EXPERIMENT:
|
||||
raise ValueError(
|
||||
f"Experiment name {experiment_context.name} is of type "
|
||||
f"({experiment_context.schema_title}) in this MetadataStore. "
|
||||
f"It must of type {constants.SYSTEM_EXPERIMENT}."
|
||||
)
|
||||
if Experiment._is_tensorboard_experiment(experiment_context):
|
||||
raise ValueError(
|
||||
f"Experiment name {experiment_context.name} is a TensorboardExperiment context "
|
||||
f"and cannot be used as a Vertex AI Experiment."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_tensorboard_experiment(context: context.Context) -> bool:
|
||||
"""Returns True if Experiment is a Tensorboard Experiment created by CustomJob."""
|
||||
return constants.TENSORBOARD_CUSTOM_JOB_EXPERIMENT_FIELD in context.metadata
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of this experiment."""
|
||||
return self._metadata_context.name
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
experiment_name: str,
|
||||
*,
|
||||
description: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> "Experiment":
|
||||
"""Creates a new experiment in Vertex AI Experiments.
|
||||
|
||||
```py
|
||||
my_experiment = aiplatform.Experiment.create('my-experiment', description='my description')
|
||||
```
|
||||
|
||||
Args:
|
||||
experiment_name (str): Required. The name of this experiment.
|
||||
description (str): Optional. Describes this experiment's purpose.
|
||||
project (str):
|
||||
Optional. Project where this experiment will be created. Overrides project set in
|
||||
aiplatform.init.
|
||||
location (str):
|
||||
Optional. Location where this experiment will be created. Overrides location set in
|
||||
aiplatform.init.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials used to create this experiment. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
Returns:
|
||||
The newly created experiment.
|
||||
"""
|
||||
|
||||
metadata_store._MetadataStore.ensure_default_metadata_store_exists(
|
||||
project=project, location=location, credentials=credentials
|
||||
)
|
||||
|
||||
with _SetLoggerLevel(resource):
|
||||
experiment_context = context.Context._create(
|
||||
resource_id=experiment_name,
|
||||
display_name=experiment_name,
|
||||
description=description,
|
||||
schema_title=constants.SYSTEM_EXPERIMENT,
|
||||
schema_version=metadata._get_experiment_schema_version(),
|
||||
metadata=constants.EXPERIMENT_METADATA,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
self = cls.__new__(cls)
|
||||
self._metadata_context = experiment_context
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls,
|
||||
experiment_name: str,
|
||||
*,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> Optional["Experiment"]:
|
||||
"""Gets experiment if one exists with this experiment_name in Vertex AI Experiments.
|
||||
|
||||
Args:
|
||||
experiment_name (str):
|
||||
Required. The name of this experiment.
|
||||
project (str):
|
||||
Optional. Project used to retrieve this resource.
|
||||
Overrides project set in aiplatform.init.
|
||||
location (str):
|
||||
Optional. Location used to retrieve this resource.
|
||||
Overrides location set in aiplatform.init.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials used to retrieve this resource.
|
||||
Overrides credentials set in aiplatform.init.
|
||||
|
||||
Returns:
|
||||
Vertex AI experiment or None if no resource was found.
|
||||
"""
|
||||
try:
|
||||
return cls(
|
||||
experiment_name=experiment_name,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
except exceptions.NotFound:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_or_create(
|
||||
cls,
|
||||
experiment_name: str,
|
||||
*,
|
||||
description: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> "Experiment":
|
||||
"""Gets experiment if one exists with this experiment_name in Vertex AI Experiments.
|
||||
|
||||
Otherwise creates this experiment.
|
||||
|
||||
```py
|
||||
my_experiment = aiplatform.Experiment.get_or_create('my-experiment', description='my description')
|
||||
```
|
||||
|
||||
Args:
|
||||
experiment_name (str): Required. The name of this experiment.
|
||||
description (str): Optional. Describes this experiment's purpose.
|
||||
project (str):
|
||||
Optional. Project where this experiment will be retrieved from or created. Overrides project set in
|
||||
aiplatform.init.
|
||||
location (str):
|
||||
Optional. Location where this experiment will be retrieved from or created. Overrides location set in
|
||||
aiplatform.init.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials used to retrieve or create this experiment. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
Returns:
|
||||
Vertex AI experiment.
|
||||
"""
|
||||
|
||||
metadata_store._MetadataStore.ensure_default_metadata_store_exists(
|
||||
project=project, location=location, credentials=credentials
|
||||
)
|
||||
|
||||
with _SetLoggerLevel(resource):
|
||||
experiment_context = context.Context.get_or_create(
|
||||
resource_id=experiment_name,
|
||||
display_name=experiment_name,
|
||||
description=description,
|
||||
schema_title=constants.SYSTEM_EXPERIMENT,
|
||||
schema_version=metadata._get_experiment_schema_version(),
|
||||
metadata=constants.EXPERIMENT_METADATA,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
cls._validate_experiment_context(experiment_context)
|
||||
|
||||
if description and description != experiment_context.description:
|
||||
experiment_context.update(description=description)
|
||||
|
||||
self = cls.__new__(cls)
|
||||
self._metadata_context = experiment_context
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def list(
|
||||
cls,
|
||||
*,
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
) -> List["Experiment"]:
|
||||
"""List all Vertex AI Experiments in the given project.
|
||||
|
||||
```py
|
||||
my_experiments = aiplatform.Experiment.list()
|
||||
```
|
||||
|
||||
Args:
|
||||
project (str):
|
||||
Optional. Project to list these experiments from. Overrides project set in
|
||||
aiplatform.init.
|
||||
location (str):
|
||||
Optional. Location to list these experiments from. Overrides location set in
|
||||
aiplatform.init.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Optional. Custom credentials to list these experiments. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
Returns:
|
||||
List of Vertex AI experiments.
|
||||
"""
|
||||
|
||||
filter_str = metadata_utils._make_filter_string(
|
||||
schema_title=constants.SYSTEM_EXPERIMENT
|
||||
)
|
||||
|
||||
with _SetLoggerLevel(resource):
|
||||
experiment_contexts = context.Context.list(
|
||||
filter=filter_str,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
experiments = []
|
||||
for experiment_context in experiment_contexts:
|
||||
# Filters Tensorboard Experiments
|
||||
if not cls._is_tensorboard_experiment(experiment_context):
|
||||
experiment = cls.__new__(cls)
|
||||
experiment._metadata_context = experiment_context
|
||||
experiments.append(experiment)
|
||||
return experiments
|
||||
|
||||
@property
|
||||
def resource_name(self) -> str:
|
||||
"""The Metadata context resource name of this experiment."""
|
||||
return self._metadata_context.resource_name
|
||||
|
||||
@property
|
||||
def backing_tensorboard_resource_name(self) -> Optional[str]:
|
||||
"""The Tensorboard resource associated with this Experiment if there is one."""
|
||||
return self._metadata_context.metadata.get(
|
||||
constants._BACKING_TENSORBOARD_RESOURCE_KEY
|
||||
)
|
||||
|
||||
def delete(self, *, delete_backing_tensorboard_runs: bool = False):
|
||||
"""Deletes this experiment all the experiment runs under this experiment
|
||||
|
||||
Does not delete Pipeline runs, Artifacts, or Executions associated to this experiment
|
||||
or experiment runs in this experiment.
|
||||
|
||||
```py
|
||||
my_experiment = aiplatform.Experiment('my-experiment')
|
||||
my_experiment.delete(delete_backing_tensorboard_runs=True)
|
||||
```
|
||||
|
||||
Args:
|
||||
delete_backing_tensorboard_runs (bool):
|
||||
Optional. If True will also delete the Tensorboard Runs associated to the experiment
|
||||
runs under this experiment that we used to store time series metrics.
|
||||
"""
|
||||
|
||||
experiment_runs = _SUPPORTED_LOGGABLE_RESOURCES[context.Context][
|
||||
constants.SYSTEM_EXPERIMENT_RUN
|
||||
].list(experiment=self)
|
||||
for experiment_run in experiment_runs:
|
||||
experiment_run.delete(
|
||||
delete_backing_tensorboard_run=delete_backing_tensorboard_runs
|
||||
)
|
||||
try:
|
||||
self._metadata_context.delete()
|
||||
except exceptions.NotFound:
|
||||
_LOGGER.warning(
|
||||
f"Experiment {self.name} metadata node not found. Skipping deletion."
|
||||
)
|
||||
|
||||
    def get_data_frame(
        self, *, include_time_series: bool = True
    ) -> "pd.DataFrame":  # noqa: F821
        """Get parameters, metrics, and time series metrics of all runs in this experiment as a DataFrame.

        ```py
        my_experiment = aiplatform.Experiment('my-experiment')
        df = my_experiment.get_data_frame()
        ```

        Args:
            include_time_series (bool):
                Optional. Whether or not to include time series metrics in the DataFrame.
                Default is True. Setting this to False will significantly improve execution
                time and reduce quota-contributing calls. Recommended when time
                series metrics are not needed or the number of runs in the Experiment is
                large. For time series metrics, consider querying a specific run
                using get_time_series_data_frame.

        Returns:
            pd.DataFrame: Pandas DataFrame of Experiment Runs.

        Raises:
            ImportError: If pandas is not installed.
        """
        try:
            import pandas as pd
        except ImportError:
            raise ImportError(
                "Pandas is not installed and is required to get dataframe as the return format. "
                'Please install the SDK using "pip install google-cloud-aiplatform[metadata]"'
            )

        service_request_args = dict(
            project=self._metadata_context.project,
            location=self._metadata_context.location,
            credentials=self._metadata_context.credentials,
        )

        filter_str = metadata_utils._make_filter_string(
            schema_title=sorted(
                list(_SUPPORTED_LOGGABLE_RESOURCES[context.Context].keys())
            ),
            parent_contexts=[self._metadata_context.resource_name],
        )
        contexts = context.Context.list(filter_str, **service_request_args)

        filter_str = metadata_utils._make_filter_string(
            schema_title=list(
                _SUPPORTED_LOGGABLE_RESOURCES[execution.Execution].keys()
            ),
            in_context=[self._metadata_context.resource_name],
        )

        executions = execution.Execution.list(filter_str, **service_request_args)

        run_count = max([len(contexts), len(executions)])
        if include_time_series and run_count > _HIGH_RUN_COUNT_THRESHOLD:
            _LOGGER.warning(
                f"Number of runs {run_count} is high. Consider setting "
                "include_time_series to False to improve execution performance."
            )
        if not include_time_series:
            _LOGGER.warning(
                "include_time_series is set to False. Time series metrics will"
                " not be included in this call even if they exist."
            )

        rows = []
        if contexts or executions:
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=run_count
            ) as executor:
                futures = [
                    executor.submit(
                        _SUPPORTED_LOGGABLE_RESOURCES[context.Context][
                            metadata_context.schema_title
                        ]._query_experiment_row,
                        metadata_context,
                        experiment=self,
                        include_time_series=include_time_series,
                    )
                    for metadata_context in contexts
                ]

                # backward compatibility
                futures.extend(
                    executor.submit(
                        _SUPPORTED_LOGGABLE_RESOURCES[execution.Execution][
                            metadata_execution.schema_title
                        ]._query_experiment_row,
                        metadata_execution,
                        experiment=self,
                        include_time_series=include_time_series,
                    )
                    for metadata_execution in executions
                )

                for future in futures:
                    try:
                        row_dict = future.result().to_dict()
                    except Exception as exc:
                        raise ValueError(
                            f"Failed to get experiment row for {self.name}"
                        ) from exc
                    else:
                        row_dict.update({"experiment_name": self.name})
                        rows.append(row_dict)

        df = pd.DataFrame(rows)

        column_name_sort_map = {
            "experiment_name": -1,
            "run_name": 1,
            "run_type": 2,
            "state": 3,
        }

        def column_sort_key(key: str) -> int:
            """Helper method to reorder columns."""
            order = column_name_sort_map.get(key)
            if order:
                return order
            elif key.startswith("param"):
                return 5
            elif key.startswith("metric"):
                return 6
            else:
                return 7

        columns = df.columns
        columns = sorted(columns, key=column_sort_key)
        df = df.reindex(columns, axis=1)

        return df

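As a usage illustration (not part of the diff): a minimal sketch of consuming the DataFrame returned above, assuming pandas is installed; the project, location, and experiment names are placeholders.

```py
# Minimal sketch; "my-project" / "my-experiment" are placeholder names.
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")
my_experiment = aiplatform.Experiment("my-experiment")

# Skipping time series metrics keeps the call fast for large experiments.
df = my_experiment.get_data_frame(include_time_series=False)

# Columns arrive ordered: experiment_name, run_name, run_type, state,
# then param.* and metric.* columns.
print(df[["run_name", "state"]].head())
```
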
    def _lookup_backing_tensorboard(self) -> Optional[tensorboard_resource.Tensorboard]:
        """Returns backing tensorboard if one is set.

        Returns:
            Tensorboard resource if one exists, otherwise returns None.
        """
        tensorboard_resource_name = self._metadata_context.metadata.get(
            constants._BACKING_TENSORBOARD_RESOURCE_KEY
        )

        if not tensorboard_resource_name:
            with _SetLoggerLevel(resource):
                self._metadata_context.sync_resource()
            tensorboard_resource_name = self._metadata_context.metadata.get(
                constants._BACKING_TENSORBOARD_RESOURCE_KEY
            )

        if tensorboard_resource_name:
            try:
                return tensorboard_resource.Tensorboard(
                    tensorboard_resource_name,
                    credentials=self._metadata_context.credentials,
                )
            except exceptions.NotFound:
                self._metadata_context.update(
                    metadata={constants._BACKING_TENSORBOARD_RESOURCE_KEY: None}
                )
        return None

    def get_backing_tensorboard_resource(
        self,
    ) -> Optional[tensorboard_resource.Tensorboard]:
        """Get the backing tensorboard for this experiment if one exists.

        ```py
        my_experiment = aiplatform.Experiment('my-experiment')
        tb = my_experiment.get_backing_tensorboard_resource()
        ```

        Returns:
            Backing Tensorboard resource for this experiment if one exists.
        """
        return self._lookup_backing_tensorboard()

    def assign_backing_tensorboard(
        self, tensorboard: Union[tensorboard_resource.Tensorboard, str]
    ):
        """Assigns tensorboard as backing tensorboard to support time series metrics logging.

        ```py
        tb = aiplatform.Tensorboard('tensorboard-resource-id')
        my_experiment = aiplatform.Experiment('my-experiment')
        my_experiment.assign_backing_tensorboard(tb)
        ```

        Args:
            tensorboard (Union[aiplatform.Tensorboard, str]):
                Required. Tensorboard resource or resource name to associate with this experiment.

        Raises:
            ValueError: If this experiment already has a previously set backing tensorboard resource.
            ValueError: If the Tensorboard is not in the same project and location as this experiment.
        """

        backing_tensorboard = self._lookup_backing_tensorboard()
        if backing_tensorboard:
            tensorboard_resource_name = (
                tensorboard
                if isinstance(tensorboard, str)
                else tensorboard.resource_name
            )
            if tensorboard_resource_name != backing_tensorboard.resource_name:
                raise ValueError(
                    f"Experiment {self._metadata_context.name} is already associated "
                    f"to tensorboard resource {backing_tensorboard.resource_name}"
                )

        if isinstance(tensorboard, str):
            tensorboard = tensorboard_resource.Tensorboard(
                tensorboard,
                project=self._metadata_context.project,
                location=self._metadata_context.location,
                credentials=self._metadata_context.credentials,
            )

        if tensorboard.project not in self._metadata_context._project_tuple:
            raise ValueError(
                f"Tensorboard is in project {tensorboard.project} but must be in project {self._metadata_context.project}"
            )
        if tensorboard.location != self._metadata_context.location:
            raise ValueError(
                f"Tensorboard is in location {tensorboard.location} but must be in location {self._metadata_context.location}"
            )

        self._metadata_context.update(
            metadata={
                constants._BACKING_TENSORBOARD_RESOURCE_KEY: tensorboard.resource_name
            },
            location=self._metadata_context.location,
        )

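For illustration, a hedged sketch of the assignment flow this method enables; all resource names and run IDs below are placeholders, and the Tensorboard must live in the experiment's project and location.

```py
# Sketch only; display names and IDs are placeholders.
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

tb = aiplatform.Tensorboard.create(display_name="my-tensorboard")
exp = aiplatform.Experiment.get_or_create("my-experiment")
exp.assign_backing_tensorboard(tb)

# With a backing tensorboard assigned, runs can log time series metrics.
aiplatform.init(experiment="my-experiment")
aiplatform.start_run("run-1")
aiplatform.log_time_series_metrics({"loss": 0.5})
aiplatform.end_run()
```
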
    def _log_experiment_loggable(self, experiment_loggable: "_ExperimentLoggable"):
        """Associates a Vertex resource that can be logged to an Experiment as a run of this experiment.

        Args:
            experiment_loggable (_ExperimentLoggable):
                A Vertex Resource that can be logged to an Experiment directly.
        """
        context = experiment_loggable._get_context()
        self._metadata_context.add_context_children([context])

    @property
    def dashboard_url(self) -> Optional[str]:
        """Cloud console URL for this resource."""
        url = f"https://console.cloud.google.com/vertex-ai/experiments/locations/{self._metadata_context.location}/experiments/{self._metadata_context.name}?project={self._metadata_context.project}"
        return url


class _SetLoggerLevel:
    """Context manager that temporarily suppresses INFO logging for a module."""

    def __init__(self, module):
        self._module = module

    def __enter__(self):
        logging.getLogger(self._module.__name__).setLevel(logging.WARNING)

    def __exit__(self, exc_type, exc_value, traceback):
        logging.getLogger(self._module.__name__).setLevel(logging.INFO)


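Usage of the context manager above, for illustration; `resource` is the metadata resource module imported elsewhere in this file, and the filter string is a placeholder.

```py
with _SetLoggerLevel(resource):
    # INFO logs from the metadata resource module are suppressed here.
    context.Context.list(filter='schema_title="system.Experiment"')
```
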
class _VertexResourceWithMetadata(NamedTuple):
    """Represents a resource coupled with its metadata representation."""

    resource: base.VertexAiResourceNoun
    metadata: Union[artifact.Artifact, execution.Execution, context.Context]


class _ExperimentLoggableSchema(NamedTuple):
    """Used with _ExperimentLoggable to capture Metadata representation information about a resource.

    For example:
        _ExperimentLoggableSchema(title='system.PipelineRun', type=context.Context)

    Defines the schema and metadata type to look up PipelineJobs.
    """

    title: str
    type: Union[Type[context.Context], Type[execution.Execution]] = context.Context


class _ExperimentLoggable(abc.ABC):
    """Abstract base class to define a Vertex Resource as loggable against an Experiment.

    For example:
        class PipelineJob(
            ...,
            experiment_loggable_schemas=(
                _ExperimentLoggableSchema(title='system.PipelineRun'),
            ),
        )
    """

    def __init_subclass__(
        cls, *, experiment_loggable_schemas: Tuple[_ExperimentLoggableSchema], **kwargs
    ):
        """Registers the metadata schema for the subclass so Experiment can use it to retrieve the associated types.

        Usage:

            class PipelineJob(
                ...,
                experiment_loggable_schemas=(
                    _ExperimentLoggableSchema(title='system.PipelineRun'),
                ),
            )

        Args:
            experiment_loggable_schemas:
                Tuple of the schema_title and type pairs that represent this resource. Note that a single item in the
                tuple will be most common. Currently only experiment run has multiple representations for backwards
                compatibility. Almost all schemas should be Contexts; Execution is currently only supported
                for backwards compatibility of experiment runs.
        """
        super().__init_subclass__(**kwargs)

        # register the type when the module is loaded
        for schema in experiment_loggable_schemas:
            _SUPPORTED_LOGGABLE_RESOURCES[schema.type][schema.title] = cls

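A self-contained sketch of this `__init_subclass__` registration pattern in plain Python; the names below (Loggable, REGISTRY, Run) are illustrative, not SDK symbols.

```py
from typing import Dict, Tuple, Type

REGISTRY: Dict[str, Type] = {}


class Loggable:
    def __init_subclass__(cls, *, schema_titles: Tuple[str, ...], **kwargs):
        super().__init_subclass__(**kwargs)
        # Each subclass registers itself at class-definition time.
        for title in schema_titles:
            REGISTRY[title] = cls


class Run(Loggable, schema_titles=("system.Run",)):
    pass


assert REGISTRY["system.Run"] is Run
```
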
    @abc.abstractmethod
    def _get_context(self) -> context.Context:
        """Should return the metadata context that represents this resource.

        The subclass should enforce that this context exists.

        Returns:
            Context that represents this resource.
        """
        pass

    @classmethod
    @abc.abstractmethod
    def _query_experiment_row(
        cls, node: Union[context.Context, execution.Execution]
    ) -> _ExperimentRow:
        """Should return parameters and metrics for this resource as a run row.

        Args:
            node: The metadata node that represents this resource.

        Returns:
            A populated run row for this resource.
        """
        pass

    def _validate_experiment(self, experiment: Union[str, Experiment]):
        """Validates that the experiment is accessible. Can be used by a subclass to throw before creating the intended resource.

        Args:
            experiment (Union[str, Experiment]): The experiment that this resource will be associated with.

        Raises:
            RuntimeError: If the service raises any exception when trying to access this experiment.
            ValueError: If the resource project or location do not match the experiment project or location.
        """

        if isinstance(experiment, str):
            try:
                experiment = Experiment.get_or_create(
                    experiment,
                    project=self.project,
                    location=self.location,
                    credentials=self.credentials,
                )
            except Exception as e:
                raise RuntimeError(
                    f"Experiment {experiment} could not be found or created. {self.__class__.__name__} not created"
                ) from e

        if self.project not in experiment._metadata_context._project_tuple:
            raise ValueError(
                f"{self.__class__.__name__} project {self.project} does not match experiment "
                f"{experiment.name} project {experiment.project}"
            )

        if experiment._metadata_context.location != self.location:
            raise ValueError(
                f"{self.__class__.__name__} location {self.location} does not match experiment "
                f"{experiment.name} location {experiment.location}"
            )

    def _associate_to_experiment(self, experiment: Union[str, Experiment]):
        """Associates this resource to the provided Experiment.

        Args:
            experiment (Union[str, Experiment]): Required. Experiment name or experiment instance.

        Raises:
            RuntimeError: If the Metadata service cannot associate the resource to the Experiment.
        """
        experiment_name = experiment if isinstance(experiment, str) else experiment.name
        _LOGGER.info(
            "Associating %s to Experiment: %s" % (self.resource_name, experiment_name)
        )

        try:
            if isinstance(experiment, str):
                experiment = Experiment.get_or_create(
                    experiment,
                    project=self.project,
                    location=self.location,
                    credentials=self.credentials,
                )
            experiment._log_experiment_loggable(self)
        except Exception as e:
            # Use experiment_name here: if get_or_create failed, `experiment`
            # may still be a plain string without a `.name` attribute.
            raise RuntimeError(
                f"{self.resource_name} could not be associated with Experiment {experiment_name}"
            ) from e


# Maps context names to their resource classes.
# Used by the Experiment implementation to filter for representations in the metadata store.
# Populated at module import time from classes that inherit _ExperimentLoggable.
# Example mapping:
# {Metadata Type} -> {schema title} -> {Vertex SDK class}
# Context -> 'system.PipelineRun' -> aiplatform.PipelineJob
# Context -> 'system.ExperimentRun' -> aiplatform.ExperimentRun
# Execution -> 'system.Run' -> aiplatform.ExperimentRun
_SUPPORTED_LOGGABLE_RESOURCES: Dict[
    Union[Type[context.Context], Type[execution.Execution]],
    Dict[str, _ExperimentLoggable],
] = {execution.Execution: dict(), context.Context: dict()}
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
from typing import Optional

from google.api_core import exceptions
from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base, initializer
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import metadata_store as gca_metadata_store
from google.cloud.aiplatform.constants import base as base_constants


class _MetadataStore(base.VertexAiResourceNounWithFutureManager):
    """Managed MetadataStore resource for Vertex AI."""

    client_class = utils.MetadataClientWithOverride
    _is_client_prediction_client = False
    _resource_noun = "metadataStores"
    _getter_method = "get_metadata_store"
    _delete_method = "delete_metadata_store"
    _parse_resource_name_method = "parse_metadata_store_path"
    _format_resource_name_method = "metadata_store_path"

    def __init__(
        self,
        metadata_store_name: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """Retrieves an existing MetadataStore given a MetadataStore name or ID.

        Args:
            metadata_store_name (str):
                Optional. A fully-qualified MetadataStore resource name or metadataStore ID.
                Example: "projects/123/locations/us-central1/metadataStores/my-store" or
                "my-store" when project and location are initialized or passed.
                If not set, metadata_store_name will be set to "default".
            project (str):
                Optional. Project to retrieve the resource from. If not set, the project
                set in aiplatform.init will be used.
            location (str):
                Optional. Location to retrieve the resource from. If not set, the location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Custom credentials to use to retrieve this metadata store. Overrides
                credentials set in aiplatform.init.
        """

        super().__init__(
            project=project,
            location=location,
            credentials=credentials,
        )
        self._gca_resource = self._get_gca_resource(resource_name=metadata_store_name)

    @classmethod
    def get_or_create(
        cls,
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        encryption_spec_key_name: Optional[str] = None,
    ) -> "_MetadataStore":
        """Retrieves or creates (if it does not exist) a Metadata Store.

        Args:
            metadata_store_id (str):
                The <metadatastore> portion of the resource name with the format:
                projects/123/locations/us-central1/metadataStores/<metadatastore>
                If not provided, the MetadataStore's ID will be set to "default" to create a default MetadataStore.
            project (str):
                Project used to retrieve or create the metadata store. Overrides project set in
                aiplatform.init.
            location (str):
                Location used to retrieve or create the metadata store. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials used to retrieve or create the metadata store. Overrides
                credentials set in aiplatform.init.
            encryption_spec_key_name (Optional[str]):
                Optional. The Cloud KMS resource identifier of the customer
                managed encryption key used to protect the metadata store. Has the
                form:
                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
                The key needs to be in the same region as where the compute
                resource is created.

                If set, this MetadataStore and all sub-resources of this MetadataStore will be secured by this key.

                Overrides encryption_spec_key_name set in aiplatform.init.

        Returns:
            metadata_store (_MetadataStore):
                Instantiated representation of the managed metadata store resource.
        """
        store = cls._get(
            metadata_store_name=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )
        if not store:
            store = cls._create(
                metadata_store_id=metadata_store_id,
                project=project,
                location=location,
                credentials=credentials,
                encryption_spec_key_name=encryption_spec_key_name,
            )
        return store

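Illustrative call only; `_MetadataStore` is an internal class, and the project and KMS key names below are placeholders.

```py
from google.cloud import aiplatform
from google.cloud.aiplatform.metadata import metadata_store

aiplatform.init(project="my-project", location="us-central1")

# Idempotent: returns the existing "default" store or creates it,
# optionally protected by a customer-managed encryption key.
store = metadata_store._MetadataStore.get_or_create(
    encryption_spec_key_name=(
        "projects/my-project/locations/us-central1"
        "/keyRings/my-kr/cryptoKeys/my-key"
    ),
)
print(store.resource_name)
```
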
    @classmethod
    def _create(
        cls,
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        encryption_spec_key_name: Optional[str] = None,
    ) -> "_MetadataStore":
        """Creates a new MetadataStore if it does not exist.

        Args:
            metadata_store_id (str):
                The <metadatastore> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadatastore>
                If not provided, the MetadataStore's ID will be set to "default" to create a default MetadataStore.
            project (str):
                Project used to create the metadata store. Overrides project set in
                aiplatform.init.
            location (str):
                Location used to create the metadata store. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials used to create the metadata store. Overrides
                credentials set in aiplatform.init.
            encryption_spec_key_name (Optional[str]):
                Optional. The Cloud KMS resource identifier of the customer
                managed encryption key used to protect the metadata store. Has the
                form:
                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
                The key needs to be in the same region as where the compute
                resource is created.

                If set, this MetadataStore and all sub-resources of this MetadataStore will be secured by this key.

                Overrides encryption_spec_key_name set in aiplatform.init.

        Returns:
            metadata_store (_MetadataStore):
                Instantiated representation of the managed metadata store resource.
        """
        appended_user_agent = []
        if base_constants.USER_AGENT_SDK_COMMAND:
            appended_user_agent = [
                f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}"
            ]
            # Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls.
            base_constants.USER_AGENT_SDK_COMMAND = ""

        api_client = cls._instantiate_client(
            location=location,
            credentials=credentials,
            appended_user_agent=appended_user_agent,
        )

        gapic_metadata_store = gca_metadata_store.MetadataStore(
            encryption_spec=initializer.global_config.get_encryption_spec(
                encryption_spec_key_name=encryption_spec_key_name,
                select_version=compat.DEFAULT_VERSION,
            )
        )

        try:
            api_client.create_metadata_store(
                parent=initializer.global_config.common_location_path(
                    project=project, location=location
                ),
                metadata_store=gapic_metadata_store,
                metadata_store_id=metadata_store_id,
            ).result()
        except exceptions.AlreadyExists:
            logging.info(f"MetadataStore '{metadata_store_id}' already exists")

        return cls(
            metadata_store_name=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )

    @classmethod
    def _get(
        cls,
        metadata_store_name: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> Optional["_MetadataStore"]:
        """Returns a MetadataStore resource.

        Args:
            metadata_store_name (str):
                Optional. A fully-qualified MetadataStore resource name or metadataStore ID.
                Example: "projects/123/locations/us-central1/metadataStores/my-store" or
                "my-store" when project and location are initialized or passed.
                If not set, metadata_store_name will be set to "default".
            project (str):
                Optional. Project to retrieve the metadata store from. If not set, the project
                set in aiplatform.init will be used.
            location (str):
                Optional. Location to retrieve the metadata store from. If not set, the location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Custom credentials to retrieve this metadata store. Overrides
                credentials set in aiplatform.init.

        Returns:
            metadata_store (Optional[_MetadataStore]):
                An optional instantiated representation of the managed Metadata Store resource.
        """

        try:
            return cls(
                metadata_store_name=metadata_store_name,
                project=project,
                location=location,
                credentials=credentials,
            )
        except exceptions.NotFound:
            logging.info(f"MetadataStore {metadata_store_name} not found.")

    @classmethod
    def ensure_default_metadata_store_exists(
        cls,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        encryption_key_spec_name: Optional[str] = None,
    ):
        """Helper method to ensure the `default` MetadataStore exists in this project and location.

        Args:
            project (str):
                Optional. Project to retrieve the resource from. If not set, the project
                set in aiplatform.init will be used.
            location (str):
                Optional. Location to retrieve the resource from. If not set, the location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials to use. Overrides
                credentials set in aiplatform.init.
            encryption_key_spec_name (str):
                Optional. The Cloud KMS resource identifier of the customer
                managed encryption key used to protect the metadata store. Has the
                form:
                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
                The key needs to be in the same region as where the compute
                resource is created.

                If set, this MetadataStore and all sub-resources of this MetadataStore will be secured by this key.

                Overrides encryption_spec_key_name set in aiplatform.init.
        """

        cls.get_or_create(
            project=project,
            location=location,
            credentials=credentials,
            encryption_spec_key_name=encryption_key_spec_name,
        )
@@ -0,0 +1,569 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import abc
import collections
import re
import threading
from copy import deepcopy
from typing import Dict, Optional, Union, Any, List

import proto
from google.api_core import exceptions
from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import base, initializer
from google.cloud.aiplatform import metadata
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.compat.types import context as gca_context
from google.cloud.aiplatform.compat.types import execution as gca_execution

_LOGGER = base.Logger(__name__)


class _Resource(base.VertexAiResourceNounWithFutureManager, abc.ABC):
    """Metadata Resource for Vertex AI."""

    client_class = utils.MetadataClientWithOverride
    _delete_method = None

    def __init__(
        self,
        resource_name: Optional[str] = None,
        resource: Optional[
            Union[gca_context.Context, gca_artifact.Artifact, gca_execution.Execution]
        ] = None,
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """Retrieves an existing Metadata resource given a resource name or ID.

        Args:
            resource_name (str):
                A fully-qualified resource name or ID.
                Example: "projects/123/locations/us-central1/metadataStores/default/<resource_noun>/my-resource"
                or "my-resource" when project and location are initialized or passed.
                If ``resource`` is provided, this should not be set.
            resource (Union[gca_context.Context, gca_artifact.Artifact, gca_execution.Execution]):
                The proto.Message that contains the full information of the resource. If both are set, this field
                overrides the ``resource_name`` field.
            metadata_store_id (str):
                MetadataStore to retrieve the resource from. If not set, metadata_store_id is set to "default".
                If resource_name is a fully-qualified resource, its metadata_store_id overrides this one.
            project (str):
                Optional. Project to retrieve the resource from. If not set, the project
                set in aiplatform.init will be used.
            location (str):
                Optional. Location to retrieve the resource from. If not set, the location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Custom credentials to use to retrieve this resource. Overrides
                credentials set in aiplatform.init.
        """

        super().__init__(
            project=project,
            location=location,
            credentials=credentials,
        )

        if resource:
            self._gca_resource = resource
        else:
            full_resource_name = utils.full_resource_name(
                resource_name=resource_name,
                resource_noun=self._resource_noun,
                parse_resource_name_method=self._parse_resource_name,
                format_resource_name_method=self._format_resource_name,
                parent_resource_name_fields={
                    metadata.metadata_store._MetadataStore._resource_noun: metadata_store_id
                },
                project=self.project,
                location=self.location,
            )

            self._gca_resource = getattr(self.api_client, self._getter_method)(
                name=full_resource_name, retry=base._DEFAULT_RETRY
            )

        self._threading_lock = threading.Lock()

    @property
    def metadata(self) -> Dict:
        return self.to_dict()["metadata"]

    @property
    def schema_title(self) -> str:
        return self._gca_resource.schema_title

    @property
    def description(self) -> str:
        return self._gca_resource.description

    @property
    def display_name(self) -> str:
        return self._gca_resource.display_name

    @property
    def schema_version(self) -> str:
        return self._gca_resource.schema_version

    @classmethod
    def get_or_create(
        cls,
        resource_id: str,
        schema_title: str,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "_Resource":
        """Retrieves or creates (if it does not exist) a Metadata resource.

        Args:
            resource_id (str):
                Required. The <resource_id> portion of the resource name with the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
            schema_title (str):
                Required. schema_title identifies the schema title used by the resource.
            display_name (str):
                Optional. The user-defined name of the resource.
            schema_version (str):
                Optional. schema_version specifies the version used by the resource.
                If not set, defaults to use the latest version.
            description (str):
                Optional. Describes the purpose of the resource to be created.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the resource.
            metadata_store_id (str):
                The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Project used to retrieve or create this resource. Overrides project set in
                aiplatform.init.
            location (str):
                Location used to retrieve or create this resource. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials used to retrieve or create this resource. Overrides
                credentials set in aiplatform.init.

        Returns:
            resource (_Resource):
                Instantiated representation of the managed Metadata resource.
        """

        resource = cls._get(
            resource_name=resource_id,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )
        if not resource:
            _LOGGER.info(f"Creating Resource {resource_id}")
            resource = cls._create(
                resource_id=resource_id,
                schema_title=schema_title,
                display_name=display_name,
                schema_version=schema_version,
                description=description,
                metadata=metadata,
                metadata_store_id=metadata_store_id,
                project=project,
                location=location,
                credentials=credentials,
            )
        return resource

    @classmethod
    def get(
        cls,
        resource_id: str,
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "_Resource":
        """Retrieves a Metadata resource.

        Args:
            resource_id (str):
                Required. The <resource_id> portion of the resource name with the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
            metadata_store_id (str):
                The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Project used to retrieve this resource. Overrides project set in
                aiplatform.init.
            location (str):
                Location used to retrieve this resource. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials used to retrieve this resource. Overrides
                credentials set in aiplatform.init.

        Returns:
            resource (_Resource):
                Instantiated representation of the managed Metadata resource, or None if no resource was found.
        """
        resource = cls._get(
            resource_name=resource_id,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )
        return resource

    def sync_resource(self):
        """Syncs the local resource with the resource in the metadata store."""
        self._gca_resource = getattr(self.api_client, self._getter_method)(
            name=self.resource_name, retry=base._DEFAULT_RETRY
        )

    @staticmethod
    def _nested_update_metadata(
        gca_resource: Union[
            gca_context.Context, gca_execution.Execution, gca_artifact.Artifact
        ],
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Helper method to update gca_resource in place.

        Performs a one-level-deep nested update on the metadata field.

        Args:
            gca_resource (Union[gca_context.Context, gca_execution.Execution, gca_artifact.Artifact]):
                Required. Metadata Protobuf resource. This proto's metadata will be
                updated in place.
            metadata (Dict[str, Any]):
                Optional. Metadata dictionary to merge into gca_resource.metadata.
        """

        if metadata:
            if gca_resource.metadata:
                for key, value in metadata.items():
                    # Note: This only supports nested dictionaries one level deep
                    if isinstance(value, collections.abc.Mapping):
                        gca_resource.metadata[key].update(value)
                    else:
                        gca_resource.metadata[key] = value
            else:
                gca_resource.metadata = metadata

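A standalone sketch of the one-level-deep merge semantics above, using plain dicts in place of proto map fields (`setdefault` stands in for the map field's auto-created entries).

```py
import collections.abc


def nested_update(target: dict, updates: dict) -> None:
    for key, value in updates.items():
        if isinstance(value, collections.abc.Mapping):
            # Merge one level deep; deeper nesting is replaced, not merged.
            target.setdefault(key, {}).update(value)
        else:
            target[key] = value


meta = {"params": {"lr": 0.1}, "state": "RUNNING"}
nested_update(meta, {"params": {"epochs": 5}, "state": "COMPLETE"})
assert meta == {"params": {"lr": 0.1, "epochs": 5}, "state": "COMPLETE"}
```
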
    def update(
        self,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        location: Optional[str] = None,
    ):
        """Updates an existing Metadata resource with new metadata.

        Args:
            metadata (Dict):
                Optional. metadata contains the updated metadata information.
            description (str):
                Optional. Description describes the resource to be updated.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials to use to update this resource. Overrides
                credentials set in aiplatform.init.
            location (str):
                Optional. Location to use for the update client. Overrides location
                set in aiplatform.init.
        """
        if not hasattr(self, "_threading_lock"):
            self._threading_lock = threading.Lock()

        with self._threading_lock:
            gca_resource = deepcopy(self._gca_resource)
            if metadata:
                self._nested_update_metadata(
                    gca_resource=gca_resource, metadata=metadata
                )
            if description:
                gca_resource.description = description

            api_client = self._instantiate_client(
                credentials=credentials, location=location
            )
            # TODO: if etag is not valid sync and retry
            update_gca_resource = self._update_resource(
                client=api_client,
                resource=gca_resource,
            )
            self._gca_resource = update_gca_resource

    @classmethod
    def list(
        cls,
        filter: Optional[str] = None,  # pylint: disable=redefined-builtin
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        order_by: Optional[str] = None,
    ) -> List["_Resource"]:
        """List resources that match the list filter in the target metadataStore.

        Args:
            filter (str):
                Optional. A query to filter available resources for
                matching results.
            metadata_store_id (str):
                The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Project used to create this resource. Overrides project set in
                aiplatform.init.
            location (str):
                Location used to create this resource. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials used to create this resource. Overrides
                credentials set in aiplatform.init.
            order_by (str):
                Optional. How the list of messages is ordered.
                Specify the values to order by and an ordering operation. The
                default sorting order is ascending. To specify descending order
                for a field, users append a " desc" suffix; for example: "foo
                desc, bar". Subfields are specified with a ``.`` character, such
                as foo.bar. See https://google.aip.dev/132#ordering for more
                details.

        Returns:
            resources (Sequence[_Resource]):
                A list of managed Metadata resources.
        """
        parent = (
            initializer.global_config.common_location_path(
                project=project, location=location
            )
            + f"/metadataStores/{metadata_store_id}"
        )

        return super().list(
            filter=filter,
            project=project,
            location=location,
            credentials=credentials,
            parent=parent,
            order_by=order_by,
        )

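For illustration, a sketch of calling this through a concrete subclass such as the public Artifact class; the filter string below is a placeholder and follows the Metadata service list-filter grammar.

```py
from google.cloud import aiplatform

# Placeholder query: newest system.Model artifacts first.
models = aiplatform.Artifact.list(
    filter='schema_title="system.Model"',
    order_by="create_time desc",
)
```
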
    @classmethod
    def _create(
        cls,
        resource_id: str,
        schema_title: str,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        metadata_store_id: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> Optional["_Resource"]:
        """Creates a new Metadata resource.

        Args:
            resource_id (str):
                Required. The <resource_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
            schema_title (str):
                Required. schema_title identifies the schema title used by the resource.
            display_name (str):
                Optional. The user-defined name of the resource.
            schema_version (str):
                Optional. schema_version specifies the version used by the resource.
                If not set, defaults to use the latest version.
            description (str):
                Optional. Describes the purpose of the resource to be created.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the resource.
            metadata_store_id (str):
                The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Project used to create this resource. Overrides project set in
                aiplatform.init.
            location (str):
                Location used to create this resource. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials used to create this resource. Overrides
                credentials set in aiplatform.init.

        Returns:
            resource (_Resource):
                Instantiated representation of the managed Metadata resource.
        """
        api_client = cls._instantiate_client(location=location, credentials=credentials)

        parent = (
            initializer.global_config.common_location_path(
                project=project, location=location
            )
            + f"/metadataStores/{metadata_store_id}"
        )

        try:
            resource = cls._create_resource(
                client=api_client,
                parent=parent,
                resource_id=resource_id,
                schema_title=schema_title,
                display_name=display_name,
                schema_version=schema_version,
                description=description,
                metadata=metadata,
            )
        except exceptions.AlreadyExists:
            _LOGGER.info(f"Resource '{resource_id}' already exists")
            return

        self = cls._empty_constructor(
            project=project,
            location=location,
            credentials=credentials,
        )

        self._gca_resource = resource
        self._threading_lock = threading.Lock()
        return self

    @classmethod
    def _get(
        cls,
        resource_name: str,
        metadata_store_id: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> Optional["_Resource"]:
        """Returns a metadata Resource.

        Args:
            resource_name (str):
                A fully-qualified resource name or resource ID.
                Example: "projects/123/locations/us-central1/metadataStores/default/<resource_noun>/my-resource"
                or "my-resource" when project and location are initialized or passed.
            metadata_store_id (str):
                The metadata_store_id portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/my-resource
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Project to get this resource from. Overrides project set in
                aiplatform.init.
            location (str):
                Location to get this resource from. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials to use to get this resource. Overrides
                credentials set in aiplatform.init.

        Returns:
            resource (Optional[_Resource]):
                An optional instantiated representation of the managed Metadata resource.
        """

        try:
            return cls(
                resource_name,
                metadata_store_id=metadata_store_id,
                project=project,
                location=location,
                credentials=credentials,
            )
        except exceptions.NotFound:
            _LOGGER.info(f"Resource {resource_name} not found.")

    @classmethod
    @abc.abstractmethod
    def _create_resource(
        cls,
        client: utils.MetadataClientWithOverride,
        parent: str,
        resource_id: str,
        schema_title: str,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
    ) -> proto.Message:
        """Create resource method."""
        pass

    @classmethod
    @abc.abstractmethod
    def _update_resource(
        cls,
        client: utils.MetadataClientWithOverride,
        resource: proto.Message,
    ) -> proto.Message:
        """Update resource method."""
        pass

    @staticmethod
    def _extract_metadata_store_id(resource_name, resource_noun) -> str:
        """Extracts the metadata store ID from the resource name.

        Args:
            resource_name (str):
                Required. A fully-qualified metadata resource name. For example
                projects/{project}/locations/{location}/metadataStores/{metadata_store_id}/{resource_noun}/{resource_id}.
            resource_noun (str):
                Required. The resource_noun portion of the resource_name.

        Returns:
            metadata_store_id (str):
                The metadata store ID for the particular resource name.

        Raises:
            ValueError: If the resource name does not match the expected pattern.
        """
        pattern = re.compile(
            r"^projects\/(?P<project>[\w-]+)\/locations\/(?P<location>[\w-]+)\/metadataStores\/(?P<store>[\w-]+)\/"
            + resource_noun
            + r"\/(?P<id>[\w-]+)(?P<version>@[\w-]+)?$"
        )
        match = pattern.match(resource_name)
        if not match:
            raise ValueError(
                f"failed to extract metadata_store_id from resource {resource_name}"
            )
        return match["store"]
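A quick, self-contained check of the extraction pattern above (resource name and noun are placeholders); the optional `@version` suffix is also captured.

```py
import re

name = (
    "projects/123/locations/us-central1/metadataStores/default"
    "/artifacts/my-artifact@v1"
)
pattern = re.compile(
    r"^projects\/(?P<project>[\w-]+)\/locations\/(?P<location>[\w-]+)"
    r"\/metadataStores\/(?P<store>[\w-]+)\/artifacts"
    r"\/(?P<id>[\w-]+)(?P<version>@[\w-]+)?$"
)
match = pattern.match(name)
assert match and match["store"] == "default" and match["version"] == "@v1"
```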
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,310 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import abc

from typing import Any, Optional, Dict, List

from google.auth import credentials as auth_credentials
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.metadata import artifact
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import constants


class BaseArtifactSchema(artifact.Artifact):
    """Base class for Metadata Artifact types."""

    @property
    @classmethod
    @abc.abstractmethod
    def schema_title(cls) -> str:
        """Identifies the Vertex Metadata schema title used by the resource."""
        pass

    def __init__(
        self,
        *,
        artifact_id: Optional[str] = None,
        uri: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
    ):
        """Initializes the Artifact with the given name, URI and metadata.

        This is the base class for defining various artifact types, which can be
        passed to google.Artifact to create a corresponding resource.
        Artifacts carry a `metadata` field, which is a dictionary for storing
        metadata related to this artifact. Subclasses of ArtifactType can enforce
        various structure and field requirements for the metadata field.

        Args:
            artifact_id (str):
                Optional. The <resource_id> portion of the Artifact name with
                the following format; this is globally unique in a metadataStore:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
            uri (str):
                Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
                artifact file.
            display_name (str):
                Optional. The user-defined name of the Artifact.
            schema_version (str):
                Optional. schema_version specifies the version used by the Artifact.
                If not set, defaults to use the latest version.
            description (str):
                Optional. Describes the purpose of the Artifact to be created.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the Artifact.
            state (google.cloud.gapic.types.Artifact.State):
                Optional. The state of this Artifact. This is a
                property of the Artifact, and does not imply or
                capture any ongoing process. This property is
                managed by clients (such as Vertex AI
                Pipelines), and the system does not prescribe or
                check the validity of state transitions.
        """
        # Initialize the exception to resolve the FutureManager exception.
        self._exception = None
        # resource_id is not stored in the proto. The create method uses the
        # resource_id along with project_id and location to construct a
        # resource_name, which is stored in the proto message.
        self.artifact_id = artifact_id

        # Store all other attributes using the proto structure.
        self._gca_resource = gca_artifact.Artifact()
        self._gca_resource.uri = uri
        self._gca_resource.display_name = display_name
        self._gca_resource.schema_version = (
            schema_version or constants._DEFAULT_SCHEMA_VERSION
        )
        self._gca_resource.description = description

        # If metadata is None, convert it to {}.
        metadata = metadata if metadata else {}
        self._nested_update_metadata(self._gca_resource, metadata)
        self._gca_resource.state = state

    # TODO() Switch to @singledispatchmethod constructor overload after py>=3.8
    def _init_with_resource_name(
        self,
        *,
        artifact_name: str,
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """Initializes the Artifact instance using an existing resource.

        Args:
            artifact_name (str):
                Artifact name with the following format; this is globally unique in a metadataStore:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
            metadata_store_id (str):
                Optional. MetadataStore to retrieve the Artifact from. If not set, metadata_store_id is set to "default".
                If artifact_name is a fully-qualified resource, its metadata_store_id overrides this one.
            project (str):
                Optional. Project to retrieve the artifact from. If not set, the project
                set in aiplatform.init will be used.
            location (str):
                Optional. Location to retrieve the Artifact from. If not set, the location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials to use to retrieve this Artifact. Overrides
                credentials set in aiplatform.init.
        """
        # Add a User Agent Header for metrics tracking if one is not specified.
        # If one is already specified, this call was initiated by a subclass.
        if not base_constants.USER_AGENT_SDK_COMMAND:
            base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_artifact.BaseArtifactSchema._init_with_resource_name"

        super(BaseArtifactSchema, self).__init__(
            artifact_name=artifact_name,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )

    def create(
        self,
        *,
        metadata_store_id: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "artifact.Artifact":
        """Creates a new Metadata Artifact.

        Args:
            metadata_store_id (str):
                Optional. The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Optional. Project used to create this Artifact. Overrides project set in
                aiplatform.init.
            location (str):
                Optional. Location used to create this Artifact. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials used to create this Artifact. Overrides
                credentials set in aiplatform.init.

        Returns:
            Artifact: Instantiated representation of the managed Metadata Artifact.
        """
        # Add a User Agent Header for metrics tracking.
        base_constants.USER_AGENT_SDK_COMMAND = (
            "aiplatform.metadata.schema.base_artifact.BaseArtifactSchema.create"
        )

        # Check if metadata exists to avoid a proto read error.
        metadata = None
        if self._gca_resource.metadata:
            metadata = self.metadata

        new_artifact_instance = artifact.Artifact.create(
            resource_id=self.artifact_id,
            schema_title=self.schema_title,
            uri=self.uri,
            display_name=self.display_name,
            schema_version=self.schema_version,
            description=self.description,
            metadata=metadata,
            state=self.state,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )

        # Reinstantiate this class using the newly created resource.
        self._init_with_resource_name(artifact_name=new_artifact_instance.resource_name)
        return self

@classmethod
|
||||
def list(
|
||||
cls,
|
||||
filter: Optional[str] = None, # pylint: disable=redefined-builtin
|
||||
metadata_store_id: str = "default",
|
||||
project: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
order_by: Optional[str] = None,
|
||||
) -> List["BaseArtifactSchema"]:
|
||||
"""List all the Artifact resources with a particular schema.
|
||||
|
||||
Args:
|
||||
filter (str):
|
||||
Optional. A query to filter available resources for
|
||||
matching results.
|
||||
metadata_store_id (str):
|
||||
The <metadata_store_id> portion of the resource name with
|
||||
the format:
|
||||
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
|
||||
If not provided, the MetadataStore's ID will be set to "default".
|
||||
project (str):
|
||||
Project used to create this resource. Overrides project set in
|
||||
aiplatform.init.
|
||||
location (str):
|
||||
Location used to create this resource. Overrides location set in
|
||||
aiplatform.init.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Custom credentials used to create this resource. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
order_by (str):
|
||||
Optional. How the list of messages is ordered.
|
||||
Specify the values to order by and an ordering operation. The
|
||||
default sorting order is ascending. To specify descending order
|
||||
for a field, users append a " desc" suffix; for example: "foo
|
||||
desc, bar". Subfields are specified with a ``.`` character, such
|
||||
as foo.bar. see https://google.aip.dev/132#ordering for more
|
||||
details.
|
||||
|
||||
Returns:
|
||||
A list of artifact resources with a particular schema.
|
||||
|
||||
"""
|
||||
schema_filter = f'schema_title="{cls.schema_title}"'
|
||||
if filter:
|
||||
filter = f"{filter} AND {schema_filter}"
|
||||
else:
|
||||
filter = schema_filter
|
||||
|
||||
return super().list(
|
||||
filter=filter,
|
||||
metadata_store_id=metadata_store_id,
|
||||
project=project,
|
||||
location=location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
def sync_resource(self):
|
||||
"""Syncs local resource with the resource in metadata store.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the artifact resource hasn't been created.
|
||||
"""
|
||||
if self._gca_resource.name:
|
||||
super().sync_resource()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__} resource has not been created."
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
description: Optional[str] = None,
|
||||
credentials: Optional[auth_credentials.Credentials] = None,
|
||||
):
|
||||
"""Updates an existing Artifact resource with new metadata.
|
||||
|
||||
Args:
|
||||
metadata (Dict):
|
||||
Optional. metadata contains the updated metadata information.
|
||||
description (str):
|
||||
Optional. Description describes the resource to be updated.
|
||||
credentials (auth_credentials.Credentials):
|
||||
Custom credentials to use to update this resource. Overrides
|
||||
credentials set in aiplatform.init.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the artifact resource hasn't been created.
|
||||
"""
|
||||
if self._gca_resource.name:
|
||||
super().update(
|
||||
metadata=metadata,
|
||||
description=description,
|
||||
credentials=credentials,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__} resource has not been created."
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self._gca_resource.name:
|
||||
return super().__repr__()
|
||||
else:
|
||||
return f"{object.__repr__(self)}\nschema_title: {self.schema_title}"
|
||||
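The create-then-reinitialize pattern above is the core lifecycle of every artifact schema class: an instance is a purely local proto wrapper until `create()` writes it to the metadata store and `_init_with_resource_name` rebinds the instance to the managed resource. A minimal sketch of that flow, assuming `aiplatform.init` has already been called and that the subclasses live under the `metadata.schema.system` package as in the files below; the project, bucket, and display names are placeholders:

```
from google.cloud import aiplatform
from google.cloud.aiplatform.metadata.schema.system import artifact_schema

aiplatform.init(project="my-project", location="us-central1")

# Local-only instance; nothing is written to the metadata store yet,
# so update() or sync_resource() here would raise RuntimeError.
model = artifact_schema.Model(uri="gs://my-bucket/model", display_name="my-model")

# create() writes the resource, then rebinds this instance to it.
model.create()

# list() adds a server-side filter on the subclass's schema_title.
models = artifact_schema.Model.list(filter='display_name="my-model"')
```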
@@ -0,0 +1,285 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import abc

from typing import Dict, List, Optional, Sequence

from google.auth import credentials as auth_credentials

from google.cloud.aiplatform.compat.types import context as gca_context
from google.cloud.aiplatform.compat.types import (
    lineage_subgraph as gca_lineage_subgraph,
)
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import constants
from google.cloud.aiplatform.metadata import context


class BaseContextSchema(context.Context):
    """Base class for Metadata Context schema."""

    @property
    @classmethod
    @abc.abstractmethod
    def schema_title(cls) -> str:
        """Identifies the Vertex Metadata schema title used by the resource."""
        pass

    def __init__(
        self,
        *,
        context_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Initializes the Context with the given name and metadata.

        Args:
            context_id (str):
                Optional. The <resource_id> portion of the Context name with
                the following format; this is globally unique in a metadataStore:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
            display_name (str):
                Optional. The user-defined name of the Context.
            schema_version (str):
                Optional. schema_version specifies the version used by the Context.
                If not set, defaults to the latest version.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the Context.
            description (str):
                Optional. Describes the purpose of the Context to be created.
        """
        # Initialize the exception to resolve the FutureManager exception.
        self._exception = None
        # resource_id is not stored in the proto. The create method uses the
        # resource_id along with project_id and location to construct a
        # resource_name, which is stored in the proto message.
        self.context_id = context_id

        # Store all other attributes using the proto structure.
        self._gca_resource = gca_context.Context()
        self._gca_resource.display_name = display_name
        self._gca_resource.schema_version = (
            schema_version or constants._DEFAULT_SCHEMA_VERSION
        )
        # If metadata is None, convert it to {}.
        metadata = metadata if metadata else {}
        self._nested_update_metadata(self._gca_resource, metadata)
        self._gca_resource.description = description

    # TODO() Switch to @singledispatchmethod constructor overload after py>=3.8
    def _init_with_resource_name(
        self,
        *,
        context_name: str,
    ):
        """Initializes the Context instance using an existing resource.

        Args:
            context_name (str):
                Context name with the following format; this is globally unique in a metadataStore:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
        """
        # Add a user agent header for metrics tracking if one is not specified.
        # If one is already specified, this call was initiated by a subclass.
        if not base_constants.USER_AGENT_SDK_COMMAND:
            base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_context.BaseContextSchema._init_with_resource_name"

        super(BaseContextSchema, self).__init__(resource_name=context_name)

    def create(
        self,
        *,
        metadata_store_id: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "context.Context":
        """Creates a new Metadata Context.

        Args:
            metadata_store_id (str):
                Optional. The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Optional. Project used to create this Context. Overrides project set in
                aiplatform.init.
            location (str):
                Optional. Location used to create this Context. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials used to create this Context. Overrides
                credentials set in aiplatform.init.

        Returns:
            Context: Instantiated representation of the managed Metadata Context.
        """
        # Add a user agent header for metrics tracking.
        base_constants.USER_AGENT_SDK_COMMAND = (
            "aiplatform.metadata.schema.base_context.BaseContextSchema.create"
        )

        # Check if metadata exists to avoid a proto read error.
        metadata = None
        if self._gca_resource.metadata:
            metadata = self.metadata

        new_context = context.Context.create(
            resource_id=self.context_id,
            schema_title=self.schema_title,
            display_name=self.display_name,
            schema_version=self.schema_version,
            description=self.description,
            metadata=metadata,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )

        # Reinstantiate this class using the newly created resource.
        self._init_with_resource_name(context_name=new_context.resource_name)
        return self

    @classmethod
    def list(
        cls,
        filter: Optional[str] = None,  # pylint: disable=redefined-builtin
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        order_by: Optional[str] = None,
    ) -> List["BaseContextSchema"]:
        """Lists all the Context resources with a particular schema.

        Args:
            filter (str):
                Optional. A query to filter available resources for
                matching results.
            metadata_store_id (str):
                The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Project used to create this resource. Overrides project set in
                aiplatform.init.
            location (str):
                Location used to create this resource. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials used to create this resource. Overrides
                credentials set in aiplatform.init.
            order_by (str):
                Optional. How the list of messages is ordered.
                Specify the values to order by and an ordering operation. The
                default sorting order is ascending. To specify descending order
                for a field, append a " desc" suffix; for example: "foo
                desc, bar". Subfields are specified with a ``.`` character, such
                as foo.bar. See https://google.aip.dev/132#ordering for more
                details.

        Returns:
            A list of Context resources with a particular schema.
        """
        schema_filter = f'schema_title="{cls.schema_title}"'
        if filter:
            filter = f"{filter} AND {schema_filter}"
        else:
            filter = schema_filter

        return super().list(
            filter=filter,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
            order_by=order_by,
        )

    def add_artifacts_and_executions(
        self,
        artifact_resource_names: Optional[Sequence[str]] = None,
        execution_resource_names: Optional[Sequence[str]] = None,
    ):
        """Associates Executions with, and attributes Artifacts to, this Context.

        Args:
            artifact_resource_names (Sequence[str]):
                Optional. The full resource names of Artifacts to attribute to
                the Context.
            execution_resource_names (Sequence[str]):
                Optional. The full resource names of Executions to associate with
                the Context.

        Raises:
            RuntimeError: If the Context resource hasn't been created.
        """
        if self._gca_resource.name:
            super().add_artifacts_and_executions(
                artifact_resource_names=artifact_resource_names,
                execution_resource_names=execution_resource_names,
            )
        else:
            raise RuntimeError(
                f"{self.__class__.__name__} resource has not been created."
            )

    def add_context_children(self, contexts: List[context.Context]):
        """Adds the provided Contexts as children of this Context.

        Args:
            contexts (List[context.Context]): Contexts to add as children.

        Raises:
            RuntimeError: If the Context resource hasn't been created.
        """
        if self._gca_resource.name:
            super().add_context_children(contexts)
        else:
            raise RuntimeError(
                f"{self.__class__.__name__} resource has not been created."
            )

    def query_lineage_subgraph(self) -> gca_lineage_subgraph.LineageSubgraph:
        """Queries the lineage subgraph of this Context.

        Returns:
            lineage subgraph (gca_lineage_subgraph.LineageSubgraph):
                Lineage subgraph of this Context.

        Raises:
            RuntimeError: If the Context resource hasn't been created.
        """
        if self._gca_resource.name:
            return super().query_lineage_subgraph()
        else:
            raise RuntimeError(
                f"{self.__class__.__name__} resource has not been created."
            )

    def __repr__(self) -> str:
        if self._gca_resource.name:
            return super().__repr__()
        else:
            return f"{object.__repr__(self)}\nschema_title: {self.schema_title}"
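Context schemas follow the same local-then-managed pattern, and the lineage helpers are all guarded by the same `_gca_resource.name` check, so they can only be called once `create()` has run. A hedged sketch, reusing `model` from the earlier artifact sketch and assuming the same `metadata.schema.system` package layout:

```
from google.cloud.aiplatform.metadata.schema.system import context_schema

exp = context_schema.Experiment(display_name="my-experiment")
exp.create()

# Both calls below would raise RuntimeError before create().
exp.add_artifacts_and_executions(
    artifact_resource_names=[model.resource_name],
)
subgraph = exp.query_lineage_subgraph()
```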
@@ -0,0 +1,418 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import abc

from typing import Any, Dict, List, Optional, Union

from google.auth import credentials as auth_credentials

from google.cloud.aiplatform import models
from google.cloud.aiplatform.compat.types import execution as gca_execution
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import artifact
from google.cloud.aiplatform.metadata import constants
from google.cloud.aiplatform.metadata import execution
from google.cloud.aiplatform.metadata import metadata


class BaseExecutionSchema(execution.Execution):
    """Base class for Metadata Execution schema."""

    @property
    @classmethod
    @abc.abstractmethod
    def schema_title(cls) -> str:
        """Identifies the Vertex Metadata schema title used by the resource."""
        pass

    def __init__(
        self,
        *,
        state: Optional[
            gca_execution.Execution.State
        ] = gca_execution.Execution.State.RUNNING,
        execution_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Initializes the Execution with the given state, name and metadata.

        Args:
            state (gca_execution.Execution.State):
                Optional. State of this Execution. Defaults to RUNNING.
            execution_id (str):
                Optional. The <resource_id> portion of the Execution name with
                the following format; this is globally unique in a metadataStore:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
            display_name (str):
                Optional. The user-defined name of the Execution.
            schema_version (str):
                Optional. schema_version specifies the version used by the Execution.
                If not set, defaults to the latest version.
            metadata (Dict):
                Optional. Contains the metadata information that will be stored in the Execution.
            description (str):
                Optional. Describes the purpose of the Execution to be created.
        """
        # Initialize the exception to resolve the FutureManager exception.
        self._exception = None
        # resource_id is not stored in the proto. The create method uses the
        # resource_id along with project_id and location to construct a
        # resource_name, which is stored in the proto message.
        self.execution_id = execution_id

        # Store all other attributes using the proto structure.
        self._gca_resource = gca_execution.Execution()
        self._gca_resource.state = state
        self._gca_resource.display_name = display_name
        self._gca_resource.schema_version = (
            schema_version or constants._DEFAULT_SCHEMA_VERSION
        )
        # If metadata is None, convert it to {}.
        metadata = metadata if metadata else {}
        self._nested_update_metadata(self._gca_resource, metadata)
        self._gca_resource.description = description

    # TODO() Switch to @singledispatchmethod constructor overload after py>=3.8
    def _init_with_resource_name(
        self,
        *,
        execution_name: str,
    ):
        """Initializes the Execution instance using an existing resource.

        Args:
            execution_name (str):
                The Execution name with the following format; this is globally unique in a metadataStore:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
        """
        # Add a user agent header for metrics tracking if one is not specified.
        # If one is already specified, this call was initiated by a subclass.
        if not base_constants.USER_AGENT_SDK_COMMAND:
            base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_execution.BaseExecutionSchema._init_with_resource_name"

        super(BaseExecutionSchema, self).__init__(execution_name=execution_name)

    def create(
        self,
        *,
        metadata_store_id: Optional[str] = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "execution.Execution":
        """Creates a new Metadata Execution.

        Args:
            metadata_store_id (str):
                Optional. The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Optional. Project used to create this Execution. Overrides project set in
                aiplatform.init.
            location (str):
                Optional. Location used to create this Execution. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials used to create this Execution. Overrides
                credentials set in aiplatform.init.

        Returns:
            Execution: Instantiated representation of the managed Metadata Execution.
        """
        # Add a user agent header for metrics tracking.
        base_constants.USER_AGENT_SDK_COMMAND = (
            "aiplatform.metadata.schema.base_execution.BaseExecutionSchema.create"
        )

        # Check if metadata exists to avoid a proto read error.
        metadata = None
        if self._gca_resource.metadata:
            metadata = self.metadata

        new_execution_instance = execution.Execution.create(
            resource_id=self.execution_id,
            schema_title=self.schema_title,
            display_name=self.display_name,
            schema_version=self.schema_version,
            description=self.description,
            metadata=metadata,
            state=self.state,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
        )
        # Reinstantiate this class using the newly created resource.
        self._init_with_resource_name(
            execution_name=new_execution_instance.resource_name
        )
        return self

    @classmethod
    def list(
        cls,
        filter: Optional[str] = None,  # pylint: disable=redefined-builtin
        metadata_store_id: str = "default",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        order_by: Optional[str] = None,
    ) -> List["BaseExecutionSchema"]:
        """Lists all the Execution resources with a particular schema.

        Args:
            filter (str):
                Optional. A query to filter available resources for
                matching results.
            metadata_store_id (str):
                The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
                If not provided, the MetadataStore's ID will be set to "default".
            project (str):
                Project used to create this resource. Overrides project set in
                aiplatform.init.
            location (str):
                Location used to create this resource. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Custom credentials used to create this resource. Overrides
                credentials set in aiplatform.init.
            order_by (str):
                Optional. How the list of messages is ordered.
                Specify the values to order by and an ordering operation. The
                default sorting order is ascending. To specify descending order
                for a field, append a " desc" suffix; for example: "foo
                desc, bar". Subfields are specified with a ``.`` character, such
                as foo.bar. See https://google.aip.dev/132#ordering for more
                details.

        Returns:
            A list of Execution resources with a particular schema.
        """
        schema_filter = f'schema_title="{cls.schema_title}"'
        if filter:
            filter = f"{filter} AND {schema_filter}"
        else:
            filter = schema_filter

        return super().list(
            filter=filter,
            metadata_store_id=metadata_store_id,
            project=project,
            location=location,
            credentials=credentials,
            order_by=order_by,
        )

    def start_execution(
        self,
        *,
        metadata_store_id: Optional[str] = "default",
        resume: bool = False,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "execution.Execution":
        """Creates and starts a new Metadata Execution, or resumes a previously created Execution.

        This method is similar to create_execution, with additional support for Experiments.
        If an Experiment is set prior to running this command, the Experiment will be
        associated with the created Execution; otherwise this method behaves the same
        as create_execution.

        To start a new execution:
        ```
        instance_of_execution_schema = execution_schema.ContainerExecution(...)
        with instance_of_execution_schema.start_execution() as exc:
            exc.assign_input_artifacts([my_artifact])
            model = aiplatform.Artifact.create(uri='gs://my-uri', schema_title='system.Model')
            exc.assign_output_artifacts([model])
        ```

        To continue a previously created execution:
        ```
        with execution_schema.ContainerExecution(
            execution_id='my-exc'
        ).start_execution(resume=True) as exc:
            ...
        ```

        Args:
            metadata_store_id (str):
                Optional. The <metadata_store_id> portion of the resource name with
                the format:
                projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<execution_id>
                If not provided, the MetadataStore's ID will be set to "default". Currently only the "default"
                MetadataStore ID is supported.
            resume (bool):
                Optional. Whether to resume an existing Execution.
            project (str):
                Optional. Project used to create this Execution. Overrides project set in
                aiplatform.init.
            location (str):
                Optional. Location used to create this Execution. Overrides location set in
                aiplatform.init.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials used to create this Execution. Overrides
                credentials set in aiplatform.init.

        Returns:
            Execution: Instantiated representation of the managed Metadata Execution.

        Raises:
            ValueError: If a metadata_store_id other than "default" is provided.
        """
        # Add a user agent header for metrics tracking.
        base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_execution.BaseExecutionSchema.start_execution"

        if metadata_store_id != "default":
            raise ValueError(
                f"metadata_store_id {metadata_store_id} is not supported. Only the default MetadataStore ID is supported."
            )

        new_execution_instance = metadata._ExperimentTracker().start_execution(
            schema_title=self.schema_title,
            display_name=self.display_name,
            resource_id=self.execution_id,
            metadata=self.metadata,
            schema_version=self.schema_version,
            description=self.description,
            # TODO: Add support for metadata_store_id once it is supported in experiment.
            resume=resume,
            project=project,
            location=location,
            credentials=credentials,
        )

        # Reinstantiate this class using the newly created resource.
        self._init_with_resource_name(
            execution_name=new_execution_instance.resource_name
        )
        return self

    def assign_input_artifacts(
        self, artifacts: List[Union[artifact.Artifact, models.Model]]
    ):
        """Assigns Artifacts as inputs to this Execution.

        Args:
            artifacts (List[Union[artifact.Artifact, models.Model]]):
                Required. Artifacts to assign as inputs.

        Raises:
            RuntimeError: If the Execution resource hasn't been created.
        """
        if self._gca_resource.name:
            super().assign_input_artifacts(artifacts)
        else:
            raise RuntimeError(
                f"{self.__class__.__name__} resource has not been created."
            )

    def assign_output_artifacts(
        self, artifacts: List[Union[artifact.Artifact, models.Model]]
    ):
        """Assigns Artifacts as outputs to this Execution.

        Args:
            artifacts (List[Union[artifact.Artifact, models.Model]]):
                Required. Artifacts to assign as outputs.

        Raises:
            RuntimeError: If the Execution resource hasn't been created.
        """
        if self._gca_resource.name:
            super().assign_output_artifacts(artifacts)
        else:
            raise RuntimeError(
                f"{self.__class__.__name__} resource has not been created."
            )

    def get_input_artifacts(self) -> List[artifact.Artifact]:
        """Gets the input Artifacts of this Execution.

        Returns:
            List of input Artifacts.

        Raises:
            RuntimeError: If the Execution resource hasn't been created.
        """
        if self._gca_resource.name:
            return super().get_input_artifacts()
        else:
            raise RuntimeError(
                f"{self.__class__.__name__} resource has not been created."
            )

    def get_output_artifacts(self) -> List[artifact.Artifact]:
        """Gets the output Artifacts of this Execution.

        Returns:
            List of output Artifacts.

        Raises:
            RuntimeError: If the Execution resource hasn't been created.
        """
        if self._gca_resource.name:
            return super().get_output_artifacts()
        else:
            raise RuntimeError(
                f"{self.__class__.__name__} resource has not been created."
            )

    def update(
        self,
        state: Optional[gca_execution.Execution.State] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Updates this Execution.

        Args:
            state (gca_execution.Execution.State):
                Optional. State of this Execution.
            description (str):
                Optional. Describes the purpose of this Execution.
            metadata (Dict[str, Any]):
                Optional. Contains the metadata information that will be stored
                in the Execution.

        Raises:
            RuntimeError: If the Execution resource hasn't been created.
        """
        if self._gca_resource.name:
            super().update(
                state=state,
                description=description,
                metadata=metadata,
            )
        else:
            raise RuntimeError(
                f"{self.__class__.__name__} resource has not been created."
            )

    def __repr__(self) -> str:
        if self._gca_resource.name:
            return super().__repr__()
        else:
            return f"{object.__repr__(self)}\nschema_title: {self.schema_title}"
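Where `create()` only registers the Execution, `start_execution` routes through the experiment tracker, so the Execution is tied to the active Experiment and usable as a context manager. A sketch under the same assumptions as the earlier examples (placeholder names; `dataset_artifact` and `model` stand in for previously created artifacts):

```
from google.cloud import aiplatform
from google.cloud.aiplatform.metadata.schema.system import execution_schema

aiplatform.init(
    project="my-project", location="us-central1", experiment="my-experiment"
)

run = execution_schema.ContainerExecution(display_name="training-step")
with run.start_execution() as exc:
    exc.assign_input_artifacts([dataset_artifact])
    exc.assign_output_artifacts([model])
```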
Binary file not shown.
File diff suppressed because it is too large
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,265 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
from typing import Optional, Dict

from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.metadata.schema import base_artifact


class Model(base_artifact.BaseArtifactSchema):
    """Artifact type for model."""

    schema_title = "system.Model"

    def __init__(
        self,
        *,
        uri: Optional[str] = None,
        artifact_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
    ):
        """Args:
        uri (str):
            Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
            artifact file.
        artifact_id (str):
            Optional. The <resource_id> portion of the Artifact name with
            the following format; this is globally unique in a metadataStore:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Artifact.
        schema_version (str):
            Optional. schema_version specifies the version used by the Artifact.
            If not set, defaults to the latest version.
        description (str):
            Optional. Describes the purpose of the Artifact to be created.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Artifact.
        state (google.cloud.gapic.types.Artifact.State):
            Optional. The state of this Artifact. This is a
            property of the Artifact, and does not imply or
            capture any ongoing process. This property is
            managed by clients (such as Vertex AI
            Pipelines), and the system does not prescribe or
            check the validity of state transitions.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        super(Model, self).__init__(
            uri=uri,
            artifact_id=artifact_id,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
            state=state,
        )


class Artifact(base_artifact.BaseArtifactSchema):
    """A generic artifact."""

    schema_title = "system.Artifact"

    def __init__(
        self,
        *,
        uri: Optional[str] = None,
        artifact_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
    ):
        """Args:
        uri (str):
            Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
            artifact file.
        artifact_id (str):
            Optional. The <resource_id> portion of the Artifact name with
            the following format; this is globally unique in a metadataStore:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Artifact.
        schema_version (str):
            Optional. schema_version specifies the version used by the Artifact.
            If not set, defaults to the latest version.
        description (str):
            Optional. Describes the purpose of the Artifact to be created.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Artifact.
        state (google.cloud.gapic.types.Artifact.State):
            Optional. The state of this Artifact. This is a
            property of the Artifact, and does not imply or
            capture any ongoing process. This property is
            managed by clients (such as Vertex AI
            Pipelines), and the system does not prescribe or
            check the validity of state transitions.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        super(Artifact, self).__init__(
            uri=uri,
            artifact_id=artifact_id,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
            state=state,
        )


class Dataset(base_artifact.BaseArtifactSchema):
    """An artifact representing a system Dataset."""

    schema_title = "system.Dataset"

    def __init__(
        self,
        *,
        uri: Optional[str] = None,
        artifact_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
    ):
        """Args:
        uri (str):
            Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
            artifact file.
        artifact_id (str):
            Optional. The <resource_id> portion of the Artifact name with
            the following format; this is globally unique in a metadataStore:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Artifact.
        schema_version (str):
            Optional. schema_version specifies the version used by the Artifact.
            If not set, defaults to the latest version.
        description (str):
            Optional. Describes the purpose of the Artifact to be created.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Artifact.
        state (google.cloud.gapic.types.Artifact.State):
            Optional. The state of this Artifact. This is a
            property of the Artifact, and does not imply or
            capture any ongoing process. This property is
            managed by clients (such as Vertex AI
            Pipelines), and the system does not prescribe or
            check the validity of state transitions.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        super(Dataset, self).__init__(
            uri=uri,
            artifact_id=artifact_id,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
            state=state,
        )


class Metrics(base_artifact.BaseArtifactSchema):
    """Artifact schema for scalar metrics."""

    schema_title = "system.Metrics"

    def __init__(
        self,
        *,
        accuracy: Optional[float] = None,
        precision: Optional[float] = None,
        recall: Optional[float] = None,
        f1score: Optional[float] = None,
        mean_absolute_error: Optional[float] = None,
        mean_squared_error: Optional[float] = None,
        uri: Optional[str] = None,
        artifact_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
    ):
        """Args:
        accuracy (float):
            Optional.
        precision (float):
            Optional.
        recall (float):
            Optional.
        f1score (float):
            Optional.
        mean_absolute_error (float):
            Optional.
        mean_squared_error (float):
            Optional.
        uri (str):
            Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
            artifact file.
        artifact_id (str):
            Optional. The <resource_id> portion of the Artifact name with
            the following format; this is globally unique in a metadataStore:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Artifact.
        schema_version (str):
            Optional. schema_version specifies the version used by the Artifact.
            If not set, defaults to the latest version.
        description (str):
            Optional. Describes the purpose of the Artifact to be created.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Artifact.
        state (google.cloud.gapic.types.Artifact.State):
            Optional. The state of this Artifact. This is a
            property of the Artifact, and does not imply or
            capture any ongoing process. This property is
            managed by clients (such as Vertex AI
            Pipelines), and the system does not prescribe or
            check the validity of state transitions.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        # Use explicit None checks so that legitimate 0.0 metric values are kept.
        if accuracy is not None:
            extended_metadata["accuracy"] = accuracy
        if precision is not None:
            extended_metadata["precision"] = precision
        if recall is not None:
            extended_metadata["recall"] = recall
        if f1score is not None:
            extended_metadata["f1score"] = f1score
        if mean_absolute_error is not None:
            extended_metadata["mean_absolute_error"] = mean_absolute_error
        if mean_squared_error is not None:
            extended_metadata["mean_squared_error"] = mean_squared_error

        super(Metrics, self).__init__(
            uri=uri,
            artifact_id=artifact_id,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
            state=state,
        )
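Because the scalar keyword arguments are folded into `metadata`, a `Metrics` artifact stores its values server-side and can be queried like any other artifact. A small sketch with made-up values, under the same `aiplatform.init` assumptions as above:

```
from google.cloud.aiplatform.metadata.schema.system import artifact_schema

metrics = artifact_schema.Metrics(
    accuracy=0.95,
    f1score=0.91,
    display_name="eval-metrics",
)
# After create(), the resource's metadata carries
# {"accuracy": 0.95, "f1score": 0.91}.
metrics.create()
```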
@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
from typing import Optional, Dict

from google.cloud.aiplatform.metadata.schema import base_context


class Experiment(base_context.BaseContextSchema):
    """Context schema for an Experiment context."""

    schema_title = "system.Experiment"

    def __init__(
        self,
        *,
        context_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Args:
        context_id (str):
            Optional. The <resource_id> portion of the Context name with
            the following format; this is globally unique in a metadataStore:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Context.
        schema_version (str):
            Optional. schema_version specifies the version used by the Context.
            If not set, defaults to the latest version.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Context.
        description (str):
            Optional. Describes the purpose of the Context to be created.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        super(Experiment, self).__init__(
            context_id=context_id,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
        )


class ExperimentRun(base_context.BaseContextSchema):
    """Context schema for an ExperimentRun context."""

    schema_title = "system.ExperimentRun"

    def __init__(
        self,
        *,
        experiment_id: Optional[str] = None,
        context_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Args:
        experiment_id (str):
            Optional. The experiment_id that this experiment_run belongs to.
        context_id (str):
            Optional. The <resource_id> portion of the Context name with
            the following format; this is globally unique in a metadataStore:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Context.
        schema_version (str):
            Optional. schema_version specifies the version used by the Context.
            If not set, defaults to the latest version.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Context.
        description (str):
            Optional. Describes the purpose of the Context to be created.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        extended_metadata["experiment_id"] = experiment_id
        super(ExperimentRun, self).__init__(
            context_id=context_id,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
        )


class Pipeline(base_context.BaseContextSchema):
    """Context schema for a Pipeline context."""

    schema_title = "system.Pipeline"

    def __init__(
        self,
        *,
        context_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Args:
        context_id (str):
            Optional. The <resource_id> portion of the Context name with
            the following format; this is globally unique in a metadataStore:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Context.
        schema_version (str):
            Optional. schema_version specifies the version used by the Context.
            If not set, defaults to the latest version.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Context.
        description (str):
            Optional. Describes the purpose of the Context to be created.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        super(Pipeline, self).__init__(
            context_id=context_id,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
        )


class PipelineRun(base_context.BaseContextSchema):
    """Context schema for a PipelineRun context."""

    schema_title = "system.PipelineRun"

    def __init__(
        self,
        *,
        pipeline_id: Optional[str] = None,
        context_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Args:
        pipeline_id (str):
            Optional. The PipelineJob resource name corresponding to this run.
        context_id (str):
            Optional. The <resource_id> portion of the Context name with
            the following format; this is globally unique in a metadataStore:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Context.
        schema_version (str):
            Optional. schema_version specifies the version used by the Context.
            If not set, defaults to the latest version.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Context.
        description (str):
            Optional. Describes the purpose of the Context to be created.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        extended_metadata["pipeline_id"] = pipeline_id
        super(PipelineRun, self).__init__(
            context_id=context_id,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
        )
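Note that `ExperimentRun` only stamps its parent's ID into the run's metadata; any resource-level parent-child structure comes from `add_context_children` on the parent Context. A sketch with placeholder IDs, reading the linkage off the classes above:

```
from google.cloud.aiplatform.metadata.schema.system import context_schema

exp = context_schema.Experiment(context_id="exp-1", display_name="my-experiment")
exp.create()

run_ctx = context_schema.ExperimentRun(
    experiment_id="exp-1",  # recorded in the run's metadata
    display_name="run-1",
)
run_ctx.create()

# Resource-level link between the two Contexts.
exp.add_context_children([run_ctx])
```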
@@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
from typing import Optional, Dict

from google.cloud.aiplatform.compat.types import execution as gca_execution
from google.cloud.aiplatform.metadata.schema import base_execution


class ContainerExecution(base_execution.BaseExecutionSchema):
    """Execution schema for a container execution."""

    schema_title = "system.ContainerExecution"

    def __init__(
        self,
        *,
        state: Optional[
            gca_execution.Execution.State
        ] = gca_execution.Execution.State.RUNNING,
        execution_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Args:
        state (gca_execution.Execution.State):
            Optional. State of this Execution. Defaults to RUNNING.
        execution_id (str):
            Optional. The <resource_id> portion of the Execution name with
            the following format; this is globally unique in a metadataStore:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Execution.
        schema_version (str):
            Optional. schema_version specifies the version used by the Execution.
            If not set, defaults to the latest version.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Execution.
        description (str):
            Optional. Describes the purpose of the Execution to be created.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        super(ContainerExecution, self).__init__(
            execution_id=execution_id,
            state=state,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
        )


class CustomJobExecution(base_execution.BaseExecutionSchema):
    """Execution schema for a custom job execution."""

    schema_title = "system.CustomJobExecution"

    def __init__(
        self,
        *,
        state: Optional[
            gca_execution.Execution.State
        ] = gca_execution.Execution.State.RUNNING,
        execution_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Args:
        state (gca_execution.Execution.State):
            Optional. State of this Execution. Defaults to RUNNING.
        execution_id (str):
            Optional. The <resource_id> portion of the Execution name with
            the following format; this is globally unique in a metadataStore:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Execution.
        schema_version (str):
            Optional. schema_version specifies the version used by the Execution.
            If not set, defaults to the latest version.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Execution.
        description (str):
            Optional. Describes the purpose of the Execution to be created.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        super(CustomJobExecution, self).__init__(
            execution_id=execution_id,
            state=state,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
        )


class Run(base_execution.BaseExecutionSchema):
    """Execution schema for a root run execution."""

    schema_title = "system.Run"

    def __init__(
        self,
        *,
        state: Optional[
            gca_execution.Execution.State
        ] = gca_execution.Execution.State.RUNNING,
        execution_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Args:
        state (gca_execution.Execution.State):
            Optional. State of this Execution. Defaults to RUNNING.
        execution_id (str):
            Optional. The <resource_id> portion of the Execution name with
            the following format; this is globally unique in a metadataStore:
            projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
        display_name (str):
            Optional. The user-defined name of the Execution.
        schema_version (str):
            Optional. schema_version specifies the version used by the Execution.
            If not set, defaults to the latest version.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Execution.
        description (str):
            Optional. Describes the purpose of the Execution to be created.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        super(Run, self).__init__(
            execution_id=execution_id,
            state=state,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
        )
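These subclasses only pin a `schema_title`; state transitions still go through the base class's guarded `update()`, which requires the resource to exist. A sketch of closing out a run, using the state enum from the compat types and placeholder names:

```
from google.cloud.aiplatform.compat.types import execution as gca_execution
from google.cloud.aiplatform.metadata.schema.system import execution_schema

job = execution_schema.CustomJobExecution(display_name="hp-tuning")
job.create()  # starts in the default RUNNING state
# ... do work ...
job.update(state=gca_execution.Execution.State.COMPLETE)
```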
@@ -0,0 +1,360 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
|
||||
from typing import Optional, Dict, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class PredictSchemata:
    """A class holding instance, parameter and prediction schema URIs.

    Args:
        instance_schema_uri (str):
            Required. Points to a YAML file stored on Google Cloud Storage
            describing the format of a single instance, which is used in
            PredictRequest.instances, ExplainRequest.instances and
            BatchPredictionJob.input_config. The schema is defined as an
            OpenAPI 3.0.2 Schema Object.
        parameters_schema_uri (str):
            Required. Points to a YAML file stored on Google Cloud Storage
            describing the parameters of prediction and explanation via
            PredictRequest.parameters, ExplainRequest.parameters and
            BatchPredictionJob.model_parameters. The schema is defined as an
            OpenAPI 3.0.2 Schema Object.
        prediction_schema_uri (str):
            Required. Points to a YAML file stored on Google Cloud Storage
            describing the format of a single prediction produced by this
            Model, which is returned via PredictResponse.predictions,
            ExplainResponse.explanations, and BatchPredictionJob.output_config.
            The schema is defined as an OpenAPI 3.0.2 Schema Object.
    """

    instance_schema_uri: Optional[str] = None
    parameters_schema_uri: Optional[str] = None
    prediction_schema_uri: Optional[str] = None

    def to_dict(self):
        """ML metadata schema dictionary representation of this DataClass.

        Returns:
            A dictionary that represents the PredictSchemata class.
        """
        results = {}
        if self.instance_schema_uri:
            results["instanceSchemaUri"] = self.instance_schema_uri
        if self.parameters_schema_uri:
            results["parametersSchemaUri"] = self.parameters_schema_uri
        if self.prediction_schema_uri:
            results["predictionSchemaUri"] = self.prediction_schema_uri

        return results

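
# Illustrative usage sketch: the GCS paths are placeholders. Note that
# to_dict() emits camelCase keys and silently skips unset fields.
schemata = PredictSchemata(
    instance_schema_uri="gs://my-bucket/instance_schema.yaml",
    prediction_schema_uri="gs://my-bucket/prediction_schema.yaml",
)
print(schemata.to_dict())
# {'instanceSchemaUri': 'gs://my-bucket/instance_schema.yaml',
#  'predictionSchemaUri': 'gs://my-bucket/prediction_schema.yaml'}
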
@dataclass
class ContainerSpec:
    """Container configuration for the model.

    Args:
        image_uri (str):
            Required. URI of the Docker image to be used as the custom
            container for serving predictions. This URI must identify an image
            in Artifact Registry or Container Registry.
        command (Sequence[str]):
            Optional. Specifies the command that runs when the container
            starts. This overrides the container's `ENTRYPOINT`.
        args (Sequence[str]):
            Optional. Specifies arguments for the command that runs when the
            container starts. This overrides the container's `CMD`.
        env (Sequence[google.cloud.aiplatform_v1.types.EnvVar]):
            Optional. List of environment variables to set in the container.
            After the container starts running, code running in the container
            can read these environment variables. Additionally, the command
            and args fields can reference these variables. Later entries in
            this list can also reference earlier entries. For example, the
            following sets the variable ``VAR_2`` to have the value
            ``foo bar``:

            .. code:: json

                [
                    {"name": "VAR_1", "value": "foo"},
                    {"name": "VAR_2", "value": "$(VAR_1) bar"}
                ]

            If you switch the order of the variables in the example, then the
            expansion does not occur. This field corresponds to the ``env``
            field of the Kubernetes Containers v1 core API.
        ports (Sequence[google.cloud.aiplatform_v1.types.Port]):
            Optional. List of ports to expose from the container. Vertex AI
            sends any prediction requests that it receives to the first port on
            this list. Vertex AI also sends liveness and health checks to
            this port.
        predict_route (str):
            Optional. HTTP path on the container to send prediction requests
            to. Vertex AI forwards requests sent using
            projects.locations.endpoints.predict to this path on the
            container's IP address and port. Vertex AI then returns the
            container's response in the API response. For example, if you set
            this field to ``/foo``, then when Vertex AI receives a prediction
            request, it forwards the request body in a POST request to the
            ``/foo`` path on the port of your container specified by the first
            value of this ``ModelContainerSpec``'s ports field. If you don't
            specify this field, it defaults to the following value when you
            deploy this Model to an Endpoint:
            /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict
            The placeholders in this value are replaced as follows:

            - ENDPOINT: The last segment (following ``endpoints/``) of the
              Endpoint.name field of the Endpoint where this Model has
              been deployed. (Vertex AI makes this value available to your
              container code as the ``AIP_ENDPOINT_ID`` environment variable.)
        health_route (str):
            Optional. HTTP path on the container to send health checks to.
            Vertex AI intermittently sends GET requests to this path on the
            container's IP address and port to check that the container is
            healthy. Read more about health checks in the Vertex AI
            documentation.
    """

    image_uri: str
    command: Optional[List[str]] = None
    args: Optional[List[str]] = None
    env: Optional[List[Dict[str, str]]] = None
    ports: Optional[List[int]] = None
    predict_route: Optional[str] = None
    health_route: Optional[str] = None

    def to_dict(self):
        """ML metadata schema dictionary representation of this DataClass.

        Returns:
            A dictionary that represents the ContainerSpec class.
        """
        results = {}
        results["imageUri"] = self.image_uri
        if self.command:
            results["command"] = self.command
        if self.args:
            results["args"] = self.args
        if self.env:
            results["env"] = self.env
        if self.ports:
            results["ports"] = self.ports
        if self.predict_route:
            results["predictRoute"] = self.predict_route
        if self.health_route:
            results["healthRoute"] = self.health_route

        return results

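
# Illustrative usage sketch: the image URI, port, and routes are
# placeholders. Only image_uri is required; unset fields are omitted
# from the resulting dictionary.
spec = ContainerSpec(
    image_uri="us-docker.pkg.dev/my-project/my-repo/my-server:latest",
    ports=[8080],
    predict_route="/predict",
    health_route="/health",
)
print(spec.to_dict())
# {'imageUri': 'us-docker.pkg.dev/my-project/my-repo/my-server:latest',
#  'ports': [8080], 'predictRoute': '/predict', 'healthRoute': '/health'}
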
@dataclass
class AnnotationSpec:
    """A class that represents the annotation spec of a Confusion Matrix.

    Args:
        display_name (str):
            Optional. Display name for a column of a confusion matrix.
        id (str):
            Optional. ID for a column of a confusion matrix.
    """

    display_name: Optional[str] = None
    id: Optional[str] = None

    def to_dict(self):
        """ML metadata schema dictionary representation of this DataClass.

        Returns:
            A dictionary that represents the AnnotationSpec class.
        """
        results = {}
        if self.display_name:
            results["displayName"] = self.display_name
        if self.id:
            results["id"] = self.id

        return results

@dataclass
class ConfusionMatrix:
    """A class that represents a Confusion Matrix.

    Args:
        matrix (List[List[int]]):
            Required. A 2D array of integers that represents the values for the confusion matrix.
        annotation_specs (List[AnnotationSpec]):
            Optional. List of column annotation specs, each of which contains a display_name (str) and an id (str).
    """

    matrix: List[List[int]]
    annotation_specs: Optional[List[AnnotationSpec]] = None

    def to_dict(self):
        """ML metadata schema dictionary representation of this DataClass.

        Returns:
            A dictionary that represents the ConfusionMatrix class.

        Raises:
            ValueError: If annotation_specs and matrix have different lengths.
        """
        results = {}
        if self.annotation_specs:
            if len(self.annotation_specs) != len(self.matrix):
                raise ValueError(
                    "Length of annotation_specs and matrix must be the same. "
                    "Got lengths {} and {} respectively.".format(
                        len(self.annotation_specs), len(self.matrix)
                    )
                )
            results["annotationSpecs"] = [
                annotation_spec.to_dict() for annotation_spec in self.annotation_specs
            ]
        if self.matrix:
            results["rows"] = self.matrix

        return results

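
# Illustrative usage sketch for a binary classifier: one AnnotationSpec
# per row/column of the matrix (the label names are placeholders).
cm = ConfusionMatrix(
    matrix=[[90, 10], [5, 95]],
    annotation_specs=[
        AnnotationSpec(display_name="negative"),
        AnnotationSpec(display_name="positive"),
    ],
)
print(cm.to_dict())
# {'annotationSpecs': [{'displayName': 'negative'}, {'displayName': 'positive'}],
#  'rows': [[90, 10], [5, 95]]}
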
@dataclass
class ConfidenceMetric:
    """A class that represents a Confidence Metric.

    Args:
        confidence_threshold (float):
            Required. Metrics are computed with the assumption that the Model never returns predictions with a score lower than this value.
            For binary classification this is the positive class threshold; for multi-class classification this is the confidence threshold.
        recall (float):
            Optional. Recall (True Positive Rate) for the given confidence threshold.
        precision (float):
            Optional. Precision for the given confidence threshold.
        f1_score (float):
            Optional. The harmonic mean of recall and precision.
        max_predictions (int):
            Optional. Metrics are computed with the assumption that the Model always returns at most this many predictions (ordered by their score, in descending order).
            All of them still need to meet the `confidence_threshold`.
        false_positive_rate (float):
            Optional. False Positive Rate for the given confidence threshold.
        accuracy (float):
            Optional. Accuracy is the fraction of predictions that were given the correct label. For multi-class classification this is a micro-averaged metric.
        true_positive_count (int):
            Optional. The number of Model-created labels that match a ground truth label.
        false_positive_count (int):
            Optional. The number of Model-created labels that do not match a ground truth label.
        false_negative_count (int):
            Optional. The number of ground truth labels that are not matched by a Model-created label.
        true_negative_count (int):
            Optional. The number of labels that were not created by the Model, but that, had they been created, would not have matched a ground truth label.
        recall_at_1 (float):
            Optional. The Recall (True Positive Rate) when only considering the label that has the highest prediction score
            and is not below the confidence threshold for each DataItem.
        precision_at_1 (float):
            Optional. The precision when only considering the label that has the highest prediction score
            and is not below the confidence threshold for each DataItem.
        false_positive_rate_at_1 (float):
            Optional. The False Positive Rate when only considering the label that has the highest prediction score
            and is not below the confidence threshold for each DataItem.
        f1_score_at_1 (float):
            Optional. The harmonic mean of recallAt1 and precisionAt1.
        confusion_matrix (ConfusionMatrix):
            Optional. Confusion matrix for the given confidence threshold.
    """

    confidence_threshold: float
    recall: Optional[float] = None
    precision: Optional[float] = None
    f1_score: Optional[float] = None
    max_predictions: Optional[int] = None
    false_positive_rate: Optional[float] = None
    accuracy: Optional[float] = None
    true_positive_count: Optional[int] = None
    false_positive_count: Optional[int] = None
    false_negative_count: Optional[int] = None
    true_negative_count: Optional[int] = None
    recall_at_1: Optional[float] = None
    precision_at_1: Optional[float] = None
    false_positive_rate_at_1: Optional[float] = None
    f1_score_at_1: Optional[float] = None
    confusion_matrix: Optional[ConfusionMatrix] = None

    def to_dict(self):
        """ML metadata schema dictionary representation of this DataClass.

        Returns:
            A dictionary that represents the ConfidenceMetric class.
        """
        results = {}
        results["confidenceThreshold"] = self.confidence_threshold
        if self.recall is not None:
            results["recall"] = self.recall
        if self.precision is not None:
            results["precision"] = self.precision
        if self.f1_score is not None:
            results["f1Score"] = self.f1_score
        if self.max_predictions is not None:
            results["maxPredictions"] = self.max_predictions
        if self.false_positive_rate is not None:
            results["falsePositiveRate"] = self.false_positive_rate
        if self.accuracy is not None:
            results["accuracy"] = self.accuracy
        if self.true_positive_count is not None:
            results["truePositiveCount"] = self.true_positive_count
        if self.false_positive_count is not None:
            results["falsePositiveCount"] = self.false_positive_count
        if self.false_negative_count is not None:
            results["falseNegativeCount"] = self.false_negative_count
        if self.true_negative_count is not None:
            results["trueNegativeCount"] = self.true_negative_count
        if self.recall_at_1 is not None:
            results["recallAt1"] = self.recall_at_1
        if self.precision_at_1 is not None:
            results["precisionAt1"] = self.precision_at_1
        if self.false_positive_rate_at_1 is not None:
            results["falsePositiveRateAt1"] = self.false_positive_rate_at_1
        if self.f1_score_at_1 is not None:
            results["f1ScoreAt1"] = self.f1_score_at_1
        if self.confusion_matrix:
            results["confusionMatrix"] = self.confusion_matrix.to_dict()

        return results

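
# Illustrative usage sketch with placeholder values. Unlike the other
# dataclasses above, to_dict() here tests `is not None`, so a metric
# that is exactly 0.0 is still serialized.
metric = ConfidenceMetric(
    confidence_threshold=0.5,
    recall=0.95,
    precision=0.9,
    f1_score=2 * 0.95 * 0.9 / (0.95 + 0.9),  # harmonic mean, ~0.9243
)
print(metric.to_dict()["f1Score"])  # ~0.9243
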
def create_uri_from_resource_name(resource_name: str) -> str:
    """Construct the service URI for a given resource_name.

    Args:
        resource_name (str):
            The name of the Vertex resource, in one of the forms:
            projects/{project}/locations/{location}/{resource_type}/{resource_id}
            projects/{project}/locations/{location}/{resource_type}/{resource_id}@{version}
            projects/{project}/locations/{location}/metadataStores/{store_id}/{resource_type}/{resource_id}
            projects/{project}/locations/{location}/metadataStores/{store_id}/{resource_type}/{resource_id}@{version}
    Returns:
        The resource URI in the form of:
        https://{service-endpoint}/v1/{resource_name},
        where {service-endpoint} is one of the supported service endpoints at
        https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints
    Raises:
        ValueError: If resource_name does not match the specified format.
    """
    # TODO: support nested resource names such as models/123/evaluations/456
    match_results = re.match(
        r"^projects\/(?P<project>[\w-]+)\/locations\/(?P<location>[\w-]+)(\/metadataStores\/(?P<store>[\w-]+))?\/[\w-]+\/(?P<id>[\w-]+)(?P<version>@[\w-]+)?$",
        resource_name,
    )
    if not match_results:
        raise ValueError(f"Invalid resource_name format for {resource_name}.")

    location = match_results["location"]
    return f"https://{location}-aiplatform.googleapis.com/v1/{resource_name}"
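
# Illustrative usage sketch (placeholder project and model IDs).
uri = create_uri_from_resource_name(
    "projects/123/locations/us-central1/models/456"
)
print(uri)
# https://us-central1-aiplatform.googleapis.com/v1/projects/123/locations/us-central1/models/456
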
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Optional, Union


def _make_filter_string(
    schema_title: Optional[Union[str, List[str]]] = None,
    in_context: Optional[List[str]] = None,
    parent_contexts: Optional[List[str]] = None,
    uri: Optional[str] = None,
) -> str:
    """Helper method to format filter strings for Metadata querying.

    No enforcement of correctness.

    Args:
        schema_title (Union[str, List[str]]): Optional. schema_title(s) to filter for.
        in_context (List[str]):
            Optional. Context resource names that the node should be in. Only for Artifacts/Executions.
        parent_contexts (List[str]): Optional. Parent contexts the context should be in. Only for Contexts.
        uri (str): Optional. URI to match. Only for Artifacts.
    Returns:
        String that can be used for Metadata service filtering.
    """
    parts = []
    if schema_title:
        if isinstance(schema_title, str):
            parts.append(f'schema_title="{schema_title}"')
        else:
            substring = " OR ".join(f'schema_title="{s}"' for s in schema_title)
            parts.append(f"({substring})")
    if in_context:
        for context in in_context:
            parts.append(f'in_context("{context}")')
    if parent_contexts:
        parent_context_str = ",".join([f'"{c}"' for c in parent_contexts])
        parts.append(f"parent_contexts:{parent_context_str}")
    if uri:
        parts.append(f'uri="{uri}"')
    return " AND ".join(parts)