structure saas with tools

Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View File

@@ -0,0 +1,900 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib
import os
import pickle
import tempfile
from typing import Any, Dict, Optional, Sequence, Union
from google.auth import credentials as auth_credentials
from google.cloud import storage
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import explain
from google.cloud.aiplatform import helpers
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.metadata.schema import utils as schema_utils
from google.cloud.aiplatform.metadata.schema.google import (
artifact_schema as google_artifact_schema,
)
from google.cloud.aiplatform.utils import gcs_utils
_LOGGER = base.Logger(__name__)
_PICKLE_PROTOCOL = 4
_MAX_INPUT_EXAMPLE_ROWS = 5
def _save_sklearn_model(
model: "sklearn.base.BaseEstimator", # noqa: F821
path: str,
) -> str:
"""Saves a sklearn model.
Args:
model (sklearn.base.BaseEstimator):
Required. A sklearn model.
path (str):
Required. The local path to save the model.
Returns:
A string representing the model class.
"""
with open(path, "wb") as f:
pickle.dump(model, f, protocol=_PICKLE_PROTOCOL)
return f"{model.__class__.__module__}.{model.__class__.__name__}"
def _save_xgboost_model(
model: Union["xgb.Booster", "xgb.XGBModel"], # noqa: F821
path: str,
) -> str:
"""Saves a xgboost model.
Args:
model (Union[xgb.Booster, xgb.XGBModel]):
Required. An xgboost model.
path (str):
Required. The local path to save the model.
Returns:
A string representing the model class.
"""
model.save_model(path)
return f"{model.__class__.__module__}.{model.__class__.__name__}"
def _save_tensorflow_model(
model: "tf.Module", # noqa: F821
path: str,
tf_save_model_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
"""Saves a tensorflow model.
Args:
model (tf.Module):
Required. A tensorflow model.
path (str):
Required. The local path to save the model.
tf_save_model_kwargs (Dict[str, Any]):
Optional. A dict of kwargs to pass to the model's save method.
If saving a tf module, this will pass to "tf.saved_model.save" method.
If saving a keras model, this will pass to "tf.keras.Model.save" method.
Returns:
A string representing the model's base class.
"""
try:
import tensorflow as tf
except ImportError:
raise ImportError(
"tensorflow is not installed and required for saving models."
) from None
tf_save_model_kwargs = tf_save_model_kwargs or {}
if isinstance(model, tf.keras.Model):
model.save(path, **tf_save_model_kwargs)
return "tensorflow.keras.Model"
elif isinstance(model, tf.Module):
tf.saved_model.save(model, path, **tf_save_model_kwargs)
return "tensorflow.Module"
def _load_sklearn_model(
model_file: str,
model_artifact: google_artifact_schema.ExperimentModel,
) -> "sklearn.base.BaseEstimator": # noqa: F821
"""Loads a sklearn model from local path.
Args:
model_file (str):
Required. A local model file to load.
model_artifact (google_artifact_schema.ExperimentModel):
Required. The artifact that saved the model.
Returns:
The sklearn model instance.
Raises:
ImportError: if sklearn is not installed.
"""
try:
import sklearn
except ImportError:
raise ImportError(
"sklearn is not installed and is required for loading models."
) from None
if sklearn.__version__ < model_artifact.framework_version:
_LOGGER.warning(
f"The original model was saved via sklearn {model_artifact.framework_version}. "
f"You are using sklearn {sklearn.__version__}."
"Attempting to load model..."
)
with open(model_file, "rb") as f:
sk_model = pickle.load(f)
return sk_model
def _load_xgboost_model(
model_file: str,
model_artifact: google_artifact_schema.ExperimentModel,
) -> Union["xgb.Booster", "xgb.XGBModel"]: # noqa: F821
"""Loads a xgboost model from local path.
Args:
model_file (str):
Required. A local model file to load.
model_artifact (google_artifact_schema.ExperimentModel):
Required. The artifact that saved the model.
Returns:
The xgboost model instance.
Raises:
ImportError: if xgboost is not installed.
"""
try:
import xgboost as xgb
except ImportError:
raise ImportError(
"xgboost is not installed and is required for loading models."
) from None
if xgb.__version__ < model_artifact.framework_version:
_LOGGER.warning(
f"The original model was saved via xgboost {model_artifact.framework_version}. "
f"You are using xgboost {xgb.__version__}."
"Attempting to load model..."
)
module, class_name = model_artifact.model_class.rsplit(".", maxsplit=1)
xgb_model = getattr(importlib.import_module(module), class_name)()
xgb_model.load_model(model_file)
return xgb_model
def _load_tensorflow_model(
model_file: str,
model_artifact: google_artifact_schema.ExperimentModel,
) -> "tf.Module": # noqa: F821
"""Loads a tensorflow model from path.
Args:
model_file (str):
Required. A path to load the model.
model_artifact (google_artifact_schema.ExperimentModel):
Required. The artifact that saved the model.
Returns:
The tensorflow model instance.
Raises:
ImportError: if tensorflow is not installed.
"""
try:
import tensorflow as tf
except ImportError:
raise ImportError(
"tensorflow is not installed and is required for loading models."
) from None
if tf.__version__ < model_artifact.framework_version:
_LOGGER.warning(
f"The original model was saved via tensorflow {model_artifact.framework_version}. "
f"You are using tensorflow {tf.__version__}."
"Attempting to load model..."
)
if model_artifact.model_class == "tensorflow.keras.Model":
tf_model = tf.keras.models.load_model(model_file)
elif model_artifact.model_class == "tensorflow.Module":
tf_model = tf.saved_model.load(model_file)
else:
raise ValueError(f"Unsupported model class: {model_artifact.model_class}")
return tf_model
def _save_input_example(
input_example: Union[list, dict, "pd.DataFrame", "np.ndarray"], # noqa: F821
path: str,
):
"""Saves an input example into a yaml file in the given path.
Supported example formats: list, dict, np.ndarray, pd.DataFrame.
Args:
input_example (Union[list, dict, np.ndarray, pd.DataFrame]):
Required. An input example to save. The value inside a list must be
a scalar or list. The value inside a dict must be a scalar, list, or
np.ndarray.
path (str):
Required. The directory that the example is saved to.
Raises:
ImportError: if PyYAML or numpy is not installed.
ValueError: if input_example is in a wrong format.
"""
try:
import numpy as np
except ImportError:
raise ImportError(
"numpy is not installed and is required for saving input examples. "
"Please install google-cloud-aiplatform[metadata]."
) from None
try:
import yaml
except ImportError:
raise ImportError(
"PyYAML is not installed and is required for saving input examples."
) from None
example = {}
if isinstance(input_example, list):
if all(isinstance(x, list) for x in input_example):
example = {
"type": "list",
"data": input_example[:_MAX_INPUT_EXAMPLE_ROWS],
}
elif all(np.isscalar(x) for x in input_example):
example = {
"type": "list",
"data": input_example,
}
else:
raise ValueError("The value inside a list must be a scalar or list.")
if isinstance(input_example, dict):
if all(isinstance(x, list) for x in input_example.values()):
example = {
"type": "dict",
"data": {
k: v[:_MAX_INPUT_EXAMPLE_ROWS] for k, v in input_example.items()
},
}
elif all(isinstance(x, np.ndarray) for x in input_example.values()):
example = {
"type": "dict",
"data": {
k: v[:_MAX_INPUT_EXAMPLE_ROWS].tolist()
for k, v in input_example.items()
},
}
elif all(np.isscalar(x) for x in input_example.values()):
example = {"type": "dict", "data": input_example}
else:
raise ValueError(
"The value inside a dictionary must be a scalar, list, or np.ndarray"
)
if isinstance(input_example, np.ndarray):
example = {
"type": "numpy.ndarray",
"data": input_example[:_MAX_INPUT_EXAMPLE_ROWS].tolist(),
}
try:
import pandas as pd
if isinstance(input_example, pd.DataFrame):
example = {
"type": "pandas.DataFrame",
"data": input_example.head(_MAX_INPUT_EXAMPLE_ROWS).to_dict("list"),
}
except ImportError:
pass
if not example:
raise ValueError(
(
"Input example type not supported. "
"Valid example must be a list, dict, np.ndarray, or pd.DataFrame."
)
)
example_file = os.path.join(path, "instance.yaml")
with open(example_file, "w") as file:
yaml.dump(
{"input_example": example}, file, default_flow_style=None, sort_keys=False
)
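# Illustrative sketch added for clarity (not part of the original module; the
# feature names are placeholders). For a dict input example such as
# {"sepal_length": [5.1, 4.9], "sepal_width": [3.5, 3.0]}, the helper above
# would be expected to write an "instance.yaml" roughly like:
#
#   input_example:
#     type: dict
#     data:
#       sepal_length: [5.1, 4.9]
#       sepal_width: [3.5, 3.0]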
_FRAMEWORK_SPECS = {
"sklearn": {
"save_method": _save_sklearn_model,
"load_method": _load_sklearn_model,
"model_file": "model.pkl",
},
"xgboost": {
"save_method": _save_xgboost_model,
"load_method": _load_xgboost_model,
"model_file": "model.bst",
},
"tensorflow": {
"save_method": _save_tensorflow_model,
"load_method": _load_tensorflow_model,
"model_file": "saved_model",
},
}
def save_model(
model: Union[
"sklearn.base.BaseEstimator", "xgb.Booster", "tf.Module" # noqa: F821
],
artifact_id: Optional[str] = None,
*,
uri: Optional[str] = None,
input_example: Union[list, dict, "pd.DataFrame", "np.ndarray"] = None, # noqa: F821
tf_save_model_kwargs: Optional[Dict[str, Any]] = None,
display_name: Optional[str] = None,
metadata_store_id: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> google_artifact_schema.ExperimentModel:
"""Saves a ML model into a MLMD artifact.
Supported model frameworks: sklearn, xgboost, tensorflow.
Example usage:
aiplatform.init(project="my-project", location="my-location", staging_bucket="gs://my-bucket")
model = LinearRegression()
model.fit(X, y)
aiplatform.save_model(model, "my-sklearn-model")
Args:
model (Union["sklearn.base.BaseEstimator", "xgb.Booster", "tf.Module"]):
Required. A machine learning model.
artifact_id (str):
Optional. The resource id of the artifact. This id must be globally unique
in a metadataStore. It may be up to 63 characters, and valid characters
are `[a-z0-9_-]`. The first character cannot be a number or hyphen.
uri (str):
Optional. A gcs directory to save the model file. If not provided,
`gs://default-bucket/timestamp-uuid-frameworkName-model` will be used.
If default staging bucket is not set, a new bucket will be created.
input_example (Union[list, dict, pd.DataFrame, np.ndarray]):
Optional. An example of a valid model input. Will be stored as a yaml file
in the gcs uri. Accepts list, dict, pd.DataFrame, and np.ndarray.
The value inside a list must be a scalar or list. The value inside
a dict must be a scalar, list, or np.ndarray.
tf_save_model_kwargs (Dict[str, Any]):
Optional. A dict of kwargs to pass to the model's save method.
If saving a tf module, this will pass to "tf.saved_model.save" method.
If saving a keras model, this will pass to "tf.keras.Model.save" method.
display_name (str):
Optional. The display name of the artifact.
metadata_store_id (str):
Optional. The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Optional. Project used to create this Artifact. Overrides project set in
aiplatform.init.
location (str):
Optional. Location used to create this Artifact. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this Artifact. Overrides
credentials set in aiplatform.init.
Returns:
An ExperimentModel instance.
Raises:
ValueError: if model type is not supported.
"""
framework_name = framework_version = ""
try:
import sklearn
except ImportError:
pass
else:
# An instance of sklearn.base.BaseEstimator might be a sklearn model
# or a xgboost/lightgbm model implemented on top of sklearn.
if isinstance(
model, sklearn.base.BaseEstimator
) and model.__class__.__module__.startswith("sklearn"):
framework_name = "sklearn"
framework_version = sklearn.__version__
try:
import xgboost as xgb
except ImportError:
pass
else:
if isinstance(model, (xgb.Booster, xgb.XGBModel)):
framework_name = "xgboost"
framework_version = xgb.__version__
try:
import tensorflow as tf
except ImportError:
pass
else:
if isinstance(model, tf.Module):
framework_name = "tensorflow"
framework_version = tf.__version__
if framework_name not in _FRAMEWORK_SPECS:
raise ValueError(
f"Model type {model.__class__.__module__}.{model.__class__.__name__} not supported."
)
save_method = _FRAMEWORK_SPECS[framework_name]["save_method"]
model_file = _FRAMEWORK_SPECS[framework_name]["model_file"]
if not uri:
staging_bucket = initializer.global_config.staging_bucket
# TODO(b/264196887)
if not staging_bucket:
project = project or initializer.global_config.project
location = location or initializer.global_config.location
credentials = credentials or initializer.global_config.credentials
staging_bucket_name = project + "-vertex-staging-" + location
client = storage.Client(project=project, credentials=credentials)
staging_bucket = storage.Bucket(client=client, name=staging_bucket_name)
if not staging_bucket.exists():
_LOGGER.info(f'Creating staging bucket "{staging_bucket_name}"')
staging_bucket = client.create_bucket(
bucket_or_name=staging_bucket,
project=project,
location=location,
)
staging_bucket = f"gs://{staging_bucket_name}"
unique_name = utils.timestamped_unique_name()
uri = f"{staging_bucket}/{unique_name}-{framework_name}-model"
with tempfile.TemporaryDirectory() as temp_dir:
# Tensorflow models can be saved directly to gcs
if framework_name == "tensorflow":
path = os.path.join(uri, model_file)
model_class = save_method(model, path, tf_save_model_kwargs)
# Other models will be saved to a temp path and uploaded to gcs
else:
path = os.path.join(temp_dir, model_file)
model_class = save_method(model, path)
if input_example is not None:
_save_input_example(input_example, temp_dir)
predict_schemata = schema_utils.PredictSchemata(
instance_schema_uri=os.path.join(uri, "instance.yaml")
)
else:
predict_schemata = None
gcs_utils.upload_to_gcs(temp_dir, uri)
model_artifact = google_artifact_schema.ExperimentModel(
framework_name=framework_name,
framework_version=framework_version,
model_file=model_file,
model_class=model_class,
predict_schemata=predict_schemata,
artifact_id=artifact_id,
uri=uri,
display_name=display_name,
)
model_artifact.create(
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
return model_artifact
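# Hedged usage sketch for save_model (illustrative only; the project, region,
# bucket, and artifact id are placeholders, and training data X, y are assumed
# to exist). The docstring above shows the same call via the aiplatform namespace.
#
#   from sklearn.linear_model import LinearRegression
#   from google.cloud import aiplatform
#
#   aiplatform.init(
#       project="my-project",
#       location="us-central1",
#       staging_bucket="gs://my-staging-bucket",
#   )
#   reg = LinearRegression().fit(X, y)
#   experiment_model = save_model(
#       model=reg,
#       artifact_id="my-sklearn-model",
#       input_example=X[:5],
#   )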
def load_model(
model: Union[str, google_artifact_schema.ExperimentModel]
) -> Union["sklearn.base.BaseEstimator", "xgb.Booster", "tf.Module"]: # noqa: F821
"""Retrieves the original ML model from an ExperimentModel resource.
Args:
model (Union[str, google_artifact_schema.ExperimentModel]):
Required. The id or ExperimentModel instance for the model.
Returns:
The original ML model.
Raises:
ValueError: if model type is not supported.
"""
if isinstance(model, str):
model = aiplatform.get_experiment_model(model)
framework_name = model.framework_name
if framework_name not in _FRAMEWORK_SPECS:
raise ValueError(f"Model type {framework_name} not supported.")
load_method = _FRAMEWORK_SPECS[framework_name]["load_method"]
model_file = _FRAMEWORK_SPECS[framework_name]["model_file"]
source_file_uri = os.path.join(model.uri, model_file)
# Tensorflow models can be loaded directly from gcs
if framework_name == "tensorflow":
loaded_model = load_method(source_file_uri, model)
# Other models need to be downloaded to local path then loaded.
else:
with tempfile.TemporaryDirectory() as temp_dir:
destination_file_path = os.path.join(temp_dir, model_file)
gcs_utils.download_file_from_gcs(source_file_uri, destination_file_path)
loaded_model = load_method(destination_file_path, model)
return loaded_model
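# Hedged round-trip sketch for load_model (illustrative only; assumes the
# "my-sklearn-model" artifact from the sketch above and held-out data X_test):
#
#   restored = load_model("my-sklearn-model")
#   predictions = restored.predict(X_test)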
# TODO(b/264893283)
def register_model(
model: Union[str, google_artifact_schema.ExperimentModel],
*,
model_id: Optional[str] = None,
parent_model: Optional[str] = None,
use_gpu: bool = False,
is_default_version: bool = True,
version_aliases: Optional[Sequence[str]] = None,
version_description: Optional[str] = None,
display_name: Optional[str] = None,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
serving_container_image_uri: Optional[str] = None,
serving_container_predict_route: Optional[str] = None,
serving_container_health_route: Optional[str] = None,
serving_container_command: Optional[Sequence[str]] = None,
serving_container_args: Optional[Sequence[str]] = None,
serving_container_environment_variables: Optional[Dict[str, str]] = None,
serving_container_ports: Optional[Sequence[int]] = None,
instance_schema_uri: Optional[str] = None,
parameters_schema_uri: Optional[str] = None,
prediction_schema_uri: Optional[str] = None,
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
explanation_parameters: Optional[explain.ExplanationParameters] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
encryption_spec_key_name: Optional[str] = None,
staging_bucket: Optional[str] = None,
sync: Optional[bool] = True,
upload_request_timeout: Optional[float] = None,
) -> models.Model:
"""Register an ExperimentModel to Model Registry and returns a Model representing the registered Model resource.
Args:
model (Union[str, google_artifact_schema.ExperimentModel]):
Required. The id or ExperimentModel instance for the model.
model_id (str):
Optional. The ID to use for the registered Model, which will
become the final component of the model resource name.
This value may be up to 63 characters, and valid characters
are `[a-z0-9_-]`. The first character cannot be a number or hyphen.
parent_model (str):
Optional. The resource name or model ID of an existing model that the
newly-registered model will be a version of.
Only set this field when uploading a new version of an existing model.
use_gpu (bool):
Optional. Whether or not to use GPUs for the serving container. Only
specify this argument when registering a Tensorflow model and
'serving_container_image_uri' is not specified.
is_default_version (bool):
Optional. When set to True, the newly registered model version will
automatically have alias "default" included. Subsequent uses of
this model without a version specified will use this "default" version.
When set to False, the "default" alias will not be moved.
Actions targeting the newly-registered model version will need
to specifically reference this version by ID or alias.
New model uploads, i.e. version 1, will always be "default" aliased.
version_aliases (Sequence[str]):
Optional. User provided version aliases so that a model version
can be referenced via alias instead of auto-generated version ID.
A default version alias will be created for the first version of the model.
The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
version_description (str):
Optional. The description of the model version being uploaded.
display_name (str):
Optional. The display name of the Model. The name can be up to 128
characters long and can consist of any UTF-8 characters.
description (str):
Optional. The description of the model.
labels (Dict[str, str]):
Optional. The labels with user-defined metadata to
organize your Models.
Label keys and values can be no longer than 64
characters (Unicode codepoints), can only
contain lowercase letters, numeric characters,
underscores and dashes. International characters
are allowed.
See https://goo.gl/xmQnxf for more information
and examples of labels.
serving_container_image_uri (str):
Optional. The URI of the Model serving container. A pre-built container
<https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers>
is automatically chosen based on the model's framework. Set this field to
override the default pre-built container.
serving_container_predict_route (str):
Optional. An HTTP path to send prediction requests to the container, and
which must be supported by it. If not specified a default HTTP path will
be used by Vertex AI.
serving_container_health_route (str):
Optional. An HTTP path to send health check requests to the container, and which
must be supported by it. If not specified a standard HTTP path will be
used by Vertex AI.
serving_container_command (Sequence[str]):
Optional. The command with which the container is run. Not executed within a
shell. The Docker image's ENTRYPOINT is used if this is not provided.
Variable references $(VAR_NAME) are expanded using the container's
environment. If a variable cannot be resolved, the reference in the
input string will be unchanged. The $(VAR_NAME) syntax can be escaped
with a double $$, ie: $$(VAR_NAME). Escaped references will never be
expanded, regardless of whether the variable exists or not.
serving_container_args (Sequence[str]):
Optional. The arguments to the command. The Docker image's CMD is used if this is
not provided. Variable references $(VAR_NAME) are expanded using the
container's environment. If a variable cannot be resolved, the reference
in the input string will be unchanged. The $(VAR_NAME) syntax can be
escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
never be expanded, regardless of whether the variable exists or not.
serving_container_environment_variables (Dict[str, str]):
Optional. The environment variables that are to be present in the container.
Should be a dictionary where keys are environment variable names
and values are environment variable values for those names.
serving_container_ports (Sequence[int]):
Optional. Declaration of ports that are exposed by the container. This field is
primarily informational; it gives Vertex AI information about the
network connections the container uses. Whether or not a port is listed here has
no impact on whether the port is actually exposed; any port listening on
the default "0.0.0.0" address inside a container will be accessible from
the network.
instance_schema_uri (str):
Optional. Points to a YAML file stored on Google Cloud
Storage describing the format of a single instance, which
are used in
``PredictRequest.instances``,
``ExplainRequest.instances``
and
``BatchPredictionJob.input_config``.
The schema is defined as an OpenAPI 3.0.2 `Schema
Object <https://tinyurl.com/y538mdwt#schema-object>`__.
AutoML Models always have this field populated by AI
Platform. Note: The URI given on output will be immutable
and probably different, including the URI scheme, than the
one given on input. The output URI will point to a location
where the user only has read access.
parameters_schema_uri (str):
Optional. Points to a YAML file stored on Google Cloud
Storage describing the parameters of prediction and
explanation via
``PredictRequest.parameters``,
``ExplainRequest.parameters``
and
``BatchPredictionJob.model_parameters``.
The schema is defined as an OpenAPI 3.0.2 `Schema
Object <https://tinyurl.com/y538mdwt#schema-object>`__.
AutoML Models always have this field populated by AI
Platform, if no parameters are supported it is set to an
empty string. Note: The URI given on output will be
immutable and probably different, including the URI scheme,
than the one given on input. The output URI will point to a
location where the user only has read access.
prediction_schema_uri (str):
Optional. Points to a YAML file stored on Google Cloud
Storage describing the format of a single prediction
produced by this Model, which are returned via
``PredictResponse.predictions``,
``ExplainResponse.explanations``,
and
``BatchPredictionJob.output_config``.
The schema is defined as an OpenAPI 3.0.2 `Schema
Object <https://tinyurl.com/y538mdwt#schema-object>`__.
AutoML Models always have this field populated by AI
Platform. Note: The URI given on output will be immutable
and probably different, including the URI scheme, than the
one given on input. The output URI will point to a location
where the user only has read access.
explanation_metadata (aiplatform.explain.ExplanationMetadata):
Optional. Metadata describing the Model's input and output for explanation.
`explanation_metadata` is optional while `explanation_parameters` must be
specified when used.
For more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
explanation_parameters (aiplatform.explain.ExplanationParameters):
Optional. Parameters to configure explaining for Model's predictions.
For more details, see `Ref docs <http://tinyurl.com/1an4zake>`
project (str):
Project to upload this model to. Overrides project set in
aiplatform.init.
location (str):
Location to upload this model to. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials to use to upload this model. Overrides credentials
set in aiplatform.init.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the model. Has the
form
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, this Model and all sub-resources of this Model will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
staging_bucket (str):
Optional. Bucket to stage local model artifacts. Overrides
staging_bucket set in aiplatform.init.
sync (bool):
Optional. Whether to execute this method synchronously. If False,
this method will unblock and it will be executed in a concurrent Future.
upload_request_timeout (float):
Optional. The timeout for the upload request in seconds.
Returns:
model (aiplatform.Model):
Instantiated representation of the registered model resource.
Raises:
ValueError: If the model doesn't have a pre-built container that is
suitable for its framework and 'serving_container_image_uri'
is not set.
"""
if isinstance(model, str):
model = aiplatform.get_experiment_model(model)
project = project or model.project
location = location or model.location
credentials = credentials or model.credentials
artifact_uri = model.uri
framework_name = model.framework_name
framework_version = model.framework_version
artifact_uri = (
f"{model.uri}/saved_model" if framework_name == "tensorflow" else model.uri
)
if not serving_container_image_uri:
if framework_name == "tensorflow" and use_gpu:
accelerator = "gpu"
else:
accelerator = "cpu"
serving_container_image_uri = helpers._get_closest_match_prebuilt_container_uri(
framework=framework_name,
framework_version=framework_version,
region=location,
accelerator=accelerator,
)
if not display_name:
display_name = models.Model._generate_display_name(f"{framework_name} model")
return models.Model.upload(
serving_container_image_uri=serving_container_image_uri,
artifact_uri=artifact_uri,
model_id=model_id,
parent_model=parent_model,
is_default_version=is_default_version,
version_aliases=version_aliases,
version_description=version_description,
display_name=display_name,
description=description,
labels=labels,
serving_container_predict_route=serving_container_predict_route,
serving_container_health_route=serving_container_health_route,
serving_container_command=serving_container_command,
serving_container_args=serving_container_args,
serving_container_environment_variables=serving_container_environment_variables,
serving_container_ports=serving_container_ports,
instance_schema_uri=instance_schema_uri,
parameters_schema_uri=parameters_schema_uri,
prediction_schema_uri=prediction_schema_uri,
explanation_metadata=explanation_metadata,
explanation_parameters=explanation_parameters,
project=project,
location=location,
credentials=credentials,
encryption_spec_key_name=encryption_spec_key_name,
staging_bucket=staging_bucket,
sync=sync,
upload_request_timeout=upload_request_timeout,
)
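# Hedged sketch for register_model (illustrative only; the artifact id, display
# name, and machine type are placeholders). When serving_container_image_uri is
# not given, a pre-built serving container is picked from the model's framework.
#
#   registered = register_model(
#       "my-sklearn-model",
#       display_name="my-registered-sklearn-model",
#   )
#   endpoint = registered.deploy(machine_type="n1-standard-4")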
def get_experiment_model_info(
model: Union[str, google_artifact_schema.ExperimentModel]
) -> Dict[str, Any]:
"""Get the model's info from an experiment model artifact.
Args:
model (Union[str, google_artifact_schema.ExperimentModel]):
Required. The id or ExperimentModel instance for the model.
Returns:
A dict of model's info. This includes model's class name, framework name,
framework version, and input example.
"""
if isinstance(model, str):
model = aiplatform.get_experiment_model(model)
model_info = {
"model_class": model.model_class,
"framework_name": model.framework_name,
"framework_version": model.framework_version,
}
# try to get the input example if it exists
input_example = None
source_file = f"{model.uri}/instance.yaml"
with tempfile.TemporaryDirectory() as temp_dir:
destination_file = os.path.join(temp_dir, "instance.yaml")
try:
gcs_utils.download_file_from_gcs(source_file, destination_file)
except Exception:
pass
else:
try:
import yaml
except ImportError:
raise ImportError(
"PyYAML is not installed and is required for loading input examples."
) from None
with open(destination_file, "r") as f:
input_example = yaml.safe_load(f)["input_example"]
if input_example:
model_info["input_example"] = input_example
return model_info
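# Hedged sketch for get_experiment_model_info (illustrative only; the artifact id
# is a placeholder). Per the docstring above, the returned dict carries
# model_class, framework_name, framework_version, and input_example when one
# was saved with the model.
#
#   info = get_experiment_model_info("my-sklearn-model")
#   print(info["framework_name"], info["framework_version"])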

View File

@@ -0,0 +1,564 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Dict, Union
import proto
import threading
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.compat.types import (
metadata_service as gca_metadata_service,
)
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import metadata_store
from google.cloud.aiplatform.metadata import resource
from google.cloud.aiplatform.metadata import utils as metadata_utils
from google.cloud.aiplatform.utils import rest_utils
_LOGGER = base.Logger(__name__)
class Artifact(resource._Resource):
"""Metadata Artifact resource for Vertex AI"""
def __init__(
self,
artifact_name: str,
*,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Retrieves an existing Metadata Artifact given a resource name or ID.
Args:
artifact_name (str):
Required. A fully-qualified resource name or resource ID of the Artifact.
Example: "projects/123/locations/us-central1/metadataStores/default/artifacts/my-resource".
or "my-resource" when project and location are initialized or passed.
metadata_store_id (str):
Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default".
If artifact_name is a fully-qualified resource, its metadata_store_id overrides this one.
project (str):
Optional. Project to retrieve the artifact from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve the Artifact from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve this Artifact. Overrides
credentials set in aiplatform.init.
"""
super().__init__(
resource_name=artifact_name,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
_resource_noun = "artifacts"
_getter_method = "get_artifact"
_delete_method = "delete_artifact"
_parse_resource_name_method = "parse_artifact_path"
_format_resource_name_method = "artifact_path"
_list_method = "list_artifacts"
@classmethod
def _create_resource(
cls,
client: utils.MetadataClientWithOverride,
parent: str,
resource_id: str,
schema_title: str,
uri: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: gca_artifact.Artifact.State = gca_artifact.Artifact.State.LIVE,
) -> gca_artifact.Artifact:
gapic_artifact = gca_artifact.Artifact(
uri=uri,
schema_title=schema_title,
schema_version=schema_version,
display_name=display_name,
description=description,
metadata=metadata if metadata else {},
state=state,
)
return client.create_artifact(
parent=parent,
artifact=gapic_artifact,
artifact_id=resource_id,
)
# TODO() refactor code to move _create to _Resource class.
@classmethod
def _create(
cls,
resource_id: str,
schema_title: str,
uri: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: gca_artifact.Artifact.State = gca_artifact.Artifact.State.LIVE,
metadata_store_id: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Artifact":
"""Creates a new Metadata resource.
Args:
resource_id (str):
Required. The <resource_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
schema_title (str):
Required. schema_title identifies the schema title used by the resource.
display_name (str):
Optional. The user-defined name of the resource.
schema_version (str):
Optional. schema_version specifies the version used by the resource.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the resource to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the resource.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to create this resource. Overrides
credentials set in aiplatform.init.
Returns:
resource (_Resource):
Instantiated representation of the managed Metadata resource.
"""
appended_user_agent = []
if base_constants.USER_AGENT_SDK_COMMAND:
appended_user_agent = [
f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}"
]
# Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls.
base_constants.USER_AGENT_SDK_COMMAND = ""
api_client = cls._instantiate_client(
location=location,
credentials=credentials,
appended_user_agent=appended_user_agent,
)
parent = utils.full_resource_name(
resource_name=metadata_store_id,
resource_noun=metadata_store._MetadataStore._resource_noun,
parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name,
format_resource_name_method=metadata_store._MetadataStore._format_resource_name,
project=project,
location=location,
)
resource = cls._create_resource(
client=api_client,
parent=parent,
resource_id=resource_id,
schema_title=schema_title,
uri=uri,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=metadata,
state=state,
)
self = cls._empty_constructor(
project=project, location=location, credentials=credentials
)
self._gca_resource = resource
self._threading_lock = threading.Lock()
return self
@classmethod
def _update_resource(
cls,
client: utils.MetadataClientWithOverride,
resource: proto.Message,
) -> proto.Message:
"""Update Artifacts with given input.
Args:
client (utils.MetadataClientWithOverride):
Required. The client used to send requests to the Metadata Service.
resource (proto.Message):
Required. The proto.Message which contains the update information for the resource.
"""
return client.update_artifact(artifact=resource)
@classmethod
def _list_resources(
cls,
client: utils.MetadataClientWithOverride,
parent: str,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
order_by: Optional[str] = None,
):
"""List artifacts in the parent path that matches the filter.
Args:
client (utils.MetadataClientWithOverride):
Required. The client used to send requests to the Metadata Service.
parent (str):
Required. The path where Artifacts are stored.
filter (str):
Optional. Filter string to restrict the list result.
order_by (str):
Optional. How the list of messages is ordered. Specify the
values to order by and an ordering operation. The default sorting
order is ascending. To specify descending order for a field, users
append a " desc" suffix; for example: "foo desc, bar". Subfields
are specified with a ``.`` character, such as foo.bar. see
https://google.aip.dev/132#ordering for more details.
Returns:
List of artifacts.
"""
list_request = gca_metadata_service.ListArtifactsRequest(
parent=parent,
filter=filter,
order_by=order_by,
)
return client.list_artifacts(request=list_request)
@classmethod
def create(
cls,
schema_title: str,
*,
resource_id: Optional[str] = None,
uri: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: gca_artifact.Artifact.State = gca_artifact.Artifact.State.LIVE,
metadata_store_id: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Artifact":
"""Creates a new Metadata Artifact.
Args:
schema_title (str):
Required. schema_title identifies the schema title used by the Artifact.
Please reference https://cloud.google.com/vertex-ai/docs/ml-metadata/system-schemas.
resource_id (str):
Optional. The <resource_id> portion of the Artifact name, in the format below.
This id must be globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
metadata_store_id (str):
Optional. The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Optional. Project used to create this Artifact. Overrides project set in
aiplatform.init.
location (str):
Optional. Location used to create this Artifact. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this Artifact. Overrides
credentials set in aiplatform.init.
Returns:
Artifact: Instantiated representation of the managed Metadata Artifact.
"""
# Add User Agent Header for metrics tracking if one is not specified
# If one is already specified this call was initiated by a sub class.
if not base_constants.USER_AGENT_SDK_COMMAND:
base_constants.USER_AGENT_SDK_COMMAND = (
"aiplatform.metadata.artifact.Artifact.create"
)
if metadata_store_id == "default":
metadata_store._MetadataStore.ensure_default_metadata_store_exists(
project=project, location=location, credentials=credentials
)
return cls._create(
resource_id=resource_id,
schema_title=schema_title,
uri=uri,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=metadata,
state=state,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
@property
def uri(self) -> Optional[str]:
"Uri for this Artifact."
return self._gca_resource.uri
@property
def state(self) -> Optional[gca_artifact.Artifact.State]:
"The State for this Artifact."
return self._gca_resource.state
@classmethod
def get_with_uri(
cls,
uri: str,
*,
metadata_store_id: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Artifact":
"""Get an Artifact by it's uri.
If more than one Artifact with this uri is in the metadata store then the Artifact with the latest
create_time is returned.
Args:
uri(str):
Required. Uri of the Artifact to retrieve.
metadata_store_id (str):
Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default".
If artifact_name is a fully-qualified resource, its metadata_store_id overrides this one.
project (str):
Optional. Project to retrieve the artifact from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve the Artifact from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve this Artifact. Overrides
credentials set in aiplatform.init.
Returns:
Artifact: Artifact with given uri.
Raises:
ValueError: If no Artifact exists with the provided uri.
"""
matched_artifacts = cls.list(
filter=f'uri = "{uri}"',
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
if not matched_artifacts:
raise ValueError(
f"No artifact with uri {uri} is in the `{metadata_store_id}` MetadataStore."
)
if len(matched_artifacts) > 1:
matched_artifacts.sort(key=lambda a: a.create_time, reverse=True)
resource_names = "\n".join(a.resource_name for a in matched_artifacts)
_LOGGER.warn(
f"Mutiple artifacts with uri {uri} were found: {resource_names}"
)
_LOGGER.warn(f"Returning {matched_artifacts[0].resource_name}")
return matched_artifacts[0]
@property
def lineage_console_uri(self) -> str:
"""Cloud console uri to view this Artifact Lineage."""
metadata_store = self._parse_resource_name(self.resource_name)["metadata_store"]
return f"https://console.cloud.google.com/vertex-ai/locations/{self.location}/metadata-stores/{metadata_store}/artifacts/{self.name}?project={self.project}"
def __repr__(self) -> str:
if self._gca_resource:
return f"{object.__repr__(self)} \nresource name: {self.resource_name}\nuri: {self.uri}\nschema_title:{self.gca_resource.schema_title}"
return base.FutureManager.__repr__(self)
class _VertexResourceArtifactResolver:
# TODO(b/235594717) Add support for managed datasets
_resource_to_artifact_type = {models.Model: "google.VertexModel"}
@classmethod
def supports_metadata(cls, resource: base.VertexAiResourceNoun) -> bool:
"""Returns True if Vertex resource is supported in Vertex Metadata otherwise False.
Args:
resource (base.VertexAiResourceNoun):
Required. Instance of Vertex AI Resource.
Returns:
True if the Vertex resource is supported in Vertex Metadata, otherwise False.
"""
return type(resource) in cls._resource_to_artifact_type
@classmethod
def validate_resource_supports_metadata(cls, resource: base.VertexAiResourceNoun):
"""Validates Vertex resource is supported in Vertex Metadata.
Args:
resource (base.VertexAiResourceNoun):
Required. Instance of Vertex AI Resource.
Raises:
ValueError: If the Vertex AI Resource is not supported in Vertex Metadata.
"""
if not cls.supports_metadata(resource):
raise ValueError(
f"Vertex {type(resource)} is not yet supported in Vertex Metadata."
f"Only {list(cls._resource_to_artifact_type.keys())} are supported"
)
@classmethod
def resolve_vertex_resource(
cls, resource: Union[models.Model]
) -> Optional[Artifact]:
"""Resolves Vertex Metadata Artifact that represents this Vertex Resource.
If there are multiple Artifacts in the metadata store that represent the provided resource, the one with the
latest create_time is returned.
Args:
resource (base.VertexAiResourceNoun):
Required. Instance of Vertex AI Resource.
Returns:
Artifact: Artifact that represents this Vertex Resource. None if Resource not found in Metadata store.
"""
cls.validate_resource_supports_metadata(resource)
resource.wait()
metadata_type = cls._resource_to_artifact_type[type(resource)]
uri = rest_utils.make_gcp_resource_rest_url(resource=resource)
artifacts = Artifact.list(
filter=metadata_utils._make_filter_string(
schema_title=metadata_type,
uri=uri,
),
project=resource.project,
location=resource.location,
credentials=resource.credentials,
)
artifacts.sort(key=lambda a: a.create_time, reverse=True)
if artifacts:
# most recent
return artifacts[0]
@classmethod
def create_vertex_resource_artifact(cls, resource: Union[models.Model]) -> Artifact:
"""Creates Vertex Metadata Artifact that represents this Vertex Resource.
Args:
resource (base.VertexAiResourceNoun):
Required. Instance of Vertex AI Resource.
Returns:
Artifact: Artifact that represents this Vertex Resource.
"""
cls.validate_resource_supports_metadata(resource)
resource.wait()
metadata_type = cls._resource_to_artifact_type[type(resource)]
uri = rest_utils.make_gcp_resource_rest_url(resource=resource)
return Artifact.create(
schema_title=metadata_type,
display_name=getattr(resource.gca_resource, "display_name", None),
uri=uri,
# Note that support for non-versioned resources requires a
# change to reference `resource_name`; please update this if
# supporting resources other than Model.
metadata={"resourceName": resource.versioned_resource_name},
project=resource.project,
location=resource.location,
credentials=resource.credentials,
)
@classmethod
def resolve_or_create_resource_artifact(
cls, resource: Union[models.Model]
) -> Artifact:
"""Create of gets Vertex Metadata Artifact that represents this Vertex Resource.
Args:
resource (base.VertexAiResourceNoun):
Required. Instance of Vertex AI Resource.
Returns:
Artifact: Artifact that represents this Vertex Resource.
"""
artifact = cls.resolve_vertex_resource(resource=resource)
if artifact:
return artifact
return cls.create_vertex_resource_artifact(resource=resource)
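# Hedged usage sketch for the Artifact class above (illustrative only; the schema
# title, uri, display name, and metadata are placeholders, and aiplatform.init is
# assumed to have been called with a project and location):
#
#   my_artifact = Artifact.create(
#       schema_title="system.Artifact",
#       uri="gs://my-bucket/path/to/output",
#       display_name="my-artifact",
#       metadata={"rows": 1000},
#   )
#   same_artifact = Artifact.get_with_uri("gs://my-bucket/path/to/output")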

View File

@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Constants used by Metadata and Vertex Experiments."""
from google.cloud.aiplatform.compat.types import artifact
SYSTEM_RUN = "system.Run"
SYSTEM_EXPERIMENT = "system.Experiment"
SYSTEM_EXPERIMENT_RUN = "system.ExperimentRun"
SYSTEM_PIPELINE = "system.Pipeline"
SYSTEM_PIPELINE_RUN = "system.PipelineRun"
SYSTEM_METRICS = "system.Metrics"
GOOGLE_CLASSIFICATION_METRICS = "google.ClassificationMetrics"
GOOGLE_REGRESSION_METRICS = "google.RegressionMetrics"
GOOGLE_FORECASTING_METRICS = "google.ForecastingMetrics"
GOOGLE_EXPERIMENT_MODEL = "google.ExperimentModel"
_EXPERIMENTS_V2_TENSORBOARD_RUN = "google.VertexTensorboardRun"
_DEFAULT_SCHEMA_VERSION = "0.0.1"
SCHEMA_VERSIONS = {
SYSTEM_RUN: _DEFAULT_SCHEMA_VERSION,
SYSTEM_EXPERIMENT: _DEFAULT_SCHEMA_VERSION,
SYSTEM_EXPERIMENT_RUN: _DEFAULT_SCHEMA_VERSION,
SYSTEM_PIPELINE: _DEFAULT_SCHEMA_VERSION,
SYSTEM_METRICS: _DEFAULT_SCHEMA_VERSION,
}
_BACKING_TENSORBOARD_RESOURCE_KEY = "backing_tensorboard_resource"
_CUSTOM_JOB_KEY = "_custom_jobs"
_CUSTOM_JOB_RESOURCE_NAME = "custom_job_resource_name"
_CUSTOM_JOB_CONSOLE_URI = "custom_job_console_uri"
_PARAM_KEY = "_params"
_METRIC_KEY = "_metrics"
_STATE_KEY = "_state"
_PARAM_PREFIX = "param"
_METRIC_PREFIX = "metric"
_TIME_SERIES_METRIC_PREFIX = "time_series_metric"
# This is currently used to filter in the Console.
EXPERIMENT_METADATA = {"experiment_deleted": False}
PIPELINE_PARAM_PREFIX = "input:"
TENSORBOARD_CUSTOM_JOB_EXPERIMENT_FIELD = "tensorboard_link"
GCP_ARTIFACT_RESOURCE_NAME_KEY = "resourceName"
# constant to mark an Experiment context as originating from the SDK
# TODO(b/235593750) Remove this field
_VERTEX_EXPERIMENT_TRACKING_LABEL = "vertex_experiment_tracking"
_TENSORBOARD_RUN_REFERENCE_ARTIFACT = artifact.Artifact(
name="google-vertex-tensorboard-run-v0-0-1",
schema_title=_EXPERIMENTS_V2_TENSORBOARD_RUN,
schema_version="0.0.1",
metadata={_VERTEX_EXPERIMENT_TRACKING_LABEL: True},
)
_TB_RUN_ARTIFACT_POST_FIX_ID = "-tb-run"
_EXPERIMENT_RUN_MAX_LENGTH = 128 - len(_TB_RUN_ARTIFACT_POST_FIX_ID)
# Label used to identify TensorboardExperiment as created from Vertex
# Experiments
_VERTEX_EXPERIMENT_TB_EXPERIMENT_LABEL = {
"vertex_tensorboard_experiment_source": "vertex_experiment"
}
ENV_EXPERIMENT_KEY = "AIP_EXPERIMENT_NAME"
ENV_EXPERIMENT_RUN_KEY = "AIP_EXPERIMENT_RUN_NAME"

View File

@@ -0,0 +1,421 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Dict, List, Sequence
import proto
import re
import threading
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import utils as metadata_utils
from google.cloud.aiplatform.compat.types import context as gca_context
from google.cloud.aiplatform.compat.types import (
lineage_subgraph as gca_lineage_subgraph,
)
from google.cloud.aiplatform.compat.types import (
metadata_service as gca_metadata_service,
)
from google.cloud.aiplatform.metadata import artifact
from google.cloud.aiplatform.metadata import execution
from google.cloud.aiplatform.metadata import metadata_store
from google.cloud.aiplatform.metadata import resource
from google.api_core.exceptions import Aborted
_ETAG_ERROR_MAX_RETRY_COUNT = 5
_ETAG_ERROR_REGEX = re.compile(
r"Specified Context \`etag\`: \`(\d+)\` does not match server \`etag\`: \`(\d+)\`"
)
class Context(resource._Resource):
"""Metadata Context resource for Vertex AI"""
_resource_noun = "contexts"
_getter_method = "get_context"
_delete_method = "delete_context"
_parse_resource_name_method = "parse_context_path"
_format_resource_name_method = "context_path"
_list_method = "list_contexts"
@property
def parent_contexts(self) -> Sequence[str]:
"""The parent context resource names of this context."""
return self.gca_resource.parent_contexts
def add_artifacts_and_executions(
self,
artifact_resource_names: Optional[Sequence[str]] = None,
execution_resource_names: Optional[Sequence[str]] = None,
):
"""Associate Executions and attribute Artifacts to a given Context.
Args:
artifact_resource_names (Sequence[str]):
Optional. The full resource name of Artifacts to attribute to the Context.
execution_resource_names (Sequence[str]):
Optional. The full resource name of Executions to associate with the Context.
"""
self.api_client.add_context_artifacts_and_executions(
context=self.resource_name,
artifacts=artifact_resource_names,
executions=execution_resource_names,
)
def get_artifacts(self) -> List[artifact.Artifact]:
"""Returns all Artifact attributed to this Context.
Returns:
artifacts(List[Artifacts]): All Artifacts under this context.
"""
return artifact.Artifact.list(
filter=metadata_utils._make_filter_string(in_context=[self.resource_name]),
project=self.project,
location=self.location,
credentials=self.credentials,
)
@classmethod
def create(
cls,
schema_title: str,
*,
resource_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
metadata_store_id: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Context":
"""Creates a new Metadata Context.
Args:
schema_title (str):
Required. schema_title identifies the schema title used by the Context.
Please reference https://cloud.google.com/vertex-ai/docs/ml-metadata/system-schemas.
resource_id (str):
Optional. The <resource_id> portion of the Context name, in the format below.
This id must be globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/Contexts/<resource_id>.
display_name (str):
Optional. The user-defined name of the Context.
schema_version (str):
Optional. schema_version specifies the version used by the Context.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Context to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Context.
metadata_store_id (str):
Optional. The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/Contexts/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Optional. Project used to create this Context. Overrides project set in
aiplatform.init.
location (str):
Optional. Location used to create this Context. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this Context. Overrides
credentials set in aiplatform.init.
Returns:
Context: Instantiated representation of the managed Metadata Context.
"""
# Add User Agent Header for metrics tracking if one is not specified
# If one is already specified this call was initiated by a sub class.
if not base_constants.USER_AGENT_SDK_COMMAND:
base_constants.USER_AGENT_SDK_COMMAND = (
"aiplatform.metadata.context.Context.create"
)
return cls._create(
resource_id=resource_id,
schema_title=schema_title,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=metadata,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
# TODO() refactor code to move _create to _Resource class.
@classmethod
def _create(
cls,
resource_id: str,
schema_title: str,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
metadata_store_id: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Context":
"""Creates a new Metadata resource.
Args:
resource_id (str):
Required. The <resource_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
schema_title (str):
Required. schema_title identifies the schema title used by the resource.
display_name (str):
Optional. The user-defined name of the resource.
schema_version (str):
Optional. schema_version specifies the version used by the resource.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the resource to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the resource.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to create this resource. Overrides
credentials set in aiplatform.init.
Returns:
resource (_Resource):
Instantiated representation of the managed Metadata resource.
"""
appended_user_agent = []
if base_constants.USER_AGENT_SDK_COMMAND:
appended_user_agent = [
f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}"
]
# Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls.
base_constants.USER_AGENT_SDK_COMMAND = ""
api_client = cls._instantiate_client(
location=location,
credentials=credentials,
appended_user_agent=appended_user_agent,
)
parent = utils.full_resource_name(
resource_name=metadata_store_id,
resource_noun=metadata_store._MetadataStore._resource_noun,
parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name,
format_resource_name_method=metadata_store._MetadataStore._format_resource_name,
project=project,
location=location,
)
resource = cls._create_resource(
client=api_client,
parent=parent,
resource_id=resource_id,
schema_title=schema_title,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=metadata,
)
self = cls._empty_constructor(
project=project, location=location, credentials=credentials
)
self._gca_resource = resource
self._threading_lock = threading.Lock()
return self
@classmethod
def _create_resource(
cls,
client: utils.MetadataClientWithOverride,
parent: str,
resource_id: str,
schema_title: str,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
) -> proto.Message:
gapic_context = gca_context.Context(
schema_title=schema_title,
schema_version=schema_version,
display_name=display_name,
description=description,
metadata=metadata if metadata else {},
)
return client.create_context(
parent=parent,
context=gapic_context,
context_id=resource_id,
)
def update(
self,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
location: Optional[str] = None,
):
"""Updates an existing Metadata Context with new metadata.
This is implemented with retry on etag errors, up to
_ETAG_ERROR_MAX_RETRY_COUNT times.
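Example (illustrative; assumes ``my_context`` already exists):
```py
my_context.update(metadata={'key': 'value'}, description='updated description')
```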
Args:
metadata (Dict):
Optional. metadata contains the updated metadata information.
description (str):
Optional. Description describes the resource to be updated.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to update this resource. Overrides
credentials set in aiplatform.init.
location (str):
Optional. Location to send the update request to. Overrides location set in
aiplatform.init.
"""
for _ in range(_ETAG_ERROR_MAX_RETRY_COUNT - 1):
try:
super().update(
metadata=metadata,
description=description,
credentials=credentials,
location=location,
)
return
except Aborted as aborted_exception:
regex_match = _ETAG_ERROR_REGEX.match(aborted_exception.message)
if regex_match:
local_etag = regex_match.group(1)
server_etag = regex_match.group(2)
if local_etag < server_etag:
self.sync_resource()
continue
raise aborted_exception
# Expose result/exception directly in the last retry.
super().update(
metadata=metadata,
description=description,
credentials=credentials,
location=location,
)
@classmethod
def _update_resource(
cls,
client: utils.MetadataClientWithOverride,
resource: proto.Message,
) -> proto.Message:
"""Update Contexts with given input.
Args:
client (utils.MetadataClientWithOverride):
Required. Client to send requests to the Metadata Service.
resource (proto.Message):
Required. The proto.Message which contains the update information for the resource.
"""
return client.update_context(context=resource)
@classmethod
def _list_resources(
cls,
client: utils.MetadataClientWithOverride,
parent: str,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
order_by: Optional[str] = None,
):
"""List Contexts in the parent path that matches the filter.
Args:
client (utils.MetadataClientWithOverride):
Required. Client to send requests to the Metadata Service.
parent (str):
Required. The path where Contexts are stored.
filter (str):
Optional. Filter string to restrict the list result.
order_by (str):
Optional. How the list of messages is ordered. Specify the
values to order by and an ordering operation. The default sorting
order is ascending. To specify descending order for a field, users
append a " desc" suffix; for example: "foo desc, bar". Subfields
are specified with a ``.`` character, such as foo.bar. see
https://google.aip.dev/132#ordering for more details.
Returns:
List of Contexts.
"""
list_request = gca_metadata_service.ListContextsRequest(
parent=parent,
filter=filter,
order_by=order_by,
)
return client.list_contexts(request=list_request)
def add_context_children(self, contexts: List["Context"]):
"""Adds the provided contexts as children of this context.
Args:
contexts (List[Context]): Contexts to add as children.
"""
self.api_client.add_context_children(
context=self.resource_name,
child_contexts=[c.resource_name for c in contexts],
)
def query_lineage_subgraph(self) -> gca_lineage_subgraph.LineageSubgraph:
"""Queries lineage subgraph of this context.
Returns:
lineage_subgraph (gca_lineage_subgraph.LineageSubgraph): Lineage subgraph of this Context.
"""
return self.api_client.query_context_lineage_subgraph(
context=self.resource_name, retry=base._DEFAULT_RETRY
)
def get_executions(self) -> List[execution.Execution]:
"""Returns Executions associated to this context.
Returns:
executions (List[Executions]): Executions associated to this context.
"""
return execution.Execution.list(
filter=metadata_utils._make_filter_string(in_context=[self.resource_name]),
project=self.project,
location=self.location,
credentials=self.credentials,
)

View File

@@ -0,0 +1,533 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
import proto
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import event as gca_event
from google.cloud.aiplatform.compat.types import execution as gca_execution
from google.cloud.aiplatform.compat.types import (
metadata_service as gca_metadata_service,
)
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import artifact
from google.cloud.aiplatform.metadata import metadata_store
from google.cloud.aiplatform.metadata import resource
class Execution(resource._Resource):
"""Metadata Execution resource for Vertex AI"""
_resource_noun = "executions"
_getter_method = "get_execution"
_delete_method = "delete_execution"
_parse_resource_name_method = "parse_execution_path"
_format_resource_name_method = "execution_path"
_list_method = "list_executions"
def __init__(
self,
execution_name: str,
*,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Retrieves an existing Metadata Execution given a resource name or ID.
Args:
execution_name (str):
Required. A fully-qualified resource name or resource ID of the Execution.
Example: "projects/123/locations/us-central1/metadataStores/default/executions/my-resource".
or "my-resource" when project and location are initialized or passed.
metadata_store_id (str):
Optional. MetadataStore to retrieve Execution from. If not set, metadata_store_id is set to "default".
If execution_name is a fully-qualified resource, its metadata_store_id overrides this one.
project (str):
Optional. Project to retrieve the artifact from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve the Execution from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve this Execution. Overrides
credentials set in aiplatform.init.
"""
super().__init__(
resource_name=execution_name,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
@property
def state(self) -> gca_execution.Execution.State:
"""State of this Execution."""
return self._gca_resource.state
@classmethod
def create(
cls,
schema_title: str,
*,
state: gca_execution.Execution.State = gca_execution.Execution.State.RUNNING,
resource_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
description: Optional[str] = None,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Execution":
"""
Creates a new Metadata Execution.
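Example (illustrative; assumes ``aiplatform.init`` has been configured and ``my_artifact`` already exists; ``system.ContainerExecution`` is a system schema title). Because ``Execution`` implements ``__enter__``/``__exit__``, it can also be used as a context manager that marks the Execution COMPLETE (or FAILED on error) on exit:
```py
from google.cloud.aiplatform.metadata import execution

with execution.Execution.create(
    schema_title='system.ContainerExecution',
    display_name='my execution',
) as my_execution:
    my_execution.assign_input_artifacts([my_artifact])
```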
Args:
schema_title (str):
Required. schema_title identifies the schema title used by the Execution.
state (gca_execution.Execution.State):
Optional. State of this Execution. Defaults to RUNNING.
resource_id (str):
Optional. The <resource_id> portion of the Execution name. This is
globally unique in a metadataStore and has the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
display_name (str):
Optional. The user-defined name of the Execution.
schema_version (str):
Optional. schema_version specifies the version used by the Execution.
If not set, defaults to use the latest version.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Execution.
description (str):
Optional. Describes the purpose of the Execution to be created.
metadata_store_id (str):
Optional. The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Optional. Project used to create this Execution. Overrides project set in
aiplatform.init.
location (str):
Optional. Location used to create this Execution. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this Execution. Overrides
credentials set in aiplatform.init.
Returns:
Execution: Instantiated representation of the managed Metadata Execution.
"""
# Add User Agent Header for metrics tracking if one is not specified
# If one is already specified this call was initiated by a sub class.
if not base_constants.USER_AGENT_SDK_COMMAND:
base_constants.USER_AGENT_SDK_COMMAND = (
"aiplatform.metadata.execution.Execution.create"
)
return cls._create(
resource_id=resource_id,
schema_title=schema_title,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=metadata,
state=state,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
# TODO() refactor code to move _create to _Resource class.
@classmethod
def _create(
cls,
schema_title: str,
*,
state: gca_execution.Execution.State = gca_execution.Execution.State.RUNNING,
resource_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
description: Optional[str] = None,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Execution":
"""
Creates a new Metadata Execution.
Args:
schema_title (str):
Required. schema_title identifies the schema title used by the Execution.
state (gca_execution.Execution.State):
Optional. State of this Execution. Defaults to RUNNING.
resource_id (str):
Optional. The <resource_id> portion of the Execution name. This is
globally unique in a metadataStore and has the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
display_name (str):
Optional. The user-defined name of the Execution.
schema_version (str):
Optional. schema_version specifies the version used by the Execution.
If not set, defaults to use the latest version.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Execution.
description (str):
Optional. Describes the purpose of the Execution to be created.
metadata_store_id (str):
Optional. The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Optional. Project used to create this Execution. Overrides project set in
aiplatform.init.
location (str):
Optional. Location used to create this Execution. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this Execution. Overrides
credentials set in aiplatform.init.
Returns:
Execution: Instantiated representation of the managed Metadata Execution.
"""
appended_user_agent = []
if base_constants.USER_AGENT_SDK_COMMAND:
appended_user_agent = [
f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}"
]
# Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls.
base_constants.USER_AGENT_SDK_COMMAND = ""
api_client = cls._instantiate_client(
location=location,
credentials=credentials,
appended_user_agent=appended_user_agent,
)
parent = utils.full_resource_name(
resource_name=metadata_store_id,
resource_noun=metadata_store._MetadataStore._resource_noun,
parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name,
format_resource_name_method=metadata_store._MetadataStore._format_resource_name,
project=project,
location=location,
)
resource = Execution._create_resource(
client=api_client,
parent=parent,
schema_title=schema_title,
resource_id=resource_id,
metadata=metadata,
description=description,
display_name=display_name,
schema_version=schema_version,
state=state,
)
self = cls._empty_constructor(
project=project, location=location, credentials=credentials
)
self._gca_resource = resource
return self
def __enter__(self):
if self.state is not gca_execution.Execution.State.RUNNING:
self.update(state=gca_execution.Execution.State.RUNNING)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
state = (
gca_execution.Execution.State.FAILED
if exc_type
else gca_execution.Execution.State.COMPLETE
)
self.update(state=state)
def assign_input_artifacts(
self, artifacts: List[Union[artifact.Artifact, models.Model]]
):
"""Assigns Artifacts as inputs to this Executions.
Args:
artifacts (List[Union[artifact.Artifact, models.Model]]):
Required. Artifacts to assign as input.
"""
self._add_artifact(artifacts=artifacts, input=True)
def assign_output_artifacts(
self, artifacts: List[Union[artifact.Artifact, models.Model]]
):
"""Assigns Artifacts as outputs to this Executions.
Args:
artifacts (List[Union[artifact.Artifact, models.Model]]):
Required. Artifacts to assign as input.
"""
self._add_artifact(artifacts=artifacts, input=False)
def _add_artifact(
self,
artifacts: List[Union[artifact.Artifact, models.Model]],
input: bool,
):
"""Connect Artifact to a given Execution.
Args:
artifact_resource_names (List[str]):
Required. The full resource name of the Artifact to connect to the Execution through an Event.
input (bool)
Required. Whether Artifact is an input event to the Execution or not.
"""
artifact_resource_names = []
for a in artifacts:
if isinstance(a, artifact.Artifact):
artifact_resource_names.append(a.resource_name)
else:
artifact_resource_names.append(
artifact._VertexResourceArtifactResolver.resolve_or_create_resource_artifact(
a
).resource_name
)
events = [
gca_event.Event(
artifact=artifact_resource_name,
type_=gca_event.Event.Type.INPUT
if input
else gca_event.Event.Type.OUTPUT,
)
for artifact_resource_name in artifact_resource_names
]
self.api_client.add_execution_events(
execution=self.resource_name,
events=events,
)
def _get_artifacts(
self, event_type: gca_event.Event.Type
) -> List[artifact.Artifact]:
"""Get Executions input or output Artifacts.
Args:
event_type (gca_event.Event.Type):
Required. The Event type, input or output.
Returns:
List of Artifacts.
"""
subgraph = self.api_client.query_execution_inputs_and_outputs(
execution=self.resource_name
)
artifact_map = {
artifact_metadata.name: artifact_metadata
for artifact_metadata in subgraph.artifacts
}
gca_artifacts = [
artifact_map[event.artifact]
for event in subgraph.events
if event.type_ == event_type
]
artifacts = []
for gca_artifact in gca_artifacts:
this_artifact = artifact.Artifact._empty_constructor(
project=self.project,
location=self.location,
credentials=self.credentials,
)
this_artifact._gca_resource = gca_artifact
artifacts.append(this_artifact)
return artifacts
def get_input_artifacts(self) -> List[artifact.Artifact]:
"""Get the input Artifacts of this Execution.
Returns:
List of input Artifacts.
"""
return self._get_artifacts(event_type=gca_event.Event.Type.INPUT)
def get_output_artifacts(self) -> List[artifact.Artifact]:
"""Get the output Artifacts of this Execution.
Returns:
List of output Artifacts.
"""
return self._get_artifacts(event_type=gca_event.Event.Type.OUTPUT)
@classmethod
def _create_resource(
cls,
client: utils.MetadataClientWithOverride,
parent: str,
schema_title: str,
state: gca_execution.Execution.State = gca_execution.Execution.State.RUNNING,
resource_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
) -> gca_execution.Execution:
"""
Creates a new Metadata Execution.
Args:
client (utils.MetadataClientWithOverride):
Required. Instantiated Metadata Service Client.
parent (str):
Required. MetadataStore parent in which to create this Execution.
schema_title (str):
Required. schema_title identifies the schema title used by the Execution.
state (gca_execution.Execution.State):
Optional. State of this Execution. Defaults to RUNNING.
resource_id (str):
Optional. The {execution} portion of the resource name with the
format:
``projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution}``
If not provided, the Execution's ID will be a UUID generated
by the service. Must be 4-128 characters in length. Valid
characters are ``/[a-z][0-9]-/``. Must be unique across all
Executions in the parent MetadataStore. (Otherwise the
request will fail with ALREADY_EXISTS, or PERMISSION_DENIED
if the caller can't view the preexisting Execution.)
display_name (str):
Optional. The user-defined name of the Execution.
schema_version (str):
Optional. schema_version specifies the version used by the Execution.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Execution to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Execution.
Returns:
Execution: Instantiated representation of the managed Metadata Execution.
"""
gapic_execution = gca_execution.Execution(
schema_title=schema_title,
schema_version=schema_version,
display_name=display_name,
description=description,
metadata=metadata if metadata else {},
state=state,
)
return client.create_execution(
parent=parent,
execution=gapic_execution,
execution_id=resource_id,
)
@classmethod
def _list_resources(
cls,
client: utils.MetadataClientWithOverride,
parent: str,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
order_by: Optional[str] = None,
):
"""List Executions in the parent path that matches the filter.
Args:
client (utils.MetadataClientWithOverride):
Required. Client to send requests to the Metadata Service.
parent (str):
Required. The path where Executions are stored.
filter (str):
Optional. Filter string to restrict the list result.
order_by (str):
Optional. How the list of messages is ordered. Specify the
values to order by and an ordering operation. The default sorting
order is ascending. To specify descending order for a field, users
append a " desc" suffix; for example: "foo desc, bar". Subfields
are specified with a ``.`` character, such as foo.bar. see
https://google.aip.dev/132#ordering for more details.
Returns:
List of Executions.
"""
list_request = gca_metadata_service.ListExecutionsRequest(
parent=parent,
filter=filter,
order_by=order_by,
)
return client.list_executions(request=list_request)
@classmethod
def _update_resource(
cls,
client: utils.MetadataClientWithOverride,
resource: proto.Message,
) -> proto.Message:
"""Update Executions with given input.
Args:
client (utils.MetadataClientWithOverride):
Required. Client to send requests to the Metadata Service.
resource (proto.Message):
Required. The proto.Message which contains the update information for the resource.
"""
return client.update_execution(execution=resource)
def update(
self,
state: Optional[gca_execution.Execution.State] = None,
description: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
):
"""Update this Execution.
Args:
state (gca_execution.Execution.State):
Optional. State of this Execution.
description (str):
Optional. Describes the purpose of the Execution to be created.
metadata (Dict[str, Any]):
Optional. Contains the metadata information that will be stored in the Execution.
"""
gca_resource = deepcopy(self._gca_resource)
if state:
gca_resource.state = state
if description:
gca_resource.description = description
self._nested_update_metadata(gca_resource=gca_resource, metadata=metadata)
self._gca_resource = self._update_resource(
self.api_client, resource=gca_resource
)

View File

@@ -0,0 +1,827 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import concurrent.futures
from dataclasses import dataclass
import logging
from typing import Dict, List, NamedTuple, Optional, Tuple, Type, Union
from google.api_core import exceptions
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform.metadata import artifact
from google.cloud.aiplatform.metadata import constants
from google.cloud.aiplatform.metadata import context
from google.cloud.aiplatform.metadata import execution
from google.cloud.aiplatform.metadata import metadata
from google.cloud.aiplatform.metadata import metadata_store
from google.cloud.aiplatform.metadata import resource
from google.cloud.aiplatform.metadata import utils as metadata_utils
from google.cloud.aiplatform.tensorboard import tensorboard_resource
_LOGGER = base.Logger(__name__)
_HIGH_RUN_COUNT_THRESHOLD = 100 # Used in get_data_frame to make suggestion to user
@dataclass
class _ExperimentRow:
"""Class for representing a run row in an Experiments Dataframe.
Attributes:
params (Dict[str, Union[float, int, str]]): Optional. The parameters of this run.
metrics (Dict[str, Union[float, int, str]]): Optional. The metrics of this run.
time_series_metrics (Dict[str, float]): Optional. The latest time series metrics of this run.
experiment_run_type (Optional[str]): Optional. The type of this run.
name (str): Optional. The name of this run.
state (str): Optional. The state of this run.
"""
params: Optional[Dict[str, Union[float, int, str]]] = None
metrics: Optional[Dict[str, Union[float, int, str]]] = None
time_series_metrics: Optional[Dict[str, float]] = None
experiment_run_type: Optional[str] = None
name: Optional[str] = None
state: Optional[str] = None
def to_dict(self) -> Dict[str, Union[float, int, str]]:
"""Converts this experiment row into a dictionary.
Returns:
Row as a dictionary.
"""
result = {
"run_type": self.experiment_run_type,
"run_name": self.name,
"state": self.state,
}
for prefix, field in [
(constants._PARAM_PREFIX, self.params),
(constants._METRIC_PREFIX, self.metrics),
(constants._TIME_SERIES_METRIC_PREFIX, self.time_series_metrics),
]:
if field:
result.update(
{f"{prefix}.{key}": value for key, value in field.items()}
)
return result
class Experiment:
"""Represents a Vertex AI Experiment resource."""
def __init__(
self,
experiment_name: str,
*,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""
```py
my_experiment = aiplatform.Experiment('my-experiment')
```
Args:
experiment_name (str):
Required. The name or resource name of this experiment.
Resource name is of the format:
`projects/123/locations/us-central1/metadataStores/default/contexts/my-experiment`
project (str):
Optional. Project where this experiment is located. Overrides
project set in aiplatform.init.
location (str):
Optional. Location where this experiment is located. Overrides
location set in aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to retrieve this experiment.
Overrides credentials set in aiplatform.init.
"""
metadata_args = dict(
resource_name=experiment_name,
project=project,
location=location,
credentials=credentials,
)
with _SetLoggerLevel(resource):
experiment_context = context.Context(**metadata_args)
self._validate_experiment_context(experiment_context)
self._metadata_context = experiment_context
@staticmethod
def _validate_experiment_context(experiment_context: context.Context):
"""Validates this context is an experiment context.
Args:
experiment_context (context.Context): Metadata context.
Raises:
ValueError: If Metadata context is not an experiment context or a TensorboardExperiment.
"""
if experiment_context.schema_title != constants.SYSTEM_EXPERIMENT:
raise ValueError(
f"Experiment name {experiment_context.name} is of type "
f"({experiment_context.schema_title}) in this MetadataStore. "
f"It must of type {constants.SYSTEM_EXPERIMENT}."
)
if Experiment._is_tensorboard_experiment(experiment_context):
raise ValueError(
f"Experiment name {experiment_context.name} is a TensorboardExperiment context "
f"and cannot be used as a Vertex AI Experiment."
)
@staticmethod
def _is_tensorboard_experiment(context: context.Context) -> bool:
"""Returns True if Experiment is a Tensorboard Experiment created by CustomJob."""
return constants.TENSORBOARD_CUSTOM_JOB_EXPERIMENT_FIELD in context.metadata
@property
def name(self) -> str:
"""The name of this experiment."""
return self._metadata_context.name
@classmethod
def create(
cls,
experiment_name: str,
*,
description: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Experiment":
"""Creates a new experiment in Vertex AI Experiments.
```py
my_experiment = aiplatform.Experiment.create('my-experiment', description='my description')
```
Args:
experiment_name (str): Required. The name of this experiment.
description (str): Optional. Describes this experiment's purpose.
project (str):
Optional. Project where this experiment will be created. Overrides project set in
aiplatform.init.
location (str):
Optional. Location where this experiment will be created. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this experiment. Overrides
credentials set in aiplatform.init.
Returns:
The newly created experiment.
"""
metadata_store._MetadataStore.ensure_default_metadata_store_exists(
project=project, location=location, credentials=credentials
)
with _SetLoggerLevel(resource):
experiment_context = context.Context._create(
resource_id=experiment_name,
display_name=experiment_name,
description=description,
schema_title=constants.SYSTEM_EXPERIMENT,
schema_version=metadata._get_experiment_schema_version(),
metadata=constants.EXPERIMENT_METADATA,
project=project,
location=location,
credentials=credentials,
)
self = cls.__new__(cls)
self._metadata_context = experiment_context
return self
@classmethod
def get(
cls,
experiment_name: str,
*,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Optional["Experiment"]:
"""Gets experiment if one exists with this experiment_name in Vertex AI Experiments.
Args:
experiment_name (str):
Required. The name of this experiment.
project (str):
Optional. Project used to retrieve this resource.
Overrides project set in aiplatform.init.
location (str):
Optional. Location used to retrieve this resource.
Overrides location set in aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to retrieve this resource.
Overrides credentials set in aiplatform.init.
Returns:
Vertex AI experiment or None if no resource was found.
"""
try:
return cls(
experiment_name=experiment_name,
project=project,
location=location,
credentials=credentials,
)
except exceptions.NotFound:
return None
@classmethod
def get_or_create(
cls,
experiment_name: str,
*,
description: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "Experiment":
"""Gets experiment if one exists with this experiment_name in Vertex AI Experiments.
Otherwise creates this experiment.
```py
my_experiment = aiplatform.Experiment.get_or_create('my-experiment', description='my description')
```
Args:
experiment_name (str): Required. The name of this experiment.
description (str): Optional. Describes this experiment's purpose.
project (str):
Optional. Project where this experiment will be retrieved from or created. Overrides project set in
aiplatform.init.
location (str):
Optional. Location where this experiment will be retrieved from or created. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to retrieve or create this experiment. Overrides
credentials set in aiplatform.init.
Returns:
Vertex AI experiment.
"""
metadata_store._MetadataStore.ensure_default_metadata_store_exists(
project=project, location=location, credentials=credentials
)
with _SetLoggerLevel(resource):
experiment_context = context.Context.get_or_create(
resource_id=experiment_name,
display_name=experiment_name,
description=description,
schema_title=constants.SYSTEM_EXPERIMENT,
schema_version=metadata._get_experiment_schema_version(),
metadata=constants.EXPERIMENT_METADATA,
project=project,
location=location,
credentials=credentials,
)
cls._validate_experiment_context(experiment_context)
if description and description != experiment_context.description:
experiment_context.update(description=description)
self = cls.__new__(cls)
self._metadata_context = experiment_context
return self
@classmethod
def list(
cls,
*,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List["Experiment"]:
"""List all Vertex AI Experiments in the given project.
```py
my_experiments = aiplatform.Experiment.list()
```
Args:
project (str):
Optional. Project to list these experiments from. Overrides project set in
aiplatform.init.
location (str):
Optional. Location to list these experiments from. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to list these experiments. Overrides
credentials set in aiplatform.init.
Returns:
List of Vertex AI experiments.
"""
filter_str = metadata_utils._make_filter_string(
schema_title=constants.SYSTEM_EXPERIMENT
)
with _SetLoggerLevel(resource):
experiment_contexts = context.Context.list(
filter=filter_str,
project=project,
location=location,
credentials=credentials,
)
experiments = []
for experiment_context in experiment_contexts:
# Filters Tensorboard Experiments
if not cls._is_tensorboard_experiment(experiment_context):
experiment = cls.__new__(cls)
experiment._metadata_context = experiment_context
experiments.append(experiment)
return experiments
@property
def resource_name(self) -> str:
"""The Metadata context resource name of this experiment."""
return self._metadata_context.resource_name
@property
def backing_tensorboard_resource_name(self) -> Optional[str]:
"""The Tensorboard resource associated with this Experiment if there is one."""
return self._metadata_context.metadata.get(
constants._BACKING_TENSORBOARD_RESOURCE_KEY
)
def delete(self, *, delete_backing_tensorboard_runs: bool = False):
"""Deletes this experiment all the experiment runs under this experiment
Does not delete Pipeline runs, Artifacts, or Executions associated to this experiment
or experiment runs in this experiment.
```py
my_experiment = aiplatform.Experiment('my-experiment')
my_experiment.delete(delete_backing_tensorboard_runs=True)
```
Args:
delete_backing_tensorboard_runs (bool):
Optional. If True, also deletes the Tensorboard Runs associated with the experiment
runs under this experiment that were used to store time series metrics.
"""
experiment_runs = _SUPPORTED_LOGGABLE_RESOURCES[context.Context][
constants.SYSTEM_EXPERIMENT_RUN
].list(experiment=self)
for experiment_run in experiment_runs:
experiment_run.delete(
delete_backing_tensorboard_run=delete_backing_tensorboard_runs
)
try:
self._metadata_context.delete()
except exceptions.NotFound:
_LOGGER.warning(
f"Experiment {self.name} metadata node not found. Skipping deletion."
)
def get_data_frame(
self, *, include_time_series: bool = True
) -> "pd.DataFrame": # noqa: F821
"""Get parameters, metrics, and time series metrics of all runs in this experiment as Dataframe.
```py
my_experiment = aiplatform.Experiment('my-experiment')
df = my_experiment.get_data_frame()
```
Args:
include_time_series (bool):
Optional. Whether or not to include time series metrics in the DataFrame.
Default is True. Setting this to False significantly improves execution
time and reduces quota-contributing calls. Recommended when time
series metrics are not needed or when the number of runs in the Experiment is
large. For time series metrics, consider querying a specific run
using get_time_series_data_frame.
Returns:
pd.DataFrame: Pandas Dataframe of Experiment Runs.
Raises:
ImportError: If pandas is not installed.
"""
try:
import pandas as pd
except ImportError:
raise ImportError(
"Pandas is not installed and is required to get dataframe as the return format. "
'Please install the SDK using "pip install google-cloud-aiplatform[metadata]"'
)
service_request_args = dict(
project=self._metadata_context.project,
location=self._metadata_context.location,
credentials=self._metadata_context.credentials,
)
filter_str = metadata_utils._make_filter_string(
schema_title=sorted(
list(_SUPPORTED_LOGGABLE_RESOURCES[context.Context].keys())
),
parent_contexts=[self._metadata_context.resource_name],
)
contexts = context.Context.list(filter_str, **service_request_args)
filter_str = metadata_utils._make_filter_string(
schema_title=list(
_SUPPORTED_LOGGABLE_RESOURCES[execution.Execution].keys()
),
in_context=[self._metadata_context.resource_name],
)
executions = execution.Execution.list(filter_str, **service_request_args)
run_count = max([len(contexts), len(executions)])
if include_time_series and run_count > _HIGH_RUN_COUNT_THRESHOLD:
_LOGGER.warning(
f"Number of runs {run_count} is high. Consider setting "
f"include_time_series to False to improve execution performance"
)
if not include_time_series:
_LOGGER.warning(
"include_time_series is set to False. Time series metrics will"
" not be included in this call even if they exist."
)
rows = []
if contexts or executions:
with concurrent.futures.ThreadPoolExecutor(
max_workers=run_count
) as executor:
futures = [
executor.submit(
_SUPPORTED_LOGGABLE_RESOURCES[context.Context][
metadata_context.schema_title
]._query_experiment_row,
metadata_context,
experiment=self,
include_time_series=include_time_series,
)
for metadata_context in contexts
]
# backward compatibility
futures.extend(
executor.submit(
_SUPPORTED_LOGGABLE_RESOURCES[execution.Execution][
metadata_execution.schema_title
]._query_experiment_row,
metadata_execution,
experiment=self,
include_time_series=include_time_series,
)
for metadata_execution in executions
)
for future in futures:
try:
row_dict = future.result().to_dict()
except Exception as exc:
raise ValueError(
f"Failed to get experiment row for {self.name}"
) from exc
else:
row_dict.update({"experiment_name": self.name})
rows.append(row_dict)
df = pd.DataFrame(rows)
column_name_sort_map = {
"experiment_name": -1,
"run_name": 1,
"run_type": 2,
"state": 3,
}
def column_sort_key(key: str) -> int:
"""Helper method to reorder columns."""
order = column_name_sort_map.get(key)
if order:
return order
elif key.startswith("param"):
return 5
elif key.startswith("metric"):
return 6
else:
return 7
columns = df.columns
columns = sorted(columns, key=column_sort_key)
df = df.reindex(columns, axis=1)
return df
def _lookup_backing_tensorboard(self) -> Optional[tensorboard_resource.Tensorboard]:
"""Returns backing tensorboard if one is set.
Returns:
Tensorboard resource if one exists, otherwise returns None.
"""
tensorboard_resource_name = self._metadata_context.metadata.get(
constants._BACKING_TENSORBOARD_RESOURCE_KEY
)
if not tensorboard_resource_name:
with _SetLoggerLevel(resource):
self._metadata_context.sync_resource()
tensorboard_resource_name = self._metadata_context.metadata.get(
constants._BACKING_TENSORBOARD_RESOURCE_KEY
)
if tensorboard_resource_name:
try:
return tensorboard_resource.Tensorboard(
tensorboard_resource_name,
credentials=self._metadata_context.credentials,
)
except exceptions.NotFound:
self._metadata_context.update(
metadata={constants._BACKING_TENSORBOARD_RESOURCE_KEY: None}
)
return None
def get_backing_tensorboard_resource(
self,
) -> Optional[tensorboard_resource.Tensorboard]:
"""Get the backing tensorboard for this experiment if one exists.
```py
my_experiment = aiplatform.Experiment('my-experiment')
tb = my_experiment.get_backing_tensorboard_resource()
```
Returns:
Backing Tensorboard resource for this experiment if one exists.
"""
return self._lookup_backing_tensorboard()
def assign_backing_tensorboard(
self, tensorboard: Union[tensorboard_resource.Tensorboard, str]
):
"""Assigns tensorboard as backing tensorboard to support time series metrics logging.
```py
tb = aiplatform.Tensorboard('tensorboard-resource-id')
my_experiment = aiplatform.Experiment('my-experiment')
my_experiment.assign_backing_tensorboard(tb)
```
Args:
tensorboard (Union[aiplatform.Tensorboard, str]):
Required. Tensorboard resource or resource name to associate to this experiment.
Raises:
ValueError: If this experiment already has a previously set backing tensorboard resource.
ValueError: If Tensorboard is not in same project and location as this experiment.
"""
backing_tensorboard = self._lookup_backing_tensorboard()
if backing_tensorboard:
tensorboard_resource_name = (
tensorboard
if isinstance(tensorboard, str)
else tensorboard.resource_name
)
if tensorboard_resource_name != backing_tensorboard.resource_name:
raise ValueError(
f"Experiment {self._metadata_context.name} already associated '"
f"to tensorboard resource {backing_tensorboard.resource_name}"
)
if isinstance(tensorboard, str):
tensorboard = tensorboard_resource.Tensorboard(
tensorboard,
project=self._metadata_context.project,
location=self._metadata_context.location,
credentials=self._metadata_context.credentials,
)
if tensorboard.project not in self._metadata_context._project_tuple:
raise ValueError(
f"Tensorboard is in project {tensorboard.project} but must be in project {self._metadata_context.project}"
)
if tensorboard.location != self._metadata_context.location:
raise ValueError(
f"Tensorboard is in location {tensorboard.location} but must be in location {self._metadata_context.location}"
)
self._metadata_context.update(
metadata={
constants._BACKING_TENSORBOARD_RESOURCE_KEY: tensorboard.resource_name
},
location=self._metadata_context.location,
)
def _log_experiment_loggable(self, experiment_loggable: "_ExperimentLoggable"):
"""Associates a Vertex resource that can be logged to an Experiment as run of this experiment.
Args:
experiment_loggable (_ExperimentLoggable):
A Vertex Resource that can be logged to an Experiment directly.
"""
context = experiment_loggable._get_context()
self._metadata_context.add_context_children([context])
@property
def dashboard_url(self) -> Optional[str]:
"""Cloud console URL for this resource."""
url = f"https://console.cloud.google.com/vertex-ai/experiments/locations/{self._metadata_context.location}/experiments/{self._metadata_context.name}?project={self._metadata_context.project}"
return url
class _SetLoggerLevel:
"""Helper method to suppress logging."""
def __init__(self, module):
self._module = module
def __enter__(self):
logging.getLogger(self._module.__name__).setLevel(logging.WARNING)
def __exit__(self, exc_type, exc_value, traceback):
logging.getLogger(self._module.__name__).setLevel(logging.INFO)
class _VertexResourceWithMetadata(NamedTuple):
"""Represents a resource coupled with it's metadata representation"""
resource: base.VertexAiResourceNoun
metadata: Union[artifact.Artifact, execution.Execution, context.Context]
class _ExperimentLoggableSchema(NamedTuple):
"""Used with _ExperimentLoggable to capture Metadata representation information about resoure.
For example:
_ExperimentLoggableSchema(title='system.PipelineRun', type=context._Context)
Defines the schema and metadata type to lookup PipelineJobs.
"""
title: str
type: Union[Type[context.Context], Type[execution.Execution]] = context.Context
class _ExperimentLoggable(abc.ABC):
"""Abstract base class to define a Vertex Resource as loggable against an Experiment.
For example:
class PipelineJob(..., experiment_loggable_schemas=
(_ExperimentLoggableSchema(title='system.PipelineRun'), ))
"""
def __init_subclass__(
cls, *, experiment_loggable_schemas: Tuple[_ExperimentLoggableSchema], **kwargs
):
"""Register the metadata_schema for the subclass so Experiment can use it to retrieve the associated types.
usage:
class PipelineJob(..., experiment_loggable_schemas=
(_ExperimentLoggableSchema(title='system.PipelineRun'), ))
Args:
experiment_loggable_schemas:
Tuple of the schema_title and type pairs that represent this resource. Note that a single item in the
tuple will be the most common case. Currently only ExperimentRun has multiple representations, for backwards
compatibility. Almost all schemas should be Contexts; Execution is currently only supported
for backwards compatibility of experiment runs.
"""
super().__init_subclass__(**kwargs)
# register the type when module is loaded
for schema in experiment_loggable_schemas:
_SUPPORTED_LOGGABLE_RESOURCES[schema.type][schema.title] = cls
@abc.abstractmethod
def _get_context(self) -> context.Context:
"""Should return the metadata context that represents this resource.
The subclass should enforce this context exists.
Returns:
Context that represents this resource.
"""
pass
@classmethod
@abc.abstractmethod
def _query_experiment_row(
cls, node: Union[context.Context, execution.Execution]
) -> _ExperimentRow:
"""Should return parameters and metrics for this resource as a run row.
Args:
node: The metadata node that represents this resource.
Returns:
A populated run row for this resource.
"""
pass
def _validate_experiment(self, experiment: Union[str, Experiment]):
"""Validates experiment is accessible. Can be used by subclass to throw before creating the intended resource.
Args:
experiment (Union[str, Experiment]): The experiment that this resource will be associated to.
Raises:
RuntimeError: If service raises any exception when trying to access this experiment.
ValueError: If resource project or location do not match experiment project or location.
"""
if isinstance(experiment, str):
try:
experiment = Experiment.get_or_create(
experiment,
project=self.project,
location=self.location,
credentials=self.credentials,
)
except Exception as e:
raise RuntimeError(
f"Experiment {experiment} could not be found or created. {self.__class__.__name__} not created"
) from e
if self.project not in experiment._metadata_context._project_tuple:
raise ValueError(
f"{self.__class__.__name__} project {self.project} does not match experiment "
f"{experiment.name} project {experiment.project}"
)
if experiment._metadata_context.location != self.location:
raise ValueError(
f"{self.__class__.__name__} location {self.location} does not match experiment "
f"{experiment.name} location {experiment.location}"
)
def _associate_to_experiment(self, experiment: Union[str, Experiment]):
"""Associates this resource to the provided Experiment.
Args:
experiment (Union[str, Experiment]): Required. Experiment name or experiment instance.
Raises:
RuntimeError: If Metadata service cannot associate resource to Experiment.
"""
experiment_name = experiment if isinstance(experiment, str) else experiment.name
_LOGGER.info(
"Associating %s to Experiment: %s" % (self.resource_name, experiment_name)
)
try:
if isinstance(experiment, str):
experiment = Experiment.get_or_create(
experiment,
project=self.project,
location=self.location,
credentials=self.credentials,
)
experiment._log_experiment_loggable(self)
except Exception as e:
raise RuntimeError(
f"{self.resource_name} could not be associated with Experiment {experiment.name}"
) from e
# maps context names to their resources classes
# used by the Experiment implementation to filter for representations in the metadata store
# populated at module import time from class that inherit _ExperimentLoggable
# example mapping:
# {Metadata Type} -> {schema title} -> {vertex sdk class}
# Context -> 'system.PipelineRun' -> aiplatform.PipelineJob
# Context -> 'system.ExperimentRun' -> aiplatform.ExperimentRun
# Execution -> 'system.Run' -> aiplatform.ExperimentRun
_SUPPORTED_LOGGABLE_RESOURCES: Dict[
Union[Type[context.Context], Type[execution.Execution]],
Dict[str, _ExperimentLoggable],
] = {execution.Execution: dict(), context.Context: dict()}

File diff suppressed because it is too large

View File

@@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from typing import Optional
from google.api_core import exceptions
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base, initializer
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import metadata_store as gca_metadata_store
from google.cloud.aiplatform.constants import base as base_constants
class _MetadataStore(base.VertexAiResourceNounWithFutureManager):
"""Managed MetadataStore resource for Vertex AI"""
client_class = utils.MetadataClientWithOverride
_is_client_prediction_client = False
_resource_noun = "metadataStores"
_getter_method = "get_metadata_store"
_delete_method = "delete_metadata_store"
_parse_resource_name_method = "parse_metadata_store_path"
_format_resource_name_method = "metadata_store_path"
def __init__(
self,
metadata_store_name: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Retrieves an existing MetadataStore given a MetadataStore name or ID.
Args:
metadata_store_name (str):
Optional. A fully-qualified MetadataStore resource name or metadataStore ID.
Example: "projects/123/locations/us-central1/metadataStores/my-store" or
"my-store" when project and location are initialized or passed.
If not set, metadata_store_name will be set to "default".
project (str):
Optional project to retrieve resource from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional location to retrieve resource from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Custom credentials to use to retrieve this metadata store. Overrides
credentials set in aiplatform.init.
"""
super().__init__(
project=project,
location=location,
credentials=credentials,
)
self._gca_resource = self._get_gca_resource(resource_name=metadata_store_name)
@classmethod
def get_or_create(
cls,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
encryption_spec_key_name: Optional[str] = None,
) -> "_MetadataStore":
""" "Retrieves or Creates (if it does not exist) a Metadata Store.
Args:
metadata_store_id (str):
The <metadatastore> portion of the resource name with the format:
projects/123/locations/us-central1/metadataStores/<metadatastore>
If not provided, the MetadataStore's ID will be set to "default" to create a default MetadataStore.
project (str):
Project used to retrieve or create the metadata store. Overrides project set in
aiplatform.init.
location (str):
Location used to retrieve or create the metadata store. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to retrieve or create the metadata store. Overrides
credentials set in aiplatform.init.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the metadata store. Has the
form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, this MetadataStore and all sub-resources of this MetadataStore will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
Returns:
metadata_store (_MetadataStore):
Instantiated representation of the managed metadata store resource.
"""
store = cls._get(
metadata_store_name=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
if not store:
store = cls._create(
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
encryption_spec_key_name=encryption_spec_key_name,
)
return store
@classmethod
def _create(
cls,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
encryption_spec_key_name: Optional[str] = None,
) -> "_MetadataStore":
"""Creates a new MetadataStore if it does not exist.
Args:
metadata_store_id (str):
The <metadatastore> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadatastore>
If not provided, the MetadataStore's ID will be set to "default" to create a default MetadataStore.
project (str):
Project used to create the metadata store. Overrides project set in
aiplatform.init.
location (str):
Location used to create the metadata store. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to create the metadata store. Overrides
credentials set in aiplatform.init.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the metadata store. Has the
form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, this MetadataStore and all sub-resources of this MetadataStore will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
Returns:
metadata_store (_MetadataStore):
Instantiated representation of the managed metadata store resource.
"""
appended_user_agent = []
if base_constants.USER_AGENT_SDK_COMMAND:
appended_user_agent = [
f"sdk_command/{base_constants.USER_AGENT_SDK_COMMAND}"
]
# Reset the value for the USER_AGENT_SDK_COMMAND to avoid counting future unrelated api calls.
base_constants.USER_AGENT_SDK_COMMAND = ""
api_client = cls._instantiate_client(
location=location,
credentials=credentials,
appended_user_agent=appended_user_agent,
)
gapic_metadata_store = gca_metadata_store.MetadataStore(
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name,
select_version=compat.DEFAULT_VERSION,
)
)
try:
api_client.create_metadata_store(
parent=initializer.global_config.common_location_path(
project=project, location=location
),
metadata_store=gapic_metadata_store,
metadata_store_id=metadata_store_id,
).result()
except exceptions.AlreadyExists:
logging.info(f"MetadataStore '{metadata_store_id}' already exists")
return cls(
metadata_store_name=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
@classmethod
def _get(
cls,
metadata_store_name: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Optional["_MetadataStore"]:
"""Returns a MetadataStore resource.
Args:
metadata_store_name (str):
Optional. A fully-qualified MetadataStore resource name or metadataStore ID.
Example: "projects/123/locations/us-central1/metadataStores/my-store" or
"my-store" when project and location are initialized or passed.
If not set, metadata_store_name will be set to "default".
project (str):
Optional project to retrieve the metadata store from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional location to retrieve the metadata store from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Custom credentials to retrieve this metadata store. Overrides
credentials set in aiplatform.init.
Returns:
metadata_store (Optional[_MetadataStore]):
An optional instantiated representation of the managed Metadata Store resource.
"""
try:
return cls(
metadata_store_name=metadata_store_name,
project=project,
location=location,
credentials=credentials,
)
except exceptions.NotFound:
logging.info(f"MetadataStore {metadata_store_name} not found.")
@classmethod
def ensure_default_metadata_store_exists(
cls,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
encryption_key_spec_name: Optional[str] = None,
):
"""Helpers method to ensure the `default` MetadataStore exists in this project and location.
Args:
project (str):
Optional. Project to retrieve resource from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve resource from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to upload this model. Overrides
credentials set in aiplatform.init.
encryption_key_spec_name (str):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the metadata store. Has the
form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, this MetadataStore and all sub-resources of this MetadataStore will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
"""
cls.get_or_create(
project=project,
location=location,
credentials=credentials,
encryption_spec_key_name=encryption_key_spec_name,
)
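# Illustrative usage sketch (not part of this module): a caller would typically
# initialize the SDK and then make sure the default MetadataStore exists before
# creating any metadata resources. The project and location values below are
# placeholders.
#
#   from google.cloud import aiplatform
#   from google.cloud.aiplatform.metadata import metadata_store
#
#   aiplatform.init(project="my-project", location="us-central1")
#   metadata_store._MetadataStore.ensure_default_metadata_store_exists()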

View File

@@ -0,0 +1,569 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import collections
import re
import threading
from copy import deepcopy
from typing import Dict, Optional, Union, Any, List
import proto
from google.api_core import exceptions
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base, initializer
from google.cloud.aiplatform import metadata
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.compat.types import context as gca_context
from google.cloud.aiplatform.compat.types import execution as gca_execution
_LOGGER = base.Logger(__name__)
class _Resource(base.VertexAiResourceNounWithFutureManager, abc.ABC):
"""Metadata Resource for Vertex AI"""
client_class = utils.MetadataClientWithOverride
_delete_method = None
def __init__(
self,
resource_name: Optional[str] = None,
resource: Optional[
Union[gca_context.Context, gca_artifact.Artifact, gca_execution.Execution]
] = None,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Retrieves an existing Metadata resource given a resource name or ID.
Args:
resource_name (str):
A fully-qualified resource name or ID
Example: "projects/123/locations/us-central1/metadataStores/default/<resource_noun>/my-resource".
or "my-resource" when project and location are initialized or passed. if ``resource`` is provided, this
should not be set.
resource (Union[gca_context.Context, gca_artifact.Artifact, gca_execution.Execution]):
The proto.Message that contains the full information of the resource. If both set, this field overrides
``resource_name`` field.
metadata_store_id (str):
MetadataStore to retrieve resource from. If not set, metadata_store_id is set to "default".
If resource_name is a fully-qualified resource, its metadata_store_id overrides this one.
project (str):
Optional project to retrieve the resource from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional location to retrieve the resource from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Custom credentials to use to retrieve this resource. Overrides
credentials set in aiplatform.init.
"""
super().__init__(
project=project,
location=location,
credentials=credentials,
)
if resource:
self._gca_resource = resource
else:
full_resource_name = utils.full_resource_name(
resource_name=resource_name,
resource_noun=self._resource_noun,
parse_resource_name_method=self._parse_resource_name,
format_resource_name_method=self._format_resource_name,
parent_resource_name_fields={
metadata.metadata_store._MetadataStore._resource_noun: metadata_store_id
},
project=self.project,
location=self.location,
)
self._gca_resource = getattr(self.api_client, self._getter_method)(
name=full_resource_name, retry=base._DEFAULT_RETRY
)
self._threading_lock = threading.Lock()
@property
def metadata(self) -> Dict:
return self.to_dict()["metadata"]
@property
def schema_title(self) -> str:
return self._gca_resource.schema_title
@property
def description(self) -> str:
return self._gca_resource.description
@property
def display_name(self) -> str:
return self._gca_resource.display_name
@property
def schema_version(self) -> str:
return self._gca_resource.schema_version
@classmethod
def get_or_create(
cls,
resource_id: str,
schema_title: str,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "_Resource":
"""Retrieves or Creates (if it does not exist) a Metadata resource.
Args:
resource_id (str):
Required. The <resource_id> portion of the resource name with the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
schema_title (str):
Required. schema_title identifies the schema title used by the resource.
display_name (str):
Optional. The user-defined name of the resource.
schema_version (str):
Optional. schema_version specifies the version used by the resource.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the resource to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the resource.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to retrieve or create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to retrieve or create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to retrieve or create this resource. Overrides
credentials set in aiplatform.init.
Returns:
resource (_Resource):
Instantiated representation of the managed Metadata resource.
"""
resource = cls._get(
resource_name=resource_id,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
if not resource:
_LOGGER.info(f"Creating Resource {resource_id}")
resource = cls._create(
resource_id=resource_id,
schema_title=schema_title,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=metadata,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
return resource
@classmethod
def get(
cls,
resource_id: str,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "_Resource":
"""Retrieves a Metadata resource.
Args:
resource_id (str):
Required. The <resource_id> portion of the resource name with the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to retrieve or create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to retrieve or create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to retrieve or create this resource. Overrides
credentials set in aiplatform.init.
Returns:
resource (_Resource):
Instantiated representation of the managed Metadata resource or None if no resource was found.
"""
resource = cls._get(
resource_name=resource_id,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
return resource
def sync_resource(self):
"""Syncs local resource with the resource in metadata store."""
self._gca_resource = getattr(self.api_client, self._getter_method)(
name=self.resource_name, retry=base._DEFAULT_RETRY
)
@staticmethod
def _nested_update_metadata(
gca_resource: Union[
gca_context.Context, gca_execution.Execution, gca_artifact.Artifact
],
metadata: Optional[Dict[str, Any]] = None,
):
"""Helper method to update gca_resource in place.
Performs a one-level deep nested update on the metadata field.
Args:
gca_resource (Union[gca_context.Context, gca_execution.Execution, gca_artifact.Artifact]):
Required. Metadata Protobuf resource. This proto's metadata will be
updated in place.
metadata (Dict[str, Any]):
Optional. Metadata dictionary to merge into gca_resource.metadata.
"""
if metadata:
if gca_resource.metadata:
for key, value in metadata.items():
# Note: This only supports nested dictionaries one level deep
if isinstance(value, collections.abc.Mapping):
gca_resource.metadata[key].update(value)
else:
gca_resource.metadata[key] = value
else:
gca_resource.metadata = metadata
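# Worked example of the one-level merge above (values are illustrative):
#   gca_resource.metadata == {"framework": {"name": "tf"}, "epochs": 5}
#   metadata              == {"framework": {"version": "2.8"}, "epochs": 10}
#   result                == {"framework": {"name": "tf", "version": "2.8"}, "epochs": 10}
# Nested mappings are merged one level deep; scalar values are overwritten in place.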
def update(
self,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
location: Optional[str] = None,
):
"""Updates an existing Metadata resource with new metadata.
Args:
metadata (Dict):
Optional. metadata contains the updated metadata information.
description (str):
Optional. Description describes the resource to be updated.
credentials (auth_credentials.Credentials):
Custom credentials to use to update this resource. Overrides
credentials set in aiplatform.init.
location (str):
Optional. Location of the resource to be updated. Overrides location set in
aiplatform.init.
"""
if not hasattr(self, "_threading_lock"):
self._threading_lock = threading.Lock()
with self._threading_lock:
gca_resource = deepcopy(self._gca_resource)
if metadata:
self._nested_update_metadata(
gca_resource=gca_resource, metadata=metadata
)
if description:
gca_resource.description = description
api_client = self._instantiate_client(
credentials=credentials, location=location
)
# TODO: if etag is not valid sync and retry
update_gca_resource = self._update_resource(
client=api_client,
resource=gca_resource,
)
self._gca_resource = update_gca_resource
@classmethod
def list(
cls,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
order_by: Optional[str] = None,
) -> List["_Resource"]:
"""List resources that match the list filter in target metadataStore.
Args:
filter (str):
Optional. A query to filter available resources for
matching results.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to create this resource. Overrides
credentials set in aiplatform.init.
order_by (str):
Optional. How the list of messages is ordered.
Specify the values to order by and an ordering operation. The
default sorting order is ascending. To specify descending order
for a field, users append a " desc" suffix; for example: "foo
desc, bar". Subfields are specified with a ``.`` character, such
as foo.bar. see https://google.aip.dev/132#ordering for more
details.
Returns:
resources (sequence[_Resource]):
A list of managed Metadata resources.
"""
parent = (
initializer.global_config.common_location_path(
project=project, location=location
)
+ f"/metadataStores/{metadata_store_id}"
)
return super().list(
filter=filter,
project=project,
location=location,
credentials=credentials,
parent=parent,
order_by=order_by,
)
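# Sketch of a typical list call on a concrete subclass such as aiplatform.Artifact
# (the filter string and ordering below are illustrative):
#
#   aiplatform.Artifact.list(
#       filter='schema_title="system.Model"',
#       order_by="create_time desc",
#   )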
@classmethod
def _create(
cls,
resource_id: str,
schema_title: str,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
metadata_store_id: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Optional["_Resource"]:
"""Creates a new Metadata resource.
Args:
resource_id (str):
Required. The <resource_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
schema_title (str):
Required. schema_title identifies the schema title used by the resource.
display_name (str):
Optional. The user-defined name of the resource.
schema_version (str):
Optional. schema_version specifies the version used by the resource.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the resource to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the resource.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to create this resource. Overrides
credentials set in aiplatform.init.
Returns:
resource (_Resource):
Instantiated representation of the managed Metadata resource.
"""
api_client = cls._instantiate_client(location=location, credentials=credentials)
parent = (
initializer.global_config.common_location_path(
project=project, location=location
)
+ f"/metadataStores/{metadata_store_id}"
)
try:
resource = cls._create_resource(
client=api_client,
parent=parent,
resource_id=resource_id,
schema_title=schema_title,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=metadata,
)
except exceptions.AlreadyExists:
_LOGGER.info(f"Resource '{resource_id}' already exist")
return
self = cls._empty_constructor(
project=project,
location=location,
credentials=credentials,
)
self._gca_resource = resource
self._threading_lock = threading.Lock()
return self
@classmethod
def _get(
cls,
resource_name: str,
metadata_store_id: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> Optional["_Resource"]:
"""Returns a metadata Resource.
Args:
resource_name (str):
A fully-qualified resource name or resource ID
Example: "projects/123/locations/us-central1/metadataStores/default/<resource_noun>/my-resource".
or "my-resource" when project and location are initialized or passed.
metadata_store_id (str):
The metadata_store_id portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/my-resource
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project to get this resource from. Overrides project set in
aiplatform.init.
location (str):
Location to get this resource from. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials to use to get this resource. Overrides
credentials set in aiplatform.init.
Returns:
resource (Optional[_Resource]):
An optional instantiated representation of the managed Metadata resource.
"""
try:
return cls(
resource_name,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
except exceptions.NotFound:
_LOGGER.info(f"Resource {resource_name} not found.")
@classmethod
@abc.abstractmethod
def _create_resource(
cls,
client: utils.MetadataClientWithOverride,
parent: str,
resource_id: str,
schema_title: str,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
) -> proto.Message:
"""Create resource method."""
pass
@classmethod
@abc.abstractmethod
def _update_resource(
cls,
client: utils.MetadataClientWithOverride,
resource: proto.Message,
) -> proto.Message:
"""Update resource method."""
pass
@staticmethod
def _extract_metadata_store_id(resource_name, resource_noun) -> str:
"""Extracts the metadata store id from the resource name.
Args:
resource_name (str):
Required. A fully-qualified metadata resource name. For example
projects/{project}/locations/{location}/metadataStores/{metadata_store_id}/{resource_noun}/{resource_id}.
resource_noun (str):
Required. The resource_noun portion of the resource_name
Returns:
metadata_store_id (str):
The metadata store id for the particular resource name.
Raises:
ValueError: If it does not exist.
"""
pattern = re.compile(
r"^projects\/(?P<project>[\w-]+)\/locations\/(?P<location>[\w-]+)\/metadataStores\/(?P<store>[\w-]+)\/"
+ resource_noun
+ r"\/(?P<id>[\w-]+)(?P<version>@[\w-]+)?$"
)
match = pattern.match(resource_name)
if not match:
raise ValueError(
f"failed to extract metadata_store_id from resource {resource_name}"
)
return match["store"]

View File

@@ -0,0 +1,310 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
from typing import Any, Optional, Dict, List
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.metadata import artifact
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import constants
class BaseArtifactSchema(artifact.Artifact):
"""Base class for Metadata Artifact types."""
@property
@classmethod
@abc.abstractmethod
def schema_title(cls) -> str:
"""Identifies the Vertex Metadata schema title used by the resource."""
pass
def __init__(
self,
*,
artifact_id: Optional[str] = None,
uri: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Initializes the Artifact with the given name, URI and metadata.
This is the base class for defining various artifact types, which can be
passed to google.Artifact to create a corresponding resource.
Artifacts carry a `metadata` field, which is a dictionary for storing
metadata related to this artifact. Subclasses of BaseArtifactSchema can enforce
various structure and field requirements for the metadata field.
Args:
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the following format, this is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
# initialize the exception to resolve the FutureManager exception.
self._exception = None
# resource_id is not stored in the proto. Create method uses the
# resource_id along with project_id and location to construct a
# resource_name which is stored in the proto message.
self.artifact_id = artifact_id
# Store all other attributes using the proto structure.
self._gca_resource = gca_artifact.Artifact()
self._gca_resource.uri = uri
self._gca_resource.display_name = display_name
self._gca_resource.schema_version = (
schema_version or constants._DEFAULT_SCHEMA_VERSION
)
self._gca_resource.description = description
# If metadata is None, convert it to {}
metadata = metadata if metadata else {}
self._nested_update_metadata(self._gca_resource, metadata)
self._gca_resource.state = state
# TODO() Switch to @singledispatchmethod constructor overload after py>=3.8
def _init_with_resource_name(
self,
*,
artifact_name: str,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Initializes the Artifact instance using an existing resource.
Args:
artifact_name (str):
Artifact name with the following format, this is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
metadata_store_id (str):
Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default".
If artifact_name is a fully-qualified resource, its metadata_store_id overrides this one.
project (str):
Optional. Project to retrieve the artifact from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve the Artifact from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve this Artifact. Overrides
credentials set in aiplatform.init.
"""
# Add User Agent Header for metrics tracking if one is not specified
# If one is already specified this call was initiated by a sub class.
if not base_constants.USER_AGENT_SDK_COMMAND:
base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_artifact.BaseArtifactSchema._init_with_resource_name"
super(BaseArtifactSchema, self).__init__(
artifact_name=artifact_name,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
def create(
self,
*,
metadata_store_id: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "artifact.Artifact":
"""Creates a new Metadata Artifact.
Args:
metadata_store_id (str):
Optional. The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Optional. Project used to create this Artifact. Overrides project set in
aiplatform.init.
location (str):
Optional. Location used to create this Artifact. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this Artifact. Overrides
credentials set in aiplatform.init.
Returns:
Artifact: Instantiated representation of the managed Metadata Artifact.
"""
# Add User Agent Header for metrics tracking.
base_constants.USER_AGENT_SDK_COMMAND = (
"aiplatform.metadata.schema.base_artifact.BaseArtifactSchema.create"
)
# Check if metadata exists to avoid proto read error
metadata = None
if self._gca_resource.metadata:
metadata = self.metadata
new_artifact_instance = artifact.Artifact.create(
resource_id=self.artifact_id,
schema_title=self.schema_title,
uri=self.uri,
display_name=self.display_name,
schema_version=self.schema_version,
description=self.description,
metadata=metadata,
state=self.state,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
# Reinstantiate this class using the newly created resource.
self._init_with_resource_name(artifact_name=new_artifact_instance.resource_name)
return self
@classmethod
def list(
cls,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
order_by: Optional[str] = None,
) -> List["BaseArtifactSchema"]:
"""List all the Artifact resources with a particular schema.
Args:
filter (str):
Optional. A query to filter available resources for
matching results.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to create this resource. Overrides
credentials set in aiplatform.init.
order_by (str):
Optional. How the list of messages is ordered.
Specify the values to order by and an ordering operation. The
default sorting order is ascending. To specify descending order
for a field, users append a " desc" suffix; for example: "foo
desc, bar". Subfields are specified with a ``.`` character, such
as foo.bar. see https://google.aip.dev/132#ordering for more
details.
Returns:
A list of artifact resources with a particular schema.
"""
schema_filter = f'schema_title="{cls.schema_title}"'
if filter:
filter = f"{filter} AND {schema_filter}"
else:
filter = schema_filter
return super().list(
filter=filter,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
def sync_resource(self):
"""Syncs local resource with the resource in metadata store.
Raises:
RuntimeError: if the artifact resource hasn't been created.
"""
if self._gca_resource.name:
super().sync_resource()
else:
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
)
def update(
self,
metadata: Optional[Dict[str, Any]] = None,
description: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
"""Updates an existing Artifact resource with new metadata.
Args:
metadata (Dict):
Optional. metadata contains the updated metadata information.
description (str):
Optional. Description describes the resource to be updated.
credentials (auth_credentials.Credentials):
Custom credentials to use to update this resource. Overrides
credentials set in aiplatform.init.
Raises:
RuntimeError: if the artifact resource hasn't been created.
"""
if self._gca_resource.name:
super().update(
metadata=metadata,
description=description,
credentials=credentials,
)
else:
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
)
def __repr__(self) -> str:
if self._gca_resource.name:
return super().__repr__()
else:
return f"{object.__repr__(self)}\nschema_title: {self.schema_title}"

View File

@@ -0,0 +1,285 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
from typing import Dict, List, Optional, Sequence
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform.compat.types import context as gca_context
from google.cloud.aiplatform.compat.types import (
lineage_subgraph as gca_lineage_subgraph,
)
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import constants
from google.cloud.aiplatform.metadata import context
class BaseContextSchema(context.Context):
"""Base class for Metadata Context schema."""
@property
@classmethod
@abc.abstractmethod
def schema_title(cls) -> str:
"""Identifies the Vertex Metadta schema title used by the resource."""
pass
def __init__(
self,
*,
context_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
):
"""Initializes the Context with the given name, URI and metadata.
Args:
context_id (str):
Optional. The <resource_id> portion of the Context name with
the following format, this is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
display_name (str):
Optional. The user-defined name of the Context.
schema_version (str):
Optional. schema_version specifies the version used by the Context.
If not set, defaults to use the latest version.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Context.
description (str):
Optional. Describes the purpose of the Context to be created.
"""
# initialize the exception to resolve the FutureManager exception.
self._exception = None
# resource_id is not stored in the proto. Create method uses the
# resource_id along with project_id and location to construct a
# resource_name which is stored in the proto message.
self.context_id = context_id
# Store all other attributes using the proto structure.
self._gca_resource = gca_context.Context()
self._gca_resource.display_name = display_name
self._gca_resource.schema_version = (
schema_version or constants._DEFAULT_SCHEMA_VERSION
)
# If metadata is None, convert it to {}
metadata = metadata if metadata else {}
self._nested_update_metadata(self._gca_resource, metadata)
self._gca_resource.description = description
# TODO() Switch to @singledispatchmethod constructor overload after py>=3.8
def _init_with_resource_name(
self,
*,
context_name: str,
):
"""Initializes the Artifact instance using an existing resource.
Args:
context_name (str):
Context name with the following format, this is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
"""
# Add User Agent Header for metrics tracking if one is not specified
# If one is already specified this call was initiated by a sub class.
if not base_constants.USER_AGENT_SDK_COMMAND:
base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_context.BaseContextSchema._init_with_resource_name"
super(BaseContextSchema, self).__init__(resource_name=context_name)
def create(
self,
*,
metadata_store_id: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "context.Context":
"""Creates a new Metadata Context.
Args:
metadata_store_id (str):
Optional. The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Optional. Project used to create this Context. Overrides project set in
aiplatform.init.
location (str):
Optional. Location used to create this Context. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this Context. Overrides
credentials set in aiplatform.init.
Returns:
Context: Instantiated representation of the managed Metadata Context.
"""
# Add User Agent Header for metrics tracking.
base_constants.USER_AGENT_SDK_COMMAND = (
"aiplatform.metadata.schema.base_context.BaseContextSchema.create"
)
# Check if metadata exists to avoid proto read error
metadata = None
if self._gca_resource.metadata:
metadata = self.metadata
new_context = context.Context.create(
resource_id=self.context_id,
schema_title=self.schema_title,
display_name=self.display_name,
schema_version=self.schema_version,
description=self.description,
metadata=metadata,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
# Reinstantiate this class using the newly created resource.
self._init_with_resource_name(context_name=new_context.resource_name)
return self
@classmethod
def list(
cls,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
order_by: Optional[str] = None,
) -> List["BaseContextSchema"]:
"""List all the Context resources with a particular schema.
Args:
filter (str):
Optional. A query to filter available resources for
matching results.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to create this resource. Overrides
credentials set in aiplatform.init.
order_by (str):
Optional. How the list of messages is ordered.
Specify the values to order by and an ordering operation. The
default sorting order is ascending. To specify descending order
for a field, users append a " desc" suffix; for example: "foo
desc, bar". Subfields are specified with a ``.`` character, such
as foo.bar. see https://google.aip.dev/132#ordering for more
details.
Returns:
A list of context resources with a particular schema.
"""
schema_filter = f'schema_title="{cls.schema_title}"'
if filter:
filter = f"{filter} AND {schema_filter}"
else:
filter = schema_filter
return super().list(
filter=filter,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
def add_artifacts_and_executions(
self,
artifact_resource_names: Optional[Sequence[str]] = None,
execution_resource_names: Optional[Sequence[str]] = None,
):
"""Associate Executions and attribute Artifacts to a given Context.
Args:
artifact_resource_names (Sequence[str]):
Optional. The full resource name of Artifacts to attribute to
the Context.
execution_resource_names (Sequence[str]):
Optional. The full resource name of Executions to associate with
the Context.
Raises:
RuntimeError: if Context resource hasn't been created.
"""
if self._gca_resource.name:
super().add_artifacts_and_executions(
artifact_resource_names=artifact_resource_names,
execution_resource_names=execution_resource_names,
)
else:
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
)
def add_context_children(self, contexts: List[context.Context]):
"""Adds the provided contexts as children of this context.
Args:
contexts (List[context.Context]): Contexts to add as children.
Raises:
RuntimeError: if Context resource hasn't been created.
"""
if self._gca_resource.name:
super().add_context_children(contexts)
else:
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
)
def query_lineage_subgraph(self) -> gca_lineage_subgraph.LineageSubgraph:
"""Queries lineage subgraph of this context.
Returns:
lineage subgraph(gca_lineage_subgraph.LineageSubgraph):
Lineage subgraph of this Context.
Raises:
RuntimeError: if Context resource hasn't been created.
"""
if self._gca_resource.name:
return super().query_lineage_subgraph()
else:
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
)
def __repr__(self) -> str:
if self._gca_resource.name:
return super().__repr__()
else:
return f"{object.__repr__(self)}\nschema_title: {self.schema_title}"

View File

@@ -0,0 +1,418 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
from typing import Any, Dict, List, Optional, Union
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import models
from google.cloud.aiplatform.compat.types import execution as gca_execution
from google.cloud.aiplatform.constants import base as base_constants
from google.cloud.aiplatform.metadata import artifact
from google.cloud.aiplatform.metadata import constants
from google.cloud.aiplatform.metadata import execution
from google.cloud.aiplatform.metadata import metadata
class BaseExecutionSchema(execution.Execution):
"""Base class for Metadata Execution schema."""
@property
@classmethod
@abc.abstractmethod
def schema_title(cls) -> str:
"""Identifies the Vertex Metadta schema title used by the resource."""
pass
def __init__(
self,
*,
state: Optional[
gca_execution.Execution.State
] = gca_execution.Execution.State.RUNNING,
execution_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
):
"""Initializes the Execution with the given name, URI and metadata.
Args:
state (gca_execution.Execution.State):
Optional. State of this Execution. Defaults to RUNNING.
execution_id (str):
Optional. The <resource_id> portion of the Execution name with
the following format, this is globally unique in a metadataStore.
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
display_name (str):
Optional. The user-defined name of the Execution.
schema_version (str):
Optional. schema_version specifies the version used by the Execution.
If not set, defaults to use the latest version.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Execution.
description (str):
Optional. Describes the purpose of the Execution to be created.
"""
# initialize the exception to resolve the FutureManager exception.
self._exception = None
# resource_id is not stored in the proto. Create method uses the
# resource_id along with project_id and location to construct a
# resource_name which is stored in the proto message.
self.execution_id = execution_id
# Store all other attributes using the proto structure.
self._gca_resource = gca_execution.Execution()
self._gca_resource.state = state
self._gca_resource.display_name = display_name
self._gca_resource.schema_version = (
schema_version or constants._DEFAULT_SCHEMA_VERSION
)
# If metadata is None, convert it to {}
metadata = metadata if metadata else {}
self._nested_update_metadata(self._gca_resource, metadata)
self._gca_resource.description = description
# TODO() Switch to @singledispatchmethod constructor overload after py>=3.8
def _init_with_resource_name(
self,
*,
execution_name: str,
):
"""Initializes the Execution instance using an existing resource.
Args:
execution_name (str):
The Execution name with the following format, this is globally unique in a metadataStore.
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
"""
# Add User Agent Header for metrics tracking if one is not specified
# If one is already specified this call was initiated by a sub class.
if not base_constants.USER_AGENT_SDK_COMMAND:
base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_execution.BaseExecutionSchema._init_with_resource_name"
super(BaseExecutionSchema, self).__init__(execution_name=execution_name)
def create(
self,
*,
metadata_store_id: Optional[str] = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "execution.Execution":
"""Creates a new Metadata Execution.
Args:
metadata_store_id (str):
Optional. The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Optional. Project used to create this Execution. Overrides project set in
aiplatform.init.
location (str):
Optional. Location used to create this Execution. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this Execution. Overrides
credentials set in aiplatform.init.
Returns:
Execution: Instantiated representation of the managed Metadata Execution.
"""
# Add User Agent Header for metrics tracking.
base_constants.USER_AGENT_SDK_COMMAND = (
"aiplatform.metadata.schema.base_execution.BaseExecutionSchema.create"
)
# Check if metadata exists to avoid proto read error
metadata = None
if self._gca_resource.metadata:
metadata = self.metadata
new_execution_instance = execution.Execution.create(
resource_id=self.execution_id,
schema_title=self.schema_title,
display_name=self.display_name,
schema_version=self.schema_version,
description=self.description,
metadata=metadata,
state=self.state,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
# Reinstantiate this class using the newly created resource.
self._init_with_resource_name(
execution_name=new_execution_instance.resource_name
)
return self
@classmethod
def list(
cls,
filter: Optional[str] = None, # pylint: disable=redefined-builtin
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
order_by: Optional[str] = None,
) -> List["BaseExecutionSchema"]:
"""List all the Execution resources with a particular schema.
Args:
filter (str):
Optional. A query to filter available resources for
matching results.
metadata_store_id (str):
The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
If not provided, the MetadataStore's ID will be set to "default".
project (str):
Project used to create this resource. Overrides project set in
aiplatform.init.
location (str):
Location used to create this resource. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Custom credentials used to create this resource. Overrides
credentials set in aiplatform.init.
order_by (str):
Optional. How the list of messages is ordered.
Specify the values to order by and an ordering operation. The
default sorting order is ascending. To specify descending order
for a field, users append a " desc" suffix; for example: "foo
desc, bar". Subfields are specified with a ``.`` character, such
as foo.bar. see https://google.aip.dev/132#ordering for more
details.
Returns:
A list of execution resources with a particular schema.
"""
schema_filter = f'schema_title="{cls.schema_title}"'
if filter:
filter = f"{filter} AND {schema_filter}"
else:
filter = schema_filter
return super().list(
filter=filter,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
def start_execution(
self,
*,
metadata_store_id: Optional[str] = "default",
resume: bool = False,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "execution.Execution":
"""Create and starts a new Metadata Execution or resumes a previously created Execution.
This method is similar to create_execution with additional support for Experiments.
If an Experiment is set prior to running this command, the Experiment will be
associtaed with the created execution, otherwise this method behaves the same
as create_execution.
To start a new execution:
```
instance_of_execution_schema = execution_schema.ContainerExecution(...)
with instance_of_execution_schema.start_execution() as exc:
exc.assign_input_artifacts([my_artifact])
model = aiplatform.Artifact.create(uri='gs://my-uri', schema_title='system.Model')
exc.assign_output_artifacts([model])
```
To continue a previously created execution:
```
with execution_schema.ContainerExecution(execution_id='my-exc').start_execution(resume=True) as exc:
...
```
Args:
metadata_store_id (str):
Optional. The <metadata_store_id> portion of the resource name with
the format:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<execution_id>
If not provided, the MetadataStore's ID will be set to "default". Currently only the 'default'
MetadataStore ID is supported.
resume (bool):
Resume an existing execution.
project (str):
Optional. Project used to create this Execution. Overrides project set in
aiplatform.init.
location (str):
Optional. Location used to create this Execution. Overrides location set in
aiplatform.init.
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this Execution. Overrides
credentials set in aiplatform.init.
Returns:
Execution: Instantiated representation of the managed Metadata Execution.
Raises:
ValueError: If metadata_store_id other than 'default' is provided.
"""
# Add User Agent Header for metrics tracking.
base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_execution.BaseExecutionSchema.start_execution"
if metadata_store_id != "default":
raise ValueError(
f"metadata_store_id {metadata_store_id} is not supported. Only the default MetadataStore ID is supported."
)
new_execution_instance = metadata._ExperimentTracker().start_execution(
schema_title=self.schema_title,
display_name=self.display_name,
resource_id=self.execution_id,
metadata=self.metadata,
schema_version=self.schema_version,
description=self.description,
# TODO: Add support for metadata_store_id once it is supported in experiment.
resume=resume,
project=project,
location=location,
credentials=credentials,
)
# Reinstantiate this class using the newly created resource.
self._init_with_resource_name(
execution_name=new_execution_instance.resource_name
)
return self
def assign_input_artifacts(
self, artifacts: List[Union[artifact.Artifact, models.Model]]
):
"""Assigns Artifacts as inputs to this Executions.
Args:
artifacts (List[Union[artifact.Artifact, models.Model]]):
Required. Artifacts to assign as input.
Raises:
RuntimeError: if Execution resource hasn't been created.
"""
if self._gca_resource.name:
super().assign_input_artifacts(artifacts)
else:
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
)
def assign_output_artifacts(
self, artifacts: List[Union[artifact.Artifact, models.Model]]
):
"""Assigns Artifacts as outputs to this Executions.
Args:
artifacts (List[Union[artifact.Artifact, models.Model]]):
Required. Artifacts to assign as input.
Raises:
RuntimeError: if Execution resource hasn't been created.
"""
if self._gca_resource.name:
super().assign_output_artifacts(artifacts)
else:
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
)
def get_input_artifacts(self) -> List[artifact.Artifact]:
"""Get the input Artifacts of this Execution.
Returns:
List of input Artifacts.
Raises:
RuntimeError: if Execution resource hasn't been created.
"""
if self._gca_resource.name:
return super().get_input_artifacts()
else:
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
)
def get_output_artifacts(self) -> List[artifact.Artifact]:
"""Get the output Artifacts of this Execution.
Returns:
List of output Artifacts.
Raises:
RuntimeError: if Execution resource hasn't been created.
"""
if self._gca_resource.name:
return super().get_output_artifacts()
else:
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
)
def update(
self,
state: Optional[gca_execution.Execution.State] = None,
description: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
):
"""Update this Execution.
Args:
state (gca_execution.Execution.State):
Optional. State of this Execution.
description (str):
Optional. Describes the purpose of the Execution to be updated.
metadata (Dict[str, Any]):
Optional. Contains the metadata information that will be stored
in the Execution.
Raises:
RuntimeError: if Execution resource hasn't been created.
"""
if self._gca_resource.name:
super().update(
state=state,
description=description,
metadata=metadata,
)
else:
raise RuntimeError(
f"{self.__class__.__name__} resource has not been created."
)
def __repr__(self) -> str:
if self._gca_resource.name:
return super().__repr__()
else:
return f"{object.__repr__(self)}\nschema_title: {self.schema_title}"

View File

@@ -0,0 +1,265 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from typing import Optional, Dict
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.metadata.schema import base_artifact
class Model(base_artifact.BaseArtifactSchema):
"""Artifact type for model."""
schema_title = "system.Model"
def __init__(
self,
*,
uri: Optional[str] = None,
artifact_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
super(Model, self).__init__(
uri=uri,
artifact_id=artifact_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
state=state,
)
class Artifact(base_artifact.BaseArtifactSchema):
"""A generic artifact."""
schema_title = "system.Artifact"
def __init__(
self,
*,
uri: Optional[str] = None,
artifact_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
super(Artifact, self).__init__(
uri=uri,
artifact_id=artifact_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
state=state,
)
class Dataset(base_artifact.BaseArtifactSchema):
"""An artifact representing a system Dataset."""
schema_title = "system.Dataset"
def __init__(
self,
*,
uri: Optional[str] = None,
artifact_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the format. This is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
super(Dataset, self).__init__(
uri=uri,
artifact_id=artifact_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
state=state,
)
class Metrics(base_artifact.BaseArtifactSchema):
"""Artifact schema for scalar metrics."""
schema_title = "system.Metrics"
def __init__(
self,
*,
accuracy: Optional[float] = None,
precision: Optional[float] = None,
recall: Optional[float] = None,
f1score: Optional[float] = None,
mean_absolute_error: Optional[float] = None,
mean_squared_error: Optional[float] = None,
uri: Optional[str] = None,
artifact_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
):
"""Args:
accuracy (float):
Optional. The accuracy metric value.
precision (float):
Optional. The precision metric value.
recall (float):
Optional. The recall metric value.
f1score (float):
Optional. The F1 score (harmonic mean of precision and recall).
mean_absolute_error (float):
Optional. The mean absolute error.
mean_squared_error (float):
Optional. The mean squared error.
uri (str):
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
artifact file.
artifact_id (str):
Optional. The <resource_id> portion of the Artifact name with
the following format; this is globally unique in a metadataStore:
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
display_name (str):
Optional. The user-defined name of the Artifact.
schema_version (str):
Optional. schema_version specifies the version used by the Artifact.
If not set, defaults to use the latest version.
description (str):
Optional. Describes the purpose of the Artifact to be created.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Artifact.
state (google.cloud.gapic.types.Artifact.State):
Optional. The state of this Artifact. This is a
property of the Artifact, and does not imply or
capture any ongoing process. This property is
managed by clients (such as Vertex AI
Pipelines), and the system does not prescribe or
check the validity of state transitions.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
if accuracy is not None:
extended_metadata["accuracy"] = accuracy
if precision is not None:
extended_metadata["precision"] = precision
if recall is not None:
extended_metadata["recall"] = recall
if f1score is not None:
extended_metadata["f1score"] = f1score
if mean_absolute_error is not None:
extended_metadata["mean_absolute_error"] = mean_absolute_error
if mean_squared_error is not None:
extended_metadata["mean_squared_error"] = mean_squared_error
super(Metrics, self).__init__(
uri=uri,
artifact_id=artifact_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
state=state,
)
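def _example_metrics_artifact() -> Metrics:
    # A minimal, illustrative sketch only: the values and the metadata key below
    # are made up. The scalar keyword arguments are folded into the artifact's
    # metadata dict alongside any user-supplied metadata; persisting the artifact
    # in Vertex AI (e.g. via a create() call on the returned schema object) is
    # assumed to be handled elsewhere by the SDK.
    return Metrics(
        accuracy=0.92,
        f1score=0.89,
        display_name="eval-metrics",
        description="Held-out evaluation metrics for the latest training run",
        metadata={"dataset": "validation-split"},
    )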

View File

@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from typing import Optional, Dict
from google.cloud.aiplatform.metadata.schema import base_context
class Experiment(base_context.BaseContextSchema):
"""Context schema for a Experiment context."""
schema_title = "system.Experiment"
def __init__(
self,
*,
context_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
):
"""Args:
context_id (str):
Optional. The <resource_id> portion of the context name with
the following format, this is globally unique in a metadataStore.
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
display_name (str):
Optional. The user-defined name of the context.
schema_version (str):
Optional. schema_version specifies the version used by the context.
If not set, defaults to use the latest version.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the context.
description (str):
Optional. Describes the purpose of the context to be created.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
super(Experiment, self).__init__(
context_id=context_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
)
class ExperimentRun(base_context.BaseContextSchema):
"""Context schema for a ExperimentRun context."""
schema_title = "system.ExperimentRun"
def __init__(
self,
*,
experiment_id: Optional[str] = None,
context_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
):
"""Args:
experiment_id (str):
Optional. The experiment_id that this experiment_run belongs to.
context_id (str):
Optional. The <resource_id> portion of the context name with
the following format, this is globally unique in a metadataStore.
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
display_name (str):
Optional. The user-defined name of the context.
schema_version (str):
Optional. schema_version specifies the version used by the context.
If not set, defaults to use the latest version.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the context.
description (str):
Optional. Describes the purpose of the context to be created.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
extended_metadata["experiment_id"] = experiment_id
super(ExperimentRun, self).__init__(
context_id=context_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
)
class Pipeline(base_context.BaseContextSchema):
"""Context schema for a Pipeline context."""
schema_title = "system.Pipeline"
def __init__(
self,
*,
context_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
):
"""Args:
context_id (str):
Optional. The <resource_id> portion of the context name with
the following format, this is globally unique in a metadataStore.
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
display_name (str):
Optional. The user-defined name of the context.
schema_version (str):
Optional. schema_version specifies the version used by the context.
If not set, defaults to use the latest version.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the context.
description (str):
Optional. Describes the purpose of the context to be created.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
super(Pipeline, self).__init__(
context_id=context_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
)
class PipelineRun(base_context.BaseContextSchema):
"""Context schema for a PipelineRun context."""
schema_title = "system.PipelineRun"
def __init__(
self,
*,
pipeline_id: Optional[str] = None,
context_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
):
"""Args:
pipeline_id (str):
Optional. PipelineJob resource name corresponding to this run.
context_id (str):
Optional. The <resource_id> portion of the context name with
the following format, this is globally unique in a metadataStore.
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/contexts/<resource_id>.
display_name (str):
Optional. The user-defined name of the context.
schema_version (str):
Optional. schema_version specifies the version used by the context.
If not set, defaults to use the latest version.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the context.
description (str):
Optional. Describes the purpose of the context to be created.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
extended_metadata["pipeline_id"] = pipeline_id
super(PipelineRun, self).__init__(
context_id=context_id,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
)
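def _example_experiment_run_context() -> ExperimentRun:
    # A minimal, illustrative sketch only: the ids and metadata below are
    # placeholders. The experiment_id is stored inside the context metadata by
    # __init__; persisting the context in Vertex AI is assumed to happen elsewhere.
    return ExperimentRun(
        experiment_id="my-experiment",
        display_name="run-001",
        metadata={"learning_rate": 0.01},
    )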

View File

@@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from typing import Optional, Dict
from google.cloud.aiplatform.compat.types import execution as gca_execution
from google.cloud.aiplatform.metadata.schema import base_execution
class ContainerExecution(base_execution.BaseExecutionSchema):
"""Execution schema for a container execution."""
schema_title = "system.ContainerExecution"
def __init__(
self,
*,
state: Optional[
gca_execution.Execution.State
] = gca_execution.Execution.State.RUNNING,
execution_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
):
"""Args:
state (gca_execution.Execution.State):
Optional. State of this Execution. Defaults to RUNNING.
execution_id (str):
Optional. The <resource_id> portion of the Execution name with
the following format, this is globally unique in a metadataStore.
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
display_name (str):
Optional. The user-defined name of the Execution.
schema_version (str):
Optional. schema_version specifies the version used by the Execution.
If not set, defaults to use the latest version.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Execution.
description (str):
Optional. Describes the purpose of the Execution to be created.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
super(ContainerExecution, self).__init__(
execution_id=execution_id,
state=state,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
)
class CustomJobExecution(base_execution.BaseExecutionSchema):
"""Execution schema for a custom job execution."""
schema_title = "system.CustomJobExecution"
def __init__(
self,
*,
state: Optional[
gca_execution.Execution.State
] = gca_execution.Execution.State.RUNNING,
execution_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
):
"""Args:
state (gca_execution.Execution.State):
Optional. State of this Execution. Defaults to RUNNING.
execution_id (str):
Optional. The <resource_id> portion of the Execution name with
the following format, this is globally unique in a metadataStore.
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
display_name (str):
Optional. The user-defined name of the Execution.
schema_version (str):
Optional. schema_version specifies the version used by the Execution.
If not set, defaults to use the latest version.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Execution.
description (str):
Optional. Describes the purpose of the Execution to be created.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
super(CustomJobExecution, self).__init__(
execution_id=execution_id,
state=state,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
)
class Run(base_execution.BaseExecutionSchema):
"""Execution schema for root run execution."""
schema_title = "system.Run"
def __init__(
self,
*,
state: Optional[
gca_execution.Execution.State
] = gca_execution.Execution.State.RUNNING,
execution_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
metadata: Optional[Dict] = None,
description: Optional[str] = None,
):
"""Args:
state (gca_execution.Execution.State):
Optional. State of this Execution. Defaults to RUNNING.
execution_id (str):
Optional. The <resource_id> portion of the Execution name with
the following format, this is globally unique in a metadataStore.
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/executions/<resource_id>.
display_name (str):
Optional. The user-defined name of the Execution.
schema_version (str):
Optional. schema_version specifies the version used by the Execution.
If not set, defaults to use the latest version.
metadata (Dict):
Optional. Contains the metadata information that will be stored in the Execution.
description (str):
Optional. Describes the purpose of the Execution to be created.
"""
extended_metadata = copy.deepcopy(metadata) if metadata else {}
super(Run, self).__init__(
execution_id=execution_id,
state=state,
display_name=display_name,
schema_version=schema_version,
description=description,
metadata=extended_metadata,
)
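def _example_container_execution() -> ContainerExecution:
    # A minimal, illustrative sketch only: the display name and container image
    # below are placeholders. Execution schemas default to the RUNNING state, and
    # user-supplied metadata is deep-copied before being stored; creating the
    # execution in Vertex AI is assumed to happen elsewhere.
    return ContainerExecution(
        display_name="preprocess-step",
        metadata={"container_image": "gcr.io/my-project/preprocess:latest"},
    )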

View File

@@ -0,0 +1,360 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from typing import Optional, Dict, List
from dataclasses import dataclass
@dataclass
class PredictSchemata:
"""A class holding instance, parameter and prediction schema uris.
Args:
instance_schema_uri (str):
Required. Points to a YAML file stored on Google Cloud Storage
describing the format of a single instance, which is used in
PredictRequest.instances, ExplainRequest.instances and
BatchPredictionJob.input_config. The schema is defined as an
OpenAPI 3.0.2 Schema Object.
parameters_schema_uri (str):
Required. Points to a YAML file stored on Google Cloud Storage
describing the parameters of prediction and explanation via
PredictRequest.parameters, ExplainRequest.parameters and
BatchPredictionJob.model_parameters. The schema is defined as an
OpenAPI 3.0.2 Schema Object.
prediction_schema_uri (str):
Required. Points to a YAML file stored on Google Cloud Storage
describing the format of a single prediction produced by this
Model, which is returned via PredictResponse.predictions,
ExplainResponse.explanations, and BatchPredictionJob.output_config.
The schema is defined as an OpenAPI 3.0.2 Schema Object.
"""
instance_schema_uri: Optional[str] = None
parameters_schema_uri: Optional[str] = None
prediction_schema_uri: Optional[str] = None
def to_dict(self):
"""ML metadata schema dictionary representation of this DataClass.
Returns:
A dictionary that represents the PredictSchemata class.
"""
results = {}
if self.instance_schema_uri:
results["instanceSchemaUri"] = self.instance_schema_uri
if self.parameters_schema_uri:
results["parametersSchemaUri"] = self.parameters_schema_uri
if self.prediction_schema_uri:
results["predictionSchemaUri"] = self.prediction_schema_uri
return results
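def _example_predict_schemata() -> dict:
    # Illustrative sketch only: the gs:// paths below are placeholders. Shows how
    # the snake_case fields map to the camelCase keys emitted by to_dict().
    schemata = PredictSchemata(
        instance_schema_uri="gs://my-bucket/schemata/instance.yaml",
        parameters_schema_uri="gs://my-bucket/schemata/parameters.yaml",
        prediction_schema_uri="gs://my-bucket/schemata/prediction.yaml",
    )
    # Returns {"instanceSchemaUri": ..., "parametersSchemaUri": ..., "predictionSchemaUri": ...}
    return schemata.to_dict()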
@dataclass
class ContainerSpec:
"""Container configuration for the model.
Args:
image_uri (str):
Required. URI of the Docker image to be used as the custom
container for serving predictions. This URI must identify an image
in Artifact Registry or Container Registry.
command (Sequence[str]):
Optional. Specifies the command that runs when the container
starts. This overrides the container's `ENTRYPOINT`.
args (Sequence[str]):
Optional. Specifies arguments for the command that runs when the
container starts. This overrides the container's `CMD`
env (Sequence[google.cloud.aiplatform_v1.types.EnvVar]):
Optional. List of environment variables to set in the container.
After the container starts running, code running in the container
can read these environment variables. Additionally, the command
and args fields can reference these variables. Later entries in
this list can also reference earlier entries. For example, the
following example sets the variable ``VAR_2`` to have the value
``foo bar``:
.. code:: json
    [
        { "name": "VAR_1", "value": "foo" },
        { "name": "VAR_2", "value": "$(VAR_1) bar" }
    ]
If you switch the order of the variables in the example, then the
expansion does not occur. This field corresponds to the ``env``
field of the Kubernetes Containers v1 core API.
ports (Sequence[google.cloud.aiplatform_v1.types.Port]):
Optional. List of ports to expose from the container. Vertex AI
sends any prediction requests that it receives to the first port on
this list. Vertex AI also sends liveness and health checks.
predict_route (str):
Optional. HTTP path on the container to send prediction requests
to. Vertex AI forwards requests sent using
projects.locations.endpoints.predict to this path on the
container's IP address and port. Vertex AI then returns the
container's response in the API response. For example, if you set
this field to ``/foo``, then when Vertex AI receives a prediction
request, it forwards the request body in a POST request to the
``/foo`` path on the port of your container specified by the first
value of this ``ModelContainerSpec``'s ports field. If you don't
specify this field, it defaults to the following value when you
deploy this Model to an Endpoint:
/v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict
The placeholders in this value are replaced as follows:
- ENDPOINT: The last segment (following ``endpoints/``) of the
Endpoint.name field of the Endpoint where this Model has been
deployed. (Vertex AI makes this value available to your container
code as the ``AIP_ENDPOINT_ID`` environment variable.)
health_route (str):
Optional. HTTP path on the container to send health checks to.
Vertex AI intermittently sends GET requests to this path on the
container's IP address and port to check that the container is
healthy. Read more about health checks.
"""
image_uri: str
command: Optional[List[str]] = None
args: Optional[List[str]] = None
env: Optional[List[Dict[str, str]]] = None
ports: Optional[List[int]] = None
predict_route: Optional[str] = None
health_route: Optional[str] = None
def to_dict(self):
"""ML metadata schema dictionary representation of this DataClass.
Returns:
A dictionary that represents the ContainerSpec class.
"""
results = {}
results["imageUri"] = self.image_uri
if self.command:
results["command"] = self.command
if self.args:
results["args"] = self.args
if self.env:
results["env"] = self.env
if self.ports:
results["ports"] = self.ports
if self.predict_route:
results["predictRoute"] = self.predict_route
if self.health_route:
results["healthRoute"] = self.health_route
return results
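def _example_container_spec() -> dict:
    # Illustrative sketch only: the image URI, routes, and port below are
    # placeholders. Shows how ContainerSpec fields map to the camelCase keys
    # emitted by to_dict().
    spec = ContainerSpec(
        image_uri="us-docker.pkg.dev/my-project/serving/model-server:latest",
        args=["--model-dir=/models"],
        env=[{"name": "LOG_LEVEL", "value": "INFO"}],
        ports=[8080],
        predict_route="/predict",
        health_route="/health",
    )
    # e.g. {"imageUri": ..., "args": [...], "predictRoute": "/predict", "healthRoute": "/health", ...}
    return spec.to_dict()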
@dataclass
class AnnotationSpec:
"""A class that represents the annotation spec of a Confusion Matrix.
Args:
display_name (str):
Optional. Display name for a column of a confusion matrix.
id (str):
Optional. Id for a column of a confusion matrix.
"""
display_name: Optional[str] = None
id: Optional[str] = None
def to_dict(self):
"""ML metadata schema dictionary representation of this DataClass.
Returns:
A dictionary that represents the AnnotationSpec class.
"""
results = {}
if self.display_name:
results["displayName"] = self.display_name
if self.id:
results["id"] = self.id
return results
@dataclass
class ConfusionMatrix:
"""A class that represents a Confusion Matrix.
Args:
matrix (List[List[int]]):
Required. A 2D array of integers that represents the values for the confusion matrix.
annotation_specs (List[AnnotationSpec]):
Optional. List of column annotation specs, each containing a display_name (str) and an id (str).
"""
matrix: List[List[int]]
annotation_specs: Optional[List[AnnotationSpec]] = None
def to_dict(self):
"""ML metadata schema dictionary representation of this DataClass.
Returns:
A dictionary that represents the ConfusionMatrix class.
Raises:
ValueError: if annotation_specs and matrix have different length.
"""
results = {}
if self.annotation_specs:
if len(self.annotation_specs) != len(self.matrix):
raise ValueError(
"Length of annotation_specs and matrix must be the same. "
"Got lengths {} and {} respectively.".format(
len(self.annotation_specs), len(self.matrix)
)
)
results["annotationSpecs"] = [
annotation_spec.to_dict() for annotation_spec in self.annotation_specs
]
if self.matrix:
results["rows"] = self.matrix
return results
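def _example_confusion_matrix() -> dict:
    # Illustrative sketch only: a 2x2 matrix for two hypothetical classes. Note
    # that to_dict() raises ValueError if annotation_specs and matrix differ in
    # length, so one AnnotationSpec is supplied per row.
    matrix = ConfusionMatrix(
        matrix=[[38, 2], [5, 55]],
        annotation_specs=[
            AnnotationSpec(display_name="cat"),
            AnnotationSpec(display_name="dog"),
        ],
    )
    # Returns {"annotationSpecs": [...], "rows": [[38, 2], [5, 55]]}
    return matrix.to_dict()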
@dataclass
class ConfidenceMetric:
"""A class that represents a Confidence Metric.
Args:
confidence_threshold (float):
Required. Metrics are computed with an assumption that the Model never returns predictions with a score lower than this value.
For binary classification this is the positive class threshold. For multi-class classification this is the confidence threshold.
recall (float):
Optional. Recall (True Positive Rate) for the given confidence threshold.
precision (float):
Optional. Precision for the given confidence threshold.
f1_score (float):
Optional. The harmonic mean of recall and precision.
max_predictions (int):
Optional. Metrics are computed with an assumption that the Model always returns at most this many predictions (ordered by score in descending order),
all of which still need to meet the `confidence_threshold`.
false_positive_rate (float):
Optional. False Positive Rate for the given confidence threshold.
accuracy (float):
Optional. Accuracy is the fraction of predictions that were assigned the correct label. For multi-class classification this is a micro-averaged metric.
true_positive_count (int):
Optional. The number of Model created labels that match a ground truth label.
false_positive_count (int):
Optional. The number of Model created labels that do not match a ground truth label.
false_negative_count (int):
Optional. The number of ground truth labels that are not matched by a Model created label.
true_negative_count (int):
Optional. The number of labels that were not created by the Model and that, had they been created, would not have matched a ground truth label.
recall_at_1 (float):
Optional. The Recall (True Positive Rate) when only considering the label that has the highest prediction score
and not below the confidence threshold for each DataItem.
precision_at_1 (float):
Optional. The precision when only considering the label that has the highest prediction score
and not below the confidence threshold for each DataItem.
false_positive_rate_at_1 (float):
Optional. The False Positive Rate when only considering the label that has the highest prediction score
and not below the confidence threshold for each DataItem.
f1_score_at_1 (float):
Optional. The harmonic mean of recallAt1 and precisionAt1.
confusion_matrix (ConfusionMatrix):
Optional. Confusion matrix for the given confidence threshold.
"""
confidence_threshold: float
recall: Optional[float] = None
precision: Optional[float] = None
f1_score: Optional[float] = None
max_predictions: Optional[int] = None
false_positive_rate: Optional[float] = None
accuracy: Optional[float] = None
true_positive_count: Optional[int] = None
false_positive_count: Optional[int] = None
false_negative_count: Optional[int] = None
true_negative_count: Optional[int] = None
recall_at_1: Optional[float] = None
precision_at_1: Optional[float] = None
false_positive_rate_at_1: Optional[float] = None
f1_score_at_1: Optional[float] = None
confusion_matrix: Optional[ConfusionMatrix] = None
def to_dict(self):
"""ML metadata schema dictionary representation of this DataClass.
Returns:
A dictionary that represents the ConfidenceMetric class.
"""
results = {}
results["confidenceThreshold"] = self.confidence_threshold
if self.recall is not None:
results["recall"] = self.recall
if self.precision is not None:
results["precision"] = self.precision
if self.f1_score is not None:
results["f1Score"] = self.f1_score
if self.max_predictions is not None:
results["maxPredictions"] = self.max_predictions
if self.false_positive_rate is not None:
results["falsePositiveRate"] = self.false_positive_rate
if self.accuracy is not None:
results["accuracy"] = self.accuracy
if self.true_positive_count is not None:
results["truePositiveCount"] = self.true_positive_count
if self.false_positive_count is not None:
results["falsePositiveCount"] = self.false_positive_count
if self.false_negative_count is not None:
results["falseNegativeCount"] = self.false_negative_count
if self.true_negative_count is not None:
results["trueNegativeCount"] = self.true_negative_count
if self.recall_at_1 is not None:
results["recallAt1"] = self.recall_at_1
if self.precision_at_1 is not None:
results["precisionAt1"] = self.precision_at_1
if self.false_positive_rate_at_1 is not None:
results["falsePositiveRateAt1"] = self.false_positive_rate_at_1
if self.f1_score_at_1 is not None:
results["f1ScoreAt1"] = self.f1_score_at_1
if self.confusion_matrix:
results["confusionMatrix"] = self.confusion_matrix.to_dict()
return results
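def _example_confidence_metric() -> dict:
    # Illustrative sketch only, with made-up values: the metrics for a single
    # confidence threshold, optionally nesting a ConfusionMatrix, serialized to
    # the camelCase keys emitted by to_dict().
    metric = ConfidenceMetric(
        confidence_threshold=0.5,
        recall=0.91,
        precision=0.88,
        f1_score=0.895,
        confusion_matrix=ConfusionMatrix(matrix=[[40, 3], [4, 53]]),
    )
    # Returns {"confidenceThreshold": 0.5, "recall": 0.91, ..., "confusionMatrix": {...}}
    return metric.to_dict()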
def create_uri_from_resource_name(resource_name: str) -> str:
"""Construct the service URI for a given resource_name.
Args:
resource_name (str):
The name of the Vertex resource, in one of the forms:
projects/{project}/locations/{location}/{resource_type}/{resource_id}
projects/{project}/locations/{location}/{resource_type}/{resource_id}@{version}
projects/{project}/locations/{location}/metadataStores/{store_id}/{resource_type}/{resource_id}
projects/{project}/locations/{location}/metadataStores/{store_id}/{resource_type}/{resource_id}@{version}
Returns:
The resource URI in the form of:
https://{service-endpoint}/v1/{resource_name},
where {service-endpoint} is one of the supported service endpoints at
https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints
Raises:
ValueError: If resource_name does not match the specified format.
"""
# TODO: support nested resource names such as models/123/evaluations/456
match_results = re.match(
r"^projects\/(?P<project>[\w-]+)\/locations\/(?P<location>[\w-]+)(\/metadataStores\/(?P<store>[\w-]+))?\/[\w-]+\/(?P<id>[\w-]+)(?P<version>@[\w-]+)?$",
resource_name,
)
if not match_results:
raise ValueError(f"Invalid resource_name format for {resource_name}.")
location = match_results["location"]
return f"https://{location}-aiplatform.googleapis.com/v1/{resource_name}"

View File

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional, Union
def _make_filter_string(
schema_title: Optional[Union[str, List[str]]] = None,
in_context: Optional[List[str]] = None,
parent_contexts: Optional[List[str]] = None,
uri: Optional[str] = None,
) -> str:
"""Helper method to format filter strings for Metadata querying.
No enforcement of correctness.
Args:
schema_title (Union[str, List[str]]): Optional. schema_titles to filter for.
in_context (List[str]):
Optional. Context resource names that the node should be in. Only for Artifacts/Executions.
parent_contexts (List[str]): Optional. Parent contexts the context should be in. Only for Contexts.
uri (str): Optional. URI to match. Only for Artifacts.
Returns:
String that can be used for Metadata service filtering.
"""
parts = []
if schema_title:
if isinstance(schema_title, str):
parts.append(f'schema_title="{schema_title}"')
else:
substring = " OR ".join(f'schema_title="{s}"' for s in schema_title)
parts.append(f"({substring})")
if in_context:
for context in in_context:
parts.append(f'in_context("{context}")')
if parent_contexts:
parent_context_str = ",".join([f'"{c}"' for c in parent_contexts])
parts.append(f"parent_contexts:{parent_context_str}")
if uri:
parts.append(f'uri="{uri}"')
return " AND ".join(parts)